# Checkpoint Evaluation
* The purpose of this notebook is to check prediction performance as a function of training epochs (checkpoints).
* Includes train/val prediction errors and a visualization of intermediate CNN (image-feature) activations.

## Adjustable Parameters

In [None]:
# Directory holding the checkpoint weights (*.h5), relative to the repo root.
checkpoint_dir = 'log/nuscenes_multipath_lstm'
# Model architecture to instantiate: 'multipath' or 'regression'.
model_type = 'multipath'
# Dataset the checkpoints were trained on: 'nuscenes' or 'l5kit'.
dataset_type = 'nuscenes'

# Toggles for the two analyses below.
run_prediction_evaluation = True  # visualize performance vs. training epoch
run_layer_visualization = True    # visualize intermediate CNN activations vs. training epoch

## Code Setup

In [None]:
import os
import sys
import glob
from tqdm import tqdm
from collections import defaultdict
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import cv2

ROOTDIR = os.getcwd().split('scripts')[0]
sys.path.append(os.path.join(ROOTDIR, 'scripts'))

from models.regression import Regression
from models.multipath import MultiPath
from datasets.splits import NUSCENES_TRAIN, NUSCENES_VAL, L5KIT_TRAIN, L5KIT_VAL
from datasets.tfrecord_utils import _parse_function
from evaluation.gmm_prediction import GMMPrediction
from evaluation.pandas_df_utils import eval_prediction_dict, average_selected_keys

In [None]:
# Collect the saved checkpoint weights, sorted by filename.
checkpoint_glob = os.path.join(ROOTDIR, checkpoint_dir, '*.h5')
checkpoints = sorted(glob.glob(checkpoint_glob))
if not checkpoints:
    raise ValueError("No checkpoints found matching {}".format(checkpoint_glob))
print("{} checkpoints found.".format(len(checkpoints)))

# Per-dataset settings: n_t / n_h_t are passed to the model as
# num_timesteps / num_hist_timesteps; the record lists select the data splits.
if dataset_type == 'nuscenes':
    n_t, n_h_t = 12, 2
    train_records, val_records = NUSCENES_TRAIN, NUSCENES_VAL
elif dataset_type == 'l5kit':
    n_t, n_h_t = 25, 5
    train_records, val_records = L5KIT_TRAIN, L5KIT_VAL
else:
    raise ValueError("{} not implemented".format(dataset_type))

# Trajectory anchors (16 clusters) and their weights share a per-dataset filename prefix.
anchors = np.load(os.path.join(ROOTDIR, f'data/{dataset_type}_clusters_16.npy'))
weights = np.load(os.path.join(ROOTDIR, f'data/{dataset_type}_clusters_16_weights.npy'))

if model_type == 'multipath':
    model = MultiPath(num_timesteps=n_t, num_hist_timesteps=n_h_t, anchors=anchors, weights=weights)
elif model_type == 'regression':
    model = Regression(num_timesteps=n_t, num_hist_timesteps=n_h_t)
else:
    # Report the offending model_type (previously this mistakenly printed dataset_type).
    raise ValueError("{} not implemented".format(model_type))

## Prediction Evaluation

