In [1]:
import torch
import monai
import dataloader
import os
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.nn import MSELoss
from torch.nn import BCEWithLogitsLoss
import numpy as np
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler
from monai.losses import DiceLoss
from monai.losses import FocalLoss
from monai.networks.nets import UNet
import time

In [2]:
# Directory of preprocessed volumes; handed to dataloader.MMapDataset as `path`.
data_path = '../data/hm30rad/'

# Label table. `names` holds each tomogram id once (as str), in first-appearance order.
df = pd.read_csv('../data/train_labels.csv')
tomo_ids_as_str = df['tomo_id'].astype(str)
names = tomo_ids_as_str.drop_duplicates().tolist()

In [3]:
import optuna
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from optuna.pruners import MedianPruner
# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Helpers used by the Optuna objective below.


def _sync(device):
    """Wait for queued CUDA work on ``device`` so ``perf_counter`` timings are meaningful.

    CUDA kernel launches are asynchronous, so without a sync the forward/backward
    timers can largely measure launch overhead. No-op on non-CUDA devices.
    """
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def _l2_penalty(model, strength):
    """Return ``strength * sum(p ** 2)`` over every parameter of ``model`` (biases included)."""
    return strength * sum(torch.sum(p ** 2) for p in model.parameters())


def _train_one_epoch(model, loader, optimizer, criterion, reg_str, device, epoch):
    """Run one training epoch and print per-batch losses plus a timing breakdown.

    The optimized loss is ``criterion(output, target) + _l2_penalty(model, reg_str)``.

    Returns:
        Mean training loss over the epoch's batches (including the L2 term).
    """
    model.train()
    loss_sum = 0.0
    data_time = forward_time = backward_time = 0.0

    epoch_start = time.perf_counter()
    fetch_start = epoch_start
    for batch_idx, batch in enumerate(loader):
        # Data time = waiting on the DataLoader iterator + host->device transfer.
        inputs = batch['src'].to(device)
        targets = batch['tgt'].to(device)
        _sync(device)
        data_time += time.perf_counter() - fetch_start

        optimizer.zero_grad()

        forward_start = time.perf_counter()
        outputs = model(inputs)
        loss = criterion(outputs, targets) + _l2_penalty(model, reg_str)
        _sync(device)
        forward_time += time.perf_counter() - forward_start

        backward_start = time.perf_counter()
        loss.backward()
        optimizer.step()
        _sync(device)
        backward_time += time.perf_counter() - backward_start

        loss_value = loss.item()
        print(f'epoch {epoch} batch {batch_idx} loss: {loss_value}')
        loss_sum += loss_value

        fetch_start = time.perf_counter()

    total_time = time.perf_counter() - epoch_start
    # max(..., 1) avoids ZeroDivisionError if the loader happens to be empty.
    mean_loss = loss_sum / max(len(loader), 1)
    print(f'Batch Loss: {mean_loss}')
    print(f'Total Time: {total_time:.2f}s, Data Load Time: {data_time:.2f}s, Forward Time: {forward_time:.2f}s, Backward Time: {backward_time:.2f}s')
    return mean_loss


