In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import os
import sys
sys.path.append('../bin')

from collections import OrderedDict

import util
from util import ConvNet, SpliceSeqDataset, RunBuilder, RunManager

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision.transforms import Compose
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torch.nn import functional as F

from torchsummary import summary

from tqdm import tqdm, tqdm_notebook

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

seed = 99
np.random.seed(seed)
torch.manual_seed(seed)
torch.set_num_threads=16

# Load data

## Generate training and validation datasets

In [6]:
# Parameters
feature_prefix = '../data/features/SE'
feature_suffix = ['upstream.bed', 'cassette_5p.bed', 'cassette_3p.bed', 'downstream.bed']
feature_files = [f'{feature_prefix}_10class_dedup_{suffix}' for suffix in feature_suffix]

# Convert categorical labels to numerical
label_names = pd.read_csv(feature_files[0], sep='\t', header=None)[3]

label_names = pd.factorize(label_names)
labels = label_names[0]
print(len(pd.Series(label_names).value_counts()))

# Additional parameters for loading data
seq_length = 400
num_classes = int(max(labels) + 1)
genome_fa = '../data/hg19.fa'
augmented_data = []
k = 1

# Sample from each splice event k times
# Necessary transforms: PadSequence, ToOneHotEncoding; Optional transforms: CropSequence, ReverseComplement
for i in range(k): 
    tf1 = [util.PadSequence(seq_length), util.CropSequence(seq_length), util.ToOneHotEncoding()]
    augmented_data.append(SpliceSeqDataset(feature_files=feature_files,
                                       genome_fa=genome_fa,
                                       transform=Compose(tf1)))
    
    tf2 = [util.PadSequence(seq_length), util.CropSequence(seq_length), util.ReverseComplement(), util.ToOneHotEncoding()]
    augmented_data.append(SpliceSeqDataset(feature_files=feature_files,
                                           genome_fa=genome_fa,
                                           transform=Compose(tf2)))
    
splice_dataset = torch.utils.data.ConcatDataset(augmented_data)

# Randomly split into training and validation datasets at 0.8:0.2 ratio
train_dataset, valid_dataset = torch.utils.data.random_split(splice_dataset,
                                                             [round(len(splice_dataset)*0.8),
                                                              round(len(splice_dataset)*0.2)])

# Balance class sampling using weighted sampler
class_sample_weights = 100. / pd.Series(label_names[0]).value_counts()
class_sample_weights = class_sample_weights[label_names[0]].values
train_sampler = WeightedRandomSampler(weights=[class_sample_weights[splice_dataset[i][1]] for i in train_dataset.indices],
                                      num_samples=len(train_dataset))

2


In [4]:
util.calc_conv_pad(50, 20, 3, 3)

5

In [7]:
params = OrderedDict(
    # model parameters
    num_classes = [num_classes],
    c1_in = [seq_length],
    c1_out = [100],
    c1_kernel_w = [20],
    c1_filter = [64, 256],
    c1_stride_w = [4],
    c2_out = [20],
    c2_kernel_w = [3],
    c2_filter = [64, 128],
    c2_stride_w = [3],
    fc_out = [128, 256, 512],
    
    # hyperparameters
    batch_size = [10, 32, 64],
    lr = [0.001, 0.0001, 0.00001]
)

manager = RunManager()
is_first_run = True
for run in RunBuilder.get_runs(params):    
    # Initialize model and dataset
    network = ConvNet(num_classes=run.num_classes,
                      c1_in=run.c1_in,
                      c1_out=run.c1_out,
                      c1_kernel_w=run.c1_kernel_w,
                      c1_filter=run.c1_filter,
                      c1_stride_w=run.c1_stride_w,
                      c2_out=run.c2_out,
                      c2_kernel_w=run.c2_kernel_w,
                      c2_filter=run.c2_filter,
                      c2_stride_w=run.c2_stride_w,
                      fc_out=run.fc_out).cuda(device)
    
    loader = DataLoader(train_dataset, batch_size=run.batch_size, sampler=train_sampler)
    optimizer = torch.optim.Adam(network.parameters(), lr=run.lr)
    log_dir = '/home/ubuntu/tb/7-26-19-10class/'
    # Display brief summary of first model
#     if is_first_run:
#         is_first_run = False
#         summary(network.cuda(), input_size=(4, 4, seq_length), device='cuda')
#         util.show_sample(train_dataset[np.random.randint(len(train_dataset))], class_names=label_names)
    
    # Perform training
    manager.begin_run(run, network, loader, log_dir)
    for epoch in range(50):
        
        manager.begin_epoch()
        # Process batch
        for batch in loader:
            seqs, labels = batch
            network.cuda()
            preds = network(seqs.cuda(device)) # pass batch
            loss = F.cross_entropy(preds, labels.cuda(device)) # calculate loss
            optimizer.zero_grad() # zero gradients
            loss.backward() # calculate gradients
            optimizer.step() # update weights

            manager.track_loss(loss)
            manager.track_num_correct(preds, labels.cuda(device))
            
        manager.end_epoch()
    manager.end_run(class_names=label_names[1])
# manager.save('../results')

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,num_classes,c1_in,c1_out,c1_kernel_w,c1_filter,c1_stride_w,c2_out,c2_kernel_w,c2_filter,c2_stride_w,fc_out,batch_size,lr
0,1,1,2.256841,0.216895,1.398219,1.427850,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
1,1,2,2.269192,0.203196,1.390165,2.935245,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
2,1,3,2.244078,0.228311,1.384100,4.436312,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
3,1,4,2.261772,0.210616,1.377707,5.911416,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
4,1,5,2.268051,0.204338,1.392865,7.422530,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
5,1,6,2.264055,0.208333,1.378861,8.920123,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
6,1,7,2.261201,0.211187,1.378823,10.415973,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
7,1,8,2.272617,0.199772,1.383698,11.918501,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
8,1,9,2.261772,0.208333,1.377993,13.424963,10,400,100,20,64,4,20,3,64,3,128,10,0.00100
9,1,10,2.277754,0.194635,1.384902,14.911545,10,400,100,20,64,4,20,3,64,3,128,10,0.00100


KeyboardInterrupt: 