In [1]:
import gym
import gym_sokoban
import time
import datetime as dt
import matplotlib.pyplot as plt
import math
from collections import deque
import random
import numpy as np
import pickle
import sys
import tqdm
import os
from collections import namedtuple,defaultdict,deque 

In [2]:
chapter = 15
level = 1


env_name = 'Sokoban-v1'
env = gym.make(env_name)
ACTION_LOOKUP = env.unwrapped.get_action_lookup()
env.unwrapped.set_level(chapter,level)
env.seed(0)
env.reset()
print("Created environment: {}".format(env_name))


#create Results/Chapter [chapter]/Level [level] folder
if not os.path.exists('Results/Chapter '+str(chapter)+'/Level '+str(level)):
    os.makedirs('Results/Chapter '+str(chapter)+'/Level '+str(level))

Created environment: Sokoban-v1


In [3]:
env.render()

True

In [4]:
def map( x,  in_min,  in_max,  out_min,  out_max) :
  return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min;

def greedy_policy(V,s,eps_comp=1e-8):
    #.9 prob of greedy action
    #.1 prob of random action

    if s not in V:
        V[s] = np.zeros(env.action_space.n)
    r_choice = 1 -map(eps_comp,0,1,0.2,.7)
    if np.random.random() < r_choice:
        return np.random.choice(np.arange(env.action_space.n))
    else:
        max_val = np.max(V[s])
         #find all actions that have the max value and choose one at random
        max_actions = np.argwhere(V[s] == max_val).flatten()
        return np.random.choice(max_actions)

In [5]:
EVERY_VISIT_MC = False
EPISODES = 150000

V = {}
total_returns = {}
N = {}
for episode in tqdm.tqdm(range(EPISODES)):
    
    visited = []
    env.reset()
    state = env.unwrapped.serialize_state()
    done = False
    for t in range(50):
        action_time = time.time()
        if done:
            break
        action = greedy_policy(V,state,eps_comp=episode/EPISODES)
        next_state, reward, done, info = env.step(action)
        if EVERY_VISIT_MC or (not EVERY_VISIT_MC and state not in visited):
            if not EVERY_VISIT_MC:
                visited.append((state,action))
            if state not in total_returns:
                total_returns[state] = np.zeros(env.action_space.n)
            if state not in N:
                N[state] = np.zeros(env.action_space.n)
            for _state,_action in visited:
                total_returns[_state][_action] += reward
                N[_state][_action] += 1
                V[_state][_action] = (total_returns[_state][_action] / N[_state][_action])
            
            
            #V[state][action] = (total_returns[state][action] / N[state][action])
        state = env.unwrapped.serialize_state()
    
    if episode % 1000 == 0:
        fname = 'Results/Chapter '+str(chapter)+'/Level '+str(level)+'/MC_'+('every' if EVERY_VISIT_MC else 'first')+'_'+str(EPISODES)+'_episodes_temp.bin'
        with open(fname, 'wb') as handle:
            pickle.dump(V, handle, protocol=pickle.HIGHEST_PROTOCOL)


            
