# Using CleanRL-style PPO to try to train Villagers

In [1]:
import numpy as np
import torch
import sys
sys.path.append('../')
from voting_games.werewolf_env_v0 import raw_env
import random
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


Loosely based on [this PettingZoo tutorial](https://pettingzoo.farama.org/tutorials/cleanrl/implementing_PPO/) and the [following blog post](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/). Another PettingZoo (PZ) implementation referenced is [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py).

One more link I plan on reading, at least for PPO, is [here](https://towardsdatascience.com/elegantrl-mastering-the-ppo-algorithm-part-i-9f36bc47b791).

In [2]:
# Instantiate a 10-player game with 2 werewolves and reset it before use.
env = raw_env(werewolves=2, num_agents=10)
env.reset()

In [6]:
# Draw one random observation for player_1 and display its raw field values.
sample_observation = env.observation_spaces['player_1'].sample()['observation']
sample_observation.values()

odict_values([2, 1, array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True]), array([0, 0, 1, 0, 1, 1, 0, 1, 0, 1]), array([1.1784745, 6.6172085, 4.434359 , 2.475156 , 3.2237024, 3.2628539,
       7.9956994, 6.5257916, 4.5016713, 6.69284  ], dtype=float32)])

In [122]:
def flat_obs(observation):
    """Flatten a structured werewolf observation into a single 1-D numpy array.

    The layout is, in order: day, phase, one int per entry of
    ``player_status``, every entry of ``roles``, then every entry of ``votes``.
    The resulting dtype follows numpy's promotion of those values.
    """
    features = [observation['day'], observation['phase']]
    # booleans become 0/1 ints
    features.extend(int(status) for status in observation['player_status'])
    features.extend(observation['roles'])
    features.extend(observation['votes'])
    return np.asarray(features)

In [123]:
def random_policy(observation, agent):
    """Pick a uniformly random legal target index.

    A target index is legal when its ``observation['action_mask']`` entry is
    truthy and its ``observation['observation']['roles']`` entry is falsy, so
    players flagged in ``roles`` (the wolves, when visible) are never chosen.
    Candidate indices run over ``range(len(player_status))``. ``agent`` is
    accepted for a uniform policy signature but is not used.

    Raises IndexError (from ``random.choice``) when no target is legal.
    """
    inner = observation['observation']
    candidate_indices = range(len(inner['player_status']))
    legal_targets = []
    for index, allowed, flagged in zip(candidate_indices, observation['action_mask'], inner['roles']):
        if allowed and not flagged:
            legal_targets.append(index)
    return random.choice(legal_targets)

In [8]:
# Length of the flattened observation vector (flat_obs returns a 1-D array).
sample_flat = flat_obs(env.observation_spaces['player_1'].sample()['observation'])
len(sample_flat)

32

In [124]:
# Baseline: every player, wolf or villager, follows random_policy.
ten_player_env = raw_env(num_agents=10, werewolves=2)

avg_game_length = 0
wolf_wins = 0
villager_wins = 0

num_games = 1000

ten_player_env.reset()

for _ in tqdm(range(num_games)):

    for agent in ten_player_env.agent_iter():
        observation, reward, termination, truncation, info = ten_player_env.last()
        # Done agents (terminated OR truncated) must be stepped with None.
        # `not termination or truncation` would parse as `(not termination) or truncation`.
        action = None if (termination or truncation) else random_policy(observation, agent)
        ten_player_env.step(action)

    # get some stats
    winner = ten_player_env.world_state['winners']
    day = ten_player_env.world_state['day']

    # NOTE(review): assumes a truthy 'winners' value means the wolves won — confirm against the env.
    if winner:
        wolf_wins += 1
    else:
        villager_wins += 1

    avg_game_length += (day * 1.0)/num_games

    # reset
    ten_player_env.reset()

print(f'Average game length = {avg_game_length:.2f}')
print(f'Wolf wins : {wolf_wins}')
print(f'Villager wins: {villager_wins}')

100%|██████████| 1000/1000 [00:01<00:00, 932.83it/s]

Average game length = 4.20
Wolf wins : 895
Villager wins: 105





## PPO Training

In [131]:
class Agent(torch.nn.Module):
    """Actor-critic network for PPO.

    Two independent MLPs (obs_size -> 64 -> 64 -> head, tanh activations):
    the critic head outputs one value, the actor head one logit per action.
    Layers are orthogonally initialised; the actor head uses std=0.01 so the
    initial policy is close to uniform.
    """

    def __init__(self, num_actions, obs_size):
        super().__init__()

        self.critic = torch.nn.Sequential(
            self._layer_init(torch.nn.Linear(obs_size, 64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64, 64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64, 1), std=1.0),
        )

        self.actor = torch.nn.Sequential(
            self._layer_init(torch.nn.Linear(obs_size, 64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64, 64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64, num_actions), std=0.01),
        )

    def get_value(self, x):
        """Return the critic's value estimate for observation(s) ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None, action_mask=None):
        """Sample (or score) an action and evaluate the critic.

        Args:
            x: observation tensor (a single flat observation or a batch of them).
            action: if given, its log-probability is returned instead of sampling.
            action_mask: optional tensor/array broadcastable to the logits;
                zero/False entries get the lowest finite logit so they are never
                sampled. ``None`` (default) disables masking.

        Returns:
            (action, log_prob(action), entropy, critic(x))
        """
        logits = self.actor(x)

        if action_mask is not None:
            mask = torch.as_tensor(action_mask, device=logits.device).bool()
            # finite minimum instead of -inf: an all-masked row stays uniform instead of NaN
            floor = torch.full_like(logits, torch.finfo(logits.dtype).min)
            logits = torch.where(mask, logits, floor)

        probs = torch.distributions.categorical.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        """Orthogonally initialise ``layer.weight`` with gain ``std`` and fill the bias."""
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

def batchify_obs(obs, device):
    """Converts PZ style observations to batch of torch arrays.

    Stacks the values of an ``{agent: observation}`` dict (in the dict's
    iteration order) along a new leading axis and moves the tensor to ``device``.

    Returns:
        The stacked tensor. The previous version never returned it and so
        always produced ``None``.
    """
    stacked = np.stack([obs[agent] for agent in obs], axis=0)
    return torch.tensor(stacked).to(device)

def batchify(x, device):
    """Stack the values of a PZ-style ``{agent: value}`` dict into one tensor on ``device``.

    Values are stacked along a new leading axis in the dict's iteration order.
    """
    per_agent_values = list(x.values())
    return torch.tensor(np.stack(per_agent_values, axis=0)).to(device)

def unbatchify(x, env):
    """Split a batched tensor back into a PZ-style ``{agent: value}`` dict.

    Row ``i`` of ``x`` (moved to CPU as numpy) is assigned to
    ``env.possible_agents[i]``. Raises IndexError if there are more agents than rows.
    """
    as_numpy = x.cpu().numpy()
    per_agent = {}
    for position, agent_name in enumerate(env.possible_agents):
        per_agent[agent_name] = as_numpy[position]
    return per_agent


In [138]:
# ALGORITHM PARAMETERS
# TODO: What is really necessary here?
# NOTE(review): `device` is not applied to the model or tensors in the cells shown here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# PPO loss weights and clipping
ent_coef = 0.1    # weight of the entropy bonus (subtracted from the loss)
vf_coef = 0.1     # weight of the value loss
clip_coef = 0.1   # ratio clip range; also bounds the value-function update

# Return / advantage estimation
gamma = 0.99
gae_lambda = 0.95

# Optimisation schedule
batch_size = 16       # minibatch size per gradient step
update_epochs = 3     # passes over each episode's transitions
total_episodes = 100  # number of training games

max_cycles = 125      # NOTE(review): not referenced by the training cell shown here

In [139]:
### Env Setup
env = raw_env(num_agents=10, werewolves=2)
# env.reset()
num_agents = 10
num_actions = env.action_spaces['player_1'].n
observation_size = flat_obs(env.observation_spaces['player_1'].sample()['observation']).shape[-1]

# Learner Setup
ppo_agent = Agent(num_actions=num_actions, obs_size=observation_size)
optimizer = torch.optim.Adam(ppo_agent.parameters(), lr=0.001, eps=1e-5)
# for agents in 
# # Algorithm Logic : Episode Storage

# # rb = rollback
# end_step = 0
# total_episodic_return = 0
# rb_obs = 
# rb_actions =
# rb_logprobs = 
# rb_rewards = 
# rb_terms = 
# rb_values =


In [140]:
# Training Logic
for episode in tqdm(range(total_episodes)):
    with torch.no_grad():
        env.reset()

        # magent_list = {agent: [] for agent in env.agents}
        magent_list = {agent : [] for agent in env.agents if not env.agent_roles[agent]}

        # print(magent_list.keys())
        for magent in env.agent_iter():
            observation, reward, termination, truncation, info = env.last()

            # werewolves have full role TODO: add logic for wolves herevisibility
            if sum(observation['observation']['roles']):
                # TODO: find a cleaner way to identify a wolf
                action = random_policy(observation, magent) if not termination or truncation else None
            else:
                obs = torch.Tensor(flat_obs(observation['observation']))
                if not termination or truncation:
                    action, logprobs, _, value = ppo_agent.get_action_and_value(obs)
                else:
                    action = None

                magent_list[magent].append({
                    "obs": obs, 
                    "action": action,
                    "prev_reward": reward,
                    "logprobs": logprobs,
                    "term": termination,
                    "value": value
                    })

            env.step(action)
        
        # take the sequential observations of each agent, and store them appropriately
        magent_obs = {agent: {'obs': [], 'rewards': [], 'actions': [], 'logprobs': [], 'values': [], 'terms': []} for agent in magent_list}
        for key, value in magent_list.items():
            # print(f'-- {key} --')
            for s1, s2 in zip(value, value[1:]):
                magent_obs[key]['obs'].append(s1['obs'])
                magent_obs[key]['rewards'].append(s2['prev_reward'])
                magent_obs[key]['actions'].append(s1['action'])
                magent_obs[key]['logprobs'].append(s1['logprobs'])
                magent_obs[key]['values'].append(s1['value'])
                magent_obs[key]['terms'].append(s2['term'])

    # We will do this for each agent in the episode
    # essentially we are calculating advantages and returns
    with torch.no_grad():
        for player, records in magent_obs.items():
            # print(f'{records}')
            advantages = torch.zeros_like(torch.tensor(records['rewards']))

            for t in reversed(range(len(records['obs']))):
                # print(f'T: {t+1} - Rewards : {torch.tensor(records["rewards"])[t+1]} ')
                # not using terms, as these are episodic

                ## this was the last one. We are not using any terminal states in a good way

                if t == len(records['obs']) - 1:
                    #print(f'T: {t} - Rewards at end : {torch.tensor(records["rewards"])[t]} ')
                    #print(f'T: {t} - Actions at end : {torch.tensor(records["actions"])[t]} ')
                    delta = records["rewards"][t] - records["values"][t]
                    advantages[t]  = delta
                else:
                    #print(f'T: {t} - Rewards : {torch.tensor(records["rewards"])[t]} ')
                    #print(f'T: {t} - Actions : {torch.tensor(records["actions"])[t]} ')                    
                    delta = records["rewards"][t] + gamma * records["values"][t+1] - records["values"][t]
                    advantages[t]  = delta + gamma * gamma * advantages[t+1]

                #delta = records['rewards'][t] + gamma * records['values'][t+1] - records['values'][t]
            magent_obs[player]["advantages"] = advantages
            magent_obs[player]["returns"] = advantages + torch.tensor(records["values"])
                #advantages[t] = delta + gamma * gamma * advantages[t+1]
    

    # optimize the policy and the value network now
    # we can take all our observations now and flatten them into one bigger list of individual transitions
    # TODO: could make this setting into a single loop, but maybe this is clearer. ALso could make all these tensors earlier

    # rec = list(magent_obs.values())[0]['obs']
    # # print(rec)
    # # print(torch.stack(rec))
    # # print([item['actions'] for item in magent_obs.values()])
    # # print(torch.cat([item['advantages'] for item in magent_obs.values()]))
    # # rec = torch.cat([item['advantages'] for item in magent_obs.values()])
    # # print(f'Length of {len(rec)} \n{rec}')
    # # rec = torch.cat([torch.stack(item['logprobs']) for item in magent_obs.values()])
    # # print(f'Length of {len(rec)} \n{rec}')
    # # rec = torch.cat([torch.stack(item['obs']) for item in magent_obs.values()])
    # # print(f'Length of {len(rec)}')

    #print(torch.stack(list(magent_obs.values())[0])
    b_observations = torch.cat([torch.stack(item['obs']) for item in magent_obs.values()])
    b_logprobs = torch.cat([torch.stack(item['logprobs']) for item in magent_obs.values()])
    b_actions = torch.cat([torch.stack(item['actions']) for item in magent_obs.values()])
    b_returns = torch.cat([item['returns'] for item in magent_obs.values()])
    b_values = torch.cat([torch.stack(item['values']) for item in magent_obs.values()])
    b_advantages =  torch.cat([item['advantages'] for item in magent_obs.values()])



    # b_index stands for batch index
    b_index = np.arange(len(b_observations))
    clip_fracs = []
    for epoch in range(update_epochs):
        np.random.shuffle(b_index)
        for start in range(0, len(b_observations), batch_size):
            end = start + batch_size
            batch_index = b_index[start:end]

            _, newlogprob, entropy, value = ppo_agent.get_action_and_value(
                b_observations[batch_index], b_actions.long()[batch_index])
            
            logratio = newlogprob - b_logprobs[batch_index]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clip_fracs += [
                    ((ratio - 1.0).abs() > clip_coef).float().mean().item()
                ]
            
            # normalizing advantages
            advantages = b_advantages[batch_index]
            advantages = advantages.float()
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # policy loss
            pg_loss1 = -advantages * ratio
            pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # value loss
            value = value.flatten()
            v_loss_unclipped = (value - b_returns[batch_index]) ** 2
            v_clipped = b_values[batch_index] + torch.clamp(
                value - b_values[batch_index],
                -clip_coef,
                clip_coef,
            )
            v_loss_clipped = (v_clipped - b_returns[batch_index]) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()

            entropy_loss = entropy.mean()
            loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    # could move them from GPU here
    y_pred, y_true = b_values.numpy(), b_returns.numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
    
    if episode % 20 == 0:
        print(f"Training episode {episode}")
        #print(f"Episodic Return: {np.mean(total_episodic_return)}")
        #print(f"Episode Length: {end_step}")
        print("")
        print(f"Value Loss: {v_loss.item()}")
        print(f"Policy Loss: {pg_loss.item()}")
        print(f"Old Approx KL: {old_approx_kl.item()}")
        print(f"Approx KL: {approx_kl.item()}")
        print(f"Clip Fraction: {np.mean(clip_fracs)}")
        print(f"Explained Variance: {explained_var.item()}")
        print("\n-------------------------------------------\n")


            
        

  5%|▌         | 5/100 [00:00<00:04, 20.32it/s]

Training episode 0

Value Loss: 492.10736083984375
Policy Loss: -0.06477343291044235
Old Approx KL: 0.03851878643035889
Approx KL: 0.006988584995269775
Clip Fraction: 0.21875000248352686
Explained Variance: -0.029214859008789062

-------------------------------------------



 25%|██▌       | 25/100 [00:01<00:03, 20.14it/s]

Training episode 20

Value Loss: 160.55172729492188
Policy Loss: 0.037208061665296555
Old Approx KL: -0.0072841644287109375
Approx KL: 0.001181522966362536
Clip Fraction: 0.07291666666666667
Explained Variance: -0.0002262592315673828

-------------------------------------------



 43%|████▎     | 43/100 [00:02<00:02, 19.49it/s]

Training episode 40

Value Loss: 63.47903823852539
Policy Loss: -0.02607208490371704
Old Approx KL: -0.016108814626932144
Approx KL: 0.004397149663418531
Clip Fraction: 0.12708333445092043
Explained Variance: -0.00024628639221191406

-------------------------------------------



 64%|██████▍   | 64/100 [00:03<00:01, 21.12it/s]

Training episode 60

Value Loss: 37.77366638183594
Policy Loss: 0.018477117642760277
Old Approx KL: 0.16878294944763184
Approx KL: 0.014280001632869244
Clip Fraction: 0.19791666666666666
Explained Variance: -2.1338462829589844e-05

-------------------------------------------



 85%|████████▌ | 85/100 [00:04<00:00, 20.59it/s]

Training episode 80

Value Loss: 39.60499572753906
Policy Loss: 0.0
Old Approx KL: -0.025804996490478516
Approx KL: 0.000391542911529541
Clip Fraction: 0.10416666666666667
Explained Variance: -1.4424324035644531e-05

-------------------------------------------



100%|██████████| 100/100 [00:04<00:00, 20.26it/s]


In [141]:
# Evaluation: wolves play randomly, villagers act with the trained ppo_agent.
ten_player_env = raw_env(num_agents=10, werewolves=2)

avg_game_length = 0
wolf_wins = 0
villager_wins = 0

num_games = 1000

ten_player_env.reset()

# Inference only: no_grad avoids building an autograd graph for every villager action.
with torch.no_grad():
    for _ in tqdm(range(num_games)):

        for agent in ten_player_env.agent_iter():
            observation, reward, termination, truncation, info = ten_player_env.last()

            # `not termination or truncation` would parse as `(not termination) or truncation`.
            if termination or truncation:
                action = None
            elif ten_player_env.agent_roles[agent]:
                # wolves (truthy agent_roles entry, same criterion the training cell uses) play randomly
                action = random_policy(observation, agent)
            else:
                ## villagers act according to the trained policy
                obs = torch.Tensor(flat_obs(observation['observation']))
                action, _, _, _ = ppo_agent.get_action_and_value(obs)

            ten_player_env.step(action)

        # get some stats
        winner = ten_player_env.world_state['winners']
        day = ten_player_env.world_state['day']

        # NOTE(review): assumes a truthy 'winners' value means the wolves won — confirm against the env.
        if winner:
            wolf_wins += 1
        else:
            villager_wins += 1

        avg_game_length += (day * 1.0)/num_games

        # reset
        ten_player_env.reset()

print(f'Average game length = {avg_game_length:.2f}')
print(f'Wolf wins : {wolf_wins}')
print(f'Villager wins: {villager_wins}')

100%|██████████| 1000/1000 [00:25<00:00, 38.84it/s]

Average game length = 4.21
Wolf wins : 909
Villager wins: 91





Further reading on debugging RL implementations: [Andy Jones's RL debugging guide](https://andyljones.com/posts/rl-debugging.html)