In [1]:
import os
import tempfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

from miniMTL.datasets import *
#from miniMTL.models import *
from miniMTL.util import *
from miniMTL.training import *
from miniMTL.hps import *

In [2]:
class encoder3(nn.Module):
    def __init__(self,dim=58,width=10):
        super().__init__()
        # in_channels, out_channels
        self.fc1 = nn.Linear(dim, width)
        #self.batch1 = nn.BatchNorm1d(width)
        self.fc2 = nn.Linear(width, width)
        #self.batch2 = nn.BatchNorm1d(width)

        self.dropout = nn.Dropout()
    
    def forward(self,x):
        x = self.dropout(F.relu(self.fc1(x)))
        #x = self.batch1(x)
        x = self.dropout(F.relu(self.fc2(x)))
        #x = self.batch2(x)
        return x


class head3(nn.Module):
    def __init__(self,width=10):
        super().__init__()
        self.fc3 = nn.Linear(width,width)
        #self.batch3 = nn.BatchNorm1d(width)
        self.fc4 = nn.Linear(width,2)

        self.dropout = nn.Dropout()
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self,x):
        x = self.dropout(F.relu(self.fc3(x)))
        #x = self.batch3(x)
        x = self.dropout(F.relu(self.fc4(x)))
        x = self.softmax(x)
        return x

## Load data

In [3]:
p_pheno = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/pheno_01-12-21.csv'
p_ids = '/home/harveyaa/Documents/masters/MTL/conf_balancing/hybrid'
p_conn = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/connectomes/'

cases = [#'SZ',
        #'BIP',
        #'ASD',
        'DEL22q11_2',
        #'DEL16p11_2',
        #'DUP16p11_2',
        #'DUP22q11_2',
        #'DEL1q21_1',
        #'DUP1q21_1'
        ]

In [4]:
# Create datasets
print('Creating datasets...')
data = []
for case in cases:
    print(case)
    data.append(balancedCaseControlDataset(case,p_ids,p_conn,format=0))
print('Done!\n')

Creating datasets...
DEL22q11_2
Done!



In [5]:
# BALANCED TEST SETS

#batch_size=16
#head=3
#encoder=3
#fold=0
#
#loss_fns = {}
#trainloaders = {}
#testloaders = {}
#decoders = {}
#for d, case in zip(data,cases):
#    train_idx, test_idx = d.split_data(fold)
#    train_d = Subset(d,train_idx)
#    test_d = Subset(d,test_idx)
#    trainloaders[case] = DataLoader(train_d, batch_size=batch_size, shuffle=True)
#    testloaders[case] = DataLoader(test_d, batch_size=batch_size, shuffle=True)
#    loss_fns[case] = nn.CrossEntropyLoss()
#    #decoders[case] = eval(f'head{head}().double()')
#    decoders[case] = head3(width=100).double()

In [24]:
# RANDOM TEST SETS

batch_size=16
head=3
encoder=3
fold=0

# Split data & create loaders & loss fns
loss_fns = {}
trainloaders = {}
testloaders = {}
decoders = {}
for d, case in zip(data,cases):
    train_d, test_d = split_data(d)
    trainloaders[case] = DataLoader(train_d, batch_size=batch_size, shuffle=True)
    testloaders[case] = DataLoader(test_d, batch_size=batch_size, shuffle=True)
    loss_fns[case] = nn.CrossEntropyLoss()
    #decoders[case] = eval(f'head{head}().double()')
    decoders[case] = head3(width=64).double()

In [25]:
# Create model
model = HPSModel(#eval(f'encoder{encoder}().double()'),
                encoder3(dim=2080,width=64).double(),
                decoders,
                loss_fns)

Initialized HPSModel using: cpu.



In [26]:
x,y = next(iter(trainloaders['DEL22q11_2']))
x

tensor([[ 0.1694,  0.7383,  0.2318,  ..., -0.0056,  0.1872,  0.1571],
        [ 0.3065,  0.7976,  0.3194,  ...,  0.7127,  0.6301,  0.2740],
        [ 0.2905,  0.9014,  0.1903,  ..., -0.1729,  0.5195,  0.2676],
        ...,
        [ 0.2226,  0.5564,  0.2904,  ...,  0.8331,  0.5270,  0.1661],
        [ 0.3304,  0.9278,  0.3499,  ...,  0.5418,  1.0939,  0.3128],
        [ 0.3647,  0.8371,  0.4023,  ...,  0.4919,  0.5511,  0.2485]],
       dtype=torch.float64)

