In [None]:
# SET CONFIGURATION ARGUMENTS

# Run configuration shared by every cell below. Built with keyword arguments;
# the key order is the same as in the original dict literal.
args = dict(
    nt=1,
    nb_epoch=250,
    batch_size=1,
    output_channels=[3, 48, 96, 192],
    num_P_CNN=1,
    num_R_CLSTM=1,
    num_passes=1,
    pan_hierarchical=False,
    downscale_factor=4,
    resize_images=False,
    train_proportion=0.7,
    results_subdir="dummy",  # placeholder, overwritten right after the dict is built
    dataset_weights="various",
    data_subset_weights="gen_ellipseV_crossH",
    dataset="general_shape_static",
    data_subset="general_ellipse_static_2nd_stage",
    data_subset_mode="test",
    model_choice="baseline",
    system="laptop",
    reserialize_dataset=False,
    output_mode="Error",
)
# The results sub-directory is derived from the dataset / subset chosen above.
args["results_subdir"] = "interp_results/{}/{}".format(args["dataset"], args["data_subset"])

In [None]:
# LOAD MODEL

import argparse
from config import update_settings, get_settings
from data_utils import serialize_dataset
import numpy as np
import os
from datetime import datetime


# Register the chosen system / weights / results locations with the project
# config, then read back the directories it resolved.
update_settings(args["system"], args["dataset_weights"], args["data_subset_weights"], args["results_subdir"])
DATA_DIR, WEIGHTS_DIR, RESULTS_SAVE_DIR, LOG_DIR = get_settings()["dirs"]
data_dirs = [DATA_DIR, WEIGHTS_DIR, RESULTS_SAVE_DIR, LOG_DIR]
# exist_ok=True creates the results directory only when missing, without the
# check-then-create gap of a separate os.path.exists() test.
os.makedirs(RESULTS_SAVE_DIR, exist_ok=True)

import os
import warnings
import hickle as hkl

# Silence every Python warning raised from this point on.
warnings.filterwarnings(action="ignore")
# Set TF_CPP_MIN_LOG_LEVEL before TensorFlow is imported below.
# "3" is used here; '2' is the alternative noted for filtering out INFO messages too.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")

import tensorflow as tf
import shutil
import keras
from keras import backend as K
from keras import layers
from data_utils import SequenceGenerator, IntermediateEvaluations, create_dataset_from_serialized_generator, config_gpus 
%matplotlib tk
import matplotlib.pyplot as plt
# import addcopyfighandler
from matplotlib.colors import LinearSegmentedColormap
import random

# PICK MODEL
# Bind the name ParaPredNet to the implementation selected by args["model_choice"].
if args["model_choice"] == "baseline":
    # Predict next frame along RGB channels only.
    # The same class is used with or without the pan-hierarchical option;
    # the option only adds a console notice here.
    from PPN_models.PPN_Baseline import ParaPredNet
    if args['pan_hierarchical']:
        print("Using Pan-Hierarchical Representation")
elif args["model_choice"] == "cl_delta":
    # Predict next frame and change from current frame
    from PPN_models.PPN_CompLearning_Delta_Predictions import ParaPredNet
elif args["model_choice"] == "cl_recon":
    # Predict current and next frame
    from PPN_models.PPN_CompLearning_Recon_Predictions import ParaPredNet
elif args["model_choice"] == "multi_channel":
    # Predict next frame along Disparity, Material Index, Object Index,
    # Optical Flow, Motion Boundaries, and RGB channels all stacked together
    assert args["dataset"] in ("monkaa", "driving"), "Multi-channel model only works with Monkaa or Driving dataset"
    from PPN_models.PPN_Multi_Channel import ParaPredNet
    # 1 Disparity, 3 Optical Flow, 3 RGB
    bottom_layer_output_channels = 7
    # The bottom layer's channel count is overridden in the shared config.
    args["output_channels"][0] = bottom_layer_output_channels
else:
    raise ValueError("Invalid model choice")

