**1.** **Setup - Get Training Data**


*   Mount Drive w/ MRI Dataset
*   Perform Data Augmentation
*   Downsample Images to create HR-LR Pairs

**NOTE:**
No need to create HR-LR pairs and augment data more than once





In [None]:
# Mount Google Drive at /content/drive; the dataset paths used later in this
# notebook all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# check GPU status
! nvidia-smi

Mon Dec 11 13:25:12 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   44C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import os
import nibabel as nib
from PIL import Image, ImageEnhance
import numpy as np

# Directories holding the source NIfTI volumes, one per dataset split
hr_training_dir = '/content/drive/MyDrive/BrainTumorDataset/BraTS18_T1CE_training'
hr_validation_dir = '/content/drive/MyDrive/BrainTumorDataset/BraTS19_T1CE_validation'
hr_testing_dir = '/content/drive/MyDrive/BrainTumorDataset/BraTS20_T1CE_testing'

# Directories that receive the 2D high-resolution (HR) slice images of each split
hr_training_2Dslices = '/content/drive/MyDrive/BrainTumorDataset/hr_train_2D'
hr_validation_2Dslices = '/content/drive/MyDrive/BrainTumorDataset/hr_valid_2D'
hr_testing_2Dslices = '/content/drive/MyDrive/BrainTumorDataset/hr_test_2D'

# Directories that receive the low-resolution (LR) images derived from the HR slices
lr_training_2Dslices = '/content/drive/MyDrive/BrainTumorDataset/lr_train_2D'
lr_validation_2Dslices = '/content/drive/MyDrive/BrainTumorDataset/lr_valid_2D'
lr_testing_2Dslices = '/content/drive/MyDrive/BrainTumorDataset/lr_test_2D'


