In [None]:
# import libraries
import numpy as np
import random
import scipy.signal as sgn
import scipy.misc
import scipy.stats as stats
import pandas as pd
import os
import shutil
import time
import matplotlib.pyplot as plt
import gc
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, average_precision_score 
import seaborn as sns
import sys



import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader, RandomSampler, ConcatDataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import WeightedRandomSampler

# torch.cuda.is_available is a function: the bare attribute (no call) is always
# truthy, which previously forced 'cuda' even on CPU-only machines.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    

In [None]:
# create connection (adjacency) matrix from network files
def connectionMatrix(fmat, nlist):
    """Build a symmetric 0/1 connection (adjacency) matrix from a network list.

    Args:
        fmat: fluorescence matrix; only its column count (number of neurons)
            is used, to size the output.
        nlist: iterable of rows ``(from, to, weight, ...)`` with 1-based
            neuron ids. Only rows whose weight equals 1 are recorded.

    Returns:
        (n, n) float array holding 1 at both ``[from-1, to-1]`` and
        ``[to-1, from-1]`` for every recorded row, 0 elsewhere.
    """
    # NOTE(review): direction is discarded (both entries are set); confirm the
    # symmetric treatment is intended if the network files list directed links.
    size = np.shape(fmat)[1]
    adjacency = np.zeros((size, size))
    for row in nlist:
        if row[2] != 1:
            continue
        # network files are 1-based; shift to python indexing
        src, dst = row[0] - 1, row[1] - 1
        adjacency[src, dst] = adjacency[dst, src] = 1
    return adjacency

In [None]:
# define dataset number to dataset name dict for easy modular loading
dataset_dict = {
    1: 'normal-1',
    2: 'normal-2',
    3: 'normal-3',
    4: 'normal-4',
    5: 'normal-3-highrate',
    6: 'normal-4-lownoise',
    7: 'lowcon',
    8: 'highcon',
    9: 'lowcc',
    10: 'highcc',
    11: 'small-1',
    12: 'small-2',
    13: 'small-3',
    14: 'small-4',
    15: 'small-5',
    16: 'small-6'
}

# define paths to data; set the DLPROJECT_DATA_DIR environment variable to use
# a different copy of the data instead of editing this hard-coded default
data_parent_dir = os.environ.get("DLPROJECT_DATA_DIR", "/scratch/dmc421/dlproject/data")

# dataset name : path (same order as dataset_dict)
data_dir = {name: os.path.join(data_parent_dir, name) for name in dataset_dict.values()}

# dataset name : file path (one dict per file category)
flr_dir = {}
net_dir = {}
pos_dir = {}

denoised_dir = {}
denoised_weighted_dir = {}

spike_dir = {}
spike_weighted_dir = {}
binspike_dir = {}

# (substring, target dict) in priority order: the first match wins, so a more
# specific pattern must precede any pattern it contains ("networkPositions"
# before "network", "denoised_weighted" before "denoised", "spike_weighted" and
# "binspike" before "spike").
_file_categories = [
    ("networkPositions", pos_dir),
    ("network", net_dir),
    ("denoised_weighted", denoised_weighted_dir),
    ("denoised", denoised_dir),
    ("spike_weighted", spike_weighted_dir),
    ("binspike", binspike_dir),
    ("spike", spike_dir),
    ("fluorescence", flr_dir),
]

# update dictionaries with corresponding filepaths. os.listdir order is
# arbitrary, so sort it: if several files match one category, the winner
# (the last in sorted order) is then deterministic.
for dataset, dataset_path in data_dir.items():
    for f in sorted(os.listdir(dataset_path)):
        for pattern, target in _file_categories:
            if pattern in f:
                target[dataset] = os.path.join(dataset_path, f)
                break

In [None]:
%%time

# import network files and denoised flr files for datasets 11-16
npath_1, npath_2, npath_3, npath_4, npath_5, npath_6 = (
    net_dir[dataset_dict[k]] for k in range(11, 17))
fpath_1, fpath_2, fpath_3, fpath_4, fpath_5, fpath_6 = (
    denoised_dir[dataset_dict[k]] for k in range(11, 17))

# import network mappings (integer rows consumed by connectionMatrix)
network_1, network_2, network_3, network_4, network_5, network_6 = (
    np.genfromtxt(path, delimiter=',').astype(int)
    for path in (npath_1, npath_2, npath_3, npath_4, npath_5, npath_6))

print("networks loaded")

# import fluorescence data
flr_1, flr_2, flr_3, flr_4, flr_5, flr_6 = (
    np.genfromtxt(path, delimiter=',')
    for path in (fpath_1, fpath_2, fpath_3, fpath_4, fpath_5, fpath_6))

print("flrs loaded")

# create connection (adjacency) matrices, pairing each flr matrix with its network
con_1, con_2, con_3, con_4, con_5, con_6 = (
    connectionMatrix(flr, net)
    for flr, net in zip(
        (flr_1, flr_2, flr_3, flr_4, flr_5, flr_6),
        (network_1, network_2, network_3, network_4, network_5, network_6)))
print("connection matrices constructed")

In [None]:
# examine the number of positive examples per dataset
# counts matrix entries: connectionMatrix sets both (i, j) and (j, i) per listed connection
for con_matrix in (con_1, con_2, con_3, con_4, con_5, con_6):
    print(np.sum(con_matrix))

