# **2D Binary Classification Models for Alzheimer's Disease Using Middle-Section Slicing**

# (1) Import the necessary packages

In [None]:
import os
import glob
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from sklearn.metrics import confusion_matrix, accuracy_score
import time
from scipy.ndimage import zoom
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.init as init
from torch import optim

import numpy as np
import nibabel as nib
from torch.utils.data import Dataset
from torchvision import transforms
from glob import glob
import torch
from skimage.measure import label, regionprops
from scipy.ndimage.morphology import binary_fill_holes
import scipy
from sklearn.metrics import roc_curve, auc
import nibabel as nib
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Check whether a CUDA-capable GPU is visible to PyTorch (the cell previously said "CPU").
cuda = torch.cuda.is_available()
print("GPU available:", cuda)

In [None]:
# Reproducibility: seed the PyTorch and NumPy global RNGs with 0 and ask cuDNN
# for deterministic kernels with autotuning (benchmark mode) switched off.
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# (2) Data Exploration

### Load sample

In [None]:
# One scan from the validation split's AD folder, used only for exploration below.
test_path = os.path.join(
    '/kaggle/input/preprocessed-data/ADNI_large_sample/validation/T1_affine/AD',
    'ADNI_005_S_0929_MPR_m06_g_AD_dx_3_age_82.7_scan0_mri_brainmask_mni152brain_affine_tl.nii',
)
test_img = nib.load(test_path)

# Voxel intensities as a NumPy array (get_fdata returns floating-point data).
data = test_img.get_fdata()

In [None]:
# Estimate how many 3-channel 2D images the training MRIDataset will produce.
# MRIDataset resamples every volume by 1 / downsample *before* slicing, so the slice count
# must be taken from the resampled size. scipy.ndimage.zoom sizes each output axis as
# round(size * zoom_factor); the same rule is mirrored here.
# Assumes every scan has the same shape as this sample volume.
downsample_factor = 1.5  # keep in sync with the `downsample` argument passed to MRIDataset
n_volumes = 386 + 405    # presumably the number of training scans (e.g. AD + CN) - TODO confirm
height = int(round(data.shape[1] * (1 / downsample_factor)))

# Middle third of axis 1: slices [n // 3, 2n // 3), exactly as MRIDataset selects them.
start_slice = height // 3
end_slice = height * 2 // 3

# MRIDataset builds non-overlapping triplets with range(0, n - 2, 3), which yields n // 3
# images for n slices (not (n - 2) // 3, which undercounts by one whenever n % 3 != 2).
images_per_volume = (end_slice - start_slice) // 3
number_of_2d_images = images_per_volume * n_volumes

print(f"Number of 2D images with 3 channels: {number_of_2d_images}")

### Visualize 3D and print dimensions

In [None]:
# Get dimensions (number of voxels along each of the three array axes)
height, width, length = data.shape
print("Height:", height)
print("Width:", width)
print("Length:", length)

# Visualize the MRI data as a point cloud of its non-zero voxels.
# A whole-brain volume can contain a very large number of non-zero voxels, which makes a
# matplotlib 3D scatter extremely slow; plot an evenly strided subset of at most
# roughly `max_points` voxels instead.
max_points = 100_000
x, y, z = data.nonzero()
step = max(1, len(x) // max_points)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x[::step], y[::step], z[::step], zdir='z', c='red', marker='o', s=1, alpha=0.1)
plt.title('3D Visualization of MRI Data')
plt.show()

### Visualize 2D slices along the three axes

In [None]:
# Report the intensity range. A bare `data.max(), data.min()` in the middle of a cell
# is evaluated but never displayed, so print it explicitly.
print('Intensity max/min:', data.max(), data.min())

# Central slice along each array axis. For a volume of shape (182, 218, 182) these are
# indices 91, 109 and 91 (presumably the intended MNI152 centre); computing them from the
# shape keeps the view centred for other volume sizes.
mid0, mid1, mid2 = (size // 2 for size in data.shape)

f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))
# Plane labels follow the original author's comments (assumes RAS-ordered voxel axes).
# Coronal
ax1.imshow(data[:, mid1, :], cmap='gray')
ax1.set_title('Coronal')
# Axial
ax2.imshow(data[:, :, mid2], cmap='gray')
ax2.set_title('Axial')
# Sagittal
ax3.imshow(data[mid0, :, :], cmap='gray')
ax3.set_title('Sagittal')

