In [None]:
%matplotlib inline

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from collections import OrderedDict
import pickle

# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# load functions from nitorch
from nitorch.transforms import  ToTensor
from nitorch.metrics import balanced_accuracy

In [None]:
# ------------------------------------------------------------------
# File locations
# ------------------------------------------------------------------
#
# Inputs
# ------
# Both base directories below are expected to contain one subfolder per
# data split, named as listed in `splits`:
#   * data_base_path/[split]   holds the h5 files (only holdout needed here)
#   * models_base_path/[split] holds the trained models
#                              (from 2_train_models_multiGPU)
#
# Output
# ------
# A file called "raw_pred.pkl" is written into every
# models_base_path/[split] subfolder.
splits = ['split_0', 'split_1', 'split_2']
models_base_path = '/path/to/models'
data_base_path = '/path/to/data'


In [None]:
# Show the installed PyTorch version and the value of torch.version.cuda,
# one per line.
print(torch.__version__, torch.version.cuda, sep="\n")

In [None]:
# Run settings.
#   gpu         - index of the CUDA device used for inference
#   b           - batch size (note: the inference cell builds its DataLoader
#                 with batch_size=1, so this value is not used there)
#   num_classes - number of target classes
gpu, b, num_classes = 0, 4, 2

# Forwarded to ADNIDataset as its `dtype` argument.
dtype = np.float64

In [None]:
# CLASSIFIER
class ClassificationModel3D(nn.Module):
    """3D CNN classifier: four convolutional stages followed by two dense layers.

    Every stage is Conv3d (kernel size 3) -> BatchNorm3d -> ReLU -> MaxPool3d.
    Channel widths go 1 -> 8 -> 16 -> 32 -> 64 and the pooling sizes are
    2, 3, 2, 3. The flattened stage-4 output (which must have 2304 features to
    match ``dense_1``) passes through dropout, a 128-unit dense layer with
    ReLU, a second dropout and a final 2-unit dense layer.

    Args:
        dropout: drop probability applied to the flattened conv features.
        dropout2: drop probability applied after the 128-unit dense layer.
    """

    def __init__(self, dropout=0.4, dropout2=0.4):
        super().__init__()

        # (in_channels, out_channels, pool_size) for stages 1..4.
        # The attribute names Conv_<i>, Conv_<i>_bn and Conv_<i>_mp become the
        # state_dict keys, so they have to stay exactly as they are.
        stage_specs = [(1, 8, 2), (8, 16, 3), (16, 32, 2), (32, 64, 3)]
        for stage, (c_in, c_out, pool) in enumerate(stage_specs, start=1):
            setattr(self, "Conv_{}".format(stage), nn.Conv3d(c_in, c_out, 3))
            setattr(self, "Conv_{}_bn".format(stage), nn.BatchNorm3d(c_out))
            setattr(self, "Conv_{}_mp".format(stage), nn.MaxPool3d(pool))

        # Classification head.
        self.dense_1 = nn.Linear(2304, 128)
        self.dense_2 = nn.Linear(128, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout2)

    def forward(self, x):
        """Run the network on ``x`` and return the raw 2-unit output (no softmax)."""
        for stage in range(1, 5):
            conv = getattr(self, "Conv_{}".format(stage))
            norm = getattr(self, "Conv_{}_bn".format(stage))
            pool = getattr(self, "Conv_{}_mp".format(stage))
            x = pool(self.relu(norm(conv(x))))

        # Flatten everything except the batch dimension.
        flat = x.view(x.size(0), -1)
        hidden = self.relu(self.dense_1(self.dropout(flat)))
        return self.dense_2(self.dropout2(hidden))
    
