Running the flow matching workflow in the Lotka-Volterra notebook leads to the error below. It seems to be an edge case related to formatting a log message inside `keras.ops.cond`, where the step count is still a symbolic tensor.
```python
history = flow_matching_workflow.fit_offline(
    training_data,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=validation_data
)
```
```
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/50
WARNING:bayesflow:Log-Sinkhorn-Knopp produced NaNs.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[18], line 1
----> 1 history = flow_matching_workflow.fit_offline(
      2     training_data,
      3     epochs=epochs,
      4     batch_size=batch_size,
      5     validation_data=validation_data
      6 )
File ~/Programming/IWR/bf2/bayesflow/workflows/basic_workflow.py:714, in BasicWorkflow.fit_offline(self, data, epochs, batch_size, keep_optimizer, validation_data, **kwargs)
    679 """
    680 Train the approximator offline using a fixed dataset. This approach will be faster than online training,
    681 since no computation time is spent in generating new data for each batch, but it assumes that simulations
    (...)
    709 metric evolution over epochs.
    710 """
    712 dataset = OfflineDataset(data=data, batch_size=batch_size, adapter=self.adapter)
--> 714 return self._fit(
    715     dataset, epochs, strategy="online", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs
    716 )
File ~/Programming/IWR/bf2/bayesflow/workflows/basic_workflow.py:913, in BasicWorkflow._fit(self, dataset, epochs, strategy, keep_optimizer, validation_data, **kwargs)
    910 self.approximator.compile(optimizer=self.optimizer, metrics=kwargs.pop("metrics", None))
    912 try:
--> 913     self.history = self.approximator.fit(
    914         dataset=dataset, epochs=epochs, validation_data=validation_data, **kwargs
    915     )
    916     self._on_training_finished()
    917     return self.history
File ~/Programming/IWR/bf2/bayesflow/approximators/continuous_approximator.py:200, in ContinuousApproximator.fit(self, *args, **kwargs)
    148 def fit(self, *args, **kwargs):
    149     """
    150     Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
    151     If `dataset` is not provided, a dataset is built from the `simulator`.
    (...)
    198     If both `dataset` and `simulator` are provided or neither is provided.
    199     """
--> 200     return super().fit(*args, **kwargs, adapter=self.adapter)
File ~/Programming/IWR/bf2/bayesflow/approximators/approximator.py:139, in Approximator.fit(self, dataset, simulator, **kwargs)
    136 mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
    137 self.build_from_data(mock_data)
--> 139 return super().fit(dataset=dataset, **kwargs)
File ~/Programming/IWR/bf2/bayesflow/approximators/backend_approximators/backend_approximator.py:22, in BackendApproximator.fit(self, dataset, **kwargs)
     21 def fit(self, *, dataset: keras.utils.PyDataset, **kwargs):
---> 22     return super().fit(x=dataset, y=None, **filter_kwargs(kwargs, super().fit))
File /data/Programming/.mamba/envs/bf2/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119 filtered_tb = _process_traceback_frames(e.__traceback__)
    120 # To get the full stack trace, call:
    121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb
File ~/Programming/IWR/bf2/bayesflow/approximators/backend_approximators/tensorflow_approximator.py:20, in TensorFlowApproximator.train_step(self, data)
     18 with tf.GradientTape() as tape:
     19     kwargs = filter_kwargs(data | {"stage": "training"}, self.compute_metrics)
---> 20     metrics = self.compute_metrics(**kwargs)
     22 loss = metrics["loss"]
     24 grads = tape.gradient(loss, self.trainable_variables)
File ~/Programming/IWR/bf2/bayesflow/approximators/continuous_approximator.py:135, in ContinuousApproximator.compute_metrics(self, inference_variables, inference_conditions, summary_variables, sample_weight, stage)
    133 # Force a conversion to Tensor
    134 inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
--> 135 inference_metrics = self.inference_network.compute_metrics(
    136     inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
    137 )
    139 loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
    141 inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()}
File ~/Programming/IWR/bf2/bayesflow/networks/flow_matching/flow_matching.py:263, in FlowMatching.compute_metrics(self, x, conditions, sample_weight, stage)
    256 x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1])
    258 if self.use_optimal_transport:
    259     # we must choose between resampling x0 or x1
    260     # since the data is possibly noisy and may contain outliers, it is better
    261     # to possibly drop some samples from x1 than from x0
    262     # in the marginal over multiple batches, this is not a problem
--> 263     x0, x1, assignments = optimal_transport(
    264         x0,
    265         x1,
    266         seed=self.seed_generator,
    267         **self.optimal_transport_kwargs,
    268         return_assignments=True,
    269     )
    270     if conditions is not None:
    271         # conditions must be resampled along with x1
    272         conditions = keras.ops.take(conditions, assignments, axis=0)
File ~/Programming/IWR/bf2/bayesflow/utils/optimal_transport/optimal_transport.py:41, in optimal_transport(x1, x2, method, return_assignments, **kwargs)
     14 def optimal_transport(x1, x2, method="log_sinkhorn", return_assignments=False, **kwargs):
     15     """Matches elements from x2 onto x1, such that the transport cost between them is minimized, according to the method
     16     and cost matrix used.
     17
     (...)
     39     x1 and x2 in optimal transport permutation order.
     40     """
---> 41     assignments = methods[method.lower()](x1, x2, **kwargs)
     42     x2 = keras.ops.take(x2, assignments, axis=0)
     44     if return_assignments:
File ~/Programming/IWR/bf2/bayesflow/utils/optimal_transport/log_sinkhorn.py:13, in log_sinkhorn(x1, x2, seed, **kwargs)
      8 def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
      9     """
     10     Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
     11     Significantly slower than the unstabilized version, so use only when you need numerical stability.
     12     """
---> 13     log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
     14     assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed)
     15     assignments = keras.ops.squeeze(assignments, axis=1)
File ~/Programming/IWR/bf2/bayesflow/utils/optimal_transport/log_sinkhorn.py:74, in log_sinkhorn_plan(x1, x2, regularization, rtol, atol, max_steps)
     71     logging.warning(msg)
     73 keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing)
---> 74 keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence)
     76 return log_plan
File ~/Programming/IWR/bf2/bayesflow/utils/optimal_transport/log_sinkhorn.py:58, in log_sinkhorn_plan.<locals>.log_steps()
     55 def log_steps():
     56     msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
---> 58     logging.debug(msg, steps)
File ~/Programming/IWR/bf2/bayesflow/utils/logging.py:26, in debug(msg, *args, **kwargs)
     25 def debug(msg, *args, **kwargs):
---> 26     _log(msg, *args, callback_fn=logger.debug, **kwargs)
File ~/Programming/IWR/bf2/bayesflow/utils/logging.py:18, in _log(msg, callback_fn, *args, **kwargs)
     16     jax.debug.callback(__log, *args, **kwargs)
     17 else:
---> 18     callback_fn(msg.format(*args, **kwargs))
TypeError: Exception encountered when calling Cond.call().

unsupported format string passed to SymbolicTensor.__format__

Arguments received by Cond.call():
  • args=('tf.Tensor(shape=(), dtype=bool)', '<function log_sinkhorn_plan.<locals>.log_steps at 0x7faf94299b20>', '<function log_sinkhorn_plan.<locals>.warn_convergence at 0x7faf94299940>')
  • kwargs=<class 'inspect._empty'>
```
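As far as I can tell, the root cause is that `log_sinkhorn_plan` wraps its convergence logging in `keras.ops.cond`, so on the TensorFlow backend the `log_steps` branch is traced as part of the graph, where `steps` is a `SymbolicTensor`. The message template `"Log-Sinkhorn-Knopp converged after {:d} steps."` then reaches `msg.format(*args)` in `bayesflow/utils/logging.py`, and `SymbolicTensor.__format__` rejects the `:d` format spec. Note that `_log` already routes through `jax.debug.callback` on the JAX backend; the TensorFlow path formats eagerly, which is where this breaks.

A minimal sketch of the failure mode outside of bayesflow (assumes the TensorFlow backend; `steps` here stands in for the tensor that `log_sinkhorn_plan` passes to `logging.debug`):

```python
import tensorflow as tf

@tf.function  # train_step runs traced, so tensors inside are symbolic
def traced():
    steps = tf.constant(10)
    # Same pattern as log_sinkhorn_plan.<locals>.log_steps ->
    # bayesflow.utils.logging._log: msg.format(*args)
    print("Log-Sinkhorn-Knopp converged after {:d} steps.".format(steps))
    return steps

# Raises during tracing:
# TypeError: unsupported format string passed to SymbolicTensor.__format__
traced()
```

One possible direction for a fix (just a sketch, not a tested patch): replace the `{:d}` spec with a plain `{}` placeholder, since `format(tensor, "")` falls back to `str()` and does not raise for symbolic tensors, or guard the formatted debug message so it is only emitted when its arguments are concrete eager values.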