In [26]:
import numpy as np
import torch
import math
from torch import nn as nn
from torch.nn import functional as F
from torch.optim import AdamW

# Let numpy print up to 300 characters per line (presumably so wide board
# rows are not wrapped when printed).
np.set_printoptions(linewidth=300)

# Seed the torch and numpy global RNGs so runs start from the same state.
torch.manual_seed(0)
np.random.seed(0)

In [27]:
def split_state(state):
  """Split a signed board into two binary planes, one per player.

  Args:
    state: 2-D board (numpy array or nested sequence) in which 1 marks an
      X token and -1 marks an O token; every other value counts as empty.

  Returns:
    Float array of shape (2, rows, cols). Plane 0 holds 1.0 wherever
    `state == 1`, plane 1 holds 1.0 wherever `state == -1`, and both are
    0.0 elsewhere.
  """
  board = np.asarray(state)
  # Two vectorised comparisons replace the per-cell Python loop, and the
  # output follows the board's real shape instead of assuming it is square.
  return np.stack((board == 1, board == -1)).astype(float)

In [28]:
class Environment():
  """Five-in-a-row game played on a square board.

  The board is a float array in which 0 is an empty cell, 1 is an X token
  and -1 is an O token. X moves first.
  """

  def __init__(self, board_len=30):
    """Create an empty `board_len` x `board_len` board with X to move."""

    self.board_len = board_len
    self.board = np.zeros((board_len, board_len))

    self.x_token = 1
    self.o_token = -1

    self.turn = self.x_token # x starts

    # list of (action, token) pairs in the order they were played
    self.move_hist = []

  def reset(self):
    """Return the environment to its initial state (empty board, X to move)."""
    self.board = np.zeros((self.board_len, self.board_len))
    self.turn = self.x_token
    # render() redraws the board from move_hist, so the history has to be
    # cleared as well or tokens from the previous game keep showing up.
    self.move_hist = []
    return

  def step(self, action, override_turn=None):
    """Place a token at `action` = (row, col).

    Args:
      action: indexable pair (row, col) of the target cell.
      override_turn: if given, this token value is placed instead of the
        current player's token, the turn is NOT advanced and None is returned.

    Returns:
      The result of `game_over()` after the move, or None when
      `override_turn` is used.
    """

    if override_turn is not None:
      self.board[action[0]][action[1]] = override_turn
      self.move_hist.append((action, override_turn))
      return

    self.board[action[0]][action[1]] = self.turn
    self.move_hist.append((action, self.turn))
    self.turn *= -1 # turn swaps
    return self.game_over()

  def game_over(self):
    """Report the status of the game.

    Returns:
      1 if X has five in a row, -1 if O has five in a row, 0 if the board
      is full without a winner (draw), and 10 while the game is still open.
    """
    win = np.ones(5)
    win_diag = np.identity(5)
    win_anti_diag = np.rot90(win_diag) # hoisted: does not change per window
    n_win = len(win)

    # Horizontal windows; the same indices on the transposed board give the
    # vertical windows.
    for i in range(len(self.board)):
      for j in range(len(self.board[i]) - n_win + 1):

        similarity = np.sum(win * self.board[i][j : j + n_win])
        similarity_t = np.sum(win * self.board.T[i][j : j + n_win])

        if similarity == 5 or similarity_t == 5:
          return 1
        elif similarity == -5 or similarity_t == -5:
          return -1

    # Both diagonals of every 5x5 window.
    for i in range(len(self.board) - n_win + 1):
      for j in range(len(self.board[i]) - n_win + 1):

        window = self.board[i : i + n_win, j : j + n_win]
        similarity = np.sum(win_diag * window)
        similarity_t = np.sum(win_anti_diag * window)

        if similarity == 5 or similarity_t == 5:
          return 1
        elif similarity == -5 or similarity_t == -5:
          return -1

    # np.any returns a numpy bool, which is never identical to the Python
    # `False` object, so an `is False` comparison can never detect a draw.
    if not np.any(self.board == 0): # draw, VERY IMPROBABLE
      return 0

    return 10


  def render(self):
    """Print the board, rebuilt from `move_hist`, between two separator lines."""
    show_board = np.full((self.board_len, self.board_len), ' ')
    for move in self.move_hist:
      action, turn = move
      if turn == 1:
        show_board[action[0]][action[1]] = 'X'
      else:
        show_board[action[0]][action[1]] = 'O'
    print("="*120)
    print(show_board)
    print("="*120)
    return

