In [1]:
import torch
import torch.utils.data as torch_split
import numpy as np
import dataset
import test
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.data import DataLoader
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')


In [2]:
path = "H:/Projects/Kaggle/CZII-CryoET-Object-Identification/datasets/3D/dim96-no-corner"
data = dataset.UNetDataset(path=path)

tv_split = 0.8
trn = int(len(data) * tv_split)
val = len(data) - trn

train_dataset, val_dataset = torch_split.random_split(data, [trn, val])
# train_dataset = dataset.UNetDataset(path=path, train=True)
# val_dataset = dataset.UNetDataset(path=path, val=True)

In [3]:
import optuna
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet

# Define the objective function for Optuna
def objective(trial):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyperparameters to optimize
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
    batch_size = trial.suggest_categorical("batch_size", [4, 8, 16])
    dropout = trial.suggest_uniform("dropout", 0.0, 0.5)
    regularization_type = trial.suggest_categorical("regularization_type", ["none", "L1", "L2"])

    # Suggest regularization strength only if regularization is used
    if regularization_type != "none":
        regularization_strength = trial.suggest_loguniform("regularization_strength", 1e-5, 1e-3)
    else:
        regularization_strength = 0  # No regularization if "none"

    print(f"LR: {lr}\tbatch size: {batch_size}\tdropout: {dropout}\treg type: {regularization_type}\treg str: {regularization_strength}")
    
    # Model initialization
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=7,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        dropout=dropout,
    ).to(device)

    # Loss function and optimizer
    criterion = DiceLoss(to_onehot_y=True, softmax=True).to(device)
    optimizer = Adam(model.parameters(), lr=lr)

    # DataLoader (update batch size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False, num_workers=4)

    # Regularization setup
    def add_regularization_loss(model, regularization_type, regularization_strength):
        reg_loss = 0
        if regularization_type == "L1":
            for param in model.parameters():
                reg_loss += torch.sum(torch.abs(param))
        elif regularization_type == "L2":
            for param in model.parameters():
                reg_loss += torch.sum(param ** 2)
        return regularization_strength * reg_loss

    # Early stopping parameters
    patience = 2
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    elapsed = 0

    # Training loop
    num_epochs = 15  # Shorten for quicker optimization
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            elapsed += 1
            
            inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Add regularization loss if applicable
            if regularization_type != "none":
                reg_loss = add_regularization_loss(model, regularization_type, regularization_strength)
                loss += reg_loss

            loss.backward()
            optimizer.step()

        # Validation loop
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

        val_loss /= len(val_loader)

        # Early stopping check
        if val_loss <= best_val_loss * 0.99:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break
    print(f"Validation loss: {best_val_loss} after {elapsed} epochs")
    return best_val_loss

# Run the Optuna optimization
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=25)

# Print the best parameters
print("Best hyperparameters:", study.best_params)

[I 2024-12-22 11:30:42,695] A new study created in memory with name: no-name-6e8b94f9-f607-498b-a3fb-39dc953982af
  lr = trial.suggest_loguniform("lr", 1e-5, 1e-3)
  dropout = trial.suggest_uniform("dropout", 0.0, 0.5)


LR: 0.0007690852832005317	batch size: 16	dropout: 0.20627530165636548	reg type: none	reg str: 0


[I 2024-12-22 11:37:38,925] Trial 0 finished with value: 0.7081862275417035 and parameters: {'lr': 0.0007690852832005317, 'batch_size': 16, 'dropout': 0.20627530165636548, 'regularization_type': 'none'}. Best is trial 0 with value: 0.7081862275417035.


Early stopping triggered at epoch 8
Validation loss: 0.7081862275417035 after 400 epochs
LR: 8.02200311056887e-05	batch size: 8	dropout: 0.38585196148834683	reg type: L1	reg str: 0.000561752035750931


  regularization_strength = trial.suggest_loguniform("regularization_strength", 1e-5, 1e-3)
