-
-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Type-hint and type-overload a few common Model methods #790
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# pylint: disable=no-name-in-module | ||
# pylint: disable=too-many-lines | ||
import logging | ||
from typing import Callable, Literal, Optional, Union, overload | ||
import warnings | ||
|
||
from copy import deepcopy | ||
|
@@ -11,6 +12,7 @@ | |
import pandas as pd | ||
|
||
from arviz.plots import plot_posterior | ||
from arviz import InferenceData | ||
|
||
from bambi.backend import PyMCModel | ||
from bambi.defaults import get_builtin_family | ||
|
@@ -102,18 +104,18 @@ class Model: | |
# pylint: disable=too-many-instance-attributes | ||
def __init__( | ||
self, | ||
formula, | ||
data, | ||
family="gaussian", | ||
priors=None, | ||
link=None, | ||
categorical=None, | ||
potentials=None, | ||
dropna=False, | ||
auto_scale=True, | ||
noncentered=True, | ||
center_predictors=True, | ||
extra_namespace=None, | ||
formula: Union[str, Formula], | ||
data: pd.DataFrame, | ||
family: Union[str, Family] = "gaussian", | ||
priors: Optional[dict[str, Prior]] = None, | ||
link: Optional[Union[str, dict[str, str]]] = None, | ||
categorical: Optional[Union[str, list[str]]] = None, | ||
potentials: Optional[list[tuple[str, Callable]]] = None, | ||
dropna: bool = False, | ||
auto_scale: bool = True, | ||
noncentered: bool = True, | ||
center_predictors: bool = True, | ||
extra_namespace: Optional[dict] = None, | ||
): | ||
# attributes that are set later | ||
self.components = {} # Constant and Distributional components | ||
|
@@ -225,21 +227,57 @@ def __init__( | |
# Build priors | ||
self._build_priors() | ||
|
||
@overload | ||
def fit( | ||
self, | ||
draws=1000, | ||
tune=1000, | ||
discard_tuned_samples=True, | ||
omit_offsets=True, | ||
include_mean=False, | ||
inference_method="mcmc", | ||
init="auto", | ||
n_init=50000, | ||
chains=None, | ||
cores=None, | ||
random_seed=None, | ||
draws: int = 1000, | ||
tune: int = 1000, | ||
discard_tuned_samples: bool = True, | ||
omit_offsets: bool = True, | ||
include_mean: bool = False, | ||
inference_method: Literal["mcmc", "blackjax_nuts", "numpyro_nuts", "laplace"] = "mcmc", | ||
init: str = "auto", | ||
n_init: int = 50000, | ||
chains: Optional[int] = None, | ||
cores: Optional[int] = None, | ||
random_seed: Optional[int] = None, | ||
**kwargs, | ||
): | ||
) -> InferenceData: ... | ||
|
||
@overload | ||
def fit( | ||
self, | ||
draws: int = 1000, | ||
tune: int = 1000, | ||
discard_tuned_samples: bool = True, | ||
omit_offsets: bool = True, | ||
include_mean: bool = False, | ||
inference_method: Literal["vi"] = ..., | ||
init: str = "auto", | ||
n_init: int = 50000, | ||
chains: Optional[int] = None, | ||
cores: Optional[int] = None, | ||
random_seed: Optional[int] = None, | ||
**kwargs, | ||
) -> pm.MeanField: ... | ||
|
||
def fit( | ||
self, | ||
draws: int = 1000, | ||
tune: int = 1000, | ||
discard_tuned_samples: bool = True, | ||
omit_offsets: bool = True, | ||
include_mean: bool = False, | ||
inference_method: Literal[ | ||
"mcmc", "blackjax_nuts", "numpyro_nuts", "vi", "laplace" | ||
] = "mcmc", | ||
init: str = "auto", | ||
n_init: int = 50000, | ||
chains: Optional[int] = None, | ||
cores: Optional[int] = None, | ||
random_seed: Optional[int] = None, | ||
**kwargs, | ||
) -> InferenceData | pm.MeanField: | ||
"""Fit the model using PyMC. | ||
|
||
Parameters | ||
|
@@ -722,7 +760,13 @@ def plot_priors( | |
) | ||
return axes | ||
|
||
def prior_predictive(self, draws=500, var_names=None, omit_offsets=True, random_seed=None): | ||
def prior_predictive( | ||
self, | ||
draws: int = 500, | ||
var_names: Optional[Union[str, list[str]]] = None, | ||
omit_offsets: bool = True, | ||
random_seed: Optional[int] = None, | ||
) -> InferenceData: | ||
"""Generate samples from the prior predictive distribution. | ||
|
||
Parameters | ||
|
@@ -763,15 +807,37 @@ def prior_predictive(self, draws=500, var_names=None, omit_offsets=True, random_ | |
|
||
return idata | ||
|
||
@overload | ||
def predict( | ||
self, | ||
idata, | ||
kind="mean", | ||
data=None, | ||
inplace=True, | ||
include_group_specific=True, | ||
sample_new_groups=False, | ||
): | ||
idata: InferenceData, | ||
kind: Literal["mean", "pps"] = "mean", | ||
data: Optional[pd.DataFrame] = None, | ||
inplace: Literal[True] = True, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't this be |
||
include_group_specific: bool = True, | ||
sample_new_groups: bool = False, | ||
) -> None: ... | ||
|
||
@overload | ||
def predict( | ||
self, | ||
idata: InferenceData, | ||
kind: Literal["mean", "pps"] = "mean", | ||
data: Optional[pd.DataFrame] = None, | ||
inplace: Literal[False] = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above |
||
include_group_specific: bool = True, | ||
sample_new_groups: bool = False, | ||
) -> InferenceData: ... | ||
|
||
def predict( | ||
self, | ||
idata: InferenceData, | ||
kind: Literal["mean", "pps"] = "mean", | ||
data: Optional[pd.DataFrame] = None, | ||
inplace: bool = True, | ||
include_group_specific: bool = True, | ||
sample_new_groups: bool = False, | ||
) -> Optional[InferenceData]: | ||
"""Predict method for Bambi models | ||
|
||
Obtains in-sample and out-of-sample predictions from a fitted Bambi model. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One problem I see here is that after #775, users have access to a huge variety of samplers through bayeux, and I'm not sure how that can be expressed with a `Literal`. Could we instead build the `Literal` on the fly (at import time), checking whether bayeux is available and, depending on that, including a different set of allowed values?
Also, I don't think there's a convenient way to access the list of samplers without first creating a model. I'm going to create an issue now.