# Checkpoint Evaluation
* The purpose of this notebook is to check prediction performance as a function of training epochs (checkpoints).
* Includes train/val prediction errors and a visualization of the learned image features (intermediate layer activations).

## Adjustable Parameters

In [None]:
# Directory (relative to ROOTDIR) that is searched for '*.h5' checkpoint files.
checkpoint_dir = 'log/l5kit_multipath_lstm'
# Which model to build: 'multipath' or 'regression'.
model_type = 'multipath'
# Which dataset to evaluate on: 'nuscenes' or 'l5kit'.
dataset_type = 'l5kit'

# Switches for the two stages of this notebook.
run_prediction_evaluation = True
run_layer_visualization = True

## Code Setup

In [None]:
import os
import sys
import glob
from tqdm import tqdm
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import cv2

ROOTDIR = os.getcwd().split('scripts')[0]
sys.path.append(os.path.join(ROOTDIR, 'scripts'))

from models.regression import Regression
from models.multipath import MultiPath
from datasets.splits import NUSCENES_TRAIN, NUSCENES_VAL, L5KIT_TRAIN, L5KIT_VAL
from datasets.tfrecord_utils import _parse_function
from evaluation.gmm_prediction import GMMPrediction

In [None]:
# Collect every '*.h5' checkpoint under the configured directory, in sorted
# filename order so later loops visit them in a stable sequence.
checkpoints = sorted(glob.glob(os.path.join(ROOTDIR, checkpoint_dir, '*.h5')))
if not checkpoints:
    raise ValueError("No records found!")
print("{} records found.".format(len(checkpoints)))

# Dataset-specific settings:
#   n_t / n_h_t      -> passed to the model as num_timesteps / num_hist_timesteps
#   anchors, weights -> arrays loaded from the 16-cluster .npy files (used by MultiPath)
#   train/val records -> record lists imported from datasets.splits
if dataset_type == 'nuscenes':
    n_t, n_h_t = 12, 2
    anchors = np.load(os.path.join(ROOTDIR, 'data/nuscenes_clusters_16.npy'))
    weights = np.load(os.path.join(ROOTDIR, 'data/nuscenes_clusters_16_weights.npy'))
    train_records, val_records = NUSCENES_TRAIN, NUSCENES_VAL
elif dataset_type == 'l5kit':
    n_t, n_h_t = 25, 5
    anchors = np.load(os.path.join(ROOTDIR, 'data/l5kit_clusters_16.npy'))
    weights = np.load(os.path.join(ROOTDIR, 'data/l5kit_clusters_16_weights.npy'))
    train_records, val_records = L5KIT_TRAIN, L5KIT_VAL
else:
    raise ValueError("{} not implemented".format(dataset_type))

# Build the (untrained) model; weights are loaded per checkpoint further below.
if model_type == 'multipath':
    model = MultiPath(num_timesteps=n_t, num_hist_timesteps=n_h_t, anchors=anchors, weights=weights)
elif model_type == 'regression':
    model = Regression(num_timesteps=n_t, num_hist_timesteps=n_h_t)
else:
    # Report the offending model_type (the original mistakenly printed dataset_type).
    raise ValueError("{} not implemented".format(model_type))

## Prediction Evaluation

