In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import os
import sys
sys.path.append('../bin')

from collections import OrderedDict

import util
import splintr
from util import RunBuilder, RunManager
from splintr import SplintrNet, SpliceEventDataset

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import GridSearchCV

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader, ConcatDataset, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.nn import functional as F

from torchsummary import summary

from tqdm import tqdm, tqdm_notebook

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

seed = 99
np.random.seed(seed)
torch.manual_seed(seed)
torch.set_num_threads=16

# Load data

## Load and transform dataset

In [3]:
# Parameters
feature_prefix = '../data/features/SE'
feature_suffix = ['upstream.bed', 'cassette_5p.bed', 'cassette_3p.bed', 'downstream.bed']
feature_files = [f'{feature_prefix}_10class_dedup_{suffix}' for suffix in feature_suffix]

# Additional parameters for loading data
seq_length = 400
genome_fa = '../data/hg19.fa'
augmented_data = []
k = 1

# Sample from each splice event k times
# Necessary transforms: PadSequence, ToOneHotEncoding; Optional transforms: CropSequence, ReverseComplement
for i in range(k): 
    # Load and perform necessary transforms on dataset
    tf1 = [splintr.PadSequence(seq_length), splintr.CropSequence(seq_length)]
    augmented_data.append(SpliceEventDataset(feature_files=feature_files,
                                       genome_fa=genome_fa,
                                       transform=tf1))
    
    # Load and perform necessary and optional transforms on dataset
    tf2 = [splintr.PadSequence(seq_length), splintr.CropSequence(seq_length), splintr.ReverseComplement()]
    augmented_data.append(SpliceEventDataset(feature_files=feature_files,
                                           genome_fa=genome_fa,
                                           transform=tf2))
    
splice_dataset = torch.utils.data.ConcatDataset(augmented_data)

train_size = int(0.8*len(splice_dataset))
splice_size = int(0.2*len(splice_dataset))
train_dataset, valid_dataset = torch.utils.data.random_split(splice_dataset, [train_size, splice_size])

# Convert categorical labels to numerical
label_names = pd.read_csv(feature_files[0], sep='\t', header=None)[3]
label_names = np.repeat(label_names, k*2)
label_names = pd.factorize(label_names)
print(len(pd.Series(label_names).value_counts()))

num_classes = int(max(label_names[0]) + 1)

sample_weights = []
# Balance class sampling using weighted sampler
for dataset in [train_dataset, valid_dataset]:
    labels = [sample[1] for sample in dataset] # get label of each sample
    weights = 100. / pd.Series(labels).value_counts() # class weights
    weights = weights[labels] # 
    sample_weights.append(weights.values)

train_sampler = WeightedRandomSampler(weights=sample_weights[0],
                                      num_samples=len(sample_weights[0]))
valid_sampler = WeightedRandomSampler(weights=sample_weights[1],
                                      num_samples=len(sample_weights[1]))

2


In [4]:
params = OrderedDict(
    # model parameters
    num_classes = [num_classes],
    c1_in = [seq_length],
    c1_out = [100],
    c1_kernel_w = [20],
    c1_filter = [64, 256],
    c1_stride_w = [4],
    c2_out = [20],
    c2_kernel_w = [3],
    c2_filter = [128],
    c2_stride_w = [3],
    fc_out = [256, 512],
    
    # hyperparameters
    batch_size = [64],
    lr = [0.001]
)