# (3) Load data and create data loaders

### Define MRIDataset class

In [None]:
class MRIDataset(Dataset):
    """Dataset of T1 MRI volumes, each returned as a list of 3-channel 2D slice images.

    Volumes are found with the pattern ``<DataDir>/T1_affine/<class_dir>/*.nii``; a volume
    is labelled 1 when its parent directory is named ``'AD'`` and 0 otherwise.

    ``__getitem__(idx)`` loads volume ``idx`` as float32, optionally max-normalizes and
    resamples it, keeps the middle third of axis 1, and stacks consecutive,
    non-overlapping triplets of those slices into channel-last images of shape
    ``(volume.shape[0], volume.shape[2], 3)``. It returns a list of
    ``{'T1': image, 'label': label}`` dicts, each passed through ``transform`` if given.

    Args:
        DataDir: root directory of one split (a trailing separator is optional).
        mode: split name; only used in the printed summary.
        input_T1: whether to index T1 volumes. When False the dataset is empty.
        transform: optional callable applied to every sample dict.
        T1_normalization_method: ``'max'`` divides by the volume maximum (when positive);
            any other value leaves intensities unchanged.
        downsample: factor passed to ``resample_img`` (every axis is scaled by
            ``1 / downsample``); a falsy value disables resampling.
    """
    def __init__(self, DataDir, mode, input_T1, transform=None, T1_normalization_method='max', downsample=2.5):
        self.DataDir = DataDir
        self.input_T1 = input_T1
        # Always defined so that __len__ works (returns 0) when T1 input is disabled.
        self.T1_img_files = []
        if input_T1:
            # os.path.join works whether or not DataDir ends with a separator.
            t1_dir = os.path.join(DataDir, 'T1_affine')
            self.T1_img_files = sorted(glob(os.path.join(t1_dir, '*', '*.nii')))
            print(f'T1 path: {t1_dir}')
            print(f'Load T1. Total T1 {mode} number is: ' + str(len(self.T1_img_files)))
        self.transform = transform
        self.T1_normalization_method = T1_normalization_method
        self.downsample = downsample

    def __len__(self):
        return len(self.T1_img_files)

    def resample_img(self, arr, val=2):
        """Rescale a 3D array by ``1 / val`` along each axis with scipy.ndimage.zoom."""
        return zoom(arr, (1/val, 1/val, 1/val))

    @staticmethod
    def _three_channel_slices(volume):
        """Stack non-overlapping triplets of axis-1 slices from the middle third of ``volume``."""
        start_slice = volume.shape[1] // 3
        end_slice = volume.shape[1] * 2 // 3
        relevant_slices = volume[:, start_slice:end_slice, :]
        # range(0, n - 2, 3) keeps only complete triplets (n // 3 images for n slices).
        return [
            np.stack(
                [relevant_slices[:, i, :], relevant_slices[:, i + 1, :], relevant_slices[:, i + 2, :]],
                axis=-1,
            )
            for i in range(0, relevant_slices.shape[1] - 2, 3)
        ]

    def __getitem__(self, idx):
        path = self.T1_img_files[idx]
        # The parent directory name is the class: 'AD' -> 1, anything else -> 0.
        class_name = os.path.basename(os.path.dirname(path))
        label = 1 if class_name == 'AD' else 0
        mri_data = nib.load(path).get_fdata().astype(np.float32)

        if self.T1_normalization_method == 'max':
            max_value = mri_data.max()
            if max_value > 0:
                mri_data /= max_value

        if self.downsample:
            mri_data = self.resample_img(mri_data, self.downsample)

        samples = [{'T1': img, 'label': label} for img in self._three_channel_slices(mri_data)]

        if self.transform:
            samples = [self.transform(sample) for sample in samples]

        return samples

### Transform samples to Tensor objects

