## Amőbázó mesterséges intelligencia (Monte Carlo Tree Search)

Kocsis–Szepesvári-féle cikk („Bandit based Monte-Carlo Planning”, 2006), amely az UCT algoritmust vezeti be: http://ggp.stanford.edu/readings/uct.pdf

In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
class State:
    """A tic-tac-toe position on a 3x3 board.

    field -- 3x3 NumPy array; 0 marks an empty cell, 1 and 2 are the players' marks
    turn  -- the player (1 or 2) whose move it is

    Instances are used as dictionary keys, so they define __hash__ and __eq__
    based on the board contents only (``turn`` is not part of the key).
    """

    def __init__(self, field, turn=1):
        self.field = field
        self.turn  = turn

    def move(self, i, j, turn):
        """Return a new State with ``turn``'s mark written at cell (i, j).

        The board is copied first, so this state is not modified. The returned
        state has the other player (``3 - turn``) to move.
        """
        field = np.copy(self.field)
        field[i, j] = turn
        return State(field, 3 - turn)

    def children(self):
        """Return one successor state per empty cell, in row-major order."""
        return [self.move(i, j, self.turn)
                for i in range(3)
                for j in range(3)
                if self.field[i, j] == 0]

    def end(self):
        """Classify the position.

        Returns the mark (1 or 2) found on a completed row, column or diagonal,
        0 if no line is complete and an empty cell remains, and 3 if the board
        is full without a completed line (draw).
        """
        f = self.field

        # Same inspection order as before: row i, column i (for each i), then
        # the main diagonal, then the anti-diagonal.
        lines = []
        for i in range(3):
            lines.append((f[i, 0], f[i, 1], f[i, 2]))
            lines.append((f[0, i], f[1, i], f[2, i]))
        lines.append((f[0, 0], f[1, 1], f[2, 2]))
        lines.append((f[0, 2], f[1, 1], f[2, 0]))

        for a, b, c in lines:
            if a != 0 and a == b and b == c:
                return a

        if np.any(f == 0):
            return 0

        return 3

    def randomchild(self):
        """Return a uniformly random successor state.

        An index is drawn instead of passing the list of State objects to
        ``np.random.choice``, which would first coerce them into an array.
        Raises ValueError when there is no empty cell left.
        """
        successors = self.children()
        return successors[np.random.randint(len(successors))]

    # Needed so that states can be used as dictionary keys.
    def __hash__(self):
        # ``ndarray.tostring()`` is deprecated and missing from recent NumPy
        # releases. Hashing the cell values (rather than the raw bytes) also
        # keeps the hash consistent with __eq__ when two equal boards are
        # stored with different dtypes (e.g. int vs. float).
        return hash(tuple(self.field.ravel().tolist()))

    def __eq__(self, other):
        if not isinstance(other, State):
            # Let Python fall back to its default comparison instead of
            # failing on a missing ``field`` attribute.
            return NotImplemented
        return bool(np.array_equal(self.field, other.field))

