This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

Update select_topk to not use deprecated function #43

Merged

@cgarciae merged 1 commit into cgarciae:master from deprecated_jax on Nov 23, 2021


@ptigwe (Contributor) commented Nov 22, 2021

Replacing the deprecated `jax.ops.index_update` with the suggested
alternative, `arr.at[idx].set(val)`. Another alternative, which uses a
masking trick, yields the same effect:

``` python
idx_axis0 = jnp.arange(prob_tensor.shape[0])
jnp.sum(idx_axis0 == jnp.expand_dims(idx_axis1, -1), 1)
```
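For illustration, here is a minimal sketch comparing the two approaches. It assumes a 2-D probability tensor of shape `(N, C)` with classes along axis 1, and uses `jax.lax.top_k` to find the top-k indices. The function names `select_topk_scatter` and `select_topk_mask` are hypothetical; this is not the repository's actual `select_topk`. Under this layout the `arange` in the masking version runs over the class axis (`prob_tensor.shape[1]`), since that is the axis the top-k indices point into.

```python
import jax
import jax.numpy as jnp


def select_topk_scatter(prob_tensor: jnp.ndarray, topk: int = 1) -> jnp.ndarray:
    """Top-k mask via the functional update `zeros.at[idx].set(1)`.

    Hypothetical helper; assumes a 2-D (N, C) tensor with classes on axis 1.
    """
    n = prob_tensor.shape[0]
    # Column indices of the k largest entries in each row, shape (N, k).
    idx_axis1 = jax.lax.top_k(prob_tensor, topk)[1]
    # Row indices, shape (N, 1), broadcast against idx_axis1 to (N, k).
    idx_axis0 = jnp.expand_dims(jnp.arange(n), -1)
    zeros = jnp.zeros_like(prob_tensor, dtype=jnp.int32)
    # Functional equivalent of the deprecated
    # jax.ops.index_update(zeros, (idx_axis0, idx_axis1), 1).
    return zeros.at[idx_axis0, idx_axis1].set(1)


def select_topk_mask(prob_tensor: jnp.ndarray, topk: int = 1) -> jnp.ndarray:
    """Same mask via the masking trick: compare class ids to top-k indices.

    Hypothetical helper; assumes a 2-D (N, C) tensor with classes on axis 1.
    """
    idx_axis1 = jax.lax.top_k(prob_tensor, topk)[1]  # (N, k)
    class_ids = jnp.arange(prob_tensor.shape[1])  # (C,)
    # (N, k, 1) == (C,) broadcasts to (N, k, C); summing over k gives (N, C).
    hits = class_ids == jnp.expand_dims(idx_axis1, -1)
    return jnp.sum(hits, axis=1).astype(jnp.int32)


probs = jnp.array([[0.1, 0.6, 0.3],
                   [0.5, 0.2, 0.3]])
# Both variants should mark the same top-2 entries in each row.
assert jnp.array_equal(select_topk_scatter(probs, 2), select_topk_mask(probs, 2))
```

The scatter form touches only `N * k` entries, while the mask form materialises an `(N, k, C)` boolean intermediate, so on large inputs the `.at[...].set(...)` version is likely the lighter of the two.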

@cgarciae (Owner)

LGTM! Thanks a lot @ptigwe!

@cgarciae cgarciae merged commit 231d99b into cgarciae:master Nov 23, 2021
@ptigwe ptigwe deleted the deprecated_jax branch November 23, 2021 16:20