In [1]:
import sys
import copy 
import random
import itertools
import pandas as pd
import numpy as np
import torch
import tqdm.auto as tqdm
from pathlib import Path
import torch_scatter

from typing import *
from enum import Enum
from dataclasses import dataclass

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import logging
logger = logging.getLogger("research")

In [4]:
import os,sys
sys.path.insert(0, str(Path.cwd().parent))

In [5]:
import utils.logging
utils.logging.setup(debug=False)

### Narde rules
https://www.bkgm.com/variants/Narde.html

In [6]:
class Game:
    class Step(Enum):
        IDLE = "idle"
        ROLL = "roll"
        TURN = "trun"
        FINISHED = "finished"

    @staticmethod
    def _init_seed(seed: int | bool | None) -> int:
        if isinstance(seed, bool):
            seed = random.randint(0, 2**32-1) if seed else None

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
        
        return seed

    def __init__(self, seed: bool | int | None = None, verbose = False):
        self.seed = self._init_seed(seed)
        self.board = np.zeros(24, dtype=int)
        self.board[0] = 15
        self.board[12] = -15
        self.dice = [0, 0]
        self.home = [0, 0]
        self.pturn = 0
        self.t = 0
        self.head_moves = 0
        self.valid_moves = None
        self.step : Game.Step = Game.Step.IDLE
        self.loglevel = logging.INFO if verbose else logging.DEBUG
        logger.log(self.loglevel, f"new game started, seed={self.seed}")

    def _get_next_pos(self, pos: int, steps: int) -> int:
        dst_pos = pos + steps
        if dst_pos > 24:
            dst_pos = dst_pos % 25 + 1
        if dst_pos < 1:
            dst_pos = 24 + dst_pos
        return dst_pos
    
    def _get_head(self, player: Optional[int] = None) -> int:
        player = player if player is not None else self.pturn
        return 1 if player == 0 else 13
    
    def _get_user_sign(self, player: Optional[int] = None) -> int:
        player = player if player is not None else self.pturn
        return 1 if player == 0 else -1

    def _can_move_home(self):
        counter = self.home[self.pturn]
        home = range(19, 25) if self.pturn == 0 else range(7, 13)
        for pos in home:
            if self._has_checkers(pos, player=self._cur_player):
                counter += self._get_checkers(pos, player=self._cur_player)
        return counter == 15

    def _is_move_home(self, pos: int, steps: int) -> bool:
        dst_pos = self._get_next_pos(pos, steps)
        return (self.pturn == 0 and dst_pos < pos) or \
               (self.pturn == 1 and pos <= 12 and dst_pos > 12)
    
    @property
    def _opponent(self) -> int:
        return 1 if self.pturn == 0 else 0
    
    @property
    def _cur_player(self) -> int:
        return self.pturn
    
    def _has_checkers(self, pos: int, player: Optional[int] = None):
        return self._get_checkers(pos, player) > 0
    
    def _get_checkers(self, pos: int, player: Optional[int] = None):
        player = player if player is not None else self.pturn
        return self._get_user_sign(player) * self.board[pos-1]
    
    def _find_prime(self, pos: int) -> Optional[Tuple[int, int]]:
        seq = 1
        result = [pos, pos]
        next_ptr = self._get_next_pos(pos, 1)
        while seq < 6 and self._has_checkers(next_ptr, player=self._cur_player):
            seq += 1
            result[1] = next_ptr
            next_ptr = self._get_next_pos(next_ptr, 1)
        prev_ptr = self._get_next_pos(pos, -1)
        while seq < 6 and self._has_checkers(prev_ptr, player=self._cur_player):
            seq += 1
            result[0] = prev_ptr
            prev_ptr = self._get_next_pos(prev_ptr, -1)
        return tuple(result) if seq == 6 else None
    
    def _is_blocking_prime(self, dst_pos: int) -> bool:
        prime_range = self._find_prime(dst_pos)
        if prime_range:
            prime_end_pos = prime_range[1]
            for step in range(1, 24):
                search_pos = self._get_next_pos(prime_end_pos, step)
                if self._get_head(self._opponent) == search_pos:
                    return True # we reached other player's home without finding it's checkers
                if self._has_checkers(pos=search_pos, player=self._opponent):
                    break
        return False
    
    def _check_move(self, pos: int, steps: int):
        dst_pos = self._get_next_pos(pos, steps)
        if self.step != Game.Step.TURN:
            raise RuntimeError("invalid action")
        if not (1 <= pos <= 24):
            raise RuntimeError("invalid position")
        if not self._has_checkers(pos, player=self._cur_player):
            raise RuntimeError(f"no checkers at position {pos}")
        if self._has_checkers(dst_pos, player=self._opponent):
            raise RuntimeError(f"can't move to position {dst_pos}")
        if steps not in self.dice:
            raise RuntimeError(f"no dice with value {steps}")
        if self._is_move_home(pos, steps):
            if not self._can_move_home():
                raise RuntimeError(f"not all checkers are at finishing table")
        if self._get_head() == pos and self.head_moves > 0:
            if not (self.head_moves == 1 and self.dice[0] == self.dice[1] and self.t < 2):
                raise RuntimeError(f"can't make any more head moves")
        if self._is_blocking_prime(dst_pos):
            raise RuntimeError(f"can't form a blocking prime")
        # TODO: If player can play one number but not both, they must play the higher one

    def _render_player(self, player: Optional[int] = None, lower: bool = True) -> str:
        player = player if player is not None else self.pturn
        result = "O" if player == 0 else "X"
        return result.lower() if lower else result

    def _is_valid_move(self, pos: int, steps: int) -> bool:
        try:
            self._check_move(pos, steps)
            return True
        except RuntimeError as e:
            return False
    
    def _enum_valid_moves(self) -> Iterator[Tuple[int, int]]:
        for pos in range(1, 25):
            if self._has_checkers(pos, player=self._cur_player):
                for steps in range(1, 7):
                    if self._is_valid_move(pos, steps):
                        yield (pos, steps)

    def get_valid_moves(self) -> List[Tuple[int, int]]:
        if self.valid_moves is None:
            self.valid_moves = list(self._enum_valid_moves())
        return self.valid_moves
    
    def has_valid_moves(self):
        return len(self.get_valid_moves()) > 0

    def start(self, d1: int = 0, d2: int = 0) -> "Game":
        if self.step != Game.Step.IDLE:
            raise RuntimeError("invalid action")

        self.dice = [d1 or random.randint(1, 6), d2 or random.randint(1, 6)]
        while self.dice[0] == self.dice[1]:
            self.dice = [random.randint(1, 6), random.randint(1, 6)]

        self.step = Game.Step.ROLL
        if self.dice[0] > self.dice[1]:
            self.pturn = 0
        else: # self.dice[0] < self.dice[1]:
            self.pturn = 1
        return self

    def roll(self, d1: int = 0, d2: int = 0) -> "Game":
        if self.step != Game.Step.ROLL:
            raise RuntimeError("invalid action")
        self.dice = [d1 or random.randint(1, 6), d2 or random.randint(1, 6)]
        logger.log(self.loglevel, f"t={self.t}, p={self._render_player()} rolls {self.dice}")
        if self.dice[0] == self.dice[1]:
            self.dice += self.dice
        self.step = Game.Step.TURN
        return self

    def turn(self, pos: int, steps: int) -> "Game":
        dst_pos = self._get_next_pos(pos, steps)
        self._check_move(pos, steps)
        if self._is_move_home(pos, steps):
            logger.log(self.loglevel, f"t={self.t}, p={self._render_player()} moves: {pos}->HOME")
            self.board[pos-1] -= self._get_user_sign()
            self.home[self.pturn] += 1
        else:
            logger.log(self.loglevel, f"t={self.t}, p={self._render_player()} moves: {pos}-({steps})->{dst_pos}")
            self.board[pos-1] -= self._get_user_sign()
            self.board[dst_pos-1] += self._get_user_sign()

        if self._get_head() == pos:
            self.head_moves += 1

        if len(self.dice) > 2:
            self.dice.pop(self.dice.index(steps, -1))
        else:
            self.dice[self.dice.index(steps)] = 0

        if self.home[self.pturn] == 15:
            logger.log(self.loglevel, f"t={self.t}, game finished, p={self._render_player()} wins")
            self.step = Game.Step.FINISHED
        elif self.dice[0] == 0 and self.dice[1] == 0:
            self.step = Game.Step.ROLL
            self.pturn = self._opponent
            self.t += 1
            self.head_moves = 0

        self.valid_moves = None
        return self
    
    def is_finished(self):
        return self.step == Game.Step.FINISHED
    
    def skip(self) -> "Game":
        if self.step != Game.Step.TURN:
            raise RuntimeError("invalid action")
        
        if self.has_valid_moves():
            raise RuntimeError("skip only possible when there's no moves")
        else:
            logger.log(self.loglevel, f"t={self.t}, p={self._render_player()} has no eligible moves, skipping")
            self.dice = [0, 0]
            self.step = Game.Step.ROLL
            self.pturn = (self.pturn + 1) % 2
            self.t += 1
            self.head_moves = 0
            
        self.valid_moves = None
        return self

    def __repr__(self):
        template = """
        |{oha}| 24 | 23 | 22 | 21 | 20 | 19 |{xst}| 18 | 17 | 16 | 15 | 14 | 13 |{xho}|
        |{ohb}|-----------------------------|     |-----------------------------|{xhn}|
        |{ohc}|{x1}|{w1}|{v1}|{u1}|{t1}|{s1}|     |{r1}|{q1}|{p1}|{o1}|{n1}|{m1}|{xhm}|
        |{ohd}|{x2}|{w2}|{v2}|{u2}|{t2}|{s2}|     |{r2}|{q2}|{p2}|{o2}|{n2}|{m2}|{xhl}|
        |{ohe}|{x3}|{w3}|{v3}|{u3}|{t3}|{s3}|     |{r3}|{q3}|{p3}|{o3}|{n3}|{m3}|{xhk}|
        |{ohf}|{x4}|{w4}|{v4}|{u4}|{t4}|{s4}|     |{r4}|{q4}|{p4}|{o4}|{n4}|{m4}|{xhj}|
        |{ohg}|{x5}|{w5}|{v5}|{u5}|{t5}|{s5}|     |{r5}|{q5}|{p5}|{o5}|{n5}|{m5}|{xhi}|
        |{ohh}|-----------------------------|{dcs}|-----------------------------|{xhh}|
        |{ohi}|{a5}|{b5}|{c5}|{d5}|{e5}|{f5}|     |{g5}|{h5}|{i5}|{j5}|{k5}|{l5}|{xhg}|
        |{ohj}|{a4}|{b4}|{c4}|{d4}|{e4}|{f4}|     |{g4}|{h4}|{i4}|{j4}|{k4}|{l4}|{xhf}|
        |{ohk}|{a3}|{b3}|{c3}|{d3}|{e3}|{f3}|     |{g3}|{h3}|{i3}|{j3}|{k3}|{l3}|{xhe}|
        |{ohl}|{a2}|{b2}|{c2}|{d2}|{e2}|{f2}|     |{g2}|{h2}|{i2}|{j2}|{k2}|{l2}|{xhd}|
        |{ohm}|{a1}|{b1}|{c1}|{d1}|{e1}|{f1}|     |{g1}|{h1}|{i1}|{j1}|{k1}|{l1}|{xhc}|
        |{ohn}|-----------------------------|     |-----------------------------|{xhb}|
        |{oho}| 01 | 02 | 03 | 04 | 05 | 06 |{ost}| 07 | 08 | 09 | 10 | 11 | 12 |{xha}|
        """

        pixels = {}
        for pos in range(1,25):
            x = chr(ord("a") + pos - 1)
            checkers = abs(int(self.board[pos-1]))
            pixel = self._render_player(player = 0 if self.board[pos-1] > 0 else 1)
            for y in range(1, 6):
                if checkers > 0:
                    if y < 5 or checkers == 1:
                        pixels[f"{x}{y}"] = (f"  {pixel} ")
                    else: # y == 5 and checkers > 1:
                        pixels[f"{x}{y}"] = f"({checkers}".rjust(3) + ")"
                    checkers -= 1
                else:
                    pixels[f"{x}{y}"] = "    "
        pixels["dcs"] = f" {self.dice[0] or ' '}:{self.dice[1] or ' '} "

        for player in [0, 1]:
            pixel = self._render_player(player=player)
            for idx in range(15):
                h = pixel + "h" + chr(ord("a") + idx)
                pixels[h] = f"  {pixel}  " if self.home[player] > idx else "     "

        pixels["ost"] = "     "
        pixels["xst"] = "     "
        if self.step == Game.Step.ROLL:
            pixels["ost" if self.pturn == 0 else "xst"] = " ROL "
        elif self.step == Game.Step.TURN:
            pixels["ost" if self.pturn == 0 else "xst"] = f" ({len([d for d in self.dice if d])}) "
        elif self.step == Game.Step.FINISHED:
            pixels["ost" if self.pturn == 0 else "xst"] = " WIN "

        return template.format(**pixels)

