In [5]:
import torch
import torch.utils.data as torch_split
import numpy as np
import dataset
import test
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.data import DataLoader
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')


In [2]:
path = "H:/Projects/Kaggle/CZII-CryoET-Object-Identification/datasets/3D/dim96-no-corner"
data = dataset.UNetDataset(path=path)

# Train/validation split (80/20 by sample count). A seeded generator makes the
# split identical across kernel restarts, so validation losses from different
# runs are measured on the same samples.
SPLIT_SEED = 42
tv_split = 0.8
trn = int(len(data) * tv_split)
val = len(data) - trn

train_dataset, val_dataset = torch_split.random_split(
    data, [trn, val], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model initialization: 3-D MONAI UNet with 1 input channel and 7 output classes.
# These channel/class counts are assumptions about the dataset; adjust as needed.
# MONAI expects len(strides) == len(channels) - 1. The previous 4th stride was
# never used (MONAI drops surplus strides with a warning), so (2, 2, 2) builds
# the same network without the warning.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=2,
).to(device)

# Dice loss on softmax-ed logits. to_onehot_y=True makes the loss one-hot encode
# the integer targets using the number of output channels.
criterion = DiceLoss(to_onehot_y=True, softmax=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# DataLoaders. Batching of samples is delegated to the project's dataset.collate_fn.
BATCH_SIZE = 16
NUM_WORKERS = 4
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, collate_fn=dataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, collate_fn=dataset.collate_fn)

