Added reward and Apply RL actions
ashkan-software committed Jul 17, 2019
1 parent d9fc018 commit 4a72e81
Showing 2 changed files with 99 additions and 5 deletions.
5 changes: 3 additions & 2 deletions flow/envs/green_wave_env.py
@@ -200,8 +200,9 @@ def _apply_rl_actions(self, rl_actions):
            rl_mask = [int(x) for x in list('{0:0b}'.format(rl_actions))]
            rl_mask = [0] * (self.num_traffic_lights - len(rl_mask)) + rl_mask
        else:
-            # convert values less than 0.5 to zero and above to 1. 0's indicate
-            # that should not switch the direction
+            # convert values less than 0 to zero and above 0 to 1; 0 indicates
+            # that the direction should not switch, and 1 indicates that a
+            # switch should happen
            rl_mask = rl_actions > 0.0

        for i, action in enumerate(rl_mask):
99 changes: 96 additions & 3 deletions tutorials/tutorial11_traffic_lights.ipynb
@@ -442,7 +442,7 @@
" return Discrete(2 ** self.num_traffic_lights)\n",
" else:\n",
" return Box(\n",
" low=-1,\n",
" low=0,\n",
" high=1,\n",
" shape=(self.num_traffic_lights,),\n",
" dtype=np.float32)"
@@ -452,6 +452,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In the case that the action space is discrete, we need 1-bit (that can be 0 or 1) for the action of each traffic light node. Hence, we need `self.num_traffic_lights` bits to represent the action spacs. To make a `self.num_traffic_lights`-bit number, we use the pyhton's `Discrete(range)`, and since we have `self.num_traffic_lights` bits, the `range` will be 2^`self.num_traffic_lights`.\n",
"\n",
"In the case that the action space is continuous, we use a range (that is currently (0,1)) of numbers for each traffic light node. Hence, we will define `self.num_traffic_lights` \"Boxes\", each in the range (0,1). \n",
"\n",
"Note that the variable `num_traffic_lights` is actually the number of intersections in the grid system, not the number of traffic lights. Number of traffic lights in our example is 4 times the number of intersections"
]
},
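{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the encoding concrete, here is a minimal sketch (assuming a hypothetical grid with 4 intersections) of how a single discrete action decodes into a per-intersection bitmask, and how a continuous action thresholds into a mask of the same shape:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"num_traffic_lights = 4  # hypothetical 2x2 grid of intersections\n",
"\n",
"# discrete case: one integer in [0, 2**4) encodes 4 on/off switch bits\n",
"rl_action = 5  # binary 101\n",
"rl_mask = [int(x) for x in '{0:0b}'.format(rl_action)]\n",
"rl_mask = [0] * (num_traffic_lights - len(rl_mask)) + rl_mask\n",
"print(rl_mask)  # [0, 1, 0, 1]\n",
"\n",
"# continuous case: one float in (0, 1) per intersection, thresholded\n",
"rl_actions = np.array([0.1, 0.7, 0.4, 0.9], dtype=np.float32)\n",
"print(rl_actions > 0.5)  # [False  True False  True]"
]
},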
@@ -562,7 +566,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"...."
"The agents in an RL scenario will learn to maximize a certain reward. This objective can be defined in terms of maximizing rewards or minimizing the penalty. In this example, we penalize the large delay and boolean actions that indicate a switch (with the negative sign)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_reward(self, rl_actions, **kwargs):\n",
" return - rewards.min_delay_unscaled(self) - rewards.boolean_action_penalty(rl_actions >= 0.5, gain=1.0)"
]
},
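{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here `rewards` refers to Flow's reward helpers in `flow.core.rewards`: roughly, `min_delay_unscaled` measures vehicle delay relative to free-flow travel, and `boolean_action_penalty` charges for switching. As a rough standalone sketch of the penalty term, assuming it simply counts the requested switches and scales them by `gain`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def boolean_action_penalty_sketch(discrete_actions, gain=1.0):\n",
"    # assumed behavior, for illustration: penalize each light that is\n",
"    # asked to switch on this step, scaled by `gain`\n",
"    return gain * np.sum(discrete_actions)\n",
"\n",
"# two of four (hypothetical) lights request a switch -> penalty of 2.0\n",
"rl_actions = np.array([0.9, 0.2, 0.6, 0.1])\n",
"print(boolean_action_penalty_sketch(rl_actions >= 0.5))  # 2.0"
]
},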
{
@@ -576,7 +590,86 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"..."
"In the `_apply_rl_actions` function, we specify what actions our agents should take in the environment. In this example, the agents (traffic light nodes) decide based on the action value how to change the traffic lights."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _apply_rl_actions(self, rl_actions):\n",
" \"\"\"See class definition.\"\"\"\n",
" # check if the action space is discrete\n",
" if self.discrete:\n",
" # convert single value to list of 0's and 1's\n",
" rl_mask = [int(x) for x in list('{0:0b}'.format(rl_actions))]\n",
" rl_mask = [0] * (self.num_traffic_lights - len(rl_mask)) + rl_mask\n",
" else:\n",
" # convert values less than 0.5 to zero and above 0.5 to 1. 0 \n",
" # indicates that should not switch the direction, and 1 indicates\n",
" # that switch should happen\n",
" rl_mask = rl_actions > 0.5\n",
"\n",
" # Loop through the traffic light nodes \n",
" for i, action in enumerate(rl_mask):\n",
" if self.currently_yellow[i] == 1: # currently yellow\n",
" # Code to change from yellow to red\n",
" ...\n",
" else:\n",
" # Code to change to yellow\n",
" ..."
]
},
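{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustration of the branching above, the sketch below traces one pass of the loop for a hypothetical 4-intersection grid in which intersection 2 is currently mid yellow phase:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical state: intersection 2 is in its yellow phase\n",
"rl_actions = np.array([0.9, 0.2, 0.6, 0.1])\n",
"currently_yellow = [0, 0, 1, 0]\n",
"\n",
"rl_mask = rl_actions > 0.5\n",
"for i, action in enumerate(rl_mask):\n",
"    if currently_yellow[i] == 1:\n",
"        print('node', i, ': in yellow phase, advance its timer')\n",
"    elif action:\n",
"        print('node', i, ': switch requested, change to yellow')\n",
"    else:\n",
"        print('node', i, ': no switch, keep the current direction')"
]
},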
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These are the portions of the code that are hidden from the above code for shortening the code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
" # Code to change from yellow to red\n",
" self.last_change[i] += self.sim_step\n",
" # Check if our timer has exceeded the yellow phase, meaning it\n",
" # should switch to red\n",
" if self.last_change[i] >= self.min_switch_time:\n",
" if self.direction[i] == 0:\n",
" self.k.traffic_light.set_state(\n",
" node_id='center{}'.format(i),\n",
" state=\"GrGr\")\n",
" else:\n",
" self.k.traffic_light.set_state(\n",
" node_id='center{}'.format(i),\n",
" state='rGrG')\n",
" self.currently_yellow[i] = 0"
]
},
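{
"cell_type": "markdown",
"metadata": {},
"source": [
"The four-character strings passed to `set_state` follow SUMO's traffic light convention of one character per signal link ('G' for green, 'y' for yellow, 'r' for red). A sketch of the mapping used by the snippet above and the one below (the dictionary names are just for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# direction -> SUMO state string, as used in the two code paths\n",
"green_state = {0: 'GrGr', 1: 'rGrG'}   # applied once the yellow phase ends\n",
"yellow_state = {0: 'yryr', 1: 'ryry'}  # applied when a switch is requested\n",
"print(green_state[0], yellow_state[0])  # GrGr yryr"
]
},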
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
" # Code to change to yellow\n",
" if action:\n",
" if self.direction[i] == 0:\n",
" self.k.traffic_light.set_state(\n",
" node_id='center{}'.format(i),\n",
" state='yryr')\n",
" else:\n",
" self.k.traffic_light.set_state(\n",
" node_id='center{}'.format(i),\n",
" state='ryry')\n",
" self.last_change[i] = 0.0\n",
" self.direction[i] = not self.direction[i]\n",
" self.currently_yellow[i] = 1"
]
}
],
