In [1]:
# Brain Tumor Segmentation for MRI on Kaggle
import torch
torch.cuda.empty_cache()

import os
import sys
from pathlib import Path
import numpy as np
import nibabel as nib
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim
from sklearn.metrics import jaccard_score
import optuna
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# ==========================================
# Dataset
# ==========================================
class BrainPatchesDataset(Dataset):
    """Dataset of paired image/mask patches stored as ``.npz`` archives.

    Image archives are read from ``data_dir_X`` (array key ``"X"``) and mask
    archives from ``data_dir_y`` (array key ``"y"``). Both directories are
    listed in sorted order and paired by position, so the i-th image file is
    matched with the i-th mask file.

    Args:
        data_dir_X: Directory containing the image ``.npz`` files.
        data_dir_y: Directory containing the mask ``.npz`` files.
        transform: Optional callable taking ``(X_tensor, y_tensor)`` and
            returning the transformed pair.

    Raises:
        AssertionError: If the two directories contain a different number of
            ``.npz`` files.
    """

    def __init__(self, data_dir_X, data_dir_y, transform=None):
        self.X_paths = sorted(Path(data_dir_X).glob("*.npz"))
        self.y_paths = sorted(Path(data_dir_y).glob("*.npz"))
        assert len(self.X_paths) == len(self.y_paths), "Mismatch in number of X and y patches"
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.X_paths)

    @staticmethod
    def _load_array(path, key):
        """Read one array from an ``.npz`` archive as float32 and close the file.

        ``np.load`` on an ``.npz`` returns a lazily-read archive object that
        keeps the underlying file open; the context manager releases the
        handle as soon as the array has been materialised.
        """
        with np.load(path) as archive:
            return archive[key].astype(np.float32)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair at position ``idx``."""
        # Load the data from the .npz files.
        X = self._load_array(self.X_paths[idx], "X")
        y = self._load_array(self.y_paths[idx], "y")

        # Convert to tensors.
        X_tensor = torch.from_numpy(X)
        y_tensor = torch.from_numpy(y)

        # Add a leading channel axis when the stored array is only 3-D.
        if X_tensor.ndim == 3:
            X_tensor = X_tensor.unsqueeze(0)
        if y_tensor.ndim == 3:
            y_tensor = y_tensor.unsqueeze(0)

        if self.transform:
            X_tensor, y_tensor = self.transform(X_tensor, y_tensor)

        return X_tensor, y_tensor


In [3]:
# ==========================================
# Model
# ==========================================

class DoubleConv(nn.Module):
    """Two consecutive Conv3d(3x3x3, padding 1) -> BatchNorm3d -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Spatial size is preserved by the padding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Same layer order as a hand-written Sequential, so parameter names
        # (double_conv.0, .1, .3, .4) stay identical.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        features = self.double_conv(x)
        return features

class AttentionGate(nn.Module):
    """Additive attention gate that rescales a skip feature map.

    Args:
        F_g: Channel count of the gating signal ``g``.
        F_l: Channel count of the skip features ``x``.
        F_int: Channel count of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch):
            # 1x1x1 convolution followed by batch normalisation.
            return [
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm3d(out_ch),
            ]

        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by a single-channel attention map in (0, 1).

        ``g`` and ``x`` are projected to ``F_int`` channels, summed, passed
        through ReLU, then reduced to one sigmoid-activated channel that is
        broadcast over the channels of ``x``.
        """
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)
        attention = self.psi(self.relu(gate_proj + skip_proj))
        return x * attention

# ======================================
# === Attention U-Net 3D ===
# ======================================

