Skip to content

Commit

Permalink
Add docstrings to utility functions (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Jul 8, 2023
1 parent 67151b3 commit 6268ccf
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 57 deletions.
3 changes: 3 additions & 0 deletions bambi/defaults/hsgp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""This module contains the default priors for the parameters of different covariance kernels
that can be used in HSGP terms."""

# fmt: off
HSGP_COV_PARAMS_DEFAULT_PRIORS = {
"ExpQuad": {
Expand Down
8 changes: 8 additions & 0 deletions bambi/families/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ class ZeroInflatedPoisson(UnivariateFamily):

# pylint: disable = protected-access
def get_success_level(term):
"""Returns the success level of a categorical term.
Whenever the concept of "success level" does not apply, it returns None.
"""
if term.kind != "categoric":
return None

Expand All @@ -444,6 +448,10 @@ def get_success_level(term):

# pylint: disable = protected-access
def get_reference_level(term):
"""Returns the reference level of a categorical term.
Whenever the concept of "reference level" does not apply, it returns None.
"""
if term.kind != "categoric":
return None

Expand Down
11 changes: 10 additions & 1 deletion bambi/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def check_additionals(self, additionals: Sequence[str]):
return additionals

def check_additional(self, additional: str):
"""Check if an additional formula match the expected format
"""Check if an additional formula matches the expected format
Parameters
----------
Expand Down Expand Up @@ -91,11 +91,20 @@ def __repr__(self):


def formula_has_intercept(formula: str) -> bool:
    """Determine whether a model formula results in a model with an intercept.

    Parameters
    ----------
    formula : str
        The model formula. It is passed unchanged to ``fm.model_description``.

    Returns
    -------
    bool
        ``True`` if at least one term of the resulting model description is an
        instance of ``fm.terms.Intercept``; ``False`` otherwise.
    """
    description = fm.model_description(formula)
    return any(isinstance(term, fm.terms.Intercept) for term in description.terms)


def check_ordinal_formula(formula: Formula) -> Formula:
"""Check if a supplied formula can be used with an ordinal model.
Ordinal models have the following constraints (for the moment):
* A single formula must be passed. This is because Bambi does not support modeling the
thresholds as a function of predictors.
* The intercept is omitted. This is to avoid non-identifiability issues between the intercept
and the thresholds.
"""
if len(formula.additionals) > 0:
raise ValueError("Ordinal families don't accept multiple formulas")
if formula_has_intercept(formula.main):
Expand Down
6 changes: 0 additions & 6 deletions bambi/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,3 @@ def prepare_prior(prior, kind, auto_scale):
else:
raise ValueError("'prior' must be instance of Prior or None.")
return prior


def with_suffix(value, suffix):
    """Append ``suffix`` to ``value``, separated by an underscore.

    Parameters
    ----------
    value
        The base value. It is interpolated into an f-string when a suffix is used.
    suffix
        The suffix to append. Any falsy suffix (for example ``None`` or ``""``)
        means that no suffix is added.

    Returns
    -------
    The string ``f"{value}_{suffix}"`` when ``suffix`` is truthy, otherwise
    ``value`` itself, unchanged and not converted to a string.
    """
    if suffix:
        return f"{value}_{suffix}"
    return value
15 changes: 12 additions & 3 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ def plot_priors(
omit_offsets=True,
omit_group_specific=True,
ax=None,
**kwargs,
):
"""
Samples from the prior distribution and plots its marginals.
Expand Down Expand Up @@ -707,6 +708,7 @@ def plot_priors(
kind=kind,
bins=bins,
ax=ax,
**kwargs,
)
return axes

Expand Down Expand Up @@ -1016,7 +1018,11 @@ def distributional_components(self):
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}


def with_categorical_cols(data, columns):
def with_categorical_cols(data: pd.DataFrame, columns) -> pd.DataFrame:
"""Convert selected columns of a DataFrame to categorical type.
It converts all object columns plus columns specified in the `columns` argument.
"""
# Convert 'object' and explicitly asked columns to categorical.
object_columns = list(data.select_dtypes("object").columns)
to_convert = list(set(object_columns + listify(columns)))
Expand All @@ -1026,18 +1032,21 @@ def with_categorical_cols(data, columns):
return data


