In [1]:
import sys
sys.path.append('..')
from a2c_ppo_acktr import algo, utils
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from evaluation import evaluate
from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib
    
%run model_evaluation

from a2c_ppo_acktr.storage import RolloutStorage
from a2c_ppo_acktr.envs import make_vec_envs
from a2c_ppo_acktr import algo, utils
from  a2c_ppo_acktr.model import Policy
import torch

import proplot as pplt
from collections import defaultdict

import imageio

In [11]:
import imageio

anim_folder = 'data/anim/tmp'
for i in range(10):
    fig, ax = pplt.subplots(nrows=2)
    x = np.linspace(0, 4, 2000)
    y = np.sin(x - i*0.1)
    ax[0].plot(x, y)
    plt.savefig(f'{anim_folder}/line-{i}.png')
    plt.close()
    
with imageio.get_writer('line.gif', mode='i') as writer:
    for i in range(10):
        image = imageio.imread(f'{anim_folder}/line-{i}.png')
        writer.append_data(image)
        



# Generate 1 episodes

In [58]:
#Generate frames from single episode

pplt.rc['meta.linewidth'] = 0
pplt.rc['figure.facecolor'] = 'f4f4f4'

array = [
    [0, 2, 2, 2],
    [1, 2, 2, 2],
    [0, 2, 2, 2],
]
width = 4

cmap = matplotlib.cm.get_cmap('coolwarm')
def cmap_val(val, cmin=-1.0, cmax=1.0, cmap=cmap):
    #Get rgba from val
    if val < cmin:
        val = cmin
    if val > cmax:
        val = cmax
    
    val = (val - cmin) / (cmax - cmin)
    return cmap(val)



model_name = f'nav_poster_netstructure/nav_pdistal_width{width}batch200'
model, obs_rms, env_kwargs = load_model_and_env(model_name, 0)
env = gym.make('NavEnv-v0', **env_kwargs)
res = evalu(model, obs_rms, env_kwargs=env_kwargs, n=1, data_callback=poster_data_callback,
            with_activations=True)
num_frames = len(res['obs'])

val_to_rgba = {
    1/6: np.array([1., 0, 0, 1]),
    4/6: np.array([1., 1, 0, 1])
}
num_rays = env.num_rays
#Draw observation rectangles
square_size = 2
gap = 0.3


