In [1]:
!pip install -q medpy

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m156.3/156.3 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for medpy (setup.py) ... [?25l[?25hdone


In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
torch.cuda.empty_cache()
from torch.utils.tensorboard import SummaryWriter
import os
from medpy.metric import binary
from datetime import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


In [3]:
def roi_mask_func(x):
    t1ce, t2, flair = x[0], x[1], x[2]
    flair_bin = flair > 0.7
    t2_bin = t2 > 0.7
    t1ce_bin = t1ce > 0.9
    flair_t2_mask = np.logical_and(flair_bin, t2_bin)

    from skimage.measure import label, regionprops
    roi = np.zeros_like(t1ce_bin, dtype=bool)
    labeled = label(t1ce_bin)
    for region in regionprops(labeled):
        if region.solidity > 0.7 and region.area > 500 and region.major_axis_length > 35:
            tumor_candidate = labeled == region.label
            overlap = np.logical_and(tumor_candidate, flair_t2_mask)
            if np.sum(overlap) > 20:
                roi = np.logical_or(roi, tumor_candidate)
    return roi[None, :, :]


class BrainPatchesDataset(Dataset):
    def __init__(self, X_paths, y_paths, roi_mask_func=None):
        self.X_paths = X_paths
        self.y_paths = y_paths
        self.roi_mask_func = roi_mask_func

    def __len__(self):
        return len(self.X_paths)

    def __getitem__(self, idx):
        # Wczytaj pliki .npz
        X_npz = np.load(self.X_paths[idx])
        y_npz = np.load(self.y_paths[idx])

        # Oryginalne dane: [4, H, W, D]
        X = X_npz["X"].astype(np.float32)
        y = y_npz["y"].astype(np.float32)  # [3, H, W, D]

        # Wybieramy slice środkowy
        z = X.shape[-1] // 2
        X_slice = X[..., z]     # [4, H, W]
        y_slice = y[..., z]     # [3, H, W]

        # Z-score normalizacja kanał po kanale
        X_zscore = (X_slice - X_slice.mean(axis=(1, 2), keepdims=True)) / \
                   (X_slice.std(axis=(1, 2), keepdims=True) + 1e-8)

        # Połącz 4 oryginalne + 4 znormalizowane kanały → [8, H, W]
        X_combined = np.concatenate([X_slice, X_zscore], axis=0)

        # ROI (np. z FLAIR + T2 + T1ce), ale nie maskuje y
        if self.roi_mask_func:
            roi = self.roi_mask_func(X_slice[[1, 2, 3]])  # T1ce, T2, FLAIR
        else:
            roi = np.ones_like(X_combined[0:1])  # [1, H, W]

        return (
            torch.tensor(X_combined).float(),  # [8, H, W]
            torch.tensor(y_slice).float(),     # [3, H, W]
            torch.tensor(roi).float()          # [1, H, W]
        )


In [4]:
class DistanceWiseAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, feature_map, expected_mask):
        B, C, H, W = feature_map.shape
        attn = torch.zeros((B, 1, H, W), device=feature_map.device)
        for b in range(B):
            mask = expected_mask[b, 0] > 0
            if mask.sum() == 0: continue
            coords = mask.nonzero(as_tuple=False).float()
            yc, xc = coords.mean(dim=0)
            y = torch.arange(H, device=feature_map.device).view(H, 1).repeat(1, W)
            x = torch.arange(W, device=feature_map.device).view(1, W).repeat(H, 1)
            dist = torch.sqrt((x - xc)**2 + (y - yc)**2) / H
            attn[b, 0] = (1.0 - dist).clamp(0, 1)
        return feature_map * attn.repeat(1, C, 1, 1)


In [5]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class GlobalPath(nn.Module):
    def __init__(self, in_channels, out_channels=32):
        super().__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.proj = nn.Linear(128, out_channels)

    def forward(self, x):
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        x = self.proj(x)
        return x

class AttentionFusion(nn.Module):
    def __init__(self, local_dim, global_dim, fusion_dim):
        super().__init__()
        self.local_proj = nn.Conv2d(local_dim, fusion_dim, kernel_size=1)
        self.global_proj = nn.Linear(global_dim, fusion_dim)
        self.fusion = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(fusion_dim, fusion_dim, kernel_size=1),
            nn.BatchNorm2d(fusion_dim)
        )

    def forward(self, local_feat, global_feat):
        B, C, H, W = local_feat.shape
        local = self.local_proj(local_feat)
        global_ = self.global_proj(global_feat).unsqueeze(2).unsqueeze(3)
        global_ = global_.expand(-1, -1, H, W)
        fused = self.fusion(local * global_)
        return fused

class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.decode = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 2, out_channels, kernel_size=1)
        )

    def forward(self, x):
        return self.decode(x)


class CascadeNet(nn.Module):
    def __init__(self, in_channels=8, num_classes=3):
        super().__init__()
        self.local_path = nn.Sequential(
            ConvBlock(in_channels, 32),
            ConvBlock(32, 64),
            ConvBlock(64, 128)
        )
        self.global_path = GlobalPath(in_channels, out_channels=128)
        self.att_fusion = AttentionFusion(local_dim=128, global_dim=128, fusion_dim=128)
        self.distance_attn = DistanceWiseAttention()
        self.decoder = Decoder(in_channels=128, out_channels=3)

    def forward(self, x, roi):
        local = self.local_path(x)
        global_ = self.global_path(x)
        fused = self.att_fusion(local, global_)
        fused = self.distance_attn(fused, roi)
        out = self.decoder(fused)
        return out


In [6]:
def dice_score(pred, target, eps=1e-6):
    pred = torch.softmax(pred, dim=1)
    return [(2 * (pred[:, i] * target[:, i]).sum() + eps) /
            ((pred[:, i] + target[:, i]).sum() + eps) for i in range(target.shape[1])]

In [7]:
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001, mode='max', verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, current_score):
        if self.best_score is None:
            self.best_score = current_score
        elif self._is_improvement(current_score):
            if self.verbose:
                print(f"✅ Metric improved: {self.best_score:.5f} → {current_score:.5f}")
            self.best_score = current_score
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"⏳ No improvement: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

    def _is_improvement(self, score):
        if self.mode == 'max':
            return score > self.best_score + self.min_delta
        elif self.mode == 'min':
            return score < self.best_score - self.min_delta
        else:
            raise ValueError("mode must be 'max' or 'min'")


In [8]:
def train_model(model, dataloaders, optimizer, criterion, device, epochs=40, save_path="/kaggle/working"):
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(os.path.join(save_path, "pred_masks"), exist_ok=True)
    run_name = f"CascadeNet_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = SummaryWriter(log_dir=os.path.join(save_path, "logs", run_name))

    early_stopping = EarlyStopping(patience=10, min_delta=0.001, mode='max', verbose=True)
    best_dice = 0.0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        model.train()
        epoch_loss = 0
        train_dice_scores = []

        for x, y, roi in dataloaders['train']:
            x, y, roi = x.to(device), y.to(device), roi.to(device)
            target = torch.argmax(y, dim=1).long()  # [B, H, W]

            outputs = model(x, roi)
            loss = criterion(outputs, target)
            optimizer.zero_grad(); loss.backward(); optimizer.step()
            epoch_loss += loss.item()
            train_dice_scores.append([d.item() for d in dice_score(outputs, y)])

        avg_train_loss = epoch_loss / len(dataloaders['train'])
        train_dice_mean = np.mean(train_dice_scores, axis=0)

        # ======= VALIDATION ========
        model.eval()
        val_loss = 0
        val_dice_scores = []
        val_hd = []
        val_precision, val_recall, val_spec = [], [], []

        with torch.no_grad():
            for x, y, roi in dataloaders['val']:
                x, y, roi = x.to(device), y.to(device), roi.to(device)
                target = torch.argmax(y, dim=1).long()

                outputs = model(x, roi)
                loss = criterion(outputs, target)
                val_loss += loss.item()
                val_dice_scores.append([d.item() for d in dice_score(outputs, y)])

                # Precision/Recall/Specificity
                pred_bin = torch.softmax(outputs, dim=1) > 0.5
                for i in range(y.shape[1]):
                    p = pred_bin[:, i]
                    t = y[:, i].bool()
                    TP = (p & t).sum().item()
                    FP = (p & ~t).sum().item()
                    FN = (~p & t).sum().item()
                    TN = (~p & ~t).sum().item()
                    val_precision.append(TP / (TP + FP + 1e-6))
                    val_recall.append(TP / (TP + FN + 1e-6))
                    val_spec.append(TN / (TN + FP + 1e-6))


                # Hausdorff 95
                pred_np = (pred_bin.cpu().numpy()).astype(np.uint8)
                y_np = y.cpu().numpy().astype(np.uint8)
                hd_list = []
                for i in range(y.shape[1]):
                    try:
                        hd = binary.hd95(pred_np[:, i], y_np[:, i])
                    except:
                        hd = np.nan
                    hd_list.append(hd)
                val_hd.append(hd_list)

        val_dice_mean = np.mean(val_dice_scores, axis=0)
        val_dice_avg = np.mean(val_dice_mean)
        hd_mean = np.nanmean(val_hd, axis=0)
        prec_mean = np.mean(np.array(val_precision).reshape(-1, 3), axis=0)
        rec_mean = np.mean(np.array(val_recall).reshape(-1, 3), axis=0)
        spec_mean = np.mean(np.array(val_spec).reshape(-1, 3), axis=0)

        print(f"Train Loss: {avg_train_loss:.4f} | Val Dice Avg: {val_dice_avg:.4f}")
        print(f"Val Dice: WT={val_dice_mean[0]:.4f}, TC={val_dice_mean[1]:.4f}, EC={val_dice_mean[2]:.4f}")
        print(f"Val Precision: {prec_mean}, Recall: {rec_mean}, Specificity: {spec_mean}")
        print(f"Hausdorff95: {hd_mean}")

        # TensorBoard log
        writer.add_scalar("Loss/train", avg_train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss / len(dataloaders['val']), epoch)
        writer.add_scalar("Dice/Val_Mean", val_dice_avg, epoch)
        for i, cls in enumerate(["WT", "TC", "EC"]):
            writer.add_scalar(f"Dice/Val_{cls}", val_dice_mean[i], epoch)
            writer.add_scalar(f"Precision/Val_{cls}", prec_mean[i], epoch)
            writer.add_scalar(f"Recall/Val_{cls}", rec_mean[i], epoch)
            writer.add_scalar(f"Specificity/Val_{cls}", spec_mean[i], epoch)
            writer.add_scalar(f"Hausdorff95/Val_{cls}", hd_mean[i], epoch)

        # Save last model
        torch.save(model.state_dict(), os.path.join(save_path, "cascade_last_model.pth"))

        # Save best model + predicted masks
        if val_dice_avg > best_dice:
            best_dice = val_dice_avg
            torch.save(model.state_dict(), os.path.join(save_path, "cascade_best_model.pth"))
            print(f"✅ Best model saved (Val Dice Avg: {val_dice_avg:.4f})")

            # Save predicted masks
            model.eval()
            with torch.no_grad():
                for i, (x, _, roi) in enumerate(dataloaders['val']):
                    x, roi = x.to(device), roi.to(device)
                    outputs = model(x, roi)
                    preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1).cpu().numpy()
                    for j, pred in enumerate(preds):
                        name = f"val_{i * x.size(0) + j}"
                        np.save(os.path.join(save_path, "pred_masks", f"{name}.npy"), pred)

        # EarlyStopping
        early_stopping(val_dice_avg)
        if early_stopping.early_stop:
            print(f"⛔ Early stopping triggered at epoch {epoch + 1}")
            break

    writer.close()
    print("✅ Training complete.")



