warehouse_parallel.py
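"""Parallel tabular q-learning for the Warehouse environment using Ray.

A Main actor holds the global q-table. Worker actors each train on their own partition of the state space,
query Main (through an LRU cache) for q-values of states outside their partition, and periodically push their
local q-tables back to Main.
"""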
import time
import random
import ray
import numpy as np
from pylru import lrucache
from more_itertools import split_into, powerset
from warehouse import Warehouse, q_table_to_action_list
import matplotlib.pyplot as plt


@ray.remote
class Main:
    """Main process. Holds the global q-table, receives local q-tables from the workers and serves max q-values
    for a given state."""

    def __init__(self, env, n_proc, r_threshold):
        self.env = env
        self.n_proc = n_proc
        self.r_threshold = r_threshold
        self.q_table = np.zeros((self.env.n_states, self.env.n_actions))
        self.stop = False

    def receive_q_loc(self, qt_loc, states):
        """Receive a local q-table from a worker and merge it into the global q-table. Return the reward obtained
        with the current global q-table, plus a termination flag that is set once the reward threshold (if given)
        is reached."""
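        # NOTE: this relies on each worker's `states` being a contiguous, ordered slice of env.states, so the
        # local q-table maps onto a contiguous block of rows in the global q-table.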
        self.q_table[self.env.states.index(states[0]):self.env.states.index(states[-1]) + 1] = qt_loc  # update q-table
        rewards = q_table_to_action_list(self.q_table, self.env)[-1]  # reward resulting from the current q-table
        if self.r_threshold and rewards >= self.r_threshold:  # threshold available and current reward above it
            self.stop = True
        return rewards, self.stop

    def send_max_q(self, state):
        """Get the max q-value of a state from the global q-table (sent to a worker)."""
        return np.max(self.q_table[self.env.states.index(state), :])

    def send_q_table(self):
        """Send the final q-table."""
        return self.q_table


@ray.remote
class Worker:
    """Worker process. Performs actions in its own part of the state space and communicates with Main periodically."""

    def __init__(self, env, states, possible_actions, corridors, grid_size, start_global, pick_pts_global, main, id,
                 update_interval, cache_size, n_episodes, n_steps, l_rate, d_rate, max_e_rate, min_e_rate, e_d_rate):
        self.env = env  # complete environment
        self.states = states  # list of local states
        self.grid_size = grid_size  # grid size of complete environment
        self.possible_actions = possible_actions  # possible actions of local partition
        self.corridors = corridors  # local corridors
        self.pick_pts_global = pick_pts_global  # global pick points
        self.qt_loc = np.zeros((len(self.states), 4))  # initialise local q-table
        self.start_global = start_global  # global starting position
        self.start = self.corridors[0]  # local starting position
        self.position = self.start
        self.state = (self.position, ())
        self.main = main  # reference to main process
        self.id = id  # id of worker process
        self.update_interval = update_interval
        self.cache = lrucache(cache_size)  # initialise empty LRU cache
        self.n_episodes = n_episodes
        self.n_steps = n_steps
        self.l_rate = l_rate
        self.d_rate = d_rate
        self.max_e_rate = max_e_rate
        self.min_e_rate = min_e_rate
        self.e_d_rate = e_d_rate

    def reset(self):
        """Reset the local state."""
        self.state = random.choice(self.states)  # pick a random local state to start the episode

    def step(self, action):
        """Take a step and update the local q-table. Return whether the episode is done."""
        done = False
        # Determine the new position
        if action in self.possible_actions[np.where(self.corridors == self.state[0])[0][0]]:  # action is valid
            # for the current field
            reward = -1  # general reward for motion
            if action == 0:  # up
                new_position = self.state[0] - self.grid_size[1]
            if action == 1:  # right
                new_position = self.state[0] + 1
            if action == 2:  # down
                new_position = self.state[0] + self.grid_size[1]
            if action == 3:  # left
                new_position = self.state[0] - 1
        else:  # action was invalid, so no movement
            reward = -2
            new_position = self.state[0]
        self.position = new_position
        # Check if the new position is above or below an unvisited pick point
        if new_position + self.grid_size[1] in np.setdiff1d(self.pick_pts_global, self.state[1]):  # above
            new_pick_state = tuple(sorted(self.state[1] + tuple([new_position + self.grid_size[1]])))
            reward = 10
        elif new_position - self.grid_size[1] in np.setdiff1d(self.pick_pts_global, self.state[1]):  # below
            new_pick_state = tuple(sorted(self.state[1] + tuple([new_position - self.grid_size[1]])))
            reward = 10
        elif new_position == self.start_global and not np.setdiff1d(self.pick_pts_global, self.state[1]).any():
            # all pick points visited and back at the global start: done
            new_pick_state = self.state[1]
            reward = 100
            done = True
        else:  # don't change the pick state
            new_pick_state = self.state[1]
        # Update the state
        old_state_idx = self.states.index(self.state)
        new_state = (new_position, new_pick_state)
        self.state = new_state
        # Check if the new position is still in the local partition
        if new_position in [state[0] for state in self.states]:  # determine max next q-value from the local q-table
            max_next = np.max(self.qt_loc[self.states.index(new_state), :])
        else:  # finish the episode and get the max next q-value from the global q-table held by Main
            done = True
            max_next = self.get_remote_max_q(new_state)
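        # Q-learning update, as implemented below:
        # Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a')),
        # with alpha = l_rate and gamma = d_rate.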
        # Update the local q-table
        self.qt_loc[old_state_idx, action] = self.qt_loc[old_state_idx, action] * (1 - self.l_rate) \
            + self.l_rate * (reward + self.d_rate * max_next)
        return done

    def get_remote_max_q(self, state):
        """Get the max q-value of the given state from the main process, with LRU caching to limit remote calls."""

        def get_from_main(state):
            return ray.get(self.main.send_max_q.remote(state))

        if state not in self.cache:  # state not yet in cache: get it from main and add it to the cache
            self.cache[state] = get_from_main(state)
        return self.cache[state]

    def train(self):
        """Run the local training process."""
        e_rate = 1
        rewards = []
        # For each episode
        for episode in range(self.n_episodes):
            self.reset()
            # For each step
            for step in range(self.n_steps):
                # Pick an action
                if random.uniform(0, 1) > e_rate:  # exploit
                    action = np.argmax(self.qt_loc[self.states.index(self.state), :])  # best action from the
                    # current state
                else:  # explore
                    action = random.choice(self.env.actions)  # choose a random action
                # Take the action
                done = self.step(action=action)
                # Break the loop if done
                if done:
                    break
            # Update the exploration rate
            e_rate = self.min_e_rate + (self.max_e_rate - self.min_e_rate) * np.exp(-self.e_d_rate * episode)
            # Update the main q-table periodically (send the local q-table to main)
            if not episode % self.update_interval and episode > 0:
                reward_total, stop = ray.get(self.main.receive_q_loc.remote(qt_loc=self.qt_loc, states=self.states))
                rewards.append(reward_total)
                if stop:
                    break
        return rewards


def train_parallel(env, n_proc, update_interval, cache_size, n_episodes, n_steps, l_rate, d_rate, max_e_rate,
                   min_e_rate, e_d_rate, r_threshold=None):
    """Perform parallel q-learning on env."""
    # Split up the environment grid according to the number of processes
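    # Each state is a (position, visited pick points) pair, so every corridor position contributes one state per
    # subset of pick points (the powerset); hence each partition gets len(corridor slice) * 2^len(pick_pts) states.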
    corridors_split = np.array_split(env.corridors, n_proc)
    possible_actions_split = list(split_into(env.possible_actions, [len(c) for c in corridors_split]))
    states_split = list(split_into(env.states, [len(c) * len(list(powerset(env.pick_pts))) for c in corridors_split]))
    ray.init()  # start up ray
    main = Main.remote(env=env, n_proc=n_proc, r_threshold=r_threshold)  # define main process
    workers = [Worker.options(
        name="worker" + str(i)).remote(
        env=env, states=states_slice, possible_actions=actions_slice, corridors=corridors_slice,
        id=i, grid_size=env.grid_size, start_global=env.start, pick_pts_global=env.pick_pts,
        main=main, update_interval=update_interval, cache_size=cache_size, n_episodes=n_episodes,
        n_steps=n_steps, l_rate=l_rate, d_rate=d_rate, max_e_rate=max_e_rate,
        min_e_rate=min_e_rate, e_d_rate=e_d_rate)
        for i, (states_slice, actions_slice, corridors_slice)
        in enumerate(zip(states_split, possible_actions_split, corridors_split))]  # define all workers
    s = time.time()
    rewards = ray.get([worker.train.remote() for worker in workers])  # start the local training process on all workers
    # Since the workers do not step through the episodes synchronously, the reward values they obtain from the
    # global q-table at a given update interval can differ. Therefore, take the maximum over all workers, as this
    # represents the best path found at that moment.
    rewards_final = list(map(max, zip(*rewards)))
    exec_time = time.time() - s
    q_table_final = ray.get(main.send_q_table.remote())  # get the final q-table from the main process
    ray.shutdown()  # shut down ray
    actions_final, _ = q_table_to_action_list(q_table_final, env)  # get the final action list from the q-table
    return actions_final, rewards_final, exec_time
if __name__ == "__main__":
env = Warehouse(8, 4, 4)
env.render()
n_proc = 4
n_episodes = 10000
update_interval = 100
actions_final, rewards_final, exec_time = train_parallel(env, n_proc=n_proc, update_interval=update_interval,
n_episodes=n_episodes, cache_size=20, n_steps=100,
l_rate=1., d_rate=1., max_e_rate=1, min_e_rate=0.001,
e_d_rate=0.001)
plt.plot(range(0, len(rewards_final)*update_interval, update_interval), rewards_final)
plt.grid()
plt.show()