Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix postprocessing on binary feature columns with number dtype #2189

Merged
merged 16 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions ludwig/features/binary_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ class _BinaryPostprocessing(torch.nn.Module):
def __init__(self, metadata: Dict[str, Any]):
super().__init__()
bool2str = metadata.get("bool2str")
# If the values in column could have been inferred as boolean dtype, do not convert preds to strings.
# This preserves the behavior of this feature before #2058.
if strings_utils.values_are_pandas_bools(bool2str):
bool2str = None
self.bool2str = {i: v for i, v in enumerate(bool2str)} if bool2str is not None else None
self.predictions_key = PREDICTIONS
self.probabilities_key = PROBABILITIES
Expand Down Expand Up @@ -159,17 +155,37 @@ def preprocessing_schema() -> Dict[str, Any]:

@staticmethod
def cast_column(column, backend):
# Binary features are always read as strings. column.astype(bool) for all non-empty cells returns True.
return column.astype(str)
"""Cast column of dtype object to bool.

Unchecked casting to boolean when given a column of dtype object converts all non-empty cells to True. We check
the values of the column directly and manually determine the best dtype to use.
"""
values = backend.df_engine.compute(column.drop_duplicates())

if strings_utils.values_are_pandas_numbers(values):
# If numbers, convert to float so it can be converted to bool
column = column.astype(float).astype(bool)
elif strings_utils.values_are_pandas_bools(values):
# If booleans, manually assign boolean values
column = backend.df_engine.map_objects(
column, lambda x: x.lower() in strings_utils.PANDAS_TRUE_STRS
).astype(bool)
else:
# If neither numbers nor booleans, they are strings (objects)
column = column.astype(object)
return column

@staticmethod
def get_feature_meta(column: DataFrame, preprocessing_parameters: Dict[str, Any], backend) -> Dict[str, Any]:
distinct_values = backend.df_engine.compute(column.drop_duplicates()).tolist()
if column.dtype != object:
return {}

distinct_values = backend.df_engine.compute(column.drop_duplicates())
if len(distinct_values) > 2:
raise ValueError(
f"Binary feature column {column.name} expects 2 distinct values, " f"found: {distinct_values}"
f"Binary feature column {column.name} expects 2 distinct values, "
f"found: {distinct_values.values.tolist()}"
)

if "fallback_true_label" in preprocessing_parameters:
fallback_true_label = preprocessing_parameters["fallback_true_label"]
else:
Expand Down Expand Up @@ -201,9 +217,7 @@ def add_feature_data(
) -> None:
column = input_df[feature_config[COLUMN]]

column_np_dtype = np.dtype(column.dtype)
# np.str_ is dtype of modin cols
if any(column_np_dtype == np_dtype for np_dtype in {np.object_, np.str_}):
if column.dtype == object:
metadata = metadata[feature_config[NAME]]
if "str2bool" in metadata:
column = backend.df_engine.map_objects(column, lambda x: metadata["str2bool"][str(x)])
Expand Down Expand Up @@ -373,13 +387,10 @@ def postprocess_predictions(
predictions_col = f"{self.feature_name}_{PREDICTIONS}"
if predictions_col in result:
if "bool2str" in metadata:
# If the values in column could have been inferred as boolean dtype, do not convert preds to strings.
# This preserves the behavior of this feature before #2058.
if not strings_utils.values_are_pandas_bools(class_names):
result[predictions_col] = backend.df_engine.map_objects(
result[predictions_col],
lambda pred: metadata["bool2str"][pred],
)
result[predictions_col] = backend.df_engine.map_objects(
result[predictions_col],
lambda pred: metadata["bool2str"][pred],
)

probabilities_col = f"{self.feature_name}_{PROBABILITIES}"
if probabilities_col in result:
Expand Down
15 changes: 10 additions & 5 deletions ludwig/utils/strings_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,18 @@ def str2bool(v: str, fallback_true_label=None) -> bool:
return v == fallback_true_label


def column_is_bool(column: Series) -> bool:
"""Returns whether a column could have been cast by read_csv as boolean."""
distinct_values = column.drop_duplicates()
return values_are_pandas_bools(distinct_values)
def values_are_pandas_numbers(values: List[str]) -> bool:
    """Returns True if values would be read by pandas as dtype float or int.

    Every value is probed with ``float()``. The first value that cannot be converted makes the
    whole collection non-numeric. An empty collection yields True because no value fails.

    Args:
        values: The values to inspect (typically the distinct values of a column).

    Returns:
        True if every value can be converted with ``float()``, False otherwise.
    """
    for v in values:
        try:
            float(v)
        except (TypeError, ValueError):
            # ValueError: a string that does not spell a number (e.g. "yes").
            # TypeError: an object float() does not accept at all (e.g. None); without catching it
            # the check would crash instead of answering "not a number".
            return False
    return True


def values_are_pandas_bools(values: List[Union[str, bool]]):
def values_are_pandas_bools(values: List[str]):
"""Returns True if values would be read by pandas as dtype bool."""
lowercase_values_set = {str(v).lower() for v in values}
return lowercase_values_set.issubset(PANDAS_FALSE_STRS | PANDAS_TRUE_STRS)

Expand Down
17 changes: 6 additions & 11 deletions ludwig/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from ludwig.utils.dataframe_utils import to_numpy_dataset, unflatten_df
from ludwig.utils.misc_utils import get_from_registry
from ludwig.utils.print_utils import logging_level_registry
from ludwig.utils.strings_utils import column_is_bool

logger = logging.getLogger(__name__)

Expand All @@ -68,15 +67,8 @@ def _convert_ground_truth(ground_truth, feature_metadata, ground_truth_apply_idx
# non-standard boolean representation
ground_truth = _vectorize_ground_truth(ground_truth, feature_metadata["str2bool"], ground_truth_apply_idx)
else:
# If the values in column could have been inferred as boolean dtype, cast (strings) as booleans.
# This preserves the behavior of this feature before #2058.
if column_is_bool(ground_truth):
ground_truth = _vectorize_ground_truth(
ground_truth, {"false": False, "False": False, "true": True, "True": True}, ground_truth_apply_idx
)
else:
# standard boolean representation
ground_truth = ground_truth.values
# standard boolean representation
ground_truth = ground_truth.values

# ensure positive_label is 1 for binary feature
positive_label = 1
Expand Down Expand Up @@ -233,7 +225,10 @@ def _extract_ground_truth_values(
reader = get_from_registry(data_format, external_data_reader_registry)

# retrieve ground truth from source data set
gt_df = reader(ground_truth)
if data_format in {"csv", "tsv"}:
gt_df = reader(ground_truth, dtype=None) # allow type inference
else:
gt_df = reader(ground_truth)

# extract ground truth for visualization
if SPLIT in gt_df:
Expand Down
114 changes: 95 additions & 19 deletions tests/integration_tests/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================

import os
from functools import partial
from unittest import mock

import numpy as np
Expand All @@ -23,11 +24,19 @@

from ludwig.api import LudwigModel
from ludwig.constants import NAME, TRAINER
from tests.integration_tests.utils import binary_feature, category_feature, generate_data
from tests.integration_tests.utils import (
binary_feature,
category_feature,
generate_data,
init_backend,
RAY_BACKEND_CONFIG,
)


@pytest.mark.distributed
@pytest.mark.parametrize("backend", ["local", "ray"])
@pytest.mark.parametrize("distinct_values", [(False, True), ("No", "Yes")])
def test_binary_predictions(tmpdir, distinct_values):
def test_binary_predictions(tmpdir, backend, distinct_values):
input_features = [
category_feature(vocab_size=3),
]
Expand All @@ -41,31 +50,18 @@ def test_binary_predictions(tmpdir, distinct_values):
input_features,
output_features,
os.path.join(tmpdir, "dataset.csv"),
num_examples=100,
)
data_df = pd.read_csv(data_csv_path)

# Optionally convert bool values to strings, e.g., {'Yes', 'No'}
false_value, true_value = distinct_values
data_df[feature[NAME]] = data_df[feature[NAME]].map(lambda x: true_value if x else false_value)
data_df.to_csv(data_csv_path, index=False)

config = {"input_features": input_features, "output_features": output_features, TRAINER: {"epochs": 1}}
ludwig_model = LudwigModel(config)
_, _, output_directory = ludwig_model.train(
dataset=data_df,
output_directory=os.path.join(tmpdir, "output"),
)

# Check that metadata JSON saves and loads correctly
ludwig_model = LudwigModel.load(os.path.join(output_directory, "model"))

# Produce an even mix of True and False predictions, as the model may be biased towards
# one direction without training
def random_logits(*args, **kwargs):
return torch.tensor(np.random.uniform(low=-1.0, high=1.0, size=(len(data_df),)))

with mock.patch("ludwig.features.binary_feature.BinaryOutputFeature.logits", random_logits):
preds_df, _ = ludwig_model.predict(dataset=data_csv_path)

preds_df = predict_with_backend(tmpdir, config, data_csv_path, backend, num_predict_samples=len(data_df))
cols = set(preds_df.columns)
assert f"{feature[NAME]}_predictions" in cols
assert f"{feature[NAME]}_probabilities_{str(false_value)}" in cols
Expand All @@ -83,4 +79,84 @@ def random_logits(*args, **kwargs):
assert prob_1 == prob
else:
assert prob_0 == prob
assert prob_0 == 1 - prob_1
assert np.allclose(prob_0, 1 - prob_1)


@pytest.mark.distributed
@pytest.mark.parametrize("backend", ["local", "ray"])
@pytest.mark.parametrize("distinct_values", [(0.0, 1.0), (0, 1)])
def test_binary_predictions_with_number_dtype(tmpdir, backend, distinct_values):
    """Binary output columns stored as numbers must yield boolean predictions and True/False probability columns."""
    feature = binary_feature()
    input_features = [category_feature(vocab_size=3)]
    output_features = [feature]

    csv_target = os.path.join(tmpdir, "dataset.csv")
    data_csv_path = generate_data(input_features, output_features, csv_target, num_examples=100)
    data_df = pd.read_csv(data_csv_path)

    # Rewrite the generated binary column using the numeric (false, true) pair under test,
    # then persist it so training and prediction read the numeric representation.
    false_value, true_value = distinct_values
    target_name = feature[NAME]
    data_df[target_name] = data_df[target_name].map(lambda x: true_value if x else false_value)
    data_df.to_csv(data_csv_path, index=False)

    config = {"input_features": input_features, "output_features": output_features, TRAINER: {"epochs": 1}}

    preds_df = predict_with_backend(tmpdir, config, data_csv_path, backend, num_predict_samples=len(data_df))

    predictions_col = f"{feature[NAME]}_predictions"
    prob_false_col = f"{feature[NAME]}_probabilities_False"
    prob_true_col = f"{feature[NAME]}_probabilities_True"
    probability_col = f"{feature[NAME]}_probability"

    cols = set(preds_df.columns)
    for expected_col in (predictions_col, prob_false_col, prob_true_col, probability_col):
        assert expected_col in cols

    rows = zip(
        preds_df[predictions_col],
        preds_df[prob_false_col],
        preds_df[prob_true_col],
        preds_df[probability_col],
    )
    for pred, prob_0, prob_1, prob in rows:
        assert isinstance(pred, bool)
        # The reported probability is the one belonging to the predicted class.
        assert (prob_1 if pred else prob_0) == prob
        assert np.allclose(prob_0, 1 - prob_1)


def predict_with_backend(tmpdir, config, data_csv_path, backend, num_predict_samples=None):
    """Train a model with the given backend, reload it from disk, and return its predictions.

    Args:
        tmpdir: Directory under which the ``output`` training directory is created.
        config: Ludwig model config passed to ``LudwigModel``.
        data_csv_path: Path of the CSV used both for training and for prediction.
        backend: Backend name handed to ``init_backend``; ``"ray"`` is expanded to a copy of
            ``RAY_BACKEND_CONFIG`` with the processor type set to ``"dask"``.
        num_predict_samples: If given, ``BinaryOutputFeature.logits`` is patched to return this many
            uniformly random logits in [-1, 1) so predictions contain both classes.

    Returns:
        The predictions dataframe returned by ``LudwigModel.predict``.
    """
    import copy

    # NOTE(review): the scraped source lost its indentation; the whole train/predict sequence is kept
    # inside the backend context here -- confirm against the original file.
    with init_backend(backend):
        if backend == "ray":
            # Work on a private copy: RAY_BACKEND_CONFIG is a module-level dict imported by several test
            # modules, so mutating it in place would leak the processor type into every later user.
            backend = copy.deepcopy(RAY_BACKEND_CONFIG)
            backend["processor"]["type"] = "dask"

        ludwig_model = LudwigModel(config, backend=backend)
        _, _, output_directory = ludwig_model.train(
            dataset=data_csv_path,
            output_directory=os.path.join(tmpdir, "output"),
        )
        # Check that metadata JSON saves and loads correctly
        ludwig_model = LudwigModel.load(os.path.join(output_directory, "model"))

        # Produce an even mix of True and False predictions, as the model may be biased towards
        # one direction without training
        def random_logits(*args, num_predict_samples=None, **kwargs):
            return torch.tensor(np.random.uniform(low=-1.0, high=1.0, size=(num_predict_samples,)))

        if num_predict_samples is not None:
            with mock.patch(
                "ludwig.features.binary_feature.BinaryOutputFeature.logits",
                partial(random_logits, num_predict_samples=num_predict_samples),
            ):
                preds_df, _ = ludwig_model.predict(dataset=data_csv_path)
        else:
            preds_df, _ = ludwig_model.predict(dataset=data_csv_path)

        return preds_df
33 changes: 2 additions & 31 deletions tests/integration_tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import os
import tempfile

Expand All @@ -37,6 +36,8 @@
h3_feature,
image_feature,
number_feature,
RAY_BACKEND_CONFIG,
ray_start,
sequence_feature,
set_feature,
text_feature,
Expand Down Expand Up @@ -66,36 +67,6 @@ def predict_cpu(model_dir, dataset):
ray = None


# Backend configuration dict for running these tests on Ray.
# It selects the "ray" backend, sets processor parallelism to 2, and configures the trainer
# with GPU use disabled and 2 workers, each requesting 0.1 CPU and 0 GPUs.
RAY_BACKEND_CONFIG = {
"type": "ray",
"processor": {
"parallelism": 2,
},
"trainer": {
"use_gpu": False,
"num_workers": 2,
"resources_per_worker": {
"CPU": 0.1,
"GPU": 0,
},
},
}


@contextlib.contextmanager
def ray_start(num_cpus=2, num_gpus=None):
    """Context manager that initialises Ray for the ``with`` body and shuts it down on exit.

    Args:
        num_cpus: Forwarded to ``ray.init`` as ``num_cpus``.
        num_gpus: Forwarded to ``ray.init`` as ``num_gpus``.

    Yields:
        Whatever ``ray.init`` returns.
    """
    init_kwargs = {
        "num_cpus": num_cpus,
        "num_gpus": num_gpus,
        "include_dashboard": False,
        "object_store_memory": 150 * 1024 * 1024,
    }
    handle = ray.init(**init_kwargs)
    try:
        yield handle
    finally:
        # Always shut Ray down, even when the body of the ``with`` block raises.
        ray.shutdown()


def run_api_experiment(config, dataset, backend_config, skip_save_processed_input=True):
# Sanity check that we get 4 slots over 1 host
kwargs = get_trainer_kwargs()
Expand Down
Loading