Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix with float type casting in pytorch backend #58

Merged
merged 2 commits into from
Nov 27, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions fdtd/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,14 @@ def __init__(

if bd.is_array(permittivity) and len(permittivity.shape) == 3:
permittivity = permittivity[:, :, :, None]
self.inverse_permittivity = bd.ones((self.Nx, self.Ny, self.Nz, 3)) / bd.float(
permittivity
self.inverse_permittivity = bd.ones((self.Nx, self.Ny, self.Nz, 3)) / bd.array(
permittivity, dtype=bd.float
)

if bd.is_array(permeability) and len(permeability.shape) == 3:
permeability = permeability[:, :, :, None]
self.inverse_permeability = bd.ones((self.Nx, self.Ny, self.Nz, 3)) / bd.float(
permeability
self.inverse_permeability = bd.ones((self.Nx, self.Ny, self.Nz, 3)) / bd.array(
permeability, dtype=bd.float
)

# save current time index
Expand All @@ -169,21 +169,21 @@ def __init__(
self.folder = None

def _handle_distance(self, distance: Number) -> int:
""" transform a distance to an integer number of gridpoints """
"""transform a distance to an integer number of gridpoints"""
if not isinstance(distance, int):
return int(float(distance) / self.grid_spacing + 0.5)
return distance

def _handle_time(self, time: Number) -> int:
""" transform a time value to an integer number of timesteps """
"""transform a time value to an integer number of timesteps"""
if not isinstance(time, int):
return int(float(time) / self.time_step + 0.5)
return time

def _handle_tuple(
self, shape: Tuple[Number, Number, Number]
) -> Tuple[int, int, int]:
""" validate the grid shape and transform to a length-3 tuple of ints """
"""validate the grid shape and transform to a length-3 tuple of ints"""
if len(shape) != 3:
raise ValueError(
f"invalid grid shape {shape}\n"
Expand All @@ -196,7 +196,7 @@ def _handle_tuple(
return x, y, z

def _handle_slice(self, s: slice) -> slice:
""" validate the slice and transform possibly float values to ints """
"""validate the slice and transform possibly float values to ints"""
start = (
s.start
if not isinstance(s.start, float)
Expand All @@ -211,7 +211,7 @@ def _handle_slice(self, s: slice) -> slice:
return slice(start, stop, step)

def _handle_single_key(self, key):
""" transform a single index key to a slice or list """
"""transform a single index key to a slice or list"""
try:
len(key)
return [self._handle_distance(k) for k in key]
Expand All @@ -224,27 +224,27 @@ def _handle_single_key(self, key):

@property
def x(self) -> int:
""" get the number of grid cells in the x-direction """
"""get the number of grid cells in the x-direction"""
return self.Nx * self.grid_spacing

@property
def y(self) -> int:
""" get the number of grid cells in the y-direction """
"""get the number of grid cells in the y-direction"""
return self.Ny * self.grid_spacing

@property
def z(self) -> int:
""" get the number of grid cells in the y-direction """
"""get the number of grid cells in the y-direction"""
return self.Nz * self.grid_spacing

@property
def shape(self) -> Tuple[int, int, int]:
""" get the shape of the FDTD grid """
"""get the shape of the FDTD grid"""
return (self.Nx, self.Ny, self.Nz)

@property
def time_passed(self) -> float:
""" get the total time passed """
"""get the total time passed"""
return self.time_steps_passed * self.time_step

def run(self, total_time: Number, progress_bar: bool = True):
Expand Down Expand Up @@ -273,7 +273,7 @@ def step(self):
self.time_steps_passed += 1

def update_E(self):
""" update the electric field by using the curl of the magnetic field """
"""update the electric field by using the curl of the magnetic field"""

# update boundaries: step 1
for boundary in self.boundaries:
Expand All @@ -299,7 +299,7 @@ def update_E(self):
det.detect_E()

def update_H(self):
""" update the magnetic field by using the curl of the electric field """
"""update the magnetic field by using the curl of the electric field"""

# update boundaries: step 1
for boundary in self.boundaries:
Expand All @@ -325,28 +325,28 @@ def update_H(self):
det.detect_H()

def reset(self):
""" reset the grid by setting all fields to zero """
"""reset the grid by setting all fields to zero"""
self.H *= 0.0
self.E *= 0.0
self.time_steps_passed *= 0

def add_source(self, name, source):
""" add a source to the grid """
"""add a source to the grid"""
source._register_grid(self)
self.sources[name] = source

def add_boundary(self, name, boundary):
""" add a boundary to the grid """
"""add a boundary to the grid"""
boundary._register_grid(self)
self.boundaries[name] = boundary

def add_detector(self, name, detector):
""" add a detector to the grid """
"""add a detector to the grid"""
detector._register_grid(self)
self.detectors[name] = detector

def add_object(self, name, obj):
""" add an object to the grid """
"""add an object to the grid"""
obj._register_grid(self)
self.objects[name] = obj

Expand Down