In [24]:
!pip install optuna
!pip install efficientnet_pytorch




In [25]:

import torch
import pytorch_lightning as pl
import optuna
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet50
from efficientnet_pytorch import EfficientNet
import wandb

In [26]:
# Initialize WandB
wandb.init()

VBox(children=(Label(value='0.001 MB of 0.029 MB uploaded\r'), FloatProgress(value=0.025984993666157793, max=1…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01116753888911464, max=1.0)…

In [27]:
class ImageClassificationModel(pl.LightningModule):
    def __init__(self, backbone, learning_rate):
        super().__init__()
        self.backbone = backbone
        self.learning_rate = learning_rate
        
        # Define backbone architecture
        if backbone == 'ResNet':
            self.model = resnet18(pretrained=True)
            num_ftrs = self.model.fc.in_features
            self.model.fc = torch.nn.Linear(num_ftrs, 6)
        elif backbone == 'EfficientNet':
            self.model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=6)
            
        self.criterion = torch.nn.CrossEntropyLoss()    
    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        self.log('val_loss', loss)
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        self.log('test_loss', loss)
        

In [28]:
# Load data
train_data = ImageFolder(root='archive/seg_train/seg_train/', transform=transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
]))
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

test_data = ImageFolder(root="archive/seg_test/seg_test/", transform=transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
]))
test_loader = DataLoader(test_data, batch_size=64)



In [29]:
# Define objective function for Optuna
def objective(trial):
    backbone = trial.suggest_categorical('backbone', ['ResNet', 'EfficientNet'])
    learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1e-2)
    
    model = ImageClassificationModel(backbone=backbone, learning_rate=learning_rate)
    trainer = pl.Trainer(max_epochs=1)  # Change max_epochs to 1
    trainer.fit(model, train_loader, test_loader)
    
    # return validation loss as the metric to minimize
    return trainer.logged_metrics['val_loss'].item()

In [30]:
# Optuna hyperparameter optimization
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=1)

# Get best hyperparameters
best_params = study.best_params
best_backbone = best_params['backbone']
best_learning_rate = best_params['learning_rate']

# Train model with best hyperparameters
best_model = ImageClassificationModel(backbone=best_backbone, learning_rate=best_learning_rate)
trainer = pl.Trainer(max_epochs=1)  # Change max_epochs to 1
trainer.fit(best_model, train_loader, test_loader)

# Evaluate on test set
trainer.test(best_model, test_loader)

[I 2024-04-28 19:36:22,663] A new study created in memory with name: no-name-f50e9157-6d51-4dff-ae11-9529e20877b9
  learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1e-2)
Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /Users/maryamsoftdev/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth
100%|██████████████████████████████████████| 20.4M/20.4M [00:12<00:00, 1.73MB/s]


Loaded pretrained weights for efficientnet-b0


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /Users/maryamsoftdev/Downloads/lightning_logs

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | EfficientNet     | 4.0 M 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.061    Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Users/maryamsoftdev/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/maryamsoftdev/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]

python(72470) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72471) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72472) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72473) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72475) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72478) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72484) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72489) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72496) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72499) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72504) Malloc

python(72797) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72799) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72800) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72801) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72802) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72806) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72810) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72811) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72812) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72813) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(72814) Malloc

python(73067) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73069) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73071) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73073) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73076) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73082) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73089) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73093) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73096) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73097) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Validation: |                                             | 0/? [00:00<?, ?it/s]

python(73098) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73100) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73104) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73106) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73107) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73109) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73111) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73115) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73117) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73122) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73123) Malloc

Loaded pretrained weights for efficientnet-b0


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

python(73134) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73139) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73141) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Training: |                                               | 0/? [00:00<?, ?it/s]

python(73145) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73147) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73148) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73149) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73150) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73155) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73156) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73157) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73164) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73168) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73170) Malloc

python(73423) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73427) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73433) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73435) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73439) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73453) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73463) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73465) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73467) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73468) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73470) Malloc

Validation: |                                             | 0/? [00:00<?, ?it/s]

python(73695) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73697) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73698) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73701) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73705) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73706) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73708) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73709) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73716) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
`Trainer.fit` stopped: `max_epochs=1` reached.
/Users/maryamsoftdev/anaconda3/lib/python3.11/site-packages/pytorch_lig

Testing: |                                                | 0/? [00:00<?, ?it/s]

python(73723) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73728) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73731) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73734) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73738) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73741) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73744) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(73748) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.2132299393415451
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


python(73754) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


[{'test_loss': 0.2132299393415451}]