Skip to content

Commit

Permalink
Add a clear error message about invalid column names in GBM datasets (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffkinnison committed Jan 4, 2023
1 parent 4498c87 commit 303f389
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
34 changes: 31 additions & 3 deletions ludwig/trainers/trainer_lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import re
import signal
import sys
import threading
Expand Down Expand Up @@ -806,14 +807,32 @@ def _construct_lgb_datasets(

# create dataset for lightgbm
# keep raw data for continued training https://github.com/microsoft/LightGBM/issues/4965#issuecomment-1019344293
lgb_train = lgb.Dataset(X_train, label=y_train, free_raw_data=False).construct()
try:
lgb_train = lgb.Dataset(X_train, label=y_train, free_raw_data=False).construct()
except lgb.basic.LightGBMError as e:
if re.search(r"special JSON characters", str(e)):
raise ValueError(
"Some column names in the training set contain invalid characters. Please ensure column names only "
"contain alphanumeric characters and underscores, then try training again."
) from e
else:
raise

eval_sets = [lgb_train]
eval_names = [LightGBMTrainer.TRAIN_KEY]
if validation_set is not None:
X_val = validation_set.to_df(self.model.input_features.values())
y_val = validation_set.to_df(self.model.output_features.values())
lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train, free_raw_data=False).construct()
try:
lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train, free_raw_data=False).construct()
except lgb.basic.LightGBMError as e:
if re.search(r"special JSON characters", str(e)):
raise ValueError(
"Some column names in the validation set contain invalid characters. Please ensure column "
"names only contain alphanumeric characters and underscores, then try training again."
) from e
else:
raise
eval_sets.append(lgb_val)
eval_names.append(LightGBMTrainer.VALID_KEY)
else:
Expand All @@ -823,7 +842,16 @@ def _construct_lgb_datasets(
if test_set is not None:
X_test = test_set.to_df(self.model.input_features.values())
y_test = test_set.to_df(self.model.output_features.values())
lgb_test = lgb.Dataset(X_test, label=y_test, reference=lgb_train, free_raw_data=False).construct()
try:
lgb_test = lgb.Dataset(X_test, label=y_test, reference=lgb_train, free_raw_data=False).construct()
except lgb.basic.LightGBMError as e:
if re.search(r"special JSON characters", str(e)):
raise ValueError(
"Some column names in the test set contain invalid characters. Please ensure column "
"names only contain alphanumeric characters and underscores, then try training again."
)
else:
raise
eval_sets.append(lgb_test)
eval_names.append(LightGBMTrainer.TEST_KEY)

Expand Down
22 changes: 21 additions & 1 deletion tests/integration_tests/test_gbm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re

import numpy as np
import pytest
Expand All @@ -10,7 +11,7 @@
from jsonschema.exceptions import ValidationError

from ludwig.api import LudwigModel
from ludwig.constants import INPUT_FEATURES, MODEL_TYPE, OUTPUT_FEATURES, TRAINER
from ludwig.constants import COLUMN, INPUT_FEATURES, MODEL_TYPE, NAME, OUTPUT_FEATURES, TRAINER
from tests.integration_tests import synthetic_test_data
from tests.integration_tests.utils import binary_feature, category_feature, generate_data, number_feature, text_feature

Expand Down Expand Up @@ -315,3 +316,22 @@ def test_save_load(tmpdir, local_backend):
preds, _ = model.predict(dataset=os.path.join(tmpdir, "training.csv"), split="test")

assert init_preds.equals(preds)


def test_lgbm_dataset_setup(tmpdir, local_backend):
"""Test that LGBM dataset column name errors are handled."""
input_features = [number_feature()]
output_features = [binary_feature()]

# Overwrite the automatically generated feature/column name with an invalid string.
input_features[0][NAME] = ",Unnamed: 0"
input_features[0][COLUMN] = input_features[0][NAME]

# Test that the custom error is raised (lightgbm.basic.LightGBMError -> ValueError)
with pytest.raises(ValueError):
try:
_train_and_predict_gbm(input_features, output_features, tmpdir, local_backend)
except ValueError as e:
# Check that the intended ValueError was raised
assert re.search("Some column names in the training set contain invalid characters", str(e))
raise

0 comments on commit 303f389

Please sign in to comment.