In [71]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [72]:
from functools import partial

import jax
import jax.numpy as jnp

from calculate_metric import get_stats_for_state
from visualize_actor import get_state_traj


In [73]:
# Evaluation setup: which artifact version to fetch (downloaded through wandb)
# and how many episodes to request from get_state_traj.
artifact_version = "467"
num_episodes = 100
model_artifact_remote_name = f"josssdan/JaxInforMARL/PPO_RNN_Runner_State:v{artifact_version}"

# get_state_traj hands back the trajectory batch together with the run config
# and the environment instance used by the cells below.
traj_batch, config, env = get_state_traj(
    model_artifact_remote_name,
    artifact_version,
    num_episodes=num_episodes,
    store_action_field=False,
)

[34m[1mwandb[0m:   12 of 12 files downloaded.  


Config:
{'derived_values': {'minibatch_size': 19200,
                    'num_actors': 300,
                    'num_updates': 39,
                    'scaled_clip_eps': 0.2},
 'env_config': {'env_cls_name': 'TargetMPEEnvironment',
                'env_kwargs': {'agent_communication_type': 'HIDDEN_STATE',
                               'agent_control_noise_std': 0.0,
                               'agent_max_speed': -1,
                               'agent_visibility_radius': [0.5],
                               'collision_reward_coefficient': -1,
                               'distance_to_goal_reward_coefficient': 10,
                               'entities_initial_coord_radius': [1],
                               'entity_acceleration': 5,
                               'max_steps': 100,
                               'num_agents': 3,
                               'one_time_death_reward': 5}},
 'network_config': {'actor_num_hidden_linear_layer': 2,
                    'critic_nu

In [74]:
# Sizes needed to reshape the trajectory, read from the loaded run config.
# NOTE(review): num_envs comes from the *training* config while the trajectory was
# requested with num_episodes; both are 100 in this run — confirm they must always match.
num_envs = config.training_config.num_envs
env_kwargs = config.env_config.env_kwargs
num_agents = env_kwargs.num_agents
num_steps = env_kwargs.max_steps  # used as the size of the step axis

In [75]:
# Bring every leaf into (num_envs, num_steps, num_agents, ...) layout.
# NOTE: this rebinds traj_batch, so the cell expects the freshly loaded batch;
# re-run the loading cell before re-running this one.
def _to_env_step_agent(leaf):
    # Split axis 1 into (num_agents, num_envs), keeping any trailing feature axes.
    split = leaf.reshape(num_steps, num_agents, num_envs, *leaf.shape[2:])
    # (steps, agents, envs, ...) -> (envs, steps, agents, ...)
    return jnp.moveaxis(split, 2, 0)


traj_batch = jax.tree.map(_to_env_step_agent, traj_batch)


In [76]:
# Sanity check: show the shape of every leaf of the trajectory batch.
jax.tree.map(lambda leaf: leaf.shape, traj_batch)

TransitionForVisualization(global_done=(100, 100, 3), done=(100, 100, 3), action=(100, 100, 3), value=(100, 100, 3), reward=(100, 100, 3), log_prob=(100, 100, 3), obs=(100, 100, 3, 6), graph=GraphsTupleWithAgentIndex(nodes=(100, 100, 3, 6, 71), edges=(100, 100, 3, 21, 1), receivers=(100, 100, 3, 21), senders=(100, 100, 3, 21), globals=None, n_node=(100, 100, 3), n_edge=(100, 100, 3), agent_indices=(100, 100, 3)), world_state=(100, 100, 3, 18), info={'returned_episode': (100, 100, 3), 'returned_episode_lengths': (100, 100, 3), 'returned_episode_returns': (100, 100, 3)}, env_state=LogEnvState(env_state=MPEState(dones=(100, 100, 3, 3), step=(100, 100, 3), entity_positions=(100, 100, 3, 6, 2), entity_velocities=(100, 100, 3, 6, 2), did_agent_die_this_time_step=(100, 100, 3, 3), agent_communication_message=(100, 100, 3, 3, 64), agent_visibility_radius=(100, 100, 3, 3)), episode_returns=(100, 100, 3, 3), episode_lengths=(100, 100, 3, 3), returned_episode_returns=(100, 100, 3, 3), returned_ep

In [77]:
# Episode return: collapse the step axis (1) and the agent axis (2) of the reward,
# leaving one total per episode, then average those totals into a Python float.
step_and_agent_axes = (1, 2)
total_reward = traj_batch.reward.sum(axis=step_and_agent_axes)
avg_reward_per_episode = jnp.average(total_reward).item()

In [78]:
# Mean over episodes of the reward summed across all steps and agents.
avg_reward_per_episode

-4792.0556640625

In [79]:
# Put the step axis last: done becomes (num_env, num_agents, num_steps).
done = jnp.swapaxes(traj_batch.done, 1, 2)

# argmax over a boolean axis yields the index of the first True, i.e. the first
# step at which the agent is done; +1 turns the index into a step count.
first_done_step = jnp.argmax(done, axis=-1) + 1
per_agent_fraction = first_done_step / num_steps

# Agents that are never done would report index 0 from argmax, so they are
# assigned the full episode length (fraction 1) instead.
agents_that_didnt_reach_goal = ~jnp.any(done, axis=-1)
per_agent_fraction = jnp.where(agents_that_didnt_reach_goal, 1, per_agent_fraction)

avg_goal_reach_time_in_episode_fraction = jnp.average(per_agent_fraction).item()

In [80]:
# Mean over all agents and episodes of the first step with done set, as a
# fraction of num_steps (1 for agents that are never done).
avg_goal_reach_time_in_episode_fraction

0.11403333395719528

In [81]:
# An agent counts as having reached its goal if done is set at any step
# (done has the step axis last at this point).
reached_goal = jnp.any(done, axis=-1)
# An episode counts as a success only when every agent in it reached its goal.
all_agents_reached_goal = jnp.all(reached_goal, axis=-1)

# Fraction of successful episodes, expressed as a percentage and unwrapped to a Python float.
episode_percent_all_agents_reached_goals = (jnp.average(all_agents_reached_goal) * 100).item()

In [82]:
# Percentage of episodes in which every agent had done set at some step.
episode_percent_all_agents_reached_goals

100.0

In [83]:
@partial(jax.jit, static_argnums=(0,))
def compute_stats_for_all_episode(env, state):
    """Apply ``get_stats_for_state`` to every state along the two leading axes.

    Args:
        env: Environment passed unchanged to every ``get_stats_for_state`` call.
            It is a static jit argument, so it must be hashable.
        state: Pytree of arrays whose two leading axes are mapped over
            (the caller passes states laid out as episode, then step).

    Returns:
        Whatever ``get_stats_for_state`` returns, with the two mapped axes
        stacked in front of each output.
    """
    # env is the same for every call, so bind it once and map over state only.
    stats_for_one_state = partial(get_stats_for_state, env)
    # Inner vmap walks the steps of one episode, outer vmap walks the episodes.
    stats_for_one_episode = jax.vmap(stats_for_one_state)
    return jax.vmap(stats_for_one_episode)(state)

In [84]:
# The environment state is stored once per agent and is the same for all of them,
# so keep only agent 0's copy; this drops the agent axis (axis 2) from every leaf.
env_state = jax.tree.map(
    lambda leaf: leaf[:, :, 0],
    traj_batch.env_state.env_state,
)

In [85]:
# Evaluate get_stats_for_state on the state of every (episode, step) pair.
num_collisions, num_agent_died = compute_stats_for_all_episode(env, env_state)

In [86]:
# Both averages run over every entry of the stat arrays, i.e. over all
# (episode, step) pairs produced by compute_stats_for_all_episode.
# NOTE(review): this is a per-step mean, not a per-episode total — confirm that is intended.
avg_num_collision_across_all_episodes, avg_num_deaths_across_all_episodes = (
    jnp.average(stat).item() for stat in (num_collisions, num_agent_died)
)

In [87]:
# Summary of this run: (avg episode reward, avg goal-reach time fraction,
# % of episodes where all agents reached their goal, avg collisions).
# avg_num_deaths_across_all_episodes is computed above but not shown here.
(
    avg_reward_per_episode,
    avg_goal_reach_time_in_episode_fraction,
    f"{episode_percent_all_agents_reached_goals} %",
    avg_num_collision_across_all_episodes,
)

(-4792.0556640625, 0.11403333395719528, '100.0 %', 0.10619999468326569)

In [88]:
# NOTE(review): hard-coded tuple rather than a computed value — presumably the summary
# of an earlier run kept for comparison with the cell above; confirm its provenance.
(-4745.63916015625, 0.11216667294502258, '100.0 %', 0.09730000048875809)

(-4745.63916015625, 0.11216667294502258, '100.0 %', 0.09730000048875809)