<a href="https://colab.research.google.com/github/joshuazhu17/Project_Studio/blob/main/HexGame.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import numpy as np
import random
import math

In [2]:
class HexEnvironment():
    def __init__(self, boardsize):
        self.size = boardsize
        self.grid = [[], []] #the first is where p1 puts its pieces, the second is where p2 puts its pieces
        for i in range(boardsize):
            p1_row = []
            p2_row = []
            for j in range(boardsize):
                p1_row.append(0)
                p2_row.append(0)
            self.grid[0].append(p1_row)
            self.grid[1].append(p2_row)
        self.current_player = 1
        self.done = False

    def neighbors(self, x, y):
        neighbors = [[x-1, y+1], [x-1, y], [x, y-1], [x+1, y-1], [x+1, y], [x, y+1]]
        real_neighbors = []
        for i in range(len(neighbors)):
            [a, b] = neighbors[i]
            if (a >= 0) and (a < self.size) and (b >= 0) and (b < self.size):
                real_neighbors.append([a, b])
        return real_neighbors

    def obs(self):
        return self.grid

    def is_done(self):
        return self.done

    def get_player(self):
        return self.current_player
    
    def reset(self):
        self.current_player = 1
        self.done = False
        for i in range(self.size):
            for j in range(self.size):
                self.grid[0][i][j] = 0
                self.grid[1][i][j] = 0

    def print_self(self):
        for i in range(self.size):
            next_line = ""
            next_line += " "*i
            for j in range(self.size):
                if self.grid[0][i][j] == 1:
                    next_line += "1"
                elif self.grid[1][i][j] == 1:
                    next_line += "2"
                else:
                    next_line += "_"
                next_line += " "
            print(next_line)

    def search_for_win(self):
        #Player 1 wants to connect bottom left ((10, 0), (10, 1), ..., (10, 10)) to top right ((0, 0), (0, 1), ... (0, 10))
        #Player 2 wants to connect top left ((0, 0), (1, 0), ... (10, 0)) to bottom right ((0, 10), (1, 10), ... (10, 10))
        #I'll do a dfs to find a winning path
        
        for player in range(2): #0 is p1, 1 is p2
            stack = []
            visited = [[False for i in range(self.size)] for j in range(self.size)]
            #Initialize search with the side corresponding to the player
            if player == 0:
                for i in range(self.size):
                    if self.grid[0][self.size-1][i] == 1:
                        stack.append([self.size-1, i])
            elif player == 1:
                for i in range(self.size):
                    if self.grid[1][i][0] == 1:
                        stack.append([i, 0])

            #Do the search
            while len(stack)>0:
                [x, y] = stack.pop()
                visited[x][y] = True
                neighbors = self.neighbors(x, y)
                for [newx, newy] in neighbors:
                    if (self.grid[player][newx][newy] == 1) and (not visited[newx][newy]):
                        stack.append([newx, newy])

            #End the search by checking the other side corresponding to the player
            if player == 0:
                for i in range(self.size):
                    if visited[0][i]:
                        self.done = True
                        return 1
            elif player == 1:
                for i in range(self.size):
                    if visited[i][self.size-1]:
                        self.done = True
                        return 2

        #0 means no one has won
        return 0

    def greedy(self, probs):
        #I need to filter out the illegal moves
        valid_probs = {}

        for i in range(self.size):
            for j in range(self.size):
                if self.grid[0][i][j] == 0 and self.grid[1][i][j] == 0:
                    valid_probs[(i, j)] = probs[i][j]
        #for now I'll just take the maximum probability
        sorted_vals = sorted(valid_probs.items(), key=lambda x: x[1], reverse=True)
        (best_x, best_y) = sorted_vals[0][0]
        
        return (best_x, best_y)
      
    def epsilon_greedy(self, probs, epsilon):
        #I need to filter out the illegal moves
        valid_probs = {}

        for i in range(self.size):
            for j in range(self.size):
                if self.grid[0][i][j] == 0 and self.grid[1][i][j] == 0:
                    valid_probs[(i, j)] = probs[i][j]
        
        sorted_vals = sorted(valid_probs.items(), key=lambda x: x[1], reverse=True)
        (best_x, best_y) = sorted_vals[0][0]

        choice = random.random()
        if choice < epsilon:
            (best_x, best_y) = sorted_vals[random.randint(0, len(sorted_vals)-1)][0]
        
        return (best_x, best_y)
    
    def random(self):
        valid_moves = []

        for i in range(self.size):
            for j in range(self.size):
                if self.grid[0][i][j] == 0 and self.grid[1][i][j] == 0:
                    valid_moves.append((i, j))

        index = random.randint(0, len(valid_moves)-1)
        (bext_x, best_y) = valid_moves[index]
        return (bext_x, best_y)

    def step(self, policy, probs = None, epsilon = 0):
        #probs will be a 2d array with probabilities for each hexagon of the grid

        if policy == 'epsilon_greedy':
            (best_x, best_y) = self.epsilon_greedy(probs, epsilon)
        elif policy == 'greedy':
            (best_x, best_y) = self.greedy(probs)
        elif policy == 'random':
            (best_x, best_y) = self.random()
        self.grid[self.current_player - 1][best_x][best_y] = 1
        
        if self.current_player == 1:
            self.current_player = 2
        else:
            self.current_player = 1

        return (best_x, best_y)

