In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
import os
from datetime import datetime
import numpy as np
# Record the TensorFlow build in the cell output so saved results can be tied to a version.
print(tf.version.VERSION)


2.6.0-dev20210603


TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.


In [2]:
# Find every training shard (file order is shuffled), then read all shards
# concurrently: with cycle_length equal to the number of files and
# block_length=1, one record is taken from each shard in turn.
train_filepath_dataset = tf.data.TFRecordDataset.list_files(
    "data/train/*.tfrecord",
    shuffle=True,
)
train_num_files = train_filepath_dataset.cardinality()
train_dataset = train_filepath_dataset.interleave(
    map_func=lambda shard_path: tf.data.TFRecordDataset(shard_path),
    cycle_length=train_num_files,
    block_length=1,
    deterministic=False,  # output order is allowed to vary between runs
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)


In [3]:
# Same shard-interleaving pipeline as for training, pointed at the validation files.
validate_filepath_dataset = tf.data.TFRecordDataset.list_files(
    "data/valid/*.tfrecord",
    shuffle=True,
)
validate_num_files = validate_filepath_dataset.cardinality()
validate_dataset = validate_filepath_dataset.interleave(
    map_func=lambda shard_path: tf.data.TFRecordDataset(shard_path),
    cycle_length=validate_num_files,
    block_length=1,
    deterministic=False,  # output order is allowed to vary between runs
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)


In [4]:
# Parsing schema for each serialized SequenceExample: two scalar context
# values and one ragged float feature list.
context_features = {
    "population_size": tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
    "selection_coefficient": tf.io.FixedLenFeature(shape=(), dtype=tf.float32),
}
sequence_features = {
    "raw_trait_frequencies": tf.io.RaggedFeature(dtype=tf.float32),
}

# Bounds used to min-max normalise the two regression targets.
pop_size_min, pop_size_max = 50, 500
sc_min, sc_max = 0.0, 1.0

# Trajectories are kept only if they hold MORE than this many recorded values.
# Original note: number of gens survived is min_survival - 2 (the starting and
# final freq are always recorded).
min_survival = 0
# Truncate trait frequencies at max_time_steps to avoid OOM issues.
max_time_steps = 45
# Number of trajectories sampled from each record for train/validation.
num_trait_data = 1000

def preprocess_and_split_into_tuples(tfrecord):
    """Expand one serialized SequenceExample into (trajectory, label) pairs.

    The record is parsed with the module-level ``context_features`` and
    ``sequence_features`` schemas. Both context values are min-max normalised
    and used as the shared label for ``num_trait_data`` trajectories drawn at
    random (with replacement) from the rows that have more than
    ``min_survival`` entries.

    Args:
        tfrecord: Serialized SequenceExample string tensor.

    Returns:
        A ``tf.data.Dataset`` of ``(features, labels)`` tuples, where features
        come from a ``(num_trait_data, max_time_steps, 1)`` tensor padded with
        ``-1`` and labels from a ``(num_trait_data, 2)`` tensor.
    """
    parsed = tf.io.parse_sequence_example(
        tfrecord,
        context_features=context_features,
        sequence_features=sequence_features,
    )
    context, sequences = parsed[0], parsed[1]

    # Normalise the two targets with the module-level bounds.
    pop_size = tf.cast(context["population_size"], dtype=tf.float32)
    pop_size_norm = (pop_size - pop_size_min) / (pop_size_max - pop_size_min)
    sel_coeff = tf.cast(context["selection_coefficient"], dtype=tf.float32)
    sc_norm = (sel_coeff - sc_min) / (sc_max - sc_min)

    # Every sampled trajectory carries the same [pop_size_norm, sc_norm] label row.
    label_columns = (
        tf.repeat(pop_size_norm, num_trait_data),
        tf.repeat(sc_norm, num_trait_data),
    )
    label_tensors = tf.reshape(
        tf.stack(label_columns, axis=1),
        shape=(num_trait_data, 2),
    )

    trajectories = sequences["raw_trait_frequencies"]
    # Keep rows with more than min_survival entries. The indices produced by
    # tf.where carry a trailing length-1 axis, so the gather result gains an
    # extra axis 1 that is squeezed away further down.
    surviving_rows = tf.where(trajectories.row_lengths() > min_survival)
    trajectories = tf.gather(trajectories, surviving_rows, axis=0)

    # Draw num_trait_data row indices uniformly, with replacement.
    sampled_rows = tf.random.uniform(
        shape=[num_trait_data],
        minval=0,
        maxval=trajectories.nrows(),
        dtype=tf.int32,
    )
    trajectories = tf.gather(trajectories, sampled_rows)

    # Truncate each trajectory to max_time_steps values, then drop the extra axis.
    trajectories = trajectories[:, :, 0:max_time_steps]
    trajectories = tf.squeeze(trajectories, axis=1)

    # Densify with -1 as the padding (mask) value.
    dense_trajectories = trajectories.to_tensor(
        default_value=-1.,
        shape=(num_trait_data, max_time_steps),
    )
    # Keras recurrent layers expect (batch_size, timesteps, features).
    dense_trajectories = tf.expand_dims(dense_trajectories, axis=2)

    feature_ds = tf.data.Dataset.from_tensor_slices(dense_trajectories)
    label_ds = tf.data.Dataset.from_tensor_slices(label_tensors)
    return tf.data.Dataset.zip((feature_ds, label_ds))


