In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import time
import datetime
import csv
import yaml
import configparser
import copy
import math 
import random
import warnings
warnings.filterwarnings('ignore')

In [3]:
import pandas as pd
import numpy as np
import sklearn
import matplotlib
# mpl.use('Agg')
%matplotlib inline
# %matplotlib notebook
from matplotlib import pyplot as plt 

matplotlib.rcParams['lines.linewidth'] = 1
matplotlib.rcParams['lines.markersize'] = 5

In [4]:
from tqdm.notebook import trange, tqdm
from ptflops import get_model_complexity_info

# 1. Data Loading

In [5]:
ROOT_PATH = '../'
try: 
    os.chdir(ROOT_PATH)
    sys.path.insert(0, ROOT_PATH)
    print("Current working directory: {}".format(os.getcwd()))
except Exception:
    print("Directory: {} is not valid".format(ROOT_PATH))
    sys.exit(1)

Current working directory: /home/geshi/ABCDFusion


In [6]:
# load and parse config 
config_file = './configs.yaml'
with open(config_file, 'r') as infile:
    try:
        configs = yaml.safe_load(infile)
    except yaml.YAMLError as exc:
        sys.exit(exc)

In [7]:
auxiliary = configs['Auxiliary']
DATA_PATH = auxiliary['DATA_PATH']

OTHER_DATA = auxiliary['OTHER_DATA'] 
DTI_DATA = auxiliary['DTI_DATA'] 
RS_DATA = auxiliary['RS_DATA']
OUTCOME = auxiliary['OUTCOME']

In [8]:
dti_file = os.path.join(DATA_PATH, DTI_DATA)
rs_file = os.path.join(DATA_PATH, RS_DATA)
other_file = os.path.join(DATA_PATH, OTHER_DATA)
label_file = os.path.join(DATA_PATH, OUTCOME)

In [44]:
import abcdfusion
from abcdfusion import get_abcd, metrics, models, utils

In [45]:
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Subset
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

# 2. Define Dataloader

In [11]:
valid_size=0.2
batch_size=128
num_splits=5
num_workers=0
num_epochs=10
step_size=10
lr=0.001
seed=42

In [12]:
def create_datasets(dataset, batch_size, num_workers=0, valid_size=0.2, shuffle=True):
    # obtain training indices that will be used for validation
    num_data = len(dataset)
    if shuffle:
        indices = list(range(num_data))
    np.random.shuffle(indices)
    split = int(np.floor(valid_size * num_data))
    valid_idx, train_idx = indices[:split], indices[split:]

    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    
    # load training data in batches
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              sampler=train_sampler,
                              num_workers=num_workers)
    
    # load validation data in batches
    valid_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              sampler=valid_sampler,
                              num_workers=num_workers)
    
    return train_loader, valid_loader

In [13]:
abcd_dataset = get_abcd(dti_file, rs_file, other_file, label_file)
train_loader, valid_loader = create_datasets(abcd_dataset, batch_size, num_workers, valid_size)

In [14]:
dti, rs, other, y = next(iter(valid_loader))
print('dti: ', dti.shape, 'rs fmri: ', rs.shape, 'other data: ', other.shape, 'label: ', y.shape)

dti:  torch.Size([128, 31]) rs fmri:  torch.Size([128, 270]) other data:  torch.Size([128, 7]) label:  torch.Size([128, 1])


# 3. Define Train / Test

In [None]:
def train_epoch(model, dataloader, num_classes, criterion, optimizer, scheduler, device):
    epoch_loss = 0.0
    epoch_acc = 0
    epoch_ce = 0.0
    epoch_iou = 0.0
    epoch_dice = 0.0
    count = 0

    piter = tqdm(dataloader, desc='Batch', unit='batch', position=1, leave=False)
    for inputs, seg_masks in piter:

        inputs = inputs.to(device)
            # transfer label to device
        targets = target.to(device)
        seg_masks = seg_masks.to(device)
        _, targets = torch.max(seg_masks, 1)

        batch_size = inputs.size(0)
        nxt_count = count+batch_size
        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, seg_masks)
        
        loss.backward()
        optimizer.step()

        # statistics
        epoch_loss = loss.item() * batch_size/nxt_count + epoch_loss * count/nxt_count
        epoch_acc = ((preds == targets).sum()/np.prod(preds.size())).item() * batch_size/nxt_count + epoch_acc * count/nxt_count

        count = nxt_count
        piter.set_postfix(accuracy=100. * epoch_acc)

    epoch_acc *= 100.
    scheduler.step()
    train_stats = {
        'train_loss': epoch_loss,
        'train_acc': epoch_acc,
    }
    
    return model, epoch_loss, epoch_acc, train_stats