[I 2024-12-22 11:43:47,207] Trial 1 finished with value: 0.8221715784072876 and parameters: {'lr': 8.02200311056887e-05, 'batch_size': 8, 'dropout': 0.38585196148834683, 'regularization_type': 'L1', 'regularization_strength': 0.000561752035750931}. Best is trial 0 with value: 0.7081862275417035.


Early stopping triggered at epoch 7
Validation loss: 0.8221715784072876 after 700 epochs
LR: 2.5814446343747655e-05	batch size: 16	dropout: 0.3886212905293643	reg type: none	reg str: 0


[I 2024-12-22 11:56:45,774] Trial 2 finished with value: 0.8416277399429908 and parameters: {'lr': 2.5814446343747655e-05, 'batch_size': 16, 'dropout': 0.3886212905293643, 'regularization_type': 'none'}. Best is trial 0 with value: 0.7081862275417035.


Validation loss: 0.8416277399429908 after 750 epochs
LR: 2.160379804543174e-05	batch size: 4	dropout: 0.38926445786989194	reg type: L1	reg str: 0.00010899501786911792


[I 2024-12-22 12:10:18,949] Trial 3 finished with value: 0.7470451152324676 and parameters: {'lr': 2.160379804543174e-05, 'batch_size': 4, 'dropout': 0.38926445786989194, 'regularization_type': 'L1', 'regularization_strength': 0.00010899501786911792}. Best is trial 0 with value: 0.7081862275417035.


Validation loss: 0.7470451152324676 after 2985 epochs
LR: 1.4122494303203815e-05	batch size: 4	dropout: 0.31329520642911346	reg type: L1	reg str: 3.7326680106871646e-05


[I 2024-12-22 12:19:06,155] Trial 4 finished with value: 0.8379334139823914 and parameters: {'lr': 1.4122494303203815e-05, 'batch_size': 4, 'dropout': 0.31329520642911346, 'regularization_type': 'L1', 'regularization_strength': 3.7326680106871646e-05}. Best is trial 0 with value: 0.7081862275417035.


Early stopping triggered at epoch 10
Validation loss: 0.8379334139823914 after 1990 epochs
LR: 4.908826331340077e-05	batch size: 16	dropout: 0.04782889900625786	reg type: L2	reg str: 0.00010664495004807392


[I 2024-12-22 12:31:20,458] Trial 5 finished with value: 0.7503093297664936 and parameters: {'lr': 4.908826331340077e-05, 'batch_size': 16, 'dropout': 0.04782889900625786, 'regularization_type': 'L2', 'regularization_strength': 0.00010664495004807392}. Best is trial 0 with value: 0.7081862275417035.


Validation loss: 0.7503093297664936 after 750 epochs
LR: 8.98074633410313e-05	batch size: 8	dropout: 0.3329289582515136	reg type: L2	reg str: 0.00010828447370348184


[I 2024-12-22 12:41:26,241] Trial 6 finished with value: 0.6713182854652405 and parameters: {'lr': 8.98074633410313e-05, 'batch_size': 8, 'dropout': 0.3329289582515136, 'regularization_type': 'L2', 'regularization_strength': 0.00010828447370348184}. Best is trial 6 with value: 0.6713182854652405.


Validation loss: 0.6713182854652405 after 1500 epochs
LR: 3.0585090990070275e-05	batch size: 4	dropout: 0.446570343495282	reg type: L2	reg str: 1.138249597272998e-05


[I 2024-12-22 12:51:24,421] Trial 7 finished with value: 0.7294552326202393 and parameters: {'lr': 3.0585090990070275e-05, 'batch_size': 4, 'dropout': 0.446570343495282, 'regularization_type': 'L2', 'regularization_strength': 1.138249597272998e-05}. Best is trial 6 with value: 0.6713182854652405.


Validation loss: 0.7294552326202393 after 2985 epochs
LR: 1.6818119268303073e-05	batch size: 16	dropout: 0.20991155717642673	reg type: L2	reg str: 0.0004180013351355291


[I 2024-12-22 13:04:16,890] Trial 8 finished with value: 0.8550590322567866 and parameters: {'lr': 1.6818119268303073e-05, 'batch_size': 16, 'dropout': 0.20991155717642673, 'regularization_type': 'L2', 'regularization_strength': 0.0004180013351355291}. Best is trial 6 with value: 0.6713182854652405.


Validation loss: 0.8550590322567866 after 750 epochs
LR: 0.0004836871643785142	batch size: 8	dropout: 0.047232011802912255	reg type: L2	reg str: 0.0003674062788900352


[I 2024-12-22 13:13:46,712] Trial 9 finished with value: 0.6730712532997132 and parameters: {'lr': 0.0004836871643785142, 'batch_size': 8, 'dropout': 0.047232011802912255, 'regularization_type': 'L2', 'regularization_strength': 0.0003674062788900352}. Best is trial 6 with value: 0.6713182854652405.


Early stopping triggered at epoch 11
Validation loss: 0.6730712532997132 after 1100 epochs
LR: 0.00020851804861614929	batch size: 8	dropout: 0.15916864328583366	reg type: L2	reg str: 3.534330338279551e-05


[I 2024-12-22 13:24:52,511] Trial 10 finished with value: 0.6503513836860657 and parameters: {'lr': 0.00020851804861614929, 'batch_size': 8, 'dropout': 0.15916864328583366, 'regularization_type': 'L2', 'regularization_strength': 3.534330338279551e-05}. Best is trial 10 with value: 0.6503513836860657.


Validation loss: 0.6503513836860657 after 1500 epochs
LR: 0.00023769047886156396	batch size: 8	dropout: 0.18175362635232617	reg type: L2	reg str: 4.3382030915381155e-05


[I 2024-12-22 13:33:55,123] Trial 11 finished with value: 0.6538152027130127 and parameters: {'lr': 0.00023769047886156396, 'batch_size': 8, 'dropout': 0.18175362635232617, 'regularization_type': 'L2', 'regularization_strength': 4.3382030915381155e-05}. Best is trial 10 with value: 0.6503513836860657.


Early stopping triggered at epoch 14
Validation loss: 0.6538152027130127 after 1400 epochs
LR: 0.0002568970689663098	batch size: 8	dropout: 0.1345357290524775	reg type: L2	reg str: 2.6229315472367273e-05


[I 2024-12-22 13:41:39,750] Trial 12 finished with value: 0.6608428263664246 and parameters: {'lr': 0.0002568970689663098, 'batch_size': 8, 'dropout': 0.1345357290524775, 'regularization_type': 'L2', 'regularization_strength': 2.6229315472367273e-05}. Best is trial 10 with value: 0.6503513836860657.


Early stopping triggered at epoch 12
Validation loss: 0.6608428263664246 after 1200 epochs
LR: 0.0002153282815777213	batch size: 8	dropout: 0.1364441202721979	reg type: L2	reg str: 3.8485051714922073e-05


[I 2024-12-22 13:48:07,035] Trial 13 finished with value: 0.6745794796943665 and parameters: {'lr': 0.0002153282815777213, 'batch_size': 8, 'dropout': 0.1364441202721979, 'regularization_type': 'L2', 'regularization_strength': 3.8485051714922073e-05}. Best is trial 10 with value: 0.6503513836860657.


Early stopping triggered at epoch 10
Validation loss: 0.6745794796943665 after 1000 epochs
LR: 0.00016787493248612858	batch size: 8	dropout: 0.14651442752120203	reg type: L2	reg str: 1.6389156017103606e-05


[I 2024-12-22 13:57:15,113] Trial 14 finished with value: 0.6887184810638428 and parameters: {'lr': 0.00016787493248612858, 'batch_size': 8, 'dropout': 0.14651442752120203, 'regularization_type': 'L2', 'regularization_strength': 1.6389156017103606e-05}. Best is trial 10 with value: 0.6503513836860657.


