# Install pkgs

**Note:** This is a training notebook only; inference is not included.
Anyone who wants to adapt this notebook for inference is most welcome to do so.

In [None]:
import os
import shutil
import numpy as np
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
import zarr, copick
from tqdm import tqdm
import napari
import mlflow
import mlflow.pytorch
from copick_utils.segmentation import segmentation_from_picks
import copick_utils.io.writers as write
from collections import defaultdict
import matplotlib.pyplot as plt
import pandas as pd
import gc
import concurrent.futures
import optuna, optunahub
import json
import copy

gc.enable()

In [None]:
# Set TorchDynamo's compile-cache size limit to 64.
# NOTE(review): presumably raised so repeated recompilations (e.g. one per
# hyper-parameter trial) do not exhaust the cache — confirm.
torch._dynamo.config.cache_size_limit = 64

In [None]:
# Root of the local competition data. It must end with '/' because every path
# derived from it is built by plain string concatenation.
path = '/media/max1024/Extreme SSD1/Kaggle/czii-cryo-et-object-identification/'
# Directory under the data root used for generated output.
output_path = path + 'output/'

In [None]:
# Make a copick project
#
# JSON text of the copick configuration: the five pickable particle types
# (name, label, radius, colour, map threshold) plus the overlay and static
# roots of the project.
# NOTE(review): "overlay_root" and "static_root" repeat the absolute prefix
# stored in `path`; if `path` changes, these two entries must be edited too.