In [7]:
# def validate(game: Game):
#     p1, p2 = 0, 0
#     for i in range(len(game.board)):
#         if game.board[i] > 0:
#             p1 += game.board[i]
#         if game.board[i] < 0:
#             p2 -= game.board[i]
#     p1 += game.home[0]
#     p2 += game.home[1]
#     assert (p1, p2) == (15, 15), "invalid # of checkers at the board"

In [8]:
# def random_move(game: Game):
#     if not game.is_finished():
#         moves = game.get_valid_moves()
#         if len(moves) > 0:
#             pos, steps = random.choice(moves)
#             game.turn(pos, steps)
#             return True
#         else:
#             game.skip()
#     return False

# def auto_turn(game: Game):
#     game.roll()
#     for _ in range(len(game.dice)):
#         if not random_move(game):
#             break
#     return game

# def auto_rollout(game, turns: int = 100):
#     for turn in range(turns):
#         auto_turn(game)
#         validate(game)
#         if game.is_finished():
#             break
#     return game

In [9]:
# g = Game(seed=42, verbose=True)
# auto_rollout(g.start(), 100)

## Test cases:

In [10]:
import unittest

In [11]:
class GameTests(unittest.TestCase):

    def test_heads(self):
        g = Game(seed=42).start()
        g.roll(3,3).turn(1,3).turn(1,3) # < can make 2 moves from head on doubles
        with self.assertRaises(RuntimeError):
            g.turn(1,3)
        g.turn(4,3).turn(4,3)
        g.roll(1,2).turn(13,1).turn(14,2) # p2
        g.roll(2,2).turn(1,2)
        with self.assertRaises(RuntimeError):
            g.turn(1, 2)
        g.turn(3,2)
        with self.assertRaises(RuntimeError):
            g.turn(1, 2)
        g.turn(5,2)

    def test_skips(self):
        g = Game(seed=42).start()
        g.roll(6,1).turn(1, 6).turn(7, 1)
        g.roll(3,2).turn(13, 3).turn(16, 2)
        g.roll(2,2).turn(1, 2).turn(3, 2).turn(5, 2).turn(8, 2)
        g.roll(6,1).turn(13,6).turn(18,1)
        g.roll(6,6).turn(10, 6).turn(16,6).turn(1, 6) # < there's no moves for red here
        with self.assertRaises(RuntimeError):
            g.turn(1, 6)
        g.skip()

    def test_blocking_primes(self):
        g = Game(seed=42).start()
        g.roll(1,1).turn(1,1).turn(1,1).turn(2,1).turn(3,1)
        g.roll(6,5).turn(13,6).turn(19,5)
        g.roll(2,3).turn(1,2).turn(3,3)
        g.roll(4,2).turn(13,4).turn(17,2)
        g.roll(3,1).turn(1,1).turn(2,3)
        g.roll(3,6).turn(13,3).turn(16,6)
        g.roll(1,2)
        with self.assertRaises(RuntimeError):
            g.turn(1,2)

    def test_nonblocking_primes(self):
        g = Game(seed=42).start()
        g.roll(1,1).turn(1,1).turn(1,1).turn(2,1).turn(3,1)
        g.roll(6,5).turn(13,6).turn(19,5)
        g.roll(2,3).turn(1,2).turn(3,3)
        g.roll(4,2).turn(13,4).turn(17,2)
        g.roll(3,1).turn(1,1).turn(2,3)
        g.roll(3,6).turn(24,3).turn(3,6)
        g.roll(1,2).turn(1,2) # < valid move, since x is ahead