Early stopping triggered at epoch 13
Validation loss: 0.6887184810638428 after 1300 epochs
LR: 0.0004088366985777809	batch size: 8	dropout: 0.2621971393883426	reg type: none	reg str: 0


[I 2024-12-22 14:10:00,408] Trial 15 finished with value: 0.6498052048683166 and parameters: {'lr': 0.0004088366985777809, 'batch_size': 8, 'dropout': 0.2621971393883426, 'regularization_type': 'none'}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 15
Validation loss: 0.6498052048683166 after 1500 epochs
LR: 0.0005294217889192675	batch size: 8	dropout: 0.2809754233475264	reg type: none	reg str: 0


[I 2024-12-22 14:20:36,219] Trial 16 finished with value: 0.6547260951995849 and parameters: {'lr': 0.0005294217889192675, 'batch_size': 8, 'dropout': 0.2809754233475264, 'regularization_type': 'none'}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 13
Validation loss: 0.6547260951995849 after 1300 epochs
LR: 0.0009590308360720139	batch size: 8	dropout: 0.09738479566681933	reg type: none	reg str: 0


[I 2024-12-22 14:25:43,287] Trial 17 finished with value: 0.6857056212425232 and parameters: {'lr': 0.0009590308360720139, 'batch_size': 8, 'dropout': 0.09738479566681933, 'regularization_type': 'none'}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 8
Validation loss: 0.6857056212425232 after 800 epochs
LR: 0.0004052327448167833	batch size: 8	dropout: 0.24336470545748473	reg type: none	reg str: 0


[I 2024-12-22 14:33:25,377] Trial 18 finished with value: 0.6622527933120728 and parameters: {'lr': 0.0004052327448167833, 'batch_size': 8, 'dropout': 0.24336470545748473, 'regularization_type': 'none'}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 12
Validation loss: 0.6622527933120728 after 1200 epochs
LR: 0.0001240640157381763	batch size: 4	dropout: 0.2632929365448615	reg type: none	reg str: 0


[I 2024-12-22 14:44:03,051] Trial 19 finished with value: 0.6611132884025573 and parameters: {'lr': 0.0001240640157381763, 'batch_size': 4, 'dropout': 0.2632929365448615, 'regularization_type': 'none'}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 13
Validation loss: 0.6611132884025573 after 2587 epochs
LR: 0.00035369642211035966	batch size: 8	dropout: 0.007353364898175535	reg type: none	reg str: 0


[I 2024-12-22 14:47:53,531] Trial 20 finished with value: 0.7126597714424133 and parameters: {'lr': 0.00035369642211035966, 'batch_size': 8, 'dropout': 0.007353364898175535, 'regularization_type': 'none'}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 6
Validation loss: 0.7126597714424133 after 600 epochs
LR: 0.00025795459593388505	batch size: 8	dropout: 0.18800602411646955	reg type: L2	reg str: 5.3601336211126e-05


[I 2024-12-22 14:56:15,160] Trial 21 finished with value: 0.6547290349006653 and parameters: {'lr': 0.00025795459593388505, 'batch_size': 8, 'dropout': 0.18800602411646955, 'regularization_type': 'L2', 'regularization_strength': 5.3601336211126e-05}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 13
Validation loss: 0.6547290349006653 after 1300 epochs
LR: 0.00012770574121546735	batch size: 8	dropout: 0.1759972497951609	reg type: L2	reg str: 6.223502829884042e-05


[I 2024-12-22 15:07:45,942] Trial 22 finished with value: 0.671155903339386 and parameters: {'lr': 0.00012770574121546735, 'batch_size': 8, 'dropout': 0.1759972497951609, 'regularization_type': 'L2', 'regularization_strength': 6.223502829884042e-05}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 15
Validation loss: 0.671155903339386 after 1500 epochs
LR: 0.00018330211028207545	batch size: 8	dropout: 0.09018376947727601	reg type: L1	reg str: 0.0002595641269967224


