In [None]:
# Download data
!pip install -q kaggle
from google.colab import files
files.upload()
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d 'andrewmvd/medical-mnist'
!unzip -q medical-mnist.zip -d data
!rm medical-mnist.zip

In [None]:
# Installations
!pip install pytorch-lightning
!pip install wandb

In [None]:
# Imports
import os
import glob

from torchvision.datasets import ImageFolder
import torch
import torchmetrics
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from pytorch_lightning.core.lightning import LightningModule
import optuna
from optuna.integration import PyTorchLightningPruningCallback

In [None]:
# Global variables
# Best validation F1 seen so far across Optuna trials; `objective` compares each
# finished trial against it to decide whether to replace the saved checkpoint.
BEST_F1 = 0

In [None]:
# Data

class MedicalDataMNIST(pl.LightningDataModule):
    """LightningDataModule over the Medical MNIST image folders.

    `setup` reads every image under `data_dir` with `ImageFolder` and splits
    it 70% / 20% / 10% into train / val / test subsets using a generator seeded
    with `seed`, so every instance (e.g. every Optuna trial) sees the same
    split. Training images get random augmentation; validation and test images
    are only converted to tensors and normalized.

    Args:
        batch_size: samples per batch for all three dataloaders.
        num_workers: DataLoader worker processes.
        data_dir: root folder containing one sub-folder per class.
        seed: seed for the train/val/test split.
    """

    # ImageNet channel statistics (the backbones used here are ImageNet-pretrained).
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, batch_size=64, num_workers=1, data_dir="./data/", seed=42):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.seed = seed
        # Class index -> class name. `setup` refreshes this from the dataset's
        # own (sorted) class list.
        self.labels_map = {0: "AbdomenCT",
                           1: "BreastMRI",
                           2: "CXR",
                           3: "ChestCT",
                           4: "Hand",
                           5: "HeadCT"}
        self.train_transform = transforms.Compose([
            transforms.ColorJitter(hue=.20, saturation=.20),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ])
        self.val_test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ])

    def prepare_data(self):
        # Lightning calls this on a single process only (the place for
        # downloads). The data is fetched by the Kaggle cell at the top of the
        # notebook, so there is nothing to do here.
        pass

    def setup(self, stage=None):
        """Build the train/val/test subsets with a seeded 70/20/10 split.

        The train subset wraps a different `ImageFolder` object than the
        val/test subsets. Subsets share their parent dataset, so setting a
        transform through one subset would silently change it for the others
        and the training set would lose its augmentation.
        """
        train_view = ImageFolder(self.data_dir, transform=self.train_transform)
        eval_view = ImageFolder(self.data_dir, transform=self.val_test_transform)
        self.dataset = eval_view
        self.labels_map = dict(enumerate(eval_view.classes))

        n_samples = len(eval_view)
        train_size = int(0.7 * n_samples)  # 70% for training
        val_size = int(0.2 * n_samples)  # 20% for validation
        # Everything left over (~10%, including rounding remainders) is the test set.
        generator = torch.Generator().manual_seed(self.seed)
        indices = torch.randperm(n_samples, generator=generator).tolist()

        self.train_set = torch.utils.data.Subset(train_view, indices[:train_size])
        self.val_set = torch.utils.data.Subset(
            eval_view, indices[train_size:train_size + val_size])
        self.test_set = torch.utils.data.Subset(
            eval_view, indices[train_size + val_size:])

    def train_dataloader(self):
        return DataLoader(self.train_set,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_set,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_set,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def _denormalize(self, norm_img):
        """Undo the Normalize transform on a (C, H, W) tensor, clamped to [0, 1] for imshow."""
        mean = torch.tensor(self.MEAN)[:, None, None]
        std = torch.tensor(self.STD)[:, None, None]
        return (norm_img * std + mean).clamp(0, 1)

    def visualize_dataset(self):
        """Plot a 3x3 grid of random training samples titled with their class names."""
        figure = plt.figure(figsize=(8, 8))
        cols, rows = 3, 3
        for i in range(1, cols * rows + 1):
            sample_idx = torch.randint(len(self.train_set), size=(1,)).item()
            norm_img, label = self.train_set[sample_idx]
            figure.add_subplot(rows, cols, i)
            plt.title(self.labels_map[label])
            plt.axis("off")
            plt.imshow(self._denormalize(norm_img).permute(1, 2, 0))
        plt.show()

    def visualize_dataloader(self):
        """Print the shapes of one training batch and display its first image."""
        train_features, train_labels = next(iter(self.train_dataloader()))
        print(f"Feature batch shape: {train_features.size()}")
        print(f"Labels batch shape: {train_labels.size()}")
        plt.imshow(self._denormalize(train_features[0]).permute(1, 2, 0))
        plt.show()
        print(f"Label: {self.labels_map[train_labels[0].item()]}")

In [None]:
# Sanity-check the data pipeline: build the splits, then show a sample grid
# and one batch from the training dataloader.
dm = MedicalDataMNIST()
dm.setup()
for show in (dm.visualize_dataset, dm.visualize_dataloader):
    show()

In [None]:
# Model

class MedicalMNIST(LightningModule):
    """Six-way Medical MNIST classifier fine-tuned from a pretrained torchvision backbone.

    Args:
        model: backbone name, one of "EfficientNetb0", "VGG16", "InceptionV3",
            "ResNet18". Its final layer(s) are replaced by 6-way heads.
        optimaizer: optimizer name, one of "Adam", "SGD", "RMSprop", "ASGD".
        lr, betas, eps, weight_decay, momentum, alpha, lambd, asgd_alpha:
            optimizer hyperparameters; each optimizer reads only the ones its
            constructor accepts (see `_build_optimizer`).
        dropout: dropout probability of the EfficientNetb0 head (ignored by
            the other backbones).

    Raises:
        ValueError: if `model` or `optimaizer` is not one of the names above.
    """

    NUM_CLASSES = 6
    # torchvision's Inception v3 is built for 299x299 inputs and cannot process
    # much smaller images (its stride-2 stages shrink them to nothing).
    INCEPTION_INPUT_SIZE = (299, 299)
    # Weight of Inception's auxiliary-classifier loss during training.
    INCEPTION_AUX_WEIGHT = 0.4

    def __init__(self, model="EfficientNetb0", optimaizer="Adam", lr=1e-4,
                 betas=(0.9, 0.999), eps=1e-08, weight_decay=0, momentum=0,
                 alpha=0.99, lambd=1e-4, asgd_alpha=0.75, dropout=0.2):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # `MedicalMNIST.load_from_checkpoint` rebuilds the same backbone.
        self.save_hyperparameters()

        self.model = self._build_backbone(model, dropout)
        self.name = model

        # Built after the backbone so that it receives all model parameters.
        self.optimaizer = self._build_optimizer(
            optimaizer, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            momentum=momentum, alpha=alpha, lambd=lambd, asgd_alpha=asgd_alpha)

        # Metrics. Micro-averaged F1 equals accuracy for single-label
        # multiclass data, so F1 is macro-averaged over the 6 classes.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.train_f1 = torchmetrics.F1(num_classes=self.NUM_CLASSES, average="macro")
        self.val_f1 = torchmetrics.F1(num_classes=self.NUM_CLASSES, average="macro")

    @classmethod
    def _build_backbone(cls, name, dropout):
        """Return pretrained torchvision backbone `name` with NUM_CLASSES-way head(s)."""
        n_out = cls.NUM_CLASSES
        if name == "EfficientNetb0":
            backbone = models.efficientnet_b0(pretrained=True)
            backbone.classifier = torch.nn.Sequential(
                torch.nn.Dropout(p=dropout, inplace=False),
                torch.nn.Linear(in_features=backbone.classifier[1].in_features,
                                out_features=n_out),
            )
        elif name == "VGG16":
            backbone = models.vgg16(pretrained=True)
            backbone.classifier[-1] = torch.nn.Linear(
                in_features=backbone.classifier[-1].in_features, out_features=n_out)
        elif name == "InceptionV3":
            backbone = models.inception_v3(pretrained=True)
            # Auxiliary head (used for an extra loss term during training).
            backbone.AuxLogits.fc = torch.nn.Linear(
                in_features=backbone.AuxLogits.fc.in_features, out_features=n_out)
            # Primary head.
            backbone.fc = torch.nn.Linear(
                in_features=backbone.fc.in_features, out_features=n_out)
        elif name == "ResNet18":
            backbone = models.resnet18(pretrained=True)
            backbone.fc = torch.nn.Linear(
                in_features=backbone.fc.in_features, out_features=n_out)
        else:
            raise ValueError(
                f"Unknown model {name!r}; expected one of "
                "'EfficientNetb0', 'VGG16', 'InceptionV3', 'ResNet18'")
        return backbone

    def _build_optimizer(self, name, lr, betas, eps, weight_decay, momentum,
                         alpha, lambd, asgd_alpha):
        """Return the torch optimizer `name` over this module's parameters."""
        params = self.parameters()
        if name == "Adam":
            return torch.optim.Adam(params, lr=lr, betas=betas, eps=eps,
                                    weight_decay=weight_decay)
        if name == "SGD":
            return torch.optim.SGD(params, lr=lr, momentum=momentum,
                                   weight_decay=weight_decay)
        if name == "RMSprop":
            return torch.optim.RMSprop(params, lr=lr, alpha=alpha, eps=eps,
                                       weight_decay=weight_decay,
                                       momentum=momentum)
        if name == "ASGD":
            return torch.optim.ASGD(params, lr=lr, lambd=lambd, alpha=asgd_alpha,
                                    weight_decay=weight_decay)
        raise ValueError(
            f"Unknown optimizer {name!r}; expected one of "
            "'Adam', 'SGD', 'RMSprop', 'ASGD'")

    def forward(self, x):
        """Run the backbone on image batch `x`.

        For InceptionV3 the batch is first resized to 299x299. In training mode
        that backbone returns a (logits, aux_logits) tuple instead of a tensor.
        """
        if (self.name == "InceptionV3"
                and tuple(x.shape[-2:]) != self.INCEPTION_INPUT_SIZE):
            x = torch.nn.functional.interpolate(
                x, size=self.INCEPTION_INPUT_SIZE, mode="bilinear",
                align_corners=False)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        if isinstance(output, tuple):
            # Inception v3 in train mode: add the down-weighted auxiliary loss.
            logits, aux_logits = output
            loss = (torch.nn.functional.cross_entropy(logits, y)
                    + self.INCEPTION_AUX_WEIGHT
                    * torch.nn.functional.cross_entropy(aux_logits, y))
        else:
            logits = output
            loss = torch.nn.functional.cross_entropy(logits, y)

        # Class-index predictions: accepted by every torchmetrics version,
        # unlike raw logits, which some versions reject as non-probabilities.
        preds = logits.argmax(dim=1)
        self.train_acc(preds, y)
        self.train_f1(preds, y)

        self.log("loss/train", loss, on_step=True, on_epoch=True)
        self.log("accuracy/train", self.train_acc, on_step=False, on_epoch=True)
        self.log("f1/train", self.train_f1, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        val_loss = torch.nn.functional.cross_entropy(logits, y)

        # Update metric state, then log the Metric objects themselves so
        # Lightning computes accuracy/F1 over the whole validation epoch
        # instead of averaging per-batch scores.
        preds = logits.argmax(dim=1)
        self.val_acc(preds, y)
        self.val_f1(preds, y)

        self.log("loss/val", val_loss)
        self.log("accuracy/val", self.val_acc, prog_bar=True, logger=True)
        self.log("f1/val", self.val_f1, prog_bar=True, logger=True)
        return val_loss

    def configure_optimizers(self):
        return self.optimaizer

In [None]:
# # Custom callbacks

# class TensorCallback(pl.Callback):
#   """
#   To fill...
#   """
#   def on_train_epoch_start(self, trainer, _):
#         """ Check if we should save a checkpoint after every train epoch """
#         epoch = trainer.current_epoch
#         if epoch == 1:
#           try:
#             from tensorboard import notebook

#           except:

#             # Load Tensorboard
#             %load_ext tensorboard

#             # Adjust the height of the tensorboard
#             from tensorboard import notebook
#             notebook.display(height=2000)

#             # Show the logs 
#             %tensorboard --logdir=lightning_logs/



In [None]:
def _save_if_best(trainer, trial_number, model_name, current_f1,
                  best_model_dir="./best_model/"):
    """Checkpoint `trainer` into `best_model_dir` when this trial is the best so far.

    The first trial of a study (number 0) always becomes the new best. Later
    trials must beat the module-level BEST_F1. Previously saved `.ckpt` files
    in `best_model_dir` are deleted first, so only the best model is kept.
    """
    global BEST_F1
    if trial_number != 0 and current_f1 <= BEST_F1:
        return
    BEST_F1 = current_f1
    os.makedirs(best_model_dir, exist_ok=True)
    for stale_ckpt in glob.glob(os.path.join(best_model_dir, "*.ckpt")):
        os.remove(stale_ckpt)
    file_name = f"trial_{trial_number}_{model_name}_f1={current_f1}.ckpt"
    trainer.save_checkpoint(os.path.join(best_model_dir, file_name))


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: train one sampled configuration and return its "f1/val".

    Samples a backbone, an optimizer and their hyperparameters, then trains for
    at most 20 epochs. "f1/val" is reported to Optuna for pruning, and training
    stops early after 5 epochs without improvement. The run is checkpointed if
    it is the best so far, and the last logged "f1/val" is returned.
    """
    # Hyperparameters
    model_name = trial.suggest_categorical(
        "model_name", ["EfficientNetb0", "VGG16", "InceptionV3", "ResNet18"]
    )
    optimizer_name = trial.suggest_categorical(
        "optimizer_name", ["Adam", "SGD", "RMSprop", "ASGD"]
    )
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64, 128])
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1, log=True)
    betas = (trial.suggest_float("beta_1", 0.8, 0.95),
             trial.suggest_float("beta_2", 0.995, 0.9999))
    eps = trial.suggest_float("eps", 1e-09, 1e-07, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-10, 1e-3, log=True)
    momentum = trial.suggest_float("momentum", 1e-5, 1e-1, log=True)
    alpha = trial.suggest_float("alpha", 0.9, 1)
    lambd = trial.suggest_float("lambd", 1e-5, 1e-2, log=True)
    asgd_alpha = trial.suggest_float("asgd_alpha", 0.7, 0.8)
    dropout = trial.suggest_categorical("dropout", [0.1, 0.2, 0.3, 0.4, 0.5])

    # Model and data
    model = MedicalMNIST(
        model=model_name,
        optimaizer=optimizer_name,  # the constructor's parameter is spelled `optimaizer`
        lr=learning_rate,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        alpha=alpha,
        lambd=lambd,
        asgd_alpha=asgd_alpha,
        dropout=dropout
    )
    datamodule = MedicalDataMNIST(
        batch_size=batch_size,  # 128 batch_size is max
        num_workers=os.cpu_count()
    )

    # Logger
    logger = pl.loggers.TensorBoardLogger(
        "logs",
        name=None,
        version=f"trial_{trial.number}_{model_name}_{learning_rate}"
    )

    # Trainer
    trainer = pl.Trainer(
        logger=logger,
        checkpoint_callback=False,
        max_epochs=20,
        gpus=torch.cuda.device_count() if torch.cuda.is_available() else None,
        callbacks=[
            PyTorchLightningPruningCallback(trial, monitor="f1/val"),
            # F1 should go up; EarlyStopping's default mode is "min".
            pl.callbacks.early_stopping.EarlyStopping(monitor="f1/val",
                                                      mode="max",
                                                      patience=5),
        ]
    )

    hyperparameters = dict(
        model=model_name,
        optimizer=optimizer_name,
        lr=learning_rate,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        alpha=alpha,
        lambd=lambd,
        asgd_alpha=asgd_alpha,
        batch_size=batch_size,
        dropout=dropout
    )
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, datamodule=datamodule)

    # Save model
    current_f1 = trainer.callback_metrics["f1/val"].item()
    _save_if_best(trainer, trial.number, model_name, current_f1)

    return current_f1

In [None]:
# Main

if __name__ == "__main__":

  # Accelerators
  print(f"Number of CPUS: {os.cpu_count()}")
  print(f"Number of GPUS: \\
  {torch.cuda.device_count() if torch.cuda.is_available() else None}")

  # Tensorboard
  %load_ext tensorboard
  %tensorboard --logdir=logs/

  # Pruner
  pruner = optuna.pruners.MedianPruner()

  # Study
  study = optuna.create_study(direction="maximize", pruner=pruner)
  study.optimize(objective, n_trials=100, timeout=36000) # 10h

  print("Number of finished trials: {}".format(len(study.trials)))

  print("Best trial:")
  trial = study.best_trial

  print("  Value: {}".format(trial.value))

  print("  Params: ")
  for key, value in trial.params.items():
      print("    {}: {}".format(key, value))
