In [None]:
import numpy as np
#import pandas as pd
#import matplotlib.pyplot as plt
import requests

import os

from astropy.units import one
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score, cohen_kappa_score, \
                            precision_recall_curve, average_precision_score, roc_auc_score, roc_curve, auc, accuracy_score

import itertools
import random
import math

from IPython.display import clear_output

from tqdm import tqdm_notebook

import pandas as pd

import time
from datetime import datetime

from matplotlib import pyplot as plt
from pathlib import Path
import seaborn as sns
import matplotlib.ticker as ticker
#import matplotlib.pyplot as plt
import copy


In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.utils.data as data_utils
import torch.optim.lr_scheduler as lr_scheduler 


In [None]:
# Seed every random number generator this notebook imports (Python's
# `random`, NumPy's global generator and PyTorch) so that a fresh
# "Restart & Run All" reproduces the same shuffles and initial weights.
random.seed(0)
np.random.seed(0)
# The trailing ';' keeps the notebook from echoing the returned torch.Generator.
torch.manual_seed(0);

# Hyper-parameters

In [None]:
# --- feature switches ---------------------------------------------------
use_features = True        # master switch for the feature-loading cells
use_comp_features = True   # load the k-mer feature files
use_bio_features = True    # load the biological feature-matrix files

# --- dataset options ----------------------------------------------------
use_testset = True         # also load the separate *_test.csv files
use_full_DS = True         # True: keep every example; False: subsample
CLASS_BALANCE_RATIO = 1    # requested class-size ratio when subsampling
val_pct = 0.20
do_horz_flip = False       # True: the loading cells append a second copy of every example
use_padding_for_bucketing = False   # False: all sequences must share a single length

# --- run modes ----------------------------------------------------------
do_cross_validation = False
do_testing_using_best_model = False
do_retraining = True
do_testing_only = False

# --- classes ------------------------------------------------------------
num_classes = 2

# Per-class weights. Uniform alternative: {0: 1., 1: 1.}
class_weights = {0: 2.36/(2.36+1.), 1: 1./(2.36+1.)}

# Paths

In [None]:
# Root folders, relative to the notebook's working directory.
project_dir = '../../../'
ds_dir = f'{project_dir}data/processed/'   # alternative: f'{project_dir}data/seq/'
snapshot_dir = f'{project_dir}snapshots/'

# Class identifiers as they appear in file names, and their display names.
# Index 0 is treated as the negative class, index 1 as the positive class.
classes = ['lnc_e', 'lnc_p']
class_names = ['e-lncRNA', 'p-lncRNA']
class_suffix = '_noN'

# Sub-folder of ds_dir holding the train/val and test splits.
train_val_test_dir = 'train_val_test__80_20'

# sequence files (train + val)
neg_filename, pos_filename = [
    f'{ds_dir}{train_val_test_dir}/{cls}{class_suffix}_tr_val.csv' for cls in classes]

# sequence files (test)
test_neg_filename, test_pos_filename = [
    f'{ds_dir}{train_val_test_dir}/{cls}{class_suffix}_test.csv' for cls in classes]

# k-mer feature files (train + val): one file per k in 1..3 for each class
feature_filenames_class_0, feature_filenames_class_1 = [
    [f'{ds_dir}{train_val_test_dir}/{cls}{class_suffix}_{k}mer_features_tr_val.csv' for k in (1, 2, 3)]
    for cls in classes]

# k-mer feature files (test)
test_feature_filenames_class_0, test_feature_filenames_class_1 = [
    [f'{ds_dir}{train_val_test_dir}/{cls}{class_suffix}_{k}mer_features_test.csv' for k in (1, 2, 3)]
    for cls in classes]

# biological feature files (train + val); note these names do not use class_suffix
bio_feature_filenames_class_0, bio_feature_filenames_class_1 = [
    [f'{ds_dir}{train_val_test_dir}/{cls}_nosim.fa.matrix_tr_val.csv'] for cls in classes]

# biological feature files (test)
test_bio_feature_filenames_class_0, test_bio_feature_filenames_class_1 = [
    [f'{ds_dir}{train_val_test_dir}/{cls}_nosim.fa.matrix_test.csv'] for cls in classes]



# Load sequence data

In [None]:
def get_data(filename):
    """Read one example per line from a text file.

    Leading and trailing whitespace (including the newline) is stripped from
    every line; blank lines are kept as empty strings.

    Args:
        filename: path of the file to read.

    Returns:
        (num_examples, examples): the number of lines read and the list of
        stripped line strings, in file order.
    """
    with open(filename, 'r') as handle:
        examples = [line.strip() for line in handle]
    return len(examples), examples

# read sequence data (train + val)
num_neg, neg_examples = get_data(neg_filename)
print("num_neg:", num_neg)
num_pos, pos_examples = get_data(pos_filename)
print("num_pos:", num_pos)

if do_horz_flip:
    # NOTE(review): this only appends a second, identical copy of every
    # example; no reversal of the sequences happens in this cell.
    num_neg *= 2
    neg_examples.extend(neg_examples)

    num_pos *= 2
    pos_examples.extend(pos_examples)
    print("After augmentation")
    print("num_neg:", num_neg)
    print("num_pos:", num_pos)

# combine negative and positive examples (sequences): negatives first, then positives
sequences_DS = neg_examples + pos_examples
labels_DS = np.array([0]*num_neg + [1]*num_pos)

print("# of sequences:", len(sequences_DS), "\n# of labels:", len(labels_DS), "\nlength of sequences:", len(sequences_DS[0]))

idx_neg = np.arange(num_neg)
# positives sit after the negatives in the concatenated array [neg, pos], so shift their indices
idx_pos = num_neg + np.arange(num_pos)

print("index sizes:", len(idx_neg), len(idx_pos))


if use_testset:
    print()

    # read test sequence data
    test_num_neg, test_neg_examples = get_data(test_neg_filename)
    print("test_num_neg:", test_num_neg)
    test_num_pos, test_pos_examples = get_data(test_pos_filename)
    print("test_num_pos:", test_num_pos)

    if do_horz_flip:
        test_num_neg *= 2
        test_neg_examples.extend(test_neg_examples)

        test_num_pos *= 2
        test_pos_examples.extend(test_pos_examples)
        print("After augmentation")
        # report the test counts here (the train counts were printed by mistake before)
        print("test_num_neg:", test_num_neg)
        print("test_num_pos:", test_num_pos)

    # combine negative and positive test examples (sequences)
    test_sequences_DS = test_neg_examples + test_pos_examples
    test_labels_DS = np.array([0]*test_num_neg + [1]*test_num_pos)

    print("# of test sequences:", len(test_sequences_DS), "\n# of test labels:", len(test_labels_DS), "\nlength of test sequences:", len(test_sequences_DS[0]) )

    test_idx_neg = np.arange(test_num_neg)
    # same index shift as for the train + val data
    test_idx_pos = test_num_neg + np.arange(test_num_pos)
    print("test index size", len(test_idx_neg), len(test_idx_pos))

    
# with open(neg_filename,'r') as nf:
#     neg_examples = [seq.strip() for seq in nf.readlines()]
#     num_neg = len(neg_examples) # number of negative examples
#     print(num_neg)

# with open(pos_filename,'r') as pf:
#     pos_examples = [seq.strip() for seq in pf.readlines()]
#     num_pos = len(pos_examples) # number of positive examples
#     print(num_pos)    

# sequences_DS = neg_examples + pos_examples 
# labels_DS = np.array([0]*num_neg + [1]*num_pos)   
# len(sequences_DS), len(labels_DS)



# Load features (if used)

In [None]:
def _read_feature_matrix(filenames, duplicate_rows=False, **read_csv_kwargs):
    """Load header-less CSV feature files and join them column-wise.

    Args:
        filenames: paths of the files to read; all must have the same number of rows.
        duplicate_rows: when True, the joined matrix is stacked on top of a
            copy of itself (row count doubles), mirroring how the sequence
            lists are doubled when `do_horz_flip` is set.
        **read_csv_kwargs: extra keyword arguments forwarded to `pd.read_csv`
            (e.g. `sep="\t"`).

    Returns:
        2-D numpy array with one row per example.
    """
    matrix = np.concatenate(
        [pd.read_csv(filename, header=None, index_col=False, **read_csv_kwargs).to_numpy()
         for filename in filenames],
        axis=1)
    if duplicate_rows:
        matrix = np.concatenate([matrix, matrix], axis=0)
    return matrix


if use_features:

    if use_comp_features:
        # kmers (train + val); these files are read as tab-separated
        features_class0 = _read_feature_matrix(feature_filenames_class_0, do_horz_flip, sep="\t")
        features_class1 = _read_feature_matrix(feature_filenames_class_1, do_horz_flip, sep="\t")

        neg_features = features_class0
        pos_features = features_class1

        # combine negative and positive examples into a single array (negatives first)
        features_DS = np.concatenate([features_class0, features_class1], axis = 0)

        print("neg features shape:",neg_features.shape, "\npos features shape:",pos_features.shape, "\ncombined features shape:",features_DS.shape)
        print()

    if use_bio_features:
        # bio features (train + val); read with pandas' default separator
        bio_features_class0 = _read_feature_matrix(bio_feature_filenames_class_0, do_horz_flip)
        bio_features_class1 = _read_feature_matrix(bio_feature_filenames_class_1, do_horz_flip)

        neg_bio_features = bio_features_class0
        pos_bio_features = bio_features_class1

        # combine negative and positive examples into a single array (negatives first)
        bio_features_DS = np.concatenate([bio_features_class0, bio_features_class1], axis = 0)

        print("neg bio features shape:",neg_bio_features.shape, "\npos bio features shape:",pos_bio_features.shape, "\ncombined bio features shape:",bio_features_DS.shape)
        print()

    if use_testset:

        if use_comp_features:
            # kmers (test)
            test_features_class0 = _read_feature_matrix(test_feature_filenames_class_0, do_horz_flip, sep="\t")
            test_features_class1 = _read_feature_matrix(test_feature_filenames_class_1, do_horz_flip, sep="\t")

            test_neg_features = test_features_class0
            test_pos_features = test_features_class1

            # combine negative and positive examples into a single array
            test_features_DS = np.concatenate([test_features_class0, test_features_class1], axis = 0)
            print("test neg features shape:",test_neg_features.shape, "\ntest pos features shape:",test_pos_features.shape, "\ncombined test features shape:",test_features_DS.shape)
            print()

        if use_bio_features:
            # bio features (test)
            test_bio_features_class0 = _read_feature_matrix(test_bio_feature_filenames_class_0, do_horz_flip)
            test_bio_features_class1 = _read_feature_matrix(test_bio_feature_filenames_class_1, do_horz_flip)

            test_neg_bio_features = test_bio_features_class0
            test_pos_bio_features = test_bio_features_class1

            # combine negative and positive examples into a single array
            test_bio_features_DS = np.concatenate([test_bio_features_class0, test_bio_features_class1], axis = 0)

            print("test neg bio features shape:",test_neg_bio_features.shape, "\ntest pos bio features shape:",test_pos_bio_features.shape, "\ntest combined bio features shape:",test_bio_features_DS.shape)
            print()

In [None]:
# labels_DS[:10],labels_DS[18487-10:18487+10],labels_DS[-10:],len(labels_DS) 

In [None]:
# def compute_kmer_counts(seq, k, do_sliding = True):
#     alphabet = list(set(list(seq)))
#     alphabet.sort()

#     kmers_list = []
#     kmers_list.append(alphabet)
    
#     for k_idx in range(k-1):
#         kmers_list.append([a+b for a in kmers_list[k_idx] for b in alphabet ])

#     seq_lst = list(seq)
#     subs = ["".join(seq_lst[idx:idx+k]) for idx in range(len(seq)-k+1)]
#     kmers = kmers_list[k-1]
#     kmers_counts = [subs.count(kmer) for kmer in kmers]
    
#     return kmers_counts

# results = []
# for idx in range(len(sequences_DS)):#np.random.randint(0, len(sequences_DS), size= (20,1) ):
#     #idx = idx.item()
#     seq = sequences_DS[idx]
#     kmer_counts1 = compute_kmer_counts(seq[:1000],3)
#     kmer_counts2 = compute_kmer_counts(seq[1000:],3)
#     #print(kmer_counts1, "||", kmer_counts2, "||",features_DS[idx,:][40:])
#     result = np.all(np.array(kmer_counts1+kmer_counts2) == np.array(features_DS[idx,40:].squeeze())) 
#     if(not result):
#         print(idx)
#         print(kmer_counts1)
#         print(kmer_counts2)
#         print(seq)
#         print(features_DS[idx,40:72].squeeze())
#         print(features_DS[idx,72:].squeeze())
#         print()
#     results.append(result)

In [None]:
# np.all(results)

## Subsample dataset (if needed)

In [None]:
# Shuffle the negative and positive indices independently. The negatives are
# permuted first; keep this order so the draws from NumPy's global RNG stay the same.
idx_neg_shuf = np.random.permutation(idx_neg)
idx_pos_shuf = np.random.permutation(idx_pos)

# Class-size ratio of the loaded dataset, expressed as larger/smaller (so it is >= 1).
DS_class_balance_ratio = len(idx_neg_shuf)/len(idx_pos_shuf)
DS_class_balance_ratio = 1/DS_class_balance_ratio if DS_class_balance_ratio < 1 else DS_class_balance_ratio

# Cap the requested ratio at the ratio the dataset actually offers.
if DS_class_balance_ratio < CLASS_BALANCE_RATIO:
    effective_class_balance_ratio = DS_class_balance_ratio
else:
    effective_class_balance_ratio = CLASS_BALANCE_RATIO

print("class balance ratio in dataset:", DS_class_balance_ratio)
print("specified class balance ratio:", CLASS_BALANCE_RATIO)
print("effective class balance ratio being used :", effective_class_balance_ratio)
print()

if use_full_DS:
    # keep everything: shuffled negatives followed by shuffled positives
    merged_index = np.concatenate([idx_neg_shuf, idx_pos_shuf], axis = 0)

    print("no subsampling, size of neg :", len(idx_neg_shuf))
    print("no subsampling, size of pos :", len(idx_pos_shuf))
    print( 'no subsampling, length of merged: ' , len(merged_index) )
    print()
else:
    # keep the smaller class whole and truncate the larger one to ratio * (smaller class size)
    if num_neg <= num_pos:
        idx_neg_selected = idx_neg_shuf
        idx_pos_selected = idx_pos_shuf[:int(effective_class_balance_ratio*num_neg)]
    else:
        idx_neg_selected = idx_neg_shuf[:int(effective_class_balance_ratio*num_pos)]
        idx_pos_selected = idx_pos_shuf

    # idx_pos already carries the offset into the combined [neg, pos] arrays,
    # so no further index shift is needed here.
    merged_index = np.concatenate([idx_neg_selected, idx_pos_selected], axis = 0)

    # each label now reports its own class (the two were swapped before)
    print("after subsampling, size of neg :", len(idx_neg_selected))
    print("after subsampling, size of pos :", len(idx_pos_selected))
    print( 'len of merged : ' , len(merged_index) )
    print()


# perform subsampling: pick the selected sequences and labels in merged_index order
sequences, labels = zip(*[ (sequences_DS[i], labels_DS[i]) for i in merged_index ])
labels = np.array(labels)

print('#seq before subsampling:' , len(sequences_DS) )
print('#labels before subsampling:' , len(labels_DS) )
print()
print('#seq after subsampling:' , len(sequences) )
print('#labels after subsampling:' , len(labels) )
print()


if use_testset:
    # no subsampling for test set
    test_sequences, test_labels = test_sequences_DS, test_labels_DS
    test_labels = np.array(test_labels)

if use_features:

    # rows of the feature matrices are selected with the same index as the sequences
    if use_comp_features:
        features = features_DS[merged_index,:]
        print('#features before subsampling:' , len(features_DS), '#features after subsampling:' , len(features))
    else:
        features = None

    if use_bio_features:
        bio_features = bio_features_DS[merged_index,:]
        print('#bio features before subsampling:' , len(bio_features_DS) , '#bio features after subsampling:' , len(bio_features) )
    else:
        bio_features = None

    if use_testset:
        # no subsampling for test set
        test_features = test_features_DS if use_comp_features else None
        test_bio_features = test_bio_features_DS if use_bio_features else None

In [None]:
# np.all(np.sort(sequences_DS, axis = 0)==np.sort(sequences, axis = 0))

### NOT ACCURATE ANYMORE: sanity check that the selected positive examples truly come from the original positive examples loaded from disk

In [None]:
# tmp_neg = [sequences[i] for i in range(len(labels)) if labels[i] == 0]
# tmp_pos = [sequences[i] for i in range(len(labels)) if labels[i] == 1]
# print(len(tmp_pos), len(tmp_neg))

# print(np.all([s in neg_examples for s in tmp_neg]))
# print(np.all([s in pos_examples for s in tmp_pos]))

In [None]:
# Collect every distinct character that occurs anywhere in the full
# (train + val) dataset, then sort it to get a stable alphabet order.
alphabet = set()
for seq in sequences_DS:
    alphabet.update(seq)

unique_DNAs = sorted(alphabet)
num_DNAs = len(unique_DNAs)
print("All nucleotides: ")
print(unique_DNAs)
print("# unique nucleaotides:",num_DNAs)

In [None]:
# lengths of the (train + val) sequences
sequence_lens = np.array([len(x) for x in sequences])

print("# of different sample lengths:",len(np.unique(sequence_lens)))
print("number of samples:",len(sequence_lens))
print("length of shortest sample: ",min(sequence_lens))
print("length of longest sample: ",max(sequence_lens))
print()

if use_testset:
    # lengths of the test sequences
    test_sequence_lens = np.array([len(x) for x in test_sequences])

    # report the test-set statistics (the train statistics were printed again by mistake before)
    print("# of different test sample lengths:",len(np.unique(test_sequence_lens)))
    print("number of test samples:",len(test_sequence_lens))
    print("length of shortest test sample: ",min(test_sequence_lens))
    print("length of longest test sample: ",max(test_sequence_lens))

In [None]:
# plot the sorted lengths of the (train + val) sequences
plt.plot(sorted(sequence_lens))
plt.title("Sequence lengths (train + val), sorted")
plt.xlabel("sequence rank")
plt.ylabel("sequence length")
plt.show()

if use_testset:
    # plot the sorted lengths of the test sequences in a separate, labelled figure
    plt.plot(sorted(test_sequence_lens))
    plt.title("Sequence lengths (test), sorted")
    plt.xlabel("sequence rank")
    plt.ylabel("sequence length")
    plt.show()

In [None]:
# code borrowed from lncNet repository for integer coding and bucketing

class CharacterTable(object):
    '''
    Two-way lookup between characters and integer codes.

    The distinct characters of `chars` are sorted and numbered 0, 1, 2, ...
    in that order. This notebook builds the table from "0" followed by the
    sequence alphabet; "0" sorts before the uppercase letters and therefore
    receives code 0, the value used for padding.

    Attributes:
        chars: sorted list of the distinct characters.
        char_indices: dict mapping character -> integer code.
        indices_char: dict mapping integer code -> character.
    '''
    def __init__(self, chars):
        self.chars = sorted(set(chars))
        self.char_indices = {ch: code for code, ch in enumerate(self.chars)}
        self.indices_char = dict(enumerate(self.chars))

    def encode(self, l):
        '''Return the integer codes of the characters of `l` as a 1-D int32 array.

        Raises KeyError for a character that is not in the table.
        '''
        codes = [self.char_indices[ch] for ch in l]
        return np.array(codes, dtype=np.int32)

    def decode(self, X, calc_argmax=True):
        '''Turn codes back into a string.

        With `calc_argmax=True` (default) `X` is reduced with argmax over its
        last axis first (e.g. one-hot rows or per-character scores); with
        `calc_argmax=False` `X` must already be an iterable of integer codes.
        '''
        codes = X.argmax(axis=-1) if calc_argmax else X
        return ''.join(self.indices_char[code] for code in codes)

def create_index_buckets(sequence_lens, bucket_high_lims):
    """Group sequence indices into length buckets.

    Each sequence index goes into the first bucket (in the order given by
    `bucket_high_lims`) whose upper limit is >= the sequence length. A
    sequence longer than every limit is not placed in any bucket.

    Args:
        sequence_lens: iterable of sequence lengths; position i is sequence i.
        bucket_high_lims: iterable of inclusive upper length limits, one per bucket.

    Returns:
        List with one list of sequence indices per bucket (possibly empty).
    """
    buckets = [[] for _ in bucket_high_lims]

    for seq_idx, seq_len in enumerate(sequence_lens):
        # position of the first bucket that can hold this length, if any
        target = next(
            (b_id for b_id, limit in enumerate(bucket_high_lims) if seq_len <= limit),
            None)
        if target is not None:
            # store the index of the data point, not the data itself
            buckets[target].append(seq_idx)

    return buckets

def seq2intcoding(seqs, ctable):
    """Integer-encode every sequence in `seqs` with `ctable.encode`.

    Args:
        seqs: iterable of sequences (e.g. strings).
        ctable: object whose `encode(seq)` returns a 1-D array of integer codes.

    Returns:
        * a regular 2-D array (one row per sequence) when all encoded
          sequences have the same length, or
        * a 1-D object array holding one code array per sequence when the
          lengths differ. Built explicitly because current NumPy versions
          refuse to create such a ragged array implicitly via `np.array`.
    """
    encoded = [ctable.encode(seq) for seq in seqs]

    # equal-length (or empty) input: a plain rectangular array
    if len({len(codes) for codes in encoded}) <= 1:
        return np.array(encoded)

    # ragged input: fill the object array element by element so NumPy never
    # tries to broadcast the rows into a rectangular shape
    ragged = np.empty(len(encoded), dtype=object)
    for row, codes in enumerate(encoded):
        ragged[row] = codes
    return ragged

