In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
base_path = Path().resolve().parent
%cd {base_path}

/Users/moritz/Documents/Master/AILS-MICCAI-UWF4DR-Challenge


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
# setup
#!apt-get update
#!pip install python-dotenv
#!pip install loguru
#!pip install gdown
#!pip install typer
#!pip install imbalanced-learn

In [4]:
# load data and unzip data

#!python tools/download_data_and_chkpts.py

In [7]:
# imports

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import wandb

from sklearn.metrics import roc_auc_score, average_precision_score

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import time

# data
from ails_miccai_uwf4dr_challenge.dataset_strategy import CustomDataset, DatasetStrategy, CombinedDatasetStrategy, \
    Task1Strategy, Task2Strategy, Task3Strategy, DatasetBuilder

# augmentation
from ails_miccai_uwf4dr_challenge.preprocess_augmentations import ResidualGaussBlur, MultiplyMask

# trainer
from ails_miccai_uwf4dr_challenge.models.trainer import DefaultMetricsEvaluationStrategy, Metric, MetricCalculatedHook, \
    NumBatches, Trainer, EpochTrainingStrategy, EpochValidationStrategy, DefaultEpochTrainingStrategy, \
    DefaultBatchTrainingStrategy, TrainingContext, PersistBestModelOnEpochEndHook, \
        OversamplingResamplingStrategy, UndersamplingResamplingStrategy

from ails_miccai_uwf4dr_challenge.models.metrics import sensitivity_score, specificity_score
from ails_miccai_uwf4dr_challenge.config import Config

# models
from ails_miccai_uwf4dr_challenge.models.architectures.task1_automorph_plain import AutoMorphModel
from ails_miccai_uwf4dr_challenge.models.architectures.task1_efficientnet_plain import Task1EfficientNetB4
from ails_miccai_uwf4dr_challenge.models.architectures.task2_efficientnetb0_plain import Task2EfficientNetB0 
from ails_miccai_uwf4dr_challenge.models.architectures.task1_convnext import Task1ConvNeXt 
from ails_miccai_uwf4dr_challenge.models.architectures.dinov2 import DinoV2Classifier, ModelSize
from ails_miccai_uwf4dr_challenge.models.architectures.ResNets import ResNet, ResNetVariant
from ails_miccai_uwf4dr_challenge.models.architectures.shufflenet import ShuffleNet

from ails_miccai_uwf4dr_challenge.config import WANDB_API_KEY, PROJ_ROOT

wandb.login(key=WANDB_API_KEY)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/moritz/.netrc


True

In [6]:
# select device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: " + str(device))

Device: cpu


## Train some model

In [17]:
def train_model(cfg=None):

    WANDB_HTTP_TIMEOUT=300
    WANDB_INIT_TIMEOUT =600
    WANDB_DEBUG=True

    wandb.init(project="task2", config=cfg)
    cfg = wandb.config

    transforms_train = A.Compose([
        A.Resize(800, 1016, p=1),
        MultiplyMask(p=1),
        ResidualGaussBlur(p=cfg.p_gaussblur),
        A.Equalize(p=cfg.p_equalize),
        A.CLAHE(clip_limit=5., p=cfg.p_clahe),
        A.HorizontalFlip(p=cfg.p_horizontalflip),
        A.Affine(rotate=cfg.rotation, rotate_method='ellipse', p=cfg.p_affine),
        A.Normalize(mean=[0.406, 0.485, 0.456], std=[0.225, 0.229, 0.224], p=1),
        #A.Resize(770, 1022, p=1), # comment whenever not using DinoV2
        ToTensorV2(p=1)
    ])

    transforms_val = A.Compose([
            A.Resize(800, 1016, p=1),
            MultiplyMask(p=1),
            A.Normalize(mean=[0.406, 0.485, 0.456], std=[0.225, 0.229, 0.224], p=1),
            #A.Resize(770, 1022, p=1), # comment whenever not using DinoV2
            ToTensorV2(p=1)
        ])

    dataset_strategy = CombinedDatasetStrategy()
    task_strategy = Task2Strategy()

    builder = DatasetBuilder(dataset_strategy, task_strategy, split_ratio=0.8)
    train_data, val_data = builder.build()

    train_dataset = CustomDataset(train_data, transform=transforms_train)
    val_dataset = CustomDataset(val_data, transform=transforms_val)

    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.backends.mps.is_available() else "mps")
    print(f"Using device: {device}")

    assert cfg.model is not None, "Model type must be specified in the config"

    if cfg.model == "Task2EfficientNetB0":
        model = Task2EfficientNetB0(num_classes=1)
    elif cfg.model == "AutoMorphModel":
        model = AutoMorphModel(enc_frozen=True)
    elif cfg.model == "DinoV2Classifier":
        model = DinoV2Classifier(ModelSize.SMALL)
    elif cfg.model == "ShuffleNet":
        model = ShuffleNet()
    elif cfg.model == "ResNet":
        model = ResNet(ResNetVariant.RESNET18)
    else:
        raise ValueError(f"Model type {cfg.model} not recognized")

    model.to(device)

    metrics = [
        Metric('auroc', roc_auc_score),
        Metric('auprc', average_precision_score),
        Metric('accuracy', lambda y_true, y_pred: (y_pred.round() == y_true).mean()),
        Metric('sensitivity', sensitivity_score),
        Metric('specificity', specificity_score)
    ]

    class WandbLoggingHook(MetricCalculatedHook):
        def on_metric_calculated(self, training_context: TrainingContext, metric: Metric, result, last_metric_for_epoch: bool):
            import wandb
            wandb.log(data={metric.name: result}, commit=last_metric_for_epoch)

    metrics_eval_strategy = DefaultMetricsEvaluationStrategy(metrics).register_metric_calculated_hook(WandbLoggingHook())

    def combined_losses(pred, target):
        bce = F.binary_cross_entropy_with_logits(pred, target) * cfg.loss_weight
        smooth_l1 = F.smooth_l1_loss(pred, target) * (1 - cfg.loss_weight)
        return bce + smooth_l1

    criterion = combined_losses
    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=cfg.lr_schedule_factor, patience=cfg.lr_schedule_patience, verbose=True)

    trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, lr_scheduler, device, 
                        metrics_eval_strategy=metrics_eval_strategy,
                        val_dataloader_adapter=UndersamplingResamplingStrategy(),
                        train_dataloader_adapter=UndersamplingResamplingStrategy())
    
    # build a file name for the model weights containing current timestamp and the model class
    wandb_run_name = wandb.run.name

    training_timestamp = time.strftime("%Y-%m-%d_%H-%M")
    persist_model_hook = PersistBestModelOnEpochEndHook(f"models/{wandb_run_name}_{training_timestamp}.pth")
    trainer.add_epoch_end_hook(persist_model_hook) # TODO uncomment this line to save the best model

    #print("First train 2 epochs 2 batches to check if everything works - you can comment these two lines after the code has stabilized...")
    #trainer.train(num_epochs=2, num_batches=NumBatches.TWO_FOR_INITIAL_TESTING)
    
    print("Now train train train")
    trainer.train(num_epochs=cfg.epochs)
    wandb.finish()
    print("Finished training")

