In [1]:
%load_ext autoreload
%autoreload 2
# %pylab populates the interactive namespace from numpy and matplotlib; later
# cells use `np` without importing it explicitly.
%pylab inline

import sys
import glob
import pandas as pd
import os
import seaborn as sns
# from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm
import pickle
from statsmodels.distributions.empirical_distribution import ECDF
from collections import defaultdict
import logging
from open_spiel.python.examples.ubc_mccfr_cpp_example import action_to_bids
from open_spiel.python.examples.ubc_nfsp_example import policy_from_checkpoint,  lookup_model_and_args
from open_spiel.python.examples.ubc_br import BR_DIR, make_dqn_agent
# NOTE(review): star import. `pyspiel` and `get_actions` are used in later
# cells but never imported by name; presumably they come from here — TODO
# confirm and import them explicitly.
from open_spiel.python.examples.ubc_utils import *
from open_spiel.python.examples.ubc_decorators import CachingAgentDecorator
from open_spiel.python.pytorch import ubc_nfsp, ubc_dqn, ubc_rnn

from open_spiel.python.pytorch.ubc_nfsp import NFSP
# bokeh is imported after the star import above, so its `figure`/`show`
# shadow any same-named objects from %pylab or ubc_utils; keep this order.
import bokeh
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import HoverTool, ColumnDataSource, ColorBar, LogColorMapper, LinearColorMapper
from bokeh.transform import linear_cmap, log_cmap
from open_spiel.python import rl_environment, policy

import yaml
import torch

# Render bokeh plots inline in the notebook.
output_notebook()

Populating the interactive namespace from numpy and matplotlib


In [32]:
# Experiment configuration: game, training length and reporting cadence.
# `pyspiel` is imported explicitly here rather than relying on it leaking in
# through the `ubc_utils` star import.
import random

import numpy as np
import pyspiel

# Seed the global RNGs so training runs are reproducible.
# NOTE(review): assumes the DQN agent / environment draw from these global
# generators — TODO confirm they do not keep private RNG state.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

game = pyspiel.load_game("hallway(width=11,height=1)")
NUM_ITERS = 100_000    # number of training episodes (one env.reset() each)
env = rl_environment.Environment(game)
agents = []            # indexed by player id; populated by the agent cell
report_freq = 5_000    # episodes between Q-value reports during training

In [33]:
# Build the single DQN learner for player 0 from a YAML network config.
player_id = 0
# Directory with the network YAML configs; override through the environment
# instead of editing an absolute path in the notebook.
config_folder = os.environ.get(
    'OPEN_SPIEL_NETWORK_CONFIGS',
    '/apps/open_spiel/open_spiel/python/examples/network_configs',
)
with open(os.path.join(config_folder, 'config_example_mlp.yml'), 'rb') as fh:
    config = yaml.load(fh, Loader=yaml.FullLoader)

num_actions = game.num_distinct_actions()
num_players = game.num_players()
state_size = env.observation_spec()["info_state"][0]

rl_model, rl_model_args = lookup_model_and_args(config['rl_model'], state_size, num_actions, num_players)

dqn_kwargs = {
    "replay_buffer_capacity": config['replay_buffer_capacity'],
    # Epsilon anneals from epsilon_start to epsilon_end over the full run.
    "epsilon_decay_duration": NUM_ITERS,
    "epsilon_start": 0.8,
    "epsilon_end": 0.01,
    "batch_size": config['batch_size'],
    "learning_rate": config['rl_learning_rate'],
    "learn_every": config['learn_every'],
    "min_buffer_size_to_learn": config['min_buffer_size_to_learn'],
    "optimizer_str": config['optimizer_str'],
    "update_target_network_every": config['update_target_network_every'],
    "loss_str": config['loss_str']
}

print(dqn_kwargs)

agent = ubc_dqn.DQN(
    player_id,
    num_actions,
    q_network_model=rl_model,
    q_network_args=rl_model_args,
    **dqn_kwargs
)
# Rebind instead of appending: re-running this cell must not leave a stale,
# previously built agent at agents[0] (the training loop indexes agents by
# player id). Single learner, for player 0.
agents = [agent]

{'replay_buffer_capacity': 50000, 'epsilon_decay_duration': 100000, 'epsilon_start': 0.8, 'epsilon_end': 0.01, 'batch_size': 256, 'learning_rate': 0.01, 'learn_every': 64, 'min_buffer_size_to_learn': 1000, 'optimizer_str': 'sgd', 'update_target_network_every': 1000, 'loss_str': 'mse'}


In [34]:
def check_on_q_values(agent, game, state=None):
    """Return the agent's Q-values for the legal actions at ``state``.

    Args:
        agent: a DQN-style agent exposing ``_q_network``, which must provide
            ``prep_batch`` and ``reshape_infostate`` and be callable on the
            prepared batch.
        game: the pyspiel game; only used to build the initial state when
            ``state`` is None.
        state: state to probe; defaults to ``game.new_initial_state()``.

    Returns:
        dict mapping each legal action's name (``state.action_to_string``
        for the player to move) to its Q-value as a 0-d tensor.
    """
    if state is None:
        state = game.new_initial_state()
    q_network = agent._q_network
    info_state = q_network.prep_batch(
        [q_network.reshape_infostate(state.information_state_tensor())])
    # Read-only probe: no autograd graph is needed.
    with torch.no_grad():
        q_values = q_network(info_state)[0]
    player = state.current_player()
    # Label each value from its own action id so names stay aligned with
    # values even when the legal actions at `state` are only a subset of
    # all actions.
    return {
        state.action_to_string(player, action): q_values[action]
        for action in state.legal_actions()
    }

In [35]:
# Train for NUM_ITERS episodes, periodically reporting the Q-values of the
# monitored agent at the initial state. `agent` is bound explicitly here
# (rather than leaking out of the loops below) because later cells inspect it.
agent = agents[0]

for i in tqdm(range(NUM_ITERS)):
    if i % report_freq == 0:
        # tqdm.write prints above the progress bar instead of tearing it.
        tqdm.write(str(check_on_q_values(agent, game)))

    time_step = env.reset()
    while not time_step.last():
        current_player = time_step.observations["current_player"]
        agent_output = agents[current_player].step(time_step)
        time_step = env.step([agent_output.action])

    # Episode is over, step all agents with final info state.
    for learner in agents:
        learner.step(time_step)

# The in-loop report fires before the episode at each multiple of
# report_freq, so the values after the final episodes are shown here.
tqdm.write(str(check_on_q_values(agent, game)))


  0%|          | 25/100000 [00:00<06:53, 241.68it/s]

{'LEFT': tensor(0.), 'RIGHT': tensor(0.)}


  5%|▌         | 5028/100000 [00:21<06:28, 244.31it/s]

{'LEFT': tensor(0.5883), 'RIGHT': tensor(1.1203)}


 10%|█         | 10047/100000 [00:44<06:51, 218.36it/s]

{'LEFT': tensor(4.7508), 'RIGHT': tensor(5.6708)}


 15%|█▌        | 15021/100000 [01:21<13:56, 101.62it/s]

{'LEFT': tensor(6.7380), 'RIGHT': tensor(6.7801)}


 20%|██        | 20024/100000 [02:10<11:14, 118.65it/s]

{'LEFT': tensor(6.0035), 'RIGHT': tensor(6.1433)}


 25%|██▌       | 25018/100000 [02:57<11:39, 107.19it/s]

{'LEFT': tensor(5.9075), 'RIGHT': tensor(6.0698)}


 30%|███       | 30006/100000 [03:50<12:32, 93.04it/s] 

{'LEFT': tensor(5.5480), 'RIGHT': tensor(5.5749)}


 35%|███▌      | 35006/100000 [04:58<15:09, 71.47it/s] 

{'LEFT': tensor(5.5554), 'RIGHT': tensor(5.6106)}


 40%|████      | 40003/100000 [06:27<22:25, 44.60it/s]

{'LEFT': tensor(5.8169), 'RIGHT': tensor(5.8725)}


 45%|████▌     | 45010/100000 [07:59<21:37, 42.38it/s]

{'LEFT': tensor(6.1838), 'RIGHT': tensor(6.1371)}


 50%|█████     | 50009/100000 [09:29<13:39, 60.98it/s]

{'LEFT': tensor(6.1418), 'RIGHT': tensor(6.1001)}


 55%|█████▌    | 55013/100000 [10:39<09:53, 75.75it/s] 

{'LEFT': tensor(6.2447), 'RIGHT': tensor(6.1954)}


 60%|██████    | 60022/100000 [11:38<06:05, 109.53it/s]

{'LEFT': tensor(5.9129), 'RIGHT': tensor(5.8635)}


 65%|██████▌   | 65014/100000 [12:20<04:01, 145.08it/s]

{'LEFT': tensor(5.5284), 'RIGHT': tensor(5.4884)}


 70%|███████   | 70006/100000 [12:53<03:01, 165.69it/s]

{'LEFT': tensor(5.3055), 'RIGHT': tensor(5.2899)}


 75%|███████▌  | 75004/100000 [13:21<02:07, 196.49it/s]

{'LEFT': tensor(5.1276), 'RIGHT': tensor(5.1045)}


 80%|████████  | 80015/100000 [13:51<02:28, 134.35it/s]

{'LEFT': tensor(4.9300), 'RIGHT': tensor(4.9280)}


 85%|████████▌ | 85022/100000 [14:31<02:08, 116.18it/s]

{'LEFT': tensor(5.0513), 'RIGHT': tensor(5.0345)}


 90%|█████████ | 90014/100000 [15:17<01:32, 107.85it/s]

{'LEFT': tensor(5.0715), 'RIGHT': tensor(5.0424)}


 95%|█████████▌| 95021/100000 [16:05<00:44, 111.31it/s]

{'LEFT': tensor(5.0300), 'RIGHT': tensor(5.0120)}


100%|██████████| 100000/100000 [16:54<00:00, 98.54it/s]


In [38]:
# Raw Q-network output for every action (legal or not) at the state reached
# by applying action 0 twice from the initial state.
probe_state = game.new_initial_state().child(0).child(0)
probe_info_state = torch.Tensor(probe_state.information_state_tensor())
# Inspection only: skip building an autograd graph.
with torch.no_grad():
    probe_q_values = agent._q_network(probe_info_state)
probe_q_values

tensor([ 4.9416, -0.1797,  5.0058,  0.1988], grad_fn=<AddBackward0>)

In [41]:
# Number of stored replay-buffer transitions whose reward is exactly 5.
sum(1 for transition in agent.replay_buffer._data if transition.reward == 5)

2004