In [1]:
import torch
import torch.utils.data as torch_split
import numpy as np
import dataset
import test
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler
from monai.losses import DiceLoss
from monai.losses import FocalLoss
from torchmetrics.classification import F1Score
from monai.networks.nets import UNet
from monai.data import DataLoader
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
import visual

In [30]:
path = "H:/Projects/Kaggle/CZII-CryoET-Object-Identification/datasets/3D/dim104-3000sample"
data = dataset.UNetDataset(path=path)

tv_split = 0.8
trn = int(len(data) * tv_split)
val = len(data) - trn

train_dataset, val_dataset = torch_split.random_split(data, [trn, val])
# train_dataset = dataset.UNetDataset(path=path, train=True)
# val_dataset = dataset.UNetDataset(path=path, val=True)
labels = [
"background",
"apo-ferritin(E)",
"beta-amylase(NS)",
"beta-galactosidase(H)",
"ribosome(E)",
"thyroglobulin(H)",
"virus-like-particle(E)"
]

In [None]:
import optuna
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from optuna.pruners import MedianPruner

# Define the objective function for Optuna
def objective(trial):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vis.new_trial()
    
    warmup_epochs = 0
    
    # Hyperparameters to optimize
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    decay = trial.suggest_float('decay', 0.3, 1.0)
    # dropout = trial.suggest_float("dropout", 0.25, 0.5)
    dropout = 0.3
    regularization_strength = trial.suggest_float("regularization_strength", 5e-5, 1e-2, log=True)
    # alpha = trial.suggest_float("alpha", 0.25, 1.0)
    # theta = trial.suggest_float("theta", 0.1, 0.9)
    theta = 0.6
    gamma = trial.suggest_float("gamma", 2.0, 5.0)
    
    # t_max = trial.suggest_int("t_max", np.ceil(0.3 * cosine_epochs), np.ceil(0.5 * cosine_epochs))
    
    # regularization_type = trial.suggest_categorical("regularization_type", ["none", "L1", "L2"])
    # Suggest regularization strength only if regularization is used
    # if regularization_type != "none":
    #     regularization_strength = trial.suggest_float("regularization_strength", 1e-5, 1e-3, log=True)
    # else:
    #     regularization_strength = 0

    # Model initialization
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=7,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2),
        num_res_units=2,
        dropout=dropout,
    ).to(device)

    # Loss function and optimizer
    
    weights = torch.tensor([0.0434743, 1.16546, 1.1661, 1.16513, 1.14281, 1.15554, 1.16149]).to(device)  # Example weights for classes

    dice_loss = DiceLoss(to_onehot_y=True, softmax=True, weight=weights).to(device)
    focal_loss = FocalLoss(to_onehot_y=True, use_softmax=True, weight=torch.tensor([0.0] + list(weights[1:])), gamma=gamma ).to(device)
    
    optimizer = Adam(model.parameters(), lr=lr)

    # Learning rate warmup (LinearLR) and then cosine annealing (CosineAnnealingLR)
    
    scheduler_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=decay)
    
    # DataLoader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False, num_workers=4)

    # Regularization setup
    def add_regularization_loss(model, regularization_type, regularization_strength):
        reg_loss = 0
        if regularization_type == "L1":
            for param in model.parameters():
                reg_loss += torch.sum(torch.abs(param))
        elif regularization_type == "L2":
            for param in model.parameters():
                reg_loss += torch.sum(param ** 2)
        return regularization_strength * reg_loss

    # num_epochs = warmup_epochs + cosine_epochs
    num_epochs = 25
    prec, rec = 0, 0
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)

            # Add regularization loss if applicable
            # if regularization_type != "none":
            reg_loss = add_regularization_loss(model, "L2", regularization_strength)
            loss += reg_loss

            loss.backward()
            optimizer.step()

        # Scheduler step after each epoch
        if epoch < warmup_epochs:
            scheduler_warmup.step()
        else:
            # scheduler_cosine.step()
            scheduler.step()
            

        # Validation loop
        model.eval()
        val_loss = 0
        precision_metric = MulticlassPrecision(num_classes=7, average='none').to(device)
        recall_metric = MulticlassRecall(num_classes=7, average='none').to(device)
        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
                outputs = model(inputs)
                loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)
                val_loss += loss.item()
                precision_metric.update(outputs.argmax(dim=1).flatten(), targets.flatten())
                recall_metric.update(outputs.argmax(dim=1).flatten(), targets.flatten())
        # print("batch done")
        val_loss /= len(val_loader)
        precision = precision_metric.compute().cpu()
        recall = recall_metric.compute().cpu()
        pr = torch.stack([precision, recall], dim=0)
        vis.report(val_loss, pr)
        
        class_weights = np.array([0.00621062, 0.16649395, 0.16658561, 0.16644739, 0.16325901, 0.16507644, 0.16592698])
        
        # print(precision.shape)
        # print(recall.shape)
        precision = precision.detach().numpy()
        recall = recall.detach().numpy()
        
        weighted_prec = np.sum(precision * class_weights)
        weighted_rec = np.sum(recall * class_weights)
        
        obj = val_loss / 0.6 + (1 - (0.75 * weighted_rec + 0.25 * weighted_prec) )
        
        # Report weighted loss between loss, prec, recall results to Optuna
        trial.report(obj, epoch)

        # Prune trial if necessary
        if trial.should_prune():
            
            raise optuna.exceptions.TrialPruned()

    # print(f"Final validation loss: {val_loss:.6g}")
    
    return val_loss

n_epochs = 25
vis = visual.loss_precision_recall(n_epochs, labels, 1.2)
vis.start()

# Run the Optuna optimization with Median Pruner
study = optuna.create_study(direction="minimize", pruner=MedianPruner(n_startup_trials=4, n_warmup_steps=8))
study.optimize(objective, n_trials=10)