In [18]:
cfg = Config(
    batch_size=32,
    epochs=25,
    lr=0.001,
    lr_schedule_factor=0.1,
    lr_schedule_patience=5,
    p_gaussblur=0,
    p_equalize=0,
    p_clahe=0.5,
    p_horizontalflip=0.3,
    rotation=15,
    p_affine=0.3,
    loss_weight=0.5,
    resampling_strategy='undersampling',
    model="ShuffleNet",
)

In [19]:
train_model(cfg)

Using cache found in /Users/moritz/.cache/torch/hub/pytorch_vision_v0.10.0


Using device: cpu
Number of output features in ShuffleNet encoder:  1
Now train train train


Epoch 1/25 - Avg train Loss: 0.474475:  12%|█▎        | 1/8 [00:40<04:43, 40.56s/it]


KeyboardInterrupt: 

## SWEEP

In [1]:

# Define the sweep configuration
sweep_config = {
    "method": "random",  # or "grid", or "bayes"
    "parameters": {
        "model": {
            "values": ["Task2EfficientNetB0"]
        },
        "lr": {
            "values": [1e-3]
        },
        "epochs": {
            "values": [30]
        },
        "batch_size": {
            "values": [16]
        },
        "p_gaussblur": {
            "values": [0, 0.3]
        },
        "p_equalize": {
            "values": [0]
        },
        "p_clahe": {
            "values": [0.3, 0.5]
        },
        "p_horizontalflip": {
            "values": [0.5]
        },
        "rotation": {
            "values": [10]
        },
        "p_affine": {
            "values": [0.3]
        },
        "loss_weight": {
            "values": [0.5]
        },
        "lr_schedule_factor": {
            "values": [0.1]
        },
        "lr_schedule_patience": {
            "values": [5]
        },
        "resampling_strategy": {
            "values": ['oversampling', 'undersampling']
        }
    }
}

# Initialize the sweep
#sweep_id = wandb.sweep(sweep=sweep_config, project="task2")

# Start the sweep
#wandb.agent(sweep_id, function=train_model)

'''
class ResamplingStrategy(ABC):
    @abstractmethod
    def apply(self, dataloader, epoch):
        pass


    def _calculate_weights(self, dataloader):
        dataset = dataloader.dataset
        labels = [item[1] for item in dataset]
        
        class_counts = Counter(labels)
        
        weights = [1.0 / class_counts[label] for label in labels]
        return weights
    

class DefaultResamplingStrategy(ResamplingStrategy):
    def apply(self, dataloader, epoch):
        # no resampling is applied by default
        return dataloader
    

class OversamplingResamplingStrategy(ResamplingStrategy):
    def apply(self, dataloader, epoch):

        weights = self._calculate_weights(dataloader)

        sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

        oversampled_loader = DataLoader(dataloader.dataset, 
                                        batch_size=dataloader.batch_size, 
                                        sampler=sampler, 
                                        num_workers=dataloader.num_workers)
        
        return oversampled_loader
    
    
'''