In [1]:
import gym
import gym_sokoban
import time
import datetime as dt
import matplotlib.pyplot as plt
import math
from collections import deque
import random
import numpy as np
import pickle
import sys
import tqdm
import os
from collections import namedtuple,defaultdict,deque 

In [2]:
env_name = 'Sokoban-v1'
env = gym.make(env_name)
ACTION_LOOKUP = env.unwrapped.get_action_lookup()
env.unwrapped.set_level(0,1)
env.seed(0)
env.reset()
print("Created environment: {}".format(env_name))

Created environment: Sokoban-v1


In [3]:
def greedy_policy(V,s):
    #.9 prob of greedy action
    #.1 prob of random action

    if s not in V:
        V[s] = np.zeros(env.action_space.n)
    r_choice = .4
    if np.random.random() < r_choice:
        return np.random.choice(np.arange(env.action_space.n))
    else:
        max_val = np.max(V[s])
         #find all actions that have the max value and choose one at random
        max_actions = np.argwhere(V[s] == max_val).flatten()
        return np.random.choice(max_actions)

In [11]:
EVERY_VISIT_MC = False
EPISODES = 1000

V = {}
total_returns = {}
N = {}
for episode in tqdm.tqdm(range(EPISODES)):
    
    visited = []
    env.reset()
    state = env.unwrapped.serialize_state()
    done = False
    for t in range(20):
        action_time = time.time()
        if done:
            break
        action = greedy_policy(V,state)
        next_state, reward, done, info = env.step(action)
        next_state = env.unwrapped.serialize_state()
        if EVERY_VISIT_MC or (not EVERY_VISIT_MC and state not in visited):
            if not EVERY_VISIT_MC:
                visited.append(state)
            if state not in total_returns:
                total_returns[state] = np.zeros(env.action_space.n)
            for _state in total_returns:
                total_returns[_state][action] += reward
            if state not in N:
                N[state] = np.zeros(env.action_space.n)
            
            N[state][action] += 1
            V[state][action] = (total_returns[state][action] / N[state][action])
        state = next_state

            
#save the value function
with open('value_function_1.pickle', 'wb') as handle:
    pickle.dump(V, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:09<00:00, 105.71it/s]


In [12]:
N.values()

dict_values([array([649., 115.,  84., 152.]), array([627., 111.,  88., 119.]), array([118.,  23.,  24.,  62.]), array([41., 14., 20., 31.]), array([600.,  96.,  87., 130.]), array([583., 114.,  89., 101.]), array([572.,  95.,  92., 103.]), array([571.,  82.,  89.,  90.]), array([77., 19., 24., 14.]), array([12.,  2.,  2.,  7.]), array([6., 0., 1., 1.]), array([77.,  6., 10., 12.]), array([9., 5., 1., 0.]), array([558.,  77.,  93.,  85.]), array([0., 0., 0., 1.]), array([0., 1., 0., 0.]), array([0., 1., 0., 0.]), array([67., 20.,  3., 16.]), array([11.,  3.,  5.,  2.]), array([0., 1., 0., 3.]), array([541.,  80.,  77.,  86.]), array([62., 11.,  4., 23.]), array([14.,  3.,  2.,  6.]), array([0., 1., 1., 0.]), array([5., 0., 0., 1.]), array([0., 1., 0., 0.]), array([0., 2., 0., 0.]), array([1., 0., 0., 0.])])

In [13]:

for s in V:
    argmax = np.argmax(V[s])
    max_val = np.max(V[s])
    s = s[26:40]
    print("State: {} Action: {} Value: {}".format(s,ACTION_LOOKUP[argmax],max_val))

State: 30010004000053 Action: right Value: 20.657673343604998
State: 30001004000053 Action: right Value: 21.382488038276943
State: 30100004000053 Action: right Value: 111.38644067796322
State: 31000004000053 Action: right Value: 306.04073170730965
State: 30000104000053 Action: right Value: 22.34798333333274
State: 30000014000053 Action: right Value: 23.000566037735233
State: 30000001400053 Action: right Value: 23.399842657342035
State: 30000000140053 Action: right Value: 23.485043782836502
State: 30000010400053 Action: right Value: 173.01870129869673
State: 30000100400053 Action: right Value: 917.8808333333131
State: 30001000400053 Action: left Value: 673.8000000000017
State: 30000001040053 Action: right Value: 172.1037662337617
State: 30000010040053 Action: right Value: 1446.189999999963
State: 30000000014053 Action: right Value: 23.976810035841655
State: 30000100040053 Action: right Value: 0.0
State: 30001000040053 Action: right Value: 0.0
State: 30010000040053 Action: right Value: 0

In [14]:
def greedy_policy_pi(V,s):
    #.9 prob of greedy action
    #.1 prob of random action
    return np.argmax(V[s])


In [21]:
_ = env.reset()
state = env.unwrapped.serialize_state()


In [22]:
st_time = time.time()
done = False
while not done:
    if time.time() - st_time < .2:
        env.render()
        continue
    st_time = time.time()
    action_time = time.time()
    action = greedy_policy_pi(V,state)
    next_state, reward, done, info = env.step(action)
    next_state = env.unwrapped.serialize_state()
    print("Action: {} Reward: {}".format(ACTION_LOOKUP[action],reward))
    env.render()

Action: right Reward: 1.99
Action: right Reward: 1.98
Action: right Reward: 1.97
Action: right Reward: 1.96
Action: right Reward: 1.95
Action: right Reward: 1.94
Action: right Reward: 1.93
Action: right Reward: 13.92


In [None]:

action = 0
next_state, reward, done, info = env.step(action)
env.render()
print("Action: {} Reward: {}".format(ACTION_LOOKUP[action],reward))

Action: push up Reward: -35.60000000000006
