In [1]:
import numpy as np
import pandas as pd
import polars as pl
import os
import gc
import json
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import zarr
import napari
from scipy.spatial import KDTree
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import torch.cuda.amp as amp  # ✅ Import automatic mixed precision (AMP)

gc.enable()

pd.options.display.max_columns = None
#pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_colwidth', None)

#pl.Config.set_tbl_rows(-1)
pl.Config.set_tbl_cols(-1)
pl.Config.set_fmt_str_lengths(10000)

polars.config.Config

In [2]:
import sys
sys.path.append("/home/max1024/projects/MedicalNet")  # Adjust path as needed

# Test import
from model import generate_model
print("✅ MedicalNet imported successfully!")

✅ MedicalNet imported successfully!


In [3]:
from models import resnet

In [4]:
path = '/media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/'

In [5]:
train_data_experiment_folders_path = path + 'train/static/ExperimentRuns/'
train_data_experiment_folders_path

'/media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/train/static/ExperimentRuns/'

In [6]:
test_data_experiment_folders_path = path + 'test/static/ExperimentRuns/'
test_data_experiment_folders_path

'/media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/test/static/ExperimentRuns/'

In [7]:
train_data_experiments = os.listdir(train_data_experiment_folders_path)
train_data_experiments

['TS_5_4', 'TS_69_2', 'TS_6_4', 'TS_6_6', 'TS_73_6', 'TS_86_3', 'TS_99_9']

In [8]:
test_data_experiments = os.listdir(test_data_experiment_folders_path)
test_data_experiments

['TS_5_4', 'TS_69_2', 'TS_6_4']

In [9]:
data_dict = {}
for experiment in tqdm(train_data_experiments):
    image_types_dict = {}    
    image_types_dict['denoised'] = zarr.open(train_data_experiment_folders_path + f'{experiment}/VoxelSpacing10.000/denoised.zarr', mode='r')
    image_types_dict['iso'] = zarr.open(train_data_experiment_folders_path + f'{experiment}/VoxelSpacing10.000/isonetcorrected.zarr', mode='r')
    image_types_dict['dcon'] = zarr.open(train_data_experiment_folders_path + f'{experiment}/VoxelSpacing10.000/ctfdeconvolved.zarr', mode='r')
    image_types_dict['wbp'] = zarr.open(train_data_experiment_folders_path + f'{experiment}/VoxelSpacing10.000/wbp.zarr', mode='r')
    data_dict[experiment] = image_types_dict

100%|████████████████████████████████████████████| 7/7 [00:00<00:00, 206.93it/s]


In [10]:
data_dict