In [29]:
def ucb_score(child_prior, parent_visit_count, child_visit_count):
  """Exploration bonus for a child: prior * sqrt(parent visits) / (child visits + 1)."""
  parent_term = math.sqrt(parent_visit_count)
  return child_prior * parent_term / (child_visit_count + 1)

def np_softmax(arr_2d):
  """Softmax taken over every entry of `arr_2d`, returned in the input's shape.

  The maximum is subtracted before exponentiating so the largest exponent is 0.
  """
  values = arr_2d.flatten()
  exps = np.exp(values - values.max())
  return (exps / exps.sum()).reshape(arr_2d.shape)

class Node():
  """A board position in the search tree."""

  def __init__(self, state, turn):
    self.state = state     # board at this position
    self.children = []     # one Edge per empty cell, filled in by expand()
    self.visit_count = 0   # number of find_leaf() calls that reached this node
    self.turn = turn       # token (1 or -1) of the player to move here

  def is_leaf_node(self):
    """Return True while the node has no children."""
    return (len(self.children) == 0)

  def expand(self, model):
    """Create one child Edge per empty cell, using the model's policy as priors.

    Args:
      model: object whose `predict(planes)` returns `(p_vals, value)`;
        `p_vals[0]` is indexed here as a 2-D grid of per-cell priors.

    Returns:
      The `value` returned by the model, or None if the node already has
      children (in which case nothing is changed).
    """

    if not self.is_leaf_node(): # safety so a node doesn't double its children
      return

    p_vals, value = model.predict(split_state(self.state))
    p_vals = p_vals[0]

    for i, row in enumerate(p_vals):
      for j, prior_prob in enumerate(row):
        if self.state[i][j] == 0: # only append available moves
          self.children.append(Edge(prior_prob, (i, j), self.turn))

    return value

  def find_leaf(self, model):
    """Walk down the tree along the best-scoring edges and expand the leaf reached.

    Returns:
      The value produced by the expansion at the end of the walk.
    """

    self.visit_count += 1

    if self.is_leaf_node():
      return self.expand(model)

    # Find the child that maximizes Q + U from the mover's point of view.
    # The running maximum starts at -inf with an integer index: seeding it
    # with 0.0 meant that when no score was strictly positive the index
    # stayed the float 0.0 and `self.children[0.0]` raised a TypeError.
    best_ind, best_val = 0, float("-inf")
    for i, child_node in enumerate(self.children):
      # turn is +/-1, so this equals turn * Q + U: Q is flipped when O is to
      # move so that moves that are good for O score high.
      val = self.turn * (child_node.Q + self.turn * ucb_score(child_node.P, self.visit_count, child_node.N))
      if val > best_val:
        best_ind = i
        best_val = val

    v = self.children[best_ind].traverse(self.state, model)
    return v

class Edge():
  """A move leading from one Node to the position it produces."""

  def __init__(self, p, action, turn):

    self.action = action   # (row, col) of the cell this move fills
    self.turn = turn # turn before action is made
    self.N = 0             # number of traversals through this edge
    self.P = p             # prior supplied when the parent was expanded
    self.W = 0             # sum of the values returned through this edge
    self.Q = 0             # mean value, W / N

    self.node = None       # destination node, created lazily

  def initialize_node(self, state): # destination node doesn't need to be initialized all the time, only if we're actually going to use it
    """Create the destination node by applying this move to a copy of `state`."""
    next_state = np.copy(state)
    next_state[self.action[0]][self.action[1]] = self.turn
    self.node = Node(next_state, self.turn*-1)
    return

  def traverse(self, state, model):
    """Descend through this edge and update its statistics.

    Args:
      state: board of the parent node; used only the first time, to build
        the destination node.
      model: passed through to the destination node's `find_leaf`.

    Returns:
      The value returned by the destination node's `find_leaf`.
    """
    # Build the destination only once. Re-creating it on every traversal
    # replaced the node, discarding the children and visit counts gathered
    # below it, so the search could never go deeper than one move.
    if self.node is None:
      self.initialize_node(state)

    v =  self.node.find_leaf(model)
    self.N += 1
    self.W += v
    self.Q = self.W / self.N

    return v


