In [3]:
import random
import numpy as np

import gym
import tensorflow

from mcts import Node, ucb



In [4]:
class TestEnv(object):
    '''Thin wrapper around a gym environment that records random-play trajectories.

    Subclasses must provide the `valid_actions` property.

    Attributes set in __init__:
        env: environment returned by gym.make(name)
        state: most recent observation (initially the result of env.reset())
        trajectory: empty list; not used by any method of this class
        history: list of recorded trajectories, each a list of
            (state, action, reward) tuples
        is_terminal: True once env.step has reported the episode as finished
    '''
    def __init__(self, name):
        self.env = gym.make(name)
        self.state = self.env.reset()
        self.trajectory = []
        self.history = []
        self.is_terminal = False

    def reset(self):
        '''Start a fresh episode and drop all recorded trajectories.
        '''
        self.state = self.env.reset()
        self.history = []
        # Without clearing the flag, simulate() after a finished episode
        # would return an empty path immediately.
        self.is_terminal = False

    @property
    def valid_actions(self):
        '''Actions that sample_action may choose from; must be overridden.
        '''
        raise NotImplementedError

    def simulate(self):
        '''Simulate path starting from current state

        Steps the environment with sampled actions until it reports a terminal
        state.  Returns a list of (state, action, reward) tuples where `state`
        is the observation the action was taken from.  The list is empty if
        the environment is already terminal.
        '''
        path = []
        while not self.is_terminal:
            # sample_action() returns a length-1 array with the default n=1
            action = self.sample_action()
            next_state, reward, terminal, info = self.env.step(action)
            path.append((self.state, action, reward))
            self.state = next_state
            self.is_terminal = terminal

        return path

    def rollout(self, num_plays=1):
        '''Simulate num_play trajectories

        Each trajectory is appended to self.history and the environment is
        reset afterwards, so every play starts from a fresh episode.
        '''
        if self.is_terminal:
            self.clear()

        for i in range(num_plays):
            trajectory = self.simulate()
            self.history.append(trajectory)
            self.clear()

    def _unpack(self, trajectory):
        '''Transpose a list of (state, action, reward) tuples into three tuples.
        '''
        s, a, r = zip(*trajectory)
        return s, a, r

    def get_trajectories(self):
        '''Returns list of trajectories, where each trajectory is a list of states, actions, and rewards

        Returns None (after printing a notice) when no trajectory has been recorded.
        '''
        if self.history:
            return [self._unpack(trajectory) for trajectory in self.history]
        else:
            # Call form behaves the same on Python 2 and Python 3
            print("No trajectories")
            return None

    def sample_action(self, state=None, policy=None, n=1):
        '''Sample n actions uniformly at random from valid_actions.

        Returns a numpy array of length n.  `state` is currently unused.

        Raises NotImplementedError if a policy is supplied, since
        policy-driven sampling is not implemented.
        '''
        if policy is None:
            return np.random.choice(self.valid_actions, size=n)
        # Previously fell through and returned None silently
        raise NotImplementedError("policy-based action sampling is not implemented")

    def clear(self):
        '''Reset the environment for a new episode, keeping recorded history.
        '''
        self.state = self.env.reset()
        self.is_terminal = False

    def clear_all(self):
        '''Reset the environment and discard recorded history.
        '''
        self.clear()
        self.history = []
        
class PongEnv(TestEnv):
    '''TestEnv bound to the "Pong-v0" gym environment.
    '''
    def __init__(self):
        super(PongEnv, self).__init__("Pong-v0")

    @property
    def valid_actions(self):
        '''Action ids offered to the sampler: 2 (up) and 3 (down).
        '''
        up, down = 2, 3
        return [up, down]
    
    


In [6]:
# Build the Pong wrapper; TestEnv.__init__ calls gym.make("Pong-v0") and resets the env.
pong = PongEnv()

[2016-09-13 01:58:19,522] Making new env: Pong-v0


In [9]:
# Play one episode (num_plays defaults to 1) with uniformly random actions
# and append its trajectory to pong.history.
pong.rollout()

In [10]:
# Number of recorded trajectories; rollout() appends one per play.
# NOTE(review): a value above 1 presumably means the rollout cell was run more than once.
len(pong.history)

2

In [11]:
# One (states, actions, rewards) triple per recorded trajectory, or None if history is empty.
ts = pong.get_trajectories()

In [12]:
# Unpack the first trajectory into per-step states, actions and rewards.
s, a, r = ts[0]

In [58]:
import itertools


def take(it, n):
    '''Return a list holding at most the next n items of `it`.

    When `it` is an iterator it is advanced past the returned items.
    '''
    head = itertools.islice(it, n)
    return [item for item in head]

def partition_points(r, keep_tail=False):
    '''
    Partition episode (single Pong game) into sequences of points

    First player to reach 21 points wins game.  Partitioning episodes based on point sequences necessary
    for proper discounting of rewards (+1 for self point, -1 for opponent point)

    r: one-dimensional sequence of per-step rewards
    keep_tail: when True, steps after the last non-zero reward (e.g. from a
        truncated episode) are returned as a final sequence; by default they
        are dropped, so every returned sequence ends in a non-zero reward

    Returns a list of lists; each inner list holds the rewards of one point,
    ending with the non-zero reward that decided it.  Returns [] when `r`
    contains no non-zero reward and keep_tail is False.
    '''
    rewards = list(r)

    # Exclusive end index of each point: one past every non-zero reward
    ends = (np.nonzero(rewards)[0] + 1).tolist()
    # Each point starts where the previous one ended
    starts = [0] + ends[:-1]
    seqs = [rewards[lo:hi] for lo, hi in zip(starts, ends)]

    if keep_tail:
        tail_start = ends[-1] if ends else 0
        if tail_start < len(rewards):
            seqs.append(rewards[tail_start:])

    return seqs
        

In [59]:
# Split the episode's rewards into one sub-sequence per point scored.
seqs = partition_points(r)


In [61]:
# Net reward of each point sequence.  The loop variable is `seq`, not `s`:
# under Python 2 a list-comprehension variable leaks into the enclosing scope
# and would overwrite the states tuple `s` unpacked from the trajectory.
[sum(seq) for seq in seqs]

[-1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0]

In [52]:
# Number of point sequences found in the episode.
len(seqs)


21

In [55]:
# Net reward per point sequence; each sequence ends at its single non-zero
# reward, so the sum equals that reward.
[sum(seq) for seq in seqs]

[-1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0,
 -1.0]

In [29]:
# Check that the reward of the first step is numerically zero.
np.allclose(r[0], 0.)

True

In [30]:
# Leading run of zero rewards: every step before the first non-zero reward.
list(itertools.takewhile(lambda reward: reward == 0., r))

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]