In [1]:
import gym
import gym_sokoban
import time
import datetime as dt
import matplotlib.pyplot as plt
import math
from collections import deque
import random
import numpy as np
import pickle
import sys
import os
from collections import namedtuple,defaultdict,deque 

In [2]:
env_name = 'Sokoban-v1'
env = gym.make(env_name)
ACTION_LOOKUP = env.unwrapped.get_action_lookup()
env.unwrapped.set_level(0,1)
env.seed(0)
env.reset()
print("Created environment: {}".format(env_name))

Created environment: Sokoban-v1


In [3]:
def greedy_policy(V,s):
    #.9 prob of greedy action
    #.1 prob of random action

    if s not in V:
        V[s] = np.zeros(env.action_space.n)
    r_choice = .1
    if np.random.random() < r_choice:
        return np.random.choice(np.arange(env.action_space.n))
    else:
        max_val = np.max(V[s])
         #find all actions that have the max value and choose one at random
        max_actions = np.argwhere(V[s] == max_val).flatten()
        return np.random.choice(max_actions)
def print_time(ep_start,action_time):
    delta_1 = time.time() - ep_start
    delta_2 = time.time() - action_time
    st_del2 = str(dt.timedelta(seconds=delta_2))
    print("\rEpisode time: {} eps/s Action time: {}".format(int(1/(delta_1+.0001)),st_del2),end="")

In [4]:
EVERY_VISIT_MC = False
EPISODES = 300

V = {}
total_returns = {}
N = {}
for episode in range(EPISODES):
    ep_start = time.time()
    
    visited = []
    env.reset()
    state = env.unwrapped.serialize_state()
    done = False
    for t in range(1000):
        action_time = time.time()
        if done:
            break
        action = greedy_policy(V,state)
        next_state, reward, done, info = env.step(action)
        next_state = env.unwrapped.serialize_state()
        if EVERY_VISIT_MC or (not EVERY_VISIT_MC and state not in visited):
            if not EVERY_VISIT_MC:
                visited.append(state)
            if state not in total_returns:
                total_returns[state] = np.zeros(env.action_space.n)
            for _state in total_returns:
                total_returns[_state][action] += reward
            if state not in N:
                N[state] = np.zeros(env.action_space.n)
            
            N[state][action] += 1
            V[state][action] = total_returns[state][action] / N[state][action]
        state = next_state
        print_time(ep_start,action_time)
        print("  Episode: {}".format(episode),end="         ")

            
#save the value function
with open('value_function_1.pickle', 'wb') as handle:
    pickle.dump(V, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

Episode time: 9 eps/s Action time: 0:00:00  Episode: 299                  

In [9]:
V['3333333333333300000000000030010004000053000000000000333333333333330000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000']

array([-43244.35      , -38464.72      , -13114.3       ,   -575.6661597 ,
       -49014.63333333, -15471.4       ,  -8095.28333333, -13094.56      ])

In [5]:
N.values()

dict_values([array([  2.,   5.,   8., 263.,   6.,   5.,   6.,   5.]), array([  2., 205.,   5.,   3.,   4.,   5.,   7.,   3.]), array([  5.,   1.,   4.,   3., 117.,   2.,   4.,   7.]), array([171.,   5.,   3.,   2.,   5.,   4.,   3.,   3.]), array([  5.,   1.,   2.,   3., 116.,   1.,   3.,   5.]), array([  6.,   4., 111.,   2.,   5.,   3.,   5.,   2.]), array([ 5.,  4.,  3.,  3., 78.,  3.,  1.,  5.]), array([ 2.,  2.,  1., 32.,  2.,  2.,  1.,  1.]), array([ 2.,  3.,  3., 40.,  3.,  2.,  3.,  1.]), array([ 2.,  1.,  2.,  3.,  1., 14.,  3.,  2.]), array([1., 2., 1., 2., 1., 1., 1., 1.]), array([1., 1., 0., 1., 1., 1., 1., 0.]), array([0., 0., 2., 0., 0., 1., 1., 1.]), array([0., 0., 0., 0., 0., 1., 0., 1.]), array([1., 1., 0., 0., 0., 0., 0., 0.]), array([3., 2., 2., 3., 2., 4., 1., 2.]), array([1., 2., 2., 0., 0., 1., 1., 1.]), array([0., 1., 1., 1., 1., 1., 1., 1.]), array([0., 1., 0., 0., 0., 1., 0., 1.]), array([0., 1., 0., 0., 0., 0., 1., 1.]), array([1., 0., 0., 0., 0., 1., 1., 0.])

In [6]:

for s in V:
    argmax = np.argmax(V[s])
    max_val = np.max(V[s])
    s = s[26:40]
    print("State: {} Action: {} Value: {}".format(s,ACTION_LOOKUP[argmax],max_val))

State: 30010004000053 Action: push up Value: -17.929571984435764
State: 30001004000053 Action: push up Value: -26.467045454545403
State: 30000104000053 Action: push up Value: -42.15221238938045
State: 30000014000053 Action: move up Value: -85.7607142857142
State: 30100004000053 Action: move left Value: -27.856830601092888
State: 31000004000053 Action: push down Value: -54.89904761904758
State: 30000001400053 Action: push right Value: -135.73333333333318
State: 30000010400053 Action: push left Value: -388.4999999999996
State: 30000100400053 Action: move up Value: -261.19999999999965
State: 30001000400053 Action: push up Value: 0.0
State: 30010000400053 Action: push down Value: 0.0
State: 30100000400053 Action: push up Value: 0.0
State: 30000000140053 Action: move down Value: -221.69999999999948
State: 30000001040053 Action: push down Value: 0.0
State: 30000000014053 Action: push left Value: -410.3333333333328
State: 30000000104053 Action: push left Value: -58.3000000000001
State: 300000

In [6]:
env.reset()
s = env.unwrapped.serialize_state()

In [7]:
st_time = time.time()
done = False
while not done:
    if time.time() - st_time < .2:
        env.render()
        continue
    st_time = time.time()
    action_time = time.time()
    action = greedy_policy(V,state)
    next_state, reward, done, info = env.step(action)
    next_state = env.unwrapped.serialize_state()
    print("Action: {} Reward: {}".format(ACTION_LOOKUP[action],reward))

Action: push up Reward: 9.9
Action: move right Reward: 9.8
Action: push up Reward: -0.29999999999999927
Action: push up Reward: -10.399999999999999
Action: push up Reward: -20.5
Action: move up Reward: -20.6
Action: push up Reward: -30.700000000000003
Action: push up Reward: -40.800000000000004
Action: push up Reward: -50.900000000000006
Action: push up Reward: -61.00000000000001
Action: push up Reward: -71.1
Action: push up Reward: -81.19999999999999
Action: push up Reward: -91.29999999999998
Action: push up Reward: -101.39999999999998
Action: push up Reward: -111.49999999999997
Action: push up Reward: -121.59999999999997
Action: push up Reward: -131.69999999999996
Action: push up Reward: -141.79999999999995
Action: push up Reward: -151.89999999999995
Action: push up Reward: -161.99999999999994
Action: push up Reward: -172.09999999999994
Action: move right Reward: -172.19999999999993
Action: push up Reward: -182.29999999999993
Action: push up Reward: -192.39999999999992
Action: push u

KeyboardInterrupt: 