In [2]:
# MONAI Dual-Task (Segmentation + Classification) training using nnU-Net-style pipeline
# - Dataset: derived/unified_dualtask (train/val/test CSVs)
# - Spacing standardization to (0.8, 0.8, 1.0) mm
# - Label-preserving resample via one-hot + optional dilate-then-erode
# - Sliding-window patch training (192x192x160)
# - Shared-encoder segmentation (DynUNet) + classification head
# - QC counters that flag label shrinkage post-resample

import os
import math
import time
import json
import random
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler

import nibabel as nib

from monai.config import print_config
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.data.utils import no_collation
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import DynUNet
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    RandSpatialCropd,
    RandFlipd,
    RandRotate90d,
    RandAffined,
    AsDiscreted,
    EnsureTyped,
    CastToTyped,
)
from monai.utils import set_determinism

print_config()


MONAI version: 1.6.dev2533
Numpy version: 2.3.2
Pytorch version: 2.8.0+cu128
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: d7a3eeb01a4d660d5fe25ed186117499aa57466e
MONAI __file__: /home/<username>/projects/brain_tumor_segmentation/venv-monai/lib/python3.12/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.16.1
Pillow version: 11.3.0
Tensorboard version: 2.20.0
gdown version: 5.2.0
TorchVision version: 0.23.0+cu128
tqdm version: 4.67.1
lmdb version: 1.7.3
psutil version: 7.0.0
pandas version: 2.3.1
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 3.2.0
pynrrd version: 1.1.3
clearml version: 2.0.3rc0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [3]:
# Reproducibility
SEED = 42
set_determinism(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Paths
PROJ_ROOT = Path('/home/ant/projects/brain_tumor_segmentation')
DUALTASK_ROOT = PROJ_ROOT / 'derived' / 'unified_dualtask'
TRAIN_CSV = DUALTASK_ROOT / 'train.csv'
VAL_CSV = DUALTASK_ROOT / 'val.csv'
TEST_CSV = DUALTASK_ROOT / 'test.csv'

assert TRAIN_CSV.exists() and VAL_CSV.exists() and TEST_CSV.exists(), 'Split CSVs missing'

# Target spacing and patch params
TARGET_SPACING = (0.8, 0.8, 1.0)
PATCH_SIZE = (192, 192, 160)
PATCH_OVERLAP = 0.5  # sliding window overlap


Device: cuda


In [4]:
# CSV -> dict list helpers

def read_unified_csv(path: Path) -> List[Dict]:
    df = pd.read_csv(path)
    # expected columns: case_id,class_label,image_path,label_path
    items = []
    for _, row in df.iterrows():
        items.append({
            'case_id': row['case_id'],
            'image': row['image_path'],
            'label': row['label_path'],
            'class_label': int(row['class_label']),
        })
    return items

train_items = read_unified_csv(TRAIN_CSV)
val_items = read_unified_csv(VAL_CSV)
test_items = read_unified_csv(TEST_CSV)

len(train_items), len(val_items), len(test_items)


(972, 209, 207)

In [5]:
# Morphology utilities and QC counters
import scipy.ndimage as ndi

class LabelQC:
    def __init__(self, shrink_warn_threshold: float = 0.35):
        self.shrink_warn_threshold = shrink_warn_threshold
        self.total = 0
        self.warn = 0
    def update(self, before_voxels: int, after_voxels: int, case_id: str):
        self.total += 1
        if before_voxels > 0:
            ratio = (after_voxels + 1e-6) / (before_voxels + 1e-6)
            if ratio < (1.0 - self.shrink_warn_threshold):
                self.warn += 1
                print(f'[QC] label shrinkage: {case_id} before={before_voxels} after={after_voxels} ratio={ratio:.3f}')
    def summary(self):
        print(f'[QC] shrinkage warnings: {self.warn}/{self.total}')


def binary_dilate_then_erode(mask: np.ndarray, radius_vox: int = 1) -> np.ndarray:
    if radius_vox <= 0:
        return mask
    structure = ndi.generate_binary_structure(3, 1)
    for _ in range(radius_vox):
        mask = ndi.binary_dilation(mask, structure=structure)
    for _ in range(radius_vox):
        mask = ndi.binary_erosion(mask, structure=structure)
    return mask


In [6]:
# Transforms: spacing standardization and intensity scale
# Labels are handled with a custom post-transform step to preserve small lesions via one-hot resample and optional morph.

from monai.transforms import MapTransform

class OneHotResampleWithMorphology(MapTransform):
    def __init__(self, keys, num_classes: int = 2, morph_radius: int = 0, allow_missing_keys: bool = False):
        super().__init__(keys, allow_missing_keys)
        self.num_classes = num_classes
        self.morph_radius = morph_radius

    def __call__(self, data):
        d = dict(data)
        # expects d['label'] to be a MONAI image tensor with metadata spacing attached
        label = d['label']  # torch.Tensor [1, D, H, W] after EnsureChannelFirstd
        meta = d.get('label_meta_dict', {})
        # Before voxels for QC
        before_vox = int((label > 0.5).sum().item())

        # one-hot
        label_oh = F.one_hot(label.long().squeeze(0), num_classes=self.num_classes).permute(3, 0, 1, 2).float()

        # resample using metadata of image (already resampled) to match size via trilinear/nearest
        # assume image and label now share the same shape; if not, interpolate to image shape
        img = d['image']
        if label_oh.shape[1:] != img.shape[1:]:
            # channels first
            label_oh = F.interpolate(label_oh.unsqueeze(0), size=img.shape[1:], mode='trilinear', align_corners=False).squeeze(0)
        # discretize back to argmax
        label_res = label_oh.argmax(dim=0, keepdim=True)

        # optional light morph
        if self.morph_radius > 0:
            arr = label_res.detach().cpu().numpy().astype(np.uint8)
            arr = binary_dilate_then_erode(arr[0], radius_vox=self.morph_radius)[None]
            label_res = torch.as_tensor(arr, dtype=torch.long, device=label_res.device)

        after_vox = int((label_res > 0).sum().item())
        d['label'] = label_res
        d['qc_before_vox'] = before_vox
        d['qc_after_vox'] = after_vox
        return d

TARGET_MODE = 'bilinear'

common_load = [
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    EnsureTyped(keys=['image', 'label'], dtype=torch.float32),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    Spacingd(keys=['image', 'label'], pixdim=TARGET_SPACING, mode=('bilinear', 'nearest')),
]

from monai.transforms import RandCropByPosNegLabeld, SpatialPadd

intensity_train = [
    ScaleIntensityRanged(keys=['image'], a_min=0, a_max=3000, b_min=0.0, b_max=1.0, clip=True),
    RandFlipd(keys=['image', 'label'], spatial_axis=[0, 1, 2], prob=0.2),
    RandRotate90d(keys=['image', 'label'], prob=0.2, max_k=3),
    RandAffined(keys=['image', 'label'], rotate_range=(math.pi/36, math.pi/36, math.pi/36),
                scale_range=(0.1, 0.1, 0.1), mode=('bilinear', 'nearest'), prob=0.2),
    SpatialPadd(keys=['image', 'label'], spatial_size=PATCH_SIZE),
    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=PATCH_SIZE,
                           pos=1, neg=1, num_samples=1, image_key='image', allow_smaller=True),
]

