In [None]:
from pathlib import Path 
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
from tqdm import tqdm
from IPython.display import Video

import torch
import numpy as np

import envs

import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Task identifiers used to pick the evaluation environment.
JACO = 'jaco_reach_top_left'
QUADRUPED = 'quadruped_run'
WALKER = 'walker_run'
MW = 'mw_reach'
# Maps the agent's action dimensionality (`agent.act_dim`) to the task name.
ACT2TASK_DICT = {6: WALKER, 9: JACO, 12: QUADRUPED, 4: MW}

# Snapshot to visualise. The default is the original machine-specific location;
# set the CHOREO_AGENT_SNAPSHOT environment variable to point at another
# snapshot without editing the notebook.
DEFAULT_AGENT_PATH = '/home/idlab204/submissions/ICLR2023/choreo_code/exp_local/2023.02.15/000350_choreo/last_snapshot.pt'
agent_path = Path(os.environ.get('CHOREO_AGENT_SNAPSHOT', DEFAULT_AGENT_PATH))

In [None]:
def load_agent(agent_path):
    """Load a pickled agent snapshot and move its device bookkeeping to this machine.

    Args:
        agent_path: Location of the snapshot file (``str`` or ``pathlib.Path``).
            The file is expected to hold a dict with the keys ``'agent'`` and
            ``'_global_step'``.

    Returns:
        A ``(agent, step)`` tuple: the un-pickled agent object and the global
        step stored alongside it.

    Raises:
        FileNotFoundError: If ``agent_path`` does not exist.
        KeyError: If the snapshot lacks ``'agent'`` or ``'_global_step'``.
    """
    import inspect  # local import: only needed for the torch.load feature check

    agent_path = Path(agent_path)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    load_kwargs = dict(map_location=torch.device(device))
    # The snapshot stores a whole Python object, not just tensors. PyTorch
    # releases that default `weights_only` to True refuse to un-pickle it, so
    # request full un-pickling explicitly when the argument is available.
    # SECURITY: this executes arbitrary pickle code - only load trusted snapshots.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    with agent_path.open('rb') as f:
        obj = torch.load(f, **load_kwargs)

    agent = obj['agent']
    step = obj['_global_step']

    # The agent and its world-model components each cache a `device` attribute;
    # overwrite them so they match where the tensors were mapped.
    for component in (agent, agent.wm, agent.wm.rssm, agent.wm.rssm._cell):
        component.device = device
    return agent, step

In [None]:
# Restore the agent together with the global step saved in the snapshot.
agent, global_step = load_agent(agent_path)

# Settings read from the configuration stored on the agent.
obs_type = agent.cfg.obs_type
action_repeat = agent.cfg.action_repeat
# Snapshot step scaled by the action repeat.
snapshot_ts = global_step * action_repeat

# Flags set on the loaded agent before rolling out skills.
# NOTE(review): their exact effect is defined by the agent class, which is not
# visible from this notebook - confirm against the agent implementation.
for flag, value in (('force_skills', True), ('is_ft', False), ('reward_free', True)):
    setattr(agent, flag, value)
skill_dim = agent.skill_dim

for flag, value in (('use_selector', False), ('detached_exploration', True)):
    setattr(agent, flag, value)

seed = agent.cfg.seed

# The task is looked up from the action dimensionality; the domain is the part
# of the task name before the first underscore.
task = ACT2TASK_DICT[agent.act_dim]
domain = task.partition("_")[0]

# Environment used for the rollouts; the `exorl` flag is on when the snapshot
# path contains the substring 'exorl'.
train_env = envs.make(
    task,
    obs_type,
    frame_stack=1,
    action_repeat=action_repeat,
    seed=seed,
    img_size=64,
    exorl='exorl' in str(agent_path),
)

In [None]:
# Roll out every one-hot skill for a fixed number of steps, recording rendered
# frames and rewards, and (for pixel observations) decode a preview image of
# each skill from the world model.
render_size = 64
camera_id = dict(quadruped=2).get(domain, 0)

# Grid layout: the number of columns is the smallest power of two that is
# >= sqrt(skill_dim); rows are rounded UP so every skill gets a panel even when
# skill_dim is not a multiple of `columns`.
columns = 2
while columns < np.sqrt(skill_dim):
    columns *= 2