In [None]:
def extract_middle_slice(nii_path, output_dir):
    """Save the middle slice of a NIfTI volume as an 8-bit JPEG.

    The slice is taken at the midpoint of the third axis, min-max scaled to
    0-255 and written to ``output_dir`` under the volume's base name with a
    ``.jpg`` extension.

    Args:
        nii_path: Path to a ``.nii`` or ``.nii.gz`` file.
        output_dir: Directory that receives the image; created if missing.
    """
    # Load the NIfTI file as a floating point array
    volume = nib.load(nii_path).get_fdata()

    # Middle index along the third axis (assuming axial slices)
    middle_slice = volume[:, :, volume.shape[2] // 2]

    # Min-max normalise for image representation. A constant slice has a zero
    # value range and would otherwise divide by zero, so it maps to all zeros.
    low = np.min(middle_slice)
    value_range = np.max(middle_slice) - low
    if value_range > 0:
        scaled = ((middle_slice - low) / value_range) * 255.0
    else:
        scaled = np.zeros_like(middle_slice)
    slice_image = Image.fromarray(scaled.astype(np.uint8))

    # Strip the NIfTI suffix, testing '.nii.gz' first so that compressed
    # volumes end in '.jpg' rather than '.jpg.gz'.
    filename = os.path.basename(nii_path)
    for suffix in ('.nii.gz', '.nii'):
        if filename.endswith(suffix):
            filename = filename[:-len(suffix)]
            break
    filename += '.jpg'

    # Save the slice image, creating the target directory on first use
    os.makedirs(output_dir, exist_ok=True)
    slice_image.save(os.path.join(output_dir, filename))

def process_nii_files(nii_dir, output_2d_dir):
    """Extract the middle slice of every NIfTI file directly inside ``nii_dir``.

    Entries whose names end in ``.nii`` or ``.nii.gz`` are handed to
    ``extract_middle_slice`` together with ``output_2d_dir``; everything else
    in the directory is ignored.
    """
    nifti_suffixes = ('.nii', '.nii.gz')
    for entry in os.listdir(nii_dir):
        if not entry.endswith(nifti_suffixes):
            continue
        extract_middle_slice(os.path.join(nii_dir, entry), output_2d_dir)

In [None]:
# Extract the middle slice of every volume in each split into that split's
# HR 2D-slice directory (training, validation, testing).
process_nii_files(hr_training_dir, hr_training_2Dslices)
process_nii_files(hr_validation_dir, hr_validation_2Dslices)
process_nii_files(hr_testing_dir, hr_testing_2Dslices)

FileNotFoundError: ignored

In [None]:
import cv2
import random

# Define Augmentation and Downsampling Functions

def random_rotation(image):
    """Return ``image`` rotated by 90, 180 or 270 degrees, chosen at random."""
    angle = random.choice([90, 180, 270])
    return image.rotate(angle)

def random_flip(image):
    """Return ``image`` mirrored left-right or top-bottom, chosen at random."""
    flip_mode = random.choice((Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM))
    return image.transpose(flip_mode)

def random_scale(image, min_scale=0.9, max_scale=1.1):
    """Return ``image`` resized by a random factor in [min_scale, max_scale].

    Both sides are multiplied by the same factor and truncated to whole
    pixels; resampling is bicubic.
    """
    factor = random.uniform(min_scale, max_scale)
    new_size = tuple(int(side * factor) for side in image.size)
    return image.resize(new_size, Image.BICUBIC)

def random_brightness_contrast(image):
    """Return ``image`` with brightness, then contrast, scaled by random factors.

    Each factor is drawn independently from [0.8, 1.2]; the contrast change is
    applied to the already brightness-adjusted image.
    """
    brightness_factor = random.uniform(0.8, 1.2)
    brightened = ImageEnhance.Brightness(image).enhance(brightness_factor)

    contrast_factor = random.uniform(0.8, 1.2)
    return ImageEnhance.Contrast(brightened).enhance(contrast_factor)

def downsample_image(image, scale_factor):
    """Degrade ``image`` by shrinking it and resizing it back to its original size.

    Args:
        image: PIL image to degrade.
        scale_factor: Positive factor by which each side is divided for the
            intermediate low-resolution image.

    Returns:
        A NumPy array of the degraded image, with the same width and height
        as the input.

    Raises:
        ValueError: If ``scale_factor`` is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor}")

    # Calculate new dimensions based on scale factor. int() truncates, so a
    # side smaller than scale_factor would become 0, which is not a valid
    # resize target; clamp to at least one pixel.
    width, height = image.size
    new_width = max(1, int(width / scale_factor))
    new_height = max(1, int(height / scale_factor))

    # Resize down and back up using bicubic interpolation
    image_down = image.resize((new_width, new_height), Image.BICUBIC)
    image_up = image_down.resize((width, height), Image.BICUBIC)

    return np.array(image_up)

def process_images(hr_input_dir, lr_output_dir, scale_factor):
    """Create augmented low-resolution images for every HR slice in a directory.

    For each ``.jpg`` in ``hr_input_dir`` four augmented variants are built
    (rotation, flip, scale, brightness/contrast, in that order), each is
    degraded with ``downsample_image`` and saved to ``lr_output_dir`` as
    ``LR_aug<idx>_<original name>``.

    Args:
        hr_input_dir: Directory containing the HR ``.jpg`` slices.
        lr_output_dir: Directory that receives the LR images; created if
            missing.
        scale_factor: Downsampling factor passed to ``downsample_image``.
    """
    os.makedirs(lr_output_dir, exist_ok=True)

    for img_name in os.listdir(hr_input_dir):
        if not img_name.endswith('.jpg'):  # HR slices are saved as '.jpg'
            continue

        img_path = os.path.join(hr_input_dir, img_name)
        # convert() returns an independent RGB copy, so the source file can be
        # released as soon as the block exits.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        # Apply each augmentation independently to the same source image
        augmented_images = [
            random_rotation(img),
            random_flip(img),
            random_scale(img),
            random_brightness_contrast(img)
        ]

        # Only the augmented variants are saved (in LR form); the
        # un-augmented image itself is not written.
        # NOTE(review): the HR files are left untouched, so rotated or flipped
        # LR images are no longer spatially aligned with their HR counterpart
        # - confirm this is intended for paired training.
        for idx, aug_img in enumerate(augmented_images):
            img_lr = Image.fromarray(downsample_image(aug_img, scale_factor))
            aug_img_name = f"LR_aug{idx}_{img_name}"
            img_lr.save(os.path.join(lr_output_dir, aug_img_name))


In [None]:
# (HR slice directory, LR output directory) for each dataset split
directories = [
    (hr_training_2Dslices, lr_training_2Dslices),
    (hr_validation_2Dslices, lr_validation_2Dslices),
    (hr_testing_2Dslices, lr_testing_2Dslices)
]

for hr_dir, lr_dir in directories:
    # Create the LR directory up front: os.listdir raises FileNotFoundError
    # for a path that does not exist yet (e.g. on the very first run).
    os.makedirs(lr_dir, exist_ok=True)
    if not os.listdir(lr_dir):  # Empty directory: nothing has been generated yet
        process_images(hr_dir, lr_dir, scale_factor=2)
    else:
        print(f"2D Slices already downsampled and saved in {lr_dir}!")

2D Slices already downsampled and saved in /content/drive/MyDrive/BrainTumorDataset/lr_train_2D!
2D Slices already downsampled and saved in /content/drive/MyDrive/BrainTumorDataset/lr_valid_2D!
2D Slices already downsampled and saved in /content/drive/MyDrive/BrainTumorDataset/lr_test_2D!


In [None]:
#im = cv2.imread('/content/drive/MyDrive/BrainTumorDataset/lr_train_2D/LR_aug0_BraTS2018_HGG_Brats18_2013_11_1_Brats18_2013_11_1_t1ce.jpg')
#print(im.shape)

(240, 240, 3)


**2.** **Prepare Data Loader**


*   Create Dataset Class to load and preprocess images

  * Resizing & Normalization

  * Return Images as PyTorch Tensors




In [None]:
import torch
from torchvision import transforms as T
from torch.utils.data import Dataset

class MRIImageDataset(Dataset):
    """Dataset of HR slices paired with one of their augmented LR variants.

    Every ``.jpg``/``.png`` file in ``hr_dir`` is one sample. Its LR
    candidates are expected in ``lr_dir`` under the names
    ``LR_aug<idx>_<hr file name>`` for ``idx`` in ``range(num_augments)``;
    one candidate is picked at random each time the sample is fetched.

    Args:
        hr_dir: Directory containing the HR images.
        lr_dir: Directory containing the LR images.
        transform: Optional callable applied to both images of a pair.
        num_augments: Number of LR variants available per HR image.
    """

    def __init__(self, hr_dir, lr_dir, transform=None, num_augments=4):
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.transform = transform
        self.num_augments = num_augments
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.hr_images = sorted(
            name for name in os.listdir(hr_dir) if name.endswith(('.jpg', '.png'))
        )
        self.image_pairs = [
            (hr_img, [f'LR_aug{idx}_{hr_img}' for idx in range(num_augments)])
            for hr_img in self.hr_images
        ]

    def __len__(self):
        return len(self.image_pairs)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an in-memory RGB copy of the image."""
        with Image.open(path) as handle:
            return handle.convert('RGB')

    def __getitem__(self, idx):
        """Return ``{'hr': ..., 'lr': ...}`` for sample ``idx``, both 240x240."""
        hr_img_name, lr_img_names = self.image_pairs[idx]
        selected_lr_img_name = random.choice(lr_img_names)  # Randomly select one LR image

        hr_image = self._load_rgb(os.path.join(self.hr_dir, hr_img_name))
        lr_image = self._load_rgb(os.path.join(self.lr_dir, selected_lr_img_name))

        # Transformations
        target_size = (240, 240)

        # Crop if larger. The decision is made from the HR image and applied
        # to both images of the pair.
        if hr_image.size[0] > target_size[0] or hr_image.size[1] > target_size[1]:
            hr_image = T.functional.center_crop(hr_image, target_size)
            lr_image = T.functional.center_crop(lr_image, target_size)

        # Pad if smaller; the padding is derived from the HR image's size
        if hr_image.size[0] < target_size[0] or hr_image.size[1] < target_size[1]:
            padding = [0, 0, target_size[0] - hr_image.size[0], target_size[1] - hr_image.size[1]]  # left, top, right, bottom
            hr_image = T.functional.pad(hr_image, padding)
            lr_image = T.functional.pad(lr_image, padding)

        # Ensure final size is consistent (an LR image whose size differs from
        # its HR image is brought to the target size here)
        hr_image = T.functional.resize(hr_image, target_size)
        lr_image = T.functional.resize(lr_image, target_size)

        if self.transform:
            hr_image = self.transform(hr_image)
            lr_image = self.transform(lr_image)

        return {'hr': hr_image, 'lr': lr_image}

# Transformation applied to both images of every pair: convert to a tensor,
# then normalise each channel with a fixed mean/std.
# NOTE(review): these mean/std values match the commonly used ImageNet
# statistics rather than values computed from this dataset - confirm intended.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard normalization
])

