In [1]:
from copy import deepcopy
import dill
import logging
import os
from pprint import pprint

from human_aware_rl.rllib.rllib import gen_trainer_from_params, load_agent, OvercookedMultiAgent
from human_aware_rl.ppo.ppo_rllib_client import my_config
from overcooked_ai_py.agents.benchmarking import AgentEvaluator
from overcooked_ai_py.mdp.actions import Action
from overcooked_ai_py.mdp.layout_generator import DEFAILT_PARAMS_SCHEDULE_FN, LayoutGenerator, MDPParamsGenerator

RAY_DIRECTORY = os.path.expanduser("~/ray_results")

In [2]:
def checkpoint_path(run_name, seed=0, checkpoint_num=1):
    """Build the path of a checkpoint file for a run stored under RAY_DIRECTORY.

    A run directory is any entry of RAY_DIRECTORY whose name contains
    ``f"{run_name}_{seed}"``.  When several entries match, the
    lexicographically smallest name is used.  The match is a plain substring
    test, so e.g. ``seed=1`` also matches a directory created for seed 1024.

    Args:
        run_name: Name fragment identifying the run.
        seed: Seed fragment that follows ``run_name`` in the directory name.
        checkpoint_num: Number used in both the checkpoint directory and the
            checkpoint file name.

    Returns:
        The string
        ``{RAY_DIRECTORY}/{run_dir}/checkpoint_{n}/checkpoint-{n}``.
        The path is only assembled; it is not checked for existence.

    Raises:
        IndexError: If no entry of RAY_DIRECTORY matches.  The exception type
            is the one the original bare ``[0]`` lookup produced; only the
            message is more descriptive.
    """
    run_tag = f"{run_name}_{seed}"
    matches = [entry for entry in os.listdir(RAY_DIRECTORY) if run_tag in entry]
    if not matches:
        raise IndexError(
            f"no run directory containing '{run_tag}' found in {RAY_DIRECTORY}"
        )
    # min() picks the same entry as sorted(...)[0] without sorting everything.
    run_dir = min(matches)
    return f"{RAY_DIRECTORY}/{run_dir}/checkpoint_{checkpoint_num}/checkpoint-{checkpoint_num}"

def load_params(run_name, seed=0):
    """Load the pickled training config saved next to a run's checkpoint.

    The config is read from ``config.pkl`` in the directory that holds the
    checkpoint file returned by ``checkpoint_path(run_name, seed)`` (i.e. the
    default checkpoint number of that function).

    Args:
        run_name: Run name forwarded to ``checkpoint_path``.
        seed: Seed forwarded to ``checkpoint_path``.

    Returns:
        Whatever object ``dill.load`` deserializes from ``config.pkl``.
    """
    # Previously the run was hard-coded to ("10k", 1024), so the arguments
    # were silently ignored and every caller got the same config.
    cp_path = checkpoint_path(run_name, seed)
    params_path = os.path.join(os.path.dirname(cp_path), "config.pkl")
    # NOTE: dill.load can execute arbitrary code; only load configs you trust.
    with open(params_path, "rb") as params_file:
        return dill.load(params_file)

def load_env(run_name, seed=0):
    """Recreate a run's environment from its saved config.

    Reads the config via ``load_params(run_name, seed)`` and passes its
    ``"environment_params"`` entry to ``OvercookedMultiAgent.from_config``.

    Returns:
        The object returned by ``OvercookedMultiAgent.from_config``.
    """
    env_config = load_params(run_name, seed)["environment_params"]
    return OvercookedMultiAgent.from_config(env_config)

def load_agents(run_name, seeds, checkpoint_num):
    """Load one "ppo" agent per seed from the same checkpoint number.

    Args:
        run_name: Run name forwarded to ``checkpoint_path``.
        seeds: Iterable of seeds; each one becomes a key of the result.
        checkpoint_num: Checkpoint number forwarded to ``checkpoint_path``.

    Returns:
        Dict mapping each seed to the agent returned by ``load_agent``.
    """
    return {
        seed: load_agent(
            checkpoint_path(run_name, seed=seed, checkpoint_num=checkpoint_num),
            policy_id="ppo",
            # Placeholder index: set to 0 or 1 when initializing an episode.
            agent_index=-1,
        )
        for seed in seeds
    }

