Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for numpyro and blackjax PyMC samplers #526

Merged
merged 18 commits into from
Jun 10, 2022

Conversation

markgoodhead
Copy link
Contributor

@markgoodhead markgoodhead commented Jun 8, 2022

This is to address #522 and #525 inspired by @zwelitunyiswa's example

I decided to add a single new value to the fit() method which allows switching in of numpyro/blackjax samplers instead of the pymc default. I decided against some cpu/gpu flags because it's mostly decided by whatever Jax can find and the methods I saw to disable GPUs are quite hacky involving playing with your environment variables which I felt is out of scope for a library to be fiddling with so I've just noted this in the documentation instead.

I've tested the samplers locally and they work on one of my personal projects, but I'll try and knock up a simple example shortly which demonstrates them all.

One note: The PyMC 4 release blog post says:

These samplers live in a different submodule sampling_jax but the plan is to integrate them into pymc.sample(backend="JAX").

So we should expect the implementation here to change pretty soon, so I think it's worth keeping the implementation in bambi simple so it's easy to port-over when this happens.

@markgoodhead
Copy link
Contributor Author

markgoodhead commented Jun 8, 2022

OK so I've hit a snag I don't quite understand. My personal project works fine but my test example fails:

import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd
import time

az.style.use("arviz-darkgrid")
rng = np.random.default_rng(0)

size = 1000
x = rng.normal(size=size)
print(x)
data = pd.DataFrame(
    {
        "x": x,
        "y": rng.normal(loc=x, size=size)
    }
)
print(data)

bmb_model = bmb.Model("y ~ x", data)
bmb_model_numpyro = bmb.Model("y ~ x", data)
bmb_model_blackjax = bmb.Model("y ~ x", data)
t0 = time.time()
idata = bmb_model.fit()
t1 = time.time()
idata_numpyro = bmb_model_numpyro.fit(chains=4, tune=1000, draws=1000, sampler_backend="numpyro", chain_method="vectorized")
t2 = time.time()
idata_blackjax = bmb_model_blackjax.fit(chains=4, tune=1000, draws=1000, sampler_backend="blackjax", chain_method="vectorized")
t3 = time.time()

print(f"Default: {t1-t0} Numpyro: {t2-t1} Blackjax: {t3-t2}")

It fails on line

idata.posterior[intercept_name] -= np.dot(X.mean(0), coefs).reshape(shape)
during the numpyro run because it appears that one of the idata attributes is read-only... must admit this is quite deep into the weeds of bambi/pymc internals and I'm a little stuck as to how to proceed. It seems like the InferenceData object being returned by numpyro's fit() is inconsistent with the standard one, which is most likely a PyMC bug?

Copy link
Collaborator

@aloctavodia aloctavodia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@markgoodhead Thanks for your contribution. I have a few comments.

bambi/backend/pymc.py Outdated Show resolved Hide resolved
bambi/models.py Outdated Show resolved Hide resolved
bambi/models.py Outdated Show resolved Hide resolved
bambi/models.py Outdated Show resolved Hide resolved
Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
@aloctavodia
Copy link
Collaborator

aloctavodia commented Jun 8, 2022

OK so I've hit a snag I don't quite understand. My personal project works fine but my test example fails:

import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd
import time

az.style.use("arviz-darkgrid")
rng = np.random.default_rng(0)

size = 1000
x = rng.normal(size=size)
print(x)
data = pd.DataFrame(
    {
        "x": x,
        "y": rng.normal(loc=x, size=size)
    }
)
print(data)

bmb_model = bmb.Model("y ~ x", data)
bmb_model_numpyro = bmb.Model("y ~ x", data)
bmb_model_blackjax = bmb.Model("y ~ x", data)
t0 = time.time()
idata = bmb_model.fit()
t1 = time.time()
idata_numpyro = bmb_model_numpyro.fit(chains=4, tune=1000, draws=1000, sampler_backend="numpyro", chain_method="vectorized")
t2 = time.time()
idata_blackjax = bmb_model_blackjax.fit(chains=4, tune=1000, draws=1000, sampler_backend="blackjax", chain_method="vectorized")
t3 = time.time()

print(f"Default: {t1-t0} Numpyro: {t2-t1} Blackjax: {t3-t2}")

It fails on line

idata.posterior[intercept_name] -= np.dot(X.mean(0), coefs).reshape(shape)

during the numpyro run because it appears that one of the idata attributes is read-only... must admit this is quite deep into the weeds of bambi/pymc internals and I'm a little stuck as to how to proceed. It seems like the InferenceData object being returned by numpyro's fit() is inconsistent with the standard one, which is most likely a PyMC bug?

This is weird, this runs for me with both numpyro and blackjax

@markgoodhead
Copy link
Contributor Author

markgoodhead commented Jun 8, 2022

How odd! Perhaps my environment isn't setup correctly and I'm behind on the latest versions. What versions of pymc/jax/arviz/xarray etc are you using?

I have modified my example script to work now with the new method arg approach (and fixed a bug in chains handling I spotted). Note: I'm not actually sure if I need to construct 3 models - call it paranoia at ruling out bugs in case bmb.Model was stateful between fits 😂