# One dataset per split, each pairing an HR slice directory with its LR directory
train_dataset = MRIImageDataset(hr_dir=hr_training_2Dslices, lr_dir=lr_training_2Dslices, transform=transform)
validate_dataset = MRIImageDataset(hr_dir=hr_validation_2Dslices, lr_dir=lr_validation_2Dslices, transform=transform)
test_dataset = MRIImageDataset(hr_dir=hr_testing_2Dslices, lr_dir=lr_testing_2Dslices, transform=transform)

**3.** **Build Network Architecture**


*   Number of layers

*   Types of layers
  * Convolutional, Pooling, Upsampling, etc. and activation functions.



In [None]:
import torch.nn as nn
import torch.nn.functional as F

class DenseLayer(nn.Module):
    """3x3 convolution + ReLU whose output is appended to its input.

    The forward pass returns the input concatenated with ``growth_rate`` new
    feature maps along the channel dimension, so the channel count grows from
    ``in_channels`` to ``in_channels + growth_rate``.
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        new_features = self.relu(self.conv(x))
        return torch.cat((x, new_features), dim=1)

class DenseBlock(nn.Module):
    """Sequence of ``num_layers`` dense layers applied one after another.

    Layer ``i`` receives ``in_channels + i * growth_rate`` channels, because
    each preceding layer appends ``growth_rate`` channels to its input.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.dense_layers = nn.Sequential(*(
            DenseLayer(in_channels + i * growth_rate, growth_rate)
            for i in range(num_layers)
        ))

    def forward(self, x):
        return self.dense_layers(x)

