In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

# sklearn functions
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, KFold, GroupShuffleSplit

# load functions from nitorch
from nitorch.data import load_nifti
from nitorch.transforms import  ToTensor, SagittalTranslate, SagittalFlip, \
                                AxialTranslate, normalization_factors, Normalize, \
                                IntensityRescale
from nitorch.callbacks import EarlyStopping, ModelCheckpoint
from nitorch.trainer import Trainer
from nitorch.initialization import weights_init
from nitorch.metrics import balanced_accuracy, sensitivity, specificity, auc_score
from nitorch.utils import count_parameters

In [None]:
# ---------------------------------------------------------------------------
# File paths (edit these before running)
# ---------------------------------------------------------------------------

# Inputs: HDF5 splits written by 1_create_dataset_splits_stratified.
train_h5 = '/path/to/ADNI_3T_AD_CN_train.h5'
val_h5 = '/path/to/ADNI_3T_AD_CN_val.h5'
holdout_h5 = '/path/to/ADNI_3T_AD_CN_holdout.h5'
# Male-only / female-only holdout files, used for the sex-specific evaluation.
holdout_m_h5 = '/path/to/ADNI_3T_AD_CN_holdout_m.h5'
holdout_f_h5 = '/path/to/ADNI_3T_AD_CN_holdout_f.h5'

# Outputs: directory where the trained model checkpoints are written.
model_path = '/path/to/model'
# Path prefix for the training plots, including the file-name stem. The '{}' is
# replaced by the trial number, so 'trial_{}' yields e.g.
#   trial_0.png                    (loss curve)
#   trial_0_balanced_accuracy.png  (balanced-accuracy curve)
training_graphics_path = model_path + '/trial_{}'


In [None]:
# Report the PyTorch version and the CUDA version it was built against.
for version in (torch.__version__, torch.version.cuda):
    print(version)

In [None]:
# NOTE(review): passed to ADNIDataset, which stores it but never applies it to the data.
dtype = np.float64

gpu = 0          # index of the primary CUDA device
num_classes = 2  # binary task; ADNIDataset derives labels as y >= 0.5
b = 4            # batch size (rescaled for multi-GPU training further down)

In [None]:
# Open the three split files read-only. The handles are kept open because the
# following cells read the 'X' / 'y' datasets from them.
train_h5_, val_h5_, holdout_h5_ = (h5py.File(path, 'r') for path in (train_h5, val_h5, holdout_h5))

In [None]:
# Bind the 'X' (inputs) and 'y' (targets) datasets of each split.
# These are still lazy h5py datasets; they are copied into numpy arrays in the next cell.
(X_train, y_train), (X_val, y_val), (X_holdout, y_holdout) = (
    (h5['X'], h5['y']) for h5 in (train_h5_, val_h5_, holdout_h5_)
)

In [None]:
# Intensity-normalisation switches used by the next cell (both may be enabled;
# mean/std is applied first, then per-image min-max).
min_max_normalization = True    # rescale every image to [0, 1]
mean_std_normalization = False  # z-score with the training-set mean / std

In [None]:
# Load the splits into memory as floating-point arrays and normalise the images.
# Reading straight from the open HDF5 handles (rather than from X_train & co.)
# keeps this cell re-runnable: re-executing it starts again from the raw data
# instead of normalising already-normalised arrays a second time.


def _to_float_images(data):
    """Read `data` into a numpy array, promoting non-floating dtypes to float32.

    The in-place `-=` / `/=` used for min-max scaling raise on integer arrays.
    """
    arr = np.array(data)
    return arr if np.issubdtype(arr.dtype, np.floating) else arr.astype(np.float32)


def _min_max_in_place(images):
    """Rescale every image (entry along the first axis) to [0, 1], in place.

    Works image by image to avoid allocating a second full-size copy of the data.
    A constant image is left as all zeros instead of turning into NaN (0 / 0).
    """
    for i in range(len(images)):
        images[i] -= np.min(images[i])
        peak = np.max(images[i])
        if peak > 0:
            images[i] /= peak
    return images