class AttentionUNet3D(nn.Module):
    """Three-level 3-D U-Net whose skip connections pass through attention gates.

    The encoder halves the spatial size three times with ``MaxPool3d(2)``;
    the decoder doubles it back with stride-2 transposed convolutions. Each
    skip tensor is gated by an ``AttentionGate`` driven by the up-sampled
    decoder features before being concatenated.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of output channels (raw logits, no activation).
        base_channels: Channel width of the first encoder level; deeper
            levels use 2x, 4x and 8x this width.

    Note:
        Every spatial dimension must be at least 8 so the bottleneck is not
        empty. Sizes that are not multiples of 8 are supported: max-pooling
        floors odd sizes, so the up-sampled map can be one voxel short of its
        skip tensor and is zero-padded to match before gating.
    """

    def __init__(self, in_channels=4, out_channels=3, base_channels=32):
        super().__init__()

        self.enc1 = DoubleConv(in_channels, base_channels)
        self.pool1 = nn.MaxPool3d(2)

        self.enc2 = DoubleConv(base_channels, base_channels*2)
        self.pool2 = nn.MaxPool3d(2)

        self.enc3 = DoubleConv(base_channels*2, base_channels*4)
        self.pool3 = nn.MaxPool3d(2)

        self.bottleneck = DoubleConv(base_channels*4, base_channels*8)

        self.up3 = nn.ConvTranspose3d(base_channels*8, base_channels*4, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=base_channels*4, F_l=base_channels*4, F_int=base_channels*2)
        self.dec3 = DoubleConv(base_channels*8, base_channels*4)

        self.up2 = nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=base_channels*2, F_l=base_channels*2, F_int=base_channels)
        self.dec2 = DoubleConv(base_channels*4, base_channels*2)

        self.up1 = nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=base_channels, F_l=base_channels, F_int=base_channels//2)
        self.dec1 = DoubleConv(base_channels*2, base_channels)

        self.out_conv = nn.Conv3d(base_channels, out_channels, kernel_size=1)

    @staticmethod
    def _match_spatial(up, skip):
        """Pad ``up`` so its spatial size equals that of ``skip``.

        Returns ``up`` untouched when the sizes already agree, so inputs whose
        dimensions are multiples of 8 take exactly the same path as before.
        """
        if up.shape[2:] == skip.shape[2:]:
            return up
        pads = []
        # F.pad expects (left, right) pairs starting from the LAST dimension.
        for up_size, skip_size in zip(reversed(up.shape[2:]), reversed(skip.shape[2:])):
            diff = skip_size - up_size
            pads.extend([diff // 2, diff - diff // 2])
        return F.pad(up, pads)

    def forward(self, x):
        """Return per-voxel logits with the same spatial size as ``x``."""
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        bottleneck = self.bottleneck(self.pool3(enc3))

        up3 = self._match_spatial(self.up3(bottleneck), enc3)
        att3 = self.att3(up3, enc3)
        dec3 = self.dec3(torch.cat([up3, att3], dim=1))

        up2 = self._match_spatial(self.up2(dec3), enc2)
        att2 = self.att2(up2, enc2)
        dec2 = self.dec2(torch.cat([up2, att2], dim=1))

        up1 = self._match_spatial(self.up1(dec2), enc1)
        att1 = self.att1(up1, enc1)
        dec1 = self.dec1(torch.cat([up1, att1], dim=1))

        return self.out_conv(dec1)

In [4]:
# ==========================================
# Loss function
# ==========================================
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid-activated logits.

    The Dice coefficient is evaluated separately for every (sample, channel)
    pair by reducing over all spatial axes, then averaged; the loss is
    ``1 - mean_dice``.

    Args:
        smooth: Constant added to numerator and denominator to keep the ratio
            defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return the scalar Dice loss.

        Args:
            preds: Raw logits of shape ``(batch, channels, *spatial)`` with at
                least one spatial axis (2-D and 3-D data both work).
            targets: Tensor with the same shape as ``preds``.

        Raises:
            ValueError: If ``preds`` has fewer than three dimensions.
            AssertionError: If ``preds`` and ``targets`` differ in shape.
        """
        if preds.ndim < 3:
            raise ValueError("DiceLoss expects input shaped (batch, channels, *spatial)")
        assert preds.shape == targets.shape, "Shape mismatch between preds and targets"

        probs = torch.sigmoid(preds)
        # Reduce over every axis after (batch, channel), whatever the rank.
        spatial_dims = tuple(range(2, probs.ndim))

        intersection = (probs * targets).sum(dim=spatial_dims)
        union = probs.sum(dim=spatial_dims) + targets.sum(dim=spatial_dims)
        dice = (2. * intersection + self.smooth) / (union + self.smooth)

        return 1 - dice.mean()  # mean over classes and batch

class CombinedLoss(nn.Module):
    """Weighted sum of soft Dice loss and binary cross-entropy on logits.

    Args:
        dice_weight: Multiplier applied to the Dice term (default 0.5).
        bce_weight: Multiplier applied to the BCE-with-logits term
            (default 0.5).
    """

    def __init__(self, dice_weight=0.5, bce_weight=0.5):
        super().__init__()
        self.dice = DiceLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight

    def forward(self, preds, targets):
        """Return ``dice_weight * dice + bce_weight * bce`` for raw logits ``preds``."""
        dice_term = self.dice(preds, targets)
        bce_term = self.bce(preds, targets)
        return self.dice_weight * dice_term + self.bce_weight * bce_term



In [5]:
# ==========================================
# Evaluation
# ==========================================
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import recall_score

def dice_score(pred, target, smooth=1e-5):
    """Return the Dice coefficient between two binary masks.

    Args:
        pred: Predicted mask (array-like supporting ``*`` and ``.sum()``).
        target: Reference mask with the same shape as ``pred``.
        smooth: Constant added to numerator and denominator. With the default,
            two empty masks score exactly 1.0 instead of dividing by zero.

    Returns:
        ``(2 * |pred ∩ target| + smooth) / (|pred| + |target| + smooth)``.
    """
    intersection = (pred * target).sum()
    total = pred.sum() + target.sum()
    return (2 * intersection + smooth) / (total + smooth)

def _tumor_regions(channels):
    """Collapse a 3-channel binary stack into WT / TC / EC integer masks.

    Channel meaning follows the notebook's own convention: channel 0 is
    class 1 (necrosis), channel 2 is class 3 (enhancing), and whole tumor is
    the union of all channels.
    """
    return {
        "WT": (channels.sum(axis=0) > 0).astype(int),           # any of class 1-3
        "TC": ((channels[0] + channels[2]) > 0).astype(int),    # class 1 + class 3
        "EC": (channels[2] > 0).astype(int),                    # class 3 only
    }


def _nanmean_or_nan(values):
    """Mean of the non-NaN entries, or NaN (without a warning) if there are none."""
    arr = np.asarray(values, dtype=float)
    valid = arr[~np.isnan(arr)]
    return valid.mean() if valid.size else np.nan


def evaluate_model_multiclass(model, val_loader, device):
    """Compute mean Dice and sensitivity for the WT, TC and EC regions.

    The model is switched to eval mode, its logits are passed through a
    sigmoid and thresholded at 0.5, and every sample in ``val_loader`` is
    scored individually.

    Args:
        model: Network returning 3-channel logits for a batch of images.
        val_loader: Iterable yielding ``(imgs, masks)`` batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        ``{"dice": {...}, "sensitivity": {...}}`` where each inner dict maps
        ``"WT"``, ``"TC"`` and ``"EC"`` to the mean over samples. Sensitivity
        skips samples whose reference region is empty; a region with no
        scored samples yields NaN.
    """
    model.eval()

    regions = ("WT", "TC", "EC")
    dice = {name: [] for name in regions}
    sensitivity = {name: [] for name in regions}

    with torch.no_grad():
        for imgs, masks in val_loader:
            outputs = model(imgs.to(device))
            # Threshold on-device, then copy the whole batch to host memory once.
            preds = (torch.sigmoid(outputs) > 0.5).float().cpu().numpy()
            targets = masks.cpu().numpy()

            for pred, target in zip(preds, targets):
                pred_regions = _tumor_regions(pred)
                true_regions = _tumor_regions(target)

                for name in regions:
                    pred_bin = pred_regions[name]
                    true_bin = true_regions[name]

                    dice[name].append(dice_score(pred_bin, true_bin))

                    # Sensitivity (recall) = TP / (TP + FN); undefined when the
                    # reference region is empty, so record NaN and skip it in
                    # the mean.
                    positives = true_bin.sum()
                    if positives == 0:
                        sensitivity[name].append(np.nan)
                    else:
                        true_positives = (pred_bin * true_bin).sum()
                        sensitivity[name].append(true_positives / positives)

    # Final results
    results = {
        "dice": {k: _nanmean_or_nan(v) for k, v in dice.items()},
        "sensitivity": {k: _nanmean_or_nan(v) for k, v in sensitivity.items()}
    }
    return results

In [6]:
# ==========================================
# Training
# ==========================================
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils import clip_grad_norm_

class EarlyStopping:
    """Signal that training should stop once a score stops improving.

    A score counts as an improvement only when it exceeds the best score seen
    so far by more than ``min_delta``. After ``patience`` consecutive calls
    without improvement, ``early_stop`` becomes True.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, current_score):
        """Record ``current_score`` (higher is better) and update the stop flag."""
        is_first = self.best_score is None
        if is_first or current_score > self.best_score + self.min_delta:
            # New best: remember it and restart the patience window.
            self.best_score = current_score
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


def _train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0

    for imgs, masks in train_loader:
        imgs = imgs.to(device)
        masks = masks.to(device)

        # Insert a channel axis when a batch arrives as a 4-D tensor.
        if imgs.dim() == 4:
            imgs = imgs.unsqueeze(1)
        if masks.dim() == 4:
            masks = masks.unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(imgs)

        # Resample the masks if the network output has a different shape.
        if masks.shape != outputs.shape:
            masks = F.interpolate(masks, size=outputs.shape[2:], mode='nearest')

        loss = criterion(outputs, masks)
        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item()

    return running_loss / len(train_loader)


def train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=10, save_path="/kaggle/working/results"):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Each epoch: train on ``train_loader``, evaluate on ``val_loader`` with
    ``evaluate_model_multiclass``, step a ``ReduceLROnPlateau`` scheduler,
    log scalars to TensorBoard under ``save_path/logs``, save
    ``best_model.pth`` whenever whole-tumor Dice improves, and always save
    ``last_model.pth`` plus the ``train_loss.npy`` / ``val_dice.npy``
    histories. Training stops early after 10 epochs without a whole-tumor
    Dice gain of more than 0.001.

    Args:
        model: Network to optimise (already moved to ``device``).
        train_loader: Loader yielding ``(imgs, masks)`` training batches.
        val_loader: Loader yielding ``(imgs, masks)`` validation batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimiser bound to ``model.parameters()``.
        device: Device batches are moved to.
        epochs: Maximum number of epochs.
        save_path: Directory for checkpoints, metric arrays and logs.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; cannot compute an average training loss")

    os.makedirs(save_path, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(save_path, "logs"))

    train_losses, val_dices = [], []
    best_dice = 0.0
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)
    # NOTE(review): the scheduler monitors the TRAINING loss, while model
    # selection and early stopping use validation Dice — confirm this is the
    # intended plateau signal.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)

    # try/finally so the TensorBoard writer is closed even if an epoch raises.
    try:
        for epoch in range(epochs):
            avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
            train_losses.append(avg_loss)

            metrics = evaluate_model_multiclass(model, val_loader, device)
            current_dice = metrics["dice"]["WT"]
            val_sens = metrics["sensitivity"]["WT"]
            val_dices.append(current_dice)

            scheduler.step(avg_loss)

            writer.add_scalar("Loss/train", avg_loss, epoch)
            # One entry per region; this covers WT too, so it is logged once.
            for cls in ["WT", "TC", "EC"]:
                writer.add_scalar(f"Val/Dice_{cls}", metrics["dice"][cls], epoch)
                writer.add_scalar(f"Val/Sens_{cls}", metrics["sensitivity"][cls], epoch)

            print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_loss:.4f} - Dice WT: {metrics['dice']['WT']:.4f} - Sens WT: {val_sens:.4f}")

            # NOTE(review): if `model` is wrapped in nn.DataParallel the saved
            # keys carry a "module." prefix; loading into an unwrapped model
            # needs that prefix stripped.
            if current_dice > best_dice:
                best_dice = current_dice
                torch.save(model.state_dict(), os.path.join(save_path, "best_model.pth"))
                writer.add_text("Model", f"Best model saved at epoch {epoch+1} (Dice WT: {current_dice:.4f})", epoch)
                print(f"✅ Best model saved (Dice WT: {current_dice:.4f})")

            # Save the metric histories and the latest model every epoch.
            np.save(os.path.join(save_path, "train_loss.npy"), train_losses)
            np.save(os.path.join(save_path, "val_dice.npy"), val_dices)
            torch.save(model.state_dict(), os.path.join(save_path, "last_model.pth"))
            writer.flush()

            early_stopping(current_dice)
            if early_stopping.early_stop:
                writer.add_text("EarlyStopping", f"Triggered at epoch {epoch+1}", epoch)
                print(f"⛔ Early stopping triggered at epoch {epoch+1}")
                break
    finally:
        writer.close()

