In [None]:
import torch
import torch.nn as nn
from copy import deepcopy


class TicTacToe:
    """A batch of independent tic-tac-toe games stored as tensors.

    Cell values: ``0`` is empty; ``1`` and ``-1`` are the two players' marks
    (``current_player`` starts at ``1`` and is negated by ``change_turn``).

    State created by ``reset``:
        boards: ``(num_boards, board_size, board_size)`` float tensor.
        game_over: ``(num_boards,)`` float tensor, set to ``1`` once a
            completed line has been detected on that board.
        current_player: ``(num_boards,)`` float tensor holding ``1`` or ``-1``.
    """

    def __init__(self, num_boards=128, board_size=3):
        self.reset(num_boards, board_size)

    def reset(self, num_boards, board_size, mask=None):
        """Start fresh games.

        With ``mask=None`` every state tensor is re-allocated with the given
        sizes. Otherwise only the games selected by ``mask`` are cleared in
        place, and ``num_boards`` / ``board_size`` are ignored.
        """
        if mask is None:
            self.boards = torch.zeros(num_boards, board_size, board_size)
            self.game_over = torch.zeros(num_boards)
            self.current_player = torch.ones(num_boards)
        else:
            self.boards[mask] = 0
            self.game_over[mask] = 0
            self.current_player[mask] = 1

    def is_valid_move(self, batch, row, col):
        """Return a boolean tensor: True where cell ``(batch, row, col)`` is empty."""
        return self.boards[batch, row, col] == 0

    def make_move(self, batch, row, col):
        """Place each board's current player mark on the requested cells.

        Args:
            batch: index tensor selecting the boards to play on.
            row: index tensor of row positions, aligned with ``batch``.
            col: index tensor of column positions, aligned with ``batch``.

        Moves that target an occupied cell are dropped silently. After the
        marks are written, ``game_over`` is set for every board that
        ``check_wins`` reports as won.

        Returns:
            Always ``True``.
        """
        valid_moves = self.is_valid_move(batch, row, col)

        batch = batch[valid_moves]
        row = row[valid_moves]
        col = col[valid_moves]

        # Look the player up per selected board so the written values stay
        # aligned with the moves that survived the validity filter.
        self.boards[batch, row, col] = self.current_player[batch]
        wins = self.check_wins()
        self.game_over[wins] = 1
        return True

    def change_turn(self, mask):
        """Hand the move to the other player on the boards selected by ``mask``."""
        self.current_player[mask] = self.current_player[mask] * -1

    def get_empty_positions(self):
        """Return, for each board, the ``(rows, cols)`` index tensors of its empty cells."""
        moves_per_game = [torch.where(board == 0) for board in self.boards]
        return moves_per_game

    def check_wins(self):
        """Return a ``(num_boards,)`` boolean tensor marking boards with a completed line.

        A line is a full row, full column or either diagonal filled with one
        player's mark; both marks (``1`` and ``-1``) are checked.
        """
        boards = self.boards
        board_h = boards.shape[-2]
        idx = torch.arange(board_h, device=boards.device)

        has_won = torch.zeros(boards.shape[0], dtype=torch.bool, device=boards.device)
        for player in (1, -1):
            marks = boards == player
            # Element-wise OR: Python's `or` cannot combine batched tensors.
            has_won = has_won | marks.all(dim=-1).any(dim=-1)  # rows
            has_won = has_won | marks.all(dim=-2).any(dim=-1)  # columns
            has_won = has_won | marks[:, idx, idx].all(dim=-1)  # main diagonal
            has_won = has_won | marks[:, board_h - 1 - idx, idx].all(dim=-1)  # anti-diagonal

        return has_won


