In [3]:
import copy

class TicTacToe:
    """A tic-tac-toe position: a 3x3 grid of 'X', 'O' or ' ' plus the player to move.

    Positions are treated as values: `move` returns a new position and leaves
    the receiver untouched. Equality and hashing are based on the flattened
    board followed by the player to move, so positions can be stored in sets
    and used as dict keys. Do not edit `board` in place while a position is
    stored in a set or dict, because its hash would change.
    """
    height,width = 3,3

    # Every row, column and diagonal of the 3x3 grid, as (row, col) triples.
    _LINES = (
        ((0,0),(0,1),(0,2)),
        ((1,0),(1,1),(1,2)),
        ((2,0),(2,1),(2,2)),
        ((0,0),(1,0),(2,0)),
        ((0,1),(1,1),(2,1)),
        ((0,2),(1,2),(2,2)),
        ((0,0),(1,1),(2,2)),
        ((2,0),(1,1),(0,2)),
    )

    def __init__(self, board=None, player='X'):
        """Create a position.

        board  -- optional list of rows to deep-copy; an empty grid when None.
        player -- mark of the player to move ('X' or 'O').
        """
        if board is None:
            self.board = [[' ' for _ in range(TicTacToe.width)] for _ in range(TicTacToe.height)]
        else:
            # Copy so that the new position never shares rows with the caller's board.
            self.board = copy.deepcopy(board)
        self.player = player

    def __str__(self):
        return '\n-----\n'.join('|'.join(row) for row in self.board)

    __repr__ = __str__

    def _key(self):
        """Flattened board followed by the player to move; shared by __hash__ and __eq__."""
        return ''.join(c for r in self.board for c in r) + self.player

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        # Returning NotImplemented lets Python fall back to its default handling,
        # so comparing with None or any unrelated object yields False instead of
        # raising AttributeError.
        if not isinstance(other, TicTacToe):
            return NotImplemented
        return self._key() == other._key()

    def can_move(self, row, col):
        """Return True when the square at (row, col) is empty."""
        return self.board[row][col] == ' '

    def move(self, row, col):
        """Return the position reached when the current player marks (row, col).

        Raises AssertionError when the square is already taken.
        """
        assert self.can_move(row, col)
        next_state = TicTacToe(board=self.board, player=('O' if self.player == 'X' else 'X'))
        # The constructor deep-copied the board, so only the new position is written to.
        next_state.board[row][col] = self.player
        return next_state

    def board_full(self):
        """Return True when no square is empty."""
        return all(self.board[r][c] != ' ' for r in range(TicTacToe.height) for c in range(TicTacToe.width))

    def available_moves(self):
        """Return the empty squares as a list of (row, col) tuples, in row-major order."""
        return [(row,col) for row in range(TicTacToe.height) for col in range(TicTacToe.width) if self.can_move(row,col)]

    def next_states(self):
        """Return the set of positions reachable in one move; empty for a terminal position."""
        if self.is_terminal():
            return set()
        return set(self.move(r,c) for r,c in self.available_moves())

    def check_win(self):
        """Return (True, mark) when some line holds three equal marks.

        Otherwise return (False, "tie") when the board is full and (False, None)
        when the game is still open.
        """
        for line in self._LINES:
            first_r, first_c = line[0]
            mark = self.board[first_r][first_c]
            if mark != ' ' and all(self.board[r][c] == mark for r,c in line):
                return True, mark

        return False, ("tie" if self.board_full() else None)

    def is_terminal(self):
        """Return True when a line is complete or the board is full."""
        return self.check_win()[0] or self.board_full()

    def reward(self):
        """Score a terminal position for the player who is NOT to move.

        Returns 1.0 when that player owns the winning line, 0 when the winning
        line belongs to the player to move, and 0.5 when there is no winner.
        Raises Exception for a non-terminal position.
        """
        if not self.is_terminal():
            raise Exception("Cannot reward non-terminal state")
        is_win, winner = self.check_win()
        if not is_win:
            return 0.5
        previous_player = 'X' if self.player == 'O' else 'O'
        return 1.0 if winner == previous_player else 0