In [7]:
# Report the visible GPUs. get_device_name(0) raises when no CUDA device is
# present, so only query it when CUDA is actually available.
if torch.cuda.is_available():
    print("Devices:", torch.cuda.device_count(), torch.cuda.get_device_name(0))
else:
    print("Devices: 0 (CUDA not available)")

Devices: 2 Tesla T4


In [8]:
# ==========================================
# Run Training
# ==========================================
# Root of the pre-extracted patch dataset; each split has images/ and masks/.
data_dir = Path("/kaggle/input/braintumor-dataset-patches")
train_dir = data_dir / "train"
val_dir = data_dir / "val"

train_dataset = BrainPatchesDataset(train_dir / "images", train_dir / "masks")
val_dataset = BrainPatchesDataset(val_dir / "images", val_dir / "masks")

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=4, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [9]:
# Use the GPU when one is present, otherwise fall back to the CPU.
gpu_count = torch.cuda.device_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Number of GPUs available: {gpu_count}")

# Build the network and replicate it across GPUs when more than one is visible.
model = AttentionUNet3D(in_channels=4, out_channels=3)
if gpu_count > 1:
    print(f"Using {gpu_count} GPUs with DataParallel.")
    model = torch.nn.DataParallel(model)
model = model.to(device)

criterion = CombinedLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Launch training; checkpoints and logs are written below save_path.
run_config = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=80,
    save_path="/kaggle/working/results/attention_unet",
)
train_model(**run_config)

