Support other backends than PyTorch using autoray #137
Conversation
NumPy is now supported as a numerical backend in Trapezoid, Simpson, Boole and MonteCarlo. The numerical backend is determined by the type of the integration_domain argument.
JAX and TensorFlow now work with the Newton-Cotes rules and MonteCarlo. Since JAX has a special syntax for in-place changes, I replaced the zeros-array initialisation with stacking. To avoid breaking the torch gradients, I used stack instead of array for the h values. MonteCarlo no longer initialises a tensor full of zeros; I haven't tested the runtime and memory impact of this change yet. For TensorFlow I replaced the way ravel is called: TensorFlow's ravel only works if NumPy behaviour is enabled, otherwise it is missing in the latest TensorFlow version.
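As an illustration of the stacking approach (a sketch; dim, n and dom are placeholder names, not torchquad API):

```python
import jax.numpy as jnp

dim, n = 2, 5
dom = jnp.array([[0.0, 1.0], [-1.0, 1.0]])

# In-place writes would need JAX's special .at[] syntax and break other backends:
#   grid = jnp.zeros((dim, n)); grid = grid.at[d].set(row)  # JAX-only
# Stacking rows instead works the same way in NumPy, PyTorch, JAX and TensorFlow:
rows = [jnp.linspace(dom[d, 0], dom[d, 1], n) for d in range(dim)]
grid = jnp.stack(rows)  # shape (dim, n), built without in-place assignment
```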
The backend argument is used if integration_domain is, for example, a list.
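A usage sketch based on the integrate signature shown later in this conversation; the integrand f is a placeholder:

```python
import torchquad as tq

def f(x):
    # Placeholder integrand; x is a batch of points with shape (N, dim)
    return x[:, 0] * x[:, 1]

trap = tq.Trapezoid()
# integration_domain is a plain list here, so the backend argument decides
# which framework's tensors are created from it
result = trap.integrate(
    f,
    dim=2,
    N=10000,
    integration_domain=[[0, 1], [-1, 1]],
    backend="jax",
)
```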
I added an RNG helper class to maintain a state for random number generation if the backend supports it. This should reduce problems when an integrand function itself changes a global RNG state. For consistency between backends, if seed is None, the RNG is initialized with a random state where possible. For the torch backend this means that a previous call to torch.random.manual_seed no longer affects future random number generation in MonteCarlo when the seed argument is unused or None.
* Add the default argument for the seed and adjust the parameter order
* Add a uniform dummy method as a place to document this backend-specific function, which is defined in the constructor
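A minimal sketch of what such an RNG wrapper could look like for the torch backend; the class and method names are illustrative rather than torchquad's exact API:

```python
import torch

class RNG:
    """Keeps a private generator state so that integrand code which reseeds
    the global torch RNG cannot affect the sample points."""

    def __init__(self, seed=None):
        self._generator = torch.Generator()
        if seed is None:
            self._generator.seed()  # initialise from a non-deterministic source
        else:
            self._generator.manual_seed(seed)

    def uniform(self, size, dtype=torch.float32):
        # Draw uniform samples in [0, 1) from the private generator,
        # leaving torch's global RNG state untouched
        return torch.rand(size, dtype=dtype, generator=self._generator)
```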
I moved the imports into the functions so that set_precision can work if torch is not installed.
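A sketch of the deferred-import pattern this describes (the helper name and precision values are illustrative):

```python
def _torch_dtype(data_type):
    # torch is imported only when this code path runs, so importing torchquad
    # and calling set_precision works even if torch is not installed
    import torch

    return torch.float64 if data_type == "float64" else torch.float32
```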
…level if it exists
…ntation mistake
* Fix _linspace_with_grads doc: N is int
* Cast the integration_domain to float in _setup_integration_domain so that tensorflow cannot use integers for the domain tensor. This is required if integration_domain is a list and the backend argument is specified.
* Change _linspace_with_grads so that it does not fail with tensorflow and requires_grad set to True (see the sketch below)
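A hedged sketch of a gradient-friendly linspace along these lines, assuming autoray's anp dispatch; torchquad's actual implementation may differ:

```python
from autoray import numpy as anp

def _linspace_with_grads(start, stop, N, requires_grad=False):
    # start and stop are 0-dimensional backend tensors, N is an int
    if requires_grad:
        # Scale a constant [0, 1] grid into (start, stop) so the result stays
        # differentiable with respect to the domain bounds
        grid = anp.linspace(0.0, 1.0, N, like=start)
        return start + (stop - start) * grid
    return anp.linspace(start, stop, N, like=start)
```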
…DME and docs
I used conda and a wildcard in the build versions in environment.yml to install numerical backends with CUDA support; jaxlib with CUDA support seems to be available only with pip.
With the pip installation of tensorflow, the ~/miniconda3/envs/torchquad/lib/python3.9/site-packages/tests folder appears and breaks the imports in torchquad's tests. I tried prepending "../" to sys.path instead of appending it, but this did not fix the problem.
I also added a run_example_functions function which, in comparison to compute_test_errors, additionally returns the functions. Furthermore, the example functions are no longer generated on import but only when run_example_functions is called. The difference in test runtime due to this change is negligible in comparison to the time required to import a numerical backend.
…s and remove compute_test_errors
* Check the integration domain with a _check_integration_domain utils function. To calculate gradients with tensorflow, integration_domain needs to be a tf.Variable, and len() does not work on this type, so I changed the input check code to use shape instead of len for tensors (sketched below).
* Support JIT compilation of an integrator.integrate function over integration_domain with JAX, Torch and Tensorflow. MonteCarlo does not yet work with all of them.
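A sketch of the shape-based check; the body is illustrative, not torchquad's exact code:

```python
def _check_integration_domain(integration_domain):
    # len() does not work on tf.Variable, so use the static shape for tensors
    if hasattr(integration_domain, "shape"):
        dim = integration_domain.shape[0]
    else:
        dim = len(integration_domain)
    if dim < 1:
        raise ValueError("integration_domain needs at least one dimension")
    return dim
```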
NumPy and TensorFlow now support both float32 and float64 precision.
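For example, with the set_up_backend function that appears later in this conversation (the data_type value here is just an illustration):

```python
import torchquad as tq

# Select TensorFlow with 64-bit floats as the default precision
tq.set_up_backend("tensorflow", data_type="float64")
```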
Running the
Is this expected?
Installing jax from the all_env environment does not work and produces the following error:
I think this may be a Windows problem 🤔 I think jax is not natively available on Windows? I solved it by following https://github.com/cloudhan/jax-windows-builder. Not a huge problem, but it may be worth pointing out that it was tested on Linux? 🤔
Sending in an array with the wrong backend currently leads to quite cryptic messages, e.g.

```python
import torchquad as tq
import jax.numpy as jnp

tq.set_up_backend("jax", data_type="float32")

def some_function(x):
    return jnp.sin(x[:, 0]) + jnp.exp(x[:, 1])

trap = tq.Trapezoid()
# Set the backend argument to "tensorflow" instead of "jax"
integral_value = trap.integrate(
    some_function,
    dim=2,
    N=10000,
    integration_domain=[[0, 1], [-1, 1]],
    backend="tensorflow",
)
```

leads to
Similarly, if I specify no backend, I also get errors. Not sure if we leave that problem for the user to fix though 🤔 However, intuitively, if I call set_up_backend I would naively expect that to be enough.

```python
import torchquad as tq
import jax.numpy as jnp

tq.set_up_backend("jax", data_type="float32")

def some_function(x):
    return jnp.sin(x[:, 0]) + jnp.exp(x[:, 1])

trap = tq.Trapezoid()
# No backend argument is passed here
integral_value = trap.integrate(
    some_function,
    dim=2,
    N=10000,
    integration_domain=[[0, 1], [-1, 1]],
)
```

produces
I've tested the installations only on GNU/Linux operating systems and forgot to mention this. Here are some commands to test for CUDA support:

```sh
python3 -c 'import tensorflow as tf; print(tf.test.is_built_with_cuda(), tf.config.list_physical_devices("GPU"))'
# Show if TensorFlow supports CUDA and a list of found GPU devices

python3 -c 'import tensorflow as tf; tf.function(lambda x: x, jit_compile=True)(1.0)'
# Fails if TensorFlow does not support compilation with XLA

CUDA_VISIBLE_DEVICES= python3 -c 'import tensorflow as tf; tf.function(lambda x: x, jit_compile=True)(1.0)'
# Fails if TensorFlow does not support compilation with XLA on CPU

python3 -c 'import torch; print(torch.version.cuda, torch.cuda.is_available(), torch.backends.cudnn.enabled, torch.backends.cudnn.version(), torch.cuda.is_initialized())'
# Show the CUDA version supported by PyTorch, a bool if CUDA works,
# a bool if cudnn works, the cudnn version,
# and a bool if CUDA is already initialised (which is usually False)

python3 -c 'import jax; print(jax.devices(), jax.devices()[0].device_kind)'
# Fails if JAX cannot find a GPU; otherwise, it lists the available GPUs
# and shows the name of the first GPU

# Set the TF_CPP_MIN_LOG_LEVEL=0 environment variable for more information
```

The deployment test uses N=101, which is very small, so VEGAS executes warmups without points. This leads to the 'Cannot update the VEGASMap.' warning. The cryptic error message appears because TensorFlow's NumPy behaviour is not enabled.
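If the NumPy-behaviour point is unclear: enabling it in TensorFlow looks like this (a sketch, not part of this PR):

```python
import tensorflow as tf

# After this call, tf tensors gain NumPy-style methods, including ravel
tf.experimental.numpy.experimental_enable_numpy_behavior()

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
print(t.ravel())  # raises AttributeError without the call above
```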
`python3 -c 'import tensorflow as tf; tf.function(lambda x: x, jit_compile=True)(1.0)'` leads to
or directly in the terminal:
Then let's increase N! :)
Then let's add a global/env variable to track the selected backend :)
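A hypothetical sketch of that suggestion; the variable, helper and environment-variable names are invented, not torchquad API:

```python
import os

# Module-level default remembered by set_up_backend
_DEFAULT_BACKEND = None

def set_up_backend(backend, data_type=None):
    global _DEFAULT_BACKEND
    _DEFAULT_BACKEND = backend
    # ... existing backend initialisation ...

def _resolve_backend(backend_argument):
    # Fall back to the remembered backend, then to an environment variable
    if backend_argument is not None:
        return backend_argument
    return _DEFAULT_BACKEND or os.environ.get("TORCHQUAD_BACKEND", "torch")
```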
If I run locally, the Monte Carlo test now fails for me.
I also get many, many warnings exactly like this one
and a few other ones after that
See comments :) almost all minor things, I think
… on GNU/Linux
Co-authored-by: Pablo Gómez <pablo.gomez@gmx.de>
This avoids the 'Cannot update the VEGASMap.' warning, which was shown because there were too few points for the warmup.
… be inferred from user-provided arguments
* Move the get_jit_compiled_integrate methods to the bottom
* Move the calculate_result methods below the integrate methods
Done.
I haven't seen the deprecated FileDescriptor() warning before. For the deprecation warnings about np.object, np.bool, np.int and the imp module, I initially added a pyproject.toml file to hide them.

The torch.meshgrid warning appears with torch 1.10.0 but not with torch 1.9.1.post3, since the indexing argument was added recently. Using this argument hides the deprecation warning with a new PyTorch version but makes torchquad incompatible with older versions (unless I add a case distinction). torchquad works with both indexing orders.

The test executions on GPU currently may require environment variables which change the memory allocation behaviour of the backends, since all backends are imported one after another and some of them can reserve the whole GPU memory.
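The case distinction mentioned above could look roughly like this; a sketch using the packaging library for the version comparison, not necessarily torchquad's solution:

```python
import torch
from packaging.version import parse as parse_version

xs = torch.linspace(0.0, 1.0, 3)
ys = torch.linspace(-1.0, 1.0, 3)

if parse_version(torch.__version__) >= parse_version("1.10.0"):
    # Passing indexing explicitly silences the deprecation warning
    grid = torch.meshgrid(xs, ys, indexing="ij")
else:
    # Older torch versions only support the implicit "ij" behaviour
    grid = torch.meshgrid(xs, ys)
```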
Very nice work! 💪 :) I'll be looking to create a new release specifically for this asap (probably early May)
Description
The main change is a big code rewrite so that Trapezoid, Simpson, Boole and MonteCarlo can be used with NumPy, PyTorch, JAX and TensorFlow, and VEGAS with NumPy and PyTorch.
Other changes include, for example:
* The log level can be set with the TORCHQUAD_LOG_LEVEL environment variable

How Has This Been Tested?
Additional tests for all backends