In [1]:
import time
import pickle
import numpy as np

import gym
from griddy_env import GriddyEnv

In [2]:
def calculate_qs(episode_mem, discount_factor=0.95):
    """Attach a target Q-value to every transition of an episode.

    Each entry of ``episode_mem`` is a dict that must provide ``'reward'``
    and, for every entry except the last, ``'new_observation'``.  The
    entries are modified in place by adding a ``'q'`` key:

    * last transition: ``q = reward``
    * any other transition: ``q = reward + discount_factor * best_next``
      where ``best_next`` is the action value returned by
      ``greedy_policy(new_observation, return_action_val=True)``.

    Transitions are visited from the end of the episode to the start.

    Args:
        episode_mem: indexable sequence of transition dicts.
        discount_factor: weight applied to the next state's best value.

    Returns:
        The same ``episode_mem`` object, with ``'q'`` filled in.
    """
    last_index = len(episode_mem) - 1
    for index in range(last_index, -1, -1):
        step = episode_mem[index]
        if index == last_index:
            # Final transition of the episode: no bootstrapping.
            step['q'] = step['reward']
            continue
        _, best_next_q = greedy_policy(step['new_observation'], return_action_val=True)
        step['q'] = step['reward'] + discount_factor * best_next_q
    return episode_mem

In [3]:
def update_q_table(q_table, episode_mem, alpha=0.5):
    """Move stored Q-values towards the targets computed for an episode.

    For every transition the table entry is updated with
    ``new = old + alpha * (target - old)``, where ``target`` is the
    transition's ``'q'`` value and a missing entry counts as ``old = 0``.

    Args:
        q_table: dict keyed by ``pickle.dumps`` of the 1-D array made from
            the flattened observation followed by the action.  Updated in
            place.
        episode_mem: iterable of transition dicts providing
            ``'observation'`` (an array with ``.flatten()``), ``'action'``
            and ``'q'``.
        alpha: learning rate.

    Returns:
        ``(q_table, mean_abs_change)`` where ``mean_abs_change`` is the mean
        absolute change applied to the touched entries, or ``0.0`` when
        ``episode_mem`` is empty (instead of ``nan`` plus a NumPy warning).
    """
    abs_changes = []
    for mem in episode_mem:
        # The key layout must stay identical to the one used by the lookup
        # code elsewhere in the notebook, otherwise entries are never found.
        key = pickle.dumps(np.array((*mem['observation'].flatten(), mem['action'])))
        old_val = q_table.get(key, 0)  # unseen (state, action) pairs start at 0
        new_val = old_val + alpha * (mem['q'] - old_val)
        abs_changes.append(abs(old_val - new_val))
        q_table[key] = new_val
    if not abs_changes:
        # Nothing was updated; np.mean([]) would return nan and warn.
        return q_table, 0.0
    return q_table, np.mean(abs_changes)

In [4]:
def greedy_policy(state, return_action_val=False):
    """Pick the action with the highest stored Q-value for ``state``.

    Looks up actions 0..3 in the module-level ``q_table``.  Any
    (state, action) key that is not yet present is inserted with value 0
    as a side effect of the lookup.  Ties are resolved by ``np.argmax``,
    i.e. the lowest-numbered action wins.

    Args:
        state: array-like observation; it is copied and flattened to build
            the table key.
        return_action_val: when true, also return the chosen action's value.

    Returns:
        The chosen action index, or ``(action, value)`` when
        ``return_action_val`` is true.
    """
    flat_state = np.copy(state).flatten()
    action_values = []
    for candidate in range(4):  # one lookup per available action
        key = pickle.dumps(np.array((*flat_state, candidate)))
        # setdefault both reads the entry and registers unseen keys as 0.
        action_values.append(q_table.setdefault(key, 0))
    best = np.argmax(action_values)
    if not return_action_val:
        return best
    return best, action_values[best]

