In [None]:
import os
import sys
sys.path.append('/home/kal/TF_models/bin/')
os.environ['CUDA_VISIBLE_DEVICES'] = '2' # Must be before importing keras!
import tf_memory_limit

from keras.models import Model
from keras.callbacks import ModelCheckpoint, TensorBoard, Callback, LearningRateScheduler
from keras.utils import plot_model
from keras.layers import Input, Lambda, Dense, Conv1D, Activation
from keras.optimizers import SGD, Adam, RMSprop
import keras.backend as K
import tensorflow as tf


import numpy as np
import matplotlib.pylab as plt 
from sklearn.metrics import precision_recall_curve
from scipy.integrate import trapz
from tqdm import tqdm
import ucscgenome
import pandas as pd
import pickle
import time

import sequence
import train_TFmodel

In [None]:
# set up directories
# exist_ok=True keeps this cell re-runnable: plain os.makedirs raises
# FileExistsError once the directories exist (e.g. on Restart & Run All).
out_path = '/home/kal/TF_models/seq_only/count_regression/zinb_CTCF'
weights_path = os.path.join(out_path, 'intermediate_weights')
history_path = os.path.join(out_path, 'history')
for directory in (out_path, weights_path, history_path):
    os.makedirs(directory, exist_ok=True)

# load in data
# Headerless tab-separated table, one row per region; the 'nucs' column holds
# the sequence fed to the model and c1..c9 hold the target scores.
bed_path = '/home/kal/TF_models/data/count_regression/ctcf_regions_9_seqs.bed'
columns = 'chr start end name score nucs c1 c2 c3 c4 c5 c6 c7 c8 c9'
score_columns = 'c1 c2 c3 c4 c5 c6 c7 c8 c9'.split()
peaks = pd.read_table(bed_path, header=None)
# Assigning .columns (instead of passing names=) fails loudly if the file
# does not have exactly one column per name.
peaks.columns = columns.split()

# macro variables
prediction_window = 256  # sequence length (positions) of each model input
half_window = prediction_window // 2
# chr8 is held out for validation/test; every other chromosome is training data.
num_training_examples = int((peaks.chr != 'chr8').sum())
batch_size = 32
drop_rate = 0.1
# Underscore-separated layer specs; each 'a.b' is parsed into [a, b] for
# train_TFmodel.BasicConv (presumably [num_filters, kernel_width] — confirm).
conv_string = '32.3_32.32_32.3_16.3'
conv_list = [[int(x) for x in cell.split('.')] for cell in conv_string.split('_')]
num_outputs = len(score_columns)
num_epochs = 10

In [None]:
def native_gen(mode='train', once=False, data=None, score_cols=None, holdout_chrom='chr8'):
    """Yield (sequence, scores) pairs for positive (native) regions.

    Regions on ``holdout_chrom`` are held out: rows whose index label is even
    form the 'test' split and rows whose index label is odd form the 'val'
    split. All other rows form the 'train' split (used for any other ``mode``).

    Args:
        mode: 'train', 'val' or 'test'.
        once: if True, stop after a single shuffled pass over the split;
            otherwise reshuffle and repeat forever.
        data: table to draw rows from; defaults to the module-level ``peaks``.
        score_cols: columns holding the target scores; defaults to the
            module-level ``score_columns``.
        holdout_chrom: chromosome reserved for the val/test splits.

    Yields:
        ``(seq, scores)`` where ``seq`` is the row's 'nucs' value and
        ``scores`` is a 1-D numpy array with one entry per score column.
        This also holds for a single column, so callers can always take ``len``.
        Yields nothing if the selected split is empty.
    """
    if data is None:
        data = peaks
    if score_cols is None:
        score_cols = score_columns
    score_cols = list(score_cols)

    on_holdout = data.chr == holdout_chrom
    if mode == 'test':
        indices = [x for x in data[on_holdout].index.values if x % 2 == 0]
    elif mode == 'val':
        indices = [x for x in data[on_holdout].index.values if x % 2 == 1]
    else:
        indices = list(data[~on_holdout].index.values)

    if not indices:
        # Nothing to yield; returning avoids spinning forever in the loop below.
        return

    while True:
        np.random.shuffle(indices)
        for idx in indices:
            # DataFrame.at replaces the deprecated/removed DataFrame.get_value.
            scores = np.asarray([data.at[idx, c] for c in score_cols])
            yield data.at[idx, 'nucs'], scores
        if once:
            return
            
