# Small 6 nimmt! tournament

In [None]:
import numpy as np
import logging
import sys
import torch
import pickle
from tqdm import tqdm
from matplotlib import pyplot as plt

sys.path.append("../")

from rl_6_nimmt import Tournament, GameSession
from rl_6_nimmt.agents import Human, DrunkHamster, MaskedReinforceAgent, BatchedReinforceAgent, BatchedACERAgent

# Root logger: bare messages (no level/name prefix) at INFO level.
logging.basicConfig(format="%(message)s", level=logging.INFO)

# Every logger registered so far whose name does not mention rl_6_nimmt is
# raised to WARNING, so only this package's INFO output gets through.
for logger_name in logging.root.manager.loggerDict:
    if "rl_6_nimmt" in logger_name:
        continue
    logging.getLogger(logger_name).setLevel(logging.WARNING)


## Agents

In [None]:
# One row per agent kind: (display name, number of copies, factory, whether
# .train() is called on each copy). Row order fixes both the construction
# order and the insertion order of the `agents` dict.
agent_specs = [
    ("Batched REINFORCE", 1, lambda: BatchedReinforceAgent(r_factor=0.1), True),
    ("Masked REINFORCE", 1, lambda: MaskedReinforceAgent(r_factor=0.1), True),
    ("Batched ACER", 2, BatchedACERAgent, True),
    ("Random", 1, DrunkHamster, False),
]

# Display name -> agent; copies of one kind are numbered from 1.
agents = {}
for kind, count, make_agent, call_train in agent_specs:
    for i in range(count):
        key = f"{kind} {i+1}"
        agents[key] = make_agent()
        if call_train:
            agents[key].train()

# Opponents handed to the tournament as its baseline (a single DrunkHamster).
baseline_agents = [DrunkHamster()]

# Human player, used for the final session against the tournament winner.
merle = Human("Merle")


In [None]:
# Tournament with both player-count bounds set to 2, using the baseline
# agents defined earlier.
tournament = Tournament(
    min_players=2,
    max_players=2,
    baseline_agents=baseline_agents,
)

# Register every agent under its display name.
for player_name, player in agents.items():
    tournament.add_player(player_name, player)

print(tournament)

## Load existing state (skip when running this for the first time)

In [None]:
# agents, tournament = pickle.load(open("./.tournament.pickle", "rb"))

## Let the games begin

In [None]:
num_games = 200000  # keep playing until tournament.total_games reaches this
block_len = 10000  # games played between two printed summaries

# Best-effort reset of tqdm's registry of live progress bars; important after
# cancelling any step. Only a missing `_instances` attribute is tolerated: a
# bare `except:` would also swallow KeyboardInterrupt and SystemExit.
try:
    tqdm._instances.clear()
except AttributeError:
    pass

# Play in blocks of block_len games (one progress bar per block) and print
# the tournament after each block.
while tournament.total_games < num_games:
    for _ in tqdm(range(block_len)):
        tournament.play_game()
    print(tournament)


In [None]:
# Persist the agents and the tournament together in one pickle file.
# The context manager flushes and closes the file even if pickling fails;
# the previous bare open() left the handle to the garbage collector.
with open("./.tournament.pickle", "wb") as state_file:
    pickle.dump((agents, tournament), state_file)

## Let's see the results

In [None]:
# Print the tournament's string representation (the same call that is made
# after every block of games in the game loop).
print(tournament)

In [None]:
def create_color(name):
    """Return the matplotlib color string for an agent display name.

    The color is chosen by the first agent kind found as a substring of
    ``name``; kinds are checked in the order listed below. Returns ``None``
    when no kind matches.
    """
    palette = (
        ("Batched REINFORCE", "C0"),
        ("Masked REINFORCE", "C1"),
        ("Batched ACER", "C2"),
        ("Random", "0.6"),
    )
    for kind, color in palette:
        if kind in name:
            return color
    return None
    
    
def create_label(name):
    """Return the legend label for an agent display name, or ``None``.

    Agents are named ``"<kind> <index>"``. Only the agent with index 1 of
    each kind gets a label (the kind, with the trailing ``" 1"`` removed), so
    each kind shows up once in the legend. Every other name yields ``None``.

    The suffix is tested explicitly: a plain ``"1" in name`` check would also
    match indices such as 10 or 21 and produce duplicate, truncated labels.
    """
    return name[:-2] if name.endswith(" 1") else None


def plot_running_mean(x, y, n=10, **kwargs):
    """Plot the running mean of ``y`` against the running mean of ``x``.

    Both series are averaged over the same sliding window of ``n``
    consecutive samples, so the curve has ``len(y) - n + 1`` points and is
    empty when fewer than ``n`` samples are given. Any extra keyword
    arguments are forwarded to ``plt.plot``.
    """

    def window_mean(values):
        # Prefix sums with a leading 0: the difference between two entries
        # that are n apart is the sum over one window.
        totals = np.cumsum(np.insert(values, 0, 0))
        return (totals[n:] - totals[:-n]) / n

    smoothed_y = window_mean(y)
    smoothed_x = window_mean(x)

    plt.plot(smoothed_x, smoothed_y, **kwargs)


In [None]:
# Four parallel lists, one entry per plot panel.

# y-axis label of each panel.
labels = [
    "Mean Hornochsen in tournament",
    "Mean Hornochsen vs baseline",
    "Win fraction in tournament",
    "Win fraction vs baseline"
]

# Per-agent series recorded by the tournament; each is indexed by agent name.
quantities = [
    tournament.tournament_scores,
    tournament.baseline_scores,
    tournament.tournament_wins,
    tournament.baseline_wins
]

# Transformation applied to each series before plotting: the two score
# series are negated, the two win series are plotted unchanged.
# NOTE: the builtin `id` must not be used as the "unchanged" transform: it
# returns an object's integer identity, not the object itself.
fns = [lambda x: -x, lambda x: -x, lambda x: x, lambda x: x]

# Whether the panel shows a vs-baseline series (True) or a tournament series.
baselines = [False, True, False, True]

In [None]:
fig = plt.figure(figsize=(8, 8))

# One panel per (quantity, label, transform, baseline flag), on a 2x2 grid.
panel_specs = zip(quantities, labels, fns, baselines)
for position, (series_by_agent, y_label, transform, is_baseline) in enumerate(panel_specs, start=1):
    ax = plt.subplot(2, 2, position)

    # Baseline series: x is spaced by 1000 and smoothed over 10 points.
    # Tournament series: x is spaced by 1 and smoothed over 1000 points.
    # NOTE(review): presumably the baseline is evaluated once every 1000
    # tournament games - confirm against Tournament.
    x_spacing = 1000 if is_baseline else 1
    window = 10 if is_baseline else 1000

    for agent_name in tournament.agents.keys():
        values = np.array(series_by_agent[agent_name])
        games = x_spacing * np.arange(1, len(values) + 1)
        plot_running_mean(
            games,
            transform(values),
            n=window,
            label=create_label(agent_name),
            color=create_color(agent_name),
        )

    plt.legend(loc="upper left")
    plt.xlabel("Played tournament games")
    plt.ylabel(y_label)

plt.tight_layout()
plt.show()


## Winner vs Merle

In [None]:
# Session between the human player and whatever tournament.winner() returns.
session = GameSession(merle, tournament.winner())

# Switch the rl_6_nimmt loggers to DEBUG: that output shows the moves,
# which is nice.
for logger_name in logging.root.manager.loggerDict:
    if "rl_6_nimmt" not in logger_name:
        continue
    logging.getLogger(logger_name).setLevel(logging.DEBUG)
        

In [None]:
# Play a game in the session with rendering switched on.
# NOTE(review): one participant is a Human agent, so this presumably asks for
# moves interactively - confirm against GameSession / Human.
session.play_game(render=True)