In [4]:
class Agent():
    def __init__(self, model, size):
        self.model = model
        self.size = size

    def action(self, grid):
        nn_input = torch.tensor(grid).float()
        nn_input = torch.flatten(nn_input)
        preds = self.model(nn_input)
        out = preds.detach().clone().numpy()
        out = out.reshape((self.size, self.size))
        out = out.tolist()
        return out, preds

In [30]:
model1 = torch.nn.Sequential(
    torch.nn.Linear(50, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 25),
    torch.nn.LogSoftmax()
)
model2 = torch.nn.Sequential(
    torch.nn.Linear(50, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 25),
    torch.nn.LogSoftmax()
)
agent1 = Agent(model1, 5)
agent2 = Agent(model2, 5)

In [22]:
class Controller():
  def __init__(self, size, player1, player2):
    self.size = size
    self.env = HexEnvironment(size)
    self.player1 = player1
    self.player2 = player2

    #These are going to be 3 dimensional. Dim 1 (size: games recorded so far): game playthroughs. Dim 2 (size: number of moves made in the game): moves. Dim 3 (size: 2) the move made and probs
    self.p1preds = []
    self.p2preds = []

    self.winner_history = []

    self.p1_optimizer = torch.optim.Adam(self.player1.model.parameters())
    self.p2_optimizer = torch.optim.Adam(self.player2.model.parameters())
  
  def run_game(self, epsilon=0, printgame = False):
    current_p1preds = []
    current_p2preds = []
    self.env.reset()
    while not self.env.is_done():
      obs = self.env.obs()
      if self.env.current_player == 1:
        move, pred = self.player1.action(obs)
        (movex, movey) = self.env.step('epsilon_greedy', move, epsilon)
        tile_num = movex*self.size + movey
        current_p1preds.append((tile_num, pred))
      else:
        move, pred = self.player2.action(obs)
        (movex, movey) = self.env.step('epsilon_greedy', move, epsilon)
        tile_num = movex*self.size + movey
        current_p2preds.append((tile_num, pred))
      winner = self.env.search_for_win()
      if winner != 0:
        self.winner_history.append(winner)
    if printgame:
      self.env.print_self()

    self.p1preds.append(current_p1preds)
    self.p2preds.append(current_p2preds)

  def test_game(self):
    with torch.no_grad():
      current_p1preds = []
      current_p2preds = []
      self.env.reset()
      self.winner = 0
      while not self.env.is_done():
        obs = self.env.obs()
        if self.env.current_player == 1:
          move, pred = self.player1.action(obs)
          (movex, movey) = self.env.step('greedy', probs=move)
          tile_num = movex*self.size + movey
          current_p1preds.append((tile_num, pred))
        else:
          move, pred = self.player2.action(obs)
          (movex, movey) = self.env.step('greedy', probs=move)
          tile_num = movex*self.size + movey
          current_p2preds.append((tile_num, pred))
        winner = self.env.search_for_win()
        if winner != 0:
          self.winner = winner
        #Since this is in the while loop, it will print the whole history of the game
        self.env.print_self()
      print(current_p1preds)
      print(current_p2preds)
  
  def backprop(self, clip=1):
    #player1 backprop

    p1total_loss = 0
    for i in range(len(self.p1preds)):
      game = self.p1preds[i]
      p1batch_loss = 0
      for (tile_num, pred) in game:
        p1batch_loss += pred[tile_num]
      p1batch_loss = p1batch_loss*1.0/len(game)
      if self.winner_history[i] == 2:
        p1batch_loss *= -1
      p1total_loss += p1batch_loss

    print("Player1 loss: ", p1total_loss)
    p1total_loss.backward()

    #gradient clipping
    torch.nn.utils.clip_grad_norm_(self.player1.model.parameters(), clip)

    self.p1_optimizer.step()

    self.player1.model.zero_grad()

    #player2 backprop

    p2total_loss = 0
    for i in range(len(self.p2preds)):
      game = self.p2preds[i]
      p2batch_loss = 0
      for (tile_num, pred) in game:
        p2batch_loss += pred[tile_num]
      p2batch_loss = p2batch_loss*1.0/len(game)
      if self.winner_history[i] == 1:
        p2batch_loss *= -1
      p2total_loss += p2batch_loss

    print("Player2 loss: ", p2total_loss)
    p2total_loss.backward()

    #gradient clipping
    torch.nn.utils.clip_grad_norm_(self.player2.model.parameters(), clip)

    self.p2_optimizer.step()

    self.player2.model.zero_grad()

    #reset histories
    self.p1preds = []
    self.p2preds = []

  def train(self, games, batch_size):
    for i in range(games):
      epsilon = max(math.e**(-games/1000), 0.1)
      self.run_game(epsilon)
      if i%batch_size == 0 and i != 0:
        self.backprop()

  def p1_against_random(self, n):
    with torch.no_grad():
      wins = 0.0
      total = 0.0
      for i in range(n):
        self.env.reset()
        self.winner = 0
        while not self.env.is_done():
          obs = self.env.obs()
          if self.env.current_player == 1:
            move, pred = self.player1.action(obs)
            (movex, movey) = self.env.step('greedy', probs=move)
          else:
            pred = 'random'
            (movex, movey) = self.env.step('random')
            tile_num = movex*self.size + movey
          winner = self.env.search_for_win()
          if winner == 1:
            self.winner = winner
            wins += 1
            total += 1
          elif winner == 2:
            self.winner = winner
            total += 1
        self.env.print_self()
      print(wins/total)

