In [None]:
# Attach Google Drive at /content/drive so later cells can work from the
# DDI folder stored on Drive.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

In [None]:
# Navigate to DDI working directory
%cd drive/MyDrive/DDI
!ls

In [None]:
# Install dependencies
!pip install --quiet pytorch_lightning wandb

In [None]:
# Third-party imports used by the experiment cell.
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger
from torchvision import transforms as T

# Log in to Weights & Biases so training runs can be logged there.
wandb.login()

In [None]:
from datasets import DDI_DataModule
from models import DDI_DeepDerm

In [None]:
# RUN EXPERIMENT

# Constants and transforms copied from DDI-Code.
# Per-channel RGB normalization statistics (the usual ImageNet values).
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]
# Force 3-channel RGB, resize the shorter side to 299, center-crop to
# 299x299, convert to a tensor, and normalize with the statistics above.
test_transform = T.Compose([
    lambda x: x.convert('RGB'),
    T.Resize(299),
    T.CenterCrop(299),
    T.ToTensor(),
    T.Normalize(mean=means, std=stds)
])

# Config parameters
annotation_file = './data/ddi_metadata.csv'
img_dir = './data'
batch_size = 256
num_workers = 2  # recommended by Colab
random_seed = 0
classify_malignant = True
finetune_mode = 'first_conv'
transform = test_transform  # TODO: match DDI experiments
# Optional dataset filters passed to DDI_DataModule; presumably None means
# "no filtering" -- TODO(review): confirm against DDI_DataModule.
skin_tone = None
malignant = None
diseases = None

# Set random seed. workers=True also gives each DataLoader worker process
# (num_workers > 0 here) a reproducible seed, consistent with the
# deterministic=True Trainer below.
pl.seed_everything(random_seed, workers=True)

# Initialize LightningDataModule, LightningModule, Logger, Trainer
data_module = DDI_DataModule(random_seed,
                             batch_size,
                             num_workers,
                             annotation_file,
                             img_dir,
                             transform,
                             skin_tone,
                             malignant,
                             diseases)
model = DDI_DeepDerm(classify_malignant=classify_malignant,
                     mode=finetune_mode)
# The logger is created in this cell so the Trainer below never depends on
# state left over from another cell; the cell runs on a fresh kernel.
# TODO(review): confirm the W&B project name.
wandb_logger = WandbLogger(project='DDI')
trainer = pl.Trainer(deterministic=True,  # for reproducibility
                     max_epochs=500,  # to match DDI experiments
                     logger=wandb_logger,
                     log_every_n_steps=2)  # 2 batches per epoch

# Train model. Always close the W&B run afterwards (even if training fails
# or is interrupted); otherwise re-running this cell in the same kernel
# would make the new WandbLogger reuse the still-active run.
try:
    trainer.fit(model, data_module)
finally:
    wandb.finish()