In [1]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from datetime import datetime
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from sklearn.metrics import accuracy_score, f1_score, balanced_accuracy_score, recall_score, precision_score, classification_report

from train_model_multi_grouped_bal import train_model
from validate_multi import validate

# Seed torch's global RNG once, before anything draws random numbers.
torch.manual_seed(1337)

# Optimisation hyperparameters
N_EPOCHS = 100
BATCH_SIZE = 1  # 5
LEARNING_RATE = 0.01

# Model hyperparameters: dropout probability, conv1 feature maps, output classes
DROPOUT_RATE = 0.2
LAYER_SIZE = 5
NUM_CLASSES = 3

# Fraction of every group held out for validation + test; the rest is used for training.
SPLIT = 0.45

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [2]:
class COBRE_dataset_subj(torch.utils.data.Dataset):
    """Dataset of COBRE subjects: one fMRI image tensor and one class label per subject.

    Every sub-directory of ``data_path`` is one subject and must contain
    ``session_1/rest_1/nourest.nii`` (the image) and
    ``session_1/<subject>_data.csv`` (the table the label is read from).
    Images are loaded lazily, one subject per ``__getitem__`` call.

    Parameters:
    - data_path (str): Path to the folder containing one sub-directory per subject.

    Returns (per item):
    - Tensor: float32 torch tensor of the NIfTI data with axes 0 and 3 swapped.
    - int: label -- 0 for 'Control', 1 for a 'Patient' whose Diagnosis is in
      ``DIAGNOSES_LABEL_1``, 2 for any other 'Patient'.

    Raises:
    - ValueError: from ``__getitem__`` when 'Subject Type' is neither 'Patient' nor 'Control'.
    """

    # Diagnosis values (as rendered by str() on the CSV cell) that map to label 1.
    # NOTE(review): the data folders are named control/paranoid/rest, but this set holds
    # several codes, not only one -- confirm this grouping is the intended one.
    DIAGNOSES_LABEL_1 = frozenset({'295.1', '295.2', '295.3', '295.6', '295.9'})

    def __init__(self, data_path):
        self.data_path = data_path
        # List once and sort so index -> subject is stable between calls and machines
        # (os.listdir order is arbitrary). Plain files (e.g. a stray desktop.ini) are
        # skipped: they have no scan and would crash __getitem__.
        self.subjects = sorted(
            name for name in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, name))
        )
        self.num_subjects = len(self.subjects)

    def __len__(self):
        return self.num_subjects

    def __getitem__(self, index):
        subject = self.subjects[index]
        session_dir = os.path.join(self.data_path, subject, 'session_1')
        img_path = os.path.join(session_dir, 'rest_1', 'nourest.nii')
        csv_path = os.path.join(session_dir, subject + '_data.csv')

        img = nib.load(img_path).get_fdata()
        # Swap axes 0 and 3 so the NIfTI's 4th axis comes first.
        # NOTE(review): assumes a 4-D image with time last, giving (T, ...) -- confirm.
        img = torch.from_numpy(np.swapaxes(img, 0, 3)).to(torch.float)

        label = self._label_from_table(pd.read_csv(csv_path))
        return img, label

    @classmethod
    def _label_from_table(cls, table):
        """Map the first row of a subject's table to an int label (0, 1 or 2); raise ValueError otherwise."""
        row = table.iloc[0]
        subject_type = row['Subject Type']
        if subject_type == 'Control':
            return 0
        if subject_type == 'Patient':
            return 1 if str(row['Diagnosis']) in cls.DIAGNOSES_LABEL_1 else 2
        # Previously this printed a warning and returned the DataFrame itself as the label.
        raise ValueError(f'Something wrong with data label: {subject_type!r}\n{table}')

In [3]:
# Root folder holding one sub-folder per group (control / paranoid / rest).
DATA_ROOT = 'C:/Users/oscar/OneDrive - University of Bergen/Documents/Master/vsc/COBRE_learning/data/regrouped'
control_path = DATA_ROOT + '/control'
paranoid_path = DATA_ROOT + '/paranoid'
rest_path = DATA_ROOT + '/rest'

control_data = COBRE_dataset_subj(control_path)
paranoid_data = COBRE_dataset_subj(paranoid_path)
rest_data = COBRE_dataset_subj(rest_path)

# torch.manual_seed re-seeds and returns torch's *global* default generator. Every split
# and loader below shares it, so the order of these calls determines the resulting split.
# Each group is split separately so every split contains subjects from every group.
seed = torch.manual_seed(1337)

# Stage 1: per group, hold out SPLIT of the subjects for validation + test.
train_control, test_val_control = random_split(control_data, [1-SPLIT, SPLIT], seed)
train_paranoid, test_val_paranoid = random_split(paranoid_data, [1-SPLIT, SPLIT], seed)
train_rest, test_val_rest = random_split(rest_data, [1-SPLIT, SPLIT], seed)

# Stage 2: divide each held-out part 55 / 45 into validation and test.
val_control, test_control = random_split(test_val_control, [0.55, 0.45], seed)
val_paranoid, test_paranoid = random_split(test_val_paranoid, [0.55, 0.45], seed)
val_rest, test_rest = random_split(test_val_rest, [0.55, 0.45], seed)

# Subject counts per split, in the order: control, paranoid, rest.
print('Train')
print(len(train_control), len(train_paranoid), len(train_rest))
print()
print('Val')
print(len(val_control), len(val_paranoid), len(val_rest))
print()
print('Test')
print(len(test_control), len(test_paranoid), len(test_rest))
print()

train_data = ConcatDataset([train_control, train_paranoid, train_rest])
val_data = ConcatDataset([val_control, val_paranoid, val_rest])
# test_data = ConcatDataset([test_control, test_paranoid, test_rest])

trainloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, generator=seed, num_workers=0)
# Validation order does not affect the metrics, so it is not shuffled. This keeps y_true/y_pred
# aligned with val_data's index order and avoids drawing a permutation from the shared
# generator on every validation pass.
valloader = DataLoader(val_data, batch_size=1, shuffle=False, generator=seed, num_workers=0)
# testloader = DataLoader(test_data, batch_size=1, shuffle=True, generator=seed, num_workers=0)

Control
40 35 4

Val
18 15 2

Test
14 12 1



