# Run the following code if working in Colab

In [None]:
# Clone the entire repo.
!git clone -l -s https://github.com/cimat-ris/TF-PathPred.git clonedrep

import os
os.chdir("./clonedrep")

!pip install pykalman

# Testing the Transformer

This notebook tests the predictions of the transformer model on the ETH and UCY datasets. It also visualizes those predictions and the attention weight matrices.

In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import tensorflow as tf
from test_TF import test_model, print_sol
import matplotlib.pyplot as plt
import numpy as np
import os

from tools.trajectories import traj_to_real_coordinates

# Generate testing information

Run the cell below if you wish to generate the testing information. If you only want to visualize existing results, skip to the next section.

In [2]:
path = "./"
#choose a dataset
# NOTE(review): the visualization cell further down uses "ETH-univ" (lower-case
# "u"); on a case-sensitive filesystem the two names point to different
# folders -- confirm which spelling is intended.
test_dataset = "ETH-Univ"

test_path = path + f"generated_data/testing_data/{test_dataset}/"

# Create the output folder together with any missing parents. An already
# existing folder is fine; any other failure (permissions, a file in the way)
# is raised instead of being silently swallowed.
os.makedirs(test_path, exist_ok=True)

# Run the evaluation. The meaning of the last argument (5) is defined by
# test_model -- presumably the number of sampled predictions; TODO confirm.
ade,fde,weights,trajs,transformer = test_model([test_dataset],path,5)

ade,fde = np.array(ade), np.array(fde)

# Persist the metrics and the three trajectory arrays so the visualization
# section can be run without recomputing the predictions.
np.save(test_path+"ade.npy", ade)
np.save(test_path+"fde.npy", fde)
np.save(test_path+"inp.npy", trajs[0])
np.save(test_path+"tar.npy", trajs[1])
np.save(test_path+"pred.npy", trajs[2])

Loading trajlets from:  ./generated_data/trajlets\ETH-Univ-trl.npy
Small trajectories: 92
(2321, 8, 2)
(2321, 12, 2)
./generated_data/checkpoints/train/ETH-Univ
Latest checkpoint restored!! ./generated_data/checkpoints/train/ETH-Univ\ckpt-110
Calculating predictions
ADE: 0.58671695 FDE: 1.3588676


In [14]:
# Inspect entry 12 of trajs[1], the array that the cell above saves as tar.npy.
trajs[1][12]

array([[ 0.37390298, -0.2288619 ],
       [ 0.        ,  0.        ],
       [-0.0871935 , -0.03051508],
       [-0.20659758, -0.38946742],
       [-0.45374957, -0.24956225],
       [-0.6715732 , -0.05334945],
       [-0.9829161 , -0.161539  ],
       [-1.3253201 , -0.18906768],
       [-1.7472138 , -0.28921288],
       [-2.2055075 , -0.44761813],
       [-2.64101   , -0.69115925],
       [-2.9735653 , -0.9468997 ],
       [-3.3288553 , -1.3065428 ],
       [-3.611939  , -1.7388265 ]], dtype=float32)

In [3]:
# Print the layer / parameter-count summary of the model returned by test_model.
transformer.summary()

