In [1]:
import torch
import torch.utils.data as torch_split
import numpy as np
import dataset
import test
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler
from monai.losses import DiceLoss
from monai.losses import FocalLoss
from torchmetrics.classification import F1Score
from monai.networks.nets import UNet
from monai.data import DataLoader
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from cstm_model import MaskToPointUNet
import optuna
import cstm_model
from optuna.pruners import MedianPruner

import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/postprocessing')

import visual
import metrics

In [2]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset root. NOTE(review): absolute, machine-specific path — adjust per machine.
path = "H:/Projects/Kaggle/CZII-CryoET-Object-Identification/datasets/3D/dim104-with-pts"

# Whole dataset; below it is only used to size an 80/20 random split, and that
# random split is currently disabled (see commented line) in favour of fold 1.
data = dataset.UNetDataset(path=path)
tv_split = 0.8
trn = int(len(data) * tv_split)
val = len(data) - trn
# Alternative random split (disabled):
# train_dataset, val_dataset = torch_split.random_split(data, [trn, val])

# Train/val subsets selected by the dataset's own `fold` argument
# (presumably a fixed cross-validation fold — confirm in dataset.UNetDataset).
train_dataset = dataset.UNetDataset(path=path, train=True, fold=1)
val_dataset = dataset.UNetDataset(path=path, val=True, fold=1)

# Class names; list position is presumably the class index (7 entries, matching
# the 7 per-class loss weights defined later).
labels = [
    "background",
    "apo-ferritin (easy)",
    "beta-amylase (impossible, NS)",
    "beta-galactosidase (hard)",
    "ribosome (easy)",
    "thyroglobulin (hard)",
    "virus-like-particle (easy)",
]

# Radius per non-background class index 1..6 (same order as `labels[1:]`).
# NOTE(review): units assumed to be Ångström — confirm.
radii = dict(enumerate((60, 65, 90, 150, 130, 135), start=1))

In [3]:
# Training-loop settings shared by every Optuna trial.
batch_size = 16
num_epochs = 10

# Per-class loss weights, one per entry of `labels`: the background class gets a
# near-zero weight, and the six particle classes share the same weight.
weights = torch.tensor([1e-7] + [1.1666] * 6).to(device)

