# ACDC MRI Disease Classification Pipeline



In [38]:
import os
import json
import random
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import nibabel as nib  # noqa: F401 (used in visualization helpers)
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from dataset import ACDCDataset

plt.style.use("seaborn-v0_8")

def get_processed_dir() -> Path:
    external = Path("/Volumes/Crucial X6/medical_ai_extra/processed")
    if external.exists():
        return external
    project_root = Path.cwd()
    if (project_root / "database" / "processed").exists():
        return project_root / "database" / "processed"
    raise FileNotFoundError("Could not locate processed directory.")


def seed_everything(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
seed_everything(42)

PROCESSED_DIR = get_processed_dir()
META_CSV = PROCESSED_DIR / "meta.csv"
SPLITS_JSON = PROCESSED_DIR / "splits.json"
BEST_MODEL_PATH = Path("best_model.pt")
EVAL_RESULTS_PATH = Path("evaluation_results.json")
CONF_MATRIX_PATH = Path("confusion_matrix.png")
LABEL_NAMES = {0: "NOR", 1: "MINF", 2: "DCM", 3: "HCM", 4: "RV"}



Using device: cpu


In [39]:
metadata_df = pd.read_csv(META_CSV)
with SPLITS_JSON.open("r", encoding="utf-8") as f:
    splits = json.load(f)

print("Processed directory:", PROCESSED_DIR)
display(metadata_df.head())

split_counts = {split: len(ids) for split, ids in splits.items()}
print("Samples per split:")
for split, count in split_counts.items():
    print(f"  {split}: {count}")



Processed directory: /Volumes/Crucial X6/medical_ai_extra/processed


Unnamed: 0,patient_id,filepath,height,weight,ED,ES,label
0,patient001,/Volumes/Crucial X6/medical_ai_extra/processed...,184.0,95.0,1,12,2
1,patient002,/Volumes/Crucial X6/medical_ai_extra/processed...,160.0,70.0,1,12,2
2,patient003,/Volumes/Crucial X6/medical_ai_extra/processed...,165.0,77.0,1,15,2
3,patient004,/Volumes/Crucial X6/medical_ai_extra/processed...,159.0,46.0,1,15,2
4,patient005,/Volumes/Crucial X6/medical_ai_extra/processed...,165.0,77.0,1,13,2


Samples per split:
  train: 80
  val: 20
  test: 50


In [40]:
class Conv3DBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2
        self.block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class MRIEncoder3D(nn.Module):
    def __init__(self, in_channels: int = 30, base_channels: int = 32):
        super().__init__()
        self.layers = nn.Sequential(
            Conv3DBlock(in_channels, base_channels),
            nn.MaxPool3d((2, 2, 1)),
            Conv3DBlock(base_channels, base_channels * 2),
            nn.MaxPool3d((2, 2, 1)),
            Conv3DBlock(base_channels * 2, base_channels * 4),
            nn.MaxPool3d((2, 2, 1)),
            Conv3DBlock(base_channels * 4, base_channels * 8),
        )
        self.global_pool = nn.AdaptiveAvgPool3d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layers(x)
        x = self.global_pool(x)
        return torch.flatten(x, 1)


class MetaEncoder(nn.Module):
    def __init__(self, input_dim: int = 4, hidden_dim: int = 32, output_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class FusionClassifier(nn.Module):
    def __init__(self, mri_dim: int, meta_dim: int, num_classes: int = 5, dropout: float = 0.3):
        super().__init__()
        fusion_dim = mri_dim + meta_dim
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, mri_feat: torch.Tensor, meta_feat: torch.Tensor) -> torch.Tensor:
        fused = torch.cat([mri_feat, meta_feat], dim=1)
        return self.classifier(fused)


class ACDCNet(nn.Module):
    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.mri_encoder = MRIEncoder3D()
        self.meta_encoder = MetaEncoder()
        self.classifier = FusionClassifier(mri_dim=256, meta_dim=64, num_classes=num_classes)

    def forward(self, image: torch.Tensor, metadata: torch.Tensor) -> torch.Tensor:
        mri_feat = self.mri_encoder(image)
        meta_feat = self.meta_encoder(metadata)
        return self.classifier(mri_feat, meta_feat)


def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = ACDCNet(num_classes=len(LABEL_NAMES)).to(device)
print(model)
print(f"Total trainable parameters: {count_parameters(model):,}")



ACDCNet(
  (mri_encoder): MRIEncoder3D(
    (layers): Sequential(
      (0): Conv3DBlock(
        (block): Sequential(
          (0): Conv3d(30, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (1): MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
      (2): Conv3DBlock(
        (block): Sequential(
          (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (3): MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
      (4): Conv3DBlock(
        (block): Sequential(
          (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      

In [41]:
TARGET_DEPTH = 216

def adjust_depth(volume: torch.Tensor, target_depth: int = TARGET_DEPTH) -> torch.Tensor:
    depth = volume.shape[0]
    if depth == target_depth:
        return volume
    if depth > target_depth:
        start = (depth - target_depth) // 2
        end = start + target_depth
        return volume[start:end, :, :, :]
    pad_before = (target_depth - depth) // 2
    pad_after = target_depth - depth - pad_before
    padding = (0, 0, 0, 0, 0, 0, pad_before, pad_after)
    return F.pad(volume, padding)


def image_transform(volume: torch.Tensor) -> torch.Tensor:
    if volume.ndim != 4:
        raise ValueError("Expected 4D volume (D, H, W, T)")
    volume = adjust_depth(volume)
    # Permute to (T, D, H, W) to align with Conv3d input ordering (C, D, H, W)
    return volume.permute(3, 0, 1, 2)


def create_dataset(split: str) -> ACDCDataset:
    return ACDCDataset(
        meta_csv=META_CSV,
        splits_json=SPLITS_JSON,
        subset=split,
        image_transform=image_transform,
    )


TARGET_CHANNELS = 30


def pad_trim_channels(image: torch.Tensor, target_channels: int = TARGET_CHANNELS) -> torch.Tensor:
    c, d, h, w = image.shape
    if c == target_channels:
        return image
    if c > target_channels:
        start = (c - target_channels) // 2
        end = start + target_channels
        return image[start:end, :, :, :]
    pad_before = (target_channels - c) // 2
    pad_after = target_channels - c - pad_before
    padding = (0, 0, 0, 0, 0, 0, pad_before, pad_after)
    return F.pad(image, padding)


def pad_image(image: torch.Tensor, target_shape: tuple[int, int, int, int]) -> torch.Tensor:
    image = pad_trim_channels(image)
    c, d, h, w = image.shape
    tc, td, th, tw = target_shape
    output = torch.zeros(target_shape, dtype=image.dtype)
    output[:c, :d, :h, :w] = image[:tc, :td, :th, :tw]
    return output


def pad_collate(batch):
    images = [item[0][0] for item in batch]
    metadata = [item[0][1] for item in batch]
    labels = [item[1] for item in batch]

    max_dims = [
        max(img.shape[dim] for img in images) for dim in range(4)
    ]
    max_dims[0] = TARGET_CHANNELS
    padded_images = [pad_image(img, tuple(max_dims)) for img in images]

    images_tensor = torch.stack(padded_images)
    metadata_tensor = torch.stack(metadata).float()
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    return (images_tensor, metadata_tensor), labels_tensor


train_dataset = create_dataset("train")
val_dataset = create_dataset("val")
test_dataset = create_dataset("test")

BATCH_SIZE = 2
# Use single-process loading inside notebooks to avoid fork/pickle issues.
NUM_WORKERS = 0

def make_loader(dataset: ACDCDataset, shuffle: bool) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=pad_collate,
    )

train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)

print(
    f"Train/Val/Test sizes: {len(train_dataset)} / {len(val_dataset)} / {len(test_dataset)}"
)



Train/Val/Test sizes: 80 / 20 / 50


# üîç DIAGNOSTIC CHECKLIST - Run this before training to debug 0% validation accuracy



In [42]:
from collections import Counter

print("=" * 60)
print("üîç DIAGNOSTIC CHECKLIST")
print("=" * 60)

# 1. CHECK LABEL DISTRIBUTION
print("\n1Ô∏è‚É£  LABEL DISTRIBUTION")
train_labels = [int(train_dataset[i][1]) for i in range(len(train_dataset))]
val_labels = [int(val_dataset[i][1]) for i in range(len(val_dataset))]
test_labels = [int(test_dataset[i][1]) for i in range(len(test_dataset))]

print("Train labels:", Counter(train_labels))
print("Val labels  :", Counter(val_labels))
print("Test labels :", Counter(test_labels))

if len(set(val_labels)) < len(LABEL_NAMES):
    print("‚ö†Ô∏è  WARNING: Val set missing some classes!")
else:
    print("‚úÖ Val contains all classes")

# 2. CHECK MODEL OUTPUT DIMENSION
print("\n2Ô∏è‚É£  MODEL OUTPUT DIMENSION")
for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and name.startswith("classifier"):
        print(f"Final layer: {module}")
        if module.out_features == len(LABEL_NAMES):
            print(f"‚úÖ Output dimension correct: {module.out_features}")
        else:
            print(f"‚ùå Wrong output dimension: {module.out_features}, expected {len(LABEL_NAMES)}")

# 3. CHECK DATASET OUTPUT SHAPES
print("\n3Ô∏è‚É£  DATASET OUTPUT SHAPES")
(img, meta), label = train_dataset[0]
print(f"Image shape: {img.shape}")
print(f"Meta shape : {meta.shape}")
print(f"Label      : {label.item()} ({LABEL_NAMES[label.item()]})")

if img.ndim not in [4, 5]:
    print(f"‚ö†Ô∏è  WARNING: Image should be 4D or 5D, got {img.ndim}D")
if meta.shape[0] == 0:
    print("‚ùå ERROR: Metadata is empty!")

# 4. CHECK BATCH UNPACKING
print("\n4Ô∏è‚É£  BATCH UNPACKING")
for batch in val_loader:
    (images, metas), labels = batch
    print(f"‚úÖ Batch shapes - Images: {images.shape}, Metas: {metas.shape}, Labels: {labels.shape}")
    break

# 5. CHECK NORMALIZATION
print("\n5Ô∏è‚É£  NORMALIZATION CHECK")
img_train = train_dataset[0][0][0]
img_val = val_dataset[0][0][0] if len(val_dataset) > 0 else None

print(f"Train mean: {img_train.mean().item():.4f}, std: {img_train.std().item():.4f}")
if img_val is not None:
    print(f"Val   mean: {img_val.mean().item():.4f}, std: {img_val.std().item():.4f}")
    if abs(img_train.mean() - img_val.mean()) > 0.5:
        print("‚ö†Ô∏è  WARNING: Train/val normalization differs significantly")

# 6. CHECK METADATA IN BATCHES
print("\n6Ô∏è‚É£  METADATA IN BATCHES")
for batch in train_loader:
    (i, m), l = batch
    print(f"‚úÖ Train batch meta shape: {m.shape}, mean: {m.mean().item():.4f}")
    break

for batch in val_loader:
    (i, m), l = batch
    print(f"‚úÖ Val batch meta shape  : {m.shape}, mean: {m.mean().item():.4f}")
    break

# 7. OVERFIT TEST
print("\n7Ô∏è‚É£  OVERFIT TEST (40 steps on single batch)")
model_copy = ACDCNet(num_classes=len(LABEL_NAMES)).to(device)
optimizer_test = torch.optim.Adam(model_copy.parameters(), lr=1e-3)
criterion_test = nn.CrossEntropyLoss()

test_batch = next(iter(train_loader))
(images_test, metas_test), labels_test = test_batch
images_test = images_test.float().to(device)
metas_test = metas_test.float().to(device)
labels_test = labels_test.to(device)

initial_loss = None
for step in range(40):
    optimizer_test.zero_grad()
    out = model_copy(images_test, metas_test)
    loss = criterion_test(out, labels_test)
    loss.backward()
    optimizer_test.step()
    if step == 0:
        initial_loss = loss.item()
    if step % 10 == 0 or step == 39:
        print(f"  Step {step:2d}: loss = {loss.item():.4f}")

if initial_loss is not None and loss.item() < initial_loss * 0.5:
    print("‚úÖ Loss decreased - architecture is correct")
else:
    print("‚ùå Loss did not decrease significantly - check architecture!")

# 8. CHECK PATIENT IDS MATCH SPLITS
print("\n8Ô∏è‚É£  PATIENT IDS MATCH SPLITS")
meta_df = pd.read_csv(META_CSV)
with SPLITS_JSON.open("r") as f:
    splits_check = json.load(f)

train_ids_set = set(splits_check['train'])
val_ids_set = set(splits_check['val'])
meta_ids_set = set(meta_df['patient_id'])

missing_in_meta = val_ids_set - meta_ids_set
if missing_in_meta:
    print(f"‚ùå Val patients not in meta.csv: {missing_in_meta}")
else:
    print("‚úÖ All val patients found in meta.csv")

duplicates = meta_df['patient_id'].value_counts()
if (duplicates > 1).any():
    print(f"‚ö†Ô∏è  Duplicate patient IDs: {duplicates[duplicates > 1]}")
else:
    print("‚úÖ No duplicate patient IDs")

print("\n" + "=" * 60)
print("Diagnostic complete!")
print("=" * 60)



üîç DIAGNOSTIC CHECKLIST

1Ô∏è‚É£  LABEL DISTRIBUTION
Train labels: Counter({2: 16, 3: 16, 1: 16, 0: 16, 4: 16})
Val labels  : Counter({2: 4, 3: 4, 1: 4, 0: 4, 4: 4})
Test labels : Counter({2: 10, 0: 10, 1: 10, 3: 10, 4: 10})
‚úÖ Val contains all classes

2Ô∏è‚É£  MODEL OUTPUT DIMENSION
Final layer: Linear(in_features=320, out_features=256, bias=True)
‚ùå Wrong output dimension: 256, expected 5
Final layer: Linear(in_features=256, out_features=5, bias=True)
‚úÖ Output dimension correct: 5

3Ô∏è‚É£  DATASET OUTPUT SHAPES
Image shape: torch.Size([30, 216, 256, 10])
Meta shape : torch.Size([4])
Label      : 2 (DCM)

4Ô∏è‚É£  BATCH UNPACKING
‚úÖ Batch shapes - Images: torch.Size([2, 30, 216, 256, 10]), Metas: torch.Size([2, 4]), Labels: torch.Size([2])

5Ô∏è‚É£  NORMALIZATION CHECK
Train mean: 0.0389, std: 0.9915
Val   mean: -0.0000, std: 1.0000

6Ô∏è‚É£  METADATA IN BATCHES
‚úÖ Train batch meta shape: torch.Size([2, 4]), mean: 68.5000
‚úÖ Val batch meta shape  : torch.Size([2, 4]), mean:

KeyboardInterrupt: 

In [None]:
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss

loss_fn = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)
scaler = GradScaler(enabled=device.type == "cuda")

NUM_EPOCHS = 15
history: List[Dict] = []
best_val_acc = 0.0


def run_epoch(loader: DataLoader, train: bool, description: str) -> Dict[str, float]:
    epoch_loss = 0.0
    correct = 0
    total = 0

    if train:
        model.train()
    else:
        model.eval()

    iterator = tqdm(loader, desc=description, leave=False)
    for (images, metadata), labels in iterator:
        images = images.float().to(device, non_blocking=True)
        metadata = metadata.float().to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        if train:
            optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=device.type == "cuda"):
            logits = model(images, metadata)
            loss = loss_fn(logits, labels)

        if train:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        epoch_loss += loss.item() * labels.size(0)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return {
        "loss": epoch_loss / max(total, 1),
        "acc": correct / max(total, 1),
    }


for epoch in range(1, NUM_EPOCHS + 1):
    train_metrics = run_epoch(train_loader, train=True, description=f"Train {epoch}")
    with torch.no_grad():
        val_metrics = run_epoch(val_loader, train=False, description=f"Val {epoch}")

    scheduler.step(val_metrics["loss"])

    history.append(
        {
            "epoch": epoch,
            "train_loss": train_metrics["loss"],
            "train_acc": train_metrics["acc"],
            "val_loss": val_metrics["loss"],
            "val_acc": val_metrics["acc"],
            "lr": optimizer.param_groups[0]["lr"],
        }
    )

    if val_metrics["acc"] > best_val_acc:
        best_val_acc = val_metrics["acc"]
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_tag = " (new best)"
    else:
        best_tag = ""

    current_lr = optimizer.param_groups[0]["lr"]
    print(
        f"Epoch {epoch:02d}/{NUM_EPOCHS}\n"
        f"  Train -> loss: {train_metrics['loss']:.4f}, acc: {train_metrics['acc']:.3f}\n"
        f"  Val   -> loss: {val_metrics['loss']:.4f}, acc: {val_metrics['acc']:.3f}{best_tag}\n"
        f"  LR: {current_lr:.6f}"
    )

metrics_df = pd.DataFrame(history)
display(metrics_df.tail())

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
metrics_df.plot(x="epoch", y=["train_loss", "val_loss"], ax=axes[0])
axes[0].set_title("Loss")
metrics_df.plot(x="epoch", y=["train_acc", "val_acc"], ax=axes[1])
axes[1].set_ylim(0, 1)
axes[1].set_title("Accuracy")
plt.tight_layout()
plt.show()



                                                        

Epoch 01/15
  Train -> loss: 1.6146, acc: 0.150
  Val   -> loss: 1.5558, acc: 0.300 (new best)
  LR: 0.000100


                                                        

Epoch 02/15
  Train -> loss: 1.5874, acc: 0.325
  Val   -> loss: 1.5490, acc: 0.250
  LR: 0.000100


                                                        

Epoch 03/15
  Train -> loss: 1.5539, acc: 0.325
  Val   -> loss: 1.5497, acc: 0.250
  LR: 0.000100


                                                        

Epoch 04/15
  Train -> loss: 1.5220, acc: 0.362
  Val   -> loss: 1.5551, acc: 0.200
  LR: 0.000100


                                                        

Epoch 05/15
  Train -> loss: 1.5507, acc: 0.312
  Val   -> loss: 1.4949, acc: 0.200
  LR: 0.000100


                                                        

Epoch 06/15
  Train -> loss: 1.4947, acc: 0.275
  Val   -> loss: 1.5471, acc: 0.400 (new best)
  LR: 0.000100


                                                        

Epoch 07/15
  Train -> loss: 1.5229, acc: 0.350
  Val   -> loss: 1.5471, acc: 0.250
  LR: 0.000100


                                                        

Epoch 08/15
  Train -> loss: 1.4577, acc: 0.425
  Val   -> loss: 1.5071, acc: 0.250
  LR: 0.000050


                                                        

Epoch 09/15
  Train -> loss: 1.4524, acc: 0.375
  Val   -> loss: 1.7247, acc: 0.200
  LR: 0.000050


                                                         

Epoch 10/15
  Train -> loss: 1.4355, acc: 0.388
  Val   -> loss: 1.5634, acc: 0.300
  LR: 0.000050


Val 11:  30%|‚ñà‚ñà‚ñà       | 3/10 [00:32<01:13, 10.51s/it]   

In [None]:
def load_best_model(model: nn.Module, path: Path) -> None:
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()


def evaluate(model: nn.Module, loader: DataLoader):
    all_preds: List[int] = []
    all_labels: List[int] = []
    metas: List[np.ndarray] = []

    with torch.no_grad():
        for (images, metadata), labels in loader:
            images = images.float().to(device)
            metadata = metadata.float().to(device)
            logits = model(images, metadata)
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())
            metas.extend(metadata.cpu().numpy())

    return np.array(all_preds), np.array(all_labels), metas


load_best_model(model, BEST_MODEL_PATH)
preds, labels_np, metas = evaluate(model, test_loader)

acc = accuracy_score(labels_np, preds)
macro_f1 = f1_score(labels_np, preds, average="macro")
micro_f1 = f1_score(labels_np, preds, average="micro")

cm = confusion_matrix(labels_np, preds, labels=list(LABEL_NAMES.keys()))
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cm, cmap="Blues")
ax.set_xticks(range(len(LABEL_NAMES)))
ax.set_yticks(range(len(LABEL_NAMES)))
ax.set_xticklabels(LABEL_NAMES.values())
ax.set_yticklabels(LABEL_NAMES.values())
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, cm[i, j], ha="center", va="center", color="black")
fig.colorbar(im)
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig(CONF_MATRIX_PATH)
plt.show()

per_class_acc = {}
for label_idx, label_name in LABEL_NAMES.items():
    mask = labels_np == label_idx
    if mask.any():
        per_class_acc[label_name] = float((preds[mask] == labels_np[mask]).mean())
    else:
        per_class_acc[label_name] = float("nan")

evaluation_results = {
    "accuracy": float(acc),
    "macro_f1": float(macro_f1),
    "micro_f1": float(micro_f1),
    "per_class_accuracy": per_class_acc,
    "best_model_path": str(BEST_MODEL_PATH.resolve()),
}

with EVAL_RESULTS_PATH.open("w", encoding="utf-8") as f:
    json.dump(evaluation_results, f, indent=2)

print(json.dumps(evaluation_results, indent=2))

print("\nRandom sample predictions:")
sample_count = min(5, len(preds))
if sample_count == 0:
    print("No test samples available.")
else:
    indices = random.sample(range(len(preds)), k=sample_count)
    for idx in indices:
        meta_vals = metas[idx]
        print(
            f"Sample {idx}: metadata={meta_vals.tolist()} | "
            f"pred={LABEL_NAMES[preds[idx]]} vs actual={LABEL_NAMES[labels_np[idx]]}"
        )



In [None]:
def show_volume_slices(volume: np.ndarray, num_slices: int = 6) -> None:
    if volume.ndim != 4:
        raise ValueError("Volume must be 4D (D, H, W, T)")
    depth = volume.shape[0]
    indices = np.linspace(0, depth - 1, num_slices, dtype=int)
    fig, axes = plt.subplots(1, num_slices, figsize=(15, 3))
    for ax, idx in zip(axes, indices):
        ax.imshow(volume[idx, :, :, 0], cmap="gray")
        ax.set_title(f"Slice {idx}")
        ax.axis("off")
    plt.suptitle("Volume slices (first time frame)")
    plt.tight_layout()
    plt.show()


def show_single_slice(volume: np.ndarray, slice_index: int = 0, time_index: int = 0) -> None:
    if volume.ndim != 4:
        raise ValueError("Volume must be 4D (D, H, W, T)")
    slice_index = np.clip(slice_index, 0, volume.shape[0] - 1)
    time_index = np.clip(time_index, 0, volume.shape[3] - 1)
    plt.figure(figsize=(4, 4))
    plt.imshow(volume[slice_index, :, :, time_index], cmap="gray")
    plt.title(f"Slice {slice_index} | Time {time_index}")
    plt.axis("off")
    plt.show()


def show_metadata(meta_vector: np.ndarray) -> None:
    labels = ["Height", "Weight", "ED", "ES"]
    plt.figure(figsize=(4, 3))
    plt.bar(labels, meta_vector)
    plt.title("Metadata values")
    plt.show()



In [None]:
with EVAL_RESULTS_PATH.open("r", encoding="utf-8") as f:
    final_metrics = json.load(f)

print("Pipeline completed successfully.")
print(f"Best model saved at: {BEST_MODEL_PATH.resolve()}")
print("Final test metrics:")
print(json.dumps(final_metrics, indent=2))

