[Torchscript] Adds GPU-enabled input types for Vector and Timeseries #2197
Conversation
`ludwig/data/dataset_synthesizer.py` (Outdated)
```diff
@@ -247,7 +247,9 @@ def generate_text(feature):


 def generate_timeseries(feature):
     series = []
-    for _ in range(feature.get("max_len", 10)):
+    max_len = feature.get("max_len", 10)
```
nit: Let's make `10` a default function parameter.
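A minimal sketch of what the reviewer's nit might look like: `max_len` becomes a default function parameter instead of a literal inside the loop. The names mirror the diff above, but this is illustrative, not Ludwig's actual implementation.

```python
import random

def generate_timeseries(feature, max_len=10):
    # `max_len` is a default parameter now; the feature dict can still
    # override it. (Hypothetical sketch, not Ludwig's actual code.)
    series_len = feature.get("max_len", max_len)
    series = [str(round(random.uniform(0, 1), 4)) for _ in range(series_len)]
    return " ".join(series)
```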
```python
if torch.jit.isinstance(v, List[torch.Tensor]):
    return self.forward_list_of_tensors(v)
elif torch.jit.isinstance(v, List[str]):
    return self.forward_list_of_strs(v)
else:
    raise ValueError(f"Unsupported input: {v}")
```
nit (personal preference):

```diff
-if torch.jit.isinstance(v, List[torch.Tensor]):
-    return self.forward_list_of_tensors(v)
-elif torch.jit.isinstance(v, List[str]):
-    return self.forward_list_of_strs(v)
-else:
-    raise ValueError(f"Unsupported input: {v}")
+if torch.jit.isinstance(v, List[torch.Tensor]):
+    return self.forward_list_of_tensors(v)
+if torch.jit.isinstance(v, List[str]):
+    return self.forward_list_of_strs(v)
+raise ValueError(f"Unsupported input: {v}")
```
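A plain-Python illustration of the early-return style being suggested: because every branch returns or raises, the `elif`/`else` nesting adds nothing. Here `isinstance` stands in for `torch.jit.isinstance`, and the handler bodies are placeholders rather than the PR's real methods.

```python
def dispatch(v):
    # Early-return dispatch: no elif/else needed since each branch exits.
    if isinstance(v, list) and v and isinstance(v[0], float):
        return "tensor path"   # would call self.forward_list_of_tensors(v)
    if isinstance(v, list) and v and isinstance(v[0], str):
        return "string path"   # would call self.forward_list_of_strs(v)
    raise ValueError(f"Unsupported input: {v}")
```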
```python
if v.isnan().any():
    if self.computed_fill_value == "":
        v = torch.nan_to_num(v, nan=self.padding_value)
    else:
        raise ValueError(f"Fill value must be empty string. Got {self.computed_fill_value}.")
return v
```
```diff
-if v.isnan().any():
-    if self.computed_fill_value == "":
-        v = torch.nan_to_num(v, nan=self.padding_value)
-    else:
-        raise ValueError(f"Fill value must be empty string. Got {self.computed_fill_value}.")
-return v
+if not v.isnan().any():
+    # No nans to replace.
+    return v
+if self.computed_fill_value != "":
+    # Nans present, but fill value is non-empty. (Question: why does the fill value have to be an empty string?)
+    raise ValueError(f"Fill value must be empty string. Got {self.computed_fill_value}.")
+return torch.nan_to_num(v, nan=self.padding_value)
```
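A pure-Python sketch of the same early-return structure, with a list of floats standing in for the torch tensor `v`. The function and parameter names are illustrative, not Ludwig's actual implementation.

```python
import math

def fill_nans(values, computed_fill_value="", padding_value=0.0):
    # Early-return version of the nan-handling logic discussed above.
    if not any(math.isnan(x) for x in values):
        # No nans to replace.
        return values
    if computed_fill_value != "":
        raise ValueError(f"Fill value must be empty string. Got {computed_fill_value}.")
    # Replace each nan with the padding value (torch.nan_to_num analogue).
    return [padding_value if math.isnan(x) else x for x in values]
```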
Great question! I've actually updated the function to support all `computed_fill_value` values. Thanks for the inspiration!
```python
feature_name = feature_name_expected[: feature_name_expected.rfind("_")]  # remove proc suffix
if feature_name not in preproc_inputs.keys():
    continue
```
This seems potentially brittle: if the feature naming changes, then this test won't actually check anything. Are output features the only features that wouldn't have an entry in `preproc_inputs`? Perhaps we should do a hard `continue` only for the output feature.

Alternatively, if we could turn `feature_name_expected[: feature_name_expected.rfind("_")]  # remove proc suffix` into a tested function that guarantees it stays in sync with the preproc module's feature naming, that would feel a bit more robust.
Good point. There should actually be no mismatching values between the two dictionaries, so I've changed the conditional to an assert. Thanks!
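A small sketch of the resolution: the suffix-stripping expression moves into a helper, and an `assert` replaces the `continue` so a naming mismatch fails loudly instead of silently skipping the check. The helper name and toy data are hypothetical, not Ludwig's actual test code.

```python
def strip_proc_suffix(feature_name_expected):
    # Hypothetical helper wrapping the expression from the test, so the
    # suffix-stripping logic lives in one place that can be unit-tested.
    return feature_name_expected[: feature_name_expected.rfind("_")]

# Toy stand-in for the preprocessed-inputs dict from the test.
preproc_inputs = {"temperature": [1.0], "description": ["a"]}
for expected in ("temperature_proc", "description_proc"):
    name = strip_proc_suffix(expected)
    # Assert instead of `continue`: a mismatch now fails the test.
    assert name in preproc_inputs, f"{name} missing from preproc_inputs"
```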
```diff
-        skip_save_progress=True,
-        skip_save_log=True,
-        skip_save_processed_input=True,
+    ludwig_model, script_module = initialize_ludwig_model_and_scripted_module(
```
Nice simplification.
LGTM
This PR enables torchscript users to pass in `Union[torch.Tensor, List[torch.Tensor]]` objects for Vector input features and `List[torch.Tensor]` objects for Timeseries input features in order to better utilize GPU resources.

Prior to this change, users could only pass Vector and Timeseries features as `List[str]` objects, which required stripping and parsing each sample into `torch.Tensor` objects on CPU, which can be slow. With this change, users now have the option to pass in `torch.Tensor` or `List[torch.Tensor]` objects, which can be operated on directly on GPU.
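A rough sketch of the two input paths the description contrasts, with plain lists standing in for `torch.Tensor` objects. The function name and shapes are illustrative only, not Ludwig's actual preprocessor API.

```python
from typing import List, Union

def preprocess_vector(v: Union[List[str], List[List[float]]]) -> List[List[float]]:
    # Old path: string samples must be split and parsed on CPU (slow).
    if v and isinstance(v[0], str):
        return [[float(tok) for tok in s.split()] for s in v]
    # New path: numeric input passes through and could stay on GPU.
    return v
```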