In [3]:
# Experiment configuration read by the setup, objective and study cells of this notebook.
CFG = dict(
    dataset="mrnet",
    plane="coronal",
    protocol="TSE",
    # Backbone options: "tf_efficientnetv2_s_in21k" or "tf_mobilenetv3_small_minimal_100"
    backbone="tf_efficientnetv2_s_in21k",
    n_epochs=15,
    n_trials=1,
    wandb_project="mrknee",
)

In [4]:
%%capture
import os

# Clone the MRKnee project and switch the working directory to the checkout.
# NOTE(review): presumably this is what lets the later `from src...` imports resolve,
# and it assumes the clone lands in /kaggle/working — confirm.
!git clone https://github.com/nclibz/MRKnee/
os.chdir('/kaggle/working/MRKnee/')
# Check out the `v3` ref of the repository.
!git checkout v3
# Use the first entry under /kaggle/input as the dataset directory.
# NOTE(review): os.listdir returns entries in arbitrary order, so this presumably
# relies on exactly one dataset being attached to the notebook — confirm.
dataset_name = os.listdir('/kaggle/input')[0]
DATADIR = f"/kaggle/input/{dataset_name}/"

# INSTALL pyodbc driver
# Register Microsoft's package signing key and apt source list for this Ubuntu release,
# refresh the package index, then install the msodbcsql18 ODBC driver with the EULA
# accepted non-interactively.
# NOTE(review): `sudo` only covers `curl` on the first two lines; `apt-key add` and the
# `>` redirect run as the notebook user — presumably root on Kaggle, confirm.
!sudo curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add -
!sudo curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/prod.list > /etc/apt/sources.list.d/mssql-release.list
!sudo apt-get update
!sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18
# Packages
# NOTE(review): versions are not pinned and `-U` upgrades to the latest releases, so
# results may drift between runs as these libraries change.
!pip install -U torchmetrics timm optuna albumentations scikit-image madgrad wandb
%conda install -y pyodbc

from kaggle_secrets import UserSecretsClient

# Read credentials from the Kaggle secrets client instead of hard-coding them, and
# collect them in ENV under the same names they are stored with.
user_secrets = UserSecretsClient()
ENV = {
    "RDB_URL": user_secrets.get_secret("RDB_URL"),
    "WANDB_API_KEY": user_secrets.get_secret("WANDB_API_KEY"),
}

## DIRTY FIX FOR OAI IMGDIR
# The "oai" dataset uses "imgs/imgs" as its image directory; any other dataset uses "imgs".
IMGDIR = "imgs/imgs" if CFG['dataset'] == 'oai' else "imgs"


In [None]:
import torch
import wandb
from madgrad import MADGRAD
import optuna
from src.augmentations import Augmentations
from src.data import OAI, MRNet, get_dataloader
from src.metrics import AUC, Loss, MetricLogger
from src.model import VanillaMRKnee
from src.model_checkpoint import SaveModelCheckpoint
from src.trainer import Trainer
from src.utils import seed_everything

# Seed once for the whole session using the project's helper from src.utils.
seed_everything(123)

# Log in to Weights & Biases with the API key stored in ENV.
wandb.login(key=ENV["WANDB_API_KEY"])

In [6]:

