# **3D Binary Classification Models for Alzheimer**

# (1) Import the necessary packages

In [None]:
import os
import glob
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from sklearn.metrics import confusion_matrix, accuracy_score
import time
from scipy.ndimage import zoom
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.init as init
from torch import optim

import numpy as np
import nibabel as nib
from torch.utils.data import Dataset
from torchvision import transforms
from glob import glob
import torch
from skimage.measure import label, regionprops
from scipy.ndimage.morphology import binary_fill_holes
import scipy
from sklearn.metrics import roc_curve, auc
import nibabel as nib
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Check whether a CUDA-capable GPU is available
cuda = torch.cuda.is_available()
print("GPU available:", cuda)

In [None]:
# Set the NumPy and PyTorch random seeds for reproducibility
np.random.seed(0)
torch.manual_seed(0)
# Request deterministic cuDNN kernels and switch off the cuDNN auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# (2) Data Exploration

### Count samples for each label for training and validation

In [None]:
#Checking the number of samples in each folder (train, validation and test)
!ls /kaggle/input/preprocessed-data/ADNI_large_sample/train/T1_affine/AD/ADNI_*.nii| wc -l
#!ls /kaggle/input/preprocessed-data/ADNI_large_sample/test/T1_affine/AD/ADNI_*.nii| wc -l
!ls /kaggle/input/preprocessed-data/ADNI_large_sample/validation/T1_affine/AD/ADNI_*.nii| wc -l

!ls /kaggle/input/preprocessed-data/ADNI_large_sample/train/T1_affine/CN/ADNI_*.nii| wc -l
#!ls /kaggle/input/preprocessed-data/ADNI_large_sample/test/T1_affine/CN/ADNI_*.nii| wc -l
!ls /kaggle/input/preprocessed-data/ADNI_large_sample/validation/T1_affine/CN/ADNI_*.nii| wc -l

### Load sample

In [None]:
# Load one scan from the validation/AD folder and print its header, array shape and affine matrix
test_path = '/kaggle/input/preprocessed-data/ADNI_large_sample/validation/T1_affine/AD/ADNI_005_S_0929_MPR_m06_g_AD_dx_3_age_82.7_scan0_mri_brainmask_mni152brain_affine_tl.nii'
test_img = nib.load(test_path)
print(test_img)
# Voxel intensities of the scan as an array
arr = test_img.get_fdata()
print(arr.shape, type(arr))
print(test_img.affine)

### Visualize the three axes: coronal, axial, and sagittal

In [None]:
# NOTE: this bare expression is not the last statement of the cell, so a notebook will not display it
arr.max(), arr.min()
# Show one fixed slice per array axis (index 109 on axis 1, index 91 on axes 2 and 0)
f, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(10,5))
ax1.imshow(arr[:,109,:], cmap='gray')
ax2.imshow(arr[:,:,91], cmap='gray')
ax3.imshow(arr[91,:,:], cmap='gray')

### Test and visualize downsampling factor

In [None]:
from scipy.ndimage import zoom

def resample_img(arr, val=2):
    """Shrink a 3D array by the factor ``val`` along each of its three axes.

    The shapes of the resized and of the original array are printed so the
    effect of the factor can be inspected.

    Args:
        arr: Three-dimensional array to resample.
        val: Downsampling factor; every axis is scaled by ``1 / val``.

    Returns:
        The resampled array produced by ``scipy.ndimage.zoom``.
    """
    scale = 1 / val
    resized = zoom(arr, (scale, scale, scale))
    print(resized.shape, arr.shape)
    return resized
                         
# Downsample the example volume by a factor of 2.5 to inspect the loss of detail
out = resample_img(arr,2.5)