class AdaNorm(nn.Module):
    """Layer normalisation whose scale and shift are predicted from a conditioning embedding.

    Args:
        dim: size of the normalised (last) dimension; also the expected size
            of the conditioning embedding's last dimension.
        eps: numerical-stability term forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, dim, eps=1e-5):
        super(AdaNorm, self).__init__()
        self.eps = eps
        # One projection yields both modulation terms; it is split in forward().
        self.proj = nn.Linear(dim, dim*2)
        # No learned affine here: scale and shift come from the embedding.
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)

    def forward(self, x, emb):
        """Normalise ``x`` and modulate it with scale/shift derived from ``emb``.

        ``emb`` must have last dimension ``dim`` and be broadcastable against
        ``x`` after projection.
        """
        scale, shift = self.proj(torch.nn.functional.silu(emb)).chunk(2, dim=-1)
        return self.norm(x) * scale + shift


class ResnetBlock(nn.Module):
    """Pre-norm residual MLP block: ``x + Dropout(Linear(act(Linear(LayerNorm(x)))))``.

    Args:
        dim: input and output feature size.
        mid_dim: hidden feature size of the two-layer MLP.
        dropout: dropout probability applied to the MLP output before the
            residual addition (``0.0`` disables it).
        act: activation module class, instantiated with no arguments.
        bias: whether the two linear layers carry a bias term.
    """

    def __init__(self, dim, mid_dim, dropout=0.0, act=nn.GELU, bias=False):
        super(ResnetBlock, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.net = nn.Sequential(
            nn.Linear(dim, mid_dim, bias=bias),
            act(),
            nn.Linear(mid_dim, dim, bias=bias),
            # Appended last so the indices (and state_dict keys) of the
            # layers above are unchanged; Dropout has no parameters.
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return ``x`` plus the MLP applied to the layer-normalised ``x``."""
        return x + self.net(self.norm(x))

class Agent(nn.Module):
    """MLP agent with a shared residual trunk and separate action-value and state-value heads.

    Args:
        dropout: dropout probability handed to every ``ResnetBlock``.
        num_layers: number of residual blocks in the shared trunk.
        dim: hidden feature size.
        board_size: side length of the square board; the input and the
            action-value output both have ``board_size * board_size`` features.
    """

    def __init__(self,
                dropout=0.0, 
                num_layers=6,
                dim=128,
                board_size=3,
                ):
        super().__init__()
        self.turn_embed = nn.Embedding(2, dim)
        self.in_norm = AdaNorm(dim)
        self.proj_in = nn.Linear(board_size * board_size, dim)

        self.stem = nn.ModuleList([ResnetBlock(dim, dim*2, dropout) for _ in range(num_layers)])

        self.q_layers = nn.ModuleList([ResnetBlock(dim, dim*2, dropout) for _ in range(1)])
        self.q_norm_out = nn.LayerNorm(dim)
        self.q_proj_out = nn.Linear(dim, board_size * board_size)

        self.v_layers = nn.ModuleList([ResnetBlock(dim, dim*2, dropout) for _ in range(1)])
        self.v_norm_out = nn.LayerNorm(dim)
        self.v_proj_out = nn.Linear(dim, 1)


    def forward(self, x, turn):
        """Compute both heads for a batch of boards.

        Args:
            x: float tensor whose last dimension is ``board_size * board_size``
                (the input size of ``proj_in``).
                NOTE(review): presumably the flattened board — confirm with the caller.
            turn: integer index tensor with values 0 or 1 (indices into
                ``turn_embed``), broadcastable against ``x``'s leading dimensions.

        Returns:
            ``(q, v)``: ``q`` has last dimension ``board_size * board_size``
            and ``v`` has last dimension 1.
        """
        x = self.proj_in(x)
        x = self.in_norm(x, self.turn_embed(turn))
        for layer in self.stem:
            x = layer(x)

        # Action-value head: one output per board cell.
        q = x
        for layer in self.q_layers:
            q = layer(q)
        q = self.q_norm_out(q)
        q = self.q_proj_out(q)

        # State-value head: a single scalar per board.
        v = x
        for layer in self.v_layers:
            v = layer(v)
        v = self.v_norm_out(v)
        v = self.v_proj_out(v)

        return q, v






In [None]:
import torch
import torch.nn as nn
from copy import deepcopy


