In [None]:
import einops
import matplotlib.pyplot as plt
import torch as t
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked
from typing import List, Tuple
patch_typeguard()

In [None]:
%load_ext line_profiler

In [None]:
# Device the model is moved to; board tensors are later created on the model's device.
DEVICE = 'cpu'

In [None]:
# Type alias for a batch of 3x3 boards stored as long integers.
# 0 marks an empty cell. `make_moves` adds +1 for the side that moves, and
# `play_game` / `train` negate the board after each move, so in a board handed
# to the model +1 cells belong to the side to move and -1 cells to its opponent.
Boards = TensorType["batch_size", 3, 3, t.long]

def finished(boards):
    """Return a boolean mask of the boards on which the game is over.

    A game is over when either side has completed a row, a column or one of
    the two diagonals, or when no empty (zero) cell is left. Checking only for
    a full board would let games continue after a side has already won.

    Args:
        boards: integer tensor of shape (..., 3, 3). Zero cells are empty,
            positive cells belong to one side and negative cells to the other.

    Returns:
        Boolean tensor of shape (...), True where the game has ended.
    """
    def has_line(marks):
        # `marks` is a boolean (..., 3, 3) mask of the cells held by one side.
        rows = marks.all(dim=-1).any(dim=-1)
        cols = marks.all(dim=-2).any(dim=-1)
        diag = t.diagonal(marks, dim1=-2, dim2=-1).all(dim=-1)
        # Mirroring the columns turns the anti-diagonal into the main diagonal.
        anti = t.diagonal(marks.flip(-1), dim1=-2, dim2=-1).all(dim=-1)
        return rows | cols | diag | anti

    board_full = (boards != 0).flatten(start_dim=-2).all(dim=-1)
    return board_full | has_line(boards > 0) | has_line(boards < 0)

In [None]:
def maybe_typechecked(f):
    """Decorator hook applied to the typed functions below; currently a no-op.

    It returns `f` unchanged, so no runtime type checking takes place.
    NOTE(review): presumably meant to be switched to typeguard's `typechecked`
    while debugging -- confirm with the author.
    """
    return f

In [None]:
@maybe_typechecked
def ttn(tt: TensorType["batch": ..., 3, 3]) -> TensorType["batch": ..., 9]:
    """Merge the two trailing axes of `tt` into one, row by row.

    Leading (batch) axes are left untouched.
    """
    merge_last_two = "... i j -> ... (i j)"
    return einops.rearrange(tt, merge_last_two)

@maybe_typechecked
def ntt(n: TensorType["batch": ..., 9]) -> TensorType["batch": ..., 3, 3]:
    """Split the trailing axis of length 9 back into a 3x3 grid (inverse of `ttn`).

    Leading (batch) axes are left untouched.
    """
    split_last = "... (i j) -> ... i j"
    return einops.rearrange(n, split_last, i=3, j=3)

In [None]:
@maybe_typechecked
def get_wins(boards : Boards) -> TensorType["batch_size", t.long]:
    """Score, per board, whether the positive side holds a complete line.

    For every row, column and both diagonals the minimum cell value is taken;
    the largest of those minima is returned, clamped below at zero. With cells
    restricted to {-1, 0, 1} this is 1 when the +1 side owns a full line and 0
    otherwise.
    """
    row_best = boards.amin(dim=2).amax(dim=1)
    col_best = boards.amin(dim=1).amax(dim=1)
    main_diag = boards.diagonal(dim1=-2, dim2=-1).amin(dim=-1)
    # Reversing the row order maps the anti-diagonal onto the main diagonal.
    anti_diag = boards.flip(1).diagonal(dim1=-2, dim2=-1).amin(dim=-1)
    straight_best = t.max(row_best, col_best)
    diagonal_best = t.max(main_diag, anti_diag)
    return t.nn.functional.relu(t.max(straight_best, diagonal_best))

@maybe_typechecked
def get_outcomes(boards : Boards) -> TensorType["batch_size", t.long]:
    """Return, per board, the positive side's `get_wins` score minus the negative side's.

    The negative side is scored by negating the board before calling `get_wins`.
    """
    score_positive = get_wins(boards)
    score_negative = get_wins(-boards)
    return score_positive - score_negative