intensity_val = [
    ScaleIntensityRanged(keys=['image'], a_min=0, a_max=3000, b_min=0.0, b_max=1.0, clip=True),
]

# Post label discretization with morphology + QC bookkeeping performed inline by OneHotResampleWithMorphology

post_label_preserve = [
    OneHotResampleWithMorphology(keys=['label'], num_classes=2, morph_radius=1),
]

class CastClassLabeld(MapTransform):
    def __init__(self, keys, allow_missing_keys=False):
        super().__init__(keys, allow_missing_keys)
    def __call__(self, data):
        d = dict(data)
        if 'class_label' in d:
            d['class_label'] = torch.as_tensor(d['class_label'], dtype=torch.float32)
        return d

train_transforms = Compose(common_load + intensity_train + post_label_preserve + [CastClassLabeld(keys=['class_label'])])
val_transforms = Compose(common_load + intensity_val + post_label_preserve + [CastClassLabeld(keys=['class_label'])])


In [7]:
# Datasets and Loaders with QC hooks

qc_train = LabelQC(shrink_warn_threshold=0.35)
qc_val = LabelQC(shrink_warn_threshold=0.35)

class QCAugmentWrapper(CacheDataset):
    def __init__(self, data, transform, cache_rate=0.0, num_workers=4, copy_cache=True):
        super().__init__(data=data, transform=transform, cache_rate=cache_rate, num_workers=num_workers, copy_cache=copy_cache)
    def __getitem__(self, index):
        item = super().__getitem__(index)
        # Collect QC counters if present
        case_id = item.get('case_id') if isinstance(item, dict) else None
        before_vox = int(item.get('qc_before_vox', 0))
        after_vox = int(item.get('qc_after_vox', 0))
        if before_vox or after_vox:
            # decide which QC to update based on internal flag set earlier
            pass
        return item