#save the value function
fname = 'Results/Chapter '+str(chapter)+'/Level '+str(level)+'/MC_'+('every' if EVERY_VISIT_MC else 'first')+'_'+str(EPISODES)+'_episodes.bin'
with open(fname, 'wb') as handle:
    pickle.dump(V, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

  0%|          | 0/150000 [00:00<?, ?it/s]

  2%|▏         | 2660/150000 [01:38<1:31:03, 26.97it/s]


KeyboardInterrupt: 

In [7]:
fname = 'Results/Chapter '+str(chapter)+'/Level '+str(level)+'/MC_'+('every' if EVERY_VISIT_MC else 'first')+'_'+str(EPISODES)+'_episodes.bin'
with open(fname, 'wb') as handle:
    pickle.dump(V, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [6]:
N.values()

dict_values([array([ 52384.,  52107.,  52980., 122343.,  52637.]), array([136292., 139675., 314538., 138631., 138755.]), array([148984., 238664., 117794., 125513., 122755.]), array([98890., 45995., 45971., 44610., 44141.]), array([ 58072.,  60011.,  58574., 131687.,  55708.]), array([ 55958.,  54904.,  58713., 122999.,  56324.]), array([30716., 15339., 14248., 17861., 16790.]), array([21492., 46726., 20333., 19445., 19245.]), array([16142.,  7158.,  6153.,  6945.,  7172.]), array([4742., 2012., 1936., 2342., 1814.]), array([8800., 3714., 5040., 4364., 4148.]), array([2279., 2250., 4867., 2562., 2212.]), array([4577., 4452., 2900., 2794., 2719.]), array([2538., 3227., 2673., 5360., 2303.]), array([ 911.,  832., 1505.,  621.,  721.]), array([335., 293., 410., 225., 327.]), array([1983., 1007., 1733., 2169., 1532.]), array([ 700.,  527.,  870., 1760.,  720.]), array([436., 577., 419., 702., 523.]), array([ 617.,  761., 1028.,  701.,  990.]), array([1491., 1181.,  981.,  833., 1063.]), arr

In [7]:
V

{'3333000000000000000000000030030000000000000000000000300300000000000000000000003003333300000000000000000035440413000000000000000000300505030000000000000000003003333300000000000000000033330000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000': array([7.27139474, 7.1562696 , 7.10016978, 7.37775278, 7.27694855]),
 '33330000000000000000000000300300000000000000000000003003000000000000000000000030033333000000000000000000354441030000000000000000003005050300000000000000000030033333000000000000000000333300000000000000000000000000000000000000000000000000000000000000

In [8]:

for s in V:
    argmax = np.argmax(V[s])
    max_val = np.max(V[s])
    s = s[26:40] +'\n'+ s[26+26:40+26]
    print("{} Action: {} Value: {}".format(s,ACTION_LOOKUP[argmax],max_val))

30030000000000
30030000000000 Action: left Value: 7.377752777028111
30030000000000
30030000000000 Action: down Value: 7.2671474638999225
30030000000000
30030000000000 Action: up Value: 7.246014195689408
30030000000000
30030000000000 Action: right Value: 7.220452796035666
30030000000000
30030000000000 Action: left Value: 7.229032121621111
30030000000000
30030000000000 Action: left Value: 7.237631728712619
30030000000000
30030000000000 Action: right Value: 7.1910839627555285
30030000000000
30030000000000 Action: up Value: 7.283958738175734
30030000000000
30030000000000 Action: right Value: 7.072367054887877
30030000000000
30030000000000 Action: right Value: 6.721695908899206
30030000000000
30030000000000 Action: right Value: 7.162724545454543
30030000000000
30030000000000 Action: down Value: 7.186666735155102
30030000000000
30030000000000 Action: up Value: 7.171038858939778
30030000000000
30030000000000 Action: left Value: 7.167425746268596
30030000000000
30030000000000 Action: down Valu

In [9]:
#load the value function

fname = "D:\\2023-2024\\RNN\\Sokoban\\Results\\Chapter 0\\Level 3\\MC_first_100000_episodes_temp.bin"
with open(fname, 'rb') as handle:
    V = pickle.load(handle)

In [10]:
def greedy_policy_pi(V,s):
    #.9 prob of greedy action
    #.1 prob of random action

    if s not in V:
        V[s] = np.zeros(env.action_space.n)
    r_choice = .3
    if np.random.random() < r_choice:
        return np.random.choice(np.arange(env.action_space.n))
    else:
        max_val = np.max(V[s])
         #find all actions that have the max value and choose one at random
        max_actions = np.argwhere(V[s] == max_val).flatten()
        return np.random.choice(max_actions)


In [11]:
_ = env.reset()
state = env.unwrapped.serialize_state()


In [12]:

for i in range(1000):
    time.sleep(1)
    _ = env.reset()
    state = env.unwrapped.serialize_state()
    st_time = time.time()
    done = False
    t= 0
    last_reward = 0
    
    last_state = state
    while True:
        if time.time() - st_time < .2:
            env.render()
            time.sleep(1/60)
            continue
        st_time = time.time()
        action_time = time.time()
        action = greedy_policy_pi(V,state)
        
        state, reward, done, info = env.step(action)
        state = env.unwrapped.serialize_state()
        print(f'{ACTION_LOOKUP[action]} state change: {last_state != state}',end=' ')
        last_state = state
        last_reward = reward
        t += 1

        if done or t > 100:
            env.render()
            break
        env.render()

down state change: True extinguish state change: False left state change: True left state change: True left state change: True right state change: True down state change: False up state change: True extinguish state change: False right state change: True left state change: True extinguish state change: False extinguish state change: False extinguish state change: False right state change: True down state change: True extinguish state change: False extinguish state change: False left state change: True up state change: True up state change: False down state change: True up state change: True up state change: False right state change: True right state change: False left state change: True extinguish state change: False down state change: True up state change: True down state change: True left state change: True left state change: True up state change: True extinguish state change: False extinguish state change: False left state change: True down state change: True extinguish state change

KeyboardInterrupt: 

In [32]:
ACTION_LOOKUP

{0: 'right', 1: 'up', 2: 'down', 3: 'left', 4: 'extinguish'}

In [42]:

action = 1
next_state, reward, done, info = env.step(action)
env.render()
print("Action: {} Reward: {}".format(ACTION_LOOKUP[action],reward))

Action: up Reward: 3.692000000000001


In [41]:
for v in V.values():
    print(np.round(v,2))

[1.01 1.08 1.19 0.83 1.1 ]
[1.23 1.32 1.04 1.16 1.22]
[0.97 0.99 1.1  1.16 1.09]
[0.99 1.13 1.05 1.08 1.06]
[1.06 1.16 1.09 1.12 1.13]
[0.71 0.69 0.92 1.08 0.82]
[0.86 1.14 0.13 0.9  0.82]
[0.12 0.16 0.81 0.26 0.3 ]
[ 0.23  0.59  0.95  0.33 -0.13]
[-2.41 -0.31 -0.89  1.   -2.1 ]
[-2.85 -6.28 -4.43 -0.37 -6.55]
[ -3.72  -5.09 -12.     0.48  -4.89]
[-20.74  -3.76  -0.66  -2.44  -3.65]
[-33.99 -10.96  -9.26 -18.23 -26.69]
[-38.35 -36.99  -6.6   -7.39 -16.01]
[-16.92  -1.98   2.04   1.98  -8.4 ]
[  1.96  -8.72 -10.61   1.93   2.  ]
[-8.15 -5.13 -8.87  2.05  1.97]
[1.93 1.86 1.92 1.99 1.9 ]
[2.02 2.02 1.96 1.98 2.04]
[1.88 1.98 2.08 2.01 2.25]
[2.04 1.98 1.91 1.84 1.89]
[1.83 1.74 1.75 1.77 1.81]
[1.76 1.74 1.79 1.67 1.72]
[1.53 1.78 1.74 1.78 1.83]
[1.62 1.74 1.71 1.73 1.71]
[ -5.4  -15.72  -6.1  -13.61 -15.67]
[  3.38 -24.19   2.47  -9.55   2.38]
[-44.78 -48.45 -30.56  -5.3  -53.99]
[-72.25 -64.68 -20.95 -15.71 -20.18]
[-89.36 -75.99 -84.99 -89.9  -86.49]
[ -91.1  -108.89 -110.14 -105.4  