Skip to content

Commit

Permalink
cast abs(JaxDataArray) to jnp.array
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex authored and momchil-flex committed Aug 18, 2023
1 parent 821d78b commit 9161dd2
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ the difference that can be observed when slightly modifying the grid resolution.
- Numerically stable sigmoid function in radius of curvature constraint.
- Fixed 2d checking in `Geometry.intersections_2dbox()`.
- Spatial monitor downsampling when the monitor is crossing a symmetry plane or Bloch boundary conditions.
- Cast `JaxDataArray.__abs__` output to `jnp.array`, reducing conversions needed in objective functions.

## [2.4.0rc1] - 2023-7-27

Expand Down
4 changes: 2 additions & 2 deletions tidy3d/plugins/adjoint/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def conj(self) -> JaxDataArray:

def __abs__(self) -> JaxDataArray:
"""Absolute value of self's values."""
new_values = jnp.abs(self.values)
new_values = jnp.abs(self.as_jnp_array)
return self.updated_copy(values=new_values)

def __pow__(self, power: int) -> JaxDataArray:
Expand Down Expand Up @@ -213,7 +213,7 @@ def sum(self, dim: str = None):
"""Sum (optionally along a single or multiple dimensions)."""

if dim is None:
return jnp.sum(self.values)
return jnp.sum(self.as_jnp_array)

# dim is supplied
if isinstance(dim, str):
Expand Down

0 comments on commit 9161dd2

Please sign in to comment.