In [1]:
import torch
from torch import distributions, nn
import numpy as np
import random
import math
import sys
import copy

In [2]:
# Colab-only: mount Google Drive so files under /content/gdrive are reachable from this kernel.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
# Save directory on the mounted Drive. NOTE(review): hardcoded absolute path, and not referenced by any cell shown here.
path = '/content/gdrive/My Drive/12th grade/Joshua Zhu Project Studio/Save'

In [4]:
class HexEnvironment():
  """Hex board for two players, used for real games and as a scratch board.

  State is two ``size x size`` grids of 0/1: ``grid[0]`` marks player 1's
  stones and ``grid[1]`` marks player 2's stones. Tiles are addressed either
  as ``(row, col)`` or as a flat index ``tile = row * size + col``.

  Player 1 wins by connecting the last row (``size - 1``) to row 0; player 2
  wins by connecting column 0 to column ``size - 1`` (see search_for_win).
  """

  def __init__(self, boardsize):
    self.size = boardsize
    # grid[0] is where p1 puts its pieces, grid[1] is where p2 puts its pieces.
    # Every row is a separate list so rows never alias one another.
    self.grid = [[[0] * boardsize for _ in range(boardsize)] for _ in range(2)]
    self.current_player = 1  # 1 or 2: whose turn it is
    self.done = False  # kept up to date by search_for_win()

  def _tile_to_rc(self, tile):
    """Convert a flat tile index into a (row, col) pair."""
    return divmod(int(tile), self.size)

  def neighbors(self, x, y):
    """Return the in-bounds hex neighbours of (x, y) as [row, col] lists.

    Row i is drawn shifted right by half a tile per row (see print_self), so
    (x, y) touches (x-1, y+1) and (x-1, y) above, (x, y-1) and (x, y+1)
    beside, and (x+1, y-1) and (x+1, y) below.
    """
    candidates = [[x-1, y+1], [x-1, y], [x, y-1], [x+1, y-1], [x+1, y], [x, y+1]]
    return [[a, b] for [a, b] in candidates
            if 0 <= a < self.size and 0 <= b < self.size]

  def obs(self):
    """Return a deep copy of the board so callers cannot mutate self.grid."""
    return copy.deepcopy(self.grid)

  def is_done(self):
    """Return the win flag as of the last search_for_win() call."""
    return self.done

  def get_player(self):
    """Return the player (1 or 2) whose turn it is."""
    return self.current_player

  def is_occupied(self, tile):
    """Return True if the flat-indexed tile holds a stone of either player."""
    row, col = self._tile_to_rc(tile)
    return self.grid[0][row][col] == 1 or self.grid[1][row][col] == 1

  def set_correct_player(self):
    """Derive whose turn it is from the stone counts (player 1 moves first).

    Raises:
      Exception: if the counts are neither equal nor player 1 exactly one ahead.
    """
    p1_tokens = sum(cell == 1 for row in self.grid[0] for cell in row)
    p2_tokens = sum(cell == 1 for row in self.grid[1] for cell in row)
    if p1_tokens == p2_tokens:
      self.current_player = 1
    elif p1_tokens == (p2_tokens + 1):
      self.current_player = 2
    else:
      raise Exception("This board is invalid.")

  def swap_players(self):
    """Hand the turn to the other player."""
    self.current_player = 2 if self.current_player == 1 else 1

  def set_tile(self, tile):
    """Place a stone for whoever is due to move on the flat-indexed tile.

    The mover is re-derived from the stone counts first, so this also works
    right after set_board(). Does not update ``self.done``.

    Raises:
      Exception: if the tile is already occupied (same message as step()).
    """
    self.set_correct_player()
    row, col = self._tile_to_rc(tile)
    # Without this check a second stone would be written onto an occupied
    # tile, leaving it marked for both players.
    if self.grid[0][row][col] == 1 or self.grid[1][row][col] == 1:
      raise Exception("Tried to make an illegal move")
    self.grid[self.current_player-1][row][col] = 1
    self.swap_players()

  def set_board(self, board):
    """Load ``board`` (same layout as obs()), then refresh ``done`` and the mover.

    The list is stored by reference, not copied. Pass a copy (e.g. from obs())
    if the caller keeps using it, because later moves mutate it in place.
    """
    self.grid = board
    self.search_for_win()
    self.set_correct_player()

  def reset(self):
    """Clear the board in place and give the first move to player 1."""
    self.current_player = 1
    self.done = False
    for i in range(self.size):
      for j in range(self.size):
        self.grid[0][i][j] = 0
        self.grid[1][i][j] = 0

  def print_self(self):
    """Print the board, indenting row i by i spaces to show the hex offset.

    '1' and '2' mark each player's stones; '_' marks an empty tile.
    """
    for i in range(self.size):
      cells = []
      for j in range(self.size):
        if self.grid[0][i][j] == 1:
          cells.append("1")
        elif self.grid[1][i][j] == 1:
          cells.append("2")
        else:
          cells.append("_")
      print(" " * i + "".join(cell + " " for cell in cells))

  def search_for_win(self):
    """Return the winner (1 or 2), or 0 if nobody has won; also sets self.done.

    Player 1 (grid[0]) must connect row size-1 to row 0. Player 2 (grid[1])
    must connect column 0 to column size-1. Each check is an iterative
    depth-first search over that player's stones, starting from its first
    edge. Player 1 is checked first.
    """
    for player in range(2):  # 0 is p1, 1 is p2
      grid = self.grid[player]
      visited = [[False] * self.size for _ in range(self.size)]
      # Seed the search with this player's stones on its starting edge.
      if player == 0:
        stack = [[self.size - 1, i] for i in range(self.size) if grid[self.size - 1][i] == 1]
      else:
        stack = [[i, 0] for i in range(self.size) if grid[i][0] == 1]

      while stack:
        x, y = stack.pop()
        if visited[x][y]:
          continue  # already expanded via another path
        visited[x][y] = True
        for newx, newy in self.neighbors(x, y):
          if grid[newx][newy] == 1 and not visited[newx][newy]:
            stack.append([newx, newy])

      # Reaching the opposite edge means there is a connected winning path.
      if player == 0:
        if any(visited[0][i] for i in range(self.size)):
          self.done = True
          return 1
      else:
        if any(visited[i][self.size - 1] for i in range(self.size)):
          self.done = True
          return 2

    # 0 means no one has won
    self.done = False
    return 0

  def filter_illegal(self, qvals, transpose = False):
    """Return the empty tiles with their q-values, sorted from best to worst.

    Returns a list of ``((row, col), qvalue)`` tuples in true board
    coordinates. When ``transpose`` is True, ``qvals`` is read as [col][row].
    """
    valid_probs = {}
    for i in range(self.size):
      for j in range(self.size):
        if self.grid[0][i][j] == 0 and self.grid[1][i][j] == 0:
          valid_probs[(i, j)] = qvals[j][i] if transpose else qvals[i][j]
    return sorted(valid_probs.items(), key=lambda item: item[1], reverse=True)

  def greedy(self, qvals, transpose = False):
    """Return the (row, col) of the empty tile with the highest q-value.

    If ``transpose`` is True, ``qvals`` describes the transposed board. The
    result is always in the true board's coordinates.
    """
    sorted_vals = self.filter_illegal(qvals, transpose)
    return sorted_vals[0][0]

  def epsilon_greedy(self, qvals, epsilon, transpose = False):
    """With probability ``epsilon`` return a uniformly random empty tile, else the greedy one.

    Coordinates follow the same transpose convention as greedy().
    """
    sorted_vals = self.filter_illegal(qvals, transpose)
    best_x, best_y = sorted_vals[0][0]
    if random.random() < epsilon:
      best_x, best_y = random.choice(sorted_vals)[0]
    return (best_x, best_y)

  def random(self):
    """Return the (row, col) of a uniformly random empty tile."""
    # `random` below is the module: a method body does not see class-level names.
    valid_moves = [(i, j)
                   for i in range(self.size)
                   for j in range(self.size)
                   if self.grid[0][i][j] == 0 and self.grid[1][i][j] == 0]
    return random.choice(valid_moves)

  def stochastic(self, preds):
    """Sample a legal (row, col) from per-tile weights ``preds``.

    ``preds`` is indexed by flat tile (row * size + col). NOTE(review):
    presumably a 1-D torch tensor of length size**2 — confirm against callers.
    Occupied tiles are zeroed on a detached copy before sampling from a
    Categorical distribution.
    """
    filtered_preds = preds.detach().clone()
    for i in range(self.size):
      for j in range(self.size):
        if self.grid[0][i][j] != 0 or self.grid[1][i][j] != 0:
          filtered_preds[(i*self.size) + j] = 0
    action = distributions.Categorical(filtered_preds).sample().item()
    return divmod(action, self.size)

  def step(self, policy, move = None, epsilon = 0, transpose = False):
    """Play one move for the current player and return its (row, col).

    Args:
      policy: 'epsilon_greedy', 'greedy', 'random' or 'stochastic'.
      move: for the greedy policies, a size x size grid of q-values ([row][col],
        or [col][row] when ``transpose`` is True). For 'stochastic', per-tile
        weights as accepted by stochastic(). Ignored for 'random'.
      epsilon: exploration probability for 'epsilon_greedy'.
      transpose: whether ``move`` describes the transposed board.

    Raises:
      ValueError: if ``policy`` is not one of the names above.
      Exception: if the chosen tile is already occupied.
    """
    if policy == 'epsilon_greedy':
      (x, y) = self.epsilon_greedy(move, epsilon, transpose)
    elif policy == 'greedy':
      (x, y) = self.greedy(move, transpose)
    elif policy == 'random':
      (x, y) = self.random()
    elif policy == 'stochastic':
      (x, y) = self.stochastic(move)
    else:
      raise ValueError("Unknown policy: {!r}".format(policy))

    if self.grid[0][x][y] == 1 or self.grid[1][x][y] == 1:
      raise Exception("Tried to make an illegal move")
    self.grid[self.current_player - 1][x][y] = 1
    self.swap_players()
    return (x, y)