# Do index bucketing

In [None]:
def _keep_nonempty_buckets(idxs_in_buckets, bucket_high_lims):
    """Drop buckets that received no sequence.

    Args:
        idxs_in_buckets: list with one list of data indices per bucket.
        bucket_high_lims: upper length limit of each bucket, same order.

    Returns:
        (used_bucket_idxs, kept_buckets, kept_high_lims): tuple of the
        original positions of the non-empty buckets, tuple of their index
        lists, and a 1-D numpy array with their upper limits.

    Raises:
        ValueError: if every bucket is empty.
    """
    used = [b_id for b_id, members in enumerate(idxs_in_buckets) if members]
    if not used:
        raise ValueError("no sequence fits into any bucket")
    kept = tuple(idxs_in_buckets[b_id] for b_id in used)
    # Index the 1-D limit array with a *list*: indexing with the tuple of
    # positions would be read as one multi-dimensional index and only works
    # when bucket 0 is the single used bucket.
    return tuple(used), kept, np.array(bucket_high_lims)[used]


# candidate upper limits: 500, 1000, ..., 100000
bucket_high_lims = [500*i for i in range(1,201)]

if use_testset:
    test_bucket_high_lims = [500*i for i in range(1,201)]

# do not pad sequences if option not enabled (for datasets with fixed-length sequences where the length does not match any bucket's high limit)
if not use_padding_for_bucketing:
    assert max(sequence_lens) == min(sequence_lens), "Sequence length varies; please enable padding for bucketing"
    bucket_high_lims = [max(sequence_lens)]   # fixed length sequences, so there will be only one bucket
    # the test-set variables only exist when the test set was loaded
    if use_testset:
        assert max(test_sequence_lens) == min(test_sequence_lens), "test set's sequence length varies; please enable padding for bucketing"
        test_bucket_high_lims = [max(test_sequence_lens)]

# "0" is added to the alphabet; it sorts first and so gets integer code 0
chars= "0" + "".join(unique_DNAs)
ctable = CharacterTable(chars)


# integer coded sequences
seq_int_coded = seq2intcoding(sequences, ctable)
# create buckets containing data indices (one list of indices per bucket)
idxs_in_buckets = create_index_buckets(sequence_lens, bucket_high_lims)
# remove buckets with no elements and keep only the high limits of the non-empty ones
used_bucket_idxs, idxs_in_buckets, bucket_high_lims = _keep_nonempty_buckets(idxs_in_buckets, bucket_high_lims)


# for test set
if use_testset:
    # integer coded sequences (same character table as the training data)
    test_seq_int_coded = seq2intcoding(test_sequences, ctable)
    # create buckets containing data indices (one list of indices per bucket)
    test_idxs_in_buckets = create_index_buckets(test_sequence_lens, test_bucket_high_lims)
    # remove buckets with no elements and keep only the high limits of the non-empty ones
    test_used_bucket_idxs, test_idxs_in_buckets, test_bucket_high_lims = _keep_nonempty_buckets(test_idxs_in_buckets, test_bucket_high_lims)


# Do sequence and feature (if being used) bucketing along with one-hot encoding of the sequences

In [None]:
def create_data_buckets(X, Y, idxs_in_buckets, bucket_high_lims, sequence_lens, features = None, bio_features = None):
    """
    Build, per non-empty bucket, the padded one-hot sequences (plus optional
    feature arrays) and the matching labels.

    params:
    X: array of integer coded sequences, indexable with an integer index array;
       code 0 is treated as padding
    Y: array of labels, aligned with X
    idxs_in_buckets: array-like containing the list of data indices in each bucket
    bucket_high_lims: maximum length of sequences in each bucket (same order)
    sequence_lens: lengths of sequences in X (currently unused; kept so
                   existing callers do not have to change)
    features: optional 2-D array of per-example features aligned with X, or None
    bio_features: optional 2-D array of per-example features aligned with X, or None

    returns:
    (X_buckets_padded, Y_buckets), two lists with one entry per non-empty bucket.
    Each X entry is a tuple:
      * `(one_hot,)` when both `features` and `bio_features` are None, else
      * `(one_hot, features_expanded, bio_features_expanded)`, where a missing
        feature set is represented by an empty array of shape (0, 0) and a
        present one gets a trailing axis of length 1.
    `one_hot` is float32 with shape (n_in_bucket, padded_len, n_codes); padded
    positions are all-zero rows because the channel of code 0 is not emitted.
    """

    # side features are attached as soon as at least one of the two sets is given
    has_side_features = features is not None or bio_features is not None

    X_buckets_padded = []
    Y_buckets = []

    for b_id, idxs_in_bucket in enumerate(idxs_in_buckets):

        # ignore empty buckets
        if len(idxs_in_bucket) == 0:
            continue

        # indices of the data points in this bucket
        x_idxs = np.array(idxs_in_bucket)

        X_bucket = X[x_idxs]
        Y_bucket = Y[x_idxs]

        rows = [torch.Tensor(x) for x in X_bucket]

        # in pytorch, in order to pad a list of sequences to a specific length,
        # we need a (possibly dummy) sequence of that length
        rows.append(torch.zeros((bucket_high_lims[b_id],)))
        padded = pad_sequence(rows, batch_first = True, padding_value = 0.0).numpy()

        # drop the dummy last row again
        padded = padded[:-1]

        # one-hot encoding over the codes 1..max; code 0 (padding) gets no
        # channel, so padded positions become all-zero vectors.
        # NOTE(review): the number of channels is taken from the largest code
        # present in *this* bucket, so a bucket that lacks the highest code
        # would get fewer channels than the others -- confirm this cannot happen.
        one_hot = (np.arange(1, padded.max()+1) == padded[:,:,None]).astype(dtype='float32')

        if has_side_features:

            if features is not None:
                # add a last dimension - feature dimension (length = 1)
                features_expanded = np.expand_dims(features[x_idxs,:],-1)
            else:
                features_expanded = np.empty((0,0))

            if bio_features is not None:
                bio_features_expanded = np.expand_dims(bio_features[x_idxs,:],-1)
            else:
                bio_features_expanded = np.empty((0,0))

            bucket_inputs = ( one_hot, features_expanded, bio_features_expanded )

            print(bucket_inputs[0].shape, bucket_inputs[1].shape, bucket_inputs[2].shape)

        else:
            bucket_inputs = ( one_hot, )

        X_buckets_padded.append(bucket_inputs)
        Y_buckets.append(Y_bucket)

    return X_buckets_padded, Y_buckets

In [None]:
# The feature arrays are only handed over when feature usage is switched on;
# otherwise create_data_buckets falls back to its defaults (None, None).
train_side_features = (features, bio_features) if use_features else ()
X_buckets, Y_buckets = create_data_buckets(
    seq_int_coded, labels, idxs_in_buckets, bucket_high_lims, sequence_lens,
    *train_side_features)

if use_testset:
    test_side_features = (test_features, test_bio_features) if use_features else ()
    test_X_buckets, test_Y_buckets = create_data_buckets(
        test_seq_int_coded, test_labels, test_idxs_in_buckets, test_bucket_high_lims, test_sequence_lens,
        *test_side_features)
    
    

In [None]:
# for b_id in range(len(X_buckets)):
#     idxs_bucket = idx_buckets[b_id]
#     for rand_idx in np.random.randint(len(X_buckets[b_id]), size = 10):
#         print(X_buckets[b_id][rand_idx].shape, sequence_lens[idxs_bucket[rand_idx]])
        
#     print()

In [None]:
# len(X_buckets), len(idxs_in_buckets), torch.cuda.is_available()

In [None]:
# bucket_content = X_buckets[0]
# seq_elements = bucket_content[0]
# seq_element = seq_elements[0,:]

# len(bucket_content), seq_element.shape

## Sanity check
[Convert the one-hot encoded sequences back to DNA strings and compare each against the corresponding original from the dataset]


In [None]:
# # sanity check [converting one-hot encoded sequences to DNA seq and checking against the corresponding original from the dataset]

# def compute_kmer_counts(seq, k, alphabet, do_sliding = True):

#     kmers_list = []
#     kmers_list.append(alphabet)
    
#     for k_idx in range(k-1):
#         kmers_list.append([a+b for a in kmers_list[k_idx] for b in alphabet ])

#     seq_lst = list(seq)
#     subs = ["".join(seq_lst[idx:idx+k]) for idx in range(len(seq)-k+1)]
#     kmers = kmers_list[k-1]
#     kmers_counts = [subs.count(kmer) for kmer in kmers]
    
#     return kmers_counts

# data_check_arr = []
# feature_check_arr = []
# incorrect_feature_match_indices = []
# incorrect_feature_match_seqs = []

# for b_idx in range(len(X_buckets)):
#     x_bucket_padded = X_buckets[b_idx]

#     if(use_features):
#         features_bucket = x_bucket_padded[1]
#         x_bucket_padded = x_bucket_padded[0]
        
#     y_bucket = Y_buckets[b_idx]
#     idx_bucket = idxs_in_buckets[b_idx]
    
#     for seq_idx in range(x_bucket_padded.shape[0]):
#         seq_oh = x_bucket_padded[seq_idx,:,:]
        
#         seq_str = "".join([unique_DNAs[dna_idx] for dna_idx in np.argmax(seq_oh, axis = -1)]) # convert from one-hot to DNA
#         seq_str = seq_str[:sequence_lens[idx_bucket[seq_idx]]]  # get rid of the padding
        
#         seq = sequences[idx_bucket[seq_idx]]  # original seq retrieved according to saved indices of the one-hot data
        
#         if(use_features):
#             idx_range = [[0,11],[12,59],[60,251]]
#             feature_check_acc = []
#             #seq = seq[-500:] + seq[:-500]
#             for k in range(1,4):
#                 kmer_counts1 = compute_kmer_counts(seq_str[:200], k, unique_DNAs)    
#                 kmer_counts2 = compute_kmer_counts(seq_str[200:400], k, unique_DNAs)    
#                 kmer_counts3 = compute_kmer_counts(seq_str[400:], k, unique_DNAs)    
            
#                 features_computed= np.array(kmer_counts1+kmer_counts2+kmer_counts3)
#                 features_from_file = features_bucket[seq_idx,idx_range[k-1][0]:idx_range[k-1][1]+1].squeeze()
#                 feature_check = np.all( features_computed == features_from_file)
#                 #print(feature_check)
#                 #print(features_computed)
#                 #print(features_from_file)
#                 #print()
#                 #print(features_bucket[seq_idx,idx_range[k-1][0]:idx_range[k-1][1]+1].squeeze())
#                 #print()
#                 feature_check_acc.append(feature_check)
#             #print()    
#             feature_check_arr.append(np.all(feature_check_acc))
            
#             if not np.all(feature_check_acc):
#                 #print(seq_idx)
#                 incorrect_feature_match_indices.append(seq_idx)
#                 incorrect_feature_match_seqs.append(seq)
#                 for k in range(1,4):
#                     kmer_counts1 = compute_kmer_counts(seq_str[:200], k)    
#                     kmer_counts2 = compute_kmer_counts(seq_str[200:400], k)    
#                     kmer_counts3 = compute_kmer_counts(seq_str[400:], k)    

#                     features_computed= np.array(kmer_counts1+kmer_counts2+kmer_counts3)
#                     features_from_file = features_bucket[seq_idx,idx_range[k-1][0]:idx_range[k-1][1]+1].squeeze()
#                     feature_check = np.all( features_computed == features_from_file)
#                     feature_check_acc.append(feature_check)
#                     print(len(kmer_counts1), len(kmer_counts2), len(kmer_counts3),features_computed.shape, features_from_file.shape)
#                 print(seq)
#                 print(seq_oh.shape, len(seq))
#                 print()
#         data_check_arr.append(seq == seq_str)
        
        
# print("One-hot encoding and sequence bucketing is correct:", np.all(np.array(data_check_arr)))

# if(use_features):
#     print("feature computation and bucketing is correct:", np.all(np.array(feature_check_arr)))

In [None]:
# s = "CATCACTATCATCATCATCATCACACCGCCACCATCACCACCACCACCATCACACCACCACCACCACCACCACCGTCACCATCACTATCATCATCATCACACCACCACCACCACCACCATCATCACTATCATCATCATCATCACCACACCACCACCATCACCACCACCACCATCACACCACCACCACCACTGTCACCACCACCACCACCACCATCACACCACCACCACCGCCACCGTCGTCATCACTATCATCATCATCATCACACCACCACCATCACCACCACCATCACACCACCACCACCACTACTGTCATCACTATCATCATCATCACACCACCACCATCACCATCATCACACCACCACCACCATCACCATCACCATCATCACACCACCACCACCATCACACCACCACCACCATCATTATCATCACCATCATCACACCACCACCACCACCATCACCATCACCATCATCACACCACCACCACCATTACCATCACCATCATCACACCACCACCACCACCATTACCATCACCATCATCACACCACCACCATCACCATCATCACACCACCACCACCATCACCATCATCACACCACCACCAT"
# kcount = compute_kmer_counts(s[400:], 1)

# print(s[400:])
# kcount, s[400:].count('G'), len(s[400:]), len(s),[s[400:].count(c) for c in "ACGT"]
# incorrects = [sequences[idxs_buckets[0][i]] for i in incorrect_feature_match_indices]
# [print(incorrect,"\n\n") for incorrect in incorrects]

In [None]:
# len(incorrect_feature_match_indices), incorrect_feature_match_seqs

In [None]:
## Verify that in a packed seq, batch_size equals the number of sequences having a valid input at (index+1) of the index batch_size in batch_sizes

# sequence_lens_np = np.array(sequence_lens)
# batch_size_idx = random.choice(range(len(X_buckets_packed[0].batch_sizes)))
# batch_size = X_buckets_packed[0].batch_sizes[batch_size_idx]
# batch_size.numpy() == np.sum(sequence_lens_np[idx_buckets[0]]>=(batch_size_idx+1))

# Metrics

In [None]:
class Metrics:
    """Binary-classification metrics with optional benchmark comparison.

    Labels are treated as 0 (negative class) and 1 (positive class).
    ``y_true`` / ``y_pred`` are expected to be NumPy arrays of hard class
    labels (``compute_accuracy`` relies on element-wise ``==`` and ``.sum()``).
    Metric values can be appended to per-metric history lists and printed next
    to the benchmark values supplied at construction time.
    """
    
    def __init__(self, benchmark_acc = None, benchmark_f1 = None, benchmark_spec = None, benchmark_sen = None, benchmark_auc = None, benchmark_mcc = None):
        """Store the optional benchmark values.

        A benchmark left as ``None`` is never reported as beaten and is printed
        as ``n/a``.
        """
        self.b_accuracy = benchmark_acc
        self.b_f1 = benchmark_f1
        self.b_sensitivity = benchmark_sen
        self.b_specificity = benchmark_spec
        self.b_auc_roc = benchmark_auc
        self.b_mcc = benchmark_mcc
        # Create the history lists up front so compute_metrics(store_vals=True)
        # works even when the caller never calls reset_history() explicitly.
        self.reset_history()
        
    def reset_history(self):
        """Clear every per-metric history list."""
        self.accuracies = []
        self.f1s = []
        self.recalls = []
        self.precisions = []
        self.sensitivity = []
        self.specificity = []
        self.auc_roc = []        
        self.tp_tn_fp_fn = []
        self.mcc = []

    @staticmethod
    def _confusion_counts(y_true, y_pred):
        """Return ``(tn, fp, fn, tp)`` for labels 0/1."""
        # labels=[0, 1] forces a 2x2 matrix; without it a single-class input
        # yields a 1x1 matrix and the 4-way unpack fails.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        return tn, fp, fn, tp

    @staticmethod
    def _mcc_from_counts(tn, fp, fn, tp):
        """Matthews correlation coefficient from confusion counts."""
        # Work in floats: the product of four int64 sums can overflow for large
        # evaluation sets. The 1e-8 keeps the division defined when a row or
        # column of the confusion matrix is empty.
        tn, fp, fn, tp = float(tn), float(fp), float(fn), float(tp)
        return (tp*tn - fp*fn) / (np.sqrt(  (tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)  ) + 1e-8) 

    @staticmethod
    def _fmt_benchmark(value):
        """Format a benchmark for printing; ``None`` becomes ``n/a``."""
        return 'n/a' if value is None else f'{value:.4f}'

    @staticmethod
    def _beats(value, benchmark):
        """True only when a benchmark is set and ``value`` is strictly greater."""
        return benchmark is not None and bool(value > benchmark)

    @classmethod
    def compute_mcc(cls, y_true, y_pred):
        """Return the Matthews correlation coefficient of the predictions."""
        tn, fp, fn, tp = cls._confusion_counts(y_true, y_pred)
        return cls._mcc_from_counts(tn, fp, fn, tp)
    
    @classmethod
    def compute_sensitivity(cls, y_true, y_pred):
        """Return tp/(tp+fn); NaN when there are no positive samples."""
        tn, fp, fn, tp = cls._confusion_counts(y_true, y_pred)
        sensitivity = tp/(tp+fn)
        return sensitivity
    
    @classmethod
    def compute_specificity(cls, y_true, y_pred):
        """Return tn/(fp+tn); NaN when there are no negative samples."""
        tn, fp, fn, tp = cls._confusion_counts(y_true, y_pred)
        specificity = tn/(fp+tn)
        return specificity
    
    @classmethod
    def compute_accuracy(cls, y_true, y_pred):
        """Return the fraction of positions where ``y_true == y_pred``."""
        accuracy = (y_true==y_pred).sum()/len(y_true)
        return accuracy
    
    def compute_metrics(self, y_true, y_pred, epoch, do_print = False, store_vals = False):
        """Compute all metrics for one set of predictions.

        Args:
            y_true: ground-truth labels (0/1).
            y_pred: predicted labels (0/1).
            epoch: only used in the printed header / "beaten" banner.
            do_print: print a report comparing against the benchmarks.
            store_vals: append each metric to the history lists.

        Returns:
            dict with keys ``accuracy``, ``specificity``, ``sensitivity``,
            ``f1_score``, ``auc_roc``, ``recall``, ``precision``,
            ``tp_fp_tn_fn`` (dict of counts) and ``mcc``.
        """
        accuracy = self.compute_accuracy(y_true, y_pred)
        tn, fp, fn, tp = self._confusion_counts(y_true, y_pred)

        specificity = tn/(fp+tn)
        sensitivity = tp/(tp+fn)
        f1 = f1_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        # ROC AUC is undefined when y_true holds a single class (scikit-learn
        # raises or warns depending on version); report NaN in that case.
        auc_roc = roc_auc_score( y_true, y_pred ) if len(np.unique(y_true)) > 1 else float('nan')
        mcc = self._mcc_from_counts(tn, fp, fn, tp)
        
        results = {'accuracy': accuracy, 'specificity':specificity, 'sensitivity':sensitivity, 'f1_score':f1, 'auc_roc':auc_roc, 'recall':recall, 
                   'precision': precision, 'tp_fp_tn_fn':{'tp':tp, 'fp':fp, 'tn':tn, 'fn':fn}, 'mcc': mcc }
        
        if(store_vals):
            self.accuracies.append(accuracy)
            self.specificity.append(specificity)
            self.sensitivity.append(sensitivity)
            self.f1s.append(f1)
            self.auc_roc.append(auc_roc)        
            self.recalls.append(recall)
            self.precisions.append(precision)
            # Note the tuple order (tp, tn, fp, fn) differs from the results dict.
            self.tp_tn_fp_fn.append((tp, tn, fp, fn)) 
            self.mcc.append(mcc) 
            
        if(do_print):
            is_beaten_acc = self._beats(accuracy, self.b_accuracy)
            is_beaten_sens = self._beats(sensitivity, self.b_sensitivity)
            is_beaten_spec = self._beats(specificity, self.b_specificity)
            is_beaten_mcc = self._beats(mcc, self.b_mcc)
                
            print(f'_________________________________________ METRICS for epoch {epoch} _______________________________________________________')
            print('accuracy '.ljust(16),f':{accuracy:.5f} ',f' bmark:({self._fmt_benchmark(self.b_accuracy)})',f' accuracy beaten:{is_beaten_acc}')
            print('sensitivity '.ljust(16),f':{sensitivity:.5f} ',f' bmark:({self._fmt_benchmark(self.b_sensitivity)})',f' sen.beaten:{is_beaten_sens}')
            print('specificity '.ljust(16),f':{specificity:.5f} ',f' bmark:({self._fmt_benchmark(self.b_specificity)})',f' spe.beaten:{is_beaten_spec}')
            print('mcc '.ljust(16),f':{mcc:.5f} ',f' bmark:({self._fmt_benchmark(self.b_mcc)})',f' mcc.beaten:{is_beaten_mcc}')
            
            print( '---------------------') 
            print(f'| tp:{tp} '.ljust(9),f'| fp:{fp}'.ljust(9),'|')
            print(f'| fn:{fn} '.ljust(9),f'| tn:{tn}'.ljust(9),'|')
            print('---------------------') 

            # F1 and ROC AUC are computed and returned but are neither printed
            # nor part of the "beaten all" condition.
            if(is_beaten_acc and is_beaten_sens and is_beaten_spec and is_beaten_mcc):
                print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
                print(f'##############################  BEATEN ALL at epoch:{epoch} ##############################################################')
                print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
        
        
        return results
    
    