In [None]:
@maybe_typechecked
class TicTacToe(t.nn.Module):
    """Fully connected network mapping a 3x3 board to three scores per cell.

    The board is flattened to 9 inputs, passed through an embedding layer and
    `num_layers` hidden Linear+ReLU blocks of width `hidden_size`, then
    projected to 27 outputs that are reshaped to (batch, 3, 3, 3). Elsewhere in
    the notebook the last axis is trained against `outcome + 1`, and index 2 is
    read out as the win probability.
    """

    def __init__(self, num_layers: int, hidden_size: int):
        super().__init__()
        self.embed = t.nn.Sequential(t.nn.Linear(9, hidden_size), t.nn.ReLU())
        hidden_blocks = []
        for _ in range(num_layers):
            hidden_blocks.append(
                t.nn.Sequential(t.nn.Linear(hidden_size, hidden_size), t.nn.ReLU())
            )
        self.layers = t.nn.Sequential(*hidden_blocks)
        self.unembed = t.nn.Linear(hidden_size, 27)

    def input_device(self):
        """Return the device holding the first linear layer's weights."""
        first_linear = self.embed[0]
        return first_linear.weight.device

    def forward(self, x: Boards) -> TensorType["batch_size", 3, 3, 3, float]:
        """Compute per-cell scores of shape (batch, row, column, 3) for integer boards."""
        flat_cells = ttn(x.to(t.float))
        hidden = self.layers(self.embed(flat_cells))
        scores = self.unembed(hidden)
        return einops.rearrange(scores, 'b (i j o) -> b i j o', i=3, j=3)

# Build a small model on DEVICE, then run it once on an empty board as a smoke
# test; the cell displays the raw (1, 3, 3, 3) output.
model = TicTacToe(num_layers=2, hidden_size=32).to(DEVICE)
model(t.zeros(1, 3, 3, dtype=t.long, device=model.input_device()))

In [None]:
@maybe_typechecked
def get_win_probs(model: TicTacToe, boards: Boards) -> Tuple[TensorType["batch_size", 9, 3], TensorType["batch_size", 9]]:
    """Run `model` on `boards` and return per-cell logits and win probabilities.

    Returns:
        A pair `(logits, win_probs)`: `logits` has shape (batch, 9, 3) with the
        cells flattened row by row, and `win_probs` has shape (batch, 9) and is
        the softmax probability of index 2 for each cell.
    """
    per_cell_logits = einops.rearrange(model(boards), 'b i j o -> b (i j) o')
    outcome_probs = t.softmax(per_cell_logits, dim=-1)
    return per_cell_logits, outcome_probs[:, :, 2]

def show_win_probs(model: TicTacToe, boards: Boards) -> TensorType["batch_size", 3, 3]:
    """Return the model's per-cell win probabilities laid out as 3x3 grids."""
    win_probs = get_win_probs(model, boards)[1]
    return ntt(win_probs)

In [None]:
@maybe_typechecked
def choose_moves(model: TicTacToe, boards: Boards) -> Tuple[TensorType["batch_size", 9, 3], TensorType["batch_size", t.long]]:
    """Sample one move per board, weighted by the model's win probability.

    Occupied cells get weight zero, so only empty cells can be drawn. Every
    board must therefore contain at least one empty cell.

    Returns:
        A pair `(logits, moves)`: the (batch, 9, 3) logits from
        `get_win_probs`, and the sampled flat cell indices of shape (batch,).
    """
    logits, win_probs = get_win_probs(model, boards)
    is_empty = ttn(boards) == 0
    sampling_weights = win_probs * is_empty
    moves = t.multinomial(sampling_weights, num_samples=1)[:, 0]
    return logits, moves

In [None]:
@maybe_typechecked
def make_moves(boards: Boards, moves: TensorType["batch_size", t.long]) -> Boards:
    """Return new boards with +1 added at the flat cell index `moves` of each board.

    The input boards are not modified, and no check is made that the chosen
    cell was empty.
    """
    placed = ntt(t.nn.functional.one_hot(moves, 9))
    return boards + placed

In [None]:
@maybe_typechecked
def play_game(model: TicTacToe) -> List[Boards]:
    """Play one self-play game with `model` and return every board state.

    The game starts from an empty board and continues while a cell is still
    empty and `get_outcomes` reports no winner. After each move the board is
    negated, so every stored state is seen from the side about to move.

    Returns:
        List of (1, 3, 3) long tensors: the empty start position followed by
        the position after each move.
    """
    board = t.zeros(1, 3, 3, dtype=t.long, device=model.input_device())
    game = [board]
    # Pure inference: there is nothing to backpropagate, so skip graph building.
    with t.no_grad():
        while (board == 0).any() and get_outcomes(board) == 0:
            _, moves = choose_moves(model, board)
            # Flip the sign so the next iteration sees the board as the other side.
            board = -make_moves(board, moves)
            game.append(board)
    return game