In [None]:
def test(model, dataloader, num_classes, device):
    since = time.time()
    model.eval()   # Set model to evaluate mode
    
    corrects = 0
    count = 0

    # Iterate over data.
    with torch.no_grad():
        piter = tqdm(dataloader, unit='batch')
        for inputs, seg_masks in piter:
            piter.set_description(f"Test ")

            inputs = inputs.to(device)
            _, targets = torch.max(seg_masks, 1)
            targets = targets.to(device)
            
            batch_size = inputs.size(0)
            count += batch_size

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # statistics
            corrects += torch.sum(preds == targets.data)
            pos_corrects += torch.sum(preds[] == targets.data)

            acc = corrects.double().item() / count
            piter.set_postfix(accuracy=100. * acc)


    acc = corrects.double().item() / count

    time_elapsed = time.time() - since
    print(f'Testing complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s, Test Acc: {100. * acc}, Test Iou: {mean_IOU}')
    
    test_stats = {
        "test_acc": 100. * acc,
    }

    return cl_wise_iou, test_stats

In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device: {}'.format(device))

Using device: cuda


# 4. DTI Model

In [135]:
x = torch.rand(1, 31)
dti_model = models.BinaryMLP(31, [32, 64, 32], p=0.2)
out = dti_model(x)
print(out.shape)
dti_model = dti_model.to(device)

torch.Size([1, 2])


In [137]:
weights = torch.FloatTensor([0.05, 0.95])
criterion = metrics.WeightedBCELoss(weights, reduction='mean')# metrics.FocalLoss(gamma=1, weights=weights) # nn.BCELoss() # nn.BCEWithLogitsLoss()
# criterion = metrics.FocalLoss(gamma=1, weights=weights)
optimizer = optim.Adam(dti_model.parameters(), lr=lr)

