# Prediction Evaluation and Visualization (NuScenes)
* The purpose of this notebook is to help evaluate trained models on the NuScenes dataset.
* It runs the model to generate predictions on the train and val splits, saving each split's predictions to a pickle file.
* It then builds a Pandas dataframe to aggregate metrics and to select interesting outlier/edge cases for visualization.

## Adjustable Parameters

In [None]:
# --- Checkpoint and data locations ---
# Checkpoint path, relative to the repository root.
checkpoint_file = 'log/nuscenes_multipath_lstm/00030_epochs.h5'
# Architecture the checkpoint was trained with: 'multipath' or 'regression'.
model_type = 'multipath'
# Absolute path to the NuScenes dataroot.
nuscenes_datadir = '/media/data/nuscenes-data/'

# --- Which notebook stages to run ---
# Run the model on the NuScenes splits and pickle the resulting predictions.
make_predictions = True
# Build a metrics dataframe, flag interesting cases, and visualize them.
viz_interesting_cases = True

## Code and Model Setup

In [None]:
import os
import sys
import glob
from tqdm import tqdm
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import pickle

# Repository root: the part of the working-directory path before the first
# occurrence of 'scripts' (the whole cwd if 'scripts' does not appear).
ROOTDIR = os.getcwd().partition('scripts')[0]
# Make the project packages living under <repo>/scripts importable.
sys.path.append(os.path.join(ROOTDIR, 'scripts'))

from models.regression import Regression
from models.multipath import MultiPath
from datasets.splits import NUSCENES_TRAIN, NUSCENES_VAL
from evaluation.gmm_prediction import GMMPrediction
from evaluation.pandas_df_utils import eval_prediction_dict
from evaluation.nuscenes_pred_visualization import NuscenesPredictionVisualizer

In [None]:
# Anchor set and weights loaded from the 16-cluster files in <repo>/data.
# `anchors` is passed to MultiPath here and to the metric evaluation later on.
anchors = np.load(os.path.join(ROOTDIR, 'data/nuscenes_clusters_16.npy'))
weights = np.load(os.path.join(ROOTDIR, 'data/nuscenes_clusters_16_weights.npy'))

# Instantiate the architecture matching the checkpoint, then restore weights.
if model_type == 'multipath':
    model = MultiPath(num_timesteps=12, num_hist_timesteps=2, anchors=anchors, weights=weights)
elif model_type == 'regression':
    model = Regression(num_timesteps=12, num_hist_timesteps=2)
else:
    # Report the unsupported value of the setting that was actually checked.
    raise ValueError("{} not implemented".format(model_type))

model.load_weights(os.path.join(ROOTDIR, checkpoint_file))

# The model name is the single path component of the checkpoint containing
# 'nuscenes' (e.g. 'nuscenes_multipath_lstm' for the default checkpoint).
name_candidates = [part for part in checkpoint_file.split('/') if 'nuscenes' in part]
if len(name_candidates) != 1:
    raise ValueError("Expected the model name to contain the substring <nuscenes>.  Unable to detect this.")
model_name = name_candidates[0]

## Make Predictions

In [None]:
if make_predictions:
    # Run inference on each split and pickle the prediction dict next to the
    # checkpoint as <checkpoint>_{train,val}_preds.pkl.
    for dataset_name, dataset in zip(['train', 'val'], [NUSCENES_TRAIN, NUSCENES_VAL]):
        predict_dict = model.predict(dataset)
        savename = os.path.join(
            ROOTDIR, checkpoint_file.replace('.h5', '_{}_preds.pkl'.format(dataset_name))
        )
        print('Saving predictions to: {}'.format(savename))
        # The context manager guarantees the file is flushed and closed.
        with open(savename, "wb") as f:
            pickle.dump(predict_dict, f)

## Determine Interesting Examples