# Learning rate scheduler

In [None]:
# https://discuss.pytorch.org/t/how-to-implement-torch-optim-lr-scheduler-cosineannealinglr/28797

class OneCycleCosineAnnealing:
    """Repeating one-cycle learning-rate schedule built from two cosine segments.

    A cycle lasts ``steps_per_cycle`` calls to :meth:`step`. The first
    ``int(steps_per_cycle*half_cycle_len_pct)`` steps are driven by a
    ``CosineAnnealingLR`` whose ``eta_min`` is ``max_lr`` (intended to ramp the
    rate up from ``start_lr``); the remaining steps are driven by a second
    ``CosineAnnealingLR`` whose ``eta_min`` is ``end_lr``. After every
    ``cycles_per_round`` completed cycles ``max_lr`` is multiplied by
    ``decay_rate``.

    Args:
        optim: optimizer whose param-group ``lr`` values are overwritten.
        steps_per_cycle: number of ``step()`` calls in one cycle.
        cycles_per_round: number of cycles between ``max_lr`` decays.
        max_lr: peak learning rate of a cycle.
        start_lr: initial rate; defaults to ``max_lr/25``.
        end_lr: rate at the end of a cycle; defaults to ``max_lr/(25*1e4)``.
        half_cycle_len_pct: fraction of a cycle spent in the rising segment.
        decay_rate: multiplier applied to ``max_lr`` once per round.

    Attributes:
        lrs: learning rate of the first param group, recorded at construction
            and after every ``step()``.

    NOTE(review): a ``half_cycle_len_pct`` that makes the rising or the falling
    segment zero steps long gives ``T_max = 0`` for one of the schedulers; that
    configuration looks unsupported -- confirm before using it.
    """
    
    def __init__(self, optim, steps_per_cycle, cycles_per_round, max_lr = 1e-3, start_lr = None, end_lr = None, half_cycle_len_pct = .3, decay_rate = 1.0):
        init_div_factor, last_div_factor = 25.0, 1e4
        self.optim = optim
        self.max_lr = max_lr
        self.decay_rate = decay_rate
        # one round is during a span for which the max lr is constant. It is decayed at the start of every round
        self.cycles_per_round = int(cycles_per_round) 
        self.start_lr = self.max_lr/init_div_factor if start_lr is None else start_lr
        self.end_lr = self.max_lr/(init_div_factor*last_div_factor) if end_lr is None else end_lr
        # The optimizer must already sit at start_lr when the schedulers are
        # created, because they pick up the current lr at construction.
        for param_group in self.optim.param_groups:
            param_group['lr'] = self.start_lr
        self.steps_per_cycle = int(steps_per_cycle)
        self.half_cycle_len_pct = half_cycle_len_pct
        self.peak_step_idx = int(self.steps_per_cycle*self.half_cycle_len_pct)
        print("self.start_lr",self.start_lr)
        print("self.max_lr",self.max_lr)
        print("self.end_lr",self.end_lr)
        
        self._build_segments()
        self.step_idx = 0
        self.cycle_idx = 0
        self.lrs = []
        self.lrs.append(self.optim.param_groups[0]['lr'])

    def _build_segments(self):
        """(Re)create the rising (coslr1) and falling (coslr2) cosine schedulers."""
        self.coslr1 = lr_scheduler.CosineAnnealingLR(self.optim, T_max = self.peak_step_idx, eta_min = self.max_lr)        
        self.coslr2 = lr_scheduler.CosineAnnealingLR(self.optim, T_max = self.steps_per_cycle - self.peak_step_idx, eta_min = self.end_lr)
        
    def step(self):
        """Advance the schedule by one step and record the resulting rate."""
        self.step_idx+=1
        # `<=` (not `<`): the rising scheduler must be stepped exactly T_max
        # times so the rate actually reaches max_lr, which leaves exactly T_max
        # steps for the falling scheduler instead of pushing it one step past
        # the end of its half period.
        if(self.step_idx <= self.peak_step_idx):
            self.coslr1.step()  
        elif(self.step_idx <= self.steps_per_cycle):
            self.coslr2.step()
        else:
            # Cycle finished: optionally decay the peak, rebuild both segments
            # and take the first step of the new cycle.
            self.cycle_idx +=1
            if(self.cycle_idx%self.cycles_per_round==0):
                self.max_lr*=self.decay_rate
            self._build_segments()
            self.coslr1.step()           
            self.step_idx = 1
            
        self.lrs.append(self.optim.param_groups[0]['lr'])
        
            
# Demo: trace the schedule on a throw-away one-parameter model and plot the
# learning rate over ten cycles' worth of steps.
model = torch.nn.Linear(1, 1)

max_lr = 1e-3
init_lr = max_lr/25.0
last_lr = max_lr/25e4

num_epochs = 50

optimizer = torch.optim.SGD(model.parameters(), lr = init_lr)
for param_group in optimizer.param_groups:
    param_group['lr'] = init_lr

scheduler = OneCycleCosineAnnealing(optimizer, num_epochs, 5, max_lr = 1e-3, decay_rate = .95)

# First entry is the rate before any step; one more entry per scheduler step.
lrs = [optimizer.param_groups[0]['lr']]
for _ in range(10*num_epochs):
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])

print("min:",min(lrs))
plt.plot(lrs)
plt.show()

# Network definitions

In [None]:
# class ListDataset(torch.utils.data.Dataset):
    
#     def __init__(self, data, label) -> None:
#         super().__init__()
#         self.data = data
#         self.label = label
        
#         assert len(self.data) == len(self.label)
    
#     def __getitem__(self, index):
#         return (self.data[index], self.label[index])
    
#     def __len__(self):
#         return len(self.data)
    
class LSTM_CNN_Net(nn.Module):
    """Two-layer bidirectional LSTM followed by a small 2-D CNN and an MLP head.

    The LSTM output of every time step (last LSTM layer) is treated as a
    single-channel image and passed through two conv / batch-norm / ReLU /
    dropout / max-pool stages, then flattened into two linear layers that emit
    2 logits.

    The flattened size is hard-coded to ``5*147*64``, so the input length and
    ``hidden_sz`` must reduce to a 147x5 feature map after the conv stack.

    Args:
        num_features: size of each time step's feature vector (LSTM input_size).
        seq_info: ``(hidden_sz, seq_len)`` pair; both are stored, only
            ``hidden_sz`` is used to build layers.
    """

    def __init__(self, num_features, seq_info):
        super().__init__()

        self.hidden_sz, self.seq_len = seq_info
        self.num_features = num_features
        self.num_layers = 2
        self.is_bidirectional = True

        # LSTM first; its per-step outputs feed the 2-D convolutions.
        self.lstm = nn.LSTM(self.num_features, self.hidden_sz,
                            num_layers = self.num_layers,
                            batch_first = True,
                            bidirectional = self.is_bidirectional)

        self.conv2D_1 = nn.Conv2d(1, 32, kernel_size = 3)
        self.batchnorm2D_1 = nn.BatchNorm2d(32)

        # One pooling module shared by both conv stages.
        self.maxpool2D_1 = torch.nn.MaxPool2d(kernel_size = 2)
        self.conv2D_2 = nn.Conv2d(32, 64, kernel_size = 5)
        self.batchnorm2D_2 = nn.BatchNorm2d(64)

        self.linear1 = nn.Linear(5*147*64, 48)
        self.linear2 = nn.Linear(48, 2)

        self.dropout1 = nn.Dropout(0.8)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x, seq_lengths):
        """Return logits of shape (batch, 2).

        ``x`` is (batch, seq_len, num_features) because the LSTM was built with
        ``batch_first=True``. ``seq_lengths`` is accepted but not used.
        """
        # lstm_states: (batch, seq_len, num_directions * hidden_size); the final
        # hidden / cell states are not needed.
        lstm_states, _ = self.lstm(x)

        # Add a channel axis so the state matrix can be convolved as an image.
        fmap = torch.unsqueeze(lstm_states, dim = 1)

        fmap = F.relu(self.batchnorm2D_1(self.conv2D_1(fmap)))
        fmap = self.maxpool2D_1(self.dropout2(fmap))

        fmap = F.relu(self.batchnorm2D_2(self.conv2D_2(fmap)))
        fmap = self.maxpool2D_1(self.dropout2(fmap))

        flat = fmap.contiguous().view(-1, 5*147*64)
        hidden = self.dropout1(F.relu(self.linear1(flat)))

        return self.linear2(hidden)

    def getname(self):
        """Return the model's display name."""
        return "LSTM_CNN_Net"
    
    
# The network
class RNA_Net(nn.Module):
    """Single-layer bidirectional LSTM classifier.

    The L2 norm of the LSTM outputs over the time axis (dim 1) is fed to a
    linear layer that emits 2 logits.

    Args:
        num_features: size of each time step's feature vector.
        hidden_sz: LSTM hidden size per direction.
    """
    
    def __init__(self, num_features, hidden_sz):
        super(RNA_Net, self).__init__()
        
        self.num_features = num_features
        self.hidden_sz = hidden_sz
        self.lstm = nn.LSTM(input_size = num_features, hidden_size = hidden_sz, batch_first = True, bidirectional = True)
        # Not used in forward(); kept so existing state_dicts (which contain
        # its parameters and buffers) still load.
        self.batchnrm = nn.BatchNorm1d(hidden_sz)
        self.linear2 = nn.Linear(in_features = 2*hidden_sz, out_features= 2)
        
    def forward(self, x, seq_lengths):
        """Return logits of shape (batch, 2).

        ``x`` is (batch, seq_len, num_features) (``batch_first=True``).
        ``seq_lengths`` is accepted but not used.
        """
        # Only the per-step outputs are needed; the final hidden / cell states
        # were previously concatenated and then discarded on every call.
        out, _ = self.lstm(x)        
        l2allstates = torch.sqrt(torch.sum(out**2, dim = 1))
        net_out = self.linear2(l2allstates.view(-1, 2*self.hidden_sz))
        
        return net_out

    
    
class CNPPNet(nn.Module):
    """1-D CNN built from four inception-style stages plus a linear head.

    Every stage applies three parallel convolutions (kernel sizes 3, 5 and 7,
    length-preserving padding) with ReLU, concatenates their 64+64+32 = 160
    channels, then applies dropout and a max-pool of 2. A final unpadded
    kernel-7 convolution, dropout and pooling feed two linear layers that
    emit 2 logits.

    The head is hard-coded to ``32*15`` inputs, so the input length must
    reduce to 15 positions after the conv stack.

    Args:
        num_features: channels per sequence position.
        hidden_sz: stored on the instance; not used to build layers.
    """

    def __init__(self, num_features, hidden_sz):
        super().__init__()

        self.hidden_sz = hidden_sz

        # inception layer 1
        self.conv11 = nn.Conv1d(num_features, 64, kernel_size = 3, padding = 1)
        self.conv12 = nn.Conv1d(num_features, 64, kernel_size = 5, padding = 2)
        self.conv13 = nn.Conv1d(num_features, 32, kernel_size = 7, padding = 3)

        # inception layer 2
        self.conv21 = nn.Conv1d(160, 64, kernel_size = 3, padding = 1)
        self.conv22 = nn.Conv1d(160, 64, kernel_size = 5, padding = 2)
        self.conv23 = nn.Conv1d(160, 32, kernel_size = 7, padding = 3)

        # inception layer 3
        self.conv31 = nn.Conv1d(160, 64, kernel_size = 3, padding = 1)
        self.conv32 = nn.Conv1d(160, 64, kernel_size = 5, padding = 2)
        self.conv33 = nn.Conv1d(160, 32, kernel_size = 7, padding = 3)

        # inception layer 4
        self.conv41 = nn.Conv1d(160, 64, kernel_size = 3, padding = 1)
        self.conv42 = nn.Conv1d(160, 64, kernel_size = 5, padding = 2)
        self.conv43 = nn.Conv1d(160, 32, kernel_size = 7, padding = 3)

        self.maxpool1 = torch.nn.MaxPool1d(kernel_size = 2)
        self.maxpool2 = torch.nn.MaxPool1d(kernel_size = 4)  # not used in forward()

        self.conv3 = nn.Conv1d(160, 32, kernel_size = 7)
        self.linear1 = nn.Linear(32*15, 16)
        self.linear2 = nn.Linear(16, 2)

        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.6)
        self.dropout3 = nn.Dropout(0.7)  # not used in forward()

    def _inception_stage(self, inp, conv_k3, conv_k5, conv_k7):
        """Run one stage: three parallel conv+ReLU branches, concat, dropout, pool."""
        branch_k3 = F.relu(conv_k3(inp))
        branch_k5 = F.relu(conv_k5(inp))
        branch_k7 = F.relu(conv_k7(inp))
        merged = torch.cat((branch_k3, branch_k5, branch_k7), 1)
        return self.maxpool1(self.dropout1(merged))

    def forward(self,x, features):
        """Return logits of shape (batch, 2).

        ``x`` is (batch, length, num_features); it is permuted to channels-first
        for Conv1d. ``features`` is accepted but not used.
        """
        stages = (
            (self.conv11, self.conv12, self.conv13),
            (self.conv21, self.conv22, self.conv23),
            (self.conv31, self.conv32, self.conv33),
            (self.conv41, self.conv42, self.conv43),
        )

        act = x.permute(0,2,1)
        for conv_k3, conv_k5, conv_k7 in stages:
            act = self._inception_stage(act, conv_k3, conv_k5, conv_k7)

        # Final convolution has no activation before dropout / pooling.
        act = self.maxpool1(self.dropout1(self.conv3(act)))

        # No non-linearity between the two linear layers, only dropout.
        hidden = self.dropout2(self.linear1(act.view(-1, 32*15)))
        return self.linear2(hidden)

    def getname(self):
        """Return the model's display name."""
        return "CNPPNet"
    
    
    
class CNPPNet_Hybrid(nn.Module):
    """Two-stage inception-style 1-D CNN fused with an MLP over extra features.

    Sequence branch: two stages of three parallel convolutions (kernel sizes
    3, 5, 7 with ReLU, concatenated to 160 channels), each followed by dropout
    and a max-pool of 4, then an unpadded kernel-7 convolution, dropout and
    pooling. Feature branch: two linear layers (168 -> 128 -> 64). Both
    branches are concatenated and passed through two linear layers that emit
    2 logits.

    The sizes ``32*29`` (flattened sequence branch) and ``168`` (flattened
    feature vector) are hard-coded, so inputs must match them.

    Args:
        num_features: channels per sequence position.
        hidden_sz: stored on the instance; not used to build layers.
    """
    
    def __init__(self, num_features, hidden_sz):
        super(CNPPNet_Hybrid, self).__init__()
        
        self.hidden_sz = hidden_sz
        # inception layer 1
        self.conv11 = nn.Conv1d(num_features,64, kernel_size= 3, padding=1)
        self.conv12 = nn.Conv1d(num_features,64, kernel_size= 5, padding=2)
        self.conv13 = nn.Conv1d(num_features,32, kernel_size= 7, padding=3)
        
        # inception layer 2
        self.conv21 = nn.Conv1d(160,64, kernel_size= 3, padding=1)
        self.conv22 = nn.Conv1d(160,64, kernel_size= 5, padding=2)
        self.conv23 = nn.Conv1d(160,32, kernel_size= 7, padding=3)
        
        self.maxpool1d = torch.nn.MaxPool1d(kernel_size=4)
        
        self.conv3 = nn.Conv1d(160,32, kernel_size= 7)
        
        # 32*29 flattened sequence activations + 64 from the feature branch.
        self.linear1 = nn.Linear(in_features = 32*29 + 64, out_features= 16)
        self.linear2 = nn.Linear(in_features = 16, out_features= 2)
 
        self.feature_linear1 = nn.Linear(in_features = 168, out_features= 128)
        self.feature_linear2 = nn.Linear(in_features = 128, out_features= 64)

        # These modules are not used in forward(), which applies functional
        # dropout with its own probabilities (0.4 / 0.5).
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.6)
        self.dropout3 = nn.Dropout(0.7)

    def _dropout(self, tensor, p):
        """Functional dropout that honours train()/eval() mode."""
        # F.dropout defaults to training=True; without passing self.training
        # it would keep dropping activations during evaluation.
        return F.dropout(tensor, p = p, training = self.training)
        
    def forward(self,seq, features):
        """Return logits of shape (batch, 2).

        ``seq`` is (batch, length, num_features) and is permuted to
        channels-first for Conv1d. ``features`` is a 3-D tensor that is
        permuted with (0, 2, 1) and flattened to rows of 168 values.
        """
        x = seq.permute(0,2,1)
        features = features.permute(0,2,1)
        
        out1 = F.relu(self.conv11(x))
        out2 = F.relu(self.conv12(x))
        out3 = F.relu(self.conv13(x))
        out = torch.cat((out1,out2,out3),1)
        out = self._dropout(out, 0.4)
        out = self.maxpool1d(out)
        
        out1 = F.relu(self.conv21(out))
        out2 = F.relu(self.conv22(out))
        out3 = F.relu(self.conv23(out))
        out = torch.cat((out1,out2,out3),1)
        out = self._dropout(out, 0.4)
        out = self.maxpool1d(out)

        out = self.conv3(out)
        out = self._dropout(out, 0.4)
        out = self.maxpool1d(out)
        
        # Feature branch: two linear layers with no activation in between.
        feature_out = self.feature_linear1(features.view(-1,168))
        feature_out = self.feature_linear2(feature_out)
        
        combined_out = torch.cat((out.view(-1, 32*29),feature_out),1)
        out = self.linear1(combined_out)
        
        out = self._dropout(out, 0.5)
        out = self.linear2(out)

        return out    
    
    
class VGGNet(nn.Module):
    """VGG-style stack of ten unpadded kernel-3 1-D convolutions plus an MLP head.

    Layout reference: https://peltarion.com/static/vgg_pa03.jpg

    Conv groups of 2, 2, 3 and 3 layers (64, 128, 256, 256 channels) are
    followed by max-pools of 2, 2, 4 and 4. The head (256*7 -> 64 -> 16 -> 2)
    hard-codes the flattened size, so the input length must reduce to 7
    positions after the conv stack.

    Args:
        num_features: channels per sequence position.
        hidden_sz: stored on the instance; not used to build layers.
    """
    
    def __init__(self, num_features, hidden_sz):
        super(VGGNet, self).__init__()
        
        self.hidden_sz = hidden_sz
        self.conv1 = nn.Conv1d(in_channels = num_features, out_channels = 64, kernel_size= 3)        
        self.conv2 = nn.Conv1d(in_channels = self.conv1.out_channels, out_channels = 64, kernel_size= 3)
        # Pooling modules are stateless, so one instance per kernel size is
        # shared by every group that needs it.
        self.maxpool1d_2 = torch.nn.MaxPool1d(kernel_size=2)
        
        self.conv3 = nn.Conv1d(in_channels = self.conv2.out_channels,out_channels = 128, kernel_size= 3)        
        self.conv4 = nn.Conv1d(in_channels = self.conv3.out_channels,out_channels = 128, kernel_size= 3)

        self.conv5 = nn.Conv1d(in_channels = self.conv4.out_channels,out_channels = 256, kernel_size= 3)        
        self.conv6 = nn.Conv1d(in_channels = self.conv5.out_channels,out_channels = 256, kernel_size= 3)
        self.conv7 = nn.Conv1d(in_channels = self.conv6.out_channels,out_channels = 256, kernel_size= 3)
        self.maxpool1d_4 = torch.nn.MaxPool1d(kernel_size=4)

        self.conv8 = nn.Conv1d(in_channels = self.conv7.out_channels,out_channels = 256, kernel_size= 3)        
        self.conv9 = nn.Conv1d(in_channels = self.conv8.out_channels,out_channels = 256, kernel_size= 3)
        self.conv10 = nn.Conv1d(in_channels = self.conv9.out_channels,out_channels = 256, kernel_size= 3)
        
        self.linear1 = nn.Linear(in_features = 256*7, out_features= 64)
        self.linear2 = nn.Linear(in_features = 64, out_features= 16)
        self.linear3 = nn.Linear(in_features = 16, out_features= 2)        
 
        self.dropout = nn.Dropout(p =0.5)
    
    def forward(self,x, features = None):
        """Return logits of shape (batch, 2).

        ``x`` is (batch, length, num_features) and is permuted to
        channels-first for Conv1d. ``features`` is optional and ignored; it
        exists so this model can be called with the same ``(x, features)``
        arguments as the other networks in this notebook.
        """
        x = x.permute(0,2,1)
        
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = self.maxpool1d_2(out)

        out = F.relu(self.conv3(out))
        out = F.relu(self.conv4(out))
        out = self.maxpool1d_2(out)
        
        out = F.relu(self.conv5(out))
        out = F.relu(self.conv6(out))
        out = F.relu(self.conv7(out))
        out = self.maxpool1d_4(out)

        out = F.relu(self.conv8(out))
        out = F.relu(self.conv9(out))
        out = F.relu(self.conv10(out))
        out = self.maxpool1d_4(out)
        
        out = self.linear1(out.view(-1, 256*7))
        out = F.relu(out)
        out = self.dropout(out)
        out = self.linear2(out)
        out = F.relu(out)
        out = self.dropout(out)
        out = self.linear3(out)

        return out

    
    
    