In [33]:
class Node():
  """One position in the Monte Carlo search tree.

  A node stores the board *after* ``tile`` was played, plus visit/win
  statistics kept from the point of view of the player who played ``tile``.
  That player is the parent's player to move, which is the perspective the
  parent needs when ranking its children by UCB.
  """
  def __init__(self, board_state, tile, parent, whose_move, terminal, is_root):
    self.board_state = board_state
    self.tile = tile  # the move that was just made (flat tile index)
    self.is_root = is_root  # only true for the blank board

    self.parent = parent  # the parent Node object itself (None for the root)
    self.children = []  # list of child Nodes; empty until expanded

    self.whose_move = whose_move  # player (1 or 2) to move in this position
    self.terminal = terminal  # True if the game is already decided here
    self.num_visits = 0
    # Wins for the player who just moved into this node, i.e. NOT self.whose_move.
    # The parent's whose_move is the opposite of this node's, so these are the
    # parent's wins and, equivalently, the losses of self.whose_move.
    self.num_wins_of_parent = 0

  def obs(self):
    """Return a deep copy of this node's board."""
    return copy.deepcopy(self.board_state)

  # Note the "opposite"-ness of these two: statistics are from the parent's view.
  def this_player_lost(self):
    """Record a playout lost by self.whose_move (a win for the parent's player)."""
    self.num_visits += 1
    self.num_wins_of_parent += 1

  def this_player_won(self):
    """Record a playout won by self.whose_move (a loss for the parent's player)."""
    self.num_visits += 1

  def get_value(self, C):
    """Return this child's UCB1 score as seen by its parent.

    UCB = wins / visits + C * sqrt(ln(N + 1) / visits), where N is the parent's
    visit count (the +1 keeps the log defined while N is 0). Unvisited nodes
    score sys.maxsize so every child is tried before any is revisited.
    Requires a parent, so it cannot be called on the root.
    """
    if self.num_visits == 0:
      return sys.maxsize
    N = self.parent.num_visits
    exploitation = (1.0 * self.num_wins_of_parent) / (1.0 * self.num_visits)
    exploration = C * ((math.log(N + 1) / self.num_visits) ** 0.5)
    return exploitation + exploration

  def best_UCB(self, C = 2**(0.5)):
    """Return the child with the highest UCB score, breaking ties uniformly at random.

    ``C`` is the exploration constant (sqrt(2) is the usual default).

    Raises:
      Exception: if this node has no children.
    """
    if not self.has_children():
      raise Exception("This node has no children, cannot call best_UCB().")

    # Start from -inf (not -1) so a valid child is always selected, even when
    # a negative C pushes every score below -1.
    best_value = float('-inf')
    best_children = []
    for child in self.children:
      value = child.get_value(C)  # evaluate once per child
      if value > best_value:
        best_value = value
        best_children = [child]
      elif value == best_value:
        best_children.append(child)
    return random.choice(best_children)

  def random_child(self):
    """Return a uniformly random child.

    Raises:
      Exception: if this node has no children.
    """
    if not self.has_children():
      raise Exception("This node has no children, cannot call random_child().")
    return random.choice(self.children)

  def has_children(self):
    """Return True if this node has been expanded with at least one child."""
    return len(self.children) > 0