def _validate(model, loader, criterion, device, epoch):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean ``criterion`` loss.

    The L2 penalty is intentionally excluded, so this is the pure data loss.

    Raises:
        ValueError: if ``loader`` yields no batches (the mean would be undefined).
    """
    if len(loader) == 0:
        raise ValueError('validation loader yielded no batches')

    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            inputs = batch['src'].to(device)
            targets = batch['tgt'].to(device)
            loss_value = criterion(model(inputs), targets).item()
            print(f'val epoch {epoch} batch {batch_idx} loss: {loss_value}')
            loss_sum += loss_value
    return loss_sum / len(loader)


# Define the objective function for Optuna
def objective(trial, split_seed=42):
    """Train a 3D UNet with trial-suggested hyperparameters; return the final validation loss.

    Tuned hyperparameters: ``lr`` (Adam learning rate), ``decay`` (StepLR gamma) and
    ``regularization_strength`` (weight of the explicit L2 penalty on all parameters).

    Args:
        trial: Optuna trial that supplies hyperparameters and receives per-epoch reports.
        split_seed: ``random_state`` for the train/validation split. Fixed by default so
            every trial is scored on the same validation tomograms and objective values
            are comparable across trials; pass ``None`` for a fresh random split per trial.

    Returns:
        Mean validation BCE loss (without the L2 term) after the last epoch.

    Raises:
        optuna.exceptions.TrialPruned: when the study's pruner stops the trial early.
    """
    # ------------------------------ #
    #        HYPERPARAMETERS         #
    # ------------------------------ #
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    decay = trial.suggest_float('decay', 0.3, 1.0)
    reg_str = trial.suggest_float("regularization_strength", 1e-4, 1e-2, log=True)
    dropout = 0.3
    # Positive-class weight for BCE; presumably offsets a very sparse foreground — TODO confirm
    # against label statistics.
    pos_weight_val = 10000.0
    pos_weight = torch.tensor([pos_weight_val], device=device)

    num_epochs = 7
    batch_size = 32
    # Previously explored, currently fixed/unused:
    #   dropout in [0.25, 0.5], alpha in [0.25, 1.0], theta in [0.1, 0.9], gamma in [2.0, 5.0]

    # ------------------------------ #
    #              DATA              #
    # ------------------------------ #
    aug_params = {
        "patch_size": (80,80,80),
        "final_size":   (80,80,80),
        "flip_prob":  0.5,
        "rot_prob":   0.5,
        "scale_prob": 1.0,
        "rot_range":  np.pi,
        "scale_range": 0.2
    }

    train_names, val_names = train_test_split(names, test_size=0.2, random_state=split_seed)

    # NOTE(review): the validation dataset gets the same aug_params; if MMapDataset applies
    # them randomly, the validation loss is itself stochastic and pruning becomes noisier.
    # NOTE(review): if gpu=True makes the dataset emit CUDA tensors, pin_memory=True and CUDA
    # use inside the 8 worker processes may fail — verify against dataloader.MMapDataset.
    train_dataset = dataloader.MMapDataset(names=train_names, path=data_path, gpu=True, aug_params=aug_params)
    val_dataset = dataloader.MMapDataset(names=val_names, path=data_path, gpu=True, aug_params=aug_params)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8, collate_fn=dataloader.custom_collate, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=8, collate_fn=dataloader.custom_collate, shuffle=False, pin_memory=True)

    # ------------------------------ #
    #             MODEL              #
    # ------------------------------ #
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2),
        num_res_units=2,
        dropout=dropout,
    ).to(device)

    # ------------------------------ #
    #        TRAINING METHODS        #
    # ------------------------------ #
    # Alternatives tried earlier: Dice + Focal (softmax) mix and MSELoss.
    bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = Adam(model.parameters(), lr=lr)
    # Stepped once per epoch: LR is multiplied by `decay` after every 4 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=decay)

    val_loss = float('inf')
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, optimizer, bce_loss, reg_str, device, epoch)
        scheduler.step()

        val_loss = _validate(model, val_loader, bce_loss, device, epoch)
        trial.report(val_loss, epoch)
        print(f"Epoch {epoch} loss: {val_loss}")

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_loss

# NOTE(review): not read by objective(), which hard-codes num_epochs = 7; kept in case later
# cells reference it.
n_epochs = 25

# Seeded TPE sampler so the sequence of suggested hyperparameters is reproducible across runs.
# MedianPruner: no pruning until 2 trials have completed, and within a trial not before the
# reported step exceeds 4 epochs.
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
    pruner=MedianPruner(n_startup_trials=2, n_warmup_steps=4),
)
study.optimize(objective, n_trials=15)

print("Best hyperparameters:", study.best_params)

[I 2025-03-26 21:30:51,165] A new study created in memory with name: no-name-513ffb9b-4c1b-4dc2-8616-79c9333a1661
[W 2025-03-26 21:31:01,056] Trial 0 failed with parameters: {'lr': 2.198939685901193e-05, 'decay': 0.8216810617292174, 'regularization_strength': 0.0009543806812201595} because of the following error: RuntimeError('Caught RuntimeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n  File "c:\\Users\\hmhor\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\monai\\transforms\\transform.py", line 141, in apply_transform\n    return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)\n  File "c:\\Users\\hmhor\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\monai\\transforms\\transform.py", line 98, in _apply_transform\n    return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)\n  File "h:\\Projects\\Kaggle\\BYU-Locating-Bacterial-Flagellar-Motors\\preproce

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\monai\transforms\transform.py", line 141, in apply_transform
    return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\monai\transforms\transform.py", line 98, in _apply_transform
    return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
  File "h:\Projects\Kaggle\BYU-Locating-Bacterial-Flagellar-Motors\preprocess\augment.py", line 43, in __call__
    shape = data['src'].shape
AttributeError: 'tuple' object has no attribute 'shape'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "h:\Projects\Kaggle\BYU-Locating-Bacterial-Flagellar-Motors\preprocess\dataloader.py", line 36, in __getitem__
    out_src, out_tgt = augment.rand_aug(
  File "h:\Projects\Kaggle\BYU-Locating-Bacterial-Flagellar-Motors\preprocess\augment.py", line 155, in rand_aug
    result = augment(sample)
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\monai\transforms\compose.py", line 335, in __call__
    result = execute_compose(
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\monai\transforms\compose.py", line 111, in execute_compose
    data = apply_transform(
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\monai\transforms\transform.py", line 171, in apply_transform
    raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <augment.RandCropMMapd object at 0x00000224A4BE5510>
