[Torchscript] Parallelized Text/Sequence Preprocessing #2206

Merged 10 commits on Jun 29, 2022

ludwig/features/sequence_feature.py (62 changes: 36 additions & 26 deletions)

@@ -92,38 +92,48 @@ def forward(self, v: TorchscriptPreprocessingInput) -> torch.Tensor:
         if not torch.jit.isinstance(v, List[str]):
             raise ValueError(f"Unsupported input: {v}")

-        v = [self.computed_fill_value if s == "nan" else s for s in v]
+        futures: List[torch.jit.Future[torch.Tensor]] = []
+        for sequence in v:
+            futures.append(
+                torch.jit.fork(
+                    self._process_sequence,
+                    sequence,
+                )
+            )
+
+        sequence_matrix = []
+        for future in futures:
+            sequence_matrix.append(torch.jit.wait(future))
+
+        return torch.stack(sequence_matrix)
+
+    def _process_sequence(self, sequence: str) -> torch.Tensor:
+        sequence = self.computed_fill_value if sequence == "nan" else sequence

         if self.lowercase:
-            sequences = [sequence.lower() for sequence in v]
+            sequence_str: str = sequence.lower()
         else:
-            sequences = v
+            sequence_str: str = sequence

-        unit_sequences = self.tokenizer(sequences)
-        # refines type of unit_sequences from Any to List[List[str]]
-        assert torch.jit.isinstance(unit_sequences, List[List[str]]), "unit_sequences is not a list of lists."
+        unit_sequence = self.tokenizer(sequence_str)
+        assert torch.jit.isinstance(unit_sequence, List[str])

-        sequence_matrix = torch.full(
-            [len(unit_sequences), self.max_sequence_length], self.unit_to_id[self.padding_symbol]
-        )
-        sequence_matrix[:, 0] = self.unit_to_id[self.start_symbol]
-        for sample_idx, unit_sequence in enumerate(unit_sequences):
-            # Add <EOS> if sequence length is less than max_sequence_length. Else, truncate to max_sequence_length.
-            if len(unit_sequence) + 1 < self.max_sequence_length:
-                sequence_length = len(unit_sequence)
-                sequence_matrix[sample_idx][len(unit_sequence) + 1] = self.unit_to_id[self.stop_symbol]
-            else:
-                sequence_length = self.max_sequence_length - 1
-
-            for i in range(sequence_length):
-                curr_unit = unit_sequence[i]
-                if curr_unit in self.unit_to_id:
-                    curr_id = self.unit_to_id[curr_unit]
-                else:
-                    curr_id = self.unit_to_id[self.unknown_symbol]
-                sequence_matrix[sample_idx][i + 1] = curr_id
-
-        return sequence_matrix
+        sequence_vector = torch.full([self.max_sequence_length], self.unit_to_id[self.padding_symbol])
+        sequence_vector[0] = self.unit_to_id[self.start_symbol]
+        # Add <EOS> if sequence length is less than max_sequence_length. Else, truncate to max_sequence_length.
+        if len(unit_sequence) + 1 < self.max_sequence_length:
+            sequence_length = len(unit_sequence)
+            sequence_vector[len(unit_sequence) + 1] = self.unit_to_id[self.stop_symbol]
+        else:
+            sequence_length = self.max_sequence_length - 1
+
+        for i in range(sequence_length):
+            curr_unit = unit_sequence[i]
+            if curr_unit in self.unit_to_id:
+                curr_id = self.unit_to_id[curr_unit]
+            else:
+                curr_id = self.unit_to_id[self.unknown_symbol]
+            sequence_vector[i + 1] = curr_id
+        return sequence_vector


 class _SequencePostprocessing(torch.nn.Module):
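The change above replaces batch-at-once tokenization with a per-sequence `_process_sequence` helper dispatched through `torch.jit.fork` and joined with `torch.jit.wait`, so a scripted module can preprocess the sequences of a batch in parallel. Below is a minimal standalone sketch of that fork/wait pattern; it is illustrative only (the toy `ParallelDouble` module is not Ludwig code):

```python
from typing import List

import torch


class ParallelDouble(torch.nn.Module):
    """Toy module demonstrating the fork/wait pattern used in the diff above."""

    def forward(self, values: List[torch.Tensor]) -> torch.Tensor:
        # Schedule one asynchronous task per input element; under
        # torch.jit.script these forks can run on parallel interop threads.
        futures: List[torch.jit.Future[torch.Tensor]] = []
        for value in values:
            futures.append(torch.jit.fork(self._double, value))

        # Join the futures in order and stack the per-element results,
        # mirroring how forward() stacks per-sequence vectors into a matrix.
        results: List[torch.Tensor] = []
        for future in futures:
            results.append(torch.jit.wait(future))
        return torch.stack(results)

    def _double(self, value: torch.Tensor) -> torch.Tensor:
        return value * 2


scripted = torch.jit.script(ParallelDouble())
print(scripted([torch.tensor(1.0), torch.tensor(2.0)]))  # tensor([2., 4.])
```

Note that each per-sequence vector produced by `_process_sequence` keeps the old layout, `[<SOS>, unit ids..., <EOS>, <PAD>...]` up to `max_sequence_length` (truncating when the units do not fit); only the batching strategy changes.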
ludwig/models/inference.py (74 changes: 34 additions & 40 deletions)

@@ -56,8 +56,7 @@ def __init__(

     def preprocessor_forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, torch.Tensor]:
         """Forward pass through the preprocessor."""
-        with torch.no_grad():
-            return self.preprocessor(inputs)
+        return self.preprocessor(inputs)

     def predictor_forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Forward pass through the predictor.

@@ -67,24 +66,22 @@ def predictor_forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str
         for k, v in preproc_inputs.items():
             preproc_inputs[k] = v.to(self.predictor.device)

-        with torch.no_grad():
+        with torch.no_grad():  # Ensure model params do not compute gradients
             predictions_flattened = self.predictor(preproc_inputs)
         return predictions_flattened

     def postprocessor_forward(self, predictions_flattened: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]:
         """Forward pass through the postprocessor."""
-        with torch.no_grad():
-            postproc_outputs_flattened: Dict[str, Any] = self.postprocessor(predictions_flattened)
-            # Turn flat inputs into nested predictions per feature name
-            postproc_outputs: Dict[str, Dict[str, Any]] = _unflatten_dict_by_feature_name(postproc_outputs_flattened)
-            return postproc_outputs
+        postproc_outputs_flattened: Dict[str, Any] = self.postprocessor(predictions_flattened)
+        # Turn flat inputs into nested predictions per feature name
+        postproc_outputs: Dict[str, Dict[str, Any]] = _unflatten_dict_by_feature_name(postproc_outputs_flattened)
+        return postproc_outputs

     def forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, Dict[str, Any]]:
-        with torch.no_grad():
-            preproc_inputs: Dict[str, torch.Tensor] = self.preprocessor_forward(inputs)
-            predictions_flattened: Dict[str, torch.Tensor] = self.predictor_forward(preproc_inputs)
-            postproc_outputs: Dict[str, Dict[str, Any]] = self.postprocessor_forward(predictions_flattened)
-            return postproc_outputs
+        preproc_inputs: Dict[str, torch.Tensor] = self.preprocessor_forward(inputs)
+        predictions_flattened: Dict[str, torch.Tensor] = self.predictor_forward(preproc_inputs)
+        postproc_outputs: Dict[str, Dict[str, Any]] = self.postprocessor_forward(predictions_flattened)
+        return postproc_outputs

     @torch.jit.unused
     def predict(

@@ -172,12 +169,11 @@ def __init__(self, config: Dict[str, Any], training_set_metadata: Dict[str, Any]
             self.preproc_modules[module_dict_key] = feature.create_preproc_module(training_set_metadata[feature_name])

     def forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, torch.Tensor]:
-        with torch.no_grad():
-            preproc_inputs = {}
-            for module_dict_key, preproc in self.preproc_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                preproc_inputs[feature_name] = preproc(inputs[feature_name])
-            return preproc_inputs
+        preproc_inputs = {}
+        for module_dict_key, preproc in self.preproc_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            preproc_inputs[feature_name] = preproc(inputs[feature_name])
+        return preproc_inputs


 class _InferencePredictor(nn.Module):

@@ -200,17 +196,16 @@ def __init__(self, model: "ECD", device: TorchDevice):
             self.predict_modules[module_dict_key] = feature.prediction_module.to(device=self.device)

     def forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        with torch.no_grad():
-            model_outputs = self.model(preproc_inputs)
-            predictions_flattened: Dict[str, torch.Tensor] = {}
-            for module_dict_key, predict in self.predict_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                feature_predictions = predict(model_outputs, feature_name)
-                # Flatten out the predictions to support Triton input/output
-                for predict_key, tensor_values in feature_predictions.items():
-                    predict_concat_key = output_feature_utils.get_feature_concat_name(feature_name, predict_key)
-                    predictions_flattened[predict_concat_key] = tensor_values
-            return predictions_flattened
+        model_outputs = self.model(preproc_inputs)
+        predictions_flattened: Dict[str, torch.Tensor] = {}
+        for module_dict_key, predict in self.predict_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            feature_predictions = predict(model_outputs, feature_name)
+            # Flatten out the predictions to support Triton input/output
+            for predict_key, tensor_values in feature_predictions.items():
+                predict_concat_key = output_feature_utils.get_feature_concat_name(feature_name, predict_key)
+                predictions_flattened[predict_concat_key] = tensor_values
+        return predictions_flattened


 class _InferencePostprocessor(nn.Module):

@@ -231,16 +226,15 @@ def __init__(self, model: "ECD", training_set_metadata: Dict[str, Any]):
             self.postproc_modules[module_dict_key] = feature.create_postproc_module(training_set_metadata[feature_name])

     def forward(self, predictions_flattened: Dict[str, torch.Tensor]) -> Dict[str, Any]:
-        with torch.no_grad():
-            postproc_outputs_flattened: Dict[str, Any] = {}
-            for module_dict_key, postproc in self.postproc_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                feature_postproc_outputs = postproc(predictions_flattened, feature_name)
-                # Flatten out the predictions to support Triton input/output
-                for postproc_key, tensor_values in feature_postproc_outputs.items():
-                    postproc_concat_key = output_feature_utils.get_feature_concat_name(feature_name, postproc_key)
-                    postproc_outputs_flattened[postproc_concat_key] = tensor_values
-            return postproc_outputs_flattened
+        postproc_outputs_flattened: Dict[str, Any] = {}
+        for module_dict_key, postproc in self.postproc_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            feature_postproc_outputs = postproc(predictions_flattened, feature_name)
+            # Flatten out the predictions to support Triton input/output
+            for postproc_key, tensor_values in feature_postproc_outputs.items():
+                postproc_concat_key = output_feature_utils.get_feature_concat_name(feature_name, postproc_key)
+                postproc_outputs_flattened[postproc_concat_key] = tensor_values
+        return postproc_outputs_flattened


 def save_ludwig_model_for_inference(
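The `inference.py` changes consolidate gradient suppression: the nested `torch.no_grad()` scopes in the pre- and postprocessor modules are removed, leaving a single scope around the model call in `predictor_forward`. A small standalone check of why one scope suffices (illustrative only; `torch.nn.Linear` stands in for the wrapped model):

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the predictor's wrapped model
x = torch.randn(1, 4)

with torch.no_grad():  # mirrors the single scope kept in predictor_forward
    y = model(x)

# Tensors created inside no_grad carry no autograd history, so the
# postprocessor receives grad-free tensors without a no_grad scope of its own.
print(y.requires_grad)  # False
```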
tests/integration_tests/test_torchscript.py (3 changes: 2 additions & 1 deletion)

@@ -622,9 +622,10 @@ def validate_torchscript_outputs(tmpdir, config, backend, training_data_csv_path

             assert output_name in feature_outputs
             output_values = feature_outputs[output_name]
+            assert utils.has_no_grad(output_values), f'"{feature_name}.{output_name}" tensors have gradients'
             assert utils.is_all_close(
                 output_values, output_values_expected
-            ), f"feature: {feature_name}, output: {output_name}"
+            ), f'"{feature_name}.{output_name}" tensors are not close to ludwig model'


 def initialize_torchscript_module(tmpdir, config, backend, training_data_csv_path, device=None):
tests/integration_tests/utils.py (11 changes: 11 additions & 0 deletions)

@@ -501,6 +501,17 @@ def get_weights(model: torch.nn.Module) -> List[torch.Tensor]:
     return [param.data for param in model.parameters()]


+def has_no_grad(
+    val: Union[np.ndarray, torch.Tensor, str, list],
+):
+    """Checks that the value (including any tensors nested in lists) requires no gradients."""
+    if isinstance(val, list):
+        return all(has_no_grad(v) for v in val)
+    if isinstance(val, torch.Tensor):
+        return not val.requires_grad
+    return True
+
+
 def is_all_close(
     val1: Union[np.ndarray, torch.Tensor, str, list],
     val2: Union[np.ndarray, torch.Tensor, str, list],
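For reference, a hedged usage sketch of the new `has_no_grad` helper, mirroring the assertion added in `test_torchscript.py` (the tensors are made up, and the import assumes the repository root is on PYTHONPATH):

```python
import torch

from tests.integration_tests.utils import has_no_grad

with torch.no_grad():
    outputs = [torch.ones(3), [torch.zeros(2), torch.zeros(2)]]

assert has_no_grad(outputs)        # recurses through nested lists of tensors
assert has_no_grad("some string")  # non-tensor leaves pass trivially
```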