Skip to content

Commit

Permalink
Improved efficiency of overlap_sort
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyflex committed Feb 13, 2024
1 parent 86f02ab commit 9a1eee1
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,7 @@ def overlap_sort(
for freq_id in range(f0_ind + step, last_ind, step):
# Get next frequency to sort
data_to_sort = self._isel(f=[freq_id])
data_to_sort = data_to_sort._assign_coords(f=[self.monitor.freqs[f0_ind]])

# Compute "sorting w.r.t. to neighbor" and overlap values

Expand Down Expand Up @@ -1126,6 +1127,18 @@ def _isel(self, **isel_kwargs):
update_dict = {key: field.isel(**isel_kwargs) for key, field in update_dict.items()}
return self._updated(update=update_dict)

def _assign_coords(self, **isel_kwargs):
"""Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and
mode index. Used in ``overlap_sort`` but not officially supported since for example
``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the
newly created data."""

update_dict = dict(self._grid_correction_dict, **self.field_components)
update_dict = {
key: field.assign_coords(**isel_kwargs) for key, field in update_dict.items()
}
return self._updated(update=update_dict)

def _find_ordering_one_freq(
self,
data_to_sort: ModeData,
Expand All @@ -1147,19 +1160,20 @@ def _find_ordering_one_freq(
if num_modes_to_sort <= 1:
return pairs, complex_amps

# Compute an overlap matrix for modes chosen for sorting
amps_reduced = np.zeros((num_modes_to_sort, num_modes_to_sort), dtype=np.complex128)

# Extract all modes of interest from template data
data_template_reduced = self._isel(mode_index=modes_to_sort)

for i, mode_index in enumerate(modes_to_sort):
amps_reduced = data_template_reduced.outer_dot(
data_to_sort._isel(mode_index=modes_to_sort)
).to_numpy()[0, :, :]

for i, _mode_index in enumerate(modes_to_sort):
# Get one mode from data_to_sort

one_mode = data_to_sort._isel(mode_index=[mode_index])
# one_mode = data_to_sort._isel(mode_index=[mode_index])

# Project to all modes of interest from data_template
amps_reduced[:, i] = data_template_reduced.dot(one_mode).data.ravel()
# amps_reduced[:, i] = data_template_reduced.dot(one_mode).data.ravel()
if self.monitor.store_fields_direction == "-":
amps_reduced[:, i] *= -1

Expand Down

0 comments on commit 9a1eee1

Please sign in to comment.