In [1]:
import math
import os
from datetime import datetime

import tensorflow as tf
#os.environ["TF_GPU_THREAD"]="gpu_private"
#physical_devices = tf.config.list_physical_devices('GPU') 
#tf.config.experimental.set_memory_growth(physical_devices[0], True)
%load_ext tensorboard

# Show the TensorFlow version this notebook is running against.
print(tf.version.VERSION)


2.5.0


In [2]:
# Collect the training shards (listed in shuffled order) and read them all
# concurrently, taking one serialized record from each file in turn.
# `TFRecordDataset` inherits `list_files` from `tf.data.Dataset`.
train_filepath_dataset = tf.data.Dataset.list_files(
    "data/train/*.tfrecord",
    shuffle=True,
)
# One interleave slot per matched file.
train_num_files = train_filepath_dataset.cardinality()
train_dataset = train_filepath_dataset.interleave(
    map_func=lambda shard_path: tf.data.TFRecordDataset(shard_path),
    cycle_length=train_num_files,
    block_length=1,
    deterministic=False,  # element order may vary between runs
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)


In [3]:
# Same reading strategy as the training shards, applied to the validation
# files: list them, then interleave one record at a time across every file.
# `TFRecordDataset` inherits `list_files` from `tf.data.Dataset`.
validate_filepath_dataset = tf.data.Dataset.list_files(
    "data/valid/*.tfrecord",
    shuffle=True,
)
# One interleave slot per matched file.
validate_num_files = validate_filepath_dataset.cardinality()
validate_dataset = validate_filepath_dataset.interleave(
    map_func=lambda shard_path: tf.data.TFRecordDataset(shard_path),
    cycle_length=validate_num_files,
    block_length=1,
    deterministic=False,  # element order may vary between runs
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

In [4]:
# Parsing schema for each serialized SequenceExample: two scalar context
# values (used as regression targets) and one ragged float sequence feature.
context_features = {
    "population_size": tf.io.FixedLenFeature((), tf.int64),
    "selection_coefficient": tf.io.FixedLenFeature((), tf.float32),
}
sequence_features = {
    "raw_trait_frequencies": tf.io.RaggedFeature(tf.float32),
}

# Bounds used to min-max normalise the two targets.
pop_size_min, pop_size_max = 20, 500
sc_min, sc_max = 0., 0.5

# Rows must be strictly longer than this to be kept; per the original note,
# the number of generations survived is min_survival - 2 (starting and final
# frequency are included in the row).
min_survival = 12
max_time_steps = 60  # cap on time steps kept per row, to avoid OOM issues
num_trait_data = 2000  # rows sampled from every record

def preprocess_and_split_into_tuples(tfrecord):
    """Turn one serialized SequenceExample into a dataset of (features, labels).

    The record's two context values are min-max normalised into the labels.
    Its ragged ``raw_trait_frequencies`` rows are filtered to those longer
    than ``min_survival``, resampled to exactly ``num_trait_data`` rows and
    truncated to ``max_time_steps`` values each.

    Args:
        tfrecord: scalar string tensor holding one serialized SequenceExample.

    Returns:
        A zipped ``tf.data.Dataset`` yielding ``num_trait_data`` pairs of
        (trait-frequency sequence, label), each with a trailing axis of size 1.
    """
    parsed = tf.io.parse_sequence_example(
        tfrecord,
        context_features=context_features,
        sequence_features=sequence_features,
    )
    context, sequences = parsed[0], parsed[1]

    # Min-max normalise both regression targets with the module-level bounds.
    raw_pop_size = tf.cast(context["population_size"], dtype=tf.float32)
    raw_sel_coeff = tf.cast(context["selection_coefficient"], dtype=tf.float32)
    pop_size_norm = (raw_pop_size - pop_size_min) / (pop_size_max - pop_size_min)
    sc_norm = (raw_sel_coeff - sc_min) / (sc_max - sc_min)

    # Keep only the rows whose length exceeds min_survival.
    # tf.where returns indices with a trailing axis, so the gathered result
    # presumably carries an extra axis 1 (squeezed out below) - TODO confirm.
    all_rows = sequences["raw_trait_frequencies"]
    surviving_idx = tf.where(all_rows.row_lengths() > min_survival)
    survivors = tf.gather(all_rows, surviving_idx, axis=0)

    # Draw num_trait_data row indices uniformly at random; the draws are
    # independent, so the same row can be picked more than once.
    # NOTE(review): if no row passes the filter, maxval is 0 here - confirm
    # the data guarantees at least one surviving row per record.
    sampled_idx = tf.random.uniform(
        shape=[num_trait_data],
        minval=0,
        maxval=survivors.nrows(),
        dtype=tf.int32,
    )
    trait_frequencies = tf.gather(survivors, sampled_idx)

    # Truncate every row to max_time_steps, drop axis 1, then append a
    # trailing feature axis.
    trait_frequencies = trait_frequencies[:, :, 0:max_time_steps]
    trait_frequencies = tf.squeeze(trait_frequencies, axis=1)
    trait_frequencies = tf.expand_dims(trait_frequencies, axis=2)

    # Repeat the two normalised targets once per sampled row, then append a
    # trailing axis to match the features.
    label_pairs = tf.stack(
        (tf.repeat(pop_size_norm, num_trait_data), tf.repeat(sc_norm, num_trait_data)),
        axis=1,
    )
    label_tensors = tf.reshape(label_pairs, shape=(num_trait_data, 2))
    label_tensors = tf.expand_dims(label_tensors, axis=2)

    feature_ds = tf.data.Dataset.from_tensor_slices(trait_frequencies)
    label_ds = tf.data.Dataset.from_tensor_slices(label_tensors)
    return tf.data.Dataset.zip((feature_ds, label_ds))


In [5]:
# Expand every serialized training record into its (features, label) pairs
# and interleave them one element at a time across records.
# TODO: test cycle length and block length for performance.
train_dataset = train_dataset.interleave(
    map_func=preprocess_and_split_into_tuples,
    cycle_length=train_num_files,
    block_length=1,
    deterministic=False,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)


Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


In [6]:
# Expand every serialized validation record into its (features, label) pairs
# and interleave them one element at a time across records.
# TODO: test cycle length and block length for performance.
validate_dataset = validate_dataset.interleave(
    map_func=preprocess_and_split_into_tuples,
    cycle_length=validate_num_files,
    block_length=1,
    deterministic=False,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)


In [7]:
batch_size = 512
# Number of individual examples held in the shuffle buffer.
shuffle_buffer_size = 20 * batch_size

# Training: cache the individual examples, then shuffle BEFORE batching so
# each epoch sees differently composed batches. Caching already-built batches
# (the previous order) replayed identical batches in identical order every
# epoch, with no shuffling at all.
train_dataset = (
    train_dataset
    .cache()
    .shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Validation: order does not affect the metrics, so cache the finished
# batches and skip shuffling.
validate_dataset = (
    validate_dataset
    .batch(batch_size)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)


In [8]:
# Functional model over ragged sequences with one feature per time step.
inputs = tf.keras.layers.Input(shape=(None, 1), ragged=True)
# First recurrent layer; a Bidirectional wrapper around the same LSTM was
# tried as an alternative here.
rnn = tf.keras.layers.LSTM(50, return_sequences=True)

# Convert the ragged input to a padded tensor and pass an explicit mask
# built from the row lengths to the first LSTM.
padded_inputs = inputs.to_tensor()
padding_mask = tf.sequence_mask(inputs.row_lengths())
outputs = rnn(padded_inputs, mask=padding_mask)

# Remaining stack: two more LSTMs, dropout, one hidden dense layer and a
# linear 2-unit head.
remaining_layers = (
    tf.keras.layers.LSTM(50, return_sequences=True),
    tf.keras.layers.LSTM(100),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(100, activation='elu'),
    tf.keras.layers.Dense(2, activation='linear'),
)
for layer in remaining_layers:
    outputs = layer(outputs)

model = tf.keras.Model(inputs, outputs)
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
input.row_lengths (InstanceMeth (None,)              0           input_1[0][0]                    
__________________________________________________________________________________________________
input.to_tensor (InstanceMethod (None, None, 1)      0           input_1[0][0]                    
__________________________________________________________________________________________________
tf.sequence_mask (TFOpLambda)   (None, None)         0           input.row_lengths[0][0]          
______________________________________________________________________________________________

In [None]:
# Alternative architecture: two stacked bidirectional LSTMs followed by a
# dense head. Running this cell rebinds `model`, replacing any model built
# earlier in the notebook.
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(None, 1), ragged=True),
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(50, activation='tanh', return_sequences=True)),
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(50, activation='tanh')),
    tf.keras.layers.Dense(
        100,
        activation='relu',
        kernel_initializer=tf.keras.initializers.HeNormal()),
    tf.keras.layers.Dense(2, activation='linear'),
])
model.summary()