import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd
import time

az.style.use("arviz-darkgrid")
rng = np.random.default_rng(0)

size = 1000
x = rng.normal(size=size)
data = pd.DataFrame(
    {
        "x": x,
        "y": rng.normal(loc=x, size=size)
    }
)

bmb_model = bmb.Model("y ~ x", data)
bmb_model_numpyro = bmb.Model("y ~ x", data)
bmb_model_blackjax = bmb.Model("y ~ x", data)
t0 = time.time()
idata = bmb_model.fit()
t1 = time.time()
idata_numpyro = bmb_model_numpyro.fit(method="nuts_numpyro", chain_method="vectorized")
t2 = time.time()
idata_blackjax = bmb_model_blackjax.fit(method="nuts_blackjax", chain_method="vectorized")
t3 = time.time()

print(f"Default: {t1-t0} Numpyro: {t2-t1} Blackjax: {t3-t2}")

bambi/backend/pymc.py Outdated Show resolved Hide resolved
bambi/backend/pymc.py Show resolved Hide resolved
bambi/backend/pymc.py Show resolved Hide resolved
bambi/backend/pymc.py Outdated Show resolved Hide resolved
bambi/backend/pymc.py Outdated Show resolved Hide resolved
bambi/models.py Outdated Show resolved Hide resolved
@zwelitunyiswa
Copy link

zwelitunyiswa commented Jun 8, 2022 via email

bambi/__init__.py Outdated Show resolved Hide resolved
@markgoodhead
Copy link
Contributor Author

markgoodhead commented Jun 8, 2022

