diff --git a/notebooks/jax.ipynb b/notebooks/jax.ipynb index 100ac6c..579818f 100644 --- a/notebooks/jax.ipynb +++ b/notebooks/jax.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 18, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -54,15 +54,17 @@ "@struct.dataclass\n", "class BufferState:\n", " needs_reward: jnp.ndarray\n", + " populated: jnp.ndarray\n", " buffer: Experience\n", " next_index: int\n", " batch_size: int\n", - " max_len_per_batch: int" + " max_len_per_batch: int\n", + " sample_batch_size: int" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -70,7 +72,8 @@ "def init_buffer_state(\n", " experience: Experience,\n", " batch_size: int,\n", - " max_len_per_batch: int\n", + " max_len_per_batch: int,\n", + " sample_batch_size: int\n", ") -> BufferState:\n", " \n", " experience = jax.tree_map(jnp.empty_like, experience)\n", @@ -85,9 +88,11 @@ " return BufferState(\n", " buffer=experience,\n", " next_index=0,\n", - " needs_reward=jnp.zeros((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),\n", + " populated=jnp.zeros((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),\n", + " needs_reward=jnp.ones((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),\n", " max_len_per_batch=max_len_per_batch,\n", - " batch_size=batch_size\n", + " batch_size=batch_size,\n", + " sample_batch_size=sample_batch_size\n", " )\n", "\n", "def add_experience(\n", @@ -102,12 +107,14 @@ "\n", " # Update the next index\n", " needs_reward = buffer_state.needs_reward.at[:, buffer_state.next_index, 0].set(True)\n", + " populated = buffer_state.populated.at[:, buffer_state.next_index, 0].set(True)\n", " updated_next_index = (buffer_state.next_index + 1) % buffer_state.max_len_per_batch\n", " \n", " return buffer_state.replace(\n", " buffer=updated_pytree,\n", " next_index=updated_next_index,\n", - " needs_reward=needs_reward\n", + " needs_reward=needs_reward,\n", + " populated=populated\n", " )\n", "\n", "def assign_rewards(\n", @@ -120,19 +127,41 @@ "\n", " return buffer_state.replace(\n", " needs_reward = buffer_state.needs_reward * (1 - select_batch)\n", - " )\n" + " )\n", + "\n", + "def sample(\n", + " buffer_state: BufferState,\n", + " rng: jax.random.PRNGKey\n", + ") -> Experience:\n", + " probs = ((~buffer_state.needs_reward).reshape(-1) * buffer_state.populated.reshape(-1)).astype(jnp.float32)\n", + " indices = jax.random.choice(\n", + " rng,\n", + " buffer_state.max_len_per_batch * buffer_state.batch_size,\n", + " shape=(buffer_state.sample_batch_size,),\n", + " replace=False,\n", + " p = probs / probs.sum()\n", + " )\n", + " batch_indices = indices // buffer_state.max_len_per_batch\n", + " item_indices = indices % buffer_state.max_len_per_batch\n", + "\n", + " return jax.tree_util.tree_map(\n", + " lambda x: x[batch_indices, item_indices],\n", + " buffer_state.buffer\n", + " ), batch_indices, item_indices\n", + " \n" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "buff_state = init_buffer_state(\n", " {\"obs\": jnp.array([0, 0]), \"reward\": jnp.array([0])},\n", " batch_size=4,\n", - " max_len_per_batch=100\n", + " max_len_per_batch=100,\n", + " sample_batch_size=10\n", ")\n", "\n", "for j in range(10):\n", @@ -157,7 +186,49 @@ " buff_state,\n", " jnp.array([5, 6, 7, 8]).reshape(-1, 1, 1),\n", " jnp.array([1, 0, 1, 1]).reshape(-1, 1, 1)\n", - ")" + ")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "({'obs': Array([[2, 3],\n", + " [1, 1],\n", + " [5, 2],\n", + " [1, 1],\n", + " [6, 2],\n", + " [1, 2],\n", + " [1, 2],\n", + " [2, 0],\n", + " [7, 2],\n", + " [9, 3]], dtype=int32),\n", + " 'reward': Array([[8],\n", + " [0],\n", + " [7],\n", + " [2],\n", + " [7],\n", + " [7],\n", + " [7],\n", + " [5],\n", + " [7],\n", + " [8]], dtype=int32)},\n", + " Array([3, 1, 2, 1, 2, 2, 2, 0, 2, 3], dtype=int32),\n", + " Array([12, 11, 15, 1, 6, 11, 1, 12, 7, 19], dtype=int32))" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample(buff_state, jax.random.PRNGKey(1))" ] }, {