def scrambled_gen(scrambled, mode='train'):
    """Yield block-shuffled copies of native sequences, for use as negatives.

    Each sequence produced by ``native_gen(mode=mode)`` is split into
    consecutive blocks of ``scrambled`` characters. The blocks are permuted at
    random and re-joined into a string. If ``scrambled`` does not evenly divide
    ``prediction_window``, a message is printed and single-character shuffling
    is used instead. The underlying ``native_gen`` call repeats forever, so this
    stream is endless too.
    """
    block_len = scrambled
    if prediction_window % block_len:
        print(str(block_len) + 'mers do not evenly divide the sequence.')
        block_len = 1
    for seq, _scores in native_gen(mode=mode):
        blocks = np.array(list(seq)).reshape((-1, block_len))
        np.random.shuffle(blocks)  # permutes rows, i.e. whole blocks
        yield ''.join(blocks.ravel())
            
            
def pair_gen(mode='train', once=False, batch_size=32):
    """Generate batches of paired positive/negative samples.

    Each batch holds ``batch_size // 2`` native sequences followed by the same
    number of 2-mer-scrambled sequences, all one-hot encoded with
    ``sequence.encode_to_onehot``. The labels are the native score vectors
    followed by all-zero rows for the scrambled negatives.

    Args:
        mode: split passed to ``native_gen`` / ``scrambled_gen``.
        once: if True, stop after one pass over the native sequences. A
            trailing partial batch is dropped so every batch has the same size.
        batch_size: total number of sequences per batch (at least 2).

    Yields:
        ``(seqs, labels)``, each with ``2 * (batch_size // 2)`` rows.

    Raises:
        ValueError: if ``batch_size < 2`` (raised on the first ``next``).
    """
    half = batch_size // 2
    if half < 1:
        raise ValueError('batch_size must be at least 2')
    positives = native_gen(mode=mode, once=once)
    negatives = scrambled_gen(2, mode=mode)
    while True:
        pos_seqs, neg_seqs, scores = [], [], []
        for _ in range(half):
            try:
                pos_seq, score = next(positives)
            except StopIteration:
                # Under PEP 479 a StopIteration escaping a generator becomes a
                # RuntimeError; end the stream cleanly instead.
                return
            neg_seq = next(negatives)
            pos_seqs.append(sequence.encode_to_onehot(pos_seq))
            neg_seqs.append(sequence.encode_to_onehot(neg_seq))
            scores.append(score)
        # The zero padding must match batch_size // 2 (it was hard-coded to
        # 32 // 2). reshape also accepts scalar scores, which len() would not.
        pos_labels = np.asarray(scores).reshape(half, -1)
        labels = np.concatenate([pos_labels, np.zeros(pos_labels.shape)], axis=0)
        yield np.asarray(pos_seqs + neg_seqs), labels

In [None]:
# Make the external ZINB helper importable from its checkout.
sys.path.append('/home/thouis/basenji_embeddings')
from zinb import ZINB