config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 2,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 3,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 4,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 5,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        }
    ],

    "overlay_root": "/media/max1024/Extreme SSD1/Kaggle/czii-cryo-et-object-identification/output/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "/media/max1024/Extreme SSD1/Kaggle/czii-cryo-et-object-identification/train/static"
}"""

# Where the generated copick config is written, and where the writable overlay
# copy lives.
copick_config_path = path + "output/copick.config"
output_overlay = path + "output/overlay"

# open(..., "w") does not create missing parent directories, so make sure the
# output directory exists before writing the config on a fresh data folder.
os.makedirs(os.path.dirname(copick_config_path), exist_ok=True)

with open(copick_config_path, "w") as f:
    f.write(config_blob)

# Update the overlay
# Mirror the training overlay into the output overlay, giving every file the
# "curation_0_" prefix.
# NOTE(review): presumably the prefix encodes the copick user/session so that
# the picks can be queried with user_id='curation' — confirm.
source_dir = path + 'train/overlay'
destination_dir = path + 'output/overlay'

# Walk through the source directory
for root, dirs, files in os.walk(source_dir):
    # Create corresponding subdirectories in the destination
    relative_path = os.path.relpath(root, source_dir)
    target_dir = os.path.join(destination_dir, relative_path)
    os.makedirs(target_dir, exist_ok=True)

    # Copy and rename each file
    for file in files:
        # Files that already carry the prefix keep their name, so copying an
        # already-prefixed overlay does not stack prefixes.
        if file.startswith("curation_0_"):
            new_filename = file
        else:
            new_filename = f"curation_0_{file}"

        # Define full paths for the source and destination files
        source_file = os.path.join(root, file)
        destination_file = os.path.join(target_dir, new_filename)

        # Copy the file with the new name (copy2 also copies file metadata)
        shutil.copy2(source_file, destination_file)
        print(f"Copied {source_file} to {destination_file}")

# Prepare the dataset
## 1. Get copick root

In [None]:
# Open the copick project described by the config file at copick_config_path.
root = copick.from_file(copick_config_path)

# User id and segmentation name under which the painted masks are written and
# later looked up.
copick_user_name = "copickUtils"
copick_segmentation_name = "paintedPicks"
# Voxel spacing passed to run.get_voxel_spacing() and run.get_segmentations().
voxel_size = 10
#tomo_type = "denoised"

## 2. Generate multi-class segmentation masks from picks and save them to the copick overlay directory (one-time)

In [None]:
# Display the copick root object as the cell output.
root

In [None]:
# Just do this once
# Paints a multi-class mask from the curated picks of every run and writes it
# to the copick overlay; set the flag to False to skip the regeneration.
generate_masks = True

if generate_masks:
    # Label value and radius of every pickable object that is a particle.
    # The loop variable is deliberately not called `object`, which would
    # shadow the builtin for the rest of the notebook.
    target_objects = defaultdict(dict)
    for pickable_object in root.pickable_objects:
        if pickable_object.is_particle:
            target_objects[pickable_object.name]['label'] = pickable_object.label
            target_objects[pickable_object.name]['radius'] = pickable_object.radius

    for run in tqdm(root.runs):
        tomo = run.get_voxel_spacing(voxel_size)
        for tomogram in tomo.tomograms:
            tomo_type = tomogram.tomo_type
            image = tomogram.numpy()
            # Empty mask with the tomogram's shape; 0 stays background.
            target = np.zeros(image.shape, dtype=np.uint8)
            # Iterate only over particle objects: a non-particle object has no
            # 'radius'/'label' entry and would raise KeyError when painted.
            for object_name, paint_params in target_objects.items():
                pick = run.get_picks(object_name=object_name, user_id='curation')
                if len(pick):
                    target = segmentation_from_picks.from_picks(pick[0],
                                                                target,
                                                                paint_params['radius'],# * 0.5,
                                                                paint_params['label']
                                                               )
            # NOTE(review): written once per tomogram under the same user and
            # name, so presumably each write replaces the previous one — confirm.
            write.segmentation(run, target, copick_user_name, name=copick_segmentation_name)

## 3. Get the tomograms and their segmentation masks (painted from picks) as arrays

In [None]:
# Folder with one sub-directory per run; create_labels_df reads
# <run>/Picks/<particle>.json below it.
# NOTE(review): repeats the absolute prefix that is also stored in `path`.
train_label_experiment_folders_path = '/media/max1024/Extreme SSD1/Kaggle/czii-cryo-et-object-identification/' + 'train/overlay/ExperimentRuns/'

In [None]:
# Integer segmentation label of each particle type (0 is background).
class_ids = {
    'apo-ferritin': 1,
    'beta-galactosidase': 2,
    'ribosome': 3,
    'thyroglobulin': 4,
    'virus-like-particle': 5,
}

In [None]:
# Radius of each particle type, same numbers as the "radius" entries of the
# copick config; create_labels_df divides them by 10 before storing them.
particle_radius = {
    'apo-ferritin': 60,
    'beta-galactosidase': 90,
    'ribosome': 150,
    'thyroglobulin': 130,
    'virus-like-particle': 135,
}

In [None]:
def create_labels_df(experiment):
    """Collect the picks of one run into a flat table.

    Reads the five ``<particle>.json`` pick files found under
    ``train_label_experiment_folders_path/<experiment>/Picks/`` and emits one
    row per picked point.

    Args:
        experiment: Name of the run folder.

    Returns:
        DataFrame with columns ``experiment``, ``particle_type``, ``x``, ``y``,
        ``z``, ``radius`` and ``label``. Coordinates are the stored locations
        divided by fixed per-axis constants, the radius is
        ``particle_radius[...] / 10`` and the label comes from ``class_ids``.
    """
    particle_names = (
        'apo-ferritin',
        'beta-galactosidase',
        'ribosome',
        'thyroglobulin',
        'virus-like-particle',
    )

    # Parsed pick file per particle type, in the fixed order above.
    picks_by_type = {}
    for name in particle_names:
        with open(f'{train_label_experiment_folders_path}{experiment}/Picks/{name}.json') as f:
            picks_by_type[name] = json.loads(f.read())

    columns = {
        'experiment': [],
        'particle_type': [],
        'x': [],
        'y': [],
        'z': [],
        'radius': [],
        'label': [],
    }

    for key, picks in picks_by_type.items():
        for point in picks['points']:
            location = point['location']
            columns['experiment'].append(picks['run_name'])
            columns['particle_type'].append(picks['pickable_object_name'])
            # Per-axis divisors as used by the notebook author.
            # NOTE(review): presumably the voxel spacing of the tomograms;
            # the x divisor differs slightly from y/z — confirm.
            columns['x'].append(location['x']/10.012444537618887)
            columns['y'].append(location['y']/10.012444196428572)
            columns['z'].append(location['z']/10.012444196428572)
            columns['radius'].append(particle_radius[key]/10)
            columns['label'].append(class_ids[key])

    return pd.DataFrame(columns)

In [None]:
# Build one sample per (run, tomogram): the tomogram volume, the painted mask
# and the table of picks for the run.
data_dicts = []
for run in tqdm(root.runs):
    tomo = run.get_voxel_spacing(voxel_size)#.get_tomograms(tomo_type)[0].numpy()
    # The pick table depends on the run only, so every tomogram of the run
    # shares the same DataFrame object.
    labels_df = create_labels_df(run.name)
    for tomogram in tomo.tomograms:
        tomo_type = tomogram.tomo_type
        image = tomogram.numpy()
        # NOTE(review): this lookup does not depend on `tomogram`, yet it is
        # repeated for every tomogram of the run — consider loading it once
        # per run.
        segmentation = run.get_segmentations(name=copick_segmentation_name, user_id=copick_user_name, voxel_size=voxel_size, is_multilabel=True)[0].numpy()
        data_dicts.append({"tomo_type":tomo_type, "image": image, "label": segmentation, "label_df": labels_df})

# Label values present in the first sample's mask.
print(np.unique(data_dicts[0]['label']))

In [None]:
# Keys stored per sample.
data_dicts[0].keys()

In [None]:
# Shape of the first sample's mask volume.
data_dicts[0]['label'].shape

In [None]:
# Number of samples loaded so far.
len(data_dicts)

In [None]:
# Same check as the first cell of this group (duplicate cell).
data_dicts[0].keys()

In [None]:
# Pick table attached to the first sample.
data_dicts[0]['label_df']

## 4. Visualize the tomogram and painted segmentation from ground-truth picks

In [None]:
# Show slice 100 of the first sample: tomogram on the left, painted mask on the
# right.
plt.figure(figsize=(15, 5))

# (title, sample key, colour map) for each panel, left to right.
panels = (
    ('Tomogram', 'image', 'gray'),
    ('Painted Segmentation from Picks', 'label', 'viridis'),
)

for position, (title, key, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.imshow(data_dicts[0][key][100], cmap=cmap)
    plt.axis('off')

plt.tight_layout()
plt.show()

# Simulated data prep

In [None]:
# Folder with the simulated training runs, relative to the working directory.
# It must end with '/' because paths are built by string concatenation.
simulated_data_path = '../../czii_downloaded_data/simulated_training_data/'

In [None]:
# Names of the simulated runs: every directory entry whose name contains 'TS_'.
simulated_experiments_list = [f for f in os.listdir(simulated_data_path) if 'TS_' in f]
simulated_experiments_list

In [None]:
# Label and radius per particle type for the simulated data; the numbers match
# class_ids and particle_radius.
particle_info = {
    "apo-ferritin": {"label": 1, "radius": 60},
    "beta-galactosidase": {"label": 2, "radius": 90},
    "ribosome": {"label": 3, "radius": 150},
    "thyroglobulin": {"label": 4, "radius": 130},
    "virus-like-particle": {"label": 5, "radius": 135}
}

In [None]:
def create_masks(image, locations_df):
    """Paint a labelled sphere for every particle into an empty volume.

    Args:
        image: Array whose 3-D ``shape`` (z, y, x) defines the mask size; its
            values are not read.
        locations_df: DataFrame with columns ``x``, ``y``, ``z`` and ``radius``
            (compared directly against voxel indices) and ``label``. Rows are
            processed in positional order, so any index is accepted.

    Returns:
        ``np.uint8`` array of ``image.shape``. Voxels outside every sphere are
        0; where spheres overlap, the later row wins. Rows whose centre or
        radius is NaN paint nothing.
    """
    shape = image.shape
    volume = np.zeros(shape, dtype=np.uint8)

    # Pull the columns out once; positional iteration avoids the KeyError that
    # label-based `.loc[i]` raises on a filtered or re-indexed DataFrame.
    centers = locations_df[['z', 'y', 'x']].to_numpy(dtype=float)
    radii = locations_df['radius'].to_numpy(dtype=float)
    labels = locations_df['label'].to_numpy()

    limits = np.asarray(shape, dtype=float)
    for center, radius, label in zip(centers, radii, labels):
        # Evaluate the sphere test only inside its bounding box (padded by one
        # voxel against rounding) instead of over the whole volume.
        reach = abs(radius)
        lower = np.floor(center - reach) - 1
        upper = np.floor(center + reach) + 2
        if np.isnan(lower).any() or np.isnan(upper).any():
            continue
        z0, y0, x0 = (int(v) for v in np.clip(lower, 0, limits))
        z1, y1, x1 = (int(v) for v in np.clip(upper, 0, limits))
        if z1 <= z0 or y1 <= y0 or x1 <= x0:
            # Sphere lies completely outside the volume.
            continue

        z, y, x = np.ogrid[z0:z1, y0:y1, x0:x1]
        inside = (x - center[2])**2 + (y - center[1])**2 + (z - center[0])**2 <= radius**2

        # The slice is a view, so the masked assignment writes into `volume`.
        volume[z0:z1, y0:y1, x0:x1][inside] = label

    return volume

In [None]:
# Substring looked for in a simulated annotation file name -> particle name
# used as key of particle_info.
particle_name_dict = {
    'ferritin':'apo-ferritin',
    'galacto':'beta-galactosidase',
    'ribosome':'ribosome',
    'thyro':'thyroglobulin',
    'vlp':'virus-like-particle'
}

In [None]:
def process_experiment(experiment):
    """Load one simulated run and paint its label mask.

    Opens the run's tomogram, reads every ``orientedpoint.ndjson`` annotation
    file whose name matches a key of ``particle_name_dict``, and paints the
    points with ``create_masks``.

    Args:
        experiment: Name of the simulated run folder.

    Returns:
        Dict with keys ``tomo_type`` (always ``'Unknown'``), ``image``,
        ``label`` and ``label_df``, or ``None`` when no matching annotation
        file was found.
    """
    tomogram_path = f'{simulated_data_path}{experiment}/Reconstructions/VoxelSpacing10.000/Tomograms/100/{experiment}.zarr'
    # Index 0 of the opened zarr group, materialised as an in-memory array.
    tomogram = zarr.open(tomogram_path, mode='r')[0].__array__()

    annotations_base = f'{simulated_data_path}{experiment}/Reconstructions/VoxelSpacing10.000/Annotations/'
    per_file_tables = []

    for location_folder in os.listdir(annotations_base):
        location_dir = os.path.join(annotations_base, location_folder)
        for location_file in os.listdir(location_dir):
            if 'orientedpoint.ndjson' not in location_file:
                continue

            # First dictionary key contained in the file name decides the type.
            particle_type = next(
                (full_name for short_name, full_name in particle_name_dict.items()
                 if short_name in location_file),
                None,
            )
            if not particle_type:
                continue

            info = particle_info[particle_type]
            location_df = pd.read_json(os.path.join(location_dir, location_file), lines=True)
            # Flatten the nested "location" record into x/y/z columns.
            for axis in ('x', 'y', 'z'):
                location_df[axis] = location_df['location'].map(lambda loc, axis=axis: loc[axis])
            location_df['label'] = info['label']
            location_df['radius'] = info['radius'] / 10# * 0.5
            location_df['particle_type'] = particle_type

            per_file_tables.append(location_df)

    if not per_file_tables:
        return None

    all_particle_locations_df = pd.concat(per_file_tables, ignore_index=True)
    mask_image = create_masks(tomogram, all_particle_locations_df)
    all_particle_locations_df['experiment'] = experiment
    label_df = all_particle_locations_df[['experiment', 'particle_type', 'x', 'y', 'z', 'radius', 'label']]
    return {"tomo_type": 'Unknown', "image": tomogram, "label": mask_image, "label_df": label_df}

def append_simulation_data():
    """Process every simulated run in a process pool and extend ``data_dicts``.

    Each entry of ``simulated_experiments_list`` is handed to
    ``process_experiment``; non-empty results are appended to the module-level
    ``data_dicts`` in completion order, which may differ from the list order.
    """
    global data_dicts
    # Use half the available cores to prevent memory issues. os.cpu_count()
    # may return None, and halving a single core gives 0, which
    # ProcessPoolExecutor rejects — so fall back to one worker.
    max_workers = max(1, (os.cpu_count() or 1) // 2)

    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Submit all experiments for processing
        futures = [executor.submit(process_experiment, exp) for exp in simulated_experiments_list]

        # Collect results as they complete; result() re-raises any exception
        # raised inside the worker.
        for future in tqdm(concurrent.futures.as_completed(futures),
                         total=len(simulated_experiments_list),
                         desc="Processing experiments"):
            result = future.result()
            if result:
                data_dicts.append(result)

In [None]:
# Load the simulated runs and append them to data_dicts.
append_simulation_data()

In [None]:
# Total number of samples after appending the simulated ones.
len(data_dicts)

In [None]:
# Volume of the most recently appended sample.
temp_image = data_dicts[-1]['image']
temp_image.shape

In [None]:
# Mask of the most recently appended sample.
temp_label = data_dicts[-1]['label']
temp_label.shape

In [None]:
# Slice 100 of that volume.
plt.imshow(temp_image[100])

In [None]:
# Slice 100 of that mask.
plt.imshow(temp_label[100])

In [None]:
# Particle table of that sample.
data_dicts[-1]['label_df']

## 5. Prepare dataloaders

In [None]:
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
from monai.transforms import (
    Compose,
    EnsureChannelFirstd, 
    Orientationd,  
    AsDiscrete,  
    RandFlipd, 
    RandRotate90d, 
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)
from monai.networks.nets import UNet
from monai.losses import DiceLoss, FocalLoss, TverskyLoss
from monai.metrics import DiceMetric, ConfusionMatrixMetric

In [None]:
import random

# Draw the indices of the validation samples: one fifth of data_dicts, rounded
# down, without replacement.
# NOTE(review): no seed is set, so the split differs on every run. Indices are
# drawn per sample, so samples built from the same run can land on both sides
# of the split.
numbers = list(range(len(data_dicts)))
random_numbers = random.sample(numbers, int(len(data_dicts)/5))

print(random_numbers)

In [None]:
# Partition data_dicts by index: drawn indices go to validation, everything
# else to training. Both lists keep the order of data_dicts and hold the very
# same dict objects (no copies).
val_files = [entry for i, entry in enumerate(data_dicts) if i in random_numbers]
train_files = [entry for i, entry in enumerate(data_dicts) if i not in random_numbers]

In [None]:
# Report the size of both splits.
print(f"Number of training samples: {len(train_files)}")
print(f"Number of validation samples: {len(val_files)}")

In [None]:
# Number of segmentation classes handed to the crop sampler and the one-hot
# post-processing.
num_classes = 6 # 5 particles + 1 background

In [None]:
# Deep copies that still contain 'tomo_type' and 'label_df'.
# NOTE(review): deepcopy also duplicates every image and label array in
# memory — confirm the copies need the volumes.
train_files_copy = copy.deepcopy(train_files)
val_files_copy = copy.deepcopy(val_files)

In [None]:
# Remove the metadata entries from every train/validation sample, leaving
# 'image' and 'label'. The dicts are the same objects that are stored in
# data_dicts, so those entries lose the keys as well.
# pop(..., None) instead of `del` keeps the cell re-runnable: a second run
# finds the keys already gone and would otherwise raise KeyError.
for sample in train_files + val_files:
    sample.pop('tomo_type', None)
    sample.pop('label_df', None)

In [None]:
# Passed as num_samples to RandCropByLabelClassesd.
my_num_samples = 16
# Batch sizes handed to the DataLoaders.
train_batch_size = 1
val_batch_size = 1

In [None]:
# Non-random transforms to be cached
# (add a channel axis, normalise the image intensities, reorient to RAS)
non_random_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    Orientationd(keys=["image", "label"], axcodes="RAS")
])

# Random transforms to be applied during training
# (my_num_samples crops of 96^3 voxels per volume, then random rotation/flip)
random_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[96, 96, 96],
        num_classes=num_classes,
        num_samples=my_num_samples,
        allow_missing_keys=True
    ),
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),    
])

# Create the cached dataset with non-random transforms
train_ds = CacheDataset(data=train_files, transform=non_random_transforms, cache_rate=1.0)

# Wrap the cached dataset to apply random transforms during iteration
# (the name train_ds is rebound to the wrapping Dataset)
train_ds = Dataset(data=train_ds, transform=random_transforms)

# DataLoader remains the same
train_loader = DataLoader(
    train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available()
)

# Validation transforms for *raw* samples (channel axis, normalisation, random
# crops). Kept for standalone use; it must not be applied to the cached
# validation dataset below, whose samples already have a channel axis and
# normalised intensities.
val_transforms = Compose([
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    NormalizeIntensityd(keys="image"),
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[96, 96, 96],
        num_classes=num_classes,
        num_samples=my_num_samples,  # my_num_samples crops per volume (not a single crop)
        allow_missing_keys=True
    ),
])

# Crop-only transform for the cached validation samples. Validation uses the
# same crop sampling as training but none of the training augmentations
# (RandRotate90d / RandFlipd), so the metric is measured on unaugmented data.
# NOTE: the crops are still drawn at random on every pass.
val_crop_transforms = Compose([
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[96, 96, 96],
        num_classes=num_classes,
        num_samples=my_num_samples,
        allow_missing_keys=True
    ),
])

# Create validation dataset (deterministic preprocessing is cached once)
val_ds = CacheDataset(data=val_files, transform=non_random_transforms, cache_rate=1.0)

# Wrap the cached dataset so the crops are drawn during iteration
val_ds = Dataset(data=val_ds, transform=val_crop_transforms)

# Create validation DataLoader
val_loader = DataLoader(
    val_ds,
    batch_size=val_batch_size,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    shuffle=False,  # Ensure the data order remains consistent
)

## Model setup

In [None]:
"""
Derived from:
https://github.com/cellcanvas/album-catalog/blob/main/solutions/copick/compare-picks/solution.py
"""

import numpy as np
import pandas as pd

from scipy.spatial import KDTree


class ParticipantVisibleError(Exception):
    """Raised by ``score`` when the submission holds an unknown particle type."""


def compute_metrics(reference_points, reference_radius, candidate_points):
    """Count true positives, false positives and false negatives.

    A reference point counts as found when at least one candidate lies within
    ``reference_radius`` of it.

    Args:
        reference_points: Ground-truth coordinates, one row per particle.
        reference_radius: Matching distance.
        candidate_points: Predicted coordinates, one row per particle.

    Returns:
        Tuple ``(tp, fp, fn)`` where ``tp`` is the number of distinct matched
        reference points, ``fp`` the number of candidates minus ``tp`` and
        ``fn`` the number of reference points minus ``tp``.
    """
    n_reference = len(reference_points)
    n_candidates = len(candidate_points)

    # Without ground truth every candidate is a false positive.
    if n_reference == 0:
        return 0, n_candidates, 0

    # Without candidates every reference particle is missed.
    if n_candidates == 0:
        return 0, 0, n_reference

    neighbours = KDTree(candidate_points).query_ball_tree(
        KDTree(reference_points), r=reference_radius
    )
    # Each reference particle is counted once, however many candidates hit it.
    # This is not strictly correct in the (extremely rare) case where true
    # particles are very close to each other.
    matched_references = {ref_idx for hits in neighbours for ref_idx in hits}

    tp = int(len(matched_references))
    return tp, int(n_candidates - tp), int(n_reference - tp)


def score(
        solution: pd.DataFrame,
        submission: pd.DataFrame,
        row_id_column_name: str,
        distance_multiplier: float,
        beta: int) -> float:
    '''
    Weighted F-beta over particle types.

      - a prediction is a true positive when it lies within
        ``distance_multiplier`` times the particle radius of a reference point
        of the same `particle_type` in the same experiment
      - TP / FP / FN are summed over all experiments per particle type
      - f_beta is computed per particle type from the summed counts
      - the result is the weight-averaged f_beta (the sum of weighted scores
        divided by the sum of all weights)

    Only submission rows whose experiment occurs in `solution` are used.
    `row_id_column_name` is accepted but not read.

    Raises ParticipantVisibleError when the submission contains a particle
    type without a weight.
    '''

    base_radius = {
        'apo-ferritin': 60,
        'beta-amylase': 65,
        'beta-galactosidase': 90,
        'ribosome': 150,
        'thyroglobulin': 130,
        'virus-like-particle': 135,
    }

    weights = {
        'apo-ferritin': 1,
        'beta-amylase': 0,
        'beta-galactosidase': 2,
        'ribosome': 1,
        'thyroglobulin': 2,
        'virus-like-particle': 1,
    }

    # Matching distance per particle type.
    particle_radius = {name: radius * distance_multiplier for name, radius in base_radius.items()}

    # Keep only submission rows of experiments present in the solution split.
    split_experiments = set(solution['experiment'].unique())
    submission = submission.loc[submission['experiment'].isin(split_experiments)]

    # Only allow known particle types.
    submitted_types = set(submission['particle_type'].unique())
    if not submitted_types.issubset(set(weights.keys())):
        raise ParticipantVisibleError('Unrecognized `particle_type`.')

    assert solution.duplicated(subset=['experiment', 'x', 'y', 'z']).sum() == 0
    assert particle_radius.keys() == weights.keys()

    solution_types = solution['particle_type'].unique()
    results = {
        particle_type: {'total_tp': 0, 'total_fp': 0, 'total_fn': 0}
        for particle_type in solution_types
    }

    coordinate_columns = ['x', 'y', 'z']
    for experiment in split_experiments:
        in_solution = solution['experiment'] == experiment
        in_submission = submission['experiment'] == experiment
        for particle_type in solution_types:
            reference_radius = particle_radius[particle_type]
            reference_points = solution.loc[
                in_solution & (solution['particle_type'] == particle_type), coordinate_columns
            ].values
            candidate_points = submission.loc[
                in_submission & (submission['particle_type'] == particle_type), coordinate_columns
            ].values

            # Empty selections are replaced by empty 1-D arrays before counting.
            if len(reference_points) == 0:
                reference_points = np.array([])
                reference_radius = 1

            if len(candidate_points) == 0:
                candidate_points = np.array([])

            tp, fp, fn = compute_metrics(reference_points, reference_radius, candidate_points)

            totals = results[particle_type]
            totals['total_tp'] += tp
            totals['total_fp'] += fp
            totals['total_fn'] += fn

    aggregate_fbeta = 0.0
    beta_squared = beta**2
    for particle_type, totals in results.items():
        tp = totals['total_tp']
        fp = totals['total_fp']
        fn = totals['total_fn']

        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        if (precision + recall) > 0:
            fbeta = (1 + beta_squared) * (precision * recall) / (beta_squared * precision + recall)
        else:
            fbeta = 0.0
        aggregate_fbeta += fbeta * weights.get(particle_type, 1.0)

    # `weights` is a non-empty literal, so the average is always taken over
    # the sum of all weights.
    return aggregate_fbeta / sum(weights.values())

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
class FBetaLoss(nn.Module):
    """Soft F-beta loss over the foreground classes of a segmentation.

    Softmax probabilities are used as soft counts: per sample and class, TP,
    FP and FN are summed over all spatial positions, turned into an F-beta
    score, weight-averaged over the classes and averaged over the batch. The
    loss is one minus that value.

    Args:
        beta: F-beta parameter; values above 1 weight recall more than
            precision.
        eps: Small constant added to denominators to avoid division by zero.
        class_weights: Mapping ``name -> weight`` for the foreground classes.
            Only the values are used, in insertion order, so the order must
            follow class indices ``1 .. enzyme_classes``. Defaults to the five
            particle weights 1, 2, 1, 2, 1.
        enzyme_classes: Number of foreground classes; channels
            ``1 .. enzyme_classes`` of the prediction are scored and channel 0
            (background) is ignored.
    """

    def __init__(self, beta=4, eps=1e-6, class_weights=None, enzyme_classes=5):
        super().__init__()
        self.beta = beta
        self.eps = eps
        self.enzyme_classes = enzyme_classes

        self.class_weights = class_weights or {
            'apo-ferritin': 1,
            'beta-galactosidase': 2,
            'ribosome': 1,
            'thyroglobulin': 2,
            'virus-like-particle': 1,
        }
        self.weight_tensor = torch.tensor(
            list(self.class_weights.values()),
            dtype=torch.float32
        )

    def forward(self, y_pred, y_true):
        """Return the scalar loss.

        Args:
            y_pred: Logits of shape ``(batch, classes, *spatial)``.
            y_true: Integer labels of shape ``(batch, 1, *spatial)``.
        """
        # Remove channel dim & ensure integer type
        y_true = y_true.squeeze(1).long()

        # Take the one-hot width from the prediction instead of a fixed 6, so
        # the loss also works for a different number of classes.
        num_classes = y_pred.shape[1]
        foreground = slice(1, self.enzyme_classes + 1)
        # Every axis after (batch, class) is spatial.
        spatial_dims = tuple(range(2, y_pred.dim()))

        y_pred_probs = F.softmax(y_pred, dim=1)
        y_pred_enzymes = y_pred_probs[:, foreground, ...]

        # one_hot puts the class axis last; move it next to the batch axis.
        y_true_onehot = F.one_hot(y_true, num_classes=num_classes).movedim(-1, 1).float()
        y_true_enzymes = y_true_onehot[:, foreground, ...]

        tp = (y_pred_enzymes * y_true_enzymes).sum(dim=spatial_dims)
        fp = (y_pred_enzymes * (1 - y_true_enzymes)).sum(dim=spatial_dims)
        fn = ((1 - y_pred_enzymes) * y_true_enzymes).sum(dim=spatial_dims)

        precision = tp / (tp + fp + self.eps)
        recall = tp / (tp + fn + self.eps)
        beta2 = self.beta ** 2
        f_beta = (1 + beta2) * (precision * recall) / (beta2 * precision + recall + self.eps)

        # The weights are a plain attribute, so move them to the prediction's
        # device on every call.
        weight_tensor = self.weight_tensor.to(y_pred.device)
        weighted_f_beta = f_beta * weight_tensor
        aggregate_f_beta = weighted_f_beta.sum(dim=1) / weight_tensor.sum()

        return 1 - aggregate_f_beta.mean()

In [None]:
# F-beta loss with its default settings (beta=4, five weighted foreground classes).
fbeta_loss_function = FBetaLoss()

In [None]:
class ParticleTverskyCrossEntropyLoss(nn.Module):
    """Weighted cross-entropy plus a class-weighted Tversky loss for 3-D volumes."""

    def __init__(self, particle_weights, alpha=16/17, beta=1/17, ce_weight=1.0, tversky_weight=1.0, smooth=1e-6):
        """
        Args:
            particle_weights (torch.Tensor): One weight per class, background
                first, e.g. [0.0, 1.0, 2.0, 1.0, 2.0, 1.0]. Used both as the
                cross-entropy class weights and as the Tversky class weights.
            alpha: Factor applied to the false positives in the Tversky
                denominator.
            beta: Factor applied to the false negatives in the Tversky
                denominator.
            ce_weight: Factor applied to the cross-entropy term.
            tversky_weight: Factor applied to the Tversky term.
            smooth: Constant added to numerator and denominator of the index.

        NOTE(review): with the defaults, false positives are penalised 16x
        more than false negatives, which favours precision. If the intent was
        to emphasise recall (as an F-beta with beta=4 does), the two defaults
        look swapped — confirm before changing them.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth
        self.ce_weight = ce_weight
        self.tversky_weight = tversky_weight

        # The same weight vector drives both terms.
        self.ce = nn.CrossEntropyLoss(weight=particle_weights)
        self.register_buffer('class_weights', particle_weights)

    def forward(self, inputs, targets):
        """Return ``ce_weight * CE + tversky_weight * Tversky``.

        Args:
            inputs: Logits of shape ``(B, C, D, H, W)``.
            targets: Integer labels of shape ``(B, 1, D, H, W)``.
        """
        labels = targets.squeeze(1).long()
        ce_loss = self.ce(inputs, labels)

        n_classes = inputs.shape[1]

        # Class probabilities and one-hot targets, spatial axes flattened: [B, C, D*H*W]
        probs = F.softmax(inputs, dim=1)
        probs_flat = probs.view(probs.size(0), n_classes, -1)
        onehot = F.one_hot(labels, n_classes).permute(0, 4, 1, 2, 3).float()
        targets_flat = onehot.view(onehot.size(0), n_classes, -1)

        # Soft counts per sample and class: [B, C]
        true_pos = (probs_flat * targets_flat).sum(dim=2)
        false_pos = (probs_flat * (1 - targets_flat)).sum(dim=2)
        false_neg = ((1 - probs_flat) * targets_flat).sum(dim=2)

        denominator = true_pos + self.alpha*false_pos + self.beta*false_neg + self.smooth
        tversky_index = (true_pos + self.smooth) / denominator

        # Weight each class, then average over all batch/class entries
        # (zero-weight classes still count in the divisor).
        tversky_loss = ((1 - tversky_index) * self.class_weights).mean()

        return self.ce_weight * ce_loss + self.tversky_weight * tversky_loss

