In [1]:
import os
import tempfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

from miniMTL.datasets import *
from miniMTL.models import *
from miniMTL.util import *
from miniMTL.training import *
from miniMTL.hps import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Seed for the fixed permutation that get_connectome / get_concat apply before
# reshaping a connectome vector into a 2-D array (format == 1).
SEED = 123

def get_connectome(sub_id,conn_path,format):
    """Load one subject's connectome and return it in the requested layout.

    Parameters
    ----------
    sub_id
        Subject identifier, substituted into ``conn_path`` with ``str.format``.
    conn_path : str
        Path template containing one ``{}`` placeholder; the resulting file is
        read with ``np.load``. The 64x64 mask used below assumes a 64x64 matrix.
    format : int
        0 -- 1-D vector of the 2080 lower-triangle entries (diagonal included).
        1 -- the same entries, shuffled by a fixed SEED-derived permutation and
             reshaped to (40, 52).
        2 -- the full matrix as a torch tensor with a leading singleton dimension.

    Returns
    -------
    numpy.ndarray for formats 0 and 1, torch.Tensor for format 2.

    Raises
    ------
    ValueError
        If ``format`` is not 0, 1 or 2.
    """
    conn = np.load(conn_path.format(sub_id))
    mask = np.tri(64,dtype=bool)

    if format == 0:
        return conn[mask]
    elif format == 1:
        # A private RandomState seeded with SEED yields the same permutation as
        # np.random.seed(SEED) + np.random.permutation, but does not reset
        # NumPy's global RNG every time a sample is fetched.
        rng = np.random.RandomState(SEED)
        return conn[mask][rng.permutation(2080)].reshape(40,52)
    elif format == 2:
        # The tensor was previously built but never returned (callers got None).
        return torch.unsqueeze(torch.from_numpy(conn),0)
    else:
        raise ValueError('Connectome format (int encoded) must be in [0,1,2].')

def get_concat(conf,sub_id,conn_path,format):
    """Load one subject's connectome and append the confound values to it.

    Parameters
    ----------
    conf : array-like
        1-D confound values for the subject. Format 1 assumes exactly 58 of
        them, so that the concatenated vector has 2080 + 58 entries.
    sub_id
        Subject identifier, substituted into ``conn_path`` with ``str.format``.
    conn_path : str
        Path template containing one ``{}`` placeholder; the resulting file is
        read with ``np.load``. The 64x64 mask used below assumes a 64x64 matrix.
    format : int
        0 -- 1-D vector: the 2080 lower-triangle connectome entries followed by
             the confounds.
        1 -- the same vector, shuffled by a fixed SEED-derived permutation,
             zero-padded by 2 on each side (2142 values) and reshaped to (42, 51).

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    ValueError
        If ``format`` is not 0 or 1.
    """
    conn = np.load(conn_path.format(sub_id))
    mask = np.tri(64,dtype=bool)
    concat = np.concatenate([conn[mask],conf])

    if format == 0:
        # Return the concatenated vector; returning conn[mask] here silently
        # dropped the confounds.
        return concat
    elif format == 1:
        # A private RandomState seeded with SEED yields the same permutation as
        # np.random.seed(SEED) + np.random.permutation, but does not reset
        # NumPy's global RNG every time a sample is fetched.
        rng = np.random.RandomState(SEED)
        return np.pad(concat[rng.permutation(2080+58)],2).reshape(42,51)
    else:
        raise ValueError('Concatenated conn + conf format (int encoded) must be in [0,1].')

