
(Alternative) Improved efficiency of outer_dot #1464

Merged
2 commits merged into pre/2.6 on Feb 15, 2024

Conversation

caseyflex
Contributor

@caseyflex caseyflex commented Feb 13, 2024

This is an alternative version of the previous PR: #1463

It still improves the efficiency by reducing the use of expensive xarray functions (it may be even faster than the previous version). However, it is no longer vectorized over frequency and mode_index, which removes the memory overhead from the previous PR.

Collaborator

@tylerflex tylerflex left a comment


just a couple minor comments, main thing is that I think we should avoid directly indexing in DataArrays if we know the coordinate names.

Also, it's a bit unfortunate that we need to do this instead of the vectorized approach. Is there some package (eg Dask) that we can use to handle this more adaptively? https://examples.dask.org/xarray.html

tidy3d/components/data/monitor_data.py (outdated review thread, resolved)

# Cross products of fields
e_self_x_h_other = e_self_1 * h_other_2 - e_self_2 * h_other_1
h_self_x_e_other = h_self_1 * e_other_2 - h_self_2 * e_other_1

# Integrate over plane
d_area = self._diff_area
Collaborator

so in essence what this code is doing is something like this?

  1. array0 has dims (x, y, f, mode_index)
  2. array1 has dims (x, y, f, mode_index)
  3. outer product array0 and array1 to give array3 with dims (x, y, f, mode_index0, mode_index1)
  4. apply some function on array3 [technically a bit more complicated as it involves other array3-like objects]
  5. sum the result over dims (x, y)

?
And is the reason it's written out like this because step 3 can take too much memory?

so instead of step 3, we are just looping over the non-summed indices (f, mode_index0, mode_index1) , constructing array3 evaluated at a specific (f, mode_index0, mode_index1), and then summing this over (x,y)?
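The contrast between the two approaches can be sketched in plain numpy (hypothetical shapes, with a simple product-and-sum standing in for the real operation -- not the actual tidy3d code):

```python
import numpy as np

rng = np.random.default_rng(0)
nx, ny, nf, nm = 4, 5, 2, 3
array0 = rng.random((nx, ny, nf, nm))
array1 = rng.random((nx, ny, nf, nm))

# Vectorized: build the full outer product at once, then sum over
# (x, y). Peak memory scales like nx * ny * nf * nm * nm.
outer = array0[:, :, :, :, None] * array1[:, :, :, None, :]
full = outer.sum(axis=(0, 1))  # dims (f, mode_index0, mode_index1)

# Loop-based: construct one (f, mode_index0, mode_index1) entry at a
# time; the intermediate is only ever nx * ny.
looped = np.empty((nf, nm, nm))
for f in range(nf):
    for i in range(nm):
        for j in range(nm):
            looped[f, i, j] = np.sum(array0[:, :, f, i] * array1[:, :, f, j])
```

Both paths produce the same (f, mode_index0, mode_index1) result; the trade-off is peak memory versus loop overhead.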

Contributor Author

yeah, that's the difference between this and the previous PR. This avoids constructing the full outer product at once and only constructs one entry at a time.

Contributor Author

note that the approach you describe here is also what is currently implemented e.g. in pre/2.6. The difference in this PR is removing xarray overhead by converting to numpy outside the loops. I found this to be pretty important in my profiling.

@caseyflex
Contributor Author

just a couple minor comments, main thing is that I think we should avoid directly indexing in DataArrays if we know the coordinate names.

Also, it's a bit unfortunate that we need to do this instead of the vectorized approach. Is there some package (eg Dask) that we can use to handle this more adaptively? https://examples.dask.org/xarray.html

When I profile the original code, most of the time is spent in xarray sel and in xarray align. The latter is performed when doing arithmetic on xarrays, and is eliminated by converting to numpy before doing arithmetic. The former is addressed here by converting the xarray to numpy outside of the loops, rather than running xarray sel inside the loops. I still found performance issues using any variant of sel or isel inside the loops. Actually, I think this is a known shortcoming of xarray, and they recommend a solution like this one -- using numpy inside performance-critical loops and only converting to and from xarray at the beginning and the end. https://docs.xarray.dev/en/stable/user-guide/computation.html#wrapping-custom-computation

I don't think we need dask or anything fancy -- I don't see an issue with numpy? The reason for not vectorizing is memory based, and this approach solves that. Actually, I don't think vectorizing is needed since the number of modes isn't that large -- as long as there isn't too much overhead inside the loop.
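The pattern being described can be sketched generically (hypothetical data and shapes, not tidy3d's actual arrays):

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
nf, nm = 2, 3
arr = xr.DataArray(rng.random((4, 5, nf, nm)), dims=("x", "y", "f", "mode_index"))

# Slow pattern: sel/isel inside the loop pays xarray's indexing and
# alignment overhead on every single iteration.
slow = np.empty((nf, nm))
for fi in range(nf):
    for mi in range(nm):
        slow[fi, mi] = arr.isel(f=fi, mode_index=mi).sum().item()

# Fast pattern: convert to numpy once, outside the loops, index plain
# arrays inside, and wrap back into xarray only at the end.
raw = arr.values
fast = np.empty((nf, nm))
for fi in range(nf):
    for mi in range(nm):
        fast[fi, mi] = raw[:, :, fi, mi].sum()

result = xr.DataArray(fast, dims=("f", "mode_index"))
```

Both loops compute the same values; only the per-iteration overhead differs.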

@caseyflex
Contributor Author


I did a bit more benchmarking to compare with the vectorized version. I used about 1000 grid points, 100 modes, and 10 frequencies, and I ran outer_dot between two different sets of mode solver data. Under pre/2.6, it never finished (I cut it off after several minutes). Under casey/outerdot (the vectorized version), it took 8 seconds. Under this non-vectorized version, it took 1.5 seconds. So it looks like this version is better than the vectorized one in both memory and time. Reshaping the data to vectorize the operation is more costly than just looping in this instance.

@tylerflex
Collaborator

I see, thanks. It's unfortunate but I also notice some similar overhead using xarray in the adjoint plugin. Often the sel and interp there are the limiting factor. However it's only a problem when they are called too many times. Calling with a large number of points is actually not too bad. Thanks for explaining.

@tylerflex
Collaborator

tylerflex commented Feb 14, 2024

Actually, I think this is a known shortcoming of xarray, and they recommend a solution like this one -- using numpy inside performance-crticial loops and only converting to and from xarray at the beginning and the end. https://docs.xarray.dev/en/stable/user-guide/computation.html#wrapping-custom-computation

Yeah, this ufunc approach (and see dask="parallelized") is what I was getting at, but I'm not sure if it totally applies here or if it's worth the effort.
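For reference, the apply_ufunc approach from the linked xarray docs would look roughly like this for the outer-product-and-sum operation discussed above (a generic sketch with assumed dimension names, not the PR's code):

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
dims = ("x", "y", "f", "mode_index")
a = xr.DataArray(rng.random((4, 5, 2, 3)), dims=dims)
b = xr.DataArray(rng.random((4, 5, 2, 3)), dims=dims)

def outer_sum(u, v):
    # u: (..., x, y, mode0), v: (..., x, y, mode1) as plain numpy arrays;
    # returns (..., mode0, mode1), summed over x and y
    return np.einsum("...xyi,...xyj->...ij", u, v)

result = xr.apply_ufunc(
    outer_sum,
    a.rename({"mode_index": "mode_index_0"}),
    b.rename({"mode_index": "mode_index_1"}),
    input_core_dims=[["x", "y", "mode_index_0"], ["x", "y", "mode_index_1"]],
    output_core_dims=[["mode_index_0", "mode_index_1"]],
    # dask="parallelized" could be added here if the inputs were dask-backed
)
```

apply_ufunc moves the listed core dims to the end of each input, broadcasts the remaining dims (here f), and labels the output's new dims, so the numpy body stays fast while the caller still sees coordinate names.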

I don't think we need dask or anything fancy -- I don't see an issue with numpy? The reason for not vectorizing is memory based, and this approach solves that. Actually, I don't think vectorizing is needed since the number of modes isn't that large -- as long as there isn't too much overhead inside the loop.

My concern with moving to numpy is mainly code clarity. The main advantage of xarray is that it allows us to abstract away the raw indices of the data and use labelled coordinates instead, which keeps the code much easier to read and reason about. If we are going to use raw numpy arrays in some instances (like this one) for performance reasons, I would prefer that to be abstracted away somehow -- for example, in its own function that performs a mathematical operation over some data arrays but uses numpy under the hood. Kind of like how I described before, one could imagine a function that does something like:

  1. array0 has dims (x, y, f, mode_index)
  2. array1 has dims (x, y, f, mode_index)
  3. outer product array0 and array1 to give array3 with dims (x, y, f, mode_index0, mode_index1)
  4. apply some function on array3 [technically a bit more complicated as it involves other array3-like objects]
  5. sum the result over dims (x, y)

More generally it could look something like

from typing import Callable, List

from xarray import DataArray

def outer_product_integration(
    arr1: DataArray, arr2: DataArray, fn: Callable,
    shared_dims: List[str], sum_dims: List[str],
) -> DataArray:
    """Take outer product of two arrays sharing some dimensions, apply a function, sum over some dimensions."""

    # note: replace this with the for loop approach
    arr3 = outer_product(arr1, arr2, shared_dims)  # outer_product is part of the sketch, not an existing helper
    arr_result = fn(arr3)

    return arr_result.sum(sum_dims)

Perhaps if this kind of operation is used in other places (e.g. the mode solver), we can generalize it into its own helper function?
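A concrete (and hypothetical) version of the loop-based body that the "# note:" placeholder alludes to, specialized to (x, y, f, mode_index) data with the shared and summed dims fixed for brevity, might look like:

```python
from typing import Callable

import numpy as np
import xarray as xr

def outer_product_integration(
    arr1: xr.DataArray, arr2: xr.DataArray, fn: Callable
) -> xr.DataArray:
    """Outer product over mode_index, apply fn, sum over (x, y).

    Raw numpy inside the loops for speed; xarray only at the boundary.
    """
    a = arr1.transpose("x", "y", "f", "mode_index").values
    b = arr2.transpose("x", "y", "f", "mode_index").values
    nf, nm1, nm2 = a.shape[2], a.shape[3], b.shape[3]
    out = np.empty((nf, nm1, nm2), dtype=np.result_type(a, b))
    for i in range(nm1):
        for j in range(nm2):
            # one (mode_index_0, mode_index_1) entry at a time,
            # staying vectorized over (x, y, f)
            out[:, i, j] = fn(a[:, :, :, i] * b[:, :, :, j]).sum(axis=(0, 1))
    return xr.DataArray(out, dims=("f", "mode_index_0", "mode_index_1"))

# usage: with an identity fn this reduces to a plain outer-product integration
rng = np.random.default_rng(0)
dims = ("x", "y", "f", "mode_index")
a = xr.DataArray(rng.random((4, 5, 2, 3)), dims=dims)
b = xr.DataArray(rng.random((4, 5, 2, 3)), dims=dims)
res = outer_product_integration(a, b, lambda p: p)
```

The dimension names and the mode_index_0/mode_index_1 output labels are assumptions for the sketch; the real helper would presumably take shared_dims and sum_dims as arguments as in the signature above.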

@tylerflex
Collaborator

After looking at this, maybe because this logic is already encapsulated in outer_dot, it's enough separation.

@caseyflex
Contributor Author

After looking at this, maybe because this logic is already encapsulated in outer_dot, it's enough separation.

Yeah, it might be. Your comment makes sense, though -- I might end up doing that extra abstraction anyway, because I use some variant of outer_dot elsewhere. And I see, it could be parallelized using dask, but that's probably not necessary at the current speed.

@tylerflex
Collaborator

Cool, thanks. It's nice that we can use this outer_dot in other parts of the code, as in your other PR.

I think I have some mild trauma from trying to write my BEM solver using pure numpy. Sometimes I was dealing with like 6 indices and had a real hard time keeping track of things internally. So whenever possible, I think it's nice to abstract some of these details away and xarray is great for that (when it's fast enough I guess).

@weiliangjin2021
Collaborator

Actually, I don't think vectorizing is needed since the number of modes isn't that large -- as long as there isn't too much overhead inside the loop.

How many modes are considered a large number? Sometimes people solve for hundreds of modes.

@caseyflex
Contributor Author

Actually, I don't think vectorizing is needed since the number of modes isn't that large -- as long as there isn't too much overhead inside the loop.

How many modes are considered a large number? Sometimes people solve for hundreds of modes.

I tried with 100 modes. Anyway, the existing version isn't vectorized, and the vectorized version takes too much memory in this case. This should be a strict improvement over the existing version in all cases.

Collaborator

@momchil-flex momchil-flex left a comment


Looks good, thanks!

I know @tylerflex's pain about not using conveniences for dimensions, but sometimes it seems like there's no way around it, performance-wise. It happens more often on the backend.

@tylerflex
Collaborator

tylerflex commented Feb 14, 2024

I know @tylerflex's pain about not using conveniences for dimensions, but sometimes it seems like there's no way around it, performance-wise. It happens more often on the backend.

I think as long as it's abstracted away somehow and separated from the regular "logic" of the code, it's ok.

EDIT: I guess that's exactly what xarray does for us, but sometimes I guess we need to do this ourselves.

@caseyflex
Contributor Author

caseyflex commented Feb 15, 2024

I know @tylerflex's pain about not using conveniences for dimensions, but sometimes it seems like there's no way around it, performance-wise. It happens more often on the backend.

I think as long as it's abstracted away somehow and separated from the regular "logic" of the code, it's ok.

EDIT: I guess that's exactly what xarray does for us, but sometimes I guess we need to do this ourselves.

In the new commit, I abstracted away the raw numpy manipulations into a separate function. I was also able to vectorize over frequency without adding memory overhead. I benchmarked and memory-profiled it; performance and memory usage are similar to before.

I see some automated checks are failing because it can't import numpy. I'm not sure if that is related to this PR.

@daquinteroflex
Collaborator

I've just fixed it @caseyflex btw! Apologies, I think you just need to rebase again

@caseyflex
Contributor Author

I've just fixed it @caseyflex btw! Apologies, I think you just need to rebase again

thanks!

@momchil-flex
Collaborator

@tylerflex can we merge?

@tylerflex
Collaborator

sure

@momchil-flex momchil-flex merged commit 3f5be54 into pre/2.6 Feb 15, 2024
16 checks passed
@momchil-flex momchil-flex deleted the casey/alternativeouterdot branch February 15, 2024 19:07