In [21]:
import os
import torch
from torch import nn
from torch.utils.data import Subset
from torch.utils.data import DataLoader

from miniMTL.datasets import caseControlDataset
from miniMTL.models import *
from miniMTL.training import Trainer
from miniMTL.mps import MPSModel

## Load data

In [28]:
# Configuration: RNG seed and filesystem locations.
# Seed torch's global RNG so torch-based randomness (module weight init,
# DataLoader shuffling) repeats across Restart & Run All.
# NOTE(review): if miniMTL also draws from numpy/random, seed those as well.
SEED = 0
torch.manual_seed(SEED)

# Root directories default to the original author's machine. Set the
# MTL_DATA_DIR / MTL_REPO_DIR environment variables to run elsewhere without
# editing the notebook.
DATA_DIR = os.environ.get('MTL_DATA_DIR', '/Users/harveyaa/Documents/masters/data/')
REPO_DIR = os.environ.get('MTL_REPO_DIR', '/Users/harveyaa/Documents/masters/neuropsych_mtl/')

p_pheno = os.path.join(DATA_DIR, 'pheno_26-01-22.csv')
# The trailing '' keeps the final '/' the original path string had, in case
# the dataset class concatenates file names onto id_path directly.
p_ids = os.path.join(REPO_DIR, 'datasets', 'cv_folds', 'intrasite', '')
p_conn = os.path.join(DATA_DIR, 'connectomes')

# Passed to Trainer as log_dir; '' as in the original run.
p_out = ''

In [29]:
# Build one case-vs-control connectome dataset per condition. `data` stays
# aligned index-for-index with `cases`, because the next cell zips them together.
print('Creating datasets...')
cases = ['ASD', 'BIP', 'SZ']
data = []
for case in cases:
    print(case)
    case_dataset = caseControlDataset(
        case,
        p_pheno,
        id_path=p_ids,
        conn_path=p_conn,
        type='conn',
        strategy='balanced',
        format=0,
    )
    data.append(case_dataset)
print('Done!\n')

Creating datasets...
ASD


  pheno = pd.read_csv(pheno_path,index_col=0)


BIP


  pheno = pd.read_csv(pheno_path,index_col=0)


SZ
Done!



  pheno = pd.read_csv(pheno_path,index_col=0)


In [30]:
# Per-task setup: split each dataset into train/test subsets and build that
# task's loaders, loss function, pre-encoder and output head, keyed by case name.
bs = 32  # batch size shared by every task's train and test loader

loss_fns = {}
trainloaders = {}
testloaders = {}
preencoders = {}
decoders = {}
for d, case in zip(data, cases):
    # Presumably selects predefined CV fold 0 (from p_ids) rather than a random
    # split -- confirm in caseControlDataset.split_data.
    train_idx, test_idx = d.split_data(random=False, fold=0)
    train_d = Subset(d, train_idx)
    test_d = Subset(d, test_idx)
    trainloaders[case] = DataLoader(train_d, batch_size=bs, shuffle=True)
    # Evaluation does not need shuffling; a fixed order keeps scoring deterministic.
    testloaders[case] = DataLoader(test_d, batch_size=bs, shuffle=False)
    loss_fns[case] = nn.CrossEntropyLoss()
    # Architectures come from `from miniMTL.models import *`. Call them directly
    # instead of eval()-ing a string that hard-codes the same name.
    preencoders[case] = preencoder3().double()
    decoders[case] = head3().double()

In [31]:
# Assemble the multi-task model from the per-task pre-encoders and heads, a
# single encoder instance (presumably shared across tasks -- see MPSModel) and
# the per-task loss functions. The encoder is called directly rather than via
# eval() on a hard-coded name string.
model = MPSModel(preencoders,
                 encoder33().double(),
                 decoders,
                 loss_fns)

Initialized HPSModel using: cpu.



In [32]:
# Adam over all of the model's parameters; p_out is handed to the Trainer as log_dir.
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
trainer = Trainer(optimizer, log_dir=p_out)

In [33]:
# Fit the model on the per-task train loaders. The test loaders are passed to
# fit() too (presumably for per-epoch evaluation -- see miniMTL.training.Trainer).
NUM_EPOCHS = 50
trainer.fit(model, trainloaders, testloaders, num_epochs=NUM_EPOCHS)

Epoch 0: 100%|██████████| 44/44 [00:03<00:00, 13.50it/s]
Epoch 1: 100%|██████████| 44/44 [00:02<00:00, 18.16it/s]
Epoch 2: 100%|██████████| 44/44 [00:02<00:00, 18.10it/s]
Epoch 3: 100%|██████████| 44/44 [00:02<00:00, 17.89it/s]
Epoch 4: 100%|██████████| 44/44 [00:02<00:00, 17.93it/s]
Epoch 5: 100%|██████████| 44/44 [00:02<00:00, 18.11it/s]
Epoch 6: 100%|██████████| 44/44 [00:02<00:00, 18.08it/s]
Epoch 7: 100%|██████████| 44/44 [00:02<00:00, 18.07it/s]
Epoch 8: 100%|██████████| 44/44 [00:02<00:00, 18.19it/s]
Epoch 9: 100%|██████████| 44/44 [00:02<00:00, 18.07it/s]
Epoch 10: 100%|██████████| 44/44 [00:02<00:00, 18.16it/s]
Epoch 11: 100%|██████████| 44/44 [00:02<00:00, 18.17it/s]
Epoch 12: 100%|██████████| 44/44 [00:02<00:00, 18.20it/s]
Epoch 13: 100%|██████████| 44/44 [00:02<00:00, 18.12it/s]
Epoch 14: 100%|██████████| 44/44 [00:02<00:00, 18.21it/s]
Epoch 15: 100%|██████████| 44/44 [00:02<00:00, 18.11it/s]
Epoch 16: 100%|██████████| 44/44 [00:02<00:00, 18.23it/s]
Epoch 17: 100%|█████████

In [34]:
# Score the trained model on every task's test loader and print each task's
# accuracy and loss, one blank-line-separated section per task.
metrics = model.score(testloaders)
for task, task_metrics in metrics.items():
    print()
    print(task)
    print('Accuracy: ', task_metrics['accuracy'])
    print('Loss: ', task_metrics['loss'])
print()


ASD
Accuracy:  57.14285714285714
Loss:  0.02342739164122072

BIP
Accuracy:  68.75
Loss:  0.019178310227088703

SZ
Accuracy:  58.59375
Loss:  0.026285603406605355