class TransitionLayer(nn.Module):
    """1x1 convolution to ``out_channels`` followed by 2x2 average pooling.

    The pooling halves both spatial dimensions of the feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.pool = nn.AvgPool2d(2)

    def forward(self, x):
        return self.pool(self.conv(x))

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions (ReLU in between) with an identity skip connection.

    Channel count and spatial size are unchanged; the block's input is added
    to the output of the second convolution.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden) + x

class SuperResolutionCNN(nn.Module):
    """Dense/residual CNN mapping a 3-channel image to a 3-channel image.

    Pipeline: initial 3x3 conv -> two (dense block, transition) stages ->
    two residual blocks -> x4 bilinear upsampling -> output 3x3 conv.
    Each transition halves the spatial size, so the x4 upsampling restores
    the input resolution when height and width are divisible by 4.
    """

    def __init__(self):
        super().__init__()
        base_channels = 64
        growth_rate = 32
        num_dense_layers = 4
        # Channel count leaving a dense block: its input plus one
        # growth_rate-sized group per dense layer
        dense_out_channels = base_channels + growth_rate * num_dense_layers

        # Initial convolution layer
        self.initial_conv = nn.Conv2d(3, base_channels, kernel_size=3, padding=1)

        # Dense blocks, each followed by a transition back to base_channels
        self.denseblock1 = DenseBlock(base_channels, growth_rate, num_dense_layers)
        self.transition1 = TransitionLayer(dense_out_channels, base_channels)
        self.denseblock2 = DenseBlock(base_channels, growth_rate, num_dense_layers)
        self.transition2 = TransitionLayer(dense_out_channels, base_channels)

        # Residual blocks
        self.resblock1 = ResidualBlock(base_channels)
        self.resblock2 = ResidualBlock(base_channels)

        # Upsampling layer
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)

        # Output convolution
        self.output_conv = nn.Conv2d(base_channels, 3, kernel_size=3, padding=1)

    def forward(self, x):
        stages = (
            self.initial_conv,
            self.denseblock1,
            self.transition1,
            self.denseblock2,
            self.transition2,
            self.resblock1,
            self.resblock2,
            self.upsample,
            self.output_conv,
        )
        out = x
        for stage in stages:
            out = stage(out)
        return out

# Initialize the model

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# runtime without CUDA instead of failing on the first .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = SuperResolutionCNN().to(device)
print(model)


SuperResolutionCNN(
  (initial_conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (denseblock1): DenseBlock(
    (dense_layers): Sequential(
      (0): DenseLayer(
        (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU(inplace=True)
      )
      (1): DenseLayer(
        (conv): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU(inplace=True)
      )
      (2): DenseLayer(
        (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU(inplace=True)
      )
      (3): DenseLayer(
        (conv): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU(inplace=True)
      )
    )
  )
  (transition1): TransitionLayer(
    (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
    (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (denseblock2): DenseBlock(
    (dense_layers): Sequential(
      (0): D

In [None]:
pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.1


In [None]:
pip install lpips

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━[0m [32m41.0/53.8 kB[0m [31m959.4 kB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m997.3 kB/s[0m eta [36m0:00:00[0m
Installing collected packages: lpips
Successfully installed lpips-0.1.4


In [None]:
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_tensor
from torchvision.utils import save_image
import torchmetrics
import lpips
# LPIPS perceptual distance with an AlexNet backbone, placed on the compute device
lpips_fn = lpips.LPIPS(net='alex').to(device)  # Using AlexNet

# Data loaders (batches of 16), MSE training loss and Adam optimizer
# NOTE(review): the validation loader also shuffles; order should not matter
# for evaluation - confirm whether shuffle=True is intended there.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
validation_loader = DataLoader(validate_dataset, batch_size=16, shuffle=True)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def lr_psnr_metric(lr_images, hr_images):
    """Return the PSNR of ``lr_images`` measured against ``hr_images``.

    A fresh metric object is created for every call, so the result covers
    only the batch passed in.
    """
    psnr = torchmetrics.PeakSignalNoiseRatio().to(device)
    psnr.update(lr_images, hr_images)
    return psnr.compute()

def calculate_pixel_std_dev(images):
    """Return the mean, over the batch, of each image's pixel standard deviation."""
    images = images.to(device)  # Ensure images are on the compute device
    running_total = 0
    image_count = 0
    for image in images:
        running_total += torch.std(image).item()
        image_count += 1
    return running_total / image_count