[I 2024-12-22 15:15:31,993] Trial 23 finished with value: 0.7843618059158325 and parameters: {'lr': 0.00018330211028207545, 'batch_size': 8, 'dropout': 0.09018376947727601, 'regularization_type': 'L1', 'regularization_strength': 0.0002595641269967224}. Best is trial 15 with value: 0.6498052048683166.


Early stopping triggered at epoch 9
Validation loss: 0.7843618059158325 after 900 epochs
LR: 0.00032731036287932553	batch size: 8	dropout: 0.24536040493307815	reg type: L2	reg str: 2.0610968568687057e-05


[I 2024-12-22 15:28:26,822] Trial 24 finished with value: 0.6468523573875428 and parameters: {'lr': 0.00032731036287932553, 'batch_size': 8, 'dropout': 0.24536040493307815, 'regularization_type': 'L2', 'regularization_strength': 2.0610968568687057e-05}. Best is trial 24 with value: 0.6468523573875428.


Validation loss: 0.6468523573875428 after 1500 epochs
Best hyperparameters: {'lr': 0.00032731036287932553, 'batch_size': 8, 'dropout': 0.24536040493307815, 'regularization_type': 'L2', 'regularization_strength': 2.0610968568687057e-05}


In [None]:
import numpy as np
from ipywidgets import interact
import augment
import load
import matplotlib.pyplot as plt


root = load.get_root()

picks = load.get_picks_dict(root)

vol, coords, scales = load.get_run_volume_picks(root, level=0)

mask = load.get_picks_mask(vol.shape, picks, coords, int(scales[0]))

In [26]:
params = augment.aug_params
params['patch_size'] = (96,96,96)
params['final_size'] = (96,96,96)
params['flip_prob'] - 0.0
params['rot_prob'] = 0.0
params['rot_range'] = 0.0



samples = augment.random_augmentation(vol, mask, num_samples=1, aug_params=params)

model.eval()

inp = np.array(samples[0]["source"].unsqueeze(0).unsqueeze(0))
inp = torch.from_numpy(inp).to(device)
pred_mask = model(inp)

src = samples[0]['source']
tgt = samples[0]['target']  # Mask with interest points (non-zero values)

pred_tgt = pred_mask.squeeze().cpu().detach()


pred_tgt = torch.argmax(pred_tgt, dim = 0).numpy()


print(f'# Particles Types Represented: {len(np.unique(tgt)) - 1}')
print(f'# Particles Types Predicted: {len(np.unique(pred_tgt)) - 1}')


def plot_cross_section(i):
    plot_vol = tgt
    plot_mask = pred_tgt
    
    plt.figure(figsize=(15, 5))
    alpha = 0.3

    # Slice at x-coordinate
    plt.subplot(131)
    plt.imshow(plot_vol[i, :, :], cmap="viridis")
    plt.imshow(plot_mask[i, :, :], cmap="Reds", alpha=alpha)  # Overlay mask with transparency
    plt.title(f'Slice at x={i}')

    # Slice at y-coordinate
    plt.subplot(132)
    plt.imshow(plot_vol[:, i, :], cmap="viridis")
    plt.imshow(plot_mask[:, i, :], cmap="Reds", alpha=alpha)
    plt.title(f'Slice at y={i}')

    # Slice at z-coordinate
    plt.subplot(133)
    plt.imshow(plot_vol[:, :, i], cmap="viridis")
    plt.imshow(plot_mask[:, :, i], cmap="Reds", alpha=alpha)
    plt.title(f'Slice at z={i}')

    plt.show()

# Interactive Slider for scrolling through slices
interact(plot_cross_section, i=(0, tgt.shape[0] - 1))

# Particles Types Represented: 4
# Particles Types Predicted: 6


interactive(children=(IntSlider(value=47, description='i', max=95), Output()), _dom_classes=('widget-interact'…

<function __main__.plot_cross_section(i)>