## Fitted Value Iteration - Function Approximation

#### Imports

In [61]:
import numpy as np
import tensorflow as tf
np.random.seed(1337)

#### Parameters

In [70]:
n_states = 100 # Number of states
gamma = 0.9 # Discount Factor
learning_rate = 0.8 # Learning Rate
tolerance = 0.001 # Convergence criteria
iterations = 100# Number of iterations
n_epoch = 500 # Number of model training epochs

#### Set state rewards

In [63]:
rewards = np.zeros(n_states)
rewards[-1] = 1 # Goal state
rewards[-2] = -1 # Penalty state

#### Define transition probabilities

In [64]:
transition_prob = np.random.random([n_states,n_states])
s = transition_prob.sum(axis=-1)
transition_prob = transition_prob/np.repeat(s,n_states).reshape([n_states,n_states]) # Normalization
transition_prob[-1] = 0 # Make goal state terminal
transition_prob[-2] = 0 # Make penalty state terminal

#### Initialize value network

In [65]:
inputs = tf.placeholder(tf.float32, [None, n_states])
weights = tf.Variable(tf.zeros([n_states,1]))
outputs = tf.matmul(inputs, weights)
targets = tf.placeholder(tf.float32, [None, 1])

#### Define loss and optimizer 

In [66]:
loss = tf.losses.mean_squared_error(targets, outputs)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

#### Fitted Value Iteration through Bellman updates until convergence

In [74]:
init = tf.global_variables_initializer()
state_one_hot = np.eye(n_states)

with tf.Session() as sess:
    sess.run(init)
    for itr in range(iterations):
        v_theta = sess.run([outputs], feed_dict={inputs: state_one_hot})[0].T[0]
        v_estimated = np.zeros(n_states)
        for s in range(n_states):
            v_estimated[s] = rewards[s] + gamma*np.dot(transition_prob[s], v_theta) # Bellman Update
        v_estimated = np.expand_dims(v_estimated, -1)
        for i in range(n_epoch):
            _, mse = sess.run([optimizer, loss], feed_dict={inputs: state_one_hot, targets: v_estimated})
        print('Iteration %d, Error %f' % (itr, mse))
                
print(v_theta)

Iteration 0, Error 0.000000
Iteration 1, Error 0.000000
Iteration 2, Error 0.000000
Iteration 3, Error 0.000000
Iteration 4, Error 0.000000
Iteration 5, Error 0.000000
Iteration 6, Error 0.000000
Iteration 7, Error 0.000005
Iteration 8, Error 0.000331
Iteration 9, Error 0.000548
Iteration 10, Error 0.000476
Iteration 11, Error 0.000558
Iteration 12, Error 0.000309
Iteration 13, Error 0.000455
Iteration 14, Error 0.000321
Iteration 15, Error 0.000422
Iteration 16, Error 0.000293
Iteration 17, Error 0.000757
Iteration 18, Error 0.000402
Iteration 19, Error 0.000448
Iteration 20, Error 0.000438
Iteration 21, Error 0.000329
Iteration 22, Error 0.000379
Iteration 23, Error 0.000219
Iteration 24, Error 0.000352
Iteration 25, Error 0.000672
Iteration 26, Error 0.000551
Iteration 27, Error 0.000523
Iteration 28, Error 0.000362
Iteration 29, Error 0.000524
Iteration 30, Error 0.000704
Iteration 31, Error 0.000528
Iteration 32, Error 0.000658
Iteration 33, Error 0.000951
Iteration 34, Error 0.00