Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add full support for contrasts to Formulaic #70

Merged
merged 5 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 55 additions & 29 deletions formulaic/materializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def _encode_evaled_factor(
drop_rows: set,
reduced_rank: bool = False,
) -> Dict[str, Any]:
if not isinstance(factor.values, dict) or not factor.metadata.encoded:
if not factor.metadata.encoded:
if factor.expr in self.encoded_cache:
encoded = self.encoded_cache[factor.expr]
elif (factor.expr, reduced_rank) in self.encoded_cache:
Expand Down Expand Up @@ -551,40 +551,64 @@ def wrapped(values, metadata, state, *args, **kwargs):
)
if nested_state:
state[k] = nested_state
return encoded
if isinstance(values, FactorValues):
return FactorValues(
encoded, metadata=values.__formulaic_metadata__
)
return encoded # pragma: no cover; nothing in formulaic uses this, but is here for generality.
return f(values, metadata, state, *args, **kwargs)

return wrapped

# If we need to unpack values into columns, we do this here.
# Otherwise, we pass through the original values.
factor_values = FactorValues(
self._extract_columns_for_encoding(factor),
metadata=factor.metadata,
)

encoder_state = spec.encoder_state.get(factor.expr, [None, {}])[1]
if factor.metadata.kind is Factor.Kind.CATEGORICAL:
encoded = map_dict(self._encode_categorical)(
factor_values,
factor.metadata,
encoder_state,
spec,
drop_rows,
reduced_rank=reduced_rank,
)
elif factor.metadata.kind is Factor.Kind.NUMERICAL:
encoded = map_dict(self._encode_numerical)(
factor_values, factor.metadata, encoder_state, spec, drop_rows
)
elif factor.metadata.kind is Factor.Kind.CONSTANT:
encoded = map_dict(self._encode_constant)(
factor_values, factor.metadata, encoder_state, spec, drop_rows

if factor.metadata.encoder is not None:
encoded = as_columns(
factor.metadata.encoder(
factor.values,
reduced_rank=reduced_rank,
drop_rows=drop_rows,
encoder_state=encoder_state,
model_spec=spec,
)
)
else:
raise FactorEncodingError(
factor
) # pragma: no cover; it is not currently possible to reach this sentinel
# If we need to unpack values into columns, we do this here.
# Otherwise, we pass through the original values.
factor_values = FactorValues(
self._extract_columns_for_encoding(factor),
metadata=factor.metadata,
)

if factor.metadata.kind is Factor.Kind.CATEGORICAL:
encoded = map_dict(self._encode_categorical)(
factor_values,
factor.metadata,
encoder_state,
spec,
drop_rows,
reduced_rank=reduced_rank,
)
elif factor.metadata.kind is Factor.Kind.NUMERICAL:
encoded = map_dict(self._encode_numerical)(
factor_values,
factor.metadata,
encoder_state,
spec,
drop_rows,
)
elif factor.metadata.kind is Factor.Kind.CONSTANT:
encoded = map_dict(self._encode_constant)(
factor_values,
factor.metadata,
encoder_state,
spec,
drop_rows,
)
else:
raise FactorEncodingError(
factor
) # pragma: no cover; it is not currently possible to reach this sentinel
spec.encoder_state[factor.expr] = (factor.metadata.kind, encoder_state)

# Only encode once for encodings where we can just drop a field
Expand All @@ -596,7 +620,9 @@ def wrapped(values, metadata, state, *args, **kwargs):

self.encoded_cache[cache_key] = encoded
else:
encoded = factor.values
encoded = as_columns(
factor.values
) # pragma: no cover; we don't use this in formulaic yet.

encoded = FactorValues(
encoded,
Expand Down
22 changes: 13 additions & 9 deletions formulaic/materializers/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas
import scipy.sparse as spsparse
from interface_meta import override
from formulaic.utils.cast import as_columns

from .base import FormulaMaterializer
from .types import NAAction
Expand All @@ -31,7 +32,9 @@ def _check_for_nulls(self, name, values, na_action, drop_rows):
if na_action is NAAction.IGNORE:
return

if isinstance(values, dict):
if isinstance(
values, dict
): # pragma: no cover; no formulaic transforms return dictionaries any more
for key, vs in values.items():
self._check_for_nulls(f"{name}[{key}]", vs, na_action, drop_rows)

Expand Down Expand Up @@ -74,16 +77,18 @@ def _encode_categorical(
# Even though we could reduce rank here, we do not, so that the same
# encoding can be cached for both reduced and unreduced rank. The
# rank will be reduced in the _encode_evaled_factor method.
from formulaic.transforms import encode_categorical
from formulaic.transforms import encode_contrasts

if drop_rows:
values = values.drop(index=values.index[drop_rows])
return encode_categorical(
values,
reduced_rank=False,
_metadata=metadata,
_state=encoder_state,
_spec=spec,
return as_columns(
encode_contrasts(
values,
reduced_rank=False,
_metadata=metadata,
_state=encoder_state,
_spec=spec,
)
)

@override
Expand Down Expand Up @@ -132,7 +137,6 @@ def _get_columns_for_term(self, factors, spec, scale=1):

@override
def _combine_columns(self, cols, spec, drop_rows):

# If we are outputing a pandas DataFrame, explicitly override index
# in case transforms/etc have lost track of it.
if spec.output == "pandas":
Expand Down
13 changes: 12 additions & 1 deletion formulaic/materializers/types/factor_values.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Generic, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import wrapt

Expand All @@ -28,6 +28,12 @@ class FactorValuesMetadata:
format: The format to use when exploding factors into multiple columns
(e.g. when encoding categories via dummy-encoding).
encoded: Whether the values should be treated as pre-encoded.
encoder: An optional callable with signature
`(values: Any, reduced_rank: bool, drop_rows: List[int], encoder_state: Dict[str, Any], spec: ModelSpec)`
that outputs properly encoded values suitable for the current
materializer. Note that this should only be used in cases where
direct evaluation would yield different results in reduced vs.
non-reduced rank scenarios.
"""

kind: Factor.Kind = Factor.Kind.UNKNOWN
Expand All @@ -36,6 +42,7 @@ class FactorValuesMetadata:
drop_field: Optional[str] = None
format: str = "{name}[{field}]"
encoded: bool = False
encoder: Optional[Callable[[Any, bool, List[int], Dict[str, Any]], Any]] = None

def replace(self, **kwargs) -> FactorValuesMetadata:
"""
Expand Down Expand Up @@ -65,6 +72,9 @@ def __init__(
drop_field: Optional[str] = MISSING,
format: str = MISSING,
encoded: bool = MISSING,
encoder: Optional[
Callable[[Any, bool, List[int], Dict[str, Any]], Any]
] = MISSING,
):
metadata_constructor = FactorValuesMetadata
metadata_kwargs = dict(
Expand All @@ -74,6 +84,7 @@ def __init__(
drop_field=drop_field,
format=format,
encoded=encoded,
encoder=encoder,
)
for key in set(metadata_kwargs):
if metadata_kwargs[key] is MISSING:
Expand Down
16 changes: 14 additions & 2 deletions formulaic/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,21 @@

from .basis_spline import basis_spline
from .identity import identity
from .encode_categorical import encode_categorical
from .contrasts import C, encode_contrasts, ContrastsRegistry
from .poly import poly
from .scale import center, scale

__all__ = [
"basis_spline",
"identity",
"C",
"encode_contrasts",
"ContrastsRegistry",
"poly",
"center",
"scale",
"TRANSFORMS",
]

TRANSFORMS = {
# Common transforms
Expand All @@ -21,6 +32,7 @@
"center": center,
"poly": poly,
"scale": scale,
"C": encode_categorical,
"C": C,
"contr": ContrastsRegistry,
"I": identity,
}
Loading