In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial

import jax
import jax.numpy as jnp

from calculate_metric import get_stats_for_state
from visualize_actor import get_state_traj


In [3]:
# Which trained PPO-RNN runner-state artifact (wandb version tag) to evaluate, and how many
# evaluation episodes to request from the rollout helper.
artifact_version = "343"
num_episodes = 100
model_artifact_remote_name = f"josssdan/JaxInforMARL/PPO_RNN_Runner_State:v{artifact_version}"

# Project helper (visualize_actor.get_state_traj); its result is unpacked as the rollout
# trajectory batch, the run config and the environment instance.
traj_batch, config, env = get_state_traj(
    model_artifact_remote_name,
    artifact_version,
    num_episodes,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m:   13 of 13 files downloaded.  


Config:
{'derived_values': {'minibatch_size': 19200,
                    'num_actors': 300,
                    'num_updates': 78,
                    'scaled_clip_eps': 0.2},
 'env_config': {'env_cls_name': 'TargetMPEEnvironment',
                'env_kwargs': {'agent_communication_type': None,
                               'agent_control_noise_std': 0.0,
                               'agent_max_speed': -1,
                               'agent_visibility_radius': [1],
                               'collision_reward': -5,
                               'entities_initial_coord_radius': [1],
                               'entity_acceleration': 5,
                               'max_steps': 50,
                               'num_agents': 3,
                               'one_time_death_reward': 10}},
 'network_config': {'actor_num_hidden_linear_layer': 2,
                    'critic_num_hidden_linear_layer': 2,
                    'entity_type_embedding_dim': 4,
                   

In [4]:
# Batch dimensions used to unflatten the rollout trajectory in the next cell.
# NOTE(review): assumes the rollout used exactly config.training_config.num_envs environments
# (one episode each); the trajectory reshape only succeeds if that matches what
# get_state_traj actually rolled out — confirm it is tied to num_episodes.
num_envs, num_agents, num_steps = (
    config.training_config.num_envs,
    config.env_config.env_kwargs.num_agents,
    config.env_config.env_kwargs.max_steps,
)

In [5]:
# Reorder every trajectory leaf from (num_steps, num_actors, ...) to
# (num_envs, num_steps, num_agents, ...), so episodes come first.
# The actor axis is split as (num_agents, num_envs), i.e. agent-major.
# NOTE(review): that split assumes actors were batched agent-by-agent (all envs of agent 0,
# then agent 1, ...) — confirm against the rollout code.
def _to_episode_major(x):
    """Reshape one leaf from (num_steps, num_actors, ...) to (num_envs, num_steps, num_agents, ...).

    Raises:
        ValueError: if the leaf's leading axes are not (num_steps, num_agents * num_envs),
            e.g. because this cell was already run and the leaf is episode-major.
    """
    expected_leading_axes = (num_steps, num_agents * num_envs)
    if tuple(x.shape[:2]) != expected_leading_axes:
        # Guard against re-running the cell: a second pass would fail with a cryptic reshape
        # error, or silently scramble axes when num_agents == 1 and num_envs == num_steps.
        raise ValueError(
            f"expected leaf with leading axes {expected_leading_axes}, got {x.shape}; "
            "has this cell already been run?"
        )
    split = x.reshape(num_steps, num_agents, num_envs, *x.shape[2:])
    # (steps, agents, envs, ...) -> (envs, steps, agents, ...)
    return jnp.moveaxis(split, 2, 0)


traj_batch = jax.tree.map(_to_episode_major, traj_batch)


In [6]:
# Leaf shapes after the reshape; leading axes should read (num_envs, num_steps, num_agents).
jax.tree.map(jnp.shape, traj_batch)

TransitionWithEnvState(global_done=(100, 50, 3), done=(100, 50, 3), action=(100, 50, 3), value=(100, 50, 3), reward=(100, 50, 3), log_prob=(100, 50, 3), obs=(100, 50, 3, 6), graph=GraphsTupleWithAgentIndex(nodes=(100, 50, 3, 6, 7), edges=(100, 50, 3, 33, 1), receivers=(100, 50, 3, 33), senders=(100, 50, 3, 33), globals=None, n_node=(100, 50, 3), n_edge=(100, 50, 3), agent_indices=(100, 50, 3)), world_state=(100, 50, 3, 18), info={'returned_episode': (100, 50, 3), 'returned_episode_lengths': (100, 50, 3), 'returned_episode_returns': (100, 50, 3)}, env_state=LogEnvState(env_state=MPEState(dones=(100, 50, 3, 3), step=(100, 50, 3), entity_positions=(100, 50, 3, 6, 2), entity_velocities=(100, 50, 3, 6, 2), did_agent_die_this_time_step=(100, 50, 3, 3), agent_communication_message=(100, 50, 3, 0), agent_visibility_radius=(100, 50, 3, 3)), episode_returns=(100, 50, 3, 3), episode_lengths=(100, 50, 3, 3), returned_episode_returns=(100, 50, 3, 3), returned_episode_lengths=(100, 50, 3, 3)))

In [7]:
# Team return per episode: sum rewards over time steps (axis 1) and agents (axis 2),
# then take the mean of those returns over episodes (axis 0).
# NOTE(review): if every agent receives the same shared team reward, summing over agents
# counts it num_agents times — confirm rewards are per-agent.
total_reward = traj_batch.reward.sum(axis=(1, 2))
avg_reward_per_episode = total_reward.mean().item()

In [8]:
# Mean team return per episode (summed over steps and agents, averaged over episodes).
avg_reward_per_episode

765.2213745117188

In [9]:
# Reorder to (num_envs, num_agents, num_steps) so per-agent reductions run over the last axis.
done = jnp.swapaxes(traj_batch.done, 1, 2)

# jnp.any treats any non-zero flag as "done", so this is a logical negation whatever done's
# dtype is (bitwise `~` applied directly to an integer array would not be).
agents_that_didnt_reach_goal = ~jnp.any(done, axis=-1)

# argmax returns the index of the first step whose flag is set; +1 turns it into a step count.
first_done_step = jnp.argmax(done, axis=-1) + 1

# Agents that never finished are charged the whole episode (fraction 1). jnp.where keeps the
# computation shape-static (jit-friendly), unlike a boolean-mask scatter via `.at[mask].set`.
goal_reach_fraction_per_agent = jnp.where(
    agents_that_didnt_reach_goal, 1.0, first_done_step / num_steps
)

# Scalar: mean over episodes and agents.
avg_goal_reach_time_in_episode_fraction = jnp.average(goal_reach_fraction_per_agent).item()

In [10]:
# Mean (over episodes and agents) of the first step at which each agent's done flag is set,
# as a fraction of num_steps; agents whose flag is never set count as 1.0.
avg_goal_reach_time_in_episode_fraction

0.5062667727470398

In [11]:
# An agent counts as having reached its goal if its done flag is set at any step; an episode
# counts only when that holds for every agent in it.
reached_goal = jnp.any(done, axis=-1)  # (num_envs, num_agents)
all_agents_reached_goal = reached_goal.all(axis=-1)  # (num_envs,)

# Share of such episodes, as a Python float percentage.
episode_percent_all_agents_reached_goals = (
    jnp.average(all_agents_reached_goal) * 100
).item()

In [12]:
# Percentage of episodes in which every agent's done flag was set at some step.
episode_percent_all_agents_reached_goals

22.0

In [13]:
@partial(jax.jit, static_argnums=(0,))
def compute_stats_for_all_episode(env, state):
    """Apply ``get_stats_for_state`` to every (episode, step) entry of a batched env state.

    Args:
        env: environment passed unchanged to every call. It is a static argument of
            ``jax.jit`` (``static_argnums=(0,)``), so it must be hashable.
        state: env-state pytree whose leaves share two leading batch axes, mapped as
            (episodes, steps) — presumably (num_envs, num_steps, ...); confirm at the call site.

    Returns:
        ``get_stats_for_state``'s output with those two leading axes prepended to every leaf.
    """
    env_fixed_state_mapped = (None, 0)  # env is broadcast; state is mapped along axis 0
    stats_per_step = jax.vmap(get_stats_for_state, in_axes=env_fixed_state_mapped)
    stats_per_episode = jax.vmap(stats_per_step, in_axes=env_fixed_state_mapped)
    return stats_per_episode(env, state)

In [14]:
# The recorded env state also carries the agent axis: leaves are (num_envs, num_steps, num_agents, ...).
# Keep only agent 0's copy.
# NOTE(review): assumes the env state is identical along the agent axis — confirm.
env_state = jax.tree.map(
    lambda leaf: leaf[:, :, 0],
    traj_batch.env_state.env_state,
)

In [15]:
# Per-(episode, step) stats: the nested vmaps prepend env_state's two leading axes
# (episodes, steps) to both outputs of get_stats_for_state.
num_collisions, num_agent_died = compute_stats_for_all_episode(env, env_state)

In [16]:
# NOTE(review): jnp.average with no axis averages over every element, i.e. over episodes AND
# steps (and any per-state axes), so these are per-step means rather than per-episode totals
# despite the variable names — confirm which metric is intended.
avg_num_collision_across_all_episodes, avg_num_deaths_across_all_episodes = (
    jnp.average(stat).item() for stat in (num_collisions, num_agent_died)
)

In [17]:
# Summary: (mean episode team return, mean goal-reach fraction,
#           % of episodes where all agents reached their goal, mean collision statistic).
# NOTE(review): avg_num_deaths_across_all_episodes is computed in an earlier cell but not
# reported here — add it if deaths are meant to be part of the summary.
(
    avg_reward_per_episode,
    avg_goal_reach_time_in_episode_fraction,
    f"{episode_percent_all_agents_reached_goals} %",
    avg_num_collision_across_all_episodes,
)

(765.2213745117188, 0.5062667727470398, '22.0 %', 0.03989999741315842)