In [None]:
# create a dataset in line with the filtering criteria from Romaszko / Dunn & Koo
# adapted from https://github.com/spoonsso/TFconnect/blob/master/conutils/utils.py
class RomaDataset(Dataset):
    """Pairwise fluorescence dataset using the Romaszko / Dunn & Koo filtering.

    Adapted from https://github.com/spoonsso/TFconnect/blob/master/conutils/utils.py

    The fluorescence matrix is z-scored. Only time samples whose
    frame-to-frame change, summed over neurons, exceeds ``th * n_neurons`` are
    kept. Item ``idx`` is the ordered neuron pair ``divmod(idx, n_neurons)``
    (identical pairs included), with a randomly placed window of ``seq_len``
    filtered time samples.

    Args:
        flrmat: fluorescence matrix indexed as (time, neuron).
        conmat: (n_neurons, n_neurons) label matrix; ``conmat[p1, p2]`` is the
            label for the pair (from p1, to p2).
        seq_len: window length, in filtered time samples.
        th: per-neuron activity threshold; the filter threshold is
            ``th * n_neurons``.

    Raises:
        ValueError: if fewer than ``seq_len`` time samples survive the filter,
            so that no full window could be drawn.
    """
    def __init__(self, flrmat, conmat, seq_len=330, th=0.02):
        self.flrmat = flrmat
        self.conmat = conmat
        self.seq_len = seq_len
        self.th = th

        self.n_neurons = np.shape(self.flrmat)[1]
        self.n_pairs = self.n_neurons ** 2  # possible pairs including identical pairs
        self.n_timesamples = np.shape(self.flrmat)[0]
        self.sum_thresh = self.th * self.n_neurons

        # z-score input flr matrix
        # NOTE(review): axis=1 standardises each time sample across neurons,
        # not each neuron over time -- confirm this matches the reference code.
        self.flrz = stats.zscore(flrmat, axis=1)

        # population-average trace (one value per time sample) and its z-score
        self.avg = np.mean(self.flrmat, axis=1)
        self.avgz = stats.zscore(self.avg)

        # take the diff matrix per Romaszko paper / spoonsso implementation
        fdiff = np.diff(flrmat, axis=0)
        totF = np.sum(fdiff, axis=1)
        totF = np.hstack([totF, totF[-1]])  # repeat last value for boolean logic size match
        keep = totF > self.sum_thresh
        self.ffilt = self.flrz[keep, :]
        self.avg = self.avg[keep]
        # apply the same mask to the average track: windows are drawn on the
        # filtered time axis, so an unfiltered avgz would be misaligned
        self.avgz = self.avgz[keep]

        self.n_timesamples_filt = np.shape(self.ffilt)[0]
        if self.n_timesamples_filt < self.seq_len:
            raise ValueError(
                "only {} time samples pass the activity filter (th={}), "
                "fewer than seq_len={}".format(
                    self.n_timesamples_filt, self.th, self.seq_len))

    # length of dataset is the number of pairs
    def __len__(self):
        return self.n_pairs

    def __getitem__(self, idx):
        """Return ``(track, contype, samples)`` for ordered pair ``idx``.

        track:   float32 tensor (3, seq_len); rows are the "from" neuron, the
                 z-scored population average and the "to" neuron.
        contype: float32 tensor ``[conmat[p1, p2]]`` (the label).
        samples: float32 tensor ``[start, end]`` of the window in filtered
                 time samples (for debugging purposes).
        """
        p1, p2 = divmod(int(idx), self.n_neurons)  # from neuron, to neuron

        # pick a random time window of the filtered flr matrix; every start in
        # [0, n_timesamples_filt - seq_len] gives a full window (high is exclusive)
        start = np.random.randint(0, self.n_timesamples_filt - self.seq_len + 1)
        end = start + self.seq_len

        # Assemble a single 3 x seq_len track; stacking into one ndarray first
        # avoids PyTorch's slow list-of-ndarrays tensor construction
        track = torch.as_tensor(np.stack([
            self.ffilt[start:end, p1],
            self.avgz[start:end],
            self.ffilt[start:end, p2],
        ]), dtype=torch.float32)

        # produce label for generated track
        contype = torch.tensor([self.conmat[p1, p2]], dtype=torch.float32)

        # produce the time window used (for debugging purposes)
        samples = torch.tensor([start, end], dtype=torch.float32)

        return track, contype, samples


   

In [None]:
# set the model sequence length and fluorescence threshold to generate datasets
seq_len = 330
th = 0.02

dataset_1, dataset_2, dataset_3, dataset_4, dataset_5, dataset_6 = (
    RomaDataset(flr, con, seq_len=seq_len, th=th)
    for flr, con in zip(
        (flr_1, flr_2, flr_3, flr_4, flr_5, flr_6),
        (con_1, con_2, con_3, con_4, con_5, con_6)))


# train on datasets 1:4, validate on 5, test on 6
dataset_combined = ConcatDataset([dataset_1, dataset_2, dataset_3, dataset_4])
dataset_validation = dataset_5

transformed_dataset = {
    'train': dataset_combined,
    'validate': dataset_validation
}


# ensure all datasets are long enough to sample sequence lengths: check the
# filtered length explicitly (with a readable message), then confirm a drawn
# item really has a full-length track
for k, ds in enumerate(
        (dataset_1, dataset_2, dataset_3, dataset_4, dataset_5, dataset_6), start=1):
    assert ds.n_timesamples_filt >= seq_len, (
        "dataset_{} has only {} filtered time samples (< seq_len={})".format(
            k, ds.n_timesamples_filt, seq_len))
    assert ds[0][0].size(1) >= seq_len

print("individual datasets created")


In [None]:
# main training loop, adapted from the in-class lab examples
def _safe_div(num, den):
    """Return ``num / den`` as a float, or NaN when ``den`` is zero."""
    return float(num) / float(den) if den else float('nan')


def _binary_counts(labels, preds):
    """Return ``(tn, fp, fn, tp)`` for 0/1 labels and predictions.

    ``labels=[0, 1]`` forces a 2x2 confusion matrix even when one class is
    absent (e.g. a batch without connections), so the 4-way unpack is safe.
    """
    return confusion_matrix(labels, preds, labels=[0, 1]).ravel()


def _print_batch_metrics(targets, preds, loss):
    """Print connection counts, recall, specificity, f1 and loss for one batch."""
    print("connections present: {}".format(torch.sum(targets)))
    print("connection calls: {}".format(torch.sum(preds == 1)))
    tn, fp, fn, tp = _binary_counts(targets.cpu().numpy(), preds.cpu().numpy())
    print("recall: {}".format(_safe_div(tp, tp + fn)))
    print("specificity: {}".format(_safe_div(tn, tn + fp)))
    print("f1: {}".format(_safe_div(2 * tp, 2 * tp + fp + fn)))
    print("loss: {}".format(loss.item()))
    print("\n")