# Resolve the weights file that is loaded prior to evaluation.
if (args["dataset_weights"], args["data_subset_weights"]) in [
    ("rolling_square", "single_rolling_square"),
    ("rolling_circle", "single_rolling_circle"),
]:
    # NOTE(review): this branch names the file after args["data_subset"], not
    # args["data_subset_weights"] like the others - confirm this is intended.
    weights_file = os.path.join(WEIGHTS_DIR, "para_prednet_" + args["data_subset"] + "_weights.hdf5")
elif args["dataset_weights"] in ["all_rolling", "ball_collisions", "various"]:
    # File name carries both the weights dataset and its subset. This also
    # covers ("all_rolling", "single") and ("all_rolling", "multi"), which
    # previously had a separate branch producing the very same path.
    weights_file = os.path.join(WEIGHTS_DIR, "para_prednet_" + args["dataset_weights"] + "_" + args["data_subset_weights"] + "_weights.hdf5")
else:
    # Everything else: file name carries the weights dataset only.
    weights_file = os.path.join(WEIGHTS_DIR, "para_prednet_" + args["dataset_weights"] + "_weights.hdf5")
assert os.path.exists(weights_file), "Weights file not found"

# Report whether evaluation data and weights come from the same dataset.
if args['dataset'] != args['dataset_weights']:
    print(f"WARNING: dataset ({args['dataset']}) and dataset_weights ({args['dataset_weights']}/{args['data_subset_weights']}) do not match - generalizing...")
else:
    # Fixed: the second field previously repeated dataset_weights instead of
    # showing data_subset_weights.
    print(f"OK: dataset ({args['dataset']}) and dataset_weights ({args['dataset_weights']}/{args['data_subset_weights']}) match")

# Training parameters pulled out of the config for brevity
nt = args["nt"]  # number of time steps
batch_size = args["batch_size"]  # 4
output_channels = args["output_channels"]

# Define image shape: original_im_shape is the per-dataset frame size,
# im_shape is the size actually fed to the network.
if args["dataset"] == "kitti":
    # kitti frames are used as they are; no downscale factor is read
    im_shape = original_im_shape = (128, 160, 3)
