In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
#import matplotlib.pyplot as plt
import random
import torch.nn.functional as F


In [2]:
# Dataset Class
class BraTSDataset(Dataset):
    def __init__(self, img_dir, mask_dir):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_list = os.listdir(img_dir)
        self.mask_list = os.listdir(mask_dir)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_list[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_list[idx])

        img = np.load(img_path)
        mask = np.load(mask_path)
        mask = np.argmax(mask, axis=3)  # Convert one-hot encoded mask to single channel

        img = torch.tensor(img, dtype=torch.float32).permute(3, 0, 1, 2)  # Channels first
        mask = torch.tensor(mask, dtype=torch.long)

        return img, mask

# Define the 3D U-Net Model
class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3D, self).__init__()
        self.c1 = self.conv_block(in_channels, 16)
        self.c2 = self.conv_block(16, 32)
        self.c3 = self.conv_block(32, 64)
        self.c4 = self.conv_block(64, 128)
        self.c5 = self.conv_block(128, 256)

        self.u6 = self.upconv_block(256, 128)
        self.u7 = self.upconv_block(128, 64)
        self.u8 = self.upconv_block(64, 32)
        self.u9 = self.upconv_block(32, 16)

        self.output_conv = nn.Conv3d(16, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels, dropout_prob=0.1):
        layers = [
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout_prob),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        ]
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels, dropout_prob=0.2):
        layers = [
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout_prob),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        enc1 = self.c1(x)
        enc2 = self.c2(F.max_pool3d(enc1, kernel_size=2, stride=2))
        enc3 = self.c3(F.max_pool3d(enc2, kernel_size=2, stride=2))
        enc4 = self.c4(F.max_pool3d(enc3, kernel_size=2, stride=2))
        enc5 = self.c5(F.max_pool3d(enc4, kernel_size=2, stride=2))

        dec6 = F.interpolate(enc5, scale_factor=2, mode='trilinear', align_corners=True)
        dec6 = torch.cat([dec6, enc4], dim=1)

        dec7 = F.interpolate(dec6, scale_factor=2, mode='trilinear', align_corners=True)
        dec7 = torch.cat([dec7, enc3], dim=1)

        dec8 = F.interpolate(dec7, scale_factor=2, mode='trilinear', align_corners=True)
        dec8 = torch.cat([dec8, enc2], dim=1)

        dec9 = F.interpolate(dec8, scale_factor=2, mode='trilinear', align_corners=True)
        dec9 = torch.cat([dec9, enc1], dim=1)

        output = self.output_conv(dec9)
        return output

# Directories
train_img_dir = "../data/BraTS2020_TrainingData/input_data_128/train/images/"
train_mask_dir = "../data/BraTS2020_TrainingData/input_data_128/train/masks/"
val_img_dir = "../data/BraTS2020_TrainingData/input_data_128/val/images/"
val_mask_dir = "../data/BraTS2020_TrainingData/input_data_128/val/masks/"

# Datasets and DataLoaders
batch_size = 1
train_dataset = BraTSDataset(train_img_dir, train_mask_dir)
val_dataset = BraTSDataset(val_img_dir, val_mask_dir)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3D(in_channels=4, out_channels=4).to(device)  # Update input channels if necessary

# Loss and Optimizer
class_weights = torch.tensor([0.25, 0.25, 0.25, 0.25], dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training Loop
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# Save the model
torch.save(model.state_dict(), "brats_3d_model.pth")

# Plot Sample Predictions
def plot_sample_prediction(data_loader, model, device):
    model.eval()
    imgs, masks = next(iter(data_loader))
    imgs, masks = imgs.to(device), masks.to(device)

    with torch.no_grad():
        outputs = model(imgs)
        preds = torch.argmax(outputs, dim=1).cpu().numpy()

    n_slice = random.randint(0, masks.shape[2] - 1)

    plt.figure(figsize=(12, 8))
    plt.subplot(1, 3, 1)
    plt.imshow(imgs[0, 0, :, :, n_slice].cpu(), cmap='gray')
    plt.title('Input Image')

    plt.subplot(1, 3, 2)
    plt.imshow(masks[0, :, :, n_slice].cpu(), cmap='gray')
    plt.title('Ground Truth')

    plt.subplot(1, 3, 3)
    plt.imshow(preds[0, :, :, n_slice], cmap='gray')
    plt.title('Prediction')

    plt.show()

plot_sample_prediction(val_loader, model, device)


RuntimeError: Given groups=1, weight of size [16, 4, 3, 3, 3], expected input[1, 3, 128, 128, 128] to have 4 channels, but got 3 channels instead