In [1]:
import sys
import copy 
import random
import itertools
import pandas as pd
import numpy as np
import torch
import tqdm.auto as tqdm
from pathlib import Path
import torch_scatter

from typing import *
from enum import Enum
from dataclasses import dataclass

In [2]:
# Enable IPython's autoreload extension in mode 2 so imported modules are
# reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import logging
# Shared notebook logger; every class below writes through the "research" channel.
logger = logging.getLogger("research")

In [4]:
import os,sys
# Put the parent of the current working directory first on the import path
# (presumably so the project-local `utils` package imported below resolves).
sys.path.insert(0, str(Path.cwd().parent))

In [5]:
# Project-local logging helper; NOTE(review): assumed to configure handlers/levels
# for the notebook — confirm against utils/logging.py.
import utils.logging
utils.logging.setup(debug=False)

### Narde rules
https://www.bkgm.com/variants/Narde.html

In [6]:
class Game:
    class Step(Enum):
        IDLE = "idle"
        ROLL = "roll"
        TURN = "trun"
        FINISHED = "finished"

    @staticmethod
    def _init_seed(seed: int | bool | None) -> int:
        if isinstance(seed, bool):
            seed = random.randint(0, 2**32-1) if seed else None

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
        
        return seed

    def __init__(self, seed: bool | int | None = None, verbose = False):
        self.seed = self._init_seed(seed)
        self.board = np.zeros(24, dtype=int)
        self.board[0] = 15
        self.board[12] = -15
        self.dice = [0, 0]
        self.home = [0, 0]
        self.pturn = 0
        self.t = 0
        self.head_moves = 0
        self.valid_moves = None
        self.step : Game.Step = Game.Step.IDLE
        self.loglevel = logging.INFO if verbose else logging.DEBUG
        logger.log(self.loglevel, f"new game started, seed={self.seed}")

    def _get_next_pos(self, pos: int, steps: int) -> int:
        dst_pos = pos + steps
        if dst_pos > 24:
            dst_pos = dst_pos % 25 + 1
        if dst_pos < 1:
            dst_pos = 24 + dst_pos
        return dst_pos
    
    def _get_head(self, player: Optional[int] = None) -> int:
        player = player if player is not None else self.pturn
        return 1 if player == 0 else 13
    
    def _get_user_sign(self, player: Optional[int] = None) -> int:
        player = player if player is not None else self.pturn
        return 1 if player == 0 else -1

    def _can_move_home(self):
        counter = self.home[self.pturn]
        home = range(19, 25) if self.pturn == 0 else range(7, 13)
        for pos in home:
            if self._has_checkers(pos, player=self._cur_player):
                counter += self._get_checkers(pos, player=self._cur_player)
        return counter == 15

    def _is_move_home(self, pos: int, steps: int) -> bool:
        dst_pos = self._get_next_pos(pos, steps)
        return (self.pturn == 0 and dst_pos < pos) or \
               (self.pturn == 1 and pos <= 12 and dst_pos > 12)
    
    @property
    def _opponent(self) -> int:
        return 1 if self.pturn == 0 else 0
    
    @property
    def _cur_player(self) -> int:
        return self.pturn
    
    def _has_checkers(self, pos: int, player: Optional[int] = None):
        return self._get_checkers(pos, player) > 0
    
    def _get_checkers(self, pos: int, player: Optional[int] = None):
        player = player if player is not None else self.pturn
        return self._get_user_sign(player) * self.board[pos-1]
    
    def _find_prime(self, pos: int) -> Optional[Tuple[int, int]]:
        seq = 1
        result = [pos, pos]
        next_ptr = self._get_next_pos(pos, 1)
        while seq < 6 and self._has_checkers(next_ptr, player=self._cur_player):
            seq += 1
            result[1] = next_ptr
            next_ptr = self._get_next_pos(next_ptr, 1)
        prev_ptr = self._get_next_pos(pos, -1)
        while seq < 6 and self._has_checkers(prev_ptr, player=self._cur_player):
            seq += 1
            result[0] = prev_ptr
            prev_ptr = self._get_next_pos(prev_ptr, -1)
        return tuple(result) if seq == 6 else None
    
    def _is_blocking_prime(self, dst_pos: int) -> bool:
        prime_range = self._find_prime(dst_pos)
        if prime_range:
            prime_end_pos = prime_range[1]
            for step in range(1, 24):
                search_pos = self._get_next_pos(prime_end_pos, step)
                if self._get_head(self._opponent) == search_pos:
                    return True # we reached other player's home without finding it's checkers
                if self._has_checkers(pos=search_pos, player=self._opponent):
                    break
        return False
    
    def _check_move(self, pos: int, steps: int):
        dst_pos = self._get_next_pos(pos, steps)
        if self.step != Game.Step.TURN:
            raise RuntimeError("invalid action")
        if not (1 <= pos <= 24):
            raise RuntimeError("invalid position")
        if not self._has_checkers(pos, player=self._cur_player):
            raise RuntimeError(f"no checkers at position {pos}")
        if self._has_checkers(dst_pos, player=self._opponent):
            raise RuntimeError(f"can't move to position {dst_pos}")
        if steps not in self.dice:
            raise RuntimeError(f"no dice with value {steps}")
        if self._is_move_home(pos, steps):
            if not self._can_move_home():
                raise RuntimeError(f"not all checkers are at finishing table")
        if self._get_head() == pos and self.head_moves > 0:
            if not (self.head_moves == 1 and self.dice[0] == self.dice[1] and self.t < 2):
                raise RuntimeError(f"can't make any more head moves")
        if self._is_blocking_prime(dst_pos):
            raise RuntimeError(f"can't form a blocking prime")
        # TODO: If player can play one number but not both, they must play the higher one

    def _render_player(self, player: Optional[int] = None, lower: bool = True) -> str:
        player = player if player is not None else self.pturn
        result = "O" if player == 0 else "X"
        return result.lower() if lower else result

    def _is_valid_move(self, pos: int, steps: int) -> bool:
        try:
            self._check_move(pos, steps)
            return True
        except RuntimeError as e:
            return False
    
    def _enum_valid_moves(self) -> Iterator[Tuple[int, int]]:
        for pos in range(1, 25):
            if self._has_checkers(pos, player=self._cur_player):
                for steps in range(1, 7):
                    if self._is_valid_move(pos, steps):
                        yield (pos, steps)

    def get_valid_moves(self) -> List[Tuple[int, int]]:
        if self.valid_moves is None:
            self.valid_moves = list(self._enum_valid_moves())
        return self.valid_moves
    
    def has_valid_moves(self):
        return len(self.get_valid_moves()) > 0

    def start(self, d1: int = 0, d2: int = 0) -> "Game":
        if self.step != Game.Step.IDLE:
            raise RuntimeError("invalid action")

        self.dice = [d1 or random.randint(1, 6), d2 or random.randint(1, 6)]
        while self.dice[0] == self.dice[1]:
            self.dice = [random.randint(1, 6), random.randint(1, 6)]

        self.step = Game.Step.ROLL
        if self.dice[0] > self.dice[1]:
            self.pturn = 0
        else: # self.dice[0] < self.dice[1]:
            self.pturn = 1
        return self

    def roll(self, d1: int = 0, d2: int = 0) -> "Game":
        if self.step != Game.Step.ROLL:
            raise RuntimeError("invalid action")
        self.dice = [d1 or random.randint(1, 6), d2 or random.randint(1, 6)]
        logger.log(self.loglevel, f"t={self.t}, p={self._render_player()} rolls {self.dice}")
        if self.dice[0] == self.dice[1]:
            self.dice += self.dice
        self.step = Game.Step.TURN
        return self

    def turn(self, pos: int, steps: int) -> "Game":
        dst_pos = self._get_next_pos(pos, steps)
        self._check_move(pos, steps)
        if self._is_move_home(pos, steps):
            logger.log(self.loglevel, f"t={self.t}, p={self._render_player()} moves: {pos}->HOME")
            self.board[pos-1] -= self._get_user_sign()
            self.home[self.pturn] += 1
        else:
            logger.log(self.loglevel, f"t={self.t}, p={self._render_player()} moves: {pos}-({steps})->{dst_pos}")
            self.board[pos-1] -= self._get_user_sign()
            self.board[dst_pos-1] += self._get_user_sign()

        if self._get_head() == pos:
            self.head_moves += 1

        if len(self.dice) > 2:
            self.dice.pop(self.dice.index(steps, -1))
        else:
            self.dice[self.dice.index(steps)] = 0

        if self.home[self.pturn] == 15:
            logger.log(self.loglevel, f"t={self.t}, game finished, p={self._render_player()} wins")
            self.step = Game.Step.FINISHED
        elif self.dice[0] == 0 and self.dice[1] == 0:
            self.step = Game.Step.ROLL
            self.pturn = self._opponent
            self.t += 1
            self.head_moves = 0

        self.valid_moves = None
        return self
    
    def is_finished(self):
        return self.step == Game.Step.FINISHED
    
    def skip(self) -> "Game":
        if self.step != Game.Step.TURN:
            raise RuntimeError("invalid action")
        
        if self.has_valid_moves():
            raise RuntimeError("skip only possible when there's no moves")
        else:
            logger.log(self.loglevel, f"t={self.t}, p={self._render_player()} has no eligible moves, skipping")
            self.dice = [0, 0]
            self.step = Game.Step.ROLL
            self.pturn = (self.pturn + 1) % 2
            self.t += 1
            self.head_moves = 0
            
        self.valid_moves = None
        return self

    def __repr__(self):
        template = """
        |{oha}| 24 | 23 | 22 | 21 | 20 | 19 |{xst}| 18 | 17 | 16 | 15 | 14 | 13 |{xho}|
        |{ohb}|-----------------------------|     |-----------------------------|{xhn}|
        |{ohc}|{x1}|{w1}|{v1}|{u1}|{t1}|{s1}|     |{r1}|{q1}|{p1}|{o1}|{n1}|{m1}|{xhm}|
        |{ohd}|{x2}|{w2}|{v2}|{u2}|{t2}|{s2}|     |{r2}|{q2}|{p2}|{o2}|{n2}|{m2}|{xhl}|
        |{ohe}|{x3}|{w3}|{v3}|{u3}|{t3}|{s3}|     |{r3}|{q3}|{p3}|{o3}|{n3}|{m3}|{xhk}|
        |{ohf}|{x4}|{w4}|{v4}|{u4}|{t4}|{s4}|     |{r4}|{q4}|{p4}|{o4}|{n4}|{m4}|{xhj}|
        |{ohg}|{x5}|{w5}|{v5}|{u5}|{t5}|{s5}|     |{r5}|{q5}|{p5}|{o5}|{n5}|{m5}|{xhi}|
        |{ohh}|-----------------------------|{dcs}|-----------------------------|{xhh}|
        |{ohi}|{a5}|{b5}|{c5}|{d5}|{e5}|{f5}|     |{g5}|{h5}|{i5}|{j5}|{k5}|{l5}|{xhg}|
        |{ohj}|{a4}|{b4}|{c4}|{d4}|{e4}|{f4}|     |{g4}|{h4}|{i4}|{j4}|{k4}|{l4}|{xhf}|
        |{ohk}|{a3}|{b3}|{c3}|{d3}|{e3}|{f3}|     |{g3}|{h3}|{i3}|{j3}|{k3}|{l3}|{xhe}|
        |{ohl}|{a2}|{b2}|{c2}|{d2}|{e2}|{f2}|     |{g2}|{h2}|{i2}|{j2}|{k2}|{l2}|{xhd}|
        |{ohm}|{a1}|{b1}|{c1}|{d1}|{e1}|{f1}|     |{g1}|{h1}|{i1}|{j1}|{k1}|{l1}|{xhc}|
        |{ohn}|-----------------------------|     |-----------------------------|{xhb}|
        |{oho}| 01 | 02 | 03 | 04 | 05 | 06 |{ost}| 07 | 08 | 09 | 10 | 11 | 12 |{xha}|
        """

        pixels = {}
        for pos in range(1,25):
            x = chr(ord("a") + pos - 1)
            checkers = abs(int(self.board[pos-1]))
            pixel = self._render_player(player = 0 if self.board[pos-1] > 0 else 1)
            for y in range(1, 6):
                if checkers > 0:
                    if y < 5 or checkers == 1:
                        pixels[f"{x}{y}"] = (f"  {pixel} ")
                    else: # y == 5 and checkers > 1:
                        pixels[f"{x}{y}"] = f"({checkers}".rjust(3) + ")"
                    checkers -= 1
                else:
                    pixels[f"{x}{y}"] = "    "
        pixels["dcs"] = f" {self.dice[0] or ' '}:{self.dice[1] or ' '} "

        for player in [0, 1]:
            pixel = self._render_player(player=player)
            for idx in range(15):
                h = pixel + "h" + chr(ord("a") + idx)
                pixels[h] = f"  {pixel}  " if self.home[player] > idx else "     "

        pixels["ost"] = "     "
        pixels["xst"] = "     "
        if self.step == Game.Step.ROLL:
            pixels["ost" if self.pturn == 0 else "xst"] = " ROL "
        elif self.step == Game.Step.TURN:
            pixels["ost" if self.pturn == 0 else "xst"] = f" ({len([d for d in self.dice if d])}) "
        elif self.step == Game.Step.FINISHED:
            pixels["ost" if self.pturn == 0 else "xst"] = " WIN "

        return template.format(**pixels)

In [7]:
# def validate(game: Game):
#     p1, p2 = 0, 0
#     for i in range(len(game.board)):
#         if game.board[i] > 0:
#             p1 += game.board[i]
#         if game.board[i] < 0:
#             p2 -= game.board[i]
#     p1 += game.home[0]
#     p2 += game.home[1]
#     assert (p1, p2) == (15, 15), "invalid # of checkers at the board"

In [8]:
# def random_move(game: Game):
#     if not game.is_finished():
#         moves = game.get_valid_moves()
#         if len(moves) > 0:
#             pos, steps = random.choice(moves)
#             game.turn(pos, steps)
#             return True
#         else:
#             game.skip()
#     return False

# def auto_turn(game: Game):
#     game.roll()
#     for _ in range(len(game.dice)):
#         if not random_move(game):
#             break
#     return game

# def auto_rollout(game, turns: int = 100):
#     for turn in range(turns):
#         auto_turn(game)
#         validate(game)
#         if game.is_finished():
#             break
#     return game

In [9]:
# g = Game(seed=42, verbose=True)
# auto_rollout(g.start(), 100)

## Test cases:

In [10]:
import unittest

In [11]:
class GameTests(unittest.TestCase):
    """Scenario tests for the head-move, skip and blocking-prime rules of ``Game``."""

    @staticmethod
    def _play(game, dice, *moves):
        """Roll the given dice, then apply each ``(pos, steps)`` move in order."""
        game.roll(*dice)
        for pos, steps in moves:
            game.turn(pos, steps)
        return game

    def _assert_rejected(self, game, pos, steps):
        """Assert that the move is refused with a ``RuntimeError``."""
        with self.assertRaises(RuntimeError):
            game.turn(pos, steps)

    def test_heads(self):
        game = Game(seed=42).start()
        # can make 2 moves from head on doubles
        self._play(game, (3, 3), (1, 3), (1, 3))
        self._assert_rejected(game, 1, 3)
        game.turn(4, 3)
        game.turn(4, 3)
        # p2
        self._play(game, (1, 2), (13, 1), (14, 2))
        self._play(game, (2, 2), (1, 2))
        self._assert_rejected(game, 1, 2)
        game.turn(3, 2)
        self._assert_rejected(game, 1, 2)
        game.turn(5, 2)

    def test_skips(self):
        game = Game(seed=42).start()
        self._play(game, (6, 1), (1, 6), (7, 1))
        self._play(game, (3, 2), (13, 3), (16, 2))
        self._play(game, (2, 2), (1, 2), (3, 2), (5, 2), (8, 2))
        self._play(game, (6, 1), (13, 6), (18, 1))
        # there's no moves for red here
        self._play(game, (6, 6), (10, 6), (16, 6), (1, 6))
        self._assert_rejected(game, 1, 6)
        game.skip()

    def test_blocking_primes(self):
        game = Game(seed=42).start()
        self._play(game, (1, 1), (1, 1), (1, 1), (2, 1), (3, 1))
        self._play(game, (6, 5), (13, 6), (19, 5))
        self._play(game, (2, 3), (1, 2), (3, 3))
        self._play(game, (4, 2), (13, 4), (17, 2))
        self._play(game, (3, 1), (1, 1), (2, 3))
        self._play(game, (3, 6), (13, 3), (16, 6))
        self._play(game, (1, 2))
        self._assert_rejected(game, 1, 2)

    def test_nonblocking_primes(self):
        game = Game(seed=42).start()
        self._play(game, (1, 1), (1, 1), (1, 1), (2, 1), (3, 1))
        self._play(game, (6, 5), (13, 6), (19, 5))
        self._play(game, (2, 3), (1, 2), (3, 3))
        self._play(game, (4, 2), (13, 4), (17, 2))
        self._play(game, (3, 1), (1, 1), (2, 3))
        self._play(game, (3, 6), (24, 3), (3, 6))
        # valid move, since x is ahead
        self._play(game, (1, 2), (1, 2))

# Run every TestCase defined so far inside the notebook: argv=[''] keeps unittest
# from parsing the kernel's command line, exit=False stops it calling sys.exit().
unittest.main(argv=[''], verbosity=2, exit=False)

test_blocking_primes (__main__.GameTests) ... ok
test_heads (__main__.GameTests) ... ok
test_nonblocking_primes (__main__.GameTests) ... ok
test_skips (__main__.GameTests) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.003s

OK


<unittest.main.TestProgram at 0x15c42d450>

## Player interface & automation

In [12]:
class BasePlayer:
    """Interface for anything that can play one move of a ``Game``."""

    def play_turn(self, game: "Game") -> bool:
        """Play a single move (or skip) for the side to move in ``game``.

        Implementations must advance the game by calling ``game.turn(...)`` or
        ``game.skip()``; automation loops rely on that to make progress.

        Returns:
            ``True`` when a checker was moved, ``False`` when the turn was skipped.

        Raises:
            NotImplementedError: always, on the base class. Silently returning
                would leave the game unchanged and make callers loop forever.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement play_turn()")

class RandomPlayer(BasePlayer):
    """Player that picks uniformly at random among the legal moves."""

    def play_turn(self, game: Game) -> bool:
        """Play one random legal move; skip (and return ``False``) if none exists."""
        candidates = game.get_valid_moves()
        if len(candidates) == 0:
            game.skip()
            return False
        pos, steps = random.choice(candidates)
        game.turn(pos, steps)
        return True

class LazyPlayer(BasePlayer):
    """Player that always takes the first legal move reported by the game."""

    def play_turn(self, game: Game) -> bool:
        """Play the first listed legal move; skip (and return ``False``) if none exists."""
        candidates = game.get_valid_moves()
        if not candidates:
            game.skip()
            return False
        first_pos, first_steps = candidates[0]
        game.turn(first_pos, first_steps)
        return True

In [13]:
class AutoGame(Game):
    """``Game`` that rolls dice, skips dead turns and plays automated sides on its own.

    After ``start()`` and after every ``turn()`` the game keeps advancing until
    it is finished or until a side without an automated player has to move.
    """

    def __init__(self, player1: BasePlayer | None = None, player2: BasePlayer | None = None, start: bool = True, *args, **kwargs):
        """
        Args:
            player1: automated player for side 0, or ``None`` for manual control.
            player2: automated player for side 1, or ``None`` for manual control.
            start: call ``start()`` right away.
            *args, **kwargs: forwarded to ``Game.__init__`` (e.g. ``seed``).
        """
        super().__init__(*args, **kwargs)
        self.auto_player1 = player1
        self.auto_player2 = player2
        self._automating = False  # re-entrancy guard for _automate()
        if start:
            self.start()

    def _automate(self) -> "AutoGame":
        """Advance the game until a manual decision is needed or the game ends.

        Automated players move through ``self.turn()``, which calls back into
        this method. Without the guard below every automated move would add
        nested stack frames and a long game could exhaust Python's recursion
        limit; with it the nested call returns at once and the outer loop
        carries on, performing the same steps in the same order.
        """
        if getattr(self, "_automating", False):
            return self
        self._automating = True
        try:
            while not self.is_finished():
                if self.step == Game.Step.ROLL:
                    self.roll()
                if self.step == Game.Step.TURN:
                    if not self.has_valid_moves():
                        self.skip()
                    else:
                        player = self.auto_player1 if (self.pturn == 0) else self.auto_player2
                        if player:
                            player.play_turn(self)
                        else:
                            break
        finally:
            self._automating = False
        return self

    def start(self, *args, **kwargs) -> "AutoGame":
        """Start the game (see ``Game.start``) and run the automation loop."""
        super().start(*args, **kwargs)
        return self._automate()

    def turn(self, *args, **kwargs) -> "AutoGame":
        """Apply one move (see ``Game.turn``) and run the automation loop."""
        super().turn(*args, **kwargs)
        return self._automate()

    def play_sequence(self, turns: List[Tuple[int, int]]) -> "AutoGame":
        """Apply a list of ``(pos, steps)`` moves in order; illegal moves raise ``RuntimeError``."""
        for (pos, steps) in turns:
            self.turn(pos, steps)
        return self

In [14]:
class AutoGameTests(unittest.TestCase):
    """Checks for the automatic rolling / automated-player behaviour of ``AutoGame``."""

    def test_basic(self):
        legal_moves = [(1, 6), (1, 6), (13, 1), (14, 4), (7, 5), (1, 4), (13, 1), (18, 3)]
        AutoGame(seed=324).play_sequence(legal_moves)
        # the same line of play followed by one move that must be refused
        with self.assertRaises(RuntimeError):
            AutoGame(seed=324).play_sequence(legal_moves + [(4, 1)])

    def test_autoplayer(self):
        # one automated side: only the other side's moves are supplied by hand
        AutoGame(seed=324, player2=LazyPlayer()).play_sequence([(1, 6), (1, 6), (7, 5), (1, 4)])
        AutoGame(seed=324, player1=LazyPlayer()).play_sequence([(13, 4), (17, 1)])
        # both sides automated: the game plays itself out once started
        self_played = AutoGame(seed=324, player1=LazyPlayer(), player2=LazyPlayer())
        self.assertEqual(self_played.is_finished(), True)
        not_started = AutoGame(seed=324, start=False, player1=LazyPlayer(), player2=LazyPlayer())
        self.assertEqual(not_started.is_finished(), False)

# Re-run the whole suite (GameTests + AutoGameTests); argv=[''] and exit=False
# keep unittest from touching the kernel's argv or exiting the process.
unittest.main(argv=[''], verbosity=2, exit=False)

test_autoplayer (__main__.AutoGameTests) ... ok
test_basic (__main__.AutoGameTests) ... ok
test_blocking_primes (__main__.GameTests) ... ok
test_heads (__main__.GameTests) ... ok
test_nonblocking_primes (__main__.GameTests) ... ok
test_skips (__main__.GameTests) ... ok

----------------------------------------------------------------------
Ran 6 tests in 0.059s

OK


<unittest.main.TestProgram at 0x121369ff0>

In [15]:
@dataclass
class Result:
    """Outcome of one finished game, as produced by ``summarize``."""
    # 1 when player 0 won the game, 0 otherwise
    winner: int
    # value of the game's turn counter when it finished
    turns: int
    # signed score from player 0's point of view
    reward: int

def summarize(game: Game) -> Result:
    """Condense a finished game into a ``Result`` seen from player 0's side.

    The magnitude of the reward is 2 when the loser has not borne off a single
    checker and 1 otherwise; it is positive when player 0 won, negative when
    player 1 won. ``game.pturn`` is the winner once the game is finished.
    """
    assert game.is_finished()
    player0_won = game.pturn == 0
    magnitude = 2 if game.home[game._opponent] == 0 else 1
    return Result(
        winner=int(player0_won),
        turns=game.t,
        reward=magnitude if player0_won else -magnitude,
    )

def simulate(game: Game, player1: BasePlayer, player2: BasePlayer | None = None) -> Result:
    """Drive ``game`` to completion with the given players and summarize it.

    Args:
        game: a started game. An ``AutoGame`` rolls by itself; for a plain
            ``Game`` this function rolls the dice whenever the game is waiting
            for a roll.
        player1: plays side 0.
        player2: plays side 1; may be ``None`` only if ``game`` automates that
            side itself (e.g. ``AutoGame(player2=...)``).

    Returns:
        The ``Result`` of the finished game.

    Raises:
        ValueError: if the side to move has no player. Previously this case
            spun forever without changing the game.
    """
    while not game.is_finished():
        if game.step == Game.Step.ROLL:
            # Only reachable for games that do not roll automatically.
            game.roll()
            continue
        player = player1 if (game.pturn == 0) else player2
        if player is None:
            raise ValueError(f"no player supplied for side {game.pturn}, which is to move")
        player.play_turn(game)
    return summarize(game)

In [16]:
# summarize(AutoGame(player1=LazyPlayer(), player2=LazyPlayer()))

In [17]:
# simulate(AutoGame(player2=LazyPlayer()), player1=LazyPlayer())

In [18]:
# simulate(AutoGame(), player1=LazyPlayer(), player2=LazyPlayer())

In [19]:
import math

# Number of self-played games per (p1, p2, start) configuration.
games = 100

# Baseline experiment: every pairing of the lazy and random players, with the
# opening move either decided randomly or forced to one side.
results = []
for p1, p2, start in tqdm.tqdm(list(itertools.product(["lazy", "rand"], ["lazy", "rand"], ["random", "first", "second"])), leave=False):
    exp_info = {"p1": p1, "p2": p2, "start": start}
    player1 = LazyPlayer() if p1 == "lazy" else RandomPlayer()
    player2 = LazyPlayer() if p2 == "lazy" else RandomPlayer()
    # Opening dice passed to start(): equal dice are re-rolled randomly there, so
    # [6,6] yields a random opener; [6,1] lets player 0 begin, [1,6] player 1.
    start_args = [6,6] if start == "random" else ([6,1] if start == "first" else [1,6])
    # Each AutoGame is created un-started (third positional arg) and plays itself
    # out inside start(); summarize() then turns it into one result row.
    sims = pd.DataFrame([summarize(AutoGame(player1, player2, False).start(*start_args)).__dict__ for _ in tqdm.trange(games, postfix=exp_info, leave=False)])
    exp_info["games"] = games
    exp_info["wins"] = sims["winner"].sum()
    wins_mu = exp_info["wins"] / exp_info["games"]
    # Binomial standard deviation of the win count; the lo/hi columns are mean -/+ 3 sd.
    wins_sd = round(math.sqrt(exp_info["games"] * wins_mu * (1 - wins_mu)), 2)
    exp_info["win_rate_lo"] = (exp_info["wins"] - wins_sd*3) / exp_info["games"]
    exp_info["win_rate_mu"] = wins_mu
    exp_info["win_rate_hi"] = (exp_info["wins"] + wins_sd*3) / exp_info["games"]
    exp_info["avg_turns"] = sims["turns"].mean()
    exp_info["avg_reward"] = sims["reward"].mean()
    results.append(exp_info)

pd.DataFrame(results)

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=lazy, start=random]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=lazy, start=first]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=lazy, start=second]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=rand, start=random]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=rand, start=first]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=rand, start=second]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=lazy, start=random]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=lazy, start=first]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=lazy, start=second]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=rand, start=random]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=rand, start=first]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=rand, start=second]

Unnamed: 0,p1,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward
0,lazy,lazy,random,100,40,0.253,0.4,0.547,90.85,-0.14
1,lazy,lazy,first,100,48,0.33,0.48,0.63,92.42,0.09
2,lazy,lazy,second,100,43,0.2815,0.43,0.5785,91.19,-0.08
3,lazy,rand,random,100,75,0.6201,0.75,0.8799,92.77,0.73
4,lazy,rand,first,100,89,0.7961,0.89,0.9839,92.75,1.14
5,lazy,rand,second,100,76,0.6319,0.76,0.8881,93.42,0.78
6,rand,lazy,random,100,56,0.4112,0.56,0.7088,94.37,0.12
7,rand,lazy,first,100,46,0.3106,0.46,0.6094,93.48,-0.06
8,rand,lazy,second,100,51,0.36,0.51,0.66,93.15,0.02
9,rand,rand,random,100,48,0.33,0.48,0.63,96.39,-0.07


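An interesting bias shows up here: the lazy policy has a significant advantage when it plays against the random policy.  
The bias only appears when the lazy policy plays as player1; it disappears when it plays as player2.

One possible cause worth checking: the lazy policy takes the first entry of `get_valid_moves()`, which is ordered by board position 1..24 rather than by each player's own direction of travel, so "first move" may mean something different for player1 (head at 1) than for player2 (head at 13).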

TODO: Figure out what's going on here

# Training a model with Q-Learning

In [20]:
@dataclass
class ReplaySample:
    state_action: torch.Tensor
    reward: int = 0
    next_state_actions: torch.Tensor | None = None


In [21]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of replay samples; the oldest entries are overwritten first."""

    def __init__(self, size: int = 1_000_000):
        """
        Args:
            size: maximum number of samples kept.

        Raises:
            ValueError: if ``size`` is not positive.
        """
        if size <= 0:
            raise ValueError("replay buffer size must be positive")
        self.buffer = [None] * size
        self.insert_ptr = 0    # slot that the next add() writes to
        self.upper_bound = 0   # number of slots that hold a sample

    def add(self, sample: "ReplaySample"):
        """Store ``sample``, overwriting the oldest entry once the buffer is full."""
        self.buffer[self.insert_ptr] = sample
        self.upper_bound = max(self.upper_bound, self.insert_ptr + 1)
        # Wrap to slot 0 right after the last slot has been written; wrapping
        # one step later would index past the end of the list.
        self.insert_ptr = (self.insert_ptr + 1) % len(self.buffer)

    def sample(self, k: int = 1) -> List["ReplaySample"]:
        """Draw ``k`` stored samples uniformly at random, with replacement.

        Raises:
            IndexError: if the buffer is empty (raised by ``random.choices``).
        """
        # Slice only while the buffer is partially filled; once it is full the
        # list can be used directly, which avoids copying it on every call.
        if self.upper_bound == len(self.buffer):
            population = self.buffer
        else:
            population = self.buffer[:self.upper_bound]
        return random.choices(population, k=k)

    def __getitem__(self, index: int) -> "ReplaySample":
        """Return the sample in slot ``index``; negative indices count from the last filled slot.

        Raises:
            IndexError: if ``index`` does not address a filled slot.
        """
        filled = self.upper_bound
        if -filled <= index < filled:
            return self.buffer[index if index >= 0 else index + filled]
        raise IndexError()

    def __len__(self) -> int:
        """Number of samples currently stored."""
        return self.upper_bound

In [22]:
class QPolicy(BasePlayer):
    """Greedy Q-learning player backed by an MLP over (state, action) rows.

    Each candidate move is encoded as one row of ``INPUT_DIM`` features:
    24 board cells, 4 dice slots, the head-move counter and the 2 numbers of
    the ``(pos, steps)`` action. ``q_network`` scores the rows and the best
    one is played. While ``training`` is on, transitions are written to a
    replay buffer; ``_train_step`` fits ``q_network`` against a separate
    target network that ``_sync_networks`` refreshes.

    Rewards are always signed from player 0's point of view (see
    ``_calc_reward``), so the policy is meant to play as player 0.
    """

    # 24 board cells + 4 dice slots + head-move counter + (pos, steps)
    INPUT_DIM = 31

    def __init__(
            self,
            layers=None,
            device="cpu",
            training: bool = True,
            replay_buffer_size: int = 1_000_000,
            params=None
        ):
        """
        Args:
            layers: hidden layer widths; defaults to ``[32, 64]``.
            device: torch device for networks and encoded tensors.
            training: record transitions into the replay buffer while playing.
            replay_buffer_size: capacity of the replay buffer.
            params: optimizer settings; only ``"lr"`` is read (default 0.001).
        """
        # None defaults instead of shared mutable default arguments.
        layers = [32, 64] if layers is None else list(layers)
        params = {"lr": 0.001} if params is None else params

        self.device = device
        widths = [self.INPUT_DIM, *layers]
        network = []
        for in_dim, out_dim in zip(widths, widths[1:]):
            network.append(torch.nn.Linear(in_dim, out_dim))
            network.append(torch.nn.ReLU())
        network.append(torch.nn.Linear(widths[-1], 1))
        self.q_network = torch.nn.Sequential(*network).to(self.device)
        self.q_network.apply(self._init_weights)
        # The target network needs its OWN parameters. Wrapping the same layer
        # objects in a second Sequential would alias the weights, so the target
        # would follow every optimizer step and syncing would be a no-op.
        self.t_network = copy.deepcopy(self.q_network)
        for frozen in self.t_network.parameters():
            frozen.requires_grad_(False)
        self.training = training
        self.replay_buffer = ReplayBuffer(size=replay_buffer_size)
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=params.get("lr", 0.001))
        self.prev_state_action = None  # last recorded move, waiting for its successor state
        self._prev_game = None         # game that prev_state_action belongs to
        self._sync_networks()

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform weights (gain sqrt(2)) and zero biases for modules that have them."""
        if hasattr(m, "weight"):
            torch.nn.init.xavier_uniform_(m.weight, gain=2 ** (1.0 / 2))
        if hasattr(m, "bias"):
            torch.nn.init.zeros_(m.bias)

    def _sync_networks(self) -> None:
        """Copy the current Q-network weights into the target network."""
        logger.debug("syncing weights of t-network")
        self.t_network.load_state_dict(self.q_network.state_dict())

    def _encode_state(self, game: Game) -> torch.Tensor:
        """Encode board, dice (zero-padded to 4 slots) and head-move count as a float vector."""
        dice = game.dice.copy()
        if len(dice) < 4:
            dice += [0] * (4 - len(dice))
        return torch.concat([
            torch.tensor(game.board),
            torch.tensor(dice),
            torch.tensor([game.head_moves])
        ]).to(self.device).float()

    def _encode_actions(self, actions: List[Tuple[int, int]]) -> torch.Tensor:
        """Encode ``(pos, steps)`` moves as float rows; an empty list becomes one ``(-1, -1)`` row."""
        if not actions:
            actions = [(-1,-1)]
        return torch.tensor(actions).to(self.device).float()

    def _encode_state_actions(self, game: Game) -> torch.Tensor:
        """Return one (state, action) row per legal move of ``game``."""
        return self._concat_state_actions(
            self._encode_state(game),
            self._encode_actions(game.get_valid_moves())
        )

    def _concat_state_actions(self, state: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Repeat ``state`` for every action row and append the action columns."""
        actions = actions.view(-1, 2)
        return torch.concat([
                state.unsqueeze(dim=0).broadcast_to((actions.shape[0], -1)),
                actions
            ], dim=1)

    def _sample_batch(self, batch_size: int = 32):
        """Sample transitions and flatten them into tensors.

        Returns:
            ``(curr_state_actions, rewards, next_state_actions, next_idx)``
            where ``next_idx[i]`` is the batch position of the sample that row
            ``i`` of ``next_state_actions`` belongs to. Terminal samples
            contribute no rows to ``next_state_actions``.
        """
        rewards = []
        curr_state_actions = []
        next_state_actions = []
        next_state_actions_idx = []
        for sample_id, sample in enumerate(self.replay_buffer.sample(k=batch_size)):
            rewards.append(sample.reward)
            curr_state_actions.append(sample.state_action)
            if sample.next_state_actions is not None:
                next_state_actions.append(sample.next_state_actions)
                next_state_actions_idx.extend([sample_id] * sample.next_state_actions.shape[0])

        rewards = torch.tensor(rewards).float().to(self.device)
        curr_state_actions = torch.vstack(curr_state_actions)
        if next_state_actions:
            next_state_actions = torch.vstack(next_state_actions)
        else:
            # Every sampled transition was terminal: vstack([]) would raise.
            next_state_actions = curr_state_actions.new_zeros((0, curr_state_actions.shape[1]))
        next_state_actions_idx = torch.tensor(next_state_actions_idx, dtype=torch.long).to(self.device)
        assert next_state_actions.shape[0] == next_state_actions_idx.shape[0]
        return curr_state_actions, rewards, next_state_actions, next_state_actions_idx

    def _calc_loss(self, q_scores, t_scores, sample_ids, rewards) -> torch.Tensor:
        """Mean squared TD error: ``reward + gamma * max_a' Q_target(s', a') - Q(s, a)``.

        Args:
            q_scores: Q-network output for the played (state, action) rows.
            t_scores: target-network output for all candidate next rows.
            sample_ids: batch position each row of ``t_scores`` belongs to.
            rewards: per-sample rewards.
        """
        gamma = 1 # 0.99
        # Per-sample maximum over the candidate next actions. The base tensor is
        # zeros so terminal samples (no next rows) bootstrap from 0 instead of
        # counting their reward twice. "amax" is the maximum reduction accepted
        # by Tensor.scatter_reduce.
        t_scores_max = torch.zeros_like(rewards).scatter_reduce(
            dim=0, index=sample_ids, src=t_scores.squeeze(-1), reduce="amax", include_self=False
        )
        td_target = rewards + gamma * t_scores_max
        # squeeze(-1) keeps the batch dimension even for single-row batches.
        td_error = td_target - q_scores.squeeze(-1)
        return td_error.pow(2).mean()

    def _train_step(self, batch_size: int = 32) -> Optional[float]:
        """Run one optimizer step on a sampled batch.

        Returns:
            The loss value, or ``None`` when the replay buffer is empty and the
            step had to be skipped.
        """
        available = len(self.replay_buffer)
        if available == 0:
            logger.warning("not enough samples in replay buffer, skipping train step")
            return None
        if available < batch_size:
            # Sampling is with replacement, so an under-filled buffer still
            # yields a full (repetitive) batch.
            logger.warning(f"replay buffer holds {available} samples (< batch size {batch_size}), sampling with replacement")

        self.optimizer.zero_grad()

        curr_sa, rewards, next_sa, sample_ids  = self._sample_batch(batch_size)

        q_scores = self.q_network(curr_sa)
        with torch.no_grad():
            t_scores = self.t_network(next_sa)

        loss = self._calc_loss(q_scores, t_scores, sample_ids, rewards)

        loss.backward()
        # grad_norm = torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)

        # for group in self.optimizer.param_groups:
        #     group["lr"] = self.learning_rate_schedule.get_value(self.step)

        self.optimizer.step()

        return loss.item()#, grad_norm.item()

    def _flush_stale_transition(self, game: Game) -> None:
        """Close the pending transition when play moves on to a different game.

        A move whose successor state was never observed is still pending when
        the opponent finishes the game. It is stored as a terminal sample
        carrying that game's final reward; a pending move of an unfinished
        (abandoned) game is dropped. Either way it is never paired with states
        of the new game.
        """
        if self._prev_game is game:
            return
        previous = self._prev_game
        if self.prev_state_action is not None and previous is not None and previous.is_finished():
            self.replay_buffer.add(ReplaySample(state_action=self.prev_state_action, reward=self._calc_reward(previous)))
        self.prev_state_action = None
        self._prev_game = game

    def play_turn(self, game: Game):
        """Play the highest-scoring legal move; skip (and return ``False``) if none exists.

        While training, the previously played move is stored together with the
        current candidate rows, and a move that ends the game is stored as a
        terminal sample right away.
        """
        self._flush_stale_transition(game)

        valid_moves = list(game.get_valid_moves())
        if not valid_moves:
            game.skip()
            return False

        with torch.no_grad():
            state_actions = self._encode_state_actions(game)

            if self.training and self.prev_state_action is not None:
                self.replay_buffer.add(ReplaySample(state_action=self.prev_state_action, reward=self._calc_reward(game), next_state_actions=state_actions))
                self.prev_state_action = None

            scores = self.q_network(state_actions).squeeze(-1)
            action_idx = int(scores.argmax().item())
            # Only remember the move when it is going to be recorded, so that
            # evaluation games never leak into the replay buffer.
            self.prev_state_action = state_actions[action_idx] if self.training else None

            game.turn(*valid_moves[action_idx])

            if game.is_finished() and self.training and self.prev_state_action is not None:
                self.replay_buffer.add(ReplaySample(state_action=self.prev_state_action, reward=self._calc_reward(game)))
                self.prev_state_action = None

            # self._train_step()

        return True

    def _calc_reward(self, game: Game) -> int:
        """Return 0 for an unfinished game, else +/-1 (or +/-2 when the loser bore off nothing).

        The sign is positive when player 0 won and negative when player 1 won.
        """
        reward = 0
        if game.is_finished():
            reward  = (2 if game.home[game._opponent] == 0 else 1)
            reward *= (1 if game.pturn == 0 else -1)
        return reward


### Training and evaluation

In [23]:
def practice(policy: QPolicy, games: int = 100, train_every: int = 1, sync_every: int = 50, batch_size = 32, show_progress : bool = False):
    """Let ``policy`` (as player 1) play against a random opponent and train on the replay buffer.

    Args:
        policy: the Q-learning policy to train.
        games: number of practice games.
        train_every: run one train step after every this-many games.
        sync_every: refresh the target network after every this-many games.
        batch_size: replay batch size per train step.
        show_progress: display a tqdm progress bar.

    Returns:
        The mean loss over the train steps that produced one, or ``nan`` when
        none did (e.g. ``games == 0``).
    """
    loss_vals = []
    game_ids = tqdm.trange(games, leave=False, desc="practicing") if show_progress else range(games)
    for game_id in game_ids:
        simulate(AutoGame(), player1=policy, player2=RandomPlayer())
        if game_id % train_every == 0:
            loss = policy._train_step(batch_size=batch_size)
            # A train step may report no loss when it had nothing to train on.
            if loss is not None:
                loss_vals.append(loss)
        if game_id % sync_every == 0:
            policy._sync_networks()
    if not loss_vals:
        return float("nan")
    return sum(loss_vals) / len(loss_vals)

In [24]:
def evaluate(model_name: str, policy: QPolicy, games: int = 100):
    """Play ``policy`` (as player 1) against a random player 2 and report win statistics.

    Training-mode recording is switched off for the duration of the evaluation
    and restored afterwards, even if a simulation raises.

    Args:
        model_name: label stored in the ``model`` column.
        policy: the policy to evaluate.
        games: number of evaluation games; must be positive.

    Returns:
        A one-row ``DataFrame`` with win count, win rate (mean and mean -/+ 3
        binomial standard deviations), average turns and average reward.

    Raises:
        ValueError: if ``games`` is not positive.
    """
    if games <= 0:
        raise ValueError("games must be a positive integer")

    prev_training = policy.training
    policy.training = False
    try:
        sims = pd.DataFrame([simulate(AutoGame(player2=RandomPlayer()), player1=policy).__dict__ for _ in tqdm.trange(games, leave=False, desc="evaluating")])
    finally:
        policy.training = prev_training

    exp_info = {"model": model_name, "p2": "random", "start": "random"}
    exp_info["games"] = games
    exp_info["wins"] = sims["winner"].sum()
    wins_mu = exp_info["wins"] / exp_info["games"]
    wins_sd = round(math.sqrt(exp_info["games"] * wins_mu * (1 - wins_mu)), 2)
    exp_info["win_rate_lo"] = (exp_info["wins"] - wins_sd*3) / exp_info["games"]
    exp_info["win_rate_mu"] = wins_mu
    exp_info["win_rate_hi"] = (exp_info["wins"] + wins_sd*3) / exp_info["games"]
    exp_info["avg_turns"] = sims["turns"].mean()
    exp_info["avg_reward"] = sims["reward"].mean()

    return pd.DataFrame([exp_info])

In [25]:
def train_eval_loop(policy: QPolicy, epochs: int = 1000, practice_games: int = 1000, eval_games: int = 100, sync_every: int = 50):
    """Alternate practice and evaluation for ``epochs`` rounds.

    The policy is evaluated once before any practice (labelled ``untrained``)
    and once after each epoch (labelled ``epoch-<id>``). The outer progress bar
    shows the latest average loss, average reward and win rate as its postfix.

    Args:
        policy: The policy to train and evaluate.
        epochs: Number of practice/evaluation rounds.
        practice_games: Games per ``practice`` call.
        eval_games: Games per ``evaluate`` call.
        sync_every: Forwarded to ``practice``.

    Returns:
        A ``pandas.DataFrame`` with one row per evaluation, in chronological order.
    """
    progress = tqdm.trange(epochs, desc="train/eval epochs")

    def run_eval(label: str) -> dict:
        # evaluate() returns a one-row frame; flatten that row into a plain dict.
        return evaluate(label, policy, games=eval_games).loc[0].to_dict()

    def eval_metrics(row: dict) -> dict:
        return {"avg_reward": row["avg_reward"], "win_rate": row["win_rate_mu"]}

    latest = run_eval("untrained")
    progress.set_postfix(eval_metrics(latest))
    history = [latest]

    for epoch_id in progress:
        avg_loss = practice(policy, games=practice_games, train_every=1, sync_every=sync_every, batch_size=32, show_progress=True)
        # Show the fresh loss right away, next to the previous evaluation's metrics.
        progress.set_postfix({"avg_loss": avg_loss, **eval_metrics(latest)})
        latest = run_eval(f"epoch-{epoch_id}")
        history.append(latest)
        progress.set_postfix({"avg_loss": avg_loss, **eval_metrics(latest)})

    return pd.DataFrame(history)

In [34]:
# Policy under test: a QPolicy built with layers=[32, 64, 32] on the CPU and
# params lr=1e-6 (presumably layer widths and the optimizer learning rate --
# confirm against the QPolicy class).
nn_player = QPolicy(layers=[32, 64, 32], device="cpu", params={"lr": 0.000001})

In [35]:
# Baseline run: score `nn_player` against the random opponent using
# evaluate()'s default of 100 games; the one-row result table is the cell output.
evaluate("untrained", nn_player)

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,model,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward
0,untrained,random,random,100,32,0.1802,0.32,0.4598,96.24,-0.6


In [None]:
# Main experiment: 100 epochs, each made of 100 practice games followed by a
# 100-game evaluation; the returned results table is the cell output.
train_eval_loop(nn_player, epochs=100, practice_games=100, eval_games=100)

train/eval epochs:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

practicing:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

### Experiments with scatter & reduce:

In [None]:
# q_score = nn_player.q_network(curr_state_actions)

# with torch.no_grad():
#     t_score = nn_player.t_network(next_state_actions).squeeze()
#     t_score_max = torch.zeros_like(rewards).scatter_reduce(0, index=next_state_actions_idx, src=t_score, reduce="max", include_self=False)
    
# t_score_max

tensor([-1.3251,  7.6956,  5.5489,  0.0000])

In [None]:
# q_score = nn_player.q_network(curr_state_actions)

# with torch.no_grad():
#     t_score = nn_player.t_network(next_state_actions).squeeze()
#     t_score_max, t_score_idx = torch_scatter.scatter_max(t_score, index=next_state_actions_idx, dim=0)
    
# t_score_max

tensor([-1.3251,  7.6956,  5.5489])

In [None]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=next_state_actions_idx, dim=0)
# max_score, max_score_idx

(tensor([[ 3.5378],
         [ 2.6326],
         [-1.3690],
         [ 7.6219]]),
 tensor([[0],
         [3],
         [4],
         [7]]))

In [None]:
# torch.zeros_like(rewards).scatter_reduce(0, index=next_state_actions_idx, src=score.squeeze(), reduce="max", include_self=False)

tensor([ 3.5378,  2.6326, -1.3690,  7.6219])

In [None]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=next_state_actions_idx, dim=0)
# max_score, max_score_idx

In [None]:
# data = torch.arange(24).view(-1,6).long()
# data

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])

In [163]:
# score = data.float().sum(dim=1)
# score

tensor([ 15.,  51.,  87., 123.])

In [164]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=torch.tensor([0,0,0,1]), dim=0)
# max_score, max_score_idx

(tensor([ 87., 123.]), tensor([2, 3]))

In [None]:
# data[max_score_idx]

tensor([[12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]])