In [1]:
%matplotlib inline

In [2]:
import glob
from natsort import natsorted
import time

import numpy as np
import pandas as pd
from tqdm import tnrange, tqdm_notebook

import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from basenji import blocks
from basenji import metrics

In [3]:
TFR_INPUT = 'sequence'
TFR_OUTPUT = 'target'

seq_length = 262144
target_length = 2048
num_targets = 1312

def file_to_records(filename):
    return tf.data.TFRecordDataset(filename, compression_type='ZLIB')

def generate_parser(raw=False):

    def parse_proto(example_protos):
        """Parse TFRecord protobuf."""

        features = {
            TFR_INPUT: tf.io.FixedLenFeature([], tf.string),
            TFR_OUTPUT: tf.io.FixedLenFeature([], tf.string)
        }
        parsed_features = tf.io.parse_single_example(example_protos, features=features)

        sequence = tf.io.decode_raw(parsed_features[TFR_INPUT], tf.uint8)
        if not raw:
            sequence = tf.reshape(sequence, [seq_length, 4])
            sequence = tf.cast(sequence, tf.float32)

        targets = tf.io.decode_raw(parsed_features[TFR_OUTPUT], tf.float16)
        if not raw:
            targets = tf.reshape(targets, [target_length, num_targets])
            targets = tf.cast(targets, tf.float32)

        return sequence, targets

    return parse_proto

def make_dataset(tfr_pattern, batch_size, mode):
    """Make Dataset w/ transformations."""

    # initialize dataset from TFRecords glob
    tfr_files = natsorted(glob.glob(tfr_pattern))
    dataset = tf.data.Dataset.list_files(tf.constant(tfr_files), shuffle=False)

    # train
    if mode == tf.estimator.ModeKeys.TRAIN:
        # repeat
        dataset = dataset.repeat()

    # flat mix files
    dataset = dataset.flat_map(file_to_records)
    
    # parse
    dataset = dataset.map(generate_parser())

    # batch
    dataset = dataset.batch(batch_size)

    return dataset

In [4]:
tfr_train_full = 'data/tfrecords/train-0.tfr'
train_data = make_dataset(tfr_train_full, 2, tf.estimator.ModeKeys.TRAIN) 

In [5]:
tfr_eval_full = 'data/tfrecords/valid-0.tfr'
eval_data = make_dataset(tfr_eval_full, 2, tf.estimator.ModeKeys.EVAL)

In [6]:
def make_model():
    sequence = tf.keras.Input(shape=(seq_length, 4), name='sequence')
    current = sequence

    current = blocks.conv_tower(current, filters_init=64, repeat=7, kernel_size=5,
                     activation='relu', pool_size=2,
                     batch_norm=True, bn_momentum=0.99)

    preds = blocks.dense(current, units=1312)

    return tf.keras.Model(inputs=sequence, outputs=preds)

In [7]:
num_epochs = 2
train_batches = 32
valid_batches = 32

In [8]:
# optimizer1 = tf.keras.optimizers.SGD(learning_rate=.005, momentum=0.99)

# model1 = make_model()
# model1.compile(loss='poisson', optimizer=optimizer1,
#                metrics=[metrics.PearsonR(num_targets)])

# model1.fit(train_data, epochs=num_epochs, steps_per_epoch=train_batches,
#            validation_data=eval_data, validation_steps=valid_batches)

In [9]:
# model1.evaluate(eval_data, steps=valid_batches)

In [11]:
optimizer2 = tf.keras.optimizers.SGD(learning_rate=.005, momentum=0.99)

model2 = make_model()
model2.compile(loss='poisson', optimizer=optimizer2,
               metrics=[metrics.PearsonR(num_targets)])
               
loss_fn = tf.keras.losses.Poisson()
train_loss = tf.keras.metrics.Poisson()
train_r = metrics.PearsonR(num_targets)
valid_r = metrics.PearsonR(num_targets)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model2(x, training=True)
        loss = loss_fn(y, pred)
    train_loss(y, pred)
    train_r(y, pred)
    gradients = tape.gradient(loss, model2.trainable_variables)
    optimizer2.apply_gradients(zip(gradients, model2.trainable_variables))

for ei in range(num_epochs):
    # train
    t0 = time.time()
    si = 0
    for x, y in train_data:
        train_step(x, y)
        si += 1
        if si >= train_batches:
            break

    train_loss_epoch = train_loss.result().numpy()
    train_r_epoch = train_r.result().numpy()
    print('Epoch %d - %ds - train_loss: %.4f - train_r: %.4f' % \
          (ei, (time.time()-t0), train_loss_epoch, train_r_epoch), end='')

    # valid
    valid_loss, valid_pr = model2.evaluate(eval_data, steps=valid_batches, verbose=0)
    print(' - valid_loss: %.4f - valid_r: %.4f' % (valid_loss, valid_pr), flush=True)

    # reset
    train_loss.reset_states()
    train_r.reset_states()

Epoch 0 - 96s - train_loss: 0.9996 - train_r: 0.0008 - valid_loss: 1.0204 - valid_r: 0.0014
Epoch 1 - 87s - train_loss: 0.9631 - train_r: 0.0021 - valid_loss: 0.8511 - valid_r: 0.0063
