Skip to content

Commit

Permalink
Abstracted numpy manipulation from outer_dot
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed Feb 15, 2024
1 parent 02b1e71 commit e5297f7
Showing 1 changed file with 107 additions and 42 deletions.
149 changes: 107 additions & 42 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,55 +688,55 @@ def outer_dot(

# Mode indices, if available
modes_in_self = "mode_index" in coords[0]
mode_index_0 = coords[0]["mode_index"].values if modes_in_self else np.zeros(1, dtype=int)
coords[0]["mode_index"].values if modes_in_self else np.zeros(1, dtype=int)
modes_in_other = "mode_index" in coords[1]
mode_index_1 = coords[1]["mode_index"].values if modes_in_other else np.zeros(1, dtype=int)

dtype = np.promote_types(arrays[0].dtype, arrays[1].dtype)
dot = np.empty((f.size, mode_index_0.size, mode_index_1.size), dtype=dtype)
coords[1]["mode_index"].values if modes_in_other else np.zeros(1, dtype=int)

keys = (e_1, e_2, h_1, h_2)
sel_fields_self = {}
sel_fields_other = {}
for key in keys:
sel_fields_self[key] = fields_self[key].isel(f=isel1).to_numpy()
# if mode_index not present, insert it for unified handling
if not modes_in_self:
sel_fields_self[key] = np.expand_dims(
sel_fields_self[key], axis=len(sel_fields_self[key].shape)
fields_self[key] = fields_self[key].isel(f=isel1)
if modes_in_self:
fields_self[key] = fields_self[key].rename(mode_index="mode_index_0")
else:
fields_self[key] = fields_self[key].expand_dims(
dim={"mode_index_0": [0]}, axis=len(fields_self[key].shape)
)
sel_fields_other[key] = fields_other[key].isel(f=isel2).to_numpy()
if not modes_in_other:
sel_fields_other[key] = np.expand_dims(
sel_fields_other[key], axis=len(sel_fields_other[key].shape)
fields_other[key] = fields_other[key].isel(f=isel2)
if modes_in_other:
fields_other[key] = fields_other[key].rename(mode_index="mode_index_1")
else:
fields_other[key] = fields_other[key].expand_dims(
dim={"mode_index_1": [0]}, axis=len(fields_other[key].shape)
)

d_area = self._diff_area.to_numpy()

# Calculate overlap for each common frequency and each mode pair
for i in range(len(f)):
for j in range(len(mode_index_0)):
e_self_1 = sel_fields_self[e_1][:, :, i, j]
e_self_2 = sel_fields_self[e_2][:, :, i, j]
h_self_1 = sel_fields_self[h_1][:, :, i, j]
h_self_2 = sel_fields_self[h_2][:, :, i, j]

for k in range(len(mode_index_1)):
e_other_1 = sel_fields_other[e_1][:, :, i, k]
e_other_2 = sel_fields_other[e_2][:, :, i, k]
h_other_1 = sel_fields_other[h_1][:, :, i, k]
h_other_2 = sel_fields_other[h_2][:, :, i, k]

# Cross products of fields
e_self_x_h_other = e_self_1 * h_other_2 - e_self_2 * h_other_1
h_self_x_e_other = h_self_1 * e_other_2 - h_self_2 * e_other_1

# Integrate over plane
integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
dot[i, j, k] = 0.25 * integrand.sum()

coords = {"f": f, "mode_index_0": mode_index_0, "mode_index_1": mode_index_1}
result = xr.DataArray(dot, coords=coords)
d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()

# function to apply at each pair of mode indices before integrating
def fn(fields_1, fields_2):
e_self_1 = fields_1[e_1]
e_self_2 = fields_1[e_2]
h_self_1 = fields_1[h_1]
h_self_2 = fields_1[h_2]
e_other_1 = fields_2[e_1]
e_other_2 = fields_2[e_2]
h_other_1 = fields_2[h_1]
h_other_2 = fields_2[h_2]

# Cross products of fields
e_self_x_h_other = e_self_1 * h_other_2 - e_self_2 * h_other_1
h_self_x_e_other = h_self_1 * e_other_2 - h_self_2 * e_other_1

summand = 0.25 * (e_self_x_h_other - h_self_x_e_other) * d_area
return summand

result = self._outer_fn_summation(
fields_1=fields_self,
fields_2=fields_other,
outer_dim_1="mode_index_0",
outer_dim_2="mode_index_1",
sum_dims=tan_dims,
fn=fn,
)

# Remove mode index coordinate if the input did not have it
if not modes_in_self:
Expand All @@ -746,6 +746,71 @@ def outer_dot(

return result

@staticmethod
def _outer_fn_summation(
fields_1: Dict[str, xr.DataArray],
fields_2: Dict[str, xr.DataArray],
outer_dim_1: str,
outer_dim_2: str,
sum_dims: List[str],
fn: Callable,
) -> xr.DataArray:
"""
Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``.
The resulting ``xr.DataArray`` has has dimensions any dimensions in the fields which are not contained in sum_dims.
This can be more memory efficient than vectorizing over the ``outer_dims``, which can involve broadcasting and reshaping data.
It also converts to numpy arrays outside the loops to minimize xarray overhead.
"""
# first, convert to numpy outside the loop to reduce xarray overhead
fields_1_numpy = {key: val.to_numpy() for key, val in fields_1.items()}
fields_2_numpy = {key: val.to_numpy() for key, val in fields_2.items()}

# get one of the data arrays to look at for indexing
# assuming all data arrays have the same structure
data_array_temp_1 = list(fields_1.values())[0]
data_array_temp_2 = list(fields_2.values())[0]
numpy_temp_1 = data_array_temp_1.to_numpy()
numpy_temp_2 = data_array_temp_2.to_numpy()

# find the numpy axes associated with the provided dimensions
outer_axis_1 = data_array_temp_1.get_axis_num(outer_dim_1)
outer_axis_2 = data_array_temp_2.get_axis_num(outer_dim_2)
sum_axes = [data_array_temp_1.get_axis_num(dim) for dim in sum_dims]

# coords and array for result of calculation
coords = {key: val.to_numpy() for key, val in data_array_temp_1.coords.items()}
for dim in sum_dims:
coords.pop(dim)
# last two inds are the outer_dims
coords.pop(outer_dim_1)
coords[outer_dim_1] = data_array_temp_1.coords[outer_dim_1].to_numpy()
coords[outer_dim_2] = data_array_temp_2.coords[outer_dim_2].to_numpy()
# drop scalar non-indexing dimensions
coords = {key: val for key, val in coords.items() if len(val.shape) != 0}
shape = [len(val) for val in coords.values()]
dtype = np.promote_types(numpy_temp_1.dtype, numpy_temp_2.dtype)
data = np.zeros(shape, dtype=dtype)

# indexing tuples
idx_1 = [slice(None)] * numpy_temp_1.ndim
idx_2 = [slice(None)] * numpy_temp_2.ndim
idx_data = [slice(None)] * data.ndim

# calculate the sums of products
for outer_1 in range(numpy_temp_1.shape[outer_axis_1]):
for outer_2 in range(numpy_temp_2.shape[outer_axis_2]):
idx_1[outer_axis_1] = outer_1
idx_2[outer_axis_2] = outer_2
idx_data[-2] = outer_1
idx_data[-1] = outer_2
fields_1_curr = {key: val[tuple(idx_1)] for key, val in fields_1_numpy.items()}
fields_2_curr = {key: val[tuple(idx_2)] for key, val in fields_2_numpy.items()}
summand_curr = fn(fields_1_curr, fields_2_curr)
data_curr = np.sum(summand_curr, axis=tuple(sum_axes))
data[tuple(idx_data)] = data_curr

return xr.DataArray(data, coords=coords)

@property
def time_reversed_copy(self) -> FieldData:
"""Make a copy of the data with time-reversed fields."""
Expand Down

0 comments on commit e5297f7

Please sign in to comment.