-
Notifications
You must be signed in to change notification settings - Fork 3
/
fit_carracing_ddpg.py
122 lines (81 loc) · 3.96 KB
/
fit_carracing_ddpg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gym
import numpy as np
from tensorflow import keras

from trickster.agent import DDPG
from trickster.experience import Experience
from trickster.rollout import Trajectory, MultiRolling, RolloutConfig
from trickster.utility import spaces
K = keras.backend
def _activate(_x, activation="leakyrelu", batch_normalize=True):
    """Optionally batch-normalize a tensor, then apply a nonlinearity.

    Args:
        _x: input Keras tensor.
        activation: ``"leakyrelu"`` selects a dedicated ``LeakyReLU`` layer;
            any other value is forwarded to ``keras.layers.Activation``.
        batch_normalize: if truthy, a ``BatchNormalization`` layer is applied
            before the nonlinearity.

    Returns:
        The transformed Keras tensor.
    """
    tensor = keras.layers.BatchNormalization()(_x) if batch_normalize else _x
    nonlinearity = (
        keras.layers.LeakyReLU()
        if activation == "leakyrelu"
        else keras.layers.Activation(activation)
    )
    return nonlinearity(tensor)
def _conv(_x, width, stride=1, activation="leakyrelu", batch_normalize=True):
    """Apply a 3x3, "same"-padded ``Conv2D`` with ``width`` filters, then ``_activate``.

    Args:
        _x: input Keras tensor.
        width: number of convolution filters.
        stride: convolution stride.
        activation: forwarded to ``_activate``.
        batch_normalize: forwarded to ``_activate``.

    Returns:
        The activated convolution output tensor.
    """
    conv_layer = keras.layers.Conv2D(width, kernel_size=3, strides=stride, padding="same")
    return _activate(conv_layer(_x), activation, batch_normalize)
def _dense(_x, width, activation="leakyrelu", batch_normalize=True):
    """Apply a fully connected layer with ``width`` units, then ``_activate``.

    Args:
        _x: input Keras tensor.
        width: number of output units.
        activation: forwarded to ``_activate``.
        batch_normalize: forwarded to ``_activate``.

    Returns:
        The activated dense output tensor.
    """
    projected = keras.layers.Dense(width)(_x)
    return _activate(projected, activation, batch_normalize)
def _clip(_x, low, high):
    """Clamp a tensor elementwise to ``[low, high]`` using a ``Lambda`` layer.

    Args:
        _x: input Keras tensor.
        low: lower bound.
        high: upper bound.

    Returns:
        The clipped Keras tensor.
    """
    def clamp(tensor):
        return K.clip(tensor, low, high)

    return keras.layers.Lambda(clamp)(_x)
def make_backbone(state_shape=None, batch_normalize=None):
    """Build the convolutional state encoder shared by the actor and critic.

    Args:
        state_shape: shape of a single observation. Defaults to the
            module-level ``input_shape`` (read at call time, so it must be
            defined before calling with the default).
        batch_normalize: whether each conv is followed by BatchNormalization.
            Defaults to the module-level ``BATCH_NORMALIZE``.

    Returns:
        ``(state_in, features)``: the ``keras.Input`` placeholder named
        ``"state_in"`` and the globally average-pooled feature tensor, whose
        width equals the last conv's filter count (32).
    """
    if state_shape is None:
        state_shape = input_shape
    if batch_normalize is None:
        batch_normalize = BATCH_NORMALIZE

    state_in = keras.Input(state_shape, name="state_in")
    x = state_in
    # Each stride-2 "same" conv halves the spatial size
    # (96 -> 48 -> 24 -> 12 for the 96x96 CarRacing frames this script uses).
    for width in (8, 16, 32):
        x = _conv(x, width=width, stride=2, batch_normalize=batch_normalize)
    features = keras.layers.GlobalAveragePooling2D()(x)
    return state_in, features