In [80]:
class MonteCarloTreeSearchAgent():
  """Hex player based on Monte Carlo Tree Search (UCB selection, random rollouts).

  ``self.root`` is the node for the blank board. ``self.current_root`` is the
  node for the position reached in the real game tracked by ``self.env``, and
  searches start from it. ``self.test_env`` is a scratch board used to expand
  nodes and play rollouts.
  """
  def __init__(self, boardsize):
    self.boardsize = boardsize
    self.env = HexEnvironment(boardsize)  # to keep track of the actual game that's going on
    self.test_env = HexEnvironment(boardsize)  # to do tests for the branches of the tree
    # Blank board: no tile played yet, player 1 to move.
    self.root = Node(self.env.obs(), None, None, 1, False, True)
    self.current_root = self.root

  def monte_carlo_tree_search(self, steps):
    """Run ``steps`` MCTS iterations starting from current_root.

    Each iteration selects a leaf by UCB (traverse). A terminal leaf is scored
    directly. Otherwise the leaf is expanded if needed, its best-UCB child gets
    a random rollout, and the result is backpropagated from that child.
    """
    for _ in range(steps):
      leaf = self.traverse()
      if leaf.terminal:
        # The game is already decided at this node, so no rollout is needed.
        self.test_env.set_board(leaf.obs())
        winner = self.test_env.search_for_win()
        self.backpropogate(leaf, winner)
      else:
        if not leaf.has_children():
          self.generate_children(leaf)
        child = leaf.best_UCB()
        winner = self.random_rollout(child)
        self.backpropogate(child, winner)

  def move(self):
    """Choose and play the agent's move from current_root; return its flat tile index.

    Plays the most-visited child (the standard "robust child" rule), breaking
    ties uniformly at random. UCB is deliberately not used here: it includes
    an exploration bonus and scores unvisited children as sys.maxsize, so it
    would happily play moves the search never tried.

    Raises:
      Exception: if there is no legal move from the current position.
    """
    self.generate_children(self.current_root)  # no-op if already expanded
    children = self.current_root.children
    if not children:
      raise Exception("No legal moves from the current position.")
    most_visits = max(child.num_visits for child in children)
    next_node = random.choice([child for child in children if child.num_visits == most_visits])
    move = next_node.tile
    self.generate_children(next_node)

    self.current_root = next_node
    self.env.set_tile(move)
    return move

  def traverse(self):
    """From current_root, follow best-UCB children until reaching a terminal or unexpanded node."""
    next_node = self.current_root
    while (not next_node.terminal) and next_node.has_children():
      next_node = next_node.best_UCB()
    return next_node

  def update(self, move):
    """Advance current_root and self.env past the opponent's ``move`` (flat tile index).

    Assumes current_root already reflects the agent's own last move.
    current_root is expanded first, so any legal move has a matching child.
    Previously an unexpanded root left current_root stale while self.env
    advanced, desynchronising the tree from the real game.

    Raises:
      Exception: if ``move`` is not a legal move from current_root.
    """
    self.generate_children(self.current_root)  # no-op if already expanded
    for node in self.current_root.children:
      if node.tile == move:
        self.current_root = node
        self.generate_children(node)
        break
    else:
      raise Exception("Tried to make an illegal move")
    self.env.set_tile(move)

  def generate_children(self, node):
    """Expand ``node`` with one child per empty tile; do nothing if it already has children.

    Each child holds the board after that tile is played, is marked terminal
    if that move produced a winner, and has the other player to move.
    """
    if node.has_children():
      return
    next_player = 2 if node.whose_move == 1 else 1
    for tile in range(self.boardsize**2):
      row, col = divmod(tile, self.boardsize)
      if node.board_state[0][row][col] == 1 or node.board_state[1][row][col] == 1:
        continue  # occupied: no child, and no need to copy the board or search it
      self.test_env.set_board(node.obs())
      self.test_env.set_tile(tile)
      terminal = self.test_env.search_for_win() != 0
      node.children.append(Node(self.test_env.obs(), tile, node, next_player, terminal, False))

  def backpropogate(self, node, winner):
    """Record one playout result on every node from ``node`` up to, but excluding, the blank-board root.

    ``winner`` is the winning player (1 or 2). A node whose whose_move won
    counts as a loss for its parent's player, and vice versa. The root only
    has its visit count incremented.
    """
    current_node = node
    while not current_node.is_root:
      if current_node.whose_move == winner:
        current_node.this_player_won()
      else:
        current_node.this_player_lost()
      current_node = current_node.parent
    self.root.num_visits += 1  # we don't care about num_wins for the root

  def random_rollout(self, node):
    """Play uniformly random moves from ``node``'s board until someone wins; return the winner (1 or 2)."""
    self.test_env.set_board(node.obs())
    winner = self.test_env.search_for_win()
    while not self.test_env.is_done():
      self.test_env.step('random')
      winner = self.test_env.search_for_win()
    return winner
  

