In [1]:
# Target diagnosis and MRI plane handed to the MRNet dataset in `objective`.
DIAGNOSIS, PLANE = "meniscus", "sagittal"
# Cross-validation folds per trial, and training epochs per fold.
N_FOLDS, N_EPOCHS = 5, 15

In [2]:
import os

# Environment setup: detect where the notebook runs, fetch the repo on hosted
# platforms, install dependencies, and define DATADIR / MODELDIR / BACKBONE.

# Platform flags, derived from environment variables that are presumably set
# by the respective hosted runtimes -- confirm if detection ever misfires.
KAGGLE =  os.getenv("KAGGLE_URL_BASE") is not None
COLAB = os.getenv("COLAB_GPU") is not None
TPU = os.getenv("XRT_TPU_CONFIG") is not None
# Anything that is neither Kaggle nor Colab is treated as a local run.
LOCAL = not KAGGLE and not COLAB

# Hosted runtimes start without the project code, so clone it first.
if not LOCAL:
    !git clone https://github.com/nclibz/MRKnee/

if COLAB:
    os.chdir('/content/MRKnee/')
    !git checkout v3
    from google.colab import drive
    drive.mount('/content/drive')
    # NOTE(review): no trailing slash here (the Kaggle DATADIR has one), and
    # MODELDIR is never set on Colab -- any later use would raise NameError.
    DATADIR = "/content/drive/MyDrive/MRKnee/data"
    if TPU:
        # torch_xla 1.9 wheel for CPython 3.7 (per the wheel filename).
        !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

if KAGGLE:
    os.chdir('/kaggle/working/MRKnee/')
    !git checkout v3
    # Uses the first entry listed under /kaggle/input; assumes exactly one
    # dataset is attached (os.listdir order is arbitrary).
    dataset_name = os.listdir('/kaggle/input')[0]
    DATADIR = f"/kaggle/input/{dataset_name}/"
    MODELDIR = DATADIR
    
    # Install the Microsoft ODBC Driver 18 (msodbcsql18), presumably what
    # pyodbc needs for the Optuna RDB connection -- confirm.
    # NOTE(review): `apt-key add` and the `>` redirect run outside sudo, so
    # this only works if the kernel already runs as root.
    !sudo curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add -
    !sudo curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/prod.list > /etc/apt/sources.list.d/mssql-release.list
    !sudo apt-get update
    !sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18
    
    if TPU:
        # NOTE(review): installs torch_xla 1.8 via the env-setup script,
        # while the Colab branch installs 1.9 -- confirm the mismatch is intended.
        !pip install torchtext==0.9
        !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
        !python pytorch-xla-env-setup.py --version 1.8

if not LOCAL:
    # Unpinned upgrades: results may drift as these packages release.
    !pip install -U torchmetrics timm optuna albumentations scikit-image
    # NOTE(review): %conda needs conda in the kernel environment; confirm it
    # exists on Colab, otherwise this line fails there.
    %conda install -y pyodbc
    BACKBONE = "tf_efficientnetv2_s_in21k"

if LOCAL:
    DATADIR = "data"
    MODELDIR = "src/"
    # Smaller backbone for local runs.
    BACKBONE = 'tf_mobilenetv3_small_minimal_100'
    # Reload edited project modules (src/...) without restarting the kernel.
    %load_ext autoreload
    %autoreload 2



In [3]:
import optuna
import torch
from optuna.pruners import ThresholdPruner
from optuna.samplers import TPESampler
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from src.augmentations import Augmentations
from src.data import MRNet
from src.metrics import AUC, Loss, MetricLogger
from src.model import VanillaMRKnee
from src.rdb import get_rdb_string
from src.trainer import Trainer
from src.utils import seed_everything

# Global seed for the run (seed_everything presumably seeds Python/NumPy/torch
# -- see src.utils), then the connection URL for the Optuna RDB storage.
SEED = 123
seed_everything(SEED)
rdb_string = get_rdb_string()

In [4]:

def objective(trial, cv_seed=123):
    """Optuna objective: mean best validation loss over stratified K-fold CV.

    Samples augmentation, dropout and AdamW hyperparameters from ``trial``,
    trains a fresh ``VanillaMRKnee`` on each of ``N_FOLDS`` folds for
    ``N_EPOCHS`` epochs, and returns the mean over folds of the lowest
    per-epoch validation loss recorded by that fold's ``MetricLogger``.

    Args:
        trial: the Optuna trial supplied by ``study.optimize``.
        cv_seed: ``random_state`` for ``StratifiedKFold``. A fixed value gives
            every trial the same folds, so trial scores are comparable instead
            of depending on how far the global NumPy RNG has advanced.

    Returns:
        float: mean over folds of the minimum validation loss (lower is better).
    """
    # Integer grids are divided by 100, so e.g. shift_limit=25 means 0.25.
    # NOTE(review): rotate_limit is scaled the same way (0-0.15). If
    # Augmentations forwards it to albumentations' ShiftScaleRotate, that
    # argument is in degrees -- confirm this scaling is intended.
    augs = Augmentations(
        train_imgsize=(256, 256),
        test_imgsize=(256, 256),
        shift_limit=trial.suggest_int("shift_limit", 0, 35, step=5) / 100,
        scale_limit=trial.suggest_int("scale_limit", 0, 20, step=5) / 100,
        rotate_limit=trial.suggest_int("rotate_limit", 0, 15, step=5) / 100,
        ssr_p=trial.suggest_int("ShiftScaleRotate_p", 50, 80, step=10) / 100,
        clahe_p=trial.suggest_int("clahe_p", 20, 70, step=10) / 100,
        reverse_p=0.0,
        indp_normalz=True,
        trim_p=0.1,
    )

    # Model/optimizer hyperparameters are fixed for the whole trial, so sample
    # them once here rather than on every pass of the fold loop.
    drop_rate = trial.suggest_int("drop_rate", 50, 90, step=10) / 100
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float("adam_wd", 0.01, 0.1, log=True)

    ds = MRNet(
        stage="train",
        diagnosis=DIAGNOSIS,
        plane=PLANE,
        clean=True,
        transforms=augs,
        # os.path.join supplies the separator that plain concatenation dropped
        # when DATADIR has no trailing slash (Colab ".../data", local "data").
        datadir=os.path.join(DATADIR, "MRNet"),
    )

    # TODO: for OAI, implement grouped stratified k-fold.
    cv = StratifiedKFold(N_FOLDS, shuffle=True, random_state=cv_seed)
    splits = list(cv.split(ds.ids, ds.lbls))

    fold_losses = []
    for train_idxs, val_idxs in splits:
        # StratifiedKFold yields sorted indices; shuffle the training fold so
        # it is not presented in the same order every epoch.
        train_dl = DataLoader(Subset(ds, train_idxs), shuffle=True)
        # NOTE(review): the validation fold comes from the same `ds`
        # (stage="train", transforms=augs); confirm MRNet does not apply the
        # training augmentations to these samples.
        val_dl = DataLoader(Subset(ds, val_idxs))

        model = VanillaMRKnee(BACKBONE, drop_rate=drop_rate)
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )
        # Presumably stepped by Trainer with the validation loss -- confirm.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, "min", patience=4
        )

        metriclogger = MetricLogger(
            train_metrics={
                "train_loss": Loss(),
                "train_auc": AUC(),
            },
            val_metrics={
                "val_loss": Loss(),
                "val_auc": AUC(),
            },
        )

        trainer = Trainer(model, optimizer, scheduler, metriclogger)

        for epoch in tqdm(range(N_EPOCHS), desc="Epochs", disable=True):
            trainer.train(train_dl)
            trainer.test(val_dl)
            # TODO: pruning is disabled. Reporting with `epoch` as the step
            # reuses steps 0..N_EPOCHS-1 for every fold, so the pruner only
            # works for the first fold. Use a fold-unique step (e.g.
            # fold_index * N_EPOCHS + epoch) before re-enabling:
            # trial.report(metriclogger.val_loss.epoch_values[-1], step)
            # if trial.should_prune():
            #     raise optuna.TrialPruned()

        # Score each fold by its best epoch (lowest validation loss). float()
        # accepts plain numbers as well as scalar tensors on any device.
        fold_losses.append(
            min(float(v) for v in metriclogger.val_loss.epoch_values)
        )

    return sum(fold_losses) / len(fold_losses)

In [None]:
# Upper bound on an intermediate value for the ThresholdPruner; meniscus gets
# a looser bound than the other diagnoses.
# NOTE(review): the pruner only acts on values passed to trial.report, which
# the objective does not currently call.
THRESHOLD = 1.4 if DIAGNOSIS == "meniscus" else 1.0

# RDB-backed storage: with load_if_exists=True below, runs from different
# sessions append to the same named study. The heartbeat lets Optuna mark
# trials whose process stopped reporting as failed.
storage = optuna.storages.RDBStorage(url=rdb_string, heartbeat_interval=360)
sampler = TPESampler(multivariate=True)

study = optuna.create_study(
    study_name=f"{DIAGNOSIS}_{PLANE}_{BACKBONE}",
    storage=storage,
    sampler=sampler,
    pruner=ThresholdPruner(
        upper=THRESHOLD,
        n_warmup_steps=5,
        interval_steps=1,
    ),
    load_if_exists=True,
)




In [None]:
## best params
# Hand-picked hyperparameters for meniscus / sagittal, enqueued as the next
# trial. Keys must match the names passed to trial.suggest_* in `objective`;
# an unmatched key is ignored and that parameter is sampled instead.
# Integer-valued entries are percentages: the objective divides them by 100.
men_sag = {
    "shift_limit": 25,
    "rotate_limit": 10,
    "scale_limit": 10,
    # Previously keyed "ssr_p", which matches no suggest call (the objective
    # names this parameter "ShiftScaleRotate_p"), so 80 was silently ignored.
    "ShiftScaleRotate_p": 80,
    "adam_wd": 0.0539477,
    "lr": 0.000220,
    "drop_rate": 90,
    "clahe_p": 20,
}

# Queue the hand-picked parameters so the next optimize() call evaluates them.
# NOTE(review): this runs on every notebook execution, so each session
# re-queues (and re-runs) the same parameters -- confirm that is intended.
study.enqueue_trial(params=men_sag)

In [None]:
# One trial per session: a trial takes ~8 hours and Kaggle caps sessions at 9.
study.optimize(func=objective, n_trials=1)