In [None]:
class ToTensor(object):
    """Convert every value of a sample dict to a ``torch.Tensor``.

    The ``'label'`` entry (a Python scalar in MRIDataset samples) is first wrapped in a
    NumPy array; every other value must already be a NumPy array. ``torch.from_numpy``
    shares memory with the source array instead of copying it.
    """
    def __call__(self, sample):
        # Indented with spaces only; the previous body mixed tabs and spaces.
        return {
            key: torch.from_numpy(np.array(value) if key == 'label' else value)
            for key, value in sample.items()
        }

### Create Datasets

In [None]:
TrainDataDir = '../input/preprocessed-data/ADNI_large_sample/train/'
ValidationDataDir = '../input/preprocessed-data/ADNI_large_sample/validation/'

# Both splits use T1 input, max-normalization and 1/1.5 resampling, and yield
# tensor samples via ToTensor.
train_dataset = MRIDataset(
    DataDir=TrainDataDir,
    mode='train',
    input_T1=True,
    transform=transforms.Compose([ToTensor()]),
    T1_normalization_method='max',
    downsample=1.5
)
validation_dataset = MRIDataset(
    DataDir=ValidationDataDir,
    mode='validation',  # `mode` only labels the printed file count
    input_T1=True,
    transform=transforms.Compose([ToTensor()]),
    T1_normalization_method='max',
    downsample=1.5
)

### Visualize 3 channels of a test 2D Image

In [None]:
# A copy of the training set without `transform`, so samples stay NumPy arrays for plotting.
test = MRIDataset(
    DataDir=TrainDataDir,
    mode='train',
    input_T1=True,
    T1_normalization_method='max',
    downsample=1.5
)
# __getitem__ returns every 3-channel 2D image generated from one volume.
samples = test[0]

# First 3-channel 2D image of the first volume (this is not a DataLoader batch).
sample = samples[0]['T1']

# Each channel is one of three consecutive slices along axis 1 of the volume.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for channel_idx, ax in enumerate(axes):
    ax.imshow(sample[:, :, channel_idx], cmap='gray')
    ax.set_title(f'Channel {channel_idx + 1}')
    ax.axis('off')

plt.show()

### Create and store datasets (to speed up training)

In [None]:
def preprocess_and_save_dataset(dataset, save_path, filename='dataset.pt'):
    """
    Materialize every item of ``dataset`` and save them as one list with ``torch.save``.

    Args:
        dataset: an indexable dataset with ``len()`` (e.g. MRIDataset, whose items are
            lists of per-slice sample dicts), or any plain iterable.
        save_path: *directory* to write into; it is created if missing.
        filename: name of the output file inside ``save_path`` (default ``'dataset.pt'``).

    Returns:
        The full path of the written file.
    """
    os.makedirs(save_path, exist_ok=True)  # Ensure the save directory exists

    if hasattr(dataset, '__len__') and hasattr(dataset, '__getitem__'):
        # Index explicitly rather than relying on the legacy __getitem__ iteration
        # protocol, which only stops if the dataset happens to raise IndexError.
        preprocessed_samples = [dataset[idx] for idx in range(len(dataset))]
    else:
        preprocessed_samples = list(dataset)

    # The whole dataset is held in memory and written as a single file.
    output_path = os.path.join(save_path, filename)
    torch.save(preprocessed_samples, output_path)
    return output_path
    
def load_preprocessed_data(preprocessed_path, map_location=None):
    """
    Load data previously written with ``torch.save`` (e.g. by preprocess_and_save_dataset).

    Args:
        preprocessed_path: path of the saved file.
        map_location: optional device remapping forwarded to ``torch.load``, e.g.
            ``'cpu'`` to load tensors saved from a GPU on a CPU-only machine.
            ``None`` leaves tensors on the device they were saved from.
    """
    return torch.load(preprocessed_path, map_location=map_location)

In [None]:
# NOTE: save_path is treated as a *directory*, so this writes
# '.../train_preprocessed/dataset.pt/dataset.pt' - the path the loading cell reads.
preprocess_and_save_dataset(
    dataset=train_dataset,
    save_path='../preprocessed_data/ADNI_large_sample/train_preprocessed/dataset.pt',
)
#preprocess_and_save_dataset(validation_dataset, '/kaggle/working/val/')

### Load datasets and create DataLoaders

