## Value Iteration - Table Representation

#### Imports

In [2]:
import numpy as np
# Seed NumPy's global random number generator so that values drawn from
# np.random later in the notebook are the same on every run.
np.random.seed(1337)

#### Parameters

In [3]:
# Dimensions of the value table: number of states and actions per state.
n_states = 10
n_actions = 4

# Discount factor.
gamma = 0.9

# Convergence threshold.
tolerance = 1e-5

#### Set state rewards

In [4]:
# Reward table indexed as [state, action]; every entry starts at zero.
rewards = np.zeros((n_states, n_actions))
rewards[-2, :] = -1  # Penalty state: second-to-last row, all actions
rewards[-1, :] = 1   # Goal state: last row, all actions

#### Define transition probabilities

In [5]:
# Random transition model indexed as [state, action, next_state].
transition_prob = np.random.random([n_states, n_actions, n_states])

# Normalize over the next_state axis so each (state, action) row is a
# probability distribution. keepdims=True keeps the summed axis with length 1,
# which lets broadcasting do the division without a manual repeat/reshape.
transition_prob /= transition_prob.sum(axis=-1, keepdims=True)

# Make the goal state (last index) absorbing: every action leads back to it.
transition_prob[-1] = 0
transition_prob[-1, :, -1] = 1

# Make the penalty state (second-to-last index) absorbing in the same way.
transition_prob[-2] = 0
transition_prob[-2, :, -2] = 1

#### Initialize state values

In [16]:
# Two zero-initialized value vectors with one entry per state:
# the working estimate and the current accepted values.
estimated_state_values = np.zeros(n_states)
state_values = np.zeros_like(estimated_state_values)

#### Value Iteration through Bellman updates until convergence

In [17]:
# Value iteration: repeatedly apply the Bellman optimality backup
#   V(s) <- max_a [ R(s, a) + gamma * sum_s' P(s' | s, a) * V(s') ]
# to every state until the values stop changing.
while True:
    for state in range(n_states):
        # transition_prob[state] has one row per action, so the dot product
        # yields the expected next-state value for each action.
        action_values = rewards[state] + gamma * np.dot(transition_prob[state], state_values)
        estimated_state_values[state] = action_values.max()  # Bellman update

    # Convergence is measured as the mean absolute change across all states.
    mean_change = np.abs(state_values - estimated_state_values).mean()

    # Commit the sweep before testing for convergence, so the final result
    # includes the most recent update instead of being one sweep stale.
    # The copy is required because estimated_state_values is overwritten
    # in place during the next sweep while state_values is still being read.
    state_values = estimated_state_values.copy()

    if mean_change < tolerance:
        break
    print(state_values)  # Progress trace of the intermediate value estimates

print(state_values)

[ 0.  0.  0.  0.  0.  0.  0.  0. -1.  1.]
[ 0.05933014  0.12598526  0.00884204  0.17474599  0.03293927  0.02026472
  0.03456583  0.08449444 -1.9         1.9       ]
[ 0.16000972  0.30168432  0.05832781  0.37230267  0.114991    0.08541099
  0.11302557  0.22058477 -2.71        2.71      ]
[ 0.28075231  0.49728548  0.13612103  0.58739251  0.22540939  0.18133922
  0.2151522   0.37824012 -3.439       3.439     ]
[ 0.41102307  0.69870454  0.22792472  0.80576786  0.3501015   0.30307085
  0.32992447  0.54481216 -4.0951      4.0951    ]
[ 0.54354261  0.89786641  0.32646592  1.01967146  0.47987561  0.43453352
  0.44964395  0.71233353 -4.68559     4.68559   ]
[ 0.67349944  1.08964401  0.42615435  1.22448114  0.60911285  0.56796356
  0.57286729  0.87536701 -5.217031    5.217031  ]
[ 0.79828302  1.27148274  0.52358352  1.41808411  0.73429767  0.69882526
  0.70779979  1.03119185 -5.6953279   5.6953279 ]
[ 0.91733983  1.44334989  0.6169941   1.60119995  0.85398334  0.82504494
  0.83768496  1.1797089 