In [1]:
import matplotlib.pyplot as plt
from mlp import MLP
from train_test import train_model, test_model
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import pandas as pd
import nibabel as nib
import random

In [2]:
class MyDataset(Dataset):
    """Cortical-surface features of one hemisphere paired with three regression targets.

    Every item loads ``<img_dir><ID><suffix>.shape.gii`` with nibabel, gathers the
    per-vertex features through ``neigh_orders``, standardises them with the
    stored mean/std and returns them together with the demographic targets
    listed in ``TARGET_COLUMNS``.

    Args:
        img_dir: Path prefix of the ``.shape.gii`` files. It is concatenated
            verbatim with the session ID, so it must end with a path separator.
        train: ``True`` reads the session IDs from ``train.npy``, ``False`` from
            ``validation.npy``.
        parity: Hemisphere choice: ``'both'``, ``'left'`` or ``'right'``.
        warp: When true, a random ``_W1`` .. ``_W100`` variant of the file is
            loaded on every access.
        device: Device the returned tensors are moved to.

    Raises:
        ValueError: If ``parity`` is not one of ``PARITIES``.
    """

    # Demographic columns that make up the target vector, in output order.
    TARGET_COLUMNS = ('GA at birth (weeks)', 'PMA at scan (weeks)', 'Birthweight (kg)')
    # Supported values for ``parity``.
    PARITIES = ('both', 'left', 'right')

    def __init__(self, img_dir, train, parity, warp, device):
        # Fail fast: an unknown parity used to surface only later as a missing
        # '<ID>_.shape.gii' file inside the DataLoader.
        if parity not in self.PARITIES:
            raise ValueError("parity must be one of %s, got %r" % (self.PARITIES, parity))
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading; only use it with split files from a trusted source.
        if train:
            self.ids = np.load('train.npy', allow_pickle=True)[:, 0]
        else:
            self.ids = np.load('validation.npy', allow_pickle=True)[:, 0]
        self.df = pd.read_csv("dHCP_gDL_demographic_data.csv")
        # Build the "sub-<subject>_ses-<session>" key used in the file names.
        # The trailing space in "Subject ID " is deliberate here (presumably it
        # matches the CSV header - verify against the file).
        self.df.insert(0, "ID", "sub-" + self.df["Subject ID "] + "_" + "ses-" + self.df["Session ID"].apply(str))
        self.df.drop("Subject ID ", axis=1, inplace=True)
        self.df.drop("Session ID", axis=1, inplace=True)
        self.mirror_index = np.load('mirror_index.npy') # mirrors the right hemisphere
        self.train = train
        self.img_dir = img_dir
        self.parity = parity
        self.warp = warp
        self.mean = np.load('means_template.npy') # native for the other dataset
        self.std = np.load('stds_template.npy') # native for the other dataset
        self.neigh_orders = np.load('neigh_orders.npy')
        self.device = device

    def __len__(self):
        # With 'both', every session contributes a left and a right item.
        return 2 * len(self.ids) if self.parity == 'both' else len(self.ids)

    def _targets(self, _id):
        """Return the target values of session ``_id`` as a NumPy array.

        Raises ``ValueError`` (from ``Series.item``) unless exactly one row of
        the demographic table matches ``_id``.
        """
        # One boolean scan of the table instead of one per target column.
        row = self.df.loc[self.df['ID'] == _id, list(self.TARGET_COLUMNS)]
        return np.array([row[column].item() for column in self.TARGET_COLUMNS])

    def _file_suffix(self, idx):
        """Return the file-name suffix (hemisphere and optional warp) for item ``idx``."""
        if self.parity == 'both':
            # Even indices are left hemispheres, odd indices right hemispheres.
            hemisphere = 'L' if idx % 2 == 0 else 'R'
        else:
            hemisphere = {'left': 'L', 'right': 'R'}.get(self.parity, '')
        suffix = '_' + hemisphere
        if self.warp:
            # A different warped variant is drawn on every access.
            suffix += '_W%d' % (random.randint(1, 100))
        return suffix

    def __getitem__(self, idx):
        """Return ``(x, y)`` as float32 tensors on ``self.device`` for item ``idx``."""
        # With 'both', items 2k and 2k+1 belong to the same session.
        _id = self.ids[idx // 2] if self.parity == 'both' else self.ids[idx]
        y = self._targets(_id)
        parity_string = self._file_suffix(idx)
        img = nib.load(self.img_dir + _id + parity_string + '.shape.gii')
        # One column per data array stored in the file.
        x = np.stack(img.agg_data(), axis=1)
        # Gather rows through neigh_orders and flatten to 28 values per vertex
        # (presumably 7 neighbours x 4 features - TODO confirm).
        x = x[self.neigh_orders].reshape([x.shape[0], 28])
        # NOTE(review): only the plain '_R' file is mirrored; warped right
        # hemispheres ('_R_W<n>') are left as loaded - confirm that those files
        # are already stored in mirrored order.
        if parity_string == '_R':
            x = x[self.mirror_index]
        x = (x - self.mean) / self.std
        return torch.from_numpy(x).to(torch.float32).to(self.device), torch.from_numpy(y).to(torch.float32).to(self.device)


In [None]:
# Run configuration: everything a reader might want to tune lives in this cell.
path = '../data/regression_template_space_features_warped/' # native for the other dataset
parity = 'both' # hemisphere choice: 'both', 'left', 'right'
warp = False
batch_size = 32
# NOTE: this name shadows the builtin breakpoint(); it is kept because the
# training cell reads it. Training stops once the best epoch is more than this
# many epochs in the past.
breakpoint = 1000
patience = 100 # for early stopping
lr = 0.001
#weight_decay = 0.01
seed = 123 # change to obtain a different, but still repeatable, run

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(device)

# Seed every RNG the notebook relies on so that Restart & Run All reproduces
# the same shuffling, weight initialisation and warp selection:
# `random` picks the warped file in MyDataset, torch drives DataLoader
# shuffling and parameter init, NumPy is seeded for completeness.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed) # silently ignored when CUDA is unavailable

train_dataset = MyDataset(path, train=True, parity=parity, warp=warp, device=device)
# Validation always uses the un-warped surfaces.
val_dataset = MyDataset(path, train=False, parity=parity, warp=False, device=device)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# The whole validation set is evaluated as a single batch.
val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False, drop_last=False)

