In [204]:
%pylab notebook
import numpy as np

Populating the interactive namespace from numpy and matplotlib


In [271]:
def random_choice(a, p=None):
    return list(a)[np.random.choice(len(a), p=p)]

def off_policy_mc(S, initialize, A, R, gamma, Q, target_policy, behavior_policy, niter):
    """
    S: set of states
    initialize: function () -> initial state
    A: function s -> A(s): set of actions in state s. Returns [] if the state is terminal
    R: function (s, a) -> (s2, r) : sample transition
    gamma: discount factor
    Q: dict (s, a) -> value
    target_policy: dict s -> a
    behavior_policy: function s -> { a: p } where p is prob of action a in state s. Returns {} if s is terminal.
    """
    C = { (s, a): 0. for s in S for a in list(A(s)) + [None] }
    
    for _ in xrange(niter):
        
        # Generate episode
        episode = [ ]
        s = initialize() #random_choice(S)
        r = 0.
        while 1:
            trans = behavior_policy(s)
            try:
                a = None if not trans else random_choice(trans.keys(), p=trans.values())
            except:
                print s
                print trans
                raise
            episode.append((s, a, r))
            if not trans:
                if len(episode) < 100:
                    print episode
                break
            s, r = R(s, a)
        
        print len(episode)
        
        G = 0
        W = 1
        for (s, a, r) in episode[::-1]:
            G = gamma * G + r
            C[(s, a)] += W
            Q[(s, a)] += W/C[(s, a)] * (G - Q[(s, a)])
            As = A(s)
            target_policy[s] = None if not As else max(As, key = lambda a2: (Q[(s, a2)], a2))
            if a != target_policy[s]:
                break
            W = W / behavior_policy(s).get(a, 1)

def epsilon_greedy(A, greedy, epsilon):
    def policy(s):
        As = A(s)
        n = len(As)
        res = { a: epsilon / (n-1) for a in As if a != greedy[s]}
        if greedy[s] is not None:
            res[greedy[s]] = 1 - epsilon
        return res
    return policy

In [281]:
MAX_SPEED = 2
GAMMA = 1
EPSILON = 0.2

TRACK_OUT = 0
TRACK_IN = 1
TRACK_START = 2
TRACK_END = 3

trackA = np.array([
    [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 ],
    [0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0 ],
][::-1])

trackAmini = np.array([
    [0, 0, 1, 1, 1, 1, 1, 1, 3 ],
    [0, 1, 1, 1, 1, 1, 1, 1, 3 ],
    [1, 1, 1, 1, 1, 1, 1, 1, 3 ],
    [1, 1, 1, 1, 1, 1, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 0, 0, 0, 0 ],
    [1, 1, 1, 1, 1, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 0, 0, 0, 0 ],
    [0, 1, 1, 1, 1, 0, 0, 0, 0 ],
    [0, 0, 1, 1, 1, 0, 0, 0, 0 ],
    [0, 0, 2, 2, 2, 0, 0, 0, 0 ],
][::-1])


track = trackAmini
track_height, track_width = track.shape
starting_line = set( (i, j) for i in xrange(track_height) for j in xrange(track_width) if track[i, j] == TRACK_START )
finishing_line = set( (i, j) for i in xrange(track_height) for j in xrange(track_width) if track[i, j] == TRACK_END )

S = [ ((i, j), (vi, vj)) 
     for i in xrange(track_height) for j in xrange(track_width) 
     for vi in xrange(1+MAX_SPEED) for vj in xrange(1+MAX_SPEED) ]
A = lambda ((i, j), (vi, vj)): [] if (i, j) in finishing_line else [ (ai, aj) for ai in (-1, 0, 1) for aj in (-1, 0, 1) 
    if ((vi+ai) <= MAX_SPEED) and ((vi+ai) >= 0) and ((vj+aj) <= MAX_SPEED) and ((vj+aj) >= 0) 
    and ((vi+ai, vj+aj) != (0, 0)) ]

initialize = lambda: (random_choice(starting_line), (0, 0))

def R(s, a):
    (i, j), (vi, vj) = s
    (ai, aj) = a
    
    vi2, vj2 = vi + ai, vj + aj
    i2, j2 = i + vi, j + vj
        
    if (i2 < 0) or (i2 >= track_height) or (j2 < 0) or (j2 >= track_width) or track[i, j] == TRACK_OUT:
        s2 = (random_choice(starting_line), (0, 0))
    else:
        s2 = ((i2, j2), (vi2, vj2))

    return s2, -1