def train_model(model, dataloader, optimizer, loss_function, scheduler, num_epochs = 10, verbose = False, print_metrics=False):
    """Train ``model`` and restore the weights with the best validation AUROC.

    Each epoch runs a 'train' phase (gradient updates, gradient-norm clipping
    at 0.5, then one ``scheduler.step()`` if a scheduler is given), followed
    by a 'validate' phase with autograd disabled.

    Args:
        model: two-class classifier returning (batch, 2) logits.
        dataloader: dict with 'train' and 'validate' iterables of batches
            ``(seq, label, ...)``; labels hold 0/1 values.
        optimizer: torch optimizer over ``model.parameters()``.
        loss_function: criterion called as ``loss_function(logits, targets)``
            with (batch,) integer class targets.
        scheduler: optional LR scheduler (falsy to skip).
        num_epochs: number of epochs.
        verbose: print loss/accuracy every epoch (otherwise every 10th).
        print_metrics: print per-batch metrics.

    Returns:
        ``(model, acc_dict, loss_dict, recall_dict, specificity_dict,
        AUROC_dict, AP_dict)``. Each dict maps phase -> per-epoch values, and
        ``model`` has the best-validation-AUROC weights loaded.
    """
    import copy

    phases = ['train', 'validate']
    acc_dict = {p: [] for p in phases}
    loss_dict = {p: [] for p in phases}
    recall_dict = {p: [] for p in phases}
    specificity_dict = {p: [] for p in phases}
    AUROC_dict = {p: [] for p in phases}
    AP_dict = {p: [] for p in phases}

    clipping_value = 0.5  # max gradient L2 norm
    best_AUROC = 0
    # state_dict() returns references to the live parameters, so deep-copy the
    # snapshot or later optimizer steps overwrite it. Initialising here also
    # guarantees load_state_dict below has something to load.
    best_model_wts = copy.deepcopy(model.state_dict())

    since = time.time()
    for i in range(num_epochs):
        print('Epoch: {}/{}'.format(i, num_epochs-1))
        print('-'*10)
        for p in phases:
            is_train = p == 'train'
            model.train(is_train)

            running_correct = 0
            running_loss = 0.0
            running_total = 0
            label_chunks, pred_chunks, score_chunks = [], [], []

            for data in dataloader[p]:
                seq = data[0].to(device)
                # (batch, 1) float labels -> (batch,) class indices; view(-1)
                # stays 1-d even for a batch of one, unlike squeeze()
                targets = data[1].to(device).view(-1).long()

                if is_train:
                    optimizer.zero_grad()
                # only build the autograd graph while training
                with torch.set_grad_enabled(is_train):
                    y_pred = model(seq)
                    loss = loss_function(y_pred, targets)

                _, preds = torch.max(y_pred, dim = 1)
                # probability of class 1 (connection), used by the ranking metrics
                scores = torch.softmax(y_pred.detach(), dim=1)[:, 1]
                num_seqs = seq.size(0)

                label_chunks.append(targets.cpu().numpy())
                pred_chunks.append(preds.cpu().numpy())
                score_chunks.append(scores.cpu().numpy())

                if print_metrics:
                    _print_batch_metrics(targets, preds, loss)

                running_correct += torch.sum(preds == targets).item()
                running_loss += loss.item() * num_seqs
                running_total += num_seqs

                if is_train:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
                    optimizer.step()

            epoch_acc = running_correct / running_total
            epoch_loss = running_loss / running_total

            # concatenate once per epoch instead of once per batch (quadratic copying)
            all_labels = np.concatenate(label_chunks)
            all_predictions = np.concatenate(pred_chunks)
            all_scores = np.concatenate(score_chunks)

            tn, fp, fn, tp = _binary_counts(all_labels, all_predictions)
            # ROC / AP need continuous scores; hard 0/1 predictions collapse
            # the curve to a single threshold
            fpr, tpr, _ = roc_curve(all_labels, all_scores, pos_label=1)
            epoch_AUROC = auc(fpr, tpr)
            epoch_AP = average_precision_score(all_labels, all_scores, pos_label=1)

            epoch_specificity = _safe_div(tn, tn + fp)
            epoch_recall = _safe_div(tp, tp + fn)
            epoch_f1 = _safe_div(2 * tp, 2 * tp + fp + fn)
            print("recall: {}".format(epoch_recall))
            print("specificity: {}".format(epoch_specificity))
            print("f1: {}".format(epoch_f1))
            print("AUROC: {}".format(epoch_AUROC))
            print("AP: {}".format(epoch_AP))

            if verbose or (i%10 == 0):
                print('Phase:{}, epoch loss: {:.4f} Acc: {:.4f}'.format(p, epoch_loss, epoch_acc))
                print('\n')

            # add metrics to epoch dict
            acc_dict[p].append(epoch_acc)
            loss_dict[p].append(epoch_loss)
            recall_dict[p].append(epoch_recall)
            specificity_dict[p].append(epoch_specificity)
            AUROC_dict[p].append(epoch_AUROC)
            AP_dict[p].append(epoch_AP)

            # choose to retain model based on AUROC
            if p == 'validate':
                if epoch_AUROC > best_AUROC:
                    best_AUROC = epoch_AUROC
                    best_model_wts = copy.deepcopy(model.state_dict())
            elif scheduler:
                scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val AUROC: {:.4f}'.format(best_AUROC))

    model.load_state_dict(best_model_wts)

    return model, acc_dict, loss_dict, recall_dict, specificity_dict, AUROC_dict, AP_dict



In [None]:
# START TO DEFINE MODEL MODULES