In [5]:
def epsilon_greedy_policy(state):
    """Choose an action epsilon-greedily.

    Draws one uniform random number; if it is below the module-level
    ``epsilon`` a random action is sampled from ``env.action_space``,
    otherwise the action chosen by ``greedy_policy(state)`` is returned.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return env.action_space.sample()
    return greedy_policy(state)

In [6]:
def q_table_viz(q_table):
    """Collect stored Q-values into a dense ``(4, 4, 4)`` array.

    For every grid index ``(i, j)`` a probe state of shape ``(3, 4, 4)`` is
    built with a 1 at ``[0, 3, 3]`` and a 1 at ``[2, i, j]``; the value
    stored for each of the four actions in that state is written to
    ``result[action, i, j]``.  Pairs missing from ``q_table`` show up as 0.

    Args:
        q_table: dict keyed by ``pickle.dumps`` of the int64 array made from
            the flattened state followed by the action.

    Returns:
        ``numpy.ndarray`` of floats indexed ``[action, i, j]``.
    """
    grid_q = np.zeros((4, 4, 4))
    template = np.zeros((3, 4, 4), dtype=np.int64)
    template[0, 3, 3] = 1
    for row, col in np.ndindex(4, 4):
        probe = template.copy()
        probe[2, row, col] = 1
        cells = probe.flatten()
        for action in range(4):
            key = pickle.dumps(np.array([*cells, action], dtype=np.int64))
            grid_q[action, row, col] = q_table.get(key, 0)
    return grid_q

In [7]:
def visualise_agent(policy_function, n=5):
    """Render ``n`` episodes of the module-level ``env`` driven by a policy.

    Each step is rendered and followed by a 0.5 s pause; the final frame of
    an episode is shown for 1.5 s before a summary line is printed.

    Args:
        policy_function: callable mapping an observation to an action.
        n: number of episodes to play.

    The environment is always closed on exit: after a normal run, after a
    ``KeyboardInterrupt`` (which is swallowed so the run can be stopped from
    the notebook), and also when any other exception propagates.
    """
    try:
        for trial_i in range(n):
            observation = env.reset()
            done = False
            t = 0
            while not done:
                env.render()
                policy_action = policy_function(observation)
                # Only the next observation and the done flag are needed here.
                observation, _reward, done, _info = env.step(policy_action)
                time.sleep(0.5)
                t += 1
            env.render()  # show the final state of the episode
            time.sleep(1.5)
            print("Episode {} finished after {} timesteps".format(trial_i, t))
    except KeyboardInterrupt:
        # Manual interruption is an expected way to stop the visualisation.
        pass
    finally:
        # Close the window even if the policy or the environment raised.
        env.close()

In [11]:
# Shared module-level state used by the policy functions and the training loop.
env = GriddyEnv()  # environment instance that the other cells reset, step, render and close
epsilon = 1  # exploration probability read by epsilon_greedy_policy; train() multiplies it by 0.9999 per step
i_episode=0  # running count of completed training episodes (incremented by train())
q_table = {}  # maps pickled (flattened observation, action) arrays to Q-value estimates

In [12]:
def train(n_episodes=100):
    """Run Q-table training episodes on the module-level ``env``.

    Every episode is played to completion with ``epsilon_greedy_policy``.
    The recorded transitions are then turned into targets by
    ``calculate_qs`` and merged into the table by ``update_q_table``.
    After each episode a summary line and the ``q_table_viz`` array are
    printed.

    Args:
        n_episodes: number of episodes to run.

    Side effects (module-level state):
        * ``epsilon`` is multiplied by 0.9999 after every environment step.
        * ``q_table`` is replaced by the table returned by ``update_q_table``.
        * ``i_episode`` is incremented once per finished episode.

    Raises:
        RuntimeError: if the observation an action was chosen from already
            had the agent at grid index (3, 3).

    A ``KeyboardInterrupt`` is swallowed so training can be stopped from the
    notebook.  The environment is closed on every exit path.
    """
    global epsilon
    global q_table
    global i_episode
    try:
        for _ in range(n_episodes):
            observation = env.reset()
            episode_mem = []
            done = False
            t = 0
            while not done:
                action = epsilon_greedy_policy(observation)
                new_observation, reward, done, info = env.step(action)
                # Sanity check on the pre-step observation: the first cell
                # equal to 1 in channel 2 (presumably the agent layer --
                # TODO confirm against GriddyEnv) must not be (3, 3).
                agent_pos = list(zip(*np.where(observation[2] == 1)))[0]
                if agent_pos == (3, 3):
                    raise RuntimeError(
                        "unexpected agent position (3, 3) in the pre-step observation")
                episode_mem.append({'observation': observation,
                                    'action': action,
                                    'reward': reward,
                                    'new_observation': new_observation,
                                    'done': done})
                observation = new_observation
                t += 1
                epsilon *= 0.9999  # decay exploration once per step
            episode_mem = calculate_qs(episode_mem)
            q_table, q_delta = update_q_table(q_table, episode_mem)
            i_episode += 1
            print("Episode {} finished after {} timesteps. Eplislon={}. q_Delta={}".format(i_episode, t, epsilon, q_delta))#, end='\r')
            print(q_table_viz(q_table))
            print()
    except KeyboardInterrupt:
        # Manual interruption is an expected way to stop training early.
        pass
    finally:
        # Release the environment even when an unexpected error propagates.
        env.close()

In [13]:
# Train for 500 episodes, the run length that produced the log recorded below.
# The count is passed explicitly so the cell does not rely on train()'s internals.
train(n_episodes=500)

Episode 1 finished after 8 timesteps. Eplislon=0.9992002799440072. q_Delta=0.0625
[[[0.  0.  0.  0. ]
  [0.  0.  0.  0. ]
  [0.  0.  0.  0. ]
  [0.  0.  0.  0. ]]

 [[0.  0.  0.  0. ]
  [0.  0.  0.  0. ]
  [0.  0.  0.  0. ]
  [0.  0.  0.  0. ]]

 [[0.  0.  0.  0. ]
  [0.  0.  0.  0. ]
  [0.  0.  0.  0. ]
  [0.  0.  0.  0. ]]

 [[0.  0.  0.  0. ]
  [0.  0.  0.  0. ]
  [0.  0.  0.  0.5]
  [0.  0.  0.  0. ]]]

Episode 2 finished after 7 timesteps. Eplislon=0.9985010495451367. q_Delta=0.1392857142857143
[[[0.     0.     0.     0.    ]
  [0.     0.     0.     0.    ]
  [0.     0.     0.     0.    ]
  [0.     0.     0.     0.    ]]

 [[0.     0.     0.     0.    ]
  [0.     0.     0.     0.    ]
  [0.     0.     0.     0.2375]
  [0.     0.     0.5    0.    ]]

 [[0.     0.     0.     0.    ]
  [0.     0.     0.     0.    ]
  [0.     0.     0.     0.    ]
  [0.     0.     0.     0.    ]]

 [[0.     0.     0.     0.    ]
  [0.     0.     0.     0.2375]
  [0.     0.     0.     0.5   ]
  [0.    

[[[0.06359819 0.5465212  0.67047157 0.78989986]
  [0.60912525 0.72284986 0.79617476 0.84067191]
  [0.73966255 0.76375496 0.82596391 0.88388153]
  [0.74643536 0.7843223  0.89299109 0.        ]]

 [[0.69055474 0.75008792 0.7762645  0.64517678]
  [0.78227721 0.84499666 0.89303778 0.81998383]
  [0.8346605  0.89392063 0.92750244 0.87021484]
  [0.89468665 0.94825325 0.99951172 0.        ]]

 [[0.52748291 0.67219407 0.79069238 0.75960617]
  [0.49079686 0.69006093 0.78497162 0.76961102]
  [0.7020797  0.79712856 0.77217519 0.62884155]
  [0.50508932 0.81925258 0.81568726 0.        ]]

 [[0.67105834 0.79606241 0.84296268 0.88575439]
  [0.77172699 0.83088921 0.89411343 0.94533234]
  [0.70386489 0.88747578 0.94588318 0.99951172]
  [0.72272496 0.89718351 0.94533234 0.        ]]]

Episode 23 finished after 7 timesteps. Eplislon=0.9146585430972585. q_Delta=0.012584855760846827
[[[0.06359819 0.5465212  0.67047157 0.78989986]
  [0.60912525 0.72284986 0.79617476 0.84067191]
  [0.73966255 0.76375496 0.825

Episode 46 finished after 11 timesteps. Eplislon=0.850519367471868. q_Delta=0.0005286982127271129
[[[0.72728931 0.73409337 0.77346099 0.81402974]
  [0.77170157 0.77356373 0.81440508 0.85725098]
  [0.81291382 0.81446019 0.85735181 0.90204634]
  [0.8572354  0.85733705 0.90248893 0.        ]]

 [[0.77362889 0.81436609 0.85718099 0.85009423]
  [0.81325386 0.85735704 0.90239071 0.90195994]
  [0.85737073 0.90248183 0.94998364 0.94991937]
  [0.90249922 0.94999984 0.99999997 0.        ]]

 [[0.72956212 0.77355631 0.81373417 0.85655279]
  [0.73069999 0.77357304 0.81441407 0.85695078]
  [0.75957459 0.81445065 0.85732202 0.90125735]
  [0.80353823 0.85731898 0.90246382 0.        ]]

 [[0.77158774 0.81446526 0.85735527 0.90245769]
  [0.81443913 0.85736674 0.90249219 0.94998548]
  [0.85699277 0.90249845 0.9499981  0.99999952]
  [0.856359   0.9024823  0.9499998  0.        ]]]

Episode 47 finished after 10 timesteps. Eplislon=0.8496692307360671. q_Delta=2.6061607691685306e-05
[[[0.72728931 0.73409337 

[[[0.73115431 0.73476986 0.77376706 0.81444043]
  [0.77371564 0.77376491 0.81450612 0.85737454]
  [0.81449375 0.81450622 0.85737498 0.90249978]
  [0.8573728  0.857375   0.9025     0.        ]]

 [[0.77374254 0.81450292 0.85735666 0.85735534]
  [0.81442768 0.85737474 0.90249977 0.90246268]
  [0.857375   0.90249999 0.94999997 0.94999492]
  [0.9025     0.95       1.         0.        ]]

 [[0.72956212 0.7737614  0.81450024 0.85734407]
  [0.73285965 0.7737732  0.81450142 0.85736075]
  [0.77372278 0.81450621 0.85737497 0.90246007]
  [0.8144848  0.857375   0.90249998 0.        ]]

 [[0.77349873 0.81450584 0.85737472 0.9024995 ]
  [0.8145062  0.857375   0.9025     0.94999998]
  [0.85736902 0.9025     0.95       1.        ]
  [0.85734316 0.9025     0.95       0.        ]]]

Episode 76 finished after 5 timesteps. Eplislon=0.8042775236246144. q_Delta=4.4035709034773165e-08
[[[0.73115431 0.73476986 0.77376706 0.81444043]
  [0.77371564 0.77376491 0.81450612 0.85737454]
  [0.81449375 0.81450622 0.8

[[[0.73484317 0.73499994 0.77377917 0.81448979]
  [0.77378081 0.77378044 0.81450623 0.85737499]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.85737493 0.857375   0.9025     0.        ]]

 [[0.77378031 0.81450619 0.85737442 0.85737236]
  [0.81448661 0.85737498 0.90249999 0.90249942]
  [0.857375   0.9025     0.95       0.94999992]
  [0.9025     0.95       1.         0.        ]]

 [[0.73439378 0.77377591 0.8145055  0.85737488]
  [0.73505549 0.77378041 0.81450625 0.85737407]
  [0.77378082 0.81450625 0.857375   0.90249875]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77377653 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.85737494 0.9025     0.95       0.        ]]]

Episode 111 finished after 16 timesteps. Eplislon=0.763292851177371. q_Delta=4.201111840351768e-09
[[[0.73484317 0.73499994 0.77377917 0.81448979]
  [0.77378081 0.77378044 0.81450623 0.85737499]
  [0.81450625 0.81450625 0.8

Episode 145 finished after 20 timesteps. Eplislon=0.7261374180074647. q_Delta=1.2318121132559589e-07
[[[0.7350763  0.73509041 0.77378092 0.81450574]
  [0.77378091 0.77378093 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.85737499 0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.85737493 0.85737368]
  [0.81450379 0.857375   0.9025     0.90249971]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.7350045  0.77378031 0.81450625 0.85737494]
  [0.73508271 0.77378093 0.81450625 0.85737454]
  [0.77378094 0.81450625 0.857375   0.90249992]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77377873 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.85737499 0.9025     0.95       0.        ]]]

Episode 146 finished after 5 timesteps. Eplislon=0.7257744219049418. q_Delta=7.293810000419399e-12
[[[0.7350763  0.7350904

Episode 181 finished after 23 timesteps. Eplislon=0.6959216991425771. q_Delta=5.380054903996082e-10
[[[0.73508409 0.7350918  0.77378094 0.81450599]
  [0.77378093 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.857375   0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.85737499 0.85737484]
  [0.81450617 0.857375   0.9025     0.90249996]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.73507004 0.7737809  0.81450625 0.85737494]
  [0.73508271 0.77378094 0.81450625 0.85737477]
  [0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77378066 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.857375   0.9025     0.95       0.        ]]]

Episode 182 finished after 7 timesteps. Eplislon=0.6954347000723794. q_Delta=3.2566013378040484e-13
[[[0.73508409 0.7350918

Episode 216 finished after 19 timesteps. Eplislon=0.6705077710917665. q_Delta=9.530388174545422e-15
[[[0.7350914  0.73509189 0.77378094 0.81450599]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.857375   0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.85737498]
  [0.81450621 0.857375   0.9025     0.90249999]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.73509121 0.77378093 0.81450625 0.857375  ]
  [0.7350916  0.77378094 0.81450625 0.85737499]
  [0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77378093 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.857375   0.9025     0.95       0.        ]]]

Episode 217 finished after 8 timesteps. Eplislon=0.6699715525795251. q_Delta=0.0
[[[0.7350914  0.73509189 0.77378094 0.8145

[[[0.73509165 0.73509189 0.77378094 0.81450622]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.857375   0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.857375  ]
  [0.81450624 0.857375   0.9025     0.9025    ]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.73509155 0.77378094 0.81450625 0.857375  ]
  [0.73509185 0.77378094 0.81450625 0.857375  ]
  [0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77378093 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.857375   0.9025     0.95       0.        ]]]

Episode 262 finished after 13 timesteps. Eplislon=0.6433141811224802. q_Delta=2.964741272994441e-11
[[[0.73509165 0.73509189 0.77378094 0.81450622]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.

Episode 307 finished after 5 timesteps. Eplislon=0.6217459180471736. q_Delta=1.1693979118376774e-12
[[[0.73509186 0.73509189 0.77378094 0.81450625]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.857375   0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.857375  ]
  [0.81450625 0.857375   0.9025     0.9025    ]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.73509155 0.77378094 0.81450625 0.857375  ]
  [0.73509187 0.77378094 0.81450625 0.857375  ]
  [0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.857375   0.9025     0.95       0.        ]]]

Episode 308 finished after 12 timesteps. Eplislon=0.6210002331610694. q_Delta=1.2261488121131e-13
[[[0.73509186 0.73509189 

Episode 351 finished after 16 timesteps. Eplislon=0.5997001092073375. q_Delta=2.2157969903346952e-13
[[[0.73509188 0.73509189 0.77378094 0.81450625]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.857375   0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.857375  ]
  [0.81450625 0.857375   0.9025     0.9025    ]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.73509185 0.77378094 0.81450625 0.857375  ]
  [0.73509187 0.77378094 0.81450625 0.857375  ]
  [0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.857375   0.9025     0.95       0.        ]]]

Episode 352 finished after 3 timesteps. Eplislon=0.5995202171649789. q_Delta=0.0
[[[0.73509188 0.73509189 0.77378094 0.814

[[[0.73509188 0.73509189 0.77378094 0.81450625]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.857375   0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.857375  ]
  [0.81450625 0.857375   0.9025     0.9025    ]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.73509187 0.77378094 0.81450625 0.857375  ]
  [0.73509189 0.77378094 0.81450625 0.857375  ]
  [0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.857375   0.9025     0.95       0.        ]]]

Episode 401 finished after 16 timesteps. Eplislon=0.5766455341370138. q_Delta=0.0
[[[0.73509188 0.73509189 0.77378094 0.81450625]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025   

[[[0.73509188 0.73509189 0.77378094 0.81450625]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.857375   0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.857375  ]
  [0.81450625 0.857375   0.9025     0.9025    ]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.73509188 0.77378094 0.81450625 0.857375  ]
  [0.73509189 0.77378094 0.81450625 0.857375  ]
  [0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.857375   0.9025     0.95       0.        ]]]

Episode 448 finished after 1 timesteps. Eplislon=0.5560876562921702. q_Delta=0.0
[[[0.73509188 0.73509189 0.77378094 0.81450625]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    

[[[0.73509189 0.73509189 0.77378094 0.81450625]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0.857375   0.9025    ]
  [0.857375   0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.857375  ]
  [0.81450625 0.857375   0.9025     0.9025    ]
  [0.857375   0.9025     0.95       0.95      ]
  [0.9025     0.95       1.         0.        ]]

 [[0.73509189 0.77378094 0.81450625 0.857375  ]
  [0.73509189 0.77378094 0.81450625 0.857375  ]
  [0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.        ]]

 [[0.77378094 0.81450625 0.857375   0.9025    ]
  [0.81450625 0.857375   0.9025     0.95      ]
  [0.857375   0.9025     0.95       1.        ]
  [0.857375   0.9025     0.95       0.        ]]]

Episode 492 finished after 22 timesteps. Eplislon=0.5350306349778062. q_Delta=2.1585586853299013e-11
[[[0.73509189 0.73509189 0.77378094 0.81450625]
  [0.77378094 0.77378094 0.81450625 0.857375  ]
  [0.81450625 0.81450625 0

In [29]:
# Display the learned Q-values as a (4, 4, 4) array indexed [action, i, j].
q_table_viz(q_table)

array([[[0.73509189, 0.73509189, 0.77378094, 0.81450625],
        [0.77378094, 0.77378094, 0.81450625, 0.857375  ],
        [0.81450625, 0.81450625, 0.857375  , 0.9025    ],
        [0.857375  , 0.857375  , 0.9025    , 0.        ]],

       [[0.77378094, 0.81450625, 0.857375  , 0.857375  ],
        [0.81450625, 0.857375  , 0.9025    , 0.9025    ],
        [0.857375  , 0.9025    , 0.95      , 0.95      ],
        [0.9025    , 0.95      , 1.        , 0.        ]],

       [[0.73509189, 0.77378094, 0.81450625, 0.857375  ],
        [0.73509189, 0.77378094, 0.81450625, 0.857375  ],
        [0.77378094, 0.81450625, 0.857375  , 0.9025    ],
        [0.81450625, 0.857375  , 0.9025    , 0.        ]],

       [[0.77378094, 0.81450625, 0.857375  , 0.9025    ],
        [0.81450625, 0.857375  , 0.9025    , 0.95      ],
        [0.857375  , 0.9025    , 0.95      , 1.        ],
        [0.857375  , 0.9025    , 0.95      , 0.        ]]])

In [31]:
# Render the agent acting greedily on the learned Q-table (5 episodes by default).
visualise_agent(greedy_policy)

Episode 0 finished after 5 timesteps
Episode 1 finished after 2 timesteps
Episode 2 finished after 1 timesteps
Episode 3 finished after 3 timesteps
Episode 4 finished after 1 timesteps