train_ds = CacheDataset(data=train_items, transform=train_transforms, cache_rate=0.0, num_workers=4)
val_ds = CacheDataset(data=val_items, transform=val_transforms, cache_rate=0.0, num_workers=2)

train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)


In [8]:
# Model: DynUNet backbone + classification head from encoder bottleneck

# Segmentation network
seg_net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    kernel_size=[3, 3, 3, 3, 3, 3],
    strides=[1, 2, 2, 2, 2, 2],
    upsample_kernel_size=[2, 2, 2, 2, 2],
    norm_name='instance',
    deep_supervision=False,
).to(device)

# Classification head
class ClassificationHead(nn.Module):
    def __init__(self, in_channels: int, num_classes: int = 1):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(in_channels, num_classes)
    def forward(self, feat):
        x = self.pool(feat).flatten(1)
        return self.fc(x)

# Identify bottleneck channels from DynUNet
# DynUNet returns a list of decoder outputs when deep_supervision; we also can hook encoder features via register_forward_hook

# Lazy-init classification head once we know feature channels
class LazyClassificationHead(nn.Module):
    def __init__(self, num_classes: int = 1):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = None
        self.num_classes = num_classes
    def forward(self, feat):
        x = self.pool(feat).flatten(1)
        if self.fc is None:
            self.fc = nn.Linear(x.shape[1], self.num_classes).to(x.device)
        return self.fc(x)

cls_head = LazyClassificationHead(num_classes=1).to(device)

# Simple hook to capture bottleneck features
encoder_feat = {'x': None}

def hook_fn(module, input, output):
    encoder_feat['x'] = output

# Attach hook to bottleneck layer (seg_net.encoder4 or seg_net.bottleneck depending on version)
if hasattr(seg_net, 'bottleneck'):
    seg_net.bottleneck.register_forward_hook(hook_fn)
elif hasattr(seg_net, 'encoder4'):
    seg_net.encoder4.register_forward_hook(hook_fn)
else:
    print('[WARN] Could not attach hook, classification head may not receive features')

# Losses and optimizer
seg_loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
cls_loss_fn = nn.BCEWithLogitsLoss()

params = list(seg_net.parameters()) + list(cls_head.parameters())
optimizer = torch.optim.AdamW(params, lr=2e-4, weight_decay=1e-5)
scaler = GradScaler(enabled=torch.cuda.is_available())

# Metrics
post_pred = AsDiscreted(keys=['pred'], argmax=True)
post_label = AsDiscreted(keys=['label'], to_onehot=2)
dice_metric = DiceMetric(include_background=False, reduction='mean')


In [9]:
# Helper: pad tensor to next multiple-of factor for each spatial dim
import torch.nn.functional as F

