In [7]:
import gym
from griddy_env import GriddyEnvOneHot

import numpy as np
import pickle
from copy import deepcopy
import time

In [8]:
def calculate_Gs(episode_mem, discount_factor=0.95):
    for i, mem in reversed(list(enumerate(episode_mem))):
        if i==len(episode_mem)-1:
            episode_mem[i]['G']= mem['reward']
        else:
            G = mem['reward']+discount_factor*episode_mem[i+1]['G']
            episode_mem[i]['G'] = G   
    return episode_mem

In [27]:
def update_value_table(value_table, episode_mem, convergence_delta=0.001):
    all_diffs=[]
    for mem in episode_mem:
        key = pickle.dumps(mem['new_observation'])
        if key not in value_table:
            value_table[key]=0 #initialize
        #value_table[key] = max(value_table[key], mem['G'])
        new_val = 0.9*value_table[key] + 0.1*mem['G']
        diff = abs(value_table[key]-new_val)
        all_diffs.append(diff)
        value_table[key] = new_val
    return value_table, np.mean(all_diffs)<=convergence_delta

In [28]:
def transition(state, action):
    state = deepcopy(state)
    agent_pos = list(zip(*np.where(state[2] == 1)))[0]
    new_agent_pos = np.array(agent_pos)
    if action==0:
        new_agent_pos[1]-=1
    elif action==1:
        new_agent_pos[1]+=1
    elif action==2:
        new_agent_pos[0]-=1
    elif action==3:
        new_agent_pos[0]+=1    
    new_agent_pos = np.clip(new_agent_pos, 0, 3)

    state[2, agent_pos[0], agent_pos[1]] = 0 #moved from this position so it is empty
    state[2, new_agent_pos[0], new_agent_pos[1]] = 1 #moved to this position
    return state

In [29]:
def pick_best_action(state):
    action_values=[]
    for test_action in range(4): #for each action
        new_state = transition(state, test_action)
        key = pickle.dumps(new_state)
        if key not in value_table: value_table[key] = 0
        action_values.append(value_table[key])
    policy_action = np.argmax(action_values)
    return policy_action

In [30]:
def value_table_viz(value_table):
    values = np.zeros((4, 4))
    base_st = np.zeros((3, 4, 4), dtype=np.int64)
    base_st[0, 3, 3]=1
    for i in range(4):
        for j in range(4):
            test_st = deepcopy(base_st)
            test_st[2, i, j] = 1
            #print(test_st)
            key = pickle.dumps(test_st)
            if key in value_table:
                val = value_table[key]
            else:
                val=0
            values[i, j] = val
    return values

In [31]:
env = GriddyEnvOneHot()
epsilon = 1
value_table = {}

In [None]:
try:
    for i_episode in range(100):
        print('Episode', i_episode)
        old_observation = env.reset()
        done=False
        episode_mem = []
        t=0
        while not done:
            env.render()
            
            policy_action = pick_best_action(old_observation)
            action = env.action_space.sample() if np.random.rand()<epsilon else policy_action
            #action = env.action_space.sample()
            #print(action)
            new_observation, reward, done, info = env.step(action)
            episode_mem.append({'old_observation':deepcopy(old_observation),
                                'action':action,
                                'reward':reward,
                                'new_observation':deepcopy(new_observation),
                                'done':done})
            old_observation=deepcopy(new_observation)
            t+=1
            epsilon*=0.999
            #time.sleep(0.5)
        env.render()
        #time.sleep(0.5)
        episode_mem = calculate_Gs(episode_mem)
        value_table, converged = update_value_table(value_table, episode_mem)
        print("Episode finished after {} timesteps. Eplislon={}. Converged={}".format(t+1, epsilon, converged))
    env.close()
except KeyboardInterrupt:
    env.close()

Episode 0
Episode finished after 83 timesteps. Eplislon=0.9212341621210596. Converged=False
Episode 1
Episode finished after 85 timesteps. Eplislon=0.8469758853683546. Converged=False
Episode 2


In [24]:
value_table_viz(value_table)

array([[0.43653329, 0.41237238, 0.58660672, 0.66545277],
       [0.40484385, 0.39460289, 0.57862029, 0.70391521],
       [0.42031157, 0.50507498, 0.79875324, 0.87779845],
       [0.45579892, 0.57541861, 0.74092617, 0.9835768 ]])