-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmain.py
More file actions
97 lines (72 loc) · 2.16 KB
/
main.py
File metadata and controls
97 lines (72 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# -*- coding: utf-8 -*-
import signal, sys, json
from environment import Environment
from strategy import Strategy
# --- Training-loop settings -------------------------------------------------
EPISODE_COUNT = 10 ** 6      # number of episodes run by the main loop
SAVE_INTERVAL = 100          # the strategy is saved every SAVE_INTERVAL episodes
MAX_EPISODE_STEPS = 100000   # an episode is cut off after this many steps

# --- Environment ------------------------------------------------------------
# Grid dimensions; they match the 10 x 10 layout in INIT_ENVIRONMENT below.
ENVIRONMENT_HEIGHT = 10
ENVIRONMENT_WIDTH = 10

# JSON file the strategy is loaded from and saved to.
SAVE_FILE = 'sarsa.json'

# Initial grid handed to Environment(): one text row per grid row, cells
# separated by single spaces. The meaning of each symbol is decided by
# Environment's parser, which is not part of this file.
INIT_ENVIRONMENT = """
X . █ . █ . . . . .
. . █ . █ █ █ █ █ 웃
. . █ . . . . . . .
. █ █ . █ . █ █ █ .
M . . . █ . █ . M .
. █ █ . █ . █ █ █ .
. . █ . . . . . . .
█ █ █ █ █ █ █ █ █ .
. . . . . . . . . .
. . . . . . . . █ .
"""
def build_environment():
    """Return a new Environment built from the INIT_ENVIRONMENT layout."""
    fresh_environment = Environment(INIT_ENVIRONMENT)
    return fresh_environment
def build_strategy():
    """Return a Strategy configured with this script's fixed hyper-parameters.

    The five numeric settings are passed to ``Strategy`` positionally, followed
    by ``Environment.actions``.
    """
    # NOTE(review): the roles below follow the usual SARSA(lambda) notation and
    # are presumed, not shown by this file -- confirm against Strategy.
    gamma = 0.99        # presumably the discount factor
    alpha = 0.1         # presumably the learning rate
    lam = 0.1           # presumably the eligibility-trace decay
    epsilon = 0.1       # presumably the exploration rate
    epsilon_decay = 1   # presumably the per-episode factor applied to epsilon
    return Strategy(
        gamma,
        alpha,
        lam,
        epsilon,
        epsilon_decay,
        Environment.actions,
    )
def load_from_file(strategy, path=None):
    """Restore ``strategy`` from a JSON save file, if a usable one exists.

    Loading is best-effort and never raises for ordinary failures: a missing
    file is ignored silently, and any other error (unreadable file, invalid
    JSON, data rejected by ``strategy.load``) is reported on stderr and then
    ignored. ``SystemExit`` and ``KeyboardInterrupt`` are *not* caught, so a
    Ctrl-C handler that calls ``sys.exit()`` still terminates the program.

    Args:
        strategy: Object with a ``load(data)`` method; it receives the decoded
            JSON document.
        path: File to read. Defaults to ``SAVE_FILE`` when omitted.
    """
    if path is None:
        path = SAVE_FILE
    try:
        with open(path, encoding='utf-8') as f:
            strategy.load(json.load(f))
    except FileNotFoundError:
        # No save file yet: keep the strategy as it was constructed.
        return
    except Exception as exc:
        # Stay best-effort, but say why the save file was ignored instead of
        # hiding the problem.
        print('Could not load', path, '-', repr(exc), file=sys.stderr)
        return
    print('Loaded', path)
def save_to_file(strategy, path=None):
    """Write ``strategy.dump()`` to a JSON save file.

    Saving is best-effort: ordinary failures are reported on stderr and
    otherwise ignored, so a failed save does not stop the caller.
    ``SystemExit`` and ``KeyboardInterrupt`` are *not* caught, so a Ctrl-C
    handler that calls ``sys.exit()`` still terminates the program.

    Args:
        strategy: Object with a ``dump()`` method returning JSON-serialisable
            data.
        path: File to write. Defaults to ``SAVE_FILE`` when omitted.
    """
    if path is None:
        path = SAVE_FILE
    try:
        # Serialise before opening the file: mode 'w' truncates on open, so a
        # failure in strategy.dump() or JSON encoding must not wipe out the
        # previous save.
        payload = json.dumps(strategy.dump())
        with open(path, 'w', encoding='utf-8') as f:
            f.write(payload)
    except Exception as exc:
        print('Could not save', path, '-', repr(exc), file=sys.stderr)
def run_episode(strategy):
    """Run one episode in a freshly built environment, updating ``strategy``.

    Each step asks the strategy for an action in the current actor state,
    performs it in the environment, and feeds the resulting
    (state, action, reward, next state) transition back to the strategy. The
    episode ends when the actor is in a terminal state or after
    MAX_EPISODE_STEPS steps.

    Returns:
        A ``(steps, total_reward)`` tuple: the number of steps performed and
        the sum of the rewards received.
    """
    env = build_environment()
    step_count = 0
    reward_sum = 0
    strategy.new_episode()
    while True:
        # Terminal check first, then the step cap.
        if env.actor_in_terminal_state or step_count >= MAX_EPISODE_STEPS:
            break
        previous_state = env.get_actor_state()
        chosen_action = strategy.next_action(previous_state)
        step_reward = env.perform_action(chosen_action)
        next_state = env.get_actor_state()
        strategy.update(previous_state, chosen_action, step_reward, next_state)
        reward_sum += step_reward
        step_count += 1
    return step_count, reward_sum
def save_and_exit(_1,_2):
    """Signal handler: save the module-level ``strategy``, then exit with status 0.

    Both arguments passed by the ``signal`` machinery are ignored.
    """
    save_to_file(strategy)
    # Same effect as sys.exit(0), which works by raising SystemExit.
    raise SystemExit(0)
if __name__ == '__main__':
    # Train for EPISODE_COUNT episodes, saving the strategy periodically, on
    # Ctrl-C, and once more when training finishes.
    strategy = build_strategy()
    load_from_file(strategy)

    # Install the Ctrl-C handler only now: save_and_exit reads the module-level
    # `strategy`, so an interrupt arriving earlier would either hit a NameError
    # or overwrite the save file with a strategy that has not been loaded yet.
    signal.signal(signal.SIGINT, save_and_exit)  # handle ctrl-c

    for episode_index in range(EPISODE_COUNT):
        run_episode(strategy)
        if episode_index > 0 and episode_index % SAVE_INTERVAL == 0:
            save_to_file(strategy)
            # Progress marker, printed at each periodic save.
            # NOTE(review): assumed to belong to the save branch rather than
            # to every episode -- confirm.
            print(episode_index)

    save_to_file(strategy)