In [None]:
!pip install torch numpy pandas zarr tqdm scikit-learn matplotlib json

In [1]:
import os
import json
import random
import math
import shutil
import logging
from typing import List, Tuple, Dict, Optional
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.cuda.amp import GradScaler, autocast
import torch.utils.checkpoint
from torch.nn.parallel import DataParallel

from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    OneCycleLR
)

import zarr
from zarr.storage import DirectoryStore

from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook

import warnings
# Silence every warning category for the rest of the session. This also hides
# deprecation notices emitted by the libraries imported above.
warnings.filterwarnings('ignore')

# Root logger: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Display preferences only: they change how arrays and frames are printed,
# not the underlying values.
np.set_printoptions(precision=4, suppress=True)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)
pd.set_option('display.float_format', lambda x: '%.3f' % x)

from typing import (
    Dict,
    List,
    Tuple,
    Optional,
    Union,
    Any,
    Callable,
    Iterator,
    Sequence,
    TypeVar,
    Generic,
    Protocol,
    runtime_checkable
)

from dataclasses import dataclass
from pathlib import Path

# Start from a clean allocator state so later memory statistics describe this
# run only. The reset_* calls need an initialised CUDA context and raise on a
# machine without a usable GPU, so the group is skipped in that case.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()

# Shared type aliases.
PathLike = Union[str, Path]              # filesystem location as text or Path
JsonDict = Dict[str, Any]                # string-keyed mapping of arbitrary values
ModelOutput = Dict[str, torch.Tensor]    # mapping from output name to tensor

