## Training our Ultimate Werewolf Agents

Our env currently has agents playing as werewolves or villagers.

We want to try training only the villager agents, only the werewolf agents, and both at the same time.


We also want to explore different training environments.


We start by installing some additional dependencies that are needed only for training.

In [None]:
# Tianshou, installed from the WillDudley/tianshou GitHub fork rather than from PyPI
%pip install git+https://github.com/WillDudley/tianshou.git

### Setting up baselines

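We want a baseline to compare our trained agents against. For this, we run 100 games in which every agent follows a random policy (voting uniformly among legal, non-wolf targets). We do this for 5-, 10- and 20-player games and record the average game length in days and the number of wins for each side.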

In [12]:
import sys
sys.path.append('../')
from voting_games.werewolf_env_v0 import raw_env
import random
from tqdm import tqdm

In [7]:
def random_policy(observation, agent):
    """Choose a uniformly random legal target player.

    A player index is a candidate when its ``action_mask`` entry is truthy
    and its ``roles`` entry is falsy, so players flagged in ``roles``
    (the werewolves) are never chosen.

    NOTE(review): for a villager this is only a truly random vote if the
    villager's ``roles`` view hides who the wolves are -- confirm against
    werewolf_env_v0.

    Args:
        observation: dict holding an ``'action_mask'`` sequence and an
            ``'observation'`` dict with per-player ``'player_status'`` and
            ``'roles'`` sequences.
        agent: the acting agent; not used by this policy.

    Returns:
        The index of the chosen player.

    Raises:
        IndexError: if no player qualifies (raised by ``random.choice``).
    """
    player_info = observation['observation']
    n_players = len(player_info['player_status'])
    mask = observation['action_mask']
    roles = player_info['roles']

    candidates = []
    for target, allowed, role in zip(range(n_players), mask, roles):
        # skip players the action mask rules out (e.g. dead players)
        if not allowed:
            continue
        # never target a player flagged as a wolf
        if role:
            continue
        candidates.append(target)

    return random.choice(candidates)

In [16]:
# Games with 5 players
five_player_env = raw_env(num_agents=5, werewolves=1)

total_days = 0
wolf_wins = 0
villager_wins = 0

num_games = 100
five_player_env.reset()

for _ in tqdm(range(num_games)):

    for agent in five_player_env.agent_iter():
        observation, reward, termination, truncation, info = five_player_env.last()
        # A finished agent (terminated OR truncated) is stepped with None.
        # The condition must group both flags: `not termination or truncation`
        # binds as `(not termination) or truncation` and would ask the policy
        # for an action from an agent that is both terminated and truncated.
        action = None if termination or truncation else random_policy(observation, agent)
        five_player_env.step(action)

    # get some stats
    # NOTE(review): assumes a truthy `winners` value means the werewolves won -- confirm in werewolf_env_v0
    winner = five_player_env.world_state['winners']
    day = five_player_env.world_state['day']

    if winner:
        wolf_wins += 1
    else:
        villager_wins += 1

    total_days += day

    # reset for the next game
    five_player_env.reset()

# Sum first, divide once: avoids accumulating rounding error from adding
# day / num_games on every game.
avg_game_length = total_days / num_games

print(f'Average game length = {avg_game_length:.2f}')
print(f'Wolf wins : {wolf_wins}')
print(f'Villager wins: {villager_wins}')




100%|██████████| 100/100 [00:00<00:00, 5237.51it/s]

Average game length = 1.91
Wolf wins : 79
Villager wins: 21





In [17]:
# Games with 10 players
ten_player_env = raw_env(num_agents=10, werewolves=2)

total_days = 0
wolf_wins = 0
villager_wins = 0

num_games = 100

ten_player_env.reset()

for _ in tqdm(range(num_games)):

    for agent in ten_player_env.agent_iter():
        observation, reward, termination, truncation, info = ten_player_env.last()
        # A finished agent (terminated OR truncated) is stepped with None.
        # The condition must group both flags: `not termination or truncation`
        # binds as `(not termination) or truncation` and would ask the policy
        # for an action from an agent that is both terminated and truncated.
        action = None if termination or truncation else random_policy(observation, agent)
        ten_player_env.step(action)

    # get some stats
    # NOTE(review): assumes a truthy `winners` value means the werewolves won -- confirm in werewolf_env_v0
    winner = ten_player_env.world_state['winners']
    day = ten_player_env.world_state['day']

    if winner:
        wolf_wins += 1
    else:
        villager_wins += 1

    total_days += day

    # reset for the next game
    ten_player_env.reset()

# Sum first, divide once: avoids accumulating rounding error from adding
# day / num_games on every game.
avg_game_length = total_days / num_games

print(f'Average game length = {avg_game_length:.2f}')
print(f'Wolf wins : {wolf_wins}')
print(f'Villager wins: {villager_wins}')

100%|██████████| 100/100 [00:00<00:00, 1418.17it/s]

Average game length = 4.26
Wolf wins : 93
Villager wins: 7





In [18]:
# Games with 20 players
twenty_player_env = raw_env(num_agents=20, werewolves=3)

total_days = 0
wolf_wins = 0
villager_wins = 0

num_games = 100

twenty_player_env.reset()

for _ in tqdm(range(num_games)):

    for agent in twenty_player_env.agent_iter():
        observation, reward, termination, truncation, info = twenty_player_env.last()
        # A finished agent (terminated OR truncated) is stepped with None.
        # The condition must group both flags: `not termination or truncation`
        # binds as `(not termination) or truncation` and would ask the policy
        # for an action from an agent that is both terminated and truncated.
        action = None if termination or truncation else random_policy(observation, agent)
        twenty_player_env.step(action)

    # get some stats
    # NOTE(review): assumes a truthy `winners` value means the werewolves won -- confirm in werewolf_env_v0
    winner = twenty_player_env.world_state['winners']
    day = twenty_player_env.world_state['day']

    if winner:
        wolf_wins += 1
    else:
        villager_wins += 1

    total_days += day

    # reset for the next game
    twenty_player_env.reset()

# Sum first, divide once: avoids accumulating rounding error from adding
# day / num_games on every game.
avg_game_length = total_days / num_games

print(f'Average game length = {avg_game_length:.2f}')
print(f'Wolf wins : {wolf_wins}')
print(f'Villager wins: {villager_wins}')

100%|██████████| 100/100 [00:00<00:00, 287.80it/s]

Average game length = 8.60
Wolf wins : 91
Villager wins: 9



