In [32]:
import numpy as np
import torch
import sys
sys.path.append('../')
from voting_games.werewolf_env_v0 import plurality_env, plurality_Phase, plurality_Role
import random
import copy
from typing import Any, Generator, Optional, Tuple
from tqdm import tqdm
import mlflow

In [2]:
env = plurality_env(num_agents=10, werewolves=2)
env.reset()

def random_coordinated_wolf(env, action=None):
    villagers_remaining = set(env.world_state["villagers"]) & set(env.world_state['alive'])
    wolves_remaining = set(env.world_state["werewolves"]) & set(env.world_state['alive'])

    target = random.choice(list(villagers_remaining))
    return {wolf: int(target.split("_")[-1]) for wolf in wolves_remaining}

def random_wolfs(env):
    return {wolf: env.action_space(wolf).sample() for
            wolf in set(env.world_state["werewolves"]) & set(env.world_state['alive'])}

def revenge_coordinated_wolf(env, actions = None):
    villagers_remaining = set(env.world_state["villagers"]) & set(env.world_state['alive'])
    wolves_remaining = set(env.world_state["werewolves"]) & set(env.world_state['alive'])

    # who tried to vote out a wolf last time?
    
    target = random.choice(list(villagers_remaining))
    # pick 
    for wolf in wolves_remaining:
        actions[wolf] = [0] * len(env.possible_agents)
        actions[wolf][int(target.split("_")[-1])] = -1
        for curr_wolf in wolves_remaining:
            actions[wolf][int(curr_wolf.split("_")[-1])] = 1
    # for wolf in env.werewolves_remaining:

def random_single_target_villager(env, agent):
    targets = set(env.world_state["alive"]) - set([agent])
    return int(random.choice(list(targets)).split("_")[-1])

# random_coordinated_wolf(env)
def random_agent_action(env, agent, action=None):
   return env.action_space(agent).sample()

def random_coordinated_single_wolf(env, agent, action=None):
    villagers_remaining = set(env.world_state["villagers"]) & set(env.world_state['alive'])
    return action if action != None else int(random.choice(list(villagers_remaining)).split("_")[-1])



In [3]:

def play_static_wolf_game(env, wolf_policy, villager_agent, num_times=100) -> tuple(plurality_Role):

    villager_wins = 0
    loop = tqdm(range(num_times))

    for _ in loop:
        next_observations, rewards, terminations, truncations, infos = env.reset()

        wolf_brain = {'day': 1, 'phase': 0, 'action': None}

        while env.agents:
            observations = copy.deepcopy(next_observations)

            actions = {}

            villagers = set(env.agents) & set(env.world_state["villagers"])
            wolves = set(env.agents) & set(env.world_state["werewolves"])

            # villager steps
                # villagers actions
            for villager in villagers:
                actions[villager] = villager_agent(env, villager)

            # at least one wolf
            day = observations[list(observations)[0]]['observation']['day']
            phase = observations[list(observations)[0]]['observation']['phase']
            
            if wolf_brain['day'] != day or wolf_brain['phase'] == plurality_Phase.NIGHT:
                wolf_brain = {'day': day, 'phase': phase, 'action': None}
            
            for wolf in wolves:
                action = wolf_policy(env, wolf, action=wolf_brain['action'])
                wolf_brain['action'] = action
                actions[wolf] = action
                
            # wolf steps
            # actions = actions | wolf_policy(env)
        
            next_observations, rewards, terminations, truncations, infos = env.step(actions)

        winner = env.world_state['winners']
        if winner == plurality_Role.VILLAGER:
            villager_wins += 1

        loop.set_description(f"Villagers won {villager_wins} out of a total of {num_times} games")

env = plurality_env(num_agents=10, werewolves=2)
env.reset()

play_static_wolf_game(env, random_coordinated_single_wolf, random_single_target_villager, num_times=1000)
play_static_wolf_game(env, random_coordinated_single_wolf, random_agent_action, num_times=1000)

# print("Random Coordinated Wolves")
# print("\t vs. Single Target Random Villagers")
# play_static_wolf_game(env, random_coordinated_wolf, random_single_target_villager, num_times=1000)
# print("\t vs. Random Villagers")
# play_static_wolf_game(env, random_coordinated_wolf, random_agent_action, num_times=1000)
# print("------------------------------------\n")
# print("Random Wolves")
# print("\t vs. Single Target Random Villagers")
# play_static_wolf_game(env, random_wolfs, random_single_target_villager, num_times=1000)
# print("\t vs. Random Villagers")
# play_static_wolf_game(env, random_wolfs, random_agent_action, num_times=1000)
# print("------------------------------------\n")

Villagers won 396 out of a total of 1000 games: 100%|██████████| 1000/1000 [00:04<00:00, 229.45it/s]
Villagers won 302 out of a total of 1000 games: 100%|██████████| 1000/1000 [00:04<00:00, 222.44it/s]


In [3]:
class PluralityAgent(torch.nn.Module):

    def __init__(self, num_actions, obs_size=None):

        self.critic = torch.nn.Sequential(
            self._layer_init(torch.nn.Linear(obs_size, 64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64,64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64,1), std=1.0),
        )

        self.actor = torch.nn.Sequential(
            self._layer_init(torch.nn.Linear(obs_size, 64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64,64)),
            torch.nn.Tanh(),
            self._layer_init(torch.nn.Linear(64, num_actions), std=0.01),
        )
    
    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def get_value(self, x):
        return self.critic(x)
    
    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)

        probs = torch.distributions.categorical.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

