In [9]:
import random
import itertools
import numpy as np
from collections import namedtuple

import gym
import tensorflow


In [159]:
class TestEnv(object):
    '''Wrapper around a gym environment that plays random episodes and records them.

    Attributes:
        env: environment object returned by ``gym.make(name)``.
        state: latest observation returned by ``env.reset()`` or ``env.step()``.
        is_terminal: True once the most recent step reported the end of the episode.
        history: list of ``(states, actions, rewards)`` tuples, one per rollout.
        trajectory: empty list created at construction; not used by this class.

    Subclasses must implement the ``valid_actions`` property.
    '''
    def __init__(self, name):
        self.env = gym.make(name)
        self.state = self.env.reset()
        self.trajectory = []
        self.history = []
        self.is_terminal = False

    def reset(self):
        '''Start a new episode and discard all recorded trajectories
        '''
        self.state = self.env.reset()
        self.history = []
        #A fresh episode is never terminal; without this a finished wrapper would refuse to simulate
        self.is_terminal = False

    def step(self, a):
        '''Apply action ``a`` and return ``(next_state, reward, terminal, info)`` from the env
        '''
        next_state, reward, terminal, info = self.env.step(a)
        self.state = next_state
        #Track the terminal flag so simulate() stops after an episode ended through step()
        self.is_terminal = terminal
        return next_state, reward, terminal, info

    def simulate(self):
        '''Simulate path starting from current state

        Plays randomly sampled actions until the episode ends and returns
        ``(states, actions, rewards)`` where ``states[i]`` is the observation
        seen before ``actions[i]`` was taken.  Returns three empty lists when
        the environment is already terminal.
        '''
        path = []
        while not self.is_terminal:
            action = self.sample_action()
            #Remember the pre-step observation; step() overwrites self.state
            state = self.state
            _, reward, _, _ = self.step(action)
            path.append((state, action, reward))

        return self._unpack(path)

    def rollout(self, num_plays=1):
        '''Simulate num_play trajectories and append each one to ``history``
        '''
        if self.is_terminal:
            self.clear()

        for i in range(num_plays):
            trajectory = self.simulate()
            self.history.append(trajectory)
            self.clear()

    def _unpack(self, trajectory):
        '''Turn a list of (state, action, reward) tuples into three parallel lists
        '''
        #zip(*[]) yields nothing, so the three-way unpacking below would raise on an empty path
        if not trajectory:
            return [], [], []
        s, a, r = map(list, zip(*trajectory))
        return s, a, r

    def get_trajectories(self):
        '''Returns list of trajectories, where each trajectory is a tuple of
        (states, actions, rewards) lists.  Prints a notice and returns None
        when nothing has been recorded.
        '''
        if self.history:
            #history entries are already unpacked by simulate(); unpacking again would transpose them back
            return list(self.history)
        else:
            print("No trajectories")
            return None

    def sample_action(self, state=None, policy=None, n=1):
        '''Sample ``n`` actions uniformly from ``valid_actions``

        Returns a numpy array of length ``n``.  ``state`` is currently unused.
        NOTE(review): sampling from a ``policy`` is not implemented; None is
        returned when a policy is supplied.
        '''
        if not policy:
            return np.random.choice(self.valid_actions, size=n)
        return None

    def clear(self):
        '''Reset the underlying env and terminal flag, keeping ``history``
        '''
        self.state = self.env.reset()
        self.is_terminal = False

    def clear_all(self):
        '''Reset the underlying env and drop all recorded trajectories
        '''
        self.clear()
        self.history = []

    @property
    def valid_actions(self):
        '''Actions that may be sampled; must be provided by subclasses
        '''
        raise NotImplementedError

        
class PongEnv(TestEnv):
    '''TestEnv specialised for the "Pong-v0" gym environment.
    '''
    def __init__(self):
        env_id = "Pong-v0"
        super(PongEnv, self).__init__(env_id)

    @property
    def valid_actions(self):
        '''Only two actions are offered: 2 (up) and 3 (down)
        '''
        up, down = 2, 3
        return [up, down]

def take(it, n):
    '''Pull up to ``n`` items from iterator ``it`` and return them as a list
    (fewer when ``it`` runs out first).
    '''
    head = itertools.islice(it, n)
    return [item for item in head]

def partition_points(r, keep_tail=False):
    '''
    Partition episode (single Pong game) into sequences of points

    First player to reach 21 points wins game.  Partitioning episodes based on point sequences necessary
    for proper discounting of rewards (+1 for self point, -1 for opponent point)

    r: sequence of per-step rewards (list or 1-D array); every non-zero entry closes a point.
    keep_tail: when True, steps after the last non-zero reward are returned as a final
        (unfinished) sequence so that the sequence lengths add up to len(r).  The default
        (False) drops them, which is the historical behaviour.

    Returns a list of lists of rewards, each one ending with the non-zero reward that
    decided the point (except an optional unfinished tail).
    '''

    #Get indices of non-zero elements, add 1 for proper slicing
    bounds = [0] + (np.nonzero(r)[0] + 1).tolist()
    if keep_tail and bounds[-1] < len(r):
        bounds.append(len(r))

    #Consume a single iterator so consecutive slices never overlap
    r_iter = iter(r)
    return [list(itertools.islice(r_iter, int(sz))) for sz in np.diff(bounds)]
    
class Node(object):
    '''Node of the search tree.

    Attributes:
        parent: node this one was expanded from (None for the root).
        action: action attached to this node by its parent (None for the root).
        state: observation stored on the node, if any.
        is_terminal: flag passed in at construction.
        children: child nodes created by ``expand``.
        explored_children: children already handed out by ``get_unvisited``.
        value, visits: statistics initialised to 0.
    '''
    def __init__(self, parent=None, action=None, state=None, terminal=False):
        self.parent = parent
        self.action = action
        self.state = state
        self.is_terminal = terminal
        self.children = []
        self.explored_children = []
        self.value = 0.
        self.visits = 0.
        #Empty iterator so get_unvisited() is safe to call before expand()
        self.child_iter = iter(())

    def expand(self, env):
        '''Create one child per action in ``env.valid_actions``.

        Does nothing when the node already has children, so a repeated call
        cannot duplicate them or rewind the exploration order.
        '''
        if self.children:
            return
        self.children.extend(Node(parent=self, action=action) for action in env.valid_actions)
        self.child_iter = iter(self.children)

    @property
    def num_children(self):
        '''Number of children created so far'''
        return len(self.children)

    @property
    def num_explored(self):
        '''Number of children already returned by ``get_unvisited``'''
        return len(self.explored_children)

    def get_unvisited(self):
        '''Return the next child not yet handed out, or None when there is none left
        '''
        child = next(self.child_iter, None)
        #Only real children count as explored; recording None would corrupt num_explored
        if child is not None:
            self.explored_children.append(child)
        return child

    @property
    def has_unvisited(self):
        '''True while at least one child has not been returned by ``get_unvisited``'''
        return self.num_explored < self.num_children
    
#Names of the attributes that Node.__init__ sets (is_terminal excluded)
fields = [
    "state",
    "action",
    "parent",
    "children",
    "explored_children",
    "visits",
    "value",
]


__Selection Phase__

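Cases, checked for the current node on every pass of the selection loop:
- Not terminal and no children --> Expand (create one child per valid action)
- Not terminal and unexplored children --> Select the next unexplored child and simulate from it
- Not terminal and all children explored --> Apply the bandit algorithm to choose among the children
- Terminal --> Break (stop selecting)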

In [174]:
import pdb

def simulate(node, env):
    '''Take ``node.action`` in ``env`` and then play the episode out with ``env.simulate()``.

    The observation produced by the action is stored on ``node.state``.

    Returns ``(states, actions, rewards)`` lists of equal length where ``states[i]`` is the
    observation seen before ``actions[i]`` was taken, the same convention ``env.simulate()``
    uses for the rest of the path.
    '''
    print("simulating")
    a = node.action
    #Observation the action is taken from; reading it after step() would duplicate the next state
    prev_state = env.state
    s, r, terminal, info = env.step(a)
    node.state = s

    #The episode ended on this very step: nothing left to roll out
    if terminal:
        return [prev_state], [a], [r]

    states, actions, rewards = env.simulate()
    states = [prev_state] + states
    actions = [a] + actions
    rewards = [r] + rewards

    return states, actions, rewards

def bandit(node):
    '''Placeholder for the bandit selection step; currently does nothing.'''
    return None

def printDx(result):
    '''Print diagnostics for one trajectory.

    result: ``(states, actions, rewards)`` sequences of equal length.

    Prints the number of steps, the count of positive vs. negative rewards as the
    final score, and how many times action 2 (up) and action 3 (down) were taken.
    '''
    states, actions, rewards = result

    #Actions can be plain ints or length-1 arrays (np.random.choice(..., size=1)); flatten to 1-D
    if len(actions):
        a = np.concatenate([np.ravel(x) for x in actions])
    else:
        a = np.array([])
    r = np.asarray(rewards)

    #Only the count of states is needed, so the frames are never stacked into one big array
    print("Num steps:  {}".format(len(states)))
    print("Final Score: {} to {}".format(np.sum(r > 0), np.sum(r < 0)))
    print("Num Up moves: {}".format(np.sum(a == 2)))
    print("Num Down moves: {}".format(np.sum(a == 3)))
    print("")
    
def backprop(result):
    '''Placeholder for the backpropagation step.

    ``result`` is accepted but not used yet; the function only logs that it ran.
    '''
    #Single-argument print calls behave the same under Python 2 and Python 3
    print("backprop'ing")
    print("")
    
def select(node, env):
    '''Run the selection phase on ``node`` until every child has been simulated once.

    On each pass over ``node``:
      - terminal node: stop;
      - no children yet: expand (stop if the env offers no actions);
      - unexplored children left: take the next one, simulate from it, print diagnostics,
        backprop the result and reset ``env``;
      - all children explored: call ``bandit(node)`` and stop.

    Returns None.
    '''
    #Iterate instead of recursing so the stack cannot grow with the number of children
    while True:
        if node.is_terminal:
            return

        #Expand if new state (and not terminal)
        if node.num_children == 0:
            print("expanding")
            node.expand(env)
            print("")

            #No valid actions: nothing to explore, and retrying would never terminate
            if node.num_children == 0:
                return

        elif node.has_unvisited:
            child = node.get_unvisited()
            print("exploring child {}".format(node.num_explored))

            #Simulate
            result = simulate(child, env)
            printDx(result)

            #Backprop
            backprop(result)
            env.clear()
        else:
            #Run bandit selection algorithm if expanded and all children visited at least once
            bandit(node)
            print("all children visited, running bandit")
            return

In [175]:
#Create the Pong environment wrapper; its constructor builds the gym env and resets it once
pong = PongEnv()

#Initial State: the observation from that first reset becomes the root of the search tree
init_state = pong.state
root = Node(state=init_state)

#Run MCTS selection from the root: expand it, then simulate one random rollout per child
select(root, pong)

[2016-09-14 01:38:50,403] Making new env: Pong-v0


expanding

exploring child 1
simulating
Num steps:  1254
Final Score: 0 to 21
Num Up moves: 640
Num Down moves: 614

backprop'ing

exploring child 2
simulating
Num steps:  1550
Final Score: 1 to 21
Num Up moves: 727
Num Down moves: 823

backprop'ing

all children visited, running bandit


expanding
exploring child 0
simulating
Num steps:  1346
Final Score: 0 to 21
Num Up moves: 656
Num Down moves: 690


In [None]:
#Selection Phase
policy = TreePolicy()
def TreePolicy(node):
    #3 cases: node has unexplored children
def UCB(Policy):