diff --git a/geomstats/_backend/pytorch/__init__.py b/geomstats/_backend/pytorch/__init__.py index 1b5aac2dba..5ff380fdf8 100644 --- a/geomstats/_backend/pytorch/__init__.py +++ b/geomstats/_backend/pytorch/__init__.py @@ -747,9 +747,9 @@ def _unnest_iterable(ls): return out -def pad(a, pad_width, constant_value=0.0): +def pad(a, pad_width, mode="constant", constant_values=0.0): return _torch.nn.functional.pad( - a, _unnest_iterable(reversed(pad_width)), value=constant_value + a, _unnest_iterable(reversed(pad_width)), mode=mode, value=constant_values )