class MCTS():
  """Tree search driven by a model's policy priors and value estimates."""

  def __init__(self, model, root_state, turn, num_simulations=1600):
    """
    Args:
      model: object with a `predict` method, handed to the nodes for expansion.
      root_state: board the search starts from.
      turn: token (1 or -1) of the player to move at the root.
      num_simulations: number of `find_leaf` descents per `search()` call.
    """
    self.model = model
    self.root = Node(root_state, turn)
    self.turn = turn
    self.num_simulations = num_simulations # can represent "thinking time"

  def search(self): # builds the search tree from the root node
    """Expand the root if needed, then run `num_simulations` descents from it."""
    self.root.expand(self.model)

    for i in range(self.num_simulations):
      self.root.find_leaf(self.model)

    return

  def select_move(self, tau=1.0):
    """Pick the root move with the largest N**(1/tau) and re-root the tree on it.

    Ties go to the first such child.

    Returns:
      The chosen move's `action`.

    Raises:
      ValueError: if the root has no children to choose from.
    """
    if not self.root.children:
      raise ValueError("select_move called on a root without legal moves")

    # Integer index and a start below any possible score (counts are >= 0):
    # seeding with 0.0/0.0 left a float index behind when no child had been
    # visited, and indexing the child list with it raised a TypeError.
    best_ind, best_val = 0, -1.0

    for i, child in enumerate(self.root.children):

      val = child.N**(1.0/tau) # maximize visit count

      if val > best_val:
        best_val = val
        best_ind = i

    # need to update the search tree
    chosen = self.root.children[best_ind]
    if chosen.node is None:
      # An edge that was never traversed has no destination node yet; build
      # it now so the new root is never None.
      chosen.initialize_node(self.root.state)
    self.root = chosen.node
    return chosen.action

  def get_pi(self, tau=1.0):
    """Return a move distribution over the board derived from root visit counts.

    Each root child contributes N**(1/tau) at its cell, cells without a
    child stay 0, and `np_softmax` is applied over the whole grid.
    """
    # NOTE(review): because the softmax runs over every cell, cells without a
    # child still receive non-zero mass — confirm that this is intended.
    move_dist = np.zeros((len(self.root.state), len(self.root.state)))

    for child in self.root.children:
      move_dist[child.action[0]][child.action[1]] = child.N**(1.0/tau)

    move_dist = np_softmax(move_dist)
    return move_dist

In [62]:
def softXEnt (input, target): # temporary
    """Cross-entropy against soft targets, averaged over the first dimension.

    `input` is passed through log-softmax along dim 1, multiplied element-wise
    by `target`, summed, negated and divided by `input.shape[0]`.
    """
    n_rows = input.shape[0]
    log_probs = torch.nn.functional.log_softmax(input, dim=1)
    weighted = target * log_probs
    return -weighted.sum() / n_rows


class Brain(nn.Module):
  """Convolutional network with a policy head and a value head.

  Args:
    input_shape: (channels, height, width) of a single input sample.

  Raises:
    ValueError: if height or width is too small for the value head's
      5x5 / stride-2 convolution.
  """

  def __init__(self, input_shape=(2, 30, 30)):
    super().__init__()

    self.input_shape = input_shape

    use_bias = True
    self.conv1 = nn.Conv2d(input_shape[0], 12, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)
    self.conv2 = nn.Conv2d(12, 24, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)
    self.conv3 = nn.Conv2d(24, 24, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)

    self.pol_conv1 = nn.Conv2d(24, 1, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)

    val_kernel, val_stride = 5, 2
    self.val_conv1 = nn.Conv2d(24, 1, kernel_size=val_kernel, stride=val_stride, bias=use_bias)

    # Size of the flattened value-conv output (no padding, single channel).
    # It was hard-coded to 64, which only matches 19- or 20-wide boards and
    # made forward() fail for the default 30x30 input.
    val_h = (input_shape[1] - val_kernel) // val_stride + 1
    val_w = (input_shape[2] - val_kernel) // val_stride + 1
    if val_h < 1 or val_w < 1:
      raise ValueError(f"input_shape {input_shape} is too small for the value head")

    self.val_linear1 = nn.Linear(val_h * val_w, 50)
    self.val_linear2 = nn.Linear(50, 1)

    self.flatten = nn.Flatten()

  def forward(self, x):
    """Run the network on a batch.

    Args:
      x: tensor of shape (batch, channels, height, width) matching `input_shape`.

    Returns:
      (p, v): `p` has shape (batch, height, width) and is a softmax over
      all cells of each sample; `v` has shape (batch, 1) and is
      tanh-squashed.
    """
    # Core:
    # NOTE(review): there is no non-linearity between these convolutions —
    # confirm that this is intended.
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)

    # Policy Head:
    p = self.pol_conv1(x)
    p = p.view(-1, self.input_shape[1]*self.input_shape[2])
    p = F.softmax(p, dim=1)
    p = p.view(-1, self.input_shape[1], self.input_shape[2])

    # Value Head:
    v = self.val_conv1(x)
    v = self.flatten(v)
    v = self.val_linear1(v)
    v = self.val_linear2(v)
    v = torch.tanh(v)

    return p, v


class ZeroTTT():
  """Self-play agent: a `Brain` network plus its optimizer and training loop."""

  def __init__(self, board_len=30, lr=3e-4, weight_decay=0.0):
    """
    Args:
      board_len: side length of the square board the network is built for.
      lr: learning rate passed to AdamW.
      weight_decay: weight decay passed to AdamW.
    """
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.brain = Brain(input_shape=(2, board_len, board_len)).to(self.device)
    self.board_len = board_len

    self.optimizer = AdamW(self.brain.parameters(), lr=lr, weight_decay=weight_decay)
    self.value_loss = nn.MSELoss()
    self.policy_loss = nn.CrossEntropyLoss()

  def get_parameter_count(self):
    """Return the number of trainable parameters in the network."""
    return sum(p.numel() for p in self.brain.parameters() if p.requires_grad)

  def predict(self, x):
    """Run the network on a numpy input.

    Args:
      x: numpy array of split board planes; a batch axis is prepended when
        it has fewer than 4 dimensions.

    Returns:
      (policy, value) tensors as returned by the network.
    """
    if len(x.shape) < 4: # if doesn't have batch dimension, give it one
      x = np.expand_dims(x, axis=0)

    x = torch.from_numpy(x).float().to(self.device)

    policy, value = self.brain(x)
    return policy, value

  def _learn(self, states, policy_labels, value_labels):
    """Take one optimizer step on the collected self-play positions."""
    print(f"Training on {len(states)} positions...")

    inputs = np.array([split_state(state) for state in states])
    target_pi = torch.from_numpy(np.array(policy_labels)).float().to(self.device)
    target_v = torch.from_numpy(np.array(value_labels)).float().to(self.device)

    prob, val = self.predict(inputs)
    val = val.flatten()

    prob = torch.flatten(prob, 1, 2)
    target_pi = torch.flatten(target_pi, 1, 2)

    # The network's forward() already returns softmax probabilities, so the
    # cross-entropy takes their log directly; running them through another
    # log-softmax would flatten the distribution a second time. The clamp
    # keeps log() finite for vanishing probabilities.
    p_loss = -(target_pi * torch.log(prob.clamp_min(1e-8))).sum() / prob.shape[0]
    v_loss = self.value_loss(val, target_v)

    loss = p_loss + v_loss

    # Clear gradients left from the previous step; otherwise they accumulate.
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

  def train(self, n_games=1, num_simulations=10, batch_size=128 ,render=False):
    """Play `n_games` of self-play and train on the positions gathered.

    Args:
      n_games: number of self-play games.
      num_simulations: search descents per move.
      batch_size: minimum number of stored positions before a training
        step is taken (checked at the end of each game).
      render: if truthy, print the board after every move.
    """

    # TODO:
    # - Implement proper sampling every couple of games rather than learning after each game

    states = []
    policy_labels = []
    value_labels = []

    for i in range(n_games):

      print(f"Game {i+1}...")

      # Every game gets its own environment and search tree. Keeping one
      # tree across games left its root at the previous game's final
      # position while the board itself had been reset.
      env = Environment(board_len=self.board_len)
      mcts = MCTS(self, np.copy(env.board), env.turn, num_simulations=num_simulations)

      outcome = env.game_over()
      while outcome == 10:
        # The search only reads the network's outputs, so skip autograd
        # bookkeeping here.
        with torch.no_grad():
          mcts.search()
        states.append(np.copy(env.board))
        policy_labels.append(mcts.get_pi())
        move = mcts.select_move()
        outcome = env.step(move) # step() returns game_over() after the move

        if render:
          env.render()

      # Label every position of this game with the final result.
      value_labels += [outcome] * (len(policy_labels) - len(value_labels))

      if len(states) >= batch_size: # learn
        self._learn(states, policy_labels, value_labels)

        states = []
        policy_labels = []
        value_labels = []

In [63]:
# Build an agent for a 20x20 board, then run 20 self-play games with 50
# search simulations per move, printing the board after each move and
# training once at least 100 positions have been collected.
model = ZeroTTT(board_len=20)
model.train(n_games=20, num_simulations=50, render=True, batch_size=100)

# TODO:
# After env.reset(), tokens from earlier games still show up on the rendered board; investigate.

Game 1...
[['X' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '

KeyboardInterrupt: ignored