In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

#for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
## Build slice-level metadata: attach each patient's majority ICH label
## (from `compact_reads`) to every row of the left-hand manifest frame.
## NOTE(review): `pq` and `compact_reads` are created in cells not shown here;
## `pq` is presumably the slice-level manifest DataFrame — TODO confirm.
import pandas as pd

label_cols = ['name', 'ICH-majority']
metadata_df = pd.merge(
    pq,
    compact_reads[label_cols],
    how='left',
    on='name',
    validate='many_to_one',  # each name may appear at most once in compact_reads
    indicator=True,          # adds a `_merge` column describing match status
)
metadata_df.to_parquet("cq500ct_metadata.parquet", index=False)

In [None]:
## Attach each patient's read category (e.g. "B1") to the slice-level metadata.
## The train/validation split itself happens in a later cell.
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedGroupKFold # Cross-Validation phase

meta = pd.read_parquet("/kaggle/input/metadata/cq500ct_metadata.parquet") # metadata
# `_merge` is the indicator column left over from building the metadata file.
meta = meta.drop(columns="_merge")
cat = pd.read_csv("/kaggle/input/b1-cat-reads/b1_cat_reads.csv") # categories

df = meta.merge(
    cat[["name", "Category"]],
    on="name",
    how="left",
    # A patient listed twice in `cat` would otherwise silently duplicate all
    # of its slices; fail loudly instead.
    validate="many_to_one",
    indicator=True,  # `_merge` marks slices whose patient has no category row
)

In [None]:
## Keep only slices whose patient is in category B1 and persist them as a
## standalone metadata file (index is not written).
is_b1 = df["Category"].eq("B1")
b1_df = df.loc[is_b1].copy()
b1_df.to_parquet("b1_metadata.parquet", index=False)

In [None]:
## Patient-level train/validation split: fold 0 of a 5-fold StratifiedGroupKFold.
## Rows are slices, so `name` is the group key (every slice of a patient lands on
## the same side) and `ICH-majority` is the stratification target.
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold

meta_df = pd.read_parquet("/kaggle/input/b1-cat-meta/b1_metadata.parquet") # load files
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

# Only the first fold is used. `train_idx` / `val_idx` (slice-row positions)
# stay module-level because the leakage-check cell reads them.
train_idx, val_idx = next(
    sgkf.split(X=meta_df[["name"]], y=meta_df["ICH-majority"], groups=meta_df["name"])
)

# The index arrays address slices; de-duplicate to one entry per patient so
# the printed counts are patient counts.
train_patients = sorted(meta_df.iloc[train_idx]["name"].unique())
val_patients = sorted(meta_df.iloc[val_idx]["name"].unique())
print(f"train = {len(train_patients)}, val = {len(val_patients)}")

train_meta = meta_df[meta_df["name"].isin(train_patients)]
val_meta = meta_df[meta_df["name"].isin(val_patients)]
# One row per patient name.
train_meta[["name"]].drop_duplicates().to_parquet("train_patients.parquet", index=False)
val_meta[["name"]].drop_duplicates().to_parquet("val_patients.parquet", index=False)

In [None]:
## Sanity check: no patient name may appear in both the train and validation indices.
train_fold_names = set(meta_df.iloc[train_idx]["name"])
val_fold_names = set(meta_df.iloc[val_idx]["name"])
assert not (train_fold_names & val_fold_names), "Leakage"

In [22]:
# DataLoader Class for 2.5D CNN
# Imports
from pathlib import Path
from typing import List, Sequence, Tuple, Dict, Optional

import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import cv2
import pydicom
import pydicom.dataset
from pydicom.pixel_data_handlers.util import _apply_modality_lut