def lpips_metric(output, target):
    """Return the mean LPIPS distance between ``output`` and ``target``."""
    distances = lpips_fn(output.to(device), target.to(device))
    return distances.mean()


# Function to compute evaluation metrics
def compute_metrics(model, loader):
    """Evaluate ``model`` on ``loader`` and return five quality metrics.

    Returns:
        A tuple ``(psnr, ssim, lpips, lr_psnr, pixel_std_dev)``. PSNR and SSIM
        are accumulated over all batches by their torchmetrics objects; the
        remaining three are per-batch values averaged over the number of
        batches.
    """
    model.eval()
    psnr_metric = torchmetrics.PeakSignalNoiseRatio().to(device)
    ssim_metric = torchmetrics.StructuralSimilarityIndexMeasure().to(device)

    # Running sums of the per-batch values
    lpips_sum = 0.0
    lr_psnr_sum = 0.0
    pixel_std_sum = 0.0

    with torch.no_grad():
        for batch in loader:
            hr_images = batch['hr'].to(device)
            lr_images = batch['lr'].to(device)
            outputs = model(lr_images)

            # Model output vs. ground truth
            psnr_metric.update(outputs, hr_images)
            ssim_metric.update(outputs, hr_images)
            lpips_sum += lpips_fn(outputs, hr_images).mean()

            # Baseline: the unprocessed LR input vs. ground truth
            lr_psnr_sum += lr_psnr_metric(lr_images, hr_images)

            # Spread of pixel values in the model output
            pixel_std_sum += calculate_pixel_std_dev(outputs)

    num_batches = len(loader)
    return (
        psnr_metric.compute(),
        ssim_metric.compute(),
        lpips_sum / num_batches,
        lr_psnr_sum / num_batches,
        pixel_std_sum / num_batches,
    )