In [None]:
# Class order: [background, apo-ferritin, beta-galactosidase, ribosome, thyroglobulin, virus-like-particle]
# Created on `device` (CUDA when available, otherwise CPU) rather than on a
# hard-coded 'cuda', so the cell also runs on a machine without a GPU.
particle_weights = torch.tensor([
    0.0,  # Background (weight=0)
    1.0,  # apo-ferritin
    2.0,  # beta-galactosidase (higher weight)
    1.0,  # ribosome
    2.0,  # thyroglobulin (higher weight)
    1.0   # virus-like-particle
], device=device)

# Combined loss: both terms enter with factor 1.0.
wtversky_ce_loss_function = ParticleTverskyCrossEntropyLoss(
    particle_weights=particle_weights,
    ce_weight=1.0,
    tversky_weight=1.0
)

In [None]:
#loss_function = DiceLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
# NOTE(review): in MONAI's TverskyLoss, alpha weights false positives and beta
# weights false negatives, so alpha=16/17 penalises false positives more; if
# recall should be emphasised, the two values look swapped — confirm.
tversky_loss_function = TverskyLoss(include_background=False, to_onehot_y=True, softmax=True, alpha=16/17, beta=1/17)  # softmax=True for multiclass
dice_metric = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)  # must use onehot for multiclass
# MONAI reduction names are lower-case; "None" is not an accepted option and
# would only fail later, when the metric is aggregated.
recall_metric = ConfusionMatrixMetric(include_background=False, metric_name="recall", reduction="none")

