In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
from model import generate_data, RNNInput, RNNModelConfig, RNNModel, train_model, evaluate_model, baseline_mse

In [None]:
# Hyperparameters for the RNN model; the resulting config object is displayed below.
# NOTE(review): keep_probability=1 presumably means no dropout — confirm against RNNModelConfig.
config = RNNModelConfig(
    learning_rate=0.0005,
    keep_probability=1,
    identity_init=True,
    max_grad_norm=100,
    state_size=75,
    batch_size=32,
)
config

In [None]:
# Data-generation parameters: number of training / test samples, the value
# range passed to generate_data, and the min/max sequence length it is given.
train_size, test_size = 100000, 10000
value_low, value_high = -100, 100
min_length, max_length = 1, 10

In [None]:
# Generate the training set. NumPy's global RNG is seeded first so the data is
# reproducible on Restart & Run All (presumably generate_data draws from np.random).
np.random.seed(1)
train_df = generate_data(
    size=train_size,
    value_low=value_low,
    value_high=value_high,
    min_length=min_length,
    max_length=max_length,
)


In [None]:
# Examine a sample of the generated training data. The frame is large
# (generated with size=train_size), so only the first rows are shown.
train_df.head(10)

In [None]:
# Start from a clean default graph and fix TensorFlow's graph-level seed so
# random ops created below are repeatable.
tf.reset_default_graph()
tf.set_random_seed(1)

# Wrap the training data, then build the model inside the "Train" name scope
# and the "Model" variable scope (reuse=None inherits the enclosing scope's
# reuse flag; at top level, variables are created rather than reused).
train_input = RNNInput(train_df)
with tf.name_scope("Train"), tf.variable_scope("Model", reuse=None):
    m = RNNModel(config)


In [None]:
# Open a session on the default graph and initialise all global variables.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

In [None]:
# Train the model and plot the returned losses on a log-scale y-axis,
# which makes the decline easier to read than a linear axis.
num_epochs = 10
train_losses = train_model(sess, m, train_input, num_epochs)

fig, ax = plt.subplots()
ax.semilogy(train_losses)
# NOTE(review): confirm whether train_model records one loss per batch or per epoch
# and adjust the x-axis label accordingly.
ax.set(title="Training loss", xlabel="recorded step", ylabel="loss (log scale)");

In [None]:
# Create a test set. Re-seed NumPy's global RNG so the test data does not depend
# on how much of the random stream earlier cells (e.g. training) may have consumed.
# A seed different from the training seed (1) is used so the test set is not a
# copy of the start of the training set.
np.random.seed(2)
test_df = generate_data(size=test_size, value_low=value_low, value_high=value_high,
                        min_length=min_length, max_length=max_length)


In [None]:
# Evaluate the trained model on the test set.
test_input = RNNInput(test_df)
pred_loss, preds = evaluate_model(sess, m, test_input)

# print(...) with a single argument behaves the same on Python 2 and 3,
# unlike the Python-2-only `print "..."` statement.
print("Test loss: {}".format(pred_loss))

# Compare with the baseline loss computed from the test set's 'l1norm' column.
baseline_loss = baseline_mse(test_df['l1norm'], value_low=value_low, value_high=value_high,
                             min_length=min_length, max_length=max_length)
print("Baseline loss: {}".format(baseline_loss))

In [None]:
# Examine a sample of the test-set predictions next to their ground truth.
# NOTE(review): column order assumes evaluate_model yields (prediction, ground_truth) pairs — confirm.
pd.DataFrame.from_records(preds, columns=['prediction', 'ground_truth']).head(10)


In [None]:
# Release the TensorFlow session's resources now that training and evaluation are done.
sess.close()