In [9]:
def scheduler(epoch, lr, warmup_epochs=3, decay_rate=0.02):
    """Learning-rate schedule: hold, then decay exponentially every epoch.

    Args:
        epoch: zero-based epoch index.
        lr: learning rate currently in use.
        warmup_epochs: number of initial epochs during which `lr` is returned
            unchanged.
        decay_rate: from epoch `warmup_epochs` onwards the rate is multiplied
            by exp(-decay_rate) each epoch.

    Returns:
        The learning rate for this epoch, as a plain Python float.
    """
    if epoch < warmup_epochs:
        return lr
    # math.exp (rather than tf.math.exp) keeps the result a Python float
    # instead of a tf.Tensor.
    return lr * math.exp(-decay_rate)

def abs_dist(y_true, y_pred):
    """Metric: absolute error between targets and predictions, averaged over the last axis."""
    return tf.reduce_mean(tf.math.abs(y_true - y_pred), axis=-1)

# Optimiser, loss and metrics.
optimizer = tf.keras.optimizers.Adam(0.002)
model.compile(
    optimizer=optimizer,
    loss="mse",
    metrics=[abs_dist, tf.keras.metrics.RootMeanSquaredError()],
)

# Learning-rate schedule driven by `scheduler`.
callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

# TensorBoard logs go to a per-run directory named after the start time.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logs = f"logs/{run_stamp}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=logs,
    histogram_freq=1,
)