rows = (skill_dim + columns - 1) // columns
c_size = max(columns // 16 * 10, 10)

eval_mode = False
steps = 200 // action_repeat

# Per-skill recordings: rendered frames, rewards, and a
# [flag, decoded image, accumulator] triple used for the preview grid.
imagelist = [[] for _ in range(skill_dim)]
rewardlist = [[] for _ in range(skill_dim)]
skill_obs = [ [0, None, 0] for _ in range(skill_dim)]


def _render_frame():
    """Render the current state of `train_env` at render_size x render_size."""
    if task == MW:
        return train_env.sim.render(
            render_size, render_size, mode="offscreen", camera_name=train_env._camera
        ).copy()
    return train_env.physics.render(height=render_size,
                                    width=render_size,
                                    camera_id=camera_id)


time_step = train_env.reset()
agent_state = None

for n in tqdm(range(skill_dim)):
    # Every skill starts from a fresh episode and a fresh recurrent state.
    time_step = train_env.reset()
    agent_state = None

    meta = dict()

    # One-hot encoding selecting skill `n`.
    skill = np.zeros(agent.skill_dim, dtype=np.float32)
    skill[n] = 1.0
    meta['skill'] = skill

    if obs_type == 'pixels':
        # No gradients are needed for the preview; disabling them also keeps the
        # final `.numpy()` conversion valid for tensors derived from parameters.
        with torch.no_grad():
            # One-hot skill -> embedding -> decoded deterministic state.
            skill_z = torch.from_numpy(skill).to(agent.device).unsqueeze(0).unsqueeze(0)
            skill_z = skill_z @ agent.skill_module.emb.weight.T

            x = deter = agent.skill_module.skill_decoder(skill_z).mean

            # Sample the stochastic state from one randomly chosen ensemble member.
            stats = agent.wm.rssm._suff_stats_ensemble(x)
            member = torch.randint(0, agent.wm.rssm._ensemble, ())
            stats = {k: v[member] for k, v in stats.items()}
            dist = agent.wm.rssm.get_dist(stats)
            stoch = dist.sample()
            prior = {'stoch': stoch, 'deter': deter, **stats}

            # Decode the observation mean and shift/clip it into [0, 1].
            decoder = agent.wm.heads['decoder']
            openl = decoder(agent.wm.rssm.get_feat(prior))['observation'].mean.squeeze()
            skill_img = torch.clip(openl + 0.5, 0, 1).cpu().numpy()
        # Move axis 0 to the end for imshow (presumably channels-first -> channels-last).
        skill_obs[n] = [1, skill_img.transpose(1,2,0), 0]

    # Record the initial frame/reward, then one entry per environment step.
    imagelist[n].append(_render_frame())
    rewardlist[n].append(time_step['reward'])
    for _ in range(steps):
        action, agent_state = agent.act(time_step,
                            meta,
                            0,
                            eval_mode=eval_mode,
                            state=agent_state)

        time_step = train_env.step(action)

        imagelist[n].append(_render_frame())
        rewardlist[n].append(time_step['reward'])

        # Keep recording across episode boundaries by starting a new episode.
        if time_step['is_last']:
            time_step = train_env.reset()
            agent_state = None

# Normalise the third slot of each entry (it is only ever initialised to 0 above).
for skill_id in range(skill_dim):
    skill_obs[skill_id][2] /= steps*skill_dim

if obs_type == 'pixels':
    # squeeze=False keeps `axes` two-dimensional even for a single row, and the
    # true division keeps the figure width positive when rows > columns.
    fig, axes = plt.subplots(rows, columns,
                             figsize=(columns / rows * c_size, c_size),
                             squeeze=False)

    # Hide every panel's axis, including unused trailing panels.
    for a in axes.flat:
        a.axis('off')
        a.set_aspect('equal')

    for index, (p_image, image, avg_image) in enumerate(skill_obs):
        r = index // columns
        c = index % columns
        a = axes[r][c]
        im = a.imshow(image, cmap=plt.get_cmap('jet'), vmin=0, vmax=255)
    fig.tight_layout()

In [None]:
# Build a grid with one panel per skill and animate the recorded rollouts.
# The row count is rounded up here so that every skill index maps to an
# existing panel, and squeeze=False keeps `axes` two-dimensional for one row.
grid_rows = (skill_dim + columns - 1) // columns
fig, axes = plt.subplots(grid_rows, columns,
                         figsize=(columns / grid_rows * c_size, c_size),
                         squeeze=False)

# Hide every panel's axis, including unused trailing panels.
for a in axes.flat:
    a.axis('off')
    a.set_aspect('equal')

# One AxesImage per skill, initialised with the first recorded frame.
# vmin/vmax are kept from the original call to pin the colour scaling.
ims = []
titles = []
for index in range(skill_dim):
    r = index // columns
    c = index % columns
    a = axes[r][c]
    titles.append(a.set_title(f"Index: {index} Rew:\n{np.sum(rewardlist[index]):.3f}"))
    im = a.imshow(imagelist[index][0], cmap=plt.get_cmap('jet'), vmin=0, vmax=255)
    ims.append(im)
fig.tight_layout()

# Animate every recorded frame (the initial frame plus one per step), limited
# to the shortest recording so indexing can never run past a list.
n_frames = min((len(frames) for frames in imagelist), default=0)


def updatefig(j):
    """Show frame `j` of every skill and return the updated artists for blitting."""
    for im, frames in zip(ims, imagelist):
        im.set_data(frames[j])
    return ims


ani = animation.FuncAnimation(fig, updatefig, frames=range(n_frames),
                              interval=25 * action_repeat, blit=True)

video_path = '/tmp/video.mp4'
# No savefig_kwargs: Animation.save discards `bbox_inches` (with a warning)
# because a tight bounding box can change the frame size between frames.
ani.save(video_path)
plt.close(fig)

# NOTE(review): the video is referenced by path, not embedded - confirm the
# front end can serve files from /tmp, otherwise pass embed=True.
Video(video_path, width=800, height=400)