Skip to content

Commit

Permalink
Update categorical.py
Browse files Browse the repository at this point in the history
Setting dtype in jax.nn.one_hot calls to avoid a return dtype different from the parameters' dtypes.
  • Loading branch information
cyprienc committed May 17, 2023
1 parent b4d78b1 commit bcc6d9a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions distrax/_src/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _sample_n(self, key: PRNGKey, n: int) -> Array:

def log_prob(self, value: EventT) -> Array:
"""See `Distribution.log_prob`."""
value_one_hot = jax.nn.one_hot(value, self.num_categories)
value_one_hot = jax.nn.one_hot(value, self.num_categories, dtype=self.logits.dtype)
mask_outside_domain = jnp.logical_or(
value < 0, value > self.num_categories - 1)
return jnp.where(
Expand All @@ -109,7 +109,7 @@ def log_prob(self, value: EventT) -> Array:

def prob(self, value: EventT) -> Array:
"""See `Distribution.prob`."""
value_one_hot = jax.nn.one_hot(value, self.num_categories)
value_one_hot = jax.nn.one_hot(value, self.num_categories, dtype=self.probs.dtype)
return jnp.sum(math.multiply_no_nan(self.probs, value_one_hot), axis=-1)

def entropy(self) -> Array:
Expand All @@ -135,7 +135,7 @@ def cdf(self, value: EventT) -> Array:
should_be_one = value >= self.num_categories - 1
# Will use value as an index below, so clip it to {0, ..., K-1}.
value = jnp.clip(value, 0, self.num_categories - 1)
value_one_hot = jax.nn.one_hot(value, self.num_categories)
value_one_hot = jax.nn.one_hot(value, self.num_categories, dtype=self.probs.dtype)
cdf = jnp.sum(math.multiply_no_nan(
jnp.cumsum(self.probs, axis=-1), value_one_hot), axis=-1)
return jnp.where(should_be_zero, 0., jnp.where(should_be_one, 1., cdf))
Expand Down

0 comments on commit bcc6d9a

Please sign in to comment.