In [46]:
import gym
import collections
from tensorboardX import SummaryWriter

# Environment id passed to gym.make() for both the agent's own and the test environment.
ENV_NAME = "FrozenLake-v0"
# Discount factor applied to the next state's value in Agent.value_iteration().
GAMMA = 0.9
# Number of test episodes whose rewards are averaged in each training iteration.
TEST_EPISODES = 20


In [None]:
class Agent:
    """Tabular Q-value-iteration agent for a discrete gym-style environment.

    The agent gathers experience into two tables and derives action values
    from them:

    * ``rewards``  maps ``(state, action, new_state)`` to the most recently
      observed reward for that transition.
    * ``transits`` maps ``(state, action)`` to a ``Counter`` of observed
      next states.
    * ``values``   maps ``(state, action)`` to the current action-value
      estimate written by :meth:`value_iteration`.

    The environment is used through ``reset()`` returning an observation and
    ``step(action)`` returning a 4-tuple ``(new_state, reward, is_done, info)``.
    """

    def __init__(self, env=None, gamma=None):
        """Create the agent and reset its exploration environment.

        Args:
            env: Environment used by :meth:`play_n_random_steps` and to size
                the state/action loops. When ``None`` (the default), one is
                created with ``gym.make(ENV_NAME)``.
            gamma: Discount factor for :meth:`value_iteration`. When ``None``
                (the default), the module-level ``GAMMA`` is read each time
                value iteration runs.
        """
        self.env = gym.make(ENV_NAME) if env is None else env
        self.gamma = gamma
        self.state = self.env.reset()
        self.rewards = collections.defaultdict(float)
        self.transits = collections.defaultdict(collections.Counter)
        self.values = collections.defaultdict(float)

    def play_n_random_steps(self, count):
        """Take ``count`` randomly sampled actions in ``self.env`` and record them.

        Every step updates the reward and transition tables. When an episode
        ends the environment is reset and exploration continues from the
        fresh initial state.
        """
        for _ in range(count):
            action = self.env.action_space.sample()
            new_state, reward, is_done, _ = self.env.step(action)
            self.rewards[(self.state, action, new_state)] = reward
            self.transits[(self.state, action)][new_state] += 1
            self.state = self.env.reset() if is_done else new_state

    def select_action(self, state):
        """Return the action with the highest stored value for ``state``.

        Ties are resolved in favour of the lowest action index. Returns
        ``None`` if the action space is empty.
        """
        # max() keeps the first maximal element, matching a strict "<" scan.
        return max(
            range(self.env.action_space.n),
            key=lambda action: self.values[(state, action)],
            default=None,
        )

    def play_episode(self, env):
        """Play one greedy episode in ``env`` and return its total reward.

        Transitions observed during the episode are also written into the
        agent's reward and transition tables.
        """
        total_reward = 0.0
        state = env.reset()
        while True:
            action = self.select_action(state)
            new_state, reward, is_done, _ = env.step(action)
            self.rewards[(state, action, new_state)] = reward
            self.transits[(state, action)][new_state] += 1
            total_reward += reward
            if is_done:
                break
            state = new_state
        return total_reward

    def value_iteration(self):
        """Run one in-place sweep updating every ``(state, action)`` value.

        Each value becomes the count-weighted average, over observed next
        states, of ``reward + gamma * value(next_state, best_action)``.
        Pairs with no recorded transitions are set to ``0.0``.
        """
        # Resolve the discount lazily so a later change to GAMMA is honoured
        # when no explicit gamma was given.
        gamma = GAMMA if self.gamma is None else self.gamma
        n_actions = self.env.action_space.n
        for state in range(self.env.observation_space.n):
            for action in range(n_actions):
                target_counts = self.transits[(state, action)]
                total = sum(target_counts.values())
                action_value = 0.0
                # Empty when the pair was never seen, so total == 0 never divides.
                for tgt_state, count in target_counts.items():
                    reward = self.rewards[(state, action, tgt_state)]
                    best_action = self.select_action(tgt_state)
                    action_value += (count / total) * (
                        reward + gamma * self.values[(tgt_state, best_action)]
                    )
                self.values[(state, action)] = action_value

    def get_values(self):
        """Return the live ``(state, action) -> value`` table (not a copy)."""
        return self.values

In [None]:
if __name__ == "__main__":
    # Training driver: alternate random exploration and a value-iteration
    # sweep, then evaluate the greedy policy until the average test reward
    # exceeds the solved threshold.
    RANDOM_STEPS_PER_ITER = 100  # exploration steps taken before each sweep
    SOLVED_REWARD = 0.80         # average test reward that ends training

    # NOTE(review): Agent() builds its own exploration env with
    # gym.make(ENV_NAME) and no extra keyword arguments, while this test env
    # is created with is_slippery=False; play_episode() also records test-env
    # transitions into the agent's tables. Confirm that mixing the two
    # dynamics is intended.
    test_env = gym.make(ENV_NAME, map_name='4x4', is_slippery=False)
    agent = Agent()
    writer = SummaryWriter(comment="-v-iteration")

    iter_no = 0
    best_reward = 0.0
    try:
        while True:
            iter_no += 1
            agent.play_n_random_steps(RANDOM_STEPS_PER_ITER)
            agent.value_iteration()

            # Average the greedy policy's return over TEST_EPISODES episodes.
            reward = 0.0
            for _ in range(TEST_EPISODES):
                reward += agent.play_episode(test_env)
            reward /= TEST_EPISODES
            writer.add_scalar("reward", reward, iter_no)
            if reward > best_reward:
                print("Best reward updated %.3f -> %.3f" % (best_reward, reward))
                best_reward = reward
            if reward > SOLVED_REWARD:
                print("Solved in %d iterations!" % iter_no)
                break
    finally:
        # Close the writer even if the open-ended loop is interrupted
        # (e.g. kernel interrupt) or an episode raises.
        writer.close()

In [45]:
# Dump the learned table: one "(state, action) value" pair per line,
# in the table's own iteration order.
value = agent.get_values()
for state_action, q_value in value.items():
    print(state_action, q_value)


(0, 0) 0.004969171784127788
(0, 1) 0.00602724224433383
(0, 2) 0.006344197870726633
(0, 3) 0.006105066840452293
(4, 0) 0.009414270330448081
(4, 1) 0.014903238344406232
(4, 2) 0.013121932061195753
(4, 3) 0.016805738781422577
(1, 0) 0.007459657127826521
(1, 1) 0.0074573933234839225
(1, 2) 0.007397828629685215
(1, 3) 0.006863850100285057
(2, 0) 0.008204256861279234
(2, 1) 0.008191200359159668
(2, 2) 0.007893946517946929
(2, 3) 0.007550047051033333
(5, 0) 0.021309168814638343
(5, 1) 0.0
(5, 2) 0.0
(5, 3) 0.0
(6, 0) 0.021960023585062313
(6, 1) 0.02131133969717352
(6, 2) 0.03865692789033549
(6, 3) 0.02619758276961985
(3, 0) 0.008800096042479576
(3, 1) 0.010357975137587485
(3, 2) 0.010681371163511363
(3, 3) 0.010753208750998982
(8, 0) 0.005031661679604207
(8, 1) 0.004933668971203032
(8, 2) 0.004932255141804283
(8, 3) 0.0052054391429361
(10, 0) 0.004900846691596509
(10, 1) 0.0037734931282452826
(10, 2) 0.004665556406753116
(10, 3) 0.005611077936693346
(7, 0) 0.031698598611506645
(7, 1) 0.0
(7, 