In [None]:
# setup
!apt-get update
!apt-get install git
!pip install python-dotenv
!pip install loguru
!pip install efficientnet_pytorch
!pip install wandb

In [None]:
# load data and unzip data
# Runs the repository script tools/download_data_and_chkpts.py in a shell subprocess.
# The path is relative, so the notebook's working directory must contain ./tools.
# NOTE(review): going by the script name it presumably also fetches model
# checkpoints - confirm in the script itself.
!python ./tools/download_data_and_chkpts.py

In [None]:
# imports
from ails_miccai_uwf4dr_challenge.dataset import DatasetBuilder, DatasetOriginationType, ChallengeTaskType, CustomDataset

import torch.nn as nn
from sklearn.metrics import roc_auc_score, average_precision_score
from torch import optim
from torch.utils.data import DataLoader

import torch
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from ails_miccai_uwf4dr_challenge.preprocess_augmentations import ResidualGaussBlur, MultiplyMask
from ails_miccai_uwf4dr_challenge.augmentations import rotate_affine_flip_choice, resize_only
from ails_miccai_uwf4dr_challenge.models.metrics import sensitivity_score, specificity_score
from ails_miccai_uwf4dr_challenge.models.trainer import Metric, DefaultMetricsEvaluationStrategy, Trainer, TrainingContext, MetricCalculatedHook

In [None]:
# Pick the training device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device: " + str(device))

In [None]:
# login to wandb
# Switch for Weights & Biases experiment tracking. Later cells read it again
# to decide whether to register the logging hook and call wandb.init.
use_wandb = True
if use_wandb:
    # wandb is only imported when tracking is switched on.
    import wandb
    # Authenticates this session with W&B.
    # NOTE(review): presumably prompts for an API key when none is stored - confirm.
    wandb.login()

In [None]:



# Deterministic preprocessing shared by the training and validation pipelines.
# Images are resized to 800x1016, normalised with the widely used ImageNet
# channel statistics, multiplied with the mask, then resized once more.
preprocessing = A.Compose(
    [
        A.Resize(800, 1016, p=1),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], p=1),
        # ResidualGaussBlur(p=1),           (disabled)
        MultiplyMask(p=1),
        A.Resize(800, 1016, p=1),
        # A.Equalize(p=0.1),                (disabled)
        # A.CLAHE(clip_limit=5., p=0.3)     (disabled)
    ]
)

# Validation images only get converted to tensors.
augment_val = A.Compose([ToTensorV2(p=1)])

# Training currently does the same: every random augmentation is switched off.
augment_train = A.Compose(
    [
        # A.VerticalFlip(p=0.5),                                    (disabled)
        # A.HorizontalFlip(p=0.5),                                  (disabled)
        # A.Affine(rotate=15, rotate_method='ellipse', p=0.5),      (disabled)
        ToTensorV2(p=1),
    ]
)

# Full pipelines handed to the datasets: preprocessing first, then the
# split-specific stage.
transforms_val = A.Compose([preprocessing, augment_val])
transforms_train = A.Compose([preprocessing, augment_train])

In [None]:
from ails_miccai_uwf4dr_challenge.models.architectures import task1_automorph_plain

# Build the Task 1 network with enc_frozen=True and move it to the selected device.
model = task1_automorph_plain(enc_frozen=True)
model.to(device)
print("Training model: ", model.__class__.__name__)

# Metrics given to the evaluation strategy, each a (name, callable) pair.
# The accuracy callable rounds the predictions before comparing them to the labels.
metrics = [
    Metric(metric_name, metric_fn)
    for metric_name, metric_fn in (
        ('auroc', roc_auc_score),
        ('auprc', average_precision_score),
        ('accuracy', lambda y_true, y_pred: (y_pred.round() == y_true).mean()),
        ('sensitivity', sensitivity_score),
        ('specificity', specificity_score),
    )
]

class WandbLoggingHook(MetricCalculatedHook):
    """Metric hook that forwards every calculated metric value to Weights & Biases."""

    def on_metric_calculated(self, training_context: TrainingContext, metric: Metric, result, last_metric_for_epoch: bool):
        """Log ``result`` under the key ``metric.name``.

        ``last_metric_for_epoch`` is passed on as ``commit``, so only the
        final metric of an epoch commits the accumulated values to W&B.
        """
        # Imported inside the method, so defining the class does not need wandb.
        import wandb

        payload = {metric.name: result}
        wandb.log(data=payload, commit=last_metric_for_epoch)

