In [65]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
%matplotlib inline

In [66]:

GAP = 0.1 # gap between two hexagons
bsize = 3

# Basechange between xy and ij coordinates / base.
ij2xy = np.array([
    [ 1, 0.5],
    [ 0, np.sqrt(3/4)],
])
xy2ij = np.linalg.inv(ij2xy)

# Shape of our hexagon
thirty_degree = 2*np.pi/12
sixty_degree = 2*np.pi/6
hex_factor = (1-GAP)/np.cos(thirty_degree)/2 # the factor in the hexagon formula (to get a gap)
hexagon = hex_factor * np.asarray([(np.cos(alpha), np.sin(alpha)) for alpha in
                               np.arange(thirty_degree, 2*np.pi+thirty_degree, sixty_degree)])

# For Detection if a point is on a hexagon
middlepoints = (hexagon[1:] + hexagon[:-1]) / 2
dual_xy = middlepoints / np.linalg.norm(middlepoints[0])

def show_board(board, numbers=True):
    assert board.shape == (bsize, bsize)
    board = board.astype(int)
    
    c = [(.2,.2,.2), (.6,.0,.0), (.0,.05,.85)]
    
    plt.figure(figsize=(10,8))
    
    patch = patches.Polygon(([-100,-100], [-100,100], [100,100], [100,-100]), facecolor='black', zorder=-100)
    plt.gca().add_patch(patch)
    
    for i in range(bsize):
        for j in range(bsize):
            ij = np.asarray((i,j))
            xy = ij2xy @ ij
            x, y = xy
            ##plt.scatter(x, y, color=c[board[i,j]])
            patch = patches.Polygon(hexagon + xy, facecolor=c[board[i,j]%3], zorder=-100)
            ##patch = patches.Polygon([xy2ij @ (hexagon[k] + xy) for k in range(6)], facecolor=c[board[i,j]%3], zorder=-100)
            plt.gca().add_patch(patch)
            
            if numbers:
                plt.text(x, y, '%d, %d' %(i,j), color='white', horizontalalignment='center', verticalalignment='center')
            
    plt.axis('equal')
    plt.xlim([-1, bsize*1.5-0.5])
    plt.ylim([-1, bsize*np.sqrt(3/4)])
    plt.axis('off')
    
    
def print_board(board):
    char = [' ', 'r', 'b']
    print((bsize+1)*' ' + 2*bsize*'_')
    for i in range(bsize-1, -1, -1):
        line = ' '
        line += i*' ' + '/'
        for j in range(bsize):
            line += char[board[j,i]%3]
            if j<bsize-1: line += ' '
        line += '/'
        print(line)
    print((2*bsize+1)*'^')
    
    
def pos_on_board(pos):
    pos_ij = xy2ij @ pos
    
    for i in range(bsize):
        if (i-pos_ij[0]) > 2: continue
        for j in range(bsize):
            if (j-pos_ij[1]) > 2: continue
            ij = 1.*np.array((i, j))
            xy = ij2xy @ ij
            if np.sum((pos-xy)**2) <= hex_factor**2:
                if np.all(np.abs(dual_xy[:3] @ (pos-xy)) < 0.5 - GAP/2):
                    return i, j
                else:
                    return None, None
    return None, None

In [67]:

def winner_single(stones):
    '''Gives info if one player is connected left to right.
    stones: nxn-array of bool type, where the stones of the player lie.'''
    connected = np.zeros_like(stones) # welche Steine sind nach unten verbunden?
    con_new   = np.zeros_like(stones) # beim letzten update dazugekommen
    connected[0] = stones[0] # welche Steine sind nach unten verbunden?
    con_new[0] = stones[0] # Steine am unteren Rand sind automatisch verbunden.
    changed = True
    while changed:
        changed = False
        for i in range(bsize):
            for j in range(bsize):
                if con_new[i,j]:
                    # activated others around (i,j)
                    for (di, dj) in [(1,0), (0,1), (-1,1), (-1,0), (0,-1), (1,-1)]:
                        if 0<=i+di<bsize and 0<=j+dj<bsize:
                            if not connected[i+di,j+dj] and stones[i+di,j+dj]:
                                # nur wenn noch nicht als connected markieren weiter fortfahren
                                connected[i+di,j+dj] = True
                                con_new[i+di,j+dj] = True
                                changed = True
                    con_new[i,j] = False
    #show_board(connected) # for debugging
    if np.any(connected[-1]):
        return True
    else:
        return False
    