def set_seed(seed=42):
    """Seed every random number generator used here and pin cuDNN behaviour.

    Args:
        seed: Value handed to Python's ``random``, NumPy's global generator,
            PyTorch's CPU generator, and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and switch off the benchmark
    # autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def process_json_files(base_dir, experiments, particle_types, output_dir):
    """Flatten pick annotations into plain coordinate lists, one file per pair.

    For every (experiment, particle type) pair except "beta-amylase", reads
    ``<base_dir>/train/overlay/ExperimentRuns/<experiment>/Picks/<type>.json``
    and writes ``<output_dir>/<experiment>_<type>_processed.json`` holding a
    list of ``{"x", "y", "z"}`` dicts copied from each point's ``location``.

    Inputs that do not exist, or that lack a top-level "points" key, are
    reported on stdout and skipped. ``output_dir`` is created if needed.
    """
    os.makedirs(output_dir, exist_ok=True)

    for run_name in experiments:
        for ptype in particle_types:
            if ptype == "beta-amylase":
                continue

            src = os.path.join(
                base_dir,
                f"train/overlay/ExperimentRuns/{run_name}/Picks/{ptype}.json"
            )
            dst = os.path.join(output_dir, f"{run_name}_{ptype}_processed.json")

            if not os.path.exists(src):
                print(f"File not found: {src}")
                continue

            with open(src, 'r') as handle:
                payload = json.load(handle)

            if "points" not in payload:
                print(f"No points found in {src}")
                continue

            # Keep only the three location components of every pick.
            coords = []
            for point in payload["points"]:
                location = point["location"]
                coords.append({
                    "x": location["x"],
                    "y": location["y"],
                    "z": location["z"],
                })

            with open(dst, 'w') as handle:
                json.dump(coords, handle, indent=2)
            print(f"Processed {src} -> {dst}")

class CryoET3DDataset(Dataset):
    """Fixed-size 3D patches sampled on a regular grid from one Zarr tomogram.

    Each item is ``(patch, target)``. ``patch`` is a float32 tensor of shape
    ``(1, patch_d, patch_h, patch_w)``; ``target`` is a dict with:

    * ``"class"``: 1 when at least one label lies inside the patch, else 0.
    * ``"coords"``: ``(x, y, z)`` of the label closest to the patch centre,
      in patch-local voxels or, with ``normalize_coords``, as a fraction of
      the patch extent along each axis; zeros for negative patches.
    * ``"particle_type"``: one-hot vector over ``self.particle_types`` (all
      zeros when ``particle_type`` is not in that list).
    * ``"offset"``: ``(z0, y0, x0)`` origin of the patch inside the volume.

    Args:
        zarr_path: Directory of the Zarr store; array ``"0"`` is read.
        labels: List of ``{"x", "y", "z"}`` dicts in raw units. They are
            converted with ``(raw - origin_offset) / voxel_size`` and labels
            outside the volume are dropped.
        particle_type: Name used for the one-hot ``"particle_type"`` target.
        patch_size: ``(depth, height, width)`` of every patch, in voxels.
        stride_3d: Grid step ``(depth, height, width)`` between patch origins.
        skip_negatives_ratio: Target share of positive patches; the number of
            negatives kept is ``positives * (1 / ratio - 1)``.
        max_patches: Upper bound on the number of patches; ``None`` disables
            the bound.
        normalize_coords: Divide coordinate targets by the patch size.
        voxel_size: Raw units per voxel.
        origin_offset: ``(x, y, z)`` offset subtracted from raw coordinates.
        augment: Randomly mirror each patch along every axis (the coordinate
            target is mirrored with it).
    """
    def __init__(
        self,
        zarr_path,
        labels=None,
        particle_type=None,
        patch_size=(64, 128, 128),
        stride_3d=(32, 64, 64),
        skip_negatives_ratio=0.7,
        max_patches=4000,
        normalize_coords=True,
        voxel_size=10.0,
        origin_offset=(0, 0, 0),
        augment=True
    ):
        super().__init__()
        self.zarr_path = zarr_path
        self.labels = labels if labels else []
        self.particle_type = particle_type
        self.patch_d, self.patch_h, self.patch_w = patch_size
        self.stride_d, self.stride_h, self.stride_w = stride_3d
        self.skip_negatives_ratio = skip_negatives_ratio
        self.max_patches = max_patches
        self.normalize_coords = normalize_coords
        self.voxel_size = voxel_size
        self.origin_offset = origin_offset
        self.augment = augment

        # Order defines the index of each type in the one-hot target.
        self.particle_types = [
            "apo-ferritin",
            "beta-galactosidase",
            "ribosome",
            "thyroglobulin",
            "virus-like-particle"
        ]
        self.particle_type_to_idx = {
            ptype: idx for idx, ptype in enumerate(self.particle_types)
        }

        self.volume = self._load_and_normalize_zarr()
        self.D, self.H, self.W = self.volume.shape

        # Labels must be in voxel space before patches are classified.
        self._normalize_labels()
        self.patch_positions = self._build_balanced_patches()

    def _load_and_normalize_zarr(self):
        """Read array "0" fully into memory and rescale it to roughly [0, 1].

        The volume is z-scored, clipped to its 2nd-98th percentile range, and
        then min-max scaled over that range.
        """
        store = DirectoryStore(self.zarr_path)
        zf = zarr.open(store, mode='r')
        volume = zf["0"][:]

        mean = np.mean(volume)
        std = np.std(volume)
        volume = (volume - mean) / (std + 1e-6)

        p2, p98 = np.percentile(volume, (2, 98))
        volume = np.clip(volume, p2, p98)
        volume = (volume - p2) / (p98 - p2 + 1e-6)

        return volume

    def _normalize_labels(self):
        """Convert raw label coordinates to voxel space, dropping out-of-volume ones."""
        scaled = []
        for lbl in self.labels:
            rx, ry, rz = lbl["x"], lbl["y"], lbl["z"]
            vx = (rx - self.origin_offset[0]) / self.voxel_size
            vy = (ry - self.origin_offset[1]) / self.voxel_size
            vz = (rz - self.origin_offset[2]) / self.voxel_size

            if 0 <= vx < self.W and 0 <= vy < self.H and 0 <= vz < self.D:
                scaled.append({"x": vx, "y": vy, "z": vz})

        self.labels = scaled

    def _build_balanced_patches(self):
        """Enumerate grid positions and return a shuffled, balanced subset.

        With no positive position, negatives alone are sampled (all of them
        when ``max_patches`` is ``None``). Otherwise every positive is kept and
        negatives are sampled according to ``skip_negatives_ratio``.
        """
        positive_positions = []
        negative_positions = []

        for z in range(0, self.D - self.patch_d + 1, self.stride_d):
            for y in range(0, self.H - self.patch_h + 1, self.stride_h):
                for x in range(0, self.W - self.patch_w + 1, self.stride_w):
                    pos = (z, y, x)
                    if self._has_particle_in_patch(pos):
                        positive_positions.append(pos)
                    else:
                        negative_positions.append(pos)

        num_positives = len(positive_positions)
        if num_positives == 0:
            # max_patches may be None (no cap); min(None, n) would raise.
            limit = (
                len(negative_positions)
                if self.max_patches is None
                else self.max_patches
            )
            num_samples = min(limit, len(negative_positions))
            positions = random.sample(negative_positions, num_samples)
        else:
            neg_factor = (1 / max(self.skip_negatives_ratio, 1e-6) - 1)
            # Clamp at zero: a ratio above 1 yields a negative factor, which
            # random.sample would reject.
            num_negatives = max(0, min(
                int(num_positives * neg_factor),
                len(negative_positions)
            ))
            sampled_negatives = random.sample(negative_positions, num_negatives)
            positions = positive_positions + sampled_negatives

        random.shuffle(positions)
        if self.max_patches:
            positions = positions[:self.max_patches]

        return positions

    def _has_particle_in_patch(self, pos):
        """Return True when any label lies inside the patch starting at ``pos``."""
        z0, y0, x0 = pos
        z1, y1, x1 = z0 + self.patch_d, y0 + self.patch_h, x0 + self.patch_w

        return any(
            x0 <= lbl["x"] < x1 and y0 <= lbl["y"] < y1 and z0 <= lbl["z"] < z1
            for lbl in self.labels
        )

    def _create_particle_type_tensor(self):
        """Return the one-hot encoding of ``self.particle_type`` (zeros if unknown)."""
        encoding = torch.zeros(len(self.particle_types))
        if self.particle_type in self.particle_type_to_idx:
            encoding[self.particle_type_to_idx[self.particle_type]] = 1.0
        return encoding

    def __len__(self):
        return len(self.patch_positions)

    def __getitem__(self, idx):
        """Return the patch at ``idx`` and its target dict (see class docstring)."""
        z0, y0, x0 = self.patch_positions[idx]
        z1, y1, x1 = z0 + self.patch_d, y0 + self.patch_h, x0 + self.patch_w

        patch_data = self.volume[z0:z1, y0:y1, x0:x1]

        # Remember which axes were mirrored so the coordinate target can be
        # mirrored the same way further down.
        flip_z = flip_y = flip_x = False
        if self.augment:
            flip_z = random.random() < 0.5
            flip_y = random.random() < 0.5
            flip_x = random.random() < 0.5
            if flip_z:
                patch_data = patch_data[::-1, :, :]
            if flip_y:
                patch_data = patch_data[:, ::-1, :]
            if flip_x:
                patch_data = patch_data[:, :, ::-1]

        # Negative strides are not accepted by torch; materialise a copy.
        patch_data = patch_data.copy()

        patch_tensor = torch.tensor(patch_data, dtype=torch.float32).unsqueeze(0)

        cls_label = 0
        coord_label = torch.zeros(3, dtype=torch.float32)
        particle_type_label = self._create_particle_type_tensor()

        picks_in_patch = [
            (lbl["x"] - x0, lbl["y"] - y0, lbl["z"] - z0)
            for lbl in self.labels
            if x0 <= lbl["x"] < x1 and y0 <= lbl["y"] < y1 and z0 <= lbl["z"] < z1
        ]

        if picks_in_patch:
            # Any patch holding a particle is positive, matching
            # _has_particle_in_patch. With several picks the one nearest the
            # patch centre supplies the regression target.
            cls_label = 1
            cx, cy, cz = self.patch_w / 2, self.patch_h / 2, self.patch_d / 2
            px, py, pz = min(
                picks_in_patch,
                key=lambda p: (p[0] - cx) ** 2 + (p[1] - cy) ** 2 + (p[2] - cz) ** 2,
            )

            # Voxel i covers [i, i + 1) under the membership test above, so
            # mirroring an axis of length n maps position p to n - p.
            if flip_x:
                px = self.patch_w - px
            if flip_y:
                py = self.patch_h - py
            if flip_z:
                pz = self.patch_d - pz

            if self.normalize_coords:
                coord_label = torch.tensor(
                    [px / self.patch_w, py / self.patch_h, pz / self.patch_d],
                    dtype=torch.float32,
                )
            else:
                coord_label = torch.tensor([px, py, pz], dtype=torch.float32)

        return patch_tensor, {
            "class": torch.tensor(cls_label, dtype=torch.long),
            "coords": coord_label,
            "particle_type": particle_type_label,
            "offset": torch.tensor([z0, y0, x0], dtype=torch.long)
        }


def build_datasets(base_dir, particle_types, experiments, patch_size, stride_3d, 
                  max_patches, output_dir, val_split=0.2):
    """Build training and validation datasets split by experiment.

    The last ``int(len(experiments) * val_split)`` experiments form the
    validation split. When ``val_split`` is positive and more than one
    experiment is given, at least one is held out. For every (experiment,
    particle type) pair except "beta-amylase" whose processed label file
    ``<output_dir>/<experiment>_<type>_processed.json`` exists, one
    ``CryoET3DDataset`` is created; pairs that fail to build are reported on
    stdout and skipped.

    Args:
        base_dir: Root directory containing ``train/static/ExperimentRuns``.
        particle_types: Particle type names to look for.
        experiments: Experiment names, in the order used for the split.
        patch_size: ``(depth, height, width)`` forwarded to the dataset.
        stride_3d: Grid stride forwarded to the dataset.
        max_patches: Per-dataset patch cap forwarded to the dataset.
        output_dir: Directory holding the processed label JSON files.
        val_split: Fraction of experiments held out for validation.

    Returns:
        ``(train_dataset, val_dataset)``, each a ``ConcatDataset``.

    Raises:
        ValueError: If either split ends up with no dataset.
    """
    experiments = list(experiments)
    num_val = int(len(experiments) * val_split)
    if val_split > 0 and len(experiments) > 1:
        num_val = max(num_val, 1)

    # Split by index: slicing with `-num_val` breaks when num_val is 0
    # (`[:-0]` is empty and `[-0:]` is the whole list).
    split_at = len(experiments) - num_val
    train_experiments = experiments[:split_at]
    val_experiments = experiments[split_at:]

    def build_dataset(exps, augment, split_name):
        datasets = []
        for experiment in exps:
            zarr_path = os.path.join(
                base_dir,
                f"train/static/ExperimentRuns/{experiment}/VoxelSpacing10.000/denoised.zarr"
            )
            for ptype in particle_types:
                if ptype == "beta-amylase":
                    continue

                json_file = os.path.join(output_dir, f"{experiment}_{ptype}_processed.json")
                if not os.path.exists(json_file):
                    continue

                with open(json_file, 'r') as f:
                    labels = json.load(f)

                # Best effort: one unreadable volume must not abort the rest.
                try:
                    dataset = CryoET3DDataset(
                        zarr_path=zarr_path,
                        labels=labels,
                        particle_type=ptype,
                        patch_size=patch_size,
                        stride_3d=stride_3d,
                        skip_negatives_ratio=0.7,
                        max_patches=max_patches,
                        normalize_coords=True,
                        augment=augment
                    )
                except Exception as e:
                    print(f"Error building dataset for {experiment}-{ptype}: {e}")
                else:
                    datasets.append(dataset)

        if not datasets:
            raise ValueError(f"No valid datasets found for the {split_name} split")
        return ConcatDataset(datasets)

    train_dataset = build_dataset(train_experiments, True, "training")
    # Validation patches are left unaugmented so the metric is comparable
    # between epochs.
    val_dataset = build_dataset(val_experiments, False, "validation")

    return train_dataset, val_dataset

class SqueezeExcitation(nn.Module):
    """Channel gate: global average pool, two 1x1x1 convs, sigmoid, rescale.

    NOTE: a class with this same name and an equivalent body is declared again
    later in this file; once the file has run top to bottom the name refers to
    that later declaration.
    """
    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        # The bottleneck never shrinks below 8 channels.
        bottleneck = channels // reduction_ratio
        if bottleneck < 8:
            bottleneck = 8
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Conv3d(channels, bottleneck, 1)
        self.fc2 = nn.Conv3d(bottleneck, channels, 1)

    def forward(self, x):
        pooled = self.avg_pool(x)
        hidden = F.relu(self.fc1(pooled))
        gate = torch.sigmoid(self.fc2(hidden))
        # One weight per channel, broadcast over the spatial dimensions.
        return x * gate

class ResidualBlock3D(nn.Module):
    """Two 3x3x3 conv + batch-norm stages, an SE gate, and an identity shortcut.

    NOTE: an equivalent ``ResidualBlock3D`` is declared again later in this
    file and replaces this binding when the file runs top to bottom.
    """
    def __init__(self, channels, se_ratio=16):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)  # keeps spatial size
        self.conv1 = nn.Conv3d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(channels)
        self.conv2 = nn.Conv3d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(channels)
        self.se = SqueezeExcitation(channels, se_ratio)

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = F.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        y = self.se(y)
        # Shortcut is added before the final activation.
        y += x
        return F.relu(y)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class AxialAttention(nn.Module):
    """Multi-head self-attention applied separately along D, H, and W.

    Every position attends only to the positions that share its other two
    spatial indices, once per axis. The three results are summed.

    Args:
        dim: Channel size of the input (last dimension).
        num_heads: Number of attention heads; must divide ``dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        # Fail at construction rather than inside a reshape during forward.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv_d = nn.Linear(dim, dim * 3)
        self.qkv_h = nn.Linear(dim, dim * 3)
        self.qkv_w = nn.Linear(dim, dim * 3)

        self.proj_d = nn.Linear(dim, dim)
        self.proj_h = nn.Linear(dim, dim)
        self.proj_w = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return a tensor of the same ``(B, D, H, W, C)`` shape as ``x``."""
        assert len(x.shape) == 5, f"Expected 5D input (B, D, H, W, C), but got {x.shape}"
        C = x.shape[-1]

        def attend(tokens, qkv_proj, out_proj):
            # Attention runs along the second-to-last axis; every leading
            # axis is folded into the batch.
            layout = tokens.shape
            seq_len = layout[-2]
            tokens = tokens.reshape(-1, seq_len, C)
            qkv = qkv_proj(tokens).reshape(
                -1, seq_len, 3, self.num_heads, self.head_dim
            )
            # -> (3, batch, heads, seq_len, head_dim)
            q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.dropout(attn)

            mixed = (attn @ v).transpose(1, 2).reshape(-1, seq_len, C)
            return out_proj(mixed).reshape(layout)

        # Move the attended axis next to the channels, attend, then undo the
        # permutation. W is already in that position.
        x_d = attend(x.permute(0, 2, 3, 1, 4), self.qkv_d, self.proj_d)
        x_h = attend(x.permute(0, 1, 3, 2, 4), self.qkv_h, self.proj_h)
        x_w = attend(x, self.qkv_w, self.proj_w)

        return x_d.permute(0, 3, 1, 2, 4) + x_h.permute(0, 1, 3, 2, 4) + x_w


class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of a 3D feature map.

    Channels are summarised by global average pooling, passed through a
    two-layer 1x1x1 convolutional bottleneck, and the sigmoid of the result
    multiplies the input.
    """
    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        squeezed = max(8, channels // reduction_ratio)  # bottleneck floor of 8
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Conv3d(channels, squeezed, kernel_size=1)
        self.fc2 = nn.Conv3d(squeezed, channels, kernel_size=1)

    def forward(self, x):
        weights = self.fc1(self.avg_pool(x))
        weights = self.fc2(F.relu(weights))
        return x * torch.sigmoid(weights)


class ResidualBlock3D(nn.Module):
    """Residual unit: conv-BN-ReLU, conv-BN, SE gate, skip add, ReLU.

    Both convolutions are 3x3x3 with padding 1 and keep the channel count, so
    the output has the same shape as the input.
    """
    def __init__(self, channels, se_ratio=16):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(channels)
        self.se = SqueezeExcitation(channels, se_ratio)

    def forward(self, x):
        shortcut = x
        hidden = F.relu(self.bn1(self.conv1(x)))
        gated = self.se(self.bn2(self.conv2(hidden)))
        gated += shortcut
        return F.relu(gated)

class ImprovedFeaturePyramid(nn.Module):
    """Three-level bottom-up / top-down feature pyramid over a 3D volume.

    Bottom-up: a 1x1x1 stem, then three stages of residual block followed by a
    stride-2 convolution; the stem output and each stage output pass through
    their own SE gate. Top-down: starting from the coarsest gated feature,
    each step upsamples by 2, refines with a residual block, and adds the 1x1x1
    lateral projection of the matching finer feature.

    ``forward`` returns four tensors ordered coarsest to finest.
    """
    def __init__(self, in_dim=1, dim=512):
        super().__init__()
        num_levels = 3
        self.init_conv = nn.Conv3d(in_dim, dim, kernel_size=1)

        down_stages = []
        for _ in range(num_levels):
            down_stages.append(nn.Sequential(
                ResidualBlock3D(dim),
                nn.Conv3d(dim, dim, kernel_size=2, stride=2),
            ))
        self.down = nn.ModuleList(down_stages)

        self.lateral = nn.ModuleList(
            [nn.Conv3d(dim, dim, kernel_size=1) for _ in range(num_levels)]
        )

        up_stages = []
        for _ in range(num_levels):
            up_stages.append(nn.Sequential(
                nn.ConvTranspose3d(dim, dim, kernel_size=2, stride=2),
                ResidualBlock3D(dim),
            ))
        self.up = nn.ModuleList(up_stages)

        # One gate for the stem output plus one per downsampling stage.
        self.se = nn.ModuleList(
            [SqueezeExcitation(dim) for _ in range(num_levels + 1)]
        )

    def forward(self, x):
        x = self.init_conv(x)
        features = [self.se[0](x)]

        # The ungated tensor feeds the next stage; the gated copy is stored.
        for level, stage in enumerate(self.down, start=1):
            x = stage(x)
            features.append(self.se[level](x))

        # Lateral projections exist for all but the coarsest feature.
        laterals = [
            self.lateral[i](features[i]) for i in range(len(self.lateral))
        ]

        x = features[-1]
        results = [x]

        for stage, skip in zip(self.up, reversed(laterals)):
            x = stage(x)
            if x.shape != skip.shape:
                # Odd input sizes do not survive halve-then-double exactly.
                x = F.interpolate(x, size=skip.shape[2:], mode='trilinear', align_corners=False)
            x = x + skip
            results.append(x)

        return results


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: axial attention and an MLP, each residual.

    Args:
        dim: Channel size of the ``(B, D, H, W, C)`` input.
        num_heads: Attention heads passed to ``AxialAttention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability inside attention and the MLP.
        drop_path: Probability of dropping a residual branch for a sample
            during training (stochastic depth).
        use_checkpointing: Recompute branch activations in the backward pass
            while training, trading compute for memory.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1, drop_path=0.0, use_checkpointing=True):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.drop_path = drop_path
        
        self.norm1 = nn.LayerNorm(dim)
        self.attn = AxialAttention(dim, num_heads, dropout)
        
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        if self.use_checkpointing and self.training:
            # use_reentrant is passed explicitly, as the PyTorch documentation
            # recommends; False selects the non-reentrant implementation.
            x = x + self._drop_path(torch.utils.checkpoint.checkpoint(
                self.attn, self.norm1(x), use_reentrant=False))
            x = x + self._drop_path(torch.utils.checkpoint.checkpoint(
                self.mlp, self.norm2(x), use_reentrant=False))
        else:
            x = x + self._drop_path(self.attn(self.norm1(x)))
            x = x + self._drop_path(self.mlp(self.norm2(x)))
        return x
        
    def _drop_path(self, x):
        """Zero the whole branch output for a random subset of samples.

        Active only in training with ``drop_path > 0``. One Bernoulli draw is
        made per sample (first dimension) and broadcast over the remaining
        dimensions; kept samples are scaled by ``1 / keep_prob`` so the
        expected value is unchanged.
        """
        if self.drop_path > 0.0 and self.training:
            keep_prob = 1 - self.drop_path
            if keep_prob <= 0.0:
                # Everything is dropped; avoid dividing by zero.
                return torch.zeros_like(x)
            mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
            mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
            return x * (mask / keep_prob)
        return x