In [31]:
test = Controller(5, agent1, agent2)

In [19]:
test.test_game()

_ _ _ _ _ 
 _ _ _ _ _ 
  _ _ _ _ _ 
   1 _ _ _ _ 
    _ _ _ _ _ 
_ _ _ _ _ 
 _ _ _ _ 2 
  _ _ _ _ _ 
   1 _ _ _ _ 
    _ _ _ _ _ 
_ _ _ _ _ 
 _ _ _ _ 2 
  _ _ _ 1 _ 
   1 _ _ _ _ 
    _ _ _ _ _ 
_ _ 2 _ _ 
 _ _ _ _ 2 
  _ _ _ 1 _ 
   1 _ _ _ _ 
    _ _ _ _ _ 
_ _ 2 _ _ 
 _ _ _ _ 2 
  _ _ _ 1 1 
   1 _ _ _ _ 
    _ _ _ _ _ 
_ _ 2 _ _ 
 _ _ _ _ 2 
  _ 2 _ 1 1 
   1 _ _ _ _ 
    _ _ _ _ _ 
_ _ 2 _ _ 
 _ _ _ _ 2 
  _ 2 _ 1 1 
   1 1 _ _ _ 
    _ _ _ _ _ 
_ _ 2 _ _ 
 _ _ _ _ 2 
  _ 2 2 1 1 
   1 1 _ _ _ 
    _ _ _ _ _ 
_ _ 2 _ _ 
 1 _ _ _ 2 
  _ 2 2 1 1 
   1 1 _ _ _ 
    _ _ _ _ _ 
_ _ 2 _ _ 
 1 _ _ _ 2 
  _ 2 2 1 1 
   1 1 _ 2 _ 
    _ _ _ _ _ 
_ _ 2 _ _ 
 1 _ _ _ 2 
  _ 2 2 1 1 
   1 1 1 2 _ 
    _ _ _ _ _ 
_ 2 2 _ _ 
 1 _ _ _ 2 
  _ 2 2 1 1 
   1 1 1 2 _ 
    _ _ _ _ _ 
_ 2 2 _ _ 
 1 _ _ 1 2 
  _ 2 2 1 1 
   1 1 1 2 _ 
    _ _ _ _ _ 
_ 2 2 2 _ 
 1 _ _ 1 2 
  _ 2 2 1 1 
   1 1 1 2 _ 
    _ _ _ _ _ 
_ 2 2 2 _ 
 1 _ _ 1 2 
  _ 2 2 1 1 
   1 1 1 2 _ 
    _ _ 1 _ _ 
_ 2 2 2 _ 
 1 2 _ 1 2 
  

  input = module(input)


In [32]:
test.p1_against_random(10000)

  input = module(input)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 _ 1 1 2 _ 
  2 1 2 1 _ 
   1 2 2 1 1 
    2 1 _ _ _ 
1 1 1 2 2 
 2 1 1 2 2 
  _ _ 1 1 _ 
   2 1 1 2 1 
    2 _ 1 2 2 
_ 1 _ _ 2 
 1 2 2 2 _ 
  2 1 _ 1 _ 
   _ 1 _ 2 1 
    _ 2 1 1 2 
2 1 1 1 _ 
 1 1 2 2 2 
  2 1 _ 2 _ 
   _ 1 2 2 1 
    2 1 2 1 1 
_ 1 1 _ _ 
 _ 1 1 _ 2 
  2 2 1 1 _ 
   2 1 2 2 1 
    _ 1 2 2 _ 
2 1 1 _ 2 
 2 1 2 _ 1 
  1 2 _ 1 2 
   1 2 2 _ 1 
    1 1 2 2 1 
2 1 1 1 1 
 _ 2 1 1 2 
  2 2 2 2 _ 
   _ 1 2 2 1 
    _ 1 2 _ 1 
_ 1 2 2 2 
 1 1 1 1 2 
  2 2 2 1 _ 
   1 1 1 2 1 
    2 2 1 1 2 
2 1 1 1 2 
 1 1 _ 2 2 
  2 2 2 2 _ 
   2 1 1 2 1 
    _ 1 1 1 2 
2 1 2 1 _ 
 2 1 _ 2 2 
  2 1 _ 2 _ 
   2 1 _ 1 1 
    2 1 1 1 _ 
1 2 2 _ 2 
 1 1 1 2 2 
  2 1 2 1 2 
   1 1 1 1 1 
    _ 2 2 1 2 
2 1 1 2 _ 
 2 1 2 1 2 
  2 1 1 1 _ 
   2 1 2 1 1 
    1 2 _ 2 _ 
_ 1 2 1 2 
 1 1 2 2 2 
  2 1 2 2 2 
   1 1 2 1 1 
    2 1 1 1 _ 
2 1 2 2 2 
 2 1 1 2 _ 
  1 1 2 1 2 
   1 2 1 1 1 
    2 1 2 1 2 
1 2 1 2 _ 
 2 1 2 1 _ 
  2 1 1 1 2 


In [24]:
test.train(10000, 100)

  input = module(input)


Player1 loss:  tensor(58.7746, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-60.1457, grad_fn=<AddBackward0>)
Player1 loss:  tensor(56.3504, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-56.8562, grad_fn=<AddBackward0>)
Player1 loss:  tensor(56.1238, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-57.1430, grad_fn=<AddBackward0>)
Player1 loss:  tensor(56.5165, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-57.2345, grad_fn=<AddBackward0>)
Player1 loss:  tensor(56.2945, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-57.3443, grad_fn=<AddBackward0>)
Player1 loss:  tensor(56.1539, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-57.3599, grad_fn=<AddBackward0>)
Player1 loss:  tensor(55.7453, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-57.8684, grad_fn=<AddBackward0>)
Player1 loss:  tensor(56.2711, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-57.8059, grad_fn=<AddBackward0>)
Player1 loss:  tensor(56.5392, grad_fn=<AddBackward0>)
Player2 loss:  tensor(-58.2216, grad_fn=<AddBackward0>)
P

In [25]:
test.test_game()

_ _ _ _ _ 
 _ _ _ _ _ 
  _ _ _ _ _ 
   1 _ _ _ _ 
    _ _ _ _ _ 
_ _ _ _ _ 
 _ _ _ _ _ 
  _ _ _ _ _ 
   1 2 _ _ _ 
    _ _ _ _ _ 
_ _ _ _ _ 
 _ 1 _ _ _ 
  _ _ _ _ _ 
   1 2 _ _ _ 
    _ _ _ _ _ 
_ _ _ _ _ 
 _ 1 _ _ _ 
  _ _ _ _ _ 
   1 2 2 _ _ 
    _ _ _ _ _ 