def winner(board):
    if winner_single(board.T==1):
        return 1
    if winner_single(board==-1):
        return -1
    return 0

In [68]:
def ismovevalid(board, i, j):
    if type(i) != int: return False
    if type(j) != int: return False
    if not 0<=i<len(board): return False
    if not 0<=j<len(board): return False
    return board[i,j] == 0

In [69]:
def int_input(text):
    while True:
        s = input(text)
        try:
            i = int(s)
            return i
        except:
            print('Cannot convert %s to int!' % s)

In [None]:
board = np.zeros((bsize,bsize), int)

while True:
    show_board(board)
    plt.show()
    print('Reds turn')
    while True:
        x = int_input('Input row coordinate: ')
        y = int_input('Input column coordinates: ')
        if not ismovevalid(board, x, y):
            print('Move is not valid!')
            continue
        else:
            board[x,y] = +1
            break
    if winner(board) == +1:
        print('Red has won!')
        show_board(board)
        plt.show()
        break
        
    show_board(board)
    plt.show()
    print('Blues turn')
    while True:
        x = int_input('Input row coordinate: ')
        y = int_input('Input column coordinates: ')
        if not ismovevalid(board, x, y):
            print('Move is not valid!')
            continue
        else:
            board[x,y] = -1
            break
    if winner(board) == -1:
        print('Blue has won!')
        show_board(board)
        plt.show()
        break

## AlphaZero Ansätze

In [70]:
def gamefinished(board):
    'Returns +1 or -1 if red or blue has won, or returns 0 if game has not ended yet.'
    return winner(board) / (1 + abs(board).sum() / bsize**2) # extra denominator makes the enemy to delay victory as long as possible

def validmoves(board):
    'Returns an array with all valid positions for the next move'
    board = board.astype(int)
    valid = [ k for k in range(bsize**2) if board.flatten()[k] == 0 ]
    if abs(board).sum() == 1 and not np.any(abs(board)>1):   # switch sides is allowed if only once and on the second move
        valid += [-1]
        
    assert len(valid)!=0 or np.all(board!=0), 'Fehler bei validmoves, kein Zug möglich obwohl Brett nicht voll'
    
    return valid

def performaction(board, action):
    'Performs a move on the current board'
    b = board.copy()
    if action == -1:
        # Switch Sides
        assert not np.any(abs(b)>4)
        b = switchsides(b)
        b *= 4   # replace 1 by 4 to mark that sides have been switched
    else:
        # Regular Move
        b[int(action/bsize), action%bsize] = +1
    return b

def switchsides(board):
    'Exchanges red and blue player'
    return -board.T



def sym_representative(board):
    'Returns a representative of the states connected by symmetry'
    # Die roten Steine sollen in der unten linken Ecke gehäuft sein wenn möglich.
    # Bei Stechen zählen die blauen Steine.
    m = bsize//2
    lower  = (board[:m,:] > 0).sum() # rote Steine in oberer Spielfeldhälte etc.
    upper  = (board[-m:,:] > 0).sum()
    left   = (board[:,:m] > 0).sum()
    right  = (board[:,-m:] > 0).sum()
    lower += (board[:m,:] < 0).sum()/bsize**2 # blaue Steine in oberer Spielfeldhälte etc.
    upper += (board[-m:,:] < 0).sum()/bsize**2
    left  += (board[:,:m] < 0).sum()/bsize**2
    right += (board[:,-m:] < 0).sum()/bsize**2
    
    if lower < upper:
        board = board[:,::-1]
    if left < right:
        board = board[::-1]
    return board


def tostring(board):
    'Bijective function which creates a string representation of board'
    char = {2: 'b', 0: '0', 1: 'r'}
    s = ''
    for i in range(bsize):
        for j in range(bsize):
            s += char[board[i,j]%3]
        s += ';'
    if np.any(abs(board)>1): s = 's' + s  # sides were switched
    return s




SyntaxError: invalid syntax (<ipython-input-70-feeca967e8a1>, line 14)

In [81]:

