In [1]:
import os
import sys
import logging
import pathlib

In [2]:
# Reload edited modules (e.g. backgammon.game, utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [3]:
import logging
# NOTE(review): the log line printed by the training run further down is tagged
# "backgammon", not "research", which suggests the later
# `from backgammon.game import *` rebinds `logger`. If so, create this logger
# after the star import (or give it a distinct name).
logger = logging.getLogger("research")

In [4]:
# Put the parent of the notebook's working directory first on the import path
# (presumably the project root holding `utils` and `backgammon`).
sys.path.insert(0, os.fspath(pathlib.Path.cwd().parent))

In [5]:
import utils.logging
# Configure logging through the project's own helper, in non-debug mode.
utils.logging.setup(debug=False)

In [6]:
from backgammon.game import *

# Training a model with Q-Learning

In [7]:
@dataclass
class ReplaySample:
    state_action: torch.Tensor
    reward: int = 0
    next_state_actions: torch.Tensor | None = None


In [8]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of replay samples with uniform random sampling.

    Once every slot is filled, each new sample overwrites the oldest one.
    Indexing addresses storage slots, not insertion order.
    """

    def __init__(self, size: int = 1_000_000):
        """Create an empty buffer.

        Args:
            size: maximum number of samples kept (must be >= 1). Previously
                ``1_000_000`` was written as an annotation, not a default, so
                the argument was effectively required.

        Raises:
            ValueError: if ``size`` < 1 (``add`` could never succeed).
        """
        if size < 1:
            raise ValueError(f"size must be >= 1, got {size}")
        self.buffer = [None] * size
        # Slot the next ``add`` writes to; wraps to 0 after the last slot.
        self.insert_ptr = 0
        # Number of filled slots; grows until it equals ``size``.
        self.upper_bound = 0

    def add(self, sample: "ReplaySample"):
        """Store ``sample``, overwriting the oldest slot once the buffer is full."""
        self.buffer[self.insert_ptr] = sample
        # Count the slot just written *before* advancing the pointer. Updating
        # afterwards missed the final slot whenever the pointer wrapped to 0,
        # so a full buffer reported size - 1 samples forever.
        self.upper_bound = max(self.upper_bound, self.insert_ptr + 1)
        self.insert_ptr = (self.insert_ptr + 1) % len(self.buffer)

    def sample(self, k: int = 1) -> List["ReplaySample"]:
        """Draw ``k`` stored samples uniformly at random, with replacement.

        Raises:
            IndexError: if the buffer is empty and ``k`` > 0.
        """
        # Choosing positions from a range avoids copying up to ``size`` list
        # entries per call; random.choices picks the same positions either way.
        positions = random.choices(range(self.upper_bound), k=k)
        return [self.buffer[i] for i in positions]

    def __getitem__(self, index: int) -> "ReplaySample":
        """Return the sample in slot ``index``; negative indices count back from the end.

        Raises:
            IndexError: if ``index`` does not address a filled slot.
        """
        # Previously a negative index fell through to the raw list and could
        # return an unfilled (None) slot.
        position = index + self.upper_bound if index < 0 else index
        if not 0 <= position < self.upper_bound:
            raise IndexError(f"replay buffer index {index} out of range for {self.upper_bound} samples")
        return self.buffer[position]

    def __len__(self) -> int:
        """Number of samples currently stored."""
        return self.upper_bound

In [11]:
class PGPolicy(BasePlayer):
    """Policy-gradient player (work in progress).

    NOTE(review): this class cannot run as written.
    - ``_sample_batch``, ``get_returns`` and ``_calc_loss`` are unimplemented.
    - ``policy``, ``baseline_network`` and ``optimizer`` are used by the update
      methods but never created.
    - ``q_network`` and ``_encode_state_actions`` are used by ``play_turn``,
      which appears to be carried over from a Q-learning player, but are
      never created either.
    - ``practice()`` calls ``_train_step`` and ``_sync_networks``, which are
      not defined here.
    """

    def __init__(self, device="cpu", replay_buffer_size: int = 1_000_000):
        """Create the player.

        Args:
            device: torch device that encoded tensors are moved to.
            replay_buffer_size: capacity of ``self.replay_buffer``.
        """
        self.device = device
        self.replay_buffer = ReplayBuffer(size=replay_buffer_size)
        # play_turn reads these before anything else assigns them. practice()
        # sets training/soft_epsilon but never prev_state_action, so the first
        # training turn raised AttributeError. Without practice(), soft_epsilon
        # was missing during evaluation.
        self.training = False
        self.soft_epsilon = 0.0
        self.prev_state_action = None

    def _encode_state(self, game: Game) -> torch.Tensor:
        """Encode board, dice (zero-padded to 4 values) and ``head_moves`` as
        one flat float tensor on ``self.device``."""
        dice = game.dice.copy()
        if len(dice) < 4:
            dice += [0] * (4 - len(dice))
        return torch.concat([
            torch.tensor(game.board),
            torch.tensor(dice),
            torch.tensor([game.head_moves])
        ]).to(self.device).float()

    def _encode_actions(self, actions: List[Tuple[int, int]]) -> torch.Tensor:
        """Encode moves as a float tensor with one row per move.

        An empty move list is encoded as the single placeholder move (-1, -1).
        """
        if not actions:
            actions = [(-1,-1)]
        return torch.tensor(actions).to(self.device).float()

    def _sample_batch(self, batch_size: int = 32):
        """Return ``(states, actions, rewards)`` for a training batch.

        Raises:
            NotImplementedError: always. Previously this returned None, which
                failed later with a confusing tuple-unpacking TypeError.
        """
        raise NotImplementedError("PGPolicy._sample_batch is not implemented yet")

    def _calc_loss(self, q_scores, t_scores, sample_ids, rewards) -> torch.Tensor:
        """Loss helper; the signature looks carried over from a Q-learning player.

        Raises:
            NotImplementedError: always (not implemented yet).
        """
        raise NotImplementedError("PGPolicy._calc_loss is not implemented yet")

    def get_returns(self, rewards) -> torch.Tensor:
        """Turn per-step ``rewards`` into returns (presumably discounted).

        Raises:
            NotImplementedError: always (not implemented yet).
        """
        raise NotImplementedError("PGPolicy.get_returns is not implemented yet")

    def get_advantages(self, states: torch.Tensor, returns: torch.Tensor) -> torch.Tensor:
        """Return ``returns - baseline_network(states)`` computed without gradients.

        NOTE(review): if the baseline outputs shape (B, 1) while ``returns`` is
        (B,), the subtraction broadcasts to (B, B). Confirm the shapes.
        """
        with torch.no_grad():
            advantages = (returns - self.baseline_network(states))
        return advantages

    def _update_policy(self, batch_size: int = 32) -> None:
        """One policy-gradient step: minimise ``mean(-log pi(a|s) * advantage)``."""
        self.optimizer.zero_grad()
        states, actions, rewards = self._sample_batch(batch_size)
        returns = self.get_returns(rewards)
        # get_advantages takes (states, returns). The call previously passed
        # them in the opposite order.
        advantages = self.get_advantages(states, returns)

        # Assumes self.policy(states) yields one logit per action and that
        # `actions` holds action indices. TODO: confirm once implemented.
        distribution = torch.distributions.Categorical(logits=self.policy(states))
        loss = (-distribution.log_prob(actions) * advantages).mean()
        loss.backward()
        self.optimizer.step()

    def _update_baseline(self, states, returns):
        """Fit the baseline to ``returns`` with one mean-squared-error step.

        NOTE(review): this reuses ``self.optimizer``, the same optimizer as
        ``_update_policy``. Confirm it owns the baseline's parameters, or give
        the baseline its own optimizer.
        """
        self.optimizer.zero_grad()
        values = self.baseline_network(states)
        advantages = (returns - values)
        loss = advantages.pow(2).mean()
        loss.backward()
        self.optimizer.step()

    def play_turn(self, game: Game):
        """Play one turn of ``game``.

        Behaviour:
        - No valid move: calls ``game.skip()`` and returns False.
        - Otherwise: with probability ``soft_epsilon`` plays a random valid
          move, else the move with the highest ``q_network`` score, and
          returns True.
        - While training: the previous decision is stored in the replay buffer
          together with the current candidates. A decision that finishes the
          game is stored with its terminal reward.
        """
        valid_moves = list(game.get_valid_moves())
        if not valid_moves:
            game.skip()
            return False

        with torch.no_grad():
            # NOTE(review): _encode_state_actions is not defined on this class.
            state_actions = self._encode_state_actions(game)

            # Close the transition started on our previous turn, now that the
            # next decision point's candidates are known.
            if self.training and self.prev_state_action is not None:
                self.replay_buffer.add(ReplaySample(state_action=self.prev_state_action, reward=self._calc_reward(game), next_state_actions=state_actions))
                self.prev_state_action = None

            if game._randfloat() < self.soft_epsilon:
                action_idx = game._randint(0, len(valid_moves)-1)
            else:
                # NOTE(review): q_network is not defined on this class.
                scores = self.q_network(state_actions).squeeze()
                action_idx = scores.argmax(dim=-1).item()

            self.prev_state_action = state_actions[action_idx]

            game.turn(*valid_moves[action_idx])

            # NOTE(review): a terminal sample is only recorded when the game ends
            # on our own move. If the opponent ends it, play_turn is presumably
            # not called again, and the pending prev_state_action never gets its
            # final reward. Confirm against simulate().
            if game.is_finished() and self.training and self.prev_state_action is not None:
                self.replay_buffer.add(ReplaySample(state_action=self.prev_state_action, reward=self._calc_reward(game)))
                self.prev_state_action = None

            # self._train_step()

        return True

    def _calc_reward(self, game: Game) -> int:
        """Reward for the current position.

        Rules:
        - 0 while the game is still running.
        - Once finished: magnitude 2 if ``game.home[game._opponent] == 0``,
          else 1.
        - Positive when ``game.pturn == 0``, negative otherwise.

        NOTE(review): confirm which player ``pturn`` refers to after the final
        move, since the sign depends on it.
        """
        reward = 0
        if game.is_finished():
            reward  = (2 if game.home[game._opponent] == 0 else 1)
            reward *= (1 if game.pturn == 0 else -1)
        return reward


### Training and evaluation

In [12]:
def practice(
        policy: "QPolicy",
        games: int = 100,
        train_every: int = 1,
        sync_every: int = 100,
        batch_size = 32,
        show_progress : bool = False,
        gamma: float = 0.99,
        lr: float = 0.001,
        grad_clip: float = 10,
        soft_epsilon: float = 0
    ):
    """Play ``games`` training games against ``RandomPlayer`` and train ``policy``.

    Steps:
    - Switches ``policy.training`` on and stores the hyper-parameters on the
      policy; they stay set after this returns.
    - Plays each game with ``policy`` as player 1.
    - Calls ``policy._train_step`` every ``train_every`` games and
      ``policy._sync_networks`` every ``sync_every`` games, both starting
      with the first game.

    The ``policy`` annotation is quoted so defining this function does not
    raise NameError on a fresh kernel where ``QPolicy`` is not yet defined.

    Returns:
        ``(avg_loss, avg_grad)`` over the training steps taken, or
        ``(nan, nan)`` when ``games`` is 0 (previously a ZeroDivisionError).

    Raises:
        ValueError: if ``train_every`` or ``sync_every`` is < 1.
    """
    if train_every < 1 or sync_every < 1:
        raise ValueError(f"train_every and sync_every must be >= 1, got {train_every=}, {sync_every=}")

    policy.training = True
    policy.gamma = gamma
    policy.lr = lr
    policy.grad_clip = grad_clip
    policy.soft_epsilon = soft_epsilon

    loss_vals = []
    grad_vals = []
    game_ids = tqdm.trange(games, leave=False, desc="practicing") if show_progress else range(games)
    for game_id in game_ids:
        # NOTE(review): evaluate() passes the opponent to AutoGame(...) rather
        # than to simulate(...). Confirm both forms set up the same opponent.
        simulate(AutoGame(), player1=policy, player2=RandomPlayer())
        if game_id % train_every == 0:
            loss, grad = policy._train_step(batch_size=batch_size)
            loss_vals.append(loss)
            grad_vals.append(grad)
        if game_id % sync_every == 0:
            policy._sync_networks()

    if not loss_vals:
        # No training step ran (games == 0); report "no data", not a crash.
        return float("nan"), float("nan")
    return sum(loss_vals) / len(loss_vals), sum(grad_vals) / len(grad_vals)

In [13]:
def evaluate(model_name: str, policy: "QPolicy", games: int = 100):
    """Play ``games`` evaluation games against ``RandomPlayer`` and summarise them.

    ``policy.training`` is switched off while playing. It is restored
    afterwards, even if a game raises (e.g. the cell is interrupted); it
    previously stayed off in that case.

    Returns:
        A one-row DataFrame with:
        - ``model``, ``p2``, ``start``, ``games`` and ``wins``;
        - the win rate with a +/-3 binomial standard-deviation band, clamped
          to [0, 1];
        - the mean number of turns and the mean reward.

    Raises:
        ValueError: if ``games`` < 1 (there is nothing to summarise).
    """
    if games < 1:
        raise ValueError(f"games must be >= 1, got {games}")

    prev_training = policy.training
    policy.training = False
    try:
        # NOTE(review): practice() passes the opponent to simulate(...) rather
        # than to AutoGame(...). Confirm both forms set up the same opponent.
        # NOTE(review): soft_epsilon is not reset here, so the policy keeps
        # exploring at whatever rate practice() left it.
        sims = pd.DataFrame([
            simulate(AutoGame(player2=RandomPlayer()), player1=policy).__dict__
            for _ in tqdm.trange(games, leave=False, desc="evaluating")
        ])
    finally:
        policy.training = prev_training

    # Assumes `winner` is 1 when player 1 (the policy) wins and 0 otherwise.
    # TODO: confirm against simulate().
    wins = sims["winner"].sum()
    wins_mu = wins / games
    # Binomial standard deviation of the win count (rounded as before).
    wins_sd = round(math.sqrt(games * wins_mu * (1 - wins_mu)), 2)
    exp_info = {
        "model": model_name,
        "p2": "random",
        "start": "random",
        "games": games,
        "wins": wins,
        # A win rate outside [0, 1] is meaningless, so clamp the band.
        "win_rate_lo": max(0.0, (wins - wins_sd * 3) / games),
        "win_rate_mu": wins_mu,
        "win_rate_hi": min(1.0, (wins + wins_sd * 3) / games),
        "avg_turns": sims["turns"].mean(),
        "avg_reward": sims["reward"].mean(),
    }
    return pd.DataFrame([exp_info])

In [14]:
def train_eval_loop(
    policy: "QPolicy",
    epochs: int = 1000,
    practice_games: int = 1000,
    batch_size: int = 32,
    eval_games: int = 100,
    sync_every: int = 50,
    train_every: int = 1,
    **kwargs
):
    """Alternate practice and evaluation of ``policy`` for ``epochs`` epochs.

    The policy is evaluated once before any training (row "untrained"). Each
    epoch then plays ``practice_games`` training games via ``practice`` and
    evaluates on ``eval_games`` games. Extra keyword arguments (e.g. ``lr``,
    ``gamma``, ``grad_clip``, ``soft_epsilon``) are forwarded to ``practice``.

    Args:
        train_every: forwarded to ``practice``. It used to be fixed at 1, so
            supplying it through ``**kwargs`` raised a duplicate-keyword
            TypeError.

    Returns:
        DataFrame with one row per evaluation. The "untrained" row has no
        ``avg_loss``/``avg_grad`` (NaN).
    """
    def _postfix(res, **extra):
        # Progress-bar summary: latest evaluation plus any training stats.
        return {"win_rate": res["win_rate_mu"], "avg_reward": res["avg_reward"], **extra}

    epoch_pbar = tqdm.trange(1, epochs+1, desc="train/eval epochs")
    result = evaluate("untrained", policy, games=eval_games).loc[0].to_dict()
    epoch_pbar.set_postfix(_postfix(result))
    logger.info(f"untrained: win_rate={result['win_rate_mu']:.4%}, avg_reward={result['avg_reward']:.2f}")
    results = [result]
    for epoch_id in epoch_pbar:
        avg_loss, avg_grad = practice(
            policy,
            games=practice_games,
            train_every=train_every,
            sync_every=sync_every,
            batch_size=batch_size,
            show_progress=True,
            **kwargs,
        )
        # While this epoch is being evaluated, show the fresh training stats
        # next to the previous evaluation.
        epoch_pbar.set_postfix(_postfix(result, avg_loss=avg_loss, avg_grad=avg_grad))
        result = evaluate(f"epoch-{epoch_id}", policy, games=eval_games).loc[0].to_dict()
        result["avg_loss"] = avg_loss
        result["avg_grad"] = avg_grad
        results.append(result)
        epoch_pbar.set_postfix(_postfix(result, avg_loss=avg_loss, avg_grad=avg_grad))
        logger.info(f"epoch={epoch_id}: win_rate={result['win_rate_mu']:.4%}, avg_reward={result['avg_reward']:.2f}, {avg_loss=:.4f}, {avg_grad=:.4f}")
    return pd.DataFrame(results)

In [15]:
# NOTE(review): QPolicy is not defined anywhere in this notebook. The policy
# cell above defines PGPolicy, whose __init__ takes no `layers` argument. On a
# fresh kernel this raises NameError unless QPolicy comes from the
# `backgammon.game` star import; confirm where it is defined.
nn_player = QPolicy(layers=[128, 512, 128], device="cpu")

In [16]:
# Loop settings shared by the training run below.
train_eval_args = dict(
    epochs=100,
    practice_games=1000,
    batch_size=100,
    eval_games=100,
    # With 1000 practice games per epoch, networks sync at games 0 and 999.
    sync_every=999,
)

In [None]:
# Full training run. lr, gamma, grad_clip and soft_epsilon reach practice()
# through train_eval_loop's **kwargs; the returned per-epoch DataFrame is the
# cell output.
train_eval_loop(nn_player, lr=0.00001, gamma=0.99, grad_clip=3, soft_epsilon=0.1, **train_eval_args)

train/eval epochs:   0%|          | 0/100 [00:00<?, ?it/s]

evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

2024-12-27 15:20:07,111 - backgammon - INFO - untrained: win_rate=88.0000%, avg_reward=1.09


practicing:   0%|          | 0/1000 [00:00<?, ?it/s]

### Experiments with scatter/reduce on MPS

When executed on the MPS backend, the `Tensor.scatter_reduce` implementation below fails with the following error:

NotImplementedError: The operator 'aten::scatter_reduce.two_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [34]:
# q_score = nn_player.q_network(curr_state_actions)

# with torch.no_grad():
#     t_score = nn_player.t_network(next_state_actions).squeeze()
#     t_score_max = torch.zeros_like(rewards).scatter_reduce(0, index=next_state_actions_idx, src=t_score, reduce="max", include_self=False)
    
# t_score_max

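The `torch_scatter.scatter_max` implementation below fails on MPS with a different error. The message shows that torch_scatter dispatched to its CPU kernel, which only accepts CPU tensors: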

RuntimeError: src.device().is_cpu() INTERNAL ASSERT FAILED at "csrc/cpu/scatter_cpu.cpp":11, please report a bug to PyTorch. src must be CPU tensor

In [35]:
# q_score = nn_player.q_network(curr_state_actions)

# with torch.no_grad():
#     t_score = nn_player.t_network(next_state_actions).squeeze()
#     t_score_max, t_score_idx = torch_scatter.scatter_max(t_score, index=next_state_actions_idx, dim=0)
    
# t_score_max

In [36]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=next_state_actions_idx, dim=0)
# max_score, max_score_idx

In [37]:
# torch.zeros_like(rewards).scatter_reduce(0, index=next_state_actions_idx, src=score.squeeze(), reduce="max", include_self=False)

In [38]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=next_state_actions_idx, dim=0)
# max_score, max_score_idx

In [39]:
# data = torch.arange(24).view(-1,6).long()
# data

In [40]:
# score = data.float().sum(dim=1)
# score

In [41]:
# max_score, max_score_idx = torch_scatter.scatter_max(score, index=torch.tensor([0,0,0,1]), dim=0)
# max_score, max_score_idx

In [42]:
# data[max_score_idx]