# Smoke test: play one game with the current model; the cell displays the list of board states.
play_game(model)

In [None]:
@maybe_typechecked
def train(model: TicTacToe, lr: float, batch_size: int, num_rounds: int):
    """Train `model` by self-play with SGD and plot a moving average of the loss.

    `batch_size` games are played in parallel. In each round the model samples
    a move for every board, then a reply is sampled for the other side as a
    one-ply look-ahead. A move is used as a training example when the game is
    over (as reported by `finished`) right after the move or right after the
    reply. Its three logits are trained with cross-entropy against
    `get_outcomes(...) + 1`, scored from the mover's point of view
    (0 = loss, 1 = draw, 2 = win). Finished games are replaced by empty boards.
    The next round continues from the position after the move (not after the
    look-ahead reply), so the model alternates between both sides.

    Args:
        model: network to train in place.
        lr: SGD learning rate.
        batch_size: number of games played in parallel.
        num_rounds: number of rounds, i.e. moves per game slot.
    """
    model.train()
    optimizer = t.optim.SGD(model.parameters(), lr=lr)
    device = model.input_device()
    boards = t.zeros(batch_size, 3, 3, dtype=t.long, device=device)
    rows = t.arange(batch_size, device=device)
    losses = []
    for _ in range(num_rounds):
        logits, moves = choose_moves(model=model, boards=boards)
        assert logits.shape == (batch_size, 9, 3)
        assert moves.shape == (batch_size,)
        # Keep only the three outcome logits of the cell that was actually played.
        logits_of_moves = logits[rows, moves]
        assert logits_of_moves.shape == (batch_size, 3), logits_of_moves.shape

        moved = make_moves(boards=boards, moves=moves)
        assert moved.shape == (batch_size, 3, 3)
        finished_by_me = finished(moved)
        assert finished_by_me.shape == (batch_size,)

        # Negate to switch sides; finished games are zeroed, which restarts
        # them from an empty board.
        after_my_move = -moved * (~finished_by_me)[:, None, None]
        # The reply is only a look-ahead used for labelling, so no graph is needed.
        with t.no_grad():
            _, replies = choose_moves(model=model, boards=after_my_move)
        after_their_reply = -make_moves(boards=after_my_move, moves=replies)
        finished_by_them = (~finished_by_me) & finished(after_their_reply)

        decided_logits = t.cat(
            (logits_of_moves[finished_by_me], logits_of_moves[finished_by_them]),
            dim=0,
        )
        decided_targets = t.cat(
            (
                get_outcomes(moved)[finished_by_me] + 1,
                get_outcomes(after_their_reply)[finished_by_them] + 1,
            ),
            dim=0,
        )
        assert decided_logits.shape == (decided_targets.shape[0], 3)
        boards = after_my_move

        # No game ended this round: cross-entropy over zero examples is NaN,
        # so there is nothing to learn from and the update is skipped.
        if decided_targets.numel() == 0:
            continue

        optimizer.zero_grad()
        loss = t.nn.functional.cross_entropy(input=decided_logits, target=decided_targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # Moving average over up to 20 recorded losses, including the last window.
    window = max(1, min(20, len(losses)))
    smoothed = [
        sum(losses[start : start + window]) / window
        for start in range(len(losses) - window + 1)
    ]
    plt.plot(smoothed)
    plt.xlabel("update step (rounds with at least one finished game)")
    plt.ylabel(f"cross-entropy loss (moving average, window={window})")

In [None]:
# Use the large batch only when the model actually lives on a GPU. Keying on
# `t.cuda.is_available()` would pick the GPU-sized batch for a CPU model
# whenever a GPU merely exists on the machine.
batch_size = 16384 if model.input_device().type == 'cuda' else 1024
train(model, lr=1e-3, batch_size=batch_size, num_rounds=1000)

In [None]:
# Inspect the trained model: per-cell win probabilities for the side to move on
# a board whose centre cell holds a -1 (opponent) mark.
show_win_probs(model, boards=t.tensor([[[0, 0, 0], [0, -1, 0], [0, 0, 0]]], device=model.input_device()))