In [None]:
import math
import random
import time
from dataclasses import dataclass
from typing import Dict, Tuple, Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm, trange


In [None]:
# -------------------------
# config + helpers
# -------------------------

@dataclass
class TrainConfig:
    """Hyperparameters and runtime switches read by the training loop in `main()`."""

    # number of optimisation steps; each step plays `games_per_step` self-play games
    steps: int = 2000
    games_per_step: int = 64

    # loss weights: `beta` scales the entropy bonus, `value_coef` the value MSE term;
    # `grad_clip` is the max gradient norm handed to clip_grad_norm_
    beta: float = 0.02
    value_coef: float = 1.0
    grad_clip: float = 1.0

    # exploration schedule: sampling temperature and the probability of a uniformly
    # random legal move are both interpolated linearly from *_start to *_end
    temp_start: float = 1.5
    temp_end: float = 0.3
    eps_start: float = 0.10
    eps_end: float = 0.02

    # logging / eval: every `eval_every` steps, play `eval_games` games vs a random opponent
    eval_every: int = 100
    eval_games: int = 200

    # printing
    print_every: int = 100         # emit one training log line every N steps
    show_progress_bar: bool = True  # True: tqdm bar + tqdm.write; False: plain range + print

    # speed toggles
    compile_model: bool = False  # wrap the model in torch.compile (PyTorch 2.x) when available


class EMA:
    """Exponential moving average of a scalar stream.

    `alpha` is the weight given to each new sample; the first sample
    initialises the average directly.
    """

    def __init__(self, alpha: float = 0.05):
        self.alpha = alpha
        self.value: Optional[float] = None

    def update(self, x: float) -> float:
        """Fold `x` into the running average and return the new average."""
        if self.value is None:
            # first observation seeds the average
            self.value = x
        else:
            self.value = self.alpha * x + (1 - self.alpha) * self.value
        return self.value

In [None]:
# -------------------------
# TicTacToe environment (single-game, for eval)
# -------------------------

# Every triple of board indices that forms a winning line.
WIN_LINES = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),  # cols
    (0, 4, 8), (2, 4, 6)              # diags
]

@dataclass
class TicTacToe:
    """Mutable single-game tic-tac-toe state.

    The board is a flat list of 9 cells in row-major order.
    """
    # board: 0 empty, +1 X, -1 O
    board: List[int]
    player: int  # +1 X to move, -1 O to move

    @staticmethod
    def new() -> "TicTacToe":
        """Return an empty board with X (+1) to move."""
        return TicTacToe(board=[0] * 9, player=+1)

    def legal_moves(self) -> List[int]:
        """Indices of all empty squares, in ascending order."""
        return [i for i, v in enumerate(self.board) if v == 0]

    def is_terminal(self) -> Tuple[bool, int]:
        """
        Returns (done, winner):
          done: bool
          winner: +1 if X wins, -1 if O wins, 0 if draw/ongoing
        """
        for a, b, c in WIN_LINES:
            s = self.board[a] + self.board[b] + self.board[c]
            if s == 3:
                return True, +1
            if s == -3:
                return True, -1

        if all(v != 0 for v in self.board):
            return True, 0  # draw

        return False, 0

    def step(self, action: int) -> Tuple[bool, int]:
        """
        Place the current player's mark on `action` and return (done, winner).

        The side to move is flipped only when the game continues.

        Raises:
          ValueError: if `action` is outside the board or the square is occupied.
        """
        # Explicit range check: a negative index would otherwise wrap around
        # (e.g. -1 -> square 8) and silently play the wrong square.
        if not 0 <= action < len(self.board) or self.board[action] != 0:
            raise ValueError(f"Illegal move: {action}")
        self.board[action] = self.player
        done, winner = self.is_terminal()
        if not done:
            self.player *= -1
        return done, winner


In [None]:
# -------------------------
# encoding + model
# -------------------------

def board_to_tokens_perspective(board: List[int], player: int) -> torch.LongTensor:
    """
    Encode a board from the point of view of `player`.

    Each cell (0, +1 or -1) is multiplied by `player`, so the side to move
    always sees its own marks as +1. The result is mapped to token ids:
      empty    -> 0
      self     -> 1
      opponent -> 2
    """
    def to_token(cell: int) -> int:
        relative = cell * player
        if relative == 0:
            return 0
        return 1 if relative == +1 else 2

    return torch.tensor([to_token(cell) for cell in board], dtype=torch.long)

