In [1]:
import os
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset

class BratsDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.root_dir = root_dir
        self.samples = sorted(os.listdir(root_dir))
        self.patch_size = patch_size

    def __len__(self):
        return len(self.samples)

    def load_nii(self, path):
        return nib.load(path).get_fdata().astype(np.float32)

    def random_crop(self, img, mask, size):
        _, D, H, W = img.shape

        # đảm bảo D/H/W > size
        d = np.random.randint(0, max(1, D - size))
        h = np.random.randint(0, max(1, H - size))
        w = np.random.randint(0, max(1, W - size))

        return (
            img[:, d:d+size, h:h+size, w:w+size],
            mask[d:d+size, h:h+size, w:w+size]
        )

    def __getitem__(self, idx):
        case = self.samples[idx]
        folder = os.path.join(self.root_dir, case)

        # flair = self.load_nii(os.path.join(folder, case + "_flair.nii.gz"))
        t1 = self.load_nii(os.path.join(folder, case + "_t1.nii.gz"))
        # t1ce = self.load_nii(os.path.join(folder, case + "_t1ce.nii.gz"))
        # t2 = self.load_nii(os.path.join(folder, case + "_t2.nii.gz"))
        mask = self.load_nii(os.path.join(folder, case + "_seg.nii.gz"))
        mask[mask == 4] = 3  # convert ET label 4 → 3

        # stack modal
        image = np.stack([ t1], axis=0)
        # image = np.stack([flair, t1, t1ce, t2], axis=0)
        # normalize
        image = (image - image.mean()) / (image.std() + 1e-6)

        # crop patch 128³
        image, mask = self.random_crop(image, mask, self.patch_size)

        return (
            torch.tensor(image, dtype=torch.float32),
            torch.tensor(mask, dtype=torch.long)
        )


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNet3D(nn.Module):
    def __init__(self, n_channels=4, n_classes=4):
        super().__init__()

        self.enc1 = DoubleConv(n_channels, 32)
        self.enc2 = DoubleConv(32, 64)
        self.enc3 = DoubleConv(64, 128)
        self.enc4 = DoubleConv(128, 256)

        self.pool = nn.MaxPool3d(2)

        self.bottleneck = DoubleConv(256, 512)

        self.up4 = nn.ConvTranspose3d(512, 256, 2, stride=2)
        self.dec4 = DoubleConv(512, 256)

        self.up3 = nn.ConvTranspose3d(256, 128, 2, stride=2)
        self.dec3 = DoubleConv(256, 128)

        self.up2 = nn.ConvTranspose3d(128, 64, 2, stride=2)
        self.dec2 = DoubleConv(128, 64)

        self.up1 = nn.ConvTranspose3d(64, 32, 2, stride=2)
        self.dec1 = DoubleConv(64, 32)

        self.out_conv = nn.Conv3d(32, n_classes, kernel_size=1)

    def forward(self, x):
        c1 = self.enc1(x)
        p1 = self.pool(c1)

        c2 = self.enc2(p1)
        p2 = self.pool(c2)

        c3 = self.enc3(p2)
        p3 = self.pool(c3)

        c4 = self.enc4(p3)
        p4 = self.pool(c4)

        bn = self.bottleneck(p4)

        u4 = self.up4(bn)
        u4 = torch.cat([u4, c4], dim=1)
        c5 = self.dec4(u4)

        u3 = self.up3(c5)
        u3 = torch.cat([u3, c3], dim=1)
        c6 = self.dec3(u3)

        u2 = self.up2(c6)
        u2 = torch.cat([u2, c2], dim=1)
        c7 = self.dec2(u2)

        u1 = self.up1(c7)
        u1 = torch.cat([u1, c1], dim=1)
        c8 = self.dec1(u1)

        return self.out_conv(c8)


In [3]:
def dice_loss(pred, target, eps=1e-6):
    pred = torch.softmax(pred, dim=1)
    target_1hot = F.one_hot(target, pred.shape[1]).permute(0,4,1,2,3)

    intersection = (pred * target_1hot).sum(dim=(0,2,3,4))
    union = pred.sum(dim=(0,2,3,4)) + target_1hot.sum(dim=(0,2,3,4))

    dice = (2 * intersection + eps) / (union + eps)
    return 1 - dice.mean()

def combined_loss(pred, target):
    ce = F.cross_entropy(pred, target)
    dl = dice_loss(pred, target)
    return ce + dl


In [4]:
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import nibabel as nib

# Giả sử bạn đã import BratsDataset và UNet3D
# from dataset import BratsDataset
# from model import UNet3D

# 1. Khởi tạo Dataset gốc
full_ds = BratsDataset("BraTS2021_Training_Data")

# 2. Tính toán kích thước cho từng tập (Tỉ lệ 8:1:1)
total_size = len(full_ds)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size # Lấy phần còn lại để đảm bảo tổng không bị lệch do làm tròn

print(f"Tổng số mẫu: {total_size}")
print(f"Train: {train_size}, Val: {val_size}, Test: {test_size}")

# 3. Thực hiện chia ngẫu nhiên (Dùng generator để cố định seed giúp kết quả lặp lại được)
generator = torch.Generator().manual_seed(42) 
train_set, val_set, test_set = random_split(full_ds, [train_size, val_size, test_size], generator=generator)

# 4. Tạo DataLoader cho từng tập
# Lưu ý: batch_size=16 cho 3D là RẤT LỚN, dễ bị tràn VRAM (OOM). 
# Với 3D UNet thường chỉ để batch_size=1 hoặc 2 tùy GPU.
batch_size = 24

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_set,   batch_size=batch_size, shuffle=False, num_workers=4) # Val không cần shuffle
test_loader  = DataLoader(test_set,  batch_size=1,          shuffle=False, num_workers=4) # Test thường batch=1 để đánh giá từng ca

# 5. Khởi tạo Model và Optimizer
model = UNet3D(n_channels=1, n_classes=4).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Tổng số mẫu: 1251
Train: 1000, Val: 125, Test: 126


In [5]:
def calculate_dice(preds, targets, num_classes=4):
    """
    preds: Output của model (Logits) [Batch, C, D, H, W]
    targets: Ground Truth [Batch, D, H, W]
    """
    # Chuyển logits thành xác suất rồi lấy class có xác suất cao nhất
    preds = torch.argmax(torch.softmax(preds, dim=1), dim=1) # [B, D, H, W]
    
    dice_per_class = []
    # Bỏ qua class 0 (Background) vì nó chiếm đa số, tính vào sẽ làm ảo chỉ số
    for c in range(1, num_classes):
        pred_c = (preds == c)
        target_c = (targets == c)
        
        intersection = (pred_c & target_c).float().sum()
        union = pred_c.float().sum() + target_c.float().sum()
        
        if union == 0:
            dice = 1.0 # Cả 2 đều không có class này => dự đoán đúng
        else:
            dice = (2.0 * intersection) / (union + 1e-8) # +epsilon để tránh chia cho 0
        dice_per_class.append(dice.item())
        
    return sum(dice_per_class) / len(dice_per_class) # Trả về Dice trung bình của 3 class (1, 2, 3)

# --- 2. Vòng lặp Training & Validation ---
best_dice = 0.0 # Biến để theo dõi kết quả tốt nhất

In [None]:
for epoch in range(20):
    model.train()
    losses = []
    for img, mask in tqdm(train_loader):
        img = img.cuda()
        mask = mask.cuda()

        pred = model(img)
        loss = combined_loss(pred, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    print(f"Epoch {epoch + 1}: Loss = {sum(losses)/len(losses):.4f}")
    if epoch%3 == 1:
        model.eval()
        val_losses = []
        val_dices = []
        
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/20 [Valid]")
            for img, mask in val_pbar:
                img = img.cuda()
                mask = mask.cuda()

                pred = model(img)
                
                # 1. Tính Loss
                loss = combined_loss(pred, mask)
                val_losses.append(loss.item())
                
                # 2. Tính Dice Score (Metric đánh giá thực tế)
                dice = calculate_dice(pred, mask, num_classes=4)
                val_dices.append(dice)
                
                val_pbar.set_postfix({'val_loss': loss.item(), 'dice': dice})

        avg_val_loss = sum(val_losses) / len(val_losses)
        avg_val_dice = sum(val_dices) / len(val_dices)

        # ================= LOGGING & SAVE =================
        print(f"\nEND EPOCH {epoch+1}:")
        # print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Valid Loss: {avg_val_loss:.4f} | Valid Dice: {avg_val_dice:.4f}")

        # Chỉ lưu model nếu Dice score cải thiện
        if avg_val_dice > best_dice:
            print(f"  >>> Model Improved (Dice: {best_dice:.4f} -> {avg_val_dice:.4f}). Saving...")
            torch.save(model.state_dict(), "best_unet3d_brats.pth")
            best_dice = avg_val_dice
        
        print("-" * 50)

  0%|          | 0/42 [00:10<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 6.00 GiB. GPU 0 has a total capacity of 139.72 GiB of which 3.09 GiB is free. Process 211501 has 123.35 GiB memory in use. Process 232482 has 13.26 GiB memory in use. Of the allocated memory 12.65 GiB is allocated by PyTorch, and 13.06 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [1]:
import torch
torch.cuda.empty_cache()