X_train = _to_float_images(train_h5_['X'])
X_val = _to_float_images(val_h5_['X'])
X_holdout = _to_float_images(holdout_h5_['X'])

y_train = np.array(train_h5_['y'])
y_val = np.array(val_h5_['y'])
y_holdout = np.array(holdout_h5_['y'])

if mean_std_normalization:
    # statistics come from the training split only and are reused for val / holdout
    mean = np.mean(X_train)
    std = np.std(X_train)
    X_train = (X_train - mean) / std
    X_val = (X_val - mean) / std
    X_holdout = (X_holdout - mean) / std

if min_max_normalization:
    for images in (X_train, X_val, X_holdout):
        _min_max_in_place(images)

In [None]:
class ADNIDataset(Dataset):
    """Map-style dataset over in-memory images and their continuous targets.

    Each item is ``{"image": image, "label": label}`` where ``label`` is a
    ``torch.long`` tensor of shape (1,) holding 1 if ``y[idx] >= 0.5`` else 0.

    Args:
        X: indexable collection of images (indexed along its first axis).
        y: indexable collection of targets aligned with ``X``.
        transform: optional callable applied to each image (e.g. augmentations + ToTensor).
        target_transform: optional callable applied to the label tensor.
        mask, z_factor, dtype, num_classes: stored as attributes but not used by
            ``__getitem__``; kept for interface compatibility.
    """

    def __init__(self, X, y, transform=None, target_transform=None, mask=None, z_factor=None, dtype=np.float32, num_classes=2):
        self.X = X
        self.y = y
        self.transform = transform
        self.target_transform = target_transform
        self.mask = mask
        self.z_factor = z_factor
        self.dtype = dtype
        self.num_classes = num_classes

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        image = self.X[idx]
        # Binarise the target at 0.5 and wrap it as a 1-element long tensor
        # (the shape/dtype nn.CrossEntropyLoss-based training consumes).
        label = torch.tensor([int(self.y[idx] >= 0.5)], dtype=torch.long)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return {"image": image, "label": label}

In [None]:
# Training-time augmentations; only the training pipeline below prepends them to ToTensor.
augmentations = [
    SagittalFlip(),
    SagittalTranslate(dist=(-2, 3)),
]

In [None]:
# Training data is augmented on access; validation and holdout are only converted to tensors.
train_transform = transforms.Compose(augmentations + [ToTensor()])
adni_data_train = ADNIDataset(X_train, y_train, transform=train_transform, dtype=dtype)
adni_data_val = ADNIDataset(
    X_val, y_val,
    transform=transforms.Compose([ToTensor()]),
    dtype=dtype,
)
adni_data_test = ADNIDataset(
    X_holdout, y_holdout,
    transform=transforms.Compose([ToTensor()]),
    dtype=dtype,
)

In [None]:
# Peek at one training sample (index 400 is arbitrary); augmentations are applied on access.
sample_idx = 400
sample = adni_data_train[sample_idx]
img = sample["image"]
img.shape

In [None]:
# Show slice 80 along the last axis of the sample's first channel.
plt.imshow(img[0, :, :, 80], cmap='gray')

# Define the classifier