# Show the central slice along each axis. Every axis uses its own midpoint:
# the axes have different lengths, so reusing shape[1]//2 for axes 2 and 0
# would pick an off-centre slice.
f, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(10,5))
ax1.imshow(out[:,out.shape[1]//2,:], cmap='gray')
ax2.imshow(out[:,:,out.shape[2]//2], cmap='gray')
ax3.imshow(out[out.shape[0]//2,:,:], cmap='gray')

# (3) Upload data and create data loader

### Define MRIDataset class

In [None]:
# Dataset that loads, normalizes and downsamples the T1 volumes of one split on demand
class MRIDataset(Dataset):
    """Dataset of T1 MRI volumes stored as NIfTI (``.nii``) files.

    Files are discovered with the pattern ``<DataDir>T1_affine/*/*.nii``. The
    name of the directory that directly contains a file is its class:
    ``'AD'`` maps to label 1, every other name to label 0.

    Args:
        DataDir: Root directory of one split. It is concatenated directly with
            ``'T1_affine/*/*.nii'``, so it has to end with a path separator.
        mode: Name of the split; only used in the log messages.
        input_T1: Whether T1 volumes are used. When false, no files are
            indexed and the dataset is empty.
        transform: Optional callable applied to every sample dictionary.
        T1_normalization_method: ``'NA'`` (none), ``'max'`` (divide by the
            volume maximum) or ``'WBtop10PercentMean'``.
        downsample: Factor by which each axis is shrunk, or ``None`` to keep
            the original resolution.
    """

    def __init__(self, DataDir, mode, input_T1, transform=None,T1_normalization_method = 'max', downsample=2.5):
        print('***************')
        print('MRIDataset')
        self.DataDir = DataDir
        self.input_T1 = input_T1
        # Always defined, so __len__ also works when input_T1 is false.
        self.T1_img_files = []

        if input_T1:
            self.T1_img_files = sorted(glob(DataDir + 'T1_affine/*/*.nii'))
            print(f'T1 path: {DataDir}T1_affine')
            print(f'Load T1. Total T1 {mode} number is: ' + str(len(self.T1_img_files)))

        self.transform = transform
        self.T1_normalization_method = T1_normalization_method
        self.downsample = downsample

    def __len__(self):
        """Return the number of indexed T1 files."""
        return len(self.T1_img_files)

    def resample_img(self,arr, val=2):
        """Return ``arr`` with each of its three axes scaled by ``1 / val``."""
        new_arr = zoom(arr, (1/val, 1/val, 1/val))
        return new_arr

    def _normalize(self, volume):
        """Apply the configured intensity normalization to one volume."""
        if self.T1_normalization_method == 'max':
            peak = volume.max()
            # Dividing an all-zero volume by its maximum would turn it into NaNs.
            return volume / peak if peak != 0 else volume

        if self.T1_normalization_method == 'WBtop10PercentMean':
            # NOTE(review): WB_to_brain_mask is not defined in the visible part
            # of this class; this branch fails unless it is provided elsewhere.
            brain_mask = self.WB_to_brain_mask(volume)
            threshold = np.percentile(volume[brain_mask == 1], 90, interpolation = 'nearest')
            # Mean intensity of the brightest 10 % of the voxels inside the mask.
            normalization_factor = np.mean(
                volume[np.logical_and(volume >= threshold, brain_mask == 1)])
            assert normalization_factor > 0
            return volume / normalization_factor

        return volume

    def __getitem__(self, idx):
        """Load, normalize and optionally downsample the volume at ``idx``.

        Returns:
            A dictionary with the keys ``'T1'`` (the float32 volume) and
            ``'label'`` (1 for AD, 0 otherwise), passed through ``transform``
            when one was given.

        Raises:
            IndexError: If ``idx`` is outside the list of indexed files.
        """
        file_path = self.T1_img_files[idx]
        class_name = file_path.split('/')[-2]
        target = 1 if class_name == 'AD' else 0

        # normalization (important !!!!)
        assert self.T1_normalization_method in ['NA', 'max', 'WBtop10PercentMean']

        volume = nib.load(file_path).get_fdata().astype(np.float32)
        volume = self._normalize(volume)

        # With downsample=None the volume is returned at its original resolution.
        if self.downsample is not None:
            volume = self.resample_img(volume, self.downsample)

        sample = {'T1': volume,
                  'label': target,
                  }

        if self.transform:
            sample = self.transform(sample)

        return sample

### Transform samples to Tensor objects

In [None]:
class ToTensor(object):
    """Transform that converts every value of a sample dictionary to a tensor."""

    def __call__(self, sample):
        """Return a new dictionary with the same keys and tensor values.

        The value stored under ``'label'`` is wrapped in a NumPy array first;
        every other value is handed to ``torch.from_numpy`` unchanged.
        """
        converted = {}
        for name in sample:
            payload = sample[name]
            if name == 'label':
                payload = np.array(payload)
            converted[name] = torch.from_numpy(payload)
        return converted

### Create Datasets

In [None]:
# Number of samples per batch
batch_size = 1

# Root directories of the training and validation splits
TrainDataDir = '../input/preprocessed-data/ADNI_large_sample/train/'
ValidationDataDir = '../input/preprocessed-data/ADNI_large_sample/validation/'

# Intensity normalization shared by both datasets
normalization_method = 'max'

# Training dataset
Train_MRIDataset = MRIDataset(
    DataDir=TrainDataDir,
    mode='train',
    input_T1=True,
    transform=transforms.Compose([ToTensor()]),
    T1_normalization_method=normalization_method,
    downsample=1.5,
)

# Validation dataset
Validation_MRIDataset = MRIDataset(
    DataDir=ValidationDataDir,
    mode='validation',
    input_T1=True,
    transform=transforms.Compose([ToTensor()]),
    T1_normalization_method=normalization_method,
    downsample=1.5,
)

### Store Datasets (to speed up training)

In [None]:
def preprocess_and_save_dataset(dataset, save_path):
    """
    Materialize every sample of ``dataset`` and store them in a single file.

    Iterating the dataset runs its per-sample preprocessing once; the list of
    resulting samples is written to ``<save_path>/dataset.pt``. The directory
    is created when it does not exist yet. All samples are held in memory at
    the same time.
    """
    os.makedirs(save_path, exist_ok=True)
    target_file = os.path.join(save_path, 'dataset.pt')

    collected = [item for item in dataset]
    torch.save(collected, target_file)
    
def load_preprocessed_data(preprocessed_path):
    """
    Read back the sample list stored as ``dataset.pt`` inside ``preprocessed_path``.
    """
    stored_file = os.path.join(preprocessed_path, 'dataset.pt')
    samples = torch.load(stored_file)
    return samples

In [None]:
# Run the preprocessing once for every training sample and cache the result on disk
preprocess_and_save_dataset(Train_MRIDataset, '/kaggle/working/train/')

# Same for the validation dataset
preprocess_and_save_dataset(Validation_MRIDataset, '/kaggle/working/val/')

In [None]:
# Read the cached sample lists back from disk
preprocessed_train_data = load_preprocessed_data('/kaggle/working/train/')
preprocessed_validation_data = load_preprocessed_data('/kaggle/working/val/')

# Wrap the sample lists in DataLoaders; only the training loader shuffles.
# train() and validate() index each batch with the 'T1' and 'label' keys.
Train_dataloader = DataLoader(preprocessed_train_data, batch_size = batch_size,
                       shuffle=True, num_workers=4)
Val_dataloader = DataLoader(preprocessed_validation_data, batch_size = batch_size,
                       shuffle=False, num_workers=4)

# (4) Define model: 3D VGG

In [None]:
class VGG(nn.Module):
    '''
    3D VGG-style classifier: a convolutional feature extractor followed by a
    fully connected head that outputs two logits.

    Args:
        feature_extractor: Module applied to the input volume; its output is
            flattened per sample and fed to the classifier.
        input_T1: Stored on the instance; not otherwise used by this class.
        T1_feature_dimension: Number of values per sample after flattening the
            feature extractor output. The default 4608 is the value that was
            previously hard-coded; pass another value for a different
            configuration or input resolution.
    '''
    def __init__(self, feature_extractor, input_T1, T1_feature_dimension=4608):
        super(VGG, self).__init__()
        self.input_T1 = input_T1
        self.feature_extractor = feature_extractor
        self.T1_feature_dimension = T1_feature_dimension
        feature_dimension = self.T1_feature_dimension
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(feature_dimension, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 2),
        )

        # Initialize convolution weights from N(0, sqrt(2 / n)) with
        # n = kernel volume * output channels; biases start at zero.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # Convolutions created with bias=False have no bias tensor.
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x1):
        '''Return the two class logits for a batch of input volumes.'''
        x1 = self.feature_extractor(x1)
        # flatten keeps the batch axis and also accepts non-contiguous feature maps
        x = torch.flatten(x1, 1)
        x = self.classifier(x)
        return x
    
def make_layers(cfg, input_T1, batch_norm=False):
    """Assemble the convolutional feature extractor described by ``cfg``.

    Args:
        cfg: Sequence of layer specifications. ``'M'`` adds a max-pooling
            layer (kernel 2, stride 2); any other entry is the number of
            output channels of a convolution (kernel 3, padding 1) followed
            by a ReLU.
        input_T1: Accepted for interface compatibility; not used here.
        batch_norm: When true, a ``BatchNorm3d`` layer is inserted between
            each convolution and its ReLU.

    Returns:
        An ``nn.Sequential`` whose first convolution expects one input channel.
    """
    modules = []
    channels = 1

    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool3d(kernel_size=2, stride=2))
            continue

        modules.append(nn.Conv3d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm3d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec

    return nn.Sequential(*modules)

# Layer specifications consumed by make_layers: an integer adds a convolution
# with that many output channels, 'M' adds a max-pooling layer.
cfg = {
    # 13 convolutions, 64 to 512 channels, five pooling stages
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    # 16 convolutions, 16 to 128 channels, five pooling stages
    'E': [16, 16, 'M', 32, 32, 'M', 64, 64, 64, 64, 'M', 128, 128, 128, 128, 'M', 
          128, 128, 128, 128, 'M'],
}

def vgg16(input_T1):
    """Build the VGG model from configuration 'D' without batch normalization.

    NOTE(review): configuration 'D' ends with 512 channels, while the VGG
    classifier defaults to 4608 flattened features; confirm that the two
    match for the input size in use.
    """
    feature_extractor = make_layers(cfg['D'], input_T1)
    return VGG(feature_extractor, input_T1)

def vgg16_bn(input_T1):
    """Build the VGG model from configuration 'D' with batch normalization.

    NOTE(review): configuration 'D' ends with 512 channels, while the VGG
    classifier defaults to 4608 flattened features; confirm that the two
    match for the input size in use.
    """
    feature_extractor = make_layers(cfg['D'], input_T1, batch_norm=True)
    return VGG(feature_extractor, input_T1)

def vgg19_bn(input_T1):
    """Build the VGG model from configuration 'E' with batch normalization."""
    feature_extractor = make_layers(cfg['E'], input_T1, batch_norm=True)
    return VGG(feature_extractor, input_T1)

In [None]:
# Build the model; the alternative architectures are kept commented out
# model = vgg16(input_T1=True)
# NOTE(review): no `vgg19` factory is defined in the visible cells, only `vgg19_bn`
# model = vgg19(input_T1=True)
model = vgg19_bn(input_T1=True)

# (5) Define helper functions and parameters for training and validation

### Helper Functions

In [None]:
# Decides when training should stop because the validation score stopped improving
class EarlyStopChecker:
    '''
    Early stop checker. Credits to Chen "Raphael" Liu and Nanyan "Rosalie" Zhu.

    Every call records one score. Once at least ``2 * search_window`` scores
    exist, stopping is signalled when both conditions hold:

    * the current score is not the best score seen so far, and
    * the mean of the last ``search_window`` scores is worse than the mean of
      the ``search_window`` scores before them.

    Args:
        search_window: Number of scores in each of the two compared windows.
        score_higher_better: True when larger scores are better (accuracy,
            AUC), false when smaller scores are better (loss).
    '''
    def __init__(self, search_window = 5, score_higher_better = True):
        self.search_window = search_window
        self.score_higher_better = score_higher_better
        self.score_history = []

    def __call__(self, current_score):
        '''Record ``current_score`` and return True when training should stop.'''
        self.score_history.append(current_score)
        window = self.search_window
        if len(self.score_history) < 2 * window:
            return False

        recent_mean = np.mean(self.score_history[-window:])
        previous_mean = np.mean(self.score_history[-2 * window:-window])

        if self.score_higher_better:
            best_score = np.max(self.score_history)
            return bool(current_score < best_score and recent_mean < previous_mean)

        # Lower is better: the best score is the minimum. Comparing against
        # the maximum could never be true, because the current score is
        # itself part of the history.
        best_score = np.min(self.score_history)
        return bool(current_score > best_score and recent_mean > previous_mean)
                
class AverageMeter(object):
    """Keeps the most recent value together with a weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, average, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy@k for the specified values of k

    Args:
        output: Score tensor of shape (batch, classes).
        target: Tensor with one class index per sample.
        topk: The values of k to evaluate.

    Returns:
        A list with one single-element tensor per k holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    num_samples = target.size(0)

    # Indices of the maxk best classes, transposed to shape (maxk, batch)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: the first k rows of the transposed tensor
        # are not contiguous when k > 1, and view() rejects that layout.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / num_samples))
    return res

### Define loss function, early stop checker, and scheduler

In [None]:
# Hyper-parameters and compute device
cuda_idx = '0'
momentum = 0.9
weight_decay = 5e-4
lr = 1e-5
# Use GPU `cuda_idx` when CUDA is available, otherwise fall back to the CPU
device = torch.device(f"cuda:{cuda_idx}" if (torch.cuda.is_available()) else "cpu")

# Define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Initialize the checker for early stop (default window, higher score is better)
early_stop_checker = EarlyStopChecker()

optimizer = torch.optim.SGD(model.parameters(), lr,
                            momentum=momentum,
                            weight_decay=weight_decay)
# Adaptive learning rate: multiply the learning rate by 0.8 when the monitored value plateaus.
# validate() passes a loss to scheduler.step().
# NOTE(review): `verbose` is documented as a boolean and is deprecated in recent
# PyTorch releases - confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.8, patience = 2, verbose = 2)

### Move model to GPU

In [None]:
# Move the model's parameters and buffers to the selected device
model.to(device)

### Define train and validation functions

In [None]:
# Module-level switches read by train(), validate() and test_model()
# When true, train() only uses a random subset of the batches in each epoch
data_dropout = True
# Number of batches kept per epoch when data_dropout is enabled
data_dropout_remaining_size = 80
# When true, the 'T1' entry of each batch is used as the model input
input_T1 = True

def train(train_loader, model, criterion, optimizer, epoch, device):
    """
    Run one training epoch and return its summary metrics.

    Reads the module-level settings ``data_dropout``,
    ``data_dropout_remaining_size`` and ``input_T1``. With data dropout
    enabled, only a random subset of the batches is used in the epoch.

    Args:
        train_loader: Iterable of batches; each batch is a dictionary with the
            keys ``'T1'`` and ``'label'``.
        model: Network to optimize.
        criterion: Loss function called as ``criterion(output, target)``.
        optimizer: Optimizer that updates the model parameters.
        epoch: Index of the current epoch; only used for reporting.
        device: Device the inputs and targets are moved to.

    Returns:
        A list with a single dictionary of epoch metrics (timing, loss,
        accuracy, sensitivity, specificity, AUC). The ROC-based metrics are
        NaN when fewer than two classes were seen.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    AD_prediction_list, AD_ground_truth_list = [], []
    metrics_to_save = []

    # switch to train mode
    model.train()

    end = time.time()

    # If we use data-dropout, we randomly sample a certain subset for training in each epoch.
    selected_batches = None
    if data_dropout:
        num_batches = len(train_loader)
        # Never request more batches than exist; sampling without replacement would fail.
        subset_size = min(data_dropout_remaining_size, num_batches)
        print('The size of the training set is :', num_batches, '. We will randomly sample %s scans for this epoch.' % (subset_size))
        if subset_size > 0:
            # A set gives constant-time membership tests inside the loop.
            selected_batches = set(np.random.choice(num_batches, subset_size, replace = False).tolist())
        else:
            selected_batches = set()

    i = -1  # stays -1 when the loader yields no batch
    for i, data in enumerate(train_loader):
        if selected_batches is not None and i not in selected_batches:
            continue

        # assign variables
        input_data_T1 = None
        if input_T1:
            input_data_T1 = data['T1'].unsqueeze(1)
        target = data['label']

        # measure data loading time
        data_time.update(time.time() - end)

        if input_T1:
            input_data_T1 = input_data_T1.to(device)
        target = target.to(device)

        # compute output
        output = model(input_data_T1)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output.detach().float()

        # measure accuracy and record loss, weighted by the real size of this batch
        num_samples = target.size(0)
        acc1 = accuracy(output, target)[0]
        losses.update(loss.item(), num_samples)
        top1.update(acc1.item(), num_samples)

        # Collect the per-sample probability of class 1 (AD) and the true labels.
        # The softmax runs over the class axis, so every sample of a batch is kept.
        probabilities = torch.softmax(output.cpu().double(), dim=1)[:, 1]
        AD_prediction_list.extend(probabilities.tolist())
        AD_ground_truth_list.extend(target.detach().cpu().view(-1).tolist())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        torch.cuda.empty_cache()

    # measure sensitivity, specificity, AUC once per epoch; a ROC curve needs both classes.
    sensitivity = specificity = AUC = float('nan')
    if len(set(AD_ground_truth_list)) > 1:
        fpr, tpr, _ = roc_curve(AD_ground_truth_list, AD_prediction_list)
        # Operating point that maximizes sensitivity + specificity
        operating_point_index = np.argmax(1 - fpr + tpr)
        sensitivity, specificity = tpr[operating_point_index], 1 - fpr[operating_point_index]
        AUC = auc(fpr, tpr)

    print('\nEpoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Accuracy@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Sensitivity ({sensitivity:.3f})\t'
                      'Specificity ({specificity:.3f})\t'
                      'AUC ({AUC:.3f})'.format(
                          epoch, i, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, top1=top1,
                          sensitivity=sensitivity, specificity=specificity, AUC=AUC))
    metrics_dict = {
                'Epoch': epoch,
                'Batch': i,
                'Time_Value': batch_time.val,
                'Time_Average': batch_time.avg,
                'Data_Value': data_time.val,
                'Data_Average': data_time.avg,
                'Loss_Value': losses.val,
                'Loss_Average': losses.avg,
                'Accuracy@1_Value': top1.val,
                'Accuracy@1_Average': top1.avg,
                'Sensitivity': sensitivity,
                'Specificity': specificity,
                'AUC': AUC
                }
    metrics_to_save.append(metrics_dict)
    return metrics_to_save

In [None]:
def validate(val_loader, model, criterion, device, scheduler = None):
    """
    Run evaluation on the validation set.

    Reads the module-level setting ``input_T1``.

    Args:
        val_loader: Iterable of batches; each batch is a dictionary with the
            keys ``'T1'`` and ``'label'``.
        model: Network to evaluate.
        criterion: Loss function called as ``criterion(output, target)``.
        device: Device the inputs and targets are moved to.
        scheduler: Optional learning-rate scheduler. When given, it is stepped
            once with the mean validation loss.

    Returns:
        A tuple ``(mean top-1 accuracy, AUC, metrics)`` where ``metrics`` is a
        list with a single dictionary of summary values. The ROC-based
        metrics are NaN when fewer than two classes were seen.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    AD_prediction_list, AD_ground_truth_list = [], []
    metrics_to_save = []

    # switch to evaluate mode
    model.eval()
    end = time.time()
    i = -1  # stays -1 when the loader yields no batch
    with torch.no_grad():
        for i, data in enumerate(val_loader):

            # assign variables
            input_data_T1 = None
            if input_T1:
                input_data_T1 = data['T1'].unsqueeze(1).to(device)
            target = data['label'].to(device)

            # compute output
            output = model(input_data_T1)
            loss = criterion(output, target)
            output = output.float()

            # measure accuracy and record loss, weighted by the real size of this batch
            num_samples = target.size(0)
            acc1 = accuracy(output, target)[0]
            losses.update(loss.item(), num_samples)
            top1.update(acc1.item(), num_samples)

            # Collect the per-sample probability of class 1 (AD) and the true labels.
            # The softmax runs over the class axis, so every sample of a batch is kept.
            probabilities = torch.softmax(output.cpu().double(), dim=1)[:, 1]
            AD_prediction_list.extend(probabilities.tolist())
            AD_ground_truth_list.extend(target.cpu().view(-1).tolist())

            # measure elapsed time (once per batch, so the average is not distorted)
            batch_time.update(time.time() - end)
            end = time.time()

            torch.cuda.empty_cache()

    # measure sensitivity, specificity, AUC once over all batches; a ROC curve needs both classes.
    sensitivity = specificity = AUC = float('nan')
    if len(set(AD_ground_truth_list)) > 1:
        fpr, tpr, _ = roc_curve(AD_ground_truth_list, AD_prediction_list)
        # Operating point that maximizes sensitivity + specificity
        operating_point_index = np.argmax(1 - fpr + tpr)
        sensitivity, specificity = tpr[operating_point_index], 1 - fpr[operating_point_index]
        AUC = auc(fpr, tpr)

    metrics_dict = {
            'Batch': i,
            'Time_Value': batch_time.val,
            'Time_Average': batch_time.avg,
            'Loss_Value': losses.val,
            'Loss_Average': losses.avg,
            'Accuracy@1_Value': top1.val,
            'Accuracy@1_Average': top1.avg,
            'Sensitivity': sensitivity,
            'Specificity': specificity,
            'AUC': AUC
            }
    metrics_to_save.append(metrics_dict)
    print('\n * Accuracy@1 {top1.avg:.3f}'.format(top1=top1))
    print(' * Sensitivity {sensitivity:.3f} Specificity {specificity:.3f} AUC {AUC:.3f}'
        .format(sensitivity=sensitivity, specificity=specificity, AUC=AUC))

    # Step the scheduler with the mean validation loss: the loss of the last
    # batch alone is a noisy signal. The scheduler is optional.
    if scheduler is not None and losses.count > 0:
        scheduler.step(losses.avg)

    return top1.avg, AUC, metrics_to_save

# (6) Training Loop

In [None]:
# Suppress warnings during the training loop
warnings.filterwarnings('ignore')
start_epoch = 0
epochs = 100
best_acc = 0
best_AUC = 0

# One metrics dictionary per epoch is appended to each of these lists
metrics_to_save_all_epochs_train = []
metrics_to_save_all_epochs_val = []
for epoch in range(start_epoch, epochs):
    # train for one epoch
    metrics_to_save_train = train(Train_dataloader, model, criterion, optimizer, epoch, device)
    metrics_to_save_all_epochs_train.extend(metrics_to_save_train)

    # evaluate on validation set
    current_acc, current_AUC, metrics_to_save_val = validate(Val_dataloader, model, criterion, device, scheduler)

    # track the best validation accuracy and AUC seen so far
    # NOTE(review): `is_best` is computed but not used in the visible code; no checkpoint is saved here
    is_best = best_AUC < current_AUC
    best_acc = max(current_acc, best_acc)
    best_AUC = max(current_AUC, best_AUC)

    metrics_to_save_all_epochs_val.extend(metrics_to_save_val)

    # stop once the validation AUC has stopped improving
    if early_stop_checker(current_AUC):
        print('The performance on validation set has not improved for a while. Early stop. Training completed.')
        break

# Save the training metrics first, then the validation metrics, one CSV file each
for csv_file_path, epoch_metrics in (
    ('/kaggle/working/training_metrics_train_downsample2.csv', metrics_to_save_all_epochs_train),
    ('/kaggle/working/training_metrics_val_downsample2.csv', metrics_to_save_all_epochs_val),
):
    keys = epoch_metrics[0].keys()  # column names come from the first record
    with open(csv_file_path, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=keys)
        writer.writeheader()
        writer.writerows(epoch_metrics)

# (7) Testing the model on the test data

### Create test dataset

In [None]:
# Root directory of the held-out test split
TestDataDir = '../input/test-data/test/'

# Same normalization and downsampling factor as the training and validation datasets
test_dataset = MRIDataset(
    DataDir=TestDataDir, 
    mode='test', 
    input_T1=True,
    transform=transforms.Compose([ToTensor()]), 
    T1_normalization_method='max', 
    downsample=1.5
)

### Store dataset, load it, and create Dataloader for testing

In [None]:
# Reuse the cached, preprocessed test set when it exists; otherwise build it
# first, so that this cell also works in a fresh session.
test_cache_dir = '/kaggle/working/test/'
if not os.path.isfile(os.path.join(test_cache_dir, 'dataset.pt')):
    preprocess_and_save_dataset(test_dataset, test_cache_dir)
preprocessed_test_data = load_preprocessed_data(test_cache_dir)
test_dataloader = DataLoader(preprocessed_test_data, batch_size=batch_size,
                            shuffle=False, num_workers=4)

### Define test function

In [None]:
def test_model(test_loader, model, device):
    """
    Evaluate the model on the test set and print accuracy, sensitivity,
    specificity and AUC.

    Reads the module-level setting ``input_T1``. The score of a sample is the
    softmax probability of class index 1; accuracy thresholds that score at
    0.5, while sensitivity and specificity are taken at the ROC point that
    maximizes their sum.
    """
    model.eval()
    scores = []
    targets = []

    # No gradients are needed for evaluation
    with torch.no_grad():
        for batch in test_loader:
            volumes = batch['T1'].unsqueeze(1).to(device) if input_T1 else None
            batch_labels = batch['label'].to(device)

            logits = model(volumes)
            positive_probability = torch.softmax(logits, dim=1)[:, 1]

            scores.extend(positive_probability.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    # ROC-based metrics
    fpr, tpr, thresholds = roc_curve(targets, scores, pos_label=1)
    AUC_score = auc(fpr, tpr)
    best_index = np.argmax(1 - fpr + tpr)
    sensitivity = tpr[best_index]
    specificity = 1 - fpr[best_index]

    # Accuracy at the fixed probability threshold 0.5
    predicted_positive = np.array(scores) > 0.5
    accuracy = np.mean(np.array(targets) == predicted_positive)

    print(f'Accuracy: {accuracy:.3f}\nSensitivity: {sensitivity:.3f}\nSpecificity: {specificity:.3f}\nAUC: {AUC_score:.3f}')

### Print test results

In [None]:
# Evaluate the model on the test set; the metrics are printed by test_model
test_model(test_dataloader, model, device)