<a href="https://colab.research.google.com/github/joshuazhu17/Project_Studio/blob/main/HexGame.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import numpy as np
import random
import math

In [2]:
class HexEnvironment():
    def __init__(self, boardsize):
        self.size = boardsize
        self.grid = [[], []] #the first is where p1 puts its pieces, the second is where p2 puts its pieces
        for i in range(boardsize):
            p1_row = []
            p2_row = []
            for j in range(boardsize):
                p1_row.append(0)
                p2_row.append(0)
            self.grid[0].append(p1_row)
            self.grid[1].append(p2_row)
        self.current_player = 1
        self.done = False

    def neighbors(self, x, y):
        neighbors = [[x-1, y+1], [x-1, y], [x, y-1], [x+1, y-1], [x+1, y], [x, y+1]]
        real_neighbors = []
        for i in range(len(neighbors)):
            [a, b] = neighbors[i]
            if (a >= 0) and (a < self.size) and (b >= 0) and (b < self.size):
                real_neighbors.append([a, b])
        return real_neighbors

    def obs(self):
        return self.grid

    def is_done(self):
        return self.done

    def get_player(self):
        return self.current_player
    
    def reset(self):
        self.current_player = 1
        self.done = False
        for i in range(self.size):
            for j in range(self.size):
                self.grid[0][i][j] = 0
                self.grid[1][i][j] = 0

    def print_self(self):
        for i in range(self.size):
            next_line = ""
            next_line += " "*i
            for j in range(self.size):
                if self.grid[0][i][j] == 1:
                    next_line += "1"
                elif self.grid[1][i][j] == 1:
                    next_line += "2"
                else:
                    next_line += "_"
                next_line += " "
            print(next_line)

    def search_for_win(self):
        #Player 1 wants to connect bottom left ((10, 0), (10, 1), ..., (10, 10)) to top right ((0, 0), (0, 1), ... (0, 10))
        #Player 2 wants to connect top left ((0, 0), (1, 0), ... (10, 0)) to bottom right ((0, 10), (1, 10), ... (10, 10))
        #I'll do a dfs to find a winning path
        
        for player in range(2): #0 is p1, 1 is p2
            stack = []
            visited = [[False for i in range(self.size)] for j in range(self.size)]
            #Initialize search with the side corresponding to the player
            if player == 0:
                for i in range(self.size):
                    if self.grid[0][self.size-1][i] == 1:
                        stack.append([self.size-1, i])
                        visited[self.size-1][i] = True
            elif player == 1:
                for i in range(self.size):
                    if self.grid[1][i][0] == 1:
                        stack.append([i, 0])
                        visited[i][0] = True

            #Do the search
            while len(stack)>0:
                [x, y] = stack.pop()
                visited[x][y] = True
                neighbors = self.neighbors(x, y)
                for [newx, newy] in neighbors:
                    if (self.grid[player][newx][newy] == 1) and (not visited[newx][newy]):
                        stack.append([newx, newy])

            #End the search by checking the other side corresponding to the player
            if player == 0:
                for i in range(self.size):
                    if visited[0][i]:
                        self.done = True
                        return 1
            elif player == 1:
                for i in range(self.size):
                    if visited[i][self.size-1]:
                        self.done = True
                        return 2

        #0 means no one has won
        return 0

    def greedy(self, probs):
        #I need to filter out the illegal moves
        valid_probs = {}

        for i in range(self.size):
            for j in range(self.size):
                if self.grid[0][i][j] == 0 and self.grid[1][i][j] == 0:
                    valid_probs[(i, j)] = probs[i][j]
        #for now I'll just take the maximum probability
        sorted_vals = sorted(valid_probs.items(), key=lambda x: x[1], reverse=True)
        (best_x, best_y) = sorted_vals[0][0]
        
        return (best_x, best_y)
      
    def epsilon_greedy(self, probs, epsilon):
        #I need to filter out the illegal moves
        valid_probs = {}

        for i in range(self.size):
            for j in range(self.size):
                if self.grid[0][i][j] == 0 and self.grid[1][i][j] == 0:
                    valid_probs[(i, j)] = probs[i][j]
        
        sorted_vals = sorted(valid_probs.items(), key=lambda x: x[1], reverse=True)
        (best_x, best_y) = sorted_vals[0][0]

        choice = random.random()
        if choice < epsilon:
            (best_x, best_y) = sorted_vals[random.randint(0, len(sorted_vals)-1)][0]
        
        return (best_x, best_y)

    def step(self, probs, policy, epsilon = 0):
        #probs will be a 2d array with probabilities for each hexagon of the grid

        if policy == 'epsilon_greedy':
            (best_x, best_y) = self.epsilon_greedy(probs, epsilon)
        elif policy == 'greedy':
            (best_x, best_y) = self.greedy(probs)
        self.grid[self.current_player - 1][best_x][best_y] = 1
        
        if self.current_player == 1:
            self.current_player = 2
        else:
            self.current_player = 1

        return (best_x, best_y)

In [3]:
class Agent():
    def __init__(self, model, size):
        self.model = model
        self.size = size

    def action(self, grid):
        nn_input = torch.tensor(grid).float()
        nn_input = torch.flatten(nn_input)
        preds = self.model(nn_input)
        out = preds.detach().clone().numpy()
        out = out.reshape((self.size, self.size))
        out = out.tolist()
        return out, preds

In [14]:
model1 = torch.nn.Sequential(
    torch.nn.Linear(50, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 25),
    torch.nn.LogSoftmax()
)
model2 = torch.nn.Sequential(
    torch.nn.Linear(50, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 25),
    torch.nn.LogSoftmax()
)
agent1 = Agent(model1, 5)
agent2 = Agent(model2, 5)

In [11]:
class Controller():
  def __init__(self, size, player1, player2):
    self.size = size
    self.env = HexEnvironment(size)
    self.player1 = player1
    self.player2 = player2
    self.p1preds = []
    self.p2preds = []
    self.winner = 0
    self.p1_optimizer = torch.optim.Adam(self.player1.model.parameters())
    self.p2_optimizer = torch.optim.Adam(self.player2.model.parameters())
  
  def run_game(self, epsilon=0):
    self.env.reset()
    self.winner = 0
    while not self.env.is_done():
      obs = self.env.obs()
      if self.env.current_player == 1:
        move, pred = self.player1.action(obs)
        (movex, movey) = self.env.step(move, 'epsilon_greedy', epsilon)
        tile_num = movex*self.size + movey
        self.p1preds.append((tile_num, pred))
      else:
        move, pred = self.player2.action(obs)
        (movex, movey) = self.env.step(move, 'epsilon_greedy', epsilon)
        tile_num = movex*self.size + movey
        self.p2preds.append((tile_num, pred))
      winner = self.env.search_for_win()
      if winner != 0:
        self.winner = winner
    self.env.print_self()
  
  def backprop(self, clip=1):
    p1advantage = 1
    p2advantage = -1
    if self.winner == 2:
      p1advantage = -1
      p2advantage = 1
    #player1 backprop

    p1preds_tensor = self.p1preds[0][1].view(1, -1)
    target_tensor1 = torch.tensor([self.p1preds[0][0]])
    for i in range(2, len(self.p1preds)):
      tile_num = torch.tensor([self.p1preds[i][0]])
      pred = self.p1preds[i][1].view(1, -1)

      p1preds_tensor = torch.cat((p1preds_tensor, pred))
      target_tensor1 = torch.cat((target_tensor1, tile_num))

    criterion = torch.nn.NLLLoss()

    loss = criterion(p1preds_tensor, target_tensor1)
    loss = loss*p1advantage/(len(self.p1preds)**2)
    print(loss)
    loss.backward()

    #gradient clipping
    torch.nn.utils.clip_grad_norm_(self.player1.model.parameters(), clip)

    self.p1_optimizer.step()

    self.player1.model.zero_grad()
    #player2 backprop

    p2preds_tensor = self.p2preds[0][1].view(1, -1)
    target_tensor2 = torch.tensor([self.p2preds[0][0]])
    for i in range(2, len(self.p2preds)):
      tile_num = torch.tensor([self.p2preds[i][0]])
      pred = self.p2preds[i][1].view(1, -1)

      p2preds_tensor = torch.cat((p2preds_tensor, pred))
      target_tensor2 = torch.cat((target_tensor2, tile_num))

    criterion = torch.nn.NLLLoss()

    loss = criterion(p2preds_tensor, target_tensor2)
    loss = loss*p2advantage/(len(self.p2preds)**2)
    print(loss)
    loss.backward()

    #gradient clipping
    torch.nn.utils.clip_grad_norm_(self.player2.model.parameters(), clip)

    self.p2_optimizer.step()

    self.player2.model.zero_grad()

    #reset histories
    self.p1preds = []
    self.p2preds = []

  def train(self, games):
    for i in range(games):
      epsilon = max(math.e**(-games/1000), 0.1)
      self.run_game(epsilon)
      self.backprop()

In [15]:
test = Controller(5, agent1, agent2)

In [16]:
test.train(10000)

  input = module(input)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
tensor(-2034.0027, grad_fn=<DivBackward0>)
_ _ _ 1 _ 
 _ _ 2 1 _ 
  2 _ 1 _ _ 
   _ 1 _ 2 _ 
    1 _ _ 2 _ 
tensor(0.0020, grad_fn=<DivBackward0>)
tensor(-2702.1467, grad_fn=<DivBackward0>)
_ _ _ 1 2 
 _ _ 2 1 _ 
  2 _ 1 _ _ 
   _ 1 _ _ _ 
    1 _ _ 2 _ 
tensor(0.0020, grad_fn=<DivBackward0>)
tensor(-2702.8721, grad_fn=<DivBackward0>)
_ _ _ 1 _ 
 _ _ 2 1 2 
  2 _ 1 _ _ 
   _ 1 _ _ _ 
    1 2 _ _ _ 
tensor(0.0027, grad_fn=<DivBackward0>)
tensor(-2990.5916, grad_fn=<DivBackward0>)
_ _ _ 1 _ 
 _ _ 2 1 2 
  2 _ 1 _ 1 
   2 2 1 _ _ 
    _ _ 1 _ _ 
tensor(0.0042, grad_fn=<DivBackward0>)
tensor(-1645.8611, grad_fn=<DivBackward0>)
_ _ 1 _ _ 
 _ _ 1 2 _ 
  2 _ 1 _ 2 
   _ 1 2 _ _ 
    1 _ _ _ _ 
tensor(0.0030, grad_fn=<DivBackward0>)
tensor(-2940.2488, grad_fn=<DivBackward0>)
_ _ 1 _ _ 
 _ _ 1 2 _ 
  2 _ 1 _ 2 
   2 1 _ _ _ 
    1 _ _ _ _ 
tensor(0.0018, grad_fn=<DivBackward0>)
tensor(-2598.6199, grad_fn=<DivBackward0>)
_ _ 1 _ _ 