In [2]:
%load_ext autoreload
%autoreload 2
%load_ext autotime

import os
import sys
sys.path.append('..')

from collections import OrderedDict

import splintr as sp
from splintr.splice import rmats_subset_top_events
sp.verbose = True

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import GridSearchCV

import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset, WeightedRandomSampler
import torch.nn.functional as F
from torchsummary import summary

from tqdm import tqdm, tqdm_notebook

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

seed = 99
np.random.seed(seed)
torch.manual_seed(seed)
torch.set_num_threads=16



In [3]:
torch.get_num_threads()

2

time: 3.58 ms


# Load and transform dataset

In [4]:
# Parameters
data_dir = '../data/features'
feature_file = f'{data_dir}/SE.txt'
feature_df = rmats_subset_top_events(feature_file, 5)
feature_df = feature_df.loc[feature_df.IncLevelDifference > 0] # upregulated AS events

# Additional parameters for loading data
seq_length = 250
genome_fa = '../data/hg19.fa'
augmented_data = []
k = 10

# Sample from each splice event k times
# Necessary transforms: PadSequence, ToOneHotEncoding; Optional transforms: CropSequence, ReverseComplement
for i in range(k): 
    # Load and perform necessary transforms on dataset
    tf1 = [sp.PadSequence(seq_length), sp.CropSequence(seq_length)]
    augmented_data.append(sp.SpliceEventDataset(feature_file=feature_df,
                                                genome_fa=genome_fa,
                                                transform=tf1))
    
    # Load and perform necessary and optional transforms on dataset
    tf2 = [sp.PadSequence(seq_length), sp.CropSequence(seq_length), sp.ReverseComplement()]
    augmented_data.append(sp.SpliceEventDataset(feature_file=feature_df,
                                                genome_fa=genome_fa,
                                                transform=tf2))
    
splice_dataset = torch.utils.data.ConcatDataset(augmented_data)

HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Pandas Apply', max=905, style=ProgressStyle(description_width…




HBox(children=(IntProgress(value=0, max=905), HTML(value='')))

time: 9.29 s


In [5]:
train_size = int(0.8*len(splice_dataset))
splice_size = int(len(splice_dataset) - train_size)
train_dataset, valid_dataset = torch.utils.data.random_split(splice_dataset, [train_size, splice_size])

# Convert categorical labels to numerical
print(feature_df['sample'].value_counts())
label_names = pd.factorize(feature_df['sample'])

num_classes = int(max(label_names[0]) + 1)
print(f'Classes: {num_classes}')

sample_weights = []
# Balance class sampling using weighted sampler
for dataset in [train_dataset, valid_dataset]:
    labels = [sample[1] for sample in dataset] # get label of each sample
    weights = 100. / pd.Series(labels).value_counts() # class weights
    weights = weights[labels] # 
    sample_weights.append(weights.values)

train_sampler = WeightedRandomSampler(weights=sample_weights[0],
                                      num_samples=len(sample_weights[0]))
valid_sampler = WeightedRandomSampler(weights=sample_weights[1],
                                      num_samples=len(sample_weights[1]))

AQR       234
HNRNPC    220
bg        178
U2AF2     135
RBM15      83
U2AF1      55
Name: sample, dtype: int64
Classes: 6
time: 6.3 s


In [18]:
sp.learning._calc_conv_pad(250, 50, 40, 5)

-7

time: 2.41 ms


In [16]:
# Run(num_classes=6, c1_in=250, c1_out=50, c1_kernel_w=10, c1_filter=64, c1_stride_w=5, c2_out=6, c2_kernel_w=4, c2_filter=8, c2_stride_w=4, fc_out=8, batch_size=128, lr=0.01, weight_decay=0, dropout=0)
params = OrderedDict(
    # model parameters
    num_classes = [num_classes],
    c1_in = [seq_length],
    c1_out = [50],
    c1_kernel_w = [20],
    c1_filter = [64],
    c1_stride_w = [5],
    c2_out = [6],
    c2_kernel_w = [4],
    c2_filter = [8],
    c2_stride_w = [4],
    fc_out = [8],
    
    # hyperparameters
    batch_size = [32],
    lr = [0.0001],
    weight_decay = [0],
    dropout = [0]
)