In [None]:
# Post-processing: predictions are arg-maxed and one-hot encoded to
# num_classes channels; labels are one-hot encoded only.
post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
post_label = AsDiscrete(to_onehot=num_classes)

In [None]:
from typing import List, Tuple, Union
import cc3d


def extract_3d_patches_minimal_overlap(arrays: List[np.ndarray], patch_size: int) -> Tuple[List[np.ndarray], List[Tuple[int, int, int]]]:
    """Cut equally shaped 3-D arrays into cubic patches covering the whole volume.

    Start offsets along each axis come from ``calculate_patch_starts``.

    Args:
        arrays: Non-empty list of 3-D arrays that all share one shape.
        patch_size: Edge length of the cubic patches; may not exceed the
            smallest array dimension.

    Returns:
        ``(patches, coordinates)``: the patches of every array in turn and the
        ``(x, y, z)`` start offset of each patch. The coordinate sequence is
        repeated once per input array.

    Raises:
        ValueError: If ``arrays`` is empty or not a list, the shapes differ, or
            ``patch_size`` exceeds the smallest dimension.
    """
    if not arrays or not isinstance(arrays, list):
        raise ValueError("Input must be a non-empty list of arrays")

    # Every array must match the shape of the first one.
    shape = arrays[0].shape
    if any(arr.shape != shape for arr in arrays):
        raise ValueError("All input arrays must have the same shape")

    if patch_size > min(shape):
        raise ValueError(f"patch_size ({patch_size}) must be smaller than smallest dimension {min(shape)}")

    m, n, l = shape

    # Start offsets per axis, then every combination of them in x-y-z order.
    x_starts = calculate_patch_starts(m, patch_size)
    y_starts = calculate_patch_starts(n, patch_size)
    z_starts = calculate_patch_starts(l, patch_size)
    corners = [(x, y, z) for x in x_starts for y in y_starts for z in z_starts]

    patches = []
    coordinates = []
    for arr in arrays:
        for x, y, z in corners:
            patches.append(arr[x:x + patch_size, y:y + patch_size, z:z + patch_size])
            coordinates.append((x, y, z))

    return patches, coordinates

    
