Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow plot_cap() to show predictions at the observation level #668

Merged
merged 8 commits into from
Apr 27, 2023

Conversation

GStechschulte
Copy link
Collaborator

@GStechschulte GStechschulte commented Apr 16, 2023

This draft PR addresses issue #623 to allow plot_cap() to show predictions at the observation level, i.e., to use the adjusted posterior predictive distribution.

This is achieved by adding an additional boolean argument pps (posterior predictive sample) to the plot_cap() function. If true, then the adjusted posterior predictive distribution is plotted for the response variable by using model.predict(..., kind="pps"), else the adjusted posterior is used for plotting, i.e., model.predict(..., kind="mean").

Below are some examples:

import arviz as az
import bambi as bmb
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from bambi.plots import plot_cap

data = pd.read_csv("data/mtcars.csv")
data["cyl"] = data["cyl"].replace({4: "low", 6: "medium", 8: "high"})
data["gear"] = data["gear"].replace({3: "A", 4: "B", 5: "C"})
data["cyl"] = pd.Categorical(data["cyl"], categories=["low", "medium", "high"], ordered=True)

model = bmb.Model("mpg ~ 0 + hp * wt + cyl + gear", data)
idata = model.fit(draws=1000, target_accept=0.95, random_seed=1234)
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, "hp", pps=False, ax=ax);

output1

# posterior predictive distribution (forward sampling)
# i.e., plot_cap at observation level
fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, "hp", pps=True, ax=ax);

output2

fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, ["hp", "wt"], pps=False, ax=ax);

output3

fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
plot_cap(model, idata, ["hp", "wt"], pps=True, ax=ax);

output4

This is a work in progress and I still need to test on other examples, add tests, and perform linting. However, if anyone has any initial comments or criticisms, let me know. Thanks!

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Collaborator

@tomicapretto tomicapretto left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the nice PR! I added some suggestions for changes. Let me know if you have any questions :)

bambi/plots/plot_cap.py Outdated Show resolved Hide resolved
bambi/plots/plot_cap.py Outdated Show resolved Hide resolved
bambi/plots/plot_cap.py Outdated Show resolved Hide resolved
bambi/plots/plot_cap.py Outdated Show resolved Hide resolved
bambi/plots/plot_cap.py Outdated Show resolved Hide resolved
docs/notebooks/data/mtcars.csv Outdated Show resolved Hide resolved
@tomicapretto
Copy link
Collaborator

@GStechschulte I just updated the PR with the mtcars dataset, now you can do bmb.load_data("mtcars"). I also bumped PyMC to 5.3.0. Let me know if you want me to review the example :)

@codecov-commenter
Copy link

codecov-commenter commented Apr 18, 2023

Codecov Report

Merging #668 (39d6ccc) into main (8594fd3) will increase coverage by 0.18%.
The diff coverage is 93.18%.

❗ Current head 39d6ccc differs from pull request most recent head f06f09e. Consider uploading reports for the commit f06f09e to get more accurate results

@@            Coverage Diff             @@
##             main     #668      +/-   ##
==========================================
+ Coverage   87.62%   87.80%   +0.18%     
==========================================
  Files          40       40              
  Lines        2650     2665      +15     
==========================================
+ Hits         2322     2340      +18     
+ Misses        328      325       -3     
Impacted Files Coverage Δ
bambi/data/datasets.py 96.15% <ø> (ø)
bambi/families/likelihood.py 85.36% <ø> (ø)
bambi/plots/plot_cap.py 89.36% <63.63%> (-1.47%) ⬇️
bambi/families/multivariate.py 97.01% <96.66%> (-1.32%) ⬇️
bambi/families/univariate.py 95.38% <96.66%> (+5.01%) ⬆️
bambi/backend/terms.py 96.47% <100.00%> (+0.02%) ⬆️
bambi/defaults/families.py 87.50% <100.00%> (ø)
bambi/families/family.py 94.56% <100.00%> (+1.30%) ⬆️
bambi/model_components.py 89.71% <100.00%> (+0.09%) ⬆️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@GStechschulte
Copy link
Collaborator Author

@GStechschulte I just updated the PR with the mtcars dataset, now you can do bmb.load_data("mtcars"). I also bumped PyMC to 5.3.0. Let me know if you want me to review the example :)

Great, thanks! I am going to perform a few more examples (other than the mtcars dataset and model) first, and then you can review the example(s). I plan to be finished with this (and move from a draft PR to PR) by the end of this weekend 23.04.

As far as adding tests in the test_plots.py of the tests/ directory goes, what do you think? I don't have much experience with writing tests, but I think with the added functionality, it would be useful to add some. Looking at the code in test_plots.py, I have a pretty good idea of what the tests should be. Or...I can ask CoPilot ;)