unittest.main(argv=[''], verbosity=2, exit=False)

test_blocking_primes (__main__.GameTests) ... ok
test_heads (__main__.GameTests) ... ok
test_nonblocking_primes (__main__.GameTests) ... ok
test_skips (__main__.GameTests) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.003s

OK


<unittest.main.TestProgram at 0x16a41dba0>

## Player interface & automation

In [12]:
class BasePlayer:
    def play_turn(self, game: Game) -> bool:
        pass

class RandomPlayer(BasePlayer):
    def play_turn(self, game: Game) -> bool:
        actions = game.get_valid_moves()
        if len(actions) > 0:
            pos, steps = random.choice(actions)
            game.turn(pos, steps)
            return True
        else:
            game.skip()
            return False

class LazyPlayer(BasePlayer):
    def play_turn(self, game: Game) -> bool:
        actions = game.get_valid_moves()
        if actions:
            pos, steps = actions[0]
            game.turn(pos, steps)
            return True
        else:
            game.skip()
            return False

In [13]:
class AutoGame(Game):
    def __init__(self, player1: BasePlayer | None = None, player2: BasePlayer | None = None, start: bool = True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.auto_player1 = player1
        self.auto_player2 = player2
        if start:
            self.start()

    def _automate(self) -> "AutoGame":
        while not self.is_finished():
            if self.step == Game.Step.ROLL:
                self.roll()
            if self.step == Game.Step.TURN:
                if not self.has_valid_moves():
                    self.skip()
                else:
                    player = self.auto_player1 if (self.pturn == 0) else self.auto_player2
                    if player:
                        player.play_turn(self)
                    else:
                        break
        return self
   
    def start(self, *args, **kwargs) -> "AutoGame":
        super().start(*args, **kwargs)
        return self._automate()

    def turn(self, *args, **kwargs) -> "AutoGame":
        super().turn(*args, **kwargs)
        return self._automate()

    def play_sequence(self, turns: List[Tuple[int, int]]) -> "AutoGame":
        for (pos, steps) in turns:
            self.turn(pos, steps)
        return self

In [14]:
class AutoGameTests(unittest.TestCase):

    def test_basic(self):
        AutoGame(seed=324).play_sequence([(1, 6), (1, 6), (13, 1), (14, 4),(7,5),(1,4),(13,1),(18,3)])
        with self.assertRaises(RuntimeError):
            AutoGame(seed=324).play_sequence([(1, 6), (1, 6), (13, 1), (14, 4),(7,5),(1,4),(13,1),(18,3),(4, 1)])

    def test_autoplayer(self):
        AutoGame(seed=324, player2=LazyPlayer()).play_sequence([(1,6), (1,6), (7,5), (1,4)])
        AutoGame(seed=324, player1=LazyPlayer()).play_sequence([(13,4), (17,1)])
        self.assertEqual(AutoGame(seed=324, player1=LazyPlayer(), player2=LazyPlayer()).is_finished(), True)
        self.assertEqual(AutoGame(seed=324, start=False, player1=LazyPlayer(), player2=LazyPlayer()).is_finished(), False)

unittest.main(argv=[''], verbosity=2, exit=False)

test_autoplayer (__main__.AutoGameTests) ... ok
test_basic (__main__.AutoGameTests) ... ok
test_blocking_primes (__main__.GameTests) ... ok
test_heads (__main__.GameTests) ... ok
test_nonblocking_primes (__main__.GameTests) ... ok
test_skips (__main__.GameTests) ... ok

----------------------------------------------------------------------
Ran 6 tests in 0.018s

OK


<unittest.main.TestProgram at 0x106b68460>

In [15]:
@dataclass
class Result:
    winner: int
    turns: int
    reward: int

def summarize(game: Game) -> Result:
    assert game.is_finished()
    reward =  (2 if game.home[game._opponent] == 0 else 1)
    reward *= (1 if game.pturn == 0 else -1)
    result = Result(
        winner = int(game.pturn == 0),
        turns = game.t,
        reward = reward
    )
    return result

def simulate(game: Game, player1: BasePlayer, player2: BasePlayer | None = None) -> Result:
    while not game.is_finished():
        player = player1 if (game.pturn == 0) else player2
        if player:
            player.play_turn(game)
    return summarize(game)

In [16]:
# summarize(AutoGame(player1=LazyPlayer(), player2=LazyPlayer()))

In [17]:
# simulate(AutoGame(player2=LazyPlayer()), player1=LazyPlayer())

In [18]:
# simulate(AutoGame(), player1=LazyPlayer(), player2=LazyPlayer())

In [19]:
import math

games = 100

results = []
for p1, p2, start in tqdm.tqdm(list(itertools.product(["lazy", "rand"], ["lazy", "rand"], ["random", "first", "second"])), leave=False):
    exp_info = {"p1": p1, "p2": p2, "start": start}
    player1 = LazyPlayer() if p1 == "lazy" else RandomPlayer()
    player2 = LazyPlayer() if p2 == "lazy" else RandomPlayer()
    start_args = [6,6] if start == "random" else ([6,1] if start == "first" else [1,6])
    sims = pd.DataFrame([summarize(AutoGame(player1, player2, False).start(*start_args)).__dict__ for _ in tqdm.trange(games, postfix=exp_info, leave=False)])
    exp_info["games"] = games
    exp_info["wins"] = sims["winner"].sum()
    wins_mu = exp_info["wins"] / exp_info["games"]
    wins_sd = round(math.sqrt(exp_info["games"] * wins_mu * (1 - wins_mu)), 2)
    exp_info["win_rate_lo"] = (exp_info["wins"] - wins_sd*3) / exp_info["games"]
    exp_info["win_rate_mu"] = wins_mu
    exp_info["win_rate_hi"] = (exp_info["wins"] + wins_sd*3) / exp_info["games"]
    exp_info["avg_turns"] = sims["turns"].mean()
    exp_info["avg_reward"] = sims["reward"].mean()
    results.append(exp_info)

pd.DataFrame(results)

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=lazy, start=random]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=lazy, start=first]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=lazy, start=second]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=rand, start=random]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=rand, start=first]

  0%|          | 0/100 [00:00<?, ?it/s, p1=lazy, p2=rand, start=second]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=lazy, start=random]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=lazy, start=first]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=lazy, start=second]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=rand, start=random]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=rand, start=first]

  0%|          | 0/100 [00:00<?, ?it/s, p1=rand, p2=rand, start=second]

