In [None]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from datetime import datetime
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from sklearn.metrics import accuracy_score, f1_score, balanced_accuracy_score, recall_score, precision_score, classification_report

from train_model_multi_grouped_bal import train_model
from validate_multi import validate

# Reproducibility: seed torch's global random number generator.
torch.manual_seed(1337)

# Hyperparameters. The `_` values are placeholders and must be replaced with concrete
# numbers before running. NOTE: in IPython `_` is the previous cell's output, so leaving
# them in place silently picks up whatever was last displayed.
BATCH_SIZE = _
N_EPOCHS = _
LEARNING_RATE = _
DROPOUT_RATE = _
LAYER_SIZE = _

SPLIT = 0.45   # fraction of every group held out for validation + test
MODE = 'bal'   # alternative: 'acc' (forwarded to train_model)
BINARY = True  # False -> three-class problem

NUM_CLASSES = 2 if BINARY else 3
print(NUM_CLASSES)

device = "cpu"  # ("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
class ERC_I_dataset_subj(torch.utils.data.Dataset):
    """Load (time-series image, label) pairs for ERC-I subjects.

    Every non-hidden sub-directory of ``data_path`` is one subject and must contain
    ``functional.nii`` (the image) and ``pheno.csv`` (read for its ``Diagnosis``
    column, plus ``Group`` in the multi-class setting).

    Parameters:
    - data_path (str): Path to the folder containing one sub-folder per subject.

    Returns (from ``__getitem__``):
    - Tensor: float32 tensor holding the first ``N_TIMEPOINTS`` entries along the
      leading axis after ``np.swapaxes(img, 0, 3)``.
    - label: 0 when ``Diagnosis == 'Ingen'``; otherwise 1 if the module-level
      ``BINARY`` flag is set, else the value of the ``Group`` column.
    """

    # Number of leading volumes kept per subject.
    N_TIMEPOINTS = 150

    def __init__(self, data_path):
        self.data_path = data_path
        # List subjects once, sorted, so index -> subject is stable across calls and
        # machines (os.listdir order is arbitrary, which would make the seeded
        # random_split non-reproducible). Non-directories and hidden entries such as
        # .DS_Store or .ipynb_checkpoints cannot be subjects and are skipped.
        self.subjects = sorted(
            name for name in os.listdir(self.data_path)
            if not name.startswith('.')
            and os.path.isdir(os.path.join(self.data_path, name))
        )
        self.num_subjects = len(self.subjects)

    def __len__(self):
        return self.num_subjects

    def __getitem__(self, index):
        subject = self.subjects[index]
        subject_dir = os.path.join(self.data_path, subject)

        img = nib.load(os.path.join(subject_dir, 'functional.nii')).get_fdata()
        # Bring axis 3 (assumed to be time in the 4-D NIfTI -- TODO confirm) to the front.
        # NOTE(review): swapaxes also exchanges the first and last spatial axes; the
        # model's hard-coded layer sizes presumably assume this orientation.
        img = np.swapaxes(img, 0, 3)
        img = img[:self.N_TIMEPOINTS]
        if img.shape[0] != self.N_TIMEPOINTS:
            # Scan is shorter than expected; report the subject so it can be inspected.
            print(subject)
        img = torch.from_numpy(img).to(torch.float)

        pheno = pd.read_csv(os.path.join(subject_dir, 'pheno.csv'))
        diagnosis = pheno.iloc[0]['Diagnosis']

        ### LABEL
        if diagnosis == 'Ingen':
            return img, 0
        if BINARY:
            return img, 1
        return img, pheno.iloc[0]['Group']

