In [1]:
import numpy as np
import matplotlib.pyplot as plt
import copy
import pickle

import tensorflow as tf
import tensorflow.keras as tfk

INF = float('inf')

In [2]:
class TicTocToeGame:
    """Tic-tac-toe environment: the current board state plus the player to move.

    The game satisfies the requirements of the search/training loop:
       1) discrete            => both the states and actions are discrete
       2) deterministic       => taking an action leads to a set outcome
                                 (no randomness as a result of an action)
       3) competitive         => players compete against each other
       4) perfect information => both players see everything
    """
    def __init__(self):
        self.reset()
        # One action per cell, numbered 0..8 row-major.
        self.action_space = np.arange(9)
    
    def step(self, action):
        """Play `action` for the player to move and hand the turn over."""
        self.state = self.state.next_state(action)
        self.player = self.player^1
    
    def reset(self):
        """Restore the empty board with player 0 to move."""
        self.state = TicTocToeGameState()
        self.player = 0
    
    def start_state(self):
        """Return a fresh, empty board state."""
        return TicTocToeGameState()
    
class TicTocToeGameState:
    """Tic-tac-toe board seen from the perspective of the player to move.

    `board` is a (3, 3, 2) array (int32 when created here). Channel 0 holds
    the pieces of the player to move and channel 1 those of the opponent;
    `next_state` swaps the channels after every move via `mirror`.
    Cells are numbered row-major:

        0 | 1 | 2
        3 | 4 | 5
        6 | 7 | 8
    """
    def __init__(self, board=None):
        self.board = board
        if self.board is None:
            self.board = np.zeros((3, 3, 2), dtype=np.int32)
    
    def __str__(self):
        """Render the board as three lines: '-' empty, 'X' channel 0, 'O' channel 1."""
        X = self.board[:,:,0] + 2*self.board[:,:,1]
        return '\n'.join([''.join(map(lambda x : '-XO'[x], X[i].flatten())) for i in range(3)])
    
    def evaluate(self):
        """Return `(is_end, reward)` for this position.

        reward is +1 if channel 0 has three in a row, -1 if channel 1 does,
        and 0 for a draw (full board) or an unfinished game.
        """
        for p in range(2):
            # diagonals (both pass through the centre cell)
            if self.board[1,1,p] == 1:
                if self.board[1,1,p] == self.board[0,0,p] and \
                   self.board[1,1,p] == self.board[2,2,p]:
                    return True, -2*p + 1
                if self.board[1,1,p] == self.board[0,2,p] and \
                   self.board[1,1,p] == self.board[2,0,p]:
                    return True, -2*p + 1
            # rows
            for i in range(3):
                if np.sum(self.board[i,:,p]) == 3:
                    return True, -2*p + 1
            # columns
            for j in range(3):
                if np.sum(self.board[:,j,p]) == 3:
                    return True, -2*p + 1
        if self.board.sum() == 9:
            return True, 0
        return False, 0
    
    def get_valid_actions(self, action_space):
        """Return the actions of `action_space` whose cell is still empty."""
        actions = []
        for ij in action_space:
            if self.board[ij//3, ij%3, 0] == 0 and self.board[ij//3, ij%3, 1] == 0:
                actions.append(ij)
        return actions
    
    def get_random_identity(self):
        """Return a state equivalent to this one for network evaluation.

        MCST.play feeds this state to the network and uses the resulting
        probabilities directly as action probabilities. Applying a random
        board symmetry here would misalign the policy with the real action
        indices, so the board is returned untransformed (as a copy, so the
        caller cannot mutate this state through it).
        """
        return TicTocToeGameState(self.board.copy())
    
    def mirror(self):
        """Return a copy with the two players' channels swapped."""
        new_state = TicTocToeGameState(self.board.copy())
        new_state.board[:,:,[0, 1]] = new_state.board[:,:,[1, 0]]
        return new_state
    
    def next_state(self, action):
        """Place a piece for the player to move on cell `action`.

        Returns the resulting state from the opponent's perspective.
        """
        new_state = TicTocToeGameState(self.board.copy())
        new_state.board[action//3, action%3, 0] = 1
        return new_state.mirror()

In [3]:
class Connect4Game:
    """Connect-Four environment: the current board plus the index of the player to move."""

    def __init__(self):
        self.reset()
        # One action per column of the default 6x7 board.
        self.action_space = np.arange(7)

    def step(self, action):
        """Drop a piece in column `action`, then pass the turn to the other player."""
        following = self.state.next_state(action)
        self.state = following
        self.player ^= 1

    def reset(self):
        """Start over from an empty board with player 0 to move."""
        self.state, self.player = Connect4GameState(), 0

    def start_state(self):
        """Build and return a brand-new empty board state."""
        empty_state = Connect4GameState()
        return empty_state
    
class Connect4GameState:
    """Connect-Four board seen from the perspective of the player to move.

    `board` has shape `board_size` = (rows, cols, 2). Channel 0 holds the
    pieces of the player to move and channel 1 those of the opponent;
    `next_state` swaps the channels after every move via `mirror`. Pieces
    fill a column from the last row index upward.
    """
    def __init__(self, board=None, board_size=(6, 7, 2)):
        self.board_size = board_size
        self.board = board
        if self.board is None:
            self.board = np.zeros(self.board_size, dtype=np.int32)
    
    def __str__(self):
        """Render one line per row: '-' empty, 'X' channel 0, 'O' channel 1."""
        X = self.board[:,:,0] + 2*self.board[:,:,1]
        return '\n'.join([''.join(map(lambda x : '-XO'[x], X[i].flatten())) for i in range(self.board_size[0])])
    
    def inside(self, i, j):
        """True if (i, j) is a cell of the board."""
        return 0 <= i < self.board_size[0] and 0 <= j < self.board_size[1]
    
    def evaluate(self):
        """Return `(is_end, reward)` for this position.

        reward is +1 if channel 0 has four in a row, -1 if channel 1 does,
        and 0 for a draw (full board) or an unfinished game.
        """
        # Line directions (row step, column step): vertical, the two
        # diagonals and horizontal; each line is scanned from its first cell.
        dr = [1, 1, 0, -1]
        dc = [0, 1, 1, 1]
        rows, cols = self.board_size[0], self.board_size[1]
        for p in range(2):
            for i in range(rows):
                for j in range(cols):
                    for k in range(4):
                        # Skip lines whose fourth cell would fall off the board.
                        if not self.inside(i + 3*dr[k], j + 3*dc[k]):
                            continue
                        cnt = sum(self.board[i + l*dr[k], j + l*dc[k], p] for l in range(4))
                        if cnt == 4:
                            return True, -2*p + 1
        if self.board.sum() == rows * cols:
            return True, 0
        return False, 0
    
    def get_valid_actions(self, action_space):
        """Return the indices of columns that are not full.

        `action_space` is unused; it is accepted so the signature matches
        TicTocToeGameState.get_valid_actions.
        """
        actions = []
        for j in range(self.board_size[1]):
            if self.board[:,j,:].sum() != self.board_size[0]:
                actions.append(j)
        return actions
    
    def get_random_identity(self):
        """Return a state equivalent to this one for network evaluation.

        MCST.play feeds this state to the network and uses the resulting
        probabilities directly as action probabilities. A random symmetry
        (e.g. a left-right flip) would misalign the policy with the real
        columns, so the board is returned untransformed (as a copy).
        """
        return Connect4GameState(self.board.copy(), self.board_size)
    
    def mirror(self):
        """Return a copy with the two players' channels swapped."""
        # Pass board_size along so non-default boards keep their geometry.
        new_state = Connect4GameState(self.board.copy(), self.board_size)
        new_state.board[:,:,[0, 1]] = new_state.board[:,:,[1, 0]]
        return new_state
    
    def next_state(self, action):
        """Drop a piece for the player to move into column `action`.

        Returns the resulting state from the opponent's perspective.

        Raises:
            ValueError: if column `action` is already full.
        """
        new_state = Connect4GameState(self.board.copy(), self.board_size)
        # Scan from the last row upward for the first empty cell.
        for i in range(self.board_size[0]-1, -1, -1):
            if new_state.board[i, action, 0] == 0 and new_state.board[i, action, 1] == 0:
                new_state.board[i, action, 0] = 1
                return new_state.mirror()
        # An explicit raise, unlike `assert`, is not stripped under `python -O`.
        raise ValueError('column {} is full'.format(action))

In [4]:
class NeuralNet:
    """Policy/value network: a conv trunk feeding a softmax policy head and a tanh value head.

    Args:
        input_shape: shape of one board tensor, e.g. (rows, cols, 2).
        num_conv: number of 3x3 conv + batch-norm + ELU stages.
        conv_channels: filters per conv stage.
        dropout_rate: dropout applied after each dense layer.
        proba_len: number of actions (size of the policy output `pi`).
    """
    def __init__(self, input_shape, num_conv=2, conv_channels=16,
                 dropout_rate=0.3, proba_len=9):
        self.input_boards = tfk.layers.Input(shape=input_shape)
        x = self.input_boards
        for _ in range(num_conv):
            x = tfk.layers.Conv2D(conv_channels, 3, padding='same')(x)
            x = tfk.layers.BatchNormalization(axis=3)(x)
            x = tfk.layers.Activation('elu')(x)
        x = tfk.layers.Flatten()(x)
        x = tfk.layers.Dense(64)(x)
        x = tfk.layers.BatchNormalization(axis=1)(x)
        x = tfk.layers.Activation('elu')(x)
        x = tfk.layers.Dropout(dropout_rate)(x)
        x = tfk.layers.Dense(32)(x)
        x = tfk.layers.BatchNormalization(axis=1)(x)
        x = tfk.layers.Activation('elu')(x)
        x = tfk.layers.Dropout(dropout_rate)(x)
        self.pi = tfk.layers.Dense(proba_len, activation='softmax', name='pi')(x)
        self.v  = tfk.layers.Dense(1, activation='tanh', name='v')(x)
        
        self.model = tfk.models.Model(inputs=self.input_boards, 
                                      outputs=[self.pi, self.v])
        self.model.compile(loss=['categorical_crossentropy', 'mean_squared_error'],
                           optimizer='adam')
    
    def train(self, states, probabilities, values, epochs=100, verbose=0,
              batch_size=16, patience=3):
        """Fit both heads: states -> (probabilities, values).

        20% of the data is held out for validation, and EarlyStopping on
        `val_loss` with the given `patience` ends training early.
        """
        early_stop = tfk.callbacks.EarlyStopping(monitor='val_loss', patience=patience)
        
        self.model.fit(states, [probabilities, values], verbose=verbose,
                       epochs=epochs, validation_split=0.20, batch_size=batch_size,
                       callbacks=[early_stop])
    
    def predict(self, state):
        """Return `(policy, value)` for a single state.

        The model is called directly rather than through `Model.predict`,
        whose per-call setup is meant for large batched inputs; this method
        runs once per MCST node expansion. `training=False` keeps batch norm
        and dropout in inference mode, as `predict` does.
        """
        batch = np.asarray([state.board], dtype=np.float32)
        prob, val = self.model(batch, training=False)
        return prob.numpy()[0], val.numpy()[0, 0]

In [5]:
class MCSTNode:
    """Search statistics stored for one game state in the MCST.

    Attributes (each per-action array has length `num_actions`):
        state   -- the game state this node represents
        Ns      -- how many simulations were routed through this node
        N       -- per-action visit counts
        Q       -- per-action running mean of backed-up values
        P       -- per-action prior, filled in when the node is expanded
        is_leaf -- True until the network has expanded the node
    """
    def __init__(self, state, num_actions):
        self.state = state
        self.Ns = 0
        self.N, self.Q, self.P = (np.zeros(num_actions) for _ in range(3))
        self.is_leaf = True

class MCST:
    """Monte-Carlo tree search guided by a policy/value network.

    Nodes are shared through `state_to_node`, keyed by `str(state)`, so a
    position reached along different move orders is searched once. Values
    returned by `search`/`simulate_episode` are from the perspective of the
    player who moved into the node; each recursion level flips the sign.
    """
    def __init__(self, cpuct=0.5, epsilon=0.25, alpha=np.full(9, 1.0/9),
                 num_simulations=64, temp=1.0):
        """
        Args:
            cpuct: weight of the exploration term in the selection score.
            epsilon: fraction of Dirichlet noise mixed into the root priors.
            alpha: Dirichlet concentration parameters; its length must equal
                the number of actions.
            num_simulations: searches run per call to `self_play`.
            temp: visit-count temperature used by `play(mode='training')`.
        """
        self.state_to_node = {}
        self.root = None
        self.cpuct = cpuct
        self.epsilon = epsilon
        self.alpha = alpha
        self.num_simulations = num_simulations
        self.temp = temp
    
    def _get_node(self, state, num_actions):
        """Return the cached node for `state`, creating and caching it if new."""
        key = str(state)
        node = self.state_to_node.get(key)
        if node is None:
            node = MCSTNode(state, num_actions)
            self.state_to_node[key] = node
        return node
    
    def _add_dirichlet_noise(self, node):
        """Mix Dirichlet noise into `node.P` to encourage exploration at the root."""
        noise = np.random.dirichlet(self.alpha)
        node.P = (1 - self.epsilon) * node.P + self.epsilon * noise
    
    def _make_root(self, node):
        """Make `node` the search root and perturb its priors.

        An unexpanded root still has all-zero priors that `search` overwrites
        with the network output, so in that case the noise is added by
        `search` right after expansion instead.
        """
        self.root = node
        if not node.is_leaf:
            self._add_dirichlet_noise(node)
        
    def set_root(self, game):
        """Use `game.state` as the root of the following searches."""
        self._make_root(self._get_node(game.state, len(game.action_space)))
    
    def search(self, node, game, nnet):
        """Run one simulation from `node` and return its value for the parent."""
        is_end, reward = node.state.evaluate()
        if is_end:
            return -reward
        
        # Expand an unvisited node with the network's priors and value estimate.
        if node.is_leaf:
            node.is_leaf = False
            node.P, v = nnet.predict(node.state)
            if node is self.root:
                # Root noise only takes effect once the priors exist.
                self._add_dirichlet_noise(node)
            return -v
        
        # PUCT selection restricted to legal moves.
        exploration = node.P * np.sqrt(node.Ns) / (1 + node.N)
        scores = node.Q + self.cpuct * exploration
        valid_actions = node.state.get_valid_actions(game.action_space)
        masked_scores = np.full(scores.shape, -np.inf)
        masked_scores[valid_actions] = scores[valid_actions]
        a = int(np.argmax(masked_scores))  # first maximum wins ties
        
        child = self._get_node(node.state.next_state(a), len(game.action_space))
        v = self.search(child, game, nnet)
        
        node.Q[a] = (node.Q[a] * node.N[a] + v) / (1 + node.N[a])
        node.N[a] += 1
        node.Ns += 1
        return -v
    
    def self_play(self, node, game, nnet):
        """Run `num_simulations` searches starting from `node`."""
        for _ in range(self.num_simulations):
            self.search(node, game, nnet)
    
    def play(self, node, game, nnet, mode='training'):
        """Return a probability distribution over actions at `node`.

        Modes:
            'predictive'  -- the network's policy (also used whenever the node
                             has no visits yet);
            'training'    -- visit counts raised to 1/temp, normalized;
            'competitive' -- uniform over the most-visited actions.
        Illegal actions get probability 0. If no legal action has positive
        mass, the result is uniform over the legal actions.

        Raises:
            ValueError: for an unknown `mode` on a visited node.
        """
        if mode == 'predictive' or node.Ns == 0:
            # The policy is indexed by the real actions, so the state is fed
            # to the network as is (no symmetry transform to undo).
            action_proba = nnet.predict(node.state)[0]
        elif mode == 'training':
            action_proba = node.N**(1.0/self.temp) / node.Ns**(1/self.temp)
        elif mode == 'competitive':
            action_proba = np.zeros(len(game.action_space))
            action_proba[np.where(node.N == node.N.max())] = 1
        else:
            raise ValueError('unknown mode: {!r}'.format(mode))
        
        valid_actions = node.state.get_valid_actions(game.action_space)
        masked = np.zeros(np.shape(action_proba))
        masked[valid_actions] = action_proba[valid_actions]
        total = masked.sum()
        if total > 0:
            return masked / total
        # Avoid returning NaNs (0/0) when no legal move received any mass.
        if valid_actions:
            masked[valid_actions] = 1.0 / len(valid_actions)
        return masked
    
    def simulate_episode(self, node, game, nnet, list_data):
        """Self-play one game from `node`, recording training examples.

        For every non-terminal state visited, appends its board, the search
        policy and the final outcome (+1/-1/0 for the player to move in that
        state) to list_data[0], list_data[1] and list_data[2], from the last
        move back to the first. Returns the outcome for the player who moved
        into `node`.
        """
        is_end, reward = node.state.evaluate()
        if is_end:
            return -reward
        
        # Each move's search is rooted at its own node, with fresh root noise.
        if node is not self.root:
            self._make_root(node)
        self.self_play(node, game, nnet)
        action_proba = self.play(node, game, nnet, mode='training')
        action = np.random.choice(len(action_proba), p=action_proba)
        
        child = self._get_node(node.state.next_state(action), len(game.action_space))
        v = self.simulate_episode(child, game, nnet, list_data)
        
        list_data[0].append(node.state.board)
        list_data[1].append(action_proba)
        list_data[2].append(v)
        
        return -v

In [6]:
class Agent:
    """A player that runs a fresh network-guided MCST for every decision.

    Args:
        nnet_input_shape: board tensor shape fed to the network.
        alpha: Dirichlet concentration for root noise; its length must match
            the number of actions. NOTE: the default is sampled once, when
            the class is defined, and shared by every Agent using it.
        num_simulations: MCST simulations per decision.
        temp: visit-count temperature for training-mode move selection.
        proba_len: size of the network's policy output (number of actions).
        cpuct: exploration weight passed to MCST.
    """
    def __init__(self, nnet_input_shape=(3,3,2),
                 alpha=np.random.dirichlet(np.full(9,2)),
                 num_simulations=32, temp=1.0, proba_len=9, cpuct=2):
        self.nnet = NeuralNet(nnet_input_shape, proba_len=proba_len)
        self.alpha = alpha
        self.num_simulations = num_simulations
        self.temp = temp
        self.cpuct = cpuct
    
    def _new_mcst(self, game):
        """Build a fresh search tree rooted at `game.state` and store it in `self.mcst`."""
        self.mcst = MCST(cpuct=self.cpuct, alpha=self.alpha,
                         num_simulations=self.num_simulations,
                         temp=self.temp)
        self.mcst.set_root(game)
        return self.mcst
    
    def choose_action(self, game, mode='competitive', return_proba=False):
        """Search from `game.state` and sample a move from the resulting policy.

        Returns the action, or `(action, action_proba, root visit counts)`
        when `return_proba` is true.
        """
        mcst = self._new_mcst(game)
        mcst.self_play(mcst.root, game, self.nnet)
        action_proba = mcst.play(mcst.root, game, self.nnet, mode=mode)
        action = np.random.choice(len(action_proba), p=action_proba)
        if return_proba:
            return action, action_proba, mcst.root.N
        return action
    
    def simulate_episode(self, game):
        """Self-play one game from `game.state`; return `[boards, policies, values]` lists."""
        mcst = self._new_mcst(game)
        list_data = [[], [], []]
        mcst.simulate_episode(mcst.root, game, self.nnet, list_data)
        return list_data

In [17]:
class AlphaZero:
    """Self-play training loop for Connect Four.

    `curr_agent` is trained on a bounded replay memory of self-play games.
    Every `duel_interval` episodes it plays `best_agent`, and it replaces
    `best_agent` when it wins at least as many games.
    """
    def __init__(self, memory_limit=100000, nnet_input_shape=(3,3,2),
                 alpha=np.random.dirichlet(np.full(7,2)), proba_len=7):
        self.memory_limit = memory_limit
        self.nnet_input_shape = nnet_input_shape
        self.alpha = alpha
        self.proba_len = proba_len
        
        # Replay memory: [board states, search policies, game outcomes].
        self.memory = [np.array([]), 
                       np.array([]),
                       np.array([])]
        self.best_agent = Agent(nnet_input_shape=self.nnet_input_shape,
                                alpha=self.alpha, proba_len=self.proba_len)
        self.curr_agent = None
    
    def _new_agent_like(self, source):
        """Return a new Agent whose network weights are copied from `source`.

        `tfk.models.clone_model` copies only the architecture (the clone gets
        freshly initialized weights), so the weights are copied explicitly.
        Both networks are built with the same arguments, so they match.
        """
        agent = Agent(nnet_input_shape=self.nnet_input_shape,
                      alpha=self.alpha, proba_len=self.proba_len)
        agent.nnet.model.set_weights(source.nnet.model.get_weights())
        return agent
    
    def _remember(self, data):
        """Append one episode's `[states, policies, values]` to the bounded memory."""
        for i in range(3):
            new = np.array(data[i])
            if len(new) == 0:
                continue
            if len(self.memory[i]) == 0:
                self.memory[i] = new
            else:
                self.memory[i] = np.concatenate((self.memory[i], new), axis=0)
            if len(self.memory[i]) > self.memory_limit:
                self.memory[i] = self.memory[i][-self.memory_limit:]
    
    def train(self, num_episodes=128, duel_interval=10,
              num_duels=20, epochs=100, verbose=0):
        """Alternate self-play, network fitting and periodic duels.

        With `verbose`, the memory is pickled to `training_data_<ep>.mem`
        after every episode and shapes/progress are printed.
        """
        if self.curr_agent is None:
            self.curr_agent = self._new_agent_like(self.best_agent)
        
        for ep in range(num_episodes):
            game = Connect4Game()
            
            data = self.curr_agent.simulate_episode(game)
            self._remember(data)
                
            if verbose:
                with open('training_data_{}.mem'.format(ep), 'wb') as f:
                    pickle.dump(self.memory, f)
                print(np.array(self.memory[0]).shape)
                print(np.array(self.memory[1]).shape)
                print(np.array(self.memory[2]).reshape(-1, 1).shape)
                
            self.curr_agent.nnet.train(self.memory[0], self.memory[1], self.memory[2].reshape(-1, 1),
                                       epochs=epochs, verbose=verbose)
            if verbose:
                print("done training")
            
            if (ep + 1) % duel_interval == 0:
                print("{} dueling...".format(ep + 1))
                res = self.pit((self.best_agent, self.curr_agent), num_duels)
                # TODO: consider requiring a margin over best_agent rather
                # than promoting on ties.
                if res[0] <= res[1]:
                    # Promote a copy so that further training of curr_agent
                    # does not also change best_agent.
                    self.best_agent = self._new_agent_like(self.curr_agent)
                print("{} done dueling".format(ep + 1))
                print(res)
                print("========================")
        
    def pit(self, agents, num_duels):
        """Play `num_duels` games between agents[0] and agents[1].

        The starting agent alternates between games. Returns `[wins0, wins1]`;
        draws are not counted.
        """
        res = [0, 0]
        for i in range(num_duels):
            game = Connect4Game()
            terminated = False
            mover = i % 2
            while not terminated:
                action = agents[mover].choose_action(game, 'competitive')
                game.step(action)
                
                terminated, reward = game.state.evaluate()
                if terminated and reward != 0:
                    # After step() the board is stored from the opponent's
                    # perspective, so reward == -1 means `mover` just won.
                    winner = mover if reward == -1 else mover ^ 1
                    res[winner] += 1
                
                mover ^= 1
        return res

In [18]:
az = AlphaZero(memory_limit=5000,
               nnet_input_shape=(6, 7, 2),
               alpha=np.random.dirichlet(np.full(7,2)),
               proba_len=7)
# Model.summary() prints the table itself and returns None, so wrapping it
# in print() would only add a stray "None" line.
az.best_agent.nnet.model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 6, 7, 2)]    0                                            
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 6, 7, 16)     304         input_5[0][0]                    
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 6, 7, 16)     64          conv2d_4[0][0]                   
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 6, 7, 16)     0           batch_normalization_8[0][0]      
____________________________________________________________________________________________