In [None]:
if run_prediction_evaluation:
    metric_key_prefixes = ["traj_LL", "class_top", "min_ade", "min_fde", "minmax_dist"]
    ks_eval = [1, 3, 5]
    keys_to_average = [f"{x}_{y}" for x in metric_key_prefixes for y in ks_eval]

    # Per-split accumulators: metric key -> list of values, one entry per checkpoint.
    checkpoint_train_dict = defaultdict(list)
    checkpoint_val_dict = defaultdict(list)

    # Explicit split lookups (replaces eval() on variable names).
    split_records = {'train': train_records, 'val': val_records}
    split_metrics = {'train': checkpoint_train_dict, 'val': checkpoint_val_dict}

    # Aggregate prediction metrics by checkpoint and by dataset split.
    for checkpoint in checkpoints:
        model.load_weights(checkpoint)
        # NOTE(review): assumes the checkpoint filename starts with the epoch number
        # followed by '_'; adjust if the naming scheme changes.
        epoch_tag = os.path.basename(checkpoint).split('_')[0]
        model_name = f"{model.model.name}_{epoch_tag}"

        for split, records in split_records.items():
            predictions_dict = model.predict(records)
            metrics_df = eval_prediction_dict(predictions_dict, anchors, model_name, ks_eval=ks_eval)
            avg_df = average_selected_keys(metrics_df, keys_to_average)

            split_dict = split_metrics[split]
            for key in avg_df.keys():
                split_dict[key].append(avg_df[key])
            split_dict['model'].append(model_name)

    # Recover the epoch from each model name (the suffix after the last '_') and
    # plot in numeric epoch order: filename order need not match epoch order
    # (e.g. 'epoch_10' sorts before 'epoch_5').
    epochs = [int(name.split('_')[-1]) for name in checkpoint_train_dict['model']]
    order = sorted(range(len(epochs)), key=epochs.__getitem__)
    sorted_epochs = [epochs[i] for i in order]

    epoch_delta = 10
    # np.arange over Python ints already yields an integer array (np.int was removed in NumPy 1.24).
    epoch_ticks = np.arange(min(epochs), max(epochs) + epoch_delta, epoch_delta)

    # One figure per metric family: train on the left, val on the right, one line per k.
    for key_prefix in metric_key_prefixes:
        fig, (ax_train, ax_val) = plt.subplots(1, 2, sharex=True, sharey=True)
        for k in ks_eval:
            key = f"{key_prefix}_{k}"
            ax_train.plot(sorted_epochs, [checkpoint_train_dict[key][i] for i in order], label=str(k))
            ax_val.plot(sorted_epochs, [checkpoint_val_dict[key][i] for i in order], label=str(k))
        for ax, title in ((ax_train, 'Train'), (ax_val, 'Val')):
            ax.set_title(title)
            ax.set_xticks(epoch_ticks)
            ax.set_xlabel('Epoch')
            ax.grid()
        ax_train.set_ylabel(key_prefix)
        ax_val.legend()
        fig.tight_layout()

    plt.show()

## Layer Visualization

In [None]:
# To choose `target_layer`, list the model's layer names, e.g.:
#   for layer in model.model.layers: print(layer.name)

if run_layer_visualization:

    target_layer = 'batch_normalization'  # name of the layer whose activations are shown

    entry_to_viz = 15  # which validation example to view
    # Fetch only the requested record instead of parsing the whole split.
    dataset = (tf.data.TFRecordDataset(val_records)
               .skip(entry_to_viz)
               .take(1)
               .map(_parse_function)
               .batch(1))
    entry = next(iter(dataset), None)
    if entry is None:
        raise IndexError(f"entry_to_viz={entry_to_viz} is out of range for the validation records")
    img_orig = entry['image']
    img_preprocessed, _, _ = model.preprocess_entry(entry)

    def plot_top_activations(activations, k=8):
        """Plot the k channels of `activations` with the largest summed |activation|.

        `activations` is indexed as [batch, row, col, channel]; channels are ranked
        by summing |activation| over the first three axes, and only batch element 0
        is displayed, on a 2-row subplot grid of the current figure.
        """
        channel_scores = tf.reduce_sum(tf.abs(activations), axis=[0, 1, 2])
        top_channels = tf.math.top_k(channel_scores, k=k).indices.numpy()
        n_cols = int(np.ceil(k / 2))  # plt.subplot requires integer grid dimensions

        for i, channel in enumerate(top_channels):
            plt.subplot(2, n_cols, i + 1)
            plt.imshow(activations[0, :, :, channel], cmap='plasma')

    # The input image does not change across checkpoints, so show it once.
    img_orig_ds = cv2.resize(img_orig[0].numpy(), (32, 32), interpolation=cv2.INTER_AREA)
    plt.figure()
    plt.imshow(img_orig_ds)
    plt.title(f"Validation entry {entry_to_viz}")

    # See how the layer evolves over time with more training.
    for checkpoint in checkpoints:
        model.load_weights(checkpoint)
        viz_model = tf.keras.Model(model.model.get_layer(name='image_input').output,
                                   model.model.get_layer(name=target_layer).output)
        out = viz_model.predict_on_batch(img_preprocessed)

        plt.figure()
        plot_top_activations(out)
        plt.suptitle(f"{target_layer} ({os.path.basename(checkpoint)})")