In [None]:
class ERC_II_dataset_subj(torch.utils.data.Dataset):
    """Load (time-series image, label) pairs for ERC-II subjects.

    Every non-hidden sub-directory of ``data_path`` is one subject and must contain
    ``functional.nii`` (the image) and ``pheno.csv`` (read for its ``Diagnosis``
    column, plus ``Include`` in the multi-class setting).

    Parameters:
    - data_path (str): Path to the folder containing one sub-folder per subject.

    Returns (from ``__getitem__``):
    - Tensor: float32 tensor holding the first ``N_TIMEPOINTS`` entries along the
      leading axis after ``np.swapaxes(img, 0, 3)``.
    - label: 0 when ``Diagnosis == 0``; otherwise 1 if the module-level ``BINARY``
      flag is set, else the value of the ``Include`` column.
    """

    # Number of leading volumes kept per subject.
    N_TIMEPOINTS = 150

    def __init__(self, data_path):
        self.data_path = data_path
        # List subjects once, sorted, so index -> subject is stable across calls and
        # machines (os.listdir order is arbitrary, which would make the seeded
        # random_split non-reproducible). Non-directories and hidden entries such as
        # .DS_Store or .ipynb_checkpoints cannot be subjects and are skipped.
        self.subjects = sorted(
            name for name in os.listdir(self.data_path)
            if not name.startswith('.')
            and os.path.isdir(os.path.join(self.data_path, name))
        )
        self.num_subjects = len(self.subjects)

    def __len__(self):
        return self.num_subjects

    def __getitem__(self, index):
        subject = self.subjects[index]
        subject_dir = os.path.join(self.data_path, subject)

        img = nib.load(os.path.join(subject_dir, 'functional.nii')).get_fdata()
        # Bring axis 3 (assumed to be time in the 4-D NIfTI -- TODO confirm) to the front.
        # NOTE(review): swapaxes also exchanges the first and last spatial axes; the
        # model's hard-coded layer sizes presumably assume this orientation.
        img = np.swapaxes(img, 0, 3)
        img = img[:self.N_TIMEPOINTS]
        if img.shape[0] != self.N_TIMEPOINTS:
            # Scan is shorter than expected; report the subject so it can be inspected.
            print(subject)
        img = torch.from_numpy(img).to(torch.float)

        pheno = pd.read_csv(os.path.join(subject_dir, 'pheno.csv'))
        diagnosis = pheno.iloc[0]['Diagnosis']

        ### LABEL
        if diagnosis == 0:
            return img, 0
        if BINARY:
            return img, 1
        return img, pheno.iloc[0]['Include']

In [None]:
class COBRE_dataset_subj(torch.utils.data.Dataset):
    """Load (time-series image, label) pairs for COBRE subjects.

    Every non-hidden sub-directory of ``data_path`` is one subject and must contain
    ``functional.nii`` (the image) and ``pheno.csv`` (read for its ``Subject Type``
    column, plus ``Diagnosis`` in the multi-class setting).

    Parameters:
    - data_path (str): Path to the folder containing one sub-folder per subject.

    Returns (from ``__getitem__``):
    - Tensor: float32 tensor holding the first ``N_TIMEPOINTS`` entries along the
      leading axis after ``np.swapaxes(img, 0, 3)``.
    - int: 0 for ``Subject Type == 'Control'``; otherwise 1 if the module-level
      ``BINARY`` flag is set; otherwise 1 when ``Diagnosis`` is one of
      ``CLASS_1_DIAGNOSES`` and 2 for any other value.
    """

    # Number of leading volumes kept per subject.
    N_TIMEPOINTS = 150
    # Diagnosis codes mapped to class 1 in the multi-class setting; presumably
    # DSM-IV 295.x schizophrenia subtypes -- confirm against the COBRE codebook.
    CLASS_1_DIAGNOSES = (295.1, 295.2, 295.3, 295.6, 295.9)

    def __init__(self, data_path):
        self.data_path = data_path
        # List subjects once, sorted, so index -> subject is stable across calls and
        # machines (os.listdir order is arbitrary, which would make the seeded
        # random_split non-reproducible). Non-directories and hidden entries such as
        # .DS_Store or .ipynb_checkpoints cannot be subjects and are skipped.
        self.subjects = sorted(
            name for name in os.listdir(self.data_path)
            if not name.startswith('.')
            and os.path.isdir(os.path.join(self.data_path, name))
        )
        self.num_subjects = len(self.subjects)

    def __len__(self):
        return self.num_subjects

    def __getitem__(self, index):
        subject = self.subjects[index]
        subject_dir = os.path.join(self.data_path, subject)

        img = nib.load(os.path.join(subject_dir, 'functional.nii')).get_fdata()
        # Bring axis 3 (assumed to be time in the 4-D NIfTI -- TODO confirm) to the front.
        # NOTE(review): swapaxes also exchanges the first and last spatial axes; the
        # model's hard-coded layer sizes presumably assume this orientation.
        img = np.swapaxes(img, 0, 3)
        img = img[:self.N_TIMEPOINTS]
        if img.shape[0] != self.N_TIMEPOINTS:
            # Scan is shorter than expected; report the subject so it can be inspected.
            print(subject)
        img = torch.from_numpy(img).to(torch.float)

        pheno = pd.read_csv(os.path.join(subject_dir, 'pheno.csv'))

        ### LABEL
        if pheno.iloc[0]['Subject Type'] == 'Control':
            return img, 0
        if BINARY:
            return img, 1
        diagnosis = pheno.iloc[0]['Diagnosis']
        # `in` compares with ==, matching the original chain of equality checks.
        if diagnosis in self.CLASS_1_DIAGNOSES:
            return img, 1
        return img, 2

