# A toy example  

In this notebook, we will use vanilla LSTM recurrent neural networks to learn our model.  

*Note: In this notebook, we will use the tensorflow probability library, which needs to be installed as it's not part of tensorflow*

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

import sys; sys.path.insert(0, '..')
from data.data_generator import *
from preprocess import *
from window_data import *

# Short aliases for the TensorFlow Probability sub-modules used throughout
tfd = tfp.distributions
tfpl = tfp.layers

# Make no GPU visible to TensorFlow for this session
tf.config.set_visible_devices(devices=[], device_type='GPU')

## Data  

We generate data using the first model (also the simplest), and we use only `10` samples.

In [None]:
# Number of prey, passed to the model through the 'no. of prey' parameter
N = 10

start = time.time()

# Generate the samples with `model1` and time how long the call takes
total = generateData(
    model1,
    num_data=10,
    init_sty='random',
    times=(0, 20),
    params={
        'no. of prey': N,
        'kappa for prey': 0.5,
        'attraction of prey a': 1,
        'repulsion of prey b_1': 1,
        'repulsion of pred b_2': 0.07,
        'attraction of pred c': 10,
        'exponent of dist pred p': 1.2,
    },
    steps=1000,
    second_order=False,
    method='rk2',
    return_vel=False,
    cores=8,
    flattened=False,
)
end = time.time()
print(f"Time taken: {end-start} seconds.")

A plot showing the data

In [None]:
# Plot element 1 of the first generated sample.
# NOTE(review): 20/1000 presumably is the time step implied by
# times=(0, 20) and steps=1000 in the generation call — confirm.
multiPlot(
    [total[0][1], 20/1000, 10],
    sample_points=[0, 0.5, 2, 4, 6, 8, 10],
    axis_lim=None,
    second_order=False,
    quiver=True,
)

The data has the shape `(batch, times, individuals, coordinates)`

In [None]:
# Stack element 1 of every generated sample into a single array
data = np.array([sample[1] for sample in total])
data.shape

We will only use one initial condition for this experiment, as our naive implementation only works well with one time series

In [None]:
# Select the first initial condition under a new name instead of overwriting
# `data`. With `data = data[0]`, re-running this cell would slice an
# already-sliced array and silently feed the wrong shape downstream; keeping
# `data` untouched makes the cell safe to re-run.
data_single = data[0]
train_ds, valid_ds, test_ds = getDatasets(data_single, scaling = False, return_ndarray=True)

For an experiment, change to `input_width=900, label_width=100, shift=100` and change the data above (otherwise there is not enough data)

In [None]:
# Window the three splits: 300 input steps, 10 label steps, shift of 10
window1 = WindowData(
    input_width=300,
    label_width=10,
    shift=10,
    train_ds=train_ds,
    val_ds=valid_ds,
    test_ds=test_ds,
)
print(window1)

In [None]:
# Replace the raw splits with the windowed datasets produced by `window1`
train_ds, valid_ds, test_ds = (
    window1.make_train(),
    window1.make_val(),
    window1.make_test(),
)

# One element spec per line, in the order train / validation / test
print(*(split.element_spec for split in (train_ds, valid_ds, test_ds)), sep="\n")
print(window1.num_points)

## A naive model

The idea is as follows: data of shape `(batch, times, individuals)` is first *embedded* in some way (the idea is similar to one-hot encoding of integer/categorical values), which gives a tensor of shape `(batch, times, length of concatenated embeddings)`; this tensor is then passed to an LSTM layer. From the LSTM output we can produce a prediction at a single future time step, of shape `(batch, 1, individuals)`.

In [None]:
from rnn import *

# Shape of one input window: (time steps, individuals, coordinates).
# The number of individuals is read from the window object rather than
# hard-coded as 21, so the model stays in sync with the generated data.
input_shape = (window1.input_width, window1.num_points, 2)

# Minimal model containing only the embedding layer, to inspect its output
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=input_shape),
    embedder(input_shape, 64, batch_size=32),
])
model.summary()

In [None]:
embedding_size = 64

# Shape of one input window: (time steps, individuals, coordinates).
# The number of individuals is read from the window object rather than
# hard-coded as 21, matching the Dense/Reshape layers below, which already
# use window1.num_points.
input_shape = (window1.input_width, window1.num_points, 2)