In [None]:
class UNET_Mari(nn.Module):
    """Small 3-D -> 2-D convolutional classifier.

    NOTE(review): despite the name this is not a U-Net (there is no decoder and no
    skip connections). It is a plain conv -> pool -> fully-connected classifier.

    Data flow (B = batch size):
        (B, in_channels, D, H, W)
        conv1 + ReLU, conv1_2 + ReLU  -> (B, 1, D, H, W)
        dropout, 2x2x2 max-pool       -> (B, 1, D/2, H/2, W/2)
        squeeze channel axis          -> (B, D/2, H/2, W/2)  (D/2 becomes the 2-D channels)
        conv2 + ReLU, dropout, pool   -> (B, 1, H/4, W/4)
        flatten, fc1 + ReLU, fc2      -> (B, num_classes) unnormalised scores

    Parameters:
    - num_classes (int): number of outputs of the final layer.
    - dropout_rate (float): dropout probability, used after both conv stages.
    - layer_size (int): number of feature maps produced by conv1.
    - in_channels (int): size of input axis 1 (default 150, the original hard-coded value).
    - pooled_depth (int): size of axis 2 after the first pooling, i.e. conv2's input
      channels (default 47).
    - flat_features (int): number of features after the second pooling and flattening,
      i.e. fc1's input size (default 19 * 19).

    The defaults reproduce the original architecture exactly (same parameter names,
    shapes and creation order), so previously saved state_dicts still load.
    """

    def __init__(self, num_classes, dropout_rate, layer_size,
                 in_channels=150, pooled_depth=47, flat_features=19 * 19):
        super().__init__()

        # Stage 1 (3-D): two 3x3x3 convolutions reduce the input to a single feature map.
        self.conv1 = nn.Conv3d(in_channels, layer_size, kernel_size=(3, 3, 3), padding=1)
        self.conv1_2 = nn.Conv3d(layer_size, 1, kernel_size=(3, 3, 3), padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)
        # A single dropout module is shared by both conv stages.
        self.dropout = nn.Dropout(dropout_rate)
        # Stage 2 (2-D): the pooled depth axis is treated as the channel axis.
        self.conv2 = nn.Conv2d(pooled_depth, 1, kernel_size=(3, 3), padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        # Classifier head.
        self.fc1 = nn.Linear(flat_features, 4 * 4)
        self.fc2 = nn.Linear(4 * 4, num_classes)

    def forward(self, x):
        """Map x of shape (B, in_channels, D, H, W) to unnormalised scores of shape (B, num_classes)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv1_2(x))
        x = self.dropout(x)
        x = self.pool1(x)
        # Remove only the singleton channel axis, so a batch of size 1 keeps its batch axis.
        x = torch.squeeze(x, dim=1)
        x = F.relu(self.conv2(x))
        x = self.dropout(x)
        x = self.pool2(x)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

unet = UNET_Mari(num_classes=NUM_CLASSES, dropout_rate=DROPOUT_RATE, layer_size=LAYER_SIZE)

In [None]:
# Train with the project-local train_model (train_model_multi_grouped_bal). It returns the
# selected parameters, loss/accuracy histories for both splits, and the chosen epoch.
(best_params,
 losses_val, losses_train,
 accuracies_val, accuracies_train,
 balanced_accuracies, balanced_accuracies_train,
 best_epoch) = train_model(
    unet, device, trainloader, valloader,
    N_EPOCHS, LEARNING_RATE, DROPOUT_RATE, LAYER_SIZE,
)

In [None]:
# Rebuild the architecture and load the parameters selected by train_model.
best_model = UNET_Mari(NUM_CLASSES, DROPOUT_RATE, LAYER_SIZE)
best_model.load_state_dict(best_params)

# best_model = torch.load('C:/Users/oscar/OneDrive - University of Bergen/Documents/Master/vsc/COBRE_learning/multilabel_regrouped/models/1503_0.01_0.2_5/best_model_057407_bal.pt')

best_model.to(device)
# Inference mode: disables dropout for every evaluation/prediction below
# (redundant but harmless if validate() also calls eval()).
best_model.eval()

In [None]:
# for param in best_model.parameters():
#     if param.requires_grad:
#         param = param
#         break
# param = deepcopy(param)
# param = param.to('cpu').detach().numpy()
# plt.imshow(param)
# plt.show()

In [None]:
# my_scan = nib.load('C:/Users/oscar/OneDrive - University of Bergen/Documents/Master/vsc/my_scan/fMRI_231207_ERC2VF_100/005_fMRI_default-pulsm_ling/nou005_fMRI_default-pulsm_ling.nii')
# my_scan = my_scan.get_fdata()
# my_scan = np.swapaxes(my_scan, 0, 3)
# my_scan = torch.from_numpy(my_scan)
# my_scan = my_scan.to(torch.float)
# my_scan = my_scan[0:150]
# my_scan = my_scan.to(device)

# with torch.no_grad():
#     pred = best_model(my_scan)
#     _, predicted = torch.max(pred.data, 1)
# print(pred)
# print(predicted)

In [None]:
print('Val')
# Predict every validation sample with the selected model, then summarise per class.
y_true, y_pred = validate(best_model, device, valloader)
val_report = classification_report(y_true, y_pred)
print(val_report)

In [None]:
print('Train')
# Same per-class summary, computed on the training loader.
y_true_train, y_pred_train = validate(best_model, device, trainloader)
train_report = classification_report(y_true_train, y_pred_train)
print(train_report)

In [None]:
# ### METRICS
# print('Val')
# print('Accuracy:', accuracy_score(y_true, y_pred))
# print('F1-score:', f1_score(y_true, y_pred, average='weighted'))
# print('Balanced accuracy score:', balanced_accuracy_score(y_true, y_pred))
# print('Balanced accuracy score:', balanced_accuracies[best_epoch-1])
# print('Recall:', recall_score(y_true, y_pred, average='weighted'))
# print('Precision:', precision_score(y_true, y_pred, average='weighted'))
# print()
# print('Train')
# print('Accuracy:', accuracy_score(y_true_train, y_pred_train))
# print('F1-score:', f1_score(y_true_train, y_pred_train, average='weighted'))
# print('Balanced accuracy score:', balanced_accuracy_score(y_true_train, y_pred_train))
# print('Balanced accuracy score:', balanced_accuracies_train[best_epoch-1])
# print('Recall:', recall_score(y_true_train, y_pred_train, average='weighted'))
# print('Precision:', precision_score(y_true_train, y_pred_train, average='weighted'))

In [None]:
# Validation: true/predicted label distributions and number of misclassified samples.
print('Val')
print('True:', y_true)
unique, count = np.unique(y_true, return_counts=True)
print(unique, count)
print('Pred:', y_pred)
unique, count = np.unique(y_pred, return_counts=True)
print(unique, count)
correct = [t == p for t, p in zip(y_true, y_pred)]
unique, count = np.unique(correct, return_counts=True)
print(unique, count)
# Count mismatches directly instead of indexing np.unique's counts. Those counts have no
# False entry when every prediction is right, and no True entry when every one is wrong.
n_wrong = sum(not c for c in correct)
print('Number of wrong predictions:', n_wrong)

In [None]:
# Training set: true/predicted label distributions and number of misclassified samples.
print('Train')
print('True:', y_true_train)
unique, count = np.unique(y_true_train, return_counts=True)
print(unique, count)
print('Pred:', y_pred_train)
unique, count = np.unique(y_pred_train, return_counts=True)
print(unique, count)
correct_train = [t == p for t, p in zip(y_true_train, y_pred_train)]
unique, count = np.unique(correct_train, return_counts=True)
print(unique, count)
# Count mismatches directly instead of indexing np.unique's counts. Those counts have no
# False entry when every prediction is right, and no True entry when every one is wrong.
n_wrong_train = sum(not c for c in correct_train)
print('Number of wrong predictions:', n_wrong_train)

In [None]:
# ### VAL MAP
# map = [t==p for t,p in zip(y_true, y_pred)]
# data_path = 'C:/Users/oscar/OneDrive - University of Bergen/Documents/Master/vsc/COBRE_learning/data/val'

# for i in range(len(val)):
#     if(not map[i]):
#         _, label = val[i]
#         # label = data_path + '/' + subject + '/' + subject + '_data.csv'
#         # label = pd.read_csv(label)
#         # label = label.iloc[0]
#         print(label)
#         print()

In [None]:
# ### TRAIN MAP
# map = [t==p for t,p in zip(y_true_train, y_pred_train)]
# data_path = 'C:/Users/oscar/OneDrive - University of Bergen/Documents/Master/vsc/COBRE_learning/data/train'

# for i in range(len(train_data)):
#     if(not map[i]):
#         subject = os.listdir(data_path)[i]
#         label = data_path + '/' + subject + '/' + subject + '_data.csv'
#         label = pd.read_csv(label)
#         label = label.iloc[0]
#         print(label)
#         print()

In [None]:
# Reload the metric histories saved for this run.
# NOTE(review): the folder name starts with *today's* date ('%d%m' = day, month). It is
# presumably written at training time, so if training ran on another day, set `date`
# by hand as in the commented line.
date = datetime.now().strftime('%d%m')
# date = str(1503) # March 15th
run_name = f'{date}_{LEARNING_RATE}_{DROPOUT_RATE}_{LAYER_SIZE}'
data_dir = 'C:/Users/oscar/OneDrive - University of Bergen/Documents/Master/vsc/COBRE_learning/multilabel_regrouped/model_acc&loss/' + run_name

validation_losses = torch.load(f'{data_dir}/losses_val')
train_losses = torch.load(f'{data_dir}/losses_train')
validation_accuracies = torch.load(f'{data_dir}/acc_val')
train_accuracies = torch.load(f'{data_dir}/acc_train')
validation_bal_acc = torch.load(f'{data_dir}/bal_acc_val')
train_bal_acc = torch.load(f'{data_dir}/bal_acc_train')

# 1-based position of the first maximum validation balanced accuracy (np.argmax returns the
# first index on ties). This overrides the best_epoch returned by train_model.
best_epoch = np.argmax(validation_bal_acc) + 1

In [None]:
# Validation (solid red) and training (dashed blue) loss, with the chosen epoch marked.
xvalues = np.linspace(1, N_EPOCHS, len(validation_losses))
yvalues01 = validation_losses
yvalues02 = train_losses
name = 'Dropout=' + str(DROPOUT_RATE)

fig, ax = plt.subplots()
ax.plot(xvalues, yvalues01, label=name + " val loss", color='r')
ax.plot(xvalues, yvalues02, label=name + " train loss", color='b', linestyle="dashed")
# Marker line top is fixed just below the 0..5 y-limit set further down.
max_value = 4.7  # alternative: np.max([np.max(validation_losses), np.max(train_losses)])
min_value = np.min([np.min(validation_losses), np.min(train_losses)])
ax.plot([best_epoch, best_epoch], [min_value, max_value], alpha=0.5, linewidth=3, label='Chosen model')

ax.set_title('Validation and training loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.grid(True)
ax.legend()
ax.set_xticks((np.array([0.1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) / 10) * N_EPOCHS)
ax.set_ylim(0, 5)
plt.show()

In [None]:
# Validation (solid red) and training (dashed blue) accuracy, with the chosen epoch marked.
xvalues = np.linspace(1, N_EPOCHS, len(validation_accuracies))
yvalues01 = validation_accuracies
yvalues02 = train_accuracies
name = 'Dropout=' + str(DROPOUT_RATE)

fig, ax = plt.subplots()
ax.plot(xvalues, yvalues01, label=name + " val acc", color='r')
ax.plot(xvalues, yvalues02, label=name + " train acc", color='b', linestyle="dashed")
ax.plot([best_epoch, best_epoch], [0, 1], alpha=0.5, linewidth=3, label='Chosen model')

ax.set_title('Validation and training accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.grid(True)
ax.legend()
ax.set_xticks(np.arange(1, 11) / 10 * N_EPOCHS)
ax.set_yticks(np.arange(0, 11) / 10)
plt.show()

In [None]:
# Validation (solid red) and training (dashed blue) balanced accuracy, with the chosen epoch marked.
xvalues = np.linspace(1, N_EPOCHS, len(validation_bal_acc))
yvalues01 = validation_bal_acc
yvalues02 = train_bal_acc
name = 'Dropout=' + str(DROPOUT_RATE)

fig, ax = plt.subplots()
ax.plot(xvalues, yvalues01, label=name + " val bal acc", color='r')
ax.plot(xvalues, yvalues02, label=name + " train bal acc", color='b', linestyle="dashed")
ax.plot([best_epoch, best_epoch], [0, 1], alpha=0.5, linewidth=3, label='Chosen model')

ax.set_title('Validation and training balanced accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Balanced accuracy')
ax.grid(True)
ax.legend()
ax.set_xticks(np.arange(1, 11) / 10 * N_EPOCHS)
ax.set_yticks(np.arange(0, 11) / 10)
plt.show()