class TicTacToe:
    """A batch of independent tic-tac-toe games stored as tensors.

    Cell values: ``0`` is empty; ``1`` and ``-1`` are the two players' marks
    (``current_player`` starts at ``1`` and is negated by ``change_turn``).

    State created by ``reset``:
        boards: ``(num_boards, board_size, board_size)`` float tensor.
        game_over: ``(num_boards,)`` float tensor, set to ``1`` once a
            completed line has been detected on that board.
        current_player: ``(num_boards,)`` float tensor holding ``1`` or ``-1``.
    """

    def __init__(self, num_boards=128, board_size=3):
        self.reset(num_boards, board_size)

    def reset(self, num_boards, board_size, mask=None):
        """Start fresh games.

        With ``mask=None`` every state tensor is re-allocated with the given
        sizes. Otherwise only the games selected by ``mask`` are cleared in
        place, and ``num_boards`` / ``board_size`` are ignored.
        """
        if mask is None:
            self.boards = torch.zeros(num_boards, board_size, board_size)
            self.game_over = torch.zeros(num_boards)
            self.current_player = torch.ones(num_boards)
        else:
            self.boards[mask] = 0
            self.game_over[mask] = 0
            self.current_player[mask] = 1

    def is_valid_move(self, batch, row, col):
        """Return a boolean tensor: True where cell ``(batch, row, col)`` is empty."""
        return self.boards[batch, row, col] == 0

    def make_move(self, batch, row, col):
        """Place each board's current player mark on the requested cells.

        Args:
            batch: index tensor selecting the boards to play on.
            row: index tensor of row positions, aligned with ``batch``.
            col: index tensor of column positions, aligned with ``batch``.

        Moves that target an occupied cell are dropped silently. After the
        marks are written, ``game_over`` is set for every board that
        ``check_wins`` reports as won.

        Returns:
            Always ``True``.
        """
        valid_moves = self.is_valid_move(batch, row, col)

        batch = batch[valid_moves]
        row = row[valid_moves]
        col = col[valid_moves]

        # Look the player up per selected board so the written values stay
        # aligned with the moves that survived the validity filter.
        self.boards[batch, row, col] = self.current_player[batch]
        wins = self.check_wins()
        self.game_over[wins] = 1
        return True

    def change_turn(self, mask):
        """Hand the move to the other player on the boards selected by ``mask``."""
        self.current_player[mask] = self.current_player[mask] * -1

    def get_empty_positions(self):
        """Return, for each board, the ``(rows, cols)`` index tensors of its empty cells."""
        moves_per_game = [torch.where(board == 0) for board in self.boards]
        return moves_per_game

    def check_wins(self):
        """Return a ``(num_boards,)`` boolean tensor marking boards with a completed line.

        A line is a full row, full column or either diagonal filled with one
        player's mark; both marks (``1`` and ``-1``) are checked.
        """
        boards = self.boards
        board_h = boards.shape[-2]
        idx = torch.arange(board_h, device=boards.device)

        has_won = torch.zeros(boards.shape[0], dtype=torch.bool, device=boards.device)
        for player in (1, -1):
            marks = boards == player
            # Element-wise OR: Python's `or` cannot combine batched tensors.
            has_won = has_won | marks.all(dim=-1).any(dim=-1)  # rows
            has_won = has_won | marks.all(dim=-2).any(dim=-1)  # columns
            has_won = has_won | marks[:, idx, idx].all(dim=-1)  # main diagonal
            has_won = has_won | marks[:, board_h - 1 - idx, idx].all(dim=-1)  # anti-diagonal

        return has_won