In [None]:
# optional recurrent "front end" for transforming the sequence representation
class seq2seq(nn.Module):
    """Optional bidirectional-GRU "front end" that re-encodes the input tracks.

    Maps a (batch, 3, seq_len) input to a ReLU-activated
    (batch, 2 * hidden_size, seq_len) representation.
    """
    def __init__(self, hidden_size):
        super(seq2seq, self).__init__()

        self.hidden_size = hidden_size
        # 3-layer bidirectional GRU over the time axis
        self.recurrent_layer = nn.Sequential(
            nn.GRU(
                input_size = 3,
                hidden_size = self.hidden_size,
                num_layers = 3,
                batch_first = True,
                dropout = 0.1,
                bidirectional = True),
        )

        # NOTE: reduce_features1/2 are not used in forward(); they are kept so
        # the parameter set and state_dict keys stay unchanged.
        self.reduce_features1 = nn.Conv2d(
            in_channels = 24,
            out_channels = 12,
            kernel_size =  1)

        self.reduce_features2 = nn.Conv2d(
            in_channels=12,
            out_channels = 4,
            kernel_size = 1)

        self.relu  = nn.ReLU()

    def forward(self, x, verbose=False):
        """Encode ``x`` (batch, 3, seq_len) -> (batch, 2 * hidden_size, seq_len)."""
        self.verbose = verbose
        if self.verbose:
            print("seq2seq input shape: {}".format(x.size()))

        # (B, 3, L) -> (B, L, 3) for the batch_first GRU
        x, _ = self.recurrent_layer(x.transpose(1, 2))
        # (B, L, 2H) -> (B, 2H, L); no squeeze, so a batch of one keeps its batch dim
        x = self.relu(x.transpose(1, 2))

        if self.verbose:
            print("out size: {}".format(x.size()))
        return x

In [None]:
# first convolutional block module, Inception v2 style with optional residual connection
class connception_module1(nn.Module):
    """Inception-v2 style module used in the first convolutional block.

    Four parallel branches (1x1; 1x1->3x3; 1x1->3x3->3x3; pool->1x1) each emit
    8 channels. They are concatenated back to 32 channels and batch-normed.
    The pooling branch uses AdaptiveMaxPool2d with output [3, 330], so the
    concatenation only lines up for inputs of shape (B, 32, 3, 330). For that
    shape the pool is a no-op, because its output size equals its input size.

    Args:
        residual: default for ``forward``'s ``residual`` argument; when true,
            the module input is added to the concatenated branches before
            batch norm.
    """
    def __init__(self, residual = False):
        super(connception_module1, self).__init__()

        self.residual = residual

        self.branch_1x = nn.Sequential(
            nn.Conv2d(
                in_channels = 32,
                out_channels = 8,
                kernel_size = 1)
        )

        self.branch_3x = nn.Sequential(
            nn.Conv2d(
                in_channels = 32,
                out_channels = 4,
                kernel_size = 1),
            nn.ReLU(),
            nn.Conv2d(
                in_channels = 4,
                out_channels = 8,
                kernel_size = [3,3],
                padding = [1,1]
            )
        )

        self.branch_5x = nn.Sequential(
            nn.Conv2d(
                in_channels = 32,
                out_channels = 4,
                kernel_size = 1),
            nn.ReLU(),
            nn.Conv2d(
                in_channels = 4,
                out_channels = 8,
                kernel_size = [3,3],
                padding = [1,1]),
            nn.Conv2d(
                in_channels = 8,
                out_channels = 8,
                kernel_size = [3,3],
                padding = [1,1])
        )

        self.branch_mp = nn.Sequential(
            nn.AdaptiveMaxPool2d(
                output_size = [3,330]),
            nn.Conv2d(
                in_channels = 32,
                out_channels = 8,
                kernel_size = 1)
        )

        self.batch_norm = nn.BatchNorm2d(32)

    def forward(self, x, residual=None, verbose=False):
        """Run the four branches, optionally add the input, then batch-norm.

        Args:
            x: input tensor (see the class docstring for the expected shape).
            residual: add ``x`` to the concatenated branches. ``None`` (the
                default) uses the value given to the constructor.
            verbose: print each branch's output size.
        """
        if residual is None:
            residual = self.residual

        branch_1x = self.branch_1x(x)
        if verbose:
            print("connception module 1 branch 1x size: {}".format(branch_1x.size()))
        branch_3x = self.branch_3x(x)
        if verbose:
            print("connception module 1 branch 3x size: {}".format(branch_3x.size()))
        branch_5x = self.branch_5x(x)
        if verbose:
            print("connception module 1 branch 5x size: {}".format(branch_5x.size()))
        branch_mp = self.branch_mp(x)
        if verbose:
            print("connception module 1 branch mp size: {}".format(branch_mp.size()))

        out = torch.cat((branch_1x, branch_3x, branch_5x, branch_mp), 1)

        if residual:
            out = out + x

        return self.batch_norm(out)

In [None]:
# first convolutional block with residual architecture from Dunn & Koo (2017)
class connception_block1(nn.Module):
    """First convolutional block (residual architecture after Dunn & Koo, 2017).

    When ``first_input`` is set, the (B, H, L) input gets a channel dimension
    and passes through a 32-filter input convolution. It then goes through
    three ``connception_module1`` modules, and a block-level skip connection
    adds the first module's output to the last module's output.

    Args:
        first_input: apply the input convolution (the input is raw tracks).
        rfe: use the input convolution for the recurrent front end. Its
            kernel/stride of 8 along the row axis are presumably meant to
            collapse seq2seq's 2*hidden rows (e.g. 24 -> 3) -- confirm if the
            front-end size changes.
    """
    def __init__(self, first_input = False, rfe = False):
        super(connception_block1, self).__init__()

        # condition input seq of lengths BATCH x 3 X 330
        self.rfe = rfe
        if self.rfe:
            kernel_size, stride = [8,5], [8,1]
        else:
            kernel_size, stride = [1,5], 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels = 1,
                out_channels = 32,
                kernel_size = kernel_size,
                stride = stride,
                padding = [0,2],
                dilation = 1,
                bias = True,
                padding_mode = 'zeros'
            ),
        )

        self.first_input = first_input
        self.connception1 = connception_module1()
        self.connception2 = connception_module1()
        self.connception3 = connception_module1()

        self.relu = nn.ReLU()

    def forward(self, x, verbose = False):
        self.verbose = verbose
        if self.verbose:
            print("input sequence size: {}".format(x.size()))

        # if first input, add a channel dimension for the 2d convolution
        if self.first_input:
            x = self.conv1(x.unsqueeze(1))
        if self.verbose:
            print("post 'input convolution' size: {}".format(x.size()))

        # verbose must be passed by keyword: the module's second positional
        # parameter is `residual`, so passing it positionally changed the maths
        x = self.connception1(x, verbose=self.verbose)
        residual = x

        x = self.relu(self.connception2(x, residual=True, verbose=self.verbose))
        x = self.relu(self.connception3(x, residual=True, verbose=self.verbose))

        # out-of-place add: ReLU saves its output for backward, so an in-place
        # `x += residual` would make loss.backward() fail
        x = x + residual

        if self.verbose:
            print("connception block 1 output size: {}".format(x.size()))

        return x

