In [1]:
import os
import numpy as np
import torch
import math
from torch import nn as nn
from torch.nn import functional as F
from torch.optim import AdamW
import random

# Wide print width so a whole board row fits on a single output line.
np.set_printoptions(linewidth=300)

# Seed every RNG the notebook draws from so a fresh run is repeatable.
# random.seed stays last: it returns None, so the cell displays nothing.
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

In [2]:
# Mount Google Drive under /content/drive/ so the project folder used by the
# following cells is reachable; force_remount=True re-mounts even if a mount
# already exists.
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [3]:
cd 'drive/MyDrive/Colab_Notebook/Machine_Learning/Projects/TTT'

/content/drive/MyDrive/Colab_Notebook/Machine_Learning/Projects/TTT


In [4]:
def challenge_AI(model_name, model_opt, num_simulations, player_token, board_len=10, render=True):
  """Play an interactive game against a saved model on the console.

  Args:
    model_name: checkpoint file name of the network (passed to ZeroTTT as brain_path).
    model_opt: checkpoint file name of the optimizer state (may be None).
    num_simulations: MCTS simulations the AI runs before each of its moves.
    player_token: 1 to play X (moves first) or -1 to play O.
    board_len: side length of the square board.
    render: when truthy, print the board after every move.

  Human moves are typed as "row,col". A malformed, out-of-range or occupied
  selection is rejected and asked for again, instead of silently letting the
  search tree's fallback move be played on the human's behalf.
  """

  env = Environment(board_len)
  model = ZeroTTT(brain_path=model_name, opt_path=model_opt, board_len=board_len)

  mcts = MCTS(model, env.board, num_simulations=num_simulations)

  # Also stop on a full board so the prompt can never loop without a legal move.
  while env.game_over() == 10 and np.any(env.board == 0):

    if env.turn == player_token:
      player_move = None
      while player_move is None:
        raw_move = input("Input Move: ")
        try:
          row, col = (int(part) for part in raw_move.split(','))
        except ValueError:
          print("Invalid selection, enter the move as row,col")
          continue
        on_board = 0 <= row < board_len and 0 <= col < board_len
        if not on_board or env.board[row][col] != 0:
          print("Invalid selection, that cell is off the board or already taken")
          continue
        player_move = (row, col)

      # Advance the AI's search tree along the human's move so it stays in sync.
      mcts.select_move(external_move=player_move, tau=0.0)
      env.step(player_move)
    else:
      mcts.search()
      move = mcts.select_move(tau=0.0)
      env.step(move)

    if render:
      env.render()

In [5]:
def split_state(state):
  """Split a signed board into two binary planes.

  Plane 0 marks the cells equal to 1 and plane 1 the cells equal to -1;
  every other cell is 0 in both planes. The result is a float array of shape
  (2, len(state), len(state)).
  """
  board = np.asarray(state)
  side = len(state)
  planes = np.zeros((2, side, side))
  planes[0][board == 1] = 1
  planes[1][board == -1] = 1
  return planes

