In [1]:
import os
import numpy as np
import nibabel as nib
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


In [2]:


DATA_PATH = "/kaggle/input/notebook73fe287098/BraTS2021_Training_Data"

all_files = []
for root, dirs, files in os.walk(DATA_PATH):
    for file in files:
        all_files.append(os.path.join(root, file))

print(f"Total files: {len(all_files)}")
all_files[:10]  # preview first 10


Total files: 6256


['/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/.DS_Store',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01030/BraTS2021_01030_flair.nii.gz',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01030/BraTS2021_01030_t1ce.nii.gz',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01030/BraTS2021_01030_t2.nii.gz',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01030/BraTS2021_01030_t1.nii.gz',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01030/BraTS2021_01030_seg.nii.gz',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01358/BraTS2021_01358_t2.nii.gz',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01358/BraTS2021_01358_seg.nii.gz',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01358/BraTS2021_01358_t1.nii.gz',
 '/kaggle/input/notebook73fe287098/BraTS2021_Training_Data/BraTS2021_01358/BraTS2021_01358_t1ce.ni

In [3]:
def load_nifti(path):
    return nib.load(path).get_fdata()
def normalize(img):
    return (img - img.min()) / (img.max() - img.min() + 1e-8)
def resize_image(img, size=(256, 256)):
    img = torch.tensor(img, dtype=torch.float32)
    img = F.interpolate(img.unsqueeze(0).unsqueeze(0),
                        size=size, mode="bilinear",
                        align_corners=False)
    return img.squeeze().numpy()
def resize_mask(mask, size=(256, 256)):
    mask = torch.tensor(mask, dtype=torch.float32)
    mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0),
                         size=size, mode="nearest")
    return mask.squeeze().numpy()


In [4]:
patients = sorted([
    p for p in os.listdir(DATA_PATH)
    if p.startswith("BraTS2021_")
])

print("Total patients:", len(patients))
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

class BraTSDataset(Dataset):
    def __init__(self, data_path, indices, slice_start=10, slice_end=145):
        self.data_path = data_path
        self.patients = sorted([
            p for p in os.listdir(data_path)
            if p.startswith("BraTS2021_")
        ])
        self.indices = indices
        self.slice_start = slice_start
        self.slice_end = slice_end

    def __len__(self):
        return len(self.indices) * (self.slice_end - self.slice_start)

    def load_nifti(self, path):
        return nib.load(path).get_fdata()

    def normalize(self, img):
        return (img - img.min()) / (img.max() - img.min() + 1e-8)

    def resize_image(self, img, size=(256, 256)):
        img = torch.tensor(img, dtype=torch.float32)
        img = F.interpolate(img.unsqueeze(0).unsqueeze(0),
                            size=size, mode="bilinear",
                            align_corners=False)
        return img.squeeze()

    def resize_mask(self, mask, size=(256, 256)):
        mask = torch.tensor(mask, dtype=torch.float32)
        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0),
                             size=size, mode="nearest")
        return mask.squeeze()

    def __getitem__(self, idx):
        patient_idx = self.indices[idx // (self.slice_end - self.slice_start)]
        slice_idx = self.slice_start + idx % (self.slice_end - self.slice_start)

        patient_id = self.patients[patient_idx]
        base = os.path.join(self.data_path, patient_id)

        t1    = self.load_nifti(os.path.join(base, f"{patient_id}_t1.nii.gz"))
        t1ce  = self.load_nifti(os.path.join(base, f"{patient_id}_t1ce.nii.gz"))
        t2    = self.load_nifti(os.path.join(base, f"{patient_id}_t2.nii.gz"))
        flair = self.load_nifti(os.path.join(base, f"{patient_id}_flair.nii.gz"))
        seg   = self.load_nifti(os.path.join(base, f"{patient_id}_seg.nii.gz"))

        t1    = self.normalize(t1[:, :, slice_idx])
        t1ce  = self.normalize(t1ce[:, :, slice_idx])
        t2    = self.normalize(t2[:, :, slice_idx])
        flair = self.normalize(flair[:, :, slice_idx])

        mask = (seg[:, :, slice_idx] > 0).astype(np.float32)

        t1    = self.resize_image(t1)
        t1ce  = self.resize_image(t1ce)
        t2    = self.resize_image(t2)
        flair = self.resize_image(flair)
        mask  = self.resize_mask(mask)

        x = torch.stack([t1, t1ce, t2, flair], dim=0)

        return x, mask




Total patients: 1251


In [5]:
import torch
import torch.nn as nn

class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


In [6]:
class Unet(nn.Module):
    def __init__(self, in_channels=4, out_channels=1):
        super().__init__()

        self.enc1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = conv_block(512, 1024)

        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)

        e2 = self.enc2(p1)
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)
        p3 = self.pool3(e3)

        e4 = self.enc4(p3)
        p4 = self.pool4(e4)

        b = self.bottleneck(p4)

        d4 = self.up4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)

        d3 = self.up3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        return self.final(d1)



In [7]:
def dice_loss(preds,targets, smooth=1e-6):
    preds = torch.sigmoid(preds) 
    preds = preds.view(-1)
    targets = targets.view(-1)
    intersection = (preds * targets).sum()
    dice = (2. * intersection + smooth) / (preds.sum() + targets.sum() + smooth)
    return 1 - dice 

def bce_dice_loss(preds, targets):
    bce = F.binary_cross_entropy_with_logits(preds, targets)
    dice = dice_loss(preds, targets)
    return bce + dice

