In [1]:
###########################################################################
# Authors: Arthur Mateos and Chris Zhang
#
# Date: 27 July, 2016
###########################################################################

In [2]:
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow.contrib import rnn

The minimum supported version is 2.4.6



In [3]:
# Locations of the CSV recordings (JolleFishData, five fish per recording)
datafiles = [
    '~/JolleFishData-5fish/CM1FRE_150324_1147_RP10_S04_G22_P.csv',
    '~/JolleFishData-5fish/CM1FRE_150324_1227_RP09_S05_G05_P.csv',
    '~/JolleFishData-5fish/CM1FRE_150324_1227_RP10_S05_G20_P.csv',
    '~/JolleFishData-5fish/CM1FRE_150324_1307_RP09_S06_G12_P.csv',
]

# Checkpoint file used for saving and restoring the model
model_path = 'ModelCheckpoints/five_fish_model.ckpt'

# --- Training parameters ---
learning_rate = 0.0001  # RMSProp step size (baked into the optimizer when the graph is built)
dropout = 0.4           # probability of dropping an LSTM output during training
n_minibatches = 500     # default number of optimizer steps per train_model call
minibatch_size = 512    # windows per minibatch
display_step = 10       # report costs / write a checkpoint every this many minibatches

# --- Network parameters ---
n_input = 10            # features per input timestep (x and y position of each fish)
n_embedded = 64         # size of the ReLU embedding applied to each input timestep
n_output = 10           # predicted values (change in x and y for each fish)
n_per_hidden = 256      # LSTM units per layer; every hidden layer must use this same size
n_hidden_layers = 2     # number of stacked LSTM layers
l2_coefficient = 0.001  # weight of the L2 regularization term in the cost
window_length = 50      # timesteps per window: window_length-1 inputs plus 1 target row

# Process Input Data into the Right Format

In [4]:
def download_data(filename):
    """Read one fish-tracking CSV into a DataFrame.

    The first row is used as the header and the first column as the index;
    the 'Color' and 'col' columns are not needed downstream and are removed.

    Args:
        filename (str): Path (or file-like object) of the CSV to read.

    Returns:
        DataFrame: The CSV contents without the 'Color' and 'col' columns.
    """
    unneeded_columns = ['Color', 'col']
    raw = pd.read_csv(filename, sep=",", header=0, index_col=0)
    return raw.drop(labels=unneeded_columns, axis=1)

def flatten_data(data):
    """Pivot long-format tracking data into one wide row per index value.

    Every remaining column is spread out by the values of the "ID" column, and
    each resulting (column, ID) pair is joined into a flat "<column>-<ID>" name.

    Args:
        data (DataFrame): Long-format data containing an "ID" column.

    Returns:
        DataFrame: Wide DataFrame with one "<column>-<ID>" column per pair.
    """
    wide = data.pivot(columns="ID")
    join_pair = '{0[0]}-{0[1]}'.format
    wide.columns = [join_pair(pair) for pair in wide.columns]
    return wide

def add_delta_pos(data):
    """Add a change-in-position column for every existing column.

    For each column `c`, a column named 'd' + c holding the row-to-row
    difference (c[t] - c[t-1]) is appended after all original columns. Rows
    containing any missing value are then dropped; this always removes the
    first row, whose differences are undefined.

    The caller's DataFrame is not modified; a new DataFrame is returned.

    Args:
        data (DataFrame): Wide-format data, one column per tracked quantity.

    Returns:
        DataFrame: New DataFrame with the original columns followed by the
        change-in-position columns.
    """
    # diff() on the whole frame equals a per-column diff; add_prefix gives 'd' + name.
    deltas = data.diff().add_prefix('d')
    return pd.concat([data, deltas], axis=1).dropna()

In [5]:
def normalize_windows(windows):
    """Min-max normalize each window column-wise, using all rows except the last.

    The minimum and maximum are taken over every row but the final one (the
    prediction target), so those rows land in [0, 1] while the last row may
    fall outside that range. A column that is constant over those rows has a
    zero range and would otherwise produce NaN/inf; for such columns the range
    is treated as 1, so the column is only shifted by its minimum.

    Args:
        windows (iterable of ndarray): 2-D arrays of shape (timesteps, features).

    Returns:
        list of ndarray: The normalized windows, in input order.
    """
    normalized_windows = []
    for window in windows:
        history = window[:-1, :]  # leave out the last (target) row when computing the scale
        mins = history.min(axis=0)
        spans = history.max(axis=0) - mins
        # Avoid 0/0 for columns that do not change within the window.
        spans = np.where(spans == 0, 1, spans)
        normalized_windows.append((window - mins) / spans)
    return normalized_windows

def partition_windows(windows, window_length, train_percent, valid_percent, test_percent):
    """Split consecutive windows into train, validation, and test sets without overlap.

    Consecutive windows share timesteps, so a gap of `window_length` windows is
    left between train and validation and between validation and test. This
    keeps any timestep from appearing in two different sets. The percentages
    are applied to the windows that remain after reserving those two gaps, and
    each set gets exactly its share (rounded down).

    Args:
        windows (ndarray): 3-D array of windows, presumably in time order.
        window_length (int): Number of timesteps per window (used as the gap size).
        train_percent (int or float): Share of the usable windows for training.
        valid_percent (int or float): Share of the usable windows for validation.
        test_percent (int or float): Share of the usable windows for testing.

    Returns:
        (ndarray, ndarray, ndarray): train, validation and test windows.

    Raises:
        ValueError: If there are not more than 2 * window_length windows, so no
            window is left once the gaps are reserved.
    """
    gap = window_length
    n_usable = len(windows) - 2 * gap
    if n_usable <= 0:
        raise ValueError("Need more than %d windows to partition with window_length=%d, got %d."
                         % (2 * gap, window_length, len(windows)))

    # int() keeps the slice bounds integral even if a percent is passed as a float.
    n_train = int(n_usable * train_percent // 100)
    n_valid = int(n_usable * valid_percent // 100)
    n_test = int(n_usable * test_percent // 100)

    valid_start = n_train + gap
    test_start = valid_start + n_valid + gap

    train = windows[:n_train, :, :]
    valid = windows[valid_start:valid_start + n_valid, :, :]
    test = windows[test_start:test_start + n_test, :, :]

    return train, valid, test


def download_and_preprocess_data(filenames=datafiles, window_length=window_length, normalize=False):
    """Load the fish CSVs and turn them into shuffled train/validation/test arrays.

    Steps:
      1. Read each CSV, pivot it to one row per timestep, and add
         change-in-position columns.
      2. Cut every run of `window_length` consecutive rows into a window.
         Windows from all files are pooled in file order.
      3. Optionally min-max normalize each window (see normalize_windows).
      4. Partition into 80% train / 10% validation / 10% test with gaps between
         the sets, then shuffle each set independently.
      5. Split each window into model input (all but the last timestep, first
         n_input columns) and target (last timestep, last n_output columns).

    Args:
        filenames (list of str): CSV files to read.
        window_length (int): Timesteps per window (inputs plus one target row).
        normalize (bool): Whether to normalize each window.

    Returns:
        tuple: x_train, y_train, x_valid, y_valid, x_test, y_test.
    """
    # List to hold windows, each containing window_length consecutive timesteps of data
    windows = []

    for file in filenames:
        data = download_data(file)
        data = flatten_data(data)
        data = add_delta_pos(data)

        # Convert once, then slice the array. A series of T rows holds
        # T - window_length + 1 complete windows (start indices 0 .. T - window_length).
        values = np.asarray(data)
        n_file_windows = len(values) - window_length + 1
        windows.extend(values[start:start + window_length] for start in range(n_file_windows))

    if normalize:
        windows = normalize_windows(windows)

    windows = np.array(windows)

    # 80% training, 10% validation, 10% test
    train, valid, test = partition_windows(windows, window_length, 80, 10, 10)

    # randomize the order
    np.random.shuffle(train)
    np.random.shuffle(valid)
    np.random.shuffle(test)

    # Select data of interest:
    #   first n_input columns (x position, y position for each fish) as input
    #   last n_output columns (delta-x, delta-y for each fish) as output
    # For each window, x contains all but the last timestep, y contains only the last timestep
    x_train = train[:, :-1, :n_input]
    y_train = train[:, -1, -n_output:]
    x_valid = valid[:, :-1, :n_input]
    y_valid = valid[:, -1, -n_output:]
    x_test = test[:, :-1, :n_input]
    y_test = test[:, -1, -n_output:]

    return x_train, y_train, x_valid, y_valid, x_test, y_test

In [6]:
def plot_fit(ground_truth, predicted):
    """Plot predicted against ground-truth values, one figure per output column.

    Each figure shows the predictions as black points, a least-squares line of
    best fit in red (when the column has at least two distinct ground-truth
    values), and the line of perfect fit y = x in blue.

    Args:
        ground_truth (ndarray): 2-D array, one row per sample and one column per
            predicted variable.
        predicted (ndarray): 2-D array of predictions with the same number of rows.

    Returns:
        None.
    """
    assert len(ground_truth) == len(predicted)

    n_points = len(ground_truth)

    # Plot each predicted variable separately
    for series_index in range(ground_truth.shape[1]):
        x_data = ground_truth[:, series_index].reshape(n_points)
        y_data = predicted[:, series_index].reshape(n_points)
        unique_x = np.unique(x_data)  # sorted, so lines are drawn left to right

        fig, ax = plt.subplots()
        ax.set_title('Dataseries ' + str(series_index) + ': predicted vs. ground truth')
        ax.set_xlabel('Ground truth value')
        ax.set_ylabel('Predicted value')

        # Predicted vs. ground truth
        ax.scatter(x_data, y_data, color='black')
        # Line of best fit; a degree-1 fit is undefined for a single distinct x value
        if unique_x.size > 1:
            best_fit = np.poly1d(np.polyfit(x_data, y_data, 1))
            ax.plot(unique_x, best_fit(unique_x), color='red')
        # Line of perfect fit (y = x)
        ax.plot(unique_x, unique_x, color='blue')

        plt.show()
        plt.close(fig)  # free the figure; many series would otherwise accumulate open figures

In [7]:
def get_minibatch(x_data, y_data, minibatch_size):
    """Draw a minibatch of distinct, randomly chosen samples.

    Indices are drawn without replacement with the `random` module, and the
    same indices select both inputs and outputs, so pairs stay aligned. The
    batch shapes follow the data rather than module-level constants:
    (minibatch_size,) + x_data.shape[1:] and (minibatch_size,) + y_data.shape[1:].

    Args:
        x_data (ndarray): x data to draw samples from.
        y_data (ndarray): y data to draw samples from (same length as x_data).
        minibatch_size (int): Size of the minibatch; at most len(x_data).

    Returns:
        (ndarray, ndarray): input minibatch, output minibatch, both float64.
    """
    assert len(x_data) == len(y_data)
    assert minibatch_size <= len(x_data)

    # Select minibatch_size random windows from the data
    rand_indices = random.sample(range(len(x_data)), minibatch_size)

    # Fancy indexing copies the selected samples in one step; cast afterwards so
    # only the minibatch (not the whole dataset) is converted.
    inputs = np.asarray(x_data)[rand_indices].astype(np.float64)
    outputs = np.asarray(y_data)[rand_indices].astype(np.float64)

    return inputs, outputs

# Create and Run the Graph


In [8]:


# Build the TensorFlow (1.x) graph: a shared per-timestep ReLU embedding, a
# stack of LSTM layers with output dropout, and a linear read-out of the last
# timestep, trained with RMSProp on RMSE plus an L2 penalty.
graph = tf.Graph()

with graph.as_default():

    # Input data
    # x: (batch, window_length-1, n_input) -- the input timesteps of each window.
    x_batch_placeholder = tf.placeholder(tf.float32,
                                      shape=(None, window_length-1, n_input))
        # None lets the placeholder hold batches of any size
    # y: (batch, n_output) -- the target values for the window's final timestep.
    y_batch_placeholder = tf.placeholder(tf.float32, shape=(None, n_output))
        # None lets the placeholder hold batches of any size
    # Drop probability (fed as a scalar); the LSTM cells keep outputs with
    # probability 1 - dropout_placeholder. The evaluation code in this notebook feeds 0.
    dropout_placeholder = tf.placeholder(tf.float32)

    # Variables to be trained.
    # NOTE(review): these are created without explicit names, so TensorFlow
    # presumably names them by creation order ("Variable", "Variable_1", ...);
    # tf.train.Saver keys checkpoints by variable name, so reordering these
    # lines would likely break restoring existing checkpoints.
    embed_weights = tf.Variable(tf.truncated_normal([n_input, n_embedded]))
    embed_biases = tf.Variable(tf.zeros([n_embedded]))
    weights = tf.Variable(tf.truncated_normal([n_per_hidden, n_output]))
    biases = tf.Variable(tf.zeros([n_output]))
    
    # Build graph: n_hidden_layers LSTM layers of n_per_hidden units, each with
    # dropout applied to its outputs, stacked into one multi-layer cell.
    cells = []
    for _ in range(n_hidden_layers):
        cell = rnn.BasicLSTMCell(n_per_hidden)
        cell = rnn.DropoutWrapper(cell, output_keep_prob=1.0 - dropout_placeholder)
        cells.append(cell)
    cell = rnn.MultiRNNCell(cells)
    
    ### Define ops to run forward pass
    
    # Embed each input timestep with the same weights followed by a ReLU;
    # this yields one (batch, n_embedded) tensor per timestep.
    stacks = [tf.nn.relu(tf.matmul(x_batch_placeholder[:,i,:], embed_weights)\
                       + embed_biases) for i in range(window_length-1)]

    # Re-assemble the per-timestep embeddings into (batch, window_length-1, n_embedded).
    embedd = tf.stack(stacks, axis=1)
    # Run the LSTM stack over all timesteps; only the last timestep's output is
    # projected to the n_output predicted values.
    outputs, states = tf.nn.dynamic_rnn(cell, embedd, dtype=tf.float32)
    logits = tf.matmul(outputs[:,-1,:], weights) + biases
    
    # Define cost and optimizer.
    # The L2 penalty covers the read-out weights/biases and the embedding
    # weights only (not embed_biases or the LSTM cells' own variables).
    l2_loss = tf.nn.l2_loss(weights)+tf.nn.l2_loss(biases)+tf.nn.l2_loss(embed_weights)
    # Cost = root-mean-squared error over all outputs + l2_coefficient * L2 penalty.
    cost = tf.sqrt(tf.reduce_mean(tf.squared_difference(logits, y_batch_placeholder))) +\
        l2_coefficient*l2_loss # RMSE term plus weighted L2 penalty
    # The module-level learning_rate is fixed into this op when the graph is
    # built; changing it afterwards requires rebuilding the graph.
    optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)
    
    # Define op to initialize global variables
    init = tf.global_variables_initializer()
    
    # Saver used to save and restore all of the variables above
    saver = tf.train.Saver()

In [9]:

def train_model(
    n_minibatches=n_minibatches,
    display_step=display_step,
    learning_rate=learning_rate, # NOTE: has no effect -- see docstring
    minibatch_size=minibatch_size,
    graph=graph,
    restore_from_save=True, # Change this to False when you train a new model
    restore_from_latest=False,
    restore_path=model_path,
    save_when_finished=True,
    save_path=model_path):
    """Train the LSTM graph on random minibatches, optionally resuming from a checkpoint.

    Uses the module-level arrays x_train, y_train, x_valid, y_valid and the
    graph ops (saver, init, cost, optimizer, logits, placeholders) defined in
    earlier cells. Every `display_step` minibatches it prints the training and
    validation cost and writes a checkpoint; at the end it optionally saves a
    final checkpoint and plots predicted vs. true values on the validation set.

    The validation cost is computed with dropout disabled (dropout 0), so it
    measures the same deterministic model that is plotted and used for generation.

    Args:
        n_minibatches (int): Number of optimizer steps to run.
        display_step (int): Report costs and write a checkpoint every this many steps.
        learning_rate (float): NOTE(review): not used. The optimizer op was built
            with the module-level `learning_rate` when the graph was constructed,
            so this argument has no effect; rebuild the graph to change it.
        minibatch_size (int): Number of windows per optimizer step.
        graph (tf.Graph): Graph the session runs on; must be the graph that owns
            the module-level ops.
        restore_from_save (bool): If True, try to restore variables from a
            checkpoint; otherwise initialize them from scratch.
        restore_from_latest (bool): If True (and restoring), use the newest
            checkpoint in './ModelCheckpoints/' instead of `restore_path`.
        restore_path (str): Checkpoint to restore when `restore_from_latest` is False.
        save_when_finished (bool): Whether to write a final checkpoint to `save_path`.
        save_path (str): Checkpoint path prefix used for every save.
    """
    # Launch the graph
    with tf.Session(graph=graph) as sess:
        restored = False
        if restore_from_save:
            if restore_from_latest:
                # latest_checkpoint returns None when the directory holds no checkpoint.
                restore_path = tf.train.latest_checkpoint('./ModelCheckpoints/')
            if restore_path is not None:
                try:
                    saver.restore(sess, restore_path)
                    restored = True
                    print("Model successfully restored from %s.\nResuming training." % restore_path)
                except (tf.errors.NotFoundError, ValueError):
                    # Depending on the TF version a missing checkpoint raises either error.
                    restored = False
            if not restored:
                print("Save file not found.\nInitializing graph from scratch instead.")
        if not restored:
            sess.run(init)
            print("Global variables initialized.\nCommencing training.")

        # Validation is evaluated with dropout disabled.
        valid_feed = {x_batch_placeholder: x_valid, y_batch_placeholder: y_valid, dropout_placeholder: 0.0}

        # Keep training until reach max iterations
        for minibatch_idx in range(n_minibatches):
            _x_batch, _y_batch = get_minibatch(x_train, y_train, minibatch_size)

            # Run optimization op (backprop), with dropout active
            feed_dict = {x_batch_placeholder: _x_batch, y_batch_placeholder: _y_batch, dropout_placeholder: dropout}
            _train_cost, _ = sess.run([cost, optimizer], feed_dict=feed_dict)

            if minibatch_idx % display_step == 0:
                _valid_cost = sess.run(cost, feed_dict=valid_feed)
                print("Minibatch " + str(minibatch_idx) + ", Minibatch cost = " + \
                      "{:.6f}".format(_train_cost))
                print("Minibatch " + str(minibatch_idx) + ", Validation set cost = " + \
                      "{:.6f}".format(_valid_cost))
                _save_path = saver.save(sess, save_path, global_step=minibatch_idx)
                print("Model saved in file: %s" % _save_path)

        if save_when_finished:
            # Save model weights to disk
            _save_path = saver.save(sess, save_path)
            print("Model saved in file: %s" % _save_path)

        # Plot fit on validation data
        print("\nCurrent validation performance:")
        plot_fit(y_valid, logits.eval(session=sess, feed_dict={x_batch_placeholder: x_valid, dropout_placeholder: 0}))

# Generate Predictions

In [11]:
def get_new_row(seed, prediction):  # TODO: adapt this to new data format
    """Return the next row of the sequence: the seed's last row advanced by `prediction`.

    Args:
        seed (ndarray): 2-D window; only its last row is used.
        prediction (ndarray): Values added element-wise (with broadcasting) to that row.

    Returns:
        ndarray: seed[-1, :] + prediction.
    """
    last_row = seed[-1, :]
    next_row = last_row + prediction
    return next_row
    
def generate_next_point(seed, sess):
    """Evaluate the network on one batch of input windows with dropout disabled.

    Args:
        seed: Batch of input windows fed to x_batch_placeholder.
        sess: tf.Session on the model graph whose variables are already set.

    Returns:
        ndarray: The value of `logits` for this batch.
    """
    no_dropout_feed = {dropout_placeholder: 0, x_batch_placeholder: seed}
    return logits.eval(session=sess, feed_dict=no_dropout_feed)
    
def shift_seed(old_seed, new_row):
    """Slide the window forward: drop the oldest row and append `new_row` at the end."""
    remaining = old_seed[1:, :]
    return np.concatenate([np.atleast_2d(remaining), np.atleast_2d(new_row)], axis=0)
    
def generate_prediction(seed, prediction_length, restore_path=model_path, progress_counter=20):
    """Generate a new sequence of positions, starting from an unnormalized seed window.

    The model is restored from `restore_path`. At each step the current window
    is fed to the network. The predicted change is added to the window's last
    row to form the next row, and the window then moves forward one timestep
    (the oldest row is dropped and the new row appended).

    Params:
        seed: ndarray holding one input window, of shape (window_length-1, n_input)
            (e.g. x_valid[0]) or the batch shape (1, window_length-1, n_input).
            It is reshaped to 2-D before use and is not modified.
        prediction_length: integer, number of desired timesteps to generate
        restore_path: string, location from which to load saved graph state
        progress_counter: integer, approximate number of progress messages to print
    Returns:
        ndarray of shape (prediction_length, n_input) holding the generated rows
        (this relies on n_input == n_output, since predicted changes are added to
        input rows), or None if the checkpoint could not be restored.
    """
    # Work in 2-D (timesteps, features); get_new_row/shift_seed index rows of a 2-D window.
    seed = np.asarray(seed).reshape(window_length - 1, n_input)

    with tf.Session(graph=graph) as sess:
        # load the variables
        try:
            saver.restore(sess, restore_path)
            print("Model successfully restored from %s.\nDisplaying current fit." % restore_path)
        except (tf.errors.NotFoundError, ValueError):
            # ValueError is raised e.g. for a None path (and, in some TF versions, a missing checkpoint).
            print("Save file not found.\nExiting.")
            return None

        predictions = []
        # Guard against progress_counter == 0 and runs shorter than progress_counter.
        show_interval = max(1, prediction_length // max(1, progress_counter))

        for index in range(prediction_length):
            if index % show_interval == 0:
                print("Generated", index, "of", prediction_length, "data points.")
            batch_seed = seed.reshape(1, window_length - 1, n_input)  # cast to batch format (batch of size 1)
            prediction = generate_next_point(batch_seed, sess)
            pred_coords = get_new_row(seed, prediction)
            predictions.append(np.reshape(pred_coords, pred_coords.size))  # drop the batch dimension
            seed = shift_seed(seed, pred_coords)
        print("Done!")

        return np.array(predictions)

# Now Run the Code

In [12]:
# Build the training, validation, and test sets from every data file
# (raw positions, no per-window normalization).
(x_train, y_train,
 x_valid, y_valid,
 x_test, y_test) = download_and_preprocess_data(filenames=datafiles,
                                                window_length=window_length,
                                                normalize=False)

In [1]:
#
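# Train for as long as needed; typically a few thousand minibatches were used.
# Decrease the learning rate as you go (note: this requires rebuilding the graph,
# because the optimizer is built with the module-level learning_rate).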
#

In [None]:
# Train from freshly initialized variables; restore_from_latest is ignored
# because restore_from_save is False.
train_model(n_minibatches=100,
            display_step=20,
            learning_rate=0.0001,
            restore_from_save=False,
            restore_from_latest=True)

Global variables initialized.
Commencing training.
Minibatch 0, Minibatch cost = 5.837825
Minibatch 0, Validation set cost = 5.733052
Model saved in file: ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA2.ckpt-0
Minibatch 20, Minibatch cost = 5.149290
Minibatch 20, Validation set cost = 5.231023
Model saved in file: ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA2.ckpt-20


In [None]:
# Resume from the newest checkpoint in ./ModelCheckpoints/ and keep training.
train_model(
    n_minibatches=1000,
    display_step=20,
    learning_rate=0.0001,
    restore_from_save=True,
    restore_from_latest=True,
)

INFO:tensorflow:Restoring parameters from ./ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-60
Model successfully restored from ./ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-60.
Resuming training.
Epoch 0, Minibatch cost = 4.581092
Epoch 0, Validation set cost = 4.662611
Model saved in file: ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-0
Epoch 20, Minibatch cost = 4.374648
Epoch 20, Validation set cost = 4.410938
Model saved in file: ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-20
Epoch 40, Minibatch cost = 4.086440
Epoch 40, Validation set cost = 4.168250
Model saved in file: ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-40
Epoch 60, Minibatch cost = 3.828678
Epoch 60, Validation set cost = 3.942400
Model saved in file: ModelCheckpo

In [None]:
# Another resumed training round from the newest checkpoint.
train_model(
    n_minibatches=1000,
    display_step=20,
    learning_rate=0.0001,
    restore_from_save=True,
    restore_from_latest=True,
)

INFO:tensorflow:Restoring parameters from ./ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-200
Model successfully restored from ./ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-200.
Resuming training.
Epoch 0, Minibatch cost = 3.085219
Epoch 0, Validation set cost = 3.377854
Model saved in file: ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-0
Epoch 20, Minibatch cost = 3.028424
Epoch 20, Validation set cost = 3.361949
Model saved in file: ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-20
Epoch 40, Minibatch cost = 3.076978
Epoch 40, Validation set cost = 3.326528
Model saved in file: ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-40
Epoch 60, Minibatch cost = 3.081317
Epoch 60, Validation set cost = 3.309201
Model saved in file: ModelCheck

In [12]:
### Generate a trajectory

savefile = './ModelOutputs/generated_trajectory.csv'
seed = x_valid[0]  # one validation window, shape (window_length-1, n_input), used to prime the generator
pred = generate_prediction(seed, prediction_length=5000, restore_path=model_path, progress_counter=40)
if pred is None:
    # generate_prediction returns None when the checkpoint cannot be restored
    print("No trajectory generated; nothing saved.")
else:
    # np.savetxt does not create missing directories
    os.makedirs(os.path.dirname(savefile), exist_ok=True)
    np.savetxt(savefile, pred, delimiter=',')

INFO:tensorflow:Restoring parameters from ./ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-80
Model successfully restored from ./ModelCheckpoints/five_fish_xy-embed_64-relu_2x256-l2_norm-pred_five_fish_delta_xy-MORE_DATA.ckpt-80.
Displaying current fit.
Generated 0 of 5000 data points.
Generated 125 of 5000 data points.
Generated 250 of 5000 data points.
Generated 375 of 5000 data points.
Generated 500 of 5000 data points.
Generated 625 of 5000 data points.
Generated 750 of 5000 data points.
Generated 875 of 5000 data points.
Generated 1000 of 5000 data points.
Generated 1125 of 5000 data points.
Generated 1250 of 5000 data points.
Generated 1375 of 5000 data points.
Generated 1500 of 5000 data points.
Generated 1625 of 5000 data points.
Generated 1750 of 5000 data points.
Generated 1875 of 5000 data points.
Generated 2000 of 5000 data points.
Generated 2125 of 5000 data points.
Generated 2250 of 5000 data points.
Generated 2375 of 5000