class caseControlDataset(Dataset):
    """Case/control dataset serving connectomes, confounds, or both.

    Parameters
    ----------
    case : str
        Name of the label column in the phenotype file; also the key of the
        label dict returned with every sample.
    pheno_path : str
        CSV file read with its first column as the subject index.
    id_path : str, optional
        Directory containing ``<case>.csv`` (needed for ``strategy='balanced'``).
        That file's index selects the subjects; its ``fold_<k>`` columns are
        used by ``split_data`` (0 selects train rows, 1 selects test rows).
    conn_path : str, optional
        Directory containing ``connectome_<id>_cambridge64.npy`` files (needed
        for ``type`` 'conn' and 'concat').
    type : {'concat', 'conn', 'conf'}
        What ``__getitem__`` returns as input: connectome + confounds,
        connectome only, or confounds only.
    strategy : {'balanced', 'stratified'}
        'balanced' takes subjects from the ID file; 'stratified' takes them
        from ``strat_mask(pheno, case, control)``.
    format : int
        Layout code forwarded to ``get_connectome`` / ``get_concat``.
    """
    def __init__(self,case,pheno_path,id_path=None,conn_path=None,type='concat',strategy='balanced',format=0):
        assert type in ['concat','conn','conf']
        assert strategy in ['balanced','stratified']
        self.name = case
        self.type = type
        self.strategy = strategy
        self.format = format
        # Always define the attribute so its absence is visible on inspection.
        self.conn_path = os.path.join(conn_path,'connectome_{}_cambridge64.npy') if conn_path else None
        pheno = pd.read_csv(pheno_path,index_col=0)

        # Select subjects
        if self.strategy == 'balanced':
            self.ids = pd.read_csv(os.path.join(id_path,f"{case}.csv"),index_col=0)
            self.idx = self.ids.index
        elif self.strategy == 'stratified':
            control = 'CON_IPC' if case in ['SZ','BIP','ASD'] else 'non_carriers'
            subject_mask = strat_mask(pheno,case,control)
            self.idx = pheno[subject_mask].index

        # Get confounds if needed
        if self.type != 'conn':
            confounds = ['AGE','SEX','SITE','mean_conn', 'FD_scrubbed']
            # The second positional argument of get_dummies is `prefix`, not
            # `columns`; name the columns to one-hot encode explicitly.
            p = pd.get_dummies(pheno[confounds],columns=['SEX','SITE'])
            cols = ['AGE','mean_conn', 'FD_scrubbed'] + [c for c in p.columns if 'SEX' in c ] + [c for c in p.columns if 'SITE' in c ]
            # Index by self.idx (not isin) so row i of X_conf belongs to
            # self.idx[i], matching the ordering of self.Y below.
            self.X_conf = p.loc[self.idx,cols]

        # Get labels
        self.Y = pheno.loc[self.idx,case].values.astype(int)
        # No explicit `del` of the temporary frames: they are released when
        # __init__ returns, and `del p` failed for type='conn', where `p` is
        # never created.

    def __len__(self):
        """Number of selected subjects."""
        return len(self.Y)

    def __getitem__(self,idx):
        """Return ``(input, {case: label})`` for the subject at position ``idx``."""
        if self.type == 'conn':
            conn = get_connectome(self.idx[idx], self.conn_path,self.format)
            return conn, {self.name:self.Y[idx]}
        elif self.type == 'conf':
            if self.format != 0:
                raise Warning('Confound format can only be 0 (vector).')
            return self.X_conf.iloc[idx].values, {self.name:self.Y[idx]}
        elif self.type == 'concat':
            concat = get_concat(self.X_conf.iloc[idx],self.idx[idx], self.conn_path,self.format)
            return concat, {self.name:self.Y[idx]}

    def split_data(self,random=True,fold=0,splits=(0.8,0.2),seed=None):
        """Return positional ``(train_idx, test_idx)`` indices into the dataset.

        With ``random=False`` the ``fold_<fold>`` column of the ID file decides
        membership (only available for ``strategy='balanced'``; raises
        ValueError otherwise). With ``random=True`` a label-stratified split is
        drawn with ``train_test_split`` using ``test_size=splits[1]`` and
        ``random_state=seed``.
        """
        if not random:
            if self.strategy != 'balanced':
                raise ValueError("Balanced CV folds only available for balanced dataset (set strategy to 'balanced').")
            rr = np.array(range(len(self.idx)))
            train_idx = rr[self.ids[f"fold_{fold}"] == 0]
            test_idx = rr[self.ids[f"fold_{fold}"] == 1]
        else:
            train_idx, test_idx, _, _ = train_test_split(range(len(self.idx)),
                                                    self.Y,
                                                    stratify=self.Y,
                                                    test_size=splits[1],
                                                    random_state=seed)
        return train_idx, test_idx

## Load data

In [22]:
# Machine-specific input locations: phenotype table, CV-fold ID files, connectome directory.
# NOTE(review): absolute paths; edit these before running on another machine.
p_pheno = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/pheno_01-12-21.csv'
p_ids = '/home/harveyaa/Documents/masters/neuropsych_mtl/datasets/cv_folds/hybrid'
p_conn = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/connectomes/'

# Cases included in this run (one dataset / task each).
# Currently switched off: 'BIP', 'ASD', 'DEL16p11_2', 'DUP16p11_2',
# 'DUP22q11_2', 'DEL1q21_1', 'DUP1q21_1'.
cases = [
    'SZ',
    'DEL22q11_2',
]

# Multi-task learning (MTL)

In [23]:
# Build one dataset per selected case, in the order given by `cases`.
# (Alternative kept for reference: balancedCaseControlDataset(case,p_ids,p_conn,format=0))
print('Creating datasets...')
data = []
for case in cases:
    print(case)
    data += [concatDataset(case, p_pheno, p_conn, format=0)]
print('Done!\n')

Creating datasets...
SZ


  if self.run_code(code, result):


DEL22q11_2
Done!



  if self.run_code(code, result):


In [24]:
# Fetch the first sample of the first dataset and show the shape of its input.
X, y_dict = data[0][0]
X.shape

(2138,)

In [25]:
# BALANCED TEST SETS

