Skip to content

Commit

Permalink
Improved efficiency of overlap_sort
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed Feb 15, 2024
1 parent 3f5be54 commit a6a90c1
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,8 @@ def overlap_sort(
for freq_id in range(f0_ind + step, last_ind, step):
# Get next frequency to sort
data_to_sort = self._isel(f=[freq_id])
# Assign to the base frequency so that outer_dot will compare them
data_to_sort = data_to_sort._assign_coords(f=[self.monitor.freqs[f0_ind]])

# Compute "sorting w.r.t. to neighbor" and overlap values

Expand Down Expand Up @@ -1191,6 +1193,18 @@ def _isel(self, **isel_kwargs):
update_dict = {key: field.isel(**isel_kwargs) for key, field in update_dict.items()}
return self._updated(update=update_dict)

def _assign_coords(self, **assign_coords_kwargs):
"""Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
mode index. Used in ``overlap_sort`` but not officially supported since for example
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
newly created data."""

update_dict = dict(self._grid_correction_dict, **self.field_components)
update_dict = {
key: field.assign_coords(**assign_coords_kwargs) for key, field in update_dict.items()
}
return self._updated(update=update_dict)

def _find_ordering_one_freq(
self,
data_to_sort: ModeData,
Expand All @@ -1212,21 +1226,15 @@ def _find_ordering_one_freq(
if num_modes_to_sort <= 1:
return pairs, complex_amps

# Compute an overlap matrix for modes chosen for sorting
amps_reduced = np.zeros((num_modes_to_sort, num_modes_to_sort), dtype=np.complex128)

# Extract all modes of interest from template data
data_template_reduced = self._isel(mode_index=modes_to_sort)

for i, mode_index in enumerate(modes_to_sort):
# Get one mode from data_to_sort

one_mode = data_to_sort._isel(mode_index=[mode_index])
amps_reduced = data_template_reduced.outer_dot(
data_to_sort._isel(mode_index=modes_to_sort)
).to_numpy()[0, :, :]

# Project to all modes of interest from data_template
amps_reduced[:, i] = data_template_reduced.dot(one_mode).data.ravel()
if self.monitor.store_fields_direction == "-":
amps_reduced[:, i] *= -1
if self.monitor.store_fields_direction == "-":
amps_reduced *= -1

# Find the most similar modes and corresponding overlap values
pairs_reduced, amps_reduced = self._find_closest_pairs(amps_reduced)
Expand Down

0 comments on commit a6a90c1

Please sign in to comment.