# Small 6 nimmt! tournament

In [1]:
import numpy as np
import logging
import sys
import torch
import pickle
from tqdm import tqdm
from matplotlib import pyplot as plt

sys.path.append("../")

from rl_6_nimmt import Tournament, GameSession
from rl_6_nimmt.agents import Human, DrunkHamster, BatchedACERAgent, Noisy_D3QN_PRB_NStep, MCSAgent, PolicyMCSAgent, PUCTAgent

# Emit bare messages (no level or logger-name prefix) from INFO upwards.
logging.basicConfig(format="%(message)s", level=logging.INFO)

# Loggers registered so far that do not carry "rl_6_nimmt" in their name are
# raised to WARNING; loggers of the game package keep the INFO default.
for name in logging.root.manager.loggerDict:
    if "rl_6_nimmt" in name:
        continue
    logging.getLogger(name).setLevel(logging.WARNING)


## Agents

In [2]:
# Tournament participants, keyed by the display name used in tables and plots.
# Entries are constructed in the order listed.
agents = {
    "Random": DrunkHamster(),
    "D3QN": Noisy_D3QN_PRB_NStep(history_length=int(1e5), n_steps=10),
    "ACER": BatchedACERAgent(minibatch=10),
    "MCS": MCSAgent(mc_max=100),
    "PolicyMCS": PolicyMCSAgent(mc_max=100),
    "AlphaAlmostZero": PUCTAgent(mc_max=100),
}

# Call `train()` on every agent that offers it; agents without a callable
# `train` attribute are skipped.
# NOTE(review): presumably this is torch's Module.train() (training mode) for
# the learning agents -- confirm against rl_6_nimmt.agents.
for agent_name, agent in agents.items():
    enter_train_mode = getattr(agent, "train", None)
    if not callable(enter_train_mode):
        continue
    try:
        enter_train_mode()
    except Exception as exc:
        # Stay best-effort like before, but report the failure instead of
        # hiding it, and let KeyboardInterrupt / SystemExit propagate.
        logging.warning("Could not call train() on agent %s: %r", agent_name, exc)

# Opponents handed to the Tournament as `baseline_agents`.
baseline_agents = [DrunkHamster()]

# Human player for the interactive game at the end of the notebook.
merle = Human("Merle")


In [3]:
# Tournament with min_players and max_players both set to 2, evaluated against
# the baseline agents defined above.
tournament = Tournament(
    min_players=2,
    max_players=2,
    baseline_agents=baseline_agents,
)

# Register every agent under its display name.
for player_name in agents:
    tournament.add_player(player_name, agents[player_name])

print(tournament)

Tournament after 0 games:
----------------------------------------------------------------------------------------------------
 Agent                | Games | Tournament score | Tournament wins | Baseline score | Baseline wins 
----------------------------------------------------------------------------------------------------
               Random |     0 |                - |               - |              - |             - 
                 D3QN |     0 |                - |               - |              - |             - 
                 ACER |     0 |                - |               - |              - |             - 
                  MCS |     0 |                - |               - |              - |             - 
            PolicyMCS |     0 |                - |               - |              - |             - 
      AlphaAlmostZero |     0 |                - |               - |              - |             - 
-----------------------------------------------------------------

## Load existing state (skip when running this for the first time)

In [4]:
# Optional: resume from the state written to "./.tournament.pickle" by the
# pickle.dump cell further down. Uncomment the two lines below to replace
# `agents` and `tournament` with the saved objects.
# NOTE(review): pickle.load can execute arbitrary code contained in the file --
# only load a file that this notebook wrote itself.
# agents, tournament = pickle.load(open("./.tournament.pickle", "rb"))

# print(tournament)

## Let the games begin

In [None]:
num_games = 10000  # stop once tournament.total_games reaches this value
block_len = 100    # games played between two standings printouts / evolve() calls

# Forget progress bars left over from an interrupted run.
# NOTE(review): `_instances` is a tqdm internal; depending on the tqdm version
# it may not exist before the first bar is created -- hence the AttributeError guard.
try:
    tqdm._instances.clear()  # Important after cancelling any step
except AttributeError:
    pass

while tournament.total_games < num_games:
    # Clamp the block so the loop never plays past num_games, e.g. when
    # num_games is not a multiple of block_len or a saved run is resumed
    # part-way through a block.
    # (assumes play_game() advances total_games -- TODO confirm)
    games_this_block = min(block_len, num_games - tournament.total_games)
    for _ in tqdm(range(games_this_block)):
        tournament.play_game()
    print(tournament)

    # No evolution step after the final block.
    if tournament.total_games < num_games:
        tournament.evolve(max_players=8, max_per_descendant=3, copies=(2,2))


  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Persist agents and tournament together so a later session can resume from them.
# The `with` block closes (and thereby flushes) the file even if pickling fails.
with open("./.tournament.pickle", "wb") as state_file:
    pickle.dump((agents, tournament), state_file)

## Let's see the results

In [None]:
# Print the tournament's standings table (its string form) after the games above.
print(tournament)

In [None]:
def plot_running_mean(x, y, n=10, **kwargs):
    """Plot the n-point running mean of ``y`` against the n-point running mean of ``x``.

    Both sequences are smoothed with the same sliding window of ``n``
    consecutive entries, so for inputs of length ``m >= n`` the plotted line
    has ``m - n + 1`` points.

    Args:
        x: Sequence of x values.
        y: Sequence of y values.
        n: Window length of the running mean.
        **kwargs: Forwarded unchanged to ``plt.plot``.
    """

    def windowed_mean(values):
        # Prefix sums with a leading zero: the sum over any n consecutive
        # entries is the difference of two prefix sums that are n apart.
        prefix = np.cumsum(np.insert(values, 0, 0))
        return (prefix[n:] - prefix[:-n]) / n

    smoothed_y = windowed_mean(y)
    smoothed_x = windowed_mean(x)

    plt.plot(smoothed_x, smoothed_y, **kwargs)


In [None]:
# Configuration of the 2x2 results figure. The four lists below are
# index-aligned: entry i of each list describes panel i.

# y-axis label per panel.
labels = [
    "Mean Hornochsen in tournament",
    "Mean Hornochsen vs baseline",
    "Win fraction in tournament",
    "Win fraction vs baseline"
]
# Per-agent series plotted in each panel (indexed by agent name when plotting).
quantities = [
    tournament.tournament_scores,
    tournament.baseline_scores,
    tournament.tournament_wins,
    tournament.baseline_wins
]
# Transformation applied to the series before plotting: the two score panels
# are negated, the two win panels are plotted unchanged.
fns = [lambda x : -x, lambda x : -x, lambda x : x, lambda x : x]
# True for the panels fed by the baseline_* series.
baselines = [False, True, False, True]

# Line colour per agent name. The keys have to match the names under which the
# agents are registered ("Random", "D3QN", "ACER", "MCS", "PolicyMCS",
# "AlphaAlmostZero"); a missing key makes `colors[name]` fail with KeyError.
colors = {
    "D3QN": "C0",
    "DQN": "C0",  # earlier spelling, kept in case an older saved tournament still uses it
    "ACER": "C1",
    "MCS": "C2",
    "PolicyMCS": "C4",
    "AlphaAlmostZero": "C3",
    "Random": "0.6",
}
# Line style per agent name; same key set as `colors`.
lss = {
    "D3QN": "-",
    "DQN": "-",
    "ACER": "-",
    "MCS": "-",
    "PolicyMCS": "-",
    "AlphaAlmostZero": "-",
    "Random": "-",
}

In [None]:
# Spacing of the baseline series on the "played tournament games" axis.
# NOTE(review): presumably the baseline evaluation runs once every 10
# tournament games -- confirm against Tournament.
BASELINE_EVAL_INTERVAL = 10

fig = plt.figure(figsize=(10,10))

# One panel per (quantity, label, transformation, baseline flag) tuple.
for panel, (quantity, label, fn, baseline) in enumerate(zip(quantities, labels, fns, baselines)):
    ax = plt.subplot(2,2,panel + 1)

    for name in tournament.agents.keys():
        y = np.array(quantity[name])
        x = np.arange(1, len(y) + 1)
        if baseline:
            x = BASELINE_EVAL_INTERVAL * x

        # Agents without an entry in `colors` / `lss` are drawn with
        # matplotlib's default colour cycle and line style instead of
        # aborting the whole figure with a KeyError.
        style = {}
        if name in colors:
            style["color"] = colors[name]
        if name in lss:
            style["ls"] = lss[name]

        plot_running_mean(x, fn(y), n=200, label=name, lw=1.5, **style)

    # A single legend, in the first panel, serves all four panels.
    if panel == 0:
        plt.legend(loc="center left")

    plt.xlabel("Played tournament games")
    plt.ylabel(label)

plt.tight_layout()
plt.savefig("tournament_results.pdf")


## Winner vs Merle

In [None]:
# Merle's opponent is the AlphaAlmostZero agent, with mc_max raised from the
# 100 it was constructed with to 1000.
opponent = agents["AlphaAlmostZero"]
opponent.mc_max = 1000

session = GameSession(merle, opponent)

# Switch the game package's loggers to DEBUG: that output shows the moves.
for name in logging.root.manager.loggerDict:
    if "rl_6_nimmt" not in name:
        continue
    logging.getLogger(name).setLevel(logging.DEBUG)
        

In [None]:
# Play one game of the Merle-vs-agent session created above.
# NOTE(review): render=True presumably displays the game state for the human player -- confirm.
session.play_game(render=True)