class FantomNet_Bio_kmer_MLP(nn.Module):
    """Sequence CNN and feature MLP whose 16-d embeddings are fused by a small head.

    Sequence branch: four unpadded 1-D convolutions (kernel sizes 7, 5, 3, 3;
    32/64/128/256 channels) with three max-pools of 4, then 256*8 -> 64 -> 16.
    Feature branch: num_features_bio -> 64 -> 16. The two 16-d vectors are
    concatenated and mapped 32 -> 16 -> 2.

    The flattened sequence size ``256*8`` is hard-coded, so the input length
    must reduce to 8 positions after the conv stack.

    Args:
        num_features: ``(num_features_seq, num_features_bio)`` pair -- channels
            per sequence position and length of the flattened feature vector.
        hidden_sz: stored on the instance; not used to build layers.
    """
    # https://peltarion.com/static/vgg_pa03.jpg
    
    def __init__(self, num_features, hidden_sz):
        super(FantomNet_Bio_kmer_MLP, self).__init__()
        
        num_features_seq, num_features_bio = num_features
        
        self.hidden_sz = hidden_sz
        self.conv1 = nn.Conv1d(in_channels = num_features_seq, out_channels = 32, kernel_size= 7)        
        self.conv2 = nn.Conv1d(in_channels = self.conv1.out_channels, out_channels = 64, kernel_size= 5)
        # Despite the names, both pooling modules use kernel_size=4.
        self.maxpool1d_2 = torch.nn.MaxPool1d(kernel_size=4)
        
        self.conv3 = nn.Conv1d(in_channels = self.conv2.out_channels,out_channels = 128, kernel_size= 3)        
        self.maxpool1d_4 = torch.nn.MaxPool1d(kernel_size=4)
        
        self.conv4 = nn.Conv1d(in_channels = self.conv3.out_channels,out_channels = 256, kernel_size= 3)        

        self.linear1 = nn.Linear(in_features = 256*8, out_features= 64)
        self.linear2 = nn.Linear(in_features = 64, out_features= 16)
        # linear3 / bio_linear3 are not used in forward(); they remain so
        # existing state_dicts (which contain their parameters) still load.
        self.linear3 = nn.Linear(in_features = 16, out_features= 2)        
 
        self.dropout = nn.Dropout(p =0.5)
    
        # Bio_MLP
        self.bio_linear1 = nn.Linear(in_features = num_features_bio, out_features= 64)
        self.bio_linear2 = nn.Linear(in_features = 64, out_features= 16)
        self.bio_linear3 = nn.Linear(in_features = 16, out_features= 2)        
 
        self.bio_dropout1 = nn.Dropout(p =0.7)
        self.bio_dropout2 = nn.Dropout(p =0.8)
    
        # The fusion layer consumes the sequence embedding (linear2) and the
        # feature embedding (bio_linear2); size it from both branches rather
        # than from bio_linear2 twice.
        self.combined_linear1 = nn.Linear(self.linear2.out_features + self.bio_linear2.out_features, 16)
        self.combined_linear2 = nn.Linear(16, 2)
        
    
    def forward(self, x, z):
        """Return logits of shape (batch, 2).

        ``x`` is (batch, length, num_features_seq), permuted to channels-first
        for Conv1d. ``z`` is a 3-D feature tensor that is permuted with
        (0, 2, 1) and flattened per sample to ``num_features_bio`` values.
        """
        x = x.permute(0,2,1)
        z = z.permute(0,2,1)

        out_seq = self.conv1(x)
        out_seq = F.relu(out_seq)
        out_seq = self.conv2(out_seq)
        out_seq = F.relu(out_seq)
        out_seq = self.maxpool1d_2(out_seq)

        # conv3 and conv4 are followed directly by pooling (no ReLU).
        out_seq = self.conv3(out_seq)
        out_seq = self.maxpool1d_4(out_seq)        
        
        out_seq = self.conv4(out_seq)
        out_seq = self.maxpool1d_2(out_seq)        

        out_seq = self.linear1(out_seq.view(-1, 256*8))
        out_seq = F.relu(out_seq)
        out_seq = self.dropout(out_seq)
        out_seq = self.linear2(out_seq)
        out_seq = F.relu(out_seq)
        out_seq = self.dropout(out_seq)

        out_bio = self.bio_linear1(z.contiguous().view(z.shape[0], -1))
        out_bio = F.relu(out_bio)
        out_bio = self.bio_dropout1(out_bio)
        out_bio = self.bio_linear2(out_bio)
        out_bio = F.relu(out_bio)
        out_bio = self.bio_dropout2(out_bio)
        
        out = torch.cat([out_seq, out_bio], dim = 1)
        out = self.combined_linear1(out)
        out = F.relu(out)
        out = self.dropout(out)    
        out = self.combined_linear2(out)
        
        return out
    
    

    
class Bio_kmer_MLP(nn.Module):
    """Three-layer MLP (544 -> 64 -> 16 -> 2) over the flattened feature tensor.

    The input width is hard-coded to 544; neither constructor argument is used.
    """
    # https://peltarion.com/static/vgg_pa03.jpg

    def __init__(self, num_features, hidden_sz):
        super().__init__()

        self.linear1 = nn.Linear(544, 64)
        self.linear2 = nn.Linear(64, 16)
        self.linear3 = nn.Linear(16, 2)

        self.dropout1 = nn.Dropout(p = 0.7)
        self.dropout2 = nn.Dropout(p = 0.8)

    def forward(self,x, z):
        """Return logits of shape (batch, 2).

        ``x`` is accepted but not used. ``z`` is a 3-D tensor that is permuted
        with (0, 2, 1) and flattened per sample to 544 values.
        """
        flat = z.permute(0,2,1).contiguous().view(z.shape[0], -1)

        hidden = self.dropout1(F.relu(self.linear1(flat)))
        hidden = self.dropout2(F.relu(self.linear2(hidden)))

        return self.linear3(hidden)
    

class FantomNet(nn.Module):
    """Four-layer 1-D CNN with a three-layer MLP head emitting 2 logits.

    Unpadded convolutions with kernel sizes 7, 5, 3, 3 (32/64/128/256
    channels), each followed by ReLU; max-pools of 4 follow conv2, conv3 and
    conv4. The head is 256*8 -> 64 -> 16 -> 2, so the input length must
    reduce to 8 positions after the conv stack.

    Args:
        num_features: channels per sequence position.
        hidden_sz: stored on the instance; not used to build layers.
    """
    # https://peltarion.com/static/vgg_pa03.jpg

    def __init__(self, num_features, hidden_sz):
        super().__init__()

        self.hidden_sz = hidden_sz
        self.conv1 = nn.Conv1d(num_features, 32, kernel_size = 7)
        self.conv2 = nn.Conv1d(32, 64, kernel_size = 5)
        # Despite the names, both pooling modules use kernel_size=4.
        self.maxpool1d_2 = torch.nn.MaxPool1d(kernel_size = 4)

        self.conv3 = nn.Conv1d(64, 128, kernel_size = 3)
        self.maxpool1d_4 = torch.nn.MaxPool1d(kernel_size = 4)

        self.conv4 = nn.Conv1d(128, 256, kernel_size = 3)

        self.linear1 = nn.Linear(256*8, 64)
        self.linear2 = nn.Linear(64, 16)
        self.linear3 = nn.Linear(16, 2)

        self.dropout1 = nn.Dropout(p = 0.3)
        self.dropout2 = nn.Dropout(p = 0.2)

    def forward(self, x, features):
        """Return logits of shape (batch, 2).

        ``x`` is (batch, length, num_features), permuted to channels-first for
        Conv1d. ``features`` is accepted but not used.
        """
        act = x.permute(0,2,1)

        act = F.relu(self.conv1(act))
        act = self.maxpool1d_2(F.relu(self.conv2(act)))
        act = self.maxpool1d_4(F.relu(self.conv3(act)))
        act = self.maxpool1d_2(F.relu(self.conv4(act)))

        hidden = self.dropout1(F.relu(self.linear1(act.view(-1, 256*8))))
        hidden = self.dropout2(F.relu(self.linear2(hidden)))

        return self.linear3(hidden)

    def getname(self):
        """Return the model's display name."""
        return "FantomNet"    
    
class BioMLP(nn.Module):
    """Fully connected classifier that uses only the per-sample feature vector.

    The network is a four-layer MLP (input -> 128 -> 64 -> 16 -> 2) with a
    ReLU and a dropout layer after each of the three hidden layers. The
    sequence input is ignored; it is accepted only so that the model shares
    the ``model(x, features)`` call signature used by the training loop.
    """
    # Link left by the original author: https://peltarion.com/static/vgg_pa03.jpg

    # Input width used when ``num_features`` is None. This is the value that
    # used to be hard-coded in the first layer.
    DEFAULT_NUM_FEATURES = 292

    def __init__(self, num_features, hidden_sz):
        """Create the layers.

        Args:
            num_features: Number of values in each sample's flattened feature
                vector. ``None`` selects ``DEFAULT_NUM_FEATURES``.
            hidden_sz: Unused; kept so that all networks can be constructed
                as ``network(num_features, num_hidden)``.
        """
        super(BioMLP, self).__init__()

        in_features = self.DEFAULT_NUM_FEATURES if num_features is None else num_features

        # The creation order of the parameterised layers is kept unchanged so
        # that seeded initialisation and state_dict keys stay the same.
        self.linear1 = nn.Linear(in_features=in_features, out_features=128)
        self.linear2 = nn.Linear(in_features=128, out_features=64)
        self.linear3 = nn.Linear(in_features=64, out_features=16)
        self.linear4 = nn.Linear(in_features=16, out_features=2)

        self.dropout1 = nn.Dropout(p=0.4)
        self.dropout2 = nn.Dropout(p=0.5)
        self.dropout3 = nn.Dropout(p=0.5)

    def forward(self, x, features):
        """Compute class logits from the first entry of ``features``.

        Args:
            x: Ignored.
            features: Sequence of tensors; only ``features[0]`` is read. It
                must be 3-D (it is permuted with ``(0, 2, 1)``) and is
                flattened to one vector per sample.
                NOTE(review): presumably shaped (batch, num_features, 1), as
                the dummy arrays in the cross-validation cell suggest - confirm.

        Returns:
            Logits with one row per sample and two columns.
        """
        z = features[0]
        z = z.permute(0, 2, 1)

        out = self.linear1(z.contiguous().view(z.shape[0], -1))
        out = F.relu(out)
        out = self.dropout1(out)

        out = self.linear2(out)
        out = F.relu(out)
        out = self.dropout2(out)

        out = self.linear3(out)
        out = F.relu(out)
        out = self.dropout3(out)

        out = self.linear4(out)
        return out

    def getname(self):
        """Return the label of this architecture."""
        return "BioMLP"
    
class FantomNet_Bio_MLP(nn.Module):
    """Two-branch classifier combining a 1-D CNN over the sequence with bio features.

    Sequence branch: conv1 (k=7) -> ReLU -> conv2 (k=5) -> ReLU -> max-pool(4)
    -> conv3 (k=3) -> max-pool(4) -> conv4 (k=3) -> max-pool(4) -> linear1
    -> ReLU -> dropout -> linear2 (16 values).
    Bio branch: one linear layer to 12 values followed by dropout.
    Both branch outputs are concatenated and mapped to 2 logits by
    ``combined_linear_L``.
    """
    # Link left by the original author: https://peltarion.com/static/vgg_pa03.jpg

    def __init__(self, num_features, hidden_sz):
        """Create the layers of both branches.

        Args:
            num_features: Pair ``(num_features_seq, num_features_bio)``: the
                number of input channels of the sequence branch and the length
                of the flattened bio feature vector.
            hidden_sz: Stored as ``self.hidden_sz``; not used by any layer.
        """
        super(FantomNet_Bio_MLP, self).__init__()

        num_features_seq, num_features_bio = num_features

        self.hidden_sz = hidden_sz

        # --- sequence branch -------------------------------------------------
        self.conv1 = nn.Conv1d(in_channels=num_features_seq, out_channels=32, kernel_size=7)
        self.conv2 = nn.Conv1d(in_channels=self.conv1.out_channels, out_channels=64, kernel_size=5)
        # Despite the "_2" suffix this pool uses kernel_size=4, like maxpool1d_4.
        self.maxpool1d_2 = torch.nn.MaxPool1d(kernel_size=4)

        self.conv3 = nn.Conv1d(in_channels=self.conv2.out_channels, out_channels=128, kernel_size=3)
        self.maxpool1d_4 = torch.nn.MaxPool1d(kernel_size=4)

        self.conv4 = nn.Conv1d(in_channels=self.conv3.out_channels, out_channels=256, kernel_size=3)

        # 256 channels x 8 remaining positions per sample.
        self.linear1 = nn.Linear(in_features=256 * 8, out_features=64)
        self.linear2 = nn.Linear(in_features=64, out_features=16)
        # linear3, dropout2 and dropout3 are not used by forward(). They are
        # kept so that state_dict keys and the seeded initialisation order of
        # the remaining layers do not change.
        self.linear3 = nn.Linear(in_features=16, out_features=2)

        self.dropout1 = nn.Dropout(p=0.3)
        self.dropout2 = nn.Dropout(p=0.2)
        self.dropout3 = nn.Dropout(p=0.6)

        # --- bio branch ------------------------------------------------------
        self.bio_linear1 = nn.Linear(in_features=num_features_bio, out_features=12)
        self.bio_dropout2 = nn.Dropout(p=0.7)

        # --- fusion ----------------------------------------------------------
        self.combined_linear_L = nn.Linear(self.linear2.out_features + self.bio_linear1.out_features, 2)

    def forward(self, x, features):
        """Compute class logits from a sequence batch and its bio features.

        Args:
            x: Sequence batch laid out as (batch, length, channels). The
                length must leave exactly 8 positions after the last pooling
                stage (a length of 600 satisfies this).
            features: Sequence of tensors; only ``features[0]`` is read. It
                must be 3-D and is flattened to one vector per sample.
                NOTE(review): presumably shaped (batch, num_features_bio, 1)
                - confirm against the data-loading code.

        Returns:
            Logits with one row per sample and two columns.

        Raises:
            RuntimeError: If the convolutional output of a sample does not
                contain exactly ``linear1.in_features`` values.
        """
        z = features[0]
        x = x.permute(0, 2, 1)
        z = z.permute(0, 2, 1)

        out_seq = F.relu(self.conv1(x))
        out_seq = F.relu(self.conv2(out_seq))
        out_seq = self.maxpool1d_2(out_seq)

        # NOTE(review): unlike FantomNet.forward in this file, no ReLU follows
        # conv3 and conv4 here. Left as is because changing it would alter the
        # behaviour of already-trained snapshots - confirm whether intended.
        out_seq = self.conv3(out_seq)
        out_seq = self.maxpool1d_4(out_seq)

        out_seq = self.conv4(out_seq)
        out_seq = self.maxpool1d_2(out_seq)

        # Flatten with an explicit batch size so that a wrong input length
        # fails here with a clear shape error instead of producing extra rows
        # that only fail later at the concatenation.
        out_seq = out_seq.view(out_seq.shape[0], self.linear1.in_features)
        out_seq = self.linear1(out_seq)
        out_seq = F.relu(out_seq)
        out_seq = self.dropout1(out_seq)
        out_seq = self.linear2(out_seq)

        out_bio = self.bio_linear1(z.contiguous().view(z.shape[0], -1))
        out_bio = self.bio_dropout2(out_bio)

        out = torch.cat([out_seq, out_bio], dim=1)
        out = self.combined_linear_L(out)

        return out

    def getname(self):
        """Return the label of this architecture."""
        return "FantomNet_BioMLP"

In [None]:
# import torch
# from matplotlib import pyplot as plt

# print("torch version: ", torch.__version__)

        
# model = torch.nn.Linear(1, 1)

# max_lr = 1e-3
# init_lr = max_lr/25.0
# last_lr = max_lr/25e4
# print("init_lr",init_lr)
# print("max_lr",max_lr)
# print("last_lr",last_lr)

# num_epochs = 20

# optim = torch.optim.SGD(model.parameters(), lr = init_lr)

# for param_group in optim.param_groups:
#         param_group['lr'] = init_lr

        

# l1 = lr_scheduler.CosineAnnealingLR(optim, T_max = num_epochs//3, eta_min = max_lr)
# lr2 = lr_scheduler.CosineAnnealingLR(optim, T_max = 2*num_epochs//3, eta_min = last_lr)
# #lr2 = lr_scheduler.StepLR(optim, gamma = 0.9, step_size = 2)
# lrs = []


# lrs.append(optim.param_groups[0]['lr'])    

# for _ in range(num_epochs):
#     #print("-", l.last_epoch, optim.param_groups[0]['lr'])
#     if(l1.last_epoch<num_epochs//3):
#         l1.step()        
#     else:
#         lr2.step()
#     lrs.append(optim.param_groups[0]['lr'])
# plt.plot(lrs)
# plt.show()

# Cross Validation

In [None]:
# Matplotlib legend settings applied through rcParams:
# legend font size 6 and legend handle length 2.
params = {}
params['legend.fontsize'] = 6
params['legend.handlelength'] = 2
plt.rcParams.update(params)


# different experiment configurations

def _benchmarks(spec, sen, mcc=1.0):
    """Build the 'bmarks' entry of one experiment (later expanded as ``Metrics(**bmarks)``).

    The accuracy, F1 and AUC benchmarks are 1.0 in every configuration; only
    specificity, sensitivity and (for one configuration) MCC differ. A new
    dict is returned on every call so configurations never share state.
    """
    return {
        'benchmark_acc': 1.0,
        'benchmark_f1': 1.0,
        'benchmark_spec': spec,
        'benchmark_sen': sen,
        'benchmark_auc': 1.0,
        'benchmark_mcc': mcc,
    }


# One entry per experiment. The ORDER of the entries matters: the active
# experiment is picked further down by position (config_idx).
# Each entry holds the network class, which feature groups are fed to it,
# the constructor arguments (num_features, num_hidden), the optimiser /
# scheduler settings and the benchmark values handed to Metrics.
configs = {

    # bio features only
    'features': dict(
        network=BioMLP,
        use_comp_features=False,
        use_bio_features=True,
        num_features=292,
        num_hidden=None,
        max_lr=2e-3,
        wd=.06,
        max_lr_decay=0.8,
        cycles_per_round=1,
        cycle_freq=1,
        bmarks=_benchmarks(spec=0.7601, sen=0.8974),
    ),

    # sequence only, CNN
    'seq': dict(
        network=FantomNet,
        use_comp_features=False,
        use_bio_features=False,
        num_features=len(unique_DNAs),
        num_hidden=16,
        max_lr=2e-3,
        wd=.002,
        max_lr_decay=0.9,
        cycles_per_round=1,
        cycle_freq=1,
        bmarks=_benchmarks(spec=0.7601, sen=0.8974),
    ),

    # sequence only, LSTM + CNN
    'seq_LSTM_CNN_Net': dict(
        network=LSTM_CNN_Net,
        use_comp_features=False,
        use_bio_features=False,
        num_features=len(unique_DNAs),
        num_hidden=(16, 600),
        max_lr=2e-3,
        wd=.003,
        max_lr_decay=0.9,
        cycles_per_round=1,
        cycle_freq=1,
        bmarks=_benchmarks(spec=0.7601, sen=0.8974),
    ),

    # sequence only, CNPPNet
    'seq-CNPP': dict(
        network=CNPPNet,
        use_comp_features=False,
        use_bio_features=False,
        num_features=len(unique_DNAs),
        num_hidden=16,
        max_lr=1e-3,
        wd=.002,
        max_lr_decay=0.95,
        cycles_per_round=2,
        cycle_freq=1,
        bmarks=_benchmarks(spec=0.7601, sen=0.8974),
    ),

#            #'features': {'network':Bio_kmer_MLP, 'num_features': 544, 'num_hidden': None, 'max_lr': 5e-3, 'wd': .065, 'max_lr_decay': 0.8, 'cycles_per_round': 1, 'cycle_freq': 7},

    # sequence + bio features
    'seq+features': dict(
        network=FantomNet_Bio_MLP,
        use_comp_features=False,
        use_bio_features=True,
        num_features=(len(unique_DNAs), 292),
        num_hidden=16,
        max_lr=1e-3,
        wd=.0045,
        max_lr_decay=0.8,
        cycles_per_round=1,
        cycle_freq=1,
        bmarks=_benchmarks(spec=0.8819, sen=0.8895, mcc=.7447),
    ),
    #'seq+features': {'network':FantomNet_Bio_kmer_MLP, 'num_features': (len(unique_DNAs), 544), 'num_hidden': 16, 'max_lr': 5e-3, 'wd': .065, 'max_lr_decay': 0.8},
}



# ---- Run-wide training constants ----
xval_fold_count = 10   # number of stratified cross-validation folds
num_epochs = 400       # epochs trained per fold
batch_sz = 2048        # batch size of every DataLoader

# NOTE(review): the optimizer in the cross-validation cell is created with
# lr = 1e-3 and does not read this value - confirm where it is used.
init_learning_rate = 0.01

###### Setup K-fold X-validation
# Shuffled stratified folds; the fixed random_state keeps the splits repeatable.
skf = StratifiedKFold(n_splits=xval_fold_count, shuffle=True, random_state=23)


## packing variable-length data in pytorch using packing and padding:
## https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial

## https://gist.github.com/MikulasZelinka/9fce4ed47ae74fca454e88a39f8d911a

num_buckets = len(X_buckets)

# train_val_idxs_each_bucket[bucket][fold] -> (train indices, validation indices)
train_val_idxs_each_bucket = []