# Print the best parameters
print("Best hyperparameters:", study.best_params)

In [5]:
import pickle
with open("param_search_results.pkl", "wb") as f:
    pickle.dump(vis.data, f)

In [50]:
with open("param_search_results.pkl", "rb") as f:
    data = pickle.load(f)
print("\t\t\t", end="")
for label in labels: print(f"{label}   |   ", end="")
print()
for trial in range(9):
    epochs = len(data["losses"][trial])
    print(f"Trial {trial} loss {data['losses'][trial][epochs-1]:.3f}")
    print("\t\t\t", end="")
    for i in range(7):
        print(f"{data['label_pr'][trial][i][0][epochs-1].item():.3f}", end="\t|\t   ")
    print()
    print("\t\t\t",end="")
    for i in range(7):
        print(f"{data['label_pr'][trial][i][1][epochs-1].item():.3f}", end="\t|\t   ")
    print()
        

			background   |   apo-ferritin(E)   |   beta-amylase(NS)   |   beta-galactosidase(H)   |   ribosome(E)   |   thyroglobulin(H)   |   virus-like-particle(E)   |   
Trial 0 loss 0.575
			0.998	|	   0.001	|	   0.000	|	   0.000	|	   0.187	|	   0.021	|	   0.024	|	   
			0.312	|	   0.030	|	   0.041	|	   0.012	|	   0.961	|	   0.537	|	   0.439	|	   
Trial 1 loss 0.591
			0.980	|	   0.000	|	   0.000	|	   0.001	|	   0.068	|	   0.006	|	   0.003	|	   
			0.109	|	   0.024	|	   0.138	|	   0.077	|	   0.674	|	   0.073	|	   0.090	|	   
Trial 2 loss 0.577
			0.996	|	   0.000	|	   0.000	|	   0.001	|	   0.180	|	   0.040	|	   0.022	|	   
			0.189	|	   0.060	|	   0.031	|	   0.051	|	   0.922	|	   0.582	|	   0.377	|	   
Trial 3 loss 0.429
			0.991	|	   0.482	|	   0.000	|	   0.277	|	   0.637	|	   0.361	|	   0.693	|	   
			0.489	|	   0.792	|	   0.393	|	   0.542	|	   0.812	|	   0.599	|	   0.874	|	   
Trial 4 loss 0.586
			0.993	|	   0.000	|	   0.000	|	   0.001	|	   0.112	|	   0.021	|	   0.001	|	   
			0.233	|	 

In [None]:
# TAKING TRAIL 6: lr 0.002378749151980637 decay 0.6383495211595349 regularization_strength 0.0014973716879159318 gamma 3.461783291951009

In [None]:
# 'losses': [],   trials x epochs (per trial)
# 'label_pr': []  trials x labels x 2 (p & r) x epochs

# [I 2024-12-24 13:38:05,114] A new study created in memory with name: no-name-3d26efe3-d09b-41e6-9859-6e780bb52cf4
# [I 2024-12-24 14:39:11,594] Trial 0 finished with value: 0.575233225759707 and parameters: {'lr': 8.895982362405156e-05, 'decay': 0.4973760593275419, 'regularization_strength': 0.0025172052804279532, 'gamma': 4.996031337088363}. Best is trial 0 with value: 0.575233225759707.
# [I 2024-12-24 15:37:22,894] Trial 1 finished with value: 0.591125853751835 and parameters: {'lr': 1.2990040043526965e-05, 'decay': 0.6813198310736835, 'regularization_strength': 0.0024965757066016288, 'gamma': 4.369881718243024}. Best is trial 0 with value: 0.575233225759707.
# [I 2024-12-24 16:35:33,212] Trial 2 finished with value: 0.5770129514367957 and parameters: {'lr': 4.4440170274243755e-05, 'decay': 0.31695145632882915, 'regularization_strength': 8.74331159347397e-05, 'gamma': 4.2913558753361665}. Best is trial 0 with value: 0.575233225759707.
# [I 2024-12-24 17:49:47,555] Trial 3 finished with value: 0.42899969769151586 and parameters: {'lr': 0.0007172248192508939, 'decay': 0.8544127180269929, 'regularization_strength': 0.0051075629067560855, 'gamma': 4.62092221228817}. Best is trial 3 with value: 0.42899969769151586.
# [I 2024-12-24 18:25:46,168] Trial 4 pruned. 
# [I 2024-12-24 19:45:32,846] Trial 5 finished with value: 0.41076544240901347 and parameters: {'lr': 0.009262442723082055, 'decay': 0.8135586839953626, 'regularization_strength': 0.0001678451998607637, 'gamma': 3.6492772738208084}. Best is trial 5 with value: 0.41076544240901347.
# [I 2024-12-24 20:43:49,142] Trial 6 finished with value: 0.4059352302237561 and parameters: {'lr': 0.002378749151980637, 'decay': 0.6383495211595349, 'regularization_strength': 0.0014973716879159318, 'gamma': 3.461783291951009}. Best is trial 6 with value: 0.4059352302237561.
# [I 2024-12-24 21:41:53,587] Trial 7 finished with value: 0.460155421181729 and parameters: {'lr': 0.0011120877059222305, 'decay': 0.7041753205677597, 'regularization_strength': 0.00032375245652757075, 'gamma': 3.589166472143006}. Best is trial 6 with value: 0.4059352302237561.
# [W 2024-12-24 22:18:59,469] Trial 8 failed with parameters: {'lr': 0.0006692026907199952, 'decay': 0.6625227751501456, 'regularization_strength': 0.00014299575697903301, 'gamma': 4.666629649772913} because of the following error: KeyboardInterrupt().
