# Run the following cell only if working in Google Colab

In [None]:
# Colab only: fetch the repository and work from inside it, so that the
# project modules imported below (e.g. test_TF) and the relative "./" paths
# resolve against the checkout.
import os
import subprocess
import sys

REPO_URL = "https://github.com/cimat-ris/TF-PathPred.git"
REPO_DIR = "clonedrep"


def _run_and_show(cmd):
    """Run a command, echo its output into the cell, and raise if it failed.

    Output is captured and printed because a plain subprocess writes to the
    kernel process's stdout, which a notebook cell does not display.
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    print(result.stdout, result.stderr, sep="")
    result.check_returncode()


# Guard so that re-running this cell is a no-op: without it, a second run
# (already inside REPO_DIR) would clone a nested copy or fail on chdir.
if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        # The original -l/-s flags only affect clones from a local path and
        # have no effect for an https URL, so they are omitted.
        _run_and_show(["git", "clone", REPO_URL, REPO_DIR])
    os.chdir(REPO_DIR)

# Install into the interpreter that runs this kernel (the same target %pip uses).
_run_and_show([sys.executable, "-m", "pip", "install", "pykalman"])

# Testing the Transformer

This notebook tests the trajectory predictions of the transformer model on the ETH and UCY datasets. It also visualizes those predictions, as well as the attention weight matrices.

In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import tensorflow as tf
from test_TF import test_model, print_sol
import matplotlib.pyplot as plt
import numpy as np
import os

# Generate testing information

Run the cell below if you wish to generate the testing information (it evaluates the model and saves the results to disk). If you only want to visualize, go on to the next section.

In [2]:
path = "./"
# Choose the dataset(s) to evaluate; the first entry also names the output folder.
test_dataset = ["ETH-hotel"]

test_path = path + f"testing_data/{test_dataset[0]}/"

# Create testing_data/<dataset>/ together with any missing parent directory.
# Only an already-existing directory is tolerated. Unlike the previous bare
# `except: pass`, real failures (e.g. permissions) are reported here rather
# than at np.save below.
os.makedirs(test_path, exist_ok=True)

# Evaluate the model on the chosen dataset. trajs[0], trajs[1] and trajs[2]
# are saved below as the inputs, targets and predictions respectively.
ade, fde, weights, trajs, transformer = test_model(test_dataset, path)

ade, fde = np.array(ade), np.array(fde)

# Persist everything the visualization section loads, so that section can
# run in a fresh kernel without re-evaluating the model.
np.save(test_path + "ade.npy", ade)
np.save(test_path + "fde.npy", fde)
np.save(test_path + "inp.npy", trajs[0])
np.save(test_path + "tar.npy", trajs[1])
np.save(test_path + "pred.npy", trajs[2])

loading trajlets from:  ./trajlets\ETH-hotel-trl.npy
Latest checkpoint restored!!
calculating predictions
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, ADE: 1.700885861878435 FDE: 1.987742307028067


# Visualizing solution

Run the cell below, setting `test_dataset` to the dataset you wish to visualize, then change the value in the input box to visualize different trajectories. The red trajectory is the observed one, the blue one is the target, and the green ones are the predicted possible outcomes. To visualize one particular mode, add a comma followed by the mode number, with no spaces; e.g. "0,0" shows mode 0 of trajectory 0.

In [4]:
# Reload the results written by the generation section, so this section also
# works in a fresh kernel. Set test_dataset to the dataset you want to inspect.
path = "./"
test_dataset = ["ETH-hotel"]
test_path = f"{path}testing_data/{test_dataset[0]}/"

# One .npy file per array, loaded in the order they were saved.
ade, fde, inp, tar, pred = (
    np.load(f"{test_path}{name}.npy")
    for name in ("ade", "fde", "inp", "tar", "pred")
)

print("General: ADE:", np.mean(ade),"FDE:", np.mean(fde))

def f(x):
    """Plot one saved test trajectory and print its ADE/FDE.

    Callback for the ``interact`` text box.

    Parameters
    ----------
    x : str
        Either ``"<trajectory>"`` to draw every predicted mode of that
        trajectory, or ``"<trajectory>,<mode>"`` to draw a single mode.

    Reads the module-level arrays ``inp``, ``tar``, ``pred``, ``ade`` and
    ``fde`` and calls ``print_sol`` to draw. Invalid input only prints a
    message (nothing is raised), so the widget keeps working while the user
    is still typing.
    """
    parts = x.split(',')
    if len(parts) not in (1, 2):
        print("the format is not valid")
        return

    try:
        indices = [int(part) for part in parts]
    except ValueError:
        if len(parts) == 1:
            print("Error: Input was not an integer")
        else:
            print("Error: Inputs were not integers")
        return

    traj = indices[0]
    # Negative indices must be rejected explicitly: NumPy would otherwise
    # silently count from the end and show an unintended trajectory.
    if traj < 0:
        print("Error: trajectory index must be non-negative")
        return
    if traj >= len(ade):
        print("There aren't that many trajectories")
        return

    if len(indices) == 1:
        shown = pred[traj]
    else:
        mode = indices[1]
        n_modes = len(pred[traj])
        # Without this check an out-of-range mode gives an empty slice, and
        # print_sol would receive no predictions at all.
        if not 0 <= mode < n_modes:
            print(f"Error: mode must be between 0 and {n_modes - 1}")
            return
        # Slice (rather than index) to keep the leading mode axis, so the
        # array has the same rank as in the all-modes case.
        shown = pred[traj][mode:mode + 1]

    print_sol(inp[traj], tar[traj], shown)
    print(f"trajectory {traj}:", "ADE:", ade[traj], "FDE:", fde[traj])

# Bind f to a text box whose initial value is "0"; the trailing ';' stops the
# notebook from echoing interact's return value below the widget.
interact(f, x="0");

General: ADE: 1.700885861878435 FDE: 1.987742307028067


interactive(children=(Text(value='0', description='x'), Output()), _dom_classes=('widget-interact',))