In [None]:
import os
import sys
import time

import numpy as np
import orbax.checkpoint as ocp
import pandas as pd
import seaborn as sns
from flax import nnx

import h5py

sys.path.insert(0, os.path.abspath("../.."))

import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.collections import LineCollection
from transformers.evaluation.eval_episodes import bb_record_only_goal
from transformers.models.pref_transformer import load_PT
from transformers.replayer.replayer import animate_segment

In [None]:
def rclr(d):
    """Robust centred log-ratio (rclr) transform.

    Log-transforms the entries of ``d`` and centres them on the mean log of
    the *observed* entries only. Entries whose log is not finite (zeros,
    negatives, NaN, inf) are treated as missing. They are excluded from the
    centring mean and set to 0.0 in the output.

    Parameters
    ----------
    d : array_like
        Composition to transform (any shape; centred over all entries).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``d``. All zeros if no entry is observed.
    """
    d = np.asarray(d)
    # log(0) -> -inf is expected here; those entries are masked out below.
    with np.errstate(divide="ignore", invalid="ignore"):
        log_d = np.log(d)
    observed = np.isfinite(log_d)
    if not observed.any():
        return np.zeros_like(log_d)
    # Centre on the observed entries only. Including the masked entries would
    # bias the mean and shift the missing values away from 0.
    centre = log_d[observed].mean()
    return np.where(observed, log_d - centre, 0.0)

In [None]:
seed = 4
rng = np.random.default_rng(seed)
e_sts, e_acts, e_ts = bb_record_only_goal(0.44, obs_dist=200, seed=seed)

# Length of one model input segment; later cells reshape model outputs to
# this size. NOTE(review): presumably the reward models' context length — confirm.
seg_len = 100
ep_len = e_sts.shape[1]
assert ep_len > 0, "bb_record_only_goal returned an empty episode"