In [4]:
class PluralityRecurrentAgent(torch.nn.Module):

    def __init__(self, num_actions, obs_size=None, hidden_state_size=64):
        super().__init__()

        # actor
        self.a_recurrent_layer = self._rec_layer_init(torch.nn.LSTM(64, hidden_state_size, batch_first=True))
        self.a_fc1 = self._layer_init(torch.nn.Linear(obs_size,64))
        self.a_fc2 = self._layer_init(torch.nn.Linear(hidden_state_size,num_actions), std=0.01)

        # critic
        self.c_recurrent_layer = self._rec_layer_init(torch.nn.LSTM(64, hidden_state_size, batch_first=True))
        self.c_fc1 = self._layer_init(torch.nn.Linear(obs_size,64))
        self.c_fc2 = self._layer_init(torch.nn.Linear(hidden_state_size,1), std=1.0)
    
    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def _rec_layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        for name, param in layer.named_parameters():
            if "bias" in name:
                torch.nn.init.constant_(param, bias_const)
            if "weight" in name:
                torch.nn.init.orthogonal_(param, std)
        return layer

    def get_value(self, x, recurrent_cell:torch.tensor):
        h = torch.tanh(self.c_fc1(x))
        h, recurrent_cell = self.c_recurrent_layer(torch.unsqueeze(h,1), recurrent_cell)
        h = torch.tanh(self.c_fc2(h))

        # UPDATE THIS X
        return h, recurrent_cell
    
    def get_action_and_value(self, x, a_recurrent_cell:torch.tensor, c_recurrent_cell:torch.tensor, action=None):
        h = torch.tanh(self.a_fc1(x))
        # we are strictly on a sequence length of 1 here, using prior information baked in
        h, recurrent_cell = self.a_recurrent_layer(torch.unsqueeze(h,1), a_recurrent_cell)
        h = torch.tanh(self.a_fc2(h))
        probs = torch.distributions.categorical.Categorical(logits=h)

        if action is None:
            action = probs.sample()

        c_val, c_rec = self.get_value(x, c_recurrent_cell)
        # UPDATE THIS X
        return action, probs.log_prob(action), probs.entropy(), recurrent_cell, c_val, c_rec
    
class PluralityRecurrentAgentv2(torch.nn.Module):

    def __init__(self, num_actions, obs_size=None, hidden_state_size=128):
        super().__init__()

        # actor
        self.a_recurrent_layer = self._rec_layer_init(torch.nn.LSTM(128, hidden_state_size, batch_first=True))
        self.a_fc1 = self._layer_init(torch.nn.Linear(obs_size, 128))
        self.a_fc2 = self._layer_init(torch.nn.Linear(hidden_state_size,num_actions), std=0.01)

        # critic
        self.c_recurrent_layer = self._rec_layer_init(torch.nn.LSTM(128, hidden_state_size, batch_first=True))
        self.c_fc1 = self._layer_init(torch.nn.Linear(obs_size,128))
        self.c_fc2 = self._layer_init(torch.nn.Linear(hidden_state_size,1), std=1.0)
    
    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def _rec_layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        for name, param in layer.named_parameters():
            if "bias" in name:
                torch.nn.init.constant_(param, bias_const)
            if "weight" in name:
                torch.nn.init.orthogonal_(param, std)
        return layer

    def get_value(self, x, recurrent_cell:torch.tensor):
        h = torch.tanh(self.c_fc1(x))
        h, recurrent_cell = self.c_recurrent_layer(torch.unsqueeze(h,1), recurrent_cell)
        h = h.squeeze(1)
        h = torch.tanh(self.c_fc2(h))

        # UPDATE THIS X
        return h, recurrent_cell
    
    def get_action_and_value(self, x, a_recurrent_cell:torch.tensor, c_recurrent_cell:torch.tensor, action=None):
        h = torch.tanh(self.a_fc1(x))
        h, recurrent_cell = self.a_recurrent_layer(torch.unsqueeze(h,1), a_recurrent_cell)
        h = h.squeeze(1)
        h = torch.tanh(self.a_fc2(h))
        probs = torch.distributions.categorical.Categorical(logits=h)

        if action is None:
            action = probs.sample()

        c_val, c_rec = self.get_value(x, c_recurrent_cell)
        # UPDATE THIS X
        return action, probs.log_prob(action), probs.entropy(), recurrent_cell, c_val, c_rec


In [18]:
class RolloutBuffer():
    
    def __init__(self, 
                 buffer_size: int, 
                 gamma: float, 
                 gae_lambda: float,
                 is_recurrent: bool,
                 recurrent_size: int = None,
                 ):
        '''
            @bufffer_size: This is the number of trajectories
        '''
        self.steps = []

        
        self.rewards = None
        self.actions = None
        self.dones = None
        self.observations = None

        # do we want these for both actor and critic?
        self.actor_hcxs = None 
        self.critic_hcxs = None 


        self.log_probs = None
        self.values = None
        self.advantages = None

        self.buffer_size = buffer_size
        self.gamma = gamma 
        self.gae_lambda = gae_lambda
        self.is_recurrent = is_recurrent
        if self.is_recurrent:
            self.recurrent_size = recurrent_size
        
        self.reset()

    def reset(self):
        self.rewards = []
        self.actions = []
        self.dones = []
        self.observations = []

        # do we want these for both actor and critic?
        self.actor_hcxs = []
        self.critic_hcxs = []

        self.log_probs = []
        self.values = []
        self.advantages = []
        self.returns = []



    def add_replay(self, game) -> bool:
         
         self.rewards.append(game['rewards'])
         self.actions.append(game['actions'])
         self.dones.append(game["terms"])
         self.observations.append(game["obs"])
         self.log_probs.append(game["logprobs"])
         self.values.append(game["values"])
         self.actor_hcxs.append(game["hcxs"][:-1])
         self.critic_hcxs.append(game["c_hcxs"][:-1])
        
         advantages, returns = self._calculate_advantages(game)
             
         self.advantages.append(advantages)
         self.returns.append(returns)

         return True
    
    def _calculate_advantages(self, game):
        """Generalized advantage estimation (GAE)
        """
        advantages = torch.zeros_like(torch.tensor(game['rewards']))

        for t in reversed(range(len(game['rewards']))):
             delta = game['rewards'][t] + self.gamma * game['values'][max((t+1)%len(game['rewards']),t)] - game['values'][t]
             advantages[t] = delta + self.gamma * self.gae_lambda * advantages[max((t+1)%len(game['rewards']),t)]

        # adv and returns
        return advantages, advantages + torch.tensor(game['values'])
    
    def get_minibatch_generator(self, batch_size):

        # fold and stack observations
        actions = torch.cat([item for sublist in self.actions for item in sublist])
        logprobs = torch.cat([item for sublist in self.log_probs for item in sublist])
        returns = torch.cat(self.returns)
        values = torch.cat([item for sublist in self.values for item in sublist])
        advantages = torch.cat(self.advantages).float()
        actor_hxs, actor_cxs = zip(*[(hxs, cxs) for hxs, cxs in [item for sublist in self.actor_hcxs for item in sublist]])
        critic_hxs, critic_cxs = zip(*[(hxs, cxs) for hxs, cxs in [item for sublist in self.critic_hcxs for item in sublist]])
        observations = torch.cat([item for sublist in self.observations for item in sublist])

        index = np.arange(len(observations))

        np.random.shuffle(index)

        # We do not handle remaining stuff here
        for start in range(0,len(observations), batch_size):
            end = start + batch_size
            batch_index = index[start:end].astype(int)

            yield {
                "actions": actions[batch_index],
                "logprobs": logprobs[batch_index],
                "returns": returns[batch_index],
                "values": values[batch_index],
                "advantages": advantages[batch_index],
                # we are using sequence lengths of 1, because everything should be encoded in 
                "actor_hxs": torch.swapaxes(torch.cat(actor_hxs)[batch_index],0,1),
                "actor_cxs": torch.swapaxes(torch.cat(actor_cxs)[batch_index],0,1),
                "critic_hxs": torch.swapaxes(torch.cat(critic_hxs)[batch_index],0,1),
                "critic_cxs": torch.swapaxes(torch.cat(critic_cxs)[batch_index],0,1),
                "observations": observations[batch_index]
            }



