In [None]:
# Import necessary libraries
import os
from glob import glob
import pandas as pd
import cv2
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg16_bn
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Function to create dataframe
def create_dataframe(data_directory):
    """Pair every mask file under ``data_directory`` with its image path.

    Masks are discovered with the glob pattern
    ``{data_directory}/*/*_mask*``. The image path for each mask is derived
    by removing ``_mask`` from the mask's file name.

    Args:
        data_directory: Root directory whose immediate sub-directories
            contain the mask files.

    Returns:
        A ``pandas.DataFrame`` with the columns ``images_paths`` and
        ``masks_paths``, one row per mask found, ordered by mask path.
        Image paths are derived from the mask names and are not checked
        for existence. The frame is empty when no mask matches.
    """
    # The order glob returns depends on the file system; sort so the rows
    # (and any seeded split built on them) are reproducible.
    masks_paths = sorted(glob(f'{data_directory}/*/*_mask*'))

    images_paths = []
    for mask_path in masks_paths:
        # Strip '_mask' from the file name only: str.replace on the whole
        # path would also rewrite directory names that contain '_mask'.
        file_name = os.path.basename(mask_path)
        parent = mask_path[:len(mask_path) - len(file_name)]
        images_paths.append(parent + file_name.replace('_mask', ''))

    return pd.DataFrame(
        data={'images_paths': images_paths, 'masks_paths': masks_paths}
    )

# Function to split dataframe into train, valid, test
def split_dataframe(dataframe, train_size=0.8, valid_share=0.5, random_state=None):
    """Split ``dataframe`` into train, validation and test subsets.

    The split is done in two passes of ``train_test_split``: the first
    separates the training rows from a held-out remainder, the second
    divides that remainder between validation and test.

    Args:
        dataframe: The frame to split.
        train_size: Passed as ``train_size`` to the first split (portion of
            all rows used for training). Defaults to 0.8.
        valid_share: Passed as ``train_size`` to the second split (portion
            of the held-out rows used for validation; the rest is the test
            set). Defaults to 0.5.
        random_state: Forwarded to both ``train_test_split`` calls. The
            default ``None`` leaves the split unseeded; pass an int to get
            the same split on every run.

    Returns:
        A tuple ``(train_dataframe, valid_dataframe, test_dataframe)``.
    """
    # First pass: training rows versus everything held out.
    train_dataframe, holdout_dataframe = train_test_split(
        dataframe, train_size=train_size, random_state=random_state
    )

    # Second pass: divide the held-out rows into validation and test.
    valid_dataframe, test_dataframe = train_test_split(
        holdout_dataframe, train_size=valid_share, random_state=random_state
    )

    return train_dataframe, valid_dataframe, test_dataframe

data_path = "/content/lgg-mri-segmentation/kaggle_3m"
df = create_dataframe(data_path)
# Fail early with a clear message: an empty frame would otherwise only
# surface as an error from inside train_test_split.
if df.empty:
    raise FileNotFoundError(f'No mask files found under {data_path}')
train_df, val_df, test_df = split_dataframe(df)

# Report the size of each split
print('Train\t', train_df.shape, '\nVal\t', val_df.shape, '\nTest\t', test_df.shape)

# Load the first training image/mask pair for a visual sanity check
sample_image_path = train_df.iloc[0, 0]
sample_mask_path = train_df.iloc[0, 1]
image = cv2.imread(sample_image_path)
mask = cv2.imread(sample_mask_path)
# cv2.imread reports an unreadable file by returning None instead of
# raising, which would otherwise show up as a TypeError on the division.
if image is None:
    raise OSError(f'cv2.imread could not read {sample_image_path}')
if mask is None:
    raise OSError(f'cv2.imread could not read {sample_mask_path}')

# Scale pixel values to [0, 1] and binarise the mask at 0.5
image = image / 255.0
mask = mask / 255.0
mask = np.where(mask >= 0.5, 1., 0.)

plt.subplot(1, 2, 1)
# OpenCV stores colour channels as BGR while matplotlib expects RGB, so the
# channel axis is reversed for display only; `image` itself is left as read.
plt.imshow(image[..., ::-1])
plt.title('Image' + str(image.shape))
plt.subplot(1, 2, 2)
plt.imshow(mask)
plt.title('Mask' + str(mask.shape))
plt.show()