def masked_softmax(logits: torch.Tensor, legal_mask: torch.Tensor) -> torch.Tensor:
    """
    Softmax over the last dimension restricted to legal entries.

    logits:     [B, 9]
    legal_mask: [B, 9] bool, True where a move is allowed

    Illegal entries are pushed to -1e9 before normalising, so they end up
    with (numerically) zero probability.
    """
    illegal = legal_mask.logical_not()
    return logits.masked_fill(illegal, -1e9).softmax(dim=-1)

In [None]:
class TinyTTTTransformer(nn.Module):
    """
    Small pre-norm transformer encoder over the 9 board squares with a
    policy head (9 move logits) and a tanh-squashed scalar value head.

    Args:
      d_model:  embedding / hidden width
      n_heads:  attention heads per layer
      n_layers: number of encoder layers
      ff_mult:  feed-forward width as a multiple of d_model
      dropout:  dropout rate used on the embeddings and inside the encoder
    """

    def __init__(
        self,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        ff_mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(3, d_model)  # empty/self/opp
        self.pos_emb = nn.Embedding(9, d_model)    # 9 squares
        self.drop = nn.Dropout(dropout)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * ff_mult,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        # norm_first=True already rules out the nested-tensor path; saying so
        # explicitly avoids the UserWarning PyTorch emits when the default
        # enable_nested_tensor=True cannot be honoured.
        self.encoder = nn.TransformerEncoder(
            enc_layer,
            num_layers=n_layers,
            enable_nested_tensor=False,
        )

        self.policy_head = nn.Linear(d_model, 9)
        self.value_head = nn.Linear(d_model, 1)

    def forward(self, tokens: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        tokens: [B, 9] long in {0,1,2}
        returns:
          logits: [B, 9]  unnormalised move scores (not masked for legality)
          value:  [B]     in [-1, 1]
        """
        B = tokens.size(0)
        pos = torch.arange(9, device=tokens.device).unsqueeze(0).expand(B, 9)  # [B,9]

        x = self.token_emb(tokens) + self.pos_emb(pos)  # [B,9,d]
        x = self.drop(x)
        x = self.encoder(x)                              # [B,9,d]

        # mean-pool over the 9 squares to get one vector per board
        pooled = x.mean(dim=1)                           # [B,d]
        logits = self.policy_head(pooled)                # [B,9]
        value = torch.tanh(self.value_head(pooled)).squeeze(-1)  # [B]
        return logits, value

In [None]:
# -------------------------
# FAST vectorized self-play (the big speedup)
# -------------------------

# Winning index triples as an (8, 3) long tensor, used for batched line lookups.
WIN_LINES_T = torch.tensor(
    [
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # cols
        (0, 4, 8), (2, 4, 6),             # diags
    ],
    dtype=torch.long,
)

def board_to_tokens_perspective_batch(board: torch.Tensor, player: torch.Tensor) -> torch.Tensor:
    """
    Batched perspective encoding.

    board:  (B,9) in {-1,0,+1} (integer)
    player: (B,)  in {-1,+1}
    returns tokens: (B,9) long with mapping empty=0, self=1, opp=2
    """
    # After multiplying by the side to move, own marks are +1 and opponent marks -1.
    relative = board * player.unsqueeze(1)
    own = (relative == 1).to(torch.long)
    theirs = (relative == -1).to(torch.long)
    return own + 2 * theirs

def check_winner_batch(board: torch.Tensor) -> torch.Tensor:
    """
    Find the winner of every board in a batch.

    board: (B,9) in {-1,0,+1}
    returns winner: (B,) int16 in {-1,0,+1}; 0 means no completed line
    """
    line_idx = WIN_LINES_T.to(board.device)
    line_sums = board[:, line_idx].sum(dim=2)       # (B,8): one sum per winning line
    x_has_line = (line_sums == 3).any(dim=1)
    o_has_line = (line_sums == -3).any(dim=1)

    result = torch.zeros((board.size(0),), dtype=torch.int16, device=board.device)
    result[x_has_line] = 1
    result[o_has_line] = -1
    return result

@torch.no_grad()
def collect_self_play_batch_vectorized(
    model: nn.Module,
    device: torch.device,
    games: int,
    temperature: float,
    epsilon: float,
) -> Dict[str, torch.Tensor]:
    """
    Plays 'games' self-play games in parallel (vectorized) and returns one flat batch.
    Much faster than looping games one-by-one.

    Args:
      model:       network returning (logits [A,9], value [A]) for token input [A,9]
      device:      device on which the game tensors are allocated
      games:       number of parallel games (must be >= 1)
      temperature: softmax temperature for sampling; <= 0 means greedy argmax
      epsilon:     per-move probability of playing a uniformly random legal move

    Returns a dict with, for N = total number of moves played across all games:
      "states":  (N,9) long   perspective tokens before each move
      "actions": (N,)  long   chosen square
      "masks":   (N,9) bool   legal-move mask at that state
      "z":       (N,)  float  game outcome from the mover's perspective (+1/0/-1)
      "winners": (games,) int64 on CPU, +1 X / -1 O / 0 draw
      "lengths": (games,) int64 on CPU, number of plies per game

    The model is switched to eval mode while playing and put back into
    training mode afterwards if it was training on entry.

    Raises:
      ValueError: if games < 1.
    """
    if games < 1:
        raise ValueError(f"games must be >= 1, got {games}")

    was_training = model.training
    model.eval()

    try:
        B = games
        board = torch.zeros((B, 9), dtype=torch.int16, device=device)   # {-1,0,+1}
        player = torch.ones((B,), dtype=torch.int16, device=device)     # start with +1
        done = torch.zeros((B,), dtype=torch.bool, device=device)
        winners = torch.zeros((B,), dtype=torch.int16, device=device)
        lengths = torch.zeros((B,), dtype=torch.int64, device=device)

        states_list: List[torch.Tensor] = []
        actions_list: List[torch.Tensor] = []
        masks_list: List[torch.Tensor] = []
        players_list: List[torch.Tensor] = []
        game_ids_list: List[torch.Tensor] = []

        for _ply in range(9):
            active = ~done
            if not active.any():
                break

            idx = torch.where(active)[0]    # active game indices
            b = board[idx]                  # (A,9)
            p = player[idx]                 # (A,)
            legal_mask = (b == 0)           # (A,9) bool

            tokens = board_to_tokens_perspective_batch(b, p)  # (A,9) long
            logits, _ = model(tokens)                         # (A,9)

            A = idx.numel()
            actions = torch.empty((A,), dtype=torch.long, device=device)

            # epsilon-greedy: some random moves
            use_rand = (torch.rand((A,), device=device) < epsilon)

            if use_rand.any():
                # uniform over legal squares
                rand_weights = legal_mask[use_rand].float()
                actions[use_rand] = torch.multinomial(rand_weights, 1).squeeze(1)

            if (~use_rand).any():
                lm = legal_mask[~use_rand]
                lg = logits[~use_rand]

                if temperature <= 0:
                    masked = lg.masked_fill(~lm, -1e9)
                    actions[~use_rand] = masked.argmax(dim=1)
                else:
                    probs = masked_softmax(lg / temperature, lm)
                    actions[~use_rand] = torch.multinomial(probs, 1).squeeze(1)

            # record transitions (state before applying action)
            states_list.append(tokens)         # (A,9)
            actions_list.append(actions)       # (A,)
            masks_list.append(legal_mask)      # (A,9)
            players_list.append(p)             # (A,)
            game_ids_list.append(idx)          # (A,)

            # apply moves
            board[idx, actions] = p
            lengths[idx] += 1

            # terminal check
            line_winner = check_winner_batch(board)       # (B,)
            full = (board != 0).all(dim=1)
            newly_done = (~done) & ((line_winner != 0) | full)

            winners[newly_done] = line_winner[newly_done]  # draw remains 0
            done[newly_done] = True

            # switch player for still-active games
            player[~done] = -player[~done]
    finally:
        # restore the caller's mode even if a forward pass or sampling step fails
        if was_training:
            model.train()

    # flatten transitions
    states = torch.cat(states_list, dim=0)     # (N,9) long on device
    actions = torch.cat(actions_list, dim=0)   # (N,) long on device
    masks = torch.cat(masks_list, dim=0)       # (N,9) bool on device
    pl = torch.cat(players_list, dim=0)        # (N,) int16 on device
    gids = torch.cat(game_ids_list, dim=0)     # (N,) long on device

    # z per transition = winner_of_game * player_at_state
    z = winners[gids].to(torch.float32) * pl.to(torch.float32)

    return {
        "states": states,
        "actions": actions,
        "masks": masks,
        "z": z,
        "winners": winners.to(torch.int64).cpu(),
        "lengths": lengths.cpu(),
    }


# -------------------------
# loss + eval
# -------------------------

def compute_loss(
    logits: torch.Tensor,
    values: torch.Tensor,
    actions: torch.Tensor,
    legal_mask: torch.Tensor,
    z: torch.Tensor,
    beta: float = 0.01,
    value_coef: float = 1.0,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Policy-gradient + value + entropy loss over a flat batch of transitions.

    Args:
      logits:     [N,9] raw policy logits
      values:     [N]   value predictions
      actions:    [N]   long, the move that was played
      legal_mask: [N,9] bool, True where a move was legal
      z:          [N]   outcome from the mover's perspective
      beta:       weight of the entropy bonus
      value_coef: weight of the value MSE term

    Returns:
      (total, stats) where total = policy_loss + value_coef * value_loss
      - beta * entropy, and stats holds the Python-float components under
      "loss", "policy_loss", "value_loss" and "entropy".
    """
    # Work in log-space: log(softmax(x).clamp_min(eps)) pins the log-prob at
    # log(eps) and zeroes its gradient once a probability underflows, whereas
    # log_softmax stays exact and differentiable.
    masked_logits = logits.masked_fill(~legal_mask, -1e9)
    log_probs = F.log_softmax(masked_logits, dim=-1)   # [N,9]
    probs = log_probs.exp()                            # [N,9], illegal entries are 0

    # log pi(a|s)
    logp = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)  # [N]

    # entropy (illegal entries contribute 0 * finite = 0)
    entropy = -(probs * log_probs).sum(dim=1)          # [N]

    policy_loss = -(z * logp).mean()
    value_loss = F.mse_loss(values, z)
    entropy_bonus = entropy.mean()

    total = policy_loss + value_coef * value_loss - beta * entropy_bonus

    stats = {
        "loss": float(total.item()),
        "policy_loss": float(policy_loss.item()),
        "value_loss": float(value_loss.item()),
        "entropy": float(entropy_bonus.item()),
    }
    return total, stats


@torch.inference_mode()
def eval_vs_random(
    model: nn.Module,
    device: torch.device,
    games: int = 200,
) -> Tuple[float, float, float]:
    """
    Play `games` games of greedy model vs. a uniformly random opponent.

    The model alternates sides: X on even-numbered games, O on odd ones.
    It always picks the highest-scoring legal move.

    Returns:
      (win_rate, draw_rate, loss_rate) from the model's point of view.
      All three are 0.0 when games <= 0.

    The model is evaluated in eval mode and switched back to training mode
    afterwards if it was training on entry.
    """
    if games <= 0:
        # nothing to play; avoid dividing by zero below
        return 0.0, 0.0, 0.0

    was_training = model.training
    model.eval()
    wins = draws = losses = 0

    try:
        for g in range(games):
            env = TicTacToe.new()
            model_side = +1 if (g % 2 == 0) else -1

            while True:
                done, winner = env.is_terminal()
                if done:
                    if winner == 0:
                        draws += 1
                    elif winner == model_side:
                        wins += 1
                    else:
                        losses += 1
                    break

                legal = env.legal_moves()

                if env.player == model_side:
                    legal_mask = torch.zeros(9, dtype=torch.bool)
                    legal_mask[legal] = True
                    tokens = board_to_tokens_perspective(env.board, env.player).to(device).unsqueeze(0)  # [1,9]
                    logits, _ = model(tokens)
                    logits = logits.squeeze(0)
                    masked = logits.masked_fill(~legal_mask.to(device), -1e9)
                    action = int(masked.argmax().item())
                else:
                    action = random.choice(legal)

                env.step(action)
    finally:
        # leave the model in the mode the caller had it in
        if was_training:
            model.train()

    total = wins + draws + losses
    return wins / total, draws / total, losses / total


In [None]:
# -------------------------
# main training loop
# -------------------------

def pick_device() -> torch.device:
    """Pick the compute device, preferring CUDA, then Apple MPS, then CPU."""
    mps_backend = getattr(torch.backends, "mps", None)  # absent on older PyTorch builds
    if torch.cuda.is_available():
        name = "cuda"
    elif mps_backend is not None and torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)