In [104]:
# Create an MCTS agent for a 5x5 board.
yahaha = MonteCarloTreeSearchAgent(5)

In [105]:
# Run 10,000 MCTS iterations from the agent's current root.
yahaha.monte_carlo_tree_search(10000)

In [106]:
# Inspect the board at the current search root ([player-1 grid, player-2 grid]).
yahaha.current_root.board_state

[[[0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0]],
 [[0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0]]]

In [100]:
# Manually re-point the search root at node `a` (debugging).
# NOTE(review): `a` is only assigned in cells placed later in the notebook (In [96] / In [108]),
# so Restart & Run All top-to-bottom fails here with a NameError. This assignment also leaves
# yahaha.env untouched, so the tracked game and the tree can disagree afterwards.
yahaha.current_root = a

In [102]:
# Run 1,000 more MCTS iterations from the agent's current root.
yahaha.monte_carlo_tree_search(1000)

In [96]:
# Keep a handle to the current root's parent for inspection (None at the blank-board root).
a = yahaha.current_root.parent

In [98]:
# Board stored at node `a`.
a.board_state

[[[0, 0, 1], [0, 0, 1], [0, 0, 0]], [[0, 0, 0], [1, 1, 0], [0, 0, 0]]]

In [108]:
# Pick the child at index 7 of the current root. When current_root is the blank board,
# generate_children appends children in tile order, so this is the child for tile 7.
a = yahaha.current_root.children[7]