Unnamed: 0,p1,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward
0,lazy,lazy,random,100,40,0.253,0.4,0.547,90.85,-0.14
1,lazy,lazy,first,100,48,0.33,0.48,0.63,92.42,0.09
2,lazy,lazy,second,100,43,0.2815,0.43,0.5785,91.19,-0.08
3,lazy,rand,random,100,75,0.6201,0.75,0.8799,92.77,0.73
4,lazy,rand,first,100,89,0.7961,0.89,0.9839,92.75,1.14
5,lazy,rand,second,100,76,0.6319,0.76,0.8881,93.42,0.78
6,rand,lazy,random,100,56,0.4112,0.56,0.7088,94.37,0.12
7,rand,lazy,first,100,46,0.3106,0.46,0.6094,93.48,-0.06
8,rand,lazy,second,100,51,0.36,0.51,0.66,93.15,0.02
9,rand,rand,random,100,48,0.33,0.48,0.63,96.39,-0.07


Interesting bias is observed here, as lazy policy seems to have a significant advantage when played against random.  
Bias only shows up if it plays as player1, but disappears if it plays as player2

TODO: Figure out what's going on here

# Training a model with Q-Learning

In [20]:
@dataclass
class ReplaySample:
    state_action: torch.Tensor
    reward: int = 0
    next_state_actions: torch.Tensor | None = None


In [21]:
class ReplayBuffer:
    def __init__(self, size: 1_000_000):
        self.buffer = [None] * size
        self.insert_ptr = 0
        self.upper_bound = 0

    def add(self, sample: ReplaySample):
        self.buffer[self.insert_ptr] = sample
        self.insert_ptr = self.insert_ptr+1 if self.insert_ptr+1 < len(self.buffer) else 0
        self.upper_bound = max(self.insert_ptr, self.upper_bound)
    
    def sample(self, k: int = 1) -> List[ReplaySample]:
        return random.choices(self.buffer[:self.upper_bound], k=k)
    
    def __getitem__(self, index: int) -> ReplaySample:
        if index < self.upper_bound:
            return self.buffer[index]
        raise IndexError()
    
    def __len__(self) -> int:
        return self.upper_bound

