In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Apply CLAHE
def apply_clahe(image):
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return clahe.apply(image)

# Custom Dataset Class for 3D Brain MRI
class Metastasis3DDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_filenames = []
        self.mask_filenames = []

        # Traverse through subdirectories and collect all .tif files
        for root, _, files in os.walk(data_dir):
            for f in files:
                # Skip __MACOSX and ._ files
                if "__MACOSX" in root or f.startswith("._"):
                    continue
                if f.endswith('.tif'):
                    if '_mask' in f:  # Check if the file is a mask
                        self.mask_filenames.append(os.path.join(root, f))
                    else:  # Otherwise, it's an image
                        self.image_filenames.append(os.path.join(root, f))

        # Keep only the matched pairs
        matched_images = set()
        matched_masks = set()

        for mask in self.mask_filenames:
            mask_name = os.path.basename(mask).replace('_mask.tif', '.tif')
            image_path = os.path.join(os.path.dirname(mask), mask_name)
            if image_path in self.image_filenames:
                matched_images.add(image_path)
                matched_masks.add(mask)

        self.image_filenames = list(matched_images)
        self.mask_filenames = list(matched_masks)

        # Ensure that the number of images and masks match
        if len(self.image_filenames) != len(self.mask_filenames):
            print(f"Warning: {len(self.image_filenames)} images and {len(self.mask_filenames)} masks found.")

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_path = self.image_filenames[idx]
        mask_path = self.mask_filenames[idx]  # Get the corresponding mask

        # Load and preprocess the image
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Image not found at {img_path}")

        image = apply_clahe(image)  # Apply CLAHE
        image = image / 255.0  # Normalize pixel values

        # Load the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Mask not found at {mask_path}")

        mask = mask / 255.0  # Normalize mask pixel values

        # Convert to a PIL Image for transformation
        image = Image.fromarray((image * 255).astype(np.uint8))  # Convert to PIL Image
        mask = Image.fromarray((mask * 255).astype(np.uint8))  # Convert to PIL Image

        # Data Augmentation and Transformations
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

# Define Data Augmentation and Normalization
def get_transform(phase):
    if phase == 'train':
        return transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ])
    else:
        return transforms.Compose([
            transforms.ToTensor(),
        ])

# Example usage
if __name__ == "__main__":
    data_dir = "/content/drive/MyDrive/Data"  # Update this to your actual data directory
    train_dataset = Metastasis3DDataset(data_dir=data_dir, transform=get_transform('train'))

    # Create DataLoader
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    # Test DataLoader
    for images, masks in train_loader:
        print(images.shape, masks.shape)  # Verify the shape of the images and masks
        break


torch.Size([1, 1, 256, 256]) torch.Size([1, 1, 256, 256])


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention Gate
class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Conv2d(F_g, F_int, kernel_size=1, padding=0)
        self.W_x = nn.Conv2d(F_l, F_int, kernel_size=1, padding=0)
        self.psi = nn.Conv2d(F_int, 1, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        attention = torch.sigmoid(psi)
        return x * attention

# Basic Convolution Block
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

# Up-sampling block for decoding path
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat((x, skip), dim=1)  # Concatenate with the skip connection
        x = self.conv(x)
        return x

# Attention U-Net Architecture
class AttentionUNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=1):
        super(AttentionUNet, self).__init__()

        # Encoder Path
        self.conv1 = ConvBlock(in_channels, 64)   # Level 1
        self.conv2 = ConvBlock(64, 128)            # Level 2
        self.conv3 = ConvBlock(128, 256)           # Level 3
        self.conv4 = ConvBlock(256, 512)           # Level 4
        self.conv5 = ConvBlock(512, 1024)          # Level 5

        # Decoder Path
        self.up4 = UpBlock(1024, 512)              # Level 4
        self.up3 = UpBlock(512, 256)                # Level 3
        self.up2 = UpBlock(256, 128)                # Level 2
        self.up1 = UpBlock(128, 64)                 # Level 1

        # Attention Gates
        self.att4 = AttentionGate(F_g=512, F_l=512, F_int=256)
        self.att3 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
        self.att1 = AttentionGate(F_g=64, F_l=64, F_int=32)

        # Final Convolution
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.conv1(x)   # Level 1
        x2 = self.conv2(F.max_pool2d(x1, 2))  # Level 2
        x3 = self.conv3(F.max_pool2d(x2, 2))  # Level 3
        x4 = self.conv4(F.max_pool2d(x3, 2))  # Level 4
        x5 = self.conv5(F.max_pool2d(x4, 2))  # Level 5

        # Decoder with Attention Gates
        x4_up = self.up4(x5, x4)
        x4_att = self.att4(x4_up, x4)

        x3_up = self.up3(x4_att, x3)
        x3_att = self.att3(x3_up, x3)

        x2_up = self.up2(x3_att, x2)
        x2_att = self.att2(x2_up, x2)

        x1_up = self.up1(x2_att, x1)
        x1_att = self.att1(x1_up, x1)

        # Output
        output = self.final_conv(x1_att)

        return output

