In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

import parallel

import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Categorical

In [2]:
class Super_Agent(nn.Module):
    #Common agent class for all hiders/seekers
    
    def __init__(self, num_actions, num_agents):
        super().__init__()

        # CNN architecture inspired by DQN for Atari
        self.network = nn.Sequential(
            nn.Conv2d(5, 32, kernel_size=3, stride=1, padding=1),  # Output: 32 x 7 x 7
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # Output: 64 x 7 x 7
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),  # Output: 64 x 7 x 7
            nn.ReLU(),
            nn.Flatten(),  # Output: 64 * 7 * 7 = 3136
        )
        self.actor = self._layer_init(nn.Linear(3136, num_actions), std=0.01)
        self.critic = self._layer_init(nn.Linear(3136, 1))

    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def get_value(self, x):
        return self.critic(self.network(x / 1.0))  # Normalize input to [0, 1]

    def get_action_and_value(self, x, action=None):
        hidden = self.network(x / 1.0)  # Normalize input to [0, 1]
        
        logits = self.actor(hidden)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)

    

def batchify_obs(obs, device):
    """Converts PZ style observations to batch of torch arrays."""
    # convert to list of np arrays
    obs = np.stack([obs[a] for a in obs], axis=0)
    # convert to torch
    obs = torch.tensor(obs).to(device)

    return obs


def batchify(x, device):
    """Converts PZ style returns to batch of torch arrays."""
    # convert to list of np arrays
    x = np.stack([x[a] for a in x], axis=0)
    # convert to torch
    x = torch.tensor(x).to(device)

    return x


def unbatchify(x, env):
    """Converts np array to PZ style arguments."""
    x = x.cpu().numpy()
    x = {a: x[i] for i, a in enumerate(env.possible_agents)}

    return x