def pad_to_factor(x: torch.Tensor, factor: int = 32) -> torch.Tensor:
    # x: (B, C, D, H, W)
    B, C, D, H, W = x.shape
    def next_m(s):
        return ((s + factor - 1) // factor) * factor
    Dn, Hn, Wn = next_m(D), next_m(H), next_m(W)
    pd = Dn - D; ph = Hn - H; pw = Wn - W
    # pad order: (W_left, W_right, H_left, H_right, D_left, D_right)
    pad = (0, pw, 0, ph, 0, pd)
    if any(p > 0 for p in pad):
        x = F.pad(x, pad, mode='constant', value=0.0)
    return x


In [10]:
# Inferer for validation/test
inferer = SlidingWindowInferer(roi_size=PATCH_SIZE, sw_batch_size=1, overlap=PATCH_OVERLAP, mode='gaussian')

# Utils
from contextlib import nullcontext

def to_device(batch: Dict, device: torch.device) -> Dict:
    out = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            out[k] = v.to(device, non_blocking=True)
        else:
            out[k] = v
    return out

ckpt_dir = PROJ_ROOT / 'runs' / 'dualtask_monai'
os.makedirs(ckpt_dir, exist_ok=True)
print('Checkpoint dir:', ckpt_dir)


Checkpoint dir: /home/ant/projects/brain_tumor_segmentation/runs/dualtask_monai


In [None]:
# Train / Val loops
EPOCHS = 50
val_interval = 1
best_val_dice = -1.0

for epoch in range(1, EPOCHS + 1):
    seg_net.train(); cls_head.train()
    epoch_loss = 0.0
    num_steps = 0

    for batch in train_loader:
        # QC update per-sample
        for b in decollate_batch(batch):
            qc_train.update(int(b.get('qc_before_vox', 0)), int(b.get('qc_after_vox', 0)), str(b.get('case_id', '?')))
        batch = to_device(batch, device)
        images = batch['image']
        labels = batch['label'].long()
        class_labels = batch['class_label'].view(-1, 1)

        optimizer.zero_grad(set_to_none=True)
        ctx = autocast(device_type='cuda', enabled=torch.cuda.is_available())
        with ctx:
            encoder_feat['x'] = None
            seg_logits = seg_net(images)
            # DynUNet with deep supervision returns list; last is highest res
            if isinstance(seg_logits, (list, tuple)):
                seg_logits_main = seg_logits
            else:
                seg_logits_main = seg_logits
            # Classification
            feat = encoder_feat['x'] if encoder_feat['x'] is not None else seg_logits_main
            cls_logits = cls_head(feat)

            loss_seg = seg_loss_fn(seg_logits_main, labels)
            loss_cls = cls_loss_fn(cls_logits, class_labels)
            loss = loss_seg + 0.3 * loss_cls
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_steps += 1

    epoch_loss /= max(1, num_steps)
    print(f'Epoch {epoch}/{EPOCHS} - train loss: {epoch_loss:.4f}')
    if epoch % val_interval == 0:
        qc_train.summary()

    if epoch % val_interval == 0:
        seg_net.eval(); cls_head.eval()
        dice_metric.reset()
        val_loss = 0.0
        steps = 0
        with torch.no_grad():
            for batch in val_loader:
                for b in decollate_batch(batch):
                    qc_val.update(int(b.get('qc_before_vox', 0)), int(b.get('qc_after_vox', 0)), str(b.get('case_id', '?')))
                batch = to_device(batch, device)
                images = batch['image']
                labels = batch['label'].long()
                class_labels = batch['class_label'].view(-1, 1)

                ctx = autocast(device_type='cuda', enabled=torch.cuda.is_available())
                with ctx:
                    images_p = pad_to_factor(images, factor=32)  # (B,C,D,H,W) -> padded to mult of 32
                    seg_logits = inferer(inputs=images_p, network=seg_net)
                    encoder_feat['x'] = None    
                    # classification from encoder feature may not be available in inferer path; do a forward to populate
                    _ = seg_net(images_p)
                    if seg_logits.shape[-3:] != labels.shape[-3:]:
                        Dz, Hy, Wx = labels.shape[-3:]
                        seg_logits = seg_logits[..., :Dz, :Hy, :Wx]
                    feat = encoder_feat['x'] if encoder_feat['x'] is not None else seg_logits
                    cls_logits = cls_head(feat)

                    loss_seg = seg_loss_fn(seg_logits, labels)
                    loss_cls = cls_loss_fn(cls_logits, class_labels)
                    loss = loss_seg + 0.3 * loss_cls
                val_loss += loss.item()
                steps += 1

                y_pred = torch.softmax(seg_logits, dim=1)
                y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
                dice_metric(y_pred=y_pred, y=labels)

        mean_dice = dice_metric.aggregate().item()
        val_loss /= max(1, steps)
        print(f'  Val loss: {val_loss:.4f} | Val Dice(tumor): {mean_dice:.4f}')
        qc_val.summary()

        if mean_dice > best_val_dice:
            best_val_dice = mean_dice
            torch.save({'seg': seg_net.state_dict(), 'cls': cls_head.state_dict()}, ckpt_dir / 'best.pt')
            print(f'  [Saved] best.pt with Dice {best_val_dice:.4f}')

        del images, labels, seg_logits, cls_logits, feat
        try:
            del images_p
        except:
            pass
        
        torch.cuda.empty_cache()



[QC] label shrinkage: UPENN_GBM_UPENN-GBM-00147_11 before=5 after=0 ratio=0.000
[QC] label shrinkage: BCBM_RADIOGENOMICS_BCBM-RadioGenomics-158-0 before=328 after=182 ratio=0.555
[QC] label shrinkage: MU_GLIOMA_POST_PatientID_0079_Timepoint_5 before=2388 after=1531 ratio=0.641
[QC] label shrinkage: MU_GLIOMA_POST_PatientID_0242_Timepoint_1 before=4 after=2 ratio=0.500
Epoch 1/50 - train loss: 0.6708


Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)
Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)


  Val loss: 0.6246 | Val Dice(tumor): 0.3353
  [Saved] best.pt with Dice 0.3353