def reconstruct_array(patches: List[np.ndarray], 
                     coordinates: List[Tuple[int, int, int]], 
                     original_shape: Tuple[int, int, int]) -> np.ndarray:
    """Reassemble a volume from patches and their start offsets.

    Args:
        patches: 3-D patches.
        coordinates: ``(x, y, z)`` start offset of each patch, in the same
            order as ``patches``.
        original_shape: Shape of the volume to rebuild.

    Returns:
        ``np.int64`` array of ``original_shape``. Voxels covered by no patch
        stay 0; where patches overlap, the later patch overwrites the earlier
        one (values are not averaged). An empty patch list yields all zeros.
    """
    reconstructed = np.zeros(original_shape, dtype=np.int64)

    for patch, (x, y, z) in zip(patches, coordinates):
        # Use each patch's own extent, so the function neither needs a first
        # patch to read the size from nor requires the patches to be cubic.
        size_x, size_y, size_z = patch.shape
        reconstructed[
            x:x + size_x,
            y:y + size_y,
            z:z + size_z
        ] = patch

    return reconstructed

    
def calculate_patch_starts(dimension_size: int, patch_size: int) -> List[int]:
    """Return evenly spread patch start offsets that cover one axis.

    The number of patches is ``ceil(dimension_size / patch_size)``; the excess
    length is spread as overlap between neighbouring patches.

    Args:
        dimension_size: Length of the axis.
        patch_size: Length of one patch.

    Returns:
        Ascending start offsets without duplicates; ``[0]`` when a single
        patch covers the axis.
    """
    if dimension_size <= patch_size:
        return [0]

    patch_count = int(np.ceil(dimension_size / patch_size))
    if patch_count == 1:
        return [0]

    # Overlap between neighbouring patches and the resulting step width.
    overlap = (patch_count * patch_size - dimension_size) / (patch_count - 1)
    stride = patch_size - overlap
    last_start = dimension_size - patch_size

    starts = []
    for index in range(patch_count):
        # Never let a patch run past the end of the axis.
        start = int(index * stride)
        if start + patch_size > dimension_size:
            start = last_start
        if start not in starts:
            starts.append(start)

    return starts
    

def dict_to_df(coord_dict, experiment_name):
    """Flatten per-label coordinates into one table.

    Args:
        coord_dict: Mapping ``label -> coordinates``, where the coordinates
            are an ``(N, 3)`` array-like of ``x, y, z`` rows. Labels without
            coordinates are skipped.
        experiment_name: Value stored in the ``experiment`` column of every
            row.

    Returns:
        DataFrame with columns ``experiment``, ``particle_type``, ``x``, ``y``
        and ``z``, one row per coordinate. When there are no coordinates at
        all, an empty DataFrame with these columns is returned instead of
        raising.
    """
    columns = ['experiment', 'particle_type', 'x', 'y', 'z']

    coord_blocks = []
    all_labels = []
    for label, coords in coord_dict.items():
        coords = np.asarray(coords)
        if coords.size == 0:
            # Nothing found for this label; an empty block would also break
            # the stacking below when it has no column axis.
            continue
        coords = np.atleast_2d(coords)
        coord_blocks.append(coords)
        all_labels.extend([label] * len(coords))

    # np.vstack raises on an empty list, so handle "no detections" explicitly.
    if not coord_blocks:
        return pd.DataFrame(columns=columns)

    all_coords = np.vstack(coord_blocks)

    df = pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2]
    })

    return df

In [None]:
# Segmentation label -> particle name (the inverse of class_ids).
id_to_name = {1: "apo-ferritin", 
              #2: "beta-amylase",
              2: "beta-galactosidase", 
              3: "ribosome", 
              4: "thyroglobulin", 
              5: "virus-like-particle"}

