In [40]:
from code.dataset import H5Dataset
from torchvision import transforms
import torch
from code.MyPytorchModel import MyPytorchModel
import torchvision
import optuna
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import Callback
import pytorch_lightning as pl

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
# Colab setup: disable HDF5 file locking and make the project code importable.
# https://stackoverflow.com/questions/57310333/can-we-disable-h5py-file-locking-for-python-file-like-object
import os
import sys

# `!export ...` runs in a separate subshell that exits immediately, so it never
# reaches this kernel's environment. Set the variable on the kernel process
# itself instead. It must be in place before any HDF5 file is opened
# (safest: before h5py is imported).
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): the `from code...` imports in the first cell rely on these
# paths; on a fresh kernel this cell has to run before that one.
PROJECT_DIR = '/content/drive/My Drive/diganes'
# Skip paths already present so re-running the cell does not grow sys.path.
for _path in (PROJECT_DIR, os.path.join(PROJECT_DIR, 'code')):
    if _path not in sys.path:
        sys.path.append(_path)

In [76]:
# Per-channel normalization statistics.
# NOTE(review): these look like dataset-wide channel mean/std — TODO confirm
# they were computed on the training split only.
CHANNEL_MEAN = [0.49191375, 0.48235852, 0.44673872]
CHANNEL_STD = [0.24706447, 0.24346213, 0.26147554]

# Training-time pipeline: convert to a PIL image, apply a random horizontal
# flip (augmentation), convert back to a tensor, then normalize per channel.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
])

In [77]:
train_file = 'diganes_train_dataset.h5'
val_file = 'diganes_val_dataset.h5'
test_file = 'diganes_test_dataset.h5'

# Evaluation must be deterministic. Reuse the training pipeline minus its
# random augmentation, so val/test metrics are not computed on randomly
# flipped images and do not change from one run to the next.
eval_transform = transforms.Compose(
    [step for step in transform.transforms
     if not isinstance(step, transforms.RandomHorizontalFlip)])

dataset = {"train": H5Dataset(train_file, transform=transform),
           "val": H5Dataset(val_file, transform=eval_transform),
           "test": H5Dataset(test_file, transform=eval_transform)}

In [83]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [82]:
from pytorch_lightning import Callback

class MetricsCallback(Callback):
    """PyTorch Lightning callback that records the trainer's callback metrics
    at the end of every validation run.

    Attributes:
        metrics: list with one snapshot dict of ``trainer.callback_metrics``
            per validation run, in the order the runs happened.
    """

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        # Store a copy. Some Lightning versions update `callback_metrics` in
        # place, so appending the live dict would make every entry alias the
        # most recent values.
        self.metrics.append(dict(trainer.callback_metrics))

In [85]:
def objective(trial):
    """Optuna objective: train one MobileNetV2-based model and score it.

    Builds a Lightning trainer with early stopping on ``val_loss``, trains
    ``MyPytorchModel`` on the notebook-level ``dataset``, then passes the model
    to ``save_model`` with ``'<trial.number>.p'`` and ``"checkpoints"``
    (presumably a file name and a directory).

    Args:
        trial: the Optuna trial being evaluated. Only ``trial.number`` is used;
            the hyperparameters are currently fixed (see ``hparams``).

    Returns:
        The ``val_loss`` recorded after the last validation run, as a float.
        This is the value the study minimizes.

    Raises:
        RuntimeError: if no validation run recorded any metrics.
    """
    metrics_callback = MetricsCallback()

    # Stop once val_loss has not decreased for `patience` validation checks.
    early_stop_callback = EarlyStopping(
       monitor='val_loss',
       patience=5,
       verbose=True,
       mode='min'
    )

    # `gpus=` and `early_stop_callback=` belong to the older Lightning Trainer
    # API this notebook targets; later Lightning releases removed them.
    trainer = pl.Trainer(
        logger=True,
        max_epochs=20,
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[metrics_callback],
        early_stop_callback=early_stop_callback,
    )

    # TODO: these values are hard-coded, so every trial trains the same
    # configuration. Sample them with `trial.suggest_*` to make the search
    # meaningful.
    hparams = {
        "batch_size": 64,
        "lr": 3e-4,
        "layers_to_freeze": 13
    }

    # Create the model from these hyperparameters and train it.
    model = MyPytorchModel(hparams, dataset, torchvision.models.mobilenet_v2(pretrained=True))
    model.to(device)
    trainer.fit(model)

    # NOTE(review): `save_model` is neither defined nor imported anywhere in
    # this notebook. It presumably came from a deleted cell; define or import
    # it first, otherwise this line raises NameError after training finishes.
    save_model(model, '{}.p'.format(trial.number), "checkpoints")

    if not metrics_callback.metrics:
        raise RuntimeError(
            "no validation metrics were recorded for trial {}".format(trial.number))
    # Optuna expects a plain number, and the logged metric may be a tensor.
    return float(metrics_callback.metrics[-1]["val_loss"])

In [None]:
# Minimize the validation loss returned by `objective`. NopPruner never prunes
# a trial, so each trial trains until max_epochs or early stopping.
pruner = optuna.pruners.NopPruner()
study = optuna.create_study(pruner=pruner, direction="minimize")
study.optimize(objective, n_trials=1)

GPU available: False, used: False
No environment variable for node rank defined. Set as 0.

    | Name                                | Type                 | Params
-------------------------------------------------------------------------
0   | model                               | PretrainedClassifier | 2 M   
1   | model.feature_extractor             | Sequential           | 2 M   
2   | model.feature_extractor.0           | ConvBNReLU           | 928   
3   | model.feature_extractor.0.0         | Conv2d               | 864   
4   | model.feature_extractor.0.1         | BatchNorm2d          | 64    
5   | model.feature_extractor.0.2         | ReLU6                | 0     
6   | model.feature_extractor.1           | InvertedResidual     | 896   
7   | model.feature_extractor.1.conv      | Sequential           | 896   
8   | model.feature_extractor.1.conv.0    | ConvBNReLU           | 352   
9   | model.feature_extractor.1.conv.0.0  | Conv2d               | 288   
10  | model.feature_

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Val-F1=0.13, Val-Loss:0.69


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…