In [19]:
%%time
# Long-running: this run was stopped by hand (KeyboardInterrupt) after the
# 360-episode duel; the progress made so far stays in `az`.
az.train(num_episodes=1000, duel_interval=30, num_duels=7, epochs=100, verbose=0)

30 dueling...
30 done dueling
[2, 5]
60 dueling...
60 done dueling
[3, 4]
90 dueling...
90 done dueling
[5, 2]
120 dueling...
120 done dueling
[3, 4]
150 dueling...
150 done dueling
[4, 3]
180 dueling...
180 done dueling
[3, 4]
210 dueling...
210 done dueling
[3, 4]
240 dueling...
240 done dueling
[3, 4]
270 dueling...
270 done dueling
[6, 1]
300 dueling...
300 done dueling
[3, 4]
330 dueling...
330 done dueling
[3, 4]
360 dueling...
360 done dueling
[3, 4]


KeyboardInterrupt: 

In [21]:
# Shallow snapshot of the replay memory list ([states, policies, values]).
memory = list(az.memory)

In [23]:
def play_with_player(agent):
    """Play an interactive console game of Connect Four: human vs `agent`.

    The human moves first and is shown as X, the agent as O. Non-numeric or
    illegal input is rejected and re-prompted. The result is printed when the
    game ends.
    """
    game = Connect4Game()
    print(game.state)
    print("================")
    
    while not game.state.evaluate()[0]:
        valid_actions = game.state.get_valid_actions(game.action_space)
        print(valid_actions)
        while True:
            print("player: ")
            try:
                action = int(input())
            except ValueError:
                action = None  # non-numeric input is treated as invalid
            if action in valid_actions:
                game.step(action)
                # After the move the board is stored from the agent's point
                # of view; mirror it so the human's pieces print as X.
                print(game.state.mirror())
                break
            else:
                print("Invalid action. Please try again")
        print("================")
        
        if game.state.evaluate()[0]:
            break
        
        print("AI")
        action, action_proba, num_uses = agent.choose_action(game, 'competitive', True)
        print(action_proba, num_uses)
        game.step(action)
        print(game.state)
        print("================")
    
    # evaluate() scores the final position for the player to move next, so a
    # nonzero reward means whoever moved last won; game.player is 1 right
    # after the human (player 0) has moved.
    _, reward = game.state.evaluate()
    if reward == 0:
        print("Draw")
    elif game.player == 1:
        print("You win!")
    else:
        print("AI wins!")