In [3]:
class MCTS:
    """Monte Carlo Tree Search with UCB1 child selection.

    ``values`` maps every state known to the tree to a two-element list
    ``[accumulated_score, visit_count]``. The statistics of a state are kept
    from the point of view of the player who moved *into* it, which is what
    ``ucb`` needs when a parent chooses among its children.
    """

    def __init__(self, state):
        self.state = state
        self.values = {state: [0, 0]}

    def ucb(self, states):
        """Pick the state with the largest UCB1 bound.

        The first state that has never been visited is returned immediately
        (this also avoids a division by zero). Every element of ``states``
        must already be a key of ``self.values``. If ``states`` is empty the
        placeholder 0 is returned, as before.
        """
        best_state = 0
        best_bound = -np.inf
        total_visits = sum(self.values[s][1] for s in states)
        for s in states:
            total, visits = self.values[s]
            if visits == 0:
                return s  # unvisited: explore it first (and avoid dividing by zero)
            # np.sqrt instead of the bare ``sqrt`` that only exists when the
            # namespace was populated by ``%pylab``.
            bound = total / visits + np.sqrt((2 * np.log(total_visits)) / visits)
            if bound > best_bound:
                best_bound = bound
                best_state = s

        return best_state

    def selection(self, state):
        """Walk down the tree from ``state`` and return the visited path.

        The walk stops at a terminal state or at the first state that has a
        child not yet stored in the tree. The path is ordered from ``state``
        (first) to the selected leaf (last).
        """
        path = [state]
        while state.end() == 0:
            children = state.children()
            if any(child not in self.values for child in children):
                break  # some child is not in the tree yet: this is the leaf
            state = self.ucb(children)
            path.append(state)
        return path

    def expansion(self, state):
        """Register every child of ``state`` that the tree does not know yet."""
        for child in state.children():
            self.values.setdefault(child, [0, 0])

    def rollout(self, state):
        """Play uniformly random moves until the game ends and score the result.

        Returns -1. if player 1 won, 0. for a draw and 1. if player 2 won,
        independent of whose turn it was in ``state``.
        """
        while state.end() == 0:
            state = state.randomchild()
        outcome = state.end()  # evaluate the final position only once
        if outcome == 1:
            return -1.
        if outcome == 3:
            return 0.
        return 1.

    def update(self, path, score):
        """Back-propagate a rollout ``score`` along ``path``.

        ``score`` is the absolute result returned by ``rollout`` (+1 when
        player 2 wins). A node whose ``turn`` is 1 was reached by a move of
        player 2, so it receives ``score``; a node whose ``turn`` is 2 receives
        ``-score``. For a path that starts at a player-1-to-move state this is
        exactly the former ``x, -x, x, ...`` alternation, but it stays correct
        when the path starts at a player-2-to-move state.

        Nodes that are not in the tree yet are registered on the fly instead
        of raising KeyError.
        """
        for node in path:
            stats = self.values.setdefault(node, [0, 0])
            stats[0] += score if node.turn == 1 else -score
            stats[1] += 1

In [4]:
# Build the search tree once, rooted at the empty 3x3 board with player 1 to move.
mcts = MCTS(
    State(field=np.zeros(shape=(3, 3)), turn=1)
)

In [6]:
N_SIMULATIONS = 1000  # MCTS iterations spent before every computer move


def _read_human_move(state):
    """Ask until the user names an empty cell; return its 0-based (i, j).

    The user types 1-based ``row,column``. Malformed text, coordinates outside
    the board (including 0, which would otherwise wrap around as index -1) and
    occupied cells are rejected and the question is asked again.
    """
    rows, cols = state.field.shape
    while True:
        answer = input("Hova teszel? ")
        try:
            i, j = (int(part) - 1 for part in answer.split(','))
        except ValueError:
            # not two integers separated by a comma
            print("Érvénytelen bemenet, így add meg: sor,oszlop (pl. 2,3)")
            continue
        if not (0 <= i < rows and 0 <= j < cols):
            print("Ez a mező a táblán kívül van.")
            continue
        if state.field[i, j] != 0:
            print("Ez a mező már foglalt.")
            continue
        return i, j


state = State(np.zeros((3,3)))

while True:
    # Computer (player 1): grow the tree from the current position, then move.
    for _ in range(N_SIMULATIONS):
        path = mcts.selection(state)
        selected = path[-1]

        mcts.expansion(selected)
        score = mcts.rollout(selected)
        mcts.update(path, score)
    state = mcts.ucb(state.children())
    print(state.field)

    if state.end():
        break

    # Human (player 2).
    i, j = _read_human_move(state)
    state = state.move(i, j, 2)

    # The human move can finish the game as well; without this check the
    # computer would keep playing on a decided position.
    if state.end():
        print(state.field)
        break

[[0. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]


Hova teszel?  1,1


[[2. 0. 0.]
 [0. 1. 1.]
 [0. 0. 0.]]


Hova teszel?  2,1


[[2. 0. 0.]
 [2. 1. 1.]
 [1. 0. 0.]]


Hova teszel?  1,3


[[2. 1. 2.]
 [2. 1. 1.]
 [1. 0. 0.]]


Hova teszel?  3,2


[[2. 1. 2.]
 [2. 1. 1.]
 [1. 2. 1.]]
