In [None]:
# Colab only: make the user's Google Drive available under /content/drive.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
import os
!git clone https://github.com/nclibz/MRKnee/
os.chdir('/content/MRKnee/')
!git checkout v3

In [None]:
# Experiment selection. These are the defaults used by `objective` and the
# pieces of the Optuna study name, so one study is created per combination.
DIAGNOSIS = "acl"                                # label passed to MRKneeDataModule
PLANE = "sagittal"                               # plane passed to MRKneeDataModule
BACKBONE = "tf_mobilenetv3_small_minimal_100"    # backbone name passed to MRKnee
DATADIR = "data"                                 # relative path, presumably resolved from the repo checkout

# SETUP

In [1]:
!pip install --quiet "pytorch-lightning>=1.3" "torchmetrics>=0.3" "torch==1.8" "torchvision" "torchtext == 0.9" "timm" "neptune-client" "optuna" "PyMySql"
!pip install albumentations --upgrade --quiet

Collecting rfc3987
  Downloading rfc3987-1.3.8-py2.py3-none-any.whl (13 kB)
Collecting webcolors
  Downloading webcolors-1.11.1-py3-none-any.whl (9.9 kB)
Collecting jsonpointer>1.13
  Downloading jsonpointer-2.1-py2.py3-none-any.whl (7.4 kB)
Collecting strict-rfc3339
  Downloading strict-rfc3339-0.7.tar.gz (17 kB)