# The folds are generated separately for every bucket so that each bucket
# contributes to both the training and the validation part of every fold.
for bucket_idx in range(num_buckets):
    # With use_features the bucket is a tuple (sequence, optionally computed
    # and bio features) and only its first element is kept here.
    bucket_input_features = X_buckets[bucket_idx][0] if use_features else X_buckets[bucket_idx]

    bucket_labels = Y_buckets[bucket_idx]

    # Only the labels drive the stratified split; a zero array of matching
    # length is passed in place of the inputs.
    sample_placeholder = np.zeros(len(bucket_labels))
    train_val_idxs = list(skf.split(sample_placeholder, bucket_labels))
    train_val_idxs_each_bucket.append(train_val_idxs)

# Select the active experiment by its position in `configs`.
config_idx = 4
config_name = list(configs)[config_idx]

# The selected experiment decides which feature groups are used from here on.
selected_config = configs[config_name]
use_comp_features = selected_config['use_comp_features']
use_bio_features = selected_config['use_bio_features']


In [None]:

if do_cross_validation:


    #metrics_ht = Metrics(benchmark_acc = 0.8169, benchmark_f1 = 1.0, benchmark_spec = 0.8060, benchmark_sen = 0.8277, benchmark_auc = 1.0, benchmark_mcc = .6476)
    metrics_ht = Metrics(**configs[config_name]['bmarks'])


    all_folds_last_results = []
    all_folds_best_results_for_fold = []

    print("############################################################################################################################")
    print("############################################################################################################################")
    print(f"##################################             {config_name}                ###############################################")
    print("############################################################################################################################")
    print("############################################################################################################################")


    num_features_means = []
    num_features_sds = []

    bio_features_means = []
    bio_features_sds = []    

    for fold_idx in range(xval_fold_count):  # k-fold x-validation

        if fold_idx == 10:
            break

        best_model_mcc = -1

        print("Fold", fold_idx)

        metrics_ht.reset_history()

        train_datasets = []
        val_datasets = []

        bucket_sampling_idxs = np.random.permutation(list(range(len(X_buckets))))  # random ordering of buckets

        training_seq_label_buckets = []
        val_seq_label_buckets = []

        training_num_features_buckets = []
        val_num_features_buckets = []

        training_bio_features_buckets = []
        val_bio_features_buckets = []


        # iterate over the train-val buckets to create separate train lists for each bucket and validation lists for each bucket
        for bucket_idx in bucket_sampling_idxs:

            bucket_input_features = X_buckets[bucket_idx]
            bucket_labels = Y_buckets[bucket_idx]

            train_idx = train_val_idxs_each_bucket[bucket_idx][fold_idx][0]
            val_idx = train_val_idxs_each_bucket[bucket_idx][fold_idx][1]

            # original indices of data points in each bucket (assuming no subsampling took place while bucketing)
            #idxs_in_bucket = np.array(idxs_in_buckets[bucket_idx])

            train_labels = bucket_labels[train_idx]
            val_labels = bucket_labels[val_idx]

            #train_data_idxs = idxs_in_bucket[train_idx]        
            #val_data_idxs = idxs_in_bucket[val_idx]

            seq_features = bucket_input_features[0]
            train_seq_features = seq_features[train_idx,:]
            val_seq_features = seq_features[val_idx,:]

            training_seq_label_buckets.append( (train_seq_features, train_labels) )
            val_seq_label_buckets.append( (val_seq_features, val_labels) )

            if(use_features):

                #train_features = (bucket_input_features[0][train_idx,:], bucket_input_features[1][train_idx,:], bucket_input_features[2][train_idx,:])
                #val_features = (bucket_input_features[0][val_idx,:], bucket_input_features[1][val_idx,:], bucket_input_features[2][val_idx,:])

                #training_seq_label_buckets.append( (train_features[0], train_labels) )
                #val_seq_label_buckets.append( (val_features[0], val_labels) )

                comp_features = bucket_input_features[1]
                bio_features = bucket_input_features[2]

                if use_features and use_comp_features:
                    train_comp_features = comp_features[train_idx,:]
                    val_comp_features = comp_features[val_idx,:]

                    # kmers
                    #training_num_features_buckets.append(np.concatenate([train_features[1], train_features[2]], axis = 1) )
                    #val_num_features_buckets.append(np.concatenate([val_features[1], val_features[2]], axis = 1))
                    training_num_features_buckets.append(train_comp_features)
                    val_num_features_buckets.append(val_comp_features)

                if use_features and use_bio_features:
                    train_bio_features = bio_features[train_idx,:]
                    val_bio_features = bio_features[val_idx,:]

                    # bio only
                    #training_bio_features_buckets.append(train_features[2])
                    #val_bio_features_buckets.append(val_features[2])

                    training_bio_features_buckets.append(train_bio_features)
                    val_bio_features_buckets.append(val_bio_features)

                #train_datasets.append(data_utils.TensorDataset(torch.from_numpy(train_features[0]), torch.from_numpy(train_features[1]).float(), torch.from_numpy(train_labels)))
                #val_datasets.append(data_utils.TensorDataset(torch.from_numpy(val_features[0]), torch.from_numpy(val_features[1]).float(), torch.from_numpy(val_labels)))

            #else:
            #    train_features = bucket_input_features[train_idx,:]
            #    val_features = bucket_input_features[val_idx]

            #    train_datasets.append(data_utils.TensorDataset(torch.from_numpy(train_features), torch.from_numpy(train_labels)))
            #    val_datasets.append(data_utils.TensorDataset(torch.from_numpy(val_features), torch.from_numpy(val_labels)))

        # normalize numerical features combining data across buckets        
        if(use_features):
            if use_comp_features:
                training_num_features = np.concatenate(training_num_features_buckets, axis = 0)
                mu = np.mean(training_num_features, axis = 0)
                sd = np.std(training_num_features, axis = 0)
                training_num_features_buckets = [ (training_num_features_bucket - mu)/sd for training_num_features_bucket in training_num_features_buckets]

                #training_num_features = (training_num_features - mu)/sd
                #val_num_features = np.concatenate(val_num_features_buckets, axis = 0)
                #val_num_features = (val_num_features - mu)/sd
                #print(np.mean(training_num_features_buckets[0].squeeze(),0) )
                #print(np.std(training_num_features_buckets[0].squeeze(),0) )
                val_num_features_buckets = [ (val_num_features_bucket - mu)/sd for val_num_features_bucket in val_num_features_buckets]

                num_features_means.append(mu)
                num_features_sds.append(sd)
            else: # dummy data
                training_num_features_buckets = [ np.zeros( (training_seq_label_bucket[0].shape[0],252,1) ) for training_seq_label_bucket in training_seq_label_buckets ]
                val_num_features_buckets = [ np.zeros( (val_seq_label_bucket[0].shape[0],252,1) ) for val_seq_label_bucket in val_seq_label_buckets ]

            if use_bio_features:
                training_bio_features = np.concatenate(training_bio_features_buckets, axis = 0)
                bio_mu = np.mean(training_bio_features, axis = 0)
                bio_sd = np.std(training_bio_features, axis = 0)
                training_bio_features_buckets = [ (training_bio_features_bucket - bio_mu)/bio_sd for training_bio_features_bucket in training_bio_features_buckets]

                #training_num_features = (training_num_features - mu)/sd
                #val_num_features = np.concatenate(val_num_features_buckets, axis = 0)
                #val_num_features = (val_num_features - mu)/sd
                #print(np.mean(training_num_features_buckets[0].squeeze(),0) )
                #print(np.std(training_num_features_buckets[0].squeeze(),0) )
                val_bio_features_buckets = [ (val_bio_features_bucket - bio_mu)/bio_sd for val_bio_features_bucket in val_bio_features_buckets]

                bio_features_means.append(bio_mu)
                bio_features_sds.append(bio_sd)
            else:  # dummy data
                training_bio_features_buckets = [ np.zeros( (training_seq_label_bucket[0].shape[0],292,1) ) for training_seq_label_bucket in training_seq_label_buckets ]
                val_bio_features_buckets = [ np.zeros( (val_seq_label_bucket[0].shape[0],292,1) ) for val_seq_label_bucket in val_seq_label_buckets ]

        for bucket_idx in bucket_sampling_idxs:
            #print(training_seq_label_buckets[bucket_idx][0].shape)
            #print(training_num_features[bucket_idx].shape)
            #print(training_seq_label_buckets[bucket_idx][1].shape)

            train_seq_T = torch.from_numpy(training_seq_label_buckets[bucket_idx][0])
            train_label_T = torch.from_numpy(training_seq_label_buckets[bucket_idx][1])
            #train_dataset_components = [train_seq_T]
            #train_dataset_components.append(train_label_T)

            val_seq_T = torch.from_numpy(val_seq_label_buckets[bucket_idx][0])
            val_label_T = torch.from_numpy(val_seq_label_buckets[bucket_idx][1])
            #val_dataset_components = [val_seq_T]
            #val_dataset_components.append(val_label_T)

            #if use_comp_features: 
            train_num_T = torch.from_numpy(training_num_features_buckets[bucket_idx]).float()
            #train_dataset_components.append(train_num_T)

            val_num_T = torch.from_numpy(val_num_features_buckets[bucket_idx]).float()
            #val_dataset_components.append(val_num_T)

            #if use_bio_features: 
            train_bio_T = torch.from_numpy(training_bio_features_buckets[bucket_idx]).float()
            #train_dataset_components.append(train_bio_T)

            val_bio_T = torch.from_numpy(val_bio_features_buckets[bucket_idx]).float()
            #val_dataset_components.append(val_bio_T)

            train_datasets.append(data_utils.TensorDataset(train_seq_T, train_label_T, train_num_T, train_bio_T ) )
            val_datasets.append(data_utils.TensorDataset(val_seq_T, val_label_T, val_num_T, val_bio_T))


        train_dataloaders = [data_utils.DataLoader(train_dataset, batch_size = batch_sz, shuffle = True) for train_dataset in train_datasets]
        val_dataloaders = [data_utils.DataLoader(val_dataset, batch_size = batch_sz, shuffle = False) for val_dataset in val_datasets]

        num_iter_per_epoch = np.sum([len(loader) for loader in train_dataloaders])

        training_losses = []
        val_losses = []

        # initialize the model
        #model = RNA_Net(num_DNAs, 16).cuda()
        #model = CNPPNet(num_DNAs, 16).cuda()

    #         if(use_features):   
    #             model = CNPPNet_Hybrid(num_DNAs, 16).cuda()
    #         else:
    #             model = FantomNet(num_DNAs, 16).cuda()

        model = configs[config_name]['network'](configs[config_name]['num_features'], configs[config_name]['num_hidden']).cuda()

        print(model)

        # loss function, optimization algorithm, and final layer activation function
        loss_function = torch.nn.CrossEntropyLoss(weight=torch.Tensor(list(class_weights.values())).cuda() )
        activationFunc = torch.nn.LogSoftmax(dim = -1)
        optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3,  weight_decay = configs[config_name]['wd'])
        scheduler = OneCycleCosineAnnealing(optimizer, 
                                            num_iter_per_epoch*num_epochs/configs[config_name]['cycle_freq'], 
                                            cycles_per_round = configs[config_name]['cycles_per_round'], 
                                            max_lr = configs[config_name]['max_lr'], 
                                            decay_rate = configs[config_name]['max_lr_decay'])

        print(num_iter_per_epoch)



        # train the network 
        for epoch in range(num_epochs):
            print("Epoch",epoch,"started.")

            train_dataset_sizes = [len(train_dataloader.dataset) for train_dataloader in train_dataloaders]
            val_dataset_sizes = [len(val_dataloader.dataset) for val_dataloader in val_dataloaders]

            print("Training dataset size(s):",*train_dataset_sizes)
            print("Validation dataset size(s):",*val_dataset_sizes)

            training_loss = []
            train_ys = []
            train_preds = []

            model = model.train()

            # train the network                
            for train_dataloader in tqdm_notebook(train_dataloaders):

                #for x, y in train_dataloader:
                for x, y, p, q in train_dataloader:
                    model.zero_grad()
                    optimizer.zero_grad()

                    x = x.cuda()
                    y = y.cuda()

                    feature_data = []

                    if use_comp_features:
                        p = p.cuda()
                        feature_data.append(p)

                    if use_bio_features:
                        q = q.cuda()
                        feature_data.append(q)

                    #x_lens = torch.sum(torch.any(x !=0, dim = -1), dim = -1)

                    #torch.all(x !=0)

                    # compute the length of each sequence in the batch (to be used for packing)
                    # pad the sequences in the batch

                    #logits = model(x, seq_lengths = x_lens)

                    logits = model(x, feature_data)

                    #logits = model(x, z)
                    preds = torch.argmax(activationFunc(logits), dim = -1)

                    #print(logits.shape)
                    #print(y.shape)
                    loss = loss_function(logits, y)
                    training_loss.append(loss.item())

                    train_ys.append(y.cpu().numpy())
                    train_preds.append(preds.cpu().numpy())

                    loss.backward()
                    scheduler.step()
                    optimizer.step()

            training_losses.append(np.mean(training_loss))        
            train_ys = np.concatenate(train_ys) 
            train_preds = np.concatenate(train_preds)

            print("Epoch",epoch,"completed.")

            #print('________________________________________')
            #print('Training metrics')
            #results = metrics_ht.compute_metrics(train_ys, train_preds, epoch, do_print = False)
            #print('training_accuracy:',results['accuracy'])
            #print('________________________________________')

            model = model.eval()
            # compute validation loss, and accuracy metrics        
            with torch.no_grad():
                val_ys = []
                val_preds = []
                val_pred_probs = []
                val_loss = []

                for val_dataloader in val_dataloaders:
                    #for x, y in val_dataloader:
                    for x, y, p, q in val_dataloader:
                        x = x.cuda()
                        y = y.cuda()

                        feature_data = []

                        if use_comp_features:
                            p = p.cuda()
                            feature_data.append(p)

                        if use_bio_features:
                            q = q.cuda()
                            feature_data.append(q)

                        #x_lens = torch.sum(torch.any(x !=0, dim = -1), dim = -1)

                        #logits = model(x, seq_lengths = x_lens.cpu())                

                        logits = model(x, feature_data)                
                        #logits = model(x, z)                
                        log_pred_probs = activationFunc(logits)
                        preds = torch.argmax(activationFunc(logits), dim = -1)

                        loss = loss_function(logits, y)
                        val_loss.append(loss.item())

                        val_ys.append(y.cpu().numpy())
                        val_preds.append(preds.cpu().numpy())
                        val_pred_probs.append(np.exp(log_pred_probs.cpu().numpy()))
                val_losses.append(np.mean(val_loss))


            val_ys = np.concatenate(val_ys) 
            val_preds = np.concatenate(val_preds)
            val_pred_probs = np.concatenate(val_pred_probs)
            #val_accuracy = (val_ys==val_preds).sum()/len(val_preds) 
            #print("--- validation accuracy:", val_accuracy)   
            print('============================================================================================================')
            print('Validation metrics')
            val_results = metrics_ht.compute_metrics(val_ys, val_preds, epoch, do_print = True, store_vals = True)
            print('============================================================================================================')
            #print('Training metrics')
            #results = metrics_ht.compute_metrics(train_ys, train_preds, epoch, do_print = True, store_vals = False)
            #print('============================================================================================================')

            line_colors = plt.cm.tab20(np.linspace(0,1,20))

            tr_loss_color = line_colors[0]
            val_loss_color = line_colors[2]



            epoch_mcc = Metrics.compute_mcc(val_ys, val_preds)
            epoch_sensitivity = Metrics.compute_sensitivity(val_ys, val_preds)
            epoch_specificity = Metrics.compute_specificity(val_ys, val_preds)

            if epoch_mcc > best_model_mcc:
                network_name = model.getname()
                best_model_mcc = epoch_mcc 
                now = datetime.now()
                savefile_name = now.strftime("%H_%M_%S")+"_"+now.strftime("%m") + now.strftime("%d") # + now.strftime("%Y")
                savedir = f'{network_name}/fold_{fold_idx:02d}'
                if not os.path.exists(savedir):
                    os.makedirs(savedir)

                pth  = Path(savedir)
                for f in list(pth.glob('*.pkl')):
                    os.remove(f)

                for f in list(pth.glob('*.npy')):
                    os.remove(f)
                    
                prefix = savedir.replace('/','_')
                torch.save(model.state_dict(), f'{savedir}/mdl_{prefix}_ep_{epoch:03d}_mcc_{epoch_mcc:.05f}_sens_{epoch_sensitivity:.05f}_spec_{epoch_specificity:.05f}__{savefile_name}.pkl')    

                if use_comp_features:
                    np.save(f'{savedir}/num_features_means_fold_{fold_idx}', num_features_means[fold_idx])
                    np.save(f'{savedir}/num_features_sds_fold_{fold_idx}', num_features_sds[fold_idx])

                if use_bio_features:
                    np.save(f'{savedir}/bio_features_means_fold_{fold_idx}', bio_features_means[fold_idx])
                    np.save(f'{savedir}/bio_features_sds_fold_{fold_idx}', bio_features_sds[fold_idx])
                
                best_results = metrics_ht.compute_metrics(val_ys, val_preds, epoch, do_print = False, store_vals = False)
                # use to load and start inference:
                # model.load_state_dict(torch.load(filepath))
                # model.eval()

    #             print("Test results")
    #             for result in all_folds_last_results:
    #                 print("accuracy:",result['accuracy'])
    #                 print("sensitivity:",result['sensitivity'])
    #                 print("specificity:",result['specificity'])


            if epoch%5==0:   # only plot in even-numbered epoch to reduce stress on the webpage
                x_data = range(epoch+1)

                f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4,figsize=(12,2))
                ax1.plot(x_data, training_losses, label="train loss", color = tr_loss_color, alpha=1.0)
                ax1.plot(x_data, val_losses, label="val_loss", color = val_loss_color, alpha=1.0)
                ax1.legend()
                ax1.set_title("Fold: "+str(fold_idx)+" epoch: "+str(epoch))

                #ax = fig.subplot(122), alpha=.7
                tmp = np.array(metrics_ht.tp_tn_fp_fn)
                ax2.plot(x_data, tmp[:,2], label="False Positive", color = tr_loss_color, alpha=1.0)
                ax2.plot(x_data, tmp[:,3], label="False Negatives", color = val_loss_color, alpha=1.0)
                ax2.legend()
                ax2.set_title("FP & FN Fold: "+str(fold_idx)+" epoch: "+str(epoch))


                #ax3.plot(x_data, [metrics_ht.b_accuracy]*len(x_data), label="b_accuracy", color = 'r', alpha=.5)
                #ax3.plot(x_data, [metrics_ht.b_sensitivity]*len(x_data), label="b_sensitivity", color = 'g', alpha=.5)
                #ax3.plot(x_data, [metrics_ht.b_specificity]*len(x_data), label="b_specificity", color = 'b', alpha=.5)
                #ax3.plot(x_data, [metrics_ht.b_f1]*len(x_data), label="b_f1_score", color = 'c', alpha=.5)
                #ax3.plot(x_data, [metrics_ht.b_auc_roc]*len(x_data), label="b_auc_roc", color = 'k', alpha=.5)

                #ax3.plot(x_data, metrics_ht.accuracies, '--', label="accuracy", color = 'r')
                #ax3.plot(x_data, metrics_ht.sensitivity, '--',  label="sensitivity", color = 'g')
                #ax3.plot(x_data, metrics_ht.specificity, '--',  label="specificity", color = 'b')
                #ax3.plot(x_data, metrics_ht.f1s,  '--', label="b_f1_score", color = 'c')
                #ax3.plot(x_data, metrics_ht.auc_roc,  '--', label="b_auc_roc", color = 'k')

                #ax3.legend(loc='lower left')
                #ax3.set_title("Acc-Sensitivity-Specificity-F1-AUC for Fold: "+str(k)+" epoch: "+str(epoch))

                ax4.plot(range(len(scheduler.lrs)), scheduler.lrs, label="learning rates")
                ax4.legend()

                plt.show()
                plt.close()

        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        print('--------------------------------------------------------------------------------')
        all_folds_last_results.append(val_results)
        all_folds_best_results_for_fold.append(best_results)

    #     for result in all_folds_last_results:
    #         print(result)

    # for result in all_folds_best_results_for_fold:
    #     print(result)

    print("Best results from each fold")
    for fold_idx, result in enumerate(all_folds_best_results_for_fold):
        print(f"Fold {fold_idx:02d}: mcc:{result['mcc']:.05f} sensitivity:{result['sensitivity']:.05f} specificity:{result['specificity']:.05f}")

In [None]:
# print("Cross validation results (for last epochs results from each fold ):")
# accuracy =  [result['accuracy'] for result in all_folds_last_results]
# #print('mean accuracy:', np.mean(accuracy), np.std(accuracy))    
# sensitivity =  [result['sensitivity'] for result in all_folds_last_results]
# print('mean sensitivity:', np.mean(sensitivity), np.std(sensitivity))    
# specificity =  [result['specificity'] for result in all_folds_last_results]
# print('mean specificity:', np.mean(specificity), np.std(specificity))   
# mcc =  [result['mcc'] for result in all_folds_last_results]
# print('mean mcc:', np.mean(mcc), np.std(mcc))   

In [None]:
if do_cross_validation:
    # Per-fold summary of the best (highest-MCC) validation result of every fold.
    print("Best results from each fold")
    for fold_idx, result in enumerate(all_folds_best_results_for_fold):
        print("Fold {:02d}: mcc:{:.05f} sensitivity:{:.05f} specificity:{:.05f}".format(
            fold_idx, result['mcc'], result['sensitivity'], result['specificity']))

    print("===========================================================")
    print("Cross validation results (for best results from each fold):")

    def _report_mean_sd(metric_key):
        """Print mean and standard deviation of one metric across folds; return the per-fold values."""
        per_fold_values = [fold_result[metric_key] for fold_result in all_folds_best_results_for_fold]
        print(f'mean {metric_key}:', np.mean(per_fold_values), np.std(per_fold_values))
        return per_fold_values

    # Accuracy is intentionally not reported here; only the three metrics below are.
    sensitivity = _report_mean_sd('sensitivity')
    specificity = _report_mean_sd('specificity')
    mcc = _report_mean_sd('mcc')