In [None]:
batch_size = 1

# The train file sits at '.../dataset.pt/dataset.pt' because the saving cell passes
# '.../dataset.pt' as a directory. The validation tensors come from a separate Kaggle
# input - presumably produced the same way (TODO confirm).
preprocessed_train_data = load_preprocessed_data(
    '../preprocessed_data/ADNI_large_sample/train_preprocessed/dataset.pt/dataset.pt'
)
preprocessed_validation_data = load_preprocessed_data(
    '/kaggle/input/tensor-data/dataset.pt'
)

# Each saved item is one volume's list of per-slice sample dicts, so with
# batch_size=1 every loader element holds the 2D images of a single volume.
train_dataloader = DataLoader(
    preprocessed_train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
val_dataloader = DataLoader(
    preprocessed_validation_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

# (4) Define model

### VGG19 with Batch Normalization

In [None]:
# import torchvision.models as models

# # Define VGG19 model pre-trained on ImageNet
# vgg19 = models.vgg19(weights="IMAGENET1K_V1") 

# # Determine which blocks to freeze - for example, freeze the first two blocks
# blocks_to_freeze = 2
# current_block = 0
# for module in vgg19.features.children():
#     if isinstance(module, nn.MaxPool2d):
#         current_block += 1
#     if current_block < blocks_to_freeze:
#         for param in module.parameters():
#             param.requires_grad = False

# # Modify the last fully connected layer to match the number of classes in your dataset
# num_classes = 2  # Example classes: AD (Alzheimer's Disease) and CN (Cognitively Normal)
# vgg19.classifier[6] = nn.Linear(4096, num_classes)

### ResNet18

In [None]:
# import torchvision.models as models
# import torch.nn as nn

# # Define ResNet18 model pre-trained on ImageNet
# resnet18 = models.resnet18(weights="IMAGENET1K_V1")

# # Freeze the first two blocks of the ResNet50 model
# blocks_to_freeze = 2
# current_block = 0
# for name, module in resnet18.named_children():
#     if "layer" in name:
#         current_block += 1
#         if current_block <= blocks_to_freeze:
#             for param in module.parameters():
#                 param.requires_grad = False

# # Modify the last fully connected layer to match the number of classes in your dataset
# num_classes = 2  # Example classes: AD (Alzheimer's Disease) and CN (Cognitively Normal)
# resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)

### Swin Transformer

In [None]:
import torch
import torch.nn as nn
from torchvision.models import swin_t

# Load Swin-T with ImageNet-1k pre-trained weights.
swin = swin_t(weights="IMAGENET1K_V1")

# Freeze the early part of the network (patch embedding + first two Swin stages),
# matching the "freeze the first two blocks" setup of the VGG19/ResNet18 cells.
# torchvision's SwinTransformer has no submodules named "layer1"/"layer2", so
# filtering parameter names on those strings froze nothing. Its stages live in
# `swin.features`, which in torchvision's implementation is ordered
# [patch embedding, stage 1, patch merging, stage 2, patch merging, stage 3, ...]
# (verify for your torchvision version); hence the first 2 * stages_to_freeze entries.
stages_to_freeze = 2
for module in swin.features[:2 * stages_to_freeze]:
    for param in module.parameters():
        param.requires_grad = False
print('Frozen parameters:', sum(p.numel() for p in swin.parameters() if not p.requires_grad))

# Replace the classification head to match the number of classes
num_classes = 2  # Example classes: AD (Alzheimer's Disease) and CN (Cognitively Normal)
swin.head = nn.Linear(swin.head.in_features, num_classes)

# (5) Define helper functions and parameters for training and validation

### Helper Functions

In [None]:
# Helpers shared by the training and validation code.
class EarlyStopChecker:
    '''
    Early stop checker. Credits to Chen "Raphael" Liu and Nanyan "Rosalie" Zhu.

    Call the instance once per epoch with the latest validation score. It returns True
    (stop) once at least ``2 * search_window`` scores have been recorded, the current
    score is not the best seen so far, and the mean of the last ``search_window`` scores
    is worse than the mean of the ``search_window`` scores before them.

    Args:
        search_window: number of epochs in each averaging window.
        score_higher_better: True for metrics such as accuracy/AUC, False for losses.
    '''
    def __init__(self, search_window = 5, score_higher_better = True):
        self.search_window = search_window
        self.score_higher_better = score_higher_better
        self.score_history = []

    def __call__(self, current_score):
        self.score_history.append(current_score)
        if len(self.score_history) < 2 * self.search_window:
            return False

        recent_mean = np.mean(self.score_history[-self.search_window:])
        previous_mean = np.mean(self.score_history[-2 * self.search_window:-self.search_window])

        if self.score_higher_better:
            # The history includes current_score, so "<" means it is not a new best.
            return bool(current_score < np.max(self.score_history) and recent_mean < previous_mean)

        # Lower is better: compare against the best (minimum) score. Comparing against
        # np.max could never succeed, because current_score is itself in the history.
        return bool(current_score > np.min(self.score_history) and recent_mean > previous_mean)
                
class AverageMeter(object):
    """Keep the most recent value together with a running, count-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: (batch, num_classes) scores or logits.
        target: (batch,) integer class labels.
        topk: tuple of k values; each must be <= num_classes.

    Returns:
        A list with one 0-dim tensor per k: the percentage of rows whose true class is
        among that row's k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # After the transpose, row j of `pred` holds every sample's j-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` can inherit the transposed (non-contiguous)
        # layout of `pred`, in which case view(-1) fails for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

### Define loss function, early stop checker, and scheduler

In [None]:
# Hyper-parameters and device selection
cuda_idx = '0'
momentum = 0.9
weight_decay = 5e-4
lr = 1e-5
device = torch.device(f"cuda:{cuda_idx}") if torch.cuda.is_available() else torch.device("cpu")

# Cross-entropy loss (criterion), placed on the training device
criterion = nn.CrossEntropyLoss().to(device)

# Decides when validation scores have stalled long enough to stop training
early_stop_checker = EarlyStopChecker()

optimizer = torch.optim.SGD(
    swin.parameters(),
    lr,
    momentum=momentum,
    weight_decay=weight_decay,
)
# Adaptive learning rate: multiply the LR by 0.8 once the stepped metric has not improved
# for more than 2 consecutive steps. It only acts when something calls scheduler.step(...).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.8, patience=2, verbose=2)

### Move model to GPU

In [None]:
# Move the selected model's parameters and buffers to `device`
# (use one of these instead when training the VGG19 / ResNet18 variants):
# vgg19.to(device)
# resnet18.to(device)
swin = swin.to(device)

### Define train and validation functions

In [None]:
# Per-epoch "data dropout": when enabled, each training epoch uses only a random subset
# of `data_dropout_remaining_size` 2D samples; when disabled, every sample is used.
data_dropout = True
data_dropout_remaining_size = 2000
# Whether T1 images are fed to the model (validate() reads this flag too).
input_T1 = True

def train(train_loader, model, criterion, optimizer, epoch, device):
    """
    Run one training epoch.

    Each element yielded by ``train_loader`` is a list of sample dicts with keys ``'T1'``
    (a channel-last image batch, permuted here to channel-first) and ``'label'``; every
    dict is processed as its own optimizer step.

    Args:
        train_loader: DataLoader over preprocessed MRIDataset items.
        model: network taking channel-first 4D input and returning per-class logits.
        criterion: loss applied to (logits, labels).
        optimizer: optimizer that updates ``model``.
        epoch: epoch index, used for logging only.
        device: device that inputs and labels are moved to.

    Returns:
        A one-element list with a dict of epoch metrics: ``epoch``, ``loss``, ``accuracy``,
        ``sensitivity``, ``specificity`` and ``AUC`` (NaN when ROC is undefined).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    model.train()
    # Per-sample class-1 (AD) probabilities and true labels, for ROC/AUC.
    AD_prediction_list, AD_ground_truth_list = [], []

    # NOTE: counting requires a full extra pass over the loader (every item is loaded once).
    total_samples = sum(len(batch) for batch in train_loader)
    print('Total samples in the training loader:', total_samples)

    # Indices of the 2D samples trained on this epoch; None means "use all of them".
    sampled_indices = None
    if data_dropout:
        # Never request more samples than exist (np.random.choice would raise).
        n_kept = min(data_dropout_remaining_size, total_samples)
        sampled_indices = set(np.random.choice(range(total_samples), n_kept, replace=False).tolist())
        print(f'We will randomly sample {n_kept} scans for this epoch from {total_samples} total samples.')

    # Start timing after the counting pass so it is not charged to the first sample.
    end = time.time()
    i = 0
    current_index = 0
    for i, batch in enumerate(train_loader):
        # Process each sample within the batch
        for data in batch:
            # Previously samples were only trained when data_dropout was enabled, so
            # data_dropout=False trained on nothing.
            use_sample = sampled_indices is None or current_index in sampled_indices
            current_index += 1
            if not use_sample:
                continue

            input_data_T1 = data['T1'].to(device) if input_T1 else None
            # Channel-last (N, d1, d2, C) -> channel-first (N, C, d1, d2).
            input_data_T1 = input_data_T1.permute(0, 3, 1, 2)
            target = data['label'].to(device)

            data_time.update(time.time() - end)

            # compute output
            output = model(input_data_T1)
            loss = criterion(output, target)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update metrics
            output = output.float()
            loss = loss.float()
            acc1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), input_data_T1.size(0))
            top1.update(acc1.item(), input_data_T1.size(0))

            # Softmax per row (axis=1) so every sample of the mini-batch gets its own
            # class-1 probability; accumulate once instead of recomputing the whole list.
            probabilities = scipy.special.softmax(output.detach().cpu().double().numpy(), axis=1)
            AD_prediction_list.extend(probabilities[:, 1].tolist())
            AD_ground_truth_list.extend(target.cpu().tolist())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    if len(set(AD_ground_truth_list)) < 2:
        # ROC/AUC are undefined with no samples or with only one class present.
        sensitivity = specificity = AUC = float('nan')
    else:
        fpr, tpr, _ = roc_curve(AD_ground_truth_list, AD_prediction_list)
        # Operating point maximizing tpr - fpr (Youden's J statistic).
        operating_point_index = np.argmax(1 - fpr + tpr)
        sensitivity, specificity = tpr[operating_point_index], 1 - fpr[operating_point_index]
        AUC = auc(fpr, tpr)
    print('\nEpoch: [{0}][{1}/{2}]\t'
          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
          'Accuracy@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
              epoch, i, len(train_loader), batch_time=batch_time,
              data_time=data_time, loss=losses, top1=top1))
    return [{'epoch': epoch,'loss': losses.avg, 'accuracy': top1.avg,'sensitivity': sensitivity,'specificity': specificity,'AUC': AUC}]