In [None]:
'''
Loads all datasets and ensures class ratio is equal in train, val and test
'''

# Each site stores subjects pre-sorted into one folder per group (control / paranoid / rest).
# NOTE(review): absolute, machine-specific paths; consider deriving them from a configurable DATA_DIR.
# COBRE (site 0) is currently excluded from the experiment.
# control_path_0 = '/data/working/oscar/data/COBRE/groups/control'
# paranoid_path_0 = '/data/working/oscar/data/COBRE/groups/paranoid'
# rest_path_0 = '/data/working/oscar/data/COBRE/groups/rest'

control_path_1 = '/data/working/oscar/data/ERC2-I/groups/control'
paranoid_path_1 = '/data/working/oscar/data/ERC2-I/groups/paranoid'
rest_path_1 = '/data/working/oscar/data/ERC2-I/groups/rest'

control_path_2 = '/data/working/oscar/data/ERC2-II/groups/control'
paranoid_path_2 = '/data/working/oscar/data/ERC2-II/groups/paranoid'
rest_path_2 = '/data/working/oscar/data/ERC2-II/groups/rest'

# control_data_0 = COBRE_dataset_subj(control_path_0)
# paranoid_data_0 = COBRE_dataset_subj(paranoid_path_0)
# rest_data_0 = COBRE_dataset_subj(rest_path_0)

control_data_1 = ERC_I_dataset_subj(control_path_1)
paranoid_data_1 = ERC_I_dataset_subj(paranoid_path_1)
rest_data_1 = ERC_I_dataset_subj(rest_path_1)

control_data_2 = ERC_II_dataset_subj(control_path_2)
paranoid_data_2 = ERC_II_dataset_subj(paranoid_path_2)
rest_data_2 = ERC_II_dataset_subj(rest_path_2)

# control_data = ConcatDataset([control_data_0, control_data_1, control_data_2])
# paranoid_data = ConcatDataset([paranoid_data_0, paranoid_data_1, paranoid_data_2])
# rest_data = ConcatDataset([rest_data_0, rest_data_1, rest_data_2])

# Pool the sites per group, so each group is split on its own (stratification by folder).
control_data = ConcatDataset([control_data_1, control_data_2])
paranoid_data = ConcatDataset([paranoid_data_1, paranoid_data_2])
rest_data = ConcatDataset([rest_data_1, rest_data_2])

# Re-seed the global RNG. `seed` is the torch.Generator returned by manual_seed (not an int).
# Every random_split below and both DataLoaders draw from this one shared generator, so the
# resulting splits depend on the exact order of the calls that follow -- do not reorder them.
seed = torch.manual_seed(1337)

# Hold out SPLIT of every group for validation + test; the rest is training data.
train_control, test_val_control = random_split(control_data, [1-SPLIT, SPLIT], seed)
train_paranoid, test_val_paranoid = random_split(paranoid_data, [1-SPLIT, SPLIT], seed)
train_rest, test_val_rest = random_split(rest_data, [1-SPLIT, SPLIT], seed)

# Divide each held-out part 55/45 into validation and test.
val_control, test_control = random_split(test_val_control, [0.55, 0.45], seed)
val_paranoid, test_paranoid = random_split(test_val_paranoid, [0.55, 0.45], seed)
val_rest, test_rest = random_split(test_val_rest, [0.55, 0.45], seed)

train_data = ConcatDataset([train_control, train_paranoid, train_rest])
val_data = ConcatDataset([val_control, val_paranoid, val_rest])
test_data = ConcatDataset([test_control, test_paranoid, test_rest])


# Per-split group sizes as a sanity check: control + paranoid + rest = split total.
print('Train')
print(len(train_control), '+', len(train_paranoid), '+', len(train_rest), '=', len(train_data))
print()
print('Val')
print(len(val_control), '+', len(val_paranoid), '+', len(val_rest), '=', len(val_data))
print()
print('Test')
print(len(test_control), '+', len(test_paranoid), '+', len(test_rest), '=', len(test_data))
print()
print('Total')
print(len(train_control)+len(val_control)+(len(test_control)), '+', len(train_paranoid)+len(val_paranoid)+(len(test_paranoid)), '+', len(train_rest)+len(val_rest)+(len(test_rest)), '=', len(train_data)+len(val_data)+len(test_data))

# Validation is evaluated one subject at a time (batch_size=1). Its shuffle only changes
# iteration order, but it still draws from the shared generator when iterated.
trainloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, generator=seed, num_workers=0)
valloader = DataLoader(val_data, batch_size=1, shuffle=True, generator=seed, num_workers=0)
# testloader = DataLoader(test_data, batch_size=1, shuffle=True, generator=seed, num_workers=0)

# calculate class weights for the loss fuction
# Inverse-frequency weights: n_train / (NUM_CLASSES * n_class). In the binary case the
# paranoid and rest groups together form class 1.
# NOTE(review): the 3-class branch assumes folder 'paranoid' <-> label 1 and 'rest' <-> label 2,
# but labels are read from pheno.csv columns by the dataset classes -- confirm they agree.
if BINARY:
    class_weights = [len(train_data)/(NUM_CLASSES*len(train_control)), len(train_data)/(NUM_CLASSES*(len(train_paranoid)+len(train_rest)))]
else:
    class_weights = [len(train_data)/(NUM_CLASSES*len(train_control)), len(train_data)/(NUM_CLASSES*len(train_paranoid)), len(train_data)/(NUM_CLASSES*len(train_rest))]
print(class_weights)