class ImprovedViT3D(nn.Module):
    """3D vision transformer with classification, coordinate, and type heads.

    The input volume is cut into non-overlapping patches by a strided
    convolution, processed by a stack of axial-attention transformer blocks,
    mean-pooled over all patch tokens, and fed to three heads.

    Args:
        patch_size: ``(depth, height, width)`` of each embedded patch.
        in_chans: Number of input channels.
        embed_dim: Token channel size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        num_classes: Size of the classification output.
        num_particle_types: Size of the particle-type output.
        dropout: Dropout probability used throughout.
        stochastic_depth_prob: Drop-path probability of the last block; the
            blocks before it ramp up linearly from 0.
        use_checkpointing: Forwarded to every transformer block.
    """
    def __init__(
        self,
        patch_size=(8, 16, 16),
        in_chans=1,
        embed_dim=512,
        depth=12,
        num_heads=16,
        mlp_ratio=4.0,
        num_classes=2,
        num_particle_types=5,
        dropout=0.1,
        stochastic_depth_prob=0.1,
        use_checkpointing=True
    ):
        super().__init__()

        # Not evaluated in forward(); kept as a submodule so parameter names
        # in state dicts stay the same.
        self.feature_pyramid = ImprovedFeaturePyramid(in_chans, embed_dim // 4)

        self.patch_embed = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(embed_dim)

        # NOTE(review): cls_token, pos_embed and pos_drop are created but never
        # read in forward(), so tokens carry no explicit position signal before
        # mean pooling. Confirm whether a positional embedding was intended.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 100, embed_dim)) 
        self.pos_drop = nn.Dropout(dropout)

        # Linearly increasing drop-path probability, one value per block.
        dpr = [x.item() for x in torch.linspace(0, stochastic_depth_prob, depth)]

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                drop_path=dpr[i],
                use_checkpointing=use_checkpointing
            )
            for i in range(depth)
        ])

        self.final_norm = nn.LayerNorm(embed_dim)

        # Raw class scores (no activation).
        self.class_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes)
        )

        # Three values squashed to (0, 1) by the final Sigmoid.
        self.coord_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, 3),
            nn.Sigmoid()
        )

        # Per-type values in (0, 1): the final Sigmoid means this head does
        # not emit logits.
        self.particle_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_particle_types),
            nn.Sigmoid()
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the model on a ``[B, C, D, H, W]`` batch.

        Returns:
            Dict with ``'logits'`` ``(B, num_classes)``, ``'coords'``
            ``(B, 3)`` and ``'particle_types'`` ``(B, num_particle_types)``.
        """
        # The feature pyramid used to be evaluated here with its result thrown
        # away. It contributed nothing to the outputs while running residual
        # 3D convolutions on the full-resolution input, so the call is gone.
        x = self.patch_embed(x)

        # AxialAttention expects channels last: [B, D, H, W, C].
        x = x.permute(0, 2, 3, 4, 1)
        x = self.norm(x)

        for block in self.blocks:
            x = block(x)

        x = self.final_norm(x)

        # Average over all patch tokens to get one vector per sample.
        x = x.mean(dim=(1, 2, 3))

        return {
            'logits': self.class_head(x),
            'coords': self.coord_head(x),
            'particle_types': self.particle_head(x)
        }




class FocalLoss(torch.nn.Module):
    """Focal loss over class logits, for imbalanced classification.

    Each sample's cross-entropy is scaled by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the predicted probability of the true class, so confidently
    correct samples contribute less. The result is the batch mean.

    Args:
        alpha: Optional per-class weights (tensor or plain sequence) indexed
            by the target class.
        gamma: Focusing exponent; 0 reduces to (weighted) cross-entropy.
    """
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, inputs, targets):
        """Return the mean focal loss for ``inputs`` (N, C) and ``targets`` (N,)."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.alpha is not None:
            # Accept lists/tuples and follow the device of the batch, so the
            # weights need not be moved by the caller.
            alpha = torch.as_tensor(
                self.alpha, dtype=focal_loss.dtype, device=focal_loss.device
            )
            focal_loss = alpha[targets] * focal_loss
        return focal_loss.mean()

