In [26]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

from datasets import *
from models import *
from util import *

## Load data

In [18]:
p_pheno = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/pheno_01-12-21.csv'
p_conn = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/connectomes_01-12-21.csv'

#data = ukbbSexDataset(p_pheno,p_conn)
data = caseControlDataset('SZ',p_pheno,p_conn)

  """


In [19]:
train_data, test_data = split_data(data)

trainloader = DataLoader(train_data, batch_size=16, shuffle=True)
testloader = DataLoader(test_data, batch_size=16, shuffle=True)

In [20]:
class ukbbSexEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # in_channels, out_channels
        self.conv = nn.Conv2d(1, 256, (40,1))
        self.batch0 = nn.BatchNorm2d(256)

        self.fc1 = nn.Linear(256*52, 64)
        self.batch1 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 64)
        self.batch2 = nn.BatchNorm1d(64)
        #self.fc3 = nn.Linear(64,2)

        self.dropout = nn.Dropout()
        #self.softmax = nn.Softmax(dim=1)
    
    def forward(self,x):
        x = self.conv(torch.unsqueeze(x,dim=1))
        x = self.batch0(x)
        x = x.view(x.size()[0],-1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.batch1(x)
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.batch2(x)
        #x = self.softmax(self.fc3(x))
        return x

class headModule(nn.Module):
    def __init__(self):
        super().__init__()
        # in_channels, out_channels
        #self.conv = nn.Conv2d(1, 256, (40,1))
        #self.batch0 = nn.BatchNorm2d(256)

        #self.fc1 = nn.Linear(256*52, 64)
        #self.batch1 = nn.BatchNorm1d(64)
        #self.fc2 = nn.Linear(64, 64)
        #self.batch2 = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64,2)

        self.dropout = nn.Dropout()
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self,x):
        #x = self.conv(torch.unsqueeze(x,dim=1))
        # = self.batch0(x)
        #x = x.view(x.size()[0],-1)
        #x = self.dropout(F.relu(self.fc1(x)))
        #x = self.batch1(x)
        #x = self.dropout(F.relu(self.fc2(x)))
        #x = self.batch2(x)
        x = self.softmax(self.fc3(x))
        return x

In [21]:
class HPSModel(nn.Module):
    """ Multi-input HPS."""
    def __init__(self,encoder,decoders):
        super().__init__()
        self.encoder = encoder
        self.decoders = decoders
    
    def forward(self,x,task):
        x = self.encoder(x)
        x = self.decoders[task](x)
        return x

In [52]:
def trainMTL(dataloaders, model, loss_fn, optimizer,device='cpu'):
    tasks = list(dataloaders.keys())
    task = tasks[0]
    dataloader = dataloaders[task]

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
            
        # Compute prediction error
        pred = model(X,task)
        loss = loss_fn(pred, y)
            
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# From torch fundamentals course
def testMTL(dataloaders, model,loss_fn,device='cpu'):
    tasks = list(dataloaders.keys())
    task = tasks[0]
    dataloader = dataloaders[task]

    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X,task)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [53]:
model = HPSModel(ukbbSexEncoder().double(),{'SZ':headModule().double()})

In [54]:
trainloaders = {'SZ':trainloader}
testloaders = {'SZ':testloader}

In [55]:
#model = ukbbSex().double()

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

device = 'cpu'

In [57]:
epochs = 20
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    trainMTL(trainloaders, model, loss_fn, optimizer)
    testMTL(testloaders, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 0.651178  [    0/  510]
Test Error: 
 Accuracy: 56.2%, Avg loss: 0.043932 

Epoch 2
-------------------------------
loss: 0.554484  [    0/  510]
Test Error: 
 Accuracy: 71.1%, Avg loss: 0.036600 

Epoch 3
-------------------------------
loss: 0.625243  [    0/  510]
Test Error: 
 Accuracy: 72.7%, Avg loss: 0.035493 

Epoch 4
-------------------------------
loss: 0.425812  [    0/  510]
Test Error: 
 Accuracy: 71.9%, Avg loss: 0.035900 

Epoch 5
-------------------------------
loss: 0.352999  [    0/  510]
Test Error: 
 Accuracy: 71.9%, Avg loss: 0.034558 

Epoch 6
-------------------------------
loss: 0.421231  [    0/  510]
Test Error: 
 Accuracy: 71.9%, Avg loss: 0.035162 

Epoch 7
-------------------------------
loss: 0.363316  [    0/  510]
Test Error: 
 Accuracy: 73.4%, Avg loss: 0.034987 

Epoch 8
-------------------------------
loss: 0.319168  [    0/  510]
Test Error: 
 Accuracy: 71.9%, Avg loss: 0.036523 

Epoch 9
----------------