{'TS_5_4': {'denoised': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/train/static/ExperimentRuns/TS_5_4/VoxelSpacing10.000/denoised.zarr>,
  'iso': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/train/static/ExperimentRuns/TS_5_4/VoxelSpacing10.000/isonetcorrected.zarr>,
  'dcon': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/train/static/ExperimentRuns/TS_5_4/VoxelSpacing10.000/ctfdeconvolved.zarr>,
  'wbp': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/train/static/ExperimentRuns/TS_5_4/VoxelSpacing10.000/wbp.zarr>},
 'TS_69_2': {'denoised': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/train/static/ExperimentRuns/TS_69_2/VoxelSpacing10.000/denoised.zarr>,
  'iso': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/train/static/ExperimentRuns/TS_69_2/VoxelSpacing10.000/i

In [11]:
test_data_dict = {}
for experiment in tqdm(test_data_experiments):
    image_types_dict = {}    
    image_types_dict['denoised'] = zarr.open(test_data_experiment_folders_path + f'{experiment}/VoxelSpacing10.000/denoised.zarr', mode='r')
    test_data_dict[experiment] = image_types_dict

100%|████████████████████████████████████████████| 3/3 [00:00<00:00, 711.14it/s]


In [12]:
test_data_dict

{'TS_5_4': {'denoised': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/test/static/ExperimentRuns/TS_5_4/VoxelSpacing10.000/denoised.zarr>},
 'TS_69_2': {'denoised': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/test/static/ExperimentRuns/TS_69_2/VoxelSpacing10.000/denoised.zarr>},
 'TS_6_4': {'denoised': <Group file:///media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/test/static/ExperimentRuns/TS_6_4/VoxelSpacing10.000/denoised.zarr>}}

In [13]:
train_label_experiment_folders_path = path + 'train/overlay/ExperimentRuns/'
train_label_experiment_folders_path

'/media/max1024/Extreme SSD/Kaggle/czii-cryo-et-object-identification/train/overlay/ExperimentRuns/'

In [14]:
train_label_experiments = os.listdir(train_label_experiment_folders_path)
train_label_experiments

['TS_5_4', 'TS_69_2', 'TS_6_4', 'TS_6_6', 'TS_73_6', 'TS_86_3', 'TS_99_9']

In [15]:
labels_dict = {}
for experiment in tqdm(train_label_experiments):
    particle_types_dict = {}
    
    with open(f'{train_label_experiment_folders_path}{experiment}/Picks/apo-ferritin.json') as f:
        loaded_json = json.loads(f.read())
    particle_types_dict['apo-ferritin'] = loaded_json

    with open(f'{train_label_experiment_folders_path}{experiment}/Picks/beta-amylase.json') as f:
        loaded_json = json.loads(f.read())
    particle_types_dict['beta-amylase'] = loaded_json

    with open(f'{train_label_experiment_folders_path}{experiment}/Picks/beta-galactosidase.json') as f:
        loaded_json = json.loads(f.read())
    particle_types_dict['beta-galactosidase'] = loaded_json

    with open(f'{train_label_experiment_folders_path}{experiment}/Picks/ribosome.json') as f:
        loaded_json = json.loads(f.read())
    particle_types_dict['ribosome'] = loaded_json

    with open(f'{train_label_experiment_folders_path}{experiment}/Picks/thyroglobulin.json') as f:
        loaded_json = json.loads(f.read())
    particle_types_dict['thyroglobulin'] = loaded_json

    with open(f'{train_label_experiment_folders_path}{experiment}/Picks/virus-like-particle.json') as f:
        loaded_json = json.loads(f.read())
    particle_types_dict['virus-like-particle'] = loaded_json

    labels_dict[experiment] = particle_types_dict

100%|████████████████████████████████████████████| 7/7 [00:00<00:00, 401.57it/s]


In [16]:
labels_dict

{'TS_5_4': {'apo-ferritin': {'pickable_object_name': 'apo-ferritin',
   'user_id': 'curation',
   'session_id': '0',
   'run_name': 'TS_5_4',
   'voxel_spacing': None,
   'unit': 'angstrom',
   'points': [{'location': {'x': 468.514, 'y': 5915.906, 'z': 604.167},
     'transformation_': [[1.0, 0.0, 0.0, 0.0],
      [0.0, 1.0, 0.0, 0.0],
      [0.0, 0.0, 1.0, 0.0],
      [0.0, 0.0, 0.0, 1.0]],
     'instance_id': 0},
    {'location': {'x': 5674.694, 'y': 1114.354, 'z': 565.068},
     'transformation_': [[1.0, 0.0, 0.0, 0.0],
      [0.0, 1.0, 0.0, 0.0],
      [0.0, 0.0, 1.0, 0.0],
      [0.0, 0.0, 0.0, 1.0]],
     'instance_id': 0},
    {'location': {'x': 5744.509, 'y': 1049.172, 'z': 653.712},
     'transformation_': [[1.0, 0.0, 0.0, 0.0],
      [0.0, 1.0, 0.0, 0.0],
      [0.0, 0.0, 1.0, 0.0],
      [0.0, 0.0, 0.0, 1.0]],
     'instance_id': 0},
    {'location': {'x': 5880.769, 'y': 1125.348, 'z': 579.56},
     'transformation_': [[1.0, 0.0, 0.0, 0.0],
      [0.0, 1.0, 0.0, 0.0],
      

In [17]:
particle_radius = {
    'apo-ferritin': 60,
    'beta-amylase': 65,
    'beta-galactosidase': 90,
    'ribosome': 150,
    'thyroglobulin': 130,
    'virus-like-particle': 135,
}

In [18]:
class_ids = {
    'apo-ferritin': 0,
    'beta-amylase': 1,
    'beta-galactosidase': 2,
    'ribosome': 3,
    'thyroglobulin': 4,
    'virus-like-particle': 5,
}

In [19]:
weights_dict = {
    'apo-ferritin': 1,
    'beta-amylase': 0,
    'beta-galactosidase': 2,
    'ribosome': 1,
    'thyroglobulin': 2,
    'virus-like-particle': 1,
}

In [20]:
experiment_list = []
particle_type_list = []
x_list = []
y_list = []
z_list = []
r_list = []
class_id_list = []
for experiment in tqdm(train_data_experiments):
    #print(experiment)
    #print(len(labels_dict[experiment]['apo-ferritin']['points']))
    #print(type(labels_dict[experiment]['apo-ferritin']['points']))
    #print(labels_dict[experiment]['apo-ferritin']['points'][0])

    for key in labels_dict[experiment].keys():
        #print(labels_dict[experiment][key])
        #print(labels_dict[experiment][key]['pickable_object_name'])
        for i in range(len(labels_dict[experiment][key]['points'])):
            experiment_list.append(labels_dict[experiment][key]['run_name'])
            particle_type_list.append(labels_dict[experiment][key]['pickable_object_name'])
            x_list.append(labels_dict[experiment][key]['points'][i]['location']['x']/10)
            y_list.append(labels_dict[experiment][key]['points'][i]['location']['y']/10)
            z_list.append(labels_dict[experiment][key]['points'][i]['location']['z']/10)
            r_list.append(particle_radius[key]/10)
            class_id_list.append(class_ids[key])

100%|███████████████████████████████████████████| 7/7 [00:00<00:00, 3281.19it/s]


In [21]:
labels_df = pd.DataFrame({'experiment':experiment_list, 'particle_type':particle_type_list, 'x':x_list, 'y':y_list, 'z':z_list, 'r':r_list, 'class_id':class_id_list})

In [22]:
labels_df['experiment'].unique()

array(['TS_5_4', 'TS_69_2', 'TS_6_4', 'TS_6_6', 'TS_73_6', 'TS_86_3',
       'TS_99_9'], dtype=object)

In [23]:
for v in sorted(labels_df['class_id'].unique()):
    labels_df['class_id_' + str(v)] = (labels_df['class_id'] == v).astype(int)

In [24]:
labels_df = labels_df.drop(columns=['class_id'])
labels_df

Unnamed: 0,experiment,particle_type,x,y,z,r,class_id_0,class_id_1,class_id_2,class_id_3,class_id_4,class_id_5
0,TS_5_4,apo-ferritin,46.8514,591.5906,60.4167,6.0,1,0,0,0,0,0
1,TS_5_4,apo-ferritin,567.4694,111.4354,56.5068,6.0,1,0,0,0,0,0
2,TS_5_4,apo-ferritin,574.4509,104.9172,65.3712,6.0,1,0,0,0,0,0
3,TS_5_4,apo-ferritin,588.0769,112.5348,57.9560,6.0,1,0,0,0,0,0
4,TS_5_4,apo-ferritin,466.1667,126.9497,81.0409,6.0,1,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...
1264,TS_99_9,virus-like-particle,201.0056,475.2618,105.7078,13.5,0,0,0,0,0,1
1265,TS_99_9,virus-like-particle,224.4068,431.0063,95.9548,13.5,0,0,0,0,0,1
1266,TS_99_9,virus-like-particle,80.4270,581.7135,57.9493,13.5,0,0,0,0,0,1
1267,TS_99_9,virus-like-particle,419.8228,553.4578,85.8169,13.5,0,0,0,0,0,1


In [25]:
first_df = labels_df[labels_df['experiment'] == 'TS_5_4']
first_df

Unnamed: 0,experiment,particle_type,x,y,z,r,class_id_0,class_id_1,class_id_2,class_id_3,class_id_4,class_id_5
0,TS_5_4,apo-ferritin,46.8514,591.5906,60.4167,6.0,1,0,0,0,0,0
1,TS_5_4,apo-ferritin,567.4694,111.4354,56.5068,6.0,1,0,0,0,0,0
2,TS_5_4,apo-ferritin,574.4509,104.9172,65.3712,6.0,1,0,0,0,0,0
3,TS_5_4,apo-ferritin,588.0769,112.5348,57.9560,6.0,1,0,0,0,0,0
4,TS_5_4,apo-ferritin,466.1667,126.9497,81.0409,6.0,1,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...
135,TS_5_4,virus-like-particle,263.6539,421.4980,96.5410,13.5,0,0,0,0,0,1
136,TS_5_4,virus-like-particle,313.7396,357.2460,37.2914,13.5,0,0,0,0,0,1
137,TS_5_4,virus-like-particle,329.4133,302.7464,67.4070,13.5,0,0,0,0,0,1
138,TS_5_4,virus-like-particle,299.7686,494.8218,116.9375,13.5,0,0,0,0,0,1


In [26]:
image = data_dict['TS_5_4']['denoised']['0']
image.shape

(184, 630, 630)

In [27]:
np.min(image)

-0.00032409787

In [28]:
image = image - np.min(image)
image = image / np.max(image)

In [29]:
np.min(image)

0.0

In [30]:
np.max(image)

1.0

In [31]:
class YOLO3D_MedicalNet(nn.Module):
    def __init__(self, num_classes=6, num_anchors=3, pretrain_path=None):
        super(YOLO3D_MedicalNet, self).__init__()

        # ✅ Load MedicalNet (Pretrained ResNet-18 3D)
        self.backbone = resnet.resnet18(
            sample_input_W=128,
            sample_input_H=128,
            sample_input_D=64,
            shortcut_type='B',
            no_cuda=False,
            num_seg_classes=num_classes
        )

        # ✅ Load Pretrained Weights
        if pretrain_path:
            print(f"✅ Loading Pretrained MedicalNet Weights from: {pretrain_path}")
            pretrain = torch.load(pretrain_path, map_location="cuda" if torch.cuda.is_available() else "cpu")
            self.backbone.load_state_dict(pretrain['state_dict'], strict=False)

        # ✅ Remove Fully Connected Layer from MedicalNet
        self.backbone.avgpool = nn.Identity()  # Remove global pooling
        self.backbone.fc = nn.Identity()  # Remove classification layer

        # ✅ Define YOLO3D Detection Head
        self.num_classes = num_classes
        self.num_anchors = num_anchors

        # ✅ Calculate Feature Size
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, 64, 128, 128)  # [batch, channels, depth, height, width]
            dummy_output = self.backbone(dummy_input)
            feature_size = dummy_output.view(1, -1).shape[1]

        # ✅ Fully Connected Layer for YOLO3D
        self.fc = nn.Linear(feature_size, self.num_anchors * (4 + self.num_classes))

    def forward(self, x):
        x = self.backbone(x)  # Extract Features using MedicalNet
        x = torch.flatten(x, start_dim=1)  # Flatten for FC Layer
        x = self.fc(x)  # Predict Bounding Boxes

        return x.view(-1, self.num_anchors, 4 + self.num_classes)  # Reshape Output

In [32]:
model = YOLO3D_MedicalNet(pretrain_path="/home/max1024/models/MedicalNet/resnet_18_23dataset.pth")

  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')


✅ Loading Pretrained MedicalNet Weights from: /home/max1024/models/MedicalNet/resnet_18_23dataset.pth


  pretrain = torch.load(pretrain_path, map_location="cuda" if torch.cuda.is_available() else "cpu")


In [33]:
model

YOLO3D_MedicalNet(
  (backbone): ResNet(
    (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNor

In [34]:
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

backbone.conv1.weight: requires_grad=True
backbone.bn1.weight: requires_grad=True
backbone.bn1.bias: requires_grad=True
backbone.layer1.0.conv1.weight: requires_grad=True
backbone.layer1.0.bn1.weight: requires_grad=True
backbone.layer1.0.bn1.bias: requires_grad=True
backbone.layer1.0.conv2.weight: requires_grad=True
backbone.layer1.0.bn2.weight: requires_grad=True
backbone.layer1.0.bn2.bias: requires_grad=True
backbone.layer1.1.conv1.weight: requires_grad=True
backbone.layer1.1.bn1.weight: requires_grad=True
backbone.layer1.1.bn1.bias: requires_grad=True
backbone.layer1.1.conv2.weight: requires_grad=True
backbone.layer1.1.bn2.weight: requires_grad=True
backbone.layer1.1.bn2.bias: requires_grad=True
backbone.layer2.0.conv1.weight: requires_grad=True
backbone.layer2.0.bn1.weight: requires_grad=True
backbone.layer2.0.bn1.bias: requires_grad=True
backbone.layer2.0.conv2.weight: requires_grad=True
backbone.layer2.0.bn2.weight: requires_grad=True
backbone.layer2.0.bn2.bias: requires_grad=Tru

In [35]:
print(f"Model training mode: {model.training}")

Model training mode: True


In [36]:
class SphereIoULoss(nn.Module):
    def __init__(self):
        super(SphereIoULoss, self).__init__()

    def forward(self, pred_spheres, target_spheres):
        """
        Compute IoU loss between predicted and target spheres.

        pred_spheres: (batch_size, num_anchors, 4) -> [x, y, z, radius]
        target_spheres: (batch_size, num_targets, 4) -> [x, y, z, radius]
        """

        pred_centers = pred_spheres[..., :3]  # Extract (x, y, z)
        pred_radii = pred_spheres[..., 3]     # Extract radius

        target_centers = target_spheres[..., :3]  # Extract (x, y, z)
        target_radii = target_spheres[..., 3]     # Extract radius

        # Compute pairwise Euclidean distances
        dists = torch.cdist(pred_centers, target_centers, p=2)  # (batch_size, num_anchors, num_targets)

        # Find best match: Assign each predicted sphere to the closest ground truth
        best_target_idx = dists.argmin(dim=-1)  # (batch_size, num_anchors)
        
        # Gather matched target spheres
        matched_target_spheres = torch.gather(target_spheres, 1, best_target_idx.unsqueeze(-1).expand(-1, -1, 4))

        matched_target_centers = matched_target_spheres[..., :3]
        matched_target_radii = matched_target_spheres[..., 3]

        # Compute Euclidean distance for matched pairs
        d = torch.norm(pred_centers - matched_target_centers, dim=-1)

        # Compute Intersection Volume
        R = pred_radii + matched_target_radii  # Sum of radii
        inter_vol = (4/3) * torch.pi * ((R - d).clamp(min=0) ** 3)  # Avoid negative values

        # Compute Union Volume
        vol_pred = (4/3) * torch.pi * (pred_radii ** 3)
        vol_target = (4/3) * torch.pi * (matched_target_radii ** 3)
        union_vol = vol_pred + vol_target - inter_vol

        # Compute IoU
        iou = inter_vol / union_vol.clamp(min=1e-6)

        # Compute IoU Loss
        loss = 1 - iou.mean()  # Take mean over batch
        return loss

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn = SphereIoULoss().to(device)

In [38]:
num_classes = 6

In [39]:
class YOLO3DDataset_Cropped(Dataset):
    def __init__(self, image, labels_df, crop_size=(64, 128, 128), stride=(32, 64, 64), transform=None):
        self.image = image  # Load full 3D image
        self.labels = labels_df[['x', 'y', 'z', 'r', 'class_id_0', 'class_id_1', 'class_id_2', 'class_id_3', 'class_id_4', 'class_id_5']].to_numpy()  # Load bounding spheres with one-hot class labels
        self.crop_size = crop_size
        self.stride = stride
        self.transform = transform  # Optional augmentations
        self.num_classes = 6  # Number of classes

        self.patches, self.patch_bboxes = self.create_crops_with_labels()

    def create_crops_with_labels(self):
        d, h, w = self.image.shape
        crops, bboxes = [], []

        crop_size = tuple(map(int, self.crop_size))
        stride = tuple(map(int, self.stride))

        for z in range(0, d - crop_size[0] + 1, stride[0]):
            for y in range(0, h - crop_size[1] + 1, stride[1]):
                for x in range(0, w - crop_size[2] + 1, stride[2]):
                    
                    cropped_patch = self.image[z:z+crop_size[0], y:y+crop_size[1], x:x+crop_size[2]]
                    cropped_bboxes = []

                    for bbox in self.labels:
                        z_center, y_center, x_center, radius, *class_one_hot = bbox
                        
                        # Check if the bounding sphere center is within the crop
                        if (z <= z_center < z + crop_size[0] and
                            y <= y_center < y + crop_size[1] and
                            x <= x_center < x + crop_size[2]):

                            z_new = z_center - z
                            y_new = y_center - y
                            x_new = x_center - x

                            cropped_bboxes.append([z_new, y_new, x_new, radius, *class_one_hot])

                    crops.append(cropped_patch)
                    bboxes.append(cropped_bboxes if cropped_bboxes else [])  

        crops_tensor = torch.tensor(np.array(crops), dtype=torch.float32)

        bboxes_fixed = []
        for bbox in bboxes:
            if len(bbox) == 0:
                bboxes_fixed.append(torch.empty((0, 4 + self.num_classes)))  # [z, y, x, radius, class_one_hot]
            else:
                bboxes_fixed.append(torch.tensor(bbox, dtype=torch.float32))

        bboxes_tensor = torch.nn.utils.rnn.pad_sequence(bboxes_fixed, batch_first=True, padding_value=-1)
        return crops_tensor, bboxes_tensor

    def __len__(self):
        """Return the total number of 3D image patches"""
        return len(self.patches)

    def __getitem__(self, idx):
        """Return one sample"""
        x = self.patches[idx].unsqueeze(0)  # Add channel dimension (1, D, H, W)
        y = self.patch_bboxes[idx]  
        return x, y

In [40]:
dataset = YOLO3DDataset_Cropped(image, first_df)

In [41]:
print(f"Total dataset size: {len(dataset)}")  # Should print an integer

Total dataset size: 256


In [42]:
train_loader = DataLoader(
    dataset, batch_size=4, shuffle=True, num_workers=0, pin_memory=False
)

In [43]:
def train_yolo3d_medicalnet(model, dataloader, loss_fn, epochs=5, lr=0.0001, fine_tune=True):
    """
    Trains YOLO3D with the pre-trained MedicalNet backbone.
    Fine-tune the entire network or freeze backbone layers as needed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(f"🚨 WARNING: {name} is frozen!")

    for param in model.parameters():
        param.requires_grad = True

    # ✅ Fine-tuning: Freeze MedicalNet backbone (if needed)
    if not fine_tune:
        print("🚀 Freezing MedicalNet Backbone!")
        for param in model.backbone.parameters():
            param.requires_grad = False

    # ✅ Optimizer: Separate learning rates for backbone vs. new layers
    backbone_params = model.backbone.parameters()
    new_params = [p for n, p in model.named_parameters() if "backbone" not in n]

    optimizer = optim.AdamW([
        {"params": backbone_params, "lr": lr * 0.1},  # Lower LR for backbone (pre-trained)
        {"params": new_params, "lr": lr}  # Higher LR for new layers
    ], weight_decay=1e-4)

    scaler = torch.amp.GradScaler(device="cuda")  # ✅ Enable mixed precision (FP16 training)
    
    torch.backends.cudnn.benchmark = True  # ✅ Optimized convolutions for varying input sizes
    torch.backends.cudnn.enabled = True  # ✅ Ensures CUDNN acceleration is active

    print(f"🚀 Training on {device} | Mixed Precision Enabled")
    print(f"📝 Fine-tuning: {fine_tune} | Backbone LR: {lr * 0.1} | New Layers LR: {lr}")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = len(dataloader)

        for batch_idx, (x, y) in enumerate(dataloader):
            if torch.isnan(x).any() or torch.isinf(x).any():
                print(f"❌ NaN or Inf detected in x at batch {batch_idx}")
                raise ValueError("Input tensor x contains NaN or Inf values!")

            assert x.dtype in [torch.float32, torch.float16], f"❌ Unexpected dtype {x.dtype}"
            assert x.numel() > 0, "❌ `x` tensor is empty!"

            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            if torch.isnan(x).any() or torch.isinf(x).any():
                print("❌ NaN or Inf detected in input data!")
            if torch.isnan(y).any() or torch.isinf(y).any():
                print("❌ NaN or Inf detected in labels!")

            optimizer.zero_grad(set_to_none=True)  # ✅ Reduce memory usage

            '''
            # 🔍 **Log Parameters Before Update**
            for name, param in model.named_parameters():
                if param.requires_grad:
                    print(f"Before Update: {name} -> {param.data.mean()}")
            '''

            # ✅ Enable mixed precision (Float16) for faster training
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True):

                '''
                predictions = model(x)
                pred_spheres = predictions[..., :4]
                loss = loss_fn(predictions, y)
                print(f"📉 Debug: Loss value = {loss.item()}")
                loss.retain_grad()  # Ensure gradient retention
                print(f"📊 Loss gradient: {loss.grad}")
                '''

                predictions = model(x)  
                # predictions.shape = [batch_size, 3, 10]  (for example)
            
                # -- Slice for sphere regression --
                pred_spheres = predictions[..., :4]     # shape [batch_size, 3, 4]
            
                # -- Slice for classification --
                pred_classes = predictions[..., 4:]     # shape [batch_size, 3, 6]
            
                # YOUR sphere + classification targets must match the above shapes!
                # e.g. target_spheres.shape = [batch_size, 3, 4]
                #      target_classes.shape = [batch_size, 3] or [batch_size, 3, 6] 
                # (Depending on your classification approach—see section 2)
            
                # -- Sphere IoU Loss --
                loss_sphere = sphere_loss_fn(pred_spheres, target_spheres)
            
                # -- Classification Loss (example: cross-entropy) --
                # If each anchor has a single class label in [0..5], 
                # then shape is [batch_size, 3] of integer labels
                # pred_classes has shape [batch_size, 3, 6]
                # We need to flatten appropriately or handle batch+anchor dims in CE
                # (see below for details)
                loss_class = classification_loss_fn(pred_classes, target_class_labels)
            
                # Combine them
                loss = loss_sphere + loss_class

            optimizer.zero_grad(set_to_none=True)  # ✅ Clears gradients before backward
            scaler.scale(loss).backward()

            for name, param in model.named_parameters():
                if param.grad is not None:
                    print(f"{name} grad mean: {param.grad.mean()}")
                else:
                    print(f"{name} has NO grad!")

            '''
            for name, param in model.named_parameters():
                if param.grad is None:
                    print(f"❌ {name} has NO gradient!")
                else:
                    print(f"✅ {name} gradient mean: {param.grad.mean()}")
            '''
            '''
            for param in model.parameters():
                param.data -= 0.001 * param.grad  # Simulated update
            '''

            '''
            for name, param in model.named_parameters():
                if param.requires_grad:
                    print(f"Gradient {name} -> {param.grad.mean() if param.grad is not None else 'None'}")
            '''

            '''
            for name, param in model.named_parameters():
                if param.grad is not None:
                    print(f"✅ {name} gradient mean: {param.grad.mean()}")
                else:
                    print(f"❌ {name} has NO gradient!")
            '''

            scaler.unscale_(optimizer)  # Unscales before stepping
            scaler.step(optimizer)  # ✅ Apply parameter update
            scaler.update()

            '''
            # 🔍 **Log Parameters After Update**
            for name, param in model.named_parameters():
                if param.requires_grad:
                    print(f"After Update: {name} -> {param.data.mean()}")
            '''

            total_loss += loss.item()

            # ✅ Dynamic logging for every 10% of dataset
            if batch_idx % max(1, num_batches // 10) == 0:
                print(f"[Epoch {epoch+1}/{epochs}] Batch {batch_idx}/{num_batches} | Loss: {loss.item():.4f}")

        avg_loss = total_loss / num_batches
        print(f"✅ Epoch [{epoch+1}/{epochs}] | Average Loss: {avg_loss:.6f}")

    print("🎯 Training Complete!")

In [45]:
def loss_fn(preds, gt_centers, gt_labels, num_classes=6):
    """
    preds.shape = [B, A, 3 + C]
       1) preds[..., :3] = predicted centers (x, y, z)
       2) preds[..., 3:] = predicted class logits

    gt_centers.shape = [B, A, 3]
    gt_labels.shape  = [B, A]  (integer labels in [0, num_classes-1])
    """

    pred_centers = preds[..., :3]          # [B, A, 3]
    pred_logits  = preds[..., 3:]         # [B, A, C]

    # -- REGRESSION Loss (e.g. L1 distance) --
    center_dist = torch.norm(pred_centers - gt_centers, dim=-1)  # [B, A]
    loss_center = center_dist.mean()  # Average across batch, anchors

    # -- CLASSIFICATION Loss (CrossEntropy) --
    B, A, C = pred_logits.shape
    pred_logits_2d = pred_logits.view(B*A, C)
    gt_labels_1d   = gt_labels.view(B*A)          # [B*A]
    loss_class = nn.CrossEntropyLoss()(pred_logits_2d, gt_labels_1d)

    # Weighted sum
    alpha, beta = 1.0, 1.0  # tune these
    total_loss = alpha * loss_center + beta * loss_class
    return total_loss

In [46]:
def train_yolo3d_medicalnet(model, dataloader, sphere_loss_fn, class_loss_fn, epochs=5, lr=1e-4, fine_tune=True):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Freeze backbone if needed
    for param in model.parameters():
        param.requires_grad = True
    if not fine_tune:
        for param in model.backbone.parameters():
            param.requires_grad = False

    # Build optimizer with two LR groups
    backbone_params = model.backbone.parameters()
    new_params = [p for n,p in model.named_parameters() if "backbone" not in n]
    optimizer = optim.AdamW([
        {"params": backbone_params, "lr": lr * 0.1},
        {"params": new_params, "lr": lr}
    ], weight_decay=1e-4)

    scaler = torch.amp.GradScaler()

    print(f"Using device={device}, backbone LR={lr*0.1}, head LR={lr}")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.
        for batch_idx, (x, y_spheres, y_classes) in enumerate(dataloader):
            # x.shape => [batch, channels=1 or 3, D, H, W]
            # y_spheres => [batch, 3, 4] or [batch, N, 4], etc.
            # y_classes => [batch, 3] or [batch, 3, 6], depending on how you store labels

            x = x.to(device)
            y_spheres = y_spheres.to(device)
            y_classes = y_classes.to(device)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                preds = model(x)  # shape [batch, 3, 10]
                loss  = loss_fn(preds, gt_centers, gt_labels, num_classes=6)
                '''
                sphere_preds = preds[..., :4]   # [batch, 3, 4]
                class_preds  = preds[..., 4:]   # [batch, 3, 6]

                loss_sphere = sphere_loss_fn(sphere_preds, y_spheres)
                loss_class  = class_loss_fn(class_preds, y_classes)
                loss = loss_sphere + loss_class
                '''

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch {batch_idx}, Loss={loss.item():.4f}")

        print(f"Epoch [{epoch+1}/{epochs}], Avg Loss={total_loss / len(dataloader):.4f}")
    print("Done.")

In [48]:
def train_yolo3d_medicalnet(model, dataloader, epochs=5, lr=1e-4, fine_tune=True):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Freeze backbone if needed
    for param in model.parameters():
        param.requires_grad = True
    if not fine_tune:
        for param in model.backbone.parameters():
            param.requires_grad = False

    # Build optimizer with two LR groups
    backbone_params = model.backbone.parameters()
    new_params = [p for n,p in model.named_parameters() if "backbone" not in n]
    optimizer = optim.AdamW([
        {"params": backbone_params, "lr": lr * 0.1},
        {"params": new_params, "lr": lr}
    ], weight_decay=1e-4)

    scaler = torch.amp.GradScaler()

    print(f"Using device={device}, backbone LR={lr*0.1}, head LR={lr}")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.
        for batch_idx, (x, y_spheres, y_classes) in enumerate(dataloader):
            # x.shape => [batch, channels=1 or 3, D, H, W]
            # y_spheres => [batch, 3, 4] or [batch, N, 4], etc.
            # y_classes => [batch, 3] or [batch, 3, 6], depending on how you store labels

            x = x.to(device)
            y_spheres = y_spheres.to(device)
            y_classes = y_classes.to(device)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                preds = model(x)  # shape [batch, 3, 10]
                loss  = loss_fn(preds, gt_centers, gt_labels, num_classes=6)
                '''
                sphere_preds = preds[..., :4]   # [batch, 3, 4]
                class_preds  = preds[..., 4:]   # [batch, 3, 6]

                loss_sphere = sphere_loss_fn(sphere_preds, y_spheres)
                loss_class  = class_loss_fn(class_preds, y_classes)
                loss = loss_sphere + loss_class
                '''

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch {batch_idx}, Loss={loss.item():.4f}")

        print(f"Epoch [{epoch+1}/{epochs}], Avg Loss={total_loss / len(dataloader):.4f}")
    print("Done.")

In [None]:
def train_yolo3d_medicalnet(model, dataloader, epochs=5, lr=1e-4, fine_tune=True):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Freeze backbone if needed
    for param in model.parameters():
        param.requires_grad = True
    if not fine_tune:
        for param in model.backbone.parameters():
            param.requires_grad = False

    # Build optimizer with two LR groups
    backbone_params = model.backbone.parameters()
    new_params = [p for n,p in model.named_parameters() if "backbone" not in n]
    optimizer = optim.AdamW([
        {"params": backbone_params, "lr": lr * 0.1},
        {"params": new_params, "lr": lr}
    ], weight_decay=1e-4)

    scaler = torch.amp.GradScaler()

    print(f"Using device={device}, backbone LR={lr*0.1}, head LR={lr}")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.
        for batch_idx, (x, y) in enumerate(dataloader):
            # x.shape => [batch, channels=1 or 3, D, H, W]
            # y_spheres => [batch, 3, 4] or [batch, N, 4], etc.
            # y_classes => [batch, 3] or [batch, 3, 6], depending on how you store labels

            x = x.to(device)
            y_spheres, y_classes = y  # if y is a tuple or list of two elements

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                preds = model(x)  # shape [batch, 3, 10]
                loss  = loss_fn(preds, gt_centers, gt_labels, num_classes=6)
                '''
                sphere_preds = preds[..., :4]   # [batch, 3, 4]
                class_preds  = preds[..., 4:]   # [batch, 3, 6]

                loss_sphere = sphere_loss_fn(sphere_preds, y_spheres)
                loss_class  = class_loss_fn(class_preds, y_classes)
                loss = loss_sphere + loss_class
                '''

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch {batch_idx}, Loss={loss.item():.4f}")

        print(f"Epoch [{epoch+1}/{epochs}], Avg Loss={total_loss / len(dataloader):.4f}")
    print("Done.")

In [50]:
def train_yolo3d_medicalnet(model, dataloader, epochs=5, lr=1e-4, fine_tune=True):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Freeze backbone if needed
    for param in model.parameters():
        param.requires_grad = True
    if not fine_tune:
        for param in model.backbone.parameters():
            param.requires_grad = False

    # Build optimizer with two LR groups
    backbone_params = model.backbone.parameters()
    new_params = [p for n, p in model.named_parameters() if "backbone" not in n]
    optimizer = optim.AdamW([
        {"params": backbone_params, "lr": lr * 0.1},
        {"params": new_params, "lr": lr}
    ], weight_decay=1e-4)

    scaler = torch.amp.GradScaler()
    print(f"Using device={device}, backbone LR={lr*0.1}, head LR={lr}")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.

        for batch_idx, (x, y) in enumerate(dataloader):
            # x.shape => [batch_size, 1, D, H, W]
            # y.shape => [batch_size, max_bboxes, 4 + num_classes], 
            #            due to your internal pad_sequence

            x = x.to(device)              # [B, 1, D, H, W]
            y = y.to(device)              # [B, max_bboxes, 4 + num_classes]

            # ------------------------------------------------
            # STEP A: Split y into spheres vs. one-hot classes
            # ------------------------------------------------
            y_spheres = y[..., :4]        # [B, max_bboxes, 4]
            y_onehot = y[..., 4:]         # [B, max_bboxes, num_classes]

            # If you want integer class labels for CrossEntropy:
            # We'll do an argmax over the last dim of the one-hot:
            y_labels = y_onehot.argmax(dim=-1)  # [B, max_bboxes]

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                # Example: Suppose your model outputs => [B, anchors, 3 + num_classes]
                # e.g. [B, 3, 3+6=9] or [B, 3, 4+6=10], depending on your design
                preds = model(x)

                # If your model has 3 coords + 6 class logits => total 9 channels:
                #   preds[..., :3] => 3D center
                #   preds[..., 3:] => class logits
                # Then call your combined loss:
                loss = loss_fn(
                    preds,           # shape [B, anchors, 3 + C]
                    y_spheres[..., :3],  # we might only use the [z,y,x] or [x,y,z], depends on your model convention
                    y_labels         # shape [B, max_bboxes], int class IDs
                )

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch {batch_idx}, Loss={loss.item():.4f}")

        print(f"Epoch [{epoch+1}/{epochs}], Avg Loss={total_loss / len(dataloader):.4f}")
    print("Done.")

In [51]:
def loss_fn(preds, gt_spheres, gt_labels, num_classes=6):
    """
    preds.shape = [B, A, 3 + C]  # e.g. [B, 3, 9]
    gt_spheres.shape = [B, N, 3] # we might only compare with matched anchors if you do anchor matching
    gt_labels.shape  = [B, N]    # integer labels [0..num_classes-1]
    """

    # Example: we ignore radius for now, or do separate code if you want radius
    # preds[..., :3] => [B, A, 3] for centers
    # preds[..., 3:] => [B, A, C] for class logits

    # -- Mismatch warning! --
    # If you have A anchors but N bounding spheres, you either need an anchor-matching step
    # or a simpler approach (like N == A == 1, or 1 anchor per object).
    # For a small toy example, you might do something naive like: min(A,N).
    # Real YOLO approaches do anchor-based matching, ignoring unmatched anchors with an "objectness" label.

    # For demonstration, let's assume A == N:
    pred_centers = preds[..., :3]       # [B, A, 3]
    gt_centers   = gt_spheres          # [B, A, 3]
    dist = torch.norm(pred_centers - gt_centers, dim=-1)  # [B, A]
    loss_center = dist.mean()

    # Classification
    pred_logits = preds[..., 3:]       # [B, A, C]
    B, A, C = pred_logits.shape
    pred_logits_2d = pred_logits.view(B*A, C)
    gt_labels_1d   = gt_labels.view(B*A)
    loss_class = nn.CrossEntropyLoss()(pred_logits_2d, gt_labels_1d)

    # Combine
    alpha, beta = 1.0, 1.0
    total_loss = alpha * loss_center + beta * loss_class
    return total_loss

In [52]:
# ✅ Run training
train_yolo3d_medicalnet(model, train_loader, epochs=10, lr=0.0001, fine_tune=True)

Using device=cuda, backbone LR=1e-05, head LR=0.0001


RuntimeError: The size of tensor a (3) must match the size of tensor b (15) at non-singleton dimension 1

In [None]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# References

1. https://www.kaggle.com/code/davidlist/experiment-ts-6-4-visualization
2. https://www.kaggle.com/code/nk35jk/3d-visualization-of-particles