In [22]:
class QPolicy(BasePlayer):
    def __init__(
            self, 
            layers=[32, 64], 
            device="cpu", 
            training: bool = True,
            replay_buffer_size: int = 1_000_000
        ):
        self.device = device
        network = []
        network.append(torch.nn.Linear(31, layers[0]))
        network.append(torch.nn.ReLU())
        for idx in range(1, len(layers)):
            network.append(torch.nn.Linear(layers[idx-1], layers[idx]))
            network.append(torch.nn.ReLU())
        network.append(torch.nn.Linear(layers[-1], 1))
        self.q_network = torch.nn.Sequential(*network).to(self.device)
        self.t_network = torch.nn.Sequential(*network).to(self.device)
        self.training = training
        self.replay_buffer = ReplayBuffer(size=replay_buffer_size)
        # self.q_network.apply(self._init_weights)
        self.optimizer = torch.optim.Adam(self.q_network.parameters())
        self.prev_state_action = None
        self.gamma = 1
        self.lr = 0.001
        self.grad_clip = 10
        self.soft_epsilon = 0.05
        self._sync_networks()

    # @staticmethod
    # def _init_weights(m):
    #     if hasattr(m, "weight"):
    #         torch.nn.init.xavier_uniform_(m.weight, gain=2 ** (1.0 / 2))
    #     if hasattr(m, "bias"):
    #         torch.nn.init.zeros_(m.bias)

    def _sync_networks(self) -> None:
        logger.debug("syncing weights of t-network")
        self.t_network.load_state_dict(self.q_network.state_dict())
    
    def _encode_state(self, game: Game) -> torch.Tensor:
        dice = game.dice.copy()
        if len(dice) < 4:
            dice += [0] * (4 - len(dice))
        return torch.concat([
            torch.tensor(game.board),
            torch.tensor(dice),
            torch.tensor([game.head_moves])
        ]).to(self.device).float()

    def _encode_actions(self, actions: List[Tuple[int, int]]) -> torch.Tensor:
        if not actions:
            actions = [(-1,-1)]
        return torch.tensor(actions).to(self.device).float()
    
    def _encode_state_actions(self, game: Game) -> torch.Tensor:
        return self._concat_state_actions(
            self._encode_state(game),
            self._encode_actions(game.get_valid_moves())
        )

    def _concat_state_actions(self, state: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        actions = actions.view(-1, 2)
        return torch.concat([
                state.unsqueeze(dim=0).broadcast_to((actions.shape[0], -1)),
                actions
            ], dim=1)
   
    def _sample_batch(self, batch_size: int = 32):
        rewards = []
        curr_state_actions = []
        next_state_actions = []
        next_state_actions_idx = []
        for sample_id, sample in enumerate(self.replay_buffer.sample(k=batch_size)):
            rewards.append(sample.reward)
            curr_state_actions.append(sample.state_action)
            if sample.next_state_actions is not None:
                next_state_actions.append(sample.next_state_actions)
                next_state_actions_idx.extend([sample_id] * sample.next_state_actions.shape[0])

        rewards = torch.tensor(rewards).float().to(self.device)
        curr_state_actions = torch.vstack(curr_state_actions)
        next_state_actions = torch.vstack(next_state_actions)
        next_state_actions_idx = torch.tensor(next_state_actions_idx, dtype=torch.long).to(self.device)
        assert next_state_actions.shape[0] == next_state_actions_idx.shape[0]
        return curr_state_actions, rewards, next_state_actions, next_state_actions_idx
    
    def _calc_loss(self, q_scores, t_scores, sample_ids, rewards) -> torch.Tensor:
        t_scores_max = rewards.scatter_reduce(dim=0, index=sample_ids, src=t_scores.squeeze(), reduce="max", include_self=False)
        # t_scores_max, t_score_idx = torch_scatter.scatter_max(t_scores.squeeze(), index=sample_ids, dim=0)
        td_target = rewards + self.gamma * t_scores_max
        td_error = td_target - q_scores.squeeze()
        return td_error.pow(2).mean()
    
    def _train_step(self, batch_size: int = 32) -> None:
        if len(self.replay_buffer) < batch_size:
            logger.warning("not enough samples in replay buffer, skipping train step")

        self.optimizer.zero_grad()
    
        curr_sa, rewards, next_sa, sample_ids  = self._sample_batch(batch_size)

        q_scores = self.q_network(curr_sa)
        with torch.no_grad():
            t_scores = self.t_network(next_sa)

        loss = self._calc_loss(q_scores, t_scores, sample_ids, rewards)

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), self.grad_clip)
            
        for group in self.optimizer.param_groups:
            group["lr"] = self.lr
        
        self.optimizer.step()

        return loss.item(), grad_norm.item()


    def play_turn(self, game: Game):
        valid_moves = list(game.get_valid_moves())
        if not valid_moves:
            game.skip()
            return False

        with torch.no_grad():
            state_actions = self._encode_state_actions(game)

            if self.training and self.prev_state_action is not None:
                self.replay_buffer.add(ReplaySample(state_action=self.prev_state_action, reward=self._calc_reward(game), next_state_actions=state_actions))
                self.prev_state_action = None
        
            scores = self.q_network(state_actions).squeeze()
            action_idx = scores.argmax(dim=-1).item()
            if np.random.random() < self.soft_epsilon:
                action_idx = random.randint(0, len(valid_moves)-1)
            self.prev_state_action = state_actions[action_idx]

            game.turn(*valid_moves[action_idx])

            if game.is_finished() and self.training and self.prev_state_action is not None:
                self.replay_buffer.add(ReplaySample(state_action=self.prev_state_action, reward=self._calc_reward(game)))
                self.prev_state_action = None

            # self._train_step()

        return True

    def _calc_reward(self, game: Game) -> int:
        reward = 0
        if game.is_finished():
            reward  = (2 if game.home[game._opponent] == 0 else 1)
            reward *= (1 if game.pturn == 0 else -1)
        return reward


### Training and evaluation

In [23]:
def practice(
        policy: QPolicy, 
        games: int = 100,
        train_every: int = 1, 
        sync_every: int = 100,
        batch_size = 32,
        show_progress : bool = False, 
        gamma: float = 0.99, 
        lr: float = 0.001,
        grad_clip: float = 10,
        soft_epsilon: float = 0
    ):
    policy.training = True
    policy.gamma = gamma
    policy.lr = lr
    policy.grad_clip = grad_clip
    policy.soft_epsilon = soft_epsilon

    loss_vals = []
    grad_vals = []
    for game_id in (tqdm.trange(games, leave=False, desc="practicing") if show_progress else range(games)):
        simulate(AutoGame(), player1=policy, player2=RandomPlayer())
        if game_id % train_every == 0:
            loss, grad = policy._train_step(batch_size=batch_size)
            loss_vals.append(loss)
            grad_vals.append(grad)
        if game_id % sync_every == 0:
            policy._sync_networks()
    return sum(loss_vals) / len(loss_vals), sum(grad_vals) / len(grad_vals)

In [24]:
def evaluate(model_name: str, policy: QPolicy, games: int = 100):
    prev_training = policy.training
    policy.training = False

    results = []
    sims = pd.DataFrame([simulate(AutoGame(player2=RandomPlayer()), player1=policy).__dict__ for _ in tqdm.trange(games, leave=False, desc="evaluating")])
    exp_info = {"model": model_name, "p2": "random", "start": "random"}
    exp_info["games"] = games
    exp_info["wins"] = sims["winner"].sum()
    wins_mu = exp_info["wins"] / exp_info["games"]
    wins_sd = round(math.sqrt(exp_info["games"] * wins_mu * (1 - wins_mu)), 2)
    exp_info["win_rate_lo"] = (exp_info["wins"] - wins_sd*3) / exp_info["games"]
    exp_info["win_rate_mu"] = wins_mu
    exp_info["win_rate_hi"] = (exp_info["wins"] + wins_sd*3) / exp_info["games"]
    exp_info["avg_turns"] = sims["turns"].mean()
    exp_info["avg_reward"] = sims["reward"].mean()
    results.append(exp_info)

    policy.training = prev_training
    return pd.DataFrame(results)

In [25]:
def train_eval_loop(
    policy: QPolicy,
    epochs: int = 1000,
    practice_games: int = 1000,
    batch_size: int = 32,
    eval_games: int = 100,
    sync_every: int = 50,
    **kwargs
):
    epoch_pbar = tqdm.trange(1, epochs+1, desc="train/eval epochs")
    results = []
    result = evaluate(f"untrained", policy, games=eval_games).loc[0].to_dict()
    epoch_pbar.set_postfix({"win_rate": result["win_rate_mu"], "avg_reward": result["avg_reward"]})
    logger.info(f"untrained: win_rate={result['win_rate_mu']:.4%}, avg_reward={result['avg_reward']:.2f}")
    results.append(result)
    for epoch_id in epoch_pbar:
        avg_loss, avg_grad = practice(policy, games=practice_games, train_every=1, sync_every=sync_every, batch_size=batch_size, show_progress=True, **kwargs)
        epoch_pbar.set_postfix({"win_rate": result["win_rate_mu"], "avg_reward": result["avg_reward"], "avg_loss": avg_loss, "avg_grad": avg_grad})
        result = evaluate(f"epoch-{epoch_id}", policy, games=eval_games).loc[0].to_dict()
        result["avg_loss"] = avg_loss
        result["avg_grad"] = avg_grad
        results.append(result)
        epoch_pbar.set_postfix({"win_rate": result["win_rate_mu"], "avg_reward": result["avg_reward"], "avg_loss": avg_loss, "avg_grad": avg_grad})
        logger.info(f"epoch={epoch_id}: win_rate={result['win_rate_mu']:.4%}, avg_reward={result['avg_reward']:.2f}, {avg_loss=:.4f}, {avg_grad=:.4f}")
    results = pd.DataFrame(results)
    return results

