From 8d728d64b03a3580afea5a420d06ab81d9ea5101 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Mon, 9 Dec 2024 16:26:21 +0100 Subject: [PATCH] fix: coupling flows and CMs not working on GPU due to int type Apparently, int32 variables are not transferred to the GPU, leading to problems with XLA. Changing the type declarations to int seems to fix the problem --- bayesflow/networks/consistency_models/consistency_model.py | 4 ++-- bayesflow/networks/coupling_flow/permutations/random.py | 4 ++-- bayesflow/networks/coupling_flow/permutations/swap.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py index d30fab9a3..3ba275162 100644 --- a/bayesflow/networks/consistency_models/consistency_model.py +++ b/bayesflow/networks/consistency_models/consistency_model.py @@ -83,7 +83,7 @@ def __init__( self.s0 = float(s0) self.s1 = float(s1) # create variable that works with JIT compilation - self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int32") + self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int") self.current_step.assign(0) self.seed_generator = keras.random.SeedGenerator() @@ -258,7 +258,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr self.current_step.assign(ops.minimum(self.current_step, self.total_steps - 1)) discretization_index = ops.take( - self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int32") + self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int") ) discretized_time = ops.take(self.discretized_times, discretization_index, axis=0) diff --git a/bayesflow/networks/coupling_flow/permutations/random.py b/bayesflow/networks/coupling_flow/permutations/random.py index b0de99838..ee4646bca 100644 --- a/bayesflow/networks/coupling_flow/permutations/random.py +++ b/bayesflow/networks/coupling_flow/permutations/random.py @@ -16,12 +16,12 @@ def build(self, xz_shape: Shape, **kwargs) -> None: shape=(xz_shape[-1],), initializer=keras.initializers.Constant(forward_indices), trainable=False, - dtype="int32", + dtype="int", ) self.inverse_indices = self.add_weight( shape=(xz_shape[-1],), initializer=keras.initializers.Constant(inverse_indices), trainable=False, - dtype="int32", + dtype="int", ) diff --git a/bayesflow/networks/coupling_flow/permutations/swap.py b/bayesflow/networks/coupling_flow/permutations/swap.py index e699da069..566753fdf 100644 --- a/bayesflow/networks/coupling_flow/permutations/swap.py +++ b/bayesflow/networks/coupling_flow/permutations/swap.py @@ -16,12 +16,12 @@ def build(self, xz_shape: Shape, **kwargs) -> None: shape=(xz_shape[-1],), initializer=keras.initializers.Constant(forward_indices), trainable=False, - dtype="int32", + dtype="int", ) self.inverse_indices = self.add_variable( shape=(xz_shape[-1],), initializer=keras.initializers.Constant(inverse_indices), trainable=False, - dtype="int32", + dtype="int", )