In [None]:
import os
import pickle
from glob import glob
from subprocess import run
from pprint import pprint

from tqdm.notebook import tqdm, trange

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Notebook parameters. Every value can be overridden with a NOTEBOOK_* environment
# variable; when the variable is unset the default given here is used.

# Run directory that is searched for `train_step_*.pkl` checkpoints.
RUN_DIR = os.getenv('NOTEBOOK_RUN_DIR', 'runs/gym_example_2024-04-21T19:46:11.725273/')
# Number of most-recent checkpoints to render.
N_TRAIN_STEPS = int(os.getenv('NOTEBOOK_N_TRAIN_STEPS', '1'))
# Side length, in bins, of the square observation-density heatmap.
HEATMAP_SIZE = int(os.getenv('NOTEBOOK_HEATMAP_SIZE', '64'))
# Path of the encoded video; defaults to a file inside RUN_DIR.
OUT_FILE = os.getenv('NOTEBOOK_OUT_FILE', f'{RUN_DIR}/pendulum_1.webm')

In [None]:
# Move the working directory up to the repository root (the first ancestor that
# contains a `.git/` directory) so that later cells can use repo-relative paths
# and imports.
print("Looking for root dir")
_start_dir = os.getcwd()
while not glob('.git/'):
    _current_dir = os.getcwd()
    print("No git repo in", _current_dir)
    os.chdir('..')
    if os.getcwd() == _current_dir:
        # `chdir('..')` at the filesystem root is a no-op; without this guard the
        # loop would spin forever when the notebook is run outside a git checkout.
        os.chdir(_start_dir)
        raise FileNotFoundError(
            f"No .git directory found in {_start_dir} or any of its parents")
print("Found git repo in", os.getcwd())

In [None]:
import outrl
from outrl.gym_utils import GymBoxAgent
import examples.gym_example as gym_example

In [None]:
# Render one PNG per timestep for the last N_TRAIN_STEPS checkpoints found in
# RUN_DIR. Each frame shows:
#   * a log-density heatmap of every (angular velocity, theta) pair in the replay buffer,
#   * a scatter of the per-episode state at timestep t, coloured by `vf_returns`,
#   * a horizontal arrow per episode whose length is the action and whose colour
#     is the advantage.
# Frames are numbered consecutively across checkpoints via the global `frame`.

frame = 0
train_step = 1
# Angular velocity is binned and plotted over [-ANG_VEL_LIMIT, ANG_VEL_LIMIT].
ANG_VEL_LIMIT = 8.0

os.makedirs("plots/frames", exist_ok=True)
# Frame numbering restarts at 0 on every run, so frames left over from an earlier,
# longer run would otherwise be picked up by the `%d` pattern of the encoder.
for stale_frame in glob("plots/frames/pendulum_frame=*.png"):
    os.remove(stale_frame)


def obs_to_coord(obs):
    """Map observations to continuous heatmap coordinates.

    Args:
        obs: 2-D array with ``obs.shape[0] == 3``; rows 0 and 1 are passed to
            ``arctan2`` to recover an angle, row 2 is the angular velocity.

    Returns:
        ``(theta_norm, ang_vel_norm)``: both scaled to ``[0, HEATMAP_SIZE]``,
        with theta mapped from ``[-pi, pi]`` and angular velocity from
        ``[-ANG_VEL_LIMIT, ANG_VEL_LIMIT]``.
    """
    assert len(obs.shape) == 2
    assert obs.shape[0] == 3
    theta = np.arctan2(obs[0], obs[1])
    theta_norm = HEATMAP_SIZE * (theta + np.pi) / (2 * np.pi)
    ang_vel_norm = HEATMAP_SIZE * (obs[2] + ANG_VEL_LIMIT) / (2 * ANG_VEL_LIMIT)
    return theta_norm, ang_vel_norm


checkpoints = glob(f"{RUN_DIR}/train_step_*.pkl")
# Sort by the integer after the last "_" in the file name and keep the newest ones.
# NOTE: N_TRAIN_STEPS == 0 keeps every checkpoint, because `[-0:]` is the full slice.
with_idx = sorted([
    (int(f_name.rsplit("_", 1)[-1].split(".", 1)[0]), f_name)
    for f_name in checkpoints
])[-N_TRAIN_STEPS:]
pprint(with_idx)