In [35]:
# Interactive game against the current best agent (reads moves from stdin).
play_with_player(agent=az.best_agent)

-------
-------
-------
-------
-------
-------
[0, 1, 2, 3, 4, 5, 6]
player: 


 2


-------
-------
-------
-------
-------
--X----
AI
[0. 0. 0. 1. 0. 0. 0.] [ 1.  0.  1. 28.  0.  1.  0.]
-------
-------
-------
-------
-------
--XO---
[0, 1, 2, 3, 4, 5, 6]
player: 


 3


-------
-------
-------
-------
---X---
--XO---
AI
[0. 0. 0. 1. 0. 0. 0.] [ 1.  0.  1. 27.  1.  1.  0.]
-------
-------
-------
---O---
---X---
--XO---
[0, 1, 2, 3, 4, 5, 6]
player: 


 3


-------
-------
---X---
---O---
---X---
--XO---
AI
[0. 0. 0. 0. 0. 1. 0.] [ 1.  0.  0.  0.  0. 30.  0.]
-------
-------
---X---
---O---
---X---
--XO-O-
[0, 1, 2, 3, 4, 5, 6]
player: 


 5


-------
-------
---X---
---O---
---X-X-
--XO-O-
AI
[0. 0. 0. 0. 0. 0. 1.] [ 3.  3.  5.  0.  0.  1. 19.]
-------
-------
---X---
---O---
---X-X-
--XO-OO
[0, 1, 2, 3, 4, 5, 6]
player: 


 4


