In [None]:
from matplotlib import animation
from pprint import pprint
import ray
%matplotlib notebook

from run import *

In [None]:
def run_evaluation(
    agent_run_names,
    agent_checkpoints,
    config_name,
    policy_name="ppo",
    seed=1,
    heterogeneous=True,
):
    """Restore one trained agent per environment slot and run a recorded evaluation.

    Ray is restarted, the "ZSC-Cleaner" environment is registered under a run
    name derived from the evaluated runs/checkpoints, and for every agent slot
    a trainer is created and loaded from
    ``~/ray_results/<run_name>/checkpoint_<NNNNNN>/checkpoint-<N>``.

    Args:
        agent_run_names: Run names, one per agent slot; entry ``i`` is the
            run that agent ``i`` is restored from.
        agent_checkpoints: Checkpoint numbers, parallel to ``agent_run_names``.
        config_name: Name handed to ``load_config``.
        policy_name: Forwarded to ``Agent`` (and, via the agent, to
            ``create_trainer``).
        seed: Forwarded to ``Agent`` (and, via the agent, to ``create_trainer``).
        heterogeneous: Forwarded to ``Agent``. When the resulting agent is
            heterogeneous, its trainer is built from one renamed copy of the
            agent per environment slot; otherwise from the agent alone.

    Returns:
        The value returned by ``evaluate(..., num_episodes=1, record=True)``
        (presumably an animation object -- confirm against ``evaluate``).

    Raises:
        ValueError: If ``agent_run_names`` and ``agent_checkpoints`` differ in
            length, or the config asks for more agents than were supplied.
    """
    # Fail before touching Ray if the two parallel lists cannot be paired up.
    if len(agent_run_names) != len(agent_checkpoints):
        raise ValueError(
            f"agent_run_names ({len(agent_run_names)}) and agent_checkpoints "
            f"({len(agent_checkpoints)}) must have the same length"
        )

    # Restart Ray so repeated calls from the notebook start from a clean state.
    ray.shutdown()
    ray.init()
    config = load_config(config_name)

    num_agents = config["env_config"]["num_agents"]
    if num_agents > len(agent_run_names):
        raise ValueError(
            f"config '{config_name}' needs {num_agents} agents but only "
            f"{len(agent_run_names)} run names were given"
        )

    # e.g. ["a", "b"], [1, 2] -> "a_1_b_2"
    eval_run_name = "_".join(
        f"{run_name}_{checkpoint}"
        for run_name, checkpoint in zip(agent_run_names, agent_checkpoints)
    )
    ray_dir = f"{os.path.expanduser('~')}/ray_results"
    register_env("ZSC-Cleaner", lambda _: CleanerEnv(config["env_config"], run_name=eval_run_name))

    agents = {}
    for i in range(num_agents):
        agent = Agent(
            policy_name=policy_name,
            run_name=agent_run_names[i],
            agent_num=i,
            config=config,
            seed=seed,
            heterogeneous=heterogeneous,
        )
        agents[agent.name] = agent

        trainer_agents = {}
        if agent.heterogeneous:
            # One renamed copy of this agent per slot, keyed "<run_name>:<slot>".
            for other_agent_num in range(agent.config["env_config"]["num_agents"]):
                other_agent = deepcopy(agent)
                other_agent.agent_num = other_agent_num
                other_agent.name = f"{agent.run_name}:{other_agent_num}"
                trainer_agents[other_agent.name] = other_agent
        else:
            # NOTE(review): the original branch was an unfinished `agent.name = `
            # hanging off the `for` loop (a SyntaxError). Registering the agent
            # under its existing name is the assumed intent -- confirm that this
            # matches the policy names stored in homogeneous checkpoints.
            trainer_agents[agent.name] = agent

        agent.trainer = create_trainer(
            agent.policy_name,
            trainer_agents,
            agent.config,
            agent.results_dir,
            seed=agent.seed,
            heterogeneous=agent.heterogeneous,
            num_workers=1,
        )
        checkpoint_num = agent_checkpoints[i]
        # Directory name is zero-padded to six digits; the file name is not.
        checkpoint_path = f"{ray_dir}/{agent.run_name}/checkpoint_" \
                          f"{str(checkpoint_num).zfill(6)}/checkpoint-{checkpoint_num}"
        agent.trainer.load_checkpoint(checkpoint_path)

    ani = evaluate(
        agents,
        config,
        eval_run_name,
        num_episodes=1,
        record=True,
    )
    return ani

In [None]:
# Evaluation inputs: two runs, each restored from checkpoint 1001,
# evaluated on the "simple_11x11" config with homogeneous agents.
heterogeneous = False
config_name = "simple_11x11"
agent_run_names, agent_checkpoints = ["hom123", "hom456"], [1001, 1001]

# Keep the call as the cell's last expression so its return value is displayed.
run_evaluation(
    agent_run_names=agent_run_names,
    agent_checkpoints=agent_checkpoints,
    config_name=config_name,
    policy_name="ppo",
    seed=1,
    heterogeneous=heterogeneous,
)

In [9]:
# Build the grid for the "ring_11x11" config, then show the agent positions
# that the helper extracts from it. `grid` is reused by the following cell.
ring_config = load_config("ring_11x11")
grid = grid_from_config(ring_config)
agent_pos_from_grid(grid)

[Position(i=1, j=1), Position(i=1, j=11)]

In [15]:
# Locate agents by scanning the grid's "agent" layer directly:
# np.where yields one index array per axis; pair the first two up as Positions.
agent_pos = np.where(grid["agent"])
[Position(row, col) for row, col in zip(agent_pos[0], agent_pos[1])]

[Position(i=1, j=1),
 Position(i=1, j=11),
 Position(i=11, j=1),
 Position(i=11, j=11)]