diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py index 88374738b..873cf6699 100644 --- a/bayesflow/adapters/transforms/filter_transform.py +++ b/bayesflow/adapters/transforms/filter_transform.py @@ -19,6 +19,11 @@ def __call__(self, key: str, value: np.ndarray, inverse: bool) -> bool: @serializable(package="bayesflow.adapters") class FilterTransform(Transform): + """ + Implements a transform that applies a different transform on a subset of the data. Used by other transforms and + base adapter class. + """ + def __init__( self, *, diff --git a/bayesflow/adapters/transforms/rename.py b/bayesflow/adapters/transforms/rename.py index 996331676..b8b1a8359 100644 --- a/bayesflow/adapters/transforms/rename.py +++ b/bayesflow/adapters/transforms/rename.py @@ -7,6 +7,24 @@ @serializable(package="bayesflow.adapters") class Rename(Transform): + """ + Transform to rename keys in data dictionary. Useful to rename variables to match those required by + approximator. This transform can only rename one variable at a time. + + Parameters: + - from_key: str of variable name that should be renamed + - to_key: str representing new name + + Example: + adapter = ( + bf.adapters.Adapter() + + # rename the variables to match the required approximator inputs + .rename("theta", "inference_variables") + .rename("x", "inference_conditions") + ) + """ + def __init__(self, from_key: str, to_key: str): super().__init__() self.from_key = from_key