In [3]:
"""ALGO PARAMS"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ent_coef = 0.1
vf_coef = 0.1
clip_coef = 0.1
gamma = 0.99
batch_size = 32
stack_size = 4
frame_size = (64, 64)
max_cycles = 250
total_episodes = 200

""" ENV SETUP """
env = parallel.parallel_env(grid_size=7)

num_agents = len(env.possible_agents)
num_actions = env.action_space(env.possible_agents[0]).n
observation_size = env.observation_space(env.possible_agents[0]).shape

""" LEARNER SETUP """
agent = Super_Agent(num_actions=num_actions, num_agents=2).to(device)
optimizer = optim.Adam(agent.parameters(), lr=0.001, eps=1e-5)

""" ALGO LOGIC: EPISODE STORAGE"""
end_step = 0
total_episodic_return = 0
rb_obs = torch.zeros((max_cycles, num_agents, 5,7,7)).to(device)
rb_actions = torch.zeros((max_cycles, num_agents)).to(device)
rb_logprobs = torch.zeros((max_cycles, num_agents)).to(device)
rb_rewards = torch.zeros((max_cycles, num_agents)).to(device)
rb_terms = torch.zeros((max_cycles, num_agents)).to(device)
rb_values = torch.zeros((max_cycles, num_agents)).to(device)

""" TRAINING LOGIC """
# train for n number of episodes
for episode in range(total_episodes):
    # collect an episode
    with torch.no_grad():
        # collect observations and convert to batch of torch tensors
        next_obs, info = env.reset(seed=None)
        # reset the episodic return
        total_episodic_return = 0

        # each episode has num_steps
        for step in range(0, max_cycles):
            # rollover the observation
            obs = batchify_obs(next_obs, device)
            # get action for first agent from the trained agent
            # get random actions for other agents
            actions = torch.zeros(num_agents, dtype=torch.long).to(device)
            logprobs = torch.zeros(num_agents).to(device)
            values = torch.zeros(num_agents).to(device)
                
            for i in range(2):
                # First agent uses policy network
                first_agent = env.possible_agents[i]
                first_agent_obs = obs[i].unsqueeze(0)
                actions[i], logprobs[i], _, values[i] = agent.get_action_and_value(first_agent_obs)

            # Other agents use random policy
            for i in range(2, num_agents):
                actions[i] = torch.randint(0, num_actions, (1,)).to(device)
                logprobs[i] = torch.log(torch.tensor(1.0/num_actions))
                values[i] = 0.0  # No value estimation for random agents

            # execute the environment and log data
            next_obs, rewards, terms, truncs, infos = env.step(
                unbatchify(actions, env)
            )

            # add to episode storage
            rb_obs[step] = obs
            rb_rewards[step] = batchify(rewards, device)
            rb_terms[step] = batchify(terms, device)
            rb_actions[step] = actions
            rb_logprobs[step] = logprobs
            rb_values[step] = values

            # compute episodic return
            total_episodic_return += rb_rewards[step].cpu().numpy()

            # if we reach termination or truncation, end
            if any([terms[a] for a in terms]) or any([truncs[a] for a in truncs]):
                end_step = step
                break

    # Bootstrap value and advantages only for the first agent
    with torch.no_grad():
        rb_advantages = torch.zeros_like(rb_rewards).to(device)
        for t in reversed(range(end_step)):
            for i in range(2):
                delta = (
                    rb_rewards[t, i]  # Only first agent's reward
                    + gamma * rb_values[t + 1, i] * rb_terms[t + 1, i]
                    - rb_values[t, i]
                )
                rb_advantages[t, i] = delta + gamma * gamma * rb_advantages[t + 1, i]
        rb_returns = rb_advantages + rb_values

    # convert our episodes to batch of individual transitions (only for first agent)
    b_obs = torch.cat([rb_obs[:end_step, 0], rb_obs[:end_step, 1]], dim=0)
    b_logprobs = torch.cat([rb_logprobs[:end_step, 0], rb_logprobs[:end_step, 1]], dim=0)
    b_actions = torch.cat([rb_actions[:end_step, 0], rb_actions[:end_step, 1]], dim=0)
    b_returns = torch.cat([rb_returns[:end_step, 0], rb_returns[:end_step, 1]], dim=0)
    b_values = torch.cat([rb_values[:end_step, 0], rb_values[:end_step, 1]], dim=0)
    b_advantages = torch.cat([rb_advantages[:end_step, 0], rb_advantages[:end_step, 1]], dim=0)

    # Optimizing the policy and value network
    b_index = np.arange(len(b_obs))
    clip_fracs = []
    for repeat in range(3):
        # shuffle the indices we use to access the data
        np.random.shuffle(b_index)
        for start in range(0, len(b_obs), batch_size):
            # select the indices we want to train on
            end = start + batch_size
            batch_index = b_index[start:end]

            _, newlogprob, entropy, value = agent.get_action_and_value(
                b_obs[batch_index], b_actions.long()[batch_index]
            )
            logratio = newlogprob - b_logprobs[batch_index]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clip_fracs += [
                    ((ratio - 1.0).abs() > clip_coef).float().mean().item()
                ]

            # normalize advantages
            advantages = b_advantages[batch_index]
            advantages = (advantages - advantages.mean()) / (
                advantages.std() + 1e-8
            )

            # Policy loss
            pg_loss1 = -b_advantages[batch_index] * ratio
            pg_loss2 = -b_advantages[batch_index] * torch.clamp(
                ratio, 1 - clip_coef, 1 + clip_coef
            )
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss
            value = value.flatten()
            v_loss_unclipped = (value - b_returns[batch_index]) ** 2
            v_clipped = b_values[batch_index] + torch.clamp(
                value - b_values[batch_index],
                -clip_coef,
                clip_coef,
            )
            v_loss_clipped = (v_clipped - b_returns[batch_index]) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()

            entropy_loss = entropy.mean()
            loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

    print(f"Training episode {episode}")
    print(f"Episodic Return: {rb_rewards[:, 0].mean().item()}")
    print(f"Episode Length: {end_step}")
    print("")
    print(f"Value Loss: {v_loss.item()}")
    print(f"Policy Loss: {pg_loss.item()}")
    print(f"Old Approx KL: {old_approx_kl.item()}")
    print(f"Approx KL: {approx_kl.item()}")
    print(f"Clip Fraction: {np.mean(clip_fracs)}")
    print(f"Explained Variance: {explained_var.item()}")
    print("\n-------------------------------------------\n")
'''
""" RENDER THE POLICY """

env = parallel.parallel_env(render_mode="human",grid_size=7)

agent.eval()

with torch.no_grad():
    # render 5 episodes out
    for episode in range(5):
        obs, infos = env.reset(seed=None)
        obs = batchify_obs(obs, device)
        terms = [False]
        truncs = [False]
        while not any(terms) and not any(truncs):
            # First agent uses trained policy
            first_agent_obs = obs[0].unsqueeze(0)
            actions, logprobs, _, values = agent.get_action_and_value(first_agent_obs)

            # Other agents use random actions
            other_actions = torch.randint(0, num_actions, (num_agents-1,))
            full_actions = torch.cat([actions, other_actions])

            obs, rewards, terms, truncs, infos = env.step(unbatchify(full_actions, env))
            obs = batchify_obs(obs, device)
            terms = [terms[a] for a in terms]
            truncs = [truncs[a] for a in truncs]'''

Training episode 0
Episodic Return: -6.192000389099121
Episode Length: 199

Value Loss: 38555.921875
Policy Loss: 271.5767517089844
Old Approx KL: 0.005476261954754591
Approx KL: 0.0031052830163389444
Clip Fraction: 0.08573717986926055
Explained Variance: -2.7894973754882812e-05

-------------------------------------------

Training episode 1
Episodic Return: -4.828000068664551
Episode Length: 199

Value Loss: 30875.853515625
Policy Loss: 230.19256591796875
Old Approx KL: 0.0395609587430954
Approx KL: 0.006656165700405836
Clip Fraction: 0.14537545828483042
Explained Variance: 3.17692756652832e-05

-------------------------------------------

Training episode 2
Episodic Return: -5.500000476837158
Episode Length: 199

Value Loss: 35231.13671875
Policy Loss: 234.8712921142578
Old Approx KL: 0.025779733434319496
Approx KL: 0.001655459520407021
Clip Fraction: 0.3444368134324367
Explained Variance: -2.4557113647460938e-05

-------------------------------------------

Training episode 3
Episo

Training episode 26
Episodic Return: -6.584000110626221
Episode Length: 199

Value Loss: 3576.08251953125
Policy Loss: -34.164363861083984
Old Approx KL: -0.032126106321811676
Approx KL: 0.010374759323894978
Clip Fraction: 0.29235348105430603
Explained Variance: 0.011089205741882324

-------------------------------------------

Training episode 27
Episodic Return: -5.988000392913818
Episode Length: 199

Value Loss: 9773.701171875
Policy Loss: 36.202117919921875
Old Approx KL: 0.04620013386011124
Approx KL: 0.012902784161269665
Clip Fraction: 0.37065018446017534
Explained Variance: 0.004499197006225586

-------------------------------------------

Training episode 28
Episodic Return: -5.5320000648498535
Episode Length: 199

Value Loss: 6750.55029296875
Policy Loss: -94.01605987548828
Old Approx KL: -0.0334811732172966
Approx KL: 0.020203731954097748
Clip Fraction: 0.4080815040148221
Explained Variance: -0.00039374828338623047

-------------------------------------------

Training episod

Training episode 52
Episodic Return: -7.808000564575195
Episode Length: 199

Value Loss: 19422.357421875
Policy Loss: 154.30563354492188
Old Approx KL: 0.12420584261417389
Approx KL: 0.026435546576976776
Clip Fraction: 0.4693223467239967
Explained Variance: 0.004492282867431641

-------------------------------------------

Training episode 53
Episodic Return: -7.160000324249268
Episode Length: 199

Value Loss: 3318.99169921875
Policy Loss: 69.24312591552734
Old Approx KL: 0.04624566063284874
Approx KL: 0.008942400105297565
Clip Fraction: 0.4355540306140215
Explained Variance: -0.002503514289855957

-------------------------------------------

Training episode 54
Episodic Return: -4.760000228881836
Episode Length: 199

Value Loss: 4128.91015625
Policy Loss: -53.21342468261719
Old Approx KL: -0.0420396514236927
Approx KL: 0.005599145777523518
Clip Fraction: 0.300022895137469
Explained Variance: -0.0005141496658325195

-------------------------------------------

Training episode 55
Episo

Training episode 78
Episodic Return: -4.604000091552734
Episode Length: 199

Value Loss: 12623.2685546875
Policy Loss: -149.57144165039062
Old Approx KL: -0.2860981523990631
Approx KL: 0.1553158164024353
Clip Fraction: 0.5005723467239966
Explained Variance: -0.007363557815551758

-------------------------------------------

Training episode 79
Episodic Return: -6.372000217437744
Episode Length: 199

Value Loss: 5359.47021484375
Policy Loss: 50.304264068603516
Old Approx KL: 0.0940658375620842
Approx KL: 0.02214549295604229
Clip Fraction: 0.48019688863020676
Explained Variance: -0.0038111209869384766

-------------------------------------------

Training episode 80
Episodic Return: -6.868000507354736
Episode Length: 199

Value Loss: 6877.2021484375
Policy Loss: 106.48084259033203
Old Approx KL: 0.13412266969680786
Approx KL: 0.01370516512542963
Clip Fraction: 0.4605082441598941
Explained Variance: -0.0018982887268066406

-------------------------------------------

Training episode 81
E

Training episode 103
Episodic Return: -4.4079999923706055
Episode Length: 199

Value Loss: 1284.5731201171875
Policy Loss: 6.077377796173096
Old Approx KL: 0.038817502558231354
Approx KL: 0.0304816123098135
Clip Fraction: 0.34684066206980974
Explained Variance: -0.0047866106033325195

-------------------------------------------

Training episode 104
Episodic Return: -5.832000255584717
Episode Length: 199

Value Loss: 2934.02001953125
Policy Loss: 50.58412170410156
Old Approx KL: 0.08832858502864838
Approx KL: 0.01999724470078945
Clip Fraction: 0.34706959892541933
Explained Variance: -0.0021696090698242188

-------------------------------------------

Training episode 105
Episodic Return: -4.860000133514404
Episode Length: 199

Value Loss: 533.8305053710938
Policy Loss: -0.8403636813163757
Old Approx KL: 0.003729607444256544
Approx KL: 0.0042157769203186035
Clip Fraction: 0.19619963451837882
Explained Variance: 0.014272034168243408

-------------------------------------------

Training 

Training episode 128
Episodic Return: -3.9200000762939453
Episode Length: 199

Value Loss: 3573.639404296875
Policy Loss: -55.365386962890625
Old Approx KL: -0.0011268258094787598
Approx KL: 0.01001429557800293
Clip Fraction: 0.33459249215248305
Explained Variance: -0.009222030639648438

-------------------------------------------

Training episode 129
Episodic Return: -4.648000240325928
Episode Length: 199

Value Loss: 2963.62451171875
Policy Loss: 32.81934356689453
Old Approx KL: 0.14453071355819702
Approx KL: 0.01817890629172325
Clip Fraction: 0.4758470730903821
Explained Variance: -0.008554458618164062

-------------------------------------------

Training episode 130
Episodic Return: -4.516000270843506
Episode Length: 199

Value Loss: 2468.075439453125
Policy Loss: 47.795536041259766
Old Approx KL: 0.0493476577103138
Approx KL: 0.02597481571137905
Clip Fraction: 0.4174679502462729
Explained Variance: 0.0013760924339294434

-------------------------------------------

Training epis

Training episode 153
Episodic Return: -5.88800048828125
Episode Length: 199

Value Loss: 7576.42333984375
Policy Loss: 101.3913803100586
Old Approx KL: 0.17155934870243073
Approx KL: 0.036002710461616516
Clip Fraction: 0.5700549483299255
Explained Variance: -0.0011593103408813477

-------------------------------------------

Training episode 154
Episodic Return: -4.212000370025635
Episode Length: 199

Value Loss: 1649.1275634765625
Policy Loss: 14.579275131225586
Old Approx KL: 0.09604661166667938
Approx KL: 0.015752797946333885
Clip Fraction: 0.318566851126842
Explained Variance: 0.016584038734436035

-------------------------------------------

Training episode 155
Episodic Return: -5.644000053405762
Episode Length: 199

Value Loss: 4755.97998046875
Policy Loss: 57.01995086669922
Old Approx KL: 0.035743750631809235
Approx KL: 0.016276288777589798
Clip Fraction: 0.44837454496285856
Explained Variance: -0.0030862092971801758

-------------------------------------------

Training episod

Training episode 178
Episodic Return: -4.776000022888184
Episode Length: 199

Value Loss: 1169.6859130859375
Policy Loss: -28.396875381469727
Old Approx KL: -0.005278230179101229
Approx KL: 0.0036788457073271275
Clip Fraction: 0.30689102716934985
Explained Variance: -0.021791934967041016

-------------------------------------------

Training episode 179
Episodic Return: -5.336000442504883
Episode Length: 199

Value Loss: 2120.249755859375
Policy Loss: 51.92646408081055
Old Approx KL: 0.060091692954301834
Approx KL: 0.025374243035912514
Clip Fraction: 0.2656822355511861
Explained Variance: 0.03380537033081055

-------------------------------------------

Training episode 180
Episodic Return: -3.8480002880096436
Episode Length: 199

Value Loss: 4841.841796875
Policy Loss: -30.369577407836914
Old Approx KL: 0.003857084782794118
Approx KL: 0.035748984664678574
Clip Fraction: 0.385416668959153
Explained Variance: 0.0023258328437805176

-------------------------------------------

Training e

'\n""" RENDER THE POLICY """\n\nenv = parallel.parallel_env(render_mode="human",grid_size=7)\n\nagent.eval()\n\nwith torch.no_grad():\n    # render 5 episodes out\n    for episode in range(5):\n        obs, infos = env.reset(seed=None)\n        obs = batchify_obs(obs, device)\n        terms = [False]\n        truncs = [False]\n        while not any(terms) and not any(truncs):\n            # First agent uses trained policy\n            first_agent_obs = obs[0].unsqueeze(0)\n            actions, logprobs, _, values = agent.get_action_and_value(first_agent_obs)\n\n            # Other agents use random actions\n            other_actions = torch.randint(0, num_actions, (num_agents-1,))\n            full_actions = torch.cat([actions, other_actions])\n\n            obs, rewards, terms, truncs, infos = env.step(unbatchify(full_actions, env))\n            obs = batchify_obs(obs, device)\n            terms = [terms[a] for a in terms]\n            truncs = [truncs[a] for a in truncs]'