[Torchscript] Adds GPU-enabled input types for Vector and Timeseries #2197

Merged
merged 7 commits into master from ts-add-timeseries-vector-alt-dtype on Jun 28, 2022

Conversation

geoffreyangus (Contributor) commented Jun 27, 2022

This PR enables TorchScript users to pass in Union[torch.Tensor, List[torch.Tensor]] objects for Vector input features and List[torch.Tensor] objects for Timeseries input features in order to better utilize GPU resources.

Prior to this change, users could only pass Vector and Timeseries features as List[str] objects, which required stripping and parsing each sample into torch.Tensor objects on the CPU, a step that can be slow. With this change, users now have the option to pass in torch.Tensor or List[torch.Tensor] objects, which can be operated on directly on the GPU.
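A rough usage sketch of what this enables (the model path, feature name, and exact forward() signature below are illustrative assumptions, not taken from this PR):

    import torch
    from ludwig.api import LudwigModel

    # Hypothetical saved model and feature name.
    model = LudwigModel.load("results/experiment_run/model")
    scripted = model.to_torchscript()

    # Before: Vector/Timeseries values had to be passed as List[str],
    # forcing per-sample string parsing on the CPU.
    inputs_str = {"vector_feature": ["1.0 2.0 3.0", "4.0 5.0 6.0"]}

    # After: tensors that already live on the GPU can be passed directly.
    inputs_tensor = {"vector_feature": torch.rand(2, 3, device="cuda")}

    preds = scripted(inputs_tensor)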

github-actions bot commented Jun 27, 2022

Unit Test Results

6 files (+1)   6 suites (+1)   2h 17m 36s ⏱️ (+34m 26s)
2 886 tests (+6): 2 840 ✔️ passed (+6), 46 💤 skipped (±0), 0 failed (±0)
8 658 runs (+179): 8 516 ✔️ passed (+156), 142 💤 skipped (+23), 0 failed (±0)

Results for commit 74a02d3. ± Comparison against base commit c26e81a.

♻️ This comment has been updated with latest results.

@geoffreyangus geoffreyangus marked this pull request as ready for review June 27, 2022 18:10
@geoffreyangus geoffreyangus requested review from justinxzhao and tgaddair and removed request for justinxzhao June 27, 2022 18:10
@@ -247,7 +247,9 @@ def generate_text(feature):

 def generate_timeseries(feature):
     series = []
-    for _ in range(feature.get("max_len", 10)):
+    max_len = feature.get("max_len", 10)
Contributor:

nit: Let's make 10 a default function parameter.
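A minimal sketch of that suggestion (the helper body here is illustrative, not the actual data-generation code):

    import random

    def generate_timeseries(feature: dict, default_max_len: int = 10) -> str:
        # Fall back to the default parameter instead of a hard-coded literal.
        max_len = feature.get("max_len", default_max_len)
        series = [str(random.uniform(0, 1)) for _ in range(max_len)]
        return " ".join(series)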

Comment on lines 102 to 107
    if torch.jit.isinstance(v, List[torch.Tensor]):
        return self.forward_list_of_tensors(v)
    elif torch.jit.isinstance(v, List[str]):
        return self.forward_list_of_strs(v)
    else:
        raise ValueError(f"Unsupported input: {v}")
Contributor:

nit (personal preference):

Suggested change

Original:

    if torch.jit.isinstance(v, List[torch.Tensor]):
        return self.forward_list_of_tensors(v)
    elif torch.jit.isinstance(v, List[str]):
        return self.forward_list_of_strs(v)
    else:
        raise ValueError(f"Unsupported input: {v}")

Suggestion:

    if torch.jit.isinstance(v, List[torch.Tensor]):
        return self.forward_list_of_tensors(v)
    if torch.jit.isinstance(v, List[str]):
        return self.forward_list_of_strs(v)
    raise ValueError(f"Unsupported input: {v}")
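For context, a self-contained sketch (not Ludwig's actual module) of how torch.jit.isinstance dispatches over a Union-typed input under TorchScript; assumes a PyTorch version with Union support in TorchScript (1.10+):

    from typing import List, Union

    import torch

    class VectorPreproc(torch.nn.Module):
        def forward(self, v: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
            # torch.jit.isinstance refines the Union type inside each branch.
            if torch.jit.isinstance(v, List[torch.Tensor]):
                return torch.stack(v)
            return v

    scripted = torch.jit.script(VectorPreproc())
    scripted([torch.ones(3), torch.zeros(3)])  # a list of per-sample tensors
    scripted(torch.ones(2, 3))                 # an already-batched tensor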

Comment on lines 58 to 63
    if v.isnan().any():
        if self.computed_fill_value == "":
            v = torch.nan_to_num(v, nan=self.padding_value)
        else:
            raise ValueError(f"Fill value must be empty string. Got {self.computed_fill_value}.")
    return v
Contributor:

Suggested change

Original:

    if v.isnan().any():
        if self.computed_fill_value == "":
            v = torch.nan_to_num(v, nan=self.padding_value)
        else:
            raise ValueError(f"Fill value must be empty string. Got {self.computed_fill_value}.")
    return v

Suggestion:

    if not v.isnan().any():
        # No nans to replace.
        return v
    if v.isnan().any() and self.computed_fill_value != "":
        # Nans present, but fill value is non-empty. (Question: why does the fill value have to be an empty string?)
        raise ValueError(f"Fill value must be empty string. Got {self.computed_fill_value}.")
    return torch.nan_to_num(v, nan=self.padding_value)

geoffreyangus (Author):

Great question! I've actually updated the function to support all computed_fill_value values. Thanks for the inspiration!
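One way such a generalization could look (a hypothetical sketch, not necessarily the exact change in this PR):

    import torch

    def fill_nans(v: torch.Tensor, computed_fill_value: str, padding_value: float) -> torch.Tensor:
        # Hypothetical helper: accept any computed_fill_value instead of
        # requiring it to be the empty string.
        if not v.isnan().any():
            return v
        fill = padding_value if computed_fill_value == "" else float(computed_fill_value)
        return torch.nan_to_num(v, nan=fill)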

Comment on lines 506 to 508
    feature_name = feature_name_expected[: feature_name_expected.rfind("_")]  # remove proc suffix
    if feature_name not in preproc_inputs.keys():
        continue
Contributor:

This seems potentially brittle: if the feature naming changes, this test won't actually check anything. Are output features the only features that wouldn't have an entry in preproc_inputs? Perhaps we should do a hard continue only for the output features.

Alternatively, if we could turn feature_name_expected[: feature_name_expected.rfind("_")] (the proc-suffix removal) into a tested function that guarantees it stays in sync with the preproc module's feature naming, that would feel a bit more robust.
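A hedged sketch of that idea (names and the suffix format are hypothetical):

    def strip_proc_suffix(proc_feature_name: str) -> str:
        # Strip the trailing "_<suffix>" appended by the preprocessing module.
        return proc_feature_name[: proc_feature_name.rfind("_")]

    def test_strip_proc_suffix():
        assert strip_proc_suffix("vector_feature_abc123") == "vector_feature"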

geoffreyangus (Author):

Good point. There should actually be no mismatched values between the two dictionaries, so I've changed the conditional to an assert. Thanks!
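For reference, a sketch of the conditional-to-assert change described above (the exact resulting line is not shown in this thread):

    feature_name = feature_name_expected[: feature_name_expected.rfind("_")]  # remove proc suffix
    assert feature_name in preproc_inputs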

skip_save_progress=True,
skip_save_log=True,
skip_save_processed_input=True,
ludwig_model, script_module = initialize_ludwig_model_and_scripted_module(
Contributor:

Nice simplification.

@geoffreyangus geoffreyangus changed the title [Torchscript] Adds alternative input types for Vector and Timeseries [Torchscript] Adds GPU-enabled input types for Vector and Timeseries Jun 27, 2022
justinxzhao (Contributor) left a comment:

LGTM

@geoffreyangus geoffreyangus merged commit d0e8439 into master Jun 28, 2022
@geoffreyangus geoffreyangus deleted the ts-add-timeseries-vector-alt-dtype branch June 28, 2022 23:27