Model: "transformer_cvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder (Encoder)            multiple                  545664    
_________________________________________________________________
cvae_attention (CVAE_attenti multiple                  52234     
_________________________________________________________________
decoder (Decoder)            multiple                  695040    
_________________________________________________________________
dense_47 (Dense)             multiple                  258       
Total params: 1,293,196
Trainable params: 1,293,196
Non-trainable params: 0
_________________________________________________________________


# Visualizing solution

Run the cell below after setting the name of the dataset you wish to visualize, then change the value in the input box to view different trajectories. The red trajectory is the observed one, the blue one is the target, and the green ones are the predicted possible outcomes. To visualize one particular mode, add a comma followed by the mode number, with no spaces, e.g. "0,0".

In [23]:
# Lengths of the observed and the predicted part of each trajectory
# (presumably counted in time steps -- TODO confirm).
Tobs, Tpred = 8, 12

from tools.opentraj_benchmark.all_datasets import get_trajlets
from tools.trajectories import (
    obs_pred_trajectories,
    obs_pred_rotated_trajectories,
    convert_to_traj_with_rotations,
)

test_name = ['ETH-univ']

# get_trajlets is indexed by dataset name; only the first two entries of
# the last axis are kept.
trajectories = get_trajlets("./", test_name)[test_name[0]][:, :, :2]

# Split into observed / predicted parts; the third argument is the total
# length Tobs + Tpred.
Starts_train, Xm_test, Xp_test, dists, mtcs = obs_pred_rotated_trajectories(
    trajectories, Tobs, Tobs + Tpred
)

Loading trajlets from:  ./generated_data/trajlets\ETH-univ-trl.npy
Small trajectories: 92


In [4]:
# ---- Inputs for the visualization ----------------------------------------
path = "./"
#choose a dataset
# NOTE(review): the generation cell writes to ".../testing_data/ETH-Univ/"
# (capital "U"); on a case-sensitive filesystem that is a different folder
# from the one read here -- confirm the intended spelling.
test_dataset = "ETH-univ"

test_path = f"{path}generated_data/testing_data/{test_dataset}/"
# NOTE(review): the reference image and H.txt are hard-coded to ETH/seq_eth
# and do not follow test_dataset.
reference_path = f"{path}datasets/ETH/seq_eth/reference.png"
H_path = f"{path}datasets/ETH/seq_eth/H.txt"

# H holds the inverse of the matrix stored in H.txt.
H = np.linalg.inv(np.loadtxt(H_path))
img = plt.imread(reference_path)

# Reload the metrics and trajectories saved under test_path.
ade = np.load(f"{test_path}ade.npy")
fde = np.load(f"{test_path}fde.npy")
inp = np.load(f"{test_path}inp.npy")
tar = np.load(f"{test_path}tar.npy")
pred = np.load(f"{test_path}pred.npy")

# Dataset-wide averages of both error metrics.
print("General: ADE:", ade.mean(), "FDE:", fde.mean())

def f(x):
    """Plot one test trajectory chosen through the text box.

    Parameters
    ----------
    x : str
        Either "i" (trajectory index; every prediction stored for that
        trajectory is plotted) or "i,k" (additionally restrict the plot to
        prediction number k). No spaces are required around the comma.

    The function prints a message and returns without plotting when the
    text is not made of one or two integers or when an index is out of
    range. It relies on the notebook-level names ade, fde, inp, tar, pred,
    H and img.
    """
    fields = x.split(',')
    if len(fields) not in (1, 2):
        print("the format is not valid")
        return

    # int() signals a malformed number with ValueError; catching only that
    # keeps unrelated errors visible.
    try:
        numbers = [int(field) for field in fields]
    except ValueError:
        if len(fields) == 1:
            print("Error: Input was not an integer")
        else:
            print("Error: Inputs were not integers")
        return

    # Negative values would otherwise wrap around silently (Python indexing).
    if min(numbers) < 0:
        print("Error: Indices must be non-negative")
        return

    traj_idx = numbers[0]
    if traj_idx >= len(ade):
        print("There aren't that many trajectories")
        return

    predictions = pred[traj_idx]
    if len(numbers) == 2:
        mode_idx = numbers[1]
        if mode_idx >= len(predictions):
            print("There aren't that many modes")
            return
        # Slice (not index) so the selected mode keeps the same number of
        # dimensions as the full set of predictions.
        predictions = predictions[mode_idx:mode_idx + 1]

    observed = traj_to_real_coordinates(inp[traj_idx], H)
    target = traj_to_real_coordinates(tar[traj_idx], H)
    predicted = traj_to_real_coordinates(predictions, H)
    print_sol(observed, target, predicted, img)

    print(f"trajectory {traj_idx}:", "ADE:", ade[traj_idx], "FDE:", fde[traj_idx])

# Build a text box (initial value "0") whose content is passed to f as x;
# the trailing semicolon suppresses the cell's return-value display.
interact(f, x="0");

General: ADE: 0.58671695 FDE: 1.3588676


interactive(children=(Text(value='0', description='x'), Output()), _dom_classes=('widget-interact',))

In [5]:
# First entry of the array loaded from inp.npy.
inp[0]

array([[ 1.0000000e+00, -1.6870156e-18],
       [ 7.6804924e-01, -3.6168404e-02],
       [ 6.8319803e-01, -5.2176796e-02],
       [ 5.3687590e-01, -4.3020580e-02],
       [ 3.5017157e-01, -5.6467921e-02],
       [ 2.4483772e-01,  6.4771124e-03],
       [ 1.2285026e-01,  8.2979165e-03]], dtype=float32)

In [7]:
# First entry of the array loaded from tar.npy.
tar[0]

array([[ 0.12285026,  0.00829792],
       [ 0.        ,  0.        ],
       [-0.13646209,  0.01131812],
       [-0.25367513,  0.02418254],
       [-0.3715892 ,  0.02686667],
       [-0.49097073,  0.02958424],
       [-0.60017306,  0.02257409],
       [-0.7002231 ,  0.02671189],
       [-0.79009384,  0.0315993 ],
       [-0.8933986 ,  0.0570505 ],
       [-0.98500806,  0.06212363],
       [-1.042929  ,  0.06932585],
       [-1.1030437 ,  0.10870075],
       [-1.1775117 ,  0.19050777]], dtype=float32)

In [6]:
# First entry of the array loaded from pred.npy.
pred[0]

array([[[ 0.12285026,  0.00829792],
        [ 0.        ,  0.        ],
        [-0.12279233,  0.01117741],
        [-0.26081228,  0.00693151],
        [-0.41017687,  0.00797362],
        [-0.53467   ,  0.00796123],
        [-0.64435357,  0.00701118],
        [-0.74307019,  0.00649073],
        [-0.86878723,  0.00361584],
        [-0.97647029,  0.00133931],
        [-1.08544576, -0.00273226],
        [-1.19096804, -0.00820347],
        [-1.27997339, -0.01881041],
        [-1.33447945, -0.02417786]]])