In [26]:
nn_player = QPolicy(layers=[32, 64, 32], device="cpu")

In [27]:
train_eval_args = {"epochs": 10, "practice_games": 5000, "batch_size": 32, "eval_games": 1000, "sync_every": 2500}

In [28]:
train_eval_loop(nn_player, lr=0.001, gamma=0.99, grad_clip=10, soft_epsilon=0.5, **train_eval_args)

train/eval epochs:   0%|          | 0/10 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:08:31,020 - research - INFO - untrained: win_rate=48.1000%, avg_reward=-0.09


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:10:45,229 - research - INFO - epoch=1: win_rate=58.6000%, avg_reward=0.25, avg_loss=0.0065, avg_grad=0.1722


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:13:25,119 - research - INFO - epoch=2: win_rate=50.1000%, avg_reward=0.01, avg_loss=0.0046, avg_grad=0.1457


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:16:14,392 - research - INFO - epoch=3: win_rate=45.1000%, avg_reward=-0.15, avg_loss=0.0086, avg_grad=0.2577


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:19:03,641 - research - INFO - epoch=4: win_rate=37.6000%, avg_reward=-0.34, avg_loss=0.0044, avg_grad=0.1962


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:21:53,171 - research - INFO - epoch=5: win_rate=43.9000%, avg_reward=-0.18, avg_loss=0.0048, avg_grad=0.1982


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:24:42,824 - research - INFO - epoch=6: win_rate=42.4000%, avg_reward=-0.22, avg_loss=0.0049, avg_grad=0.2018


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:27:32,892 - research - INFO - epoch=7: win_rate=52.2000%, avg_reward=0.07, avg_loss=0.0049, avg_grad=0.2047


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:30:23,897 - research - INFO - epoch=8: win_rate=49.7000%, avg_reward=0.02, avg_loss=0.0051, avg_grad=0.2082


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:33:14,868 - research - INFO - epoch=9: win_rate=45.7000%, avg_reward=-0.13, avg_loss=0.0049, avg_grad=0.1916


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:36:06,808 - research - INFO - epoch=10: win_rate=56.0000%, avg_reward=0.17, avg_loss=0.0046, avg_grad=0.1844


Unnamed: 0,model,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward,avg_loss,avg_grad
0,untrained,random,random,1000,481,0.4336,0.481,0.5284,94.914,-0.088,,
1,epoch-1,random,random,1000,586,0.53926,0.586,0.63274,94.555,0.254,0.006483,0.172234
2,epoch-2,random,random,1000,501,0.45357,0.501,0.54843,94.908,0.014,0.004611,0.145685
3,epoch-3,random,random,1000,451,0.40378,0.451,0.49822,95.94,-0.151,0.008641,0.257674
4,epoch-4,random,random,1000,376,0.33004,0.376,0.42196,95.91,-0.344,0.004424,0.196203
5,epoch-5,random,random,1000,439,0.39193,0.439,0.48607,95.127,-0.178,0.004755,0.198166
6,epoch-6,random,random,1000,424,0.37711,0.424,0.47089,95.74,-0.218,0.004941,0.201841
7,epoch-7,random,random,1000,522,0.4746,0.522,0.5694,94.523,0.07,0.004942,0.204653
8,epoch-8,random,random,1000,497,0.44957,0.497,0.54443,94.562,0.016,0.005121,0.208213
9,epoch-9,random,random,1000,457,0.40975,0.457,0.50425,95.934,-0.127,0.004889,0.191605


In [29]:
train_eval_loop(nn_player, lr=0.0005, gamma=0.99, grad_clip=10, soft_epsilon=0.25, **train_eval_args)

train/eval epochs:   0%|          | 0/10 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:36:26,470 - research - INFO - untrained: win_rate=59.2000%, avg_reward=0.27


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:39:16,414 - research - INFO - epoch=1: win_rate=57.2000%, avg_reward=0.17, avg_loss=0.0043, avg_grad=0.1896


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:42:07,175 - research - INFO - epoch=2: win_rate=55.2000%, avg_reward=0.15, avg_loss=0.0051, avg_grad=0.2284


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:44:57,989 - research - INFO - epoch=3: win_rate=53.6000%, avg_reward=0.10, avg_loss=0.0054, avg_grad=0.2533


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:47:48,883 - research - INFO - epoch=4: win_rate=52.5000%, avg_reward=0.07, avg_loss=0.0056, avg_grad=0.2616


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:50:40,095 - research - INFO - epoch=5: win_rate=57.6000%, avg_reward=0.22, avg_loss=0.0054, avg_grad=0.2525


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:53:31,493 - research - INFO - epoch=6: win_rate=60.3000%, avg_reward=0.27, avg_loss=0.0055, avg_grad=0.2578


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:56:23,420 - research - INFO - epoch=7: win_rate=58.1000%, avg_reward=0.24, avg_loss=0.0054, avg_grad=0.2555


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 03:59:15,398 - research - INFO - epoch=8: win_rate=56.8000%, avg_reward=0.19, avg_loss=0.0055, avg_grad=0.2675


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:02:06,918 - research - INFO - epoch=9: win_rate=51.6000%, avg_reward=0.04, avg_loss=0.0056, avg_grad=0.2625


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:04:59,087 - research - INFO - epoch=10: win_rate=60.5000%, avg_reward=0.30, avg_loss=0.0057, avg_grad=0.2647


Unnamed: 0,model,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward,avg_loss,avg_grad
0,untrained,random,random,1000,592,0.54538,0.592,0.63862,95.108,0.267,,
1,epoch-1,random,random,1000,572,0.52505,0.572,0.61895,95.438,0.171,0.004335,0.189551
2,epoch-2,random,random,1000,552,0.50481,0.552,0.59919,95.737,0.15,0.005057,0.228371
3,epoch-3,random,random,1000,536,0.48869,0.536,0.58331,95.374,0.101,0.005426,0.253321
4,epoch-4,random,random,1000,525,0.47763,0.525,0.57237,95.14,0.067,0.005642,0.261554
5,epoch-5,random,random,1000,576,0.52911,0.576,0.62289,96.508,0.217,0.005436,0.252477
6,epoch-6,random,random,1000,603,0.55659,0.603,0.64941,95.805,0.27,0.005541,0.2578
7,epoch-7,random,random,1000,581,0.5342,0.581,0.6278,96.126,0.24,0.005437,0.255523
8,epoch-8,random,random,1000,568,0.52102,0.568,0.61498,95.963,0.187,0.005455,0.267494
9,epoch-9,random,random,1000,516,0.4686,0.516,0.5634,96.657,0.044,0.005553,0.262498


In [30]:
train_eval_loop(nn_player, lr=0.00025, gamma=0.99, grad_clip=10, soft_epsilon=0.10, **train_eval_args)