def p_and_v_trivial(board):
    'Simulation of the Neural Network'
    p = np.zeros(bsize**2)
    valid = validmoves(board)
    for a in valid:
        p[a] = 1.
    p /= p.sum() + 1e-12
    v = 0.
    return p, v



verbose = False
verbose_start = np.Inf


def print_(*args):
    global verbose
    if verbose: print(*args)
        
def show_board_(board):
    global verbose
    if verbose: show_board(board)
    #if verbose: print_board(board)

def printmove(a):
    try:
        if a == -1:
            return 'sideswitch'
        else:
            return str(a//bsize)+','+str(a%bsize)
    except:
        print(a)
        assert False
        

from collections import defaultdict


class MCTS:
    def __init__(self, p_and_v, c_ut=1., niterations=10, maxdepth=np.Inf, alpha=0.3, noise=0.2):
        self.p_and_v = p_and_v
        self.visited = set()
        self.T = {}
        self.v = {}
        self.Ns = defaultdict(lambda: 0)
        self.Nsa = defaultdict(lambda: 0)
        self.Qsa = defaultdict(lambda: 0)
        self.vsa = defaultdict(lambda: [])
        self.p = {}
        self.Vs = {}
        
        self.c_ut = c_ut
        self.maxdepth = maxdepth
        self.niterations = niterations
        self.alpha = alpha # alpha parameter of dirichlet noise ()
        self.noise = noise # noise amount (1 means pure noise, 0 means deterministic)
        
        
    def run(self):
        'Recursive variant of the mcts algorithm, parent function of the recursion'
        s = np.zeros((bsize, bsize), dtype=int)
        for i in range(self.niterations):
            
            global verbose, verbose_start
            if i >= verbose_start: verbose = True
            
            print_(f'Starting iteration nr {i}!')
            self.search_recursive(s)
            print_('\n\n\n\n')
        
    
    def search_iterative(self, s0):
        'Recursive variant of the mcts algorithm, recursive subroutine'
        
        s_stack = []
        a_stack = []
        
        s = s0
        
        for d in range(self.maxdepth):
            
            s = sym_representative(s)
            s_ = tostring(s)
        
        
            print_(f'Reached state s, id is {s_},')
            print_(f'which was already visited {self.Ns[s_]} times:')
            show_board_(s)
            plt.show()
        
            # visited for first time?
            if s_ not in self.visited:
                self.visited.add(s_)
                self.p[s_], self.v[s_] = self.p_and_v(s)
                self.T[s_] = gamefinished(s)
                self.Vs[s_] = validmoves(s)
            
            # is game finished?
            if self.T[s_] != 0:
                print_(f'Game has ended, going back with v = {self.T[s_]}')
                v = -self.T[s_]
                break

            # maximal depth was reached -> terminate here and work with v instead of T
            if depth >= self.maxdepth:
                print_(f'Maximal deptch reached, going back with v = {self.v[s_]}')
                v =  -self.v[s_]
                break

            # otherwise leaf node?
            if self.Ns[s_] == 0:
                self.Ns[s_] = 1
                print_(f'Node visited for first time, going back with v = {self.v[s_]}')
                v = -self.v[s_]
                break
                
                ##################
        
            # make a decision
            valid = self.Vs[s_]
            assert len(valid) != 0
            (f'There are {len(valid)} valid moves')
            Umax = -np.Inf
            abest = None
            dirichlet = np.random.dirichlet(len(valid)*[self.alpha])
            for k, a in enumerate(valid):
                if self.Nsa[s_, a] > 0:
                    U = self.Qsa[s_, a] + self.c_ut * self.p[s_][a] * np.sqrt(self.Ns[s_])/(self.Nsa[s_, a]+1)
                    print_(f'Move {printmove(a)} chosen already {self.Nsa[s_, a]} times,')
                    print_(f'has Q value of {self.Qsa[s_,a]} and U value of {U}.')
                    U = (1-self.noise)*U + self.noise*dirichlet[k]
                else:
                    U = self.c_ut * self.p[s_][a] * np.sqrt(self.Ns[s_]+1e-6) # small epsilon in case Ns=0 -> well defined U guaranteed
                    print_(f'Move {printmove(a)} chosen 0 times,')
                    print_(f'has unknown Q value and U value of {U}.')
                    U = (1-self.noise)*U + self.noise*dirichlet[k]
                if U > Umax:
                    Umax = U
                    abest = a
            a = abest
            print_(f'Chose move {printmove(a)}.')
            print_(f'\n\n')
        
        # Perform the best action which was just found
        sprime = switchsides(performaction(s, a))
        v = self.search_recursive(sprime, depth+1)
        
        self.vsa[s_,a].append(v)
        self.Qsa[s_,a] = (self.Qsa[s_,a] * self.Nsa[s_,a] + v) / (self.Nsa[s_,a]+1)
        self.Ns[s_] += 1
        self.Nsa[s_,a] += 1
        
        return -v
        
    
    def search_recursive(self, s, depth=0):
        'Recursive variant of the mcts algorithm, recursive subroutine'
        
        s = sym_representative(s)
        s_ = tostring(s)
        
        
        print_(f'Reached state s, id is {s_},')
        print_(f'which was already visited {self.Ns[s_]} times:')
        show_board_(s)
        plt.show()
        
        # visited for first time?
        if s_ not in self.visited:
            self.visited.add(s_)
            self.p[s_], self.v[s_] = self.p_and_v(s)
            self.T[s_] = gamefinished(s)
            self.Vs[s_] = validmoves(s)
            
        # is game finished?
        if self.T[s_] != 0:
            print_(f'Game has ended, going back with v = {self.T[s_]}')
            return -self.T[s_]
        
        # maximal depth was reached -> terminate here and work with v instead of T
        if depth >= self.maxdepth:
            print_(f'Maximal deptch reached, going back with v = {self.v[s_]}')
            return -self.v[s_]
        
        # otherwise leaf node?
        if self.Ns[s_] == 0:
            self.Ns[s_] = 1
            print_(f'Node visited for first time, going back with v = {self.v[s_]}')
            return -self.v[s_]
        
        # Otherwise go deeper into tree
        valid = self.Vs[s_]
        assert len(valid) != 0, valid, board
        (f'There are {len(valid)} valid moves')
        Umax = -np.Inf
        abest = None
        dirichlet = np.random.dirichlet(len(valid)*[self.alpha])
        for k, a in enumerate(valid):
            if self.Nsa[s_, a] > 0:
                U = self.Qsa[s_, a] + self.c_ut * self.p[s_][a] * np.sqrt(self.Ns[s_])/(self.Nsa[s_, a]+1)
                print_(f'Move {printmove(a)} chosen already {self.Nsa[s_, a]} times,')
                print_(f'has Q value of {self.Qsa[s_,a]} and U value of {U}.')
                U = (1-self.noise)*U + self.noise*dirichlet[k]
            else:
                U = self.c_ut * self.p[s_][a] * np.sqrt(self.Ns[s_]+1e-6) # small epsilon in case Ns=0 -> well defined U guaranteed
                print_(f'Move {printmove(a)} chosen 0 times,')
                print_(f'has unknown Q value and U value of {U}.')
                U = (1-self.noise)*U + self.noise*dirichlet[k]
            if U > Umax:
                Umax = U
                abest = a
        a = abest
        print_(f'Chose move {printmove(a)}.')
        print_(f'\n\n')
        
        # Perform the best action which was just found
        sprime = switchsides(performaction(s, a))
        v = self.search_recursive(sprime, depth+1)
        
        self.vsa[s_,a].append(v)
        self.Qsa[s_,a] = (self.Qsa[s_,a] * self.Nsa[s_,a] + v) / (self.Nsa[s_,a]+1)
        self.Ns[s_] += 1
        self.Nsa[s_,a] += 1
        
        return -v

SyntaxError: invalid syntax (<ipython-input-81-47436fe60dbf>, line 195)

In [80]:
np.random.seed(0)
verbose = False
verbose_start = 6000
mcts = MCTS(p_and_v_trivial, niterations=6010)
mcts.run()

AssertionError: []

In [17]:

verbose = False
verbose_start = 6000000
mcts = MCTS(p_and_v_trivial, niterations=1000)

%prun mcts.search_recursive()

 

In [30]:
board = np.array([[-1, -1,  1],
 [ 1,  1,  0],
 [ 0,  0, -1]], dtype=int)

validmoves(board)

[5, 6, 7]

In [31]:
board

array([[-1, -1,  1],
       [ 1,  1,  0],
       [ 0,  0, -1]])