In [1]:
import pickle
import random

import numpy as np
import torch

In [2]:
import torch.nn as nn
import torch.nn.functional as F

NUM_MOVES = 81 # Size of the policy output: upper bound on the valid moves in a game of NestedTTT


class NestedTTTNet(nn.Module):
    """
    Two-headed (policy, value) convolutional network for nested tic-tac-toe.

    Every 3x3 inner board is processed on its own and collapsed to a single
    feature vector.  The nine vectors then form a 3x3 "outer" feature map which is
    processed and collapsed the same way before it feeds the two heads.

    Each stage owns exactly one residual convolution and one BatchNorm2d layer;
    both are reused on every pass through that stage (n_res_layers only controls
    how many times they are applied).
    """

    def __init__(self, inner_board_units = 32, outer_board_units = 64, final_units = 256, n_res_layers = 2):
        """
        inner_board_units: channels used while looking at a single inner board
        outer_board_units: channels used once each inner board is one outer square
        final_units:       width of the flat feature vector shared by both heads
        n_res_layers:      residual passes applied in each of the two stages
        """
        super().__init__()
        self.n_res_layers = n_res_layers
        self.outer_board_units = outer_board_units
        self.final_units = final_units

        # Inner-board stage: padded 3x3 convolutions keep the 3x3 spatial size
        self.conv1 = nn.Conv2d(3, inner_board_units, 3, padding = 1)
        self.bn_inner = nn.BatchNorm2d(inner_board_units)
        self.res1 = nn.Conv2d(inner_board_units, inner_board_units, 3, padding = 1)

        # Unpadded 3x3 convolution: squeezes each inner board down to 1x1
        self.conv2_collapse = nn.Conv2d(inner_board_units, outer_board_units, 3)
        self.bn_outer = nn.BatchNorm2d(outer_board_units)
        self.res2 = nn.Conv2d(outer_board_units, outer_board_units, 3, padding = 1)

        # Unpadded 3x3 convolution: squeezes the outer board down to 1x1
        self.conv3_collapse = nn.Conv2d(outer_board_units, final_units, 3)
        self.bn_linear = nn.BatchNorm1d(final_units)

        policy_layers = [nn.Linear(final_units, NUM_MOVES), nn.Softmax(dim = 1)]
        self.policy_head = nn.Sequential(*policy_layers)

        half_units = final_units // 2
        value_layers = [nn.Linear(final_units, half_units), nn.ReLU(), nn.Linear(half_units, 1), nn.Tanh()]
        self.value_head = nn.Sequential(*value_layers)

    def _residual_stack(self, x, conv, bn):
        """Apply n_res_layers residual passes, reusing the same conv and batch-norm every time."""
        for _ in range(self.n_res_layers):
            hidden = F.relu(bn(conv(x)))
            x = F.relu(x + bn(conv(hidden)))
        return x

    def forward(self, board_states):
        '''
        board_states.size() == (-1, 3, 3, 3, 3, 3)
                            == (batch_size, [x-owned, o-owned, turn-state], outer_row, outer_col, inner_row, inner_col)

        Returns (policy, value):
          policy.size() == (batch_size, NUM_MOVES), each row a softmax distribution
          value.size()  == (batch_size, 1), squashed by tanh
        '''
        # Fold the outer row/col axes into the batch axis so that every inner board
        # is convolved independently before any inter-square mixing happens
        per_square = board_states.permute(0, 2, 3, 1, 4, 5).contiguous()
        per_square = per_square.view(-1, 3, 3, 3)

        features = F.relu(self.bn_inner(self.conv1(per_square)))
        features = self._residual_stack(features, self.res1, self.bn_inner)

        # One feature vector per inner board, then unfold back into a
        # (batch, channels, outer_row, outer_col) map
        collapsed = F.relu(self.bn_outer(self.conv2_collapse(features)))
        outer = collapsed.view(-1, 3, 3, self.outer_board_units)
        outer = outer.permute(0, 3, 1, 2).contiguous()
        outer = self._residual_stack(outer, self.res2, self.bn_outer)

        flat = self.conv3_collapse(outer).view(-1, self.final_units)
        shared = F.relu(self.bn_linear(flat))

        return self.policy_head(shared), self.value_head(shared)

In [3]:
import torch

class NestedTTT:
    """
    Game state for nested tic-tac-toe: a 3x3 outer board whose squares are 3x3 inner boards.

    A player who completes a line on an inner board claims that outer square; a player
    who completes a line of claimed outer squares wins the game.

    Moves are 5-tuples ({0 for X, 1 for O}, outer row, outer col, inner row, inner col).
    The turn marker stored in ``state[2]`` is flipped by make_move/undo_move, whereas
    ``current_player`` is only changed by switch_player; callers keep the two in sync.
    """

    # 3x3 mask of the anti-diagonal.  check_win no longer needs it (uint8 masks are
    # deprecated in PyTorch); it is kept so that existing references keep working.
    anti_diag_tensor = torch.ByteTensor([[1 if r + c == 2 else 0 for c in range(3)] for r in range(3)])

    def __init__(self):
        #Board State Tensor
        #Outermost dimension:
        #First value represents Player 0's (X's) holdings, second represents O's
        #Third is filled with the number of the current player to move (0 = X, 1 = O)
        #Rest of the dimensions are (Outer Row, Outer Column, Inner Row, Inner Column)
        self.state = torch.full((3, 3, 3, 3, 3), 0, dtype = torch.float32, requires_grad = False)

        #Valid Move Tensor
        #(Outer Row, Outer Column, Inner Row, Inner Column)
        self.valid_moves = torch.full((3, 3, 3, 3), 1, dtype = torch.uint8, requires_grad = False)

        #Summary Board Tensor
        #(Player, Row, Column)
        self.summary_boards = torch.full((2, 3, 3), 0, dtype = torch.uint8, requires_grad = False)

        #Current Player
        self.current_player = 0

    def check_win(self, board):
        """
        Takes a 3x3 tensor holding 1 where a single player owns a square and 0 elsewhere.

        Short-circuits if that player has less than 3 squares on the board
        If they do, checks all rows, columns, and diagonals for a win

        Returns 0 if the player did not win the board in question, or 1 if they did
        """
        if board.sum().item() < 3:
            return 0

        # Plain indexing instead of masked_select with a uint8 mask
        anti_diag_total = board[0, 2] + board[1, 1] + board[2, 0]
        line_complete = (
            (board.sum(0) == 3).any().item() or    # any full column
            (board.sum(1) == 3).any().item() or    # any full row
            (board.diag().sum() == 3).item() or    # main diagonal
            (anti_diag_total == 3).item()          # anti-diagonal
        )
        return int(line_complete)

    def make_move(self, move):
        """
        Takes a move:
          ({0 for X, 1 for O}, outer row, outer col, inner row, inner col)
        Returns 1 if the move won the game, 0 otherwise

        Raises ValueError if the square is not playable or if the move's player is not
        the one marked as next to move in the state tensor.
        Does not touch self.current_player; call switch_player for that.
        """
        # A list would trigger advanced (fancy) indexing below, so force a tuple
        move = tuple(move)

        if not self.valid_moves[move[1:]].item():
            raise ValueError("Invalid move placement")
        if move[0] != self.state[2,0,0,0,0].item():
            raise ValueError("Only the current player can make a move")
        self.state[move] = 1
        self.valid_moves[move[1:]] = 0
        self.state[2] = 1 - self.state[2]  # hand the turn marker to the other player

        if self.check_win(self.state[move[:3]]):
            # Inner board won: record it and close the whole inner board
            self.summary_boards[move[:3]] = 1
            self.valid_moves[move[1:3]] = 0
            if self.check_win(self.summary_boards[move[0]]):
                # Game over: nothing is playable any more
                self.valid_moves.fill_(0)
                return 1

        return 0

    def undo_move(self, move):
        """
        Reverts `move`, which is expected to be the most recently played move.

        Raises ValueError if the square is empty, if the move belongs to
        self.current_player, or if the state tensor's turn marker and
        self.current_player disagree.
        Like make_move, this flips the turn marker but leaves self.current_player alone.
        """
        move = tuple(move)

        if not self.state[move].item():
            raise ValueError("Can only undo where a move has been played")
        if move[0] == self.current_player:
            raise ValueError("Cannot undo a move for the current player")
        if self.state[(2,0,0,0,0)].item() != self.current_player:
            raise ValueError("Desynchronization between current player and state tensor")
        self.state[move] = 0
        self.summary_boards[move[:3]] = 0
        self.state[2] = 1 - self.state[2]

        # Rebuild validity for the whole board rather than just the touched inner board:
        # a game-winning move zeroes every entry, so a local repair would leave the
        # other boards closed.  A square is playable when it is empty and its inner
        # board has not been claimed by either player.
        open_cells = self.state[:2].sum(0) == 0
        open_boards = (self.summary_boards.sum(0) == 0).view(3, 3, 1, 1)
        self.valid_moves.copy_(open_cells & open_boards)

    def switch_player(self):
        """Toggles self.current_player between 0 and 1."""
        self.current_player = (self.current_player + 1) % 2

    def copy(self):
        """Returns an independent copy of the game (all tensors are cloned)."""
        c = self.__new__(NestedTTT)
        c.state = self.state.clone().detach()
        c.valid_moves = self.valid_moves.clone().detach()
        c.summary_boards = self.summary_boards.clone().detach()
        c.current_player = self.current_player
        return c

    def reset(self):
        """Restores the empty starting position."""
        self.__init__()

    def get_valid_moves(self):
        """Returns the playable squares as a list of [outer row, outer col, inner row, inner col] lists."""
        return self.valid_moves.nonzero().tolist()

