#!/usr/bin/env python
"""
A scripted agent called "Just Enough Retained Knowledge" (JERK).
"""
import random
import gym
import numpy as np
import gym_remote.client as grc
import gym_remote.exceptions as gre
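
# Base chance of replaying a stored solution instead of exploring; the
# effective exploit probability grows as total_steps_ever approaches
# TOTAL_TIMESTEPS.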
EXPLOIT_BIAS = 0.25
TOTAL_TIMESTEPS = int(1e6)
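
# JERK alternates between exploring (run right, backtracking briefly on
# negative reward) and, with increasing probability over time, replaying
# the best action sequence recorded so far.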
def main():
    """Run JERK on the attached environment."""
    env = grc.RemoteEnv('tmp/sock')
    env = TrackedEnv(env)
    new_ep = True
    solutions = []
    while True:
        if new_ep:
            if (solutions and
                    random.random() < EXPLOIT_BIAS + env.total_steps_ever / TOTAL_TIMESTEPS):
                solutions = sorted(solutions, key=lambda x: np.mean(x[0]))
                best_pair = solutions[-1]
                new_rew = exploit(env, best_pair[1])
                best_pair[0].append(new_rew)
                print('replayed best with reward %f' % new_rew)
                continue
            else:
                env.reset()
                new_ep = False
        rew, new_ep = move(env, 100)
        if not new_ep and rew <= 0:
            print('backtracking due to negative reward: %f' % rew)
            _, new_ep = move(env, 70, left=True)
        if new_ep:
            solutions.append(([max(env.reward_history)], env.best_sequence()))
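
# NOTE: the indices below assume gym-retro's standard Genesis button order,
# in which index 0 is B (jump in the Sonic games), 6 is LEFT, and 7 is RIGHT.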
def move(env, num_steps, left=False, jump_prob=1.0 / 10.0, jump_repeat=4):
    """
    Move right or left for a certain number of steps,
    jumping periodically.
    """
    total_rew = 0.0
    done = False
    steps_taken = 0
    jumping_steps_left = 0
    while not done and steps_taken < num_steps:
        # np.bool was removed in NumPy 1.24; use the builtin bool dtype.
        action = np.zeros((12,), dtype=bool)
        action[6] = left
        action[7] = not left
        if jumping_steps_left > 0:
            action[0] = True
            jumping_steps_left -= 1
        else:
            if random.random() < jump_prob:
                jumping_steps_left = jump_repeat - 1
                action[0] = True
        _, rew, done, _ = env.step(action)
        total_rew += rew
        steps_taken += 1
    return total_rew, done
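
# Replaying actions open-loop only reproduces the original trajectory if
# the environment is deterministic, as the contest's Sonic levels are.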
def exploit(env, sequence):
    """
    Replay an action sequence; pad with NOPs if needed.

    Returns the final cumulative reward.
    """
    env.reset()
    done = False
    idx = 0
    while not done:
        if idx >= len(sequence):
            _, _, done, _ = env.step(np.zeros((12,), dtype=bool))
        else:
            _, _, done, _ = env.step(sequence[idx])
        idx += 1
    return env.total_reward
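
# reward_history stores the cumulative reward after every step, so the
# index of its maximum marks the end of the best action prefix.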
class TrackedEnv(gym.Wrapper):
    """
    An environment that tracks the current trajectory and
    the total number of timesteps ever taken.
    """
    def __init__(self, env):
        super(TrackedEnv, self).__init__(env)
        self.action_history = []
        self.reward_history = []
        self.total_reward = 0
        self.total_steps_ever = 0

    def best_sequence(self):
        """
        Get the prefix of the trajectory with the best
        cumulative reward.
        """
        max_cumulative = max(self.reward_history)
        for i, rew in enumerate(self.reward_history):
            if rew == max_cumulative:
                return self.action_history[:i+1]
        raise RuntimeError('unreachable')

    # pylint: disable=E0202
    def reset(self, **kwargs):
        self.action_history = []
        self.reward_history = []
        self.total_reward = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        self.total_steps_ever += 1
        self.action_history.append(action.copy())
        obs, rew, done, info = self.env.step(action)
        self.total_reward += rew
        self.reward_history.append(self.total_reward)
        return obs, rew, done, info
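
# gym_remote signals problems with the remote environment (for example,
# the evaluation ending) by raising GymRemoteError.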
if __name__ == '__main__':
    try:
        main()
    except gre.GymRemoteError as exc:
        print('exception', exc)