In [6]:
class Environment():
  """Five-in-a-row game on a square board.

  The board is a float array where 1 is an X stone, -1 an O stone and 0 an
  empty cell. X always moves first.
  """

  def __init__(self, board_len=30):

    self.board_len = board_len
    self.board = np.zeros((board_len, board_len))

    self.x_token = 1
    self.o_token = -1

    self.turn = self.x_token # x starts

    # Chronological list of (action, token) pairs; render() replays it.
    self.move_hist = []

  def reset(self):
    """Start a new game: fresh board (a new array), empty history, X to move."""
    self.board = np.zeros((self.board_len, self.board_len))
    self.move_hist = []
    self.turn = self.x_token
    return

  def step(self, action, override_turn=None):
    """Place a stone at action = (row, col).

    With override_turn set, that token is placed without changing whose turn
    it is and None is returned. Otherwise the current player's token is
    placed, the turn is swapped and the result of game_over() is returned.
    The cell is not checked for being empty.
    """

    if override_turn is not None:
      self.board[action[0]][action[1]] = override_turn
      self.move_hist.append((action, override_turn))
      return

    self.board[action[0]][action[1]] = self.turn
    self.move_hist.append((action, self.turn))
    self.turn *= -1 # turn swaps
    return self.game_over()

  def game_over(self):
    """Return 1 if X has five in a row, -1 if O has, 0 for a draw, 10 if still running."""
    win = np.ones(5)
    win_diag = np.identity(5)

    # Horizontal windows (board) and vertical windows (board.T).
    for i in range(len(self.board)):
      for j in range(len(self.board[i]) - len(win) + 1):

        similarity = np.sum(win * self.board[i][j : j + len(win)])
        similarity_t = np.sum(win * self.board.T[i][j : j + len(win)])

        if similarity == 5 or similarity_t == 5:
          return 1
        elif similarity == -5 or similarity_t == -5:
          return -1

    # Both diagonals of every 5x5 window.
    for i in range(len(self.board) - len(win) + 1):
      for j in range(len(self.board[i]) - len(win) + 1):

        window = self.board[i : i+len(win_diag), j : j +len(win_diag)]
        similarity = np.sum(win_diag * window)
        similarity_t = np.sum(np.rot90(win_diag) * window)

        if similarity == 5 or similarity_t == 5:
          return 1
        elif similarity == -5 or similarity_t == -5:
          return -1

    # np.any returns a numpy bool, which is never the `False` singleton, so an
    # `is False` comparison can never succeed; test truthiness instead.
    if not np.any(self.board == 0): # draw, VERY IMPROBABLE
      return 0

    return 10


  def render(self):
    """Print the board with 'X' / 'O' for the moves recorded in move_hist."""
    show_board = np.full((self.board_len, self.board_len), ' ')
    for move in self.move_hist:
      action, turn = move
      if turn == 1:
        show_board[action[0]][action[1]] = 'X'
      else:
        show_board[action[0]][action[1]] = 'O'
    print("="*120)
    print(show_board)
    print("="*120)
    return

In [7]:
def ucb_score(child_prior, parent_visit_count, child_visit_count):
  """Exploration bonus U for one edge: prior * sqrt(N_parent) / (1 + N_child)."""
  # Alternative exploration rate that was tried:
  # math.log((parent_visit_count + 19652 + 1)/19652) + 1.25
  exploration = math.sqrt(parent_visit_count) / (1 + child_visit_count)
  return exploration * child_prior

def np_softmax(arr_2d, dim=2):
  """Numerically stable softmax over ALL elements of an array.

  Args:
    arr_2d: array-like of any shape (non-empty).
    dim: kept for backward compatibility only. The computation does not
      depend on the number of dimensions, so any value is accepted
      (previously anything other than 1 or 2 hit an unbound local).

  Returns:
    Array of the same shape as the input whose elements sum to 1.
  """
  arr = np.asarray(arr_2d)
  flat = arr.ravel()

  # Subtract the maximum before exponentiating to avoid overflow.
  e_x = np.exp(flat - np.max(flat))
  e_x = e_x / e_x.sum()
  return e_x.reshape(arr.shape)

