In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

from datasets import *

## TODO
- Create torch datasets
- Create dataloaders
- Create simple CNN model (from paper)
    - Random reformat
    - Figure out dimensions of conv layers
    - Batchnorm (after conv layer only?)
    - 3 dense layers of w/ 64 units
    - Relu
    - Softmax
- Train model

## Load data

In [9]:
p_pheno = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/pheno_01-12-21.csv'
p_conn = '/home/harveyaa/Documents/fMRI/data/ukbb_9cohorts/connectomes_01-12-21.csv'

#data = ukbbSexDataset(p_pheno,p_conn)
data = caseControlDataset('SZ',p_pheno,p_conn)

  """


In [10]:
train, test = split_data(data)

trainloader = DataLoader(train, batch_size=16, shuffle=True)
testloader = DataLoader(test, batch_size=16, shuffle=True)

In [11]:
class ukbbSex(nn.Module):
    def __init__(self):
        super().__init__()
        # in_channels, out_channels
        self.conv = nn.Conv2d(1, 256, (40,1))
        self.batch0 = nn.BatchNorm2d(256)

        self.fc1 = nn.Linear(256*52, 64)
        self.batch1 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 64)
        self.batch2 = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64,2)

        self.dropout = nn.Dropout()
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self,x):
        x = self.conv(torch.unsqueeze(x,dim=1))
        x = self.batch0(x)
        x = x.view(x.size()[0],-1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.batch1(x)
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.batch2(x)
        x = self.softmax(self.fc3(x))
        return x

In [12]:
model = ukbbSex().double()

loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

device = 'cpu'

In [13]:

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [14]:
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [15]:
epochs = 20
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(trainloader, model, loss_fn, optimizer)
    test(testloader, model)
print("Done!")

Epoch 1
-------------------------------
loss: 0.614644  [    0/  510]
Test Error: 
 Accuracy: 54.7%, Avg loss: 0.042169 

Epoch 2
-------------------------------
loss: 0.652532  [    0/  510]
Test Error: 
 Accuracy: 62.5%, Avg loss: 0.040722 

Epoch 3
-------------------------------
loss: 0.679898  [    0/  510]
Test Error: 
 Accuracy: 65.6%, Avg loss: 0.039045 

Epoch 4
-------------------------------
loss: 0.505100  [    0/  510]
Test Error: 
 Accuracy: 65.6%, Avg loss: 0.038314 

Epoch 5
-------------------------------
loss: 0.493570  [    0/  510]
Test Error: 
 Accuracy: 63.3%, Avg loss: 0.040883 

Epoch 6
-------------------------------
loss: 0.611237  [    0/  510]
Test Error: 
 Accuracy: 69.5%, Avg loss: 0.036259 

Epoch 7
-------------------------------
loss: 0.415378  [    0/  510]
Test Error: 
 Accuracy: 67.2%, Avg loss: 0.037392 

Epoch 8
-------------------------------
loss: 0.404512  [    0/  510]
Test Error: 
 Accuracy: 64.8%, Avg loss: 0.039477 

Epoch 9
----------------