In [1]:
import torch
import torch.utils.data as torch_split
import torch.nn as nn
from torch.nn import MSELoss
from torch.nn import BCELoss
import numpy as np
import dataset
import heatmap_dataset
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler
from monai.losses import DiceLoss
from monai.losses import FocalLoss
from monai.networks.nets import UNet
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/postprocessing')
import visual

import metrics

In [2]:
path = "H:/Projects/Kaggle/CZII-CryoET-Object-Identification/datasets/3D/dim112-gaussian-heatmap-1700"
# data = dataset.UNetDataset(path=path)

# tv_split = 0.8
# trn = int(len(data) * tv_split)
# val = len(data) - trn

# train_dataset, val_dataset = torch_split.random_split(data, [trn, val])

train_dataset = dataset.UNetDataset(path=path, train=True, in_mem=True)
val_dataset = dataset.UNetDataset(path=path, val=True, in_mem=True)

labels = [
"background",
"apo-ferritin(E)",
"beta-amylase(NS)",
"beta-galactosidase(H)",
"ribosome(E)",
"thyroglobulin(H)",
"virus-like-particle(E)"
]

In [4]:
import optuna
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from monai.networks.nets import UNet
from optuna.pruners import MedianPruner

# Define the objective function for Optuna
def objective(trial):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # ------------------------------ #
    #        HYPERPARAMETERS         #
    # ------------------------------ #
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    decay = trial.suggest_float('decay', 0.7, 1.0)
    gamma = trial.suggest_float('gamma', 2.5, 5.0)
    # alpha = trial.suggest_float('alpha', 0.05, 0.125)
    alpha = 0.0
    dropout = 0.35
    
    
    regularization_strength = trial.suggest_float("regularization_strength", 1e-4, 1e-2, log=True)
    
    # Model initialization
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=7,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        dropout=dropout,
    ).to(device)
    
    num_epochs = 10
    batch_size = 16

    # ------------------------------ #
    #        TRAINING METHODS        #
    # ------------------------------ #
    weights = torch.tensor([0.0434743, 1.16546, 1.1661, 1.16513, 1.14281, 1.15554, 1.16149]).to(device)  # Example weights for classes
    weights = torch.tensor([1.0,1.0,1.0,1.0,1.0,1.0,1.0]).to(device)
    dice_loss = DiceLoss(to_onehot_y=False, softmax=False, weight=weights).to(device)
    focal_loss = FocalLoss(to_onehot_y=False, use_softmax=False, weight=weights, gamma=gamma ).to(device)
    bce_loss = BCELoss()
    
    optimizer = Adam(model.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=decay)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=heatmap_dataset.collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=heatmap_dataset.collate_fn, shuffle=False, num_workers=4)

    def add_regularization_loss(model, regularization_strength):
        reg_loss = 0
        for param in model.parameters():
            reg_loss += torch.sum(param ** 2)
        return regularization_strength * reg_loss

    # ------------------------------ #
    #             TRAIN              #
    # ------------------------------ #
    print(alpha)
    for epoch in range(num_epochs):
        model.train()
        trn_loss = 0
        for batch in train_loader:
            input, target = batch['src'].to(device), batch['tgt'].to(device)
            optimizer.zero_grad()
            output = torch.softmax(model(input),dim=1)
            
            dice = dice_loss(output, target)
            focal = focal_loss(output, target)
            # print(f"BCE: {bce.item()}, Focal: {focal.item()}")
            
            loss = alpha * dice + (1 - alpha) * focal
            
            reg_loss = add_regularization_loss(model, regularization_strength)
            loss += reg_loss
            
            trn_loss += loss.item()
            loss.backward()
            optimizer.step()
            
        trn_loss /= len(train_loader)
        print(f"Epoch {epoch} train loss: {trn_loss} | ", end="")

        scheduler.step()
            
        # ---------- #
        # VALIDATION #
        # ---------- #
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                input, target = batch['src'].float().to(device), batch['tgt'].long().to(device)
                output = torch.softmax(model(input),dim=1)
                dice = dice_loss(output, target)
                focal = focal_loss(output, target)
                # print(f"BCE: {bce.item()}, Focal: {focal.item()}")
                
                loss = alpha * dice + (1 - alpha) * focal
                reg_loss = add_regularization_loss(model, regularization_strength)
                loss += reg_loss
                
                val_loss += loss.item()
                
        # print("batch done")
        val_loss /= len(val_loader)
        
        trial.report(val_loss, epoch)
        print(f"val loss: {val_loss}")
        
        if trial.should_prune():
            
            raise optuna.exceptions.TrialPruned()

    return val_loss

n_epochs = 25

study = optuna.create_study(direction="minimize", pruner=MedianPruner(n_startup_trials=4, n_warmup_steps=25))
study.optimize(objective, n_trials=15)

print("Best hyperparameters:", study.best_params)

[I 2025-01-06 20:18:13,420] A new study created in memory with name: no-name-7c0ad501-550a-436b-9a8a-2e657c37a7e7


0.0
Epoch 0 train loss: 0.35756284571610963 | val loss: 0.14566134568303823
Epoch 1 train loss: 0.09305984256686745 | val loss: 0.07682784972712398
Epoch 2 train loss: 0.04936106204167827 | val loss: 0.05677360272966325
Epoch 3 train loss: 0.03464291199714273 | val loss: 0.04796311864629388
Epoch 4 train loss: 0.028567430655379873 | val loss: 0.044609803007915616
Epoch 5 train loss: 0.025769271591043735 | val loss: 0.0425678095780313
Epoch 6 train loss: 0.024038283747958612 | val loss: 0.0414244313724339
Epoch 7 train loss: 0.022976479196286464 | val loss: 0.040582116693258286
Epoch 8 train loss: 0.022344149055553007 | val loss: 0.04030047380365431
Epoch 9 train loss: 0.021963281597901178 | 

[I 2025-01-06 20:37:52,302] Trial 0 finished with value: 0.039869413478299975 and parameters: {'lr': 0.00020624778640571535, 'decay': 0.7312970427827094, 'gamma': 4.971919416167764, 'regularization_strength': 0.0010761395256515213}. Best is trial 0 with value: 0.039869413478299975.


val loss: 0.039869413478299975
0.0
Epoch 0 train loss: 0.08948711817572405 | val loss: 0.07975284662097692
Epoch 1 train loss: 0.04985466497121276 | val loss: 0.0764064253307879
Epoch 2 train loss: 0.04808145790145947 | val loss: 0.07576812990009785
Epoch 3 train loss: 0.04765818341747745 | val loss: 0.07564633758738637
Epoch 4 train loss: 0.04750629588142856 | val loss: 0.07554898504167795
Epoch 5 train loss: 0.04741794288485915 | val loss: 0.07550577307119966


[W 2025-01-06 20:47:17,886] Trial 1 failed with parameters: {'lr': 0.0007493643515328585, 'decay': 0.9338778269871351, 'gamma': 3.700358843138672, 'regularization_strength': 0.0002945581454425103} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\optuna\study\_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\hmhor\AppData\Local\Temp\ipykernel_15940\1285635154.py", line 83, in objective
    loss.backward()
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "c:\Users\hmhor\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\autograd\__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
KeyboardInterrupt
[W 2025-01-06 20:47:17,925] Tria

KeyboardInterrupt: 

In [3]:
import optuna
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from optuna.pruners import MedianPruner

vis = visual.loss_precision_recall(20, labels, 2.0)
vis.start()
vis.new_trial()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------------------ #
#        HYPERPARAMETERS         #
# ------------------------------ #
lr = 5.0e-4
decay = 0.9
# dropout = trial.suggest_float("dropout", 0.25, 0.5)
dropout = 0.3
regularization_strength = 1e-3
# alpha = trial.suggest_float("alpha", 0.25, 1.0)
# theta = trial.suggest_float("theta", 0.1, 0.9)
theta = 0.5
# gamma = trial.suggest_float("gamma", 2.0, 5.0)
gamma = 4.0
# Model initialization
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=2,
    dropout=dropout,
).to(device)

num_epochs = 15
batch_size = 16

# ------------------------------ #
#        TRAINING METHODS        #
# ------------------------------ #
weights = torch.tensor([0.0434743, 1.16546, 1.1661, 1.16513, 1.14281, 1.15554, 1.16149]).to(device)  # Example weights for classes
# weights = torch.tensor([1.0,1.0,1.0,1.0,1.0,1.0,1.0])
dice_loss = DiceLoss(to_onehot_y=False, softmax=True, weight=weights).to(device)
focal_loss = FocalLoss(to_onehot_y=False, use_softmax=True, weight=weights, gamma=gamma ).to(device)
# mse_loss = nn.MSELoss(reduction='none')

optimizer = Adam(model.parameters(), lr=lr)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=decay)

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False, num_workers=4)

def add_regularization_loss(model, regularization_strength):
    reg_loss = 0
    for param in model.parameters():
        reg_loss += torch.sum(param ** 2)
    return regularization_strength * reg_loss

# ------------------------------ #
#             TRAIN              #
# ------------------------------ #
for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        input, target = batch['src'].to(device), batch['tgt'].to(device)
        optimizer.zero_grad()
        output = model(input)
        loss = (theta) * dice_loss(output, target) + (1 - theta) * focal_loss(output, target)
        reg_loss = add_regularization_loss(model, regularization_strength)
        loss += reg_loss
        loss.backward()
        optimizer.step()

    scheduler.step()
        
    # ---------- #
    # VALIDATION #
    # ---------- #
    model.eval()
    val_loss = 0
    precision = torch.zeros((7))
    recall = torch.zeros((7))
    with torch.no_grad():
        for batch in val_loader:
            input, target = batch['src'].float().to(device), batch['tgt'].long().to(device)
            output = model(input)
            loss = (theta) * dice_loss(output, target) + (1 - theta) * focal_loss(output, target)
            reg_loss = add_regularization_loss(model, regularization_strength)
            loss += reg_loss
            val_loss += loss.item()
            p, r = metrics.continuous_precision_recall(target.to('cpu'), torch.softmax(output.to('cpu'), dim=1))
            precision += p
            recall += r
    val_loss /= len(val_loader)
    pr = torch.stack([precision, recall], dim=0)
    vis.report(val_loss, pr)
            
    print(f"Epoch {epoch} loss: {val_loss}")
    

Epoch 0 loss: 0.6432971689436171
Epoch 1 loss: 0.5598176187939115
Epoch 2 loss: 0.5357193417019315
Epoch 3 loss: 0.5257072117593553
Epoch 4 loss: 0.525143735938602


KeyboardInterrupt: 

In [4]:
torch.save(model.state_dict(), "HeatNet-1-0.pth")


In [1]:
# Inference
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
import load
import augment
import os
import torch
from monai.networks.nets import UNet
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
run = 'TS_6_4'
root = load.get_root()
picks = load.get_picks_dict(root)
vol, coords, scales = load.get_run_volume_picks(root, run=run, level=0)
for key in coords.keys():
    coords[key] = np.array(coords[key], dtype=np.int16)
coord_list = []
for key in coords.keys():
    coord_list.append(coords[key])
radii = [ 6,
          6,
          9,
          15,
          13,
          14 ]
params = augment.aug_params
params["final_size"] = (104,104,104)
params["flip_prob"] = 0.0
params["patch_size"] = (104,104,104)
params["rot_prob"] = 0.0

In [5]:
mask = load.create_exponential_heatmap_gpu(6, vol.shape, coord_list, radii).cpu().numpy()

In [35]:
sample = augment.random_augmentation_gpu(vol, 
                                mask,
                                num_samples=1, 
                                aug_params=params
                                )

src = sample[0]["source"].unsqueeze(0).to(device)
tgt = sample[0]["target"].unsqueeze(0).to(device)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=2,
    dropout=0.1,
).to(device)
model.load_state_dict(torch.load("HeatNet-1-0.pth"))

model.eval()
pred = torch.softmax(model(src), dim=1).to('cpu')
pred = pred.squeeze().to('cpu').detach().numpy()
src = src.to('cpu').squeeze()
tgt = tgt.squeeze(0).to('cpu').detach().numpy()
print(f"src shape {src.shape}")
print(f"tgt shape {tgt.shape}")
print(f"pred shape {pred.shape}")


src shape torch.Size([104, 104, 104])
tgt shape (7, 104, 104, 104)
pred shape (7, 104, 104, 104)


In [36]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact

def plot_cross_section(i):
    vol1 = 1.0 - src
    vol1 = np.zeros(pred[0].shape)
    vol2 = pred[1]
    
    plt.figure(figsize=(15, 5))
    alpha = 0.3

    # Slice at x-coordinate
    plt.subplot(131)
    plt.imshow(vol1[i, :, :], cmap="viridis", alpha=alpha)
    plt.imshow(vol2[i, :, :], cmap="Reds", alpha=alpha)  # Overlay mask with transparency
    plt.title(f'Slice at x={i}')

    # Slice at y-coordinate
    plt.subplot(132)
    plt.imshow(vol1[:, i, :], cmap="viridis", alpha=alpha)
    plt.imshow(vol2[:, i, :], cmap="Blues", alpha=alpha)
    plt.title(f'Slice at y={i}')

    # Slice at z-coordinate
    plt.subplot(133)
    plt.imshow(vol1[:, :, i], cmap="viridis", alpha=alpha)
    plt.imshow(vol2[:, :, i], cmap="Reds", alpha=alpha)
    plt.title(f'Slice at z={i}')

    plt.show()

# Interactive Slider for scrolling through slices
interact(plot_cross_section, i=(0, src.shape[0] - 1))

interactive(children=(IntSlider(value=51, description='i', max=103), Output()), _dom_classes=('widget-interact…

<function __main__.plot_cross_section(i)>