Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Alternative) Improved efficiency of outer_dot #1464

Merged
merged 2 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions tests/test_data/test_monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,33 @@ def test_mode_solver_numerical_grid_data():
def test_outer_dot():
    """Check the coordinates and the frequency handling of ``outer_dot``.

    A ``mode_index_0`` / ``mode_index_1`` coordinate is expected in the result
    only when the corresponding operand (``self`` / ``other``) is mode data.
    The result is also expected to keep only the frequencies present in both
    operands.
    """
    mode_data = make_mode_solver_data()
    field_data = make_field_data_2d()

    # mode x mode: both mode index coordinates are present
    dot = mode_data.outer_dot(mode_data)
    assert "mode_index_0" in dot.coords
    assert "mode_index_1" in dot.coords

    # field x mode: only the second operand contributes a mode index
    dot = field_data.outer_dot(mode_data)
    assert "mode_index_0" not in dot.coords
    assert "mode_index_1" in dot.coords

    # mode x field: only the first operand contributes a mode index
    dot = mode_data.outer_dot(field_data)
    assert "mode_index_0" in dot.coords
    assert "mode_index_1" not in dot.coords

    # field x field: no mode index coordinates at all
    dot = field_data.outer_dot(field_data)
    assert "mode_index_0" not in dot.coords
    assert "mode_index_1" not in dot.coords

    # test that only common freqs are kept
    inds1 = [0, 1, 3]
    inds2 = [1, 2, 3, 4]

    def isel(data, freqs):
        """Return a copy of ``data`` with ``Ex`` restricted to the frequency indices ``freqs``."""
        data = data.updated_copy(
            Ex=data.Ex.isel(f=freqs),
        )
        # mode solver data also carries ``n_complex`` along ``f``; restrict it the same way
        if isinstance(data, td.ModeSolverData):
            data = data.updated_copy(n_complex=data.n_complex.isel(f=freqs))
        return data

    mode_data = isel(mode_data, inds1)
    field_data = isel(field_data, inds2)

    dot = mode_data.outer_dot(field_data)

    # index sets {0, 1, 3} and {1, 2, 3, 4} share exactly two entries (1 and 3)
    assert len(dot.f) == 2


