Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): multi-dimensional integrand #160

Merged
merged 42 commits into from
Jan 24, 2023
Merged

Conversation

ilan-gold
Copy link
Collaborator

@ilan-gold ilan-gold commented Jan 3, 2023

Description

There might be a better word to describe this, but I figured this was descriptive to distinguish from integration dimensions.

Summary of changes

  • Allow for an integrand which is multi-dimensional for all integration methods (except VEGAS for now) i.e a matrix/vector of values

Resolved Issues

How Has This Been Tested?

  • Tests have been added to ensure that the feature works over simple 2x1 ("in both senses" i.e (2,1) and (2,)), 2x2, and 3x3 dimensional integrands

Marking this as draft while we resolve some questions, namely how badly we want VEGAS (I would need to read the paper, which I am fine doing if y'all do not want to accept the PR without a full suite of features)

torchquad/integration/newton_cotes.py Outdated Show resolved Hide resolved
torchquad/integration/newton_cotes.py Outdated Show resolved Hide resolved
torchquad/tests/integration_test_functions.py Outdated Show resolved Hide resolved
torchquad/tests/helper_functions.py Show resolved Hide resolved
@@ -47,8 +48,16 @@ def calculate_result(self, function_values, dim, n_per_dim, hs):
backend tensor: Quadrature result
"""
# Reshape the output to be [N,N,...] points instead of [dim*N] points
function_values = function_values.reshape([n_per_dim] * dim)

self.integrand_shape = function_values.shape[1:]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are arguments for the constructor of the class for both integration domain and shape, so we could follow that and add integrand_shape as a constructor argument although that seems like it could be inferred as here (I understand the convenience of having it for integration dimension since you can just write the integration domain once and then the number of times/dimensions you use that domain)

@gomezzz
Copy link
Collaborator

gomezzz commented Jan 3, 2023

@ilan-gold Should I have a look now? :)

Ideally, also merge current main into this as I just finished fixing the CI to run the tests a lot quicker and properly on all supported frameworks (torch, jax, TF, np) ✌️

@ilan-gold
Copy link
Collaborator Author

@gomezzz You can definitely have a look. I asked the questions I asked just because they were things I could see coming up in a review; between that and the lack of VEGAS, I left this as draft so it's up to you how to proceed! thanks!

@ilan-gold
Copy link
Collaborator Author

@gomezzz have a look at most recent commit for handling 1d arrays same as nd arrays. going to look into the mechanism that causes that monte-carlo test to fail - for the other ones i'm not sure what we should do since it basically seems lke a rounding error (the monte-carlo one seems obviously not that).

@ilan-gold
Copy link
Collaborator Author

ok @gomezzz down to just the very strange monte-carlo test now. you seemed to think you might have an idea of what was up? if so, could you post your thought? otherwise i'll finish it tomorrow! thanks!

@ilan-gold
Copy link
Collaborator Author

ilan-gold commented Jan 12, 2023

@gomezzz I would say I'm at a near total loss of what is happening here and why. Maybe I am missing something - the issue appears to be the anp.sum call in the monte-carlo integral. It's returning a decimal in the last entry for this test when it should return a whole number, 336. I have tried the following which slice the integrand first and they yield the correct result (336):

anp.sum(function_values[:, 1, 1, 1], axis=0) * volume / N

and

anp.sum(function_values[:, 1, 1, 1]) * volume / N

Even weirder is that when I do the following:

anp.sum(function_values) * volume / N

i.e summing the whole integrand, even though function_values is shaped Nx2x2x2, it gives a whole number! But

volume * anp.sum(function_values, axis=0) / N

(which is what I'd expect to work normally) gives ~332.78 in the last entry i.e integral[1, 1, 1]! I will come back to this later - it must be something insanely trivial.

@gomezzz
Copy link
Collaborator

gomezzz commented Jan 12, 2023

@ilan-gold Sounds like some weird implicit conversion or broadcasts happening? 🤔 It's a bit of a busy week for me, I'll let you experiment a bit more. Should I look at the rest in the mean time?

Btw for autoblack test it would be nice if you could just run black once on the repo 🙏 :) (can also do in the end if you prefer)

@ilan-gold
Copy link
Collaborator Author

@gomezzz Yes I think the rest of the PR is "done" at this point. I've addressed the internal consistency issue with a decorator for expanding/squeezing so we should be good: 4724216

@ilan-gold
Copy link
Collaborator Author

This would be a fairly large change over what I have here so I'd understand wanting to merge this and the Guassian branch before considering that.

@ilan-gold
Copy link
Collaborator Author

ilan-gold commented Jan 17, 2023

unrelated @gomezzz have you ever seen anything like this in tests? I would google this but there is no information at all on what happened:

Fatal Python error: Aborted

Thread 0x00007f3a50d06740 (most recent call first):
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/dispatch.py", line 1014 in backend_compile
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/profiler.py", line 314 in wrapper
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/dispatch.py", line 1079 in compile_or_get_cached
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 3439 in from_hlo
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 3170 in _compile_unloaded
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 3202 in compile
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359 in _xla_callable_uncached
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/dispatch.py", line 202 in xla_primitive_callable
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/util.py", line 247 in cached
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/util.py", line 254 in wrapper
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/dispatch.py", line 118 in apply_primitive
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/core.py", line 712 in process_primitive
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/core.py", line 332 in bind_with_trace
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/core.py", line 329 in bind
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 284 in gather
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3855 in _gather
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3828 in _rewriting_take
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/jax/_src/array.py", line 261 in __getitem__
  File "/home/ig62/torchquad/torchquad/tests/../integration/utils.py", line 203 in _check_integration_domain
  File "/home/ig62/torchquad/torchquad/tests/../integration/base_integrator.py", line 94 in _check_inputs
  File "/home/ig62/torchquad/torchquad/tests/../integration/newton_cotes.py", line 27 in integrate
  File "/home/ig62/torchquad/torchquad/tests/../integration/boole.py", line 28 in integrate
  File "/home/ig62/torchquad/torchquad/tests/integration_test_functions.py", line 83 in evaluate
  File "/home/ig62/torchquad/torchquad/tests/helper_functions.py", line 243 in compute_integration_test_errors
  File "/home/ig62/torchquad/torchquad/tests/boole_test.py", line 22 in _run_boole_tests
  File "/home/ig62/torchquad/torchquad/tests/helper_functions.py", line 275 in func
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/python.py", line 195 in pytest_pyfunc_call
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/python.py", line 1789 in runtest
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/runner.py", line 167 in pytest_runtest_call
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/runner.py", line 260 in <lambda>
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/runner.py", line 339 in from_call
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/runner.py", line 259 in call_runtest_hook
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/runner.py", line 220 in call_and_report
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/runner.py", line 131 in runtestprotocol
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/main.py", line 324 in _main
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/config/__init__.py", line 167 in main
  File "/home/ig62/software/miniconda3/envs/torchquad/lib/python3.10/site-packages/_pytest/config/__init__.py", line 190 in console_main
  File "/home/ig62/software/miniconda3/envs/torchquad/bin/pytest", line 10 in <module>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, jaxlib.cpu_feature_guard, numpy.linalg.lapack_lite (total: 22)
Aborted

@gomezzz
Copy link
Collaborator

gomezzz commented Jan 17, 2023

@gomezzz Would you be interested in allowing for different domains for each integrand in the multi-dimensional integrand case? I don't see a reason why this wouldn't work. And it seems doable with the current API since we specify integration domains and number of dimensions in two different places.

Sure, sounds like a sensible addition. Maybe we make it a dedicated PR though? Then we can review this, merge it and build on that? Just to keep track which changes relate to which desired feature

@gomezzz
Copy link
Collaborator

gomezzz commented Jan 17, 2023

unrelated @gomezzz have you ever seen anything like this in tests? I would google this but there is no information at all on what happened:

Hm, not precisely this but a few ideas what could be happening. My first suspicion would be that conda has a numpy version that is incompatible with jax? I remember seeing something related to this before. Since we pip install jax, conda dependencies may not be accounted for. (So basically conda may have installed a numpy version based on the environment file and then pip install another version on jax* needs, however the conda env uses its version instead of the pip one?)

EDIT:typos

@ilan-gold
Copy link
Collaborator Author

@gomezzz Would you be interested in allowing for different domains for each integrand in the multi-dimensional integrand case? I don't see a reason why this wouldn't work. And it seems doable with the current API since we specify integration domains and number of dimensions in two different places.

Sure, sounds like a sensible addition. Maybe we make it a dedicated PR though? Then we can review this, merge it and build on that? Just to keep track which changes relate to which desired feature

Yes definitely. I will wait until after the Gaussian PR since it will require an overhaul of the IntegrationGrid, which that PR edits.

@ilan-gold
Copy link
Collaborator Author

unrelated @gomezzz have you ever seen anything like this in tests? I would google this but there is no information at all on what happened:

Hm, not precisely this but a few ideas what could be happening. My first suspicion would be that conda has a numpy version that is incompatible with jax? I remember seeing something related to this before. Since we pip install jax, conda dependencies may not be accounted for. (So basically conda may have installed a numpy version based on the environment file and then pip install another version on jax* needs, however the conda env uses its version instead of the pip one?)

EDIT:typos

Hmmm, ok. I'm not sure how to handle this. I only have access to GPU's I can get, and I did try to create the environment from scratch. I will try again tomorrow.

@gomezzz
Copy link
Collaborator

gomezzz commented Jan 17, 2023

unrelated @gomezzz have you ever seen anything like this in tests? I would google this but there is no information at all on what happened:

Hm, not precisely this but a few ideas what could be happening. My first suspicion would be that conda has a numpy version that is incompatible with jax? I remember seeing something related to this before. Since we pip install jax, conda dependencies may not be accounted for. (So basically conda may have installed a numpy version based on the environment file and then pip install another version on jax* needs, however the conda env uses its version instead of the pip one?)
EDIT:typos

Hmmm, ok. I'm not sure how to handle this. I only have access to GPU's I can get, and I did try to create the environment from scratch. I will try again tomorrow.

Sorry, I should have elaborated more. You can see the pip installed numpy version with pip freeze. If I am correct changing the numpy version in the environment_all_backends.yml to match that one should fix it. (btw. I recommend using mamba to set up the environments, otherwise it takes forever)

# jaxlib with CUDA support is not available for conda
- pip:
- --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ilan-gold Ah, this is likely the problem btw! these change break the CI. You need the same env as currently on main
https://github.com/esa/torchquad/blob/main/environment_all_backends.yml

Same for #141

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Sorry was a little busy yesterday finding small bugs: #141 (comment)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dealing with this now!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

absl-py                       1.4.0
aiohttp                       3.8.3
aiosignal                     1.3.1
alabaster                     0.7.13
appdirs                       1.4.4
astunparse                    1.6.3
async-timeout                 4.0.2
attrs                         22.2.0
autoray                       0.0.0
Babel                         2.11.0
blinker                       1.5
brotlipy                      0.7.0
cached-property               1.5.2
cachetools                    5.2.1
certifi                       2022.12.7
cffi                          1.15.1
charset-normalizer            2.1.1
click                         8.1.3
colorama                      0.4.6
contourpy                     1.0.7
cryptography                  39.0.0
cycler                        0.11.0
docutils                      0.19
exceptiongroup                1.1.0
flatbuffers                   23.1.4
fonttools                     4.38.0
frozenlist                    1.3.3
gast                          0.4.0
google-auth                   2.16.0
google-auth-oauthlib          0.4.6
google-pasta                  0.2.0
grpcio                        1.51.1
h5py                          3.7.0
idna                          3.4
imagesize                     1.4.1
importlib-metadata            6.0.0
iniconfig                     2.0.0
jax                           0.4.1
jaxlib                        0.4.1
Jinja2                        3.1.2
keras                         2.11.0
Keras-Preprocessing           1.1.2
kiwisolver                    1.4.4
loguru                        0.6.0
Markdown                      3.4.1
MarkupSafe                    2.1.1
matplotlib                    3.6.3
multidict                     6.0.4
munkres                       1.1.4
numpy                         1.24.1
oauthlib                      3.2.2
opt-einsum                    3.3.0
packaging                     23.0
Pillow                        9.4.0
pip                           22.3.1
pluggy                        1.0.0
ply                           3.11
pooch                         1.6.0
protobuf                      4.21.12
pyasn1                        0.4.8
pyasn1-modules                0.2.7
pycparser                     2.21
Pygments                      2.14.0
PyJWT                         2.6.0
pyOpenSSL                     23.0.0
pyparsing                     3.0.9
PyQt5                         5.15.7
PyQt5-sip                     12.11.0
PySocks                       1.7.1
pytest                        7.2.1
python-dateutil               2.8.2
pytz                          2022.7.1
pyu2f                         0.1.5
requests                      2.28.2
requests-oauthlib             1.3.1
rsa                           4.9
scipy                         1.10.0
setuptools                    66.0.0
sip                           6.7.5
six                           1.16.0
snowballstemmer               2.2.0
Sphinx                        6.1.3
sphinx-rtd-theme              0.5.2
sphinxcontrib-applehelp       1.0.2
sphinxcontrib-devhelp         1.0.2
sphinxcontrib-htmlhelp        2.0.0
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          1.0.3
sphinxcontrib-serializinghtml 1.1.5
tensorboard                   2.11.2
tensorboard-data-server       0.6.1
tensorboard-plugin-wit        1.8.1
tensorflow                    2.11.0
tensorflow-estimator          2.11.0
termcolor                     2.2.0
toml                          0.10.2
tomli                         2.0.1
torch                         1.13.1.post200
tornado                       6.2
tqdm                          4.64.1
typing_extensions             4.4.0
unicodedata2                  15.0.0
urllib3                       1.26.14
Werkzeug                      2.2.2
wheel                         0.38.4
wrapt                         1.14.1
yarl                          1.8.2
zipp                          3.11.0

This is my pip list. The only thing that appears to have been broken is one of the MC tests. I am using python3.10 which could be a problem but there is no claim about version in the repo. Can you reproduce what is happening on CI locally? It looks like an issue with the tensorflow import. In any case, we should then probably pin a version in the conda file maybe?

Copy link
Collaborator Author

@ilan-gold ilan-gold Jan 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And furthermore, that MC test did not break in python3.9 (at least on the CI). I think it's probably best to expand testing to 3.8-10 (or 11) and make sure everything works, but I also do not know enough about the project to make that call. Happy to take your order here.

Copy link
Collaborator Author

@ilan-gold ilan-gold Jan 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

conda also does indeed seem to find "conflicts" with 3.9 which maybe is why things were taking a while for you. none were found with 3.10, from what I could tell.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

absl-py                       1.4.0
aiohttp                       3.8.3
aiosignal                     1.3.1
alabaster                     0.7.13
appdirs                       1.4.4
astunparse                    1.6.3
async-timeout                 4.0.2
attrs                         22.2.0
autoray                       0.0.0
Babel                         2.11.0
blinker                       1.5
brotlipy                      0.7.0
cached-property               1.5.2
cachetools                    5.2.1
certifi                       2022.12.7
cffi                          1.15.1
charset-normalizer            2.1.1
click                         8.1.3
colorama                      0.4.6
contourpy                     1.0.7
cryptography                  39.0.0
cycler                        0.11.0
docutils                      0.19
exceptiongroup                1.1.0
flatbuffers                   23.1.4
fonttools                     4.38.0
frozenlist                    1.3.3
gast                          0.4.0
google-auth                   2.16.0
google-auth-oauthlib          0.4.6
google-pasta                  0.2.0
grpcio                        1.51.1
h5py                          3.7.0
idna                          3.4
imagesize                     1.4.1
importlib-metadata            6.0.0
iniconfig                     2.0.0
jax                           0.4.1
jaxlib                        0.4.1
Jinja2                        3.1.2
keras                         2.11.0
Keras-Preprocessing           1.1.2
kiwisolver                    1.4.4
loguru                        0.6.0
Markdown                      3.4.1
MarkupSafe                    2.1.1
matplotlib                    3.6.3
multidict                     6.0.4
munkres                       1.1.4
numpy                         1.24.1
oauthlib                      3.2.2
opt-einsum                    3.3.0
packaging                     23.0
Pillow                        9.4.0
pip                           22.3.1
pluggy                        1.0.0
ply                           3.11
pooch                         1.6.0
protobuf                      4.21.12
pyasn1                        0.4.8
pyasn1-modules                0.2.7
pycparser                     2.21
Pygments                      2.14.0
PyJWT                         2.6.0
pyOpenSSL                     23.0.0
pyparsing                     3.0.9
PyQt5                         5.15.7
PyQt5-sip                     12.11.0
PySocks                       1.7.1
pytest                        7.2.1
python-dateutil               2.8.2
pytz                          2022.7.1
pyu2f                         0.1.5
requests                      2.28.2
requests-oauthlib             1.3.1
rsa                           4.9
scipy                         1.10.0
setuptools                    66.0.0
sip                           6.7.5
six                           1.16.0
snowballstemmer               2.2.0
Sphinx                        6.1.3
sphinx-rtd-theme              0.5.2
sphinxcontrib-applehelp       1.0.2
sphinxcontrib-devhelp         1.0.2
sphinxcontrib-htmlhelp        2.0.0
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          1.0.3
sphinxcontrib-serializinghtml 1.1.5
tensorboard                   2.11.2
tensorboard-data-server       0.6.1
tensorboard-plugin-wit        1.8.1
tensorflow                    2.11.0
tensorflow-estimator          2.11.0
termcolor                     2.2.0
toml                          0.10.2
tomli                         2.0.1
torch                         1.13.1.post200
tornado                       6.2
tqdm                          4.64.1
typing_extensions             4.4.0
unicodedata2                  15.0.0
urllib3                       1.26.14
Werkzeug                      2.2.2
wheel                         0.38.4
wrapt                         1.14.1
yarl                          1.8.2
zipp                          3.11.0

This is my pip list.

For a full list you need to use conda list I think.

Can you reproduce what is happening on CI locally?

Yes, should be possible with https://github.com/nektos/act I think?

I might be wrong but my suspicion would be it is the same problem as here #158 (comment)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And furthermore, that MC test did not break in python3.9 (at least on the CI). I think it's probably best to expand testing to 3.8-10 (or 11) and make sure everything works, but I also do not know enough about the project to make that call. Happy to take your order here.

You mean setting up the CI to run with different python version? We can do that sure, it's basically just modifying this to run it with several python version. We can create a small separate PR to quickly merge into main and develop for that if you want?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

conda also does indeed seem to find "conflicts" with 3.9 which maybe is why things were taking a while for you. none were found with 3.10, from what I could tell.

Frankly I have given up a bit on conda. I have had the problem of the CI running forever in 4 different repos with completely different dependencies and python versions. I think conda is just too slow to handle a medium number of dependencies anymore. With mamba it is super quick though :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I think I need to install mamba and then proceed from there then.

@ilan-gold
Copy link
Collaborator Author

@gomezzz so sorry, i appear to have temporarily had a brain fart...too much integration in my life. In any case I think this is basically ready to go if you want to give it a proper review. If you approve, I can add some docs and merge/ask for re-review.

Copy link
Collaborator

@gomezzz gomezzz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor comments. Could you format with black, get flake to pass and add the docs? :) Then we'll be good to go, I think! Thank you 🙏

@@ -53,6 +53,7 @@ def integrate(
function_values, self._nr_of_fevals = self.evaluate_integrand(fn, sample_points)
return self.calculate_result(function_values, integration_domain)

@expand_func_values_and_squeeze_intergal
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@expand_func_values_and_squeeze_intergal
@expand_func_values_and_squeeze_integral

in a few other places too I think :)

# i.e we only have one dimension, or the second dimension (that of the integrand) is 1
is_1d = len(args[1].shape) == 1 or (len(args[1].shape) == 2 and args[1].shape[1] == 1)
if is_1d:
warnings.warn("DEPRECATION WARNING: In future versions of torchquad, an array-like object will always be returned.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So on the warning either is fine for me as long as it is consistent with the rest of the codebase. :)

On the return value, I think, a solution that keeps compatibility with existing code for now is best, which this does as far as I understand? If not, it's okay but we'll need to put it in the changelog of the next version (I'd plan to make a new release once your PRs are merged :) )

)
print(f"3D Boole Test passed. N: {N}, backend: {backend}, Errors: {errors}")
for err, test_function in zip(errors, funcs):
assert test_function.get_order() > 5 or err < 2e-13
assert test_function.get_order() > 5 or (err < 2e-13 if test_function.is_integrand_1d else err < 2e-11)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a comment on why we have the if statement here now?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! it's because the errors are additive (so a bunch of small errors make a slightly larger error)

@ilan-gold
Copy link
Collaborator Author

@gomezzz I edited the warning to make it clear that thing would change in the future. Other than that, should be good (fingers crossed). Thanks so much for sticking with this!

Copy link
Collaborator

@gomezzz gomezzz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor change to docs, ready to merge after! ✌️ Thanks! :)

Multidimensional/Vectorized Integrands
--------------------------------------

If you wish to evaluate many different integrands over the same domain, it may be faster to pass in a vectorized formulation if possible.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could link here to the scipy thing you used? So people that are familiar with it realize that is the same way.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. I'll add the link.

docs/source/tutorial.rst Outdated Show resolved Hide resolved
result = torch.stack([torch.Tensor([simp.integrate(lambda x: parametrized_integrand(x, a, b), dim=1, N=101, integration_domain=integration_domain) for a in a_params]) for b in b_params])


Now let's see how to do this a bit more simply, and in a way that provides signficant speedup as the size of the integrand's `grid` grows:
Copy link
Collaborator

@gomezzz gomezzz Jan 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually running this gives little speedup because the example is super small and simple (both compute virtually instantly). Maybe we can either evaluate more integrands in parallel or make the integrand more complex?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g. running like this

def parametrized_integrand(x, a, b):
    return torch.sqrt(torch.cos(torch.sin((a + b) * x)))

a_params = torch.arange(40)
b_params = torch.arange(10, 20)
integration_domain = torch.Tensor([[0, 1]])
simp = tq.Simpson()
result = torch.stack([torch.Tensor([simp.integrate(lambda x: parametrized_integrand(x, a, b), dim=1, N=101, integration_domain=integration_domain) for a in a_params]) for b in b_params])

and

grid = torch.stack([torch.Tensor([a + b for a in a_params]) for b in b_params])

def integrand(x):
    return torch.sqrt(torch.cos(torch.sin(torch.einsum("i,jk->ijk", x.flatten(), grid))))

result_vectorized = simp.integrate(integrand, dim=1, N=101, integration_domain=integration_domain)

torch.all(torch.isclose(result_vectorized, result)) # True!

it goes from 188ms to 5ms on my machine 🙃

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a suggestion here or are you happy it worked haha?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to increase the params and complexity of the function with the code I posted so the effect is more noticeable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, I see. Will do.

docs/source/tutorial.rst Outdated Show resolved Hide resolved
ilan-gold and others added 4 commits January 23, 2023 12:14
@gomezzz
Copy link
Collaborator

gomezzz commented Jan 24, 2023

@ilan-gold Should we merge then? :)

@ilan-gold
Copy link
Collaborator Author

i think so :) on to the next one! i'll need to do a bit of clean up on that branch so i'll ping you when it's ready

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Elementwise numerical integration
2 participants