Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: be explicit about squeeze dim in prioritised sampling to avoid flattening (1,1) arrays #27

Merged
merged 2 commits into from
Jul 4, 2024

Conversation

callumtilbury
Copy link
Contributor

Currently, the following snippet will fail:

from flashbax import make_prioritised_flat_buffer
import jax
import jax.numpy as jnp

buffer = make_prioritised_flat_buffer(
    max_length=100,
    min_length=1,
    sample_batch_size=1,  # NB
    add_sequences=False,
)

timestep = {"obs": jnp.zeros(shape=(3)),}

state = buffer.init(timestep)

for i in range(5):
    timestep = {
        "obs": jnp.ones(shape=(3)) * i,
    }
    state = buffer.add(state, timestep)

buffer.sample(state, jax.random.PRNGKey(0))

because of this line:

state, None, query_values.squeeze()

If the sample_batch_size is 1, query_values is shape (1,1), which squeezes to ().

Instead we must be explicit about the squeeze dim.

@EdanToledo EdanToledo merged commit 3c74aa8 into main Jul 4, 2024
3 checks passed
@EdanToledo EdanToledo deleted the fix/prioritised-sampling-with-batch-of-1 branch July 4, 2024 15:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants