
Add sample_size as a global preprocessing parameter #3650

Merged · 12 commits merged into master on Oct 12, 2023

Conversation

@Infernaught (Contributor) commented Sep 21, 2023

Adds sample_size as a global preprocessing parameter, allowing users to specify exactly how many samples they want to train on instead of having to calculate the equivalent sample_ratio. Also adds two integration tests to verify that sample_size works as intended.
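For illustration, a minimal sketch of how the new parameter might be used from Ludwig's Python API; the feature definitions and dataset path are hypothetical, and only preprocessing.sample_size comes from this PR:

    from ludwig.api import LudwigModel

    config = {
        "input_features": [{"name": "text", "type": "text"}],        # hypothetical feature
        "output_features": [{"name": "label", "type": "category"}],  # hypothetical feature
        # Train on at most 1000 rows instead of computing a sample_ratio by hand.
        "preprocessing": {"sample_size": 1000},
    }

    model = LudwigModel(config)
    model.train(dataset="data.csv")  # hypothetical dataset path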

@github-actions bot commented Sep 21, 2023
Unit Test Results

 6 files  ±0     6 suites  ±0    21m 26s ⏱️ (-21m 0s)
12 tests  -19     9 ✔️ -17     3 💤 -2    0 ±0
60 runs   -22    42 ✔️ -24    18 💤 +2    0 ±0

Results for commit 2de5849. ± Comparison against base commit ee92f7d.

This pull request removes 19 tests.
tests.integration_tests.test_cli ‑ test_reproducible_cli_runs[horovod-experiment-1919-0]
tests.integration_tests.test_cli ‑ test_reproducible_cli_runs[horovod-experiment-1919-1]
tests.integration_tests.test_cli ‑ test_reproducible_cli_runs[horovod-experiment-31-0]
tests.integration_tests.test_cli ‑ test_reproducible_cli_runs[horovod-experiment-31-1]
tests.integration_tests.test_cli ‑ test_reproducible_cli_runs[horovod-train-1919-0]
tests.integration_tests.test_cli ‑ test_reproducible_cli_runs[horovod-train-1919-1]
tests.integration_tests.test_cli ‑ test_reproducible_cli_runs[horovod-train-31-0]
tests.integration_tests.test_cli ‑ test_reproducible_cli_runs[horovod-train-31-1]
tests.integration_tests.test_cli ‑ test_train_cli_horovod
tests.integration_tests.test_horovod ‑ test_horovod_gpu_memory_limit
…

♻️ This comment has been updated with latest results.

    if sample_cap < len(dataset_df):
        dataset_df = dataset_df.sample(n=sample_cap, random_state=random_seed)
    else:
        logger.info("sample_cap is larger than dataset size, ignoring sample_cap")

Contributor:

Perhaps logger.warning?

Comment on lines 1214 to 1217:

    sample_cap = global_preprocessing_parameters["sample_cap"]
    if sample_cap:
        if sample_ratio < 1.0:
            raise ValueError("sample_cap cannot be used when sample_ratio < 1.0")

Contributor:

Wondering if we can push this up into a schema validation check, i.e., if preprocessing sample_ratio is specified and it is < 1 and sample_cap is also specified, then raise a ConfigValidationError?

Collaborator:

+1. If we can implement this as an auxiliary validation, that would allow the config to fail as early as possible.

@arnavgarg1 (Contributor) left a comment:

Minor comments, but generally LGTM

@justinxzhao (Collaborator) left a comment:

Thanks! I like the change overall.

@@ -1211,6 +1211,15 @@ def build_dataset(
        logger.debug(f"sample {sample_ratio} of data")
        dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)

    sample_cap = global_preprocessing_parameters["sample_cap"]

Collaborator:

Could you refactor this section out into a separate function?

    dataset_df = get_sampled_dataset_df(dataset_df, sample_ratio, sample_cap)

Contributor (Author):

Yep. Done!
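For reference, a minimal sketch of what the extracted helper might look like, assuming the signature suggested above plus a random_seed parameter, and folding in the logger.warning and Dask-frac suggestions from this thread; it is illustrative, not the PR's actual implementation:

    import logging

    logger = logging.getLogger(__name__)

    def get_sampled_dataset_df(dataset_df, sample_ratio, sample_cap, random_seed):
        # Ratio-based downsampling first. Schema validation elsewhere prevents
        # sample_ratio < 1.0 and sample_cap from being set together.
        if sample_ratio < 1.0:
            logger.debug(f"sample {sample_ratio} of data")
            dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)

        if sample_cap is not None:
            df_len = len(dataset_df)  # compute once; expensive for Dask dataframes
            if sample_cap < df_len:
                # Cannot use 'n' with Dask dataframes -- only 'frac' is supported.
                dataset_df = dataset_df.sample(frac=sample_cap / df_len, random_state=random_seed)
            else:
                logger.warning("sample_cap is larger than dataset size, ignoring sample_cap")

        return dataset_df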

Contributor (Author):

@arnavgarg1 @justinxzhao Is this what you guys were looking for? I've tested this locally and it seems to have the right functionality.

Collaborator:

Nice! Last request from me is to add a simple test to https://github.com/ludwig-ai/ludwig/blob/master/tests/ludwig/config_validation/test_checks.py since we're adding code to checks.py.
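A minimal sketch of what such a check test might look like, assuming ModelConfig.from_dict runs the auxiliary validations and that the import paths below are correct; the feature definitions are hypothetical:

    import pytest

    from ludwig.error import ConfigValidationError              # import path assumed
    from ludwig.schema.model_types.base import ModelConfig      # import path assumed

    def test_sample_size_and_sample_ratio_incompatible():
        config = {
            "input_features": [{"name": "text", "type": "text"}],        # hypothetical
            "output_features": [{"name": "label", "type": "category"}],  # hypothetical
            "preprocessing": {"sample_ratio": 0.5, "sample_size": 100},
        }
        # Building the config should fail fast, before any data is touched.
        with pytest.raises(ConfigValidationError):
            ModelConfig.from_dict(config)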

        dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)

    if sample_cap:
        if sample_cap < len(dataset_df):

Collaborator:

len(dataset_df) is a very expensive op for Dask dataframes (this is why we have an explicit check above to skip calling it when df_engine.partitioned), so calling it twice in quick succession is definitely not ideal. Let's do this instead:

    df_len = len(dataset_df)
    if sample_cap < df_len:
        # Cannot use 'n' parameter when using dask DataFrames -- only 'frac' is supported
        sample_ratio = sample_cap / df_len
        dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)

    if sample_cap < len(dataset_df):
        # Cannot use 'n' parameter when using dask DataFrames -- only 'frac' is supported
        sample_ratio = sample_cap / len(dataset_df)
        dataset_df = dataset_df.sample(frac=sample_ratio, random_state=random_seed)

Collaborator:

Note that for Dask this will not be exact, but that's probably okay.
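(For context: Dask samples each partition independently by fraction, so the number of rows returned is only approximately frac * len(dataset_df); guaranteeing an exact row count would require a global pass over the data.)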

@tgaddair (Collaborator) left a comment:

Not a huge fan of the name sample_cap personally. Can we call it something like sample_size instead?

    def check_sample_ratio_and_cap_compatible(config: "ModelConfig") -> None:
        sample_ratio = config.preprocessing.sample_ratio
        sample_cap = config.preprocessing.sample_cap
        if sample_cap and sample_ratio < 1.0:
Collaborator:

Edge case, but this would allow something like:

    sample_cap: 0
    sample_ratio: 0.5

So it would be more correct to say:

    if sample_cap is not None and sample_ratio < 1.0:
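Putting the review feedback together (the is not None guard plus the later rename to sample_size), the check might end up looking roughly like this; the function name and ConfigValidationError import path are assumptions, not copied from the merged code:

    from ludwig.error import ConfigValidationError  # import path assumed

    def check_sample_ratio_and_size_compatible(config: "ModelConfig") -> None:
        sample_ratio = config.preprocessing.sample_ratio
        sample_size = config.preprocessing.sample_size
        # 'is not None' so that sample_size: 0 still conflicts with sample_ratio < 1.0
        if sample_size is not None and sample_ratio < 1.0:
            raise ConfigValidationError("sample_size cannot be used when sample_ratio < 1.0")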

    - 1000
    expected_impact: 2
    suggested_values: Depends on data size
    ui_display_name: Sample Cap

Collaborator:

nit: Sample Size.

    count = len(train_set) + len(val_set) + len(test_set)
    assert sample_size == count

    # Check that sample cap is disabled when doing preprocessing for prediction

Collaborator:

nit: sample size

@Infernaught changed the title from "Add sample_cap as a global preprocessing parameter" to "Add sample_size as a global preprocessing parameter" on Oct 11, 2023.
@Infernaught merged commit df6f5ef into master on Oct 12, 2023 — 17 checks passed.
@Infernaught deleted the sample_cap branch on Oct 12, 2023 at 14:33.