for frame in range(num_frames):
    fig, ax = pplt.subplots(array, share=False)
    ax.format(xlocator='null', ylocator='null')
    ax[0].format(xlim=[-1, 301], ylim=[-1, 301], facecolor='black')
    ax[1].format(xlim=[-2, 33], ylim=[-5, 32])


    #for frame in range(num_frames):
    env.character.pos = res['data']['pos'][frame]
    env.character.angle = res['data']['angle'][frame]
    env.character.update_rays(env.vis_walls, env.vis_wall_refs)
    env.render('human', ax=ax[0])

    obs = env.get_observation() #get normalized observation for easier interpretability

    ax[1].text(0.3, (square_size+gap)*num_rays, 'Obs', size='large')
    for i in range(num_rays):
        corner = [0, (square_size+gap)*i]
        color = val_to_rgba[obs[i]].copy()
        dist = obs[i+num_rays]
        color[-1] = color[-1]*dist
        rect = plt.Rectangle(corner, square_size, square_size, fc=color)
        ax[1].add_patch(rect)
        ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{obs[i+num_rays]:.2f}', verticalalignment='center')

    #Draw activation rectangles
    all_activs = res['activations'][frame]

    #Shared
    ax[1].text(4.7, (square_size+gap)*8, 'Shared\n(Recurrent)', size='large')
    start_y = (square_size+gap)*(num_rays//2 - width//2)
    activ = all_activs['shared_activations'][0].squeeze().tolist()
    for i in range(width):
        a = activ[i]
        corner = [5, (square_size+gap)*i+start_y]
        rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
        ax[1].add_patch(rect)
        ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')

    #Policy Network
    for j in range(2):
        activ = all_activs['actor_activations'][j].squeeze().tolist()
        start_y = (square_size+gap)*(num_rays//2 + width//2 - 1)
        ax[1].text(10+5*j, (square_size+gap)*11, f'Policy {j+1}', size='large')
        for i in range(width):
            a = activ[i]
            corner = [10+5*j, (square_size+gap)*i+start_y]
            rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')

    #Value Network
    for j in range(2):
        activ = all_activs['critic_activations'][j].squeeze().tolist()
        start_y = (square_size+gap)*(num_rays//2 - width - 1)
        ax[1].text(10+5*j, (square_size+gap)*5, f'Value {j+1}', size='large')
        for i in range(width):
            a = activ[i]
            corner = [10+5*j, (square_size+gap)*i+start_y]
            rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')


    #Final Actions
    ax[1].text(20, (square_size+gap)*11, f'Action Probs', size='large')
    cmap2 = matplotlib.cm.get_cmap('Haline')
    probs = model.dist(res['actor_features'][frame]).probs.squeeze().tolist()
    action_labels = ['Right', 'Forward', 'Left', 'Nothing']
    for i in range(len(probs)):
        prob = probs[i]
        corner = [20, (square_size+gap)*(i+7)]
        rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(prob, cmin=0, cmax=1, cmap=cmap2))
        ax[1].add_patch(rect)
        ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{prob:.2f}', verticalalignment='center')
        ax[1].text(corner[0]+square_size+0.5, corner[1]+square_size/2, action_labels[i], verticalalignment='center')

    #Final Values
    ax[1].text(20, (square_size+gap)*3.5, f'Value', size='large')
    value = res['values'][frame].item()
    corner = [20, (square_size+gap)*2.5]
    # val_cmap = matplotlib.cm.get_cmap('Haline')
    rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(value, cmin=0, cmax=10, cmap=cmap2))
    ax[1].add_patch(rect)
    ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{value:.2f}', verticalalignment='center')

    
    plt.savefig(f'data/anim/tmp/ep-{frame}.jpg', dpi=150)
    plt.close()
    
pplt.rc.reset()

In [60]:
#Convert episode into gif
with imageio.get_writer('data/anim/width4_t0_ep0.gif', mode='i') as writer:
    for i in range(num_frames):
        for j in range(2):
            image = imageio.imread(f'data/anim/tmp/ep-{i}.jpg')
            writer.append_data(image)

ValueError: Could not find a format to write the specified file in single-image mode

## Generate multiple episodes

In [3]:
#Generate frames from multiple episodes - agent 0

pplt.rc['meta.linewidth'] = 0
pplt.rc['figure.facecolor'] = 'f4f4f4'

array = [
    [0, 2, 2, 2],
    [1, 2, 2, 2],
    [0, 2, 2, 2],
]
width = 4

cmap = matplotlib.cm.get_cmap('coolwarm')
def cmap_val(val, cmin=-1.0, cmax=1.0, cmap=cmap):
    #Get rgba from val
    if val < cmin:
        val = cmin
    if val > cmax:
        val = cmax
    
    val = (val - cmin) / (cmax - cmin)
    return cmap(val)



model_name = f'nav_poster_netstructure/nav_pdistal_width{width}batch200'
model, obs_rms, env_kwargs = load_model_and_env(model_name, 0)
env = gym.make('NavEnv-v0', **env_kwargs)
val_to_rgba = {
    1/6: np.array([1., 0, 0, 1]),
    4/6: np.array([1., 1, 0, 1])
}
num_rays = env.num_rays
#Draw observation rectangles
square_size = 2
gap = 0.3



for ep_num in tqdm(range(5)):
    res = evalu(model, obs_rms, env_kwargs=env_kwargs, n=1, data_callback=poster_data_callback,
                with_activations=True, seed=ep_num)
    num_frames = len(res['obs'])

    for frame in range(num_frames):
        fig, ax = pplt.subplots(array, share=False)
        ax.format(xlocator='null', ylocator='null')
        ax[0].format(xlim=[-1, 301], ylim=[-1, 301], facecolor='black')
        ax[1].format(xlim=[-2, 33], ylim=[-5, 32])


        #for frame in range(num_frames):
        env.character.pos = res['data']['pos'][frame]
        env.character.angle = res['data']['angle'][frame]
        env.character.update_rays(env.vis_walls, env.vis_wall_refs)
        env.render('human', ax=ax[0])

        obs = env.get_observation() #get normalized observation for easier interpretability

        ax[1].text(0.3, (square_size+gap)*num_rays, 'Obs', size='large')
        for i in range(num_rays):
            corner = [0, (square_size+gap)*i]
            color = val_to_rgba[obs[i]].copy()
            dist = obs[i+num_rays]
            color[-1] = color[-1]*(1-dist)
            rect = plt.Rectangle(corner, square_size, square_size, fc=color)
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{obs[i+num_rays]:.2f}', verticalalignment='center')

        #Draw activation rectangles
        all_activs = res['activations'][frame]

        #Shared
        ax[1].text(4.7, (square_size+gap)*8, 'Shared\n(Recurrent)', size='large')
        start_y = (square_size+gap)*(num_rays//2 - width//2)
        activ = all_activs['shared_activations'][0].squeeze().tolist()
        for i in range(width):
            a = activ[i]
            corner = [5, (square_size+gap)*i+start_y]
            rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')

        #Policy Network
        for j in range(2):
            activ = all_activs['actor_activations'][j].squeeze().tolist()
            start_y = (square_size+gap)*(num_rays//2 + width//2 - 1)
            ax[1].text(10+5*j, (square_size+gap)*11, f'Policy {j+1}', size='large')
            for i in range(width):
                a = activ[i]
                corner = [10+5*j, (square_size+gap)*i+start_y]
                rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
                ax[1].add_patch(rect)
                ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')

        #Value Network
        for j in range(2):
            activ = all_activs['critic_activations'][j].squeeze().tolist()
            start_y = (square_size+gap)*(num_rays//2 - width - 1)
            ax[1].text(10+5*j, (square_size+gap)*5, f'Value {j+1}', size='large')
            for i in range(width):
                a = activ[i]
                corner = [10+5*j, (square_size+gap)*i+start_y]
                rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
                ax[1].add_patch(rect)
                ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')


        #Final Actions
        ax[1].text(20, (square_size+gap)*11, f'Action Probs', size='large')
        cmap2 = matplotlib.cm.get_cmap('Haline')
        probs = model.dist(res['actor_features'][frame]).probs.squeeze().tolist()
        action_labels = ['Right', 'Forward', 'Left', 'Nothing']
        for i in range(len(probs)):
            prob = probs[i]
            corner = [20, (square_size+gap)*(i+7)]
            rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(prob, cmin=0, cmax=1, cmap=cmap2))
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{prob:.2f}', verticalalignment='center')
            ax[1].text(corner[0]+square_size+0.5, corner[1]+square_size/2, action_labels[i], verticalalignment='center')

        #Final Values
        ax[1].text(20, (square_size+gap)*3.5, f'Value', size='large')
        value = res['values'][frame].item()
        corner = [20, (square_size+gap)*2.5]
        # val_cmap = matplotlib.cm.get_cmap('Haline')
        rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(value, cmin=0, cmax=10, cmap=cmap2))
        ax[1].add_patch(rect)
        ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{value:.2f}', verticalalignment='center')


        plt.savefig(f'data/anim/tmp/ep-{frame}.jpg', dpi=150)
        plt.close()
        
        
    #Convert episode into gif
    with imageio.get_writer(f'data/anim/width4_t0_ep{ep_num}.gif', mode='i') as writer:
        for i in range(num_frames):
            for j in range(2):
                image = imageio.imread(f'data/anim/tmp/ep-{i}.jpg')
                writer.append_data(image)
    
pplt.rc.reset()

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:43<00:00, 20.76s/it]


In [4]:
#Generate frames from multiple episodes - agent 2

pplt.rc['meta.linewidth'] = 0
pplt.rc['figure.facecolor'] = 'f4f4f4'

array = [
    [0, 2, 2, 2],
    [1, 2, 2, 2],
    [0, 2, 2, 2],
]
width = 4

cmap = matplotlib.cm.get_cmap('coolwarm')
def cmap_val(val, cmin=-1.0, cmax=1.0, cmap=cmap):
    #Get rgba from val
    if val < cmin:
        val = cmin
    if val > cmax:
        val = cmax
    
    val = (val - cmin) / (cmax - cmin)
    return cmap(val)



model_name = f'nav_poster_netstructure/nav_pdistal_width{width}batch200'
model, obs_rms, env_kwargs = load_model_and_env(model_name, 2)
env = gym.make('NavEnv-v0', **env_kwargs)
val_to_rgba = {
    1/6: np.array([1., 0, 0, 1]),
    4/6: np.array([1., 1, 0, 1])
}
num_rays = env.num_rays
#Draw observation rectangles
square_size = 2
gap = 0.3



for ep_num in tqdm(range(5)):
    res = evalu(model, obs_rms, env_kwargs=env_kwargs, n=1, data_callback=poster_data_callback,
                with_activations=True, seed=ep_num)
    num_frames = len(res['obs'])

    for frame in range(num_frames):
        fig, ax = pplt.subplots(array, share=False)
        ax.format(xlocator='null', ylocator='null')
        ax[0].format(xlim=[-1, 301], ylim=[-1, 301], facecolor='black')
        ax[1].format(xlim=[-2, 33], ylim=[-5, 32])


        #for frame in range(num_frames):
        env.character.pos = res['data']['pos'][frame]
        env.character.angle = res['data']['angle'][frame]
        env.character.update_rays(env.vis_walls, env.vis_wall_refs)
        env.render('human', ax=ax[0])

        obs = env.get_observation() #get normalized observation for easier interpretability

        ax[1].text(0.3, (square_size+gap)*num_rays, 'Obs', size='large')
        for i in range(num_rays):
            corner = [0, (square_size+gap)*i]
            color = val_to_rgba[obs[i]].copy()
            dist = obs[i+num_rays]
            color[-1] = color[-1]*(1-dist)
            rect = plt.Rectangle(corner, square_size, square_size, fc=color)
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{obs[i+num_rays]:.2f}', verticalalignment='center')

        #Draw activation rectangles
        all_activs = res['activations'][frame]

        #Shared
        ax[1].text(4.7, (square_size+gap)*8, 'Shared\n(Recurrent)', size='large')
        start_y = (square_size+gap)*(num_rays//2 - width//2)
        activ = all_activs['shared_activations'][0].squeeze().tolist()
        for i in range(width):
            a = activ[i]
            corner = [5, (square_size+gap)*i+start_y]
            rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')

        #Policy Network
        for j in range(2):
            activ = all_activs['actor_activations'][j].squeeze().tolist()
            start_y = (square_size+gap)*(num_rays//2 + width//2 - 1)
            ax[1].text(10+5*j, (square_size+gap)*11, f'Policy {j+1}', size='large')
            for i in range(width):
                a = activ[i]
                corner = [10+5*j, (square_size+gap)*i+start_y]
                rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
                ax[1].add_patch(rect)
                ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')

        #Value Network
        for j in range(2):
            activ = all_activs['critic_activations'][j].squeeze().tolist()
            start_y = (square_size+gap)*(num_rays//2 - width - 1)
            ax[1].text(10+5*j, (square_size+gap)*5, f'Value {j+1}', size='large')
            for i in range(width):
                a = activ[i]
                corner = [10+5*j, (square_size+gap)*i+start_y]
                rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
                ax[1].add_patch(rect)
                ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')


        #Final Actions
        ax[1].text(20, (square_size+gap)*11, f'Action Probs', size='large')
        cmap2 = matplotlib.cm.get_cmap('Haline')
        probs = model.dist(res['actor_features'][frame]).probs.squeeze().tolist()
        action_labels = ['Right', 'Forward', 'Left', 'Nothing']
        for i in range(len(probs)):
            prob = probs[i]
            corner = [20, (square_size+gap)*(i+7)]
            rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(prob, cmin=0, cmax=1, cmap=cmap2))
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{prob:.2f}', verticalalignment='center')
            ax[1].text(corner[0]+square_size+0.5, corner[1]+square_size/2, action_labels[i], verticalalignment='center')

        #Final Values
        ax[1].text(20, (square_size+gap)*3.5, f'Value', size='large')
        value = res['values'][frame].item()
        corner = [20, (square_size+gap)*2.5]
        # val_cmap = matplotlib.cm.get_cmap('Haline')
        rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(value, cmin=0, cmax=10, cmap=cmap2))
        ax[1].add_patch(rect)
        ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{value:.2f}', verticalalignment='center')


        plt.savefig(f'data/anim/tmp/ep-{frame}.jpg', dpi=150)
        plt.close()
        
        
    #Convert episode into gif
    with imageio.get_writer(f'data/anim/width4_t2_ep{ep_num}.gif', mode='i') as writer:
        for i in range(num_frames):
            for j in range(2):
                image = imageio.imread(f'data/anim/tmp/ep-{i}.jpg')
                writer.append_data(image)
    
pplt.rc.reset()

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:39<00:00, 19.90s/it]


### Untrained Agent

In [5]:
#Generate frames from multiple episodes - agent 2

pplt.rc['meta.linewidth'] = 0
pplt.rc['figure.facecolor'] = 'f4f4f4'

array = [
    [0, 2, 2, 2],
    [1, 2, 2, 2],
    [0, 2, 2, 2],
]
width = 4

cmap = matplotlib.cm.get_cmap('coolwarm')
def cmap_val(val, cmin=-1.0, cmax=1.0, cmap=cmap):
    #Get rgba from val
    if val < cmin:
        val = cmin
    if val > cmax:
        val = cmax
    
    val = (val - cmin) / (cmax - cmin)
    return cmap(val)



model_name = f'nav_poster_netstructure/nav_pdistal_width{width}batch200'
_, obs_rms, env_kwargs = load_model_and_env(model_name, 2)
env = gym.make('NavEnv-v0', **env_kwargs)
val_to_rgba = {
    1/6: np.array([1., 0, 0, 1]),
    4/6: np.array([1., 1, 0, 1])
}
num_rays = env.num_rays
#Draw observation rectangles
square_size = 2
gap = 0.3

nn_base_kwargs = {'hidden_size': 4}
model = Policy(env.observation_space.shape,
               env.action_space,
               base='FlexBase',
               base_kwargs={'recurrent': True,
                   **nn_base_kwargs})
model.to(device)

for ep_num in tqdm(range(3)):
    res = evalu(model, obs_rms, env_kwargs=env_kwargs, n=1, data_callback=poster_data_callback,
                with_activations=True, seed=ep_num, deterministic=False)
    num_frames = len(res['obs'])

    for frame in range(num_frames):
        fig, ax = pplt.subplots(array, share=False)
        ax.format(xlocator='null', ylocator='null')
        ax[0].format(xlim=[-1, 301], ylim=[-1, 301], facecolor='black')
        ax[1].format(xlim=[-2, 33], ylim=[-5, 32])


        #for frame in range(num_frames):
        env.character.pos = res['data']['pos'][frame]
        env.character.angle = res['data']['angle'][frame]
        env.character.update_rays(env.vis_walls, env.vis_wall_refs)
        env.render('human', ax=ax[0])

        obs = env.get_observation() #get normalized observation for easier interpretability

        ax[1].text(0.3, (square_size+gap)*num_rays, 'Obs', size='large')
        for i in range(num_rays):
            corner = [0, (square_size+gap)*i]
            color = val_to_rgba[obs[i]].copy()
            dist = obs[i+num_rays]
            color[-1] = color[-1]*(1-dist)
            rect = plt.Rectangle(corner, square_size, square_size, fc=color)
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{obs[i+num_rays]:.2f}', verticalalignment='center')

        #Draw activation rectangles
        all_activs = res['activations'][frame]

        #Shared
        ax[1].text(4.7, (square_size+gap)*8, 'Shared\n(Recurrent)', size='large')
        start_y = (square_size+gap)*(num_rays//2 - width//2)
        activ = all_activs['shared_activations'][0].squeeze().tolist()
        for i in range(width):
            a = activ[i]
            corner = [5, (square_size+gap)*i+start_y]
            rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')

        #Policy Network
        for j in range(2):
            activ = all_activs['actor_activations'][j].squeeze().tolist()
            start_y = (square_size+gap)*(num_rays//2 + width//2 - 1)
            ax[1].text(10+5*j, (square_size+gap)*11, f'Policy {j+1}', size='large')
            for i in range(width):
                a = activ[i]
                corner = [10+5*j, (square_size+gap)*i+start_y]
                rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
                ax[1].add_patch(rect)
                ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')

        #Value Network
        for j in range(2):
            activ = all_activs['critic_activations'][j].squeeze().tolist()
            start_y = (square_size+gap)*(num_rays//2 - width - 1)
            ax[1].text(10+5*j, (square_size+gap)*5, f'Value {j+1}', size='large')
            for i in range(width):
                a = activ[i]
                corner = [10+5*j, (square_size+gap)*i+start_y]
                rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(a))
                ax[1].add_patch(rect)
                ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{a:.2f}', verticalalignment='center')


        #Final Actions
        ax[1].text(20, (square_size+gap)*11, f'Action Probs', size='large')
        cmap2 = matplotlib.cm.get_cmap('Haline')
        probs = model.dist(res['actor_features'][frame]).probs.squeeze().tolist()
        action_labels = ['Right', 'Forward', 'Left', 'Nothing']
        for i in range(len(probs)):
            prob = probs[i]
            corner = [20, (square_size+gap)*(i+7)]
            rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(prob, cmin=0, cmax=1, cmap=cmap2))
            ax[1].add_patch(rect)
            ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{prob:.2f}', verticalalignment='center')
            ax[1].text(corner[0]+square_size+0.5, corner[1]+square_size/2, action_labels[i], verticalalignment='center')

        #Final Values
        ax[1].text(20, (square_size+gap)*3.5, f'Value', size='large')
        value = res['values'][frame].item()
        corner = [20, (square_size+gap)*2.5]
        # val_cmap = matplotlib.cm.get_cmap('Haline')
        rect = plt.Rectangle(corner, square_size, square_size, fc=cmap_val(value, cmin=0, cmax=10, cmap=cmap2))
        ax[1].add_patch(rect)
        ax[1].text(corner[0]+0.45, corner[1]+square_size/2, f'{value:.2f}', verticalalignment='center')


        plt.savefig(f'data/anim/tmp/ep-{frame}.jpg', dpi=150)
        plt.close()
        
        
    #Convert episode into gif
    with imageio.get_writer(f'data/anim/width4_untrained_ep{ep_num}.gif', mode='i') as writer:
        for i in range(num_frames):
            for j in range(2):
                image = imageio.imread(f'data/anim/tmp/ep-{i}.jpg')
                writer.append_data(image)
    
pplt.rc.reset()

 33%|███████████████████████████▋                                                       | 1/3 [03:49<07:38, 229.09s/it]

KeyboardInterrupt



In [6]:
from pygifsicle import optimize

optimize('data/anim/width4_untrained_ep0.gif', 'data/anim/width4_untrained_opt_ep0.gif')