In [None]:
from mlp import MLP
from train_test import train_model, test_model
import torch
import pandas as pd
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

In [None]:
data_path = '../data/release/native/' # '../data/release/template/' for template space prediction
task = 'scan_age' # 'birth_age' for birth age prediction

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

batch_size = 32
epochs = 2000 # maximum number of training epochs
patience = 200 # for early stopping: epochs without validation improvement before stopping
lr = 0.001

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across runs (some CUDA kernels may still be non-deterministic).
seed = 42
torch.manual_seed(seed)


In [None]:
# Subject/session IDs for each split. np.atleast_1d keeps a single-line file
# iterable (np.loadtxt returns a 0-d array in that case).
train_ids = np.atleast_1d(np.loadtxt(task + '_train.txt', dtype='str'))
val_ids = np.atleast_1d(np.loadtxt(task + '_validation.txt', dtype='str'))
test_ids = np.atleast_1d(np.loadtxt(task + '_test.txt', dtype='str'))

mirror_index = np.load('mirror_index.npy') # mirrors right hemispheres to match with left hemispheres

# Build the "sub-<participant>_ses-<session>" key used in the .gii file names as the
# first column, and drop the two source columns it was built from (no in-place mutation).
df_raw = pd.read_csv("combined.csv")
df = df_raw.drop(columns=["participant_id", "session_id"])
df.insert(0, "ID", "sub-" + df_raw["participant_id"] + "_ses-" + df_raw["session_id"].apply(str))

df.head()

In [None]:
def _fill_medial_wall(x):
    """Replace zeros (the medial wall cut area) in each column of ``x`` with that column's non-zero mean.

    Modifies ``x`` in place and returns it. NOTE(review): an all-zero column would
    produce NaN (mean of an empty selection) — confirm this cannot happen.
    """
    for i in range(x.shape[1]):
        channel = x[:, i]  # basic slice -> view, so assignments below write into x
        is_cut = channel == 0
        channel[is_cut] = np.mean(channel[~is_cut])
    return x


def _load_hemisphere(path, mirror=None):
    """Load one hemisphere's .shape.gii file as a (vertices, channels) float32 array.

    If ``mirror`` is given, rows are reordered with it (used to map right-hemisphere
    vertices onto the left-hemisphere ordering).
    """
    x = np.stack(nib.load(path).agg_data(), axis=1)
    if mirror is not None:
        x = x[mirror]
    return _fill_medial_wall(x).astype(np.float32)


def get_data(data_path, task, ids, split='train'):
    """Load per-hemisphere surface features and regression targets for a list of subjects.

    Args:
        data_path: directory prefix, concatenated directly with each ID (keep the trailing '/').
        task: name of the target column in the global ``df`` (e.g. 'scan_age').
        ids: iterable of "sub-..._ses-..." identifiers matching ``df['ID']``.
        split: name used only in the message printed for skipped subjects.

    Returns:
        (xs, ys): two equally long lists. Each loaded subject contributes two entries:
        its left hemisphere, then its right hemisphere reordered by the global
        ``mirror_index``; both carry the same shape-(1,) float32 target.

    Subjects whose files or label cannot be loaded are skipped entirely, so xs and
    ys always stay aligned.
    """
    xs, ys = [], []
    for _id in ids:
        try:
            # Load everything for this subject before appending anything, so a failure
            # part-way through cannot leave xs and ys with different lengths.
            y = np.array([df.loc[df['ID'] == _id, task].item()], dtype=np.float32)
            x_L = _load_hemisphere(data_path + _id + '_left.shape.gii')
            x_R = _load_hemisphere(data_path + _id + '_right.shape.gii', mirror_index)
        except Exception as exc:  # best-effort: skip missing/corrupt subjects but report why
            print('%s set element %s could not be loaded: %s' % (split, _id, exc))
            continue
        xs.extend([x_L, x_R])
        ys.extend([y, y.copy()])  # separate arrays, as torch.from_numpy later shares memory
    return xs, ys

