In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

from miniMTL.datasets import *
from miniMTL.models import *
from miniMTL.util import *
from miniMTL.training import *

## Load data

In [2]:
p_pheno = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/pheno_01-12-21.csv'
p_conn = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/connectomes/'

cases = ['SZ',
        'BIP',
        'ASD',
        # 'DEL22q11_2',
        # 'DEL16p11_2',
        # 'DUP16p11_2',
        # 'DUP22q11_2'
        ]
data = []
for case in cases:
    data.append(caseControlDataset(case,p_pheno,p_conn))

  


In [3]:
loss_fns = {}
trainloaders = {}
testloaders = {}
decoders = {}
for d, case in zip(data,cases):
    train_d, test_d = split_data(d)

    trainloaders[case] = DataLoader(train_d, batch_size=16, shuffle=True)
    testloaders[case] = DataLoader(test_d, batch_size=16, shuffle=True)
    loss_fns[case] = nn.CrossEntropyLoss()
    decoders[case] = head().double()

In [4]:
model = HPSModel(encoder().double(),
                decoders,
                loss_fns)

Initialized HPSModel using: cpu.


In [5]:
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [6]:
trainer = Trainer(optimizer,100)

trainer.fit(model,trainloaders,testloaders)

Epoch 0: 100%|██████████| 88/88 [00:09<00:00,  9.10it/s]
Epoch 1: 100%|██████████| 88/88 [00:09<00:00,  9.30it/s]
Epoch 2: 100%|██████████| 88/88 [00:09<00:00,  9.67it/s]
Epoch 3: 100%|██████████| 88/88 [00:10<00:00,  8.11it/s]
Epoch 4: 100%|██████████| 88/88 [00:10<00:00,  8.14it/s]
Epoch 5: 100%|██████████| 88/88 [00:10<00:00,  8.07it/s]
Epoch 6: 100%|██████████| 88/88 [00:10<00:00,  8.10it/s]
Epoch 7: 100%|██████████| 88/88 [00:10<00:00,  8.14it/s]
Epoch 8: 100%|██████████| 88/88 [00:11<00:00,  7.96it/s]
Epoch 9: 100%|██████████| 88/88 [00:10<00:00,  8.17it/s]
Epoch 10: 100%|██████████| 88/88 [00:10<00:00,  8.20it/s]
Epoch 11: 100%|██████████| 88/88 [00:10<00:00,  8.10it/s]
Epoch 12: 100%|██████████| 88/88 [00:10<00:00,  8.15it/s]
Epoch 13: 100%|██████████| 88/88 [00:10<00:00,  8.28it/s]
Epoch 14: 100%|██████████| 88/88 [00:10<00:00,  8.37it/s]
Epoch 15: 100%|██████████| 88/88 [00:10<00:00,  8.35it/s]
Epoch 16: 100%|██████████| 88/88 [00:10<00:00,  8.56it/s]
Epoch 17: 100%|█████████

In [11]:
metrics = model.score(testloaders)
for key in metrics.keys():
    print(key)
    print('Accuracy: ', metrics[key]['accuracy'])
    print('Test loss: ', metrics[key]['test_loss'])
    print()

SZ
Accuracy:  75.0
Test loss:  0.03542333878302963

BIP
Accuracy:  75.0
Test loss:  0.03516487538279056

ASD
Accuracy:  56.08465608465608
Test loss:  0.047184154040325936