# Run configuration: used for the optimizer, the data loaders, the epoch count,
# the start-up message, and passed to wandb.init.
# NOTE(review): the dataset below is built with DatasetOriginationType.ALL,
# while this label says "Original" - confirm which one is intended.
config = {
    "learning_rate": 1e-3,
    "dataset": "Original",
    "epochs": 10,
    "batch_size": 16,
    "model_type": model.__class__.__name__
}

# Evaluation strategy over the metric list; W&B logging is attached only when
# tracking is enabled.
metrics_eval_strategy = DefaultMetricsEvaluationStrategy(metrics)
if use_wandb:
    metrics_eval_strategy.register_metric_calculated_hook(WandbLoggingHook())

class WeightedSumLoss(nn.Module):
    """Loss that returns a weighted sum of several loss modules.

    Every wrapped loss is evaluated on the same ``(input, target)`` pair and
    the results are combined as ``sum(weight_i * loss_i(input, target))``.

    Args:
        losses: iterable of ``nn.Module`` loss functions.
        weights: iterable of numeric weights, one per loss.

    Raises:
        ValueError: if no loss is given or the number of weights differs
            from the number of losses.
    """

    def __init__(self, losses, weights):
        super().__init__()
        losses = list(losses)
        weights = [float(w) for w in weights]
        if not losses:
            raise ValueError("at least one loss is required")
        if len(losses) != len(weights):
            raise ValueError("losses and weights must have the same length")
        # ModuleList registers the sub-losses, so .to(device) reaches them too.
        self.losses = nn.ModuleList(losses)
        self.weights = weights

    def forward(self, input, target):
        """Return the weighted sum of all wrapped losses for ``(input, target)``."""
        return sum(w * loss(input, target) for w, loss in zip(self.weights, self.losses))


loss1 = nn.SmoothL1Loss()
loss2 = nn.BCEWithLogitsLoss()
# Loss modules define no arithmetic operators, so `0.5*loss1 + 0.5*loss2`
# raises TypeError. The weighting has to be applied to the loss values.
# NOTE(review): BCEWithLogitsLoss expects raw logits, while SmoothL1Loss here
# compares the same raw outputs with the targets - confirm this is intended.
criterion = WeightedSumLoss([loss1, loss2], [0.5, 0.5])

# AdamW over all model parameters, with the learning rate from the run config.
optimizer = optim.AdamW(
    model.parameters(),
    lr=config["learning_rate"],
)
# Plateau scheduler in 'min' mode: multiplies the learning rate by 0.5,
# with a patience of 5.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    verbose=True,
)

# Train/validation split for Task 1 from the builder, wrapped in datasets that
# apply the matching transform pipeline.
dataset_builder = DatasetBuilder(DatasetOriginationType.ALL, ChallengeTaskType.TASK1)
train_data, val_data = dataset_builder.get_train_val()

val_dataset = CustomDataset(val_data, transform=transforms_val)
train_dataset = CustomDataset(train_data, transform=transforms_train)

# Only the training loader shuffles; both use the configured batch size.
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=config['batch_size'])
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config['batch_size'])


# Wire model, loaders, loss, optimizer, scheduler and metric evaluation into the trainer.
trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, lr_scheduler, device,
                        metrics_eval_strategy=metrics_eval_strategy)

if use_wandb:
    # NOTE(review): the dataset and model above are for Task 1, but the run is
    # logged to project 'task2' - confirm the intended W&B project.
    wandb.init(entity='miccai-challenge-2024' ,project='task2', config=config)

print(f"Start training [{config['model_type']}] on [{config['dataset']}] dataset for [{config['epochs']}] epochs with batch size [{config['batch_size']}]")

# In a notebook the kernel keeps running after training, so the W&B run is not
# closed automatically. Close it explicitly, and mark it failed when training
# raises or is interrupted.
try:
    trainer.train(num_epochs=config["epochs"])
except BaseException:
    if use_wandb:
        wandb.finish(exit_code=1)
    raise
else:
    if use_wandb:
        wandb.finish()

print("Training finished.")