Vecrorized does work for both blackjack and numpyro( https://www.pymc.io/projects/docs/en/stable/api/samplers.html) but I found vectorized was slower on cpu (and my gpu won’t work since I have an M1).

On Wed, Jun 8, 2022 at 08:44 Osvaldo A Martin @.> wrote: chain_method="vectorized" This is weird, this runs for me with both numpyro and blackjax — Reply to this email directly, view it on GitHub <#526 (comment)>, or unsubscribe https://github.com/notifications/unsubscribe-auth/AH3QQV3MUOM6VTU6MFJSXATVOCIUBANCNFSM5YGBJLIA . You are receiving this because you were mentioned.Message ID: @.>

Yes vectorized is generally faster if you're on a single GPU, otherwise for multiple GPUs or multiple CPU cores I expect parallel would be better.

@aloctavodia
Copy link
Collaborator

Looks good, the only missing part is a test

@markgoodhead
Copy link
Contributor Author

markgoodhead commented Jun 8, 2022

Looks good, the only missing part is a test

Hmm so I just tried modifying existing tests to also run the new fit methods, e.g.

def test_group_specific_categorical_interaction(crossed_data):
    crossed_data["fourcats"] = sum([[x] * 10 for x in ["a", "b", "c", "d"]], list()) * 3
    model = Model("Y ~ continuous + (threecats:fourcats|site)", crossed_data)
    model.fit(tune=10, draws=10)
    model.fit(tune=10, draws=10, method="nuts_numpyro")

However I again get an import error on from bambi import math at the top of the test file (which works if I comment it out) and my tests fail on the same error my example code gives above 🤦 do you get the same error modifying that test or does it work for you given the example also worked for you?

@aloctavodia
Copy link
Collaborator

Do you mind adding the test anyway?

@markgoodhead
Copy link
Contributor Author

markgoodhead commented Jun 8, 2022

Do you mind adding the test anyway?

Tests added... Fingers crossed they actually work!

@aloctavodia
Copy link
Collaborator

@markgoodhead
Copy link
Contributor Author

Be sure to run black and pylint https://github.com/bambinos/bambi/blob/main/CONTRIBUTING.md#pull-request-checklist

Done 👍

@markgoodhead
Copy link
Contributor Author

One small issue is that pylint isn't happy with the import within the code itself - I assume you're happy to ignore the error here?

@tomicapretto
Copy link
Collaborator

One small issue is that pylint isn't happy with the import within the code itself - I assume you're happy to ignore the error here?

You could add # pylint: disable=import-outside-toplevel next to the import, I think that should work.

@markgoodhead
Copy link
Contributor Author

I just tried updating my version of xarray (which was 0.21.1 before) to the latest on pip (2022.3.0) and I still get the same xarray error... otherwise my versions are all compatible with the pymc 4.0.0 release on pip. Does anyone else get this error? If not, what versions of xarray etc are you using?

…ar imports and this is needed for tests to work
@markgoodhead
Copy link
Contributor Author

markgoodhead commented Jun 8, 2022

Tests look to be failing due to Jax not being installed (ModuleNotFoundError: No module named 'jax') - @tomicapretto please can you assist? @aloctavodia requested that Jax be an optional install for the user so I guess we just need to add this to the github actions install only or something?

@aloctavodia
Copy link
Collaborator

We can add jax, numpyro, blackjax and any other necessary requirement for jax-based samplers to https://github.com/bambinos/bambi/blob/main/requirements-dev.txt

@tomicapretto
Copy link
Collaborator

I think we could have something like requirements-optional.txt like what you can find in ArviZ https://github.com/arviz-devs/arviz. I think requirements-dev.txt should be only for development dependencies.

Below, you will need to add another line saying pip install -r requirements-optional.txt.

pip install -r requirements.txt
pip install -r requirements-dev.txt

@aloctavodia
Copy link
Collaborator

Agreed, that's cleaner.

@canyon289
Copy link
Collaborator

Please also ensure the optional dependencies in setup.py is setup correctly for optional requirements

https://stackoverflow.com/a/43090648/414104

@markgoodhead
Copy link
Contributor Author

markgoodhead commented Jun 9, 2022

I've added the optional requirements files and hopefully done the setup.py changes @canyon289 requested correctly (all a bit new to me so I could well have done it wrong!). I wasn't sure what versions to specify in the file so I tried to find the equivalents in pymc to align with what they have... and was a bit surprised when I couldn't find any! Perhaps something similar should be added to pymc and then by depending on a specific pymc version this would flow naturally upstream to bambi?

Another thing to note here is that if a user installs jax via this version I believe they won't get CUDA support by default - further downstream libraries like numpyro look to sort of copy the Jax installation instructions in their setup optional structure. I think the best solution overall would be for each part of the library hierarchy to depend on the correct optional install in the sub-library they depend on, e.g. bambi[gpu] would end up calling pymc[gpu] which would call numpyro[gpu] etc... perhaps this is a bit out of scope for this PR though as it requires a lot of co-ordination with other repos and this current solution is a reasonable intermediate step?

@markgoodhead
Copy link
Contributor Author

OK it looks like the tests are failing for the same reason my local environment doesn't work which I've no idea how to fix! Anyone got any advice what I should try/do here?

2022-06-09T09:04:21.8572705Z bambi/models.py:265: in fit
2022-06-09T09:04:21.8572923Z     return self.backend.run(
2022-06-09T09:04:21.8573156Z bambi/backend/pymc.py:91: in run
2022-06-09T09:04:21.8573373Z     result = self._run_mcmc(
2022-06-09T09:04:21.8573593Z bambi/backend/pymc.py:288: in _run_mcmc
2022-06-09T09:04:21.8573880Z     idata = self._clean_mcmc_results(idata, omit_offsets, include_mean)
2022-06-09T09:04:21.8574228Z bambi/backend/pymc.py:363: in _clean_mcmc_results
2022-06-09T09:04:21.8574631Z     idata.posterior[intercept_name] -= np.dot(X.mean(0), coefs).reshape(shape)
2022-06-09T09:04:21.8575105Z /usr/share/miniconda/envs/test/lib/python3.8/site-packages/xarray/core/_typed_ops.py:290: in __isub__
2022-06-09T09:04:21.8575466Z     return self._inplace_binary_op(other, operator.isub)
2022-06-09T09:04:21.8575943Z /usr/share/miniconda/envs/test/lib/python3.8/site-packages/xarray/core/dataarray.py:3121: in _inplace_binary_op
2022-06-09T09:04:21.8576273Z     f(self.variable, other_variable)
2022-06-09T09:04:21.8576690Z /usr/share/miniconda/envs/test/lib/python3.8/site-packages/xarray/core/_typed_ops.py:480: in __isub__
2022-06-09T09:04:21.8577039Z     return self._inplace_binary_op(other, operator.isub)

@codecov-commenter
Copy link

Codecov Report

Merging #526 (478a5ed) into main (f9dc90d) will increase coverage by 0.14%.
The diff coverage is 95.45%.

@@            Coverage Diff             @@
##             main     #526      +/-   ##
==========================================
+ Coverage   86.69%   86.84%   +0.14%     
==========================================
  Files          32       32              
  Lines        2586     2622      +36     
==========================================
+ Hits         2242     2277      +35     
- Misses        344      345       +1     
Impacted Files Coverage Δ
bambi/models.py 88.65% <ø> (ø)
bambi/backend/pymc.py 80.88% <90.00%> (+0.60%) ⬆️
bambi/tests/test_built_models.py 99.03% <100.00%> (+0.08%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f9dc90d...478a5ed. Read the comment docs.

**kwargs,
)
else:
raise
Copy link
Collaborator

@canyon289 canyon289 Jun 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please raise a specific exception with a helpful message

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah this was the code before I changed this function, it's just been moved around. To be honest I wondered about removing this whole error handling because I've seen pymc do the same thing internally anyway but I thought that might be out of scope for this PR - I'll do whatever is the consensus here 😄

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If youre willing editing here would be helpful, but youre right if you just moved the code it can be out of scope! My ask is just open an issue ticket to track and reference this discussion :)

model.fit(method="nuts_blackjax", chain_method="vectorized")


def test_regression_blackjax():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: These two tests test_regression_blackjax and test_regression_nunpyro could be parameterized to reduce amount of code that needs to be read or maintained

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @canyon289 here, but if you want @markgoodhead you can open an issue fix this later.

@canyon289
Copy link
Collaborator

@markgoodhead thanks for doing this! this is a great capability add for bambi

Copy link
Collaborator

@aloctavodia aloctavodia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thank you @markgoodhead

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants