Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow predictions on new groups #693

Merged
merged 11 commits into from
Jun 29, 2023

Conversation

tomicapretto
Copy link
Collaborator

@tomicapretto tomicapretto commented Jun 27, 2023

Model.predict() gains a new argument sample_new_groups which determines if we want to allow for predictions on new groups. This is useful in the context of hierarchical modeling, where groups are assumed to be a sample from a larger group. With this new feature, we can get predictions for a new, unseen, group.

Small example

import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

data = bmb.load_data("sleepstudy")
model = bmb.Model("Reaction ~ 1 + Days + (1 + Days | Subject)", data)
idata = model.fit()

df_new = data.head(10).reset_index(drop=True)
df_new["Subject"] = "xxx"
df_new = pd.concat([df_new, data.head(10)])
df_new = df_new.reset_index(drop=True)

p = model.predict(idata, data=df_new, inplace=False, sample_new_groups=True)
reaction_draws = p.posterior["Reaction_mean"]
mean = reaction_draws.mean(("chain", "draw")).to_numpy()
bounds = reaction_draws.quantile((0.025, 0.975), ("chain", "draw")).to_numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

axes[0].scatter(df_new.iloc[10:]["Days"], df_new.iloc[10:]["Reaction"])
axes[1].scatter(df_new.iloc[:10]["Days"], df_new.iloc[:10]["Reaction"])

axes[0].fill_between(np.arange(10), bounds[0, 10:], bounds[1, 10:], alpha=0.5, color="C0")
axes[1].fill_between(np.arange(10), bounds[0, :10], bounds[1, :10], alpha=0.5, color="C0")

axes[0].set_title("Original participant")
axes[1].set_title("New participant");

image

Notice how the variability increases for the new group. Only a single resampling method is supported for now and it's equivalent to sample_new_levels="uncertainty" in brms.

To Do

  • Implement tests

Closes #417

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov-commenter
Copy link

codecov-commenter commented Jun 28, 2023

Codecov Report

Merging #693 (024e275) into main (259c474) will increase coverage by 0.44%.
The diff coverage is 89.80%.

@@            Coverage Diff             @@
##             main     #693      +/-   ##
==========================================
+ Coverage   88.32%   88.77%   +0.44%     
==========================================
  Files          40       40              
  Lines        2862     2976     +114     
==========================================
+ Hits         2528     2642     +114     
  Misses        334      334              
Impacted Files Coverage Δ
bambi/model_components.py 92.97% <89.40%> (+4.08%) ⬆️
bambi/models.py 79.16% <100.00%> (+0.51%) ⬆️

... and 1 file with indirect coverage changes

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@tomicapretto tomicapretto marked this pull request as ready for review June 29, 2023 01:21
@tomicapretto tomicapretto merged commit c753266 into bambinos:main Jun 29, 2023
4 checks passed
@tomicapretto tomicapretto deleted the predict_new_groups branch June 29, 2023 02:17
GStechschulte pushed a commit to GStechschulte/bambi that referenced this pull request Jun 29, 2023
This was referenced Sep 15, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow new levels in group specific effects when using Model.predict()
2 participants