In [2]:
import torch
import torch.utils.data as torch_split
import numpy as np
import dataset
import test
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler
from monai.losses import DiceLoss
from monai.losses import FocalLoss
from torchmetrics.classification import F1Score
from monai.networks.nets import UNet
from monai.data import DataLoader
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall
from torch.optim import Adam
from monai.losses import DiceLoss
from monai.networks.nets import UNet

import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
import visual

In [2]:
# Dataset location (pre-cropped 104^3 samples, judging by the directory name).
path = "H:/Projects/Kaggle/CZII-CryoET-Object-Identification/datasets/3D/dim104-3000sample"

# The train/val split is selected by the dataset's `fold` argument (presumably a
# fixed cross-validation fold inside dataset.UNetDataset), not by a random 80/20
# split. The old random_split path was unused, so the full dataset no longer has
# to be loaded just to size it.
fold = 1
train_dataset = dataset.UNetDataset(path=path, train=True, fold=fold)
val_dataset = dataset.UNetDataset(path=path, val=True, fold=fold)

# Class names indexed by label id (0 = background); passed to the training visualiser.
labels = [
    "background",
    "apo-ferritin (easy)",
    "beta-amylase (impossible, NS)",
    "beta-galactosidase (hard)",
    "ribosome (easy)",
    "thyroglobulin (hard)",
    "virus-like-particle (easy)",
]

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TRIAL 6 HYPERPARAMETERS
batch_size = 16
num_epochs = 25

lr = 0.002378749151980637
decay = 0.6383495211595349                        # StepLR factor, applied every 4 epochs
dropout = 0.3
regularization_strength = 0.0014973716879159318   # coefficient of the explicit L2 penalty
theta = 0.6                                       # loss = theta * Dice + (1 - theta) * Focal
gamma = 3.461783291951009                         # focal-loss focusing parameter
seed = 42

# Seed before building the model and loaders so weight init and shuffling are repeatable.
torch.manual_seed(seed)
np.random.seed(seed)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=2,
    dropout=dropout,
).to(device)

# Project-local visualiser; receives validation loss and per-class precision/recall
# through `vis.report` after every epoch.
# NOTE(review): the meaning of the 1.2 argument lives in `visual`; consider naming it.
vis = visual.loss_precision_recall(num_epochs, labels, 1.2)
vis.start()
vis.new_trial()

# Per-class loss weights, index-aligned with `labels`; background is strongly down-weighted.
weights = torch.tensor([0.0434743, 1.16546, 1.1661, 1.16513, 1.14281, 1.15554, 1.16149]).to(device)

dice_loss = DiceLoss(to_onehot_y=True, softmax=True, weight=weights).to(device)

# Focal loss reuses the foreground weights but ignores background entirely (weight 0).
focal_weights = weights.clone()
focal_weights[0] = 0.0
focal_loss = FocalLoss(to_onehot_y=True, use_softmax=True, weight=focal_weights, gamma=gamma).to(device)

optimizer = Adam(model.parameters(), lr=lr)

# Multiply the learning rate by `decay` every 4 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=decay)

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=False, num_workers=4)

# Regularization setup
def add_regularization_loss(model, regularization_type, regularization_strength):
    """Return an explicit L1/L2 penalty over every parameter of ``model``.

    Args:
        model: module whose ``parameters()`` are penalised (all of them, biases included).
        regularization_type: ``"L1"`` (sum of absolute values) or ``"L2"`` (sum of
            squares), matched case-insensitively and ignoring surrounding whitespace.
            ``None`` or ``"none"`` disables the penalty.
        regularization_strength: multiplier applied to the summed penalty.

    Returns:
        ``regularization_strength * penalty``. For L1/L2 this is a scalar tensor that is
        part of the autograd graph, so it can be added to the training loss. When the
        penalty is disabled it is ``regularization_strength * 0``.

    Raises:
        ValueError: for any other ``regularization_type``. Previously a typo such as
            ``"l2"`` silently disabled regularisation.
    """
    if regularization_type is None:
        kind = "NONE"
    else:
        kind = str(regularization_type).strip().upper()

    if kind == "L1":
        reg_loss = sum(param.abs().sum() for param in model.parameters())
    elif kind == "L2":
        reg_loss = sum(param.pow(2).sum() for param in model.parameters())
    elif kind == "NONE":
        reg_loss = 0
    else:
        raise ValueError(
            f"Unknown regularization_type {regularization_type!r}; expected 'L1', 'L2' or None"
        )
    return regularization_strength * reg_loss

# Main training run: `num_epochs` epochs, validating after each one.
for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # theta blends Dice and focal loss; the explicit L2 penalty is added on top.
        loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)

        reg_loss = add_regularization_loss(model, "L2", regularization_strength)
        loss += reg_loss

        loss.backward()
        optimizer.step()

    # StepLR is stepped once per epoch.
    scheduler.step()

    # Validation loop (the reported loss excludes the regularisation term).
    model.eval()
    val_loss = 0
    # Fresh metrics each epoch so per-class scores do not accumulate across epochs.
    precision_metric = MulticlassPrecision(num_classes=7, average='none').to(device)
    recall_metric = MulticlassRecall(num_classes=7, average='none').to(device)
    with torch.no_grad():
        for batch in val_loader:
            inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
            outputs = model(inputs)
            loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)
            val_loss += loss.item()
            # Predicted class per element vs. label, computed once and fed to both metrics.
            preds_flat = outputs.argmax(dim=1).flatten()
            targets_flat = targets.flatten()
            precision_metric.update(preds_flat, targets_flat)
            recall_metric.update(preds_flat, targets_flat)

    val_loss /= len(val_loader)
    precision = precision_metric.compute().cpu()
    recall = recall_metric.compute().cpu()
    # Row 0 = per-class precision, row 1 = per-class recall.
    pr = torch.stack([precision, recall], dim=0)
    vis.report(val_loss, pr)

In [5]:
# Checkpoint the weights after the initial training run.
torch.save(obj=model.state_dict(), f="UNet_v1-1.pth")

# Compact class names for the per-epoch table header: (E) easy, (H) hard,
# (NS) as in the long-form labels.
labels = [
    "background",
    "apo-ferritin(E)",
    "beta-amylase(NS)",
    "beta-galactosidase(H)",
    "ribosome(E)",
    "thyroglobulin(H)",
    "virus-like-particle(E)",
]

In [6]:
# Continue training for `extra_epochs` more epochs with the same optimizer and
# scheduler, printing a per-class precision/recall table after each epoch.

# Header row. It must end with a newline; otherwise the first "Epoch ..." line
# is printed on the same line as the header.
print("\t\t\t" + "".join(f"{label}   |   " for label in labels))

extra_epochs = 10
for epoch in range(extra_epochs):
    model.train()
    for batch in train_loader:
        inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # theta blends Dice and focal loss; the explicit L2 penalty is added on top.
        loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)

        reg_loss = add_regularization_loss(model, "L2", regularization_strength)
        loss += reg_loss

        loss.backward()
        optimizer.step()

    # StepLR keeps decaying from where the first run left off.
    scheduler.step()

    # Validation loop (the reported loss excludes the regularisation term).
    model.eval()
    val_loss = 0
    precision_metric = MulticlassPrecision(num_classes=7, average='none').to(device)
    recall_metric = MulticlassRecall(num_classes=7, average='none').to(device)
    with torch.no_grad():
        for batch in val_loader:
            inputs, targets = batch['src'].float().to(device), batch['tgt'].long().to(device)
            outputs = model(inputs)
            loss = (theta) * dice_loss(outputs, targets) + (1 - theta) * focal_loss(outputs, targets)
            val_loss += loss.item()
            # Predicted class per element vs. label, computed once and fed to both metrics.
            preds_flat = outputs.argmax(dim=1).flatten()
            targets_flat = targets.flatten()
            precision_metric.update(preds_flat, targets_flat)
            recall_metric.update(preds_flat, targets_flat)

    val_loss /= len(val_loader)
    precision = precision_metric.compute().cpu()
    recall = recall_metric.compute().cpu()
    pr = torch.stack([precision, recall], dim=0)
    vis.report(val_loss, pr)

    precision = precision.detach().numpy()
    recall = recall.detach().numpy()

    # Epoch numbers continue on from the first training run.
    print(f"Epoch {epoch + num_epochs} loss {val_loss:.3f}")
    # First row: per-class precision; second row: per-class recall.
    for row in (precision, recall):
        print("\t\t\t" + "".join(f"{value:.3f}\t|\t   " for value in row))

			background   |   apo-ferritin(E)   |   beta-amylase(NS)   |   beta-galactosidase(H)   |   ribosome(E)   |   thyroglobulin(H)   |   virus-like-particle(E)   |   Epoch 25 loss 0.432
			0.991	|	   0.547	|	   0.117	|	   0.216	|	   0.535	|	   0.326	|	   0.556	|	   
			0.967	|	   0.754	|	   0.370	|	   0.635	|	   0.794	|	   0.622	|	   0.772	|	   
Epoch 26 loss 0.434
			0.990	|	   0.674	|	   0.102	|	   0.211	|	   0.525	|	   0.360	|	   0.677	|	   
			0.970	|	   0.642	|	   0.726	|	   0.456	|	   0.798	|	   0.534	|	   0.709	|	   
Epoch 27 loss 0.429
			0.988	|	   0.536	|	   0.198	|	   0.353	|	   0.593	|	   0.351	|	   0.594	|	   
			0.977	|	   0.761	|	   0.574	|	   0.448	|	   0.705	|	   0.528	|	   0.784	|	   
Epoch 28 loss 0.424
			0.991	|	   0.607	|	   0.158	|	   0.353	|	   0.521	|	   0.335	|	   0.763	|	   
			0.969	|	   0.736	|	   0.647	|	   0.509	|	   0.817	|	   0.627	|	   0.766	|	   
Epoch 29 loss 0.425
			0.989	|	   0.599	|	   0.150	|	   0.234	|	   0.616	|	   0.329	|	   0.603	|	   
			0.974

In [7]:
# Checkpoint the weights after the extra training epochs (loaded again for inference).
torch.save(f="UNet_v1-2.pth", obj=model.state_dict())

In [5]:
# Inference
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
import load
import augment
import os
import torch
from monai.networks.nets import UNet

root = load.get_root()

picks = load.get_picks_dict(root)

# Experiment runs available on disk; used to validate the chosen run up front so a
# typo fails here with a clear message instead of deep inside the loaders.
runs_dir = 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/data/train/static/ExperimentRuns'
runs = os.listdir(runs_dir)
run = 'TS_6_4'
if run not in runs:
    raise ValueError(f"run {run!r} not found in {runs_dir}; available runs: {sorted(runs)}")

In [6]:
vol, coords, scales = load.get_run_volume_picks(root, run=run, level=0)

# Work on a copy of the module-level defaults. Assigning into
# `augment.aug_params` itself would silently change the augmentation settings
# for any other code in this kernel that reads it.
# NOTE(review): assumes aug_params is a flat dict of settings (only top-level keys are replaced below).
params = dict(augment.aug_params)
params["final_size"] = (104,104,104)
params["flip_prob"] = 0.0
params["patch_size"] = (104,104,104)
params["rot_prob"] = 0.0

mask = load.get_picks_mask(vol.shape, picks, coords, int(scales[0]))

In [20]:
from monai.networks.nets import UNet
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One sample crop drawn with `params` (flips/rotations disabled, 104^3 size).
sample = augment.random_augmentation(vol,
                                     mask,
                                     num_samples=1,
                                     aug_params=params,
                                     save=False)
src = sample[0]["source"].unsqueeze(0).to(device)  # add a batch dimension
tgt = sample[0]["target"]

print(src.shape)

# Architecture must match the one the checkpoint was trained with. Dropout has no
# parameters in the state dict and is inactive after model.eval().
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=2,
    dropout=0.1,
).to(device)
# map_location lets a checkpoint saved from CUDA load on a CPU-only machine.
model.load_state_dict(torch.load("UNet_v1-2.pth", map_location=device))

model.eval()
# Without no_grad the 3D UNet forward pass keeps every activation for a backward
# pass that never happens, which wastes a lot of memory.
with torch.no_grad():
    # argmax over the class channel, then drop the batch/channel singleton dims.
    prediction = model(src).argmax(1).squeeze().to('cpu')
src = src.to('cpu').squeeze()
print(tgt.shape)

torch.Size([1, 1, 104, 104, 104])
torch.Size([1, 104, 104, 104])


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact
%matplotlib inline


# Number of distinct non-background (non-zero) labels. Unlike `len(unique) - 1`,
# this stays correct when a crop contains no background voxels.
print(f'# Particles Types Represented: {np.count_nonzero(np.unique(tgt))}')
print(f'# Particles Types Predicted: {np.count_nonzero(np.unique(prediction))}')



def plot_cross_section(i):
    """Show predicted vs. ground-truth label slices at index ``i`` along each axis.

    Used as the ``interact`` callback below. Reads the notebook globals
    ``prediction`` (the model's 3-D label volume) and ``tgt`` (whose first entry
    is the 3-D ground-truth label volume). Draws three side-by-side panels: the
    prediction (viridis) overlaid by the ground truth (Reds).

    Args:
        i: slice index, applied to each of the three axes in turn.
    """
    vol_true = tgt[0]
    vol_pred = prediction
    alpha = 0.3
    # Pin the colour scale to the label range (7 classes: 0..6). Without vmin/vmax,
    # imshow rescales every slice to its own min/max, so the same class can appear
    # in different colours from one slice (or panel) to the next.
    vmin, vmax = 0, 6

    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    index_by_axis = (
        ("x", (i, slice(None), slice(None))),
        ("y", (slice(None), i, slice(None))),
        ("z", (slice(None), slice(None), i)),
    )
    for ax, (axis_name, index) in zip(axes, index_by_axis):
        ax.imshow(vol_pred[index], cmap="viridis", alpha=alpha, vmin=vmin, vmax=vmax)
        # Overlay the ground-truth mask with transparency.
        ax.imshow(vol_true[index], cmap="Reds", alpha=alpha, vmin=vmin, vmax=vmax)
        ax.set_title(f'Slice at {axis_name}={i}')

    plt.show()

# Interactive Slider for scrolling through slices
# Interactive slider for scrolling through slices. The same index is used on all
# three axes, so it is bounded by the smallest dimension. The trailing `;` hides
# the returned function's repr.
interact(plot_cross_section, i=(0, min(prediction.shape) - 1));

# Particles Types Represented: 4
# Particles Types Predicted: 5


interactive(children=(IntSlider(value=51, description='i', max=103), Output()), _dom_classes=('widget-interact…

<function __main__.plot_cross_section(i)>