# Testing with best (manually selected) model

In [None]:
if do_testing_using_best_model:
    # Prepare evaluation of the checkpoint saved for one manually selected
    # cross-validation fold: build the network, the normalised test loaders and
    # the loss. The weights themselves are loaded by the evaluation cell below.

    model = configs[config_name]['network'](configs[config_name]['num_features'], configs[config_name]['num_hidden']).cuda()

    best_fold_idx = 0  # fold whose saved checkpoint is evaluated (chosen by hand)

    best_model_dir = model.getname() + f'/fold_{best_fold_idx:02d}'

    # The fold directory is written by the cross-validation loop, which keeps a
    # single .pkl per fold. Fail with a clear message if no checkpoint is there.
    best_model_candidates = sorted(Path(f'{best_model_dir}/').glob('*.pkl'))
    if not best_model_candidates:
        raise FileNotFoundError(f'No .pkl checkpoint found in {best_model_dir}')
    best_model_path = str(best_model_candidates[0])

    test_model_dir = best_model_dir
    test_model_path = best_model_path

    test_metrics_ht = Metrics(**configs[config_name]['bmarks'])
    test_metrics_ht.reset_history()
    test_bucket_sampling_idxs = np.random.permutation(list(range(len(test_X_buckets))))  # random ordering of buckets

    test_seq_label_buckets = []
    test_num_features_buckets = []
    test_bio_features_buckets  = []
    test_datasets = []

    # Split every test bucket into (sequence, label) pairs and the optional
    # compositional / biological feature arrays, visiting buckets in random order.
    for test_bucket_idx in test_bucket_sampling_idxs:

        test_bucket_input_features = test_X_buckets[test_bucket_idx]
        test_bucket_labels = test_Y_buckets[test_bucket_idx]

        test_seq_features = test_bucket_input_features[0]

        test_seq_label_buckets.append( (test_seq_features, test_bucket_labels) )

        if(use_features):

            test_comp_features = test_bucket_input_features[1]
            test_bio_features = test_bucket_input_features[2]

            if use_features and use_comp_features:
                test_num_features_buckets.append(test_comp_features)

            if use_features and use_bio_features:
                test_bio_features_buckets.append(test_bio_features)

    if(use_features):
        if use_comp_features:
            # The .npy files already contain the statistics of the selected fold
            # (they were saved as num_features_means[fold_idx]), so they are used
            # as loaded. Indexing them again with best_fold_idx would pick a single
            # element of the vector and mis-normalise every feature.
            mu = np.load(f'{best_model_dir}/num_features_means_fold_{best_fold_idx}.npy')
            sd = np.load(f'{best_model_dir}/num_features_sds_fold_{best_fold_idx}.npy')
            test_num_features_buckets = [ (test_num_features_bucket - mu)/sd for test_num_features_bucket in test_num_features_buckets]
        else: # dummy data (never fed to the model: the evaluation loop only uses it when use_comp_features is set)
            test_num_features_buckets = [ np.zeros( (test_seq_label_bucket[0].shape[0],1,1) ) for test_seq_label_bucket in test_seq_label_buckets ]

        if use_bio_features:
            # Same reasoning as above: the saved arrays are the per-fold statistics.
            bio_mu = np.load(f'{best_model_dir}/bio_features_means_fold_{best_fold_idx}.npy')
            bio_sd = np.load(f'{best_model_dir}/bio_features_sds_fold_{best_fold_idx}.npy')
            test_bio_features_buckets = [ (test_bio_features_bucket - bio_mu)/bio_sd for test_bio_features_bucket in test_bio_features_buckets]

        else:  # dummy data (never fed to the model: the evaluation loop only uses it when use_bio_features is set)
            test_bio_features_buckets = [ np.zeros( (test_seq_label_bucket[0].shape[0],2,1) ) for test_seq_label_bucket in test_seq_label_buckets ]

    # One TensorDataset per bucket: (sequence, label, comp. features, bio features).
    for test_bucket_idx in test_bucket_sampling_idxs:
        test_seq_T = torch.from_numpy(test_seq_label_buckets[test_bucket_idx][0])
        test_label_T = torch.from_numpy(test_seq_label_buckets[test_bucket_idx][1])

        test_num_T = torch.from_numpy(test_num_features_buckets[test_bucket_idx]).float()

        test_bio_T = torch.from_numpy(test_bio_features_buckets[test_bucket_idx]).float()

        test_datasets.append(data_utils.TensorDataset(test_seq_T, test_label_T, test_num_T, test_bio_T ) )


    test_dataloaders = [data_utils.DataLoader(test_dataset, batch_size = batch_sz, shuffle = False) for test_dataset in test_datasets]

    num_iter_per_test_epoch = np.sum([len(loader) for loader in test_dataloaders])

    test_losses = []

    # loss function and final layer activation function
    loss_function = torch.nn.CrossEntropyLoss(weight=torch.Tensor(list(class_weights.values())).cuda() )
    activationFunc = torch.nn.LogSoftmax(dim = -1)


    

    

In [None]:
# (neg_gt_test_preds_probs[:len(neg_gt_test_preds_probs)//2,:] + neg_gt_test_preds_probs[len(neg_gt_test_preds_probs)//2:,:])/2

# Testing by retraining with train+val set

In [None]:
# best_fold_idx = 0
# best_model_filename = 'mdl_BioMLP_fold_00_ep_034_mcc_0.67554_sens_0.89331_spec_0.79017__03_03_48_0809.pkl'

# best_model_dir = model.getname() + f'/fold_{best_fold_idx:02d}'
# best_model_path = best_model_dir + '/' + best_model_filename 


#metrics_ht = Metrics(benchmark_acc = 0.8169, benchmark_f1 = 1.0, benchmark_spec = 0.8060, benchmark_sen = 0.8277, benchmark_auc = 1.0, benchmark_mcc = .6476)

if do_retraining or do_testing_only:
    # Build the retraining (X_buckets / Y_buckets) and test loaders, standardise
    # the optional features with statistics of the retraining data, create a fresh
    # network and, when do_retraining is set, train it and save the final weights.
    # The evaluation on the test set happens in the next cell.

    retrain_metrics_ht = Metrics(**configs[config_name]['bmarks'])
    retrain_metrics_ht.reset_history()

    test_metrics_ht = Metrics(**configs[config_name]['bmarks'])
    test_metrics_ht.reset_history()

    retrain_bucket_sampling_idxs = np.random.permutation(list(range(len(X_buckets))))  # random ordering of buckets
    test_bucket_sampling_idxs = np.random.permutation(list(range(len(test_X_buckets))))  # random ordering of buckets

    test_seq_label_buckets = []
    test_num_features_buckets = []
    test_bio_features_buckets  = []
    test_datasets = []

    retrain_seq_label_buckets = []
    retrain_num_features_buckets = []
    retrain_bio_features_buckets  = []
    retrain_datasets = []


    # Split every retraining bucket into (sequence, label) pairs and the optional
    # compositional / biological feature arrays.
    for retrain_bucket_idx in retrain_bucket_sampling_idxs:

        retrain_bucket_input_features = X_buckets[retrain_bucket_idx]
        retrain_bucket_labels = Y_buckets[retrain_bucket_idx]

        retrain_seq_features = retrain_bucket_input_features[0]

        retrain_seq_label_buckets.append( (retrain_seq_features, retrain_bucket_labels) )

        if(use_features):

            retrain_comp_features = retrain_bucket_input_features[1]
            retrain_bio_features = retrain_bucket_input_features[2]

            if use_features and use_comp_features:
                retrain_num_features_buckets.append(retrain_comp_features)

            if use_features and use_bio_features:
                retrain_bio_features_buckets.append(retrain_bio_features)

    # Same split for the test buckets.
    for test_bucket_idx in test_bucket_sampling_idxs:

        test_bucket_input_features = test_X_buckets[test_bucket_idx]
        test_bucket_labels = test_Y_buckets[test_bucket_idx]

        test_seq_features = test_bucket_input_features[0]

        test_seq_label_buckets.append( (test_seq_features, test_bucket_labels) )

        if(use_features):

            test_comp_features = test_bucket_input_features[1]
            test_bio_features = test_bucket_input_features[2]

            if use_features and use_comp_features:
                test_num_features_buckets.append(test_comp_features)

            if use_features and use_bio_features:
                test_bio_features_buckets.append(test_bio_features)

    if(use_features):
        if use_comp_features:

            # Standardise with the mean / sd of the retraining data only and apply
            # the very same statistics to the test data.
            retrain_num_features = np.concatenate(retrain_num_features_buckets, axis = 0)
            mu = np.mean(retrain_num_features, axis = 0)
            sd = np.std(retrain_num_features, axis = 0)
            retrain_num_features_buckets = [ (retrain_num_features_bucket - mu)/sd for retrain_num_features_bucket in retrain_num_features_buckets]

            test_num_features_buckets = [ (test_num_features_bucket - mu)/sd for test_num_features_bucket in test_num_features_buckets]
        else: # dummy data (not fed to the model: the loops below only use it when use_comp_features is set)
            retrain_num_features_buckets = [ np.zeros( (retrain_seq_label_bucket[0].shape[0],252,1) ) for retrain_seq_label_bucket in retrain_seq_label_buckets ]
            test_num_features_buckets = [ np.zeros( (test_seq_label_bucket[0].shape[0],252,1) ) for test_seq_label_bucket in test_seq_label_buckets ]

        if use_bio_features:
            retrain_bio_features = np.concatenate(retrain_bio_features_buckets, axis = 0)
            bio_mu = np.mean(retrain_bio_features, axis = 0)
            bio_sd = np.std(retrain_bio_features, axis = 0)
            retrain_bio_features_buckets = [ (retrain_bio_features_bucket - bio_mu)/bio_sd for retrain_bio_features_bucket in retrain_bio_features_buckets]

            test_bio_features_buckets = [ (test_bio_features_bucket - bio_mu)/bio_sd for test_bio_features_bucket in test_bio_features_buckets]

        else:  # dummy data (not fed to the model: the loops below only use it when use_bio_features is set)
            retrain_bio_features_buckets = [ np.zeros( (retrain_seq_label_bucket[0].shape[0],292,1) ) for retrain_seq_label_bucket in retrain_seq_label_buckets ]
            test_bio_features_buckets = [ np.zeros( (test_seq_label_bucket[0].shape[0],292,1) ) for test_seq_label_bucket in test_seq_label_buckets ]

    # One TensorDataset per bucket: (sequence, label, comp. features, bio features).
    for retrain_bucket_idx in retrain_bucket_sampling_idxs:
        retrain_seq_T = torch.from_numpy(retrain_seq_label_buckets[retrain_bucket_idx][0])
        retrain_label_T = torch.from_numpy(retrain_seq_label_buckets[retrain_bucket_idx][1])

        retrain_num_T = torch.from_numpy(retrain_num_features_buckets[retrain_bucket_idx]).float()

        retrain_bio_T = torch.from_numpy(retrain_bio_features_buckets[retrain_bucket_idx]).float()

        retrain_datasets.append(data_utils.TensorDataset(retrain_seq_T, retrain_label_T, retrain_num_T, retrain_bio_T ) )


    for test_bucket_idx in test_bucket_sampling_idxs:
        test_seq_T = torch.from_numpy(test_seq_label_buckets[test_bucket_idx][0])
        test_label_T = torch.from_numpy(test_seq_label_buckets[test_bucket_idx][1])

        test_num_T = torch.from_numpy(test_num_features_buckets[test_bucket_idx]).float()

        test_bio_T = torch.from_numpy(test_bio_features_buckets[test_bucket_idx]).float()

        test_datasets.append(data_utils.TensorDataset(test_seq_T, test_label_T, test_num_T, test_bio_T ) )

    retrain_dataloaders = [data_utils.DataLoader(retrain_dataset, batch_size = batch_sz, shuffle = True) for retrain_dataset in retrain_datasets]
    # Evaluation data is NOT shuffled: the metrics do not depend on sample order,
    # the do_horz_flip post-processing in the evaluation cell slices predictions by
    # position, and this matches the loaders built for the best-model test above.
    test_dataloaders = [data_utils.DataLoader(test_dataset, batch_size = batch_sz, shuffle = False) for test_dataset in test_datasets]

    num_iter_per_retrain_epoch = np.sum([len(loader) for loader in retrain_dataloaders])
    num_iter_per_test_epoch = np.sum([len(loader) for loader in test_dataloaders])

    model = configs[config_name]['network'](configs[config_name]['num_features'], configs[config_name]['num_hidden']).cuda()


    # loss function and final layer activation function
    loss_function = torch.nn.CrossEntropyLoss(weight=torch.Tensor(list(class_weights.values())).cuda() )
    activationFunc = torch.nn.LogSoftmax(dim = -1)

    if do_retraining:

        print(model)


        retrain_losses = []  # mean training loss of every epoch


        optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3,  weight_decay = configs[config_name]['wd'])
        scheduler = OneCycleCosineAnnealing(optimizer,
                                            num_iter_per_retrain_epoch*num_epochs/configs[config_name]['cycle_freq'],
                                            cycles_per_round = configs[config_name]['cycles_per_round'],
                                            max_lr = configs[config_name]['max_lr'],
                                            decay_rate = configs[config_name]['max_lr_decay'])

        print(num_iter_per_retrain_epoch)

        model = model.train()

        # train the network
        for epoch in range(num_epochs):
            print("Epoch",epoch,"started.")

            retrain_dataset_sizes = [len(retrain_dataloader.dataset) for retrain_dataloader in retrain_dataloaders]
            print("Retrain dataset size(s):",*retrain_dataset_sizes)

            retraining_loss = []
            retrain_ys = []
            retrain_preds = []

            # one pass over every bucket's loader
            for retrain_dataloader in tqdm_notebook(retrain_dataloaders):

                for x, y, p, q in retrain_dataloader:
                    model.zero_grad()
                    optimizer.zero_grad()

                    x = x.cuda()
                    y = y.cuda()

                    # Only the enabled feature groups are handed to the model.
                    retrain_feature_data = []

                    if use_comp_features:
                        p = p.cuda()
                        retrain_feature_data.append(p)

                    if use_bio_features:
                        q = q.cuda()
                        retrain_feature_data.append(q)

                    logits = model(x, retrain_feature_data)

                    preds = torch.argmax(activationFunc(logits), dim = -1)

                    loss = loss_function(logits, y)
                    retraining_loss.append(loss.item())

                    retrain_ys.append(y.cpu().numpy())
                    retrain_preds.append(preds.cpu().numpy())

                    loss.backward()
                    # NOTE(review): the scheduler is stepped before the optimizer; kept
                    # as in the original since OneCycleCosineAnnealing is defined
                    # elsewhere in this notebook - confirm this ordering is intended.
                    scheduler.step()
                    optimizer.step()

            retrain_losses.append(np.mean(retraining_loss))
            retrain_ys = np.concatenate(retrain_ys)
            retrain_preds = np.concatenate(retrain_preds)

            print("Epoch",epoch,"completed.")

            print('============================================================================================================')
            print('Retrain metrics')
            retrain_results = retrain_metrics_ht.compute_metrics(retrain_ys, retrain_preds, epoch, do_print = True, store_vals = True)
            print('============================================================================================================')

            if epoch%5==0:  # plot only every 5th epoch
                line_colors = plt.cm.tab20(np.linspace(0,1,20))

                retr_loss_color = line_colors[0]
                retr_loss_color2 = line_colors[1]

                x_data = range(epoch+1)

                f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4,figsize=(12,2))
                ax1.plot(x_data, retrain_losses, label="retrain loss", color = retr_loss_color, alpha=1.0)
                ax1.legend()
                ax1.set_title("Epoch: "+str(epoch))

                # columns 2 and 3 of the stored history are plotted as FP and FN
                tmp = np.array(retrain_metrics_ht.tp_tn_fp_fn)
                ax2.plot(x_data, tmp[:,2], label="False Positive", color = retr_loss_color, alpha=1.0)
                ax2.plot(x_data, tmp[:,3], label="False Negatives", color = retr_loss_color2, alpha=1.0)
                ax2.legend()
                ax2.set_title("FP & FN Epoch: "+str(epoch))

                ax4.plot(range(len(scheduler.lrs)), scheduler.lrs, label="learning rates")
                ax4.legend()

                plt.show()
                plt.close()


        # Save the weights of the final epoch, replacing any earlier retrained checkpoint.
        network_name = model.getname()
        now = datetime.now()
        savefile_name = now.strftime("%H_%M_%S")+"_"+now.strftime("%m") + now.strftime("%d") # + now.strftime("%Y")
        savedir = f'{network_name}/retrained'
        os.makedirs(savedir, exist_ok=True)

        pth  = Path(savedir)
        for f in list(pth.glob('*.pkl')):
            os.remove(f)

        prefix = savedir.replace('/','_')
        torch.save(model.state_dict(), f'{savedir}/mdl_{prefix}__{savefile_name}.pkl')

        # Picked up by the evaluation cell that follows.
        test_model_dir = savedir
        test_model_path = f'{test_model_dir}/mdl_{prefix}__{savefile_name}.pkl'

        

