__Selection Phase__

Starting from the root node, run the *Tree Policy*.

*Tree Policy*

Cases:
1. Not terminal and no children --> run __Expansion__, then proceed to case 2.
2. Not terminal and some children unexplored --> select one of the unexplored children, then proceed to __Simulation__.
3. Not terminal and all children explored --> apply a bandit algorithm (e.g. UCB) to select an action / child, step the environment forward, then repeat the selection from that child.
4. Terminal --> stop.
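
The four cases can be sketched as a short loop. This is only a sketch under assumptions: `SketchNode` is a hypothetical stand-in (not the `Node` class from `MCTS.py`), each node gets two actions, UCB1 is the bandit rule, and stepping the environment is left out to keep it short.

```python
import math
import random


class SketchNode:
    """Hypothetical stand-in for the notebook's Node class."""

    def __init__(self, parent=None, terminal=False):
        self.parent = parent
        self.terminal = terminal
        self.children = []   # filled in by expansion
        self.visits = 0
        self.value = 0.0


def expand(node, num_actions=2):
    """Case 1: create one child per action (two actions is an assumption)."""
    node.children = [SketchNode(parent=node) for _ in range(num_actions)]


def ucb(child, c=math.sqrt(2)):
    """UCB1 score: average value plus an exploration bonus."""
    exploit = child.value / child.visits
    explore = c * math.sqrt(math.log(child.parent.visits) / child.visits)
    return exploit + explore


def tree_policy(node):
    """Walk down from `node` and return the node to simulate from."""
    while not node.terminal:                   # case 4 ends the loop
        if not node.children:                  # case 1
            expand(node)
        unexplored = [ch for ch in node.children if ch.visits == 0]
        if unexplored:                         # case 2
            return random.choice(unexplored)
        node = max(node.children, key=ucb)     # case 3
    return node
```

On a fresh root the first call expands the root and returns one of its two children; once both children have been visited, the walk descends through the child with the higher UCB score.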

__Expansion__

Initialize the child nodes, one for each action.

__Simulation__

Run a trajectory according to the *rollout policy*, starting from the node's state, until a terminal state is reached, then run __Backprop__.

Pong-specific: end the simulation when a point is awarded, rather than on the Pong environment's terminal signal, which returns true only at the end of a game (when one player has reached 21 points).
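
A sketch of a rollout that stops at the first point. In Pong the reward is non-zero only on the step where a point is scored, so that can serve as the stopping signal. `FakePongEnv` is a hypothetical stub (it awards the opponent a point every 5th step) standing in for the real environment, and the action ids `(2, 3)` for up/down are an assumption.

```python
import random


class FakePongEnv:
    """Hypothetical stub with the 4-tuple step() used in this notebook.

    The opponent scores on every 5th step; `done` only becomes True once
    21 points have been scored, mimicking the real terminal signal.
    """

    def __init__(self):
        self.steps = 0
        self.opponent_points = 0

    def step(self, action):
        self.steps += 1
        reward = 0.0
        if self.steps % 5 == 0:
            reward = -1.0
            self.opponent_points += 1
        done = self.opponent_points >= 21
        return None, reward, done, {}


def rollout_one_point(env, actions=(2, 3), rng=random):
    """Random rollout that stops at the first point or at game over."""
    total, steps = 0.0, 0
    while True:
        _, reward, done, _ = env.step(rng.choice(actions))
        total += reward
        steps += 1
        if reward != 0 or done:   # a non-zero reward marks a scored point
            return total, steps
```

With the stub, `rollout_one_point(FakePongEnv())` stops after 5 steps with a return of -1.0, long before `done` would fire.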

__Backprop__

Starting from the terminal node, update value and visits recursively, applying discounts as needed, until the root is reached.
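
A minimal sketch of the update, written as a loop over parent links rather than recursion. `StatNode` is a hypothetical class holding only the fields the update touches; discounting the reward by `gamma` once per level is one possible reading of "applying discounts as needed", not necessarily what `MCTS.py` does.

```python
class StatNode:
    """Hypothetical node with only the statistics backprop touches."""

    def __init__(self, parent=None):
        self.parent = parent
        self.visits = 0
        self.value = 0.0


def backprop(node, reward, gamma=1.0):
    """Add `reward` to every node on the path from `node` up to the root."""
    while node is not None:
        node.visits += 1
        node.value += reward
        reward *= gamma        # discount once per level (assumption)
        node = node.parent
```

With `gamma=1` every ancestor receives the full reward, so a node's value is the sum of the returns of all rollouts that passed through it.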

In [1]:
import random
import itertools
import numpy as np
from collections import namedtuple

import gym
import tensorflow

from MCTS import MCTS, PongEnv, Node
from utils import partition_points

%load_ext autoreload
%autoreload 2

# Create env
pong = PongEnv()

# Initial state
init_state = pong.state
root = Node(state=init_state, root=True)

# Set up MCTS (gamma=1: no discounting)
mcts = MCTS(gamma=1)

[2016-09-15 10:50:44,174] Making new env: Pong-v0


In [2]:
num_rollouts = 4

for i in range(num_rollouts):
    mcts.select(root, pong)


point scored

****** Point Dx ******
Num steps:  89
Final Score: 0 to 1
Num Up moves: 47
Num Down moves: 42

point scored

****** Point Dx ******
Num steps:  48
Final Score: 0 to 1
Num Up moves: 28
Num Down moves: 20


point scored

****** Point Dx ******
Num steps:  48
Final Score: 0 to 1
Num Up moves: 20
Num Down moves: 28


point scored

****** Point Dx ******
Num steps:  48
Final Score: 0 to 1
Num Up moves: 25
Num Down moves: 23



In [11]:
root.__class__.__name__

'Node'

In [5]:
%run MCTS.py

[2016-09-15 10:53:16,618] Making new env: Pong-v0



point scored

****** Point Dx ******
Num steps:  85
Final Score: 0 to 1
Num Up moves: 40
Num Down moves: 45

point scored

****** Point Dx ******
Num steps:  48
Final Score: 0 to 1
Num Up moves: 26
Num Down moves: 22


point scored

****** Point Dx ******
Num steps:  46
Final Score: 0 to 1
Num Up moves: 21
Num Down moves: 25


point scored

****** Point Dx ******
Num steps:  45
Final Score: 0 to 1
Num Up moves: 20
Num Down moves: 25



In [6]:
!cat MCTS.log

2016-09-15 10:14:52,525 INFO expanding root
2016-09-15 10:14:52,675 INFO expanding root/child1
2016-09-15 10:14:52,724 INFO expanding root/child2
2016-09-15 10:17:19,500 INFO expanding root
2016-09-15 10:17:19 INFO expanding root
2016-09-15 10:17:19,734 INFO expanding root/child1
2016-09-15 10:17:19 INFO expanding root/child1
2016-09-15 10:17:19,786 INFO expanding root/child2
2016-09-15 10:17:19 INFO expanding root/child2
2016-09-15 10:31:30,048 INFO expanding root
2016-09-15 10:31:30 INFO expanding root
2016-09-15 10:31:30,050 INFO exploring child root/child1
2016-09-15 10:31:30 INFO exploring child root/child1
2016-09-15 10:31:30,148 INFO exploring child root/child2
2016-09-15 10:31:30 INFO exploring child root/child2
2016-09-15 10:31:30,202 INFO expanding root/child1
2016-09-15 10:31:30 INFO expanding root/child1
2016-09-15 10:31:30,204 INFO exploring child root/child1/child1
2016-09-15 10:31:30 INFO exploring child root/child1/child1
2016-09-15 10:31:30,260 INFO 

c1, c2 = root.explored_children
c11, c12 = c1.children
c21, c22 = c2.children

In [6]:
node = root

def print_stats(node):
    print node.name, node.value, node.visits
    
    for child in node.children:
        print_stats(child)

print_stats(root)

root -4.0 4.0
root/child1 -2.0 2.0
root/child1/child1 -1.0 1.0
root/child1/child2 0.0 0.0
root/child2 -2.0 2.0
root/child2/child1 -1.0 1.0
root/child2/child2 0.0 0.0


In [169]:
visits = c1.visits
value = c1.value
print visits, value

1.0 -1.0


In [171]:
exploit = value / visits
print exploit

explore = np.sqrt((2 * np.log(c1.parent.visits)) / c1.visits)
print explore

print exploit + explore


-1.0
1.17741002252
0.177410022515
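
The calculation above is the UCB1 score, exploit + sqrt(2 ln N / n), where N is the parent's visit count and n the child's. Here it is wrapped in a function as a sketch; `ucb1` is a hypothetical helper (not `mcts.ucb`), and returning infinity for an unvisited child is an added assumption to avoid dividing by zero.

```python
import math


def ucb1(value, visits, parent_visits, c=math.sqrt(2)):
    """UCB1 score for a child; unvisited children get +inf (assumption)."""
    if visits == 0:
        return float("inf")
    exploit = value / visits
    explore = c * math.sqrt(math.log(parent_visits) / visits)
    return exploit + explore
```

With value -1.0, 1 visit, and a parent visit count of 2, this reproduces the printed numbers: the exploration term is sqrt(2 ln 2) ≈ 1.1774 and the score is ≈ 0.1774.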


In [175]:
print np.argmax(map(mcts.ucb, root.children))
print np.argmax([1,2])

0
1


test_env = gym.make("Pong-v0")
test_env.reset()

s, r, t, i = test_env.step(2)




In [94]:
np.allclose(s, child.state)

True

In [67]:
pong.points_played

1

In [60]:
print pong.games_played
print pong.points_played

0
0


In [62]:
while pong.points_played < 10:
    pong.simulate()

point scored
point scored
point scored
point scored
point scored
point scored
point scored
point scored
point scored
point scored


In [53]:
print pong.games_played
print pong.points_played

2
43


In [48]:
h = []
sum(len(p) for p in h)

0

In [46]:
from operator import add
reduce(add, range(5))

10

In [90]:
print len(pong.trajectory), len(pong.history)

21 1


In [48]:
pong.is_point

True

In [49]:
res = pong.trajectory

In [50]:
res

[([array([[[  0,   0,   0],
           [  0,   0,   0],
           [  0,   0,   0],
           ..., 
           [109, 118,  43],
           [109, 118,  43],
           [109, 118,  43]],
   
          [[109, 118,  43],
           [109, 118,  43],
           [109, 118,  43],
           ..., 
           [109, 118,  43],
           [109, 118,  43],
           [109, 118,  43]],
   
          [[109, 118,  43],
           [109, 118,  43],
           [109, 118,  43],
           ..., 
           [109, 118,  43],
           [109, 118,  43],
           [109, 118,  43]],
   
          ..., 
          [[ 53,  95,  24],
           [ 53,  95,  24],
           [ 53,  95,  24],
           ..., 
           [ 53,  95,  24],
           [ 53,  95,  24],
           [ 53,  95,  24]],
   
          [[ 53,  95,  24],
           [ 53,  95,  24],
           [ 53,  95,  24],
           ..., 
           [ 53,  95,  24],
           [ 53,  95,  24],
           [ 53,  95,  24]],
   
          [[ 53,  95,  24],
      

In [34]:
len(r)

1901

In [29]:
len(pong.trajectory[0][2])


1175

In [17]:
mcts.select(root, pong)

expanding

exploring child 1
simulating
Num steps:  1407
Final Score: 1 to 21
Num Up moves: 687
Num Down moves: 720

backprop'ing

Point scored
Gameover


In [6]:
pong.is_point

False

In [14]:
results = penv.simulate()

In [15]:
s, a, r = results

In [17]:
seqs = partition_points(s, a, r)

In [20]:
s0 = seqs[0]
s1 = seqs[1]

In [22]:
init_state0 = s0[0][0]

In [24]:
init_state0.shape

(210, 160, 3)

In [25]:
init_state_1 = s1[0][0]

In [26]:
init_state_1.shape

(210, 160, 3)

In [27]:
np.allclose(init_state0, init_state_1)

False
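
The source of `partition_points` is not shown here. One plausible way to split a full-game trajectory into per-point sequences is to cut after every non-zero reward; `split_by_point` below is a hypothetical sketch of that idea, not the helper from `utils`.

```python
def split_by_point(states, actions, rewards):
    """Split parallel lists into per-point (states, actions, rewards) tuples.

    A segment ends on the step whose reward is non-zero (a scored point).
    """
    segments, start = [], 0
    for i, r in enumerate(rewards):
        if r != 0:
            end = i + 1
            segments.append((states[start:end],
                             actions[start:end],
                             rewards[start:end]))
            start = end
    if start < len(rewards):   # trailing steps after the last point
        segments.append((states[start:], actions[start:], rewards[start:]))
    return segments
```

With this layout, `segments[k][0][0]` is the first state of point `k`, which matches the `s0[0][0]` indexing used above.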