In [1]:
import sys

# Extend the module search path BEFORE importing the project-local modules below.
# Previously this ran after `import dataset` / `import test`, so on a fresh kernel it
# could not help those imports resolve.
# NOTE(review): presumably `dataset` and `test` live in (or next to) this preprocessing
# directory -- confirm. Index 1 keeps the first sys.path entry ahead of it.
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torch_split
from torch.utils.data import DataLoader
from monai.losses import DiceLoss
from monai.networks.nets import UNet
# Rebinds `DataLoader` to MONAI's class (same order as before); a later cell re-imports
# torch's DataLoader, which is the one the training code ends up using.
from monai.data import DataLoader

import dataset
import test


In [2]:
# Location of the pre-cut 3D patches consumed by dataset.UNetDataset.
path = "H:/Projects/Kaggle/CZII-CryoET-Object-Identification/datasets/3D/dim96-no-corner"
data = dataset.UNetDataset(path=path)

# 80/20 train/validation split; `val` takes the remainder so the sizes always sum to len(data).
tv_split = 0.8
trn = int(len(data) * tv_split)
val = len(data) - trn

# Seed the split so that every kernel restart (and therefore every Optuna study) sees the
# same train/validation partition; an unseeded split makes validation losses incomparable
# between runs.
split_seed = 42
train_dataset, val_dataset = torch_split.random_split(
    data,
    [trn, val],
    generator=torch.Generator().manual_seed(split_seed),
)
# Alternative kept for reference: let the dataset class select the subsets itself.
# train_dataset = dataset.UNetDataset(path=path, train=True)
# val_dataset = dataset.UNetDataset(path=path, val=True)

In [3]:
import optuna
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet

# Define the objective function for Optuna
def objective(trial):
    """Train a 3D UNet for one Optuna trial and return its final validation loss.

    The trial samples the learning rate, dropout, L2 penalty strength and the cosine
    annealing period (`t_max`, in epochs). The model is trained for `warmup_epochs`
    of linear LR warmup followed by `cosine_epochs` of cosine annealing. After every
    epoch the mean validation Dice loss is reported to Optuna so the study's pruner
    can stop unpromising trials early.

    Reads the notebook-level `train_dataset`, `val_dataset` and `dataset.collate_fn`.
    Batches are expected to be mappings with 'src' and 'tgt' entries.

    Args:
        trial: the `optuna` trial used to sample hyperparameters and report progress.

    Returns:
        Mean validation Dice loss (float) of the last epoch.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner decides to stop this trial.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    warmup_epochs = 5
    cosine_epochs = 25
    batch_size = 16

    # Hyperparameters to optimize
    lr = trial.suggest_float("lr", 1e-5, 5e-3, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.6)
    regularization_strength = trial.suggest_float("regularization_strength", 1e-5, 1e-2, log=True)
    # Cosine annealing period in epochs, between 10% and 50% of the cosine phase.
    # np.ceil returns a float, so the bounds are cast to the ints suggest_int expects.
    t_max = trial.suggest_int(
        "t_max",
        int(np.ceil(0.1 * cosine_epochs)),
        int(np.ceil(0.5 * cosine_epochs)),
    )

    # Model initialization
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=7,
        channels=(64, 128, 256, 512),
        strides=(2, 2, 2),
        num_res_units=2,
        dropout=dropout,
    ).to(device)

    # Loss function and optimizer
    criterion = DiceLoss(to_onehot_y=True, softmax=True).to(device)
    optimizer = Adam(model.parameters(), lr=lr)

    # Learning rate warmup (LinearLR) and then cosine annealing (CosineAnnealingLR).
    # Both schedulers wrap the same optimizer; only one of them is stepped per epoch.
    scheduler_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
    # `t_max` is already expressed in epochs. It used to be multiplied by cosine_epochs,
    # which gave periods of 75-325 epochs for a 25-epoch phase, so the LR hardly decayed.
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

    # DataLoader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False, num_workers=4)

    # Regularization setup
    def add_regularization_loss(model, regularization_type, regularization_strength):
        """Return `regularization_strength` times the L1 or L2 norm penalty over all parameters.

        Any `regularization_type` other than "L1" or "L2" yields a zero penalty.
        """
        reg_loss = 0
        if regularization_type == "L1":
            for param in model.parameters():
                reg_loss += torch.sum(torch.abs(param))
        elif regularization_type == "L2":
            for param in model.parameters():
                reg_loss += torch.sum(param ** 2)
        return regularization_strength * reg_loss

    num_epochs = warmup_epochs + cosine_epochs
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # The penalty type is fixed to L2; only its strength is tuned.
            reg_loss = add_regularization_loss(model, "L2", regularization_strength)
            loss = loss + reg_loss

            loss.backward()
            optimizer.step()

        # Scheduler step after each epoch
        if epoch < warmup_epochs:
            scheduler_warmup.step()  # Warmup phase
        else:
            scheduler_cosine.step()  # Cosine annealing phase

        # Validation loop (plain Dice loss, without the regularization term)
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

        val_loss /= len(val_loader)

        # Report intermediate results to Optuna
        trial.report(val_loss, epoch)

        # Prune trial if necessary
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    print(f"Final validation loss: {val_loss:.6g}")
    return val_loss

# Run the Optuna optimization with Median Pruner.
# The pruner is referenced as optuna.pruners.MedianPruner because only `optuna` itself is
# imported in this notebook; the bare name `MedianPruner` raises NameError on a fresh kernel.
# With these settings the first 4 trials are never pruned, and no trial is pruned during
# its first 8 reported epochs.
study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(n_startup_trials=4, n_warmup_steps=8),
)
study.optimize(objective, n_trials=25)

# Print the best parameters
print("Best hyperparameters:", study.best_params)

[I 2024-12-22 16:20:47,145] A new study created in memory with name: no-name-dd7c313b-bac0-448a-91d0-3342c18246e0


Trial 1: LR: 0.002257 | Decay: 0.8412 | batch size: 16 | dropout: 0.5468 | reg type: none | reg str: 0 | 

[I 2024-12-22 16:40:40,541] Trial 0 finished with value: 0.671778692648961 and parameters: {'lr': 0.0022568182210017046, 'lr_decay': 0.8411740729181514, 'batch_size': 16, 'dropout': 0.5467517131478999, 'regularization_type': 'none'}. Best is trial 0 with value: 0.671778692648961.


loss: 0.671779 | epoch 30
Trial 2: LR: 1.916e-05 | Decay: 0.8028 | batch size: 16 | dropout: 0.332 | reg type: L2 | reg str: 3.656e-05 | 

[I 2024-12-22 17:06:03,229] Trial 1 finished with value: 0.9137653616758493 and parameters: {'lr': 1.915656953031922e-05, 'lr_decay': 0.8027895876789398, 'batch_size': 16, 'dropout': 0.3319746917382432, 'regularization_type': 'L2', 'regularization_strength': 3.65616278612233e-05}. Best is trial 0 with value: 0.671778692648961.


loss: 0.913765 | epoch 30
Trial 3: LR: 0.0002571 | Decay: 0.8122 | batch size: 8 | dropout: 0.2157 | reg type: L2 | reg str: 0.0001786 | 

[I 2024-12-22 17:27:19,138] Trial 2 finished with value: 0.6554909825325013 and parameters: {'lr': 0.0002571366097468487, 'lr_decay': 0.8122059055367643, 'batch_size': 8, 'dropout': 0.21572053433527905, 'regularization_type': 'L2', 'regularization_strength': 0.00017863285757000076}. Best is trial 2 with value: 0.6554909825325013.


loss: 0.655491 | epoch 30
Trial 4: LR: 0.0001804 | Decay: 0.9858 | batch size: 16 | dropout: 0.1537 | reg type: L2 | reg str: 0.000255 | 

[I 2024-12-22 17:53:20,198] Trial 3 finished with value: 0.6666678786277771 and parameters: {'lr': 0.00018037415504867577, 'lr_decay': 0.9858063251996845, 'batch_size': 16, 'dropout': 0.1537436371119879, 'regularization_type': 'L2', 'regularization_strength': 0.0002549863953333922}. Best is trial 2 with value: 0.6554909825325013.


loss: 0.666668 | epoch 30
Trial 5: LR: 0.001035 | Decay: 0.9917 | batch size: 16 | dropout: 0.03217 | reg type: L1 | reg str: 0.0004499 | 

[I 2024-12-22 17:55:56,357] Trial 4 pruned. 


Trial 6: LR: 0.001287 | Decay: 0.8175 | batch size: 8 | dropout: 0.1341 | reg type: none | reg str: 0 | 

[I 2024-12-22 18:18:02,493] Trial 5 finished with value: 0.6580322718620301 and parameters: {'lr': 0.0012868692405141073, 'lr_decay': 0.8175013256611177, 'batch_size': 8, 'dropout': 0.13412583613540544, 'regularization_type': 'none'}. Best is trial 2 with value: 0.6554909825325013.


loss: 0.658032 | epoch 30
Trial 7: LR: 0.001363 | Decay: 0.8971 | batch size: 16 | dropout: 0.05967 | reg type: L1 | reg str: 0.0001089 | 

[I 2024-12-22 18:19:54,940] Trial 6 pruned. 


Trial 8: LR: 0.001544 | Decay: 0.9977 | batch size: 8 | dropout: 0.5024 | reg type: L2 | reg str: 0.0004908 | 

[I 2024-12-22 18:21:50,767] Trial 7 pruned. 


Trial 9: LR: 5.277e-05 | Decay: 0.9638 | batch size: 8 | dropout: 0.2418 | reg type: none | reg str: 0 | 

[I 2024-12-22 18:23:56,426] Trial 8 pruned. 


Trial 10: LR: 1.65e-05 | Decay: 0.9505 | batch size: 16 | dropout: 0.4929 | reg type: none | reg str: 0 | 

[I 2024-12-22 18:26:32,216] Trial 9 pruned. 


Trial 11: LR: 0.0002841 | Decay: 0.8714 | batch size: 8 | dropout: 0.3576 | reg type: L2 | reg str: 1.019e-05 | 

[I 2024-12-22 18:52:08,110] Trial 10 pruned. 


Trial 12: LR: 0.000368 | Decay: 0.8015 | batch size: 8 | dropout: 0.1861 | reg type: none | reg str: 0 | 

[I 2024-12-22 19:11:48,000] Trial 11 pruned. 


Trial 13: LR: 0.004215 | Decay: 0.8408 | batch size: 8 | dropout: 0.1217 | reg type: L2 | reg str: 0.0001008 | 

[I 2024-12-22 19:16:59,564] Trial 12 pruned. 


Trial 14: LR: 9.926e-05 | Decay: 0.8398 | batch size: 8 | dropout: 0.2413 | reg type: none | reg str: 0 | 

[I 2024-12-22 19:19:34,548] Trial 13 pruned. 


Trial 15: LR: 0.0004468 | Decay: 0.8714 | batch size: 8 | dropout: 0.392 | reg type: L1 | reg str: 3.101e-05 | 

[I 2024-12-22 19:22:10,555] Trial 14 pruned. 


Trial 16: LR: 0.0007197 | Decay: 0.8232 | batch size: 8 | dropout: 0.2558 | reg type: L2 | reg str: 0.0009881 | 

[I 2024-12-22 19:27:21,825] Trial 15 pruned. 


Trial 17: LR: 0.0001243 | Decay: 0.9246 | batch size: 8 | dropout: 0.09955 | reg type: none | reg str: 0 | 

[I 2024-12-22 19:29:56,243] Trial 16 pruned. 


Trial 18: LR: 0.0006565 | Decay: 0.8711 | batch size: 8 | dropout: 0.1915 | reg type: L2 | reg str: 3.381e-05 | 

[I 2024-12-22 19:32:32,732] Trial 17 pruned. 


Trial 19: LR: 0.003892 | Decay: 0.8232 | batch size: 8 | dropout: 0.002092 | reg type: none | reg str: 0 | 

[W 2024-12-22 19:43:35,153] Trial 18 failed with parameters: {'lr': 0.003892073860719418, 'lr_decay': 0.823209857500322, 'batch_size': 8, 'dropout': 0.0020923062452283536, 'regularization_type': 'none'} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\optuna\study\_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\hmhor\AppData\Local\Temp\ipykernel_4508\2062051421.py", line 63, in objective
    for batch in train_loader:
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\dataloader.py", line 630, in __next__
    data = self._next_data()
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-package

KeyboardInterrupt: 

In [None]:
import numpy as np
from ipywidgets import interact
import augment
import load
import matplotlib.pyplot as plt


# Handle returned by the project-local `load` helper; every loader call below takes it.
# NOTE(review): presumably the dataset root -- confirm in load.py.
root = load.get_root()

# Pick annotations looked up from `root`; only consumed by get_picks_mask below.
picks = load.get_picks_dict(root)

# Volume for one run at level=0, together with the `coords` and `scales` that
# get_picks_mask needs. NOTE(review): level 0 is presumably the full-resolution level -- confirm.
vol, coords, scales = load.get_run_volume_picks(root, level=0)

# Label mask built for the volume's shape; only the first scale entry is passed, cast to int.
mask = load.get_picks_mask(vol.shape, picks, coords, int(scales[0]))

In [26]:
# Draw one un-augmented 96^3 patch and run the segmentation model on it.
#
# NOTE(review): `model` is not defined at notebook scope by any cell shown here (the one
# in `objective` is local to that function). A trained model must be created or loaded
# before this cell is run.
# NOTE(review): `params` aliases augment.aug_params, so these assignments also change the
# module-level defaults for later calls -- copy first if that is not intended.
params = augment.aug_params
params['patch_size'] = (96,96,96)
params['final_size'] = (96,96,96)
# Was `params['flip_prob'] - 0.0`: a discarded subtraction, so flips were never disabled.
params['flip_prob'] = 0.0
params['rot_prob'] = 0.0
params['rot_range'] = 0.0

samples = augment.random_augmentation(vol, mask, num_samples=1, aug_params=params)

src = samples[0]['source']
tgt = samples[0]['target']  # Mask with interest points (non-zero values)

model.eval()
# Take the device from the model itself instead of relying on a `device` variable left
# over in the kernel from another cell.
device = next(model.parameters()).device

# Add batch and channel dimensions; cast to float like the training loop does.
inp = src.unsqueeze(0).unsqueeze(0).float().to(device)
with torch.no_grad():  # inference only: skip building the autograd graph
    pred_mask = model(inp)

# Drop the batch dimension, then pick the highest-scoring class per voxel.
pred_tgt = pred_mask.squeeze(0).cpu().detach()
pred_tgt = torch.argmax(pred_tgt, dim = 0).numpy()

# Subtract one so that the background label is not counted as a particle type.
# NOTE(review): this assumes a background label is always present in both arrays.
print(f'# Particles Types Represented: {len(np.unique(tgt)) - 1}')
print(f'# Particles Types Predicted: {len(np.unique(pred_tgt)) - 1}')


def plot_cross_section(i):
    """Draw slice `i` along each of the three axes, with the prediction overlaid.

    Uses the notebook-level `tgt` as the base image (viridis) and `pred_tgt` as a
    semi-transparent red overlay, one panel per axis in a single row.

    Args:
        i: slice index applied to the first, second and third axis in turn.
    """
    base, overlay = tgt, pred_tgt
    overlay_alpha = 0.3
    whole = slice(None)

    # One entry per panel: subplot position, index selecting the slice, panel title.
    panels = (
        (131, (i, whole, whole), f'Slice at x={i}'),
        (132, (whole, i, whole), f'Slice at y={i}'),
        (133, (whole, whole, i), f'Slice at z={i}'),
    )

    plt.figure(figsize=(15, 5))
    for position, index, title in panels:
        plt.subplot(position)
        plt.imshow(base[index], cmap="viridis")
        # Overlay mask with transparency
        plt.imshow(overlay[index], cmap="Reds", alpha=overlay_alpha)
        plt.title(title)

    plt.show()

# Interactive slider for scrolling through slices. The range comes from the first axis
# of `tgt`, and the same index is applied to all three axes.
# The trailing semicolon stops the notebook from echoing the returned function object.
interact(plot_cross_section, i=(0, tgt.shape[0] - 1));

# Particles Types Represented: 4
# Particles Types Predicted: 6


interactive(children=(IntSlider(value=47, description='i', max=95), Output()), _dom_classes=('widget-interact'…

<function __main__.plot_cross_section(i)>