# Save the weights at the end of every epoch, one checkpoint per epoch.
checkpoint_filepath = './checkpoint_{epoch:02d}'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=False,
    save_weights_only=True,
    save_freq='epoch',
)

In [10]:
# Train for up to 50 epochs, evaluating on the validation set after each one.
training_callbacks = [callback, tensorboard_callback, model_checkpoint_callback]
history = model.fit(
    train_dataset,
    validation_data=validate_dataset,
    epochs=50,
    callbacks=training_callbacks,
)


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
 40/762 [>.............................] - ETA: 34s - loss: 0.0044 - abs_dist: 0.0407 - root_mean_squared_error: 0.0660

KeyboardInterrupt: 

In [None]:
# Launch TensorBoard inline, reading event files from the "logs" directory.
%tensorboard --logdir=logs

In [None]:




def scheduler(epoch, lr, warmup_epochs=3, decay_rate=0.1):
    """Learning-rate schedule: hold, then decay exponentially every epoch.

    Args:
        epoch: zero-based epoch index.
        lr: learning rate currently in use.
        warmup_epochs: number of initial epochs during which `lr` is returned
            unchanged.
        decay_rate: from epoch `warmup_epochs` onwards the rate is multiplied
            by exp(-decay_rate) each epoch.

    Returns:
        The learning rate for this epoch, as a plain Python float.
    """
    if epoch < warmup_epochs:
        return lr
    # math.exp (rather than tf.math.exp) keeps the result a Python float
    # instead of a tf.Tensor.
    return lr * math.exp(-decay_rate)

def abs_dist(y_true, y_pred):
    """Metric: absolute error between targets and predictions, averaged over the last axis.

    Uses the full tensors. The previous `y_true[0] - y_pred[0]` indexed the
    leading axis, so only the first entry of each batch contributed to the
    reported value, unlike the other `abs_dist` definitions in this notebook.
    """
    abs_difference = tf.math.abs(y_true - y_pred)
    return tf.reduce_mean(abs_difference, axis=-1)

# Re-compile the current `model` with a fresh Adam optimiser and the
# `scheduler` / `abs_dist` defined just above in this cell.
optimizer = tf.keras.optimizers.Adam(0.002)
model.compile(loss="mse", optimizer=optimizer, metrics=[abs_dist, tf.keras.metrics.RootMeanSquaredError()])
callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

# `dataset` was never defined in this notebook (NameError on a fresh kernel);
# train on the prepared training pipeline instead.
history = model.fit(train_dataset, epochs=10, callbacks=[callback])


In [None]:
# Bidirectional-LSTM variant of the model: two stacked bidirectional LSTMs
# followed by a dense head. Rebinds `model`.
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(None, 1), ragged=True),
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(50, activation='tanh', return_sequences=True)),
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(50, activation='tanh')),
    tf.keras.layers.Dense(
        100,
        activation='relu',
        kernel_initializer=tf.keras.initializers.HeNormal()),
    tf.keras.layers.Dense(2, activation='linear'),
])
model.summary()
def scheduler(epoch, lr, warmup_epochs=3, decay_rate=0.1):
    """Learning-rate schedule: hold, then decay exponentially every epoch.

    Args:
        epoch: zero-based epoch index.
        lr: learning rate currently in use.
        warmup_epochs: number of initial epochs during which `lr` is returned
            unchanged.
        decay_rate: from epoch `warmup_epochs` onwards the rate is multiplied
            by exp(-decay_rate) each epoch.

    Returns:
        The learning rate for this epoch, as a plain Python float.
    """
    if epoch < warmup_epochs:
        return lr
    # math.exp (rather than tf.math.exp) keeps the result a Python float
    # instead of a tf.Tensor.
    return lr * math.exp(-decay_rate)

def abs_dist(y_true, y_pred):
    """Metric: mean over the last axis of |y_true - y_pred|."""
    residual = y_true - y_pred
    return tf.reduce_mean(tf.math.abs(residual), axis=-1)

# Compile the model built above with Adam and the `scheduler` / `abs_dist`
# defined in this cell.
optimizer = tf.keras.optimizers.Adam(0.0015)
model.compile(loss="mse", optimizer=optimizer, metrics=[abs_dist, tf.keras.metrics.RootMeanSquaredError()])
callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

# `dataset` was never defined in this notebook (NameError on a fresh kernel);
# train on the prepared training pipeline instead.
history = model.fit(train_dataset, epochs=20, callbacks=[callback])