# layers
# NOTE(review): `input` shadows the Python builtin of the same name.
input = Input(batch_shape=(batch_size, prediction_window, 4))
# Concatenate each sequence with a copy reversed along both the position and
# channel axes, doubling the batch: first half forward, second half reversed.
# Assumes the one-hot channel order makes channel reversal equal to
# complementing (e.g. A,C,G,T) — TODO confirm against sequence.encode_to_onehot.
add_RC_to_batch = Lambda(lambda x: K.concatenate([x, x[:, ::-1, ::-1]], axis=0), output_shape=lambda s: (2 * s[0], s[1], s[2]))
# Per-position scores from the project's conv stack; presumably num_outputs
# channels per position — verify in train_TFmodel.BasicConv.
per_base_score = train_TFmodel.BasicConv(prediction_window, conv_list, final_activation=None, num_outputs=num_outputs, drop_rate=drop_rate)(add_RC_to_batch(input))
# Fixed (non-trainable, no bias) sliding sum over a 50-position window.
# NOTE(review): a Conv1D kernel has shape (kernel_size, in_channels, filters).
# With an all-ones kernel, every one of the num_outputs filters computes the
# same sum over ALL input channels, so all output channels are identical. If a
# per-channel window sum was intended, this mixes the channels — confirm.
wide_scan = Conv1D(num_outputs, 50, use_bias=False, kernel_initializer='ones', trainable=False, name='wide_scan', padding='valid')
# Split the doubled batch back into its forward and reversed halves. Take the
# max over positions (axis 1) within each half, then the elementwise max of
# the two halves -> shape (batch / 2, num_outputs).
# The [::-1] on axis 1 does not affect a max taken over that same axis.
max_by_direction = Lambda(lambda x: K.maximum(K.max(x[:x.shape[0]//2, :, :], axis=1), K.max(x[x.shape[0]//2:, ::-1, :], axis=1)), name='stackmax', output_shape=lambda s: (s[0] // 2, num_outputs))
# train_TFmodel.Bias is presumably a learnable per-output offset — verify.
predictions = train_TFmodel.Bias(num_outputs, name='bias')(max_by_direction(wide_scan(per_base_score)))

# build the model
model = Model(inputs=[input], outputs=[predictions])

#zinb stuff
# Dense on a 3-D tensor acts on the last axis, i.e. independently at each
# position of the (RC-doubled) one-hot input; sigmoid keeps values in (0, 1).
pi_layer = Dense(num_outputs, activation='sigmoid')

# Reduced to one value per example/output with the same max_by_direction layer.
pi = max_by_direction(pi_layer(add_RC_to_batch(input)))       # not sure what layer to put here


# ZINB is presumably a zero-inflated negative binomial loss parameterised by
# pi and a dispersion variable initialised from theta_init (zeros) — confirm.
zinb = ZINB(pi, theta_init=tf.zeros([1, num_outputs]))
# pi_layer and zinb's theta variable are not on the input -> predictions path
# used to build `model`, so presumably they would not be trained. This
# appends them by hand to the last layer's trainable weight list.
# NOTE(review): this only has an effect if `trainable_weights` returns the
# layer's internal list rather than a copy, which depends on the Keras version.
model.layers[-1].trainable_weights.extend([zinb.theta_variable,
                                           *pi_layer.trainable_weights])

# save a graph of the model configuration
#plot_model(model, to_file=os.path.join(out_path, conv_string + '_model.png'), show_shapes=True)

In [None]:
# train the model
# SMOKE_TEST=True reproduces the short debugging run (2 steps per epoch for
# training and validation). Set it to False to use the data-derived step counts.
SMOKE_TEST = True

# NOTE: each batch consumes batch_size // 2 native sequences, so num_batches
# steps cover roughly half of the training regions per epoch. Likewise,
# val_steps covers roughly half of the (odd-index chr8) validation regions.
num_batches = len(peaks[peaks.chr != 'chr8']) // batch_size
print(str(num_batches) + ' batches')
val_steps = len(peaks[peaks.chr == 'chr8']) // (batch_size * 2)
verb = 1

# Only RMSprop is used. The Adam optimizer previously built here was
# immediately overwritten, so it has been removed.
opt = RMSprop(lr=1e-6)
model.compile(optimizer=opt,
              loss=zinb.loss)  # zinb loss function

checkpath = os.path.join(weights_path, 'weights_{epoch:02d}_{val_loss:.2f}.hdf5')
checkpoint = ModelCheckpoint(checkpath, verbose=verb, monitor='val_loss', mode='min')

# batch_size is passed explicitly because the model's Input has a fixed batch size.
history = model.fit_generator(pair_gen(batch_size=batch_size),
                              steps_per_epoch=2 if SMOKE_TEST else num_batches,
                              epochs=200,
                              callbacks=[checkpoint],
                              validation_data=pair_gen(mode='val', batch_size=batch_size),
                              validation_steps=2 if SMOKE_TEST else val_steps,
                              verbose=2)