In [26]:
import random
import numpy as np

import gym
import tensorflow

from mcts import Node, ucb



In [12]:
# Build a Node with both constructor arguments set to None and list the
# attribute names it exposes.
node = Node(None, None)

dir(node)

['__doc__',
 '__init__',
 '__module__',
 'action',
 'children',
 'explored_children',
 'parent',
 'value',
 'visits']

In [77]:
class TestEnv(object):
    '''Wrapper around a gym environment that plays and records random episodes.

    Subclasses must override the ``valid_actions`` property; the base
    implementation raises NotImplementedError.
    '''

    def __init__(self, name):
        '''Create the gym environment registered as ``name`` and reset it.'''
        self.env = gym.make(name)
        # Latest observation returned by the environment.
        self.state = self.env.reset()
        # [state, action, reward] entries of the episode in progress.
        self.trajectory = []
        # Finished trajectories, one list per play performed by rollout().
        self.history = []
        # True once the environment has reported a terminal step.
        self.is_terminal = False

    def reset(self):
        '''Restart the environment and drop everything recorded so far.

        Also clears the terminal flag and the in-progress trajectory, so
        that step() works again after a finished episode.
        '''
        self.clear_all()

    @property
    def valid_actions(self):
        '''Sequence of actions that sample_action() may draw from.

        Raises:
            NotImplementedError: always, in this base class.
        '''
        raise NotImplementedError

    def step(self, action):
        '''Apply ``action`` to the environment and record the transition.

        Returns:
            ``(next_state, reward)``, or None when the episode has already
            terminated (nothing is sent to the environment in that case).
        '''
        if self.is_terminal:
            return None

        next_state, reward, terminal, info = self.env.step(action)
        # The state stored is the one the action was taken FROM.
        self.trajectory.append([self.state, action, reward])
        self.state = next_state
        self.is_terminal = terminal
        return next_state, reward

    def rollout(self, num_plays=1):
        '''Play ``num_plays`` episodes with randomly sampled actions.

        Each finished episode is appended to ``self.history``. If the
        environment is not terminal on entry, the first play continues from
        the current state (with a fresh trajectory).
        '''
        for _ in range(num_plays):
            # Restart before every play that begins in a terminal state.
            # Without this, every play after the first would record an
            # empty trajectory.
            if self.is_terminal:
                self.clear()

            self.trajectory = []
            while not self.is_terminal:
                self.step(self.sample_action())

            self.history.append(self.trajectory)

    def _unpack(self, trajectory):
        '''Transpose a list of [state, action, reward] entries.

        Returns:
            Three tuples ``(states, actions, rewards)``; three empty tuples
            when ``trajectory`` is empty.
        '''
        if not trajectory:
            # zip(*[]) yields nothing, so it cannot be unpacked into 3 names.
            return (), (), ()
        s, a, r = zip(*trajectory)
        return s, a, r

    def get_trajectories(self):
        '''Returns list of trajectories, where each trajectory is a tuple of
        states, actions, and rewards.

        Prints a notice and returns None when no episode has been recorded.
        '''
        if not self.history:
            # Parenthesised form prints the same text on Python 2 and 3.
            print("No trajectories")
            return None
        return [self._unpack(trajectory) for trajectory in self.history]

    def sample_action(self, n=1):
        '''Draw ``n`` actions at random from ``valid_actions``.

        Returns:
            A numpy array of length ``n`` (an array even when ``n == 1``).
        '''
        return np.random.choice(self.valid_actions, size=n)

    def clear(self):
        '''Restart the environment for a new episode, keeping ``history``.'''
        self.state = self.env.reset()
        # Rebind instead of mutating: the previous list may already be
        # stored inside self.history.
        self.trajectory = []
        self.is_terminal = False

    def clear_all(self):
        '''Restart the environment and forget all recorded episodes.'''
        self.clear()
        self.history = []
        
class PongEnv(TestEnv):
    '''TestEnv bound to the "Pong-v0" gym environment.'''

    def __init__(self):
        '''Construct the wrapper around the "Pong-v0" environment.'''
        super(PongEnv, self).__init__("Pong-v0")

    @property
    def valid_actions(self):
        '''Action ids available for sampling: 2 (up) and 3 (down).'''
        allowed = [2, 3]
        return allowed
    
    


In [78]:
# Instantiate the Pong wrapper; its constructor creates and resets the gym env.
pong = PongEnv()

[2016-09-12 19:34:33,110] Making new env: Pong-v0


In [80]:
# Play one episode (default num_plays=1) with randomly sampled actions; the
# finished trajectory is appended to pong.history.
pong.rollout()

In [81]:
# One (states, actions, rewards) triple per recorded episode.
trajs = pong.get_trajectories()

In [83]:
# Split the first recorded episode into its states, actions and rewards.
s, a, r = trajs[0]

(0.0, 0.0, 0.0, 0.0, -1.0)