In [4]:
from math import sqrt
import random

# Constant multiplying the prior-weighted exploration bonus in Edge.value_estimate
EXPLORATION_PARAM = sqrt(2)

class Node:
    """
    Nodes represent a state in the game.  They can have edges which
    indicate legal actions to take from the state

    parent_edge: the Edge that leads into this node (None for a root)
    leaf:        True until expansion() has added at least one edge
    edges:       dict mapping an action to the Edge taken by that action
    """
    def __init__(self, incoming_edge):
        self.parent_edge = incoming_edge
        self.leaf = True
        self.edges = dict()

    def choose_action(self, temperature = 1):
        """
        Samples an action with probability proportional to N ** (1 / temperature),
        where N is the visit count of the action's edge.

        temperature == 0 picks the most visited action (first one on ties).
        If no edge has been visited yet, an action is picked uniformly at random.

        Raises ValueError if the node has no edges.
        """
        if not self.edges:
            raise ValueError("Cannot choose action from an unexplored node")

        actions = list(self.edges)
        visits = [self.edges[action].N for action in actions]

        if temperature == 0:
            # Limit of the distribution as the temperature goes to zero
            return actions[max(range(len(actions)), key = visits.__getitem__)]

        weights = [n ** (1 / temperature) for n in visits]
        visits_exp = sum(weights)
        if visits_exp == 0:
            # Nothing visited yet: dividing by the total would fail, so fall back to uniform
            return random.choice(actions)

        thresh = random.random()
        cumsum = 0
        for action, weight in zip(actions, weights):
            cumsum = cumsum + weight
            if thresh < cumsum / visits_exp:
                return action

        # Only reachable when the weights overflow (inf / inf is nan, which compares False)
        return actions[-1]

    def selection_step(self):
        """
        Returns the action whose edge has the highest value_estimate
        (the first such action on ties).

        Raises ValueError if the node is still a leaf.
        """
        if self.leaf:
            raise ValueError("Cannot perform selection step at a leaf")

        total_trials = sum(edge.N for edge in self.edges.values())
        return max(self.edges, key = lambda action: self.edges[action].value_estimate(total_trials))

    def expansion(self, actions, ps):
        """
        Adds one edge per action, pairing actions[i] with prior ps[i].
        Does nothing (the node stays a leaf) when `actions` is empty.
        """
        if actions:
            for action, p in zip(actions, ps):
                self.edges[action] = Edge(self, p)
            self.leaf = False

    def backpropagate(self, value):
        """
        Starts a backup from this node; `value` is the value of this node's state.

        Raises ValueError if the node has no incoming edge.
        """
        if self.parent_edge is None:
            raise ValueError("Can't start backpropagating at the root")
        else:
            #Current state for player A is a result of the action (edge) chosen by player B
            #If the current state has a good value for player A, then it has a bad value for player B
            #Value of current state (node) is inverted to update value of action taken (incoming edge)
            self.parent_edge.backpropagate(-value)