In [None]:
# Foreground label values; background (0) is not listed.
classes = [1, 2, 3, 4, 5]

In [None]:
def ensemble_prediction_tta(model, input_tensor, threshold=0.05):
    """Predict a label map using flip-based test-time augmentation.

    The model is run four times: once on the input as given and once on a
    copy mirrored along each of dims 2, 3 and 4.  Every mirrored output is
    flipped back before use.  The first element of each output is turned into
    probabilities with a softmax over its leading dim, the four probability
    maps are averaged, and the average is compared against ``threshold``.

    Args:
        model: Callable mapping a tensor to a tensor of logits.
        input_tensor: Tensor with at least five dims (dims 2-4 are flipped).
            Only element ``[0]`` of each model output is used.
            # assumes layout (batch=1, channel, D, H, W) — TODO confirm
        threshold: Averaged probability a class has to exceed to count as
            present at a voxel.

    Returns:
        Integer tensor of class indices with the leading (class) dim reduced
        away.  The index comes from ``max`` over the boolean "exceeds
        threshold" mask, i.e. a class that passed the threshold, not
        necessarily the most probable one.
        # NOTE(review): per the torch.max docs ties resolve to the first
        # index, so the lowest passing class id wins — confirm this is the
        # intended selection rule.
    """
    # Un-augmented pass first, then one mirrored pass per flipped axis.
    # torch.flip already returns a new tensor, so no explicit clone is needed.
    probs_list = [torch.softmax(model(input_tensor)[0], dim=0)]
    for axis in (2, 3, 4):
        mirrored_output = model(torch.flip(input_tensor, dims=[axis]))
        # Undo the flip so the prediction lines up with the original input.
        restored_output = torch.flip(mirrored_output, dims=[axis])
        probs_list.append(torch.softmax(restored_output[0], dim=0))

    avg_probs = torch.mean(torch.stack(probs_list), dim=0)
    thresh_probs = avg_probs > threshold
    _, max_classes = thresh_probs.max(dim=0)

    return max_classes

In [None]:
# Preprocessing pipeline applied to each {"image": patch} dict before it is
# fed to the model for validation/inference.  The transforms run in the order
# listed and only touch the "image" key (there is no label at this stage).
inference_transforms = Compose([
    # The raw patches carry no channel axis ("no_channel"), so one is added.
    EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
    # Intensity normalisation with the transform's default settings.
    NormalizeIntensityd(keys="image"),
    # Re-orient the volume to the RAS axis convention.
    Orientationd(keys=["image"], axcodes="RAS")
])

In [None]:
# Lookup table from the loss name sampled during tuning to the callable that
# implements it.
loss_function_dict = dict(
    fbeta=fbeta_loss_function,
    wtversky_ce=wtversky_ce_loss_function,
    tversky=tversky_loss_function,
)

