In [None]:
!pip install nibabel
!pip install -U scikit-learn
!pip install torch torchvision

import os
import zipfile
import requests
import tarfile
import random
import numpy as np
import nibabel as nib
from itertools import combinations
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import subprocess


In [None]:
def load_data(data_path, modalities, fraction=1.0):
    """Load a random subset of patient folders as stacked image volumes plus masks.

    Each sub-directory ``<patient>`` of `data_path` is treated as one patient and
    must contain ``<patient>_<modality>.nii.gz`` for every requested modality and
    ``<patient>_seg.nii.gz`` for the segmentation mask.

    Args:
        data_path: directory whose sub-directories are patient folders.
        modalities: modality suffixes to load, e.g. ``['t1', 'flair']``; their
            order is the channel order of the returned images.
        fraction: share of patient folders to load, in (0, 1]. Patients are
            picked by shuffling with the global ``random`` module, so call
            ``random.seed(...)`` first for a reproducible subset.

    Returns:
        ``(images, masks)``: ``images`` is float32 with the modalities stacked on
        the LAST axis (channels-last); ``masks`` is float32 holding the raw label
        values of each ``_seg`` file.

    Raises:
        ValueError: if `fraction` is outside (0, 1] or no patient folder is selected.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")

    # Sort before shuffling: os.listdir order is arbitrary, so sorting makes the
    # selected subset depend only on the RNG state. Plain files (e.g. stray
    # archives or notes) cannot be patient folders and are skipped.
    patient_folders = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )
    random.shuffle(patient_folders)
    selected_patients = patient_folders[:int(len(patient_folders) * fraction)]
    if not selected_patients:
        raise ValueError(
            f"No patient folders selected from {data_path!r} (fraction={fraction})"
        )

    data_list = []
    mask_list = []
    for patient_folder in selected_patients:
        patient_path = os.path.join(data_path, patient_folder)

        image_data = [
            nib.load(
                os.path.join(patient_path, f"{patient_folder}_{modality}.nii.gz")
            ).get_fdata(dtype=np.float32)
            for modality in modalities
        ]
        data_list.append(np.stack(image_data, axis=-1))

        # float32 like the images (half the memory of the float64 default);
        # segmentation labels are presumably small integer ids, exact in float32.
        mask_file = os.path.join(patient_path, f"{patient_folder}_seg.nii.gz")
        mask_list.append(nib.load(mask_file).get_fdata(dtype=np.float32))

    # np.stack fails with a clear message if volumes differ in shape.
    return np.stack(data_list), np.stack(mask_list)

def generate_modality_combinations(modalities):
    """Return every non-empty subset of `modalities`, each as a list.

    Subsets are ordered by size (singletons first, the full set last) and,
    within one size, in `itertools.combinations` order.
    """
    return [
        list(subset)
        for size in range(1, len(modalities) + 1)
        for subset in combinations(modalities, size)
    ]


class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3D, self).__init__()
        # Encoder
        self.enc1 = self.conv_block(in_channels, 32)
        self.enc2 = self.conv_block(32, 64)
        self.enc3 = self.conv_block(64, 128)
        self.enc4 = self.conv_block(128, 256)
        
        self.pool = nn.MaxPool3d(2)
        
        # Decoder
        self.upconv3 = self.upconv_block(256, 128)
        self.dec3 = self.conv_block(256, 128)
        self.upconv2 = self.upconv_block(128, 64)
        self.dec2 = self.conv_block(128, 64)
        self.upconv1 = self.upconv_block(64, 32)
        self.dec1 = self.conv_block(64, 32)
        
        self.out_conv = nn.Conv3d(32, out_channels, kernel_size=1)
    
    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        
        dec3 = self.dec3(torch.cat((self.upconv3(enc4), enc3), dim=1))
        dec2 = self.dec2(torch.cat((self.upconv2(dec3), enc2), dim=1))
        dec1 = self.dec1(torch.cat((self.upconv1(dec2), enc1), dim=1))
        
        return self.out_conv(dec1)


In [None]:
current_directory = os.getcwd()
data_path = os.path.join(current_directory, "data")
os.makedirs(data_path, exist_ok=True)
!pip install kaggle

%env KAGGLE_USERNAME=ihindal
%env KAGGLE_KEY=549e8a0e9862683f6f255cb289ece9de
import kaggle
subprocess.run(["kaggle", "datasets", "download", "-d", "dschettler8845/brats-2021-task1", "-p", data_path])


In [None]:
# Extract the Kaggle download (a zip wrapping the BraTS training tarball) into
# `<data dir>/dataset`, then point `data_path` at the extracted patient folders.
# os, tarfile and zipfile are imported in the first cell.

# Default to the folder the download cell writes to, so Restart & Run All works.
# Set BRATS_DATA_DIR to use another copy (e.g. /content/drive/MyDrive/brats on Colab).
data_path = os.environ.get("BRATS_DATA_DIR", os.path.join(os.getcwd(), "data"))

zip_file_path = os.path.join(data_path, "brats-2021-task1.zip")
tar_file_path = os.path.join(data_path, "BraTS2021_Training_Data.tar")
dataset_path = os.path.join(data_path, "dataset")
# Written only after both extraction steps succeed. Re-running the cell then skips
# extraction, while an interrupted run is redone. It lives in data_path, not in
# dataset_path, so it is never listed as a patient folder.
done_marker = os.path.join(data_path, ".brats_extracted")

if not os.path.exists(done_marker):
    # Unzip the dataset
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(data_path)

    # Untar the training data into its own directory
    os.makedirs(dataset_path, exist_ok=True)
    with tarfile.open(tar_file_path, 'r') as tar_ref:
        # The archive comes from the internet. Where the running Python provides
        # extraction filters, the "data" filter rejects absolute paths, ".."
        # components and links that escape the destination directory.
        if hasattr(tarfile, "data_filter"):
            tar_ref.extractall(dataset_path, filter="data")
        else:
            tar_ref.extractall(dataset_path)

    open(done_marker, "w").close()

# Downstream cells load patients from the extracted dataset.
data_path = dataset_path


In [None]:
# Define the parameters
SEED = 42
all_modalities = ['t1', 't1ce', 't2', 'flair']
modality_combinations = generate_modality_combinations(all_modalities)

batch_size_per_gpu = 4
num_gpus = 1
total_batch_size = batch_size_per_gpu * num_gpus

fraction = 0.1    # share of patient folders loaded for every experiment
num_epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): each batch holds whole volumes and goes through a full-resolution
# 32-channel encoder, which may exceed GPU memory depending on the volume size.
# Consider random patch crops if training runs out of memory. TODO confirm.


def _batch_tensors(X, y, idx, device):
    """Build one (inputs, labels) batch on `device` from the sample indices `idx`.

    `load_data` stacks modalities on the LAST axis, but Conv3d expects
    (N, C, D, H, W), so the channel axis is moved to position 1. Masks are
    binarized with `> 0` because BCEWithLogitsLoss needs targets in [0, 1],
    while the raw masks hold label ids. The binarization presumably gives a
    whole-tumour target. TODO confirm this is the intended task.
    """
    inputs = torch.from_numpy(np.ascontiguousarray(np.moveaxis(X[idx], -1, 1))).float()
    labels = torch.from_numpy(y[idx] > 0).float().unsqueeze(1)
    return inputs.to(device), labels.to(device)


def _run_epoch(model, X, y, criterion, batch_size, device, optimizer=None):
    """Make one pass over (X, y) and return the mean per-sample loss.

    When `optimizer` is given, the pass trains: it visits samples in shuffled
    order and takes a gradient step per batch. Otherwise it evaluates in order
    with gradients disabled. Each batch's mean loss is weighted by the batch
    size, so a short final batch is not over-counted.
    """
    training = optimizer is not None
    model.train(training)
    order = np.random.permutation(len(X)) if training else np.arange(len(X))
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            inputs, labels = _batch_tensors(X, y, idx, device)
            loss = criterion(model(inputs), labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * len(idx)
    return total_loss / max(len(X), 1)


# Loop through all possible combinations of modalities
for modalities in modality_combinations:
    print(f"Training with modalities: {modalities}")
    run_name = "_".join(modalities)

    # Re-seed for every combination. load_data shuffles with the global `random`
    # module, so every combination gets the same patients and the same
    # train/val split. The weight-init and batch-order RNG state also match,
    # which keeps the per-modality results comparable.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Load the data and split it into training and validation sets
    X, y = load_data(data_path, modalities, fraction)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=SEED)
    del X, y  # only the split arrays are used from here on

    # Create the model
    model = UNet3D(len(modalities), 1).to(device)

    # Scale the learning rate linearly with the total batch size (1e-3 at 16)
    lr = 1e-3 * (total_batch_size / 16)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    # The context manager closes the TensorBoard writer even if training raises
    with SummaryWriter(log_dir=f'./logs/3DUNet_{run_name}') as writer:
        for epoch in range(num_epochs):
            train_loss = _run_epoch(model, X_train, y_train, criterion,
                                    batch_size_per_gpu, device, optimizer)
            writer.add_scalar('train_loss', train_loss, epoch)

            val_loss = _run_epoch(model, X_val, y_val, criterion,
                                  batch_size_per_gpu, device)
            writer.add_scalar('val_loss', val_loss, epoch)

    # Save the model
    torch.save(model.state_dict(), f'3DUNet_{run_name}.pth')

    # Release this combination's model and data before the next one is built
    del model, optimizer, X_train, X_val, y_train, y_val
    if device.type == 'cuda':
        torch.cuda.empty_cache()