In [3]:
def cross_play(mdp, agent_0, agent_1, num_games=100, rnd_obj_prob_thresh=0.0):
    """Roll out both agents alone and together, printing mean episode returns.

    Three evaluations are run in order, each on a fresh
    ``AgentEvaluator.from_mdp(mdp, params)`` with a horizon of 400:
    ``agent_0`` passed as ``a0`` only, ``agent_1`` passed as ``a0`` only
    (both labelled "self-play" in the printed summary), and finally
    ``agent_0`` as ``a0`` with ``agent_1`` as ``a1``.

    Args:
        mdp: Object providing ``get_random_start_state_fn``; also handed to
            ``AgentEvaluator.from_mdp``.
        agent_0: First agent.
        agent_1: Second agent.
        num_games: Number of games per evaluation.
        rnd_obj_prob_thresh: Forwarded to ``mdp.get_random_start_state_fn``.

    Returns:
        Tuple ``(trajs_0_0, trajs_1_1, trajs_0_1)`` with the results of the
        three evaluations in the order described above.
    """
    mlam_params = {
        'start_orientations': False,
        'wait_allowed': False,
        'counter_goals': [],
        'counter_drop': [],
        'counter_pickup': [],
        'same_motion_goals': True
    }
    params = {'horizon': 400, 'mlam_params': mlam_params}
    start_state_fn = mdp.get_random_start_state_fn(
        random_start_pos=False, rnd_obj_prob_thresh=rnd_obj_prob_thresh
    )

    def rollout(**agents):
        # A new evaluator per rollout, exactly as the three inline calls did.
        evaluator = AgentEvaluator.from_mdp(mdp, params)
        return evaluator.get_agent_pair_trajs(
            num_games=num_games, start_state_fn=start_state_fn, **agents
        )

    trajs_0_0 = rollout(a0=agent_0)
    print(f"agent 0 self-play: {trajs_0_0[0]['ep_returns'].mean()}")

    trajs_1_1 = rollout(a0=agent_1)
    print(f"agent 1 self-play: {trajs_1_1[0]['ep_returns'].mean()}")

    trajs_0_1 = rollout(a0=agent_0, a1=agent_1)
    print(f"cross-play: {trajs_0_1[0]['ep_returns'].mean()}")

    return trajs_0_0, trajs_1_1, trajs_0_1

In [4]:
# Disabled alternative: build the MDP from the layout generator instead of
# reusing the one stored with the "10k" run (presumably a freshly sampled layout).
# mdp = LayoutGenerator(MDPParamsGenerator(DEFAILT_PARAMS_SCHEDULE_FN)).generate_padded_mdp()
# Evaluate on the MDP stored in the saved environment config of run "10k", seed 1024.
mdp = load_env("10k", 1024).base_env.mdp
# Checkpoints to compare: run "10k" at checkpoint 6601 vs. run "mod" at checkpoint 1426.
path_0 = checkpoint_path("10k", seed=1024, checkpoint_num=6601)
path_1 = checkpoint_path("mod", seed=1024, checkpoint_num=1426)
# Only 10 games per pairing here (cross_play defaults to 100).
num_games = 10
rnd_obj_prob_thresh = 0.0

# Each agent is bound to a fixed player index: agent_0 -> 0, agent_1 -> 1.
agent_0 = load_agent(path_0, policy_id="ppo", agent_index=0)
agent_1 = load_agent(path_1, policy_id="ppo", agent_index=1)
# Results stay in notebook globals because later cells read them (e.g. trajs_1_1).
trajs_0_0, trajs_1_1, trajs_0_1 = cross_play(mdp, agent_0, agent_1, num_games, rnd_obj_prob_thresh)

2021-07-27 12:18:49,417	INFO resource_spec.py:212 -- Starting Ray with 14.26 GiB memory available for workers and up to 7.14 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2021-07-27 12:18:49,984	INFO trainer.py:421 -- Tip: set 'eager': true or the --eager flag to enable TensorFlow eager execution
2021-07-27 12:18:50,051	INFO trainer.py:580 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
2021-07-27 12:18:57,726	INFO trainable.py:217 -- Getting current IP.
2021-07-27 12:18:57,792	INFO trainable.py:217 -- Getting current IP.
2021-07-27 12:18:57,793	INFO trainable.py:423 -- Restored on 192.168.1.233 from checkpoint: /home/anchorwatt/ray_results/10k_1024_2021-07-17_19-09-08lxahhjhf/checkpoint_6601/checkpoint-6601
2021-07-27 12:18:57,793	INFO trainable.py:430 -- Current state after restoring: {'_iteration': 6601, '_timesteps_total': 79212000, '_time_total': 92236.27

Skipping trajectory consistency checking because MDP was recognized as variable. Trajectory consistency checking is not yet supported for variable MDPs.
agent 0 self-play: 62.0


Avg rew: 0.00 (std: 0.00, se: 0.00); avg len: 400.00; : 100%|███████████████████████████| 10/10 [00:12<00:00,  1.26s/it]


Skipping trajectory consistency checking because MDP was recognized as variable. Trajectory consistency checking is not yet supported for variable MDPs.
agent 1 self-play: 0.0


Avg rew: 32.00 (std: 16.00, se: 5.06); avg len: 400.00; : 100%|█████████████████████████| 10/10 [00:12<00:00,  1.25s/it]


Skipping trajectory consistency checking because MDP was recognized as variable. Trajectory consistency checking is not yet supported for variable MDPs.


Avg rew: 32.00 (std: 20.40, se: 6.45); avg len: 400.00; : 100%|█████████████████████████| 10/10 [00:12<00:00,  1.28s/it]

Skipping trajectory consistency checking because MDP was recognized as variable. Trajectory consistency checking is not yet supported for variable MDPs.
cross-play: 32.0





In [5]:
# NOTE(review): wildcard import in the middle of the notebook; this cell only
# uses StateVisualizer. Consider an explicit import in the top import cell, but
# first confirm no other cell relies on additional names pulled in here.
from overcooked_ai_py.visualization.state_visualizer import *

# Show the first element of trajs_1_1 (the agent_1-only rollout) as an
# interactive rendering; images go to img_directory_path.
# NOTE(review): the trajectory is trajs_1_1 but the directory is named
# "traj_0_0" — confirm which one is intended. The absolute home path is also
# machine-specific.
StateVisualizer().display_rendered_trajectory(
    trajs_1_1[0], img_directory_path="/home/anchorwatt/traj_0_0"
)

pygame 1.9.5
Hello from the pygame community. https://www.pygame.org/contribute.html


interactive(children=(IntSlider(value=0, description='timestep', max=399), Output()), _dom_classes=('widget-in…

In [12]:
# NOTE(review): leftover experiment. The recorded output is a PicklingError, the
# cell was run out of order (In [12]), and `pickle` is not imported in any
# visible cell. Remove it or wrap it in an explicit demonstration before sharing.
pickle.dumps(agent_0.policy)

PicklingError: Can't pickle <class 'ray.rllib.policy.tf_policy_template.PPOTFPolicy'>: attribute lookup PPOTFPolicy on ray.rllib.policy.tf_policy_template failed

In [7]:
# Display the repr of the agent loaded from path_1 (quick type check).
agent_1

<human_aware_rl.rllib.rllib.RlLibAgent at 0x7f7f7475dad0>

ALSA lib pcm.c:8545:(snd_pcm_recover) underrun occurred
ALSA lib pcm.c:8545:(snd_pcm_recover) underrun occurred
ALSA lib pcm.c:8545:(snd_pcm_recover) underrun occurred