class Node():
  """A position in the search tree, stored from the perspective of the player to move.

  Attributes:
    state: board array for this position.
    children: list of Edge objects, one per empty cell (empty until expanded).
    visit_count: number of times find_leaf passed through this node.
    alpha: mixing weight of the Dirichlet noise added to the priors at the root.
  """

  def __init__(self, state, dir_alpha=0.25):
    self.state = state
    self.children = []
    self.visit_count = 0
    self.alpha = dir_alpha

  def is_leaf_node(self):
    """True while the node has no child edges."""
    return len(self.children) == 0

  def expand(self, model, is_root=False):
    """Create one child edge per empty cell using the model's policy priors.

    Returns None if the node is already expanded. Otherwise returns the
    negated value estimate as a plain Python float.
    """

    if not self.is_leaf_node(): # safety so a node doesn't double its children
      return

    p_vals, value = model.predict(split_state(self.state))
    p_vals = p_vals[0].cpu().detach().numpy()
    # Keep only the number: storing the tensor itself in Edge.W / Edge.Q would
    # keep the autograd graph of every forward pass alive for the whole search.
    value = float(value)

    # Add dirichlet noise if root state:
    if is_root:
      noise = np.random.dirichlet([0.03]*len(self.state)*len(self.state)).reshape((self.state.shape))
      p_vals = (1 - self.alpha) * p_vals + self.alpha * noise

    for i, row in enumerate(p_vals):
      for j, prior_prob in enumerate(row):
        if self.state[i][j] == 0: # only append available moves
          self.children.append(Edge(prior_prob, (i, j)))

    # Quick experiment: randomize the order so ties are not always broken
    # in favor of the same cell.
    random.shuffle(self.children)

    return (-1.0)*value # negative value because this is evaluated from the position of the opposite player

  def find_leaf(self, model):
    """Run one simulation through this node and return its value for the parent."""

    self.visit_count += 1

    if self.is_leaf_node():
      return self.expand(model)

    # Pick the child edge that maximizes Q + U. max() starts from the first
    # child rather than from a fixed 0.0 threshold, so the best edge is still
    # found when every score is negative; ties keep the earliest child.
    best_edge = max(
      self.children,
      key=lambda edge: edge.Q + ucb_score(edge.P, self.visit_count, edge.N),
    )

    v = best_edge.traverse(self.state, model)
    return (-1.0)*v

class Edge():
  """A move from a parent node, carrying the search statistics for that move.

  Attributes:
    action: (row, col) of the move.
    N: visit count.
    P: prior probability of the move.
    W: sum of the values backed up through this edge.
    Q: mean value, W / N.
    node: destination Node, created lazily on first use.
  """

  def __init__(self, p, action):

    self.action = action
    self.N = 0
    self.P = p
    self.W = 0
    self.Q = 0

    self.node = None

  def initialize_node(self, state): # destination node doesn't need to be initialized all the time, only if we're actually going to use it
    """Create the destination node by playing this edge's action on a copy of state."""
    next_state = np.copy(state)
    next_state[self.action[0]][self.action[1]] = 1
    self.node = Node(next_state * (-1)) # multiply state by -1 to swap to opposite perspective
    return

  def traverse(self, state, model):
    """Descend through this edge, back up the returned value and return it."""
    # Create the destination node only once. Re-creating it on every visit
    # would throw away the subtree built below it, so the search could never
    # look more than one move past the root's children.
    if self.node is None:
      self.initialize_node(state)

    v = self.node.find_leaf(model)

    self.N += 1
    self.W += v
    self.Q = self.W / self.N

    return v


class MCTS():
  """Monte Carlo tree search driver that keeps its tree between moves.

  Attributes:
    model: object exposing predict(); used to expand nodes.
    root: Node for the current position.
    num_simulations: simulations run by each search() call ("thinking time").
  """

  def __init__(self, model, root_state, num_simulations=1600):
    self.model = model
    self.root = Node(root_state)
    self.root.expand(self.model, is_root=True)
    self.num_simulations = num_simulations # represents "thinking time"

  def search(self): # builds the search tree from the root node
    """Run num_simulations simulations from the current root."""
    # No-op when the root is already expanded.
    self.root.expand(self.model, is_root=True)

    for _ in range(self.num_simulations):
      self.root.find_leaf(self.model)

    return

  def select_move(self, tau=0.0, external_move=None):
    """Pick a move, make its destination node the new root and return the action.

    Args:
      tau: 0.0 plays the most visited move; any other value samples a move
        from the softmax of the visit counts.
      external_move: (row, col) chosen outside the tree (e.g. by the opponent).
        When given, no selection happens and the tree just follows that move.
        If the move is not among the root's children the first child is
        followed instead, so callers can detect the mismatch by comparing
        the returned action with the move they passed in.

    Raises:
      IndexError / ValueError if the root has no legal moves at all.
    """

    if external_move is None:
      visit_counts = np.array([child.N for child in self.root.children])

      # NOTE(review): softmax over raw visit counts is far more peaked than a
      # plain normalization (counts / counts.sum()) - confirm this is intended.
      probas = np_softmax(visit_counts, dim=1)

      if tau == 0.0:
        max_ind = int(np.argmax(probas))
      else:
        # Draw a scalar directly; int() on a size-1 array is deprecated in NumPy.
        max_ind = int(np.random.choice(len(probas), p=probas))

    else:
      # The root may be a node that was never visited by a search (and so has
      # no edges yet); expand it so the external move can be located.
      if self.root.is_leaf_node():
        self.root.expand(self.model, is_root=True)

      wanted = tuple(external_move) # accept lists as well as tuples
      max_ind = 0
      for i, child in enumerate(self.root.children):
        if tuple(child.action) == wanted:
          max_ind = i
          break

    # need to update the search tree
    chosen = self.root.children[max_ind]
    if chosen.node is None:
      chosen.initialize_node(self.root.state)
    self.root = chosen.node
    return chosen.action

  def get_pi(self, tau=1.0, as_prob=True):
    """Return the root's visit counts laid out on the board.

    With as_prob truthy the counts are passed through np_softmax over the
    whole board; otherwise the raw counts are returned. tau is currently
    unused.
    """
    move_dist = np.zeros((len(self.root.state), len(self.root.state)))

    for child in self.root.children:
      move_dist[child.action[0]][child.action[1]] = child.N

    if as_prob:
      move_dist = np_softmax(move_dist)

    return move_dist

In [9]:
def softXEnt (input, target): # temporary
    """Cross-entropy against soft targets, averaged over the first dimension.

    `input` is treated as unnormalized scores: log_softmax is applied along
    dim 1 before it is weighted by `target`.
    """
    batch_size = input.shape[0]
    log_probs = F.log_softmax(input, dim=1)
    weighted_sum = torch.sum(target * log_probs)
    return -weighted_sum / batch_size


def append_state(states, labels, state, label):
  """Append the 8 rotations / reflections of a (state, label) pair.

  Four 90-degree rotations of the pair are stored first, then four rotations
  of its transpose. Both lists are extended in place and stay index-aligned.

  Every stored array is an independent copy: np.rot90 and .T return views,
  so without copying, the entries would alias the caller's array and change
  whenever that array is modified later (e.g. a board updated in place).
  """
  # Augmentation
  for _ in range(2):
    for quarter_turns in range(4):
      states.append(np.rot90(state, quarter_turns).copy())
      labels.append(np.rot90(label, quarter_turns).copy())

    state = state.T
    label = label.T

  return

class Brain(nn.Module):
  """Convolutional policy / value network.

  Args:
    input_shape: (channels, height, width) of one input position.

  forward() returns (p, v): p has shape (batch, height, width) and sums to 1
  over each board; v has shape (batch, 1) and lies in [-1, 1].
  """

  def __init__(self, input_shape=(2, 30, 30)):
    super().__init__()

    self.input_shape = input_shape

    use_bias = True
    # 5x5 kernels with padding 2 keep the board size unchanged.
    self.conv1 = nn.Conv2d(input_shape[0], 24, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)
    self.conv2 = nn.Conv2d(24, 36, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)
    self.conv3 = nn.Conv2d(36, 48, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)
    self.conv4 = nn.Conv2d(48, 24, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)

    self.pol_conv1 = nn.Conv2d(24, 1, padding=(2,2), kernel_size=5, stride=1, bias=use_bias)

    self.val_conv1 = nn.Conv2d(24, 1, kernel_size=3, stride=1, bias=use_bias)
    # The unpadded 3x3 convolution shrinks each side by 2, so the flattened
    # feature count follows from the board size (64 for a 10x10 board)
    # instead of being hard-coded.
    val_features = (input_shape[1] - 2) * (input_shape[2] - 2)
    self.val_linear1 = nn.Linear(val_features, 50)
    self.val_linear2 = nn.Linear(50, 1)

    self.flatten = nn.Flatten()

  def forward(self, x):
    # Core:
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))
    x = F.relu(self.conv4(x))

    # Policy Head: one score per cell, softmax over the whole board.
    p = self.pol_conv1(x)
    p = p.view(-1, self.input_shape[1]*self.input_shape[2])
    p = F.softmax(p, dim=1)
    p = p.view(-1, self.input_shape[1], self.input_shape[2])

    # Value Head: the activations must be applied to v. Applying them to x
    # (already non-negative) did nothing and left this head purely linear
    # up to the final tanh.
    # NOTE(review): checkpoints trained with the old forward pass still load,
    # but their value output will differ - re-evaluate or retrain.
    v = F.relu(self.val_conv1(x))
    v = self.flatten(v)
    v = F.relu(self.val_linear1(v))
    v = self.val_linear2(v)
    v = torch.tanh(v)

    return p, v


class ZeroTTT():
  """AlphaZero-style agent: a Brain network, its optimizer and the training loops.

  Checkpoints are read from and written to the relative 'models' directory.
  """

  def __init__(self, brain_path=None, opt_path=None, board_len=10, lr=3e-4, weight_decay=0.0):
    """Build the network and optimizer; load a checkpoint when brain_path is given."""
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.brain = Brain(input_shape=(2, board_len, board_len)).to(self.device)
    self.board_len = board_len

    self.optimizer = AdamW(self.brain.parameters(), lr=lr, weight_decay=weight_decay)
    self.value_loss = nn.MSELoss()
    self.policy_loss = nn.CrossEntropyLoss() # not used by self_play, which needs soft targets

    if brain_path is not None:
      self.load_brain(brain_path, opt_path)

  def get_parameter_count(self):
    """Number of trainable parameters in the network."""
    return sum(p.numel() for p in self.brain.parameters() if p.requires_grad)

  def save_brain(self, model_name, opt_state_name):
    """Save the network weights (and the optimizer state unless opt_state_name is None)."""
    print("Saving brain...")
    os.makedirs('models', exist_ok=True) # first save on a fresh checkout
    torch.save(self.brain.state_dict(), os.path.join('models', model_name))
    if opt_state_name is not None:
        torch.save(self.optimizer.state_dict(), os.path.join('models', opt_state_name))

  def load_brain(self, model_name, opt_state_name):
    """Load the network weights (and the optimizer state unless opt_state_name is None)."""
    print("Loading brain...")
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    self.brain.load_state_dict(torch.load(os.path.join('models', model_name), map_location=self.device))
    if opt_state_name is not None:
        self.optimizer.load_state_dict(torch.load(os.path.join('models', opt_state_name), map_location=self.device))
    return

  def predict(self, x):
    """Run the network on one position (2, H, W) or a batch (B, 2, H, W) given as a numpy array.

    Returns the (policy, value) tensors produced by the network, on self.device.
    """

    if len(x.shape) < 4:
      x = np.expand_dims(x, axis=0)

    x = torch.from_numpy(x).float().to(self.device)

    policy, value = self.brain(x)
    return policy, value

  def _train_on_positions(self, states, policy_labels, value_labels, batch_size):
    """One shuffled pass of mini-batch updates over the collected positions."""
    self.brain.train() # evaluate() may have been called since the last update

    states = np.array([split_state(state) for state in states])
    policy_labels = np.array(policy_labels)
    value_labels = np.array(value_labels)

    order = np.random.permutation(len(states))
    states = states[order]
    policy_labels = policy_labels[order]
    value_labels = value_labels[order]

    # The last batch is simply shorter when the count is not a multiple of batch_size.
    for start in range(0, len(states), batch_size):
      stop = start + batch_size

      self.optimizer.zero_grad()

      batch_pl = torch.from_numpy(policy_labels[start:stop]).float().to(self.device)
      batch_vl = torch.from_numpy(value_labels[start:stop]).float().to(self.device)
      prob, val = self.predict(states[start:stop])
      val = val.flatten()

      prob = torch.flatten(prob, 1, 2)
      batch_pl = torch.flatten(batch_pl, 1, 2)

      # The network already outputs probabilities, while softXEnt applies
      # log_softmax to its input. Passing log-probabilities makes that
      # log_softmax an identity (the probabilities sum to 1), so the loss is
      # the true cross-entropy rather than one computed through two softmaxes.
      p_loss = softXEnt(torch.log(prob.clamp_min(1e-12)), batch_pl)
      v_loss = self.value_loss(val, batch_vl)

      loss = p_loss + v_loss
      loss.backward()

      self.optimizer.step()

  def self_play(self, n_games=1, num_simulations=100, positions_per_learn=100, batch_size=20 ,render=10,
            games_per_evaluation=-1, evaluation_game_count=20, evaluation_num_simulations=50):
    """Generate games with the saved best model and train this model on them.

    Args:
      n_games: number of self-play games.
      num_simulations: MCTS simulations per move during self-play.
      positions_per_learn: train once at least this many (augmented) positions are buffered.
      batch_size: mini-batch size for training.
      render: print the board for every render-th game.
      games_per_evaluation: evaluate against the best model every this many games (-1 disables).
      evaluation_game_count: games per evaluation round.
      evaluation_num_simulations: MCTS simulations per move during evaluation.
    """

    # Put model in training mode:
    self.brain.train()

    # TODO:
    # - Implement proper sampling every couple of games rather than learning after each game

    states = []
    policy_labels = []
    value_labels = []

    best_model = ZeroTTT(brain_path='best_model', opt_path='best_opt_state', board_len=self.board_len) # best model always generates data
    env = Environment(board_len=self.board_len)

    for game_nr in range(n_games):

      mcts = MCTS(best_model, env.board, num_simulations=num_simulations)
      tau = 1.0
      val_chunk = []
      show_game = (game_nr+1) % render == 0

      print(f"Game {game_nr+1}...")

      while env.game_over() == 10:

        if len(env.move_hist) > 30: # after 30 moves no randomness
          tau = 0.0

        # np.any returns a numpy bool, so an `is False` test never fires.
        if not np.any(env.board == 0): # tie
          break

        mcts.search()

        if env.turn == env.x_token:
          # Snapshot the board: it is modified in place by env.step below.
          append_state(states, policy_labels, env.board.copy(), mcts.get_pi())
        elif env.turn == env.o_token: # swap persepctive so O tokens are positive and X tokens are negative
          append_state(states, policy_labels, (-1)*env.board, mcts.get_pi())

        val_chunk += [env.turn]*8 # accounting for augmentation

        move = mcts.select_move(tau=tau)
        env.step(move)

        if show_game:
          env.render()

      result = env.game_over()

      if show_game:
        print(f"Player with token: {result} won the game")

      if result == env.x_token: # pass because the turns correctly specify the return from the proper perspectives
        pass
      elif result == env.o_token:
        val_chunk = [lab * (-1.0) for lab in val_chunk] # invert the turns because that will represent -1 return for x turns and 1 for o turns
      else: # tie
        val_chunk = [0 for lab in val_chunk]

      value_labels += val_chunk

      if len(states) >= positions_per_learn: # learn

        print(f"Training on {len(states)} positions...")

        self._train_on_positions(states, policy_labels, value_labels, batch_size)

        states = []
        policy_labels = []
        value_labels = []

      if games_per_evaluation != -1 and (game_nr + 1) % games_per_evaluation == 0:

        print("Evaluating trained model...")

        win_count = 0
        token = 1 # alternate sides between evaluation games
        for _ in range(evaluation_game_count):
          result = self.evaluate('best_model', 'best_opt_state', token, board_len=self.board_len, num_simulations=evaluation_num_simulations, render=False)
          if result == 1:
            win_count += 1
          token *= -1

        win_pct = win_count/evaluation_game_count

        print(f"Model won {win_pct} of games.")

        if win_pct >= 0.55:
          print("Overwriting best model...")
          self.save_brain('best_model', 'best_opt_state')
          best_model = ZeroTTT(brain_path='best_model', opt_path='best_opt_state', board_len=self.board_len) # overwrite the best_model to be the current weights

      env.reset()

  def evaluate(self, opp_name, opp_opt_state, model_token, board_len=10, num_simulations=100, render=False):
    """Play one game against a saved opponent.

    Args:
      opp_name / opp_opt_state: checkpoint names of the opponent.
      model_token: 1 if this model plays X (first), -1 if it plays O.
      board_len: board side length; must match the size both networks were built for.
      num_simulations: MCTS simulations per move for both players.
      render: print the board after every move when truthy.

    Returns:
      1 if this model won, -1 if it lost, 0 for a tie.
    """

    # Put model in evaluation mode for the game and restore the previous mode
    # afterwards, so training that continues after an evaluation is unaffected.
    was_training = self.brain.training
    self.brain.eval()

    try:
      env = Environment(board_len=board_len)
      opponent = ZeroTTT(brain_path=opp_name, opt_path=opp_opt_state, board_len=board_len)
      opponent.brain.eval()

      model_MCTS = MCTS(self, env.board, num_simulations=num_simulations)
      opponent_MCTS = MCTS(opponent, env.board, num_simulations=num_simulations)

      while env.game_over() == 10:

        if not np.any(env.board == 0): # board full without a winner
          break

        # The side to move searches and plays; the other tree follows the move.
        if env.turn == model_token:
          model_MCTS.search()
          move = model_MCTS.select_move(tau=0.0)
          env.step(move)
          opponent_MCTS.select_move(external_move=move)
        else:
          opponent_MCTS.search()
          move = opponent_MCTS.select_move(tau=0.0)
          env.step(move)
          model_MCTS.select_move(external_move=move)

        if render:
          env.render()

      result = env.game_over()
    finally:
      self.brain.train(was_training)

    # Return 1 for win of model, -1 for loss and 0 for tie. The loop only ends
    # with the "still running" code 10 when the board is full, which is a tie.
    if result == model_token:
      return 1
    elif result == 0 or result == 10:
      return 0
    else:
      return -1




In [None]:
# Resume from the current best checkpoint. To start from scratch instead:
# model = ZeroTTT(brain_path=None, opt_path=None, lr=3e-4, board_len=10)
model = ZeroTTT(
    brain_path='best_model',
    opt_path='best_opt_state',
    lr=3e-4,
    board_len=10,
)

# Quick shape check of the value output:
# test = torch.randn((2, 10, 10))
# p, v = model.predict(test)
# print(v.shape)

# Full training run. For a fast smoke test use instead:
# model.self_play(n_games=1000, num_simulations=10, render=1, positions_per_learn=800, batch_size=40,
#             games_per_evaluation=50, evaluation_game_count=20, evaluation_num_simulations=10)
model.self_play(
    n_games=1000,
    num_simulations=200,
    render=20,
    positions_per_learn=1600,
    batch_size=40,
    games_per_evaluation=50,
    evaluation_game_count=20,
    evaluation_num_simulations=200,
)

# Single rendered evaluation game against the best model:
# model.evaluate(opp_name='best_model', opp_opt_state='best_opt_state', board_len=10, num_simulations=100, render=True, model_token=1)



# TODO:

# Mcts Def makes it worse
# MCTS makes it go crazy hmmmm
# maybe it's just not enough training


Loading brain...
Loading brain...
Game 1...
Game 2...
Game 3...
Game 4...
Training on 1680 positions...
Game 5...
Game 6...
Game 7...
Game 8...
Training on 1984 positions...
Game 9...
Game 10...
Game 11...
Game 12...
Training on 2040 positions...
Game 13...
Game 14...
Game 15...
Game 16...
Training on 2048 positions...
Game 17...
Game 18...
Game 19...
Training on 1656 positions...
Game 20...
[[' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' 'X' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']]
[[' ' 'O' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' ']
 [' ' ' ' ' ' ' ' 'X' ' ' ' ' ' ' ' ' ' ']
 [