## Model definition

In [None]:
# Build the network and its optimiser. The MLP arguments are presumably
# (input size, hidden layer sizes, output size) - confirm against mlp.MLP.
hidden_sizes = [28] * 4
model = MLP(28, hidden_sizes, 3, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
print(model)

# Count only the parameters that will receive gradients.
n_trainable = 0
for param in model.parameters():
    if param.requires_grad:
        n_trainable += param.numel()
print('Number of parameters: ', n_trainable)

## Training the model

In [None]:
# Train with Adam, keeping the checkpoint of the epoch with the lowest
# validation loss and stopping early when validation stops improving.
val_losses = []
tr_losses = []
best_val_index = -1           # epoch of the best validation loss so far (-1: none yet)
best_val_loss = float('inf')  # tracked explicitly so epoch 0 is a candidate too
p_counter = 0                 # epochs since the last improvement
max_epochs = 1000
for epoch in range(max_epochs):
    train_loss = train_model(train_loader, model, optimizer)
    val_loss = test_model(val_loader, model)
    val_loss = val_loss.cpu().detach().numpy()
    # Record the losses first so the history also contains the epoch on which
    # training stops.
    tr_losses.append(train_loss.cpu().detach().numpy())
    val_losses.append(val_loss)

    # Compare against the best loss seen so far. Previously the comparison used
    # val_losses[-1] (the previous epoch) until a first "improvement" happened,
    # and the epoch-0 model was never saved.
    if val_loss < best_val_loss:
        new_min = "*"
        best_val_loss = val_loss
        best_val_index = epoch
        torch.save(model.state_dict(), 'MLP_template.pt')
        p_counter = 0
    else:
        new_min = " "
        p_counter += 1
    print(new_min, "Epoch: %d, train loss: %1.3f, val loss: %1.3f" % (epoch, train_loss, val_loss))

    # Hard limit on the distance to the best epoch (`breakpoint` is the
    # configuration value, not the builtin).
    if epoch - best_val_index > breakpoint:
        print ("Max epoch is reached")
        break
    # early stopping is called
    if p_counter > patience:
        print ("Early stopping, best val loss and index:")
        print(val_losses[best_val_index], best_val_index)
        break


In [None]:
# Learning curves: one point per epoch for the training and validation loss.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(tr_losses, 'g', label='train')
ax.plot(val_losses, 'b', label='validation')
# Labels and a legend so the figure can be read without the code next to it.
ax.set(xlabel='Epoch', ylabel='Loss', title='MLP training curves')
ax.legend()
plt.show()