@pytest.mark.parametrize("phase_shift", np.linspace(0, 2 * np.pi, 10))
Expand Down
152 changes: 115 additions & 37 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,47 +683,60 @@ def outer_dot(

# Common frequencies to both data arrays
f = np.array(sorted(set(coords[0]["f"].values).intersection(coords[1]["f"].values)))
isel1 = [list(coords[0]["f"].values).index(freq) for freq in f]
isel2 = [list(coords[1]["f"].values).index(freq) for freq in f]

# Mode indices, if available
modes_in_self = "mode_index" in coords[0]
mode_index_0 = coords[0]["mode_index"].values if modes_in_self else np.zeros(1, dtype=int)
coords[0]["mode_index"].values if modes_in_self else np.zeros(1, dtype=int)
modes_in_other = "mode_index" in coords[1]
mode_index_1 = coords[1]["mode_index"].values if modes_in_other else np.zeros(1, dtype=int)

dtype = np.promote_types(arrays[0].dtype, arrays[1].dtype)
dot = np.empty((f.size, mode_index_0.size, mode_index_1.size), dtype=dtype)

# Calculate overlap for each common frequency and each mode pair
for i, freq in enumerate(f):
indexer_self = {"f": freq}
indexer_other = {"f": freq}
for mi0 in mode_index_0:
if modes_in_self:
indexer_self["mode_index"] = mi0
e_self_1 = fields_self[e_1].sel(indexer_self, drop=True)
e_self_2 = fields_self[e_2].sel(indexer_self, drop=True)
h_self_1 = fields_self[h_1].sel(indexer_self, drop=True)
h_self_2 = fields_self[h_2].sel(indexer_self, drop=True)

for mi1 in mode_index_1:
if modes_in_other:
indexer_other["mode_index"] = mi1
e_other_1 = fields_other[e_1].sel(indexer_other, drop=True)
e_other_2 = fields_other[e_2].sel(indexer_other, drop=True)
h_other_1 = fields_other[h_1].sel(indexer_other, drop=True)
h_other_2 = fields_other[h_2].sel(indexer_other, drop=True)

# Cross products of fields
e_self_x_h_other = e_self_1 * h_other_2 - e_self_2 * h_other_1
h_self_x_e_other = h_self_1 * e_other_2 - h_self_2 * e_other_1

# Integrate over plane
d_area = self._diff_area
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so in essence what this code is doing is something like this?

  1. array0 has dims (x, y, f, mode_index)
  2. array1 has dims (x, y, f, mode_index)
  3. outer product array0 and array1 to give array3 with dims (x, y, f, mode_index0, mode_index1)
  4. apply some function on array3 [technically a bit more complicated as it involves other array3-like objects]
  5. sum the result over dims (x, y)

?
And is the reason it's written out like this because step 3 can take too much memory?

So instead of step 3, we are just looping over the non-summed indices (f, mode_index0, mode_index1), constructing array3 evaluated at a specific (f, mode_index0, mode_index1), and then summing this over (x, y)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that's the difference between this and the previous PR. This avoids constructing the full outer product at once and only constructs one entry at a time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that the approach you describe here is also what is currently implemented e.g. in pre/2.6. The difference in this PR is removing xarray overhead by converting to numpy outside the loops. I found this to be pretty important in my profiling.

integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
dot[i, mi0, mi1] = 0.25 * integrand.sum(dim=d_area.dims)

coords = {"f": f, "mode_index_0": mode_index_0, "mode_index_1": mode_index_1}
result = xr.DataArray(dot, coords=coords)
coords[1]["mode_index"].values if modes_in_other else np.zeros(1, dtype=int)

keys = (e_1, e_2, h_1, h_2)
for key in keys:
fields_self[key] = fields_self[key].isel(f=isel1)
if modes_in_self:
fields_self[key] = fields_self[key].rename(mode_index="mode_index_0")
else:
fields_self[key] = fields_self[key].expand_dims(
dim={"mode_index_0": [0]}, axis=len(fields_self[key].shape)
)
fields_other[key] = fields_other[key].isel(f=isel2)
if modes_in_other:
fields_other[key] = fields_other[key].rename(mode_index="mode_index_1")
else:
fields_other[key] = fields_other[key].expand_dims(
dim={"mode_index_1": [0]}, axis=len(fields_other[key].shape)
)

d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()

# function to apply at each pair of mode indices before integrating
def fn(fields_1, fields_2):
e_self_1 = fields_1[e_1]
e_self_2 = fields_1[e_2]
h_self_1 = fields_1[h_1]
h_self_2 = fields_1[h_2]
e_other_1 = fields_2[e_1]
e_other_2 = fields_2[e_2]
h_other_1 = fields_2[h_1]
h_other_2 = fields_2[h_2]

# Cross products of fields
e_self_x_h_other = e_self_1 * h_other_2 - e_self_2 * h_other_1
h_self_x_e_other = h_self_1 * e_other_2 - h_self_2 * e_other_1

summand = 0.25 * (e_self_x_h_other - h_self_x_e_other) * d_area
return summand

result = self._outer_fn_summation(
fields_1=fields_self,
fields_2=fields_other,
outer_dim_1="mode_index_0",
outer_dim_2="mode_index_1",
sum_dims=tan_dims,
fn=fn,
)

# Remove mode index coordinate if the input did not have it
if not modes_in_self:
Expand All @@ -733,6 +746,71 @@ def outer_dot(

return result

@staticmethod
def _outer_fn_summation(
    fields_1: Dict[str, xr.DataArray],
    fields_2: Dict[str, xr.DataArray],
    outer_dim_1: str,
    outer_dim_2: str,
    sum_dims: List[str],
    fn: Callable,
) -> xr.DataArray:
    """
    Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``,
    and sum over ``sum_dims``.

    The resulting ``xr.DataArray`` has any dimensions of the fields which are not contained in
    ``sum_dims``, followed by ``outer_dim_1`` and ``outer_dim_2`` as the last two dimensions.
    This can be more memory efficient than vectorizing over the outer dimensions, which can
    involve broadcasting and reshaping data. It also converts to numpy arrays outside the loops
    to minimize xarray overhead.

    Parameters
    ----------
    fields_1 : Dict[str, xr.DataArray]
        Data arrays that contain ``outer_dim_1``. The first entry is used as the template for
        axis numbers and coordinates, so all entries are assumed to share its structure.
    fields_2 : Dict[str, xr.DataArray]
        Data arrays that contain ``outer_dim_2``; all entries are assumed to share the
        structure of the first one.
    outer_dim_1 : str
        Dimension of ``fields_1`` to loop over.
    outer_dim_2 : str
        Dimension of ``fields_2`` to loop over.
    sum_dims : List[str]
        Dimensions (looked up in ``fields_1``) that are summed over in the output of ``fn``.
    fn : Callable
        Called as ``fn(fields_1_curr, fields_2_curr)`` with dictionaries of numpy arrays in
        which the outer dimension has been indexed away. Its output is assumed to keep the
        axis layout of the sliced ``fields_1`` arrays.

    Returns
    -------
    xr.DataArray
        Summed result for every pair of outer indices.
    """
    # first, convert to numpy outside the loop to reduce xarray overhead
    fields_1_numpy = {key: val.to_numpy() for key, val in fields_1.items()}
    fields_2_numpy = {key: val.to_numpy() for key, val in fields_2.items()}

    # get one of the data arrays to look at for indexing
    # assuming all data arrays have the same structure
    data_array_temp_1 = list(fields_1.values())[0]
    data_array_temp_2 = list(fields_2.values())[0]
    # reuse the arrays converted above instead of converting the templates a second time
    numpy_temp_1 = list(fields_1_numpy.values())[0]
    numpy_temp_2 = list(fields_2_numpy.values())[0]

    # find the numpy axes associated with the provided dimensions
    outer_axis_1 = data_array_temp_1.get_axis_num(outer_dim_1)
    outer_axis_2 = data_array_temp_2.get_axis_num(outer_dim_2)
    # ``fn`` only ever sees arrays in which the outer axis has been indexed away, so any
    # summed axis located after the outer axis moves down by one position
    sum_axes = tuple(
        axis - 1 if axis > outer_axis_1 else axis
        for axis in (data_array_temp_1.get_axis_num(dim) for dim in sum_dims)
    )

    # coords and array for result of calculation
    coords = {key: val.to_numpy() for key, val in data_array_temp_1.coords.items()}
    for dim in sum_dims:
        coords.pop(dim)
    # last two inds are the outer_dims: pop and re-insert so that they end up last
    coords.pop(outer_dim_1)
    coords[outer_dim_1] = data_array_temp_1.coords[outer_dim_1].to_numpy()
    coords[outer_dim_2] = data_array_temp_2.coords[outer_dim_2].to_numpy()
    # drop scalar non-indexing dimensions
    coords = {key: val for key, val in coords.items() if val.ndim != 0}
    shape = [len(val) for val in coords.values()]
    dtype = np.promote_types(numpy_temp_1.dtype, numpy_temp_2.dtype)
    data = np.zeros(shape, dtype=dtype)

    # indexing lists; only the entry of the outer axis is overwritten inside the loops
    idx_1 = [slice(None)] * numpy_temp_1.ndim
    idx_2 = [slice(None)] * numpy_temp_2.ndim

    # calculate the sums of products
    for outer_1 in range(numpy_temp_1.shape[outer_axis_1]):
        # the slice of ``fields_1`` does not depend on ``outer_2``: build it once per outer_1
        idx_1[outer_axis_1] = outer_1
        sel_1 = tuple(idx_1)
        fields_1_curr = {key: val[sel_1] for key, val in fields_1_numpy.items()}

        for outer_2 in range(numpy_temp_2.shape[outer_axis_2]):
            idx_2[outer_axis_2] = outer_2
            sel_2 = tuple(idx_2)
            fields_2_curr = {key: val[sel_2] for key, val in fields_2_numpy.items()}

            summand_curr = fn(fields_1_curr, fields_2_curr)
            # the outer dims are the last two axes of ``data``
            data[..., outer_1, outer_2] = np.sum(summand_curr, axis=sum_axes)

    # pass dims explicitly so the axis order does not rely on inference from the coords dict
    return xr.DataArray(data, coords=coords, dims=list(coords))

@property
def time_reversed_copy(self) -> FieldData:
"""Make a copy of the data with time-reversed fields."""
Expand Down
Loading