class Edge:
    """
    An action leading from one game state (origin node) to the next (destination node).

    Each edge stores
      N: how many backups have passed through it
      Q: the running mean of the values backed up through it
      P: the prior probability the network assigned to its action
    """
    def __init__(self, origin, p):
        self.origin = origin
        self.destination = Node(self)
        self.N, self.Q = 0, 0 # Visits, value estimate
        self.P = p            # NN-generated prior

    def backpropagate(self, value):
        """
        Records `value` on this edge, then walks up through the incoming edge of each
        origin node until a node without an incoming edge is reached.
        """
        edge, signed_value = self, value
        while edge is not None:
            edge.N += 1
            # Running mean: previous mean scaled by (N - 1) / N plus the new sample's share
            edge.Q = edge.Q * ((edge.N - 1) / edge.N) + signed_value / edge.N

            #Consecutive edges belong to alternating players, so the sign flips at every level
            edge = edge.origin.parent_edge
            signed_value = -signed_value

    def value_estimate(self, parent_trials):
        """
        Score used during selection: Q plus an exploration bonus that grows with the
        prior and with the parent's total visits, and shrinks as this edge is visited.
        """
        exploration_bonus = EXPLORATION_PARAM * self.P * parent_trials ** .5 / (1 + self.N)
        return self.Q + exploration_bonus

    def __str__(self):
        return "N={}, Q={}, P={}".format(self.N, self.Q, self.P)

    def __repr__(self):
        return self.__str__()

In [15]:
from time import time
import logging

class AlphaMCTSAgent:
    """
    Monte Carlo tree search guided by a policy/value network.

    control_net is called with a batch of one game state and must return
    (policy, value); the policy provides edge priors and the value is backed up
    through the tree.  The tree persists between searches: take_action moves the
    root to the chosen child so its statistics can be reused.
    """
    def __init__(self, control_net = None):
        self.root = Node(None)
        self.playout_total = 0
        self.control_net = control_net

    def update_control_net(self, control_net):
        """Replaces the network used to evaluate positions."""
        self.control_net = control_net

    def search(self, game, turn_num, allotted_playouts = 800):
        """
        Runs `allotted_playouts` playouts from `game` (which is copied, never modified)
        and picks a move for the player to move.

        turn_num selects the sampling temperature: 1 before turn 30, .1 afterwards.

        Returns (action, mcts_probs):
          action:     5-tuple (player, outer row, outer col, inner row, inner col)
          mcts_probs: list of 81 floats, the share of root visits per board position
                      (uniform over the legal positions if the root has no visits)

        Raises ValueError if no control net is set, or (from Node.choose_action)
        if the root could not be expanded because the game has no valid moves.
        """
        if self.control_net is None:
            raise ValueError("Control net must be set before starting searches")
        playouts = 0

        while playouts < allotted_playouts:
            current_game, node = self._selection(game.copy())

            positions = current_game.get_valid_moves()
            player_to_move = round(current_game.state[2,0,0,0,0].item())

            if positions:
                # Inference only: skip building an autograd graph for every playout
                with torch.no_grad():
                    ps, v = self.control_net(current_game.state.unsqueeze(0))
                ps = ps.squeeze()
                v = v.squeeze().item()

                indices = [self.ttt_position_to_index(position) for position in positions]
                actions = [(player_to_move,) + tuple(position) for position in positions]
                node.expansion(actions, ps.detach().numpy()[indices])
            else:
                # Finished game: back up the real outcome instead of the network's guess
                v = self._terminal_value(current_game, player_to_move)

            if node is not self.root:
                node.backpropagate(v)

            playouts += 1

        temp = 1 if turn_num < 30 else .1
        max_action = self.root.choose_action(temp)

        # debug info
        self.playout_total = playouts
        self.action_node = max_action

        player_to_move = round(game.state[2,0,0,0,0].item())
        visit_counts = {}
        for position in game.get_valid_moves():
            action = (player_to_move,) + tuple(position)
            visit_counts[self.ttt_position_to_index(position)] = self.root.edges[action].N

        total_trials = sum(visit_counts.values())
        if total_trials:
            mcts_probs = [visit_counts.get(x, 0) / total_trials for x in range(81)]
        else:
            # Happens when too few playouts ran to visit any edge of the root
            mcts_probs = [1 / len(visit_counts) if x in visit_counts else 0 for x in range(81)]

        return max_action, mcts_probs

    def take_action(self, action):
        """
        Moves the root to the child reached by `action`, keeping that subtree.
        If the action is not in the tree, the tree is discarded.
        """
        try:
            self.root = self.root.edges[action].destination
        except KeyError:
            print("Action not found, throwing away tree")
            self.reset()
        else:
            # Cut the link to the old tree: backups now stop at the new root and the
            # discarded part of the tree is no longer referenced from here
            self.root.parent_edge = None

    def reset(self):
        """Throws the search tree away."""
        self.root = Node(None)

    def _selection(self, state):
        """
        Progress through the tree of Nodes, starting at the root, until a
        leaf is found, applying each chosen move to `state` along the way.

        Each step takes the action returned by Node.selection_step.

        Returns (state, leaf node); `state` is modified in place.
        """
        node = self.root
        while not node.leaf:
            move = node.selection_step()
            node = node.edges[move].destination
            state.make_move(move)
        return state, node

    @staticmethod
    def _terminal_value(game, player_to_move):
        """
        Value of a finished game for the player who would move next:
        -1 if the previous player has a winning line of outer squares, 0 otherwise (draw).
        """
        previous_player = 1 - player_to_move
        if game.check_win(game.summary_boards[previous_player]):
            return -1
        return 0

    @staticmethod
    def index_to_ttt_position(idx):
        """Inverse of ttt_position_to_index: 0..80 -> (outer row, outer col, inner row, inner col)."""
        return idx // 27 % 3, idx // 9 % 3, idx // 3 % 3, idx % 3

    @staticmethod
    def ttt_position_to_index(position):
        """Flattens (outer row, outer col, inner row, inner col) into an index in 0..80."""
        return position[0] * 27 + position[1] * 9 + position[2] * 3 + position[3]