In [None]:
# second convolutional block module, Inception v2 style with optional residual connection

class connception_module2(nn.Module):
    """Inception-v2 style module used in the second convolutional block.

    Four parallel branches emit 16 channels each and are concatenated back to
    64 channels, then batch-normed. The pooling branch uses AdaptiveMaxPool2d
    with output [2, 110], so the concatenation only lines up for inputs of
    shape (B, 64, 2, 110).

    Args:
        residual: default for ``forward``'s ``residual`` argument; when true,
            the module input is added to the concatenated branches before
            batch norm.
    """
    def __init__(self, residual = False):
        super(connception_module2, self).__init__()

        self.residual = residual

        self.branch_1x = nn.Sequential(
            nn.Conv2d(
                in_channels = 64,
                out_channels = 16,
                kernel_size = 1)
        )

        # height 2 -> 3 (padded [2,1] conv) -> 2 ([2,3] conv)
        self.branch_3x = nn.Sequential(
            nn.Conv2d(
                in_channels = 64,
                out_channels = 8,
                kernel_size = [2,1],
                padding = [1,0]),
            nn.ReLU(),
            nn.Conv2d(
                in_channels = 8,
                out_channels = 16,
                kernel_size = [2,3],
                padding = [0,1]
            )
        )

        self.branch_5x = nn.Sequential(
            nn.Conv2d(
                in_channels = 64,
                out_channels = 8,
                kernel_size = [2,1],
                padding = [1,0]),
            nn.ReLU(),
            nn.Conv2d(
                in_channels = 8,
                out_channels = 16,
                kernel_size = [3,3],
                padding = [1,1]),
            nn.Conv2d(
                in_channels = 16,
                out_channels = 16,
                kernel_size = [2,3],
                padding = [0,1])
        )

        self.branch_mp = nn.Sequential(
            nn.AdaptiveMaxPool2d(
                output_size = [2,110]),
            nn.Conv2d(
                in_channels = 64,
                out_channels = 16,
                kernel_size = 1)
        )

        self.batch_norm = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def forward(self, x, residual=None, verbose=False):
        """Run the four branches, optionally add the input, then batch-norm.

        Args:
            x: input tensor (see the class docstring for the expected shape).
            residual: add ``x`` to the concatenated branches. ``None`` (the
                default) uses the value given to the constructor.
            verbose: print each branch's output size.
        """
        if residual is None:
            residual = self.residual
        self.verbose = verbose

        branch_1x = self.branch_1x(x)
        branch_3x = self.branch_3x(x)
        branch_5x = self.branch_5x(x)
        branch_mp = self.branch_mp(x)

        if verbose:
            print("connception module 2 branch 1x size: {}".format(branch_1x.size()))
            print("connception module 2 branch 3x size: {}".format(branch_3x.size()))
            print("connception module 2 branch 5x size: {}".format(branch_5x.size()))
            print("connception module 2 branch mp size: {}".format(branch_mp.size()))

        out = torch.cat((branch_1x, branch_3x, branch_5x, branch_mp), 1)

        if residual:
            out = out + x

        return self.batch_norm(out)

In [None]:
# second convolutional block with residual architecture from Dunn & Koo (2017)

class connception_block2(nn.Module):
    """Second convolutional block (residual architecture after Dunn & Koo, 2017).

    ``conv1`` maps the block-1 output (B, 32, 3, L) to (B, 64, 2, L')
    (kernel [2,3], stride [1,3]; L' = 110 for L = 330). Three
    ``connception_module2`` modules follow, the last two with their internal
    residual connection, each followed by ReLU.

    NOTE(review): unlike block 1, there is no block-level skip connection
    here; confirm this matches the intended architecture.
    """
    def __init__(self):
        super(connception_block2, self).__init__()

        # condition input from connception block 1:
        # B x 32 x 3 x L --> B x 64 x 2 x L/3
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels = 32,
                out_channels = 64,
                kernel_size = [2,3],
                padding = [0,1],
                stride = [1,3],
                dilation = 1,
                bias = True)
        )

        self.connception1 = connception_module2()
        self.connception2 = connception_module2()
        self.connception3 = connception_module2()

        # NOTE: connception4-6 and max_pool_out are not used in forward();
        # they are kept so the parameter set and state_dict keys stay unchanged.
        self.connception4 = connception_module2()
        self.connception5 = connception_module2()
        self.connception6 = connception_module2()

        self.relu = nn.ReLU()
        self.max_pool_out = nn.AdaptiveMaxPool2d(
            output_size = [2,110])

    def forward(self, x, verbose=False):
        self.verbose = verbose

        # convolution to match dimensions
        x = self.conv1(x)
        if self.verbose:
            print("post connception block 2 input conv size: {}".format(x.size()))

        x = self.relu(self.connception1(x, verbose = self.verbose))
        x = self.relu(self.connception2(x, verbose = self.verbose, residual=True))
        x = self.relu(self.connception3(x, verbose = self.verbose, residual=True))

        if self.verbose:
            print("connception block 2 output size: {}".format(x.size()))

        return x

        

In [None]:
# final convolutional module, collapse to 1d sequence using Inception v2 style filter arrangement
class connception_module3(nn.Module):
    """Final convolutional module: collapses the (B, 64, 2, L) map to a 1-d sequence.

    Three Inception-style branches reduce height 2 -> 1 and emit 16 channels
    each (48 in total). A 1x1 conv reduces them to 16 channels, followed by
    ReLU; the height axis is then squeezed, and adaptive max pooling resizes
    the sequence to length 64. Output: (B, 16, 64).
    """

    def __init__(self):
        super(connception_module3, self).__init__()

        self.branch_1x = nn.Sequential(
            nn.Conv2d(
                in_channels = 64,
                out_channels = 16,
                kernel_size = [2,1])
        )

        self.branch_3x = nn.Sequential(
            nn.Conv2d(
                in_channels = 64,
                out_channels = 16,
                kernel_size = [2,1],
                padding = [0,0]),
            # Conv2d, not Conv1d: the input is still 4-d (B, 16, 1, L) and the
            # kernel is 2-d. Current PyTorch rejects 4-d input to Conv1d.
            # The weight shape (16, 16, 1, 3) is unchanged.
            nn.Conv2d(
                in_channels = 16,
                out_channels = 16,
                kernel_size = [1,3],
                padding = [0,1]
            )
        )

        self.branch_5x = nn.Sequential(
            nn.Conv2d(
                in_channels = 64,
                out_channels = 16,
                kernel_size = [2,1],
                padding = [0,0]),
            nn.Conv2d(
                in_channels = 16,
                out_channels = 16,
                kernel_size = [1,3],
                padding = [0,1]),
            nn.Conv2d(
                in_channels = 16,
                out_channels = 16,
                kernel_size = [1,3],
                padding = [0,1])
        )

        self.feature_reduce = nn.Conv2d(
            in_channels = 48,
            out_channels = 16,
            kernel_size = 1)

        self.relu = nn.ReLU()
        self.max_pool = nn.AdaptiveMaxPool1d(output_size = 64)

    def forward(self, x, verbose=False):
        self.verbose = verbose

        branch_1x = self.branch_1x(x)
        branch_3x = self.branch_3x(x)
        branch_5x = self.branch_5x(x)

        if verbose:
            print("connception module 3 branch 1x size: {}".format(branch_1x.size()))
            print("connception module 3 branch 3x size: {}".format(branch_3x.size()))
            print("connception module 3 branch 5x size: {}".format(branch_5x.size()))

        x = torch.cat((branch_1x, branch_3x, branch_5x), 1)

        x = self.relu(self.feature_reduce(x))

        # (B, 16, 1, L) -> (B, 16, L) -> (B, 16, 64)
        x = x.squeeze(2)
        x = self.max_pool(x)

        if self.verbose:
            print("connception block 3 output size: {}".format(x.size()))

        return x
    
    
    

In [None]:
# recurrent portion of classifier output, gated recurrent unit 
class connception_recurrent_out(nn.Module):
    """Summarise a 16-channel feature sequence with a bidirectional GRU.

    The input is (B, 16, L), or (B, 16, 1, L) with a singleton height. The
    output is the final hidden state of every layer and direction (2 layers x
    2 directions), concatenated per sample: (B, 4 * hidden_size).
    """
    def __init__(self, hidden_size):
        super().__init__()

        self.hidden_size = hidden_size
        gru_config = dict(
            input_size=16,
            hidden_size=hidden_size,
            num_layers=2,
            bias=True,
            batch_first=True,
            dropout=0,
            bidirectional=True,
        )
        self.gru = nn.GRU(**gru_config)

    def forward(self, x, verbose=False):
        self.verbose = verbose
        # (B, 16, [1,] L) -> (B, L, 16) for the batch_first GRU
        steps = torch.transpose(x.squeeze(2), 1, 2)
        if verbose:
            print('post transpose shape: {}'.format(steps.size()))

        _, final_hidden = self.gru(steps)
        # (layers*directions, B, H) -> (B, layers*directions*H)
        out = final_hidden.permute(1, 0, 2).reshape(final_hidden.size(1), -1)

        if verbose:
            print('recurrent out size: {}'.format(out.size()))
        return out


In [None]:
# fully connected classifier head: takes the GRU hidden state and produces two output logits (no connection / connection)
class connception_fc(nn.Module):
    """Three-layer MLP head: 64 features -> 16 -> 16 -> 2 logits (ReLU between)."""
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(64, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 2)

        self.relu = nn.ReLU()

    def forward(self, x, verbose=False):
        self.verbose = verbose

        # flatten everything after the batch dimension
        hidden = x.view(x.size(0), -1)
        if verbose:
            print('post transpose shape: {}'.format(hidden.size()))

        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.fc3(hidden)

In [None]:
# define model with Recurrent Front End (RFE version)
class tmodel_rfe(nn.Module):
    """Connection classifier with the recurrent front end (RFE).

    Pipeline: seq2seq GRU front end -> conv block 1 (RFE input conv) ->
    conv block 2 -> module 3, each followed by ReLU -> BatchNorm1d(16) ->
    bidirectional-GRU summary -> fully connected head (2 logits per sample).
    """
    def __init__(self):
        super().__init__()

        # sizes of the two recurrent components
        self.s2s_hidden_size = 12
        self.recurrent_out_size = 16

        self.seq2seq = seq2seq(hidden_size=self.s2s_hidden_size)

        # convolutional feature extractor
        self.conv_block1 = connception_block1(first_input=True, rfe=True)
        self.conv_block2 = connception_block2()
        self.conv_block3 = connception_module3()

        # recurrent summary and classifier head
        self.connception_recurrent_out = connception_recurrent_out(self.recurrent_out_size)
        self.connception_fc = connception_fc()

        self.batch_norm = nn.BatchNorm1d(16)
        self.relu = nn.ReLU()

    def forward(self, x, verbose=False):
        self.verbose = verbose
        if verbose:
            print("input size: {}".format(x.size()))

        features = self.seq2seq(x, verbose=verbose)
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3):
            features = self.relu(stage(features, verbose=verbose))
        features = self.batch_norm(features)

        summary = self.connception_recurrent_out(features, verbose=verbose)
        return self.connception_fc(summary, verbose=verbose)
           

In [None]:
# define model with no rfe
class tmodel(nn.Module):
    """Classifier without the recurrent front end: input goes straight to conv_block1.

    Forward pass: conv_block1 (rfe=False) -> conv_block2 -> conv_block3, each
    followed by ReLU, then BatchNorm1d(16), the recurrent output stage and the
    fully connected head.

    NOTE(review): ``self.seq2seq`` is constructed but ``forward`` never calls it.
    It is kept so that construction order (RNG draws during initialisation) and
    state_dict keys stay compatible with existing checkpoints. Its parameters are
    still included in ``model.parameters()``.
    """

    def __init__(self):
        super().__init__()

        # sizes forwarded to the recurrent submodules
        self.s2s_hidden_size = 12
        self.recurrent_out_size = 16

        # Keep this construction order: it fixes the RNG draws made during weight
        # initialisation and the order of parameters / state_dict entries.
        self.seq2seq = seq2seq(hidden_size=self.s2s_hidden_size)

        self.conv_block1 = connception_block1(first_input=True, rfe=False)
        self.conv_block2 = connception_block2()
        self.conv_block3 = connception_module3()

        self.connception_recurrent_out = connception_recurrent_out(self.recurrent_out_size)
        self.connception_fc = connception_fc()

        # presumably 16 == number of channels produced by conv_block3 -- TODO confirm
        self.batch_norm = nn.BatchNorm1d(16)
        self.relu = nn.ReLU()

    def forward(self, x, verbose=False):
        """Run the network (without seq2seq) on ``x``; ``verbose`` is passed to every submodule."""
        self.verbose = verbose
        if verbose:
            print("input size: {}".format(x.size()))

        # three conv stages, each followed by ReLU
        for conv_stage in (self.conv_block1, self.conv_block2, self.conv_block3):
            x = self.relu(conv_stage(x, verbose=verbose))
        x = self.batch_norm(x)

        x = self.connception_recurrent_out(x, verbose=verbose)
        return self.connception_fc(x, verbose=verbose)
           

In [None]:
# clear memory held by a previous run of this cell. The old optimizer and
# scheduler keep references to the old model's parameters (and Adam state),
# so deleting only `model` would not release any GPU memory.
for stale_name in ('model', 'optimizer', 'scheduler'):
    globals().pop(stale_name, None)  # no-op on the first run in a fresh kernel
gc.collect()
torch.cuda.empty_cache()  # safe no-op when CUDA was never initialised


# set training parameters and train
batch_size = 100
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
LR_DECAY = 0.9  # LambdaLR multiplies LEARNING_RATE by LR_DECAY ** epoch
# CrossEntropyLoss class weights (class 0, class 1); the loss plot's y-label quotes these
CLASS_WEIGHTS = [0.15, 0.85]
SPLITS = ('train', 'validate')

dataloader = {x: DataLoader(transformed_dataset[x],
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=0) for x in SPLITS}
data_sizes = {x: len(transformed_dataset[x]) for x in SPLITS}

model = tmodel().to(device)

# put the weights on the same device as the model instead of hard-coding CUDA
loss = nn.CrossEntropyLoss(weight=torch.tensor(CLASS_WEIGHTS, device=device))

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
lambda_func = lambda epoch: LR_DECAY ** epoch
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)
model, acc_dict, loss_dict, recall_dict, specificity_dict, AUROC_dict, AP_dict = train_model(model, dataloader, optimizer, loss, scheduler=scheduler, num_epochs=NUM_EPOCHS, print_metrics=False, verbose = True)


In [None]:
## PLOT METRICS

In [None]:
# Accuracy per epoch: training curve and validation curve on the same axes
plt.plot(acc_dict["train"], label="train")
plt.plot(acc_dict["validate"], label="validation")
plt.gca().set(title="Model Accuracy over Epochs", xlabel="Epoch", ylabel="Accuracy")
plt.legend()








In [None]:

# Weighted cross-entropy loss per epoch for training and validation
plt.plot(loss_dict["train"], label="train")
plt.plot(loss_dict["validate"], label="validation")
plt.gca().set(title="Training / Validation Loss over Epochs",
              xlabel="Epoch",
              ylabel="Weighted Cross Entropy Loss \n [0.15, 0.85]")
plt.legend()


In [None]:
# Recall per epoch for training and validation. Both curves are labelled so
# the legend identifies each of them; the validation curve used to be unlabelled.
plt.plot(recall_dict["train"], label="train")
plt.plot(recall_dict["validate"], label="validation")
plt.gca().set(title="Model Recall over Epochs", xlabel="Epoch", ylabel="Recall")
plt.legend()

In [None]:
# Specificity per epoch for training and validation
plt.plot(specificity_dict["train"], label="train")
plt.plot(specificity_dict["validate"], label="validation")
plt.gca().set(title="Model Specificity over Epochs", xlabel="Epoch", ylabel="Specificity")
plt.legend()

In [None]:
# Area under the ROC curve per epoch for training and validation
plt.plot(AUROC_dict["train"], label="train")
plt.plot(AUROC_dict["validate"], label="validation")
plt.gca().set(title="AUROC over Epochs", xlabel="Epoch", ylabel="AUROC")
plt.legend()

In [None]:
# AP (presumably average precision, as computed by train_model) per epoch
plt.plot(AP_dict["train"], label="train")
plt.plot(AP_dict["validate"], label="validation")
plt.gca().set(title="AP over Epochs", xlabel="Epoch", ylabel="AP")
plt.legend()

In [None]:
# Evaluation loader. Aggregate metrics do not depend on sample order, so the
# order is kept fixed (shuffle=False) to make per-batch output reproducible.
test_loader = DataLoader(dataset_6,
                         batch_size=100,
                         shuffle=False,
                         num_workers=0)

