-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Differentiable n-dimensional nearest
and linear
interpolation for DataArray
#1769
base: develop
Are you sure you want to change the base?
Conversation
0ced3b6
to
0162bcc
Compare
61f33e4
to
02fd006
Compare
e4fcb00
to
adcdf69
Compare
dfa68bb
to
511e3be
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @yaugenst-flex , let's discuss these last few things but overall looks great
…not defined for ArrayBox
…e old 0-size monitor dim behavior
d9ea683
to
c96525a
Compare
c96525a
to
ff75702
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good! 👍 feel like this is ready once the notebook issue is figured out. thanks Yannick!
return anp.sum(data[monitor.name].flux.values) | ||
elif objtype == "intensity": | ||
return anp.sum(data.get_intensity(monitor.name).values) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in view of the notebook failing, maybe we need to add a parameterize
for the size of the monitor (planar vs point?)
|
||
@pytest.mark.parametrize("dim", [1, 2, 3, 4]) | ||
@pytest.mark.parametrize("method", ["linear", "nearest"]) | ||
class TestInterpn: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do these test classes do? want to learn 😆
@@ -1014,8 +1014,8 @@ def insert_value(x, path: tuple[str, ...], sub_dict: dict): | |||
sub_element = current_dict[final_key] | |||
if isinstance(sub_element, DataArray): | |||
current_dict[final_key] = sub_element.copy(deep=False, data=x) | |||
if "AUTOGRAD" in sub_element.attrs: | |||
current_dict[final_key].attrs["AUTOGRAD"] = x | |||
if AUTOGRAD_KEY in sub_element.attrs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yikes, good catch
vals = interpn(points, obj.tracers, xi, method=method) | ||
|
||
da = DataArray(vals, dict(obj.coords) | coords) # tracers go into .attrs | ||
if isbox(self.values.flat[0]): # if tracing .values instead of .attrs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or da.contains_tracers
? although this seems to no longer exist?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe it's fine for now until we standardize all of this.
Implements differentiable interpolation methods for
DataArray
.plugins/autograd/functions.py
.td.DataArray
.DataArray
implementation.