diff --git a/fdtd/backend.py b/fdtd/backend.py index 6bcea1e..9abc0d0 100644 --- a/fdtd/backend.py +++ b/fdtd/backend.py @@ -322,21 +322,21 @@ def numpy(self, arr): class TorchCudaBackend(TorchBackend): """Torch Cuda Backend""" - def ones(self, shape): + def ones(self, shape, **kwargs): """create an array filled with ones""" - return torch.ones(shape, device="cuda") + return torch.ones(shape, device="cuda", **kwargs) - def zeros(self, shape): + def zeros(self, shape, **kwargs): """create an array filled with zeros""" - return torch.zeros(shape, device="cuda") + return torch.zeros(shape, device="cuda", **kwargs) - def array(self, arr, dtype=None): + def array(self, arr, dtype=None, **kwargs): """create an array from an array-like sequence""" if dtype is None: dtype = torch.get_default_dtype() if torch.is_tensor(arr): - return arr.clone().to(device="cuda", dtype=dtype) - return torch.tensor(arr, device="cuda", dtype=dtype) + return arr.clone().to(device="cuda", dtype=dtype, **kwargs) + return torch.tensor(arr, device="cuda", dtype=dtype, **kwargs) # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # The same warning applies here.