def _detection_loss(outputs, class_labels, coord_labels, particle_type_labels,
                    focal_loss, coord_loss):
    """Return the weighted sum of class, coordinate, and particle-type losses.

    Coordinate and particle-type terms are computed on positive samples only
    (``class_labels == 1``) and are zero when the batch has none.
    """
    class_loss = focal_loss(outputs['logits'], class_labels)

    pos_mask = (class_labels == 1)
    if pos_mask.any():
        coord_loss_val = coord_loss(outputs['coords'][pos_mask], coord_labels[pos_mask])
        # outputs['particle_types'] is treated as probabilities (the model in
        # this file ends that head with a Sigmoid), so plain binary
        # cross-entropy is the matching loss; a *WithLogits loss would apply a
        # second sigmoid. binary_cross_entropy is not allowed under autocast,
        # hence the float32 region.
        with autocast(enabled=False):
            particle_loss_val = F.binary_cross_entropy(
                outputs['particle_types'][pos_mask].float(),
                particle_type_labels[pos_mask].float(),
            )
    else:
        coord_loss_val = torch.zeros((), device=class_loss.device)
        particle_loss_val = torch.zeros((), device=class_loss.device)

    return class_loss + 2.0 * coord_loss_val + 0.5 * particle_loss_val


def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=15,
    save_dir='checkpoints',
    model_id=0
):
    """Train ``model`` with AdamW + OneCycleLR and keep the best checkpoint.

    Runs on ``cuda:0`` with mixed precision (wrapping the model in
    ``DataParallel`` when several GPUs are visible) and falls back to the CPU
    in full precision when CUDA is unavailable. After every epoch the model is
    evaluated on ``val_loader``; the checkpoint with the lowest validation
    loss is written to ``<save_dir>/model_<model_id>_best.pth``. Training
    stops early after 5 epochs without improvement.

    Args:
        model: Module returning a dict with ``'logits'``, ``'coords'`` and
            ``'particle_types'`` (the latter as probabilities).
        train_loader: Iterable of ``(volumes, meta)`` batches where ``meta``
            has ``'class'``, ``'coords'`` and ``'particle_type'``.
        val_loader: Same structure, used for validation.
        num_epochs: Maximum number of epochs.
        save_dir: Directory for checkpoints (created if missing).
        model_id: Integer/string used in the checkpoint file name.

    Returns:
        The trained model in its state after the last executed epoch (the
        ``DataParallel`` wrapper when one was applied).

    Raises:
        ValueError: If either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    os.makedirs(save_dir, exist_ok=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    model = model.to(device)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        print(f"Using {torch.cuda.device_count()} GPUs")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-4,
        weight_decay=0.05,
        betas=(0.9, 0.999)
    )

    # Stepped once per batch, hence steps_per_epoch.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=2e-4,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        div_factor=25,
        final_div_factor=1000
    )

    # Loss scaling and autocast are only meaningful on CUDA.
    scaler = GradScaler(enabled=use_cuda)

    class_weights = torch.tensor([1.0, 10.0]).to(device)
    focal_loss = FocalLoss(alpha=class_weights, gamma=2.0)
    coord_loss = nn.SmoothL1Loss()

    best_val_loss = float('inf')
    patience = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for batch_idx, (volumes, meta) in enumerate(train_loader):
            volumes = volumes.to(device)
            class_labels = meta['class'].to(device)
            coord_labels = meta['coords'].to(device)
            particle_type_labels = meta['particle_type'].to(device)
            optimizer.zero_grad()

            with autocast(enabled=use_cuda):
                outputs = model(volumes)
                loss = _detection_loss(
                    outputs, class_labels, coord_labels, particle_type_labels,
                    focal_loss, coord_loss,
                )

            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()

            total_loss += loss.item()

            if batch_idx % 50 == 0:
                print(f'Epoch: {epoch+1}/{num_epochs} | Batch: {batch_idx}/{len(train_loader)} | '
                      f'Loss: {loss.item():.4f}')

        avg_train_loss = total_loss / len(train_loader)

        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for volumes, meta in val_loader:
                volumes = volumes.to(device)
                class_labels = meta['class'].to(device)
                coord_labels = meta['coords'].to(device)
                particle_type_labels = meta['particle_type'].to(device)

                with autocast(enabled=use_cuda):
                    outputs = model(volumes)
                    loss = _detection_loss(
                        outputs, class_labels, coord_labels, particle_type_labels,
                        focal_loss, coord_loss,
                    )
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)

        print(f'\nEpoch {epoch+1} Summary:')
        print(f'Training Loss: {avg_train_loss:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f}')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience = 0
            save_path = os.path.join(save_dir, f'model_{model_id}_best.pth')
            # Save the underlying module so the keys carry no "module." prefix
            # and load into an unwrapped model.
            raw_model = model.module if isinstance(model, nn.DataParallel) else model
            torch.save({
                'epoch': epoch,
                'model_state_dict': raw_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': best_val_loss
            }, save_path)
            print(f'Saved best model to {save_path}')
        else:
            patience += 1
            if patience >= 5:
                print('Early stopping triggered')
                break

    return model

def validate_model(model, val_loader, device, focal_loss, coord_loss, particle_loss):
    """Return the mean combined loss over ``val_loader``, evaluated under autocast.

    The combined loss per batch is ``class + 2.0 * coords + 0.5 * particle``;
    the coordinate and particle terms use positive samples only and are zero
    for batches without any.

    NOTE: a function with this same name is defined again right after this
    one, so this definition is replaced when the file runs top to bottom.
    """
    model.eval()
    # Running sums in the order: total, class, coord, particle.
    sums = [0.0, 0.0, 0.0, 0.0]

    with torch.no_grad():
        for volumes, meta in val_loader:
            volumes = volumes.to(device)
            class_labels = meta["class"].to(device)
            coord_labels = meta["coords"].to(device)
            particle_type_labels = meta["particle_type"].to(device)

            with autocast():
                outputs = model(volumes)
                loss_class = focal_loss(outputs['logits'], class_labels)

                positives = (class_labels == 1)
                if not positives.any():
                    loss_coords = torch.tensor(0.0, device=device)
                    loss_particle = torch.tensor(0.0, device=device)
                else:
                    loss_coords = coord_loss(
                        outputs['coords'][positives], coord_labels[positives]
                    )
                    loss_particle = particle_loss(
                        outputs['particle_types'][positives],
                        particle_type_labels[positives],
                    )

                batch_loss = loss_class + 2.0 * loss_coords + 0.5 * loss_particle

            parts = (batch_loss, loss_class, loss_coords, loss_particle)
            for slot, value in enumerate(parts):
                sums[slot] += value.item()

    return sums[0] / len(val_loader)

def validate_model(model, val_loader, device, focal_loss, coord_loss, particle_loss):
    """Return the mean combined validation loss over ``val_loader``.

    For each batch the combined loss is ``class + 2.0 * coords + 0.5 *
    particle``. The coordinate and particle terms are evaluated on positive
    samples only (``class == 1``) and count as zero when a batch has none.
    Per-component sums are accumulated, but only the mean total is returned.
    """
    model.eval()
    running = dict.fromkeys(('total', 'class', 'coord', 'particle'), 0.0)

    with torch.no_grad():
        for volumes, meta in val_loader:
            volumes = volumes.to(device)
            class_labels = meta["class"].to(device)
            coord_labels = meta["coords"].to(device)

            outputs = model(volumes)
            loss_class = focal_loss(outputs['logits'], class_labels)

            has_particle = (class_labels == 1)
            if has_particle.any():
                predicted_xyz = outputs['coords'][has_particle]
                loss_coords = coord_loss(predicted_xyz, coord_labels[has_particle])

                # Particle-type targets are only moved when they are needed.
                type_targets = meta['particle_type'].to(device)[has_particle]
                loss_particle = particle_loss(
                    outputs['particle_types'][has_particle], type_targets
                )
            else:
                loss_coords = torch.tensor(0.0, device=device)
                loss_particle = torch.tensor(0.0, device=device)

            combined = loss_class + 2.0 * loss_coords + 0.5 * loss_particle

            running['total'] += combined.item()
            running['class'] += loss_class.item()
            running['coord'] += loss_coords.item()
            running['particle'] += loss_particle.item()

    return running['total'] / len(val_loader)

def predict_ensemble(models, zarr_path, patch_size, stride_3d, device="cuda", threshold=0.85,
                     voxel_size=10.0):
    """Average several models over every sliding window of one tomogram.

    Each model must return a dict with ``'logits'`` (class scores),
    ``'coords'`` (patch-relative ``x, y, z`` in [0, 1]) and
    ``'particle_types'`` (per-type probabilities). Outputs are averaged across
    models; a window yields a prediction when its mean positive-class
    probability reaches ``threshold``.

    Args:
        models: Iterable of models; each is moved to ``device``, set to eval
            mode, and wrapped in ``DataParallel`` when several GPUs are visible
            and ``device`` is a CUDA device.
        zarr_path: Path of the Zarr volume to scan.
        patch_size: ``(depth, height, width)`` of the windows.
        stride_3d: ``(depth, height, width)`` step between windows.
        device: Device used for inference.
        threshold: Minimum mean positive-class probability to keep a window.
        voxel_size: Factor converting voxel coordinates to output units.

    Returns:
        List of dicts with ``"x"``, ``"y"``, ``"z"`` (scaled by
        ``voxel_size``), ``"prob"``, and ``"particle_type_probs"`` (list of
        floats, one per model particle-type output).
    """
    wrap_parallel = (
        torch.cuda.device_count() > 1 and torch.device(device).type == "cuda"
    )
    # Collect the prepared models: rebinding the loop variable alone would
    # throw the DataParallel wrappers away.
    prepared = []
    for model in models:
        if wrap_parallel and not isinstance(model, DataParallel):
            model = DataParallel(model)
        model.to(device)
        model.eval()
        prepared.append(model)

    dataset = CryoET3DDataset(
        zarr_path=zarr_path,
        labels=[],
        patch_size=patch_size,
        stride_3d=stride_3d,
        skip_negatives_ratio=0.0,
        # A very large integer keeps every window. An integer is used because
        # the label-free sampling path of the dataset takes
        # min(max_patches, number_of_windows).
        max_patches=2 ** 31 - 1,
        normalize_coords=True,
        augment=False
    )

    loader = DataLoader(dataset, batch_size=8, num_workers=4, pin_memory=True)
    predictions = []

    with torch.no_grad():
        for volumes, meta in loader:
            volumes = volumes.to(device)
            offsets = meta["offset"]

            ensemble_probs = []
            ensemble_coords = []
            ensemble_types = []

            for model in prepared:
                outputs = model(volumes)
                ensemble_probs.append(F.softmax(outputs['logits'], dim=-1)[:, 1])
                ensemble_coords.append(outputs['coords'])
                ensemble_types.append(outputs['particle_types'])

            mean_probs = torch.stack(ensemble_probs, dim=0).mean(0).float().cpu()
            mean_coords = torch.stack(ensemble_coords, dim=0).mean(0).float().cpu()
            mean_types = torch.stack(ensemble_types, dim=0).mean(0).float().cpu()

            for i in range(len(volumes)):
                if mean_probs[i] < threshold:
                    continue

                coord = mean_coords[i].numpy()
                z0, y0, x0 = offsets[i].tolist()

                # Coordinates are (x, y, z) fractions of the patch extent;
                # patch_size is ordered (depth, height, width).
                px = coord[0] * patch_size[2]
                py = coord[1] * patch_size[1]
                pz = coord[2] * patch_size[0]

                gx = float(px + x0) * voxel_size
                gy = float(py + y0) * voxel_size
                gz = float(pz + z0) * voxel_size

                predictions.append({
                    "x": round(gx, 2),
                    "y": round(gy, 2),
                    "z": round(gz, 2),
                    "prob": float(mean_probs[i].item()),
                    "particle_type_probs": [float(p) for p in mean_types[i].tolist()]
                })

    return predictions

def nms_3d(predictions, iou_threshold=0.3, radius=15):
    """
    Greedy distance-based non-maximum suppression for 3D point predictions.

    Predictions are visited from highest to lowest 'prob'; one is kept only if
    its Euclidean distance to every already-kept prediction is at least
    `radius`.

    Args:
        predictions: List of dicts with 'x', 'y', 'z' coordinates and a 'prob' score
        iou_threshold: Accepted but not used by the current implementation
        radius: Minimum distance between two kept predictions, in the same
            units as the coordinates
    Returns:
        List of kept prediction dicts (the same objects), highest 'prob' first
    """
    if not predictions:
        return []

    def separation(a, b):
        return ((a['x'] - b['x'])**2 +
                (a['y'] - b['y'])**2 +
                (a['z'] - b['z'])**2)**0.5

    ranked = sorted(predictions, key=lambda item: item['prob'], reverse=True)
    survivors = []

    for candidate in ranked:
        too_close = any(
            separation(candidate, winner) < radius for winner in survivors
        )
        if not too_close:
            survivors.append(candidate)

    return survivors

def generate_submission(
    models,
    test_experiments,
    particle_types,
    base_dir,
    patch_size,
    stride_3d,
    submission_file,
    device="cuda"
):
    """Run ensemble prediction plus NMS on every test volume and write a CSV.

    For each experiment whose volume exists under
    ``<base_dir>/test/static/ExperimentRuns/<experiment>/VoxelSpacing10.000/denoised.zarr``,
    predictions are produced by ``predict_ensemble`` (threshold 0.75) and
    filtered by ``nms_3d`` (radius 15). One row is emitted per surviving
    prediction and per particle type whose probability is at least 0.5.

    Args:
        models: Models forwarded to ``predict_ensemble``.
        test_experiments: Experiment names to process.
        particle_types: Particle type names. "beta-amylase" is ignored, as in
            the other pipeline steps; the remaining names are matched by
            position to each prediction's ``"particle_type_probs"``.
        base_dir: Dataset root directory.
        patch_size: Window size forwarded to ``predict_ensemble``.
        stride_3d: Window stride forwarded to ``predict_ensemble``.
        submission_file: Destination CSV path.
        device: Inference device.

    The CSV always has the columns id, experiment, particle_type, x, y, z,
    even when no prediction was produced.
    """
    columns = ["id", "experiment", "particle_type", "x", "y", "z"]

    # NOTE(review): position i of this list is read as output i of the model's
    # particle head. Confirm the caller passes the names in the same order the
    # training data used.
    scored_types = [ptype for ptype in particle_types if ptype != "beta-amylase"]

    all_predictions = []
    prediction_id = 0

    for experiment in test_experiments:
        print(f"\nProcessing experiment: {experiment}")
        zarr_path = os.path.join(
            base_dir, 
            f"test/static/ExperimentRuns/{experiment}/VoxelSpacing10.000/denoised.zarr"
        )
        
        if not os.path.exists(zarr_path):
            print(f"Volume not found, skipping: {zarr_path}")
            continue
        
        predictions = predict_ensemble(
            models=models,
            zarr_path=zarr_path,
            patch_size=patch_size,
            stride_3d=stride_3d,
            device=device,
            threshold=0.75 
        )
        
        filtered_preds = nms_3d(
            predictions, 
            iou_threshold=0.3, 
            radius=15
        )
        
        for pred in filtered_preds:
            particle_probs = pred["particle_type_probs"]
            
            for i, particle_type in enumerate(scored_types):
                if particle_probs[i] >= 0.5:  
                    all_predictions.append({
                        "id": prediction_id,
                        "experiment": experiment,
                        "particle_type": particle_type,
                        "x": round(float(pred["x"]), 5),
                        "y": round(float(pred["y"]), 5),
                        "z": round(float(pred["z"]), 5)
                    })
                    prediction_id += 1

    # Explicit columns keep the header present when the list is empty.
    df = pd.DataFrame(all_predictions, columns=columns)
    df.to_csv(submission_file, index=False)
    print(f"Saved {len(all_predictions)} predictions to {submission_file}")    
