Skip to content
This repository has been archived by the owner on Jan 9, 2024. It is now read-only.

Commit

Permalink
Add back override feature lost in previous merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhang-gp committed Dec 3, 2019
1 parent e845600 commit eeede46
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
15 changes: 9 additions & 6 deletions foreshadow/smart/intent_resolving/intentresolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
IntentResolver as AutoIntentResolver,
)
from foreshadow.smart.smart import SmartTransformer
from foreshadow.utils import get_transformer
from foreshadow.utils import Override, get_transformer


_temporary_naming_conversion = {
Expand Down Expand Up @@ -52,7 +52,7 @@ def _resolve_intent(self, X, y=None):
# TODO Add sampling on X to reduce run time if the dataset is big
auto_intent_resolver = AutoIntentResolver(X)
intent_pd_series = auto_intent_resolver.predict()
return intent_pd_series[[0]].values[0]
return _temporary_naming_convert(intent_pd_series[[0]].values[0])

def resolve(self, X, *args, **kwargs):
"""Pick the appropriate transformer if necessary.
Expand Down Expand Up @@ -94,9 +94,12 @@ def pick_transformer(self, X, y=None, **fit_params):
Best intent transformer.
"""
intent_class_name = self._resolve_intent(X, y=y)
intent_class = get_transformer(
_temporary_naming_convert(intent_class_name)
)
column = X.columns[0]
override_key = "_".join([Override.INTENT, column])
if override_key in self.cache_manager["override"]:
intent_override = self.cache_manager["override"][override_key]
intent_class = get_transformer(intent_override)
else:
intent_class = get_transformer(self._resolve_intent(X, y=y))

return intent_class()
8 changes: 4 additions & 4 deletions foreshadow/tests/test_foreshadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,12 +818,12 @@ def test_foreshadow_serialization_adults_small_classification_override():

from foreshadow.intents import IntentType

shadow.override_intent("age", IntentType.NUMERIC)
shadow.override_intent("age", IntentType.CATEGORICAL)
shadow.override_intent("workclass", IntentType.CATEGORICAL)
shadow.fit(X_train, y_train)
shadow.to_json("foreshadow_adults_small_logistic_regression_2.json")

assert shadow.get_intent("age") == IntentType.NUMERIC
assert shadow.get_intent("age") == IntentType.CATEGORICAL
assert shadow.get_intent("workclass") == IntentType.CATEGORICAL
score2 = shadow.score(X_test, y_test)

Expand Down Expand Up @@ -854,10 +854,10 @@ def test_foreshadow_adults_small_classification_override_upfront():

from foreshadow.intents import IntentType

shadow.override_intent("age", IntentType.NUMERIC)
shadow.override_intent("age", IntentType.CATEGORICAL)
shadow.override_intent("workclass", IntentType.CATEGORICAL)
shadow.fit(X_train, y_train)
assert shadow.get_intent("age") == IntentType.NUMERIC
assert shadow.get_intent("age") == IntentType.CATEGORICAL
assert shadow.get_intent("workclass") == IntentType.CATEGORICAL
shadow.to_json(
"foreshadow_adults_small_logistic_regression_override_upfront.json"
Expand Down

0 comments on commit eeede46

Please sign in to comment.