In [1]:
import sys
sys.setrecursionlimit(10000)

In [27]:
class TransportationMDP(object):
  def __init__(self,N):
    # N is the number of blocks
    self.N = N

  def startState(self):
    return 1

  def isEnd(self, state):
    return state == self.N

  def actions(self, state):
    # return list of valid actions that can be taken for the current state
    result = []
    if state+1 <= self.N:
      result.append('walk')
    if state*2 <= self.N:
      result.append('tram')
    return result

  def succProbReward(self, state, action):
    # return list of (newState, prob, reward) triplet
    # state = s, action = a, newState = s`
    # prob = T(s,a,s`), reward = Reward(s,a,s`)
    result = []
    if action == 'walk':
      result.append((state+1, 1, -1))
    elif action == 'tram':
      failProb = 0.9
      result.append((state*2, failProb, -2))
      result.append((state, failProb, -2))
    return result

  def discount(self):
    return 1
  
  def states(self):
    return range(1, self.N+1)


In [28]:
mdp = TransportationMDP(N=10)

In [29]:
mdp.actions(3)

['walk', 'tram']

In [30]:
mdp.succProbReward(3,'walk')

[(4, 1, -1)]

In [31]:
mdp.succProbReward(3,'tram')

[(6, 0.9, -2), (3, 0.9, -2)]

In [32]:
def valueIteration(mdp):
  #initialize
  V = {} #state -> Vopt[state]
  for state in mdp.states():
    V[state] = 0
  
  def Q(state, action):
    return sum( prob * (reward + mdp.discount() * V[newState])
        for newState, prob, reward in mdp.succProbReward(state, action)
    )
  
  while True:
    #compute new values of newV given the old value V
    newV = {}
    for state in mdp.states():
      if mdp.isEnd(state):
        newV[state] = 0
      else:
        newV[state] = max( Q(state, action) for action in mdp.actions(state) )
    
    #check for convergence
    if max( abs(V[state] -newV[state] ) for state in mdp.states() ) < 1e-10:
      break

    V = newV

    #read out policy
    pi = {}
    for state in mdp.states():
      if mdp.isEnd(state):
        pi[state] = 'none'
      else:
        pi[state] =max(
            ( Q(state, action), action) for action in mdp.actions(state)
        ) [1]
    
    # print stuff out
    print("{:15} {:15} {:15}".format('s', 'V(s)', 'pi(s)'))
    for state in mdp.states():
      print("{:15} {:15} {:15}".format(state, V[state], pi[state]))
    

In [33]:
valueIteration(mdp)

s               V(s)            pi(s)          
              1              -1 walk           
              2              -1 walk           
              3              -1 walk           
              4              -1 walk           
              5              -1 walk           
              6              -1 walk           
              7              -1 walk           
              8              -1 walk           
              9              -1 walk           
             10               0 none           
s               V(s)            pi(s)          
              1              -2 walk           
              2              -2 walk           
              3              -2 walk           
              4              -2 walk           
              5              -2 walk           
              6              -2 walk           
              7              -2 walk           
              8              -2 walk           
              9              -1 walk    