In [6]:
@torch.no_grad()
def fill_recurrent_buffer(env, wolf_policy, villager_agent, num_times=10, hidden_state_size=None) -> RolloutBuffer:

    buffer = RolloutBuffer(buffer_size=10, 
                           gamma=0.99, 
                           gae_lambda=0.95,
                           is_recurrent=True)
    buffer.reset()
    
    for _ in range(num_times):
        ## Play the game 
        next_observations, rewards, terminations, truncations, infos = env.reset()
        # init recurrent stuff for actor and critic to 0 as well
        magent_obs = {agent: {'obs': [], 
                              'rewards': [], 
                              'actions': [], 
                              'logprobs': [], 
                              'values': [], 
                              'terms': [],

                              # obs size, and 1,1,64 as we pass batch first
                              'a_hcxs': [(torch.zeros((1,1,hidden_state_size), dtype=torch.float32), torch.zeros((1,1,hidden_state_size), dtype=torch.float32))],
                              'c_hcxs': [(torch.zeros((1,1,hidden_state_size), dtype=torch.float32), torch.zeros((1,1,hidden_state_size), dtype=torch.float32))]
                    } for agent in env.agents if not env.agent_roles[agent]}
        
        wolf_brain = {'day': 1, 'phase': 0, 'action': None}
        while env.agents:
            observations = copy.deepcopy(next_observations)
            actions = {}

            villagers = set(env.agents) & set(env.world_state["villagers"])
            wolves = set(env.agents) & set(env.world_state["werewolves"])

            # villager steps
                # villagers actions
            for villager in villagers:
                #torch.tensor(env.convert_obs(observations['player_0']['observation']), dtype=torch.float)
                torch_obs = torch.tensor(env.convert_obs(observations[villager]['observation']), dtype=torch.float)
                obs = torch.unsqueeze(torch_obs, 0)

                # TODO: Testing this, we may need a better way to pass in villagers
                actor_recurrent_cell = magent_obs[villager]["a_hcxs"][-1]
                critic_recurrent_cell = magent_obs[villager]["c_hcxs"][-1]
                
                # ensure that the obs is of size (batch,seq,inputs)
                ml_action,  logprobs, _, actor_recurrent_cell, c_val, critic_recurrent_cell = villager_agent.get_action_and_value(obs, actor_recurrent_cell, critic_recurrent_cell)
                actions[villager] = ml_action.item()

                # can store some stuff 
                magent_obs[villager]["obs"].append(obs)
                magent_obs[villager]["actions"].append(ml_action)
                magent_obs[villager]["logprobs"].append(logprobs)
                magent_obs[villager]["values"].append(c_val)

                #store the next recurrent cells
                magent_obs[villager]["a_hcxs"].append(actor_recurrent_cell)
                magent_obs[villager]["c_hcxs"].append(critic_recurrent_cell)


            # wolf steps
            day = observations[list(observations)[0]]['observation']['day']
            phase = observations[list(observations)[0]]['observation']['phase']

            if wolf_brain['day'] != day or wolf_brain['phase'] == plurality_Phase.NIGHT:
                wolf_brain = {'day': day, 'phase': phase, 'action': None}

            for wolf in wolves:
                action = wolf_policy(env, wolf, action=wolf_brain['action'])
                wolf_brain['action'] = action
                actions[wolf] = action

            # actions = actions | wolf_policy(env)
        
            next_observations, rewards, terminations, truncations, infos = env.step(actions)

            for villager in villagers:
                magent_obs[villager]["rewards"].append(rewards[villager])
                magent_obs[villager]["terms"].append(terminations[villager])

        ## Fill bigger buffer, keeping in mind sequence
        for agent in magent_obs:
            buffer.add_replay(magent_obs[agent])
    
    return buffer
        

# env = plurality_env(num_agents=10, werewolves=2)
# observations, rewards, terminations, truncations, infos = env.reset()

# obs_size= env.convert_obs(observations['player_0']['observation']).shape[-1]
# rec_agent = PluralityRecurrentAgent(num_actions=env.action_space("player_0").n, obs_size=obs_size)

# def test_policy(obs, net, agent=None, env=None, a_rec=None, c_rec=None):
#    # 
#    return rec_agent.get_action_and_value(obs, a_rec, c_rec)

# def test_recurrent_policy(obs, agent=None, env=None):
#     # we need to return the hx and cx from the model, chyou know? and also have an initial one of 0 to feed the model the first time
#     pass


# buff = fill_recurrent_buffer(env, random_coordinated_wolf, test_policy, num_times=10)