In [None]:
def eval_prediction_dict(predict_dict, model_name):
    """Compute per-entry prediction metrics and collect them in a DataFrame.

    Args:
        predict_dict: mapping from an entry key to a dict holding
            'future_traj' (ground-truth future trajectory array) and
            'gmm_pred' (mapping of mode index -> dict with
            'mode_probability', 'mus' and 'sigmas').
        model_name: label stored in the "model" column of every row.

    Returns:
        A pandas DataFrame with one row per entry and the columns
        "sample", "instance", "model", followed by traj_LL_k, class_top_k,
        min_ade_k, min_fde_k and minmax_dist_k for k in (1, 3, 5).

    Note:
        The multimodal branch reads the module-level ``anchors`` array.
    """
    ks_eval = [1, 3, 5]
    num_ks = len(ks_eval)

    columns = ["sample", "instance", "model"]
    for prefix in ("traj_LL", "class_top", "min_ade", "min_fde", "minmax_dist"):
        columns += [f"{prefix}_{k}" for k in ks_eval]

    rows = []
    for key in tqdm(predict_dict.keys()):
        entry = predict_dict[key]
        future_traj_gt = entry['future_traj']
        # Columns 1 and 2 of the trajectory (presumably x/y -- confirm against the dataset format).
        future_xy_gt = future_traj_gt[:, 1:3]
        raw_gmm = entry['gmm_pred']

        n_modes = len(raw_gmm.keys())
        n_timesteps = future_traj_gt.shape[0]
        per_mode = [raw_gmm[mode] for mode in range(n_modes)]
        mode_probabilities = np.array([m['mode_probability'] for m in per_mode])
        mus = np.array([m['mus'] for m in per_mode])
        sigmas = np.array([m['sigmas'] for m in per_mode])

        gmm = GMMPrediction(n_modes, n_timesteps, mode_probabilities, mus, sigmas)

        # The key is '_'-separated: the last two pieces form the instance
        # token, everything before them forms the sample token.
        key_parts = key.split('_')
        sample_token = '_'.join(key_parts[:-2])
        instance_token = '_'.join(key_parts[-2:])

        row = [sample_token, instance_token, model_name]
        if n_modes == 1:
            # Unimodal prediction: every k shares the same value, and the
            # class_top_k score is fixed at 1.
            shared_values = (
                gmm.compute_trajectory_log_likelihood(future_xy_gt),
                1,
                gmm.compute_min_ADE(future_xy_gt),
                gmm.compute_min_FDE(future_xy_gt),
                gmm.compute_minmax_d(future_xy_gt),
            )
            for value in shared_values:
                row.extend([value] * num_ks)
        else:
            # Multimodal prediction: evaluate the top-k truncated mixture for each k.
            top_k_gmms = [gmm.get_top_k_GMM(k) for k in ks_eval]

            row.extend(g.compute_trajectory_log_likelihood(future_xy_gt) for g in top_k_gmms)
            row.extend(gmm.get_class_top_k_scores(future_xy_gt, anchors, ks_eval))
            row.extend(g.compute_min_ADE(future_xy_gt) for g in top_k_gmms)
            row.extend(g.compute_min_FDE(future_xy_gt) for g in top_k_gmms)
            row.extend(g.compute_minmax_d(future_xy_gt) for g in top_k_gmms)

        rows.append(row)

    return pd.DataFrame(rows, columns=columns)

def average_selected_metrics(df, relevant_keys):
    """Return a dict mapping each key in ``relevant_keys`` to ``np.mean(df[key])``."""
    return {metric: np.mean(df[metric]) for metric in relevant_keys}

def select_relevant_keys(df):
    """Return the columns of ``df`` that hold metrics worth averaging.

    A column is kept, in its original order, when its name contains any of:
    'class' (mode classification), 'ade' (min ADE), 'fde' (min FDE) or
    'traj_LL' (log likelihood).
    """
    wanted_tags = ('class', 'ade', 'fde', 'traj_LL')
    return [column for column in df.columns
            if any(tag in column for tag in wanted_tags)]
    