else:
    downscale_factor = args["downscale_factor"]
    if args["dataset"] in ("monkaa", "driving"):
        # these two are always downscaled, independent of args["resize_images"]
        original_im_shape = (540, 960, 3)
        im_shape = (original_im_shape[0] // downscale_factor, original_im_shape[1] // downscale_factor, 3)
    else:
        if args["dataset"] in ["rolling_square", "rolling_circle"]:
            original_im_shape = (50, 100, 3)
        else:
            original_im_shape = (50, 50, 3)
        # remaining datasets are downscaled only on request
        if args["resize_images"]:
            im_shape = (original_im_shape[0] // downscale_factor, original_im_shape[1] // downscale_factor, 3)
        else:
            im_shape = original_im_shape

print(f"Working on dataset: {args['dataset']}")

# Create ParaPredNet
# Channel count of each input stream for the datasets that feed the network a
# tuple of inputs; every other dataset uses a single 3-channel input.
multi_stream_channels = {
    "monkaa": (1, 1, 1, 3, 1, 3),
    "driving": (1, 3, 3),
}
if args["dataset"] in multi_stream_channels:
    inputs = tuple(
        keras.Input(shape=(nt, im_shape[0], im_shape[1], n_channels))
        for n_channels in multi_stream_channels[args["dataset"]]
    )
else:
    inputs = keras.Input(shape=(nt, im_shape[0], im_shape[1], 3))

PPN_layer = ParaPredNet(args, im_height=im_shape[0], im_width=im_shape[1])  # [3, 48, 96, 192]
if args["dataset"] not in ("kitti", "monkaa", "driving"):
    # As before, only the remaining (non kitti/monkaa/driving) datasets switch
    # the layer to "Prediction" output with continuous evaluation.
    PPN_layer.output_mode = "Prediction"
    PPN_layer.continuous_eval = True
outputs = PPN_layer(inputs)
# Wrap the layer in a functional model; later cells reach the layer again
# through PPN.layers[-1].
PPN = keras.Model(inputs=inputs, outputs=outputs)

resos = PPN.layers[-1].resolutions
PPN.compile(optimizer="adam", loss="mean_squared_error")
print("ParaPredNet compiled...")
PPN.build(input_shape=(None, nt) + im_shape)
# summary() prints the table itself and returns None, so it is not wrapped in
# print() (that emitted a stray "None" line).
PPN.summary()
num_layers = len(output_channels)  # number of layers in the architecture
print(f"{num_layers} PredNet layers with resolutions:")
for i in reversed(range(num_layers)):
    print(f"Layer {i+1}:  {resos[i][0]} x {resos[i][1]} x {output_channels[i]}")

# load previously saved weights
try:
    PPN.load_weights(weights_file)
except Exception as err:
    # Catch Exception rather than everything so Ctrl-C / SystemExit are not
    # turned into a ValueError, and chain the cause so the underlying loader
    # error stays visible in the traceback.
    raise ValueError("Weights don't fit - exiting...") from err
else:
    print("Weights loaded successfully...")

# Load dataset - only working for animations
# Path of the serialized split; the same file is read again after a rebuild.
test_data_file = DATA_DIR + f"{args['data_subset']}_{args['data_subset_mode']}.hkl"
try:
    test_data = hkl.load(test_data_file)[0]
except Exception:
    # Any load failure triggers re-serialization from the PNG frames.
    # Exception (not a bare except) keeps Ctrl-C from starting a rebuild.
    png_paths = [DATA_DIR + f"{args['dataset']}/frames/{args['data_subset']}/"]
    # png_paths = [DATA_DIR + f"{args['dataset']}/frames/{args['data_subset']}_{args['data_subset_mode']}/"]
    # NOTE(review): assumes serialize_dataset(..., test_data=True) writes the
    # file named by test_data_file - confirm against data_utils.
    serialize_dataset(data_dirs, pfm_paths=[], pgm_paths=[], png_paths=png_paths, dataset_name=args['data_subset'], test_data=True)
    print("Dataset serialized...")
    test_data = hkl.load(test_data_file)[0]
print("Test data ready...")

# Number of entries along the first axis of the test data, and a handle on the
# model's last layer for direct state access in the cells below.
td_len = test_data.shape[0]
ppn = PPN.layers[-1]


In [None]:
# DATA COLLECTION CODE - AGGREGATION OR STATE DATA

collection_mode = "agg" # "agg" or "state"
# NOTE(review): collection_mode only affects the output file name. The code
# below always computes and saves the aggregation matrices unless the
# commented-out collectors are re-enabled and the dump call is changed.


def _top_fraction_channels(max_pooled, fraction):
    """Return the channel indices that carry the top share of activation.

    ``max_pooled`` is normalised by its sum and its channels are visited from
    largest to smallest share. Channels are kept while the running share stays
    at or below ``fraction``; the single largest channel is always kept, even
    when it alone exceeds ``fraction``.

    No guard is applied for a zero sum (the shares would be NaN).
    """
    shares = max_pooled / np.sum(max_pooled)
    order = np.argsort(shares)[::-1]
    running = 0
    selected = []
    for rank, channel in enumerate(order):
        if running <= fraction:
            running += shares[channel]
        if running <= fraction or rank == 0:
            selected.append(channel)
        else:
            break
    return selected


collection_file = DATA_DIR + f"/{collection_mode}_data_{args['data_subset']}_{args['data_subset_mode']}.hkl"

if not os.path.exists(collection_file):
    print(f"Collecting {collection_mode} data...")

    # Assuming initialization and data loading is done here
    start = 0
    stop = td_len - 1
    num_samples = td_len - 1
    sample_shape = (1, 1, *test_data.shape[1:])

    # Initialize lists to store intermediate data for each layer
    # (only filled by the commented-out collectors below)
    R_states_list = [[] for _ in ppn.predlayers]
    P_states_list = [[] for _ in ppn.predlayers]

    # initialize lists to hold global max pooled states
    # (only filled by the commented-out collectors below)
    R_state_maxes = [[] for _ in ppn.predlayers]
    P_state_maxes = [[] for _ in ppn.predlayers]

    # Per-layer counters: entry c counts the samples in which channel c was
    # among the selected top-share channels.
    R_agg_mtx = [np.zeros((ppn.predlayers[j].output_channels)) for j in range(len(ppn.predlayers))]
    P_agg_mtx = [np.zeros((ppn.predlayers[j].output_channels)) for j in range(len(ppn.predlayers))]

    # Share of the normalised max-pooled activation to cover per sample.
    k = 0.1

    # Collect all states
    indices = np.random.permutation(range(start, stop))
    for i in indices[:num_samples]:
        ppn.init_layer_states()
        ground_truth_image = np.reshape(test_data[i], sample_shape)
        # The same frame is deliberately passed through the network twice
        # before the layer states are read.
        predicted_image = ppn(ground_truth_image)
        predicted_image = ppn(ground_truth_image)
        # ground_truth_image = np.reshape(test_data[i+1], sample_shape)
        # predicted_image = ppn(ground_truth_image)

        # # UNCOMMENT TO COLLECT ALL STATE DATA
        # for j in range(len(ppn.predlayers)):
        #     R_states_list[j].append(ppn.predlayers[j].states["R"][0])
        #     P_states_list[j].append(ppn.predlayers[j].states["P"][0])

        # # UNCOMMENT TO COLLECT GLOBAL MAX POOLED STATES
        # for j in range(len(ppn.predlayers)):
        #     R_state_maxes[j].append(np.max(ppn.predlayers[j].states["R"][0], axis=(0,1)))
        #     P_state_maxes[j].append(np.max(ppn.predlayers[j].states["P"][0], axis=(0,1)))

        # COLLECT AGGREGATION MATRICES
        # Aggregate top k% of channels by weight
        for j in range(len(ppn.predlayers)):
            R_max_pooled = np.max(ppn.predlayers[j].states["R"][0], axis=(0,1))
            P_max_pooled = np.max(ppn.predlayers[j].states["P"][0], axis=(0,1))
            R_indices = _top_fraction_channels(R_max_pooled, k)
            # Fixed: the P selection previously appended indices taken from
            # the R ranking, so P counts landed on the wrong channels.
            P_indices = _top_fraction_channels(P_max_pooled, k)
            R_agg_mtx[j][R_indices] += 1
            P_agg_mtx[j][P_indices] += 1

    print("Done...")
    # Save intermediate state data
    print(f"Saving {collection_mode} data...")
    hkl.dump([R_agg_mtx, P_agg_mtx], collection_file)
    print("Done...")

else:
    print(f"{collection_mode} data for '{args['data_subset']}_{args['data_subset_mode']}' already exists...")



In [None]:
# LOAD AND PLOT AGGREGATION MATRICES AND DIFFERENCE MATRICES

agg_datasets = [
    DATA_DIR + "/agg_data_general_cross_static_2nd_stage_test.hkl",
    DATA_DIR + "/agg_data_general_ellipse_static_2nd_stage_test.hkl",
]
# The difference step pairs dataset i with dataset 1-i, which is only
# meaningful for exactly two datasets.
assert len(agg_datasets) == 2, "Difference matrices require exactly two aggregation datasets"


def _plot_layer_bars(R_per_dataset, P_per_dataset, R_label, P_label):
    """Overlay per-layer bar charts of every dataset; returns (fig, axs).

    One subplot row per entry of ``ppn.predlayers``; column 0 shows the R
    vectors and column 1 the P vectors. ``squeeze=False`` keeps ``axs``
    two-dimensional even for a single layer.
    """
    n_layers = len(ppn.predlayers)
    fig, axs = plt.subplots(n_layers, 2, figsize=(10, 10), squeeze=False)
    colors = ['blue', 'orange']
    for i in range(len(R_per_dataset)):
        for j in range(n_layers):
            channel_idx = range(len(R_per_dataset[i][j]))
            axs[j, 0].bar(channel_idx, R_per_dataset[i][j], alpha=0.7, color=colors[i])
            axs[j, 1].bar(channel_idx, P_per_dataset[i][j], alpha=0.7, color=colors[i])
            axs[j, 0].set_title(f"{R_label}, Layer {j+1}")
            axs[j, 1].set_title(f"{P_label}, Layer {j+1}")
    # plt.legend()
    plt.show()
    return fig, axs


# Load the [R, P] aggregation matrices saved for each dataset
R_agg_mtxs = []
P_agg_mtxs = []
for agg_file in agg_datasets:
    R_loaded, P_loaded = hkl.load(agg_file)
    R_agg_mtxs.append(R_loaded)
    P_agg_mtxs.append(P_loaded)

# Plot the aggregate matrices
fig, axs = _plot_layer_bars(R_agg_mtxs, P_agg_mtxs, "R Agg", "P Agg")

# Find the difference matrices from the two datasets: for each dataset, keep
# only the positive part of (its counts - the other dataset's counts).
R_diffs = [[] for _ in range(len(agg_datasets))]
P_diffs = [[] for _ in range(len(agg_datasets))]
for i in range(len(agg_datasets)):
    for j in range(len(ppn.predlayers)):
        R_diff = R_agg_mtxs[i][j] - R_agg_mtxs[1-i][j]
        R_diff[R_diff < 0] = 0
        P_diff = P_agg_mtxs[i][j] - P_agg_mtxs[1-i][j]
        P_diff[P_diff < 0] = 0
        R_diffs[i].append(R_diff)
        P_diffs[i].append(P_diff)

# Plot the difference matrices
fig, axs = _plot_layer_bars(R_diffs, P_diffs, "R Aggregate Differences", "P Aggregate Differences")

In [None]:
# LOAD STATE MAXES

state_datasets = [
    DATA_DIR + "/state_data_general_cross_static_2nd_stage_test.hkl",
    DATA_DIR + "/state_data_general_ellipse_static_2nd_stage_test.hkl",
]

# One entry per dataset; each entry holds one array per prediction layer.
R_state_maxes = []
P_state_maxes = []
# Only filled by the commented-out "state tensors" variant below.
R_max_indices = [[] for _ in state_datasets]
P_max_indices = [[] for _ in state_datasets]

# if saved datasets are state maxes
for state_file in state_datasets:
    R_layers, P_layers = hkl.load(state_file)
    # convert each layer's collected maxes to an ndarray in place
    for j in range(len(ppn.predlayers)):
        R_layers[j] = np.array(R_layers[j])
        P_layers[j] = np.array(P_layers[j])
    R_state_maxes.append(R_layers)
    P_state_maxes.append(P_layers)

# # if saved datasets are state tensors
# for i in range(len(state_datasets)):
#     R_states_list, P_states_list = hkl.load(state_datasets[i])
#     for j in range(len(ppn.predlayers)):
#         R_states = np.array(R_states_list[j])
#         P_states = np.array(P_states_list[j])
#
#         # Max pooling across all samples for each layer
#         R_max_pooled = np.max(R_states, axis=(1, 2))  # Max pooling across spatial dimensions
#         P_max_pooled = np.max(P_states, axis=(1, 2))
#
#         R_state_maxes[i].append(R_max_pooled)
#         P_state_maxes[i].append(P_max_pooled)
#
#         # Find top-k indices
#         num_filters = R_states.shape[-1]
#         top_k = num_filters // 3
#
#         R_max_indices[i].append(np.argsort(R_max_pooled, axis=1)[:,-top_k:])
#         P_max_indices[i].append(np.argsort(P_max_pooled, axis=1)[:,-top_k:])

print("Done...")

In [None]:
# PLOT ALL LAYER STATE MAXES


def _top_stable_channels(state_maxes, std_max, top_k):
    """Pick the highest-mean channels among those with low spread.

    ``state_maxes`` is reduced along axis 0 to a per-channel mean and standard
    deviation. Channels whose std is not strictly below ``std_max`` are
    dropped, and of the remainder the ``top_k`` with the largest mean are kept
    (in ascending order of mean).

    Returns ``(channel_indices, means, stds)`` for the kept channels, where
    ``channel_indices`` refer to positions in the unfiltered channel axis.
    """
    mean_all = np.mean(state_maxes, axis=0)
    std_all = np.std(state_maxes, axis=0)
    stable = std_all < std_max
    stable_mean = mean_all[stable]
    # computed once and reused for indices, means and stds
    order = np.argsort(stable_mean)[-top_k:]
    return np.arange(len(mean_all))[stable][order], stable_mean[order], std_all[stable][order]


# Calculate mean, std of max activations per layer, plot error bars
# discard activations with std >= std_max
# then plot top_k activations with x-axis as filter index
std_max = 0.1
top_k = 10
# One column per prediction layer; squeeze=False keeps axs two-dimensional
# even when there is a single layer.
n_pred_layers = len(ppn.predlayers)
fig, axs = plt.subplots(2, n_pred_layers, figsize=(50, 50), squeeze=False)
labels = ["Cross-Right", "Ellipse-Down"]
for k in range(len(state_datasets)):
    for i in range(n_pred_layers):
        R_indices, R_mean, R_std = _top_stable_channels(R_state_maxes[k][i], std_max, top_k)
        axs[0,i].errorbar(R_indices, R_mean, yerr=R_std, fmt='o', label=labels[k])
        P_indices, P_mean, P_std = _top_stable_channels(P_state_maxes[k][i], std_max, top_k)
        axs[1,i].errorbar(P_indices, P_mean, yerr=P_std, fmt='o', label=labels[k])
        # axs[0,i].set_ylim(-0.1, 1.1)
        # axs[1,i].set_ylim(-0.1, 1.1)
        axs[0,i].set_title(f"R-State Layer {i+1}")
        axs[1,i].set_title(f"P-State Layer {i+1}")
        axs[0,i].set_xlabel("Channel Index")
        axs[1,i].set_xlabel("Channel Index")
        axs[0,i].legend()
        axs[1,i].legend()

fig.suptitle(f"Top {top_k} R- / P-State Max Channel-Values with STD < {std_max}")
# axs[1].set_title(f"Top {top_k} P State Max Values with STD < {std_max}")
plt.show()

In [None]:
# PLOT SINGLE LAYER STATE MAXES

# For one layer: per-channel mean and std of the max activations (reduced over
# axis 0), keep channels with std below std_max, then plot the top_k by mean
# against their channel index with std error bars.
std_max = 1
top_k = 200
fig, axs = plt.subplots(2, 1, figsize=(50, 50))
labels = ["Cross-Right", "Ellipse-Down"]
i = 3 # layer index (L-1)
for k in range(len(state_datasets)):
    # row 0 shows the R states, row 1 the P states
    for row, state_name, layer_maxes in (
        (0, "R", R_state_maxes[k][i]),
        (1, "P", P_state_maxes[k][i]),
    ):
        channel_mean = np.mean(layer_maxes, axis=0)
        channel_std = np.std(layer_maxes, axis=0)
        low_spread = channel_std < std_max
        # positions (within the filtered channels) of the top_k largest means
        ranked = np.argsort(channel_mean[low_spread])[-top_k:]
        axs[row].errorbar(
            np.arange(len(channel_mean))[low_spread][ranked],
            channel_mean[low_spread][ranked],
            yerr=channel_std[low_spread][ranked],
            fmt='o',
            label=labels[k],
        )
        axs[row].set_ylim(-0.1, 1.1)
        axs[row].set_title(f"{state_name}-State Layer {i+1}")
        axs[row].set_xlabel("Channel Index")
        axs[row].legend()
fig.suptitle(f"Top {top_k} R- / P-State Max Channel-Values with STD < {std_max}")
# axs[1].set_title(f"Top {top_k} P State Max Values with STD < {std_max}")
plt.show()

In [None]:
# OLD STATE MAX EXTRACTION CODE


def _append_row(stacked, row):
    """Return ``stacked`` with the 1-D ``row`` added as a new last row.

    ``stacked`` is either None (nothing collected yet), in which case a new
    2-D array with one row is started, or a 2-D array that is extended along
    axis 0.
    """
    row = np.expand_dims(row, axis=0)
    return row if stacked is None else np.concatenate((stacked, row), axis=0)


# iterate through whole test dataset pulling out filters from each layer
ppn = PPN.layers[-1]
start = 0
stop = td_len-1
num_samples = td_len-1
# per-layer accumulators: one row per processed sample
R_state_maxes = [None]*len(ppn.predlayers)
P_state_maxes = [None]*len(ppn.predlayers)
R_max_indices = [None]*len(ppn.predlayers)
P_max_indices = [None]*len(ppn.predlayers)
# accumulators over all layers concatenated
all_R_states_maxes = None
all_P_states_maxes = None
all_R_max_indices = None
all_P_max_indices = None
for it, i in enumerate(random.sample(range(start, stop), num_samples)):
    print(f"Sample {it+1}/{num_samples}...")
    # if i > 0: break
    # manually initialize PPN layer states
    ppn.init_layer_states()
    # run image through twice to get activation that captures class-recognition
    # not necessary to feed second sequence image in, as the image only contributes to bottom-up error, but nonetheless...
    ground_truth_image = np.reshape(test_data[i], (1, 1, *test_data.shape[1:]))
    predicted_image = ppn(ground_truth_image)
    ground_truth_image = np.reshape(test_data[i+1], (1, 1, *test_data.shape[1:]))
    predicted_image = ppn(ground_truth_image)

    for j in range(len(ppn.predlayers)):
        R_state = ppn.predlayers[j].states["R"][0]
        P_state = ppn.predlayers[j].states["P"][0]
        assert R_state.shape[-1] == P_state.shape[-1]

        # perform global max pooling on each layer's R and P states
        R_state_maxes[j] = _append_row(R_state_maxes[j], np.max(R_state, axis=(0,1)))
        P_state_maxes[j] = _append_row(P_state_maxes[j], np.max(P_state, axis=(0,1)))

        num_filters = R_state.shape[-1]
        top_k = int(num_filters/3)

        # Indices of the top_k largest pooled values of this sample.
        # Fixed: only the first sample took the largest ([-top_k:]); every
        # later sample took the smallest ([:top_k]).
        R_max_indices[j] = _append_row(R_max_indices[j], np.argsort(R_state_maxes[j][-1])[-top_k:])
        P_max_indices[j] = _append_row(P_max_indices[j], np.argsort(P_state_maxes[j][-1])[-top_k:])

    # also find pan-hierarchical distributed representations: this sample's
    # pooled values of all layers concatenated into one vector
    all_R_states_maxes = _append_row(all_R_states_maxes, np.concatenate([layer_maxes[-1] for layer_maxes in R_state_maxes], axis=0))
    all_P_states_maxes = _append_row(all_P_states_maxes, np.concatenate([layer_maxes[-1] for layer_maxes in P_state_maxes], axis=0))
    # Fixed in the same way: always the 10 largest, not the 10 smallest.
    all_R_max_indices = _append_row(all_R_max_indices, np.argsort(all_R_states_maxes[-1])[-10:])
    all_P_max_indices = _append_row(all_P_max_indices, np.argsort(all_P_states_maxes[-1])[-10:])

    # # plot filter weights for each layer
    # num_layers = len(ppn.predlayers)
    # for layer in ppn.predlayers:
    #     for p_c_layer in layer.prediction.conv_layers:
    #         for weights in p_c_layer.trainable_weights:
    #             if len(weights.shape) == 1: continue

