In [1]:
import os
import torch
from torch import nn
from torch.utils.data import Subset
from torch.utils.data import DataLoader

from miniMTL.datasets import caseControlDataset
from miniMTL.models import head0, encoder100
from miniMTL.training import Trainer

from miniMTL.hps import HPSModel

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [2]:
# Input/output locations. Each root can be overridden with an environment
# variable so the notebook runs outside the original author's machine; the
# defaults reproduce the original hardcoded paths exactly.
DATA_DIR = os.environ.get('MTL_DATA_DIR', '/Users/harveyaa/Documents/masters/data/')
REPO_DIR = os.environ.get('MTL_REPO_DIR', '/Users/harveyaa/Documents/masters/neuropsych_mtl/')

p_pheno = os.path.join(DATA_DIR, 'pheno_26-01-22.csv')
# The trailing '' keeps the trailing separator the original id path had,
# in case caseControlDataset concatenates filenames onto it.
p_ids = os.path.join(REPO_DIR, 'datasets', 'cv_folds', 'intrasite', '')
p_conn = os.path.join(DATA_DIR, 'connectomes')

# Passed to Trainer(log_dir=...); '' matches the original run.
p_out = os.environ.get('MTL_OUT_DIR', '')

In [3]:
# Create datasets: one case-vs-control dataset per phenotype listed in `cases`.
# Full phenotype list used in earlier runs (swap in to train on all tasks):
#   'ASD','BIP','SZ','DEL22q11_2','DUP22q11_2','DEL1q21_1','DUP1q21_1','DEL16p11_2','DUP16p11_2'
cases = ['ASD', 'BIP', 'SZ']

# Keyword arguments shared by every dataset; only the case name varies.
dataset_kwargs = dict(id_path=p_ids, conn_path=p_conn,
                      type='conn', strategy='balanced', format=2)

print('Creating datasets...')
data = []
for case in cases:
    print(case)
    data.append(caseControlDataset(case, p_pheno, **dataset_kwargs))
print('Done!\n')

Creating datasets...
ASD


  pheno = pd.read_csv(pheno_path,index_col=0)


BIP


  pheno = pd.read_csv(pheno_path,index_col=0)


SZ
Done!



  pheno = pd.read_csv(pheno_path,index_col=0)


In [5]:
# Split data & create loaders & loss fns, then assemble the multi-task model.

# Seed torch's global RNG so model initialisation and training-batch
# shuffling are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

bs = 16

loss_fns = {}
trainloaders = {}
testloaders = {}
decoders = {}
for d, case in zip(data, cases):
    # Deterministic split; presumably fold 4 of the pre-computed CV folds
    # under p_ids is the held-out test fold — confirm in caseControlDataset.
    train_idx, test_idx = d.split_data(random=False, fold=4)
    trainloaders[case] = DataLoader(Subset(d, train_idx), batch_size=bs, shuffle=True)
    # Evaluation does not depend on order, so the test set is not shuffled.
    testloaders[case] = DataLoader(Subset(d, test_idx), batch_size=bs, shuffle=False)
    loss_fns[case] = nn.CrossEntropyLoss()
    # One head per case. Called directly: the original eval(f'head{0}()...')
    # always resolved to head0 because the interpolated value was a literal.
    decoders[case] = head0().double()

# Shared encoder plus the per-case decoders and losses, all keyed by case name.
# .double() on both encoder and heads keeps their dtypes consistent.
hps = HPSModel(encoder100().double(),
               decoders,
               loss_fns)

Initialized HPSModel using: cpu.



In [6]:
# Create optimizer & trainer: Adam over all of hps's parameters.
learning_rate = 1e-3
optim_hps = torch.optim.Adam(params=hps.parameters(), lr=learning_rate)
trainer_hps = Trainer(optim_hps, log_dir=p_out)

In [7]:
# Train model. testloaders are passed through to Trainer.fit (presumably for
# per-epoch evaluation — confirm in miniMTL.training).
num_epochs = 5
trainer_hps.fit(hps, trainloaders, testloaders, num_epochs=num_epochs)

Epoch 0:   0%|          | 0/88 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 88/88 [02:23<00:00,  1.63s/it]
Epoch 1:  99%|█████████▉| 87/88 [02:20<00:01,  1.62s/it]


KeyboardInterrupt: 

In [8]:
# Evaluate at end: score every task's test loader, then print accuracy and
# loss per task, each preceded by a blank line.
metrics_hps = hps.score(testloaders)
for task_name, task_metrics in metrics_hps.items():
    print()
    print(task_name)
    print('Accuracy: ', task_metrics['accuracy'])
    print('Loss: ', task_metrics['loss'])
print()


ASD
Accuracy:  46.27659574468085
Loss:  0.04457285473529006

BIP
Accuracy:  67.74193548387096
Loss:  0.03442825679932411

SZ
Accuracy:  67.71653543307087
Loss:  0.03822824655208791



In [8]:
# Display the raw per-task metrics dict returned by hps.score(testloaders).
# NOTE(review): the saved output below lists nine tasks with values that differ
# from the printed evaluation above, while `cases` selects only three. It was
# produced by an earlier run (note the duplicated In [8] execution count).
# Restart & Run All to refresh it.
metrics_hps

{'ASD': {'accuracy': 57.97872340425532, 'loss': 0.1369718253028779},
 'BIP': {'accuracy': 77.41935483870968, 'loss': 0.06530408812576415},
 'SZ': {'accuracy': 75.59055118110236, 'loss': 0.09313200944603715},
 'DEL22q11_2': {'accuracy': 76.47058823529412, 'loss': 0.06881384382984729},
 'DUP22q11_2': {'accuracy': 50.0, 'loss': 0.22229325533492122},
 'DEL1q21_1': {'accuracy': 30.0, 'loss': 0.4331672511722159},
 'DUP1q21_1': {'accuracy': 42.857142857142854, 'loss': 0.44524299089875496},
 'DEL16p11_2': {'accuracy': 41.66666666666667, 'loss': 0.3415105960658535},
 'DUP16p11_2': {'accuracy': 50.0, 'loss': 0.3076967009626213}}