# Function to validate model
def validate_model(model, loader):
    """Return the loss of ``model`` averaged over the batches of ``loader``.

    The model is switched to evaluation mode and gradients are disabled; the
    loss is the module-level ``criterion`` between output and HR target.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in loader:
            targets = batch['hr'].to(device)
            inputs = batch['lr'].to(device)
            batch_losses.append(criterion(model(inputs), targets).item())

    return sum(batch_losses) / len(loader)

def save_hr_images(images, directory, epoch, batch_idx):
    """Write each image of a batch to ``directory`` as a JPEG.

    Files are named ``epoch<epoch>_batch<batch_idx>_img<i>.jpg`` where ``i``
    is the image's position in ``images``. The directory is created if it
    does not exist.
    """
    os.makedirs(directory, exist_ok=True)
    for i, image in enumerate(images):
        filename = f"epoch{epoch}_batch{batch_idx}_img{i}.jpg"
        save_image(image, os.path.join(directory, filename))

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:05<00:00, 41.2MB/s]


Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


In [None]:
# Training loop: for each epoch, run one optimisation pass over train_loader,
# then report validation loss and quality metrics on validation_loader.
# best_loss, patience and trigger_times are only referenced by the
# early-stopping logic at the bottom, which is currently commented out.
best_loss = float('inf')
patience = 5
trigger_times = 0
num_epochs = 30
save_interval = 10  # model outputs are written to disk every save_interval epochs
generated_hr_dir = '/content/drive/MyDrive/BrainTumorDataset/generated_hr'

for epoch in range(num_epochs):
    model.train()  # validate_model/compute_metrics leave the model in eval mode
    total_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(train_loader):
        hr_images = batch['hr'].to(device)
        lr_images = batch['lr'].to(device)

        outputs = model(lr_images)

        # Save the generated images of every batch during epochs 0, save_interval, 2*save_interval, ...
        if epoch % save_interval == 0:  # 'save_interval' can be defined as per your requirement
            save_hr_images(outputs.cpu(), generated_hr_dir, epoch, batch_idx)

        loss = criterion(outputs, hr_images)
        total_loss += loss.item()
        num_batches += 1

        # Standard optimisation step: clear old gradients, backpropagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_training_loss = total_loss / num_batches
    # Both calls below iterate over the whole validation loader, so the
    # validation set is traversed twice per epoch.
    val_loss = validate_model(model, validation_loader)
    psnr, ssim, lpips_value, lr_psnr, pixel_std_dev = compute_metrics(model, validation_loader)

    print(f'Epoch [{epoch+1}/{num_epochs}], Avg Training Loss: {avg_training_loss}, Validation Loss: {val_loss}, PSNR: {psnr}, SSIM: {ssim},lpips valu: {lpips_value}, LRPSNR: {lr_psnr} PXL STD: {pixel_std_dev}')

    # Early stopping (disabled): would stop once val_loss has failed to
    # improve for `patience` consecutive epochs.
    # # Early stopping logic
    # if val_loss < best_loss:
    #     best_loss = val_loss
    #     trigger_times = 0
    # else:
    #     trigger_times += 1
    #     if trigger_times >= patience:
    #         print("Early stopping!")
    #         break





Epoch [1/30], Avg Training Loss: 1.371220207048787, Validation Loss: 0.4206377693584987, PSNR: 17.302001953125, SSIM: 0.7063705325126648,lpips valu: 0.5252938866615295, LRPSNR: 22.128170013427734 PXL STD: 0.4652189782624241
Epoch [2/30], Avg Training Loss: 0.25427332603269154, Validation Loss: 0.17722030409744807, PSNR: 21.24116325378418, SSIM: 0.7500584125518799,lpips valu: 0.41490232944488525, LRPSNR: 22.469463348388672 PXL STD: 0.6245036619480978
Epoch [3/30], Avg Training Loss: 0.1672774205605189, Validation Loss: 0.1500229211080642, PSNR: 21.74248504638672, SSIM: 0.7599294781684875,lpips valu: 0.40570881962776184, LRPSNR: 22.426902770996094 PXL STD: 0.6986367851875874
Epoch [4/30], Avg Training Loss: 0.16741596617632443, Validation Loss: 0.14681137956324078, PSNR: 21.62818145751953, SSIM: 0.7564318776130676,lpips valu: 0.3929511606693268, LRPSNR: 22.037668228149414 PXL STD: 0.6445175437066628
Epoch [5/30], Avg Training Loss: 0.16260434314608574, Validation Loss: 0.1492322615924335