In [None]:
class ClassificationModel3D(nn.Module):
    """3D CNN classifier: four (conv -> batch-norm -> ReLU -> max-pool) stages, then two dense layers.

    Args:
        dropout: dropout probability applied to the flattened convolutional features.
        dropout2: dropout probability applied after the first dense layer.
        num_classes: number of output logits (default 2, as in the original model).
        in_features: length of the flattened conv output fed to ``dense_1``. The
            default 2304 = 64 channels x 36 voxels depends on the input volume size
            (for example a single-channel 182x218x182 volume ends as a 3x4x3 map);
            change it when feeding differently sized volumes.
    """

    def __init__(self, dropout=0.4, dropout2=0.4, num_classes=2, in_features=2304):
        super().__init__()
        # Attribute names are unchanged on purpose: they are the keys of saved state_dicts.
        self.Conv_1 = nn.Conv3d(1, 8, 3)
        self.Conv_1_bn = nn.BatchNorm3d(8)
        self.Conv_1_mp = nn.MaxPool3d(2)
        self.Conv_2 = nn.Conv3d(8, 16, 3)
        self.Conv_2_bn = nn.BatchNorm3d(16)
        self.Conv_2_mp = nn.MaxPool3d(3)
        self.Conv_3 = nn.Conv3d(16, 32, 3)
        self.Conv_3_bn = nn.BatchNorm3d(32)
        self.Conv_3_mp = nn.MaxPool3d(2)
        self.Conv_4 = nn.Conv3d(32, 64, 3)
        self.Conv_4_bn = nn.BatchNorm3d(64)
        self.Conv_4_mp = nn.MaxPool3d(3)
        self.dense_1 = nn.Linear(in_features, 128)
        self.dense_2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout2)

    def forward(self, x):
        """Map a (batch, 1, D, H, W) tensor to (batch, num_classes) unnormalised logits."""
        x = self.Conv_1_mp(self.relu(self.Conv_1_bn(self.Conv_1(x))))
        x = self.Conv_2_mp(self.relu(self.Conv_2_bn(self.Conv_2(x))))
        x = self.Conv_3_mp(self.relu(self.Conv_3_bn(self.Conv_3(x))))
        x = self.Conv_4_mp(self.relu(self.Conv_4_bn(self.Conv_4(x))))
        x = torch.flatten(x, 1)  # keep the batch axis, flatten channels x voxels
        x = self.dropout(x)
        x = self.relu(self.dense_1(x))
        x = self.dropout2(x)
        return self.dense_2(x)

In [None]:
# Build the classifier and split every batch across several GPUs with nn.DataParallel.
net = ClassificationModel3D()

gpu_ids = [0, 1, 2, 3]
# Scale the batch size with the number of GPUs, because DataParallel divides each
# batch between them. It is computed from a fixed per-GPU value instead of
# `b = b * 4`, so re-running this cell does not multiply the batch size again.
per_gpu_batch_size = 4
b = per_gpu_batch_size * len(gpu_ids)  # 16 with 4 GPUs
net = nn.DataParallel(net, device_ids=gpu_ids)

# DataParallel requires the wrapped module to live on device_ids[0]
net = net.cuda(gpu)

In [None]:
# Size of the model (as counted by nitorch's count_parameters).
print("Trainable model parameters:", count_parameters(net))

# Training

In [None]:
def _show_first_training_image(loader):
    """Plot slice 70 (last axis) of the first image of one batch from `loader`, titled with its label."""
    sample = next(iter(loader))
    img = sample["image"][0]
    lbl = sample["label"][0]
    plt.imshow(img.squeeze()[:, :, 70], cmap='gray')
    plt.title(lbl.item())
    plt.show()


