# In [None]:
import torch
import timm
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
import numpy as np
from pathlib import Path
from datetime import datetime
import joblib
import pandas as pd
from uwf_data import get_dataloaders
from timm.utils import ModelEmaV2

# Run configuration: every hyper-parameter and path choice read by the rest of
# the script lives here. The whole dict is also dumped to config.joblib inside
# the run directory so a run can be reproduced later.
config = dict(
    dataset="tsk1",
    model_name='mobilenetv3_rw',
    batch_size=32,
    epochs=5,
    learning_rate=1e-4,
    weight_decay=1e-2,
    optimizer="AdamW",
    targets=['image quality level'],
    mixed_precision=False,
    dtype=torch.bfloat16,  # autocast dtype; only takes effect when mixed_precision is True
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    resolution=(254, 200),
    aug_type="quality_rev3",
    aug_prob=0.95,
    grad_clip_norm=1.0,  # gradient clipping is skipped when this is <= 0
    label_smoothing_min_max=(0.001, 0.999),  # training targets are clamped into this range
    model_to_load=None,
    uwf_task1_extra_augs=True,
    uwf_task1_synthetic_bad_only=False,
    use_ema=True,
    ema_decay=0.99,
    ema_start_epoch=1,
    n_folds=30,
    cv_index=0,
)
# One output logit per target column.
config["num_classes"] = len(config["targets"])

# Load the labels and select the train/validation split for fold `cv_index` of a
# stratified K-fold on the image quality label. The seed is fixed, so folds are
# reproducible across runs.
df = pd.read_csv('uwf_task1_labels.csv')
if not 0 <= config["cv_index"] < config["n_folds"]:
    # Without this check the fold loop below never matches, and train_ids/val_ids
    # stay undefined until they surface as an opaque NameError further down.
    raise ValueError(
        f"cv_index must be in [0, {config['n_folds']}), got {config['cv_index']}"
    )
skf = StratifiedKFold(n_splits=config["n_folds"], shuffle=True, random_state=42)

for fold, (train_index, val_index) in enumerate(skf.split(df, df['image quality level'])):
    if fold == config["cv_index"]:
        train_ids = df.iloc[train_index]['image'].values
        val_ids = df.iloc[val_index]['image'].values
        break

train_df = df[df.image.isin(train_ids)].copy()
if config['uwf_task1_synthetic_bad_only']:
    # Train only on rows whose quality label is 1.
    # NOTE(review): the flag name suggests label 1 means "bad" quality; confirm.
    train_df_to_use = train_df[train_df['image quality level'] == 1].copy()
else:
    train_df_to_use = train_df
val_df = df[df.image.isin(val_ids)]

train_loader, val_loader = get_dataloaders(train_df_to_use, val_df, target_cols=config["targets"], batch_size=config['batch_size'],
                                           uwf_task1_extra_augs=config['uwf_task1_extra_augs'],
                                           res=config["resolution"], aug_type=config["aug_type"], aug_prob=config["aug_prob"])
print(len(train_loader.dataset))

# Derive a unique run directory from the key settings plus a timestamp, create
# it, and store the run configuration inside it.
current_time = datetime.now().strftime("%m%d%H%M%S")
res_str = "{}x{}".format(*config['resolution'])
save_name = (
    f"miccai24/uwf_runs/{config['dataset']}_{config['model_name']}_"
    f"{res_str}_{current_time}_fold{config['cv_index']}"
)
if config['uwf_task1_extra_augs']:
    save_name += 'extraaug' + ('synth_only' if config['uwf_task1_synthetic_bad_only'] else '')

# Shorten common architecture family names to keep directory names compact.
for long_name, short_name in (
    ('efficientnet', 'efn'),
    ('resnet', 'rn'),
    ('mobilenet', 'mn'),
    ('densenet', 'dn'),
):
    save_name = save_name.replace(long_name, short_name)

save_path = Path(save_name)
save_path.mkdir(parents=True, exist_ok=True)
print(f"Training started. Save name: {save_name}")
joblib.dump(config, save_path / 'config.joblib')

# Build the backbone. By default it starts from timm's pretrained weights. When
# config['model_to_load'] is set, the script instead rebuilds the architecture
# recorded in that checkpoint and initialises from it, excluding the classifier head.
if config['model_to_load']:
    to_load = config['model_to_load']
    # A str/Path value is used as the checkpoint path. Any other truthy value
    # (e.g. True) keeps the fixed location this script has always read from.
    if isinstance(to_load, (str, Path)):
        ckpt_path = Path(to_load)
    else:
        ckpt_path = Path('miccai24/uwf_runs/model_to_load/final_model.pth')
    # map_location='cpu': the model has not been moved to the device yet, and
    # this lets a GPU-saved checkpoint load on a CPU-only machine.
    state_dict = torch.load(ckpt_path, map_location='cpu')

    # Checkpoints saved by this script carry non-tensor metadata. Strip it so it
    # never reaches load_state_dict, and adopt the saved architecture name
    # *before* constructing the model, so the architecture matches the weights.
    saved_model_name = state_dict.pop("model_name", None)
    state_dict.pop("resolution", None)
    if saved_model_name is not None:
        config["model_name"] = saved_model_name
    model = timm.create_model(config["model_name"], pretrained=False, num_classes=config["num_classes"])

    conv1_weight = state_dict.get('conv1.weight')
    if conv1_weight is not None and conv1_weight.shape[1] == 2:
        # The checkpoint stem takes 2 input channels, so append an all-zero third
        # channel whose shape is taken from the checkpoint itself.
        # NOTE(review): assumes the new model's conv1 expects 3 channels (timm default); confirm.
        state_dict['conv1.weight'] = torch.cat(
            [conv1_weight, torch.zeros_like(conv1_weight[:, :1])], dim=1
        )

    # Drop the old classifier head. The new head is initialised separately.
    for key in [k for k in state_dict if k.startswith(('fc', 'classifier'))]:
        del state_dict[key]
    load_result = model.load_state_dict(state_dict, strict=False)
    print(
        f"Loaded {ckpt_path}: missing keys {load_result.missing_keys}, "
        f"unexpected keys {load_result.unexpected_keys}"
    )
    del state_dict
else:
    model = timm.create_model(config["model_name"], pretrained=True, num_classes=config["num_classes"])

with torch.no_grad():
    # Start the classifier at the label prior: with zero weights and
    # bias = logit(mean label), sigmoid(output) equals each target's mean in
    # train_df before any training step.
    target_means = torch.tensor(train_df[config['targets']].mean(axis=0).values)
    # eps clamps the means into [eps, 1 - eps]. Without it, a target that is
    # all 0 or all 1 in train_df would produce an infinite bias.
    per_target_means = torch.special.logit(target_means, eps=1e-6)
    # timm's get_classifier() returns the final classification layer, whichever
    # attribute (fc / classifier) the architecture stores it under.
    head = model.get_classifier()
    head.weight.zero_()
    # copy_ keeps the existing Parameter and raises on a shape mismatch, instead
    # of silently replacing the tensor.
    head.bias.copy_(per_target_means.float())

# Move the model to the target device before the optimizer captures its parameters.
model = model.to(config["device"])

# The EMA copy is created lazily inside the training loop.
ema_model = None

criterion = torch.nn.BCEWithLogitsLoss()

# Look up the optimizer class by its configured name.
optimizer_cls = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW}.get(config["optimizer"])
if optimizer_cls is None:
    raise ValueError(f"Unsupported optimizer: {config['optimizer']}")
