Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Differentiable n-dimensional nearest and linear interpolation for DataArray #1769

Open
wants to merge 37 commits into
base: develop
Choose a base branch
from

Conversation

yaugenst-flex
Copy link
Contributor

@yaugenst-flex yaugenst-flex commented Jun 17, 2024

Implements differentiable interpolation methods for DataArray.

  • Implement interpolation function in plugins/autograd/functions.py.
  • Write tests & documentation for above.
  • Integrate into td.DataArray.
  • Test DataArray implementation.

@tylerflex tylerflex mentioned this pull request Jun 19, 2024
52 tasks
@yaugenst-flex yaugenst-flex changed the base branch from develop to pre/2.8 June 28, 2024 14:29
@yaugenst-flex yaugenst-flex force-pushed the yaugenst-flex/autograd-interp branch 2 times, most recently from e4fcb00 to adcdf69 Compare July 1, 2024 13:21
@yaugenst-flex yaugenst-flex self-assigned this Jul 2, 2024
@yaugenst-flex yaugenst-flex force-pushed the yaugenst-flex/autograd-interp branch 2 times, most recently from dfa68bb to 511e3be Compare July 3, 2024 13:44
Copy link
Collaborator

@tylerflex tylerflex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @yaugenst-flex , let's discuss these last few things but overall looks great

tests/test_components/test_autograd.py Show resolved Hide resolved
tests/utils.py Outdated Show resolved Hide resolved
tidy3d/components/autograd/derivative_utils.py Outdated Show resolved Hide resolved
tidy3d/components/autograd/functions.py Show resolved Hide resolved
tidy3d/components/autograd/functions.py Show resolved Hide resolved
tidy3d/components/data/data_array.py Outdated Show resolved Hide resolved
tidy3d/components/data/data_array.py Outdated Show resolved Hide resolved
tidy3d/components/data/data_array.py Outdated Show resolved Hide resolved
tidy3d/components/data/data_array.py Show resolved Hide resolved
tidy3d/components/data/monitor_data.py Outdated Show resolved Hide resolved
@tylerflex tylerflex added the 2.7 will go into version 2.7.* label Jul 3, 2024
tests/utils.py Outdated Show resolved Hide resolved
@yaugenst-flex yaugenst-flex force-pushed the yaugenst-flex/autograd-interp branch from d9ea683 to c96525a Compare July 4, 2024 20:13
@yaugenst-flex yaugenst-flex changed the base branch from pre/2.8 to develop July 4, 2024 20:13
@yaugenst-flex yaugenst-flex force-pushed the yaugenst-flex/autograd-interp branch from c96525a to ff75702 Compare July 4, 2024 20:16
@yaugenst-flex yaugenst-flex marked this pull request as ready for review July 4, 2024 20:18
Copy link
Collaborator

@tylerflex tylerflex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good! 👍 feel like this is ready once the notebook issue is figured out. thanks Yannick!

return anp.sum(data[monitor.name].flux.values)
elif objtype == "intensity":
return anp.sum(data.get_intensity(monitor.name).values)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in view of the notebook failing, maybe we need to add a parameterize for the size of the monitor (planar vs point?)


@pytest.mark.parametrize("dim", [1, 2, 3, 4])
@pytest.mark.parametrize("method", ["linear", "nearest"])
class TestInterpn:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do these test classes do? want to learn 😆

@@ -1014,8 +1014,8 @@ def insert_value(x, path: tuple[str, ...], sub_dict: dict):
sub_element = current_dict[final_key]
if isinstance(sub_element, DataArray):
current_dict[final_key] = sub_element.copy(deep=False, data=x)
if "AUTOGRAD" in sub_element.attrs:
current_dict[final_key].attrs["AUTOGRAD"] = x
if AUTOGRAD_KEY in sub_element.attrs:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yikes, good catch

vals = interpn(points, obj.tracers, xi, method=method)

da = DataArray(vals, dict(obj.coords) | coords) # tracers go into .attrs
if isbox(self.values.flat[0]): # if tracing .values instead of .attrs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or da.contains_tracers? although this seems to no longer exist?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it's fine for now until we standardize all of this.

@tylerflex tylerflex added the .1 label Jul 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
2.7 will go into version 2.7.* .1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants