In [None]:
# NOTE(review): in the recorded run this unzip failed (archive not found at this path).
# The dataset is read from Google Drive (/content/drive/MyDrive/Data) further below instead.
!unzip /content/sample_data/Data.zip.ZIP -d /content/sample_data

unzip:  cannot find or open /content/sample_data/Data.zip.ZIP, /content/sample_data/Data.zip.ZIP.zip or /content/sample_data/Data.zip.ZIP.ZIP.


In [None]:
# Mount Google Drive so the dataset directory used below (/content/drive/MyDrive/Data) is reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Apply CLAHE
def apply_clahe(image, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to a grayscale image.

    Args:
        image: single-channel image, e.g. the array returned by
            ``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``.
        clip_limit: contrast clip limit forwarded to ``cv2.createCLAHE``.
            The default matches the value this notebook always used.
        tile_grid_size: ``(cols, rows)`` tile grid forwarded to ``cv2.createCLAHE``.

    Returns:
        The equalised image as produced by ``clahe.apply``.

    Raises:
        ValueError: if ``image`` is None. ``cv2.imread`` returns None for missing or
            unreadable files, and OpenCV's own error for that case is unhelpful.
    """
    if image is None:
        raise ValueError("apply_clahe received None; the image file could not be read.")
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(image)

# Custom Dataset Class for 3D Brain MRI
class Metastasis3DDataset(Dataset):
    """Paired grayscale image / mask ``.tif`` files found anywhere under ``data_dir``.

    Despite the name, each item is a single 2-D image read with
    ``cv2.IMREAD_GRAYSCALE``. A file whose name contains ``_mask`` is treated as a mask.
    It is paired with the image ``<name>.tif`` that sits in the same directory as
    ``<name>_mask.tif``. Files without a partner are skipped.

    ``image_filenames[i]`` and ``mask_filenames[i]`` always describe the same pair.
    Pairs are sorted by image path, so ordering is deterministic across runs and
    across dataset instances.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        # Collect every .tif under data_dir, split into masks and candidate images.
        image_paths = set()  # set: O(1) membership tests during pairing
        mask_paths = []
        for root, _, files in os.walk(data_dir):
            for f in files:
                if not f.endswith('.tif'):
                    continue
                path = os.path.join(root, f)
                if '_mask' in f:
                    mask_paths.append(path)
                else:
                    image_paths.add(path)

        # Build explicit (image, mask) tuples so the two lists can never drift apart.
        pairs = []
        for mask in mask_paths:
            image_name = os.path.basename(mask).replace('_mask.tif', '.tif')
            image_path = os.path.join(os.path.dirname(mask), image_name)
            if image_path in image_paths:
                pairs.append((image_path, mask))
        pairs.sort()

        self.image_filenames = [image for image, _ in pairs]
        self.mask_filenames = [mask for _, mask in pairs]

        # Report files that could not be paired (the two lists are equal-length by construction).
        unmatched = len(image_paths) + len(mask_paths) - 2 * len(pairs)
        if unmatched:
            print(f"Warning: {unmatched} .tif files had no image/mask partner and were skipped.")

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for pair ``idx``.

        The image is CLAHE-equalised; the mask is not. Without a transform, both are
        returned as PIL images.
        """
        img_path = self.image_filenames[idx]
        mask_path = self.mask_filenames[idx]

        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = apply_clahe(image)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Hand the uint8 arrays to ToPILImage directly (8-bit grayscale PIL images).
        # A ToTensor step in the transform will rescale them to [0, 1]; pre-dividing
        # by 255 only added a lossy float -> uint8 round-trip.
        image = transforms.ToPILImage()(image)
        mask = transforms.ToPILImage()(mask)

        if self.transform:
            # Replay the same torch RNG state for the mask so that random transforms
            # drawing from torch's global CPU RNG (e.g. torchvision's random flips)
            # make the SAME decision for image and mask. Transforms that use Python's
            # `random` module would not be synchronised by this.
            rng_state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(rng_state)
            mask = self.transform(mask)

        return image, mask

# Define Data Augmentation and Normalization
def get_transform(phase):
    """Build the torchvision pipeline for a dataset phase.

    ``'train'`` prepends random horizontal and vertical flips; any other phase name
    gets only ``ToTensor``.
    """
    steps = []
    if phase == 'train':
        steps.extend([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ])
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)

# Example usage
if __name__ == "__main__":
    # Smoke test: build the training dataset and print the shapes of a single batch.
    data_dir = "/content/drive/MyDrive/Data"  # Update this to your actual data directory
    train_dataset = Metastasis3DDataset(
        data_dir=data_dir,
        transform=get_transform('train'),
    )
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    # Pull just the first batch; an empty dataset yields None and prints nothing.
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        images, masks = first_batch
        print(images.shape, masks.shape)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Basic Convolution Block
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by an in-place ReLU.

    Padding of 1 keeps the spatial size unchanged. Only the first convolution changes
    the channel count (``in_channels`` -> ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.relu(conv(x))
        return x

# Up-sampling block for decoding path
class UpBlock(nn.Module):
    """Decoder step: 2x transposed-conv upsample, concatenate the skip, then a ConvBlock.

    The upsample maps ``in_channels`` -> ``out_channels``. After concatenation, the
    ConvBlock expects ``in_channels`` channels, so ``skip`` must carry
    ``in_channels - out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x, skip):
        x = self.up(x)
        # Encoder max-pooling floors odd sizes, so the upsampled map can be one pixel
        # short of the skip. Pad it to match; this is a no-op when sizes already agree.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat((x, skip), dim=1)  # Concatenate with the skip connection
        return self.conv(x)

# U-Net++ Architecture
class NestedUNet(nn.Module):
    """Encoder-decoder segmentation network with U-Net-style skip connections.

    Five ConvBlock encoder levels (64 -> 1024 channels, separated by 2x max-pooling)
    feed a four-step UpBlock decoder (up4_1 -> up3_2 -> up2_3 -> up1_4). A final 1x1
    convolution maps the 64-channel result to ``num_classes`` raw logits; no
    activation is applied.

    NOTE(review): despite the name, the U-Net++ nested blocks (up3_1, up2_2, up1_3,
    up2_1, up1_2, up1_1) never reach the output. They are therefore not evaluated in
    ``forward`` (their compute was pure overhead). The modules are still constructed
    so that existing state_dicts / checkpoints keep loading. Wiring them into dense
    skips or deep supervision would be a genuine architecture change.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output channels (logit maps).
    """

    def __init__(self, in_channels=1, num_classes=1):
        super(NestedUNet, self).__init__()

        # Encoder Path
        self.conv1_0 = ConvBlock(in_channels, 64)
        self.conv2_0 = ConvBlock(64, 128)
        self.conv3_0 = ConvBlock(128, 256)
        self.conv4_0 = ConvBlock(256, 512)
        self.conv5_0 = ConvBlock(512, 1024)

        # Decoder path actually used by forward()
        self.up4_1 = UpBlock(1024, 512)
        self.up3_2 = UpBlock(512, 256)
        self.up2_3 = UpBlock(256, 128)
        self.up1_4 = UpBlock(128, 64)

        # Nested blocks: registered for checkpoint compatibility, not used in forward()
        self.up3_1 = UpBlock(512, 256)
        self.up2_2 = UpBlock(256, 128)
        self.up1_3 = UpBlock(128, 64)

        self.up2_1 = UpBlock(256, 128)
        self.up1_2 = UpBlock(128, 64)

        self.up1_1 = UpBlock(128, 64)

        # Final Convolution: Adjust output channels according to num_classes
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits at the input's spatial resolution.

        Assumes ``x`` is laid out as (N, in_channels, H, W). There are four 2x
        poolings, so H and W divisible by 16 keep every skip connection aligned.
        """
        # Encoder
        x1_0 = self.conv1_0(x)
        x2_0 = self.conv2_0(F.max_pool2d(x1_0, 2))
        x3_0 = self.conv3_0(F.max_pool2d(x2_0, 2))
        x4_0 = self.conv4_0(F.max_pool2d(x3_0, 2))
        x5_0 = self.conv5_0(F.max_pool2d(x4_0, 2))

        # Decoder: upsample from the bottleneck, fusing one encoder skip per level
        x4_1 = self.up4_1(x5_0, x4_0)
        x3_2 = self.up3_2(x4_1, x3_0)
        x2_3 = self.up2_3(x3_2, x2_0)
        x1_4 = self.up1_4(x2_3, x1_0)

        # 1x1 convolution -> segmentation logits
        return self.final_conv(x1_4)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import os


# Loss Function: BCE + Dice
def bce_dice_loss(outputs, masks, smooth=1.0):
    """BCE-with-logits plus a differentiable (soft) Dice loss.

    The Dice term is computed on ``sigmoid(outputs)`` without thresholding. A
    thresholded prediction (as in ``dice_coeff``) has zero gradient, which silently
    reduced the old "BCE + Dice" loss to plain BCE.

    Args:
        outputs: raw logits, same shape as ``masks``; dim 0 is assumed to be the batch.
        masks: target masks with values in [0, 1].
        smooth: additive smoothing in the Dice ratio. 1.0 keeps the term stable on
            slices whose mask is empty.

    Returns:
        Scalar tensor: mean BCE + (1 - mean per-sample soft Dice).
    """
    bce_loss = nn.BCEWithLogitsLoss()(outputs, masks)

    probs = torch.sigmoid(outputs)
    dims = tuple(range(1, probs.dim()))  # reduce everything except the batch dim
    intersection = (probs * masks).sum(dim=dims)
    denominator = probs.sum(dim=dims) + masks.sum(dim=dims)
    soft_dice = (2.0 * intersection + smooth) / (denominator + smooth)

    return bce_loss + (1 - soft_dice.mean())

# Dice Coefficient Calculation
def dice_coeff(pred, target, smooth=1e-6):
    """Hard Dice score, averaged over the batch (an evaluation metric, not differentiable).

    ``pred`` holds logits. They are passed through a sigmoid and binarised at 0.5,
    then compared with ``target`` over dims 1-3 of each sample. When both the
    prediction and the target are empty, the score is 1.
    """
    axes = (1, 2, 3)
    binary = torch.sigmoid(pred).gt(0.5).float()
    overlap = binary.mul(target).sum(dim=axes)
    total = binary.sum(dim=axes) + target.sum(dim=axes)
    per_sample = (2.0 * overlap + smooth) / (total + smooth)
    return per_sample.mean()

# Training Function
def train_model(model, train_loader, optimizer, device, num_epochs=3):
    """Train ``model`` with ``bce_dice_loss`` and report loss and hard Dice per epoch.

    Args:
        model: network producing logits shaped like the masks.
        train_loader: iterable of ``(inputs, masks)`` batches with a ``len``.
        optimizer: optimizer over ``model``'s parameters.
        device: device that each batch is moved to.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        List with one ``{"loss": float, "dice": float}`` dict per epoch. These are
        per-batch averages.

    Raises:
        ValueError: if ``train_loader`` is empty (the epoch averages would divide by zero).
    """
    total_batches = len(train_loader)
    if total_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on.")

    model.train()
    history = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_dice = 0.0

        print(f"Epoch [{epoch+1}/{num_epochs}]")

        for inputs, masks in tqdm(train_loader):
            inputs, masks = inputs.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = bce_dice_loss(outputs, masks)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            # The metric needs no gradients: detach so no autograd graph is built for it.
            epoch_dice += dice_coeff(outputs.detach(), masks).item()

        avg_loss = epoch_loss / total_batches
        avg_dice = epoch_dice / total_batches
        history.append({"loss": avg_loss, "dice": avg_dice})

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Dice Score: {avg_dice:.4f}")

    return history

# Evaluation Function
def evaluate_model(model, val_loader, device):
    """Return the mean per-batch hard Dice score of ``model`` over ``val_loader``.

    Switches the model to eval mode (it is left in eval mode) and runs inference
    under ``torch.no_grad()``. The average is printed and returned. With batch_size=1,
    as used in this notebook, the per-batch mean equals the per-sample mean.

    Raises:
        ValueError: if ``val_loader`` is empty (the average would divide by zero).
    """
    total_batches = len(val_loader)
    if total_batches == 0:
        raise ValueError("val_loader is empty; cannot compute a validation Dice score.")

    model.eval()
    total_dice = 0.0

    with torch.no_grad():
        for inputs, masks in tqdm(val_loader):
            inputs, masks = inputs.to(device), masks.to(device)
            outputs = model(inputs)
            total_dice += dice_coeff(outputs, masks).item()

    avg_dice = total_dice / total_batches
    print(f"Validation Dice Score: {avg_dice:.4f}")
    return avg_dice

# Example Usage
if __name__ == "__main__":
    from torch.utils.data import Subset  # only this driver needs Subset

    data_dir = "/content/drive/MyDrive/Data"  # Replace with your actual path
    SEED = 42  # fixes weight init and the train/val split so runs are comparable
    torch.manual_seed(SEED)

    # Two views over the same files: random flips for training, ToTensor-only for
    # validation. Splitting one dataset built with the 'train' transform would also
    # randomly flip the validation images.
    full_dataset = Metastasis3DDataset(data_dir=data_dir, transform=get_transform('train'))
    eval_dataset = Metastasis3DDataset(data_dir=data_dir, transform=get_transform('val'))
    if len(full_dataset) == 0:
        raise RuntimeError(f"No image/mask pairs found under {data_dir!r}")
    # The index-based split below is only valid if both views list files in the same order.
    assert full_dataset.image_filenames == eval_dataset.image_filenames

    # Seeded 80/20 split of indices, applied to the matching view for each phase
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    split_generator = torch.Generator().manual_seed(SEED)
    train_split, val_split = random_split(
        range(len(full_dataset)), [train_size, val_size], generator=split_generator
    )
    train_dataset = Subset(full_dataset, train_split.indices)
    val_dataset = Subset(eval_dataset, val_split.indices)

    # DataLoader
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Model. The logged ~9 s/iteration at batch size 1 suggests a CPU runtime;
    # printing the device makes that visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = NestedUNet(in_channels=1).to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train the model
    num_epochs = 3
    train_model(model, train_loader, optimizer, device, num_epochs)

    # Evaluate on validation set
    evaluate_model(model, val_loader, device)


Epoch [1/3]


100%|██████████| 1189/1189 [3:11:28<00:00,  9.66s/it]


Epoch [1/3], Loss: 0.4323, Dice Score: 0.6384
Epoch [2/3]


100%|██████████| 1189/1189 [2:56:03<00:00,  8.88s/it]


Epoch [2/3], Loss: 0.4149, Dice Score: 0.6384
Epoch [3/3]


 28%|██▊       | 338/1189 [50:27<1:58:57,  8.39s/it]