From a14fe6891c942b0990f5a7c5f85306c9f1d67de5 Mon Sep 17 00:00:00 2001 From: flaport Date: Fri, 17 Nov 2023 10:03:08 -0800 Subject: [PATCH] bugfix #66: apply proper check for complex --- .pre-commit-config.yaml | 23 ++++++++++++++--------- fdtd/backend.py | 15 ++++++++++++++- fdtd/objects.py | 18 ++++++++---------- 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d20e871..b7d71af 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,15 @@ repos: - - repo: https://github.com/psf/black - rev: 20.8b1 - hooks: - - id: black - language_version: python3 - - repo: https://github.com/kynan/nbstripout - rev: 0.3.9 - hooks: - - id: nbstripout +- repo: https://github.com/kynan/nbstripout + rev: 0.6.0 + hooks: + - id: nbstripout +- repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.9.1 + hooks: + - id: black + language_version: python3.11 +- repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.9.1 + hooks: + - id: black-jupyter + language_version: python3.11 diff --git a/fdtd/backend.py b/fdtd/backend.py index 483260e..497cda8 100644 --- a/fdtd/backend.py +++ b/fdtd/backend.py @@ -75,6 +75,17 @@ class Backend: def __repr__(self): return self.__class__.__name__ + def is_complex(x): + """check if an object is a `ComplexFloat`""" + return ( + isinstance(x, complex) + or (isinstance(x, np.ndarray) and x.dtype in (np.complex64, np.complex128)) + or ( + isinstance(x, torch.Tensor) + and x.dtype in (torch.complex64, torch.complex128) + ) + ) + def _replace_float(func): """replace the default dtype a function is called with""" @@ -99,7 +110,7 @@ class NumpyBackend(Backend): float = numpy.float64 """ floating type for array """ - + complex = numpy.complex128 """ complex type for array """ @@ -305,6 +316,8 @@ def numpy(self, arr): else: return numpy.asarray(arr) + is_complex = staticmethod(torch.is_complex) + # Torch Cuda Backend if TORCH_CUDA_AVAILABLE: diff --git a/fdtd/objects.py b/fdtd/objects.py index 7e20cdb..fc36422 100644 --- a/fdtd/objects.py +++ b/fdtd/objects.py @@ -22,7 +22,7 @@ ## Object class Object: - """ An object to place in the grid """ + """An object to place in the grid""" def __init__(self, permittivity: Tensorlike, name: str = None): """ @@ -48,7 +48,7 @@ def _register_grid( self.grid = grid self.grid.objects.append(self) - if self.permittivity.dtype is bd.complex().dtype: + if bd.is_complex(self.permittivity): self.grid.promote_dtypes_to_complex() if self.name is not None: @@ -70,7 +70,8 @@ def _register_grid( if bd.is_array(self.permittivity) and len(self.permittivity.shape) == 3: self.permittivity = self.permittivity[:, :, :, None] self.inverse_permittivity = ( - bd.ones((self.Nx, self.Ny, self.Nz, 3),dtype=self.permittivity.dtype) / self.permittivity + bd.ones((self.Nx, self.Ny, self.Nz, 3), dtype=self.permittivity.dtype) + / self.permittivity ) # set the permittivity values of the object at its border to be equal @@ -123,9 +124,9 @@ def update_E(self, curl_H): """ loc = (self.x, self.y, self.z) - self.grid.E[loc] = self.grid.E[loc] +( + self.grid.E[loc] = self.grid.E[loc] + ( self.grid.courant_number * self.inverse_permittivity * curl_H[loc] - ) + ) def update_H(self, curl_E): """custom update equations for inside the object @@ -134,9 +135,6 @@ def update_H(self, curl_E): curl_E: the curl of electric field in the grid. """ - # def promote_dtypes_to_complex(self): - # self.E = self.E.astype(bd.complex) - # self.H = self.H.astype(bd.complex) def __repr__(self): return f"{self.__class__.__name__}(name={repr(self.name)})" @@ -163,7 +161,7 @@ def _handle_slice(s): class AbsorbingObject(Object): - """ An absorbing object takes conductivity into account """ + """An absorbing object takes conductivity into account""" def __init__( self, permittivity: Tensorlike, conductivity: Tensorlike, name: str = None @@ -232,7 +230,7 @@ def update_H(self, curl_E): class AnisotropicObject(Object): - """ An object with anisotropic permittivity tensor """ + """An object with anisotropic permittivity tensor""" def _register_grid( self, grid: Grid, x: slice = None, y: slice = None, z: slice = None