manager = RunManager()
is_first_run = True
for run in RunBuilder.get_runs(params):    
    # Initialize model and dataset
    network = SplintrNet(num_classes=run.num_classes,
                      c1_in=run.c1_in,
                      c1_out=run.c1_out,
                      c1_kernel_w=run.c1_kernel_w,
                      c1_filter=run.c1_filter,
                      c1_stride_w=run.c1_stride_w,
                      c2_out=run.c2_out,
                      c2_kernel_w=run.c2_kernel_w,
                      c2_filter=run.c2_filter,
                      c2_stride_w=run.c2_stride_w,
                      fc_out=run.fc_out).cuda(device)
    
    train_loader = DataLoader(train_dataset, batch_size=run.batch_size, sampler=train_sampler)
    valid_loader = DataLoader(valid_dataset, batch_size=run.batch_size, sampler=valid_sampler)

    optimizer = torch.optim.Adam(network.parameters(), lr=run.lr)
    log_dir = '/home/ubuntu/tb/7-29-19-10class/'
    # Display brief summary of first model
#     if is_first_run:
#         is_first_run = False
#         summary(network.cuda(), input_size=(4, 4, seq_length), device='cuda')
#         util.show_sample(train_dataset[np.random.randint(len(train_dataset))], class_names=label_names)
    
    # Perform training
    manager.begin_run(run, network, train_loader, valid_loader, log_dir)
    network.cuda()
    for epoch in range(30):
        
        manager.begin_epoch()
        
        # Train on batch
        for batch in train_loader:
            seqs, labels = batch
            preds = network(seqs.cuda(device)) # pass batch
            loss = F.cross_entropy(preds, labels.cuda(device)) # calculate loss
            optimizer.zero_grad() # zero gradients
            loss.backward() # calculate gradients
            optimizer.step() # update weights

            manager.track_train_loss(loss)
            manager.track_train_num_correct(preds, labels.cuda(device))
        
        # Check validation set
        with torch.no_grad():
            for data in valid_loader:
                seqs, labels = data
                preds = network(seqs.cuda(device))
                loss = F.cross_entropy(preds, labels.cuda(device))
                
                manager.track_valid_loss(loss)
                manager.track_valid_num_correct(preds, labels.cuda(device))

        manager.end_epoch()
    manager.end_run(train_class_names=[sample[1] for sample in train_dataset],
                    valid_class_names=[sample[1] for sample in valid_dataset])
manager.save('../results')

Unnamed: 0,run,epoch,train_loss,valid_loss,train_accuracy,valid_accuracy,epoch_duration,run_duration,num_classes,c1_in,...,c1_kernel_w,c1_filter,c1_stride_w,c2_out,c2_kernel_w,c2_filter,c2_stride_w,fc_out,batch_size,lr
0,1,1,2.341366,2.327524,0.140982,0.194064,1.501570,1.633299,10,400,...,20,64,4,20,3,128,3,256,64,0.001
1,1,2,2.274475,2.314246,0.231735,0.157534,1.211799,3.127759,10,400,...,20,64,4,20,3,128,3,256,64,0.001
2,1,3,2.174091,2.293837,0.352740,0.187215,1.213675,4.619667,10,400,...,20,64,4,20,3,128,3,256,64,0.001
3,1,4,2.069532,2.266576,0.472603,0.235160,1.212054,6.148105,10,400,...,20,64,4,20,3,128,3,256,64,0.001
4,1,5,1.956765,2.258410,0.577055,0.235160,1.212661,7.680177,10,400,...,20,64,4,20,3,128,3,256,64,0.001
5,1,6,1.885698,2.311547,0.644977,0.168950,1.218637,9.218761,10,400,...,20,64,4,20,3,128,3,256,64,0.001
6,1,7,1.842381,2.280912,0.676370,0.207763,1.206686,10.745524,10,400,...,20,64,4,20,3,128,3,256,64,0.001
7,1,8,1.784826,2.298970,0.738584,0.194064,1.225987,12.254277,10,400,...,20,64,4,20,3,128,3,256,64,0.001
8,1,9,1.758735,2.292335,0.761986,0.198630,1.202674,13.740095,10,400,...,20,64,4,20,3,128,3,256,64,0.001
9,1,10,1.712317,2.248508,0.817922,0.251142,1.216353,15.241376,10,400,...,20,64,4,20,3,128,3,256,64,0.001
