In [32]:
import json
import torch

# state = torch.load('models/c51/cartpole-v0/reproducibility-0_l575yfpi_296000_20220205T024634.pth')
# state = torch.load('models/c51/frozenlake-v1/reproducibility-0_ra56h3s5_163840_20220205T025400.pth')
state = torch.load('models/c51/frozenlake-v1/reproducibility-0_30vk1qbd_163840_20220205T030126.pth')
print(json.dumps(state['config'], indent=2, sort_keys=True))

{
  "activation": "relu",
  "batch_size": 1024,
  "buffer_size": 100000,
  "buffer_type": "uniform",
  "distribution_resolution": 51,
  "embed": false,
  "embed_size": 0,
  "env": "FrozenLake-v1",
  "env_args": {
    "is_slippery": true
  },
  "eps_sched_final": 0.01,
  "eps_sched_len": 32000,
  "gamma": 0.995,
  "layer_size": 64,
  "learning_starts": 2048,
  "log_step": 8192,
  "lr": 0.001,
  "max_episode_steps": null,
  "max_return": 10,
  "min_return": -10,
  "num_layers": 2,
  "save_final": true,
  "save_max_eps": false,
  "seed": 0,
  "steps": 163840,
  "target_update_freq": 4096,
  "training_freq": 64,
  "use_target": false
}


In [33]:
import gym
from spin_class.algos.c51 import make_model

kwargs = state['config']['env_args'] if 'env_args' in state['config'] else {}
env = gym.make(state['config']['env'], **kwargs)
device = torch.device('cpu')
q_net = make_model(env, device, state['config'])
q_net.load_state_dict(state['q_state_dict'])
q_net.eval()

C51MLP(
  (head): Sequential(
    (0): OneHot1d()
    (1): Linear(in_features=16, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear2d(
      (linear): Linear(in_features=64, out_features=204, bias=True)
    )
    (6): Softmax(dim=-1)
  )
)

In [34]:
import gym
from gym import wrappers
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display
from torch.distributions.normal import Normal

# start virtual display
if 'display' not in globals():
    display = Display(visible=False, size=(1400, 900))
    display.start()

def play(q_net, env, config, steps=1000):
    env = wrappers.Monitor(env, "./video", force=True)
    obs_dtype = (
        torch.int64
        if isinstance(env.observation_space, gym.spaces.Discrete)
        else torch.float32
    )
    min_return = config["min_return"]
    max_return = config["max_return"]
    dist_res = config["distribution_resolution"]
    delta_z = torch.as_tensor(
        (max_return - min_return) / (dist_res - 1), dtype=torch.float32, device=device
    )
    z = torch.arange(
        min_return,
        max_return + 1e-3,
        delta_z,
        dtype=torch.float32,
        device=device,
    )
    obs = env.reset()
    for _ in range(steps):
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=obs_dtype, device=device).unsqueeze(0)
            q = q_net(obs_t)[0]
            expected_q = torch.matmul(q, z)
            action = torch.argmax(expected_q).cpu().numpy().tolist()
        obs, reward, done, info = env.step(action)
        if done:
            print(_)
            break
    env.close()

    video = io.open('./video/openaigym.video.%s.video000000.mp4' % env.file_infix, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''
        <video alt="test" autoplay loop controls style="height: 400px;">
            <source src="data:video/mp4;base64,{0}" type="video/mp4" />
        </video>'''.format(encoded.decode('ascii'))))

    #HTML(data='''
    #    <video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>
    #'''.format(encoded.decode('ascii')))

In [22]:
play(q_net, env, state["config"], steps=200)

180


In [35]:
import random
import torch

def play_frozenlake(q_net, env, config, eps=0.0):
    min_return = config["min_return"]
    max_return = config["max_return"]
    dist_res = config["distribution_resolution"]
    delta_z = torch.as_tensor(
        (max_return - min_return) / (dist_res - 1), dtype=torch.float32, device=device
    )
    z = torch.arange(
        min_return,
        max_return + 1e-3,
        delta_z,
        dtype=torch.float32,
        device=device,
    )
    obs_dtype = (
        torch.int64
        if isinstance(env.observation_space, gym.spaces.Discrete)
        else torch.float32
    )
    done = False
    obs = env.reset()
    print('====== step 0 ======')
    env.render()
    for i in range(100):
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=obs_dtype, device=device).unsqueeze(0)
            q = q_net(obs_t)[0]
            expected_q = torch.matmul(q, z)
            action = torch.argmax(expected_q).cpu().numpy().tolist() if random.random() > eps else env.action_space.sample()
        obs, reward, done, info = env.step(action)
        print(f'====== step {i + 1} ======')
        env.render()
        if done:
            print(_)
            break
    env.close()

In [36]:
play_frozenlake(q_net, env, state["config"], 0.0)


[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
FFFH
H[41mF[0mFG
  (Right)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
FFFH
H[41mF[0mFG
  (Right)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG
  (Left)
SFFF
FHFH
FFFH
HF[41mF[0mG
  (Down)
SFFF
FHFH
FFFH
HFF[41mG[0m
C51MLP(
  (head): Sequential(
    (0): OneHot1d()
    (1): Linear(in_features=16, out_features=64, bias=True)
    (2): ReLU()
    (3): L