In [None]:
if do_retraining or do_testing_using_best_model or do_testing_only:
    # Load the selected checkpoint (test_model_path) into `model`, run it over the
    # test loaders without gradients and report the test metrics.

    if(do_testing_only and not do_retraining):
        # No retraining happened in this session: reuse the checkpoint that an
        # earlier retraining run left in '<network name>/retrained'.
        network_name = model.getname()

        savedir = f'{network_name}/retrained'

        retrained_checkpoints = sorted(Path(savedir).glob('*.pkl'))
        if not retrained_checkpoints:
            raise FileNotFoundError(f'No .pkl checkpoint found in {savedir}')
        test_model_path = retrained_checkpoints[0]

    # use to load and start inference:
    model.load_state_dict(torch.load(test_model_path))
    model.eval()

    print(model)

    print(num_iter_per_test_epoch)

    test_dataset_sizes = [len(test_dataloader.dataset) for test_dataloader in test_dataloaders]

    print("Test dataset size(s):",*test_dataset_sizes)

    test_losses = []

    with torch.no_grad():
        test_ys = []
        test_preds = []
        test_preds_probs = []
        test_loss = []

        for test_dataloader in test_dataloaders:
            for test_x, test_y, test_p, test_q in test_dataloader:
                test_x = test_x.cuda()
                test_y = test_y.cuda()

                # Only the enabled feature groups are handed to the model.
                test_feature_data = []

                if use_comp_features:
                    test_p = test_p.cuda()
                    test_feature_data.append(test_p)

                if use_bio_features:
                    test_q = test_q.cuda()
                    test_feature_data.append(test_q)

                test_logits = model(test_x, test_feature_data)

                # Log-probabilities are computed once and reused for the argmax.
                log_pred_probs = activationFunc(test_logits)
                preds = torch.argmax(log_pred_probs, dim = -1)

                loss = loss_function(test_logits, test_y)
                test_loss.append(loss.item())

                test_ys.append(test_y.cpu().numpy())
                test_preds.append(preds.cpu().numpy())
                # exp() turns the log-softmax output back into probabilities
                test_preds_probs.append(np.exp(log_pred_probs.cpu().numpy()))

        test_losses.append(np.mean(test_loss))


    test_ys = np.concatenate(test_ys)
    test_preds = np.concatenate(test_preds)
    test_preds_probs = np.concatenate(test_preds_probs)

    test_ys_processed = test_ys
    test_preds_processed = test_preds

    if do_horz_flip:
        # NOTE(review): this branch slices by position, i.e. it assumes the first
        # test_num_neg predictions are the negatives and that each class block holds
        # the originals in its first half and the flipped copies in its second half.
        # Confirm that the bucket ordering / loaders preserve that layout.
        neg_test_ys = test_ys[:test_num_neg]
        pos_test_ys = test_ys[test_num_neg:]

        neg_gt_test_preds = test_preds[:test_num_neg]
        pos_gt_test_preds = test_preds[test_num_neg:]

        neg_gt_test_preds_probs = test_preds_probs[:test_num_neg,:]
        pos_gt_test_preds_probs = test_preds_probs[test_num_neg:,:]

        # fold in half
        neg_test_ys = neg_test_ys[::2]
        pos_test_ys = pos_test_ys[::2]

        # average the probabilities of the two halves before taking the argmax
        neg_gt_test_preds_probs = (neg_gt_test_preds_probs[:len(neg_gt_test_preds_probs)//2,:] + neg_gt_test_preds_probs[len(neg_gt_test_preds_probs)//2:,:])/2
        pos_gt_test_preds_probs = (pos_gt_test_preds_probs[:len(pos_gt_test_preds_probs)//2,:] + pos_gt_test_preds_probs[len(pos_gt_test_preds_probs)//2:,:])/2

        neg_gt_preds = np.argmax(neg_gt_test_preds_probs, axis = -1)
        pos_gt_preds = np.argmax(pos_gt_test_preds_probs, axis = -1)

        test_ys_processed = np.concatenate([neg_test_ys, pos_test_ys], axis = 0)
        test_preds_processed = np.concatenate([neg_gt_preds, pos_gt_preds], axis = 0)

    print('============================================================================================================')
    print('Test metrics')
    test_results = test_metrics_ht.compute_metrics(test_ys_processed, test_preds_processed, 0, do_print = True, store_vals = False)
    print('============================================================================================================')

    # Computed on the raw (un-folded) predictions; not used by the prints below.
    test_mcc = Metrics.compute_mcc(test_ys, test_preds)
    test_sensitivity = Metrics.compute_sensitivity(test_ys, test_preds)
    test_specificity = Metrics.compute_specificity(test_ys, test_preds)


    print("Test results")
    print("sensitivity:",test_results['sensitivity'])
    print("specificity:",test_results['specificity'])
    print("mcc:",test_results['mcc'])
    print("accuracy:",test_results['accuracy'])


    # visual separator in the cell output
    for _ in range(9):
        print('--------------------------------------------------------------------------------')


In [None]:
# Display the per-class weights that are passed to the CrossEntropyLoss above.
class_weights

In [None]:
if do_testing_only or do_testing_using_best_model or do_retraining:
    # Compact recap of the test metrics computed by the evaluation cell.
    print("Test results")
    for metric_label in ("sensitivity", "specificity", "mcc", "accuracy"):
        print(f"{metric_label}:", test_results[metric_label])

# Helper functions for Saliency

In [None]:
def get_kmers(alphabet, k):
    """Return every possible k-mer over ``alphabet``.

    The k-mers are listed in the order of the alphabet, with the last position
    varying fastest (for ``['A', 'C']`` and ``k=2``: ``AA, AC, CA, CC``).

    Parameters
    ----------
    alphabet : sequence of str
        The symbols to combine.
    k : int
        Length of the k-mers; must be at least 1.

    Returns
    -------
    list of str
        All ``len(alphabet) ** k`` concatenations of ``k`` symbols.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # The earlier version also sliced an undefined global `seq` into substrings
    # and discarded the result; that dead code raised NameError and was removed.
    return ["".join(symbols) for symbols in itertools.product(alphabet, repeat=k)]

def sequence2encodings(alphabet, sequence_lst):
    """Encode equally long sequences as 1-based labels and as one-hot arrays.

    Returns a pair ``(label_encodings, onehot_encodings)``: the first holds, per
    sequence, the position of every character in ``alphabet`` plus one; the
    second is a float32 array with one channel per alphabet symbol.
    """
    num_sequences = len(sequence_lst)
    num_symbols = len(alphabet)

    # 1-based position of every character, flattened then regrouped per sequence.
    flat_labels = []
    for seq in sequence_lst:
        for c in seq:
            flat_labels.append(alphabet.index(c) + 1)
    label_encodings = np.array(flat_labels).reshape(num_sequences, -1)

    # Compare each label against 1..num_symbols: channel j is hot when label == j + 1.
    symbol_ids = np.arange(1, num_symbols + 1)
    onehot_encodings = (label_encodings[..., None] == symbol_ids).astype('float32')
    onehot_encodings = onehot_encodings.reshape(num_sequences, -1, num_symbols)

    return label_encodings, onehot_encodings

def onehots2sequences(alphabet, onehot_encodings):
    """Decode one-hot encoded sequences back into strings and 1-based labels.

    Every position is decoded to the alphabet symbol with the largest value.
    Returns ``(DNA_sequences, label_encodings)`` where the labels are the
    alphabet positions (plus one) of the decoded characters.
    """
    # The last axis of every encoding must have one channel per alphabet symbol.
    for encoding in onehot_encodings:
        assert encoding.shape[-1] ==len(alphabet), "One hot sequence must be batch_sz x seq_length x alphabet_sz"

    DNA_sequences = []
    for encoding in onehot_encodings:
        hot_positions = np.argmax(encoding, axis = 1)
        decoded_symbols = [alphabet[position] for position in hot_positions]
        DNA_sequences.append("".join(decoded_symbols))

    flat_labels = []
    for decoded in DNA_sequences:
        for symbol in decoded:
            flat_labels.append(alphabet.index(symbol) + 1)
    label_encodings = np.array(flat_labels).reshape(len(DNA_sequences), -1)

    return DNA_sequences, label_encodings


def mutate_seq_for_saliencymap(sequence_to_mutate, window_sz, mutation_probs, mutation_rate_pct, stride = None, start_idx = 0, alphabet = None, include_identicals = False ):
    """Randomly mutate a sequence window by window and list its single-position variants.

    The sequence is cut into windows of `window_sz` symbols starting every
    `stride` symbols from `start_idx`. In every window
    int(len(window) * mutation_rate_pct) positions are drawn (with replacement)
    and overwritten by symbols drawn from `alphabet` according to
    `mutation_probs`. Uses the global NumPy random state.

    Parameters
    ----------
    sequence_to_mutate : str
    window_sz : int
        Window length; the last window may be shorter.
    mutation_probs : dict
        Replacement probability per symbol. When every symbol of `alphabet` is
        a key, the probabilities are looked up by symbol; otherwise the dict
        values are used positionally, in dict order.
    mutation_rate_pct : float
        Fraction of each window to mutate (0 mutations per window is allowed).
    stride : int, optional
        Distance between window starts; defaults to `window_sz` when None or <= 0.
    start_idx : int
        Position of the first window.
    alphabet : sequence of str, optional
        Replacement symbols; defaults to the sorted symbols of the sequence.
    include_identicals : bool
        When False only positions whose symbol actually changed produce a
        variant; when True every position does.

    Returns
    -------
    (sequence_to_mutate, mutated_sequence, all_one_pos_change_sequences)
        The input, the fully mutated sequence, and a list of copies of the
        input that each take exactly one position from the mutated sequence.
    """
    if stride is None or stride <= 0:
        stride = window_sz

    if alphabet is None:
        alphabet = sorted(set(sequence_to_mutate))

    seq_len = len(sequence_to_mutate)
    region_start_idxs = list(range(start_idx, seq_len, stride))
    region_end_idxs = [min(start + window_sz, seq_len) for start in region_start_idxs]

    mutation_segments = [sequence_to_mutate[start:start+window_sz] for start in region_start_idxs]
    mutation_counts = [int(len(mutation_segment)*mutation_rate_pct) for mutation_segment in mutation_segments]

    # Pair probabilities with alphabet symbols by key when possible, so a
    # differently ordered dict cannot silently assign a probability to the wrong symbol.
    if all(symbol in mutation_probs for symbol in alphabet):
        replacement_probs = [mutation_probs[symbol] for symbol in alphabet]
    else:
        replacement_probs = list(mutation_probs.values())

    # All locations are drawn first, then all replacements (keeps the random-draw order).
    # reshape(-1) accepts any number of mutations per window, including zero.
    mutation_locations_segments = [np.random.randint(0,len(mutation_segment), size=(1, mutation_count)).reshape(-1)
                                   for mutation_segment, mutation_count in zip(mutation_segments, mutation_counts)]
    replacements_segments = [np.random.choice(alphabet, size = (mutation_count, ), p = replacement_probs) for mutation_count in mutation_counts ]

    mutated_segments = []
    for mutation_segment, mutation_locations, replacements in zip(mutation_segments, mutation_locations_segments, replacements_segments):
        mutated_segment = list(mutation_segment)
        for mloc, replacement in zip(mutation_locations, replacements):
            mutated_segment[mloc] = replacement
        mutated_segments.append("".join(mutated_segment))

    # Splice every mutated window back into a copy of the sequence.
    mutated_sequence = sequence_to_mutate
    for mutated_segment, start, end in zip(mutated_segments, region_start_idxs, region_end_idxs):
        mutated_sequence = mutated_sequence[:start] + mutated_segment + mutated_sequence[end:]

    if not include_identicals:
        mutated_idxs = [i for i in range(seq_len) if sequence_to_mutate[i]!=mutated_sequence[i] ]
    else:
        mutated_idxs = list(range(seq_len))

    all_one_pos_change_sequences = [sequence_to_mutate[:i] + mutated_sequence[i] + sequence_to_mutate[i+1:] for i in mutated_idxs]

    return sequence_to_mutate, mutated_sequence, all_one_pos_change_sequences
    
    
    
# Alphabet used by the saliency helpers, and every 2-mer over it.
alphabet = unique_DNAs
kmers = get_kmers(alphabet, 2)

# Replacement distributions for mutate_seq_for_saliencymap:
#   'eq'              -> all four bases equally likely
#   'A','C','G','T'   -> always substitute that base
mutation_probs = {'eq': dict.fromkeys('ACGT', 0.25)}
mutation_probs.update({target: {base: float(base == target) for base in 'ACGT'} for target in 'ACGT'})


# # # labelencodings, one_hots = sequence2encodings(alphabet, kmers)
# # # DNA_seqs, labelencodings_2 = onehots2sequences(alphabet, one_hots)

# # # DNA_seqs, kmers, labelencodings, labelencodings_2, one_hots
# # # "".join(list(map(chr,list(range(65,65+26)))))
# # #"".join(DNA_seqs)
# # #print(mut_segs),"".join(DNA_seqs), mutation_segments, mutated_segments
# # # print(mutation_segments)
# # # print(mutated_segments)
# # # print()

# test_seq = "AGCTATG"

# all_one_place_changes_for_all_DNA = []
# for DNA in unique_DNAs:
#     orig_seq, mut_seq, all_changed = mutate_seq_for_saliencymap(test_seq, 
#                                                                 window_sz = 1, 
#                                                                 mutation_probs = mutation_probs[DNA], 
#                                                                 mutation_rate_pct = 1., 
#                                                                 stride = 1, 
#                                                                 alphabet = unique_DNAs, 
#                                                                 include_identicals = True)
#     all_one_place_changes_for_all_DNA.extend(all_changed)
# len(all_one_place_changes_for_all_DNA), print(all_one_place_changes_for_all_DNA)


# # #[print(c_seq) for c_seq in all_one_place_changes_for_all_DNA]    
# # #print(len(all_one_place_changes_for_all_DNA))
# # all_one_place_changes_for_all_DNA = list(sorted(list(set(all_one_place_changes_for_all_DNA))))    
# # #print(len(all_one_place_changes_for_all_DNA))
# # [print(c_seq) for c_seq in all_one_place_changes_for_all_DNA]
# # print(orig_seq)
# # print(mut_seq)
# # print()
# # _, ohe = sequence2encodings(alphabet, mut_seq)
# # [print(c_seq) for c_seq in all_one_place_changes_for_all_DNA]


# # #print(ohe)

In [None]:
# NOTE(review): a bare reference to the `time` module only displays the module object;
# this looks like a leftover check — confirm whether the cell is still needed.
time

# Compute Saliency

In [None]:
# Saliency is computed on a single test bucket for now: the first one.
bucket_idx = 0

# Slicing the bucket's dataset with [:] returns all of its fields at once; the first four are
# (presumably, going by the commented-out construction below) sequences, labels,
# computed features and biological features — TODO confirm.
test_X_onehots, test_Ys, test_num_features, test_bio_features = test_datasets[bucket_idx][:][:4]

# Decode the one-hot sequences back to strings so they can be mutated as text.
test_X_seqs, test_X_label_encodings = onehots2sequences(unique_DNAs, test_X_onehots)

# Quick look at sizes and types.
(test_X_onehots.shape, len(test_X_seqs), len(test_Ys), test_Ys[285:295],
 test_num_features.shape, test_bio_features.shape, type(test_X_onehots), type(test_bio_features))


# test_num_features_buckets, test_bio_features_buckets

# for test_bucket_idx in test_bucket_sampling_idxs:
#     test_seq_T = torch.from_numpy(test_seq_label_buckets[test_bucket_idx][0])
#     test_label_T = torch.from_numpy(test_seq_label_buckets[test_bucket_idx][1])

#     #if use_comp_features: 
#     test_num_T = torch.from_numpy(test_num_features_buckets[test_bucket_idx]).float()
    
#     #if use_bio_features: 
#     test_bio_T = torch.from_numpy(test_bio_features_buckets[test_bucket_idx]).float()
    
#     test_datasets.append(data_utils.TensorDataset(test_seq_T, test_label_T, test_num_T, test_bio_T ) )


# print(len(X_seqs))

# verifying generated sequences and onehot conversion algorithm's correctness
# X_label_encodings, X_onehot_from_seq = sequence2encodings(unique_DNAs, X_seqs)
# torch.all(np.equal(X_onehot_from_seq[0], X_onehots[0])).item()

In [None]:
# Saliency by exhaustive single-position substitution: for every correctly classified test
# sample, each position is set to each base in turn, the model is re-run on all variants, and
# the change in the predicted probability of the sample's class (variant minus the unmutated
# sequence) is stored as a (len(unique_DNAs) x sequence length) map.
alphabet = unique_DNAs

# kmers = get_kmers(alphabet, 2)

# Replacement distributions for mutate_seq_for_saliencymap: 'eq' is uniform over the bases,
# each single-base entry always substitutes that base.
mutation_probs = {'eq':{'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25 }, 
                  'A': {'A': 1.,   'C': 0.,   'G': 0.,   'T': 0.,  },
                  'C': {'A': 0.,   'C': 1.,   'G': 0.,   'T': 0.,  },
                  'G': {'A': 0.,   'C': 0.,   'G': 1.,   'T': 0.,  },
                  'T': {'A': 0.,   'C': 0.,   'G': 0.,   'T': 1.,  },
                 }

# saliency_maps[class label] -> list holding one saliency map per sample of that class.
saliency_maps = dict()

# Inference only: eval mode is switched on once, instead of once per sample.
model = model.eval()

for sample_class_label_int in [0,1]:  # for both negative and positive classes
    
    # Test samples of this class that were classified correctly. reshape(-1) keeps the index
    # array 1-D; squeeze() turned a single hit into a 0-d array and broke the loop below.
    correct_classification_idxs_for_class = np.where(np.array(test_preds == test_ys) & np.array(test_ys == sample_class_label_int))[0].reshape(-1)

    # Ascending sort on the predicted probability of the class: most confident sample is last.
    sort_idxs = np.argsort(test_preds_probs[correct_classification_idxs_for_class, sample_class_label_int])
    best_performing_examples = correct_classification_idxs_for_class[sort_idxs]
    
    print(correct_classification_idxs_for_class.shape)
    print(test_preds_probs[correct_classification_idxs_for_class, sample_class_label_int].shape)
    print(best_performing_examples.shape)
    
    
    saliency_maps_for_class = []
    
    # Progress is printed about every 5% of the samples; max(1, ...) avoids a modulo by zero
    # when fewer than 20 samples are available.
    progress_every = max(1, len(best_performing_examples)//20)
    
    then = time.time() 
    for i_sample_idx, sample_idx in enumerate(best_performing_examples):
        if i_sample_idx%progress_every == 0:
            print(f"{i_sample_idx}/{len(best_performing_examples)}")
        # NOTE(review): sample_idx is derived from test_preds / test_ys but indexes
        # test_X_seqs / test_Ys, which were taken from one bucket — confirm both share
        # the same sample order.
        sequence_to_mutate_idx = sample_idx
        sequence_to_mutate = test_X_seqs[sequence_to_mutate_idx]
        label_of_sequence_to_mutate = test_Ys[sequence_to_mutate_idx]

        # For every base: all copies of the sequence with one position set to that base
        # (include_identicals=True also keeps copies equal to the original), so the list
        # holds len(unique_DNAs) * len(sequence) variants, ordered base by base.
        all_one_step_changed = []
        for DNA in unique_DNAs:
            orig_seq, mut_seq, all_changed = mutate_seq_for_saliencymap( sequence_to_mutate, window_sz = 1, mutation_probs = mutation_probs[DNA], mutation_rate_pct = 1., 
                                                                        stride = 1, alphabet = unique_DNAs, include_identicals=True)
            all_one_step_changed.extend(all_changed)


        # Row 0 is the unmutated reference sequence; the variants follow.
        _, X_onehot_mutated = sequence2encodings(unique_DNAs, [sequence_to_mutate]+all_one_step_changed)

        saliencyy_Xs = torch.from_numpy(X_onehot_mutated)
        # np.int64 replaces the np.long alias, which current NumPy releases no longer provide.
        saliency_Ys = torch.from_numpy(np.zeros(shape=(X_onehot_mutated.shape[0],), dtype=np.int64)+label_of_sequence_to_mutate.item())
        saliency_data_idx = torch.from_numpy(np.arange(X_onehot_mutated.shape[0]))
        
        # All-zero stand-ins for the two extra feature inputs of the model.
        # NOTE(review): 252 and 292 are hard-coded; presumably they match the widths of the
        # computed / biological feature tensors — confirm against the model.
        saliency_dummy_num_features = torch.from_numpy(np.zeros((saliencyy_Xs.shape[0],252,1) )).float()
        saliency_dummy_bio_features = torch.from_numpy(np.zeros((saliencyy_Xs.shape[0],292,1) )).float()
        
        
        saliency_dataset = data_utils.TensorDataset(saliencyy_Xs, saliency_Ys, saliency_dummy_num_features, saliency_dummy_bio_features, saliency_data_idx)

        saliency_loader_batch_sz = 512

        # shuffle=False: the row order must survive so row 0 stays the reference sequence.
        saliency_dataloader = data_utils.DataLoader(saliency_dataset, batch_size = saliency_loader_batch_sz, shuffle = False)


        with torch.no_grad():

            saliency_ys = []
            saliency_preds = []
            saliency_pred_probs = []
            saliency_loss = []
            saliency_losses = []

            saliency_x_idxs = []
            for saliency_x, saliency_y, dummy_num_features, dummy_bio_features, saliency_x_batch_idxs in saliency_dataloader:
                saliency_x = saliency_x.cuda()
                saliency_y = saliency_y.cuda()

                # Only the feature groups enabled in the hyper-parameters are handed to the model.
                saliency_feature_data = []

                if use_comp_features:
                    dummy_num_features = dummy_num_features.cuda()
                    saliency_feature_data.append(dummy_num_features)

                if use_bio_features:
                    dummy_bio_features = dummy_bio_features.cuda()
                    saliency_feature_data.append(dummy_bio_features)
                
                saliency_logits = model(saliency_x, saliency_feature_data)                
                # NOTE(review): activationFunc presumably returns log-probabilities, since
                # np.exp is applied to its output below — confirm.
                saliency_log_pred_probs_batch = activationFunc(saliency_logits)
                saliency_preds_batch = torch.argmax(saliency_log_pred_probs_batch, dim = -1)

                saliency_loss_batch = loss_function(saliency_logits, saliency_y)

                saliency_loss.append(saliency_loss_batch.item())
                saliency_ys.append(saliency_y.cpu().numpy())
                saliency_pred_probs.append(np.exp(saliency_log_pred_probs_batch.cpu().numpy()))
                saliency_preds.append(saliency_preds_batch.cpu().numpy())
                saliency_x_idxs.append(saliency_x_batch_idxs)

            saliency_losses.append(np.mean(saliency_loss))


        saliency_ys = np.concatenate(saliency_ys) 
        saliency_preds = np.concatenate(saliency_preds)
        saliency_pred_probs = np.concatenate(saliency_pred_probs)
        saliency_x_idxs = np.concatenate(saliency_x_idxs)

        # Change in the class probability of every row relative to the reference (row 0).
        score_diff = saliency_pred_probs[:, sample_class_label_int] - saliency_pred_probs[0, sample_class_label_int]

        # Drop the reference row and fold the variants into one row per base.
        saliency_map = score_diff[1:].reshape(len(unique_DNAs), len(orig_seq))
        saliency_maps_for_class.append(saliency_map)

    now = time.time()
    print(f"Time taken for class{sample_class_label_int}:{now-then} seconds")
    saliency_maps[sample_class_label_int] = saliency_maps_for_class
    

# Plot Saliency maps

In [None]:
# Sign applied to the averaged map of class 0 / class 1.
score_multipliers = [-1.,1.]

# Columns of the saliency map that are drawn, and the position labels shown for them.
# NOTE(review): the 400 offset presumably re-centres the labels on a reference position — confirm.
span_min = 350
span_max = 450 + 1
label_min = span_min - 400
label_max = label_min + (span_max-span_min) + 1

gridlinewidth = 1.6

for sample_class_label_int in [0,1]:    
    # Average the per-sample maps of the class, then apply the class sign.
    saliency_map = score_multipliers[sample_class_label_int] * np.mean(np.array(saliency_maps[sample_class_label_int]), axis = 0)

    sns.set()

    saliency_window = saliency_map[:,span_min:span_max]

    plt.figure(figsize=(30,3))
    # Diverging palette centred on zero, scaled to the value range of the drawn window.
    ax = sns.heatmap(saliency_window,
                     cmap = sns.diverging_palette(220,20, n=250, as_cmap = True),
                     vmax = saliency_window.max(),
                     vmin = saliency_window.min(),
                     center = 0.0,
                     cbar=False)

    plt.grid(True, color='r', linestyle='--', linewidth=2, which = 'both')
    loc = ticker.MultipleLocator(base=2)
    ax.xaxis.set_minor_locator(loc)
    ax.yaxis.set_minor_locator(loc)

    # One x tick every 10 positions, one y tick per base; +0.5 centres ticks on the cells.
    plt.xticks(np.arange(0,span_max-span_min,10)+0.5, np.arange(label_min,label_max,10), fontsize=20)
    plt.yticks(np.arange(0,saliency_map.shape[0])+0.5, unique_DNAs, fontsize=20)
    plt.yticks(rotation=0)
    plt.title(f"({chr(65+sample_class_label_int)}) For {class_names[sample_class_label_int]} promoter", fontsize=20)

    # White separator lines between the heatmap cells.
    ax.hlines(np.arange(len(unique_DNAs)+1), *ax.get_xlim(), color='w', linewidth=gridlinewidth)
    ax.vlines(np.arange(0,span_max-span_min+1), *ax.get_ylim(), color='w', linewidth=gridlinewidth)
    plt.show()

# Plot saliency map

# PWM generation step zero for PWM Exp 1 and PWM Exp 2

In [None]:
# Cut every correctly classified test sequence into overlapping fixed-length subsequences,
# per class; these feed the PWM experiments below.

# Length of the sliding window (stride 1).
seq_length = 7

# subsequences[class][sample] -> list of the sample's subsequences
subsequences = []

for sample_class_label_int in [0,1]:  # for both negative and positive classes
    
    # Test samples of this class that were classified correctly. reshape(-1) keeps the index
    # array 1-D; squeeze() turned a single hit into a 0-d array and broke the loop below.
    correct_classification_idxs_for_class = np.where(np.array(test_preds == test_ys) & np.array(test_ys == sample_class_label_int))[0].reshape(-1)

    # Ascending sort on the predicted probability of the class: most confident sample is last.
    sort_idxs = np.argsort(test_preds_probs[correct_classification_idxs_for_class, sample_class_label_int])
    best_performing_examples = correct_classification_idxs_for_class[sort_idxs]

    subsequences_for_class = []
    
    for i_sample_idx, sample_idx in enumerate(best_performing_examples):
        sequence_idx = sample_idx
        sequence = test_X_seqs[sequence_idx]

        # Sequences shorter than seq_length yield no subsequences.
        sub_seqs = [sequence[i:i+seq_length] for i in range(len(sequence)-(seq_length-1))] 
        subsequences_for_class.append(sub_seqs)

    subsequences.append(subsequences_for_class)

print('Subsequence generation completed')


# Flatten the per-sample lists into one list per class (duplicates kept, order preserved).
all_subsequences_class_0 = list(itertools.chain.from_iterable(subsequences[0]))
all_subsequences_class_1 = list(itertools.chain.from_iterable(subsequences[1]))

print('Subsequence merging completed')

# PWM Exp 1: PWM from subsequences that appear in only one class, computing an activation score for each, and using half of the maximum score as the cut-off

In [None]:
# len(all_unique_subsequences_class_0), len(all_subsequences_class_0), len(all_unique_subsequences_class_1), len(all_subsequences_class_1)

In [None]:
from itertools import count
from collections import defaultdict

# Give every distinct subsequence of a class a running integer id in order of first
# appearance: a defaultdict whose factory is a counter hands out the next id on first lookup.
c_0 = itertools.count(0)
indexer_0 = c_0.__next__
class_0_unique_subseqs_dict = defaultdict(indexer_0)
# For each occurrence, the id of its subsequence.
class_0_unique_subseq_indxs = list(map(class_0_unique_subseqs_dict.__getitem__, all_subsequences_class_0))

c_1 = itertools.count(0)
indexer_1 = c_1.__next__
class_1_unique_subseqs_dict = defaultdict(indexer_1)
class_1_unique_subseq_indxs = list(map(class_1_unique_subseqs_dict.__getitem__, all_subsequences_class_1))

# Distinct subsequences per class, listed in id order.
unique_subseqs_class_0 = list(class_0_unique_subseqs_dict)
unique_subseqs_class_1 = list(class_1_unique_subseqs_dict)

# Plain lookup tables subsequence -> id, used for fast membership tests below.
d_class_0 = {subseq: position for position, subseq in enumerate(unique_subseqs_class_0)}
d_class_1 = {subseq: position for position, subseq in enumerate(unique_subseqs_class_1)}

# Split the occurrences into class-exclusive and shared ones, timing each pass.
t0 = time.time()
all_subsequences_only_in_class_0 = [subseq for subseq in all_subsequences_class_0 if subseq not in d_class_1]
t1 = time.time()
all_subsequences_only_in_class_1 = [subseq for subseq in all_subsequences_class_1 if subseq not in d_class_0]
t2 = time.time()
common_subsequences = [subseq for subseq in all_subsequences_class_1 if subseq in d_class_0]
t3 = time.time()

print(f"{t1-t0:0.2f}")
print(f"{t2-t1:0.2f}")
print(f"{t3-t2:0.2f}")

In [None]:
# Occurrence counts and distinct counts of the subsequences, per class.
tuple(map(len, (all_subsequences_class_0, all_subsequences_class_1, unique_subseqs_class_0, unique_subseqs_class_1)))

In [None]:
# Occurrence counts of the class-exclusive and the shared subsequences.
tuple(map(len, (all_subsequences_only_in_class_0, all_subsequences_only_in_class_1, common_subsequences)))

In [None]:
# Distinct counts of the class-exclusive and the shared subsequences.
len(set(all_subsequences_only_in_class_0)), len(set(all_subsequences_only_in_class_1)), len(set(common_subsequences))

In [None]:
# Distinct count of shared plus class-0-exclusive subsequences, next to each part on its own.
len(set(common_subsequences) | set(all_subsequences_only_in_class_0)), len(set(common_subsequences)), len(set(all_subsequences_only_in_class_0))

In [None]:
# PWMs from class-exclusive subsequences: every qualifying filter is applied to each
# subsequence, the subsequences whose activation exceeds half of the maximum are kept, and the
# per-position symbol frequencies of the kept subsequences form the PWM of that filter.
unique_subsequences_per_class__list = []
unique_subsequences_per_class__list.append(all_subsequences_only_in_class_0)
unique_subsequences_per_class__list.append(all_subsequences_only_in_class_1)


params = list(model.parameters())
print([param.shape for param in params])

# params[0][0]
print(unique_DNAs)


# One entry per class; each entry holds one item per filter that produced a PWM.
mean_activations_for_classes = []
PWMs_for_classes = []
layers_idxs_for_classes = []
filters_idxs_for_classes = []

for unique_subsequences in unique_subsequences_per_class__list:
    
    mean_activations = []
    PWMs = []
    layers_idxs = []
    filters_idxs = []

    motif_sequences = unique_subsequences#get_kmers(unique_DNAs,filter_length)

    # The encodings depend only on the subsequences, so they are built once per class rather
    # than once per filter. Labels are 1-based positions in unique_DNAs.
    label_encodings = np.array([unique_DNAs.index(c)+1  for motif_seq in motif_sequences for c in motif_seq ]).reshape(len(motif_sequences),-1)

    # One-hot tensor, (n_subsequences, subsequence length, len(unique_DNAs)).
    ohs = torch.from_numpy((label_encodings[..., None] == np.arange(1, len(unique_DNAs)+1)).astype('float32'))

    # Symbols that actually occur; each gives one PWM row.
    encoded_symbols = np.unique(label_encodings)

    # NOTE(review): params[0], params[2], params[4] are presumably the weight tensors of the
    # first-layer convolutions — confirm against the model definition.
    for l in range(0,6,2): # first layer conv filters only
        param_layer = params[l]

        for i, filter_weights in enumerate(param_layer):

            # Only filters spanning at least 6 positions are turned into motifs.
            if(filter_weights.shape[1] < 6):
                continue
            filter_length = filter_weights.shape[1]
            filter_weights = filter_weights.data.permute(1,0)

            # Element-wise product; broadcasting requires the subsequence length to match
            # the filter length.
            logits = (ohs*filter_weights.cpu())

            unique_subseq_activisions = torch.nn.functional.relu(torch.sum(torch.sum(logits,1),1)).detach().reshape(-1)

            # Keep subsequences that activate the filter above half of the maximum activation.
            passed_mask = unique_subseq_activisions > unique_subseq_activisions.max()/2
            passed_sequences_idx = np.where(passed_mask.numpy())[0]
            if(len(passed_sequences_idx) <=1):
                continue
            
            mean_activation = torch.mean(unique_subseq_activisions[passed_mask])
            mean_activations.append(mean_activation.item())

            passed_sequences_le = label_encodings[passed_sequences_idx]

            # (positions x kept subsequences); each PWM row is the fraction of kept
            # subsequences carrying that symbol at each position.
            arr = passed_sequences_le.transpose(1,0)
            probs = [np.sum(arr==symbol, axis = 1, keepdims=True)/len(passed_sequences_le) for symbol in encoded_symbols ]

            PWMs.append(np.array(probs).squeeze())

            layers_idxs.append(l//2)
            filters_idxs.append(i)

    mean_activations_for_classes.append(mean_activations)
    PWMs_for_classes.append(PWMs)
    layers_idxs_for_classes.append(layers_idxs)
    filters_idxs_for_classes.append(filters_idxs)



#unique_subseqs_class_1_dict = dict(zip(unique_subseqs_class_1, list(range(len(unique_subseqs_class_1))) ) )

# len(unique_subseqs_class_0), len(unique_subseqs_class_1)
# recons = [unique_subseqs_class_1[i] for i in class_1_unique_subseq_indxs]
# len(recons), len(all_subsequences_class_1)
# comps = [recons[i]==all_subsequences_class_1[i] for i in range(len(all_subsequences_class_1))]
# np.all(comps)

# PWM Exp 2: PWM for all subsequences appearing in a class (including duplicates), computing an activation score for all of them, and using half of the maximum score as the cut-off

In [None]:
# PWMs from every subsequence occurrence of a class (duplicates included): activations are
# computed once per distinct subsequence, expanded to one value per occurrence, thresholded at
# half of the maximum, and the kept occurrences give the per-position symbol frequencies.
all_subsequences_for_classes = []
all_subsequences_for_classes.append(all_subsequences_class_0)
all_subsequences_for_classes.append(all_subsequences_class_1)


unique_subsequences_for_classes = []
unique_subsequences_for_classes.append(unique_subseqs_class_0)
unique_subsequences_for_classes.append(unique_subseqs_class_1)

unique_subsequences_idxs_for_classes = []
unique_subsequences_idxs_for_classes.append(class_0_unique_subseq_indxs)
unique_subsequences_idxs_for_classes.append(class_1_unique_subseq_indxs)


params = list(model.parameters())
print([param.shape for param in params])

# params[0][0]
print(unique_DNAs)


# One entry per class; each entry holds one item per filter that produced a PWM.
mean_activations_for_classes = []
PWMs_for_classes = []
layers_idxs_for_classes = []
filters_idxs_for_classes = []

for unique_subsequences, unique_subsequences_idx in zip(unique_subsequences_for_classes, unique_subsequences_idxs_for_classes):
    
    mean_activations = []
    PWMs = []
    layers_idxs = []
    filters_idxs = []

    # unique_subsequences_idx[n] is the position, within unique_subsequences, of the n-th
    # subsequence occurrence of the class (that is how the index lists were built).
    unique_subsequences_idx = np.array(unique_subsequences_idx)

    motif_sequences = unique_subsequences#get_kmers(unique_DNAs,filter_length)

    # The encodings depend only on the subsequences, so they are built once per class rather
    # than once per filter. Labels are 1-based positions in unique_DNAs.
    label_encodings = np.array([unique_DNAs.index(c)+1  for motif_seq in motif_sequences for c in motif_seq ]).reshape(len(motif_sequences),-1)

    # One-hot tensor, (n_distinct_subsequences, subsequence length, len(unique_DNAs)).
    ohs = torch.from_numpy((label_encodings[..., None] == np.arange(1, len(unique_DNAs)+1)).astype('float32'))

    # Symbols that actually occur; each gives one PWM row.
    encoded_symbols = np.unique(label_encodings)

    # NOTE(review): params[0], params[2], params[4] are presumably the weight tensors of the
    # first-layer convolutions — confirm against the model definition.
    for l in range(0,6,2): # first layer conv filters only
        param_layer = params[l]

        for i, filter_weights in enumerate(param_layer):

            # Only filters spanning at least 6 positions are turned into motifs.
            if(filter_weights.shape[1] < 6):
                continue
            filter_length = filter_weights.shape[1]
            filter_weights = filter_weights.data.permute(1,0)

            # Element-wise product; broadcasting requires the subsequence length to match
            # the filter length.
            logits = (ohs*filter_weights.cpu())

            unique_subseq_activisions = torch.nn.functional.relu(torch.sum(torch.sum(logits,1),1))
            
            # Expand to one activation per occurrence with a single vectorised lookup.
            all_subseq_activations = unique_subseq_activisions.detach().numpy()[unique_subsequences_idx]
            
            # No squeeze(): the index array has to stay 1-D so that len() works even when
            # exactly one occurrence passes the cut-off.
            passed_sequences_idx = np.where(all_subseq_activations > all_subseq_activations.max()/2)[0]
            if(len(passed_sequences_idx) <=1):
                continue
            
            mean_activation = np.mean(all_subseq_activations[passed_sequences_idx])
            mean_activations.append(mean_activation.item())

            # Encodings of the kept occurrences, looked up through their distinct-subsequence ids.
            passed_sequences_le = label_encodings[unique_subsequences_idx[passed_sequences_idx]]

            # (positions x kept occurrences); each PWM row is the fraction of kept
            # occurrences carrying that symbol at each position.
            arr = passed_sequences_le.transpose(1,0)
            probs = [np.sum(arr==symbol, axis = 1, keepdims=True)/len(passed_sequences_le) for symbol in encoded_symbols ]

            PWMs.append(np.array(probs).squeeze())

            layers_idxs.append(l//2)
            filters_idxs.append(i)

    mean_activations_for_classes.append(mean_activations)
    PWMs_for_classes.append(PWMs)
    layers_idxs_for_classes.append(layers_idxs)
    filters_idxs_for_classes.append(filters_idxs)


  

# PWM visualization for PWM Exp 1 and Exp 2

In [None]:
# Print and draw every PWM of each class, ordered by increasing mean activation.
np.set_printoptions(precision=3)

per_class_results = zip(mean_activations_for_classes, PWMs_for_classes, layers_idxs_for_classes, filters_idxs_for_classes)
for j, (mean_activations, PWMs, layers_idxs, filters_idxs) in enumerate(per_class_results):

    sorted_idxs = np.argsort(mean_activations)

    for i in sorted_idxs:
        mean_activation = mean_activations[i]
        PWM, layer_idx, filter_idx = PWMs[i], layers_idxs[i], filters_idxs[i]

        print(f"Mean activation:{mean_activation:.3}")
        print(f"Layer {layer_idx}, filter#{filter_idx}\n")
        print(repr(PWM))
        print()

        # The PWM is shown as an image; the title is the class index.
        plt.figure(figsize=(12,2))
        ax = plt.imshow(PWM, cmap = plt.get_cmap('Blues') )
        plt.grid(False)

        gridlinewidth = 1.6  # not used by this cell

        plt.title(str(j))
        plt.show()

        print("----------------------------------------------------------------------")

    print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    print(sorted_idxs)

# PWM Exp 3: PWM for all subsequences

# Motif (PWM) generation from the Conv1 layers: the n_dict x L filter weights are applied to all L-length subsequences, keeping the subsequences that lead to activations greater than half of the maximum activation, and computing the probability of each nucleotide at each of the L positions of the subsequence. Only motifs of length 6 or greater are considered.

In [None]:
# PWMs from all possible k-mers: every qualifying filter is applied to each k-mer of its own
# length, the k-mers activating it above half of the maximum are kept, and their per-position
# symbol frequencies form the PWM of that filter.
params = list(model.parameters())
print([param.shape for param in params])

# params[0][0]
print(unique_DNAs)

mean_activations, PWMs, layers_idxs, filters_idxs = [], [], [], []

for l in range(0,6,2): # first layer conv filters only
    for i, filter_weights in enumerate(params[l]):

        # Only filters spanning at least 6 positions are turned into motifs.
        filter_length = filter_weights.shape[1]
        if filter_length < 6:
            continue
        filter_weights = filter_weights.data.permute(1,0)

        motif_sequences = get_kmers(unique_DNAs,filter_length)

        # 1-based labels, one row per k-mer, and the matching one-hot tensor.
        label_encodings = np.array([unique_DNAs.index(c)+1  for motif_seq in motif_sequences for c in motif_seq ]).reshape(len(motif_sequences),-1)
        ohs = torch.Tensor((label_encodings[..., None] == np.arange(1, len(unique_DNAs)+1)).astype(dtype='float32'))

        vals = ohs*filter_weights.cpu()
        activisions = torch.nn.functional.relu(vals.sum(1).sum(1))

        # k-mers whose activation exceeds half of the maximum activation.
        passed_mask = activisions > activisions.max()/2
        mean_activation = activisions[passed_mask].squeeze().mean()

        passed_sequences_idx = np.where(passed_mask)[0]
        if len(passed_sequences_idx) <= 1:
            continue
        mean_activations.append(mean_activation.item())

        passed_sequences = [label_encodings[idx] for idx in passed_sequences_idx]
        passed_sequences_motif = [motif_sequences[idx] for idx in passed_sequences_idx]

        # (positions x kept k-mers); each PWM row is the fraction of kept k-mers carrying
        # that symbol at each position.
        arr = np.array(passed_sequences).squeeze().transpose(1,0)
        probs = [np.sum(arr==symbol, axis = 1, keepdims=True)/len(passed_sequences) for symbol in np.unique(label_encodings) ]

        PWMs.append(np.array(probs).squeeze())

        layers_idxs.append(l//2)
        filters_idxs.append(i)
        
#         print("Mean activation:",mean_activation.item())
#         print(f"Layer {l//2}, filter#{i}\n",np.array(probs).squeeze())
#         print()
#         ax = plt.imshow(np.array(probs).squeeze(), cmap = plt.get_cmap('Blues') )
        
#         plt.grid(False)
        
#         #plt.grid(True, color='w', linestyle='-', linewidth=2, which = 'both')
#         #loc = ticker.MultipleLocator(base=2)
#         #ax.xaxis.set_minor_locator(loc)
#         #ax.yaxis.set_minor_locator(loc)

#         #plt.xticks(np.arange(0,filter_weights.shape[0])+0.5, np.arange(1,filter_weights.shape[0] + 1), fontsize=20)
#         #plt.yticks(np.arange(0,filter_weights.shape[1])+0.5, unique_DNAs, fontsize=20)
#         #ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
#         #plt.yticks(rotation=0)
#         gridlinewidth = 1.6
#         #ax.hlines(np.arange(len(unique_DNAs)+1), *ax.get_xlim(), color='r', linewidth=gridlinewidth)
#         #ax.vlines(np.arange(0,filter_weights.shape[0]+1), *ax.get_ylim(), color='w', linewidth=gridlinewidth)
#         plt.show()
#         print("----------------------------------------------------------------------")

In [None]:
# Show every retained filter, ordered by ascending mean activation
# (np.argsort sorts from smallest to largest).
sorted_idxs = np.argsort(mean_activations)
blues = plt.get_cmap('Blues')

for i in sorted_idxs:
    mean_activation, PWM = mean_activations[i], PWMs[i]
    layer_idx, filter_idx = layers_idxs[i], filters_idxs[i]

    # Text summary first, then the same matrix as a heat map.
    print("Mean activation:", mean_activation)
    print(f"Layer {layer_idx}, filter#{filter_idx}\n", PWM)
    print()

    ax = plt.imshow(PWM, cmap=blues)
    plt.grid(False)

    # Optional cosmetics (per-position x ticks, unique_DNAs labels on the
    # y axis, lines drawn between cells) were disabled in this cell;
    # gridlinewidth is only meaningful for that disabled line overlay.
    gridlinewidth = 1.6

    plt.show()
    print("----------------------------------------------------------------------")



# params = list(model.parameters())
# print([param.shape for param in params])

# # params[0][0]
# print(unique_DNAs)



# for l in range(0,6,2): # first layer conv filters only
#     param_layer = params[l]

#     for i, filter_weights in enumerate(param_layer):

#         if(filter_weights.shape[1] < 6):
#             continue
#         print("Filter shape:",filter_weights.shape)
#         filter_length = filter_weights.shape[1]
#         filter_weights = filter_weights.data.permute(1,0)

#         motif_sequences = get_kmers(unique_DNAs,filter_length)

#         label_encodings = np.array([unique_DNAs.index(c)+1  for motif_seq in motif_sequences for c in motif_seq ]).reshape(len(motif_sequences),-1)

#         ohs=[(np.arange(len(unique_DNAs)+1) == label_encoding[:,None]).astype(dtype='float32') for label_encoding in label_encodings]#one_hot

#         ohs = torch.Tensor(np.array([np.delete(oh,0, axis=-1) for oh in ohs]).reshape(len(ohs),-1, len(unique_DNAs)))
#         #print(ohs.shape)
#         #print(filter_weights.shape)
#         vals = (ohs*filter_weights.cpu())


#         activisions = torch.nn.functional.relu(torch.sum(torch.sum(vals,1),1))

#         #activisions[np.where(activisions > max(activisions)/2)], max(activisions)/2
        
#         passed_sequences_idx = np.where(activisions > max(activisions)/2)[0]
#         if(len(passed_sequences_idx) ==0):
#             continue
#         passed_sequences = [label_encodings[i] for i in passed_sequences_idx.squeeze()]
#         passed_sequences_motif = [motif_sequences[i] for i in passed_sequences_idx.squeeze()]

#         #passed_sequences, print(passed_sequences_motif)

#         arr = np.array(passed_sequences).squeeze().transpose(1,0)
#         probs = [np.sum(arr==i, axis = 1, keepdims=True)/len(passed_sequences) for i in np.unique(label_encodings) ]
#         print(f"Layer {l//2}, filter#{i}\n",np.array(probs).squeeze())
#         print()
#         plt.imshow(np.array(probs).squeeze(), cmap = plt.get_cmap('jet') )
#         plt.grid('off')
        
#         plt.show()
        