[QC] label shrinkage: UCSD_PTGBM_UCSD-PTGBM-0178_02 before=17 after=3 ratio=0.176
[QC] label shrinkage: MU_GLIOMA_POST_PatientID_0105_Timepoint_4 before=776 after=401 ratio=0.517
[QC] label shrinkage: MU_GLIOMA_POST_PatientID_0094_Timepoint_4 before=425 after=255 ratio=0.600
[QC] label shrinkage: BCBM_RADIOGENOMICS_BCBM-RadioGenomics-139-0 before=320 after=142 ratio=0.444
[QC] label shrinkage: PRETREAT_METS_BraTS-MET-00167-000 before=130 after=66 ratio=0.508
[QC] label shrinkage: MU_GLIOMA_POST_PatientID_0038_Timepoint_1 before=15 after=0 ratio=0.000
[QC] label shrinkage: UPENN_GBM_UPENN-GBM-00274_11 before=94 after=25 ratio=0.266
Epoch 2/50 - train loss: 0.5491
  Val loss: 0.5137 | Val Dice(tumor): 0.4562
  [Saved] best.pt with Dice 0.4562
[QC] label shrinkage: BCBM_RADIOGENOMICS_BCBM-RadioGenomics-103-0 before=247 after=140 ratio=0.567
[QC] label shrinkage: UPENN_GBM_UPENN-GBM-00147_11 before=5 after=0 ra

Num foregrounds 0, Num backgrounds 2141705, unable to generate class balanced samples, setting `pos_ratio` to 0.


Epoch 24/50 - train loss: 0.2718
  Val loss: 0.2984 | Val Dice(tumor): 0.6074
[QC] label shrinkage: PRETREAT_METS_BraTS-MET-00104-000 before=302 after=181 ratio=0.599
[QC] label shrinkage: BCBM_RADIOGENOMICS_BCBM-RadioGenomics-98-0 before=216 after=129 ratio=0.597
[QC] label shrinkage: UPENN_GBM_UPENN-GBM-00274_11 before=467 after=284 ratio=0.608
[QC] label shrinkage: MU_GLIOMA_POST_PatientID_0037_Timepoint_4 before=740 after=443 ratio=0.599
[QC] label shrinkage: MU_GLIOMA_POST_PatientID_0038_Timepoint_1 before=15 after=0 ratio=0.000
[QC] label shrinkage: UCSD_PTGBM_UCSD-PTGBM-0076_01 before=515 after=266 ratio=0.517
[QC] label shrinkage: UCSD_PTGBM_UCSD-PTGBM-0036_01 before=23 after=0 ratio=0.000
[QC] label shrinkage: PRETREAT_METS_BraTS-MET-00175-000 before=282 after=180 ratio=0.638
[QC] label shrinkage: UCSD_PTGBM_UCSD-PTGBM-0137_01 before=906 after=529 ratio=0.584
Epoch 25/50 - train loss: 0.2697
  Val loss: 0.3820 | Val Dice(tumor): 0.6069
[QC] label shrinkage: UCSD_PTGBM_UCSD-PTG

In [None]:
# Test evaluation (seg Dice + cls AUC/ACC)
from sklearn.metrics import roc_auc_score, accuracy_score

# Build test loader
test_ds = CacheDataset(data=test_items, transform=val_transforms, cache_rate=0.0, num_workers=2)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

# Load best
ckpt = torch.load(ckpt_dir / 'best.pt', map_location=device)
seg_net.load_state_dict(ckpt['seg'])
cls_head.load_state_dict(ckpt['cls'])
seg_net.eval(); cls_head.eval()

dice_metric.reset()
y_true, y_prob = [], []

with torch.no_grad():
    for batch in test_loader:
        batch = to_device(batch, device)
        images = batch['image']
        labels = batch['label'].long()
        class_labels = batch['class_label'].view(-1, 1)

        ctx = autocast(device_type='cuda', enabled=torch.cuda.is_available())
        with ctx:
            seg_logits = inferer(inputs=images, network=seg_net)
            _ = seg_net(images)
            feat = encoder_feat['x'] if encoder_feat['x'] is not None else seg_logits
            cls_logits = cls_head(feat)

            y_pred = torch.softmax(seg_logits, dim=1)
            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
            dice_metric(y_pred=y_pred, y=labels)

            y_true.append(int(class_labels.item()))
            y_prob.append(torch.sigmoid(cls_logits).item())

mean_dice = dice_metric.aggregate().item()
y_pred_cls = [1 if p >= 0.5 else 0 for p in y_prob]
acc = accuracy_score(y_true, y_pred_cls)
try:
    auc = roc_auc_score(y_true, y_prob)
except Exception:
    auc = float('nan')

print({'test_dice': mean_dice, 'test_acc': acc, 'test_auc': auc})