# Example of creating a model
model = AttentionUNet(in_channels=1, num_classes=1)


In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize the model and move it to GPU if available
model = AttentionUNet(in_channels=1, num_classes=1).to(device)

# Define loss function (Dice Loss is commonly used in segmentation)
def dice_loss(pred, target, smooth=1):
    pred = torch.sigmoid(pred)  # Sigmoid activation for binary segmentation
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    return 1 - ((2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth))

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# DataLoader (assuming dataset is already defined and loaded)
data_dir = "/content/drive/MyDrive/Data"  # Update this with the correct path
train_dataset = Metastasis3DDataset(data_dir=data_dir, transform=get_transform('train'))
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)

# Training settings
num_epochs = 4  # Adjust as needed
accumulation_steps = 4  # Gradient accumulation steps to simulate larger batches

# Training Loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    print(f'Starting epoch {epoch + 1}/{num_epochs}...')

    for batch_idx, (images, masks) in enumerate(train_loader):
        # Move images and masks to GPU (if available)
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)

        # Calculate Dice loss
        loss = dice_loss(outputs, masks.unsqueeze(1))

        # Normalize loss for gradient accumulation
        loss = loss / accumulation_steps

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()

        # Update weights after accumulating gradients for several steps
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()

        # Accumulate the loss for monitoring
        running_loss += loss.item()

        # Print progress every 10 batches
        if batch_idx % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

    # Average loss for the epoch
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}] finished with average loss: {avg_loss:.4f}")

print("Training completed.")


Using device: cpu
Starting epoch 1/4...
Epoch [1/4], Batch [0/901], Loss: 0.2372
Epoch [1/4], Batch [10/901], Loss: 0.2500
Epoch [1/4], Batch [20/901], Loss: 0.2500
Epoch [1/4], Batch [30/901], Loss: 0.2390
Epoch [1/4], Batch [40/901], Loss: 0.2500
Epoch [1/4], Batch [50/901], Loss: 0.2256
Epoch [1/4], Batch [60/901], Loss: 0.2380
Epoch [1/4], Batch [70/901], Loss: 0.2101
Epoch [1/4], Batch [80/901], Loss: 0.2439
Epoch [1/4], Batch [90/901], Loss: 0.2016
Epoch [1/4], Batch [100/901], Loss: 0.2500
Epoch [1/4], Batch [110/901], Loss: 0.2500
Epoch [1/4], Batch [120/901], Loss: 0.2500
Epoch [1/4], Batch [130/901], Loss: 0.2500
Epoch [1/4], Batch [140/901], Loss: 0.2367
Epoch [1/4], Batch [150/901], Loss: 0.2500
Epoch [1/4], Batch [160/901], Loss: 0.2071
Epoch [1/4], Batch [170/901], Loss: 0.2500
Epoch [1/4], Batch [180/901], Loss: 0.2500
Epoch [1/4], Batch [190/901], Loss: 0.2500
Epoch [1/4], Batch [200/901], Loss: 0.2115
Epoch [1/4], Batch [210/901], Loss: 0.2446
Epoch [1/4], Batch [220/9

In [1]:
def validate_model(model, val_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode
    total_dice_score = 0
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass
            outputs = model(images)

            # Resize the mask to match the output
            masks = masks.squeeze(1)  # Remove the extra channel dimension

            # Calculate loss
            loss = criterion(outputs, masks)
            total_loss += loss.item()

            # Calculate DICE score
            outputs = torch.sigmoid(outputs)  # Apply sigmoid to get probabilities
            preds = (outputs > 0.5).float()  # Binarize predictions

            # Dice score calculation
            intersection = (preds * masks).sum((1, 2, 3))
            union = preds.sum((1, 2, 3)) + masks.sum((1, 2, 3))
            dice_score = (2 * intersection + 1e-8) / (union + 1e-8)
            total_dice_score += dice_score.mean().item()

            num_batches += 1

    avg_dice_score = total_dice_score / num_batches
    avg_loss = total_loss / num_batches
    print(avg_dice_score, avg_loss)
    return avg_dice_score, avg_loss