In [None]:
def validate(val_loader, model, criterion, device, scheduler=None):
    """
    Evaluate ``model`` on every 2D sample in ``val_loader``.

    Each loader element is a list of sample dicts with a channel-last ``'T1'`` batch
    (permuted to channel-first when ``input_T1`` is set) and a ``'label'`` batch.

    Args:
        val_loader: DataLoader over preprocessed MRIDataset items.
        model: network to evaluate (switched to eval mode).
        criterion: loss applied to (logits, labels).
        device: device that inputs and labels are moved to.
        scheduler: optional ReduceLROnPlateau-style scheduler; when given it is stepped
            once with the epoch's *average* validation loss.

    Returns:
        ``(top-1 accuracy, AUC, [metrics_dict])``, where metrics_dict holds
        ``Time_Average``, ``Loss_Average``, ``Accuracy@1_Average``, ``Sensitivity``,
        ``Specificity`` and ``AUC``. ROC-based values are NaN when undefined.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    # Per-sample class-1 (AD) probabilities and true labels, for ROC/AUC.
    AD_prediction_list, AD_ground_truth_list = [], []

    # switch to evaluate mode
    model.eval()
    end = time.time()
    with torch.no_grad():
        for batch in val_loader:
            # Process each sub-batch
            for data in batch:
                input_data_T1 = data['T1']
                if input_T1:
                    input_data_T1 = input_data_T1.to(device)
                    # Channel-last -> [batch_size, channels, d1, d2]
                    input_data_T1 = input_data_T1.permute(0, 3, 1, 2)

                target = data['label'].to(device)

                # compute output
                output = model(input_data_T1)
                loss = criterion(output, target)

                output = output.float()
                loss = loss.float()

                # measure accuracy and record loss
                acc1 = accuracy(output.data, target)[0]
                losses.update(loss.item(), input_data_T1.size(0))
                top1.update(acc1.item(), input_data_T1.size(0))

                # Row-wise softmax gives each sample its own class-1 probability;
                # accumulate incrementally instead of recomputing the whole list.
                probabilities = scipy.special.softmax(output.cpu().double().numpy(), axis=1)
                AD_prediction_list.extend(probabilities[:, 1].tolist())
                AD_ground_truth_list.extend(target.cpu().tolist())

                # One timing entry per sample (loading + forward pass), so Time_Average
                # is the mean per-sample time.
                batch_time.update(time.time() - end)
                end = time.time()

    if len(set(AD_ground_truth_list)) < 2:
        # ROC/AUC are undefined with no samples or with only one class present.
        sensitivity = specificity = AUC = float('nan')
    else:
        fpr, tpr, _ = roc_curve(AD_ground_truth_list, AD_prediction_list)
        # Operating point maximizing tpr - fpr (Youden's J statistic).
        operating_point_index = np.argmax(1 - fpr + tpr)
        sensitivity, specificity = tpr[operating_point_index], 1 - fpr[operating_point_index]
        AUC = auc(fpr, tpr)
    print('\n * Accuracy@1 {top1.avg:.3f}'.format(top1=top1))
    print(' * Sensitivity {sensitivity:.3f} Specificity {specificity:.3f} AUC {AUC:.3f}'
          .format(sensitivity=sensitivity, specificity=specificity, AUC=AUC))

    # Adjust the learning rate on the epoch's mean validation loss (not the loss of
    # whichever mini-batch happened to be processed last).
    if scheduler is not None and losses.count > 0:
        scheduler.step(losses.avg)

    metrics_dict = {
        'Time_Average': batch_time.avg,
        'Loss_Average': losses.avg,
        'Accuracy@1_Average': top1.avg,
        'Sensitivity': sensitivity,
        'Specificity': specificity,
        'AUC': AUC
    }
    torch.cuda.empty_cache()

    return top1.avg, AUC, [metrics_dict]

# (6) Training Loop

In [None]:
warnings.filterwarnings('ignore')
start_epoch = 0
epochs = 25
best_acc = 0
best_AUC = 0
# Weights of the epoch with the best validation AUC are written here. The in-memory
# model keeps the weights of the last epoch trained.
best_model_path = '/kaggle/working/best_model.pt'


def _write_metrics_csv(rows, csv_file_path):
    """Write a list of metric dicts to CSV, taking the column names from the first dict."""
    if not rows:
        print(f'No metrics to write to {csv_file_path}')
        return
    with open(csv_file_path, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)


metrics_to_save_all_epochs_train = []
metrics_to_save_all_epochs_val = []
for epoch in range(start_epoch, epochs):
    # train for one epoch
#     metrics_to_save_train = train(train_dataloader, resnet18, criterion, optimizer, epoch, device)
    metrics_to_save_train = train(train_dataloader, swin, criterion, optimizer, epoch, device)
    metrics_to_save_all_epochs_train.extend(metrics_to_save_train)

    # Evaluate on the validation set. Passing the scheduler lets validate() step
    # ReduceLROnPlateau; without it the scheduler was created but never used.
#     current_acc, current_AUC, metrics_to_save_val = validate(val_dataloader, resnet18, criterion, device)
    current_acc, current_AUC, metrics_to_save_val = validate(val_dataloader, swin, criterion, device, scheduler=scheduler)
    metrics_to_save_all_epochs_val.extend(metrics_to_save_val)

    # Remember the best scores and checkpoint the best-AUC weights
    # (`>` is False for a NaN AUC, so NaN never becomes the best).
    is_best = current_AUC > best_AUC
    best_acc = max(current_acc, best_acc)
    if is_best:
        best_AUC = current_AUC
        torch.save(swin.state_dict(), best_model_path)
        print(f'New best validation AUC {best_AUC:.3f}; weights saved to {best_model_path}')

    if early_stop_checker(current_AUC):
        print('The performance on validation set has not improved for a while. Early stop. Training completed.')
        break

# A helper (instead of a module-level loop) avoids rebinding the global `data` variable.
_write_metrics_csv(metrics_to_save_all_epochs_train, '/kaggle/working/training_metrics_train_downsample2.csv')
_write_metrics_csv(metrics_to_save_all_epochs_val, '/kaggle/working/training_metrics_val_downsample2.csv')

# (7) Testing the model

### Create test dataset

In [None]:
TestDataDir = '../input/test-data/test/'

# Held-out test split, built with the same settings as the training/validation datasets.
test_dataset = MRIDataset(
    mode='test',
    DataDir=TestDataDir,
    input_T1=True,
    T1_normalization_method='max',
    downsample=1.5,
    transform=transforms.Compose([ToTensor()]),
)

### Store dataset, load it, and create Dataloader for testing

In [None]:
# The test loader reads a previously saved tensor file, presumably produced by the
# commented-out call below; `test_dataset` from the previous cell is not used here.
#preprocess_and_save_dataset(test_dataset, '/kaggle/working/test/')
preprocessed_test_data = load_preprocessed_data(
    '/kaggle/input/tensor-data/dataset-2.pt'
)
test_dataloader = DataLoader(
    preprocessed_test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

### Define test function

In [None]:
def test_model(test_loader, model, device):
    """
    Evaluate ``model`` on ``test_loader`` and print accuracy, sensitivity, specificity and AUC.

    Each loader element is a list of sample dicts with a channel-last ``'T1'`` batch
    (permuted here to channel-first) and a ``'label'`` batch.

    Accuracy, sensitivity and specificity are computed from the hard (argmax)
    predictions. AUC is computed from the softmax probability of class 1: an ROC curve
    built from hard 0/1 predictions has a single operating point and does not measure
    how well the model ranks positives above negatives. Metrics that are undefined
    (no samples, or a class missing from the targets) are reported as NaN.

    Returns:
        dict with keys ``'accuracy'``, ``'sensitivity'``, ``'specificity'`` and ``'AUC'``.
    """
    model.eval()  # Set the model to evaluate mode
    total_predictions, total_probabilities, total_targets = [], [], []

    with torch.no_grad():  # Gradients are not needed for evaluation; saves memory and compute
        for batch in test_loader:
            for data in batch:
                input_data_T1 = data['T1'].to(device).permute(0, 3, 1, 2)  # channel-last -> channel-first
                labels = data['label'].to(device)

                outputs = model(input_data_T1).float()
                predicted = outputs.argmax(dim=1)
                positive_probability = torch.softmax(outputs, dim=1)[:, 1]

                total_predictions.extend(predicted.cpu().tolist())
                total_probabilities.extend(positive_probability.cpu().tolist())
                total_targets.extend(labels.cpu().tolist())

    targets = np.asarray(total_targets)
    predictions = np.asarray(total_predictions)
    nan = float('nan')

    if targets.size == 0:
        acc = sensitivity = specificity = AUC_score = nan
    else:
        # Named `acc` so the module-level accuracy() helper is not shadowed.
        acc = float(np.mean(targets == predictions))
        # labels=[0, 1] keeps the 2x2 layout even when one class never occurs.
        tn, fp, fn, tp = confusion_matrix(targets, predictions, labels=[0, 1]).ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) else nan
        specificity = tn / (tn + fp) if (tn + fp) else nan
        if len(np.unique(targets)) == 2:
            fpr, tpr, _ = roc_curve(targets, total_probabilities, pos_label=1)
            AUC_score = auc(fpr, tpr)
        else:
            AUC_score = nan

    print(f'Accuracy: {acc}\nSensitivity: {sensitivity}\nSpecificity: {specificity}\nAUC: {AUC_score}')
    return {'accuracy': acc, 'sensitivity': sensitivity, 'specificity': specificity, 'AUC': AUC_score}

### Print test results

In [None]:
# Evaluate the trained Swin model on the held-out test split.
test_model(test_loader=test_dataloader, model=swin, device=device)