In [138]:
dti_model.train()
pbar = trange(200, desc='Epoch', unit='epoch', initial=0, position=0)
for epoch in pbar:  # loop over the dataset multiple times

    running_loss = 0.0
    for inputs, _, _, labels in piter:
        # get the inputs; data is a list of [inputs, labels]
        inputs = inputs.to(device)
        labels = labels.view(-1).long().to(device)
        labels = F.one_hot(labels, num_classes=2).float()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = dti_model(inputs)
        # probs = F.softmax(outputs)
        probs = F.sigmoid(outputs)
        
        loss = criterion(probs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
    pbar.set_postfix(loss = running_loss)

print('Finished Training')

Epoch:   0%|          | 0/200 [00:00<?, ?epoch/s]

Finished Training


In [139]:
dti_model.eval()
with torch.no_grad():
    corrects = 0
    pos_corrects = 0
    neg_corrects = 0
    
    count = 0
    pos_count = 0
    neg_count = 0
    
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, _, _, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outputs = dti_model(inputs)
        _, preds = torch.max(outputs, 1, keepdim=True)

        # statistics
        corrects += torch.sum(preds == labels.data)
        pos_corrects += torch.sum(preds[labels==1] == labels[labels==1].data)
        neg_corrects += torch.sum(preds[labels==0] == labels[labels==0].data)
        
        count += batch_size
        pos_count += (labels.data==1).sum()
        neg_count += (labels.data==0).sum()

    acc = corrects.double().item() / count
    pos_acc = pos_corrects.double().item() / pos_count
    neg_acc = neg_corrects.double().item() / neg_count

print('Finished Training Validation')
print(f'accuracy: {acc*100. : .2f}, pos accuracy: {pos_acc*100. : .2f}, neg accuracy: {neg_acc*100. : .2f}')

Finished Training Validation
accuracy:  98.62, pos accuracy:  99.64, neg accuracy:  100.00


In [140]:
dti_model.eval()
with torch.no_grad():
    corrects = 0
    pos_corrects = 0
    neg_corrects = 0
    
    count = 0
    pos_count = 0
    neg_count = 0
    
    for i, data in enumerate(valid_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, _, _, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outputs = dti_model(inputs)
        _, preds = torch.max(outputs, 1, keepdim=True)

        # statistics
        corrects += torch.sum(preds == labels.data)
        pos_corrects += torch.sum(preds[labels==1] == labels[labels==1].data)
        neg_corrects += torch.sum(preds[labels==0] == labels[labels==0].data)
        
        count += batch_size
        pos_count += (labels.data==1).sum()
        neg_count += (labels.data==0).sum()

    acc = corrects.double().item() / count
    pos_acc = pos_corrects.double().item() / pos_count
    neg_acc = neg_corrects.double().item() / neg_count

print('Finished Training Validation')
print(f'accuracy: {acc*100. : .2f}, pos accuracy: {pos_acc*100. : .2f}, neg accuracy: {neg_acc*100. : .2f}')

Finished Training Validation
accuracy:  75.00, pos accuracy:  7.81, neg accuracy:  92.00


In [None]:
torch.cuda.empty_cache()

# 5. rs fMRI Model

In [157]:
x = torch.rand(1, 270)
rs_model = models.BinaryMLP(270, [270, 270, 40], p=0.2, hidden_dim=300)
out = rs_model(x)
print(out.shape)
rs_model = rs_model.to(device)

torch.Size([1, 2])


In [158]:
weights = torch.FloatTensor([0.05, 0.95])
criterion = metrics.WeightedBCELoss(weights, reduction='mean')# metrics.FocalLoss(gamma=1, weights=weights) # nn.BCELoss() # nn.BCEWithLogitsLoss()
# criterion = metrics.FocalLoss(gamma=1, weights=weights)
optimizer = optim.Adam(rs_model.parameters(), lr=lr)

In [159]:
rs_model.train()
pbar = trange(200, desc='Epoch', unit='epoch', initial=0, position=0)
for epoch in pbar:  # loop over the dataset multiple times

    running_loss = 0.0
    for _, inputs, _, labels in piter:
        # get the inputs; data is a list of [inputs, labels]
        inputs = inputs.to(device)
        labels = labels.view(-1).long().to(device)
        labels = F.one_hot(labels, num_classes=2).float()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = rs_model(inputs)
        # probs = F.softmax(outputs)
        probs = F.sigmoid(outputs)
        
        loss = criterion(probs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
    pbar.set_postfix(loss = running_loss)

print('Finished Training')

Epoch:   0%|          | 0/200 [00:00<?, ?epoch/s]

Finished Training


In [160]:
rs_model.eval()
with torch.no_grad():
    corrects = 0
    pos_corrects = 0
    neg_corrects = 0
    
    count = 0
    pos_count = 0
    neg_count = 0
    
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        _, inputs, _, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outputs = rs_model(inputs)
        _, preds = torch.max(outputs, 1, keepdim=True)

        # statistics
        corrects += torch.sum(preds == labels.data)
        pos_corrects += torch.sum(preds[labels==1] == labels[labels==1].data)
        neg_corrects += torch.sum(preds[labels==0] == labels[labels==0].data)
        
        count += batch_size
        pos_count += (labels.data==1).sum()
        neg_count += (labels.data==0).sum()

    acc = corrects.double().item() / count
    pos_acc = pos_corrects.double().item() / pos_count
    neg_acc = neg_corrects.double().item() / neg_count

print('Finished Training Validation')
print(f'accuracy: {acc*100. : .2f}, pos accuracy: {pos_acc*100. : .2f}, neg accuracy: {neg_acc*100. : .2f}')

Finished Training Validation
accuracy:  98.66, pos accuracy:  100.00, neg accuracy:  100.00


In [161]:
rs_model.eval()
with torch.no_grad():
    corrects = 0
    pos_corrects = 0
    neg_corrects = 0
    pos_preds = 0
    
    count = 0
    pos_count = 0
    neg_count = 0
    
    for i, data in enumerate(valid_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        _, inputs, _, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outputs = rs_model(inputs)
        _, preds = torch.max(outputs, 1, keepdim=True)

        # statistics
        corrects += torch.sum(preds == labels.data)
        pos_corrects += torch.sum(preds[labels==1] == labels[labels==1].data)
        neg_corrects += torch.sum(preds[labels==0] == labels[labels==0].data)
        
        count += batch_size
        pos_count += (labels.data==1).sum()
        neg_count += (labels.data==0).sum()
        pos_preds += (preds.data==1).sum()

    acc = corrects.double().item() / count
    pos_acc = pos_corrects.double().item() / pos_count
    neg_acc = neg_corrects.double().item() / neg_count

print('Finished Training Validation')
print(f'accuracy: {acc*100. : .2f}, pos accuracy: {pos_acc*100. : .2f}, neg accuracy: {neg_acc*100. : .2f}, positive predictions: {pos_preds}')

Finished Training Validation
accuracy:  78.79, pos accuracy:  0.00, neg accuracy:  97.38, positive predictions: 19
