# Inspect model performance on test data

In [None]:
import dataclasses

import numpy as np
import matplotlib.pyplot as plt
import torch

import trajectory
import run

In [None]:
# Directory of the saved run (parameters + trained model) to evaluate
RUN_DIRPATH = 'models/20220510_01'
# How many test trajectories to simulate and run through the model
N_TRIALS = 500

## Load model

In [None]:
# Restore the run parameters and the trained model stored in RUN_DIRPATH,
# then echo the parameters so the evaluated configuration is visible.
loaded_run = run.load_run(RUN_DIRPATH)
run_params, model = loaded_run
run_params.print()

## Simulate trajectories

In [None]:
# Build the test-data trajectory parameters: a copy of the run's trajectory
# parameters with only the RNG seed changed, so the test batch is drawn with
# a different seed than the one stored with the run.
# Passing rng_seed to dataclasses.replace() (instead of mutating the copy
# afterwards) also works for frozen dataclasses, and lets __init__ /
# __post_init__ see the new seed. NOTE(review): the latter only matters if
# the params class derives state from rng_seed at construction — unverified.
traj_params = dataclasses.replace(
    run_params.traj,
    rng_seed=run_params.traj.rng_seed + 1,
)

# Sample a batch of N_TRIALS test trajectories (velocities and positions)
tgen = trajectory.TrajectoryGenerator(traj_params)
vel, pos = tgen.smp_batch(N_TRIALS)

## Run model on all trials

In [None]:
# Put the model in evaluation mode before testing, so that layers such as
# dropout / batch-norm (if present) behave as at inference time.
# NOTE(review): assumes `model` is a torch.nn.Module — confirm in run.load_run.
model.eval()

# Convert the velocity array to a float32 tensor so the model can consume it
# (same dtype the legacy torch.Tensor(...) constructor produced).
vel_t = torch.as_tensor(vel, dtype=torch.float32)

# Predict estimated position. Gradients are not needed for evaluation, so
# autograd is disabled to avoid building a graph over all N_TRIALS trials.
with torch.no_grad():
    pos_est_t, u_vals_t = model(vel_t)

# Convert model outputs back to NumPy arrays (.cpu() is a no-op on CPU and
# required before .numpy() if the outputs live on another device).
pos_est = pos_est_t.detach().cpu().numpy()
u_vals = u_vals_t.detach().cpu().numpy()

## Plot results

In [None]:
# Square boundary built from the run's trajectory parameters (the test
# trajectory parameters are a copy of these, so the height is the same).
boundary_height = run_params.traj.boundary_height
boundary = trajectory.SquareBoundary(boundary_height)

In [None]:
# Which trial to plot, and the half-open time window [t_start, t_stop) to show
trial_plt = 100
t_start, t_stop = 0, 450
window = slice(t_start, t_stop)

# True and model-estimated positions for the chosen trial within the window
pos_plt = pos[trial_plt][window]
pos_est_plt = pos_est[trial_plt][window]

# Single axes with equal aspect so distances are not distorted
fig, ax = plt.subplots()
ax.set_aspect('equal')
ax.set_title(f'Trial: {trial_plt}')
trajectory.plot_position_estimate(boundary, pos_plt, pos_est_plt, ax)