In [7]:
def calc_minibatch_loss(agent: PluralityRecurrentAgent, samples: dict, clip_range: float, beta: float, v_loss_coef: float, optimizer):

    # TODO:Consider checking for NAans anywhere. we cant have these. also do this in the model itself
    # if torch.isnan(tensor).any(): print(f"{label} contains NaN values")

    # get new log probs need to pass in the recurrent cells as well for actor and critic
    _, logprobs, entropies, _, values, _ = agent.get_action_and_value(samples['observations'], 
                                (samples['actor_hxs'], samples['actor_cxs']),
                                (samples['critic_hxs'], samples['critic_cxs']),
                                samples['actions']
                                )
    
    ratio = torch.exp(logprobs - samples['logprobs'])

    # normalize advantages
    norm_advantage = (samples["advantages"] - samples["advantages"].mean()) / (samples["advantages"].std() + 1e-8)
    # normalized_advantage = normalized_advantage.unsqueeze(1).repeat(1, len(self.action_space_shape)) # Repeat is necessary for multi-discrete action spaces

    # policy loss w/ surrogates
    surr1 = norm_advantage * ratio
    surr2 = norm_advantage * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    policy_loss = torch.min(surr1, surr2)
    policy_loss = policy_loss.mean()

    # Value  function loss
    clipped_values = samples["values"] + (values - samples["values"]).clamp(min=-clip_range, max=clip_range)
    vf_loss = torch.max((values - samples['returns']) ** 2, (clipped_values - samples["returns"]) ** 2)
    vf_loss = vf_loss.mean()

    # Entropy Bonus
    entropy_loss = entropies.mean()

    # Complete loss
    loss = -(policy_loss - v_loss_coef * vf_loss + beta * entropy_loss)


    # TODO : do i reset the LR here? do I want to?

    
    # Compute gradients
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=0.5)
    optimizer.step()

    return [policy_loss.cpu().data.numpy(),
            vf_loss.cpu().data.numpy(),
            loss.cpu().data.numpy(),
            entropy_loss.cpu().data.numpy()]

In [8]:
@torch.no_grad()
def play_recurrent_game(env, wolf_policy, villager_agent, num_times=10, hidden_state_size=None):
    
    wins = 0
    loop = tqdm(range(num_times))
    for _ in loop:
        ## Play the game 
        next_observations, rewards, terminations, truncations, infos = env.reset()
        # init recurrent stuff for actor and critic to 0 as well
        magent_obs = {agent: {'obs': [], 
                              # obs size, and 1,1,64 as we pass batch first
                              'a_hcxs': [(torch.zeros((1,1,hidden_state_size), dtype=torch.float32), torch.zeros((1,1,hidden_state_size), dtype=torch.float32))],
                              'c_hcxs': [(torch.zeros((1,1,hidden_state_size), dtype=torch.float32), torch.zeros((1,1,hidden_state_size), dtype=torch.float32))]
                    } for agent in env.agents if not env.agent_roles[agent]}
        

        wolf_brain = {'day': 1, 'phase': 0, 'action': None}

        while env.agents:
            observations = copy.deepcopy(next_observations)
            actions = {}

            villagers = set(env.agents) & set(env.world_state["villagers"])
            wolves = set(env.agents) & set(env.world_state["werewolves"])

            # villagers actions
            for villager in villagers:
                #torch.tensor(env.convert_obs(observations['player_0']['observation']), dtype=torch.float)
                torch_obs = torch.tensor(env.convert_obs(observations[villager]['observation']), dtype=torch.float)
                obs = torch.unsqueeze(torch_obs, 0)

                # TODO: Testing this, we may need a better way to pass in villagers
                actor_recurrent_cell = magent_obs[villager]["a_hcxs"][-1]
                critic_recurrent_cell = magent_obs[villager]["c_hcxs"][-1]
                
                # ensure that the obs is of size (batch,seq,inputs)
                ml_action,  logprobs, _, actor_recurrent_cell, c_val, critic_recurrent_cell = villager_agent.get_action_and_value(obs, actor_recurrent_cell, critic_recurrent_cell)
                actions[villager] = ml_action.item()

                #store the next recurrent cells
                magent_obs[villager]["a_hcxs"].append(actor_recurrent_cell)
                magent_obs[villager]["c_hcxs"].append(critic_recurrent_cell)

            # wolf steps
            day = observations[list(observations)[0]]['observation']['day']
            phase = observations[list(observations)[0]]['observation']['phase']
            
            if wolf_brain['day'] != day or wolf_brain['phase'] == plurality_Phase.NIGHT:
                wolf_brain = {'day': day, 'phase': phase, 'action': None}
            
            for wolf in wolves:
                action = wolf_policy(env, wolf, action=wolf_brain['action'])
                wolf_brain['action'] = action
                actions[wolf] = action

            # actions = actions | wolf_policy(env)
        
            next_observations, rewards, terminations, truncations, infos = env.step(actions)

        ## Fill bigger buffer, keeping in mind sequence
        winner = env.world_state['winners']
        if winner == plurality_Role.VILLAGER:
            wins += 1

        loop.set_description(f"Villagers won {wins} out of a total of {num_times} games")
    
    return wins

In [13]:

def calc_minibatch_loss(agent: PluralityRecurrentAgent, samples: dict, clip_range: float, beta: float, v_loss_coef: float, optimizer):

    # TODO:Consider checking for NAans anywhere. we cant have these. also do this in the model itself
    # if torch.isnan(tensor).any(): print(f"{label} contains NaN values")

    # get new log probs need to pass in the recurrent cells as well for actor and critic
    _, logprobs, entropies, _, values, _ = agent.get_action_and_value(samples['observations'], 
                                (samples['actor_hxs'], samples['actor_cxs']),
                                (samples['critic_hxs'], samples['critic_cxs']),
                                samples['actions']
                                )
    
    ratio = torch.exp(logprobs - samples['logprobs'])

    # normalize advantages
    norm_advantage = (samples["advantages"] - samples["advantages"].mean()) / (samples["advantages"].std() + 1e-8)
    # normalized_advantage = normalized_advantage.unsqueeze(1).repeat(1, len(self.action_space_shape)) # Repeat is necessary for multi-discrete action spaces

    # policy loss w/ surrogates
    surr1 = norm_advantage * ratio
    surr2 = norm_advantage * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    policy_loss = torch.min(surr1, surr2)
    policy_loss = policy_loss.mean()

    # Value  function loss
    clipped_values = samples["values"] + (values - samples["values"]).clamp(min=-clip_range, max=clip_range)
    vf_loss = torch.max((values - samples['returns']) ** 2, (clipped_values - samples["returns"]) ** 2)
    vf_loss = vf_loss.mean()

    # Entropy Bonus
    entropy_loss = entropies.mean()

    # Complete loss
    loss = -(policy_loss - v_loss_coef * vf_loss + beta * entropy_loss)


    # TODO : do i reset the LR here? do I want to?

    
    # Compute gradients
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=0.5)
    optimizer.step()

    return [policy_loss.cpu().data.numpy(),
            vf_loss.cpu().data.numpy(),
            loss.cpu().data.numpy(),
            entropy_loss.cpu().data.numpy()]

### Some 
CLIP_RANGE = 0.1
BETA = 0.1
V_LOSS_COEF = 0.1
BATCH_SIZE = 256
TRAIN_LOOPS = 1000
EPOCHS = 3
GAMES_PER_EPOCH = 100
HIDDEN_STATE_SIZE=256


env = plurality_env(num_agents=10, werewolves=2)
observations, rewards, terminations, truncations, infos = env.reset()
obs_size= env.convert_obs(observations['player_0']['observation']).shape[-1]
train_agent = PluralityRecurrentAgentv2(num_actions=env.action_space("player_0").n, obs_size=obs_size,hidden_state_size=HIDDEN_STATE_SIZE)
optimizer = torch.optim.Adam(train_agent.parameters(), lr=0.0005, eps=1e-5)
# Testing passing a minibatch into this 

train_info = []

for tid in range(TRAIN_LOOPS):
   # train 100 times
   for epid in range(EPOCHS):
      if tid % 10 == 0 and epid == 0:
         # print(f'Playing games with our trained agent after {epid} epochs')
         wins = play_recurrent_game(env, random_coordinated_single_wolf, train_agent, num_times=100, hidden_state_size=HIDDEN_STATE_SIZE)

      # fill buffer
      buff = fill_recurrent_buffer(env, random_coordinated_single_wolf, train_agent, num_times=GAMES_PER_EPOCH, hidden_state_size=HIDDEN_STATE_SIZE)

      # run through batches and train network
      for batch in buff.get_minibatch_generator(BATCH_SIZE):
         train_info.append(calc_minibatch_loss(train_agent, batch, clip_range=CLIP_RANGE, beta=BETA, v_loss_coef=V_LOSS_COEF, optimizer=optimizer))

train_stats = np.mean(train_info, axis=0)
print(train_stats)

torch.save(train_agent, "rnn_agent_3")


Villagers won 33 out of a total of 100 games: 100%|██████████| 100/100 [00:09<00:00, 10.49it/s]
Villagers won 28 out of a total of 100 games: 100%|██████████| 100/100 [00:24<00:00,  4.13it/s]
Villagers won 30 out of a total of 100 games: 100%|██████████| 100/100 [00:07<00:00, 13.68it/s]
Villagers won 31 out of a total of 100 games: 100%|██████████| 100/100 [00:07<00:00, 14.26it/s]
Villagers won 34 out of a total of 100 games: 100%|██████████| 100/100 [00:07<00:00, 14.08it/s]
Villagers won 33 out of a total of 100 games: 100%|██████████| 100/100 [00:07<00:00, 13.78it/s]
Villagers won 23 out of a total of 100 games: 100%|██████████| 100/100 [00:07<00:00, 14.29it/s]
Villagers won 37 out of a total of 100 games: 100%|██████████| 100/100 [00:07<00:00, 14.01it/s]


ValueError: Expected parameter logits (Tensor of shape (1, 10)) of distribution Categorical(logits: torch.Size([1, 10])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]])

In [41]:
# after training is done

play_recurrent_game(env, random_coordinated_single_wolf, train_agent, num_times=1000, hidden_state_size=HIDDEN_STATE_SIZE)

Villagers won 299 out of a total of 1000 games: 100%|██████████| 1000/1000 [01:30<00:00, 11.03it/s]


299

## Agent with shared network

In [49]:
class PluralityRecurrentAgentv3(torch.nn.Module):
    def __init__(self, config:dict, num_actions, obs_size=None):
        super().__init__()

        # recurrent layer
        # TODO: Do I want 2 here?
        self.recurrent_layer = self._rec_layer_init(torch.nn.LSTM(obs_size, config['rec_hidden_size'], num_layers=config['rec_layers'], batch_first=True))

        # hidden layers
        self.fc_joint = self._layer_init(torch.nn.Linear(config['rec_hidden_size'], config['hidden_mlp_size']))
        self.policy_hidden = self._layer_init(torch.nn.Linear(config['hidden_mlp_size'], config['hidden_mlp_size']))
        self.value_hidden = self._layer_init(torch.nn.Linear(config['hidden_mlp_size'], config['hidden_mlp_size']))

        # policy output
        self.policy_out = self._layer_init(torch.nn.Linear(config['hidden_mlp_size'], num_actions), std=0.01)

        # value output
        self.value_out = self._layer_init(torch.nn.Linear(config['hidden_mlp_size'], 1), std=1.0)
    
    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        torch.nn.init.orthogonal_(layer.weight, std)
        # torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def _rec_layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        for name, param in layer.named_parameters():
            # if "bias" in name:
                # torch.nn.init.constant_(param, bias_const)
            if "weight" in name:
                torch.nn.init.orthogonal_(param, std)
        return layer
    
    
    def forward(self, x, recurrent_cell: torch.tensor):

        # pass  through the Recurrence Layer
        h, recurrent_cell = self.recurrent_layer(torch.unsqueeze(x,1), recurrent_cell)
        h = torch.squeeze(h,1)

        # Pass through a hidden layer
        h = torch.relu(self.fc_joint(h))

        # Split for Value and Policy
        h_value = torch.relu(self.value_hidden(h))
        h_policy = torch.relu(self.policy_hidden(h))

        # value
        value = self.value_out(h_value).reshape(-1)

        # policy
        policy = self.policy_out(h_policy)
        policy = torch.distributions.Categorical(logits=policy)

        return policy, value, recurrent_cell
    

class RolloutBufferv3():
    
    def __init__(self, buffer_size: int, gamma: float, gae_lambda: float):
        '''
            @bufffer_size: This is the number of trajectories
        '''
        
        self.rewards = None
        self.actions = None
        self.dones = None
        self.observations = None

        # do we want these for both actor and critic?
        self.hcxs = None 

        self.log_probs = None
        self.values = None
        self.advantages = None

        self.buffer_size = buffer_size
        self.gamma = gamma 
        self.gae_lambda = gae_lambda

        self.reset(gamma=gamma, gae_lambda=gae_lambda)

    def reset(self, gamma: float, gae_lambda: float):
        self.rewards = []
        self.actions = []
        self.dones = []
        self.observations = []

        # do we want these for both actor and critic?
        self.hcxs = []

        self.log_probs = []
        self.values = []
        self.advantages = []
        self.returns = []

        self.gamma = gamma 
        self.gae_lambda = gae_lambda

    def add_replay(self, game) -> bool:
         
         self.rewards.append(game['rewards'])
         self.actions.append(game['actions'])
         self.dones.append(game["terms"])
         self.observations.append(game["obs"])
         self.log_probs.append(game["logprobs"])
         self.values.append(game["values"])
         self.hcxs.append(game["hcxs"][:-1])
        
         advantages, returns = self._calculate_advantages(game)
             
         self.advantages.append(advantages)
         self.returns.append(returns)

         return True
    
    @torch.no_grad()
    def _calculate_advantages(self, game):
        """Generalized advantage estimation (GAE)
        """
        advantages = torch.zeros_like(torch.tensor(game['rewards']))

        for t in reversed(range(len(game['rewards']))):
             delta = game['rewards'][t] + self.gamma * game['values'][max((t+1)%len(game['rewards']),t)] - game['values'][t]
             advantages[t] = delta + self.gamma * self.gae_lambda * advantages[max((t+1)%len(game['rewards']),t)]

        # adv and returns
        return advantages, advantages + torch.tensor(game['values'])
    
    def get_minibatch_generator(self, batch_size):

        # fold and stack observations
        actions = torch.cat([item for sublist in self.actions for item in sublist])
        logprobs = torch.cat([item for sublist in self.log_probs for item in sublist])
        returns = torch.cat(self.returns)
        values = torch.cat([item for sublist in self.values for item in sublist])
        advantages = torch.cat(self.advantages).float()

        # TODO : Gotta update these to work with a single set of hxs, rxs
        hxs, cxs = zip(*[(hxs, cxs) for hxs, cxs in [item for sublist in self.hcxs for item in sublist]])
        observations = torch.cat([item for sublist in self.observations for item in sublist])

        index = np.arange(len(observations))

        np.random.shuffle(index)

        # We do not handle remaining stuff here
        for start in range(0,len(observations), batch_size):
            end = start + batch_size
            batch_index = index[start:end].astype(int)

            yield {
                "actions": actions[batch_index],
                "logprobs": logprobs[batch_index],
                "returns": returns[batch_index],
                "values": values[batch_index],
                "advantages": advantages[batch_index],
                # we are using sequence lengths of 1, because everything should be encoded in 
                "hxs": torch.swapaxes(torch.cat(hxs)[batch_index],0,1),
                "cxs": torch.swapaxes(torch.cat(cxs)[batch_index],0,1),
                "observations": observations[batch_index]
            } 

@torch.no_grad()
def fill_recurrent_buffer(buffer, env, config:dict, wolf_policy, villager_agent) -> RolloutBuffer:

    buffer.reset(gamma=config["training"]["gamma"], gae_lambda=config["training"]["gae_lambda"])
    
    for _ in range(config["training"]["buffer_games_per_update"]):
        ## Play the game 
        next_observations, rewards, terminations, truncations, infos = env.reset()
        # init recurrent stuff for actor and critic to 0 as well
        magent_obs = {agent: {'obs': [], 
                              'rewards': [], 
                              'actions': [], 
                              'logprobs': [], 
                              'values': [], 
                              'terms': [],

                              # obs size, and 1,1,64 as we pass batch first
                              'hcxs': [(torch.zeros((1,1,config["model"]["recurrent_hidden_size"]), dtype=torch.float32), 
                                        torch.zeros((1,1,config["model"]["recurrent_hidden_size"]), dtype=torch.float32))]
                    } for agent in env.agents if not env.agent_roles[agent]}
        
        wolf_brain = {'day': 1, 'phase': 0, 'action': None}
        while env.agents:
            observations = copy.deepcopy(next_observations)
            actions = {}

            villagers = set(env.agents) & set(env.world_state["villagers"])
            wolves = set(env.agents) & set(env.world_state["werewolves"])

            # villager steps
                # villagers actions
            for villager in villagers:
                #torch.tensor(env.convert_obs(observations['player_0']['observation']), dtype=torch.float)
                torch_obs = torch.tensor(env.convert_obs(observations[villager]['observation']), dtype=torch.float)
                obs = torch.unsqueeze(torch_obs, 0)

                # TODO: Testing this, we may need a better way to pass in villagers
                recurrent_cell = magent_obs[villager]["hcxs"][-1]
                
                # ensure that the obs is of size (batch,seq,inputs)
                # this needs to be updated
                policy, value, recurrent_cell = villager_agent(obs, recurrent_cell)
                action = policy.sample()
                
                actions[villager] = action.item()

                # can store some stuff 
                magent_obs[villager]["obs"].append(obs)
                magent_obs[villager]["actions"].append(action)

                # how do we get these
                magent_obs[villager]["logprobs"].append(policy.log_prob(action))
                magent_obs[villager]["values"].append(value)

                #store the next recurrent cells
                magent_obs[villager]["hcxs"].append(recurrent_cell)


            # wolf steps
            day = observations[list(observations)[0]]['observation']['day']
            phase = observations[list(observations)[0]]['observation']['phase']

            if wolf_brain['day'] != day or wolf_brain['phase'] == plurality_Phase.NIGHT:
                wolf_brain = {'day': day, 'phase': phase, 'action': None}

            for wolf in wolves:
                action = wolf_policy(env, wolf, action=wolf_brain['action'])
                wolf_brain['action'] = action
                actions[wolf] = action

            # actions = actions | wolf_policy(env)
        
            next_observations, rewards, terminations, truncations, infos = env.step(actions)

            for villager in villagers:
                magent_obs[villager]["rewards"].append(rewards[villager])
                magent_obs[villager]["terms"].append(terminations[villager])

        ## Fill bigger buffer, keeping in mind sequence
        for agent in magent_obs:
            buffer.add_replay(magent_obs[agent])
    
    return buffer

@torch.no_grad()
def play_recurrent_game(env, wolf_policy, villager_agent, num_times=10, hidden_state_size=None):
    
    wins = 0
    # loop = tqdm(range(num_times))
    for _ in range(num_times):
        ## Play the game 
        next_observations, rewards, terminations, truncations, infos = env.reset()
        # init recurrent stuff for actor and critic to 0 as well
        magent_obs = {agent: {'obs': [], 
                              # obs size, and 1,1,64 as we pass batch first
                              'hcxs': [(torch.zeros((1,1,hidden_state_size), dtype=torch.float32), torch.zeros((1,1,hidden_state_size), dtype=torch.float32))],
                    } for agent in env.agents if not env.agent_roles[agent]}
        

        wolf_brain = {'day': 1, 'phase': 0, 'action': None}

        while env.agents:
            observations = copy.deepcopy(next_observations)
            actions = {}

            villagers = set(env.agents) & set(env.world_state["villagers"])
            wolves = set(env.agents) & set(env.world_state["werewolves"])

            # villagers actions
            for villager in villagers:
                #torch.tensor(env.convert_obs(observations['player_0']['observation']), dtype=torch.float)
                torch_obs = torch.tensor(env.convert_obs(observations[villager]['observation']), dtype=torch.float)
                obs = torch.unsqueeze(torch_obs, 0)

                # TODO: Testing this, we may need a better way to pass in villagers
                recurrent_cell = magent_obs[villager]["hcxs"][-1]
                
                # ensure that the obs is of size (batch,seq,inputs)
                policy, value, recurrent_cell = villager_agent(obs, recurrent_cell)
                action = policy.sample()
                
                actions[villager] = action.item()

                #store the next recurrent cells
                magent_obs[villager]["hcxs"].append(recurrent_cell)

            # wolf steps
            day = observations[list(observations)[0]]['observation']['day']
            phase = observations[list(observations)[0]]['observation']['phase']
            
            if wolf_brain['day'] != day or wolf_brain['phase'] == plurality_Phase.NIGHT:
                wolf_brain = {'day': day, 'phase': phase, 'action': None}
            
            for wolf in wolves:
                action = wolf_policy(env, wolf, action=wolf_brain['action'])
                wolf_brain['action'] = action
                actions[wolf] = action

            # actions = actions | wolf_policy(env)
        
            next_observations, rewards, terminations, truncations, infos = env.step(actions)

        ## Fill bigger buffer, keeping in mind sequence
        winner = env.world_state['winners']
        if winner == plurality_Role.VILLAGER:
            wins += 1

        # loop.set_description(f"Villagers won {wins} out of a total of {num_times} games")
    
    return wins

def calc_minibatch_loss(agent: PluralityRecurrentAgentv3, samples: dict, clip_range: float, beta: float, v_loss_coef: float, optimizer):

    # TODO:Consider checking for NAans anywhere. we cant have these. also do this in the model itself
    # if torch.isnan(tensor).any(): print(f"{label} contains NaN values")
    policies, values, _ = agent(samples['observations'], (samples['hxs'], samples['cxs']))
    
    # log_probs, entropies = [], []
    log_probs = policies.log_prob(samples['actions'])
    entropies = policies.entropy() # need to sum if we have more than 1 action
    
    ratio = torch.exp(log_probs - samples['logprobs'])

    # normalize advantages
    norm_advantage = (samples["advantages"] - samples["advantages"].mean()) / (samples["advantages"].std() + 1e-8)
    # normalized_advantage = normalized_advantage.unsqueeze(1).repeat(1, len(self.action_space_shape)) # Repeat is necessary for multi-discrete action spaces

    # policy loss w/ surrogates
    surr1 = norm_advantage * ratio
    surr2 = norm_advantage * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    policy_loss = torch.min(surr1, surr2)
    policy_loss = policy_loss.mean()

    # Value  function loss
    clipped_values = samples["values"] + (values - samples["values"]).clamp(min=-clip_range, max=clip_range)
    vf_loss = torch.max((values - samples['returns']) ** 2, (clipped_values - samples["returns"]) ** 2)
    vf_loss = vf_loss.mean()

    # Entropy Bonus
    entropy_loss = entropies.mean()

    # Complete loss
    loss = -(policy_loss - v_loss_coef * vf_loss + beta * entropy_loss)


    # TODO : do i reset the LR here? do I want to?

    
    # Compute gradients
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=0.5)
    optimizer.step()

    
    return [policy_loss.cpu().data.numpy(),     # policy loss
            vf_loss.cpu().data.numpy(),         # value loss
            loss.cpu().data.numpy(),            # total loss
            entropy_loss.cpu().data.numpy()]    # entropy loss

### Some 
# CLIP_RANGE = 0.2
# BETA = 0.1
# V_LOSS_COEF = 0.1
# BATCH_SIZE = 128
# TRAIN_LOOPS = 100
# EPOCHS = 6
# GAMES_PER_EPOCH = 200
# HIDDEN_STATE_SIZE=256


# env = plurality_env(num_agents=10, werewolves=2)
# observations, rewards, terminations, truncations, infos = env.reset()
# obs_size= env.convert_obs(observations['player_0']['observation']).shape[-1]
# nn_agent = PluralityRecurrentAgentv3({"rec_hidden_size": HIDDEN_STATE_SIZE, "rec_layers": 1, "hidden_mlp_size": 128},num_actions=env.action_space("player_0").n, obs_size=obs_size)
# optimizer = torch.optim.Adam(nn_agent.parameters(), lr=0.0001)
# # Testing passing a minibatch into this 

# for tid in range(TRAIN_LOOPS):
#     # train 100 times
#     if tid % 10 == 0:
#         # print(f'Playing games with our trained agent after {epid} epochs')
#         wins = play_recurrent_game(env, random_coordinated_single_wolf, nn_agent, num_times=100, hidden_state_size=HIDDEN_STATE_SIZE)

#     # fill buffer
#     buff = fill_recurrent_buffer(env, random_coordinated_single_wolf, nn_agent, num_times=GAMES_PER_EPOCH, hidden_state_size=HIDDEN_STATE_SIZE)

#     # train info will hold our metrics
#     train_info = []
#     for epid in range(EPOCHS):
#       # run through batches and train network
#       for batch in buff.get_minibatch_generator(BATCH_SIZE):
#          train_info.append(calc_minibatch_loss(nn_agent, batch, clip_range=CLIP_RANGE, beta=BETA, v_loss_coef=V_LOSS_COEF, optimizer=optimizer))

#     train_stats = np.mean(train_info, axis=0)

#     # we can store the 

    

# print(train_stats)

# torch.save(nn_agent, "rnn_agent_combined")

In [46]:
config_training = {
    "model": {
        "recurrent_layers": 1,
        "recurrent_hidden_size": 128, # 256
        "mlp_size": 128, # 256
    },
    "training" : {
        "batch_size": 32, # 128
        "epochs": 3, # 6
        "updates": 10, # 1000
        "buffer_games_per_update": 10, # 200
        "clip_range": 0.2,
        "value_loss_coefficient": 0.1,
        "max_grad_norm": 0.5,
        "beta": 0.01, # entropy loss multiplier
        "learning_rate": 0.0001,
        "adam_eps": 1e-8,
        "gamma": 0.99,
        "gae_lambda": 0.95,
    }
}

config_game = {
    "rewards": {
        "day": -1,
        "player_death": -1,
        "player_win": 10,
        "player_loss": -5,
        "self_vote": -1,
        "dead_vote": -1,
        "dead_wolf": 5,
        "no_viable_vote": -1,
        "no_sleep": -1,
    },
    "gameplay": {
        "accusation_phases": 1,
        "num_agents": 10,
        "num_werewolves": 2,
    }
}

config = {
    "config_game": config_game,
    "config_training": config_training,
}

In [51]:
class PPOTrainer:
    def __init__(self, config:dict, run_id:str="run", device:torch.device=torch.device("cpu"), mlflow_uri:str=None) -> None:
        """Initializes all needed training components.
        Arguments:
            config {dict} -- Configuration and hyperparameters of the environment, trainer and model.
            run_id {str, optional} -- A tag used to save Tensorboard Summaries and the trained model. Defaults to "run".
            device {torch.device, optional} -- Determines the training device. Defaults to cpu.
        """
        # Set variables
        self.config = config
        self.device = device
        self.run_id = run_id
        self.mlflow_uri = mlflow_uri
        self.env = None

        # we are not using schedules yet
        # self.lr_schedule = config["learning_rate_schedule"]
        # self.beta_schedule = config["beta_schedule"]
        # self.cr_schedule = config["clip_range_schedule"]

        # Initialize Environment
        env = plurality_env(num_agents=10, werewolves=2)
        self.env = env
        
        observations, rewards, terminations, truncations, infos = env.reset()
        obs_size= env.convert_obs(observations['player_0']['observation']).shape[-1]

        # Initialize Buffer
        self.buffer = RolloutBufferv3(buffer_size=10, gamma=0.99, gae_lambda=0.95)

        # Initialize Model & Optimizer
        self.agent = PluralityRecurrentAgentv3({"rec_hidden_size": self.config["config_training"]["model"]["recurrent_hidden_size"], 
                                                "rec_layers": self.config["config_training"]["model"]["recurrent_layers"], 
                                                "hidden_mlp_size": self.config["config_training"]["model"]["mlp_size"]},
                                                num_actions=self.env.action_space("player_0").n,
                                                obs_size=obs_size)
        self.optimizer = torch.optim.Adam(self.agent.parameters(), lr=0.0001, eps=1e-5)

        # setup mlflow run if we are using it

    def train(self, idx: int):
        if self.mlflow_uri:
            mlflow.set_tracking_uri(self.mlflow_uri)

        name = f'{self.run_id}_{idx}'
        with mlflow.start_run(run_name=name):
            
            mlflow.log_params(self.config["config_training"]["training"])
            mlflow.log_params(self.config["config_training"]["model"])

            loop = tqdm(range(self.config["config_training"]["training"]["updates"]))

            for tid, _ in enumerate(loop):
                # train 100 times
                if tid % 2 == 0:
                    # print(f'Playing games with our trained agent after {epid} epochs')
                    loop.set_description("Playing games and averaging score")
                    wins = []
                    for _ in range(10):
                        wins.append(play_recurrent_game(self.env, 
                                                        random_coordinated_single_wolf, 
                                                        self.agent, 
                                                        num_times=50,
                                                        hidden_state_size=self.config["config_training"]["model"]["recurrent_hidden_size"]))
                    
                    mlflow.log_metric("avg_wins/50", np.mean(wins))

                loop.set_description("Filling buffer")
                # fill buffer
                buff = fill_recurrent_buffer(self.buffer, 
                                             self.env,
                                             self.config["config_training"],
                                             random_coordinated_single_wolf, 
                                             self.agent)

                # train info will hold our metrics
                train_info = []
                loop.set_description("Epoch Training")
                for _ in range(self.config['config_training']["training"]['epochs']):
                    # run through batches and train network
                    for batch in buff.get_minibatch_generator(self.config['config_training']["training"]['batch_size']):
                        train_info.append(calc_minibatch_loss(self.agent, 
                                                              batch, 
                                                              clip_range=self.config['config_training']["training"]['clip_range'], 
                                                              beta=self.config['config_training']["training"]['beta'], 
                                                              v_loss_coef=self.config['config_training']["training"]['value_loss_coefficient'], 
                                                              optimizer=self.optimizer))

                train_stats = np.mean(train_info, axis=0)
                mlflow.log_metric("policy loss", train_stats[0])
                mlflow.log_metric("value loss", train_stats[1])
                mlflow.log_metric("total loss", train_stats[2])
                mlflow.log_metric("entropy loss", train_stats[3])
            # one more run

        # torch.save(self.agent, f"rnn_agent_{self.run_id}")


trainer = PPOTrainer(config=config,run_id="plu", mlflow_uri="http://mlflow:5000")
trainer.train(1)

Epoch Training: 100%|██████████| 10/10 [02:56<00:00, 17.65s/it]                  


In [53]:
import time

www