In [1]:
# scan age or birth age
task = 'scan_age'

model_name = 'MLP.pt'

path = '/home/daniel/data/release/'

# hyperparameters
bs = 8
lr = 0.001
epochs = 200
hidden = 64
features = 'pos+norm+dha+x'

in_channels = 0
if 'pos' in features:
    in_channels += 3
if 'norm' in features:
    in_channels += 3
if 'dha' in features:
    in_channels += 3
if 'x' in features:
    in_channels += 4

import numpy as np
import pandas as pd
import nibabel as nib
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from MLP import MLP, actor_MLP
from GCN import GCN, actor_GCN
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
import matplotlib.pyplot as plt

In [2]:
log_dir=f'runs/invase/{task}/{model_name}/features={features}/bs={bs}_lr={lr}_epoch={epochs}_hidden={hidden}'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

train_ids = pd.read_csv(task + '_train.txt', header=None)
val_ids = pd.read_csv(task + '_val.txt', header=None)
test_ids = pd.read_csv(task + '_test.txt', header=None)

df = pd.read_csv("combined.tsv", sep='\t')

df.insert(0, "ID", "sub-" + df["participant_id"] + "_" + "ses-" + df["session_id"].apply(str))
df.drop("participant_id", axis=1, inplace=True)
df.drop("session_id", axis=1, inplace=True)

transform = T.Compose([T.NormalizeScale(), T.GenerateMeshNormals(), T.FaceToEdge()])

def get_data(path, task, ids):
    dataset = []
    for _id in ids[0]:
        try:
            surface = nib.load(os.path.join(path, 'surfaces', _id + '_left.wm.surf.gii'))
            pos, face = surface.agg_data()
            feature = nib.load(os.path.join(path, 'features', _id + '_left.shape.gii'))
            x = np.stack(feature.agg_data(), axis=1)
            y = np.array([[df.loc[df['ID'] == _id, task].item()]])
            data = Data()
            data.id = _id
            if 'x' in features:
                data.x = torch.from_numpy(x).to(torch.float32)
            data.pos = torch.from_numpy(pos).to(torch.float32)
            data.face = torch.from_numpy(face.T).to(torch.long)
            data.y = torch.from_numpy(y).to(torch.float32)
            if task == 'birth_age':
                confound = np.array([[df.loc[df['ID'] == _id, 'scan_age'].item()]])
                data.confound = torch.from_numpy(confound).to(torch.float32)
            data = transform(data)
            if 'norm' not in features:
                data.norm = None
            if 'dha' in features:
                data.dha = torch.from_numpy(np.load(os.path.join(path, 'preprocess/V_dihedral_angles', \
                                                                 _id + '_left.wm.surf_V_dihedralAngles.npy'))).to(torch.float32)
            # data.eig = torch.from_numpy(np.load(os.path.join(path, 'preprocess/aligned_eigen_vectors',
            #                                                     _id + '_left.wm.surf_eigen.npy'))).to(torch.float32)
            # data.curv = torch.from_numpy(np.load(os.path.join(path, 'preprocess/gaussian_curvatures',
            #                                                     _id + '_left.wm.surf_gaussian_curvature.npy'))).to(torch.float32).unsqueeze(1)
            # data.curv = (data.curv - data.curv.min()) / (data.curv.max() - data.curv.min())
            # data.hks = torch.from_numpy(np.load(os.path.join(path, 'preprocess/HKS',
            #                                                 _id + '_left.wm.surf_hks.npy'))).to(torch.float32)
            dataset.append(data)
        except Exception as error:
            print(error)
    return dataset

train_set = get_data(path, task, train_ids)
val_set = get_data(path, task, val_ids)
test_set = get_data(path, task, test_ids)

train_loader = DataLoader(train_set, batch_size=bs, shuffle=True)
val_loader = DataLoader(val_set, batch_size=bs)
test_loader = DataLoader(test_set, batch_size=bs)