Building wheels for collected packages: strict-rfc3339
  Building wheel for strict-rfc3339 (setup.py) ... [?25ldone
[?25h  Created wheel for strict-rfc3339: filename=strict_rfc3339-0.7-py3-none-any.whl size=18149 sha256=42d3e8824a38d7deedf3c1150cd1739553762ebe3974e4fd6a7e5a66c6b0d5ff
  Stored in directory: /home/nicolai/.cache/pip/wheels/25/38/74/7ec7f77ec64b2907430120931ba588b40e6e26f02d4df5be35
Successfully built strict-rfc3339
Installing collected packages: webcolors, strict-rfc3339, rfc3987, jsonpointer
Successfully installed jsonpointer-2.1 rfc3987-1.3.8 strict-rfc3339-0.7 webcolors-1.11.1
Collecting albumentations
  Downloading albumentations-1.1.0-py3-none-any.whl (1

In [None]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

# MODEL

In [None]:
# Project code from the MRKnee checkout (importable because the working
# directory was switched to the repo root when it was cloned).
from src.model import MRKnee
from src.data import MRKneeDataModule
from src.augmentations import Augmentations
from src.callbacks import Callbacks

# Third-party libraries.
import pytorch_lightning as pl
import optuna

# Seed the global RNGs via Lightning once for this kernel session; the
# returned seed is the cell's displayed output.
pl.seed_everything(seed=123)

In [None]:

def objective(trial, diagnosis=DIAGNOSIS, plane=PLANE, backbone=BACKBONE, datadir=DATADIR, seed=123):
    """Train one MRKnee model for an Optuna trial and return its validation loss.

    Builds the model, augmentations and data module with fixed settings,
    trains them with a Lightning ``Trainer`` whose logger and callbacks come
    from ``Callbacks``, uploads the best checkpoints, and returns the loss to
    Optuna.

    Args:
        trial: the Optuna trial supplied by ``study.optimize``. It is handed to
            ``Callbacks`` (defined in ``src.callbacks``; presumably used for
            pruning/logging -- confirm there).
        diagnosis: passed to ``MRKneeDataModule`` (default ``DIAGNOSIS``).
        plane: passed to ``MRKneeDataModule`` (default ``PLANE``).
        backbone: passed to ``MRKnee`` (default ``BACKBONE``).
        datadir: passed to ``MRKneeDataModule`` (default ``DATADIR``).
        seed: passed to ``pl.seed_everything`` at the start of every trial, so
            a trial's RNG state does not depend on how many trials already ran
            in the same kernel.

    Returns:
        ``val_loss`` from the last entry of ``callbacks.metrics_callback.metrics``,
        converted with ``.item()``.

    Raises:
        RuntimeError: if ``metrics_callback`` recorded no metrics (e.g. training
            stopped before the first validation run).

    NOTE(review): this function makes no ``trial.suggest_*`` calls, so every
    hyperparameter below is fixed. Unless ``Callbacks`` samples values itself,
    all trials train the same configuration.
    """
    # Re-seed per trial: a single notebook-level seed only fixes the state of
    # the first trial; this complements deterministic=True in the Trainer.
    pl.seed_everything(seed)

    model = MRKnee(
        backbone=backbone,
        drop_rate=0.0,
        final_drop=0.0,
        learning_rate=0.0001,
        log_auc=True,
        log_ind_loss=False,
        adam_wd=0.01,
        max_epochs=20,
        precision=32,
    )

    augs = Augmentations(
        model,
        shift_limit=0.20,
        scale_limit=0.20,
        rotate_limit=30,
        reverse_p=0.5,
        same_range=True,
        indp_normalz=True,
    )

    dm = MRKneeDataModule(
        datadir=datadir,
        diagnosis=diagnosis,
        plane=plane,
        transforms=augs,
        clean=True,
        num_workers=1,
        pin_memory=True,
        trim_train=True,
    )

    # TODO: Create a config class? For now the attributes of model, augs and dm
    # are flattened into one dict; on key collisions later updates win.
    cfg = dict()
    cfg.update(model.__dict__)
    cfg.update(augs.__dict__)
    cfg.update(dm.__dict__)

    callbacks = Callbacks(cfg, trial, neptune_name="tester")

    trainer = pl.Trainer(
        gpus=1,
        precision=cfg["precision"],
        max_epochs=cfg["max_epochs"],
        logger=callbacks.get_neptune_logger(),
        log_every_n_steps=100,
        num_sanity_val_steps=0,
        callbacks=callbacks.get_callbacks(),
        progress_bar_refresh_rate=20,
        deterministic=True,
    )

    trainer.fit(model, dm)

    ## UPLOAD BEST CHECKPOINTS TO LOG
    callbacks.upload_best_checkpoints()

    recorded = callbacks.metrics_callback.metrics
    if not recorded:
        # Fail with a clear message instead of an IndexError on recorded[-1].
        raise RuntimeError(
            "No validation metrics were recorded for this trial; "
            "training may have stopped before the first validation run."
        )
    return recorded[-1]["val_loss"].item()


In [None]:

import os
from getpass import getpass

pruner = optuna.pruners.HyperbandPruner(min_resource=10)
sampler = optuna.samplers.TPESampler(multivariate=True)

# The storage URL embeds database credentials, so it must never be written
# into the notebook. Read it from OPTUNA_STORAGE_URL, or prompt for it with
# hidden input when the variable is not set.
storage_url = os.environ.get("OPTUNA_STORAGE_URL") or getpass(
    "Optuna storage URL (mysql+pymysql://USER:PASSWORD@HOST/DB): "
)
# Heartbeats (every 120 s, 360 s grace period) let Optuna mark trials whose
# worker died as failed instead of leaving them RUNNING forever.
storage = optuna.storages.RDBStorage(
    url=storage_url,
    heartbeat_interval=120,
    grace_period=360,
)
study_name = f"{DIAGNOSIS}_{PLANE}_{BACKBONE}"

# load_if_exists=True resumes the same named study across sessions/workers.
study = optuna.create_study(
    storage=storage,
    study_name=study_name,
    load_if_exists=True,
    sampler=sampler,
    pruner=pruner,
    direction="minimize",
)
# Optional warm start: uncomment to queue a trial with known parameter values
# (keys must match the names used in trial.suggest_* calls).
#study.enqueue_trial({
#    'dropout': 55,
#    'lr': 3.e-4,
#    'rotate': 25,
#    'scale': 8,
#    'shift': 10,
#    'adam_wd': 0.0900
#    })


# Stop after 40 trials or 8 hours (timeout is in seconds), whichever comes first.
study.optimize(objective, n_trials=40, timeout=8 * 60 * 60)