In [1]:
import json
import torch

state = torch.load('models/ddqn/frozenlake-v1/reproducibility-0_1alf0ajm_245759_20220131T202108.pth')
#state = torch.load('models/ddqn/cartpole-v0/reproducibility-0_1wpkgfkj_327679_20220130T222047.pth')
print(json.dumps(state['config'], indent=2, sort_keys=True))

{
  "activation": "relu",
  "alpha": 0.6,
  "batch_size": 1024,
  "buffer_size": 200000,
  "buffer_type": "uniform",
  "embed": false,
  "embed_size": 0,
  "env": "FrozenLake-v1",
  "env_args": {
    "is_slippery": true
  },
  "eps_sched_final": 0.02,
  "eps_sched_len": 100000,
  "gamma": 0.995,
  "layer_size": 64,
  "learning_starts": 2048,
  "log_step": 4096,
  "lr": 0.001,
  "max_episode_steps": null,
  "num_layers": 3,
  "save_final": true,
  "save_max_eps": false,
  "seed": 0,
  "steps": 245760,
  "target_update_freq": 8192,
  "training_freq": 64
}


In [2]:
import gym
from spin_class.algos.ddqn import make_model

kwargs = state['config']['env_args'] if 'env_args' in state['config'] else {}
env = gym.make(state['config']['env'], **kwargs)
device = torch.device('cpu')
q_net = make_model(env, device, state['config'])
q_net.load_state_dict(state['q_state_dict'])
q_net.eval()

DQNMLP(
  (head): Sequential(
    (0): OneHot1d()
    (1): Linear(in_features=16, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=64, bias=True)
    (6): ReLU()
    (7): Linear(in_features=64, out_features=4, bias=True)
  )
)

In [3]:
import gym
from gym import wrappers
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display
from torch.distributions.normal import Normal

# start virtual display
if 'display' not in globals():
    display = Display(visible=False, size=(1400, 900))
    display.start()

def play(q_net, env, steps=1000):
    env = wrappers.Monitor(env, "./video", force=True)
    obs_dtype = (
        torch.int64
        if isinstance(env.observation_space, gym.spaces.Discrete)
        else torch.float32
    )
    obs = env.reset()
    for _ in range(steps):
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=obs_dtype, device=device).unsqueeze(
                    0
                )
            q = q_net(obs_t)[0]
            action = torch.argmax(q).cpu().numpy().tolist()
        obs, reward, done, info = env.step(action)
        if done:
            print(_)
            break
    env.close()

    video = io.open('./video/openaigym.video.%s.video000000.mp4' % env.file_infix, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''
        <video alt="test" autoplay loop controls style="height: 400px;">
            <source src="data:video/mp4;base64,{0}" type="video/mp4" />
        </video>'''.format(encoded.decode('ascii'))))

    #HTML(data='''
    #    <video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>
    #'''.format(encoded.decode('ascii')))

In [None]:
play(q_net, env, steps=200)

In [14]:
import random
import torch

def play_frozenlake(q_net, env, eps=0.0):
    obs_dtype = (
        torch.int64
        if isinstance(env.observation_space, gym.spaces.Discrete)
        else torch.float32
    )
    obs = env.reset()
    print('====== step 0 ======')
    env.render()
    for i in range(100):
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=obs_dtype, device=device).unsqueeze(0)
            q = q_net(obs_t)[0]
            action = torch.argmax(q).cpu().numpy().tolist() if random.random() > eps else env.action_space.sample()
        obs, reward, done, info = env.step(action)
        print(f'====== step {i + 1} ======')
        env.render()
        if done:
            print(_)
            break
    env.close()

In [17]:
play_frozenlake(q_net, env, 0.0)


[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
[41mS[0mFFF
FHFH
FFFH
HFFG
  (Left)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Left)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
FFFH
H[41mF[0mFG
  (Right)
SFFF
FHFH
FFFH
H[41mF[0mFG
  (Right)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG
  (Left)
SFFF
FHFH
F[41mF[0mFH
HFFG
  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG
  (Left)
SFFF
FHFH
FFFH
HF[41mF[0mG
  (Down)
SFFF
FHFH
FFFH
HFF[41mG[0m
DQNMLP(
  (head): Sequential(
    (0): OneHot1d()
    (1): Linear(in_features=16, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=64, bias=True)
    (6): ReLU()
    (7): Linear(in_features=64, out_features=4, bias=True)
  )
)