In [4]:
def objective(trial):
    """Optuna objective: train ``MaskToPointUNet`` with sampled hyperparameters.

    Sampled per trial: learning rate, StepLR decay factor, dropout, L2
    regularization strength and Focal-loss ``gamma``. The model starts from the
    weights loaded by ``model.load_seg('UNet_v1-2.pth')`` and is trained for
    ``num_epochs`` epochs. After each epoch the mean validation loss is reported
    to Optuna so the study's pruner can stop unpromising trials.

    Relies on notebook globals: ``device``, ``weights``, ``batch_size``,
    ``num_epochs``, ``train_dataset``, ``val_dataset``.

    Args:
        trial: the ``optuna.trial.Trial`` passed in by ``study.optimize``.

    Returns:
        float: mean validation loss after the last epoch, computed as
        ``theta * Dice + (1 - theta) * Focal`` (the L2 term is not included).
        This is the value the study minimizes.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner decides to stop the trial.
        ValueError: if the validation loader yields no batches.
    """
    # ---------------------------------------------
    # DEFINE HYPERPARAMETERS
    # ---------------------------------------------
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    decay = trial.suggest_float("decay", 0.5, 1.0)
    dropout = trial.suggest_float("dropout", 0.2, 0.5)
    regularization_strength = trial.suggest_float("regularization_strength", 1e-4, 1e-2, log=True)
    gamma = trial.suggest_float("gamma", 3.0, 5.0)

    # Fixed blend factor: loss = theta * Dice + (1 - theta) * Focal.
    theta = 0.6

    # ---------------------------------------------
    # DEFINE TRAINING TOOLS
    # ---------------------------------------------
    model = MaskToPointUNet(dropout=dropout).to(device)
    model.load_seg('UNet_v1-2.pth')

    # softmax/use_softmax=False: the losses consume the model output as-is.
    # NOTE(review): assumes the model already returns probabilities — confirm.
    dice_loss = DiceLoss(to_onehot_y=True, softmax=False, weight=weights).to(device)
    focal_loss = FocalLoss(to_onehot_y=True, use_softmax=False, weight=weights, gamma=gamma).to(device)

    optimizer = Adam(model.parameters(), lr=lr)
    # Multiply the learning rate by `decay` every 4 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=decay)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False, num_workers=4)

    # Fail fast instead of dividing by zero after a full (expensive) training epoch.
    if len(val_loader) == 0:
        raise ValueError("val_dataset produced no validation batches")

    def blended_loss(outputs, targets):
        """Weighted combination of the Dice and Focal losses."""
        return theta * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)

    def add_regularization_loss(model, regularization_strength):
        """Return ``regularization_strength * sum(p ** 2)`` over trainable parameters.

        Parameters with ``requires_grad=False`` are skipped: they would only add
        a constant to the loss (no gradient) at the cost of extra compute.
        """
        squared_norm = sum(
            torch.sum(param ** 2) for param in model.parameters() if param.requires_grad
        )
        return regularization_strength * squared_norm

    # ---------------------------------------------
    # TRAIN
    # ---------------------------------------------
    # Returned as-is if num_epochs == 0 (no validation ever ran).
    val_loss = float("inf")
    for epoch in range(num_epochs):
        # NOTE(review): train_pts() is the custom model's training-mode switch
        # (presumably enables training of the point head only) — confirm in cstm_model.
        model.train_pts()
        for batch in train_loader:
            inputs = batch['src'].float().to(device)
            targets = batch['pts'].long().to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = blended_loss(outputs, targets) + add_regularization_loss(model, regularization_strength)
            loss.backward()
            optimizer.step()

        scheduler.step()

        # Validation: mean blended loss over all validation batches (no L2 term).
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch['src'].float().to(device)
                targets = batch['pts'].long().to(device)
                val_loss += blended_loss(model(inputs), targets).item()
        val_loss /= len(val_loader)

        trial.report(val_loss, epoch)
        print(val_loss)
        # Let the study's pruner stop this trial early if it is not promising.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    # Optuna needs the objective value: without this return every finished
    # trial has value None and is recorded as FAILED.
    return val_loss
        
        
# Minimize the validation loss returned by `objective`.
# NOTE(review): MedianPruner only prunes after `n_startup_trials` trials have
# finished, so with n_startup_trials=4 and n_trials=4 no trial can ever be
# pruned. n_warmup_steps=8 also disables pruning for the first 8 of the 10
# reported epochs. Lower these if early stopping of bad trials is wanted.
study = optuna.create_study(direction="minimize", pruner=MedianPruner(n_startup_trials=4, n_warmup_steps=8))
study.optimize(objective, n_trials=4)

# Print the best parameters. `study.best_params` raises ValueError when no trial
# completed (e.g. all trials failed or were pruned), so check first.
completed_trials = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
if completed_trials:
    print("Best hyperparameters:", study.best_params)
else:
    print("No trial completed; no best hyperparameters available.")

[I 2024-12-28 21:21:14,145] A new study created in memory with name: no-name-19f574f1-266c-4a1d-8779-b5281c56e847


0.623422904809316
0.6234559469752842
0.6233324276076423
0.6233419895172119
0.6233389443821378
0.6233372370402018
0.6233179105652703
0.6234946860207452


[W 2024-12-28 22:08:45,903] Trial 0 failed with parameters: {'lr': 0.002143120706777237, 'decay': 0.9495611682876461, 'dropout': 0.4360854021712893, 'regularization_strength': 0.00041585374190832676, 'gamma': 3.5715903639018607} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\optuna\study\_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\hmhor\AppData\Local\Temp\ipykernel_22228\2318281938.py", line 46, in objective
    loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    

KeyboardInterrupt: 