In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import time
import datetime
import csv
import yaml
import configparser
import copy
import math 
import random
import warnings
warnings.filterwarnings('ignore')

In [3]:
import pandas as pd
import numpy as np
import sklearn
import matplotlib
# mpl.use('Agg')
%matplotlib inline
# %matplotlib notebook
from matplotlib import pyplot as plt 

matplotlib.rcParams['lines.linewidth'] = 1
matplotlib.rcParams['lines.markersize'] = 5

In [14]:
from tqdm.notebook import trange, tqdm
from ptflops import get_model_complexity_info

In [11]:
import abcdfusion
from abcdfusion import get_abcd, metrics, models

In [15]:
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Subset
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

# 1. Data Loading

In [4]:
ROOT_PATH = '../'
try: 
    os.chdir(ROOT_PATH)
    sys.path.insert(0, ROOT_PATH)
    print("Current working directory: {}".format(os.getcwd()))
except Exception:
    print("Directory: {} is not valid".format(ROOT_PATH))
    sys.exit(1)

Current working directory: /home/geshi/ABCDFusion


In [5]:
# load and parse config 
config_file = './configs.yaml'
with open(config_file, 'r') as infile:
    try:
        configs = yaml.safe_load(infile)
    except yaml.YAMLError as exc:
        sys.exit(exc)

In [12]:
auxiliary = configs['Auxiliary']
DATA_PATH = auxiliary['DATA_PATH']

OTHER_DATA = auxiliary['OTHER_DATA'] 
DTI_DATA = auxiliary['DTI_DATA'] 
RS_DATA = auxiliary['RS_DATA']
OUTCOME = auxiliary['OUTCOME']

In [21]:
dti_file = os.path.join(DATA_PATH, DTI_DATA)
rs_file = os.path.join(DATA_PATH, RS_DATA)
other_file = os.path.join(DATA_PATH, OTHER_DATA)
label_file = os.path.join(DATA_PATH, OUTCOME)

# 2. Define Dataloader

In [32]:
valid_size=0.2
batch_size=128
num_splits=5
num_workers=0
num_epochs=10
step_size=10
lr=0.001
seed=42

In [33]:
def create_datasets(dataset, batch_size, num_workers=0, valid_size=0.2, shuffle=True):
    # obtain training indices that will be used for validation
    num_data = len(dataset)
    if shuffle:
        indices = list(range(num_data))
    np.random.shuffle(indices)
    split = int(np.floor(valid_size * num_data))
    valid_idx, train_idx = indices[:split], indices[split:]

    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    
    # load training data in batches
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              sampler=train_sampler,
                              num_workers=num_workers)
    
    # load validation data in batches
    valid_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              sampler=valid_sampler,
                              num_workers=num_workers)
    
    return train_loader, valid_loader

In [34]:
abcd_dataset = get_abcd(dti_file, rs_file, other_file, label_file)
train_loader, valid_loader = create_datasets(abcd_dataset, batch_size, num_workers, valid_size)

In [35]:
dti, rs, other, y = next(iter(valid_loader))
print('dti: ', dti.shape, 'rs fmri: ', rs.shape, 'other data: ', other.shape, 'label: ', y.shape)

dti:  torch.Size([128, 31]) rs fmri:  torch.Size([128, 270]) other data:  torch.Size([128, 7]) label:  torch.Size([128, 1])


# 3. Define Train

In [None]:
def train_epoch(model, dataloader, num_classes, criterion, optimizer, scheduler, device):
    epoch_loss = 0.0
    epoch_acc = 0
    epoch_ce = 0.0
    epoch_iou = 0.0
    epoch_dice = 0.0
    count = 0

    piter = tqdm(dataloader, desc='Batch', unit='batch', position=1, leave=False)
    for inputs, seg_masks in piter:

        inputs = inputs.to(device)
            # transfer label to device
        targets = target.to(device)
        seg_masks = seg_masks.to(device)
        _, targets = torch.max(seg_masks, 1)

        batch_size = inputs.size(0)
        nxt_count = count+batch_size
        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, seg_masks)
        
                loss.backward()
        optimizer.step()

        # statistics
        epoch_loss = loss.item() * batch_size/nxt_count + epoch_loss * count/nxt_count
        epoch_acc = ((preds == targets).sum()/np.prod(preds.size())).item() * batch_size/nxt_count + epoch_acc * count/nxt_count

        count = nxt_count
        piter.set_postfix(accuracy=100. * epoch_acc)

    epoch_acc *= 100.
    scheduler.step()
    train_stats = {
        'train_loss': epoch_loss,
        'train_acc': epoch_acc,
    }
    
    return model, epoch_loss, epoch_acc, train_stats

# 4. DTI Model

In [37]:
x = torch.rand(1, 31)
dti_model = models.BinaryMLP(31, [32, 64, 32], p=0.2)
out = dti_model(x)
print(out.shape)

torch.Size([1, 1])


In [39]:
criterion = nn.BCELoss()
optimizer = optim.Adam(dti_model.parameters(), lr=lr)

In [41]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, _, _, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = dti_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

RuntimeError: all elements of input should be between 0 and 1

# 5. rs fMRI Model