In [None]:
## OPTIONAL CELL TO LOAD PREVIOUSLY SAVED MODELS FOR TESTING
# model = tmodel().to(device)
# state_dict = torch.load("/scratch/dmc421/dlproject/model_state.pt", map_location=device)
# model.load_state_dict(state_dict)

In [None]:
# main testing: evaluate the model once on a held-out loader and report metrics
def _binary_rates(tn, fp, fn, tp):
    """Return recall, specificity, precision and F1 from confusion-matrix counts.

    A ratio whose denominator is zero (e.g. no positives present) is nan
    instead of raising or emitting a numpy divide warning.
    """
    def ratio(numerator, denominator):
        return float(numerator) / denominator if denominator else float('nan')

    return {
        'recall': ratio(tp, tp + fn),
        'specificity': ratio(tn, tn + fp),
        'precision': ratio(tp, tp + fp),
        'f1': ratio(2 * tp, 2 * tp + fp + fn),
    }


def _confusion_counts(labels, predictions):
    """Return (tn, fp, fn, tp) for binary 0/1 ``labels`` vs ``predictions``.

    ``labels=[0, 1]`` keeps the matrix 2x2 even when only one class occurs,
    so the four-way unpacking cannot fail.
    """
    return confusion_matrix(labels, predictions, labels=[0, 1]).ravel()


def _print_batch_metrics(labels, predictions, batch_loss=None):
    """Print connection counts, recall, specificity, F1 (and loss if given) for one batch."""
    rates = _binary_rates(*_confusion_counts(labels, predictions))
    print("connections present: {}".format(int(labels.sum())))
    print("connection calls: {}".format(int((predictions == 1).sum())))
    print("recall: {}".format(rates['recall']))
    print("specificity: {}".format(rates['specificity']))
    print("f1: {}".format(rates['f1']))
    if batch_loss is not None:
        print("loss: {}".format(batch_loss))
    print("\n")


def test_model(model, dataloader, verbose = True, print_metrics=False, loss_function=None):
    """Evaluate ``model`` on every batch of ``dataloader`` and print summary metrics.

    The model is put in eval mode and run without autograd. The class
    prediction is the argmax over the model output. The softmax probability of
    class 1 is used as the score for ROC / average precision, so those metrics
    see every threshold rather than a single hard-label operating point.

    Args:
        model: network returning per-class outputs along dim 1; a binary
            (batch, 2) output is assumed from the class-1 indexing -- TODO
            confirm for every model passed in.
        dataloader: iterable of batches indexed as ``data[0]`` (inputs) and
            ``data[1]`` (0/1 labels).
        verbose: kept for interface compatibility; currently unused.
        print_metrics: when True, also print per-batch counts, recall,
            specificity and F1.
        loss_function: optional criterion called as ``loss_function(output,
            labels)``. When given, per-batch loss is printed (with
            ``print_metrics``) and the batch-size-weighted mean is returned.

    Returns:
        dict with 'recall', 'specificity', 'precision', 'f1', 'accuracy',
        'AUROC', 'AP' and, when ``loss_function`` is given, 'loss'.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    batch_labels, batch_predictions, batch_scores = [], [], []
    running_correct = 0
    running_total = 0
    running_loss = 0.0

    # inference only: no autograd graph is built, so activations are freed per batch
    with torch.no_grad():
        for data in dataloader:
            seq = data[0].to(device)
            # flatten to (batch,) int64 so labels compare element-wise with preds
            label = data[1].to(device).view(-1).long()
            y_pred = model(seq)

            _, preds = torch.max(y_pred, dim=1)
            # probability of class 1, used as the ranking score for ROC / AP
            scores = torch.softmax(y_pred, dim=1)[:, 1]
            num_seqs = seq.size(0)

            running_correct += torch.sum(preds == label).item()
            running_total += num_seqs

            batch_loss = None
            if loss_function is not None:
                batch_loss = loss_function(y_pred, label).item()
                # weight by batch size so unequal final batches average correctly
                running_loss += batch_loss * num_seqs

            pr = preds.cpu().numpy()
            lb = label.cpu().numpy()
            batch_predictions.append(pr)
            batch_labels.append(lb)
            batch_scores.append(scores.cpu().numpy())

            if print_metrics:
                _print_batch_metrics(lb, pr, batch_loss)

    if running_total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute test metrics")

    # concatenate once at the end instead of growing arrays inside the loop
    all_labels = np.concatenate(batch_labels)
    all_predictions = np.concatenate(batch_predictions)
    all_scores = np.concatenate(batch_scores)

    metrics = _binary_rates(*_confusion_counts(all_labels, all_predictions))
    metrics['accuracy'] = running_correct / running_total
    fpr, tpr, _ = roc_curve(all_labels, all_scores, pos_label=1)
    metrics['AUROC'] = auc(fpr, tpr)
    metrics['AP'] = average_precision_score(all_labels, all_scores, pos_label=1)
    if loss_function is not None:
        metrics['loss'] = running_loss / running_total

    print("recall: {}".format(metrics['recall']))
    print("specificity: {}".format(metrics['specificity']))
    print("f1: {}".format(metrics['f1']))
    print("AUROC: {}".format(metrics['AUROC']))
    print("AP: {}".format(metrics['AP']))
    return metrics


# assign the result so the cell does not also echo the metrics dict
test_metrics = test_model(model, test_loader)

In [None]:
## SAVE MODEL

In [None]:
# torch.save(model.state_dict(), "/scratch/dmc421/dlproject/model_state.pt")

In [None]:
# torch.save(model, "/scratch/dmc421/dlproject/model.pt")

In [None]:
# list model parameters: name, shape and element count of each tensor, then the total.
# (The bare `model.parameters` only echoed the bound method; it never listed anything.)
for param_name, param in model.named_parameters():
    print("{}: shape={}, count={}".format(param_name, tuple(param.shape), param.numel()))
print("total parameters: {}".format(sum(p.numel() for p in model.parameters())))