In [9]:
import random
import itertools
import numpy as np
from collections import namedtuple

import gym
import tensorflow

from mcts import Node, ucb


In [51]:
class TestEnv(object):
    '''Wrapper around a gym environment that records random-play trajectories.

    Subclasses must override the ``valid_actions`` property with the list of
    action indices that ``sample_action`` draws from.
    '''

    def __init__(self, name):
        self.env = gym.make(name)
        self.state = self.env.reset()
        self.trajectory = []  # NOTE(review): not read by any method of this class
        self.history = []  # recorded trajectories: each a list of (state, action, reward)
        self.is_terminal = False

    def reset(self):
        '''Reset the environment and discard all recorded trajectories.'''
        self.state = self.env.reset()
        self.history = []
        # Without this, a terminal flag left over from the previous episode
        # would make the next simulate() return an empty path immediately.
        self.is_terminal = False

    def simulate(self):
        '''Play from the current state until the episode terminates.

        Returns:
            list of (state, action, reward) tuples, where ``state`` is the
            observation from which ``action`` was taken.
        '''
        path = []
        while not self.is_terminal:
            # sample_action() returns a length-1 numpy array (np.random.choice, size=1).
            action = self.sample_action()
            next_state, reward, terminal, _info = self.env.step(action)
            path.append((self.state, action, reward))
            self.state = next_state
            self.is_terminal = terminal
        return path

    def rollout(self, num_plays=1):
        '''Simulate ``num_plays`` full episodes and append each one to ``history``.

        The environment is reset after every episode, and also beforehand if
        the current episode has already terminated.
        '''
        if self.is_terminal:
            self.clear()

        for _ in range(num_plays):
            self.history.append(self.simulate())
            self.clear()

    def _unpack(self, trajectory):
        '''Transpose [(s, a, r), ...] into a (states, actions, rewards) triple of tuples.

        An empty trajectory yields three empty tuples instead of raising.
        '''
        if not trajectory:
            return (), (), ()
        s, a, r = zip(*trajectory)
        return s, a, r

    def get_trajectories(self):
        '''Return recorded trajectories as a list of (states, actions, rewards) tuples.

        Prints a message and returns None when nothing has been recorded.
        '''
        if not self.history:
            print("No trajectories")
            return None
        return [self._unpack(trajectory) for trajectory in self.history]

    def sample_action(self, state=None, policy=None, n=1):
        '''Sample ``n`` actions uniformly at random from ``valid_actions``.

        Returns a numpy array of shape (n,). ``state`` is currently unused.

        Raises:
            NotImplementedError: if ``policy`` is supplied. Policy-based
                sampling is not implemented; the method used to return None
                silently in that case.
        '''
        if policy:
            raise NotImplementedError("policy-based action sampling is not implemented")
        return np.random.choice(self.valid_actions, size=n)

    def clear(self):
        '''Start a new episode without touching ``history``.'''
        self.state = self.env.reset()
        self.is_terminal = False

    def clear_all(self):
        '''Start a new episode and discard all recorded trajectories.'''
        self.clear()
        self.history = []

    @property
    def valid_actions(self):
        '''Action indices to sample from; must be provided by subclasses.'''
        raise NotImplementedError

        
class PongEnv(TestEnv):
    '''Pong-v0 environment whose action set is restricted to paddle movement.'''

    def __init__(self):
        env_name = "Pong-v0"
        super(PongEnv, self).__init__(env_name)

    @property
    def valid_actions(self):
        '''Return the allowed Pong action indices: 2 (up) and 3 (down).'''
        up, down = 2, 3
        return [up, down]

def take(it, n):
    '''Consume and return the next ``n`` items of ``it`` as a list (fewer if it runs out).'''
    head = itertools.islice(it, n)
    return list(head)

def partition_points(r):
    '''
    Split an episode's reward sequence (one Pong game) into per-point segments.

    Each segment runs up to and including one non-zero reward (+1 for a point
    won, -1 for a point lost), so that rewards can be discounted per point
    rather than across the whole game. Zero rewards that follow the last
    non-zero reward are not included in any segment.

    Returns a list of lists of the elements of ``r``.
    '''
    # Exclusive end index of every segment: one past each non-zero reward.
    ends = (np.nonzero(r)[0] + 1).tolist()
    starts = [0] + ends[:-1]
    rewards = iter(r)
    return [list(itertools.islice(rewards, stop - start))
            for start, stop in zip(starts, ends)]
    
class Node(object):
    '''Search-tree node holding the environment, its children and visit statistics.

    NOTE(review): this definition shadows the ``Node`` imported from ``mcts``.
    '''

    def __init__(self, env, parent=None, action=None, state=None):
        self.env = env  # must expose ``valid_actions``; read by expand()
        self.parent = parent
        self.action = action  # set by the parent's expand(); None for a root
        self.state = state
        self.children = []
        self.explored_children = []
        self.value = 0.
        self.visits = 0.

    def expand(self):
        '''Create one child per action in ``env.valid_actions``.

        Calling this on an already-expanded node is a no-op. Repeated calls
        would otherwise append duplicate children, and the duplicates would
        break the length comparison in ``has_unvisited``.
        '''
        if self.children:
            return
        for action in self.env.valid_actions:
            self.children.append(Node(env=self.env, parent=self, action=action))

    @property
    def has_unvisited(self):
        '''True when the node has children and not all of them are in ``explored_children``.'''
        return len(self.children) > 0 and len(self.explored_children) != len(self.children)
    
# Names of the attributes assigned in Node.__init__. Not referenced by any
# visible cell; presumably intended for the (currently unused) namedtuple
# import — TODO confirm or remove.
fields = ["state", "action", "parent", "children", "explored_children", "visits", "value"]


In [52]:
# Build the Pong wrapper; TestEnv.__init__ calls gym.make("Pong-v0") and env.reset().
pong = PongEnv()

[2016-09-13 19:30:49,011] Making new env: Pong-v0


In [53]:
# A fresh Node (the class defined in this notebook, not mcts.Node) has no
# children, so has_unvisited is False until expand() is called.
node = Node(pong)
node.has_unvisited

False

In [54]:
# Adds one child per action in pong.valid_actions ([2, 3]).
node.expand()

In [55]:
# Children now exist and none are in explored_children, so this is True.
node.has_unvisited

True

In [36]:
# NOTE(review): `n2` is not defined in any visible cell, and this cell's
# execution count (In [36]) predates the cells above, so its output is stale.
n2.action

3