def run(
    net,
    data,
    shape,
    callbacks=None,
    augmentations=None,
    masked=False,
    metrics=None,
    k_folds=None,
    b=4,
    num_epochs=35,
    retain_metric=None,
    num_trials=5,
    val_data=None,
):
    """Train `num_trials` independently initialised models and collect their validation scores.

    Every trial builds a fresh ``ClassificationModel3D`` wrapped in ``nn.DataParallel``
    (global ``gpu_ids`` / ``gpu``) and trains it with Adam (lr=1e-4, weight decay=1e-4)
    and cross-entropy loss through nitorch's ``Trainer``. Afterwards it passes
    ``training_graphics_path.format(trial)`` to ``Trainer.visualize_training``.

    While trial ``n`` runs, every ``ModelCheckpoint`` in `callbacks` has
    ``"trial_<n>_"`` appended to its own original ``prepend``. The original
    prefixes are restored on exit, so calling ``run`` again with the same
    callbacks does not stack prefixes.

    Args:
        net: ignored; every trial starts from a freshly constructed model.
        data: training dataset.
        shape, augmentations, masked, k_folds: unused; kept for call compatibility.
        callbacks: nitorch callbacks shared by all trials (default: none).
        metrics: metric functions tracked by the Trainer (default: none).
        b: training batch size.
        num_epochs: maximum number of epochs per trial.
        retain_metric: metric name or metric function whose last-epoch validation
            value is recorded per trial. Required.
        num_trials: number of independent training runs (default 5).
        val_data: validation dataset (default: the global ``adni_data_val``).

    Returns:
        (fold_metric, models): the per-trial last-epoch validation value of
        `retain_metric`, and the models returned by ``Trainer.train_model``.

    Raises:
        ValueError: if `retain_metric` is None (checked before any training starts).
    """
    if retain_metric is None:
        raise ValueError("run() needs a retain_metric (metric name or metric function)")
    callbacks = [] if callbacks is None else callbacks
    metrics = [] if metrics is None else metrics
    val_data = adni_data_val if val_data is None else val_data
    metric_name = retain_metric if isinstance(retain_metric, str) else retain_metric.__name__

    # Checkpoint callbacks get a per-trial prefix so every trial writes its own files.
    checkpoints = [cb for cb in callbacks if isinstance(cb, ModelCheckpoint)]
    base_prepends = [cb.prepend for cb in checkpoints]

    fold_metric = []
    models = []
    try:
        for trial in range(num_trials):
            print("Starting trial {}".format(trial))
            for cb, base in zip(checkpoints, base_prepends):
                cb.prepend = base + "trial_{}_".format(trial)

            # Fresh model, loss and optimiser for every trial.
            # NOTE(review): the same callback instances are reused across trials;
            # confirm nitorch resets their internal state (best score, patience).
            model = nn.DataParallel(ClassificationModel3D(), device_ids=gpu_ids)
            model.cuda(gpu)
            criterion = nn.CrossEntropyLoss().cuda(gpu)
            optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

            train_loader = DataLoader(data, batch_size=b, num_workers=4, shuffle=True)
            # shuffling adds nothing for evaluation; a fixed order keeps runs reproducible
            val_loader = DataLoader(val_data, batch_size=1, num_workers=1, shuffle=False)

            _show_first_training_image(train_loader)

            trainer = Trainer(
                model,
                criterion,
                optimizer,
                metrics=metrics,
                callbacks=callbacks,
                device=gpu,
                prediction_type="classification"
            )
            model, report = trainer.train_model(
                train_loader,
                val_loader,
                num_epochs=num_epochs,
                show_train_steps=10,
                show_validation_epochs=1,
            )
            # last-epoch (not best-epoch) validation value of the retain metric
            fold_metric.append(report["val_metrics"][metric_name][-1])
            models.append(model)
            print("Finished trial.")

            trainer.visualize_training(report, metrics, training_graphics_path.format(trial))
            trainer.evaluate_model(val_loader, gpu)
    finally:
        for cb, base in zip(checkpoints, base_prepends):
            cb.prepend = base

    print("################################")
    print("################################")
    print("All accuracies: {}".format(fold_metric))
    return fold_metric, models


In [None]:
# Training hyperparameters.
retain_metric = balanced_accuracy  # metric used for checkpoint selection and the per-trial summary
metrics = [balanced_accuracy]      # metrics tracked by the Trainer
num_epochs = 200                   # upper bound; EarlyStopping may end a trial sooner
ignore_epochs = 15                 # passed as `ignore_before` to both callbacks
# NOTE(review): the two values below are not referenced anywhere else in this notebook.
min_iters = 3
normalize = False

In [None]:
# Keep the checkpoint with the best `retain_metric` (after the first `ignore_epochs`
# epochs), and stop a trial once the monitored loss has not improved for 8 epochs.
check = ModelCheckpoint(
    path=model_path,
    store_best=True,
    ignore_before=ignore_epochs,
    retain_metric=retain_metric,
)
early_stopping = EarlyStopping(
    patience=8,
    ignore_before=ignore_epochs,
    retain_metric="loss",
    mode='min',
)
callbacks = [check, early_stopping]