In [27]:
x_0 = model.encoder(x)
model.decoders['DEL22q11_2'](x_0)

tensor([[0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5301, 0.4699],
        [0.5458, 0.4542],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5284, 0.4716],
        [0.4694, 0.5306],
        [0.5291, 0.4709],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5190, 0.4810],
        [0.5670, 0.4330],
        [0.5000, 0.5000]], dtype=torch.float64, grad_fn=<SoftmaxBackward>)

In [28]:
model(x,['DEL22q11_2'])

{'DEL22q11_2': tensor([[0.5055, 0.4945],
         [0.5024, 0.4976],
         [0.5000, 0.5000],
         [0.5614, 0.4386],
         [0.5000, 0.5000],
         [0.5000, 0.5000],
         [0.5377, 0.4623],
         [0.5000, 0.5000],
         [0.5294, 0.4706],
         [0.5006, 0.4994],
         [0.5000, 0.5000],
         [0.5121, 0.4879],
         [0.5000, 0.5000],
         [0.5256, 0.4744],
         [0.5000, 0.5000],
         [0.5743, 0.4257]], dtype=torch.float64, grad_fn=<SoftmaxBackward>)}

In [29]:
log_dir = tempfile.mkdtemp()
print(log_dir)

/tmp/tmpdp65c9ai


In [30]:
num_epochs=50
lr = 0.001

# Create optimizer & trainer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
trainer = Trainer(optimizer,num_epochs=num_epochs,log_dir=log_dir)

In [31]:
# Train model
trainer.fit(model,trainloaders,testloaders)

Epoch 0: 100%|██████████| 5/5 [00:00<00:00, 60.58it/s]
Epoch 1: 100%|██████████| 5/5 [00:00<00:00, 68.76it/s]
Epoch 2: 100%|██████████| 5/5 [00:00<00:00, 68.38it/s]
Epoch 3: 100%|██████████| 5/5 [00:00<00:00, 72.72it/s]
Epoch 4: 100%|██████████| 5/5 [00:00<00:00, 66.74it/s]
Epoch 5: 100%|██████████| 5/5 [00:00<00:00, 67.98it/s]
Epoch 6: 100%|██████████| 5/5 [00:00<00:00, 71.26it/s]
Epoch 7: 100%|██████████| 5/5 [00:00<00:00, 69.90it/s]
Epoch 8: 100%|██████████| 5/5 [00:00<00:00, 68.36it/s]
Epoch 9: 100%|██████████| 5/5 [00:00<00:00, 67.03it/s]
Epoch 10: 100%|██████████| 5/5 [00:00<00:00, 69.50it/s]
Epoch 11: 100%|██████████| 5/5 [00:00<00:00, 70.22it/s]
Epoch 12: 100%|██████████| 5/5 [00:00<00:00, 74.46it/s]
Epoch 13: 100%|██████████| 5/5 [00:00<00:00, 67.55it/s]
Epoch 14: 100%|██████████| 5/5 [00:00<00:00, 66.27it/s]
Epoch 15: 100%|██████████| 5/5 [00:00<00:00, 68.56it/s]
Epoch 16: 100%|██████████| 5/5 [00:00<00:00, 69.07it/s]
Epoch 17: 100%|██████████| 5/5 [00:00<00:00, 69.38it/s]
Ep

In [32]:
# BALANCED
# SZ 51.59
# BIP 50.0
# ASD 47.3

# RANDOM
# SZ 58.59
# BIP 71.875
# ASD 49.74

# Evaluate at end
metrics = model.score(testloaders)
for key in metrics.keys():
    print()
    print(key)
    print('Accuracy: ', metrics[key]['accuracy'])
    print('Loss: ', metrics[key]['loss'])
print()


DEL22q11_2
Accuracy:  50.0
Loss:  0.07701635339554948



In [33]:
import shutil

shutil.rmtree(log_dir)