In [None]:
if viz_interesting_cases:
    # Assuming that we care about visualizing the val dataset to explore the generalization error.
    loadname = os.path.join(ROOTDIR, checkpoint_file.replace('.h5', '_val_preds.pkl'))
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # prediction files produced by this notebook.
    with open(loadname, 'rb') as f:
        predict_dict = pickle.load(f)
    # `model_name` was already derived and validated from checkpoint_file in
    # the model-setup cell (which this cell depends on for `anchors`).
    predict_metrics_df = eval_prediction_dict(predict_dict, anchors, model_name)

    print('Aggregate statistics:')
    for key in predict_metrics_df:
        # HACK: columns containing '_' are assumed to be the numeric metrics;
        # columns without '_' are skipped.
        if '_' in key:
            print(f"\t{key} : {predict_metrics_df[key].mean()}")

    print('Histograms and Means')
    colors = ['r', 'g', 'b']
    num_modes_list = [1, 3, 5]
    for key in ['traj_LL', 'min_ade', 'min_fde']:  # skipping minmax distance.
        # One histogram row per number of modes k; dashed black line = mean.
        fig, axes = plt.subplots(len(num_modes_list), 1, sharey=True)
        for ax, color, k in zip(axes, colors, num_modes_list):
            column = '%s_%d' % (key, k)
            ax.hist(predict_metrics_df[column], color=color, log=True)
            ax.axvline(np.mean(predict_metrics_df[column]), lw=3, color='k', ls='--')
            ax.set_ylabel(column)
        fig.tight_layout()

    for key in ['length', 'curvature']:
        # These metrics do not have a _{1,3,5} suffix.  Invariant to number of modes considered.
        fig, ax = plt.subplots()
        ax.hist(predict_metrics_df[key], color='b', log=True)
        ax.axvline(np.mean(predict_metrics_df[key]), lw=3, color='k', ls='--')
        ax.set_ylabel(key)

In [None]:
# Using the histograms, select some reasonable thresholds for outliers.
# Tune criteria for choosing interesting cases in this cell.
CURV_THRESH = 0.15   # rows with |curvature| >= this are flagged as high curvature
LL_THRESH   = -450.  # rows with traj_LL_5 <= this are flagged as low likelihood

if viz_interesting_cases:
    has_high_curvature = predict_metrics_df['curvature'].abs() >= CURV_THRESH
    has_low_likelihood = predict_metrics_df['traj_LL_5'] <= LL_THRESH
    # A row is interesting when either criterion flags it.
    interesting_cases = has_high_curvature | has_low_likelihood

    summary = (
        ('Number High Curvature Cases', has_high_curvature),
        ('Number Low Likelihood Cases', has_low_likelihood),
        ('Number Interesting Cases', interesting_cases),
    )
    for label, mask in summary:
        print(f"{label}: {mask.sum()}")

## Visualize Interesting Examples

In [None]:
if viz_interesting_cases:
    # Visualizer bound to the configured NuScenes dataroot; its
    # visualize_prediction method is used in the cell below to render
    # ground-truth and predicted images for each flagged instance.
    npv = NuscenesPredictionVisualizer(dataroot=nuscenes_datadir)

### How to interpret the plots:
* The left column shows the ground truth as red circles.
* The right column shows (by default) the top-3 modes of the GMM with 95% covariance ellipses.
  - The highest probability mode is in yellow, and the least likely mode is in magenta.
  - The probabilities are also printed for these top-3 modes following the figure.

In [None]:
if viz_interesting_cases:
    # Explicit name -> boolean row mask mapping (replaces eval() on variable
    # names); the names are printed as the threshold label.
    threshold_masks = {
        'has_high_curvature': has_high_curvature,
        'has_low_likelihood': has_low_likelihood,
    }
    for binary_thresh, mask in threshold_masks.items():
        print("-"*80)
        print(f"Thresh used: {binary_thresh}")
        interesting_cases_df = predict_metrics_df[mask]
        for _, example in interesting_cases_df.iterrows():
            # First return value is unused here; then GT image, prediction
            # image, and the probabilities of the top modes.
            _, img_gt, img_pred, top_probs = npv.visualize_prediction(predict_dict, example['instance'])
            fig, (ax_gt, ax_pred) = plt.subplots(1, 2)
            ax_gt.imshow(img_gt)
            ax_pred.imshow(img_pred)
            # Title the whole figure (not just the right panel) before layout.
            fig.suptitle(example['instance'])
            fig.tight_layout()
            plt.show()
            print(top_probs)
            print()
        print("-"*80)