diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index d30fab9a3..3ba275162 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -83,7 +83,7 @@ def __init__( self.s0 = float(s0) self.s1 = float(s1) # create variable that works with JIT compilation - self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int32") + self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int") self.current_step.assign(0) self.seed_generator = keras.random.SeedGenerator() @@ -258,7 +258,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr self.current_step.assign(ops.minimum(self.current_step, self.total_steps - 1)) discretization_index = ops.take( - self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int32") + self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int") ) discretized_time = ops.take(self.discretized_times, discretization_index, axis=0) diff --git a/bayesflow/networks/coupling_flow/permutations/random.py b/bayesflow/networks/coupling_flow/permutations/random.py index b0de99838..ee4646bca 100644 --- a/bayesflow/networks/coupling_flow/permutations/random.py +++ b/bayesflow/networks/coupling_flow/permutations/random.py @@ -16,12 +16,12 @@ def build(self, xz_shape: Shape, **kwargs) -> None: shape=(xz_shape[-1],), initializer=keras.initializers.Constant(forward_indices), trainable=False, - dtype="int32", + dtype="int", ) self.inverse_indices = self.add_weight( shape=(xz_shape[-1],), initializer=keras.initializers.Constant(inverse_indices), trainable=False, - dtype="int32", + dtype="int", ) diff --git a/bayesflow/networks/coupling_flow/permutations/swap.py b/bayesflow/networks/coupling_flow/permutations/swap.py index e699da069..566753fdf 100644 --- a/bayesflow/networks/coupling_flow/permutations/swap.py +++ b/bayesflow/networks/coupling_flow/permutations/swap.py @@ -16,12 +16,12 @@ def build(self, xz_shape: Shape, **kwargs) -> None: shape=(xz_shape[-1],), initializer=keras.initializers.Constant(forward_indices), trainable=False, - dtype="int32", + dtype="int", ) self.inverse_indices = self.add_variable( shape=(xz_shape[-1],), initializer=keras.initializers.Constant(inverse_indices), trainable=False, - dtype="int32", + dtype="int", )