class Invase():
    def __init__(self):
        if model_name == 'MLP.pt':
            self.critic = MLP(in_channels=in_channels, hidden_channels=hidden, out_channels=1)
            self.baseline = MLP(in_channels=in_channels, hidden_channels=hidden, out_channels=1)
            self.actor = actor_MLP(in_channels=in_channels, hidden_channels=hidden, out_channels=in_channels)
        elif model_name == 'GCN.pt':
            self.critic = GCN(in_channels=in_channels, hidden_channels=hidden, out_channels=1)
            self.baseline = GCN(in_channels=in_channels, hidden_channels=hidden, out_channels=1)
            self.actor = actor_GCN(in_channels=in_channels, hidden_channels=hidden, out_channels=in_channels)
        self.critic = self.critic.to(device)
        self.critic.optimizer = torch.optim.AdamW(self.critic.parameters(), lr=lr)
        self.critic.criterion = nn.MSELoss()
        self.baseline = self.baseline.to(device)
        self.baseline.optimizer = torch.optim.AdamW(self.baseline.parameters(), lr=lr)
        self.baseline.criterion = nn.MSELoss()
        self.actor = self.actor.to(device)
        self.actor.optimizer = torch.optim.AdamW(self.actor.parameters(), lr=lr)
        self.actor.criterion = self.actor_loss
        self.lambda_ = 1.0

    def actor_loss(self, actor_pred, actor_out, critic_out, baseline_out, y_true):
        critic_loss = F.mse_loss(critic_out, y_true)
        baseline_loss = F.mse_loss(baseline_out, y_true)
        reward = -(critic_loss - baseline_loss)
        # reward * BCE(actor_pred, actor_out) - lambda * ||actor_pred||
        custom_actor_loss = reward * torch.sum(actor_out * torch.log(actor_pred + 1e-8) + \
                                               (1.0 - actor_out)* torch.log(1.0 - actor_pred + 1e-8), dim=1) - \
                                                self.lambda_ * torch.mean(actor_pred, dim=1)
        custom_actor_loss = torch.mean(-custom_actor_loss)
        return custom_actor_loss

invase = Invase()

No such file or no access: '/home/daniel/data/release/features/sub-CC00061XX04_ses-13300_left.shape.gii'
No such file or no access: '/home/daniel/data/release/features/sub-CC00084XX11_ses-31201_left.shape.gii'
No such file or no access: '/home/daniel/data/release/features/sub-CC00143AN12_ses-47501_left.shape.gii'
No such file or no access: '/home/daniel/data/release/features/sub-CC00170XX06_ses-56100_left.shape.gii'
No such file or no access: '/home/daniel/data/release/features/sub-CC00217XX11_ses-73700_left.shape.gii'
No such file or no access: '/home/daniel/data/release/features/sub-CC00221XX07_ses-75000_left.shape.gii'
No such file or no access: '/home/daniel/data/release/features/sub-CC00291XX12_ses-93100_left.shape.gii'
No such file or no access: '/home/daniel/data/release/features/sub-CC00307XX10_ses-98800_left.shape.gii'
No such file or no access: '/home/daniel/data/release/features/sub-CC00341XX12_ses-108000_left.shape.gii'
No such file or no access: '/home/daniel/data/release/

In [3]:
@torch.no_grad()
def test(loader, invase):
    actor_losses = []
    critic_losses = []
    critic_accs = []
    baseline_losses = []
    baseline_accs = []
    invase.baseline.eval()
    invase.critic.eval()
    invase.actor.eval()
    for data in loader:
        data = data.to(device)
        # baseline testing
        baseline_out = invase.baseline(data)
        baseline_loss = invase.baseline.criterion(baseline_out, data.y)
        baseline_losses.append(baseline_loss.item())
        baseline_acc = F.l1_loss(baseline_out, data.y)
        baseline_accs.append(baseline_acc.item())
        # critic testing
        selection_probability = invase.actor(data)
        selection = torch.bernoulli(selection_probability)
        critic_out = invase.critic(data, selection)
        critic_loss = invase.critic.criterion(critic_out, data.y)
        critic_losses.append(critic_loss.item())
        critic_acc = F.l1_loss(critic_out, data.y)
        critic_accs.append(critic_acc.item())
        # actor testing
        actor_loss = invase.actor.criterion(selection_probability, selection, critic_out, baseline_out, data.y)
        actor_losses.append(actor_loss.item())
        return sum(actor_losses) / len(actor_losses), \
               sum(critic_losses) / len(critic_losses), sum(critic_accs) / len(critic_accs), \
               sum(baseline_losses) / len(baseline_losses), sum(baseline_accs) / len(baseline_accs)

@torch.no_grad()
def plot_reg(loader, invase):
    invase.baseline.eval()
    invase.critic.eval()
    invase.actor.eval()
    baseline_outs = []
    critic_outs = []
    ys = []
    for data in loader:
        data = data.to(device)
        baseline_out = invase.baseline(data)
        baseline_outs.append(baseline_out.cpu().numpy())
        selection_probability = invase.actor(data)
        selection = torch.bernoulli(selection_probability)
        critic_out = invase.critic(data, selection)
        critic_outs.append(critic_out.cpu().numpy())
        ys.append(data.y.cpu().numpy())
    plt.scatter(np.concatenate(ys), np.concatenate(baseline_outs))
    plt.xlabel('y')
    plt.ylabel('baseline_out')
    plt.savefig(os.path.join(log_dir, 'baseline_regression.png'))
    plt.close()
    plt.scatter(np.concatenate(ys), np.concatenate(critic_outs))
    plt.xlabel('y')
    plt.ylabel('critic_out')
    plt.savefig(os.path.join(log_dir, 'critic_regression.png'))
    plt.close()

In [4]:
invase.baseline.load_state_dict(torch.load(os.path.join(log_dir, model_name + '_baseline')))
invase.critic.load_state_dict(torch.load(os.path.join(log_dir, model_name + '_critic')))
invase.actor.load_state_dict(torch.load(os.path.join(log_dir, model_name + '_actor')))

<All keys matched successfully>

In [5]:
plot_reg(test_loader, invase)

In [38]:
test(test_loader, invase)

(1.2609210014343262, 0.23795820772647858, 0.3807697296142578, 0.4252564013004303, 0.5857276916503906)


In [41]:
with torch.no_grad():
    for data in test_set:
        # save surface
        surface = nib.load(os.path.join(path, 'surfaces', data.id + '_left.wm.surf.gii'))
        nib.save(surface, os.path.join('selection', data.id + '_left.wm.surf.gii'))
        # evaluate model
        invase.baseline.eval()
        invase.critic.eval()
        invase.actor.eval()
        data = data.to(device)
        baseline_out = invase.baseline(data)
        selection_probability = invase.actor(data)
        selection = torch.bernoulli(selection_probability)
        critic_out = invase.critic(data, selection)
        # save selection
        feature = nib.load(os.path.join(path, 'features', data.id + '_left.shape.gii'))
        for i in range(3, -1, -1):
            feature.remove_gifti_data_array(i)
        for darray in selection.cpu().numpy().T:
            gifti_data_array = nib.gifti.gifti.GiftiDataArray(data=darray, intent='NIFTI_INTENT_LABEL')
            feature.add_gifti_data_array(gifti_data_array)
            print(sum(darray) / len(darray))
        nib.save(feature, os.path.join('selection', data.id + '_left.shape.gii'))


0.3159050710412253
0.20948900853272173
0.42358452059320884
0.6075433924857626
0.6763396761841824
0.20141304981632296
0.1989543042624735
0.6563490058115804
0.23016968259830123
0.19404653151664755
0.6469319131567183
0.19336624618554296
0.2192753989387549
0.32662910115423066
0.21629555575752069
0.41898573116422794
0.5927019903662637
0.6687539761883122
0.20852494774152505
0.20318095064982278
0.6483140961555939
0.2368444969553758
0.20145414886849042
0.6326001999454695
0.19931836771789513
0.22564755066799963
0.35042054518870247
0.24155652387914936
0.40092939762873314
0.5562053939806303
0.6499124165786007
0.23211778160603386
0.22920798529177586
0.625519347974
0.2617948086917498
0.22330153307178946
0.5952777335437266
0.2222592179741448
0.2523271132214775
0.3410951379305725
0.22281916960302256
0.41121933573600933
0.5759857447179089
0.663044788917322
0.21397660740363617
0.21156499953107624
0.6368654456785326
0.24625195943139647
0.20776001822103726
0.6126421843808196
0.20718391189592572
0.2333498