In [8]:
def dice_score(preds, targets, smooth=1e-6):
    preds = torch.sigmoid(preds)
    preds = (preds > 0.5).float()
    preds = preds.view(-1)
    targets = targets.view(-1)
    intersection = (preds * targets).sum()
    return (2. * intersection + smooth) / (preds.sum() + targets.sum() + smooth)


def iou_score(preds, targets, smooth=1e-6):
    preds = torch.sigmoid(preds)
    preds = (preds > 0.5).float()
    preds = preds.view(-1)
    targets = targets.view(-1)
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum() - intersection
    return (intersection + smooth) / (union + smooth)

In [9]:
import os
import random

DATA_PATH = "/kaggle/input/notebook73fe287098/BraTS2021_Training_Data"

patients = sorted([
    p for p in os.listdir(DATA_PATH)
    if p.startswith("BraTS2021_")
])

random.seed(42)

num_patients = len(patients)
indices = list(range(num_patients))
random.shuffle(indices)

split_idx = int(0.8 * num_patients)

train_indices = indices[:split_idx]
val_indices   = indices[split_idx:]

print(f"Total patients : {num_patients}")
print(f"Train patients : {len(train_indices)}")
print(f"Val patients   : {len(val_indices)}")


Total patients : 1251
Train patients : 1000
Val patients   : 251


In [10]:
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Unet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [11]:
from tqdm import tqdm

def train_one_epoch(model, loader, optimizer, max_batches=3000):
    model.train()
    total_loss = 0
    num_batches = 0

    for i, (images, masks) in enumerate(tqdm(loader, desc="Training")):
        if i >= max_batches:
            break

        images = images.to(device)
        masks = masks.unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = bce_dice_loss(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches


In [12]:
@torch.no_grad()
def validate(model, loader, max_batches=500):
    model.eval()
    dice_total = 0
    iou_total = 0
    num_batches = 0

    for i, (images, masks) in enumerate(loader):
        if i >= max_batches:
            break

        images = images.to(device)
        masks = masks.unsqueeze(1).to(device)

        outputs = model(images)

        dice_total += dice_score(outputs, masks)
        iou_total += iou_score(outputs, masks)
        num_batches += 1

    return dice_total / num_batches, iou_total / num_batches



In [13]:
from torch.utils.data import DataLoader

train_dataset = BraTSDataset(DATA_PATH, indices=train_indices)
val_dataset   = BraTSDataset(DATA_PATH, indices=val_indices)


train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)


In [14]:
epochs = 10

for epoch in range(epochs):
    train_loss = train_one_epoch(
        model,
        train_loader,
        optimizer,
        max_batches=3000
    )

    val_dice, val_iou = validate(
        model,
        val_loader,
        max_batches=500
    )

    print(f"Epoch [{epoch+1}/{epochs}]")
    print(f"Train Loss : {train_loss:.4f}")
    print(f"Val Dice   : {val_dice:.4f}")
    print(f"Val IoU    : {val_iou:.4f}")
    print("-" * 40)



Training:   9%|▉         | 3000/33750 [50:10<8:34:20,  1.00s/it]


Epoch [1/10]
Train Loss : 0.4954
Val Dice   : 0.7381
Val IoU    : 0.7004
----------------------------------------


Training:   9%|▉         | 3000/33750 [56:58<9:44:01,  1.14s/it]


Epoch [2/10]
Train Loss : 0.3214
Val Dice   : 0.7052
Val IoU    : 0.6659
----------------------------------------


Training:   9%|▉         | 3000/33750 [57:07<9:45:28,  1.14s/it]


Epoch [3/10]
Train Loss : 0.2983
Val Dice   : 0.7692
Val IoU    : 0.7338
----------------------------------------


Training:   9%|▉         | 3000/33750 [55:20<9:27:19,  1.11s/it] 


Epoch [4/10]
Train Loss : 0.2864
Val Dice   : 0.7921
Val IoU    : 0.7569
----------------------------------------


Training:   9%|▉         | 3000/33750 [56:12<9:36:12,  1.12s/it] 


Epoch [5/10]
Train Loss : 0.2784
Val Dice   : 0.7519
Val IoU    : 0.7170
----------------------------------------


Training:   9%|▉         | 3000/33750 [54:47<9:21:40,  1.10s/it] 


Epoch [6/10]
Train Loss : 0.2628
Val Dice   : 0.7069
Val IoU    : 0.6696
----------------------------------------


Training:   9%|▉         | 3000/33750 [58:02<9:54:55,  1.16s/it] 


Epoch [7/10]
Train Loss : 0.2542
Val Dice   : 0.7676
Val IoU    : 0.7294
----------------------------------------


Training:   9%|▉         | 3000/33750 [56:13<9:36:15,  1.12s/it] 


Epoch [8/10]
Train Loss : 0.2598
Val Dice   : 0.8093
Val IoU    : 0.7724
----------------------------------------


Training:   9%|▉         | 3000/33750 [52:44<9:00:35,  1.05s/it] 


Epoch [9/10]
Train Loss : 0.2525
Val Dice   : 0.8058
Val IoU    : 0.7682
----------------------------------------


Training:   9%|▉         | 3000/33750 [57:15<9:46:52,  1.15s/it] 


Epoch [10/10]
Train Loss : 0.2388
Val Dice   : 0.8165
Val IoU    : 0.7819
----------------------------------------


In [15]:
import torch

torch.save(model.state_dict(), "unet_brats.pth")
print("Model saved as unet_brats.pth")


Model saved as unet_brats.pth