@tomicapretto
Copy link
Collaborator

Sounds great! Thanks for the extra examples.

It's not trivial to test plotting functions. That's why the code in plot_tests.py only checks things work, without checking the actual plot. If you can do something close to that, that would be awesome.

See the test code also uses mtcars, but it's loaded from the tests module. You could remove that and load it with bmb.load_data() as well :)

@GStechschulte
Copy link
Collaborator Author

GStechschulte commented Apr 23, 2023

Sounds great! Thanks for the extra examples.

It's not trivial to test plotting functions. That's why the code in plot_tests.py only checks things work, without checking the actual plot. If you can do something close to that, that would be awesome.

See the test code also uses mtcars, but it's loaded from the tests module. You could remove that and load it with bmb.load_data() as well :)

Hey @tomicapretto I have been testing this functionality with a variety of different models (linear, negative binomial, binomial, and categorical regression). The examples can be found in this Gist. Most of them look good besides categorical regression (which I opened an issue), and models with binary responses that are represented by the proportion alias.

The problem with p(y, n) ~ x is that when using pps=True, to create the cap_data, both n and x need to be passed in order to compute the proportion y / n of the outcome y. I provide an example of what a general solution to this problem looks like in the Regression for binary responses section.

@tomicapretto
Copy link
Collaborator

@GStechschulte thanks for the thorough notebook. It's a great piece of work! Here my answers

  • For binary responses the posterior predictive distribution is in a different scale than the mean parameter. The former is in the binary scale (set {0, 1}) and the latter is on the probability scale (interval (0, 1)). So a barchart or point and line plot would be more appropriate. However I think it's not needed to handle that case in this PR. I would leave it as it is, or add an error when the family is Bernoulli/Binomial.
  • For categorical responses:
    • We could plot the mean. I can help to fix the fill between.
    • Similar issues than before with the draws from the posterior predictive distribution. It doesn't need to be solved now.

What do you think?

* FAQ page first draft

* Update docs/faq.rst

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>

* Update docs/faq.rst

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>

* Remove general bayesian modelling questions + minor fixes

* Update faq.rst

* Update faq.rst

---------

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>
@GStechschulte GStechschulte marked this pull request as ready for review April 24, 2023 18:18
@GStechschulte
Copy link
Collaborator Author

GStechschulte commented Apr 24, 2023

@tomicapretto Thanks, and indeed, you are correct regarding the scales of the binary responses. Agreed, a bar plot would be a better way to visualize the observations. For this PR, I left this particular example as is. Same for categorical regression; we can collaborate on this issue in #669

One last point regarding adding tests. I suppose one could just copy and paste, albeit with pps=True, for the majority of the functions in test_plots.py. Surely there is a more pragmatic approach to this with pytest? I have been looking into parameterizing tests.

@tomicapretto
Copy link
Collaborator

@GStechschulte Yes, you can do something like this:

@pytest.mark.parametrize("pps", [False, True])
def test_basic(mtcars, pps):
    model, idata = mtcars

    # Using dictionary
    # Horizontal variable is numeric
    plot_cap(model, idata, {"horizontal": "hp"}, pps=pps)

    # Horizontal variable is categorical
    plot_cap(model, idata, {"horizontal": "gear"}, pps=pps)

    # Using list
    plot_cap(model, idata, ["hp"], pps=pps)
    plot_cap(model, idata, ["gear"], pps=pps)

Havent' tested locally but I think that's it. Let me know if you need help :)

@tomicapretto
Copy link
Collaborator

@GStechschulte looks like it can be merged? :D

@GStechschulte
Copy link
Collaborator Author

@GStechschulte looks like it can be merged? :D

Indeed, it can be merged :) thanks for your code reviews and quick feedback!

@tomicapretto tomicapretto merged commit 743001f into bambinos:main Apr 27, 2023
4 checks passed
GStechschulte added a commit to GStechschulte/bambi that referenced this pull request May 9, 2023
…inos#668)

* plot_cap using post. pred. samples

* add mtcars dataset and test plot_cap using post. pred. samples

* plot_cap show predictions at obs. level

* remove unused code and formatting

* Add mtcars as a dataset. Bump PyMC to 5.3.0

* FAQ page first draft (bambinos#657)

* FAQ page first draft

* Update docs/faq.rst

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>

* Update docs/faq.rst

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>

* Remove general bayesian modelling questions + minor fixes

* Update faq.rst

* Update faq.rst

---------

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>

* black reformatting

* add tests for pps=bool using pytest parameterization

---------

Co-authored-by: Tomas Capretto <tomicapretto@gmail.com>
Co-authored-by: Abuzar Mahmood <abuzarmahmood@gmail.com>
Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>
@GStechschulte GStechschulte deleted the plot-cap-obs-level branch July 20, 2023 08:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants