In [None]:
import copy
import time
from importlib import reload

import gymnasium as gym
import numpy as np
import torch
import torchinfo
import tqdm

import hnefatafl_env
import hnefatafl_utils as hu
import model as mdl
import piece_sampling
from memory import ReplayMemory
from optimization import optimize_batch

# ---- Configuration ----
NUM_ENVS = 32  # number of games stepped in parallel by the vector env
BATCH_SIZE = 4  # number of replay-memory samples drawn per optimisation step
BOARD = hu.GAME_HNEFATAFL
BOARD_SIDELENGTH = BOARD.shape[1]
LEARNING_RATE = 0.01

device = 'cpu'
# Set the default device before any model/tensor is created so everything lands on it.
torch.set_default_device(device)

# ---- Environments ----
envs = gym.vector.AsyncVectorEnv(
    env_fns=[lambda: hnefatafl_env.HnefataflEnv(initial_board=BOARD)
                    for _ in range(NUM_ENVS)],
    shared_memory=False,
)

# ---- Online (trained) networks, one per side ----
model_royal = mdl.Policy(board_sidelength=BOARD_SIDELENGTH).to(device)
model_attackers = mdl.Policy(board_sidelength=BOARD_SIDELENGTH).to(device)

# ---- Target-estimation networks ----
# Start them as exact copies of the online networks. Independently initialised
# nets would make every bootstrapped target random until the first sync.
model_royal_target_estimation = copy.deepcopy(model_royal)
model_attackers_target_estimation = copy.deepcopy(model_attackers)

# ---- Replay memories and optimisers ----
memory_royal = ReplayMemory()
memory_attackers = ReplayMemory()
optimizer_attacker = torch.optim.Adam(model_attackers.parameters(), lr=LEARNING_RATE)
optimizer_royal = torch.optim.Adam(model_royal.parameters(), lr=LEARNING_RATE)

# Input has shape (batch_size, sidelength, sidelength, 3)
# 3 because of 3 possible pieces (attacker, defender, king), denoted as bools
torchinfo.summary(model_royal, input_size=(BATCH_SIZE * NUM_ENVS, BOARD_SIDELENGTH, BOARD_SIDELENGTH, 3))

In [None]:
# Train the models using self-play
rewards = []
losses = []
BASE_TEMPERATURE = 100
EPOCHS = 50
STEPS_PER_EPOCH = 1000  # hard reset of all games after this many rounds (one move per side)
UPDATE_POLICY_STEPS = 8  # run one optimisation step every X rounds
UPDATE_TARGET_ESTIMATOR_STEPS = 32  # sync target-estimation nets every X rounds, counted across epochs
GAMMA = 0.9  # discount factor (constant; TODO: anneal it upwards as training progresses)


def play_turn(vec_env, policy, memory, opponent_memory, team, obs, temperature):
    """Play one move for `team` in every sub-environment and record the transition.

    Parameters
    ----------
    vec_env : vectorised gymnasium environment to step.
    policy : network returning ``(from_piece_q, to_piece_q)`` for a batch of boards.
    memory : replay memory of the side that moves; receives the new transition.
    opponent_memory : replay memory of the other side; informed of the resulting state/reward.
    team : team identifier passed to ``piece_sampling`` (e.g. ``hu.TEAM_ATTACKER``).
    obs : current batched observation from ``vec_env``.
    temperature : sampling temperature passed to ``piece_sampling``.

    Returns
    -------
    ``(next_state, reward, terminated)`` as returned by ``vec_env.step``.
    """
    board = torch.tensor(obs, dtype=torch.float32)
    # Action selection needs no gradients; avoid building an autograd graph.
    with torch.no_grad():
        from_piece_q, to_piece_q = policy(board)
    actions = piece_sampling.generate_actions_from_q(
        from_piece_q.clone(), to_piece_q.clone(), board, team, temperature
    )
    next_state, reward, terminated, _truncated, _info = vec_env.step(actions)
    memory.remember(obs, actions, reward, next_state, terminated, obs, np.zeros_like(reward))
    opponent_memory.update_memory_after_opponent_action(next_state, reward)
    return next_state, reward, terminated


for epoch in range(EPOCHS):
    print(f'#### EPOCH {epoch} ####')
    obs, info = envs.reset()

    cum_reward = 0
    n_attack_wins = 0
    n_defends_win = 0
    pbar = tqdm.trange(STEPS_PER_EPOCH)
    for i in pbar:
        # Global round counter; drives the target-net sync independently of the epoch boundary.
        global_step = epoch * STEPS_PER_EPOCH + i
        # Temperature schedule: more exploration towards the beginning of each epoch, floored at 1.
        temperature = max(1.0, BASE_TEMPERATURE - (i / 10) * 2)

        # Attacker moves first; a termination right after its move is counted as an attacker win.
        obs, reward, terminated = play_turn(
            envs, model_attackers, memory_attackers, memory_royal, hu.TEAM_ATTACKER, obs, temperature
        )
        cum_reward += np.sum(reward)
        n_attack_wins += np.sum(terminated)

        # Defender replies; a termination right after its move is counted as a defender win.
        obs, reward, terminated = play_turn(
            envs, model_royal, memory_royal, memory_attackers, hu.TEAM_DEFENDER, obs, temperature
        )
        cum_reward += np.sum(reward)
        n_defends_win += np.sum(terminated)

        if i > BATCH_SIZE and i % UPDATE_POLICY_STEPS == 0:
            # Every UPDATE_POLICY_STEPS rounds: sample a minibatch from each replay memory
            # and train each model on the reward / estimated-reward discrepancy.
            batch = memory_attackers.sample(BATCH_SIZE)
            loss_attack = optimize_batch(optimizer_attacker, model_attackers, model_attackers_target_estimation, batch, GAMMA)
            batch = memory_royal.sample(BATCH_SIZE)
            loss_defend = optimize_batch(optimizer_royal, model_royal, model_royal_target_estimation, batch, GAMMA)

            attacker_loss_value = loss_attack.detach().cpu().numpy()
            defender_loss_value = loss_defend.detach().cpu().numpy()
            losses.append(defender_loss_value + attacker_loss_value)
            rewards.append(cum_reward)

            pbar.set_postfix({
                'Temp': temperature,
                'A-loss': np.round(attacker_loss_value, 2),
                'D-loss': np.round(defender_loss_value, 2),
                'Cum. reward': cum_reward,
                'Wins A': n_attack_wins,
                'Wins D': n_defends_win
            })

            cum_reward = 0

        if global_step > 0 and global_step % UPDATE_TARGET_ESTIMATOR_STEPS == 0:
            # Every UPDATE_TARGET_ESTIMATOR_STEPS rounds, refresh the target estimators.
            model_attackers_target_estimation.load_state_dict(model_attackers.state_dict())
            model_royal_target_estimation.load_state_dict(model_royal.state_dict())

In [None]:
# Example: have the two trained agents play against each other in real time and watch them.
# Uses its own env name so the training `envs` is not silently replaced (and leaked).
TEMPERATURE = 1  # the lowest temperature reached by the training schedule
DEMO_MAX_ROUNDS = 500  # each round = one attacker move + one defender move
MOVE_DELAY_S = 0.01  # pause between moves so the rendering is watchable

demo_envs = gym.vector.AsyncVectorEnv(
    # Same board the models were trained on.
    env_fns=[lambda: hnefatafl_env.HnefataflEnv(initial_board=BOARD, render_mode='human')]
)
players = (
    (model_attackers, hu.TEAM_ATTACKER),  # attacker moves first, as in training
    (model_royal, hu.TEAM_DEFENDER),
)

try:
    obs, info = demo_envs.reset()
    game_over = False
    for _ in range(DEMO_MAX_ROUNDS):
        for policy, team in players:
            board = torch.tensor(obs, dtype=torch.float32)
            with torch.no_grad():  # pure inference
                from_piece_q, to_piece_q = policy(board)
            actions = piece_sampling.generate_actions_from_q(
                from_piece_q.clone(), to_piece_q.clone(), board, team, TEMPERATURE
            )
            obs, reward, terminated, _, info = demo_envs.step(actions)
            if np.all(terminated):
                game_over = True
                break
            time.sleep(MOVE_DELAY_S)
        if game_over:
            break
finally:
    # Release the worker process and render window however the loop ends.
    demo_envs.close()