manager = sp.RunManager()
is_first_run = True
for run in sp.RunBuilder.get_runs(params):    
    # Initialize model and dataset
    network = sp.SplintrNet(num_classes=run.num_classes,
                      c1_in=run.c1_in,
                      c1_out=run.c1_out,
                      c1_kernel_w=run.c1_kernel_w,
                      c1_filter=run.c1_filter,
                      c1_stride_w=run.c1_stride_w,
                      c2_out=run.c2_out,
                      c2_kernel_w=run.c2_kernel_w,
                      c2_filter=run.c2_filter,
                      c2_stride_w=run.c2_stride_w,
                      dropout=run.dropout,
                      fc_out=run.fc_out).cuda(device)
    
    train_loader = DataLoader(train_dataset, batch_size=run.batch_size, sampler=train_sampler)
    valid_loader = DataLoader(valid_dataset, batch_size=run.batch_size, sampler=valid_sampler)

    optimizer = torch.optim.Adam(network.parameters(), lr=run.lr, weight_decay=run.weight_decay)
    log_dir = '/home/ubuntu/tb/8-05-19-6class/'
    # Display brief summary of first model
    if is_first_run:
        is_first_run = False
        summary(network.cuda(), input_size=(4, 4, seq_length), device='cuda')
#         util.show_sample(train_dataset[np.random.randint(len(train_dataset))], class_names=label_names)
    
    # Perform training
    manager.begin_run(run, network, train_loader, valid_loader, log_dir)
    network.cuda()
    for epoch in range(30):
        
        manager.begin_epoch()
        
        # Train on batch
        for batch in train_loader:
            seqs, labels = batch
            preds = network(seqs.cuda(device)) # pass batch
            loss = F.cross_entropy(preds, labels.cuda(device)) # calculate loss
            optimizer.zero_grad() # zero gradients
            loss.backward() # calculate gradients
            optimizer.step() # update weights

            manager.track_train_loss(loss)
            manager.track_train_num_correct(preds, labels.cuda(device))
        
        # Check validation set
        with torch.no_grad():
            for data in valid_loader:
                seqs, labels = data
                preds = network(seqs.cuda(device))
                loss = F.cross_entropy(preds, labels.cuda(device))
                
                manager.track_valid_loss(loss)
                manager.track_valid_num_correct(preds, labels.cuda(device))

        manager.end_epoch()
    manager.end_run(train_class_names=label_names[1],
                    valid_class_names=label_names[1])
manager.save('../results')

Unnamed: 0,run,epoch,train_loss,valid_loss,train_accuracy,valid_accuracy,epoch_duration,run_duration,num_classes,c1_in,...,c1_stride_w,c2_out,c2_kernel_w,c2_filter,c2_stride_w,fc_out,batch_size,lr,weight_decay,dropout
0,1,1,1.79407,1.806097,0.16982,0.175138,7.818209,7.861254,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
1,1,2,1.793703,1.805805,0.176519,0.169613,7.807177,15.777744,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
2,1,3,1.793577,1.804833,0.176865,0.188398,7.800598,23.688742,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
3,1,4,1.792821,1.803963,0.187707,0.179282,7.831136,31.624688,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
4,1,5,1.792118,1.803683,0.184323,0.183978,7.821724,39.562947,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
5,1,6,1.789671,1.798516,0.188052,0.196685,7.800498,47.477336,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
6,1,7,1.787544,1.795165,0.195166,0.209669,7.814536,55.406424,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
7,1,8,1.784245,1.795464,0.204282,0.215193,7.804457,63.328367,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
8,1,9,1.780774,1.788628,0.212086,0.212983,7.97244,71.4191,6,250,...,5,6,4,8,4,8,32,0.0001,0,0
9,1,10,1.778056,1.789997,0.223964,0.231215,7.812672,79.347114,6,250,...,5,6,4,8,4,8,32,0.0001,0,0


time: 4min 3s