if run_prediction_evaluation:
    # Averaged metrics per checkpoint: metric name -> list with one value per
    # checkpoint, plus 'model' -> list of model names (same ordering).
    checkpoint_train_dict = {'model': []}
    checkpoint_val_dict   = {'model': []}

    for checkpoint in checkpoints:
        model.load_weights(checkpoint)
        # The checkpoint file name is expected to start with "<epoch>_";
        # os.path.basename handles the platform's path separator.
        model_name = model.model.name + '_' + os.path.basename(checkpoint).split('_')[0]

        # Training-set metrics for this checkpoint.
        predictions_train_dict = model.predict(train_records)
        train_df = eval_prediction_dict(predictions_train_dict, model_name)
        avg_train_df = average_selected_metrics(train_df, select_relevant_keys(train_df))
        for key, value in avg_train_df.items():
            checkpoint_train_dict.setdefault(key, []).append(value)
        checkpoint_train_dict['model'].append(model_name)

        # Validation-set metrics for this checkpoint.
        predictions_val_dict = model.predict(val_records)
        val_df = eval_prediction_dict(predictions_val_dict, model_name)
        avg_val_df = average_selected_metrics(val_df, select_relevant_keys(val_df))
        for key, value in avg_val_df.items():
            checkpoint_val_dict.setdefault(key, []).append(value)
        checkpoint_val_dict['model'].append(model_name)

    # Plot Results Across Checkpoints.  Hard-coding the keys for now, future can make this automated.
    # The epoch is recovered from the trailing "_<epoch>" piece of each model name.
    epochs = [int(name.split('_')[-1]) for name in checkpoint_train_dict['model']]
    min_epoch, max_epoch = np.amin(epochs), np.amax(epochs)
    epoch_delta = 10
    # Builtin int instead of np.int: the alias was removed from NumPy.
    epoch_ticks = np.arange(min_epoch, max_epoch + epoch_delta, epoch_delta).astype(int)
    # Left subplot shows training metrics, right subplot shows validation metrics.
    print('Train\tVal')

    key_prefixes = ['traj_LL', 'class_top', 'min_ade', 'min_fde']
    for key_prefix in key_prefixes:
        fig, ((ax1), (ax2)) = plt.subplots(1, 2, sharex=True, sharey=True)
        for k in [1, 3, 5]:
            ax1.plot(epochs, checkpoint_train_dict['%s_%d' % (key_prefix, k)], label='%d' % k)
            ax2.plot(epochs, checkpoint_val_dict['%s_%d' % (key_prefix, k)], label='%d' % k)
        ax1.set_xticks(epoch_ticks)
        ax2.set_xticks(epoch_ticks)
        ax1.set_xlabel('Epoch')
        ax2.set_xlabel('Epoch')
        ax1.set_ylabel(key_prefix)
        ax1.grid()
        ax2.grid()
        ax2.legend()
        fig.tight_layout()

    plt.show()

## Layer Visualization

In [None]:
#[print(layer.name) for layer in model.model.layers] # To locate which layer to visualize.

if run_layer_visualization:

    # Name of the layer whose output activations are visualized.
    target_layer = 'batch_normalization'

    entry_to_viz = 15 # which dataset example to view
    dataset = tf.data.TFRecordDataset(val_records)
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(1)

    # Pull out the requested validation entry (raw image + model-ready input).
    img_orig, img_preprocessed = None, None
    for ind_entry, entry in enumerate(dataset):
        if ind_entry == entry_to_viz:
            img_orig = entry['image']
            img_preprocessed, _, _ = model.preprocess_entry(entry)
            break  # no need to read the rest of the dataset
    if img_orig is None:
        raise ValueError("Dataset has no entry with index {}.".format(entry_to_viz))

    def plot_top_activations(activations, k=8):
        """Show the k channels with the largest summed absolute activation.

        ``activations`` is indexed as [0, :, :, channel]; the sum is taken
        over the first three axes to rank the channels.
        """
        summed_activations = tf.reduce_sum(tf.abs(activations), axis=[0, 1, 2])
        top_k = tf.math.top_k(summed_activations, k=k).indices.numpy()

        # Matplotlib requires integer grid dimensions; np.ceil returns a float.
        n_cols = int(np.ceil(k / 2))
        for i, act_ind in enumerate(top_k):
            plt.subplot(2, n_cols, i + 1)
            plt.imshow(activations[0, :, :, act_ind], cmap='plasma')

    # See how the layer evolves over time with more training.
    img_orig_ds = cv2.resize(img_orig[0].numpy(), (32,32), interpolation=cv2.INTER_AREA)
    for checkpoint in checkpoints:
        model.load_weights(checkpoint)
        # Sub-model mapping the image input to the target layer's output.
        viz_model = tf.keras.Model(model.model.get_layer(name='image_input').output,
                                   model.model.get_layer(name=target_layer).output)
        out = viz_model.predict_on_batch(img_preprocessed)
        plt.figure()
        plt.imshow(img_orig_ds)

        plt.figure()
        plot_top_activations(out)
