
Commit

Replace all non-word characters in feature names to ensure no downstream issues with external libraries. (#3438)
justinxzhao committed Jul 7, 2023
1 parent 231fd62 commit 60f1416
Showing 11 changed files with 160 additions and 58 deletions.
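
Before the per-file diffs, a minimal sketch (not part of the commit) of the behavior being introduced: feature names containing punctuation that trips up external libraries (e.g. LightGBM, PyTorch ModuleDict keys) are rewritten with underscores. The helper is copied from the ludwig/utils/data_utils.py hunk below; the expected outputs follow from its regex and the new unit test at the bottom of the diff.

import re

# Copy of get_sanitized_feature_name() as added in ludwig/utils/data_utils.py below.
def get_sanitized_feature_name(feature_name: str) -> str:
    return re.sub(r"[(){}.:\"\"\'\'\[\]]", "_", feature_name)

print(get_sanitized_feature_name("Pclass (new)"))  # Pclass _new_
print(get_sanitized_feature_name("review.text"))   # review_text
print(get_sanitized_feature_name("col[]:three"))   # col___three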
28 changes: 28 additions & 0 deletions examples/synthetic/train.py
@@ -0,0 +1,28 @@
"""Train a model from entirely synthetic data."""

import logging
import tempfile

import yaml

from ludwig.api import LudwigModel
from ludwig.data.dataset_synthesizer import build_synthetic_dataset_df

config = yaml.safe_load(
"""
input_features:
- name: Pclass (new)
type: category
output_features:
- name: Survived
type: binary
"""
)

df = build_synthetic_dataset_df(120, config)
model = LudwigModel(config, logging_level=logging.INFO)

with tempfile.TemporaryDirectory() as tmpdir:
model.train(dataset=df, output_directory=tmpdir)
5 changes: 5 additions & 0 deletions ludwig/data/preprocessing.py
@@ -104,6 +104,7 @@
read_spss,
read_stata,
read_tsv,
sanitize_column_names,
SAS_FORMATS,
SPSS_FORMATS,
STATA_FORMATS,
@@ -1194,6 +1195,10 @@ def build_dataset(

dataset_df = df_engine.parallelize(dataset_df)

# Ensure that column names with non-word characters won't cause problems for downstream operations.
# NOTE: Must be kept consistent with config sanitization in schema/model_types/base.py.
dataset_df = sanitize_column_names(dataset_df)

if mode == "training":
sample_ratio = global_preprocessing_parameters["sample_ratio"]
if sample_ratio < 1.0:
41 changes: 37 additions & 4 deletions ludwig/schema/model_types/base.py
@@ -7,17 +7,27 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.config_validation.checks import get_config_check_registry
from ludwig.config_validation.validation import check_schema
from ludwig.constants import BACKEND, ENCODER, MODEL_ECD
from ludwig.constants import (
BACKEND,
COLUMN,
DEPENDENCIES,
ENCODER,
INPUT_FEATURES,
MODEL_ECD,
NAME,
OUTPUT_FEATURES,
TIED,
)
from ludwig.error import ConfigValidationError
from ludwig.globals import LUDWIG_VERSION
from ludwig.schema import utils as schema_utils
from ludwig.schema.defaults.base import BaseDefaultsConfig
from ludwig.schema.features.base import BaseInputFeatureConfig, BaseOutputFeatureConfig, FeatureCollection
from ludwig.schema.hyperopt import HyperoptConfig
from ludwig.schema.model_types.utils import (
filter_combiner_entities_,
merge_fixed_preprocessing_params,
merge_with_defaults,
sanitize_and_filter_combiner_entities_,
set_derived_feature_columns_,
set_hyperopt_defaults_,
set_llm_tokenizers,
@@ -30,7 +40,7 @@
from ludwig.schema.utils import ludwig_dataclass
from ludwig.types import ModelConfigDict
from ludwig.utils.backward_compatibility import upgrade_config_dict_to_latest_version
from ludwig.utils.data_utils import load_yaml
from ludwig.utils.data_utils import get_sanitized_feature_name, load_yaml
from ludwig.utils.registry import Registry

model_type_schema_registry = Registry()
@@ -57,7 +67,7 @@ def __post_init__(self):
set_validation_parameters(self)
set_hyperopt_defaults_(self)
set_tagger_decoder_parameters(self)
filter_combiner_entities_(self)
sanitize_and_filter_combiner_entities_(self)

# Set preprocessing parameters for text features for LLM model type
set_llm_tokenizers(self)
@@ -77,6 +87,29 @@ def from_dict(config: ModelConfigDict) -> "ModelConfig":
config = copy.deepcopy(config)
config = upgrade_config_dict_to_latest_version(config)

# Use sanitized feature names.
# NOTE: This must be kept consistent with build_dataset()
for input_feature in config[INPUT_FEATURES]:
input_feature[NAME] = get_sanitized_feature_name(input_feature[NAME])
if COLUMN in input_feature and input_feature[COLUMN]:
input_feature[COLUMN] = get_sanitized_feature_name(input_feature[COLUMN])
for output_feature in config[OUTPUT_FEATURES]:
output_feature[NAME] = get_sanitized_feature_name(output_feature[NAME])
if COLUMN in output_feature and output_feature[COLUMN]:
output_feature[COLUMN] = get_sanitized_feature_name(output_feature[COLUMN])

# Sanitize tied feature names.
for input_feature in config[INPUT_FEATURES]:
if TIED in input_feature and input_feature[TIED]:
input_feature[TIED] = get_sanitized_feature_name(input_feature[TIED])

# Sanitize dependent feature names.
for output_feature in config[OUTPUT_FEATURES]:
if DEPENDENCIES in output_feature and output_feature[DEPENDENCIES]:
output_feature[DEPENDENCIES] = [
get_sanitized_feature_name(feature_name) for feature_name in output_feature[DEPENDENCIES]
]

config["model_type"] = config.get("model_type", MODEL_ECD)
model_type = config["model_type"]
if model_type not in model_type_schema_registry:
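
To see what the new from_dict() block does end to end, here is an illustrative sketch (mine, not from the commit; it assumes this minimal config validates as-is) that feeds punctuated feature names through ModelConfig.from_dict(), the same path the updated tests/integration_tests/test_gbm.py exercises:

from ludwig.schema.model_types.base import ModelConfig

config = {
    "input_features": [
        {"name": "Pclass (new)", "type": "category"},
        {"name": "other_feature", "type": "category", "tied": "Pclass (new)"},
    ],
    "output_features": [
        {"name": "Survived (new)", "type": "binary"},
        {"name": "Thrived", "type": "binary", "dependencies": ["Survived (new)"]},
    ],
}

normalized = ModelConfig.from_dict(config).to_dict()
# Expected, given the sanitization above: input names "Pclass _new_" and "other_feature",
# tied -> "Pclass _new_", output names "Survived _new_" and "Thrived",
# and Thrived's dependencies -> ["Survived _new_"].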
7 changes: 6 additions & 1 deletion ludwig/schema/model_types/utils.py
@@ -30,6 +30,7 @@
from ludwig.schema.hyperopt.scheduler import BaseHyperbandSchedulerConfig
from ludwig.schema.trainer import ECDTrainerConfig
from ludwig.types import HyperoptConfigDict, ModelConfigDict
from ludwig.utils.data_utils import get_sanitized_feature_name

if TYPE_CHECKING:
from ludwig.schema.model_types.base import ModelConfig
@@ -173,12 +174,16 @@ def set_derived_feature_columns_(config_obj: "ModelConfig"):
feature.proc_column = compute_feature_hash(feature.to_dict())


def filter_combiner_entities_(config: "ModelConfig"):
def sanitize_and_filter_combiner_entities_(config: "ModelConfig"):
if config.model_type != MODEL_ECD or config.combiner.type != "comparator":
return

input_feature_names = {input_feature.name for input_feature in config.input_features}

# Sanitize feature names.
config.combiner.entity_1 = [get_sanitized_feature_name(fname) for fname in config.combiner.entity_1]
config.combiner.entity_2 = [get_sanitized_feature_name(fname) for fname in config.combiner.entity_2]

entity_1_excluded = {fname for fname in config.combiner.entity_1 if fname not in input_feature_names}
if entity_1_excluded:
logger.warning(
4 changes: 3 additions & 1 deletion ludwig/schema/optimizers.py
@@ -543,7 +543,9 @@ def _jsonschema_type_mapping(self):
if not isinstance(default, dict):
raise ValidationError(f"Invalid default: `{default}`")

load_default = lambda: GradientClippingConfig.Schema().load(default)
def load_default():
return GradientClippingConfig.Schema().load(default)

dump_default = GradientClippingConfig.Schema().dump(default)

return field(
54 changes: 6 additions & 48 deletions ludwig/trainers/trainer_lightgbm.py
@@ -159,10 +159,6 @@ def __init__(
# and before set_steps_to_1_or_quit returns
self.original_sigint_handler = None

# LGBM Datasets do not allow JSON special characters in feature names. We sanitize the feature names before
# dataset construction but keep references to the original names to be able to cast back if needed.
self.feature_names_map: Dict[str, str] = None

@staticmethod
def get_schema_cls() -> BaseTrainerConfig:
return GBMTrainerConfig
@@ -833,8 +829,8 @@ def _construct_lgb_datasets(
validation_set: Optional["Dataset"] = None, # noqa: F821
test_set: Optional["Dataset"] = None, # noqa: F821
) -> Tuple[lgb.Dataset, List[lgb.Dataset], List[str]]:
X_train = self.sanitize_feature_names(training_set.to_scalar_df(self.model.input_features.values()))
y_train = self.sanitize_feature_names(training_set.to_scalar_df(self.model.output_features.values()))
X_train = training_set.to_scalar_df(self.model.input_features.values())
y_train = training_set.to_scalar_df(self.model.output_features.values())

# create dataset for lightgbm
# keep raw data for continued training https://github.com/microsoft/LightGBM/issues/4965#issuecomment-1019344293
@@ -852,8 +848,8 @@
eval_sets = [lgb_train]
eval_names = [LightGBMTrainer.TRAIN_KEY]
if validation_set is not None:
X_val = self.sanitize_feature_names(validation_set.to_scalar_df(self.model.input_features.values()))
y_val = self.sanitize_feature_names(validation_set.to_scalar_df(self.model.output_features.values()))
X_val = validation_set.to_scalar_df(self.model.input_features.values())
y_val = validation_set.to_scalar_df(self.model.output_features.values())
try:
lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train, free_raw_data=False).construct()
except lgb.basic.LightGBMError as e:
@@ -871,8 +867,8 @@
pass

if test_set is not None:
X_test = self.sanitize_feature_names(test_set.to_scalar_df(self.model.input_features.values()))
y_test = self.sanitize_feature_names(test_set.to_scalar_df(self.model.output_features.values()))
X_test = test_set.to_scalar_df(self.model.input_features.values())
y_test = test_set.to_scalar_df(self.model.output_features.values())
try:
lgb_test = lgb.Dataset(X_test, label=y_test, reference=lgb_train, free_raw_data=False).construct()
except lgb.basic.LightGBMError as e:
@@ -896,44 +892,6 @@ def callback(self, fn, coordinator_only=True):
for callback in self.callbacks:
fn(callback)

def sanitize_feature_names(self, df: "DataFrame") -> "DataFrame": # noqa: F821
"""Remove JSON special characters from feature names.
LightGBM Datasets raise an error when processing feature names with JSON special characters (e.g., ".", "{",
"}", "[", "]"). This method replaces non-word characters in DataFrame column names with "_" and creates a map
of `feature_name` -> `sanitized_feature_name`. Assumes that repeated calls will operate on dataframes with the
same columns.
Args:
df: The dataframe to sanitize.
Returns:
A copy of `df` with non-word characters removed from the column names.
"""
sanitizer = lambda k: re.sub(r"[\W]", "_", k)

if self.feature_names_map is None:
self.feature_names_map = {k: sanitizer(k) for k in df.columns}

return df.rename(columns=self.feature_names_map)

def desanitize_feature_names(self, df: "DataFrame") -> "DataFrame": # noqa: F821
"""Restore original feature names to a df.
The inverse of `LightGBMTrainer.sanitize_feature_names`, this method maps sanitized feature names back to their
original formats. Assumes that repeated calls will operate on dataframes with the same columns.
Args:
df: A dataframe previously run through `LightGBMTrainer.sanitize_feature_names`.
Returns:
A copy of `df` with the original feature names restored.
"""
if self.feature_names_map is None:
return df
else:
return df.rename(columns={v: k for k, v in self.feature_names_map.items()})


def _map_to_lgb_ray_params(params: Dict[str, Any]) -> "RayParams": # noqa
from lightgbm_ray import RayParams
14 changes: 14 additions & 0 deletions ludwig/utils/data_utils.py
@@ -1089,3 +1089,17 @@ def use_credentials(creds):
json.dump(old_conf, f)
conf.clear()
set_conf_files(tmpdir, conf)


def get_sanitized_feature_name(feature_name: str) -> str:
"""Replaces non-word characters (anything other than alphanumeric or _) with _.
Used in model config initialization and sanitize_column_names(), which is called during dataset building.
"""
return re.sub(r"[(){}.:\"\"\'\'\[\]]", "_", feature_name)


def sanitize_column_names(df: DataFrame) -> DataFrame:
"""Renames df columns with non-word characters (anything other than alphanumeric or _) to _."""
safe_column_names = [get_sanitized_feature_name(col) for col in df.columns]
return df.rename(columns=dict(zip(df.columns, safe_column_names)))
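
Note that the regex targets a specific set of punctuation — parentheses, braces, brackets, dots, colons, and quotes — and leaves spaces alone, as the new unit test below confirms. A short usage sketch, assuming a pandas DataFrame:

import pandas as pd

from ludwig.utils.data_utils import sanitize_column_names

df = pd.DataFrame({"Pclass (new)": [1, 2], "review.text": ["a", "b"]})
df = sanitize_column_names(df)
print(list(df.columns))  # expected: ['Pclass _new_', 'review_text']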
38 changes: 38 additions & 0 deletions tests/integration_tests/test_experiment.py
@@ -30,6 +30,7 @@
from ludwig.callbacks import Callback
from ludwig.constants import BATCH_SIZE, COLUMN, ENCODER, H3, NAME, PREPROCESSING, TRAINER, TYPE
from ludwig.data.concatenate_datasets import concatenate_df
from ludwig.data.dataset_synthesizer import build_synthetic_dataset_df
from ludwig.data.preprocessing import preprocess_for_training
from ludwig.encoders.registry import get_encoder_classes
from ludwig.error import ConfigValidationError
@@ -1129,3 +1130,40 @@ def test_experiment_ordinal_category(csv_filename):

rel_path = generate_data(input_features, output_features, csv_filename)
run_experiment(input_features, output_features, dataset=rel_path)


def test_experiment_feature_names_with_non_word_chars(tmpdir):
config = yaml.safe_load(
"""
input_features:
- name: Pclass (new)
type: category
- name: review.text
type: category
- name: other_feature
type: category
tied: review.text
output_features:
- name: Survived (new)
type: binary
- name: Thrived
type: binary
dependencies:
- Survived (new)
combiner:
type: comparator
entity_1:
- Pclass (new)
- other_feature
entity_2:
- review.text
"""
)

df = build_synthetic_dataset_df(120, config)
model = LudwigModel(config, logging_level=logging.INFO)

model.train(dataset=df, output_directory=tmpdir)
8 changes: 4 additions & 4 deletions tests/integration_tests/test_explain.py
@@ -6,7 +6,7 @@
import pytest

from ludwig.api import LudwigModel
from ludwig.constants import BATCH_SIZE, BINARY, CATEGORY, MINIMUM_BATCH_SIZE, MODEL_ECD, MODEL_GBM
from ludwig.constants import BATCH_SIZE, BINARY, CATEGORY, MINIMUM_BATCH_SIZE, MODEL_ECD, MODEL_GBM, TYPE
from ludwig.explain.captum import IntegratedGradientsExplainer
from ludwig.explain.explainer import Explainer
from ludwig.explain.explanation import Explanation
@@ -173,10 +173,10 @@ def run_test_explainer_api(
input_features = [
# Include a non-canonical name that's not a valid key for a vanilla pytorch ModuleDict:
# https://github.com/pytorch/pytorch/issues/71203
{"name": "binary.1", "type": "binary"},
{"name": "type", "type": "binary"},
number_feature(),
category_feature(encoder={"type": "onehot", "reduce_output": "sum"}),
category_feature(encoder={"type": "passthrough", "reduce_output": "sum"}),
category_feature(encoder={TYPE: "onehot", "reduce_output": "sum"}),
category_feature(encoder={TYPE: "passthrough", "reduce_output": "sum"}),
]
if model_type == MODEL_ECD:
# TODO(travis): need unit tests to test the get_embedding_layer() of every encoder to ensure it is
3 changes: 3 additions & 0 deletions tests/integration_tests/test_gbm.py
@@ -6,6 +6,7 @@
from ludwig.api import LudwigModel
from ludwig.constants import INPUT_FEATURES, MODEL_TYPE, OUTPUT_FEATURES, TRAINER
from ludwig.error import ConfigValidationError
from ludwig.schema.model_types.base import ModelConfig
from tests.integration_tests import synthetic_test_data
from tests.integration_tests.utils import binary_feature
from tests.integration_tests.utils import category_feature as _category_feature
@@ -65,6 +66,8 @@ def _train_and_predict_gbm(input_features, output_features, tmpdir, backend_conf
if trainer_config:
config[TRAINER].update(trainer_config)

config = ModelConfig.from_dict(config).to_dict()

model = LudwigModel(config, backend=backend_config)
_, _, output_directory = model.train(
dataset=dataset_filename,
16 changes: 16 additions & 0 deletions tests/ludwig/utils/test_data_utils.py
@@ -34,6 +34,7 @@
read_csv,
read_html,
read_parquet,
sanitize_column_names,
use_credentials,
)
from tests.integration_tests.utils import private_param
@@ -223,3 +224,18 @@ def test_read_html(df_lib, nrows):
kwargs["nrows"] = nrows

read_html(HTML_DOCUMENT, df_lib, **kwargs)


def test_sanitize_column_names():
df = pd.DataFrame(
{
"col.one": [1, 2, 3, 4],
"col(two)": [4, 5, 6, 7],
"col[]:three": [7, 8, 9, 10],
"col 'one' (new)": [1, 2, 3, 4],
}
)

df = sanitize_column_names(df)

assert list(df.columns) == ["col_one", "col_two_", "col___three", "col _one_ _new_"]