In [None]:
# Train all trials, then summarise the per-trial final validation scores.
fold_metric, models = run(
    net=net,
    data=adni_data_train,
    shape=-1,
    k_folds=-1,
    masked=False,
    callbacks=callbacks,
    metrics=metrics,
    retain_metric=retain_metric,
    num_epochs=num_epochs,
    b=b,
)

for summarize in (np.mean, np.std):
    print(summarize(fold_metric))

# Inference on the holdout set

In [None]:
from collections import OrderedDict

# Reload the best checkpoint of every trial into plain (non-DataParallel) CPU models.
n_trials = 5  # must match the number of trials trained by run()
dp_prefix = "module."  # nn.DataParallel prefixes every state_dict key with this

models = []
for i in range(n_trials):
    filename = "/trial_{}_BEST_ITERATION.h5".format(i)
    net = ClassificationModel3D()

    # The model lives on the CPU here, so load the weights there directly
    # (also works on machines without the GPU the checkpoint was saved from).
    state_dict = torch.load(model_path + filename, map_location="cpu")
    # Strip the DataParallel prefix only where it is actually present,
    # instead of blindly chopping the first 7 characters of every key.
    new_state_dict = OrderedDict(
        (k[len(dp_prefix):] if k.startswith(dp_prefix) else k, v)
        for k, v in state_dict.items()
    )

    net.load_state_dict(new_state_dict)
    models.append(net)

In [None]:
# Holdout loader: one sample per batch, fixed order.
test_loader = DataLoader(adni_data_test, batch_size=1, shuffle=False, num_workers=1)

In [None]:
# Evaluate every trial's best checkpoint on the holdout set.
# AUC is computed from the predicted probability of class 1: an ROC AUC computed
# from hard 0/1 predictions collapses to (sensitivity + specificity) / 2, i.e. the
# balanced accuracy. A separate name (`trial_metrics`) is used so the `metrics`
# list used for training is not overwritten.
device = torch.device("cuda:" + str(gpu))
trial_metrics = []

for trial, model in enumerate(models):
    print("Trial {}".format(trial))

    all_preds = []   # predicted class per image
    all_probs = []   # predicted probability of class 1 per image
    all_labels = []

    net = model.to(device)
    net.eval()
    with torch.no_grad():
        for sample in test_loader:
            probs = F.softmax(net(sample["image"].to(device)), dim=1)
            # reduce over the class axis only, so this works for any batch size
            all_preds.extend(probs.argmax(dim=1).tolist())
            all_probs.extend(probs[:, 1].tolist())
            all_labels.extend(sample["label"].view(-1).tolist())

    balanced_acc = balanced_accuracy(all_labels, all_preds)
    sens = sensitivity(all_labels, all_preds)
    spec = specificity(all_labels, all_preds)
    auc = auc_score(all_labels, all_probs)
    print(balanced_acc)
    print()

    trial_metrics.append((balanced_acc, sens, spec, auc))

print("######## Final results ########")
metrics_df = pd.DataFrame(trial_metrics, columns=["balanced_accuracy", "sensitivity", "specificity", "auc"])
print(metrics_df)
print("Balanced accuracy mean {:.2f} %".format(np.mean(metrics_df["balanced_accuracy"]) * 100))


### Sex-specific evaluation (male / female holdout sets)

In [None]:
# Load the male / female holdout splits and preprocess them exactly like the main splits.


def _load_and_normalize(path):
    """Read 'X' / 'y' from the HDF5 file at `path` (closing it afterwards) and normalise X.

    Uses the same switches as the main splits (`mean_std_normalization` with the
    training-set `mean` / `std`, then per-image `min_max_normalization`), so the
    models see identically preprocessed inputs.
    """
    with h5py.File(path, 'r') as f:
        X = np.array(f['X'])
        y = np.array(f['y'])
    # in-place `-=` / `/=` below would raise on an integer array
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(np.float32)
    if mean_std_normalization:
        X = (X - mean) / std
    if min_max_normalization:
        for i in range(len(X)):
            X[i] -= np.min(X[i])
            peak = np.max(X[i])
            if peak > 0:  # constant image: keep zeros instead of 0 / 0 = NaN
                X[i] /= peak
    return X, y