-------
-------
---X---
---O---
---X-X-
--XOXOO
AI
[0. 0. 0. 0. 1. 0. 0.] [ 1.  0.  0.  0. 30.  0.  0.]
-------
-------
---X---
---O---
---XOX-
--XOXOO
[0, 1, 2, 3, 4, 5, 6]
player: 


 4


-------
-------
---X---
---OX--
---XOX-
--XOXOO
AI
[0. 0. 0. 0. 0. 0. 1.] [ 1.  0.  4.  0. 12.  0. 14.]
-------
-------
---X---
---OX--
---XOXO
--XOXOO
[0, 1, 2, 3, 4, 5, 6]
player: 


 2


-------
-------
---X---
---OX--
--XXOXO
--XOXOO
AI
[0.5 0.5 0.  0.  0.  0.  0. ] [7. 7. 5. 0. 6. 0. 6.]
-------
-------
---X---
---OX--
--XXOXO
O-XOXOO
[0, 1, 2, 3, 4, 5, 6]
player: 


 0


-------
-------
---X---
---OX--
X-XXOXO
O-XOXOO
AI
[0.5 0.  0.5 0.  0.  0.  0. ] [13.  0. 13.  0.  5.  0.  0.]
-------
-------
---X---
--OOX--
X-XXOXO
O-XOXOO
[0, 1, 2, 3, 4, 5, 6]
player: 


 2


-------
-------
--XX---
--OOX--
X-XXOXO
O-XOXOO
AI
[1. 0. 0. 0. 0. 0. 0.] [10.  1.  1.  7.  7.  1.  4.]
-------
-------
--XX---
O-OOX--
X-XXOXO
O-XOXOO
[0, 1, 2, 3, 4, 5, 6]
player: 


 4


-------
-------
--XXX--
O-OOX--
X-XXOXO
O-XOXOO
AI
[0. 0. 0. 0. 0. 0. 1.] [ 1.  1.  1. 10.  1.  1. 16.]
-------
-------
--XXX--
O-OOX-O
X-XXOXO
O-XOXOO
[0, 1, 2, 3, 4, 5, 6]
player: 


 2


-------
--X----
--XXX--
O-OOX-O
X-XXOXO
O-XOXOO


In [25]:
# Persist the current best network to disk.
best_model = az.best_agent.nnet.model
best_model.save('connect4_model.h5')

In [26]:
# Save the replay memory ([states, policies, values]) for later reuse.
# NOTE: pickle files can execute code when loaded; only unpickle trusted copies.
memory_path = 'connect4_training_data.mem'
with open(memory_path, 'wb') as out_file:
    pickle.dump(az.memory, out_file)