In [110]:
# For each child of `a`: tile played, visit count, and UCB value with C = sqrt(2).
for node in a.children:
  print(node.tile, node.num_visits, node.get_value(2**(0.5)))

0 30 1.0289481381478942
1 21 1.0296730002416818
2 37 1.0287840071341303
3 24 1.0321198619019734
4 21 1.0296730002416818
5 27 1.0314393006695093
6 29 1.0184309286104571
8 21 1.0296730002416818
9 35 1.041724926575691
10 24 1.0321198619019734
11 16 1.031866253369957
12 39 1.0423977371395319
13 29 1.0184309286104571
14 43 1.0648111354639602
15 43 1.0415553215104718
16 34 1.0338698233480952
17 37 1.0287840071341303
18 39 1.0423977371395319
19 24 1.0321198619019734
20 27 1.0314393006695093
21 49 1.0284133692726285
22 14 1.0409093759250594
23 35 1.041724926575691
24 21 1.0296730002416818


In [107]:
# For each child of the current root: tile played, visit count, and UCB value with C = sqrt(2).
for node in yahaha.current_root.children:
  print(node.tile, node.num_visits, node.get_value(2**(0.5)))

0 231 0.7802254276158551
1 273 0.7799074748803362
2 338 0.7807888729647854
3 371 0.7807787017176876
4 520 0.7805223852575717
5 377 0.7807286442325735
6 422 0.7800191649112641
7 719 0.7803689059761476
8 706 0.7819264710444213
9 231 0.7802254276158551
10 383 0.780666376014467
11 411 0.7810496105926623
12 454 0.7807265995139413
13 568 0.7804385587310947
14 140 0.7770221468758747
15 271 0.7810129231890383
16 591 0.7806083660631035
17 429 0.7806441304287746
18 457 0.7806379121398092
19 290 0.7796185032256959
20 641 0.7810665059317461
21 335 0.7807632274071737
22 362 0.780828783841819
23 300 0.781129491051413
24 180 0.7810145751432143


