In [None]:
import gym
import numpy as np
import tensorflow as tf

In [None]:
def one_hot(length, idx):
    """Return a 1-D float array of `length` zeros with 1.0 written at `idx`.

    Args:
        length: Number of entries in the returned vector.
        idx: Position (any valid NumPy index into a 1-D array) set to 1.0.

    Returns:
        A NumPy array of shape (length,) holding the encoding.
    """
    vector = np.zeros((length,))
    vector[idx] = 1.0
    return vector

In [None]:
# Instantiate the 'FrozenLake-v0' environment through gym's registry; the
# cells below read its state/action counts and step through it during training.
env = gym.make('FrozenLake-v0')

In [None]:
# Read the sizes of the discrete state and action sets from the public
# space objects rather than reaching through the wrapper (`env.env`) into the
# underlying environment's implementation attributes (`nS` / `nA`). The space
# API is available on any wrapped or unwrapped environment with Discrete
# observation and action spaces.
n_states = env.observation_space.n
n_actions = env.action_space.n
print(f'{n_states:,} states & {n_actions:,} actions')

### Network

In [None]:
# Start from an empty default graph so re-running this cell does not stack
# a second copy of the network on top of the first.
tf.reset_default_graph()

# Placeholders: the one-hot state vector fed in, and the row of per-action
# target Q-values the network is regressed towards.
inputs = tf.placeholder(dtype=tf.float32, shape=(n_states,))
target = tf.placeholder(dtype=tf.float32, shape=(1, n_actions))

# Lift the flat state vector to a single-row matrix for the matmul below.
X_reshape = tf.reshape(inputs, (1, n_states))

# The whole model is one weight matrix mapping states to action values.
weight = tf.Variable(
    tf.random_normal(shape=(n_states, n_actions), mean=0, stddev=0.4)
)

# Predicted Q-values for every action, and the index of the largest one.
Q_value = tf.matmul(X_reshape, weight)
predict = tf.argmax(Q_value, axis=1)
print(Q_value, predict)

### Loss & Optimizer

In [None]:
# Element-wise squared error between the target and predicted Q-values
# (left unreduced, one entry per action).
loss = tf.squared_difference(x=target, y=Q_value)

# Non-trainable counter that the optimizer increments on every update.
global_step = tf.Variable(initial_value=0, trainable=False, name='global_step')
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.1)
train = optimizer.minimize(loss=loss, global_step=global_step)

### Training

In [None]:
# gamma: discount applied to the next state's best Q-value in the target.
# epsilon: threshold used by the exploration check in the training loop.
gamma, epsilon = 0.9, 0.1

# Number of training episodes and the cap on transitions within each one.
episodes, max_trans_per_episode = 10_000, 100

In [None]:
# Create the variable-initializer op, then open a session on the default
# graph and run the op so every tf.Variable holds its initial value.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

In [None]:
# Q-learning with a single-layer network: for each transition, regress the
# predicted Q-value of the chosen action towards the one-step Bellman target.
for episode in range(episodes):
    state, done = env.reset(), False
    total_reward = 0
    # Cap the number of transitions so an episode cannot run indefinitely.
    for _step in range(max_trans_per_episode):
        # Encode the current state once; it is fed to both the forward pass
        # and the training step below.
        state_vec = one_hot(n_states, state)
        action, Q = sess.run([predict, Q_value],
                             feed_dict={inputs: state_vec})
        # Epsilon-greedy exploration. The draw must be uniform on [0, 1) so a
        # random action is taken with probability `epsilon`; a standard-normal
        # draw (randn) would fall below 0.1 more than half of the time.
        if np.random.rand() < epsilon:
            action[0] = env.action_space.sample()
        # Take the action
        new_state, reward, done, _ = env.step(action[0])
        if done:
            # Nothing follows the final transition of an episode, so the
            # target is the immediate reward alone (no bootstrapping from the
            # untrained Q-values of the final state).
            Q[0, action[0]] = reward
        else:
            # Bootstrap from the best Q-value of the next state.
            new_Q = sess.run(Q_value,
                             feed_dict={inputs: one_hot(n_states, new_state)})
            Q[0, action[0]] = reward + gamma * np.max(new_Q)
        # Train the network towards the updated target row.
        sess.run(train, feed_dict={inputs: state_vec, target: Q})
        state = new_state
        total_reward += reward
        # Stop as soon as the environment reports the episode is over instead
        # of continuing to step (and train on) a finished episode.
        if done:
            break
    if episode % 100 == 0:
        print(f'Episode: {episode}\tTotal reward: {total_reward}')