In [None]:
def objective(trial: optuna.Trial) -> float:
    """Optuna objective: train a 3D UNet for a fixed budget and score it.

    One trial samples the learning rate, the widths of the two deepest UNet
    levels, the last stride, the number of residual units and the loss
    function.  The model is trained with mixed precision for 15 epochs on
    ``train_loader``; after every epoch each distinct validation experiment in
    ``val_files_copy`` is segmented patch-wise with flip TTA, particle
    centroids are extracted from the predicted mask and the result is scored
    against the ground-truth picks.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The best validation F-beta score seen across the epochs of this trial.
    """
    learning_rate = trial.suggest_float('learning_rate', 3e-4, 3e-3)
    third_channel = trial.suggest_int('third_channel', 80, 128)
    fourth_channel = trial.suggest_int('fourth_channel', 80, 128)
    last_stride = trial.suggest_int('last_stride', 1, 2)
    num_res_units = trial.suggest_int('num_res_units', 1, 4)
    loss_function_name = trial.suggest_categorical('loss_function', ['fbeta', 'wtversky_ce', 'tversky'])
    loss_function = loss_function_dict[loss_function_name]

    # Build the UNet and its Adam optimizer from the sampled hyperparameters.
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=num_classes,
        channels=(48, 64, third_channel, fourth_channel),
        strides=(2, 2, last_stride),
        num_res_units=num_res_units,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), learning_rate)

    model.to('cuda')
    model = torch.compile(model)

    torch.set_float32_matmul_precision('high')

    # Mixed-precision training: fp16 autocast plus gradient scaling.
    scaler = torch.amp.GradScaler('cuda')
    autocast_dtype = torch.float16

    best_val_score = 0

    max_epochs = 15

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")

        # ---- training ----
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)

            optimizer.zero_grad()

            with torch.autocast(device_type='cuda', dtype=autocast_dtype):
                outputs = model(inputs)
                # Use the loss sampled for this trial; hard-coding one loss
                # here would make the 'loss_function' hyperparameter a no-op.
                loss = loss_function(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # ---- validation ----
        model.eval()
        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=autocast_dtype):

                solution_dfs_list = []
                submission_dfs_list = []
                seen_experiments = set()

                for file_idx in range(len(val_files_copy)):
                    val_data = val_files_copy[file_idx]

                    # Score every experiment once, even if it occurs in
                    # several validation entries.
                    exp_name = val_data['label_df']['experiment'][0]
                    if exp_name in seen_experiments:
                        continue
                    seen_experiments.add(exp_name)

                    # Cut the tomogram into 96-voxel patches and preprocess.
                    tomo = val_data['image']
                    tomo_patches, coordinates = extract_3d_patches_minimal_overlap([tomo], 96)
                    tomo_patched_data = [{"image": img} for img in tomo_patches]
                    tomo_ds = CacheDataset(data=tomo_patched_data, transform=inference_transforms, cache_rate=1.0, progress=False)

                    pred_masks = []
                    for patch_idx in range(len(tomo_ds)):
                        input_tensor = tomo_ds[patch_idx]['image'].unsqueeze(0).to("cuda")
                        max_classes = ensemble_prediction_tta(model, input_tensor, threshold=0.05)
                        pred_masks.append(max_classes.cpu().numpy())

                    # Stitch the per-patch label maps back to tomogram size.
                    reconstructed_mask = reconstruct_array(pred_masks, coordinates, tomo.shape)

                    location = {}
                    for c in classes:
                        cc = cc3d.connected_components(reconstructed_mask == c)
                        stats = cc3d.statistics(cc)
                        # Entry 0 of the statistics is skipped (background
                        # component); components of 255 voxels or fewer are
                        # discarded.
                        zyx = stats['centroids'][1:]
                        zyx_large = zyx[stats['voxel_counts'][1:] > 255]
                        # Reverse the axis order of the centroids (zyx -> xyz).
                        xyz = np.ascontiguousarray(zyx_large[:, ::-1])
                        location[id_to_name[c]] = xyz
                    df = dict_to_df(location, exp_name)

                    solution_df = val_data['label_df']

                    solution_dfs_list.append(solution_df)
                    submission_dfs_list.append(df)

                # ignore_index=True already yields a fresh 0..n-1 index, which
                # is then exposed as the 'id' column.
                solution_concat_df = pd.concat(solution_dfs_list, ignore_index=True).reset_index().rename(columns={'index':'id'})[['id', 'experiment', 'particle_type', 'x', 'y', 'z']]
                submission_concat_df = pd.concat(submission_dfs_list, ignore_index=True).reset_index().rename(columns={'index':'id'})

                # NOTE(review): the positional arguments are presumably the
                # id column name, the distance multiplier and beta — confirm
                # against the definition of score().
                val_fbeta_score = score(solution_concat_df, submission_concat_df, 'id', 0.1*0.5, 4)

                print(f'Epoch {epoch+1} validation F beta score: ', val_fbeta_score)

                if val_fbeta_score > best_val_score:
                    best_val_score = val_fbeta_score

    return best_val_score

In [None]:
# Run the hyperparameter search: `optuna_n_trials` trials of `objective`,
# stored in a study whose name is tagged with this notebook's run id.
notebook_number = '20250130_03'
optuna_n_trials = 50

# The tqdm bar counts finished trials; it is advanced from an Optuna callback.
with tqdm(total=optuna_n_trials, desc="Optimizing", unit="trial") as pbar:
    
    # Define a callback function to update the progress bar
    def progress_bar_callback(study, trial):
        pbar.update(1)

    # The objective returns a validation score, hence direction="maximize".
    # The sampler is OptunaHub's AutoSampler and the study is persisted in a
    # local SQLite database.
    # NOTE(review): load_if_exists is not passed, so re-running this cell with
    # the same notebook_number presumably fails because the study already
    # exists — confirm and bump notebook_number or resume as needed.
    study = optuna.create_study(
        direction="maximize",
        sampler=optunahub.load_module("samplers/auto_sampler").AutoSampler(),
        storage="sqlite:////home/max1024/Python Notebooks/czii/optuna_study/db.sqlite3",
        study_name=f"czii_3D_UNet_param_tune_{notebook_number}"
    )
    study.optimize(objective, n_trials=optuna_n_trials, callbacks=[progress_bar_callback])

# References

1. https://www.kaggle.com/code/hideyukizushi/czii-yolo11-unet3d-monai-lb-707