In [93]:
# Let the agent choose and play its move; `move` is the flat tile index of the chosen child.
move = yahaha.move()

In [94]:
# Display the tile the agent chose.
move

6

In [91]:
# Tell the agent the opponent played tile 3.
yahaha.update(3)

In [65]:
# Single MCTS iteration (debugging).
yahaha.monte_carlo_tree_search(1)

In [61]:
# Debug: the node the selection phase reaches from current_root.
testnode = yahaha.traverse()

In [62]:
# Board at that node.
testnode.board_state

[[[0, 0, 0], [1, 0, 0], [1, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 1]]]

In [63]:
# Whether the game is already decided at that node.
testnode.terminal

False

In [66]:
# Children of that node (Node objects).
testnode.children

[<__main__.Node at 0x7f17964d1d10>,
 <__main__.Node at 0x7f17964d1b10>,
 <__main__.Node at 0x7f17964d1e10>,
 <__main__.Node at 0x7f17964d1b50>,
 <__main__.Node at 0x7f17964d18d0>]

In [31]:
# Print each child's board.
# NOTE(review): the saved output of this cell (In [31]) predates the testnode cells above
# (In [61]-[66]). It lists 7 boards while testnode.children shows 5, so it is stale.
for node in testnode.children:
  print(node.board_state)

[[[1, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]]]
[[[0, 1, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]]]
[[[0, 0, 1], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]]]
[[[0, 0, 0], [1, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]]]
[[[0, 0, 0], [0, 1, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]]]
[[[0, 0, 0], [0, 0, 1], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]]]
[[[0, 0, 0], [0, 0, 0], [1, 0, 1]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]]]


In [59]:
# Visit count and UCB value (C = sqrt(2)) for each child of the blank-board root.
for node in yahaha.root.children:
  print(node.num_visits, node.get_value(2**(0.5)))

62 1.0365998688467801
82 1.0324463036845144
151 1.037600058109048
91 1.0380191532587864
134 1.0375343114223206
111 1.0375048486432794
212 1.038316599759992
75 1.029224254652203
82 1.0324463036845144


In [60]:
# Total playouts recorded at the blank-board root (backpropogate adds 1 per call).
yahaha.root.num_visits

1000