In [4]:
import random
import numpy as np

class MCTS:
    """Monte Carlo tree search for a two-player, alternating-move, zero-sum game.

    States must be hashable and provide is_terminal(), reward(),
    available_moves(), move(*action) and next_states(). reward() is expected to
    score a terminal state for the player who made the last move (1 = win,
    0 = loss, 0.5 = draw), which is how the state classes in this notebook
    define it.

    Every value stored for a state is kept from the perspective of the player
    who moved INTO that state, so a parent can simply pick the child with the
    highest average.
    """

    def __init__(self, initial_state):
        self.initial_state = initial_state
        # state -> (visit count N, accumulated reward V)
        self.table = dict()
        # state -> set of child states; an empty set marks an expanded terminal state
        self.expanded = dict()

    def __str__(self):
        return str(self.table)

    def UCT(self, state, c=1):
        """Return the child of `state` with the highest UCT score.

        Must only be called when `state` and all of its children have been
        expanded (and therefore visited at least once); `c` weights exploration.
        """
        # state should always be in the table when calling UCT. Otherwise there's an error
        assert all(n in self.expanded for n in self.expanded[state])
        parent_visits, _ = self.table[state]
        log_parent = np.log(parent_visits)  # identical for every child, so compute it once

        def uct(s):
            n,v = self.table.get(s,(0,0))
            return v/n + c * np.sqrt(log_parent/n)

        return max(self.expanded[state], key=uct)

    def select(self, state):
        """Walk down from `state` and return the path to the node to expand next.

        The walk stops at a node that has not been expanded, at an expanded node
        without children, or at the first child that has not been expanded yet;
        while every child is expanded it follows the UCT choice.
        """
        path = []
        while True:
            path.append(state)
            children = self.expanded.get(state)
            if not children:
                # Either never expanded (None) or expanded with no successors (empty set).
                return path

            for child in children:
                if child not in self.expanded:
                    path.append(child)
                    return path

            state = self.UCT(state,c=1) # UCT

    def expand(self, state):
        """Record the successors of `state`; does nothing when already expanded."""
        if state in self.expanded:
            return
        self.expanded[state] = state.next_states()

    def simulate(self, state):
        """Play uniformly random moves from `state` to the end of the game.

        Returns the result from the perspective of the player who moved into
        `state`. The terminal state's reward() is scored for whoever made the
        final move, so it is inverted when the rollout added an odd number of
        moves, i.e. when the final mover is the opponent of that player.
        """
        flipped = False
        while True:
            if state.is_terminal():
                outcome = state.reward()
                return 1 - outcome if flipped else outcome
            action = random.choice(state.available_moves())
            state = state.move(*action)
            flipped = not flipped

    def backup(self, path, reward):
        """Add one visit to every state on `path`, from the leaf back to the root.

        `reward` belongs to the last state on the path and is inverted at each
        step upwards because the players alternate. `path` is emptied.
        """
        while path != []:
            state = path.pop()
            N, V = self.table.get(state, (0,0))
            self.table[state] = (N+1, V+reward)
            reward = 1-reward

    def SESB(self, num_rollout=1):
        """Run one select / expand / simulate / backup iteration.

        num_rollout -- number of random playouts from the selected leaf. Their
                       mean is backed up as a single visit, which keeps the
                       reward inside [0, 1] so that the 1-reward inversion in
                       backup stays valid.
        Raises ValueError when num_rollout is smaller than 1.
        """
        if num_rollout < 1:
            raise ValueError("num_rollout must be at least 1")
        # select
        path = self.select(self.initial_state)
        leaf = path[-1]
        # expand
        self.expand(leaf)
        # simulate
        total = 0
        for _ in range(num_rollout):
            total += self.simulate(leaf)
        # backup/backpropagate
        self.backup(path, total / num_rollout)

    def next(self, state):
        """Return the successor of `state` that the search prefers.

        Falls back to a uniformly random move when `state` has not been
        expanded. Otherwise returns the child with the highest average reward;
        children that were never visited rank last.
        Raises Exception for a terminal state.
        """
        if state.is_terminal():
            raise Exception("Can't call next on terminal state")

        if state not in self.expanded:
            action = random.choice(state.available_moves())
            return state.move(*action)

        def score(s):
            n,v = self.table.get(s,(0,0))
            return float("-inf") if n == 0 else v/n

        return max(self.expanded[state], key=score)

    def runMCTS(self, num_iter=1000):
        """Run `num_iter` search iterations and print a one-line summary."""
        for _ in range(num_iter):
            self.SESB()
        # The printed number is the count of expanded states, labelled "visits".
        print(len(self.expanded), "visits")

# Five independent tic-tac-toe agents, each searching from its own empty board.
ttt_mcts1k, ttt_mcts2k, ttt_mcts3k, ttt_mcts4k, ttt_mcts5k = (
    MCTS(TicTacToe()) for _ in range(5)
)

# Give each agent its own search budget, smallest first; every call prints one summary line.
ttt_mcts1k.runMCTS(1000)
ttt_mcts2k.runMCTS(2000)
ttt_mcts3k.runMCTS(3000)
ttt_mcts4k.runMCTS(4000)
ttt_mcts5k.runMCTS(5000)


1000 visits
1815 visits
2629 visits
2526 visits
3719 visits


In [5]:
# A sixth tic-tac-toe agent, searched for 10,000 iterations from the empty board.
ttt_mcts10k = MCTS(TicTacToe())
ttt_mcts10k.runMCTS(10_000)

4383 visits


In [6]:
def play_ttt(agent1,agent2=None):
    """Play one game of tic-tac-toe on the console, printing the board after every move.

    agent1 -- object whose next(state) returns the successor state; always plays 'X'.
    agent2 -- optional object with the same next(state) method, playing 'O'. When
              omitted (or falsy), 'O' is read from standard input as a row index
              followed by a column index, one integer per line.

    Input that is not an integer, lies off the board, or names an occupied
    square is discarded and both numbers are read again, so a typo no longer
    aborts the game.

    Finally prints the winner's mark followed by " wins", or "tie".
    """
    ttt = TicTacToe()
    print(ttt)
    while not ttt.is_terminal():
        if ttt.player == 'X':
            print("mcts: X")
            ttt = agent1.next(ttt)
        elif agent2:
            print("mcts: O")
            ttt = agent2.next(ttt)
        else:
            print("you: O")
            while True:
                try:
                    r,c = int(input()), int(input())
                except ValueError:
                    continue
                # Check the bounds before can_move so that it never indexes off the board.
                if r in range(TicTacToe.height) and c in range(TicTacToe.width) and ttt.can_move(r,c):
                    break
            ttt = ttt.move(r,c)
        print(ttt)
    win,winner=ttt.check_win()
    print(winner + (" wins" if win else ""))

# Self-play demo: the 5000-iteration agent plays X against the 1000-iteration agent playing O.
play_ttt(ttt_mcts5k,ttt_mcts1k)

 | | 
-----
 | | 
-----
 | | 
mcts: X
 | | 
-----
X| | 
-----
 | | 
mcts: O
 | | 
-----
X| | 
-----
 | |O
mcts: X
 | | 
-----
X| | 
-----
X| |O
mcts: O
O| | 
-----
X| | 
-----
X| |O
mcts: X
O|X| 
-----
X| | 
-----
X| |O
mcts: O
O|X|O
-----
X| | 
-----
X| |O
mcts: X
O|X|O
-----
X|X| 
-----
X| |O
mcts: O
O|X|O
-----
X|X| 
-----
X|O|O
mcts: X
O|X|O
-----
X|X|X
-----
X|O|O
X wins