In [None]:
class dim_net(nn.Module):
    """Convolutional classifier: a 3-D convolution stage followed by a 2-D stage.

    Input shape is (batch, in_channels, D, H, W). The two 3-D convolutions collapse the
    channels to a single feature map, which is max-pooled. The singleton channel axis
    is then squeezed away, so the pooled depth axis becomes the channel axis of a 2-D
    convolution. Two fully connected layers produce ``num_classes`` logits.

    Parameters:
    - num_classes (int): number of output logits.
    - dropout_rate (float): dropout probability, applied after conv1_2 and after conv2.
    - layer_size (int): number of feature maps produced by the first 3-D convolution.
    - in_channels (int): number of input channels (default 150).
    - pooled_depth (int): size of axis D after the 3-D max-pool, i.e. the channel count
      conv2 expects (default 47).
    - pooled_plane (int): side length of the (H, W) plane after the 2-D max-pool;
      fc1 expects pooled_plane ** 2 inputs (default 19).
    - hidden_size (int): width of the hidden fully connected layer (default 4*4 = 16).

    The defaults reproduce the original hard-coded architecture. Layer names, and the
    order in which layers are created, are unchanged, so saved state_dicts still load and
    seeded weight initialisation is identical.
    """

    def __init__(self, num_classes, dropout_rate, layer_size,
                 in_channels=150, pooled_depth=47, pooled_plane=19, hidden_size=4 * 4):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, layer_size, kernel_size=(3, 3, 3), padding=1)
        self.conv1_2 = nn.Conv3d(layer_size, 1, kernel_size=(3, 3, 3), padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)
        self.dropout = nn.Dropout(dropout_rate)
        self.conv2 = nn.Conv2d(pooled_depth, 1, kernel_size=(3, 3), padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.fc1 = nn.Linear(pooled_plane * pooled_plane, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, in_channels, D, H, W); padding=1 keeps spatial sizes through each conv.
        x = F.relu(self.conv1(x))         # (batch, layer_size, D, H, W)
        x = F.relu(self.conv1_2(x))       # (batch, 1, D, H, W)
        x = self.dropout(x)
        x = self.pool1(x)                 # (batch, 1, D//2, H//2, W//2)
        # Drop the singleton channel axis so the pooled depth acts as 2-D channels.
        x = torch.squeeze(x, dim=1)       # (batch, D//2, H//2, W//2)
        x = F.relu(self.conv2(x))         # (batch, 1, H//2, W//2)
        x = self.dropout(x)
        x = self.pool2(x)                 # (batch, 1, H//4, W//4)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

net = dim_net(num_classes=NUM_CLASSES, dropout_rate=DROPOUT_RATE, layer_size=LAYER_SIZE)  # fresh, untrained model

In [None]:
# Train the network; train_model's 8-tuple result is unpacked into the best weights,
# the per-epoch metric histories and the index of the selected epoch.
(
    best_params,
    losses_val,
    losses_train,
    accuracies_val,
    accuracies_train,
    balanced_accuracies,
    balanced_accuracies_train,
    best_epoch,
) = train_model(
    net, device, trainloader, valloader,
    N_EPOCHS, LEARNING_RATE, DROPOUT_RATE, LAYER_SIZE, MODE, class_weights,
)

In [None]:
# Rebuild the architecture and load the weights selected during training.
best_model = dim_net(NUM_CLASSES, DROPOUT_RATE, LAYER_SIZE)
best_model.load_state_dict(best_params)

# filename = _
# best_model = torch.load(filename, map_location=torch.device('cpu'))

# Switch to inference mode so dropout is disabled during the evaluations below,
# regardless of whether `validate` changes the mode itself.
best_model.eval()
best_model.to(device)

In [None]:
# Per-class precision / recall / F1 of the selected model on the validation split.
print('Val')
y_true, y_pred = validate(best_model, device, valloader, class_weights)
print(classification_report(y_true=y_true, y_pred=y_pred))

In [None]:
# The same report on the training split, useful for spotting over-fitting.
print('Train')
y_true_train, y_pred_train = validate(best_model, device, trainloader, class_weights)
print(classification_report(y_true=y_true_train, y_pred=y_pred_train))

In [None]:
# print('Test')
# y_true_test, y_pred_test = validate(best_model, device, testloader, class_weights)
# print(classification_report(y_true_test, y_pred_test))

In [None]:
# Validation: true / predicted label distributions and the number of misclassified subjects.
print('Val')
print('True:', y_true)
unique, count = np.unique(y_true, return_counts=True)
print(unique, count)
print('Pred:', y_pred)
unique, count = np.unique(y_pred, return_counts=True)
print(unique, count)
correct_val = [t == y for t, y in zip(y_true, y_pred)]
unique, count = np.unique(correct_val, return_counts=True)
print(unique, count)
# Count mismatches directly. Indexing np.unique's counts (count[0]) reports 0 when
# every prediction is wrong, because only False is then present.
print('Number of wrong predictions:', sum(not ok for ok in correct_val))

In [None]:
# Training: true / predicted label distributions and the number of misclassified subjects.
print('Train')
print('True:', y_true_train)
unique, count = np.unique(y_true_train, return_counts=True)
print(unique, count)
print('Pred:', y_pred_train)
unique, count = np.unique(y_pred_train, return_counts=True)
print(unique, count)
correct_train = [t == y for t, y in zip(y_true_train, y_pred_train)]
unique, count = np.unique(correct_train, return_counts=True)
print(unique, count)
# Count mismatches directly. Indexing np.unique's counts (count[0]) reports 0 when
# every prediction is wrong, because only False is then present.
print('Number of wrong predictions:', sum(not ok for ok in correct_train))

In [None]:
# print('Test')
# print('True:', y_true_test)
# unique, count = np.unique(y_true_test, return_counts=True)
# print(unique, count)
# print('Pred:', y_pred_test)
# unique, count = np.unique(y_pred_test, return_counts=True)
# print(unique, count)
# unique, count = np.unique([t==y for t,y in zip(y_true_test, y_pred_test)], return_counts=True)
# print(unique, count)
# print('Number of wrong predictions:', [count[0] if len(count) != 1 else 0])

In [None]:
# date = datetime.now().strftime('%d%m')
# # date = str(1503) # March 15th
# data_dir = '/model_acc&loss/' + date + '_' + str(LEARNING_RATE) + '_' + str(DROPOUT_RATE) + '_' +  str(LAYER_SIZE)

# validation_losses = torch.load(data_dir + '/losses_val')
# train_losses = torch.load(data_dir + '/losses_train')
# validation_accuracies = torch.load(data_dir + '/acc_val')
# train_accuracies = torch.load(data_dir + '/acc_train')
# validation_bal_acc = torch.load(data_dir + '/bal_acc_val')
# train_bal_acc = torch.load(data_dir + '/bal_acc_train')
# best_epoch = np.argmax(validation_bal_acc) + 1

In [None]:
# xvalues = np.linspace(1, N_EPOCHS, len(validation_losses))
# yvalues01 = validation_losses
# yvalues02 = train_losses
# name = 'Dropout=' + str(DROPOUT_RATE)

# plt.plot(xvalues, yvalues01, label=name + " val loss", color='r')
# plt.plot(xvalues, yvalues02, label=name + " train loss", color='b', linestyle="dashed")
# max_value = 4.7 #np.max([np.max(validation_losses), np.max(train_losses)])
# min_value = np.min([np.min(validation_losses), np.min(train_losses)])
# plt.plot([best_epoch, best_epoch], [min_value,max_value], alpha=0.5, linewidth=3, label='Chosen model')

# plt.title('Validation and training loss')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.grid(True)
# plt.legend()
# plt.xticks((np.array([0.1,1,2,3,4,5,6,7,8,9,10])/10)*N_EPOCHS)
# plt.ylim(0,5)
# plt.show()

In [None]:
# xvalues = np.linspace(1, N_EPOCHS, len(validation_accuracies))
# yvalues01 = validation_accuracies
# yvalues02 = train_accuracies
# name = 'Dropout=' + str(DROPOUT_RATE)

# plt.plot(xvalues, yvalues01, label=name + " val acc", color='r')
# plt.plot(xvalues, yvalues02, label=name + " train acc", color='b', linestyle="dashed")
# plt.plot([best_epoch, best_epoch], [0,1], alpha=0.5, linewidth=3, label='Chosen model')

# plt.title('Validation and training accuracy')
# plt.xlabel('Epoch')
# plt.ylabel('Accuracy')
# plt.grid(True)
# plt.legend()
# plt.xticks((np.array([1,2,3,4,5,6,7,8,9,10])/10)*N_EPOCHS)
# plt.yticks(np.array([0,1,2,3,4,5,6,7,8,9,10])/10)
# plt.show()

In [None]:
# xvalues = np.linspace(1, N_EPOCHS, len(validation_bal_acc))
# yvalues01 = validation_bal_acc
# yvalues02 = train_bal_acc
# name = 'Dropout=' + str(DROPOUT_RATE)

# plt.plot(xvalues, yvalues01, label=name + " val bal acc", color='r')
# plt.plot(xvalues, yvalues02, label=name + " train bal acc", color='b', linestyle="dashed")
# plt.plot([best_epoch, best_epoch], [0,1], alpha=0.5, linewidth=3, label='Chosen model')

# plt.title('Validation and training balanced accuracy')
# plt.xlabel('Epoch')
# plt.ylabel('Balanced accuracy')
# plt.grid(True)
# plt.legend()
# plt.xticks((np.array([1,2,3,4,5,6,7,8,9,10])/10)*N_EPOCHS)
# plt.yticks(np.array([0,1,2,3,4,5,6,7,8,9,10])/10)
# plt.show()