Skip to content

Commit

Permalink
buffer sample fn
Browse files Browse the repository at this point in the history
  • Loading branch information
lowrollr committed Nov 9, 2023
1 parent 55c19ab commit 43e8112
Showing 1 changed file with 83 additions and 12 deletions.
95 changes: 83 additions & 12 deletions notebooks/jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -54,23 +54,26 @@
"@struct.dataclass\n",
"class BufferState:\n",
" needs_reward: jnp.ndarray\n",
" populated: jnp.ndarray\n",
" buffer: Experience\n",
" next_index: int\n",
" batch_size: int\n",
" max_len_per_batch: int"
" max_len_per_batch: int\n",
" sample_batch_size: int"
]
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# stolen from flashbax\n",
"def init_buffer_state(\n",
" experience: Experience,\n",
" batch_size: int,\n",
" max_len_per_batch: int\n",
" max_len_per_batch: int,\n",
" sample_batch_size: int\n",
") -> BufferState:\n",
" \n",
" experience = jax.tree_map(jnp.empty_like, experience)\n",
Expand All @@ -85,9 +88,11 @@
" return BufferState(\n",
" buffer=experience,\n",
" next_index=0,\n",
" needs_reward=jnp.zeros((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),\n",
" populated=jnp.zeros((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),\n",
" needs_reward=jnp.ones((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),\n",
" max_len_per_batch=max_len_per_batch,\n",
" batch_size=batch_size\n",
" batch_size=batch_size,\n",
" sample_batch_size=sample_batch_size\n",
" )\n",
"\n",
"def add_experience(\n",
Expand All @@ -102,12 +107,14 @@
"\n",
" # Update the next index\n",
" needs_reward = buffer_state.needs_reward.at[:, buffer_state.next_index, 0].set(True)\n",
" populated = buffer_state.populated.at[:, buffer_state.next_index, 0].set(True)\n",
" updated_next_index = (buffer_state.next_index + 1) % buffer_state.max_len_per_batch\n",
" \n",
" return buffer_state.replace(\n",
" buffer=updated_pytree,\n",
" next_index=updated_next_index,\n",
" needs_reward=needs_reward\n",
" needs_reward=needs_reward,\n",
" populated=populated\n",
" )\n",
"\n",
"def assign_rewards(\n",
Expand All @@ -120,19 +127,41 @@
"\n",
" return buffer_state.replace(\n",
" needs_reward = buffer_state.needs_reward * (1 - select_batch)\n",
" )\n"
" )\n",
"\n",
"def sample(\n",
" buffer_state: BufferState,\n",
" rng: jax.random.PRNGKey\n",
") -> Experience:\n",
" probs = ((~buffer_state.needs_reward).reshape(-1) * buffer_state.populated.reshape(-1)).astype(jnp.float32)\n",
" indices = jax.random.choice(\n",
" rng,\n",
" buffer_state.max_len_per_batch * buffer_state.batch_size,\n",
" shape=(buffer_state.sample_batch_size,),\n",
" replace=False,\n",
" p = probs / probs.sum()\n",
" )\n",
" batch_indices = indices // buffer_state.max_len_per_batch\n",
" item_indices = indices % buffer_state.max_len_per_batch\n",
"\n",
" return jax.tree_util.tree_map(\n",
" lambda x: x[batch_indices, item_indices],\n",
" buffer_state.buffer\n",
" ), batch_indices, item_indices\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"buff_state = init_buffer_state(\n",
" {\"obs\": jnp.array([0, 0]), \"reward\": jnp.array([0])},\n",
" batch_size=4,\n",
" max_len_per_batch=100\n",
" max_len_per_batch=100,\n",
" sample_batch_size=10\n",
")\n",
"\n",
"for j in range(10):\n",
Expand All @@ -157,7 +186,49 @@
" buff_state,\n",
" jnp.array([5, 6, 7, 8]).reshape(-1, 1, 1),\n",
" jnp.array([1, 0, 1, 1]).reshape(-1, 1, 1)\n",
")"
")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"({'obs': Array([[2, 3],\n",
" [1, 1],\n",
" [5, 2],\n",
" [1, 1],\n",
" [6, 2],\n",
" [1, 2],\n",
" [1, 2],\n",
" [2, 0],\n",
" [7, 2],\n",
" [9, 3]], dtype=int32),\n",
" 'reward': Array([[8],\n",
" [0],\n",
" [7],\n",
" [2],\n",
" [7],\n",
" [7],\n",
" [7],\n",
" [5],\n",
" [7],\n",
" [8]], dtype=int32)},\n",
" Array([3, 1, 2, 1, 2, 2, 2, 0, 2, 3], dtype=int32),\n",
" Array([12, 11, 15, 1, 6, 11, 1, 12, 7, 19], dtype=int32))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample(buff_state, jax.random.PRNGKey(1))"
]
},
{
Expand Down

0 comments on commit 43e8112

Please sign in to comment.