In [1]:
from copy import deepcopy
import dill
import logging
import os
from pprint import pprint

from human_aware_rl.rllib.rllib import gen_trainer_from_params, load_agent, OvercookedMultiAgent
from human_aware_rl.ppo.ppo_rllib_client import my_config
from overcooked_ai_py.agents.benchmarking import AgentEvaluator
from overcooked_ai_py.mdp.actions import Action
from overcooked_ai_py.mdp.layout_generator import DEFAILT_PARAMS_SCHEDULE_FN, LayoutGenerator, MDPParamsGenerator

# Root folder where Ray writes experiment runs; run dirs (with their checkpoints and config.pkl) are looked up here.
RAY_DIRECTORY = os.path.join(os.path.expanduser("~"), "ray_results")

In [2]:
def checkpoint_path(run_name, seed=0, checkpoint_num=1, ray_directory=None):
    """Return the path of a saved checkpoint for the run ``run_name``/``seed``.

    Run directories are expected to be named ``{run_name}_{seed}_<timestamp...>``
    (e.g. ``ring_1_2021-07-28_09-04-39c5zudb8n``). When several runs match,
    the first one in sorted (lexicographic) order is used.

    Args:
        run_name: experiment name, e.g. ``"ring"``.
        seed: seed suffix of the run directory.
        checkpoint_num: checkpoint number to point at.
        ray_directory: results root to search; defaults to ``RAY_DIRECTORY``.

    Returns:
        ``"{ray_directory}/{run_dir}/checkpoint_{n}/checkpoint-{n}"``.

    Raises:
        FileNotFoundError: if no run directory matches ``run_name``/``seed``.
    """
    root = RAY_DIRECTORY if ray_directory is None else ray_directory
    # The trailing "_" stops seed 1 from also matching seed 10's runs:
    # "ring_1" is a substring of "ring_10_...", which sorts before "ring_1_...".
    run_prefix = f"{run_name}_{seed}_"
    matches = sorted(r for r in os.listdir(root) if run_prefix in r)
    if not matches:
        raise FileNotFoundError(f"No run directory matching '{run_prefix}' in {root}")
    run_dir = matches[0]
    return f"{root}/{run_dir}/checkpoint_{checkpoint_num}/checkpoint-{checkpoint_num}"

def load_params(run_name, seed=0):
    """Load the pickled training config (``config.pkl``) of run ``run_name``/``seed``.

    The file is read from the directory of the run's first checkpoint
    (``checkpoint_1``, the default of ``checkpoint_path``).

    Returns:
        The unpickled object; callers index it like a dict,
        e.g. ``params["environment_params"]``.
    """
    cp_path = checkpoint_path(run_name, seed)
    # Drop the trailing "checkpoint-N" file name to get the checkpoint directory.
    params_path = "/".join(cp_path.split("/")[:-1]) + "/config.pkl"
    # SECURITY: dill/pickle loading can execute arbitrary code -- only load
    # results you produced yourself. `with` closes the handle (a bare open() leaked it).
    with open(params_path, "rb") as f:
        return dill.load(f)

def load_env(run_name, seed=0):
    """Rebuild the OvercookedMultiAgent env from a run's saved ``environment_params``."""
    env_config = load_params(run_name, seed)["environment_params"]
    return OvercookedMultiAgent.from_config(env_config)

def load_agents(run_name, seeds, checkpoint_num):
    """Load the ``"ppo"`` policy agent of each seed's run at ``checkpoint_num``.

    Returns:
        dict mapping each seed to its loaded agent.
    """
    return {
        seed: load_agent(
            checkpoint_path(run_name, seed=seed, checkpoint_num=checkpoint_num),
            policy_id="ppo",
            agent_index=-1,  # placeholder: set to 0 or 1 when initializing an episode
        )
        for seed in seeds
    }

In [3]:
def _evaluate_pairing(mdp, params, label, num_games, **agents):
    """Run one evaluation with a fresh AgentEvaluator, print its mean return, return the result.

    ``agents`` is forwarded to ``get_agent_pair_trajs``: ``a0`` alone for
    self-play, ``a0`` and ``a1`` for cross-play.
    """
    trajs = AgentEvaluator.from_mdp(mdp, params).get_agent_pair_trajs(
        num_games=num_games, **agents
    )
    # NOTE(review): trajs[0] is presumably the trajectories with a0 in seat 0 -- confirm.
    print(f"{label}: {trajs[0]['ep_returns'].mean()}")
    return trajs


def cross_play(mdp, agent_0, agent_1, num_games=100, rnd_obj_prob_thresh=0.0, horizon=400):
    """Evaluate each agent in self-play and the two agents in cross-play on ``mdp``.

    Prints the mean episode return of each of the three evaluations.

    Args:
        mdp: the MDP / layout to evaluate on.
        agent_0, agent_1: agents to evaluate.
        num_games: number of games per evaluation.
        rnd_obj_prob_thresh: currently IGNORED (random start states are disabled,
            see TODO below); kept so existing calls keep working.
        horizon: episode length passed to the evaluator (default 400, as before).

    Returns:
        (trajs_0_0, trajs_1_1, trajs_0_1): the ``get_agent_pair_trajs`` results for
        agent_0 self-play, agent_1 self-play, and agent_0 / agent_1 cross-play.
    """
    params = {
        'horizon': horizon,
        'mlam_params': {
            'start_orientations': False,
            'wait_allowed': False,
            'counter_goals': [],
            'counter_drop': [],
            'counter_pickup': [],
            'same_motion_goals': True
        }
    }
    # TODO: rnd_obj_prob_thresh is unused. It was meant for
    #   mdp.get_random_start_state_fn(random_start_pos=False, rnd_obj_prob_thresh=rnd_obj_prob_thresh)
    # TODO: change to AgentEvaluator.from_mdp_lst
    trajs_0_0 = _evaluate_pairing(mdp, params, "agent 0 self-play", num_games, a0=agent_0)
    trajs_1_1 = _evaluate_pairing(mdp, params, "agent 1 self-play", num_games, a0=agent_1)
    trajs_0_1 = _evaluate_pairing(mdp, params, "cross-play", num_games, a0=agent_0, a1=agent_1)
    return trajs_0_0, trajs_1_1, trajs_0_1

In [13]:
# Evaluate two checkpoints on the layout the run was trained on.
# Alternative: a freshly generated layout, e.g.
#   mdp = LayoutGenerator(MDPParamsGenerator(DEFAILT_PARAMS_SCHEDULE_FN)).generate_padded_mdp()
RUN_NAME = "ring"
SEED_0, SEED_1 = 1, 1  # NOTE(review): identical seeds => both agents load the same checkpoint
CHECKPOINT_NUM = 601
num_games = 10
rnd_obj_prob_thresh = 0.0  # not used by cross_play's evaluation

mdp = load_env(RUN_NAME, SEED_0).base_env.mdp
path_0 = checkpoint_path(RUN_NAME, seed=SEED_0, checkpoint_num=CHECKPOINT_NUM)
path_1 = checkpoint_path(RUN_NAME, seed=SEED_1, checkpoint_num=CHECKPOINT_NUM)
if path_0 == path_1:
    # Make the degenerate setup visible instead of silently reporting self-play as cross-play.
    logging.warning(
        "path_0 and path_1 are the same checkpoint (%s); the 'cross-play' score is really self-play.",
        path_0,
    )

agent_0 = load_agent(path_0, policy_id="ppo", agent_index=0)
agent_1 = load_agent(path_1, policy_id="ppo", agent_index=1)
trajs_0_0, trajs_1_1, trajs_0_1 = cross_play(mdp, agent_0, agent_1, num_games, rnd_obj_prob_thresh)

2021-07-28 12:59:53,432	INFO trainable.py:217 -- Getting current IP.
2021-07-28 12:59:53,631	INFO trainable.py:217 -- Getting current IP.
2021-07-28 12:59:53,634	INFO trainable.py:423 -- Restored on 192.168.1.233 from checkpoint: /home/anchorwatt/ray_results/ring_1_2021-07-28_09-04-39c5zudb8n/checkpoint_601/checkpoint-601
2021-07-28 12:59:53,635	INFO trainable.py:430 -- Current state after restoring: {'_iteration': 601, '_timesteps_total': 7212000, '_time_total': 12697.656646728516, '_episodes_total': 18030}


2021-07-28 13:00:08,575	INFO trainable.py:180 -- _setup took 14.833 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
2021-07-28 13:00:08,577	INFO trainable.py:217 -- Getting current IP.
2021-07-28 13:00:08,688	INFO trainable.py:217 -- Getting current IP.
2021-07-28 13:00:08,690	INFO trainable.py:423 -- Restored on 192.168.1.233 from checkpoint: /home/anchorwatt/ray_results/ring_1_2021-07-28_09-04-39c5zudb8n/checkpoint_601/checkpoint-601
2021-07-28 13:00:08,691	INFO trainable.py:430 -- Current state after restoring: {'_iteration': 601, '_timesteps_total': 7212000, '_time_total': 12697.656646728516, '_episodes_total': 18030}
Avg rew: 0.00 (std: 0.00, se: 0.00); avg len: 400.00; : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:30<00:00,  3.03s/it]


Skipping trajectory consistency checking because MDP was recognized as variable. Trajectory consistency checking is not yet supported for variable MDPs.
agent 0 self-play: 0.0


Avg rew: 0.00 (std: 0.00, se: 0.00); avg len: 400.00; : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:16<00:00,  1.67s/it]


Skipping trajectory consistency checking because MDP was recognized as variable. Trajectory consistency checking is not yet supported for variable MDPs.
agent 1 self-play: 0.0


Avg rew: 0.00 (std: 0.00, se: 0.00); avg len: 400.00; : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:26<00:00,  2.62s/it]


Skipping trajectory consistency checking because MDP was recognized as variable. Trajectory consistency checking is not yet supported for variable MDPs.


Avg rew: 0.00 (std: 0.00, se: 0.00); avg len: 400.00; : 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:20<00:00,  2.07s/it]

Skipping trajectory consistency checking because MDP was recognized as variable. Trajectory consistency checking is not yet supported for variable MDPs.
cross-play: 0.0





In [14]:
from overcooked_ai_py.visualization.state_visualizer import StateVisualizer

# Render agent 1's self-play trajectory. The image directory is named after the
# trajectory actually shown (the old "traj_0_0" name did not match trajs_1_1),
# and is built from "~" rather than a hard-coded home path.
TRAJ_IMG_DIR = os.path.expanduser("~/traj_1_1")
StateVisualizer().display_rendered_trajectory(trajs_1_1[0], img_directory_path=TRAJ_IMG_DIR)

interactive(children=(IntSlider(value=0, description='timestep', max=399), Output()), _dom_classes=('widget-in…