X_holdout_m, y_holdout_m = _load_and_normalize(holdout_m_h5)
X_holdout_f, y_holdout_f = _load_and_normalize(holdout_f_h5)

adni_data_test_m = ADNIDataset(X_holdout_m, y_holdout_m, transform=transforms.Compose([ToTensor()]), dtype=dtype)
adni_data_test_f = ADNIDataset(X_holdout_f, y_holdout_f, transform=transforms.Compose([ToTensor()]), dtype=dtype)

In [None]:
# One loader per sex-specific holdout set (batch size 1, fixed order).
test_m_loader, test_f_loader = (
    DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)
    for dataset in (adni_data_test_m, adni_data_test_f)
)

In [None]:
# Evaluate every trial's best checkpoint on the male-only holdout set.
# AUC uses the predicted probability of class 1 (hard 0/1 predictions would
# reduce ROC AUC to the balanced accuracy). `trial_metrics` avoids overwriting
# the `metrics` list used for training.
device = torch.device("cuda:" + str(gpu))
trial_metrics = []

print("male patients")
for trial, model in enumerate(models):
    print("Trial {}".format(trial))

    all_preds = []   # predicted class per image
    all_probs = []   # predicted probability of class 1 per image
    all_labels = []

    net = model.to(device)
    net.eval()
    with torch.no_grad():
        for sample in test_m_loader:
            probs = F.softmax(net(sample["image"].to(device)), dim=1)
            # reduce over the class axis only, so this works for any batch size
            all_preds.extend(probs.argmax(dim=1).tolist())
            all_probs.extend(probs[:, 1].tolist())
            all_labels.extend(sample["label"].view(-1).tolist())

    balanced_acc = balanced_accuracy(all_labels, all_preds)
    sens = sensitivity(all_labels, all_preds)
    spec = specificity(all_labels, all_preds)
    auc = auc_score(all_labels, all_probs)
    print(balanced_acc)
    print()

    trial_metrics.append((balanced_acc, sens, spec, auc))

print("######## Final results ########")
metrics_df = pd.DataFrame(trial_metrics, columns=["balanced_accuracy", "sensitivity", "specificity", "auc"])
print(metrics_df)
print("Balanced accuracy mean {:.2f} %".format(np.mean(metrics_df["balanced_accuracy"]) * 100))

In [None]:
# Evaluate every trial's best checkpoint on the female-only holdout set.
# AUC uses the predicted probability of class 1 (hard 0/1 predictions would
# reduce ROC AUC to the balanced accuracy). `trial_metrics` avoids overwriting
# the `metrics` list used for training.
device = torch.device("cuda:" + str(gpu))
trial_metrics = []

print("female patients")
for trial, model in enumerate(models):
    print("Trial {}".format(trial))

    all_preds = []   # predicted class per image
    all_probs = []   # predicted probability of class 1 per image
    all_labels = []

    net = model.to(device)
    net.eval()
    with torch.no_grad():
        for sample in test_f_loader:
            probs = F.softmax(net(sample["image"].to(device)), dim=1)
            # reduce over the class axis only, so this works for any batch size
            all_preds.extend(probs.argmax(dim=1).tolist())
            all_probs.extend(probs[:, 1].tolist())
            all_labels.extend(sample["label"].view(-1).tolist())

    balanced_acc = balanced_accuracy(all_labels, all_preds)
    sens = sensitivity(all_labels, all_preds)
    spec = specificity(all_labels, all_preds)
    auc = auc_score(all_labels, all_probs)
    print(balanced_acc)
    print()

    trial_metrics.append((balanced_acc, sens, spec, auc))

print("######## Final results ########")
metrics_df = pd.DataFrame(trial_metrics, columns=["balanced_accuracy", "sensitivity", "specificity", "auc"])
print(metrics_df)
print("Balanced accuracy mean {:.2f} %".format(np.mean(metrics_df["balanced_accuracy"]) * 100))