In [None]:
import os
import numpy as np
from sklearn import metrics
from matplotlib import pyplot as plt
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Modality ['t1', 't2', 't1ce', 'flair']
modality = 'flair'
assert modality in ('t1', 't2', 't1ce', 'flair'), f'unknown modality: {modality}'

# Fold
fold = 1

# Seed every RNG the notebook draws from (python `random`, numpy, and torch, which
# drives weight init, dropout and the WeightedRandomSampler) so that
# Restart & Run All reproduces the same training run.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Placeholder locations -- point these at the real directories before running.
rad_folder = '/path/to/rad/features'              # <modality>/<fold>/<split>_<fold>_<modality>.npy
g_rad_folder = '/path/to/guided/rad/features/'    # <modality>/<modality>_<fold>_<split>.npy
label_folder = '/path/to/labels/'                 # split_<fold>_<split>.csv
results_dir = '/path/to/results'                  # checkpoints: guided_<modality>_<fold>.pt

In [None]:
class GuidanceBoth_dataset(Dataset):
    """Dataset pairing a concatenated (guided, plain) radiomic feature vector with a class index.

    For one modality / fold / split it loads:
      * ``<rad_folder>/<modality>/<fold>/<split>_<fold>_<modality>.npy``
      * ``<g_rad_folder>/<modality>/<modality>_<fold>_<split>.npy``
      * ``<label_folder>/split_<fold>_<split>.csv`` (column ``class`` holding G/O/A)

    NOTE(review): rows of the two feature arrays and the CSV are assumed to be aligned
    by index (same sample order) -- only their lengths are checked here.

    Each item is ``{'input': float32 vector [guided..., plain...], 'label': int}``.

    Raises:
        ValueError: if the CSV contains an unknown class letter, or the three sources
            have different lengths.
    """

    # Class letter -> integer target used by the loss / metrics.
    LABEL_TO_INDEX = {'G': 0, 'O': 1, 'A': 2}

    def __init__(self, rad_folder, g_rad_folder, label_folder, modality, fold, datasplit):
        fold = str(fold)
        self.rad_array = np.load(os.path.join(rad_folder, modality, fold, datasplit+'_'+fold+'_'+modality+'.npy'))
        self.g_rad_array = np.load(os.path.join(g_rad_folder, modality, modality+'_'+fold+'_'+datasplit+'.npy'))

        label_csv = pd.read_csv(os.path.join(label_folder, 'split_'+fold+'_'+datasplit+'.csv'))
        label_list = label_csv['class'].tolist()
        unknown = {c for c in label_list if c not in self.LABEL_TO_INDEX}
        if unknown:
            raise ValueError(f'unexpected class labels {sorted(map(str, unknown))}; '
                             f'expected one of {list(self.LABEL_TO_INDEX)}')
        self.labels = [self.LABEL_TO_INDEX[c] for c in label_list]

        # Check alignment once at construction (rather than via `assert` on every
        # __len__ call, which is also stripped under `python -O`).
        if not (len(self.rad_array) == len(self.g_rad_array) == len(self.labels)):
            raise ValueError(f'length mismatch: rad={len(self.rad_array)}, '
                             f'guided={len(self.g_rad_array)}, labels={len(self.labels)}')

    def __len__(self):
        return len(self.rad_array)

    def __getitem__(self, idx):
        # Guided features first, then plain features.
        features = np.concatenate([self.g_rad_array[idx], self.rad_array[idx]])
        # np.load keeps the on-disk dtype; cast to float32 so the batch matches
        # nn.Linear's default float32 parameters (float64 input would fail in forward).
        return {'input': features.astype(np.float32), 'label': self.labels[idx]}

In [None]:
# One dataset per split, all for the configured modality / fold (built in train, val, test order).
train_set, val_set, test_set = (
    GuidanceBoth_dataset(rad_folder, g_rad_folder, label_folder, modality, fold, split)
    for split in ('train', 'val', 'test')
)

In [None]:
# weighted random sampler
# Draw each training sample with weight 1 / (size of its class) so mini-batches are
# roughly class-balanced (WeightedRandomSampler weights need not sum to 1).
# Labels are taken from the training dataset itself, so they are in exactly the
# order the sampler's indices refer to.
labels = np.asarray(train_set.labels, dtype=np.int64)
# Count by label value (np.bincount) rather than by position among the classes
# present, so a class missing from this split cannot shift weights onto another class.
class_sample_count = np.bincount(labels)
sample_probabilities = torch.as_tensor(1.0 / class_sample_count[labels], dtype=torch.double)
wrs = WeightedRandomSampler(weights=sample_probabilities, num_samples=len(sample_probabilities), replacement=True)

In [None]:
# Training batches are drawn through the class-balancing sampler (a sampler and
# shuffle=True are mutually exclusive); evaluation loaders keep file order.
train_loader = DataLoader(train_set, sampler=wrs, batch_size=50)
val_loader, test_loader = (
    DataLoader(split_set, batch_size=30, shuffle=False)
    for split_set in (val_set, test_set)
)

In [None]:
class GuidanceBoth_architecture(nn.Module):
    """Two-layer MLP classifier (Linear -> ReLU -> Dropout -> Linear) producing class logits.

    Args:
        in_features (int): input vector size (default 1536).
        hidden_features (int): hidden layer width (default 50).
        num_classes (int): number of output logits (default 3).
        dropout (float): dropout probability on the hidden activations (default 0.3).

    The defaults reproduce the original fixed architecture, and parameter names
    (``l1``, ``l2``) are unchanged, so saved state_dicts remain loadable.
    """
    def __init__(self, in_features=1536, hidden_features=50, num_classes=3, dropout=0.3):
        super(GuidanceBoth_architecture, self).__init__()

        self.l1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.l2 = nn.Linear(in_features=hidden_features, out_features=num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return raw (unnormalized) logits; for x of shape (batch, in_features) -> (batch, num_classes)."""
        x = F.relu(self.l1(x))
        x = self.dropout(x)
        # No ReLU / dropout on the output: nn.CrossEntropyLoss expects raw logits, and
        # clamping negatives to 0 or randomly zeroing logits distorts both loss and argmax.
        return self.l2(x)
    
# nn.Module.to moves parameters in place and returns the module itself.
model = GuidanceBoth_architecture().to(device)
model

In [None]:
# Per-class cross-entropy weights, in label-index order (G=0, O=1, A=2).
# NOTE(review): the training loader already class-balances via WeightedRandomSampler;
# confirm these extra loss weights are intended on top of that.
loss_weights = torch.tensor([1., 1.7, 1.62], device=device)
criterion = nn.CrossEntropyLoss(weight=loss_weights)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    nesterov=True,
    weight_decay=1e-6,
)

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Each call either records an improvement (validation loss less than or equal to the
    best seen so far), which saves ``model.state_dict()`` to ``ckpt_name`` and resets
    the counter, or increments the counter. ``early_stop`` becomes True once the
    counter reaches ``patience`` and ``epoch`` is greater than ``stop_epoch``.
    """
    def __init__(self, patience=20, stop_epoch=50, verbose=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 20
            stop_epoch (int): Stopping is only allowed once ``epoch > stop_epoch``.
                            Default: 50
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
        """
        self.patience = patience
        self.stop_epoch = stop_epoch
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.inf`, not `np.Inf`: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, epoch, val_loss, model, ckpt_name = 'checkpoint.pt'):
        """Update the stopping state with this epoch's validation loss.

        Args:
            epoch (int): epoch index as passed by the training loop, compared to ``stop_epoch``.
            val_loss (float): validation loss for this epoch (lower is better).
            model (nn.Module): model whose ``state_dict`` is saved on improvement.
            ckpt_name (str): path the checkpoint is written to.
        """
        # Negate so that "higher score is better".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
        elif score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.stop_epoch:
                self.early_stop = True
        else:
            # Ties count as an improvement, so the latest equally-good weights are kept.
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, ckpt_name):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), ckpt_name)
        self.val_loss_min = val_loss

# Set `early_stopping = False` instead to train for all epochs without checkpointing;
# the training loop and the checkpoint-reload cell both check `if early_stopping:`.
# early_stopping = False
early_stopping = EarlyStopping(patience=200, stop_epoch=100, verbose=False)

In [None]:
# train loop
# Each epoch: one optimisation pass over the class-balanced training loader, then a
# no-grad validation pass whose mean batch loss drives checkpointing / early stopping.

train_loss = []
val_loss = []
epochs = 500

# Loop-invariant checkpoint path; create its directory up front so the first
# torch.save inside EarlyStopping cannot fail on a missing folder.
ckpt_path = os.path.join(results_dir, 'guided_'+modality+'_'+str(fold)+'.pt')
if early_stopping:
    os.makedirs(results_dir, exist_ok=True)

for epoch in range(epochs):
    running_train_loss = 0.0
    model.train()
    for data in train_loader:
        c, label = data['input'], data['label']
        c, label = c.to(device), label.to(device)
        optimizer.zero_grad()
        pred = model(c)
        loss_train = criterion(pred, label)
        loss_train.backward()
        optimizer.step()
        running_train_loss += loss_train.item()

    loss_train_avg = running_train_loss/len(train_loader)
    train_loss.append(loss_train_avg)

    running_val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for data in val_loader:
            c, label = data['input'], data['label']
            c, label = c.to(device), label.to(device)
            pred = model(c)
            loss_val = criterion(pred, label)
            running_val_loss += loss_val.item()

    loss_val_avg = running_val_loss/len(val_loader)
    val_loss.append(loss_val_avg)
    print('Epoch {} of {}, Train Loss: {:.3f}, Val Loss: {:.3f}'.format(epoch+1, epochs, loss_train_avg, loss_val_avg))

    if early_stopping:
        early_stopping(epoch, loss_val_avg, model, ckpt_name=ckpt_path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

In [None]:
# Restore the weights saved at the most recent validation-loss improvement.
# map_location keeps this working when the checkpoint was written on another device
# (e.g. saved on a GPU, reloaded on a CPU-only machine).
if early_stopping:
    ckpt_path = os.path.join(results_dir, 'guided_'+modality+'_'+str(fold)+'.pt')
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Learning curves: mean batch loss per epoch for training and validation.
fig, ax = plt.subplots()
epochs_run = range(1, len(train_loss) + 1)
ax.plot(epochs_run, train_loss, label='train_loss')
ax.plot(epochs_run, val_loss, label='val_loss')
ax.set(xlabel='Epoch', ylabel='Mean batch loss', title='Guided ' + modality + ', fold ' + str(fold))
ax.legend()
plt.show()

In [None]:
# Testing the network
# Accumulate the loss AND every batch's labels / predictions over the whole test set.

test_loss = []
running_test_loss = 0.0
test_labels = []
test_preds = []

model.eval()
with torch.no_grad():
    for data in test_loader:
        c, label = data['input'], data['label']
        c, label = c.to(device), label.to(device)
        pred = model(c)
        predict = torch.argmax(pred, dim=1)
        loss_test = criterion(pred, label)
        running_test_loss += loss_test.item()
        test_labels.append(label)
        test_preds.append(predict)

    loss_test_avg = running_test_loss/len(test_loader)
    test_loss.append(loss_test_avg)

# Inside the loop `label` / `predict` only ever hold the current batch; rebind them to
# the full test set so metrics computed from them afterwards cover every test sample.
label = torch.cat(test_labels)
predict = torch.cat(test_preds)

In [None]:
# Balanced accuracy (mean per-class recall) of the argmax predictions vs. true labels.
# NOTE(review): make sure `label` / `predict` cover the whole test set, not just the last batch.
y_true = label.cpu().numpy()
y_pred = predict.cpu().numpy()
ba = metrics.balanced_accuracy_score(y_true, y_pred)