# Round the episode length up to a whole number of segments (ceil division).
n_splits = -(-ep_len // seg_len)
fill_size = n_splits * seg_len

# Pad unconditionally (the pad width may be 0) so sts / acts / ts / am are
# always defined, including when ep_len is already a multiple of seg_len.
sts = np.pad(
    e_sts,
    ((0, 0), (0, fill_size - e_sts.shape[1]), (0, 0)),
    constant_values=0,
)
acts = np.pad(
    e_acts,
    ((0, 0), (0, fill_size - e_acts.shape[1]), (0, 0)),
    constant_values=0,
)
ts = np.arange(fill_size)
# Attention mask: 1 for real timesteps, 0 for padding at the end.
am = np.zeros(fill_size)
am[:ep_len] = 1

sts = sts.reshape((n_splits, seg_len, sts.shape[2]))
acts = acts.reshape((n_splits, seg_len, acts.shape[2]))
ts = ts.reshape((n_splits, seg_len))
am = am.reshape((n_splits, seg_len))

am_sum = int(am.sum())
# Static features read from the first step of the last segment (state columns 22-25).
lvl = sts[-1, 0, 22]
ai = sts[-1, 0, 23]
attempt = sts[-1, 0, 24]
day = sts[-1, 0, 25]
print(f"Static Features: Level = {lvl}, AI = {ai}, Attempt = {attempt}, Day = {day}")

In [None]:
# Ids of the reward models to compare (used as checkpoint directory names).
p_ids = ["t0025", "t0072", "t0033", "t0064", "t0048", "t0009"]

# Candidate (speed, angle) actions used for the expected-reward baseline.
# Speeds are drawn before angles, so the rng stream is consumed in the same order.
n_samps = 1000
speeds = rng.uniform(0.0, 0.44, n_samps)
angles = rng.uniform(-180.0, 180.0, n_samps)
sample_acts = np.column_stack((speeds, angles))  # shape (n_samps, 2)

# Discount factor applied to the next-step baseline in the canonicalised rewards.
discount = 0.99

In [None]:
time_range = np.arange(am_sum)
min_reward = np.inf
max_reward = -np.inf

n_rewards = []
n_weights = []
n_clr_weights = []
n_c_rewards = []

min_clr_w = np.inf
max_clr_w = -np.inf

# The reshapes below (seg_len rewards, a 1 x seg_len x seg_len attention map)
# and the prefix slicing sts[:, : i + 1] all assume a single segment.
assert sts.shape[0] == 1, f"expected one segment, got {sts.shape[0]}"
seg_len = sts.shape[1]

checkpointer = ocp.Checkpointer(ocp.CompositeCheckpointHandler())
try:
    for p in p_ids:
        reward_function = os.path.expanduser(
            f"~/busy-beeway/transformers/pt_rewards_bb/{p}/best_model.ckpt"
        )

        r_model = load_PT(reward_function, checkpointer, on_cpu=True)
        # Argument index 4 is the `training` flag passed below; treat it as static.
        r_model = nnx.jit(r_model, static_argnums=4)
        rewards, weights = r_model(sts, acts, ts, am, training=False)

        # Keep only the real (unpadded) timesteps.
        rewards = np.asarray(rewards["value"]).reshape(seg_len)[:am_sum]
        # Last entry of the returned weights, averaged over axis 1 to give one
        # importance value per timestep.
        # NOTE(review): this average also includes rows for padded timesteps;
        # confirm whether it should be restricted to [:am_sum] first.
        weights = (
            np.asarray(weights[-1])
            .reshape(1, seg_len, seg_len)
            .mean(axis=1)
            .reshape(seg_len)[:am_sum]
        )

        rclr_w = rclr(weights)

        min_reward = min(min_reward, rewards.min())
        max_reward = max(max_reward, rewards.max())
        min_clr_w = min(min_clr_w, rclr_w.min())
        max_clr_w = max(max_clr_w, rclr_w.max())

        n_rewards.append(np.column_stack([time_range, rewards]))
        n_weights.append(np.column_stack([time_range, weights]))
        n_clr_weights.append(np.column_stack([time_range, rclr_w]))

        # r_mean[i]: mean predicted reward at step i over the sampled candidate
        # actions, keeping the real history before step i. Only real timesteps
        # are evaluated. `rewards` has length am_sum, so looping over the padded
        # segment length would index past its end (and waste forward passes).
        r_mean = np.zeros(am_sum)
        for i in range(am_sum):
            prefix_sts = sts[:, : (i + 1), :].reshape(1, -1, sts.shape[2])
            prefix_acts = acts[:, : (i + 1), :].reshape(1, -1, acts.shape[2])
            sts_samps = np.repeat(prefix_sts, n_samps, axis=0)
            acts_samps = np.repeat(prefix_acts, n_samps, axis=0)
            # Replace the action at the final step with each sampled candidate.
            acts_samps[:, -1, :] = sample_acts
            ts_samps = np.repeat(ts[:, : (i + 1)].reshape(1, -1), n_samps, axis=0)
            am_samps = np.repeat(am[:, : (i + 1)].reshape(1, -1), n_samps, axis=0)
            rwd_samps, _ = r_model(
                sts_samps, acts_samps, ts_samps, am_samps, training=False
            )
            r_mean[i] = np.asarray(rwd_samps["value"]).reshape(n_samps, -1)[:, -1].mean()

        # Canonicalised reward r_t + discount * r_mean[t + 1] - r_mean[t] for
        # every real timestep that has a successor.
        n_c_rewards.append(rewards[:-1] + discount * r_mean[1:] - r_mean[:-1])
finally:
    # Release checkpoint resources even if loading or inference fails.
    checkpointer.close()

In [None]:
def _plot_model_lines(ax, series, labels, line_colors, xmax, ylim, title, ylabel):
    """Plot one (timestep, value) line per model on ``ax`` and add a legend.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    series : list of ndarray
        One array per model; column 0 is the timestep, column 1 the value.
    labels : list of str
        Legend entries, in the same order as ``series``.
    line_colors : array_like
        One colour per series.
    xmax : float
        Upper x limit (the lower limit is 0).
    ylim : tuple of float
        (low, high) y limits.
    title, ylabel : str
        Axes title and y-axis label.
    """
    ax.set_xlim(0, xmax)
    ax.set_ylim(*ylim)
    ax.add_collection(LineCollection(series, colors=line_colors))
    # Empty proxy lines so the legend shows one labelled entry per colour.
    for color, label in zip(line_colors, labels):
        ax.plot([], [], color=color, label=label)
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Timestep")
    ax.set_ylabel(ylabel)


# Evenly spaced colormap samples give every model a distinct colour (random
# draws can land arbitrarily close together). The result is named
# `line_colors` so it does not shadow the imported `matplotlib.colors` module.
cmap = plt.get_cmap("rainbow")
line_colors = cmap(np.linspace(0.0, 1.0, len(p_ids)))

fig, axe = plt.subplots(3, figsize=(15, 10))
plt.subplots_adjust(hspace=0.5, wspace=0.3)
_plot_model_lines(
    axe[0], n_rewards, p_ids, line_colors, am_sum,
    (min_reward, max_reward), "Rewards", "Reward",
)
_plot_model_lines(
    axe[1], n_weights, p_ids, line_colors, am_sum,
    (0, 1.0), "Importance Weights", "Weight",
)
_plot_model_lines(
    axe[2], n_clr_weights, p_ids, line_colors, am_sum,
    (min_clr_w, max_clr_w), "Importance Weights (clr)", "Weight (clr)",
)
plt.show()

In [None]:
def _pearson_distance(corr):
    """Map Pearson correlation(s) to the distance sqrt((1 - rho) / 2).

    ``1 - rho`` is clipped at 0 because rounding can make rho slightly larger
    than 1 (e.g. a series correlated with itself). Without the clip, the
    square root would return NaN.
    """
    return np.sqrt(np.clip(1.0 - np.asarray(corr, dtype=float), 0.0, None) / 2.0)


def _row_corr(rows):
    """Pearson correlation matrix between the rows of ``rows`` (always 2-D).

    Unlike pandas ``Series.corr``, NaNs are not skipped; the inputs here are
    model outputs and are assumed finite.
    """
    return np.atleast_2d(np.corrcoef(rows))


# One row per model, one column per timestep.
reward_rows = np.stack([r[:, 1] for r in n_rewards])
weighted_rows = np.stack([w[:, 1] * r[:, 1] for w, r in zip(n_weights, n_rewards)])
clr_rows = np.stack([c[:, 1] for c in n_clr_weights])
canon_rows = np.stack(n_c_rewards)

# Robust Aitchison distance: Euclidean distance between rclr-transformed weights.
rad = np.sqrt(((clr_rows[:, None, :] - clr_rows[None, :, :]) ** 2).sum(axis=-1))

corr = _row_corr(reward_rows)
pear_dist = _pearson_distance(corr)

# Rewards multiplied by their per-timestep importance weights.
w_corr = _row_corr(weighted_rows)
w_pear_dist = _pearson_distance(w_corr)

# EPIC: Pearson distance between the canonicalised rewards.
epic = _pearson_distance(_row_corr(canon_rows))

In [None]:
# Pairwise model-comparison matrices, one heatmap per mosaic panel.
# Distance matrices use the reversed colormap (viridis_r); correlation
# matrices use the forward one (viridis).
mosaic_layout = """
    ABC
    DEF
    """
panels = (
    ("A", rad, "viridis_r", "Robust Aitchison Distance (Weights)"),
    ("B", corr, "viridis", "Correlation (Rewards)"),
    ("C", pear_dist, "viridis_r", "Pearson Distance (Rewards)"),
    ("D", w_corr, "viridis", "Weighted Correlation"),
    ("E", w_pear_dist, "viridis_r", "Weighted Pearson Distance"),
    ("F", epic, "viridis_r", "EPIC (Rewards)"),
)

fig, axe = plt.subplot_mosaic(mosaic_layout, figsize=(15, 10))
plt.subplots_adjust(hspace=0.3, wspace=0.5)
for panel_key, matrix, cmap_name, panel_title in panels:
    panel_ax = axe[panel_key]
    sns.heatmap(
        matrix, cmap=cmap_name, ax=panel_ax, xticklabels=p_ids, yticklabels=p_ids
    )
    panel_ax.set_title(panel_title)
plt.savefig("dist_goal_only_matrices.png")
plt.show()

In [None]:
# Pause for a fixed time after everything above has run.
# NOTE(review): the reason is not visible in this notebook (presumably to keep
# the kernel/job alive so outputs can be collected) — confirm.
# Set BB_FINAL_SLEEP_S to override the 300 s default; 0 disables the pause.
final_sleep_s = float(os.environ.get("BB_FINAL_SLEEP_S", "300"))
if final_sleep_s > 0:
    time.sleep(final_sleep_s)