Using device: cuda
Number of GPUs available: 2
Using 2 GPUs with DataParallel.
Epoch 1/80 - Train Loss: 0.4171 - Dice WT: 0.8495 - Sens WT: 0.8569
✅ Best model saved (Dice WT: 0.8495)
Epoch 2/80 - Train Loss: 0.2658 - Dice WT: 0.8486 - Sens WT: 0.8389
Epoch 3/80 - Train Loss: 0.2511 - Dice WT: 0.8196 - Sens WT: 0.7722
Epoch 4/80 - Train Loss: 0.2441 - Dice WT: 0.8446 - Sens WT: 0.8751
Epoch 5/80 - Train Loss: 0.2391 - Dice WT: 0.8622 - Sens WT: 0.8826
✅ Best model saved (Dice WT: 0.8622)
Epoch 6/80 - Train Loss: 0.2350 - Dice WT: 0.8654 - Sens WT: 0.8489
✅ Best model saved (Dice WT: 0.8654)
Epoch 7/80 - Train Loss: 0.2319 - Dice WT: 0.8682 - Sens WT: 0.9055
✅ Best model saved (Dice WT: 0.8682)
Epoch 8/80 - Train Loss: 0.2278 - Dice WT: 0.8711 - Sens WT: 0.8837
✅ Best model saved (Dice WT: 0.8711)
Epoch 9/80 - Train Loss: 0.2258 - Dice WT: 0.8563 - Sens WT: 0.8279
Epoch 10/80 - Train Loss: 0.2234 - Dice WT: 0.8597 - Sens WT: 0.8359
Epoch 11/80 - Train Loss: 0.2193 - Dice WT: 0.8698 - Se