In [1]:
import os
import tempfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

from miniMTL.datasets import *
#from miniMTL.models import *
from miniMTL.util import *
from miniMTL.training import *
from miniMTL.hps import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class encoder3(nn.Module):
    """ Simple MLP for connectome 2080 vec."""
    def __init__(self):
        super().__init__()
        # in_channels, out_channels
        self.fc1 = nn.Linear(2080,64)
        self.batch1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 16)
        self.batch2 = nn.BatchNorm1d(16)

        self.dropout = nn.Dropout()
        self.leaky = nn.LeakyReLU()
    
    def forward(self,x):
        #x = self.dropout(self.leaky(self.fc1(x)))
        x = self.fc1(x)
        #x = self.batch1(x)
        #x = self.dropout(self.leaky(self.fc2(x)))
        #x = self.batch2(x)
        return x


class head3(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(64,2)
        #self.batch3 = nn.BatchNorm1d(8)
        #self.fc4 = nn.Linear(8,2)

        self.dropout = nn.Dropout()
        self.leaky = nn.LeakyReLU()
    
    def forward(self,x):
        #x = self.dropout(self.leaky(self.fc3(x)))
        x = self.fc3(x)
        #x = self.batch3(x)
        #x = self.dropout(self.leaky(self.fc4(x)))
        return x

## Load data

In [3]:
p_pheno = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/pheno_01-12-21.csv'
p_ids = '/home/harveyaa/Documents/masters/neuropsych_mtl/datasets/cv_folds/hybrid'
p_conn = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/connectomes/'

cases = [#'SZ',
        #'BIP',
        #'ASD',
        'DEL22q11_2',
        #'DEL16p11_2',
        #'DUP16p11_2',
        #'DUP22q11_2',
        #'DEL1q21_1',
        #'DUP1q21_1'
        ]

# Investigate 22q

In [68]:
conn = pd.read_csv('/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/connectomes_01-12-21.csv',index_col=0)

In [73]:
df_22q = pd.read_csv(os.path.join(p_ids,'DEL22q11_2.csv'),index_col=0)

In [75]:
conn = conn[conn.index.isin(df_22q.index)]

In [99]:
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

clf = SVC(C=100)
fold = 4

train_idx = df_22q[df_22q[f'fold_{fold}'] == 0].index
test_idx = df_22q[df_22q[f'fold_{fold}'] == 1].index

X_train = conn[conn.index.isin(train_idx)].values
X_test = conn[conn.index.isin(test_idx)].values
y_train = df_22q[df_22q.index.isin(train_idx)]['DEL22q11_2'].values.reshape(-1,1)
y_test = df_22q[df_22q.index.isin(test_idx)]['DEL22q11_2'].values.reshape(-1,1)

print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)

clf.fit(X_train,y_train)
pred = clf.predict(X_test)
accuracy_score(y_test,pred)

(58, 2080)
(28, 2080)
(58, 1)
(28, 1)


  return f(*args, **kwargs)


0.6071428571428571

In [100]:
pheno = pd.read_csv(p_pheno,index_col=0)

  interactivity=interactivity, compiler=compiler, result=result)


In [102]:
conf = ['AGE',
            'SEX',
            'SITE',
            'mean_conn',
            'FD_scrubbed']
case = 'DEL22q11_2'

# PLOT TEST SET
#fig, ax = plt.subplots(len(conf),5,figsize=(15,12))
#for i,c in enumerate(conf):
#        for fold in range(5):
#                #ids = pd.read_csv(os.path.join(temp_dir.name,f"{case}_test_set_{fold}.txt"),header=None)
#                ids = df_22q[df_22q[f'fold_{fold}']==1].index
#                
#                sns.histplot(x=c,data=pheno[pheno.index.isin(ids)],hue=case,bins=25,ax=ax[i,fold])
#                if i == 0:
#                        ax[i,fold].set_title(f'fold {fold}')
#                if fold == 0:
#                        ax[i,fold].set_xlabel('')
#                        ax[i,fold].set_ylabel(c)
#                else:
#                        ax[i,fold].set_xlabel('')
#                        ax[i,fold].set_ylabel('')
#                        ax[i,fold].set_yticklabels([])
#plt.tight_layout()
#plt.subplots_adjust(wspace=0.1,hspace=0.2)
#plt.savefig(os.path.join(args.p_out,f"{case}_test.png"),dpi=300

# PLOT TRAIN SET
#fig, ax = plt.subplots(len(conf),5,figsize=(15,12))
#for i,c in enumerate(conf):
#        for fold in range(5):
#                ids_train = ids = df_22q[df_22q[f'fold_{fold}']==0].index
#                
#                sns.histplot(x=c,data=pheno[pheno.index.isin(ids_train)],hue=case,bins=25,ax=ax[i,fold])
#                if i == 0:
#                        ax[i,fold].set_title(f'fold {fold}')
#                if fold == 0:
#                        ax[i,fold].set_xlabel('')
#                        ax[i,fold].set_ylabel(c)
#                else:
#                        ax[i,fold].set_xlabel('')
#                        ax[i,fold].set_ylabel('')
#                        ax[i,fold].set_yticklabels([])
#plt.tight_layout()
#plt.subplots_adjust(wspace=0.1,hspace=0.2)