#batch_size=1
#head=3
#encoder=3
#fold=4
#
#loss_fns = {}
#trainloaders = {}
#testloaders = {}
#decoders = {}
#for d, case in zip(data,cases):
#    train_idx, test_idx = d.split_data(fold)
#    train_d = Subset(d,train_idx)
#    test_d = Subset(d,test_idx)
#    trainloaders[case] = DataLoader(train_d, batch_size=batch_size, shuffle=True)
#    testloaders[case] = DataLoader(test_d, batch_size=batch_size, shuffle=True)
#    loss_fns[case] = nn.CrossEntropyLoss()
#    decoders[case] = eval(f'head{head}().double()')

In [26]:
# RANDOM TEST SETS

# `head` / `encoder` pick the numbered classes (head<N>, encoder<N>) to instantiate.
batch_size, head, encoder = 16, 5, 5

# Per-case loaders, loss functions and decoder heads, all keyed by case name.
loss_fns, trainloaders, testloaders, decoders = {}, {}, {}, {}
for d, case in zip(data, cases):
    # Split first, then wrap each part in a shuffling loader.
    train_d, test_d = split_data(d)
    trainloaders[case] = DataLoader(train_d, shuffle=True, batch_size=batch_size)
    testloaders[case] = DataLoader(test_d, shuffle=True, batch_size=batch_size)
    loss_fns[case] = nn.CrossEntropyLoss()
    # Look up the head class by number and cast its parameters to float64.
    decoders[case] = eval(f'head{head}().double()')

In [27]:
# Assemble the model: one encoder (class chosen by `encoder`, cast to float64)
# plus the per-case decoders and loss functions built above.
model = HPSModel(eval(f'encoder{encoder}().double()'), decoders, loss_fns)

Initialized HPSModel using: cpu.



In [28]:
# Directory handed to the Trainer as `log_dir`.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
log_dir = '/home/harveyaa/Documents/masters/neuropsych_mtl/tmp'
print(log_dir)

/home/harveyaa/Documents/masters/neuropsych_mtl/tmp


In [29]:
# Training hyperparameters.
num_epochs, lr = 100, 1e-3

# Adam over all model parameters; no learning-rate scheduler is attached.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Scheduler variants tried earlier (disabled):
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.05)
#trainer = Trainer(optimizer,lr_scheduler=scheduler,num_epochs=num_epochs,log_dir=log_dir)

trainer = Trainer(optimizer, log_dir=log_dir, num_epochs=num_epochs)

In [30]:
# Train model
# Runs the trainer on the per-case train loaders, passing the test loaders as well.
trainer.fit(model,trainloaders,testloaders)

Epoch 0: 100%|██████████| 34/34 [00:03<00:00, 10.16it/s]
Epoch 1: 100%|██████████| 34/34 [00:03<00:00, 10.58it/s]
Epoch 2: 100%|██████████| 34/34 [00:03<00:00, 10.43it/s]
Epoch 3: 100%|██████████| 34/34 [00:03<00:00,  8.95it/s]
Epoch 4: 100%|██████████| 34/34 [00:04<00:00,  7.93it/s]
Epoch 5: 100%|██████████| 34/34 [00:04<00:00,  7.54it/s]
Epoch 6: 100%|██████████| 34/34 [00:04<00:00,  8.28it/s]
Epoch 7: 100%|██████████| 34/34 [00:04<00:00,  8.16it/s]
Epoch 8: 100%|██████████| 34/34 [00:04<00:00,  7.65it/s]
Epoch 9: 100%|██████████| 34/34 [00:04<00:00,  8.44it/s]
Epoch 10: 100%|██████████| 34/34 [00:04<00:00,  8.11it/s]
Epoch 11: 100%|██████████| 34/34 [00:04<00:00,  7.87it/s]
Epoch 12: 100%|██████████| 34/34 [00:04<00:00,  7.90it/s]
Epoch 13: 100%|██████████| 34/34 [00:04<00:00,  8.01it/s]
Epoch 14: 100%|██████████| 34/34 [00:04<00:00,  7.52it/s]
Epoch 15: 100%|██████████| 34/34 [00:04<00:00,  8.06it/s]
Epoch 16: 100%|██████████| 34/34 [00:04<00:00,  7.95it/s]
Epoch 17: 100%|█████████

In [11]:
# Numbers noted from earlier runs (presumably test accuracies; verify):
#   BALANCED: SZ 51.59 | BIP 50.0   | ASD 47.3
#   RANDOM:   SZ 58.59 | BIP 71.875 | ASD 49.74

# Score the trained model on the test loaders and print accuracy / loss per case.
metrics = model.score(testloaders)
for key in metrics.keys():
    print('', key, sep='\n')
    print('Accuracy: ', metrics[key]['accuracy'])
    print('Loss: ', metrics[key]['loss'])
print()


SZ
Accuracy:  53.98230088495575
Loss:  0.05245458040815248



In [23]:
# Sanity check: padding a 2080+58 = 2138-long vector by 2 on each side gives
# 2142 values, i.e. 42 * 51, the shape get_concat reshapes to for format 1.
np.pad(np.random.randn((2080+58)),2).shape

(2142,)