In [16]:
from collections import deque
import random

class ReplayBuffer:
    """
    Fixed-capacity store of training samples backed by a deque.

    Once `capacity` items are held, adding more drops items from the opposite
    (oldest) end of the deque.
    """

    def __init__(self, capacity = 100000):
        self.buffer = deque([], capacity)

    def push(self, data):
        """Appends a single item."""
        self.buffer.append(data)

    def extend(self, data):
        """Appends every item of the iterable `data`, in order."""
        self.buffer += data

    def sample(self, batch_size):
        """Returns `batch_size` distinct items chosen at random (random.sample semantics)."""
        drawn = random.sample(self.buffer, batch_size)
        return drawn

    def __len__(self):
        """Number of items currently stored."""
        return self.buffer.__len__()

In [17]:
# One class that has access to:
  # Game simulation
  # Control network
  # Current network
  # Replay buffer
  # MCTS agent
  # Training code  -- Another process?

In [22]:
import torch
import random

def zero_gen():
    """Endless generator of 0; used as the reward for every state of a drawn game."""
    reward = 0
    while True:
        yield reward

def one_neg_one_gen():
    """Endless generator alternating 1, -1, 1, -1, ... (starting with 1)."""
    reward = 1
    while True:
        yield reward
        reward = -reward

class SelfPlayTrainer:
    """
    Ties together the game, the MCTS agent, the replay buffer and two copies of the network:

      control_network: kept in eval mode; used by the agent to generate self-play data
      current_network: the copy that train_on_batch optimises

    compare_control_to_train plays the two against each other and promotes the
    current network to control network when it scores well enough.
    Note that run() itself never calls compare_control_to_train.
    """
    def __init__(self, agent, game, buffer_file = None, weights_file = None, n_batches = 0):
        """
        agent:        MCTS agent; its control net is set to this trainer's control network
        game:         game object used (and reset) for self-play
        buffer_file:  optional path to a pickled replay buffer container
        weights_file: optional path to a state dict for the control network
        n_batches:    number of batches already trained (used in saved weight file names)
        """
        self.agent = agent
        self.game = game
        self.replay_buffer = ReplayBuffer()

        if buffer_file is not None:
            # NOTE: unpickling can execute arbitrary code; only load files you created yourself
            with open(buffer_file, "rb") as buffer_handle:
                self.replay_buffer.buffer = pickle.load(buffer_handle)

        self.current_network = NestedTTTNet()
        self.control_network = NestedTTTNet()

        if weights_file is not None:
            self.control_network.load_state_dict(torch.load(weights_file))

        # Both networks start from identical weights
        self.current_network.load_state_dict(self.control_network.state_dict())
        self.control_network.eval()
        self.current_network.train()

        self.agent.update_control_net(self.control_network)

        self.n_batches = n_batches

        # NOTE(review): 10e-4 equals 1e-3; confirm that 1e-4 was not the intended weight decay
        self.optim = torch.optim.Adam(self.current_network.parameters(), lr = .01, weight_decay = 10e-4)

    def generate_self_play_data(self, n_games = 100):
        """
        Plays `n_games` games of the agent against itself and stores one
        (state, mcts visit distribution, reward) sample per move in the replay buffer.

        Rewards are 0 for every state of a drawn game.  For a decided game the last
        recorded state (where the winner was to move) gets 1, the one before it -1,
        and so on alternating back to the first move.
        """
        for _ in range(n_games):
            turn_num = 0
            self.game.reset()
            self.agent.reset()
            result = 0

            states = []
            move_vectors = []

            while len(self.game.get_valid_moves()) > 0:
                move, move_probs = self.agent.search(self.game.copy(), turn_num, allotted_playouts = 400)

                states.append(self.game.state.tolist())
                move_vectors.append(move_probs)

                result = self.game.make_move(move)
                if not result:
                    self.game.switch_player()
                    self.agent.take_action(move)
                    turn_num += 1

            if not result:
                self.replay_buffer.extend(list(zip(states, move_vectors, zero_gen())))
            else:
                self.replay_buffer.extend(list(zip(states[::-1], move_vectors[::-1], one_neg_one_gen()))[::-1])

    def compare_control_to_train(self, n_games = 100, allotted_playouts = 800):
        """
        Plays `n_games` games between the control network (agent 0) and the current
        network (agent 1); each moves first in half of them.

        If the current network scores at least 55% (a win counts 1, a tie .5), its
        weights are saved to "control_weights_<n_batches>.pth" and copied into the
        control network.  Afterwards the control network is in eval mode and the
        current network in train mode again.
        """
        self.current_network.eval()
        old_agent = AlphaMCTSAgent(control_net = self.control_network)
        new_agent = AlphaMCTSAgent(control_net = self.current_network)

        agents = [old_agent, new_agent]

        wins = 0
        ties = 0

        game = self.game.copy()

        for game_num in range(n_games):
            game.reset()
            agents[0].reset()
            agents[1].reset()
            result = 0
            player_num = int(game_num >= n_games // 2) #Index of the agent that moves first
            turn_num = 100 #Turn down the temperature

            while len(game.get_valid_moves()) > 0:
                move, _ = agents[player_num].search(game.copy(), turn_num, allotted_playouts = allotted_playouts)
                # The waiting agent searches too, so the chosen move exists in its tree
                # when take_action is called below
                _, _ = agents[1 - player_num].search(game.copy(), turn_num, allotted_playouts = allotted_playouts)

                result = game.make_move(move)
                if not result:
                    game.switch_player()
                    agents[0].take_action(move)
                    agents[1].take_action(move)
                    player_num = (player_num + 1) % 2

            # player_num is not advanced after a winning move, so it still names the winner
            if not result:
                ties += 1
            elif player_num == 1:
                wins += 1

        # Same as the original "score >= 55" for 100 games, without float rounding issues
        if (wins + .5 * ties) * 100 >= 55 * n_games:
            print("Challenger network won {} games and tied {} games; it becomes new control network".format(wins, ties))
            torch.save(self.current_network.state_dict(), "control_weights_{}.pth".format(self.n_batches))
            self.control_network.load_state_dict(self.current_network.state_dict())
        else:
            print("Challenger network not sufficiently better; {} wins and {} ties".format(wins, ties))

        self.control_network.eval()
        self.current_network.train()

    def train_on_batch(self, batch_size = 32):
        """
        Performs one optimisation step of the current network on a random batch.

        Loss = mean squared error between predicted value and reward, plus the
        cross-entropy between the MCTS visit distribution and the predicted policy,
        averaged over the batch.

        Returns the loss as a float, or None (without training) when the replay
        buffer holds fewer than `batch_size` samples.
        """
        if len(self.replay_buffer) < batch_size:
            return

        self.current_network.train()

        sample = self.replay_buffer.sample(batch_size)
        states, probs, rewards = zip(*sample)
        # Inputs and targets are plain data; only the network parameters need gradients
        states = torch.tensor(states, dtype = torch.float32)
        probs = torch.tensor(probs, dtype = torch.float32)
        rewards = torch.tensor(rewards, dtype = torch.float32).unsqueeze(1)
        self.optim.zero_grad()

        ps, vs = self.current_network(states)

        value_loss = torch.nn.functional.mse_loss(vs, rewards)
        # Clamp so that a policy entry that underflowed to 0 cannot produce 0 * -inf = nan;
        # average over the batch so that both terms are on a per-sample scale
        policy_loss = -(probs * ps.clamp(min = 1e-8).log()).sum(dim = 1).mean()
        loss = value_loss + policy_loss
        loss.backward()

        self.optim.step()

        self.n_batches += 1

        return loss.item()

    def run(self, total_runs = 100, self_play_games = 20, training_batches = 30, batch_size = 32):
        """
        Repeats `total_runs` times: play `self_play_games` self-play games, then train
        on `training_batches` batches.  The summed loss is printed every 5 trained batches.
        Batches skipped because the buffer is still too small are not counted.
        """
        losses = []
        for run_num in range(1, total_runs+1):
            print("Run {} of {}".format(run_num, total_runs))
            for selfplay_num in range(1, self_play_games + 1):
                # One game per iteration, so the progress line below is accurate
                self.generate_self_play_data(1)
                print("\tFinished self-play game {} of {} (Buffer size {})".format(selfplay_num, self_play_games, len(self.replay_buffer)))
            print("Finished {} self-play games".format(self_play_games))
            for _ in range(training_batches):
                loss = self.train_on_batch(batch_size)
                if loss is None:
                    continue
                losses.append(loss)
                if len(losses) == 5:
                    print("\tLoss for last 5 batches: {}".format(sum(losses)))
                    losses = []

In [23]:
# Build the trainer with a fresh MCTS agent and a new game; no saved buffer or weights are passed
trainer = SelfPlayTrainer(AlphaMCTSAgent(), NestedTTT())

In [None]:
# Start the self-play / training loop with its default settings
trainer.run()

Run 1 of 100


In [None]:
# Scratch cell: draw one batch from the trainer's replay buffer and turn its
# three columns (states, move distributions, rewards) into float tensors
batch_size = 32
sample = trainer.replay_buffer.sample(batch_size)
states, probs, rewards = zip(*sample)
states = torch.FloatTensor(states).requires_grad_(True)
probs = torch.FloatTensor(probs).requires_grad_(True)
rewards = torch.FloatTensor(rewards).unsqueeze(1).requires_grad_(True)

In [None]:
# Forward pass of the network being trained on the sampled states
ps, vs = trainer.current_network(states)

In [None]:
# Display the sum of the predicted values for the sampled batch
vs.sum()

In [None]:
# Run a single training step (on a batch that train_on_batch samples itself)
trainer.train_on_batch()

In [None]:
# Repeat the forward pass on the same sampled states after the training step
ps, vs = trainer.current_network(states)

In [None]:
# Display the sum of the predicted values again, for comparison with the earlier cell
vs.sum()

In [None]:
# Persist the weights of the network being trained
torch.save(trainer.current_network.state_dict(), "latest_weights.pth")