Type-hint and type-overload a few common Model methods #790

Open · wants to merge 4 commits into base: main
130 changes: 98 additions & 32 deletions bambi/models.py
@@ -1,6 +1,7 @@
# pylint: disable=no-name-in-module
# pylint: disable=too-many-lines
import logging
from typing import Callable, Literal, Optional, Union, overload
import warnings

from copy import deepcopy
@@ -11,6 +12,7 @@
import pandas as pd

from arviz.plots import plot_posterior
from arviz import InferenceData

from bambi.backend import PyMCModel
from bambi.defaults import get_builtin_family
@@ -102,18 +104,18 @@ class Model:
# pylint: disable=too-many-instance-attributes
def __init__(
self,
formula,
data,
family="gaussian",
priors=None,
link=None,
categorical=None,
potentials=None,
dropna=False,
auto_scale=True,
noncentered=True,
center_predictors=True,
extra_namespace=None,
formula: Union[str, Formula],
data: pd.DataFrame,
family: Union[str, Family] = "gaussian",
priors: Optional[dict[str, Prior]] = None,
link: Optional[Union[str, dict[str, str]]] = None,
categorical: Optional[Union[str, list[str]]] = None,
potentials: Optional[list[tuple[str, Callable]]] = None,
dropna: bool = False,
auto_scale: bool = True,
noncentered: bool = True,
center_predictors: bool = True,
extra_namespace: Optional[dict] = None,
):
# attributes that are set later
self.components = {} # Constant and Distributional components
@@ -225,21 +227,57 @@ def __init__(
# Build priors
self._build_priors()

@overload
def fit(
self,
draws=1000,
tune=1000,
discard_tuned_samples=True,
omit_offsets=True,
include_mean=False,
inference_method="mcmc",
init="auto",
n_init=50000,
chains=None,
cores=None,
random_seed=None,
draws: int = 1000,
tune: int = 1000,
discard_tuned_samples: bool = True,
omit_offsets: bool = True,
include_mean: bool = False,
inference_method: Literal["mcmc", "blackjax_nuts", "numpyro_nuts", "laplace"] = "mcmc",
Collaborator comment: One problem I see here is that, after #775, users have access to a huge variety of samplers through bayeux. I'm not sure how this is handled with literals. Is it possible that what we need is to create a Literal on the fly (at import time), where we check whether users have access to bayeux and, depending on that, the Literal contains different members? Also, I don't think there's a convenient way to access the list of samplers without first creating a model. I'm going to create an issue now.
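For reference on the question above: typing.Literal cannot be built dynamically in a way static checkers can see, so an import-time check can only widen what is accepted at runtime. A minimal sketch of one possible compromise, assuming nothing about bayeux beyond whether it is importable; the runtime check and the placeholder sampler name are hypothetical, not Bambi's actual API:

# Hypothetical sketch, not part of this diff: keep the static Literal for the
# always-available methods, and validate at runtime against a set assembled
# at import time.
import importlib.util

_STATIC_METHODS = {"mcmc", "blackjax_nuts", "numpyro_nuts", "vi", "laplace"}

_VALID_METHODS = set(_STATIC_METHODS)
if importlib.util.find_spec("bayeux") is not None:
    # The real names would come from bayeux's own registry; this is a stand-in.
    _VALID_METHODS.add("placeholder_sampler")

def _check_inference_method(name: str) -> None:
    # Static checkers still only know the Literal; this adds a runtime safety net.
    if name not in _VALID_METHODS:
        raise ValueError(f"Unknown inference_method: {name!r}")
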
init: str = "auto",
n_init: int = 50000,
chains: Optional[int] = None,
cores: Optional[int] = None,
random_seed: Optional[int] = None,
**kwargs,
):
) -> InferenceData: ...

@overload
def fit(
self,
draws: int = 1000,
tune: int = 1000,
discard_tuned_samples: bool = True,
omit_offsets: bool = True,
include_mean: bool = False,
inference_method: Literal["vi"] = ...,
init: str = "auto",
n_init: int = 50000,
chains: Optional[int] = None,
cores: Optional[int] = None,
random_seed: Optional[int] = None,
**kwargs,
) -> pm.MeanField: ...

def fit(
self,
draws: int = 1000,
tune: int = 1000,
discard_tuned_samples: bool = True,
omit_offsets: bool = True,
include_mean: bool = False,
inference_method: Literal[
"mcmc", "blackjax_nuts", "numpyro_nuts", "vi", "laplace"
] = "mcmc",
init: str = "auto",
n_init: int = 50000,
chains: Optional[int] = None,
cores: Optional[int] = None,
random_seed: Optional[int] = None,
**kwargs,
) -> InferenceData | pm.MeanField:
"""Fit the model using PyMC.

Parameters
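Taken together, the two overloads and the implementation signature let a static type checker infer fit's return type from the inference_method argument. A hypothetical usage sketch (the formula and data are placeholders, not from this PR):

import pandas as pd
import bambi as bmb

data = pd.DataFrame({"y": [1.0, 2.0, 3.0], "x": [0.1, 0.2, 0.3]})
model = bmb.Model("y ~ x", data)

idata = model.fit(draws=500)               # checker infers InferenceData
approx = model.fit(inference_method="vi")  # checker infers pm.MeanField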
@@ -722,7 +760,13 @@ def plot_priors(
)
return axes

def prior_predictive(self, draws=500, var_names=None, omit_offsets=True, random_seed=None):
def prior_predictive(
self,
draws: int = 500,
var_names: Optional[Union[str, list[str]]] = None,
omit_offsets: bool = True,
random_seed: Optional[int] = None,
) -> InferenceData:
"""Generate samples from the prior predictive distribution.

Parameters
@@ -763,15 +807,37 @@ def prior_predictive(self, draws=500, var_names=None, omit_offsets=True, random_seed=None):

return idata

@overload
def predict(
self,
idata,
kind="mean",
data=None,
inplace=True,
include_group_specific=True,
sample_new_groups=False,
):
idata: InferenceData,
kind: Literal["mean", "pps"] = "mean",
data: Optional[pd.DataFrame] = None,
inplace: Literal[True] = True,
Collaborator comment: Shouldn't this be bool instead of Literal[True]? I'm not a typing expert, so I may be missing some details, but since this can be either True or False, I think it should be bool.
include_group_specific: bool = True,
sample_new_groups: bool = False,
) -> None: ...

@overload
def predict(
self,
idata: InferenceData,
kind: Literal["mean", "pps"] = "mean",
data: Optional[pd.DataFrame] = None,
inplace: Literal[False] = False,
Collaborator comment: Same as above.
include_group_specific: bool = True,
sample_new_groups: bool = False,
) -> InferenceData: ...

def predict(
self,
idata: InferenceData,
kind: Literal["mean", "pps"] = "mean",
data: Optional[pd.DataFrame] = None,
inplace: bool = True,
include_group_specific: bool = True,
sample_new_groups: bool = False,
) -> Optional[InferenceData]:
"""Predict method for Bambi models

Obtains in-sample and out-of-sample predictions from a fitted Bambi model.
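As context for the inplace comments above: Literal[True]/Literal[False] in paired overloads is the standard idiom for flag-dependent return types. If both overloads used plain bool, the checker could not tell the two signatures apart; with Literals it narrows the return type from the value actually passed. A self-contained sketch of the pattern (toy names, not Bambi code):

from typing import Literal, Optional, overload

class Result:
    """Stand-in for InferenceData."""

@overload
def predict(inplace: Literal[True] = True) -> None: ...
@overload
def predict(inplace: Literal[False] = False) -> Result: ...

def predict(inplace: bool = True) -> Optional[Result]:
    # The mutating path returns None; the non-mutating path returns a new object.
    return None if inplace else Result()

r1 = predict()               # checker infers None (default inplace=True)
r2 = predict(inplace=False)  # checker infers Result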