In [2]:
import numpy as np
import torch
import sys
sys.path.append('../')
from voting_games.werewolf_env_v0 import plurality_env, plurality_Phase, plurality_Role
import random
import copy
from typing import Any, Generator, Optional, Tuple
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


Done
Done


In [3]:
env = plurality_env(num_agents=10, werewolves=2)
env.reset()

def random_coordinated_wolf(env):
    villagers_remaining = set(env.world_state["villagers"]) & set(env.world_state['alive'])
    wolves_remaining = set(env.world_state["werewolves"]) & set(env.world_state['alive'])

    target = random.choice(list(villagers_remaining))
    return {wolf: int(target.split("_")[-1]) for wolf in wolves_remaining}

def random_wolfs(env):
    return {wolf: env.action_space(wolf).sample() for
            wolf in set(env.world_state["werewolves"]) & set(env.world_state['alive'])}

def revenge_coordinated_wolf(env, actions = None):
    villagers_remaining = set(env.world_state["villagers"]) & set(env.world_state['alive'])
    wolves_remaining = set(env.world_state["werewolves"]) & set(env.world_state['alive'])

    # who tried to vote out a wolf last time?
    
    target = random.choice(list(villagers_remaining))
    # pick 
    for wolf in wolves_remaining:
        actions[wolf] = [0] * len(env.possible_agents)
        actions[wolf][int(target.split("_")[-1])] = -1
        for curr_wolf in wolves_remaining:
            actions[wolf][int(curr_wolf.split("_")[-1])] = 1
    # for wolf in env.werewolves_remaining:

def random_single_target_villager(env, agent):
    targets = set(env.world_state["alive"]) - set([agent])
    return int(random.choice(list(targets)).split("_")[-1])

# random_coordinated_wolf(env)
def random_agent_action(env, agent):
   return env.action_space(agent).sample()


In [9]:
def play_static_wolf_game(env, wolf_policy, villager_agent, num_times=100) -> tuple(plurality_Role):

    villager_wins = 0
    loop = tqdm(range(num_times))

    for _ in loop:
        next_observations, rewards, terminations, truncations, infos = env.reset()
        while env.agents:
            observations = copy.deepcopy(next_observations)
            actions = {}

            villagers = set(env.agents) & set(env.world_state["villagers"])
            wolves = set(env.agents) & set(env.world_state["werewolves"])

            # villager steps
            if env.world_state["phase"] != plurality_Phase.NIGHT:
                # villagers actions
                for villager in villagers:
                    actions[villager] = villager_agent(env, villager)

            # wolf steps
            actions = actions | wolf_policy(env)
        
            next_observations, rewards, terminations, truncations, infos = env.step(actions)

        winner = env.world_state['winners']
        if winner == plurality_Role.VILLAGER:
            villager_wins += 1

        loop.set_description(f"Villagers won {villager_wins} out of a total of {num_times} games")

env = plurality_env(num_agents=10, werewolves=2)
env.reset()

print("Random Coordinated Wolves")
print("\t vs. Single Target Random Villagers")
play_static_wolf_game(env, random_coordinated_wolf, random_single_target_villager, num_times=1000)
print("\t vs. Random Villagers")
play_static_wolf_game(env, random_coordinated_wolf, random_agent_action, num_times=1000)
print("------------------------------------\n")
print("Random Wolves")
print("\t vs. Single Target Random Villagers")
play_static_wolf_game(env, random_wolfs, random_single_target_villager, num_times=1000)
print("\t vs. Random Villagers")
play_static_wolf_game(env, random_wolfs, random_agent_action, num_times=1000)
print("------------------------------------\n")

Random Coordinated Wolves
	 vs. Single Target Random Villagers


Villagers won 105 out of a total of 1000 games: 100%|██████████| 1000/1000 [00:05<00:00, 176.66it/s]


	 vs. Random Villagers


Villagers won 51 out of a total of 1000 games: 100%|██████████| 1000/1000 [00:04<00:00, 200.31it/s]


------------------------------------

Random Wolves
	 vs. Single Target Random Villagers


Villagers won 689 out of a total of 1000 games: 100%|██████████| 1000/1000 [00:04<00:00, 200.93it/s]


	 vs. Random Villagers


Villagers won 589 out of a total of 1000 games: 100%|██████████| 1000/1000 [00:05<00:00, 175.50it/s]

------------------------------------






In [None]:
class PluralityAgent(torch.nn.Module):

    def __init__(self, num_actions, obs_size=None):

        self.critic = torch.nn.Sequential(
            self._layer_init(torch.nn.Linear(obs_size, 64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64,64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64,1), std=1.0),
        )

        self.actor = torch.nn.Sequential(
            self._layer_init(torch.nn.Linear(obs_size, 64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64,64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64, num_actions), std=0.01),
        )
    
    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def get_value(self, x):
        return self.critic(x)
    
    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)

        probs = torch.distributions.categorical.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

In [108]:
class PluralityRecurrentAgent(torch.nn.Module):

    def __init__(self, num_actions, obs_size=None, hidden_state_size=64):
        super().__init__()

        # actor
        self.a_recurrent_layer = self._rec_layer_init(torch.nn.LSTM(64, hidden_state_size, batch_first=True))
        self.a_fc1 = self._layer_init(torch.nn.Linear(obs_size,64))
        self.a_fc2 = self._layer_init(torch.nn.Linear(64,num_actions), std=0.01)

        # critic
        self.c_recurrent_layer = self._rec_layer_init(torch.nn.LSTM(64, hidden_state_size, batch_first=True))
        self.c_fc1 = self._layer_init(torch.nn.Linear(obs_size,64))
        self.c_fc2 = self._layer_init(torch.nn.Linear(64,1), std=1.0)
    
    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def _rec_layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        for name, param in layer.named_parameters():
            if "bias" in name:
                torch.nn.init.constant_(param, bias_const)
            if "weight" in name:
                torch.nn.init.orthogonal_(param, std)
        return layer

    def get_value(self, x, recurrent_cell:torch.tensor):
        h = torch.tanh(self.c_fc1(x))
        h, recurrent_cell = self.c_recurrent_layer(torch.unsqueeze(h,1), recurrent_cell)
        h = torch.tanh(self.c_fc2(h))

        # UPDATE THIS X
        return h, recurrent_cell
    
    def get_action_and_value(self, x, a_recurrent_cell:torch.tensor, c_recurrent_cell:torch.tensor, action=None):
        h = torch.tanh(self.a_fc1(x))
        # we are strictly on a sequence length of 1 here, using prior information baked in
        h, recurrent_cell = self.a_recurrent_layer(torch.unsqueeze(h,1), a_recurrent_cell)
        h = torch.tanh(self.a_fc2(h))
        probs = torch.distributions.categorical.Categorical(logits=h)

        if action is None:
            action = probs.sample()

        c_val, c_rec = self.get_value(x, c_recurrent_cell)
        # UPDATE THIS X
        return action, probs.log_prob(action), probs.entropy(), recurrent_cell, c_val, c_rec

In [103]:
env = plurality_env()
observations, rewards, terminations, truncations, infos = env.reset()
obs_size= env.convert_obs(observations['player_0']['observation']).shape[-1]
rec_agent = PluralityRecurrentAgent(num_actions=env.action_space("player_0").n, obs_size=obs_size)

# had to call, and unsqueeze the obs. we did this because we need to pass in a batch size.
# we also want to keep the length to 1 for now, and pass each through the model
tobs = torch.tensor(env.convert_obs(observations['player_0']['observation']), dtype=torch.float)
tobs = torch.unsqueeze(tobs, 0)
rec_agent.get_action_and_value(tobs, 
                               (torch.zeros((1), 64, dtype=torch.float32),
                                torch.zeros((1), 64, dtype=torch.float32))
                               )



TypeError: PluralityRecurrentAgent.get_action_and_value() missing 1 required positional argument: 'c_recurrent_cell'

In [120]:
class RolloutBuffer():
    
    def __init__(self, 
                 buffer_size: int, 
                 gamma: float, 
                 gae_lambda: float,
                 is_recurrent: bool,
                 recurrent_size: int = None,
                 ):
        '''
            @bufffer_size: This is the number of trajectories
        '''
        self.steps = []

        
        self.rewards = None
        self.actions = None
        self.dones = None
        self.observations = None

        # do we want these for both actor and critic?
        self.actor_hcxs = None 
        self.critic_hcxs = None 


        self.log_probs = None
        self.values = None
        self.advantages = None

        self.buffer_size = buffer_size
        self.gamma = gamma 
        self.gae_lambda = gae_lambda
        self.is_recurrent = is_recurrent
        if self.is_recurrent:
            self.recurrent_size = recurrent_size
        
        self.reset()

    def reset(self):
        self.rewards = []
        self.actions = []
        self.dones = []
        self.observations = []

        # do we want these for both actor and critic?
        self.actor_hcxs = []
        self.critic_hcxs = []

        self.log_probs = []
        self.values = []
        self.advantages = []
        self.returns = []



    def add_replay(self, game) -> bool:
         
         self.rewards.append(game['rewards'])
         self.actions.append(game['actions'])
         self.dones.append(game["terms"])
         self.observations.append(game["obs"])
         self.log_probs.append(game["logprobs"])
         self.values.append(game["values"])
         self.actor_hcxs.append(game["a_hcxs"][:-1])
         self.critic_hcxs.append(game["c_hcxs"][:-1])
        
         advantages, returns = self._calculate_advantages(game)
             
         self.advantages.append(advantages)
         self.returns.append(returns)

         return True
    
    def _calculate_advantages(self, game):
        """Generalized advantage estimation (GAE)
            Arguments:
                last_value {torch.tensor} -- Value of the last agent's state
                gamma {float} -- Discount factor
                lamda {float} -- GAE regularization parameter
        """
        advantages = torch.zeros_like(torch.tensor(game['rewards']))

        for t in reversed(range(len(game['rewards']))):
             delta = game['rewards'][t] + self.gamma * game['values'][max((t+1)%len(game['rewards']),t)] - game['values'][t]
             advantages[t] = delta + self.gamma * self.gae_lambda * advantages[max((t+1)%len(game['rewards']),t)]

        # adv and returns
        return advantages, advantages + torch.tensor(game['values'])
    
    def get_minibatch_generator(self, batch_size):

        # fold and stack observations
        actions = torch.cat([item for sublist in self.actions for item in sublist])
        logprobs = torch.cat([item for sublist in self.log_probs for item in sublist])
        returns = torch.cat(self.returns)
        values = torch.cat([item for sublist in self.values for item in sublist])
        advantages = torch.cat(self.advantages)
        actor_hxs, actor_cxs = zip(*[(hxs, cxs) for hxs, cxs in [item for sublist in self.actor_hcxs for item in sublist]])
        critic_hxs, critic_cxs = zip(*[(hxs, cxs) for hxs, cxs in [item for sublist in self.critic_hcxs for item in sublist]])
        observations = torch.cat([item for sublist in self.observations for item in sublist])

        index = np.arange(len(observations))

        np.random.shuffle(index)

        # We do not handle remaining stuff here
        for start in range(0,len(observations), batch_size):
            end = start + batch_size
            batch_index = index[start:end].astype(int)

            yield {
                "actions": actions[batch_index],
                "logprobs": logprobs[batch_index],
                "returns": returns[batch_index],
                "values": values[batch_index],
                "advantages": advantages[batch_index],
                # we are using sequence lengths of 1, because everything should be encoded in 
                "actor_hxs": torch.swapaxes(torch.cat(actor_hxs)[batch_index],0,1),
                "actor_cxs": torch.swapaxes(torch.cat(actor_cxs)[batch_index],0,1),
                "critic_hxs": torch.swapaxes(torch.cat(critic_hxs)[batch_index],0,1),
                "critic_cxs": torch.swapaxes(torch.cat(critic_cxs)[batch_index],0,1),
                "observations": observations[batch_index]
            }



In [121]:
@torch.no_grad()
def fill_recurrent_buffer(env, wolf_policy, villager_policy, num_times=10) -> RolloutBuffer:

    buffer = RolloutBuffer(buffer_size=10, 
                           gamma=0.90, 
                           gae_lambda=0.90,
                           is_recurrent=True)
    buffer.reset()
    
    for _ in range(num_times):
        ## Play the game 
        next_observations, rewards, terminations, truncations, infos = env.reset()
        # init recurrent stuff for actor and critic to 0 as well
        magent_obs = {agent: {'obs': [], 
                              'rewards': [], 
                              'actions': [], 
                              'logprobs': [], 
                              'values': [], 
                              'terms': [],

                              # obs size, and 1,1,64 as we pass batch first
                              'a_hcxs': [(torch.zeros((1,1,64), dtype=torch.float32), torch.zeros((1,1,64), dtype=torch.float32))],
                              'c_hcxs': [(torch.zeros((1,1,64), dtype=torch.float32), torch.zeros((1,1,64), dtype=torch.float32))]
                    } for agent in env.agents if not env.agent_roles[agent]}
        while env.agents:
            observations = copy.deepcopy(next_observations)
            actions = {}

            villagers = set(env.agents) & set(env.world_state["villagers"])
            wolves = set(env.agents) & set(env.world_state["werewolves"])

            # villager steps
            if env.world_state["phase"] != plurality_Phase.NIGHT:
                # villagers actions
                for villager in villagers:
                    #torch.tensor(env.convert_obs(observations['player_0']['observation']), dtype=torch.float)
                    torch_obs = torch.tensor(env.convert_obs(observations[villager]['observation']), dtype=torch.float)
                    obs = torch.unsqueeze(torch_obs, 0)

                    # TODO: Testing this, we may need a better way to pass in villagers
                    actor_recurrent_cell = magent_obs[villager]["a_hcxs"][-1]
                    critic_recurrent_cell = magent_obs[villager]["c_hcxs"][-1]
                    
                    # ensure that the obs is of size (batch,seq,inputs)
                    ml_action,  logprobs, _, actor_recurrent_cell, c_val, critic_recurrent_cell = villager_policy(obs, actor_recurrent_cell, critic_recurrent_cell)
                    actions[villager] = ml_action.item()

                    # can store some stuff 
                    magent_obs[villager]["obs"].append(obs)
                    magent_obs[villager]["actions"].append(ml_action)
                    magent_obs[villager]["logprobs"].append(logprobs)
                    magent_obs[villager]["values"].append(c_val)

                    #store the next recurrent cells
                    magent_obs[villager]["a_hcxs"].append(actor_recurrent_cell)
                    magent_obs[villager]["c_hcxs"].append(critic_recurrent_cell)

            # wolf steps
            actions = actions | wolf_policy(env)
        
            next_observations, rewards, terminations, truncations, infos = env.step(actions)

            for villager in villagers:
                    if env.history[-1]["phase"] == plurality_Phase.NIGHT:
                        magent_obs[villager]["rewards"][-1] += rewards[villager]
                        magent_obs[villager]["terms"][-1] = terminations[villager]
                    else:
                        magent_obs[villager]["rewards"].append(rewards[villager])
                        magent_obs[villager]["terms"].append(terminations[villager])

        ## Fill bigger buffer, keeping in mind sequence
        for agent in magent_obs:
            buffer.add_replay(magent_obs[agent])
    
    return buffer
        

env = plurality_env(num_agents=10, werewolves=2)
observations, rewards, terminations, truncations, infos = env.reset()

obs_size= env.convert_obs(observations['player_0']['observation']).shape[-1]
rec_agent = PluralityRecurrentAgent(num_actions=env.action_space("player_0").n, obs_size=obs_size)

def test_policy(obs, net, agent=None, env=None, a_rec=None, c_rec=None):
   # 
   return rec_agent.get_action_and_value(obs, a_rec, c_rec)

def test_recurrent_policy(obs, agent=None, env=None):
    # we need to return the hx and cx from the model, chyou know? and also have an initial one of 0 to feed the model the first time
    pass


buff = fill_recurrent_buffer(env, random_coordinated_wolf, test_policy, num_times=10)


In [125]:


env = plurality_env(num_agents=10, werewolves=2)
observations, rewards, terminations, truncations, infos = env.reset()
obs_size= env.convert_obs(observations['player_0']['observation']).shape[-1]
rec_agent = PluralityRecurrentAgent(num_actions=env.action_space("player_0").n, obs_size=obs_size)

# Testing passing a minibatch into this 

minibatch_gen = buff.get_minibatch_generator(32)
first_batch = next(minibatch_gen)


actions, logprobs, entropies, _, values, _ = rec_agent.get_action_and_value(first_batch['observations'], 
                                (first_batch['actor_hxs'], first_batch['actor_cxs']),
                                (first_batch['critic_hxs'], first_batch['critic_cxs']),
                                first_batch['actions']
                                )

print("helo World")


In [None]:
class PPOTrainer:
    def __init__(self, config:dict, run_id:str="run", device:torch.device=torch.device("cpu")) -> None:
        """Initializes all needed training components.
        Arguments:
            config {dict} -- Configuration and hyperparameters of the environment, trainer and model.
            run_id {str, optional} -- A tag used to save Tensorboard Summaries and the trained model. Defaults to "run".
            device {torch.device, optional} -- Determines the training device. Defaults to cpu.
        """
        # Set variables
        self.config = config
        self.recurrence = config["recurrence"]
        self.device = device
        self.run_id = run_id
        self.lr_schedule = config["learning_rate_schedule"]
        self.beta_schedule = config["beta_schedule"]
        self.cr_schedule = config["clip_range_schedule"]
        

        # setup mlflow run if we are using it
    def 
    def _get_mini_batch_loss():


In [50]:
# now that we have a buffer, lets try and train it
def calc_minibatch_loss(agent: PluralityRecurrentAgent, samples: dict, clip_range: float, beta: float, v_loss_coef: float, lr: float):

    # get new log probs need to pass in the recurrent cells as well for actor and critic
    _, logprobs, entropies, values = agent.get_action_and_value(observations, samples['actions'])
    ratio = torch.exp(logprobs - samples['logprobs'])

    # normalize advantages
    norm_advantage = (samples["advantages"] - samples["advantages"].mean()) / (samples["advantages"].std() + 1e-8)
    # normalized_advantage = normalized_advantage.unsqueeze(1).repeat(1, len(self.action_space_shape)) # Repeat is necessary for multi-discrete action spaces

    # policy loss w/ surrogates
    surr1 = norm_advantage * ratio
    surr2 = norm_advantage * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    policy_loss = torch.min(surr1, surr2)
    policy_loss = policy_loss.mean()

    # Value  function loss
    clipped_values = samples["values"] + (values - samples["values"]).clamp(min=-clip_range, max=clip_range)
    vf_loss = torch.max((values - samples['returns']) ** 2, (clipped_values - samples["returns"]) ** 2)
    vf_loss = vf_loss.mean()

    # Entropy Bonus
    entropy_loss = entropies.mean()

    # Complete loss
    loss = -(policy_loss - v_loss_coef * vf_loss + beta * entropy_loss)


    # Compute gradients
    for pg in self.optimizer.param_groups:
        pg["lr"] = lr
    self.optimizer.zero_grad()
    loss.backward()
    # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.config["max_grad_norm"])
    self.optimizer.step()

    return [policy_loss.cpu().data.numpy(),
            vf_loss.cpu().data.numpy(),
            loss.cpu().data.numpy(),
            entropy_loss.cpu().data.numpy()]