train/eval epochs:   0%|          | 0/10 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:05:18,921 - research - INFO - untrained: win_rate=61.4000%, avg_reward=0.30


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:08:10,477 - research - INFO - epoch=1: win_rate=52.8000%, avg_reward=0.07, avg_loss=0.0050, avg_grad=0.2533


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:11:04,488 - research - INFO - epoch=2: win_rate=54.4000%, avg_reward=0.14, avg_loss=0.0059, avg_grad=0.2979


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:13:56,501 - research - INFO - epoch=3: win_rate=60.7000%, avg_reward=0.30, avg_loss=0.0058, avg_grad=0.2955


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:16:48,581 - research - INFO - epoch=4: win_rate=46.3000%, avg_reward=-0.11, avg_loss=0.0058, avg_grad=0.3011


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:19:41,639 - research - INFO - epoch=5: win_rate=58.2000%, avg_reward=0.24, avg_loss=0.0061, avg_grad=0.3205


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:22:34,047 - research - INFO - epoch=6: win_rate=59.5000%, avg_reward=0.22, avg_loss=0.0058, avg_grad=0.3157


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:25:26,480 - research - INFO - epoch=7: win_rate=58.6000%, avg_reward=0.23, avg_loss=0.0058, avg_grad=0.3083


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:28:19,383 - research - INFO - epoch=8: win_rate=56.8000%, avg_reward=0.17, avg_loss=0.0058, avg_grad=0.3122


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:31:12,483 - research - INFO - epoch=9: win_rate=61.9000%, avg_reward=0.30, avg_loss=0.0057, avg_grad=0.3158


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:34:05,750 - research - INFO - epoch=10: win_rate=55.0000%, avg_reward=0.13, avg_loss=0.0061, avg_grad=0.3289


Unnamed: 0,model,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward,avg_loss,avg_grad
0,untrained,random,random,1000,614,0.56783,0.614,0.66017,95.434,0.303,,
1,epoch-1,random,random,1000,528,0.48063,0.528,0.57537,95.94,0.075,0.005027,0.253329
2,epoch-2,random,random,1000,544,0.49675,0.544,0.59125,96.175,0.137,0.005853,0.297899
3,epoch-3,random,random,1000,607,0.56065,0.607,0.65335,96.106,0.299,0.005757,0.295506
4,epoch-4,random,random,1000,463,0.41569,0.463,0.51031,96.651,-0.11,0.005832,0.301061
5,epoch-5,random,random,1000,582,0.5352,0.582,0.6288,95.467,0.244,0.006102,0.320532
6,epoch-6,random,random,1000,595,0.54844,0.595,0.64156,95.365,0.219,0.005818,0.315676
7,epoch-7,random,random,1000,586,0.53926,0.586,0.63274,95.277,0.225,0.005775,0.308308
8,epoch-8,random,random,1000,568,0.52102,0.568,0.61498,96.031,0.175,0.005803,0.312226
9,epoch-9,random,random,1000,619,0.57292,0.619,0.66508,95.883,0.305,0.005747,0.315775


In [31]:
train_eval_loop(nn_player, lr=0.00015, gamma=0.99, grad_clip=10, soft_epsilon=0.05, **train_eval_args)

train/eval epochs:   0%|          | 0/10 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:34:25,441 - research - INFO - untrained: win_rate=58.5000%, avg_reward=0.23


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:37:19,680 - research - INFO - epoch=1: win_rate=55.5000%, avg_reward=0.13, avg_loss=0.0059, avg_grad=0.3182


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:40:12,834 - research - INFO - epoch=2: win_rate=58.0000%, avg_reward=0.22, avg_loss=0.0060, avg_grad=0.3250


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:43:06,082 - research - INFO - epoch=3: win_rate=62.6000%, avg_reward=0.33, avg_loss=0.0061, avg_grad=0.3364


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:45:59,456 - research - INFO - epoch=4: win_rate=62.9000%, avg_reward=0.32, avg_loss=0.0065, avg_grad=0.3566


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:48:52,831 - research - INFO - epoch=5: win_rate=62.1000%, avg_reward=0.34, avg_loss=0.0062, avg_grad=0.3496


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:51:46,158 - research - INFO - epoch=6: win_rate=66.0000%, avg_reward=0.42, avg_loss=0.0062, avg_grad=0.3495


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:54:39,615 - research - INFO - epoch=7: win_rate=57.6000%, avg_reward=0.22, avg_loss=0.0062, avg_grad=0.3509


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 04:57:33,384 - research - INFO - epoch=8: win_rate=62.3000%, avg_reward=0.35, avg_loss=0.0063, avg_grad=0.3539


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:00:27,669 - research - INFO - epoch=9: win_rate=60.0000%, avg_reward=0.28, avg_loss=0.0062, avg_grad=0.3633


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:03:21,252 - research - INFO - epoch=10: win_rate=65.1000%, avg_reward=0.43, avg_loss=0.0061, avg_grad=0.3533


Unnamed: 0,model,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward,avg_loss,avg_grad
0,untrained,random,random,1000,585,0.53826,0.585,0.63174,95.546,0.233,,
1,epoch-1,random,random,1000,555,0.50784,0.555,0.60216,95.852,0.129,0.005909,0.318177
2,epoch-2,random,random,1000,580,0.53317,0.58,0.62683,95.7,0.216,0.005953,0.324958
3,epoch-3,random,random,1000,626,0.5801,0.626,0.6719,95.868,0.333,0.006132,0.336373
4,epoch-4,random,random,1000,629,0.58316,0.629,0.67484,95.678,0.317,0.006497,0.356565
5,epoch-5,random,random,1000,621,0.57498,0.621,0.66702,95.208,0.337,0.006236,0.349558
6,epoch-6,random,random,1000,660,0.61506,0.66,0.70494,95.159,0.424,0.00615,0.34955
7,epoch-7,random,random,1000,576,0.52911,0.576,0.62289,95.796,0.218,0.00616,0.350875
8,epoch-8,random,random,1000,623,0.57701,0.623,0.66899,95.863,0.355,0.006264,0.353914
9,epoch-9,random,random,1000,600,0.55353,0.6,0.64647,95.581,0.277,0.006166,0.363265


In [32]:
train_eval_loop(nn_player, lr=0.00010, gamma=0.99, grad_clip=10, soft_epsilon=0.01, **train_eval_args)

train/eval epochs:   0%|          | 0/10 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:03:40,761 - research - INFO - untrained: win_rate=66.0000%, avg_reward=0.43


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:06:33,861 - research - INFO - epoch=1: win_rate=60.3000%, avg_reward=0.29, avg_loss=0.0062, avg_grad=0.3649


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:09:27,611 - research - INFO - epoch=2: win_rate=62.9000%, avg_reward=0.38, avg_loss=0.0064, avg_grad=0.3801


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:12:18,658 - research - INFO - epoch=3: win_rate=63.0000%, avg_reward=0.36, avg_loss=0.0065, avg_grad=0.3785


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:15:09,257 - research - INFO - epoch=4: win_rate=63.4000%, avg_reward=0.37, avg_loss=0.0065, avg_grad=0.3721


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:18:00,201 - research - INFO - epoch=5: win_rate=58.4000%, avg_reward=0.27, avg_loss=0.0064, avg_grad=0.3714


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:20:51,274 - research - INFO - epoch=6: win_rate=61.4000%, avg_reward=0.34, avg_loss=0.0064, avg_grad=0.3821


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:23:42,385 - research - INFO - epoch=7: win_rate=63.9000%, avg_reward=0.39, avg_loss=0.0066, avg_grad=0.4001


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:26:33,820 - research - INFO - epoch=8: win_rate=59.5000%, avg_reward=0.28, avg_loss=0.0064, avg_grad=0.3906


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:29:25,768 - research - INFO - epoch=9: win_rate=62.6000%, avg_reward=0.37, avg_loss=0.0067, avg_grad=0.4051


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:32:17,312 - research - INFO - epoch=10: win_rate=65.3000%, avg_reward=0.43, avg_loss=0.0063, avg_grad=0.3912