def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader` and return the mean per-batch loss.

    If `optimizer` is given, the model is put in train mode and updated after
    every batch. Otherwise the model is put in eval mode and no autograd graph
    is built.

    Each batch is expected to be a mapping with 'src' (cast to float) and
    'tgt' (cast to long) tensors. NOTE(review): presumably 'tgt' is
    (B, 1, D, H, W) as DiceLoss(to_onehot_y=True) needs; confirm against
    dataset.collate_fn.

    Returns 0.0 for an empty loader instead of dividing by zero.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0
    with torch.set_grad_enabled(training):
        for batch in loader:
            inputs = batch['src'].float().to(device)
            targets = batch['tgt'].long().to(device)
            if training:
                optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            if training:
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
    # Unweighted mean over batches (a smaller final batch counts as one batch).
    return running_loss / max(len(loader), 1)


# Training loop: one optimisation pass and one validation pass per epoch.
num_epochs = 20
history = []  # (train_loss, val_loss) per epoch
for epoch in range(num_epochs):
    train_loss = run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss = run_epoch(model, val_loader, criterion, device)
    history.append((train_loss, val_loss))
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")



Epoch 1/20 - Train Loss: 0.9324, Val Loss: 0.9155
Epoch 2/20 - Train Loss: 0.9022, Val Loss: 0.8855
Epoch 3/20 - Train Loss: 0.8724, Val Loss: 0.8632
Epoch 4/20 - Train Loss: 0.8407, Val Loss: 0.8252
Epoch 5/20 - Train Loss: 0.8110, Val Loss: 0.8009
Epoch 6/20 - Train Loss: 0.7848, Val Loss: 0.7800
Epoch 7/20 - Train Loss: 0.7673, Val Loss: 0.7715
Epoch 8/20 - Train Loss: 0.7522, Val Loss: 0.7602
Epoch 9/20 - Train Loss: 0.7394, Val Loss: 0.7544
Epoch 10/20 - Train Loss: 0.7275, Val Loss: 0.7447
Epoch 11/20 - Train Loss: 0.7183, Val Loss: 0.7360
Epoch 12/20 - Train Loss: 0.7012, Val Loss: 0.7247
Epoch 13/20 - Train Loss: 0.6870, Val Loss: 0.7243
Epoch 14/20 - Train Loss: 0.6742, Val Loss: 0.7162
Epoch 15/20 - Train Loss: 0.6617, Val Loss: 0.6999
Epoch 16/20 - Train Loss: 0.6460, Val Loss: 0.6945
Epoch 17/20 - Train Loss: 0.6377, Val Loss: 0.6918
Epoch 18/20 - Train Loss: 0.6265, Val Loss: 0.6927
Epoch 19/20 - Train Loss: 0.6224, Val Loss: 0.6925
Epoch 20/20 - Train Loss: 0.6053, Val Lo

In [None]:
import numpy as np
from ipywidgets import interact
import augment
import load
import matplotlib.pyplot as plt


# Load a volume together with its pick coordinates and build a label mask from
# the picks. All helpers come from the project-local `load` module, so their
# exact return types are not visible from this notebook.
root = load.get_root()

# Pick definitions keyed off the same root used for the volume below.
picks = load.get_picks_dict(root)

# `level=0` is forwarded to the loader as the volume level to read.
# NOTE(review): presumably level 0 is full resolution; confirm in `load`.
vol, coords, scales = load.get_run_volume_picks(root, level=0)

# The mask is built from the volume's shape, the picks and their coordinates.
# The last argument, int(scales[0]), is presumably a per-pick size/scale in
# voxels (TODO confirm against load.get_picks_mask).
mask = load.get_picks_mask(vol.shape, picks, coords, int(scales[0]))

In [26]:
# Augmentation settings for a single un-flipped, un-rotated 96^3 patch.
# Work on a copy: assigning augment.aug_params directly and then editing it
# mutated the module-level defaults for every later caller.
params = dict(augment.aug_params)
params['patch_size'] = (96, 96, 96)
params['final_size'] = (96, 96, 96)
params['flip_prob'] = 0.0  # was `- 0.0`, a no-op expression, so flipping stayed at its default
params['rot_prob'] = 0.0
params['rot_range'] = 0.0

samples = augment.random_augmentation(vol, mask, num_samples=1, aug_params=params)

src = samples[0]['source']
tgt = samples[0]['target']  # Mask with interest points (non-zero values)

# Inference only: eval mode and no autograd graph. Cast to float32 as the
# training loop does, and add the batch and channel dimensions.
model.eval()
with torch.no_grad():
    inp = src.float().unsqueeze(0).unsqueeze(0).to(device)
    pred_mask = model(inp)

# Per-voxel predicted class = argmax over the class (channel) axis of the single sample.
pred_tgt = torch.argmax(pred_mask[0], dim=0).cpu().numpy()

# Count distinct non-zero labels (assumes label 0 is background). This stays
# correct even when no background voxel is present.
print(f'# Particles Types Represented: {np.count_nonzero(np.unique(tgt))}')
print(f'# Particles Types Predicted: {np.count_nonzero(np.unique(pred_tgt))}')


def plot_cross_section(i):
    """Draw three orthogonal slices at index `i`: the target mask in viridis,
    with the predicted mask overlaid in semi-transparent red.

    Reads the module-level `tgt` (base image) and `pred_tgt` (overlay). Both
    are indexed as 3-D arrays, with the same `i` applied along each axis.
    """
    base, overlay = tgt, pred_tgt
    full = slice(None)
    # (axis label, index tuple selecting the 2-D slice along that axis)
    views = (
        ('x', (i, full, full)),
        ('y', (full, i, full)),
        ('z', (full, full, i)),
    )

    plt.figure(figsize=(15, 5))
    for panel, (axis_label, index) in enumerate(views, start=1):
        plt.subplot(1, 3, panel)
        plt.imshow(base[index], cmap="viridis")
        plt.imshow(overlay[index], cmap="Reds", alpha=0.3)
        plt.title(f'Slice at {axis_label}={i}')

    plt.show()

# Interactive slider for scrolling through slices. The same index is used on
# all three axes, so bound it by the smallest dimension. The trailing `;`
# suppresses the repr of the function that interact() returns.
interact(plot_cross_section, i=(0, min(tgt.shape) - 1));

# Particles Types Represented: 4
# Particles Types Predicted: 6


interactive(children=(IntSlider(value=47, description='i', max=95), Output()), _dom_classes=('widget-interact'…

<function __main__.plot_cross_section(i)>