# MTL

In [4]:
# Create datasets
print('Creating datasets...')
data = []
for case in cases:
    print(case)
    data.append(balancedCaseControlDataset(case,p_ids,p_conn,format=0))
print('Done!\n')

Creating datasets...
DEL22q11_2
Done!



In [5]:
# BALANCED TEST SETS

batch_size=1
head=3
encoder=3
fold=4

loss_fns = {}
trainloaders = {}
testloaders = {}
decoders = {}
for d, case in zip(data,cases):
    train_idx, test_idx = d.split_data(fold)
    train_d = Subset(d,train_idx)
    test_d = Subset(d,test_idx)
    trainloaders[case] = DataLoader(train_d, batch_size=batch_size, shuffle=True)
    testloaders[case] = DataLoader(test_d, batch_size=batch_size, shuffle=True)
    loss_fns[case] = nn.CrossEntropyLoss()
    decoders[case] = eval(f'head{head}().double()')

In [6]:
# RANDOM TEST SETS

#batch_size=4
#head=0
#encoder=0
#
## Split data & create loaders & loss fns
#loss_fns = {}
#trainloaders = {}
#testloaders = {}
#decoders = {}
#for d, case in zip(data,cases):
#    train_d, test_d = split_data(d,seed=888)
#    trainloaders[case] = DataLoader(train_d, batch_size=batch_size, shuffle=True)
#    testloaders[case] = DataLoader(test_d, batch_size=batch_size, shuffle=True)
#    loss_fns[case] = nn.CrossEntropyLoss()
#    decoders[case] = eval(f'head{head}().double()')
#    #decoders[case] = head3().double()

In [7]:
# Create model
model = HPSModel(eval(f'encoder{encoder}().double()'),
                decoders,
                loss_fns)

Initialized HPSModel using: cpu.



In [8]:
log_dir = '/home/harveyaa/Documents/masters/neuropsych_mtl/tmp'
print(log_dir)

/home/harveyaa/Documents/masters/neuropsych_mtl/tmp


In [9]:
num_epochs=100
lr = 0.001

# Create optimizer & trainer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.05)
#trainer = Trainer(optimizer,lr_scheduler=scheduler,num_epochs=num_epochs,log_dir=log_dir)
trainer = Trainer(optimizer,num_epochs=num_epochs,log_dir=log_dir)

In [10]:
# Train model
trainer.fit(model,trainloaders,testloaders)

Epoch 0: 100%|██████████| 58/58 [00:01<00:00, 53.83it/s]
Epoch 1: 100%|██████████| 58/58 [00:01<00:00, 55.90it/s]
Epoch 2: 100%|██████████| 58/58 [00:01<00:00, 55.07it/s]
Epoch 3: 100%|██████████| 58/58 [00:01<00:00, 54.94it/s]
Epoch 4: 100%|██████████| 58/58 [00:01<00:00, 51.07it/s]
Epoch 5: 100%|██████████| 58/58 [00:01<00:00, 54.85it/s]
Epoch 6: 100%|██████████| 58/58 [00:01<00:00, 52.60it/s]
Epoch 7: 100%|██████████| 58/58 [00:01<00:00, 51.51it/s]
Epoch 8: 100%|██████████| 58/58 [00:01<00:00, 52.97it/s]
Epoch 9: 100%|██████████| 58/58 [00:01<00:00, 52.60it/s]
Epoch 10: 100%|██████████| 58/58 [00:01<00:00, 53.59it/s]
Epoch 11: 100%|██████████| 58/58 [00:01<00:00, 56.12it/s]
Epoch 12: 100%|██████████| 58/58 [00:01<00:00, 45.42it/s]
Epoch 13: 100%|██████████| 58/58 [00:01<00:00, 41.06it/s]
Epoch 14: 100%|██████████| 58/58 [00:01<00:00, 42.42it/s]
Epoch 15: 100%|██████████| 58/58 [00:01<00:00, 42.56it/s]
Epoch 16: 100%|██████████| 58/58 [00:01<00:00, 43.40it/s]
Epoch 17: 100%|█████████

In [20]:
# BALANCED
# SZ 51.59
# BIP 50.0
# ASD 47.3

# RANDOM
# SZ 58.59
# BIP 71.875
# ASD 49.74

# Evaluate at end
metrics = model.score(testloaders)
for key in metrics.keys():
    print()
    print(key)
    print('Accuracy: ', metrics[key]['accuracy'])
    print('Loss: ', metrics[key]['loss'])
print()


DEL22q11_2
Accuracy:  42.857142857142854
Loss:  0.10395262985841776