In [9]:
data_dir = Path("/kaggle/input/braintumor-dataset-patches")

train_X = sorted(list((data_dir / "train" / "images").glob("*.npz")))
train_y = sorted(list((data_dir / "train" / "masks").glob("*.npz")))
val_X = sorted(list((data_dir / "val" / "images").glob("*.npz")))
val_y = sorted(list((data_dir / "val" / "masks").glob("*.npz")))

train_ds = BrainPatchesDataset(train_X, train_y, roi_mask_func)
val_ds = BrainPatchesDataset(val_X, val_y, roi_mask_func)

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=2)
dataloaders = {'train': train_loader, 'val': val_loader}

model = CascadeNet(in_channels=8, num_classes=3).to(device)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs.")
    model = nn.DataParallel(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

train_model(
    model=model,
    dataloaders=dataloaders,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=100,
    save_path="/kaggle/working/CascadeNet"
)

Using 2 GPUs.

Epoch 1/100
Train Loss: 0.7421 | Val Dice Avg: 0.0908
Val Dice: WT=0.1085, TC=0.0518, EC=0.1119
Val Precision: [0.06068678 0.19563478 0.27530171], Recall: [0.98378675 0.04431727 0.18165256], Specificity: [0.00863501 0.99942782 0.99797314]
Hausdorff95: [58.75606959 40.99220812 37.82341429]
✅ Best model saved (Val Dice Avg: 0.0908)

Epoch 2/100
Train Loss: 0.3429 | Val Dice Avg: 0.1094
Val Dice: WT=0.1082, TC=0.0691, EC=0.1510
Val Precision: [0.05927296 0.17175187 0.26256501], Recall: [0.95862116 0.04932598 0.20976309], Specificity: [0.01068035 0.99932736 0.99663161]
Hausdorff95: [58.78961356 43.02377433 39.00866201]
✅ Best model saved (Val Dice Avg: 0.1094)
✅ Metric improved: 0.09075 → 0.10945

Epoch 3/100
Train Loss: 0.1887 | Val Dice Avg: 0.1084
Val Dice: WT=0.0937, TC=0.0861, EC=0.1454
Val Precision: [0.04945314 0.121473   0.16714626], Recall: [0.78212319 0.13790883 0.24853344], Specificity: [0.03122804 0.99086463 0.98398535]
Hausdorff95: [59.03371787 36.44548109 40.60