# DATASET
class ADNIDataset(Dataset):
    """Dataset that pairs images with binarized labels.

    Each item is a dict ``{"image": ..., "label": ...}``. The label is an
    int64 tensor holding 1 where ``y[idx] >= 0.5`` and 0 otherwise, wrapped
    in one extra leading dimension.

    Args:
        X: indexable collection of images; ``len(X)`` is the dataset size.
        y: indexable collection of targets, indexed in step with ``X``.
        transform: optional callable applied to the image.
        target_transform: optional callable applied to the label tensor.
        mask, z_factor, dtype, num_classes: stored on the instance but not
            used by ``__getitem__``.
    """

    def __init__(self, X, y, transform=None, target_transform=None, mask=None, z_factor=None, dtype=np.float32, num_classes=2):
        self.X = X
        self.y = y
        self.transform = transform
        self.target_transform = target_transform
        self.mask = mask
        self.z_factor = z_factor
        self.dtype = dtype
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of images."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "label": ...}`` for position ``idx``."""
        image = self.X[idx]

        # Threshold the target at 0.5. Going through a NumPy array avoids
        # relying on the legacy torch.LongTensor constructor to convert a
        # NumPy boolean scalar implicitly.
        is_positive = self.y[idx] >= 0.5
        label = torch.as_tensor(np.asarray([is_positive]), dtype=torch.long)

        if self.transform:
            image = self.transform(image)
        # Previously accepted but never applied.
        if self.target_transform:
            label = self.target_transform(label)

        return {"image": image, "label": label}

In [None]:
def _strip_dataparallel_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from keys that start with it.

    Checkpoints saved from an nn.DataParallel-wrapped model have "module."
    in front of every key. Keys that do not start with the prefix are kept
    unchanged, so checkpoints saved without the wrapper load as well.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Device used for all inference below.
device = torch.device("cuda:" + str(gpu))
# Number of trained models (trial_0 ... trial_4) loaded per split.
n_trials = 5

for split in splits:
    ################ LOAD DATA

    holdout_h5 = '{}/{}/ADNI_3T_AD_CN_holdout.h5'.format(data_base_path, split)
    # Copy both datasets into memory inside the context manager so the file
    # handle is closed before moving on.
    with h5py.File(holdout_h5, 'r') as holdout_h5_:
        X_holdout = np.array(holdout_h5_['X'])
        y_holdout = np.array(holdout_h5_['y'])

    # Min-max scale every image in place.
    for i in range(len(X_holdout)):
        X_holdout[i] -= np.min(X_holdout[i])
        peak = np.max(X_holdout[i])
        if peak > 0:  # a constant image stays all-zero instead of dividing by zero
            X_holdout[i] /= peak

    adni_data_test = ADNIDataset(X_holdout, y_holdout, transform=transforms.Compose([ToTensor()]), dtype=dtype)

    ############### LOAD MODELS

    model_path = '{}/{}'.format(models_base_path, split)
    models = []
    for i in range(n_trials):
        filename = "/trial_{}_BEST_ITERATION.h5".format(i)
        net = ClassificationModel3D()

        # NOTE: torch.load unpickles the file - only load checkpoints from a
        # trusted source. map_location="cpu" makes loading independent of the
        # GPU the checkpoint was saved from; the model is moved to `device`
        # right before inference.
        state_dict = torch.load(model_path + filename, map_location="cpu")
        net.load_state_dict(_strip_dataparallel_prefix(state_dict))
        models.append(net)

    test_loader = DataLoader(adni_data_test, batch_size=1, num_workers=1, shuffle=False)

    ############## INFERENCE

    pred_correct_all = []
    balanced_accs = []
    raw_preds_all = []
    for trial, model in enumerate(models):
        all_preds = []
        all_labels = []
        pred_correct = []
        raw_preds = []

        net = model.to(device)
        net.eval()
        with torch.no_grad():
            for sample in test_loader:
                img = sample["image"].to(device)
                label = sample["label"]

                output = net(img)
                output_softmax = F.softmax(output, dim=1)

                # Work per sample along the batch dimension so the
                # bookkeeping does not depend on the batch size.
                batch_probs = output_softmax[:, 1].cpu().tolist()  # probability of class 1
                batch_preds = torch.argmax(output_softmax, dim=1).cpu().tolist()
                batch_labels = label.reshape(-1).tolist()

                raw_preds.extend(batch_probs)
                all_preds.extend(batch_preds)
                all_labels.extend(batch_labels)
                pred_correct.extend(p == t for p, t in zip(batch_preds, batch_labels))

        raw_preds_all.append(raw_preds)
        balanced_accs.append(balanced_accuracy(all_labels, all_preds))
        pred_correct_all.append(pred_correct)

    # One list of class-1 probabilities per trial, in trial order.
    with open('{}/{}/raw_pred.pkl'.format(models_base_path, split), 'wb') as f:
        pickle.dump(raw_preds_all, f)
    print("finished split", split)