lstm_model_1 = tf.keras.models.Sequential([
    tf.keras.layers.Input(input_shape),
    embedder(input_shape, embedding_size, batch_size=32),
    tf.keras.layers.LSTM(2*embedding_size, return_sequences=False),
    # Every predicted position is a bivariate normal and needs 5 numbers:
    # 2 for the mean and 3 for the lower-triangular scale factor.
    # There are window1.num_points trajectories and window1.label_width
    # predicted time steps, so the Dense layer emits
    # label_width * num_points * 5 values.
    tf.keras.layers.Dense(window1.label_width * window1.num_points * 5, activation='linear'),
    # After the reshape the last axis holds the 5 parameters of one point at
    # one time step: (batch, label_width, num_points, 5).
    tf.keras.layers.Reshape((window1.label_width, window1.num_points, 5)),
    # loc is the first 2 entries of the last axis; the remaining 3 entries
    # are packed into a 2x2 lower-triangular scale_tril by fill_triangular.
    # Samples therefore have shape (batch, label_width, num_points, 2).
    # NOTE(review): scale_tril comes straight from a linear layer, so its
    # diagonal is not constrained to be non-zero; consider
    # tfp.bijectors.FillScaleTriL if training turns out to be unstable.
    tfpl.DistributionLambda(lambda x: tfd.MultivariateNormalTriL(
        loc=x[..., :2],
        scale_tril=tfp.math.fill_triangular(x[..., 2:])
    ))
])

lstm_model_1.summary()

Before going on, we take a sample from our dataset and pass it through the model to verify the output shape

In [None]:
# Pull a single batch and check the shapes the model produces for it
for x, y in train_ds.take(1):
    # Kept around for the visualisation cells further down
    x_sample, y_sample = x, y
    batch_log_prob = lstm_model_1(x).log_prob(y)
    print("The log probability shape is:")
    print(batch_log_prob.shape)
    print("The true value's shape is:")
    print(y.shape)
    print("After reduction the probability shape is:")
    print(tf.reduce_mean(batch_log_prob, axis=1).shape)

We also define a custom loss function that computes the negative log-likelihood over the prediction time steps.  

Recall that we would like to sum over the prediction time steps, which is the second axis in this case   

$$
L^i = - \sum_{t=T_{obs}+1}^{T_{end}} \log \Pr\left((x_i,y_i)^t \mid \mu_i^t, \sigma_i^t, \rho_i^t\right)
$$

In [None]:
def negLog(y_true, y_pred):
    """Negative log-likelihood of `y_true` under `y_pred`, summed over axis 1.

    Args:
        y_true: observed values, passed to `y_pred.log_prob`.
        y_pred: object exposing a `log_prob` method (here the distribution
            produced by the model's final layer).

    Returns:
        The negated sum of the log-probabilities along axis 1.
    """
    log_likelihood = y_pred.log_prob(y_true)
    summed_over_steps = tf.reduce_sum(log_likelihood, axis=1)
    return -summed_over_steps

In the original paper, the authors used various metrics to evaluate the model's performance. One of those is the average displacement error, which is the mean square error (MSE) over all estimated points of a trajectory and the true points.

In [None]:
def ADE(y_true, y_pred):
    """Average displacement error: mean squared displacement per point.

    The squared differences are summed over the last axis (assumed to hold
    the coordinates of one point) and then averaged over all remaining axes.
    The previous version returned the unnormalised sum over every element,
    which is not an average and grows with the batch size.

    Args:
        y_true: true positions.
        y_pred: predicted positions, broadcast-compatible with `y_true`.

    Returns:
        A scalar tensor holding the mean squared displacement.
    """
    squared_displacement = tf.reduce_sum(tf.square(y_true - y_pred), axis=-1)
    return tf.reduce_mean(squared_displacement)

Now we can use `RMSProp` as in the paper to train our model.

In [None]:
# Train with the negative log-likelihood loss and track ADE as a metric
lstm_model_1.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.003),
    loss=negLog,
    metrics=[ADE],
)

In [None]:
# 50 passes over the training windows, validating on the validation windows
lstm_model_1.fit(x=train_ds, validation_data=valid_ds, epochs=50)

We can visualize the model by sampling from the outputs

In [None]:
# Forward pass on the batch saved during the shape check; the result is the
# model's output object, which is sampled from in the cells below.
y_pred = lstm_model_1(x_sample)

In [None]:
# First batch element of the ground truth and of one random draw from y_pred
true_trajec = y_sample[0].numpy()
pred_trajec = y_pred.sample()[0].numpy()
print(pred_trajec.shape)

In [None]:
# Plot the sampled (predicted) trajectory
multiPlot(
    [pred_trajec, 1, 10],
    sample_points=[0, 1, 2, 3, 4],
    axis_lim=None,
    second_order=False,
    quiver=True,
)

In [None]:
# Plot the true trajectory for comparison
multiPlot(
    [true_trajec, 1, 10],
    sample_points=[0, 1, 2, 3, 4],
    axis_lim=None,
    second_order=False,
    quiver=True,
)

In [None]:
# Note: `y_pred.sample()` draws a fresh random sample here, so this value is
# computed on a different draw than the one stored in `pred_trajec`.
print(f"The average displacement error is {ADE(y_sample, y_pred.sample()).numpy()}")