for step_idx, checkpoint_path in tqdm(with_idx):
    # NOTE(security): pickle.load can execute arbitrary code; only point RUN_DIR
    # at checkpoints produced by a trusted training run.
    with open(checkpoint_path, 'rb') as f:
        data = pickle.load(f)
    if len(data['_replay_buffer']) == 0:
        continue

    # After the transpose the observation components are on the leading axis;
    # assumes every episode has the same length so that np.stack succeeds, giving
    # (3, n_episodes, n_steps) — TODO confirm.
    all_obs = np.stack([ep_data.episode['observations']
                        for ep_data in data['_replay_buffer']]).transpose(2, 0, 1)

    # NOTE(review): if observations are laid out as (cos θ, sin θ, θ̇), this
    # argument order yields π/2 − θ rather than θ — confirm this is intended.
    theta = np.arctan2(all_obs[0], all_obs[1])
    ang_vel = all_obs[2]

    # Bin every observation. theta is shifted by pi so that [-pi, pi] maps onto
    # rows [0, HEATMAP_SIZE - 1]; without the shift negative angles produced
    # negative indices that wrapped around and misaligned the heatmap with the
    # scatter overlay. Indices are clipped so out-of-range values land in the
    # edge bins instead of wrapping or raising.
    theta_idx = np.floor(
        (HEATMAP_SIZE - 1) * (theta + np.pi) / (2 * np.pi)).astype(int).flatten()
    ang_vel_idx = np.floor(
        (HEATMAP_SIZE - 1) * (ang_vel + ANG_VEL_LIMIT) / (2 * ANG_VEL_LIMIT)).astype(int).flatten()
    theta_idx = np.clip(theta_idx, 0, HEATMAP_SIZE - 1)
    ang_vel_idx = np.clip(ang_vel_idx, 0, HEATMAP_SIZE - 1)

    obs_heatmap = np.zeros((HEATMAP_SIZE, HEATMAP_SIZE))
    # np.add.at accumulates correctly when the same bin occurs more than once.
    np.add.at(obs_heatmap, (theta_idx, ang_vel_idx), 1)
    # Empty bins become -inf; the divide-by-zero warning is expected, so silence it.
    # Computed once here because it does not change between frames.
    with np.errstate(divide='ignore'):
        log_heatmap = np.log(obs_heatmap)

    cfg = gym_example.GymConfig.from_dict(data['cfg'])
    actor = GymBoxAgent(
        obs_size=3,
        action_size=1,
        hidden_sizes=cfg.encoder_hidden_sizes,
        pi_hidden_sizes=cfg.pi_hidden_sizes,
        init_std=cfg.init_std,
        min_std=cfg.min_std)

    trainer = outrl.Trainer(cfg, actor)
    trainer.load_state_dict(data)

    train_inputs = trainer._preprocess()
    vf_vals = np.stack([train_input.vf_returns for train_input in train_inputs])
    advantages = np.stack([train_input.advantages for train_input in train_inputs])

    # Per-step changes along the time axis (previously subtracted the last column
    # from every step because of a `-1:` / `:-1` slip). Not used by the plotting
    # below; kept for interactive inspection.
    d_ang_vel = np.diff(ang_vel, axis=1)
    d_theta = np.diff(theta, axis=1)

    all_actions = np.stack([ep_data.episode['actions']
                            for ep_data in data['_replay_buffer']]).transpose(2, 0, 1)

    for t in trange(vf_vals.shape[1]):
        plt.clf()
        # origin='lower' puts heatmap row 0 at theta = -pi, matching both the
        # binning above and the y coordinates used by scatter/quiver.
        plt.imshow(log_heatmap,
                   extent=[-ANG_VEL_LIMIT, ANG_VEL_LIMIT, -np.pi, np.pi],
                   origin='lower',
                   aspect=(2 * ANG_VEL_LIMIT) / (2 * np.pi),
                   cmap='Greys')
        plt.scatter(ang_vel[:, t], theta[:, t], c=vf_vals[:, t])
        plt.quiver(ang_vel[:, t], theta[:, t],
                   all_actions[0, :, t], 0, advantages[:, t], width=0.01, cmap='plasma')
        plt.xlabel("angular velocity")
        plt.ylabel("theta")
        plt.savefig(f"plots/frames/pendulum_frame={frame}.png")
        frame += 1

In [None]:
# Encode the numbered PNG frames into OUT_FILE; `-y` overwrites an existing file.
# check=True raises CalledProcessError when ffmpeg exits non-zero, so a failed
# encode can no longer pass silently and leave a missing or stale video behind.
run(['ffmpeg', '-y', '-i', 'plots/frames/pendulum_frame=%d.png', OUT_FILE], check=True)