def make_actor(state_in, features):
    """Build and compile the deterministic policy (actor) network.

    Two 64-unit hidden layers on top of ``features`` feed three scalar heads.
    The first is clipped to [-1, 1] and the other two to [0, 1], matching the
    ``action_minima``/``action_maxima`` passed to DDPG in this script. These are
    presumably CarRacing's steering, gas and brake; confirm against the env.

    Args:
        state_in: the ``keras.Input`` the backbone was built from.
        features: backbone feature tensor computed from ``state_in``.

    Returns:
        A ``keras.Model`` mapping states to 3-dim actions, compiled with Adam
        (learning rate ``ACTOR_LR``) and MSE loss.
    """
    hidden = _dense(features, width=64, batch_normalize=BATCH_NORMALIZE)
    hidden = _dense(hidden, width=64, batch_normalize=BATCH_NORMALIZE)

    def head(low, high):
        # Linear pre-activation before clipping. A LeakyReLU here would shrink
        # negative outputs and their gradients by its slope, making the
        # [-1, 1] head asymmetric. For [0, 1] heads the two are equivalent,
        # since negatives are clipped to 0 either way.
        # NOTE(review): hard clipping has zero gradient when saturated;
        # tanh/sigmoid squashing is the usual DDPG alternative.
        pre_activation = _dense(hidden, width=1, activation="linear", batch_normalize=False)
        return _clip(pre_activation, low, high)

    actor_output = keras.layers.concatenate([head(-1, 1), head(0, 1), head(0, 1)])
    actor_network = keras.Model(state_in, actor_output)
    actor_network.compile(keras.optimizers.Adam(ACTOR_LR), keras.losses.mean_squared_error)
    return actor_network
def make_critic(state_in, action_in, features):
    """Build and compile the action-value (critic) network Q(state, action).

    Args:
        state_in: the ``keras.Input`` the backbone was built from.
        action_in: the ``keras.Input`` for actions.
        features: backbone feature tensor computed from ``state_in``.

    Returns:
        A ``keras.Model`` mapping ``[state, action]`` to a scalar Q-value,
        compiled with Adam (learning rate ``CRITIC_LR``) and MSE loss.
    """
    critic_stream = keras.layers.concatenate([features, action_in])
    critic_stream = _dense(critic_stream, width=64, batch_normalize=BATCH_NORMALIZE)
    critic_stream = _dense(critic_stream, width=64, batch_normalize=BATCH_NORMALIZE)
    # Q-values are an unbounded regression target, so the output head is
    # linear. A LeakyReLU here would put a kink at 0 and scale negative
    # estimates (and their gradients) by its slope.
    critic_output = _dense(critic_stream, width=1, activation="linear", batch_normalize=False)
    critic_network = keras.Model([state_in, action_in], critic_output)
    critic_network.compile(keras.optimizers.Adam(CRITIC_LR), keras.losses.mean_squared_error)
    return critic_network
class CarRacing(gym.ObservationWrapper):
    """``CarRacing-v0`` whose pixel observations are rescaled from [0, 255] to [0, 1]."""

    def __init__(self):
        super().__init__(env=gym.make("CarRacing-v0"))
        # Advertise the rescaled range and dtype instead of inheriting the
        # wrapped env's raw pixel space (presumably uint8 in [0, 255]). This
        # way, anything reading observation_space sees what observation()
        # actually returns. The shape is unchanged.
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=self.env.observation_space.shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        """Return ``observation`` scaled to [0, 1] as a float32 array."""
        # Casting to float32 first avoids the float64 result that `/ 255.`
        # produces. That halves the size of every frame, including any kept
        # in the Experience memory, and matches Keras' default float32 compute.
        return observation.astype(np.float32) / 255.
# Hyperparameters. make_backbone / make_actor / make_critic read these as
# module globals, so they must stay defined at module level.
BATCH_NORMALIZE = False
ACTOR_LR = 1e-4
CRITIC_LR = 1e-4
NUM_ENVS = 4

envs = [CarRacing() for _ in range(NUM_ENVS)]
test_env = CarRacing()

try:
    # input_shape is also read as a global by make_backbone().
    input_shape = envs[0].observation_space.shape  # 96 96 3
    num_actions = envs[0].action_space.shape[0]

    # Actor and critic are built on the same backbone tensors, so they share
    # the convolutional encoder's layers.
    state_inputs, backbone_features = make_backbone()
    actor = make_actor(state_inputs, backbone_features)
    action_inputs = keras.Input([num_actions], name="critic_action_in")
    critic = make_critic(state_inputs, action_inputs, backbone_features)

    agent = DDPG(actor, critic,
                 action_space=spaces.CONTINUOUS,
                 memory=Experience(max_length=int(1e4)),
                 # NOTE(review): gamma=1 means undiscounted bootstrapped
                 # targets; confirm this is intended.
                 discount_factor_gamma=1.,
                 action_noise_sigma=0.1,
                 action_noise_sigma_decay=1.,
                 action_minima=[-1, 0, 0],
                 action_maxima=[1, 1, 1])

    rollout = MultiRolling(agent, envs)
    test_rollout = Trajectory(agent, test_env, RolloutConfig(max_steps=128))

    rollout.fit(episodes=10000, updates_per_episode=4, steps_per_update=4, update_batch_size=64,
                testing_rollout=test_rollout, render_every=100)
    test_rollout.render(repeats=100)
finally:
    # Release the resources each environment holds (e.g. render windows),
    # also when training is interrupted or fails.
    for env in envs + [test_env]:
        env.close()