def objective(trial, CFG = CFG):
    """Run one Optuna trial: sample hyperparameters, train, and report validation loss.

    Args:
        trial: Optuna trial used to sample augmentation, model, optimizer and
            label-smoothing hyperparameters.
        CFG: Experiment configuration; defaults to the notebook-level ``CFG``.
            The keys read here are "dataset", "plane", "protocol", "backbone",
            "n_epochs" and "wandb_project".

    Returns:
        The value of ``metriclogger.get_min("val_loss")`` after the last epoch
        (presumably the lowest validation loss seen during the trial).

    Raises:
        ValueError: If ``CFG["dataset"]`` is not "oai" or "mrnet", or if the sampled
            optimizer name has no matching branch.
    """
    # Resolve the dataset reader first so an unsupported name fails with a clear
    # message instead of an UnboundLocalError further down.
    readers = {"oai": OAI, "mrnet": MRNet}
    dataset = CFG["dataset"]
    if dataset not in readers:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(readers)}"
        )
    DATAREADER = readers[dataset]

    # Values are sampled as integer percentages on a fixed grid and scaled to
    # fractions. `step` is passed by keyword because recent Optuna releases deprecate
    # passing it positionally.
    # NOTE(review): rotate_limit is also divided by 100 (range 0.00-0.30); if
    # Augmentations treats it as degrees this is a very narrow range — confirm.
    augs = Augmentations(
        ssr_p=trial.suggest_int('ssr_p', 30, 80, step=5) / 100,
        shift_limit=trial.suggest_int('shift_limit', 0, 30, step=5) / 100,
        scale_limit=trial.suggest_int('scale_limit', 0, 30, step=5) / 100,
        rotate_limit=trial.suggest_int('rotate_limit', 0, 30, step=5) / 100,
        bc_p=0.00,
        brigthness_limit=0.10,
        contrast_limit=0.10,
        re_p=trial.suggest_int('re_p', 0, 80, step=10) / 100,
        clahe_p=trial.suggest_int('clahe_p', 0, 80, step=10) / 100,
        trim_p=0.0,
    )

    # TODO: move datareader loading into get_dataloader

    train_dr = DATAREADER(
        stage="train",
        diagnosis="meniscus",
        plane=CFG["plane"],
        protocol=CFG["protocol"],
        clean=True,
        datadir=DATADIR,
        img_dir=IMGDIR,
    )

    val_dr = DATAREADER(
        stage="valid",
        diagnosis="meniscus",
        plane=CFG["plane"],
        protocol=CFG["protocol"],
        clean=False,
        datadir=DATADIR,
        img_dir=IMGDIR,
    )

    train_dl = get_dataloader(train_dr, augs)
    # NOTE(review): the validation loader receives the same `augs` object as training;
    # presumably get_dataloader disables random augmentation for the "valid" stage — confirm.
    val_dl = get_dataloader(val_dr, augs)

    model = VanillaMRKnee(
        CFG["backbone"],
        pretrained=True,
        drop_rate=trial.suggest_int('drop_rate', 40, 90, step=5) / 100,
    )

    ### OPTIMIZERS

    # suggest_float(log=True) replaces the deprecated suggest_loguniform.
    # The misspelled name 'weigth_decay' is kept on purpose: the study is created with
    # load_if_exists=True, so renaming it would split the parameter history of an
    # existing study.
    LR = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    WD = trial.suggest_float('weigth_decay', 1e-5, 1e-2, log=True)
    OPTIM_NAME = trial.suggest_categorical("optimizer", ["madgrad", "adamw"])

    if OPTIM_NAME == "madgrad":
        optimizer = MADGRAD(model.parameters(), lr=LR, weight_decay=WD)
    elif OPTIM_NAME == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    else:
        # Guards against a choice being added above without a matching branch here.
        raise ValueError(f"Unsupported optimizer {OPTIM_NAME!r}")

    ### SCHEDULER
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        "min",
        patience=4,
    )

    ## HELPER CLASSES
    metriclogger = MetricLogger(
        train_metrics={"train_loss": Loss(), "train_auc": AUC()},
        val_metrics={"val_loss": Loss(), "val_auc": AUC()},
    )

    chpkt = SaveModelCheckpoint("checkpoint")

    trainer = Trainer(
        model,
        optimizer,
        scheduler,
        metriclogger,
        label_smoothing=trial.suggest_int('lbl_smoothing', 0, 10, step=5) / 100,
        progressbar=True,
    )

    ### LOGGING

    # Merge the sampled hyperparameters into the run config under a separate name so
    # the CFG argument itself is not rebound.
    run_cfg = {**CFG, **trial.params}
    wandb.init(project=run_cfg['wandb_project'], entity="nclibz", config=run_cfg)

    ### TRAINING

    try:
        for epoch in range(run_cfg["n_epochs"]):
            trainer.train(train_dl)
            trainer.validate(val_dl)

            metrics = {k: metriclogger.get_metric(k, epoch) for k in metriclogger.all_metrics}

            wandb.log(metrics)

            # Upload the checkpoint file only when this epoch is reported as the best so far.
            is_best = chpkt.check(metrics["val_loss"], model, optimizer, scheduler, epoch)
            if is_best:
                wandb.save(chpkt.get_checkpoint_path())

            # TODO: Move metric printing into metriclogger?
            print(
                f"EPOCH: {epoch} / {run_cfg['n_epochs']} \n train_loss: {metrics['train_loss']:.3f} val_loss: {metrics['val_loss']:.3f} \n train_auc: {metrics['train_auc']:.3f} val_auc: {metrics['val_auc']:.3f} "
            )
    except BaseException:
        # Close the W&B run even on failure or interrupt, marking it as failed, so it
        # does not stay open and interfere with the next trial's run.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

    return metriclogger.get_min("val_loss")


In [None]:
# TPE sampler configured with multivariate=True.
sampler = optuna.samplers.TPESampler(multivariate=True)

# Trials are stored in the relational database addressed by the RDB_URL secret.
storage = optuna.storages.RDBStorage(url=ENV['RDB_URL'], heartbeat_interval=360)

# One study per dataset/plane/backbone combination; load_if_exists=True reuses a study
# of that name when the storage already has one.
study = optuna.create_study(
    study_name=f"{CFG['dataset']}_{CFG['plane']}_{CFG['backbone']}",
    storage=storage,
    sampler=sampler,
    load_if_exists=True,
)

In [None]:
# Start the search: run `objective` for the number of trials set in CFG.
study.optimize(func=objective, n_trials=CFG["n_trials"])