Unnamed: 0,model,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward,avg_loss,avg_grad
0,untrained,random,random,1000,660,0.61506,0.66,0.70494,95.198,0.426,,
1,epoch-1,random,random,1000,603,0.55659,0.603,0.64941,95.866,0.288,0.006175,0.36491
2,epoch-2,random,random,1000,629,0.58316,0.629,0.67484,95.528,0.379,0.006377,0.380117
3,epoch-3,random,random,1000,630,0.58419,0.63,0.67581,95.806,0.358,0.006519,0.378514
4,epoch-4,random,random,1000,634,0.58831,0.634,0.67969,94.703,0.37,0.006541,0.372105
5,epoch-5,random,random,1000,584,0.53723,0.584,0.63077,95.792,0.267,0.006402,0.371419
6,epoch-6,random,random,1000,614,0.56783,0.614,0.66017,95.112,0.338,0.006426,0.3821
7,epoch-7,random,random,1000,639,0.59343,0.639,0.68457,95.128,0.387,0.006568,0.400136
8,epoch-8,random,random,1000,595,0.54844,0.595,0.64156,95.576,0.281,0.006385,0.390585
9,epoch-9,random,random,1000,626,0.5801,0.626,0.6719,94.745,0.369,0.006714,0.405099


In [33]:
train_eval_loop(nn_player, lr=0.00005, gamma=0.99, grad_clip=10, soft_epsilon=0.00, **train_eval_args)

train/eval epochs:   0%|          | 0/10 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:32:36,858 - research - INFO - untrained: win_rate=64.1000%, avg_reward=0.41


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:35:28,516 - research - INFO - epoch=1: win_rate=63.8000%, avg_reward=0.39, avg_loss=0.0065, avg_grad=0.3972


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:38:20,102 - research - INFO - epoch=2: win_rate=62.4000%, avg_reward=0.36, avg_loss=0.0066, avg_grad=0.4010


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:41:11,308 - research - INFO - epoch=3: win_rate=62.6000%, avg_reward=0.34, avg_loss=0.0065, avg_grad=0.4018


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:44:03,030 - research - INFO - epoch=4: win_rate=61.7000%, avg_reward=0.33, avg_loss=0.0063, avg_grad=0.3900


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:46:54,359 - research - INFO - epoch=5: win_rate=63.7000%, avg_reward=0.38, avg_loss=0.0064, avg_grad=0.4037


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:49:45,584 - research - INFO - epoch=6: win_rate=61.3000%, avg_reward=0.34, avg_loss=0.0067, avg_grad=0.4217


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:52:37,003 - research - INFO - epoch=7: win_rate=65.1000%, avg_reward=0.41, avg_loss=0.0068, avg_grad=0.4292


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:55:28,689 - research - INFO - epoch=8: win_rate=64.0000%, avg_reward=0.38, avg_loss=0.0066, avg_grad=0.4141


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 05:58:20,120 - research - INFO - epoch=9: win_rate=63.3000%, avg_reward=0.35, avg_loss=0.0063, avg_grad=0.3999


practicing:   0%|          | 0/5000 [00:00<?, ?it/s]

evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-12-11 06:01:11,441 - research - INFO - epoch=10: win_rate=60.6000%, avg_reward=0.32, avg_loss=0.0066, avg_grad=0.4128


Unnamed: 0,model,p2,start,games,wins,win_rate_lo,win_rate_mu,win_rate_hi,avg_turns,avg_reward,avg_loss,avg_grad
0,untrained,random,random,1000,641,0.59549,0.641,0.68651,94.472,0.412,,
1,epoch-1,random,random,1000,638,0.5924,0.638,0.6836,94.889,0.386,0.006541,0.397195
2,epoch-2,random,random,1000,624,0.57804,0.624,0.66996,95.484,0.361,0.006562,0.401
3,epoch-3,random,random,1000,626,0.5801,0.626,0.6719,95.293,0.341,0.00648,0.401813
4,epoch-4,random,random,1000,617,0.57089,0.617,0.66311,95.499,0.333,0.00626,0.389964
5,epoch-5,random,random,1000,637,0.59137,0.637,0.68263,95.238,0.382,0.006413,0.403744
6,epoch-6,random,random,1000,613,0.5668,0.613,0.6592,95.146,0.337,0.006719,0.421741
7,epoch-7,random,random,1000,651,0.60579,0.651,0.69621,95.531,0.414,0.006843,0.429195
8,epoch-8,random,random,1000,640,0.59446,0.64,0.68554,95.301,0.379,0.006582,0.414132
9,epoch-9,random,random,1000,633,0.58728,0.633,0.67872,95.653,0.354,0.006304,0.399889


### Experiments with scatter & reduce:

When executed on MPS, scatter_reduce implementation fails with the following error:

NotImplementedError: The operator 'aten::scatter_reduce.two_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [34]:
# q_score = nn_player.q_network(curr_state_actions)

# with torch.no_grad():
#     t_score = nn_player.t_network(next_state_actions).squeeze()
#     t_score_max = torch.zeros_like(rewards).scatter_reduce(0, index=next_state_actions_idx, src=t_score, reduce="max", include_self=False)
    
# t_score_max

And this implementation fails on MPS with another error:

RuntimeError: src.device().is_cpu() INTERNAL ASSERT FAILED at "csrc/cpu/scatter_cpu.cpp":11, please report a bug to PyTorch. src must be CPU tensor

In [35]:
# q_score = nn_player.q_network(curr_state_actions)

# with torch.no_grad():
#     t_score = nn_player.t_network(next_state_actions).squeeze()
#     t_score_max, t_score_idx = torch_scatter.scatter_max(t_score, index=next_state_actions_idx, dim=0)
    
# t_score_max

In [36]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=next_state_actions_idx, dim=0)
# max_score, max_score_idx

In [37]:
# torch.zeros_like(rewards).scatter_reduce(0, index=next_state_actions_idx, src=score.squeeze(), reduce="max", include_self=False)

In [38]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=next_state_actions_idx, dim=0)
# max_score, max_score_idx

In [39]:
# data = torch.arange(24).view(-1,6).long()
# data

In [40]:
# score = data.float().sum(dim=1)
# score

In [41]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=torch.tensor([0,0,0,1]), dim=0)
# max_score, max_score_idx

In [42]:
# data[max_score_idx]