In [282]:
Q = { (s, a): 0 for s in S for a in list(A(s)) + [None] }
target_policy = { s: (random_choice(A(s))) if A(s) else None for s in S }
#target_policy = { (pos, v): (1,1) if pos in starting_line else (0, 0) if pos not in finishing_line else None for (pos, v) in S }
behavior_policy = epsilon_greedy(A, target_policy, EPSILON)

In [283]:
off_policy_mc(S, initialize, A, R, GAMMA, Q, target_policy, behavior_policy, niter=100)

3582
892
3062
2924
[(((0, 2), (0, 0)), (1, 0), 0.0), (((0, 2), (1, 0)), (0, 1), -1), (((1, 2), (1, 1)), (-1, 0), -1), (((2, 3), (0, 1)), (0, 1), -1), (((2, 4), (0, 2)), (0, 0), -1), (((2, 6), (0, 2)), (1, 0), -1), (((0, 3), (0, 0)), (1, 1), -1), (((0, 3), (1, 1)), (1, 0), -1), (((1, 4), (2, 1)), (0, -1), -1), (((3, 5), (2, 0)), (0, 1), -1), (((0, 3), (0, 0)), (1, 1), -1), (((0, 3), (1, 1)), (1, 0), -1), (((1, 4), (2, 1)), (0, -1), -1), (((3, 5), (2, 0)), (0, 0), -1), (((0, 2), (0, 0)), (1, 1), -1), (((0, 2), (1, 1)), (0, 1), -1), (((1, 3), (1, 2)), (1, 0), -1), (((2, 5), (2, 2)), (-1, -1), -1), (((0, 3), (0, 0)), (0, 1), -1), (((0, 3), (0, 1)), (1, 1), -1), (((0, 4), (1, 2)), (0, 0), -1), (((1, 6), (1, 2)), (0, -1), -1), (((0, 3), (0, 0)), (1, 1), -1), (((0, 3), (1, 1)), (1, 0), -1), (((1, 4), (2, 1)), (0, -1), -1), (((3, 5), (2, 0)), (0, 1), -1), (((0, 2), (0, 0)), (0, 1), -1), (((0, 2), (0, 1)), (1, -1), -1), (((0, 3), (1, 0)), (1, 0), -1), (((1, 3), (2, 0)), (0, 0), -1), (((3, 3), (

2951
319
4269
1800
569
710
1479
714
2552
955
2383
1289
102
110
2493
363
2566
616
1565
817
1679
1384
381
418
2997
1756
1728
2533
640
4094
[(((0, 4), (0, 0)), (1, 1), 0.0), (((0, 4), (1, 1)), (1, 0), -1), (((1, 5), (2, 1)), (-1, 0), -1), (((0, 3), (0, 0)), (1, 1), -1), (((0, 3), (1, 1)), (0, 1), -1), (((1, 4), (1, 2)), (-1, -1), -1), (((2, 6), (0, 1)), (0, 1), -1), (((0, 2), (0, 0)), (1, 1), -1), (((0, 2), (1, 1)), (0, 1), -1), (((1, 3), (1, 2)), (1, 0), -1), (((2, 5), (2, 2)), (-1, 0), -1), (((0, 3), (0, 0)), (0, 1), -1), (((0, 3), (0, 1)), (1, 1), -1), (((0, 4), (1, 2)), (1, 0), -1), (((1, 6), (2, 2)), (0, 0), -1), (((0, 2), (0, 0)), (1, 1), -1), (((0, 2), (1, 1)), (0, 1), -1), (((1, 3), (1, 2)), (1, 0), -1), (((2, 5), (2, 2)), (0, -1), -1), (((0, 2), (0, 0)), (1, 1), -1), (((0, 2), (1, 1)), (-1, 0), -1), (((1, 3), (0, 1)), (0, 0), -1), (((1, 4), (0, 1)), (1, 0), -1), (((1, 5), (1, 1)), (1, 0), -1), (((0, 3), (0, 0)), (0, 1), -1), (((0, 3), (0, 1)), (0, 0), -1), (((0, 4), (0, 1)), (0, 

3991


In [None]:
finishing_line

In [None]:
s0 = ((17, 15), (0, 0))
print s0
print target_policy[s0]
print A(s0)
print behavior_policy(s0)