# Helper
## DICOM to Tensor
def dcm_to_tensor(
    path: str,
    windows: Optional[list[tuple[int, int]]] = None,
    out_size: Optional[tuple[int, int]] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Helper: convert one DICOM slice into a torch.Tensor [C x H x W].

    path     = path to a single CQ500 DICOM slice (str or path-like).
    windows  = CT windows as (level, width) pairs; each becomes one output
               channel. Default = Brain (40, 80), Subdural (80, 200),
               Bone (600, 2800).
    out_size = (H, W) to resize the slice to. Default = (256, 256).
    dtype    = final tensor precision (float32 recommended).

    Returns  = torch.Tensor of shape (len(windows), H, W); each channel is the
               slice clipped to its window and scaled into [0, 1].
    Raises   = ValueError if `windows` is empty or any window width is <= 0.
    """
    ## Handle None args
    if windows is None:
        windows = [(40, 80), (80, 200), (600, 2800)]
    if out_size is None:
        out_size = (256, 256)
    out_size = tuple(out_size)  # accept lists; cv2 and the shape check need a tuple

    ## Validate windows before touching the disk (width is a divisor below).
    if len(windows) == 0:
        raise ValueError("windows must contain at least one (level, width) pair")
    for level, width in windows:
        if width <= 0:
            raise ValueError(f"window width must be positive, got {width} (level={level})")

    ds: pydicom.dataset.FileDataset = pydicom.dcmread(str(path))

    ## Raw stored values -> Hounsfield units via the modality LUT, then int16
    ## (any fractional part is truncated by the cast).
    hu: np.ndarray = _apply_modality_lut(ds.pixel_array, ds).astype(np.int16)
    if hu.shape != out_size:
        # cv2.resize takes dsize as (W, H), hence the reversal.
        hu = cv2.resize(hu, out_size[::-1], interpolation=cv2.INTER_LINEAR)

    ## One [0, 1] float channel per window
    chans: list[np.ndarray] = []
    for level, width in windows:
        lower: int = level - (width // 2)
        upper: int = level + (width // 2)
        img_norm: np.ndarray = (np.clip(hu, lower, upper) - lower) / float(width)
        chans.append(img_norm.astype(np.float32))

    ## Stack and convert to tensor
    arr: np.ndarray = np.stack(chans, axis=0)  # C x H x W
    return torch.from_numpy(arr).type(dtype)


# Class
class CQ500DataLoader25D(Dataset):
    """
    Iterable dataset that yields (x, y) pairs.
    x = (3, H, W) windows for slice i, with slice context \
    [i-1, i, i+1].
    y = series-level ICH label.
    Optionally caches complete series volumes in RAM.
    """
    def __init__(
        self,
        metadata_path: str | Path,
        indices: Optional[Sequence[int]] = None,
        transform = None,
        cache: bool = False,
        replicate_edge: bool = True,
    ) -> None:
        super().__init__()
        self.df = pd.read_parquet(metadata_path)

        ## Restric to desired subset of patients (int idx list)
        grouped = self.df.groupby("name", sort = False)
        self.patients: List[Tuple[str, pd.DataFrame]] = list(grouped)
        if indices is not None:
            if isinstance(indices[0], int):
                self.patients = [self.patients[i] for i in indices]
            elif isinstance(indices[0], str):
                name_to_idx = {name: (name, pdf) for name, pdf in self.patients}
                self.patients = [name_to_idx[name] for name in indices if name in name_to_idx]
            else:
                raise ValueError(
                    "The argument for indices must be a list of int or str (patient names)"
                )

        ## Flatten into sample idx list (patient_idx, slice_idx)
        self.sample_index: List[Tuple[int, int]] = []
        for p_idx, (_, pdf) in enumerate(self.patients):
            pdf_sorted = pdf.sort_values("instance_num")
            n_slices = len(pdf_sorted)
            for s_idx in range(n_slices):
                self.sample_index.append((p_idx, s_idx))

        self.transform = transform
        self.cache_enabled = cache
        self.replicate_edge = replicate_edge
        self._series_cache: Dict[str, torch.Tensor] = {}

        ## Cache if caching is set to True
        if self.cache_enabled:
            self._populate_cache()

    def __len__(self) -> int:
        return len(self.sample_index)

    def _populate_cache(self) -> None:
        """
        Preload all series volumes into RAM (lazy-converted to torch.Tensor).
        """
        print("[CQ500DataLoader25D] Caching series data …")
        for patient_name, pdf in self.patients:
            pdf_sorted = pdf.sort_values("instance_num")
            slices = [dcm_to_tensor(path = p) for p in pdf_sorted["path"].tolist()]
            vol_np = np.stack(slices, axis=0)  # (S, 3, H, W)
            self._series_cache[patient_name] = torch.from_numpy(vol_np)
        print(f"→ Cached {len(self._series_cache)} series.")

    ## Public API
    def enable_cache(self) -> None:
        """Enable caching automatically."""
        if not self.cache_enabled:
            self.cache_enabled = True
            self._populate_cache()
    def disable_cache(self) -> None:
        """Disable caching automatically."""
        if self.cache_enabled:
            self.cache_enabled = False
            self._series_cache.clear()

    def __getitem__(
        self, idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        patient_idx, slice_idx = self.sample_index[idx]
        patient_name, pdf = self.patients[patient_idx]
        label = torch.tensor(pdf["ICH-majority"].iloc[0], dtype=torch.float32)

        # Load the volume form cache or disk on-the-fly
        if self.cache_enabled and patient_name in self._series_cache:
            volume = self._series_cache[patient_name]   # (S, 3, H, W) torch.Tensor
        else:
            pdf_sorted = pdf.sort_values("instance_num")
            paths = pdf_sorted["path"].tolist()
            slices = [dcm_to_tensor(path = p) for p in paths]
            volume = torch.from_numpy(np.stack(slices, axis = 0))   # (S, 3, H, W)

        n_slices = volume.shape[0]

        # Determine neighboring indices with optional edge replication
        def safe_index(i: int) -> int:
            if self.replicate_edge:
                return max(0, min(i, n_slices - 1))
            return i
        prev_idx = safe_index(slice_idx - 1)
        next_idx = safe_index(slice_idx + 1)

        x_stack = torch.stack(
            [volume[prev_idx], volume[slice_idx], volume[next_idx]], dim = 0
        )   # (3, 3, H, W)
        ## Merge HU channel and slice context dims (slice_ctx, hu_ch, H, W).\
        ## The convention is to keep the slice context and drop the HU channel.\
        ## Because we windowed into HU spaces, each slice context acts as one channel.\
        x = x_stack.mean(dim = 1)   # (3, H, W) - dim=1 for HU channel
        # x = x_stack.permute(1, 0, 2, 3).reshape(-1, x_stack.shape[2], x_stack.shape[3]) # (9, H, W)

        if self.transform:
            x = self.transform(x)

        return x, label

    # Preprocess to .pt files
    def preprocess(
        self, output_dir: str,
        chunk_size: int = 100,
        max_ram_gb: int = 12
    ) -> None:
        """
        Preprocess DICOMs: convert to tensors and save to disk \
        in small chunks to avoid exceeding RAM.
        Args:
            output_dir: Directory to save tensors (will be created if not exists).
            chunk_size: Number of slices to process and save at once.
            max_ram_gb: Maximum RAM (in GB) to use for holding tensors in memory at once.
        """
        os.makedirs(output_dir, exist_ok=True)
        tensor_buffer = []
        meta_buffer = []
        buffer_bytes = 0
        max_bytes = max_ram_gb * 1024 ** 3
        chunk_idx = 0
        print(f"[Preprocess] Saving tensors to {output_dir} \
            in chunks of {chunk_size} slices, max RAM {max_ram_gb} GB...")
        for patient_name, pdf in self.patients:
            pdf_sorted = pdf.sort_values("instance_num")
            for _, row in pdf_sorted.iterrows():
                dcm_path = row["path"]
                instance_num = row["instance_num"]
                tensor = dcm_to_tensor(dcm_path)
                tensor_buffer.append(tensor)
                meta_buffer.append((patient_name, instance_num, dcm_path))
                buffer_bytes += tensor.element_size() * tensor.nelement()
                # Save chunk if buffer is large enough
                if len(tensor_buffer) >= chunk_size or buffer_bytes >= max_bytes:
                    chunk_file = os.path.join(output_dir, f"chunk_{chunk_idx:05d}.pt")
                    meta_file = os.path.join(output_dir, f"chunk_{chunk_idx:05d}_meta.csv")
                    torch.save(torch.stack(tensor_buffer), chunk_file)
                    pd.DataFrame(
                        meta_buffer, columns=["patient_name", "instance_num", "dcm_path"]
                    ).to_csv(meta_file, index=False)
                    print(f"  Saved {len(tensor_buffer)} slices to {chunk_file}")
                    tensor_buffer.clear()
                    meta_buffer.clear()
                    buffer_bytes = 0
                    chunk_idx += 1
        # Save any remaining tensors
        if tensor_buffer:
            chunk_file = os.path.join(output_dir, f"chunk_{chunk_idx:05d}.pt")
            meta_file = os.path.join(output_dir, f"chunk_{chunk_idx:05d}_meta.csv")
            torch.save(torch.stack(tensor_buffer), chunk_file)
            pd.DataFrame(
                meta_buffer, columns=["patient_name", "instance_num", "dcm_path"]
            ).to_csv(meta_file, index=False)
            print(f"  Saved {len(tensor_buffer)} slices to {chunk_file}")
        print("[Preprocess] Done.")

In [23]:
# Model Class for 2.5 CNN (WIP)
"""
Model pipeline
"""
# 2p5d_cnn_train.py
import torch.nn as nn
import torch, torchvision as tv
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from torchmetrics.classification import (
    BinaryAUROC, BinaryAveragePrecision,
    BinaryRecall, BinarySpecificity
)
from tqdm.notebook import tqdm


# ---------- 1. Model ----------------------------------------------------------
def _replace_first_conv(m: nn.Module, in_ch: int) -> None:
    """Replace the first conv to accept `in_ch` channels; keeps pretrained weights."""
    old = m.conv1
    new = nn.Conv2d(in_ch, old.out_channels,
                    kernel_size=old.kernel_size,
                    stride=old.stride,
                    padding=old.padding,
                    bias=old.bias is not None)
    # repeat / average weights to new conv (simple heuristic)
    with torch.no_grad():
        repeat = in_ch // old.in_channels
        new.weight.copy_(old.weight.repeat(1, repeat, 1, 1) / repeat)
    m.conv1 = new


_BACKBONES = {
    "resnet18": lambda ic: tv.models.resnet18(weights="IMAGENET1K_V1"),
    "resnet34": lambda ic: tv.models.resnet34(weights="IMAGENET1K_V1"),
    "densenet121": lambda ic: tv.models.densenet121(weights="IMAGENET1K_V1"),
    # add more here...
}


class TwoPointFiveD(nn.Module):
    """
    2.5D binary ICH classifier: a pretrained 2D backbone (first conv adapted
    to `in_channels`) whose own classification layer is replaced by an
    identity, followed by a single-logit linear head.
    """
    def __init__(self, backbone_name: str = "resnet18",
                 in_channels: int = 3):            # 3 slices (HU channel is reduced)
        super().__init__()
        net = _BACKBONES[backbone_name](in_channels)
        _replace_first_conv(net, in_channels)

        # Find the backbone's classification layer (ResNet `fc`, then
        # DenseNet `classifier`), record its input width, and neutralise it.
        head_attr = None
        for candidate in ("fc", "classifier"):
            if hasattr(net, candidate):
                head_attr = candidate
                break
        if head_attr is None:
            raise ValueError("Add support for this backbone")
        feat_dim = getattr(net, head_attr).in_features
        setattr(net, head_attr, nn.Identity())

        self.backbone = net
        self.classifier = nn.Linear(feat_dim, 1)  # logits

    def forward(self, x):
        """Return one raw logit per sample, shape (N,)."""
        features = self.backbone(x)
        logits = self.classifier(features)        # (N, 1)
        return logits.squeeze(1)


# ---------- 2. Metrics helpers -----------------------------------------------
def make_metric_dict(device, threshold: float = 0.5):
    """
    Build a fresh set of binary-classification metrics on `device`.

    Args:
        device: torch device (or device string) to hold the metric states.
        threshold: probability cut-off used by sensitivity and specificity
            (AUROC and PR-AUC are threshold-free). Defaults to 0.5.

    Returns:
        dict with keys "auroc", "prauc", "sens" (recall) and "spec".
    """
    return {
        "auroc": BinaryAUROC().to(device),
        "prauc": BinaryAveragePrecision().to(device),
        "sens":  BinaryRecall(threshold=threshold).to(device),        # sensitivity
        "spec":  BinarySpecificity(threshold=threshold).to(device),   # specificity
    }


def update_metrics(metrics, preds, targets):
    """
    Feed one batch to every metric in `metrics`.

    `preds` is passed through unchanged; `targets` is cast with `.int()`
    for each update call.
    """
    for key in metrics:
        metrics[key].update(preds, targets.int())


def compute_and_reset(metrics):
    """
    Compute every metric as a Python float, then reset all of them.

    All values are computed before any reset; the returned dict has the same
    keys (and order) as `metrics`.
    """
    results = {}
    for key, metric in metrics.items():
        results[key] = float(metric.compute())
    for metric in metrics.values():
        metric.reset()
    return results


# ---------- 3. Train / Val loops ---------------------------------------------
@torch.no_grad()
def validate(model, loader, loss_fn, device, metrics):
    """
    Evaluate `model` over `loader` without gradients.

    Args:
        model: module returning one logit per sample.
        loader: iterable of (x, y) batches.
        loss_fn: loss on (logits, float targets); assumed to use mean
            reduction, since the running total re-weights it by batch size.
        device: torch.device batches are moved to.
        metrics: dict of metrics fed sigmoid probabilities; computed and
            reset at the end.

    Returns:
        dict of metric values plus "loss": mean per-sample loss over the
        samples actually iterated (nan if the loader yielded nothing).
    """
    model.eval()
    loop = tqdm(loader, desc="val", leave=False)
    total_loss = 0.0
    n_seen = 0
    for x, y in loop:
        x, y = x.to(device, non_blocking=True), y.float().to(device, non_blocking=True)
        logits = model(x)
        loss = loss_fn(logits, y)
        batch_n = y.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n

        probs = torch.sigmoid(logits)
        update_metrics(metrics, probs, y)
    stats = compute_and_reset(metrics)
    # Normalise by samples actually seen, not len(loader.dataset): the two
    # differ with subset samplers or drop_last=True.
    stats["loss"] = total_loss / n_seen if n_seen else float("nan")
    return stats


def train_one_epoch(model, loader, optimizer, scaler, loss_fn,
                    device, metrics, epoch):
    """
    Run one training epoch with autocast and the given gradient scaler.

    Args:
        model, loader, optimizer, scaler, loss_fn, device: training components;
            `loss_fn` is assumed to use mean reduction.
        metrics: dict of metrics fed sigmoid probabilities; computed and reset
            at the end of the epoch.
        epoch: epoch number, used only in the progress-bar label.

    Returns:
        dict of training metric values plus "loss": mean per-sample training
        loss over the epoch (nan if the loader yielded nothing).
    """
    model.train()
    loop = tqdm(loader, desc=f"train {epoch}")
    total_loss = 0.0
    n_seen = 0
    for x, y in loop:
        x, y = x.to(device, non_blocking=True), y.float().to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type=device.type):
            logits = model(x)
            loss = loss_fn(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        probs = torch.sigmoid(logits.detach())
        update_metrics(metrics, probs, y)

        batch_loss = loss.item()
        total_loss += batch_loss * y.size(0)
        n_seen += y.size(0)
        loop.set_postfix(loss=batch_loss)
    stats = compute_and_reset(metrics)
    stats["loss"] = total_loss / n_seen if n_seen else float("nan")
    return stats


# ---------- 4. Fit routine with early-stopping -------------------------------
def fit(model, train_loader, val_loader,
        epochs=20, patience=3, lr=3e-4, weight_decay=1e-4,
        save_path="best_auc.pt"):
    """
    Train `model` with AdamW + autocast, checkpointing on validation AUROC.

    After each epoch the model is validated; whenever val AUROC beats the best
    so far, `model.state_dict()` is written to `save_path`. Training stops
    early after `patience` consecutive epochs without improvement. The model
    is left with the weights of the last epoch run — reload `save_path` for
    the best checkpoint.

    Returns:
        The best validation AUROC seen (0.0 if no epoch exceeded 0.0).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    loss_fn = nn.BCEWithLogitsLoss()

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Loss scaling is only needed for float16 autocast on CUDA. The scaler's
    # default device is CUDA (per the torch.amp.GradScaler docs), so enable it
    # explicitly only there; when disabled, scale()/step() are pass-throughs.
    scaler = GradScaler(enabled=device.type == "cuda")

    best_auc, epochs_no_improve = 0.0, 0
    for ep in range(1, epochs + 1):
        train_metrics = train_one_epoch(model, train_loader, opt, scaler,
                                        loss_fn, device,
                                        make_metric_dict(device), ep)
        val_metrics = validate(model, val_loader, loss_fn, device,
                               make_metric_dict(device))

        print(f"\nEpoch {ep}: "
              f"AUROC {val_metrics['auroc']:.4f}  "
              f"PRAUC {val_metrics['prauc']:.4f}  "
              f"Sens {val_metrics['sens']:.4f}  "
              f"Spec {val_metrics['spec']:.4f}")
        print(f"  train AUROC {train_metrics['auroc']:.4f}  "
              f"val loss {val_metrics['loss']:.4f}")

        cur_auc = val_metrics["auroc"]
        if cur_auc > best_auc:
            best_auc = cur_auc
            torch.save(model.state_dict(), save_path)
            epochs_no_improve = 0
            print("  ↑ saved new best weights")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early-stopping (no AUROC improvement for {patience} epochs)")
                break
    return best_auc


# ---------- 5. Example usage --------------------------------------------------
if __name__ == "__main__":
    # Inputs written by earlier cells (in /kaggle/working):
    #   train_patients.parquet / val_patients.parquet  -> column "name"
    #   b1_metadata.parquet                           -> slice-level metadata
    # CQ500DataLoader25D returns
    #   x: (3, H, W) float32  |  y: binary label 0/1  (ICH-majority)

    #from data_loader_25d import CQ500DataLoader25D   # ← adjust import

    # Deduplicated patient names in sorted order: list(set(...)) depends on
    # the per-process str hash seed, so patient order changed between runs.
    # New names also avoid overwriting the fold index arrays `train_idx` /
    # `val_idx` that the leakage-check cell reads.
    train_patient_names = sorted(
        pd.read_parquet("/kaggle/working/train_patients.parquet")["name"].unique()
    )
    val_patient_names = sorted(
        pd.read_parquet("/kaggle/working/val_patients.parquet")["name"].unique()
    )

    train_set = CQ500DataLoader25D("/kaggle/working/b1_metadata.parquet",
                                   indices=train_patient_names)
    val_set   = CQ500DataLoader25D("/kaggle/working/b1_metadata.parquet",
                                   indices=val_patient_names)
    # NOTE(review): the last run of this call failed inside torch.save with
    # "unexpected pos ..."; that error is commonly reported for short writes
    # (e.g. a full disk) — check free space in the output directory.
    train_set.preprocess(output_dir="pp_train_tensors", chunk_size=100, max_ram_gb=20)

    #current_train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    #current_val_loader   = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

    #current_model = TwoPointFiveD(backbone_name="resnet18", in_channels=3)
    #fit(current_model, current_train_loader, current_val_loader, epochs=30, patience=3)


[Preprocess] Saving tensors to pp_train_tensors             in chunks of 100 slices, max RAM 20 GB...
  Saved 100 slices to pp_train_tensors/chunk_00000.pt
  Saved 100 slices to pp_train_tensors/chunk_00001.pt
  Saved 100 slices to pp_train_tensors/chunk_00002.pt
  Saved 100 slices to pp_train_tensors/chunk_00003.pt
  Saved 100 slices to pp_train_tensors/chunk_00004.pt
  Saved 100 slices to pp_train_tensors/chunk_00005.pt
  Saved 100 slices to pp_train_tensors/chunk_00006.pt
  Saved 100 slices to pp_train_tensors/chunk_00007.pt
  Saved 100 slices to pp_train_tensors/chunk_00008.pt
  Saved 100 slices to pp_train_tensors/chunk_00009.pt
  Saved 100 slices to pp_train_tensors/chunk_00010.pt
  Saved 100 slices to pp_train_tensors/chunk_00011.pt
  Saved 100 slices to pp_train_tensors/chunk_00012.pt
  Saved 100 slices to pp_train_tensors/chunk_00013.pt
  Saved 100 slices to pp_train_tensors/chunk_00014.pt
  Saved 100 slices to pp_train_tensors/chunk_00015.pt
  Saved 100 slices to pp_train_ten

RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 448 vs 342