In [7]:
def test_mcts(agent, iters=10000):
    """Pit `agent` (playing 'X') against a uniformly random 'O' player.

    Plays `iters` complete games from the empty board and prints how many the
    agent won, drew and lost. Nothing is returned.
    """
    decided = {'X': 0, 'O': 0}  # games that ended with a winner, keyed by the winning mark
    undecided = 0               # games that ended without a winner

    for _ in range(iters):
        game = TicTacToe()
        while not game.is_terminal():
            if game.player != 'X':
                # Random opponent: pick any successor position.
                game = random.choice(list(game.next_states()))
            else:
                game = agent.next(game)

        _, outcome = game.check_win()
        if outcome in decided:
            decided[outcome] += 1
        else:
            undecided += 1

    print(f"wins: {decided['X']}/{iters}")
    print(f"draws: {undecided}/{iters}")
    print(f"losses: {decided['O']}/{iters}")

# Only the 10,000-iteration agent is evaluated against the random opponent;
# the runs for the smaller agents are left disabled.
# test_mcts(ttt_mcts1k)
# test_mcts(ttt_mcts2k)
# test_mcts(ttt_mcts3k)
# test_mcts(ttt_mcts4k)
# test_mcts(ttt_mcts5k)
test_mcts(ttt_mcts10k)



wins: 9512/10000
draws: 330/10000
losses: 158/10000


In [9]:
import copy

class Connect4:
    """A Connect Four position on a 6x7 grid; row 0 is the top row.

    Cells hold '@', 'O' or ' '. `player` is the mark to move next. `terminal`
    and `winner` describe the result and are filled in by `move`: `winner` is
    the winning mark, "tie" for a full board without a winner, or None while
    the game is open.

    Positions are treated as values: `move` returns a new position and leaves
    the receiver untouched. Equality and hashing are based on the flattened
    board followed by the player to move. Do not edit `board` in place while a
    position is stored in a set or dict, because its hash would change.
    """
    height,width = 6,7

    def __init__(self, board=None, player='@', terminal=False, winner=None):
        """Create a position.

        board    -- optional list of rows to deep-copy; an empty grid when None.
        player   -- mark of the player to move ('@' or 'O').
        terminal -- whether the game is already over.
        winner   -- winning mark, "tie", or None.
        """
        if board is None:
            self.board = [[' ' for _ in range(Connect4.width)] for _ in range(Connect4.height)]
        else:
            # Copy so that the new position never shares rows with the caller's board.
            self.board = copy.deepcopy(board)
        self.player = player
        self.terminal = terminal
        self.winner = winner

    def __str__(self):
        return f"\n{'-'*(2*Connect4.width-1)}\n".join('|'.join(row) for row in self.board)

    __repr__ = __str__

    def _key(self):
        """Flattened board followed by the player to move; shared by __hash__ and __eq__."""
        return ''.join(c for r in self.board for c in r) + self.player

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        # Returning NotImplemented lets Python fall back to its default handling,
        # so comparing with None or any unrelated object yields False instead of
        # raising AttributeError.
        if not isinstance(other, Connect4):
            return NotImplemented
        return self._key() == other._key()

    def can_move(self, col):
        """Return True when the top cell of column `col` is empty."""
        return self.board[0][col] == ' '

    def move(self, col):
        """Return the position reached when the current player drops a piece into `col`.

        The piece lands on the last consecutive empty cell counted from the top
        of the column. The returned position carries the updated `terminal`
        and `winner` fields.
        Raises AssertionError when the column is full.
        """
        assert self.can_move(col)
        row = -1
        while (row+1) in range(Connect4.height) and self.board[row+1][col] == ' ':
            row += 1
        next_state = Connect4(board=self.board, player=('O' if self.player == '@' else '@'))
        # The constructor deep-copied the board, so only the new position is written to.
        next_state.board[row][col] = self.player
        win, winner = next_state.check_win(row, col)
        next_state.terminal = win or next_state.board_full()
        next_state.winner = winner
        return next_state

    def board_full(self):
        """Return True when no cell is empty."""
        return all(self.board[r][c] != ' ' for r in range(Connect4.height) for c in range(Connect4.width))

    def available_moves(self):
        """Return the playable columns as 1-tuples, so they can be splatted into move()."""
        return [(col,) for col in range(Connect4.width) if self.can_move(col)]

    def next_states(self):
        """Return the set of positions reachable in one move; empty for a terminal position."""
        if self.is_terminal():
            return set()
        return set(self.move(*c) for c in self.available_moves())

    def check_win(self, row, col):
        """Check whether the piece at (row, col) completes a line of four.

        Returns (True, mark) when it does; otherwise (False, "tie") for a full
        board and (False, None) while the game is open. Only lines that pass
        through (row, col) are examined.
        """
        in_a_row = 4
        player = self.board[row][col]

        def count(offset_row, offset_column):
            # Number of consecutive cells holding `player` in one direction,
            # excluding (row, col) itself and capped at in_a_row - 1.
            for i in range(1, in_a_row):
                r = row + offset_row * i
                c = col + offset_column * i
                if (r not in range(Connect4.height) or c not in range(Connect4.width) or self.board[r][c] != player):
                    return i - 1
            return in_a_row - 1

        needed = in_a_row - 1
        # Vertically only the cells below are counted; the other three axes add both directions.
        if (count(1, 0) >= needed
            or (count(0, 1) + count(0, -1)) >= needed
            or (count(1, 1) + count(-1, -1)) >= needed
            or (count(1, -1) + count(-1, 1)) >= needed):
            return True, player
        return False, ("tie" if self.board_full() else None)

    def is_terminal(self):
        """Return the stored `terminal` flag."""
        return self.terminal

    def reward(self):
        """Score a terminal position for the player who is NOT to move.

        Returns 1.0 when that player is the winner, 0 when the winner is the
        player to move, and 0.5 when nobody won. The "tie" marker is handled
        explicitly because it is a non-empty string and would otherwise be
        mistaken for a winning mark and scored as a loss.
        Raises Exception for a non-terminal position.
        """
        if not self.is_terminal():
            raise Exception("Cannot reward non-terminal state")
        if self.winner is None or self.winner == "tie":
            return 0.5
        previous_player = 'O' if self.player == '@' else '@'
        return 1.0 if self.winner == previous_player else 0


In [10]:
# Quick check: drop one piece into column 3 of an empty board. move() returns
# the resulting position, which the cell displays.
c4 = Connect4()
c4.move(3)

 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | |@| | | 

In [11]:
# Connect Four agent that searches from the empty board.
connect4 = MCTS(Connect4())

In [12]:
# Run 50,000 search iterations; runMCTS prints the number of expanded states when done.
connect4.runMCTS(50000)

45117 visits


In [15]:
def play_c4(agent1,agent2=None):
    """Play one game of Connect Four on the console, printing the board after every move.

    agent1 -- object whose next(state) returns the successor state; always plays '@'.
    agent2 -- optional object with the same next(state) method, playing 'O'. When
              omitted (or falsy), 'O' is read from standard input as a column index.

    Input that is not an integer, lies outside the board, or names a full
    column is discarded and read again, so a typo no longer aborts the game.

    Finally prints the winner's mark followed by " wins", "tie" for a drawn
    game, or an empty line when the final position records no result.
    """
    c4 = Connect4()
    print(c4)
    while not c4.is_terminal():
        if c4.player == '@':
            print("mcts: @")
            c4 = agent1.next(c4)
        elif agent2:
            print("mcts: O")
            c4 = agent2.next(c4)
        else:
            print("you: O")
            while True:
                try:
                    c = int(input())
                except ValueError:
                    continue
                # Check the bounds before can_move so that it never indexes off the board.
                if c in range(Connect4.width) and c4.can_move(c):
                    break
            c4 = c4.move(c)
        print(c4)
    if c4.winner == "tie":
        # "tie" is a result marker, not a player, so it must not be followed by " wins".
        print("tie")
    else:
        print(c4.winner + " wins" if c4.winner else "")

In [17]:
# Interactive game: the MCTS agent plays '@' and 'O' is typed in as column numbers.
play_c4(connect4)

 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
mcts: @
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
@| | | | | | 
you: O













a

1
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
@|O| | | | | 
mcts: @
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
@|O|@| | | | 
you: O
2
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | |O| | | | 
-------------
@|O|@| | | | 
mcts: @
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | | | | | | 
-------------
 | |O| | | | 
-------------
@|O|@| | |@| 
you: O
3
 | | | | 