optimizer = optimizer_cls(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

# With mixed precision disabled the scaler passes losses and steps straight through.
scaler = torch.cuda.amp.GradScaler(enabled=config["mixed_precision"])
# Bounds used to clamp (label-smooth) the training targets.
ls_min, ls_max = config['label_smoothing_min_max']

def _save_checkpoint(net, filename):
    """Save ``net``'s state dict plus run metadata to ``save_path / filename``.

    The ``model_name`` and ``resolution`` entries let the checkpoint be rebuilt
    without the original config.
    """
    sd = net.state_dict()
    sd["model_name"] = config["model_name"]
    sd["resolution"] = config['resolution']
    torch.save(sd, save_path / filename)


best_val_loss = np.inf
# Tracked separately from best_val_loss, so best_model.pth always holds the
# best raw model and best_ema_model.pth the best EMA model. With a shared
# threshold, each file would be skipped whenever the other variant scored lower earlier.
best_ema_val_loss = np.inf
train_metrics = []
val_metrics = []
main_pbar = tqdm(range(config["epochs"]), desc="Epochs")
for epoch in main_pbar:
    # Create the EMA shadow model once the warm-up epochs are over.
    if config["use_ema"] and epoch >= config["ema_start_epoch"] and ema_model is None:
        print('Initialising EMA')
        ema_model = ModelEmaV2(model, decay=config["ema_decay"])

    # ---- training ----
    model.train()
    epoch_loss = 0
    all_targets = []
    all_predictions = []

    pbar = tqdm(train_loader, desc="Train", leave=False)
    for batch in pbar:
        inputs = batch['image']
        # Clamp the labels into [ls_min, ls_max] (label smoothing).
        targets = batch['target'].clamp(ls_min, ls_max)
        inputs, targets = inputs.to(config["device"]), targets.to(config["device"])

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=config["mixed_precision"], dtype=config["dtype"]):
            outputs = model(inputs)
            loss = criterion(outputs, targets.float())

        scaler.scale(loss).backward()
        if config["grad_clip_norm"] > 0:
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip_norm"])

        scaler.step(optimizer)
        scaler.update()

        if ema_model is not None:
            ema_model.update(model)

        epoch_loss += loss.item() * inputs.size(0)
        # reshape(-1) rather than squeeze(): a final batch of size 1 must still
        # yield a 1-D array, otherwise list.extend fails on a 0-d array.
        all_targets.extend(targets.reshape(-1).cpu().numpy())
        all_predictions.extend(outputs.float().reshape(-1).detach().cpu().numpy())

        pbar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "mem": f"{torch.cuda.max_memory_allocated() / 1e9:.2f}GB"
        })

    epoch_loss /= len(train_loader.dataset)
    epoch_auc = roc_auc_score(np.asarray(all_targets) > 0.5, all_predictions)
    train_metrics.append({"loss": epoch_loss, "auc": epoch_auc})

    # ---- validation (raw model and, once started, its EMA copy) ----
    model.eval()
    if ema_model is not None:
        ema_model.eval()
    val_loss = 0
    ema_val_loss = 0
    all_targets = []
    all_predictions = []
    ema_all_predictions = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Val", leave=False):
            inputs = batch['image']
            targets = batch['target']
            inputs, targets = inputs.to(config["device"]), targets.to(config["device"])

            with torch.cuda.amp.autocast(enabled=config["mixed_precision"], dtype=config["dtype"]):
                outputs = model(inputs)
                loss = criterion(outputs, targets.float())
                if ema_model is not None:
                    ema_outputs = ema_model.module(inputs)
                    ema_loss = criterion(ema_outputs, targets.float())

            val_loss += loss.item() * inputs.size(0)
            all_targets.extend(targets.reshape(-1).cpu().numpy())
            all_predictions.extend(outputs.float().reshape(-1).cpu().numpy())
            if ema_model is not None:
                ema_val_loss += ema_loss.item() * inputs.size(0)
                ema_all_predictions.extend(ema_outputs.float().reshape(-1).cpu().numpy())

    val_loss /= len(val_loader.dataset)
    val_labels = np.asarray(all_targets) > 0.5
    val_auc = roc_auc_score(val_labels, all_predictions)
    val_metrics.append({"loss": val_loss, "auc": val_auc})

    ep_str = f"Ep {epoch} Train L: {epoch_loss:.4f} AUC: {epoch_auc:.4f} Val L: {val_loss:.4f} AUC: {val_auc:.4f}"

    if ema_model is not None:
        ema_val_loss /= len(val_loader.dataset)
        ema_val_auc = roc_auc_score(val_labels, ema_all_predictions)
        ep_str += f" EMA L: {ema_val_loss:.4f} AUC: {ema_val_auc:.4f}"

    print(ep_str)

    # ---- checkpointing on validation-loss improvement ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint(model, 'best_model.pth')

    if ema_model is not None and ema_val_loss < best_ema_val_loss:
        best_ema_val_loss = ema_val_loss
        _save_checkpoint(ema_model.module, 'best_ema_model.pth')

    main_pbar.set_postfix({
        "loss": f"{epoch_loss:.4f}",
        "mem": f"{torch.cuda.max_memory_allocated() / 1e9:.2f}GB"
    })

# Persist the last-epoch weights, plus the EMA weights if an EMA model was ever
# created, with the metadata needed to rebuild the architecture. Then dump the
# per-epoch metric history.
final_checkpoints = [(model, 'final_model.pth')]
if ema_model:
    final_checkpoints.append((ema_model.module, 'final_ema_model.pth'))

for net, filename in final_checkpoints:
    sd = net.state_dict()
    sd["model_name"] = config["model_name"]
    sd["resolution"] = config['resolution']
    torch.save(sd, save_path / filename)

joblib.dump(
    {"train_metrics": train_metrics, "val_metrics": val_metrics},
    save_path / 'metrics.joblib',
)

print(f"Training completed. Final model and metrics saved with prefix: {save_name}")

import shutil


def _package_submission(checkpoint, archive_prefix, required=True):
    """Copy ``checkpoint`` into TemplateSubmission/ as final_model.pth and zip that folder.

    The archive is written to ``submissions/<archive_prefix><run name>.zip``.
    When ``required`` is False, a missing checkpoint is reported and skipped
    instead of raising FileNotFoundError.
    """
    if not required and not checkpoint.exists():
        print(f"Skipping {archive_prefix} submission: {checkpoint} does not exist")
        return
    shutil.copy(checkpoint, 'TemplateSubmission/final_model.pth')
    shutil.make_archive('submissions/' + archive_prefix + save_path.name, 'zip', 'TemplateSubmission/')


# Regular model submissions
_package_submission(save_path / 'final_model.pth', 'T1')
_package_submission(save_path / 'best_model.pth', 'bT1')

# EMA model submissions. EMA checkpoints are optional: none exist if training
# stopped before ema_start_epoch, and the best one is only written when its
# validation loss improved.
if config["use_ema"]:
    _package_submission(save_path / 'final_ema_model.pth', 'T1EMA', required=False)
    _package_submission(save_path / 'best_ema_model.pth', 'bT1EMA', required=False)

print("All submissions prepared, including EMA models if applicable.")