train_xs, train_ys = get_data(data_path, task, train_ids, split='train')
val_xs, val_ys = get_data(data_path, task, val_ids, split='validation')
test_xs, test_ys = get_data(data_path, task, test_ids, split='test')


In [None]:
# Data standardization: z-score every feature channel with statistics pooled over
# all train samples and vertices, then apply the same train-set statistics to the
# validation and test sets so no information leaks from the held-out splits.
# Each sample from get_data is a (vertices, channels) array, so stacking gives
# (samples, vertices, channels) and axes (0, 1) are pooled per channel.
# NOTE: re-running this cell would standardize the already-standardized arrays again.

train_xs = np.asarray(train_xs, dtype=np.float32)
# Accumulate in float64 for accuracy, keep float32 so the tensors match the model.
means = train_xs.mean(axis=(0, 1), dtype=np.float64).astype(np.float32) # per-channel means in the train set
stds = train_xs.std(axis=(0, 1), dtype=np.float64).astype(np.float32) # per-channel stds in the train set

train_xs = (train_xs - means) / stds
val_xs = (np.asarray(val_xs, dtype=np.float32) - means) / stds
test_xs = (np.asarray(test_xs, dtype=np.float32) - means) / stds

In [None]:
def _as_tensor_pairs(features, targets):
    """Pair each feature array with its target as an (x, y) tuple of torch tensors.

    torch.from_numpy shares memory with the source arrays (no copy).
    """
    return [tuple(map(torch.from_numpy, pair)) for pair in zip(features, targets)]

train_subset = _as_tensor_pairs(train_xs, train_ys)
val_subset = _as_tensor_pairs(val_xs, val_ys)
test_subset = _as_tensor_pairs(test_xs, test_ys)


In [None]:
def _full_batch_loader(subset):
    """Unshuffled DataLoader that yields the whole subset as a single batch."""
    return torch.utils.data.DataLoader(subset, batch_size=len(subset), shuffle=False)

# Only the training batches are shuffled; val/test are evaluated in one pass each.
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_loader = _full_batch_loader(val_subset)
test_loader = _full_batch_loader(test_subset)

# Presumably MLP(in_features, hidden_sizes, out_features) — confirm against mlp.MLP.
model = MLP(4, [16, 16, 16, 16], 1, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(model)
print('Number of parameters: ', n_trainable)


In [None]:
# Train with early stopping on the validation loss. A checkpoint is written every
# time the validation loss reaches a new minimum; the test loss is only monitored
# and never used for model selection.
train_losses = []
val_losses = []
test_losses = []

checkpoint_path = 'MLP_paper.pt'
best_val_loss = float('inf')  # so epoch 0 always counts as an improvement and gets saved
best_val_index = -1

for epoch in range(epochs):
    train_loss = train_model(train_loader, model, optimizer).cpu().detach().numpy()
    val_loss = test_model(val_loader, model).cpu().detach().numpy()
    test_loss = test_model(test_loader, model).cpu().detach().numpy()
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    test_losses.append(test_loss)

    # Compare against the best loss seen so far (not the previous epoch's loss).
    new_min = " "
    if val_loss < best_val_loss:
        new_min = "*"
        best_val_loss = val_loss
        best_val_index = epoch
        torch.save(model.state_dict(), checkpoint_path)
    print(new_min, "Epoch: %d, train loss: %1.3f, val loss: %1.3f, test loss: %1.3f" % (epoch, train_loss, val_loss, test_loss))

    # early stopping: more than `patience` epochs since the last improvement
    if epoch - best_val_index > patience:
        print("Early stopping, best val loss and index:")
        print(best_val_loss, best_val_index)
        break


In [None]:
# Per-epoch loss curves; the checkpoint was selected on the validation curve,
# the test curve is shown for monitoring only.
fig, ax = plt.subplots()
ax.plot(train_losses, label='train')
ax.plot(val_losses, label='validation')
ax.plot(test_losses, label='test')
ax.set(xlabel='epoch', ylabel='loss', title='MLP %s loss curves' % task)
ax.legend()
plt.show()