# Define the BrainMRI dataset class
class BrainMRIDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image, mask)`` pairs listed in a dataframe.

    Column 0 of the dataframe holds the image path and column 1 the mask
    path. Both files are read with ``cv2.imread`` (default flags) and
    divided by 255; the mask is then binarised at 0.5.

    Args:
        dataframe: Frame whose first two columns are image and mask paths.
        transform: Optional callable applied to the scaled image.
        mask_transform: Optional callable applied to the binarised mask.
    """

    def __init__(self, dataframe, transform=None, mask_transform=None):
        self.df = dataframe
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of rows (image/mask pairs) in the dataframe."""
        return len(self.df)

    @staticmethod
    def _read_scaled(path):
        """Read ``path`` with OpenCV and scale the pixel values by 1/255.

        Raises:
            OSError: If ``cv2.imread`` cannot read the file.
        """
        pixels = cv2.imread(path)
        # cv2.imread returns None instead of raising on a missing or
        # undecodable file; surface that with the offending path.
        if pixels is None:
            raise OSError(f'cv2.imread could not read {path}')
        return pixels / 255.0

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at positional index ``idx``.

        Raises:
            OSError: If the image or the mask file cannot be read.
        """
        image = self._read_scaled(self.df.iloc[idx, 0])
        mask = self._read_scaled(self.df.iloc[idx, 1])
        # Binarise: values of at least 0.5 become 1.0, everything else 0.0.
        mask = np.where(mask >= 0.5, 1., 0.)

        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return image, mask

# Data transformations for the images: array -> tensor, then resize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256))
])

# Masks are label maps, so they get their own transform: Resize defaults to
# bilinear interpolation, which can blend the 0/1 values into fractions.
# Nearest-neighbour keeps every resized mask value at exactly 0 or 1.
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(
        (256, 256),
        interpolation=transforms.InterpolationMode.NEAREST,
    ),
])

# Create datasets
train_data = BrainMRIDataset(train_df, transform=transform, mask_transform=mask_transform)
val_data = BrainMRIDataset(val_df, transform=transform, mask_transform=mask_transform)
test_data = BrainMRIDataset(test_df, transform=transform, mask_transform=mask_transform)

# Set batch size
batch_size = 64

# Create dataloaders. Only the training loader is shuffled: validation and
# test metrics are averaged per batch, so a fixed batch order keeps them
# reproducible from run to run.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Display batch information
print('Training batches\t', len(train_dataloader))
print('Validation batches\t', len(val_dataloader))
print('Test batches\t\t', len(test_dataloader))

# Count the number of training and validation images
num_train_images = len(train_df)
num_val_images = len(val_df)

print(f'Number of Training Images: {num_train_images}')
print(f'Number of Validation Images: {num_val_images}')

# Score the model on the test set: the value of IOU(labels, predictions) is
# accumulated per batch and then averaged over the number of batches.
# NOTE(review): `model`, `device` and `IOU` are not defined in the visible
# part of this file; they are assumed to exist before this cell runs.
model.eval()
total_test_iou = 0.0

# Gradients are not needed for evaluation.
with torch.no_grad():
    for images, labels in test_dataloader:
        # Move the batch to the target device and cast it to float32.
        images = images.to(device).float()
        labels = labels.to(device).float()

        predictions = model(images)
        iou = IOU(labels, predictions)
        total_test_iou = total_test_iou + iou.item()

# Average of the per-batch scores (reported under the label "Test Accuracy").
test_accuracy = total_test_iou / len(test_dataloader)
print(f'Test Accuracy: {test_accuracy:.4f}')

# Display sample from dataloader
img_sample, msk_sample = next(iter(val_dataloader))
print(img_sample.shape, '\t', img_sample.dtype)
print(msk_sample.shape, '\t', msk_sample.dtype)

# The grid has six columns, but the batch may hold fewer samples (small
# validation split), so only plot as many pairs as the batch provides.
n_columns = 6
n_shown = min(n_columns, len(img_sample))

fig, axs = plt.subplots(2, n_columns, figsize=(20, 10))
for i in range(n_shown):
    # permute: channels-first tensor -> channels-last layout for imshow.
    # flip(-1): BrainMRIDataset reads images with cv2.imread (BGR channel
    # order) while matplotlib expects RGB, so reverse the channel axis.
    axs[0, i].imshow(img_sample[i].permute(1, 2, 0).flip(-1))
    axs[0, i].set_title("Image")

    axs[1, i].imshow(msk_sample[i].permute(1, 2, 0))
    axs[1, i].set_title("Mask")

# Hide any columns left without a sample.
for unused_ax in axs[:, n_shown:].flat:
    unused_ax.axis('off')

fig.suptitle('Data Sample')
fig.tight_layout()
# plt.show() matches the earlier preview; fig.show() does not render on
# non-GUI backends.
plt.show()