class AdaNorm(nn.Module):
    """Layer normalisation whose scale and shift are predicted from a conditioning embedding.

    Args:
        dim: size of the normalised (last) dimension; also the expected size
            of the conditioning embedding's last dimension.
        eps: numerical-stability term forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, dim, eps=1e-5):
        super(AdaNorm, self).__init__()
        self.eps = eps
        # One projection yields both modulation terms; it is split in forward().
        self.proj = nn.Linear(dim, dim*2)
        # No learned affine here: scale and shift come from the embedding.
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)

    def forward(self, x, emb):
        """Normalise ``x`` and modulate it with scale/shift derived from ``emb``.

        ``emb`` must have last dimension ``dim`` and be broadcastable against
        ``x`` after projection.
        """
        scale, shift = self.proj(torch.nn.functional.silu(emb)).chunk(2, dim=-1)
        return self.norm(x) * scale + shift



class Agent(nn.Module):
    """Attention-based agent over board cells, conditioned on whose turn it is.

    Every board cell is embedded as a token, a learned positional embedding is
    added, and a stack of attention / feed-forward sub-layers with ``AdaNorm``
    conditioning on the turn embedding is applied.

    NOTE(review): ``Attention`` and ``FeedForward`` are not defined in the
    visible part of this file. The head modules used by ``forward``
    (``policy_pooler``, ``policy_ff``, ``policy_norm_out``, ``policy_proj_out``,
    ``pred_ff``, ``pred_norm_out``, ``pred_proj_out``) are not created by this
    constructor either; they have to be provided before ``forward`` can run.

    Args:
        dropout: dropout probability passed to the attention and feed-forward blocks.
        num_layers: number of attention + feed-forward layer pairs.
        dim: token embedding size.
        board_size: side length of the square board; fixes the number of
            positional embeddings (``board_size * board_size``).
        heads: number of attention heads passed to ``Attention``.
            NOTE(review): placeholder default — confirm.
        ff_mult: hidden-size multiplier passed to ``FeedForward``.
            NOTE(review): placeholder default — confirm.
    """

    def __init__(self,
                dropout=0.0, 
                num_layers=6,
                dim=128,
                board_size=3,
                heads=4,
                ff_mult=4,
                ):
        super().__init__()
        self.gradient_checkpointing = False
        # Three cell states (the training loop marks cells with turn + 1, i.e. 1 or 2).
        self.board_embed = nn.Embedding(3, dim)
        self.turn_embed = nn.Embedding(2, dim)
        self.in_norm = AdaNorm(dim)

        self.attns = nn.ModuleList([Attention(dim, dim, heads, dropout) for _ in range(num_layers)])
        self.ffs = nn.ModuleList([FeedForward(dim, dim*ff_mult, dropout) for _ in range(num_layers)])
        self.attn_norms = nn.ModuleList([AdaNorm(dim) for _ in range(num_layers)])
        self.ff_norms = nn.ModuleList([AdaNorm(dim) for _ in range(num_layers)])

        self.pos_emb = nn.Parameter(torch.randn(1, board_size * board_size, dim))

    def _maybe_checkpoint(self, fn, *args):
        """Call ``fn(*args)``, through activation checkpointing when it is enabled."""
        if self.gradient_checkpointing:
            # Imported here so the submodule is loaded regardless of what
            # `import torch` happens to pull in.
            from torch.utils.checkpoint import checkpoint
            return checkpoint(fn, *args)
        return fn(*args)

    def forward(self, board_bhw, turn):
        """Run the trunk and both heads.

        Args:
            board_bhw: board tensor whose first dimension is the batch; all
                remaining dimensions are flattened into the cell axis. Values
                are cast to integer indices into ``board_embed`` (0, 1 or 2).
            turn: integer index tensor with values 0 or 1.
                NOTE(review): assumed to have shape ``(batch,)``, as produced
                by ``train`` — confirm for other callers.

        Returns:
            ``(policy, pred)`` as produced by the policy and prediction heads.
        """
        board = board_bhw.reshape(board_bhw.shape[0], -1).long()
        # (batch, dim) -> (batch, 1, dim) so the conditioning broadcasts over the cell axis.
        turn_embed = self.turn_embed(turn).unsqueeze(1)
        board_embed = self.board_embed(board) + self.pos_emb
        x = self.in_norm(board_embed, turn_embed)
        for attn, ff, attn_norm, ff_norm in zip(self.attns, self.ffs, self.attn_norms, self.ff_norms):
            x = x + self._maybe_checkpoint(attn, attn_norm(x, turn_embed), None)
            x = x + self._maybe_checkpoint(ff, ff_norm(x, turn_embed))

        policy = self.policy_pooler(x)
        policy = self._maybe_checkpoint(self.policy_ff, policy)
        policy = self.policy_norm_out(policy)
        policy = self.policy_proj_out(policy)

        # The value head is currently disabled.
        # value = self.value_pooler(x)
        # value = self.value_ff(value)
        # value = self.value_norm_out(value)
        # value = self.value_proj_out(value)

        pred = self._maybe_checkpoint(self.pred_ff, x)
        pred = self.pred_norm_out(pred)
        pred = self.pred_proj_out(pred)

        return policy, pred

    def step(self, board_bhw, turn):
        """Sample one move per board from the masked policy.

        Returns:
            ``(policy_sample, policy_probs, pred_probs, log_prob)``: the sampled
            action indices, the masked policy distribution, the softmax of the
            prediction head, and the log-probability of the sampled actions.
        """
        policy, pred = self(board_bhw, turn)
        # Push the logits of occupied cells far down *before* the softmax so
        # they receive (near-)zero probability.
        # NOTE(review): assumes the policy logits have one entry per board cell — confirm.
        occupied = (board_bhw != 0).reshape(policy.shape)
        policy = policy + occupied.float() * -1e4
        policy_probs = torch.nn.functional.softmax(policy, dim=-1)
        pred_probs = torch.nn.functional.softmax(pred, dim=-1)

        # Categorical's first positional argument is `probs`, so it must be
        # given the normalised distribution rather than raw logits.
        dist = torch.distributions.Categorical(probs=policy_probs)
        policy_sample = dist.sample()
        log_prob = dist.log_prob(policy_sample)

        return policy_sample, policy_probs, pred_probs, log_prob


def calculate_returns(rewards, gamma=0.99):
    """Compute the discounted return for every step of an episode.

    Args:
        rewards: sequence of per-step scalar rewards in chronological order.
        gamma: discount factor applied to future rewards.

    Returns:
        A 1-D tensor where element ``t`` is
        ``rewards[t] + gamma * rewards[t+1] + gamma**2 * rewards[t+2] + ...``.
    """
    returns = []
    G = 0
    # Accumulate from the last step backwards, appending in O(1), then flip
    # once; inserting at the front of the list would be quadratic.
    for reward in reversed(rewards):
        G = reward + gamma * G
        returns.append(G)
    returns.reverse()
    return torch.tensor(returns)

def update_policy(log_probs, returns, optimizer=None):
    """Take one REINFORCE-style gradient step.

    The loss is the sum over steps of ``-log_prob * G``.

    Args:
        log_probs: iterable of log-probability tensors, one per step; they
            must be attached to the autograd graph of the optimised parameters.
        returns: iterable of returns aligned with ``log_probs``.
        optimizer: optimiser to step. When omitted, a module-level variable
            named ``optimizer`` is used if one exists.

    Raises:
        ValueError: if no optimiser is passed and no module-level
            ``optimizer`` is defined.
    """
    if optimizer is None:
        # Backward-compatible fallback for notebooks that keep a global optimiser.
        optimizer = globals().get("optimizer")
        if optimizer is None:
            raise ValueError("update_policy needs an optimizer: pass one explicitly")

    # Flatten each term so scalar (0-d) log-probs can be concatenated as well.
    policy_loss = torch.cat(
        [(-log_prob * G).reshape(-1) for log_prob, G in zip(log_probs, returns)]
    ).sum()

    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    
def train(
    lr=1e-3,
    batch_size = 32,
    steps = 1000,
):
    """Play batches of games with a freshly created ``Agent`` under autocast.

    Each step creates ``batch_size`` boards, picks a random starting side and
    lets the model move until some board is won or every cell has been played.

    NOTE(review): no reward is computed and the optimiser is never stepped, so
    the model's parameters are not updated yet. Both sides are played by the
    same model. ``create_boards``, ``mark_boards`` and ``check_wins`` are not
    defined in the visible part of this file.

    Args:
        lr: learning rate of the Adam optimiser.
        batch_size: number of boards played in parallel.
        steps: number of batches to play.

    Returns:
        The model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Agent().to(device)
    model.gradient_checkpointing = True
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # CPU autocast is built around bfloat16; float16 is the CUDA choice.
    amp_dtype = torch.float16 if device.type == 'cuda' else torch.bfloat16

    # torch.autocast takes the device *type* string via `device_type`.
    with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=True):
        for _step in range(steps):
            optimizer.zero_grad()
            boards = create_boards(batch_size).requires_grad_(False).to(device)
            turn = torch.randint(0, 2, (batch_size,)).requires_grad_(False).to(device)
            # A board holds at most one move per cell, so bounding the loop
            # guarantees termination when a game ends without a winner.
            for _move in range(boards[0].numel()):
                policy_sample, policy_probs, pred_probs, log_prob = model.step(boards, turn)
                boards = mark_boards(boards, policy_sample, turn+1)
                has_won = check_wins(boards)
                if has_won.any():
                    break
                turn = 1 - turn

    return model
                
            