def main():
    """
    Train a TinyTTTTransformer by vectorized self-play and save its weights.

    Each step: collect one batch of self-play games with the current
    exploration settings, take one AdamW step on the combined loss, and
    periodically log smoothed losses and evaluate against a random opponent.
    The final weights are written to "ttt_transformer.pt".
    """
    cfg = TrainConfig()
    device = pick_device()
    print("Device:", device)

    # Keep a handle on the plain module. torch.compile returns a wrapper whose
    # state_dict keys are prefixed with "_orig_mod.", which would not load back
    # into an uncompiled TinyTTTTransformer; the wrapper shares parameters with
    # this module, so saving from here gives clean keys.
    base_model = TinyTTTTransformer(
        d_model=64,
        n_heads=4,
        n_layers=2,
        ff_mult=4,
        dropout=0.0
    ).to(device)

    model = base_model
    if cfg.compile_model and hasattr(torch, "compile"):
        model = torch.compile(base_model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

    ema_loss = EMA(alpha=0.05)
    ema_pi = EMA(alpha=0.05)
    ema_v = EMA(alpha=0.05)

    if cfg.show_progress_bar:
        iterator = trange(1, cfg.steps + 1, desc="training", unit="step")
        log = tqdm.write  # keeps log lines from clobbering the progress bar
    else:
        iterator = range(1, cfg.steps + 1)
        log = print

    for step in iterator:
        model.train()
        step_t0 = time.perf_counter()

        # linear schedule
        t = step / cfg.steps
        temperature = cfg.temp_start + t * (cfg.temp_end - cfg.temp_start)
        epsilon = cfg.eps_start + t * (cfg.eps_end - cfg.eps_start)

        # fast vectorized self-play
        sp_t0 = time.perf_counter()
        batch = collect_self_play_batch_vectorized(
            model=model,
            device=device,
            games=cfg.games_per_step,
            temperature=temperature,
            epsilon=epsilon,
        )
        sp_sec = time.perf_counter() - sp_t0

        logits, values = model(batch["states"])
        loss, stats = compute_loss(
            logits=logits,
            values=values,
            actions=batch["actions"],
            legal_mask=batch["masks"],
            z=batch["z"],
            beta=cfg.beta,
            value_coef=cfg.value_coef,
        )

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optimizer.step()

        winners = batch["winners"]
        x_wins = int((winners == +1).sum().item())
        o_wins = int((winners == -1).sum().item())
        draws = int((winners == 0).sum().item())
        avg_len = float(batch["lengths"].float().mean().item())

        s_loss = ema_loss.update(stats["loss"])
        s_pi = ema_pi.update(stats["policy_loss"])
        s_v = ema_v.update(stats["value_loss"])

        step_sec = time.perf_counter() - step_t0

        # one log line every `print_every` steps; a value <= 0 turns logging off
        # instead of dividing by zero
        if cfg.print_every > 0 and (step % cfg.print_every) == 0:
            log(
                f"[step {step:4d}] "
                f"loss {s_loss:.4f} | pi {s_pi:.4f} | v {s_v:.4f} | H {stats['entropy']:.3f} | "
                f"temp {temperature:.2f} | eps {epsilon:.2f} | avgT {avg_len:.2f} | "
                f"X {x_wins:3d} O {o_wins:3d} D {draws:3d} | "
                f"sp_s {sp_sec:.2f} | step_s {step_sec:.2f}"
            )

        # evaluation (likewise disabled when eval_every <= 0)
        if cfg.eval_every > 0 and (step % cfg.eval_every) == 0:
            win, draw, lose = eval_vs_random(model, device, games=cfg.eval_games)
            log(
                f"  eval vs random ({cfg.eval_games}): "
                f"W {win:.2%} D {draw:.2%} L {lose:.2%}"
            )

    torch.save(base_model.state_dict(), "ttt_transformer.pt")
    print("Saved model to ttt_transformer.pt")


if __name__ == "__main__":
    main()