_ _ _ _ _ 
 _ 1 _ _ _ 
  _ _ _ _ _ 
   1 2 2 _ 1 
    _ _ _ _ _ 
_ _ _ _ _ 
 _ 1 _ _ _ 
  2 _ _ _ _ 
   1 2 2 _ 1 
    _ _ _ _ _ 
_ _ _ 1 _ 
 _ 1 _ _ _ 
  2 _ _ _ _ 
   1 2 2 _ 1 
    _ _ _ _ _ 
_ _ _ 1 _ 
 _ 1 _ _ _ 
  2 _ _ _ _ 
   1 2 2 _ 1 
    2 _ _ _ _ 
_ _ _ 1 _ 
 _ 1 _ _ _ 
  2 _ 1 _ _ 
   1 2 2 _ 1 
    2 _ _ _ _ 
_ _ 2 1 _ 
 _ 1 _ _ _ 
  2 _ 1 _ _ 
   1 2 2 _ 1 
    2 _ _ _ _ 
_ 1 2 1 _ 
 _ 1 _ _ _ 
  2 _ 1 _ _ 
   1 2 2 _ 1 
    2 _ _ _ _ 
_ 1 2 1 _ 
 _ 1 _ _ _ 
  2 _ 1 _ _ 
   1 2 2 2 1 
    2 _ _ _ _ 
_ 1 2 1 _ 
 _ 1 _ _ _ 
  2 _ 1 _ _ 
   1 2 2 2 1 
    2 _ 1 _ _ 
_ 1 2 1 _ 
 _ 1 _ _ _ 
  2 2 1 _ _ 
   1 2 2 2 1 
    2 _ 1 _ _ 
_ 1 2 1 _ 
 _ 1 1 _ _ 
  2 2 1 _ _ 
   1 2 2 2 1 
    2 _ 1 _ _ 
_ 1 2 1 _ 
 2 1 1 _ _ 
  

  input = module(input)


In [29]:
test.p1_against_random(10000)

  input = module(input)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 _ 1 _ _ 2 
  1 _ 1 _ 2 
   1 _ 2 _ 1 
    1 2 _ _ _ 
_ 1 1 1 2 
 2 1 _ 2 _ 
  1 2 1 2 _ 
   1 _ 2 2 2 
    1 _ 1 _ _ 
2 1 2 1 2 
 _ 1 _ 2 _ 
  1 _ 1 _ _ 
   1 2 _ 2 1 
    1 2 1 _ 2 
_ 1 1 1 2 
 2 1 1 2 2 
  1 2 2 _ 1 
   1 2 2 2 1 
    2 1 1 _ 2 
_ _ 1 _ 2 
 _ 1 _ _ 2 
  1 _ 1 2 _ 
   1 2 2 _ 1 
    2 _ 1 2 _ 
_ _ 1 2 2 
 _ 1 _ _ 2 
  1 _ 1 2 2 
   1 2 _ _ 1 
    1 _ 1 _ 2 
_ 2 1 2 _ 
 2 1 _ _ _ 
  1 2 1 _ 2 
   1 _ _ _ 1 
    1 _ _ 2 _ 
2 1 2 1 1 
 _ 1 1 2 2 
  1 1 1 2 1 
   1 2 _ 2 1 
    2 2 2 _ 2 
_ 2 1 2 1 
 _ 1 1 2 _ 
  1 2 1 2 2 
   1 _ 1 2 1 
    2 1 1 2 2 
_ 1 1 1 1 
 2 1 2 1 _ 
  1 2 1 2 2 
   1 _ 1 2 2 
    2 1 2 2 _ 
_ 2 1 1 _ 
 _ 1 2 2 _ 
  1 2 2 _ _ 
   1 _ _ _ 1 
    1 _ 2 _ _ 
_ _ 1 1 _ 
 _ 1 2 _ 2 
  1 _ 2 2 _ 
   1 _ 2 _ 1 
    2 2 1 _ _ 
_ 2 1 1 _ 
 _ 1 _ 2 2 
  1 _ 1 2 _ 
   1 _ 2 _ 2 
    1 _ _ _ _ 
2 2 1 1 _ 
 _ 1 _ _ _ 
  1 2 2 _ _ 
   1 _ _ _ 1 
    1 _ 2 _ 2 
_ 2 1 1 _ 
 2 1 2 _ _ 
  1 _ 1 2 2 