In [5]:
# Expand every serialized training record into its (trajectory, label) pairs
# and interleave the per-record datasets.
# NOTE: this rebinds train_dataset, so the cell must run exactly once per kernel.
train_dataset = train_dataset.interleave(
    map_func=preprocess_and_split_into_tuples,
    cycle_length=train_num_files,
    block_length=1,
    deterministic=False,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)


Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


In [6]:
# Expand every serialized validation record into its (trajectory, label) pairs
# and interleave the per-record datasets.
# NOTE: this rebinds validate_dataset, so the cell must run exactly once per kernel.
validate_dataset = validate_dataset.interleave(
    map_func=preprocess_and_split_into_tuples,
    cycle_length=validate_num_files,
    block_length=1,
    deterministic=False,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)


In [28]:
batch_size = 1024

# The shuffle buffer holds num_trait_data examples per input file and is
# reshuffled on every pass over the data.
train_dataset = (
    train_dataset
    .shuffle(buffer_size=num_trait_data * train_num_files, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
validate_dataset = (
    validate_dataset
    .shuffle(buffer_size=num_trait_data * validate_num_files, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)


In [None]:
# Baseline network: two stacked bidirectional LSTMs over the masked trajectory,
# followed by a small dense head that regresses the two normalised targets.
model = tf.keras.models.Sequential([
    # Time steps equal to -1 (the padding value) are masked out.
    tf.keras.layers.Masking(mask_value=-1., input_shape=(None, 1)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(50, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(50, return_sequences=False)),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(30, activation='elu', kernel_initializer='he_normal'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(2),
])

model.summary()

In [None]:
def scheduler(epoch, lr):
    """Learning-rate schedule passed to ``LearningRateScheduler``.

    The rate is left untouched while ``epoch < 3``; on every later epoch the
    incoming rate is multiplied by ``exp(-0.2)``.

    Args:
        epoch: Epoch index supplied by the callback.
        lr: Learning rate currently in use.

    Returns:
        The learning rate to apply. Once decay starts this is a plain Python
        float rather than a ``tf.Tensor``, which is the return type the
        callback documents.
    """
    if epoch < 3:
        return lr
    # Use NumPy for the constant factor so no tensor is created on every epoch.
    return float(lr * np.exp(-0.2))

# Callbacks: learning-rate decay, plus early stopping that restores the best weights.
scheduler_cb = tf.keras.callbacks.LearningRateScheduler(schedule=scheduler)
earlystopping_cb = tf.keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    patience=7,
)

# Checkpoint names start with a per-run timestamp.
# NOTE(review): there is no "/" between the timestamp and the file stem, so files
# are written directly inside "checkpoints/" as "<timestamp>checkpoint_vanilla_<epoch>".
checkpoints = "checkpoints/" + datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_filepath = checkpoints + "checkpoint_vanilla_{epoch:02d}"

# save_best_only=False with save_freq='epoch': weights are written after every epoch.
model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=False,
    save_weights_only=True,
    save_freq='epoch',
)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(
    optimizer=optimizer,
    loss="mse",
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
)

In [None]:
# Train for up to 50 epochs; early stopping may end the run sooner.
history = model.fit(
    x=train_dataset,
    validation_data=validate_dataset,
    epochs=50,
    callbacks=[model_checkpoint_cb, scheduler_cb, earlystopping_cb],
)

In [10]:
# Restore weights from a previously written "layernorm" checkpoint (epoch 09).
# NOTE(review): in file order the only model defined above this cell is the
# plain-LSTM one; the execution counts suggest this cell was actually run after
# the LayerNorm model cells further down. Confirm before a top-to-bottom run.
model.load_weights('checkpoints/20210610-192445checkpoint_layernorm_09')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f322017d588>

In [7]:
# Variant network: the LSTM layers are replaced by layer-normalised LSTM cells
# from TensorFlow Addons (with recurrent dropout); the dense head is unchanged.
model = tf.keras.models.Sequential([
    # Time steps equal to -1 (the padding value) are masked out.
    tf.keras.layers.Masking(mask_value=-1., input_shape=(None, 1)),
    tf.keras.layers.Bidirectional(
        tf.keras.layers.RNN(
            tfa.rnn.LayerNormLSTMCell(60, recurrent_dropout=0.25),
            return_sequences=True,
        )
    ),
    tf.keras.layers.Bidirectional(
        tf.keras.layers.RNN(
            tfa.rnn.LayerNormLSTMCell(60, recurrent_dropout=0.25),
            return_sequences=False,
        )
    ),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(30, activation='elu', kernel_initializer='he_normal'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(2),
])

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
masking (Masking)            (None, None, 1)           0         
_________________________________________________________________
bidirectional (Bidirectional (None, None, 120)         31920     
_________________________________________________________________
bidirectional_1 (Bidirection (None, 120)               89040     
_________________________________________________________________
dropout (Dropout)            (None, 120)               0         
_________________________________________________________________
dense (Dense)                (None, 30)                3630      
_________________________________________________________________
dropout_1 (Dropout)          (None, 30)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 6

In [8]:
def scheduler(epoch, lr):
    """Learning-rate schedule passed to ``LearningRateScheduler``.

    The rate is left untouched while ``epoch < 4``; on every later epoch the
    incoming rate is multiplied by ``exp(-0.1)``.

    Args:
        epoch: Epoch index supplied by the callback.
        lr: Learning rate currently in use.

    Returns:
        The learning rate to apply. Once decay starts this is a plain Python
        float rather than a ``tf.Tensor``, which is the return type the
        callback documents.
    """
    if epoch < 4:
        return lr
    # Use NumPy for the constant factor so no tensor is created on every epoch.
    return float(lr * np.exp(-0.1))

# Callbacks: learning-rate decay, plus early stopping that restores the best weights.
scheduler_cb = tf.keras.callbacks.LearningRateScheduler(schedule=scheduler)
earlystopping_cb = tf.keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    patience=6,
)

# Checkpoint names start with a per-run timestamp.
# NOTE(review): there is no "/" between the timestamp and the file stem, so files
# are written directly inside "checkpoints/" as "<timestamp>checkpoint_layernorm_<epoch>".
checkpoints = "checkpoints/" + datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_filepath = checkpoints + "checkpoint_layernorm_{epoch:02d}"

# save_best_only=False with save_freq='epoch': weights are written after every epoch.
model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=False,
    save_weights_only=True,
    save_freq='epoch',
)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(
    optimizer=optimizer,
    loss="mse",
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
)



In [38]:
# Train for up to 50 epochs; early stopping may end the run sooner.
history = model.fit(
    x=train_dataset,
    validation_data=validate_dataset,
    epochs=50,
    callbacks=[model_checkpoint_cb, scheduler_cb, earlystopping_cb],
)


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50


In [34]:
# Restore weights from an earlier run's checkpoint. The two commented-out lines
# are alternatives from other runs ("lr01" checkpoints), kept for reference.
# NOTE(review): this overwrites whatever weights the fit cell above produced, and a
# different checkpoint is loaded in another cell of this notebook; confirm which
# checkpoint the evaluated and saved model is meant to use.
#model = tf.keras.models.load_model("checkpoints/20210609-161957checkpoint_lr01_06")
#model.load_weights("checkpoints/20210609-171355checkpoint_lr01_05")
model.load_weights("checkpoints/20210610-174049checkpoint_layernorm_05")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd6327890f0>

In [13]:
# Set up the test dataset: find the test shards and read them concurrently,
# taking one record from each shard in turn.
batch_size = 1024
test_filepath_dataset = tf.data.TFRecordDataset.list_files(
    "data/test/*.tfrecord",
    shuffle=True,
)

test_num_files = test_filepath_dataset.cardinality()

test_dataset = test_filepath_dataset.interleave(
    map_func=lambda shard_path: tf.data.TFRecordDataset(shard_path),
    cycle_length=test_num_files,
    block_length=1,
    deterministic=False,  # output order is allowed to vary between runs
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

def test_preprocess_and_split_into_tuples(tfrecord, min_length=12):
    """Expand one serialized SequenceExample into (trajectory, label) pairs for testing.

    Unlike the train/validation preprocessing, which randomly samples
    ``num_trait_data`` trajectories per record, this keeps every trajectory
    that passes the length filter.

    Args:
        tfrecord: Serialized SequenceExample string tensor.
        min_length: Trajectories are kept only if they hold MORE than this many
            recorded values. Defaults to 12, the threshold that was previously
            hard-coded here. Note this differs from the module-level
            ``min_survival`` used for train/validation.

    Returns:
        A ``tf.data.Dataset`` of ``(features, labels)`` tuples, where features
        come from an ``(n, max_time_steps, 1)`` tensor padded with ``-1``,
        labels from an ``(n, 2)`` tensor, and ``n`` is the number of
        trajectories that pass the filter.
    """
    # Parse the sequence example and normalise the two targets.
    example = tf.io.parse_sequence_example(
        tfrecord,
        context_features=context_features,
        sequence_features=sequence_features,
    )
    context, sequences = example[0], example[1]
    pop_size_norm = (
        (tf.cast(context["population_size"], dtype=tf.float32) - pop_size_min)
        / (pop_size_max - pop_size_min)
    )
    sc_norm = (
        (tf.cast(context["selection_coefficient"], dtype=tf.float32) - sc_min)
        / (sc_max - sc_min)
    )

    trait_frequencies = sequences["raw_trait_frequencies"]
    # Keep only trajectories with more than min_length recorded values. The
    # indices produced by tf.where carry a trailing length-1 axis, so the gather
    # result gains an extra axis 1 that is squeezed away below.
    long_enough = tf.where(trait_frequencies.row_lengths() > min_length)
    trait_frequencies = tf.gather(trait_frequencies, long_enough, axis=0)

    # Truncate each trajectory to max_time_steps values, then drop the extra axis.
    trait_frequencies = trait_frequencies[:, :, 0:max_time_steps]
    trait_frequencies = tf.squeeze(trait_frequencies, axis=1)
    n_trajectories = trait_frequencies.nrows()

    # Every retained trajectory carries the same [pop_size_norm, sc_norm] label row.
    label_tensors = tf.reshape(
        tf.stack(
            (tf.repeat(pop_size_norm, n_trajectories), tf.repeat(sc_norm, n_trajectories)),
            axis=1,
        ),
        shape=(n_trajectories, 2),
    )

    # Convert to a dense tensor, using -1 as the padding (mask) value.
    trait_frequencies = trait_frequencies.to_tensor(
        default_value=-1.,
        shape=(n_trajectories, max_time_steps),
    )
    # Add the final dimension: Keras expects (batch_size, timesteps, features).
    trait_frequencies = tf.expand_dims(trait_frequencies, axis=2)

    return tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(trait_frequencies),
        tf.data.Dataset.from_tensor_slices(label_tensors),
    ))

# Expand every serialized test record into its (trajectory, label) pairs.
test_dataset = test_dataset.interleave(
    map_func=test_preprocess_and_split_into_tuples,
    cycle_length=test_num_files,
    block_length=1,
    deterministic=False,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

# NOTE(review): the original comment here read "no need to shuffle test set", yet
# the pipeline does shuffle through a 10000-example buffer. The shuffle is kept
# as-is; confirm whether it is intentional.
test_dataset = (
    test_dataset
    .shuffle(10000)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [14]:
# Evaluate on at most the first 200 batches of the test pipeline. With the
# compile settings used in this notebook the result is [MSE loss, mean absolute error].
model.evaluate(test_dataset.take(200))



[0.0030169931706041098, 0.03436065465211868]

In [15]:
# Persist the weights of whichever model is currently bound to `model`
# (weights only; the architecture must be rebuilt before loading them).
model.save_weights('best_model_june10')