def prior_repr(term) -> str:
    """Get a string representation of a Bambi term.

    Parameters
    ----------
    term
        Any object exposing ``name`` and ``prior`` attributes.

    Returns
    -------
    str
        The term name and its prior joined as ``"<name> ~ <prior>"``.
    """
    return f"{term.name} ~ {term.prior}"


def hsgp_repr(term) -> str:
    """Get a string representation of a Bambi HSGP term.

    Parameters
    ----------
    term
        An object exposing ``name``, ``cov`` and a mapping-like ``prior``
        (iterated through ``prior.items()``).

    Returns
    -------
    str
        A multi-line string. The first line is the term name; it is followed by one
        indented line with the covariance (``cov: ...``) and one indented line per
        entry of ``term.prior`` (``key ~ value``).
    """
    output_list = [f"cov: {term.cov}", *[f"{key} ~ {value}" for key, value in term.prior.items()]]
    # Indent the details so they read as belonging to the term named on the first line.
    output_list = [" " + element for element in output_list]
    output_list.insert(0, term.name)
    return "\n".join(output_list)


def make_priors_summary(component: DistributionalComponent) -> str:
"""Get a summary of terms and priors in a distributional component."""
# Common effects
priors_common = [
prior_repr(term) for term in component.common_terms.values() if term.kind != "offset"
Expand Down
56 changes: 14 additions & 42 deletions bambi/terms/utils.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,32 @@
from formulae.terms.call import Call
import formulae as fm
from formulae.terms.call_resolver import get_function_from_module


# pylint: disable = protected-access
def get_reference_level(term):
if term.kind != "categoric":
return None
def is_single_component(term) -> bool:
    """Determine if a formulae term contains a single component.

    Parameters
    ----------
    term
        The object to inspect.

    Returns
    -------
    bool
        ``True`` only when ``term`` has a ``components`` attribute whose length is
        exactly one. Objects without a ``components`` attribute yield ``False``.
    """
    return hasattr(term, "components") and len(term.components) == 1

if term.levels is None:
return None

levels = term.levels
intermediate_data = term.components[0]._intermediate_data
if hasattr(intermediate_data, "_contrast"):
return intermediate_data._contrast.reference

return levels[0]


# pylint: disable = protected-access
def get_success_level(term):
    """Return the success level of a categorical term.

    Parameters
    ----------
    term
        An object exposing ``kind``, ``levels`` and ``components``.

    Returns
    -------
    ``None`` when ``term.kind`` is not ``"categoric"``. Otherwise:

    * if ``term.levels`` is ``None``, the ``reference`` attribute of the first component;
    * else, if the first component's intermediate data carries a ``_contrast``, that
      contrast's ``reference``;
    * else, the first entry of ``term.levels``.
    """
    if term.kind != "categoric":
        return None

    if term.levels is None:
        return term.components[0].reference

    levels = term.levels
    intermediate_data = term.components[0]._intermediate_data
    # A contrast, when present, takes precedence over the order of the levels.
    if hasattr(intermediate_data, "_contrast"):
        return intermediate_data._contrast.reference

    return levels[0]


def is_single_component(term):
    """Determine if the formulae term wrapped by ``term`` has a single component.

    Parameters
    ----------
    term
        An object whose ``term`` attribute exposes a sized ``components`` attribute.

    Returns
    -------
    bool
        ``True`` when ``term.term.components`` has exactly one element.
    """
    return len(term.term.components) == 1


def extract_first_component(term):
    """Return the first component of the formulae term wrapped by ``term``.

    Parameters
    ----------
    term
        An object whose ``term`` attribute exposes an indexable ``components`` attribute.

    Returns
    -------
    The element at position 0 of ``term.term.components``.
    """
    return term.term.components[0]


def is_call_component(component) -> bool:
    """Determine if a formulae component is the result of a function call.

    Parameters
    ----------
    component
        The object to inspect.

    Returns
    -------
    bool
        ``True`` when ``component`` is an instance of ``fm.terms.call.Call``.
    """
    return isinstance(component, fm.terms.call.Call)


def is_call_of_kind(call, kind):
    """Determine if a formulae call component is of a certain kind.

    To do so, it checks whether the callee has metadata and whether the 'kind' slot
    matches the kind passed to the function.

    Parameters
    ----------
    call
        A call component exposing ``call.callee`` and ``env``; both are handed to
        ``get_function_from_module`` to resolve the called function.
    kind
        The value compared against ``function.__metadata__["kind"]``.

    Returns
    -------
    bool
        ``True`` when the resolved function has a ``__metadata__`` attribute and its
        ``"kind"`` entry equals ``kind``; ``False`` otherwise.
    """
    function = get_function_from_module(call.call.callee, call.env)
    return hasattr(function, "__metadata__") and function.__metadata__["kind"] == kind


def is_censored_response(term):
    """Determine if a formulae term represents a censored response.

    Parameters
    ----------
    term
        The term to inspect. Its first component is read from ``term.term.components``.

    Returns
    -------
    bool
        ``False`` when the term is not a single-component term or when that component
        is not a call component. Otherwise, whether the call is of kind ``"censored"``.
    """
    if not is_single_component(term):
        return False
    component = term.term.components[0]  # get the first (and single) component
    if not is_call_component(component):
        return False
    return is_call_of_kind(component, "censored")
26 changes: 21 additions & 5 deletions bambi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def multilinify(sequence: Sequence[str], sep: str = ",") -> str:
return "\n" + sep.join(sequence)


def wrapify(string, width=100, indentation=2):
def wrapify(string: str, width: int = 100, indentation: int = 2) -> str:
"""Wraps long strings into multiple lines.
This function is used to print the model summary.
"""
lines = string.splitlines(True)
wrapper = textwrap.TextWrapper(width=width)
for idx, line in enumerate(lines):
Expand Down Expand Up @@ -146,30 +150,42 @@ def get_aliased_name(term):
return term.name


def is_single_component(term) -> bool:
    """Determine if a formulae term contains a single component.

    Parameters
    ----------
    term
        The object to inspect.

    Returns
    -------
    bool
        ``True`` only when ``term`` has a ``components`` attribute whose length is
        exactly one. Objects without a ``components`` attribute yield ``False``.
    """
    return hasattr(term, "components") and len(term.components) == 1


def is_call_component(component) -> bool:
    """Determine if a formulae component is the result of a function call.

    Parameters
    ----------
    component
        The object to inspect.

    Returns
    -------
    bool
        ``True`` when ``component`` is an instance of ``fm.terms.call.Call``.
    """
    return isinstance(component, fm.terms.call.Call)


def is_stateful_transform(component) -> bool:
    """Determine if a formulae call component is a stateful transformation.

    Parameters
    ----------
    component
        A call component exposing ``call.stateful_transform``.

    Returns
    -------
    bool
        ``True`` when ``component.call.stateful_transform`` is not ``None``.
    """
    return component.call.stateful_transform is not None


# Keep the previous name available so existing callers keep working.
has_stateful_transform = is_stateful_transform


def is_hsgp_term(term):
    """Determine if a formulae term represents an HSGP term.

    Bambi uses this function to detect HSGP terms and treat them in a different way.

    Parameters
    ----------
    term
        The term to inspect.

    Returns
    -------
    bool
        ``True`` only when the term has a single component, that component is a call
        component carrying a stateful transform, and the transform is an ``HSGP``
        instance. ``False`` in every other case.
    """
    if not is_single_component(term):
        return False
    component = term.components[0]
    if not is_call_component(component):
        return False
    if not is_stateful_transform(component):
        return False
    return isinstance(component.call.stateful_transform, HSGP)


def remove_common_intercept(dm: fm.matrices.DesignMatrices) -> fm.matrices.DesignMatrices:
"""Removes the intercept from the common design matrix
This is used in ordinal families, where the intercept is requested but not used because its
inclusion, together with the cutpoints, would create a non-identifiability problem.
"""
dm.common.terms.pop("Intercept")
intercept_slice = dm.common.slices.pop("Intercept")
dm.common.design_matrix = np.delete(dm.common.design_matrix, intercept_slice, axis=1)
Expand Down

0 comments on commit 6268ccf

Please sign in to comment.