diff --git a/ludwig/features/sequence_feature.py b/ludwig/features/sequence_feature.py
index 2f0d905b53d..45a834789e2 100644
--- a/ludwig/features/sequence_feature.py
+++ b/ludwig/features/sequence_feature.py
@@ -92,38 +92,48 @@ def forward(self, v: TorchscriptPreprocessingInput) -> torch.Tensor:
         if not torch.jit.isinstance(v, List[str]):
             raise ValueError(f"Unsupported input: {v}")
 
-        v = [self.computed_fill_value if s == "nan" else s for s in v]
+        futures: List[torch.jit.Future[torch.Tensor]] = []
+        for sequence in v:
+            futures.append(
+                torch.jit.fork(
+                    self._process_sequence,
+                    sequence,
+                )
+            )
+
+        sequence_matrix = []
+        for future in futures:
+            sequence_matrix.append(torch.jit.wait(future))
+
+        return torch.stack(sequence_matrix)
+
+    def _process_sequence(self, sequence: str) -> torch.Tensor:
+        sequence = self.computed_fill_value if sequence == "nan" else sequence
 
         if self.lowercase:
-            sequences = [sequence.lower() for sequence in v]
+            sequence_str: str = sequence.lower()
         else:
-            sequences = v
+            sequence_str: str = sequence
 
-        unit_sequences = self.tokenizer(sequences)
-        # refines type of unit_sequences from Any to List[List[str]]
-        assert torch.jit.isinstance(unit_sequences, List[List[str]]), "unit_sequences is not a list of lists."
+        unit_sequence = self.tokenizer(sequence_str)
+        assert torch.jit.isinstance(unit_sequence, List[str])
 
-        sequence_matrix = torch.full(
-            [len(unit_sequences), self.max_sequence_length], self.unit_to_id[self.padding_symbol]
-        )
-        sequence_matrix[:, 0] = self.unit_to_id[self.start_symbol]
-        for sample_idx, unit_sequence in enumerate(unit_sequences):
-            # Add if sequence length is less than max_sequence_length. Else, truncate to max_sequence_length.
-            if len(unit_sequence) + 1 < self.max_sequence_length:
-                sequence_length = len(unit_sequence)
-                sequence_matrix[sample_idx][len(unit_sequence) + 1] = self.unit_to_id[self.stop_symbol]
-            else:
-                sequence_length = self.max_sequence_length - 1
-
-            for i in range(sequence_length):
-                curr_unit = unit_sequence[i]
-                if curr_unit in self.unit_to_id:
-                    curr_id = self.unit_to_id[curr_unit]
-                else:
-                    curr_id = self.unit_to_id[self.unknown_symbol]
-                sequence_matrix[sample_idx][i + 1] = curr_id
+        sequence_vector = torch.full([self.max_sequence_length], self.unit_to_id[self.padding_symbol])
+        sequence_vector[0] = self.unit_to_id[self.start_symbol]
+        if len(unit_sequence) + 1 < self.max_sequence_length:
+            sequence_length = len(unit_sequence)
+            sequence_vector[len(unit_sequence) + 1] = self.unit_to_id[self.stop_symbol]
+        else:
+            sequence_length = self.max_sequence_length - 1
 
-        return sequence_matrix
+        for i in range(sequence_length):
+            curr_unit = unit_sequence[i]
+            if curr_unit in self.unit_to_id:
+                curr_id = self.unit_to_id[curr_unit]
+            else:
+                curr_id = self.unit_to_id[self.unknown_symbol]
+            sequence_vector[i + 1] = curr_id
+        return sequence_vector
 
 
 class _SequencePostprocessing(torch.nn.Module):
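
Note on the change above: `torch.jit.fork` schedules an asynchronous task on TorchScript's inter-op thread pool and returns a `torch.jit.Future`, which `torch.jit.wait` later blocks on, so each input sequence can be tokenized and encoded in parallel. A minimal, self-contained sketch of the same fork/wait pattern (the `tokenize_to_ids` helper is hypothetical, not Ludwig code):

```python
from typing import List

import torch


@torch.jit.script
def tokenize_to_ids(text: str) -> torch.Tensor:
    # Stand-in for _process_sequence: encode each character as its ordinal.
    return torch.tensor([ord(c) for c in text], dtype=torch.int64)


@torch.jit.script
def parallel_preprocess(texts: List[str]) -> List[torch.Tensor]:
    # Schedule one asynchronous task per input sequence.
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for text in texts:
        futures.append(torch.jit.fork(tokenize_to_ids, text))
    # Wait on each future; tasks may have run on separate interpreter threads.
    results: List[torch.Tensor] = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return results


print(parallel_preprocess(["abc", "de"]))
```

The diff additionally pads every per-sequence vector to `max_sequence_length` before forking completes, which is what makes the final `torch.stack` over the gathered results well-formed.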
diff --git a/ludwig/models/inference.py b/ludwig/models/inference.py
index 5e89dffe941..bf24808d5b4 100644
--- a/ludwig/models/inference.py
+++ b/ludwig/models/inference.py
@@ -56,8 +56,7 @@ def __init__(
 
     def preprocessor_forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, torch.Tensor]:
         """Forward pass through the preprocessor."""
-        with torch.no_grad():
-            return self.preprocessor(inputs)
+        return self.preprocessor(inputs)
 
     def predictor_forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Forward pass through the predictor.
@@ -67,24 +66,22 @@ def predictor_forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str
         for k, v in preproc_inputs.items():
             preproc_inputs[k] = v.to(self.predictor.device)
 
-        with torch.no_grad():
+        with torch.no_grad():  # Ensure model params do not compute gradients
             predictions_flattened = self.predictor(preproc_inputs)
         return predictions_flattened
 
     def postprocessor_forward(self, predictions_flattened: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]:
         """Forward pass through the postprocessor."""
-        with torch.no_grad():
-            postproc_outputs_flattened: Dict[str, Any] = self.postprocessor(predictions_flattened)
-            # Turn flat inputs into nested predictions per feature name
-            postproc_outputs: Dict[str, Dict[str, Any]] = _unflatten_dict_by_feature_name(postproc_outputs_flattened)
-            return postproc_outputs
+        postproc_outputs_flattened: Dict[str, Any] = self.postprocessor(predictions_flattened)
+        # Turn flat inputs into nested predictions per feature name
+        postproc_outputs: Dict[str, Dict[str, Any]] = _unflatten_dict_by_feature_name(postproc_outputs_flattened)
+        return postproc_outputs
 
     def forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, Dict[str, Any]]:
-        with torch.no_grad():
-            preproc_inputs: Dict[str, torch.Tensor] = self.preprocessor_forward(inputs)
-            predictions_flattened: Dict[str, torch.Tensor] = self.predictor_forward(preproc_inputs)
-            postproc_outputs: Dict[str, Dict[str, Any]] = self.postprocessor_forward(predictions_flattened)
-            return postproc_outputs
+        preproc_inputs: Dict[str, torch.Tensor] = self.preprocessor_forward(inputs)
+        predictions_flattened: Dict[str, torch.Tensor] = self.predictor_forward(preproc_inputs)
+        postproc_outputs: Dict[str, Dict[str, Any]] = self.postprocessor_forward(predictions_flattened)
+        return postproc_outputs
 
     @torch.jit.unused
     def predict(
@@ -172,12 +169,11 @@ def __init__(self, config: Dict[str, Any], training_set_metadata: Dict[str, Any]
             self.preproc_modules[module_dict_key] = feature.create_preproc_module(training_set_metadata[feature_name])
 
     def forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, torch.Tensor]:
-        with torch.no_grad():
-            preproc_inputs = {}
-            for module_dict_key, preproc in self.preproc_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                preproc_inputs[feature_name] = preproc(inputs[feature_name])
-            return preproc_inputs
+        preproc_inputs = {}
+        for module_dict_key, preproc in self.preproc_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            preproc_inputs[feature_name] = preproc(inputs[feature_name])
+        return preproc_inputs
 
 
 class _InferencePredictor(nn.Module):
@@ -200,17 +196,16 @@ def __init__(self, model: "ECD", device: TorchDevice):
             self.predict_modules[module_dict_key] = feature.prediction_module.to(device=self.device)
 
     def forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        with torch.no_grad():
-            model_outputs = self.model(preproc_inputs)
-            predictions_flattened: Dict[str, torch.Tensor] = {}
-            for module_dict_key, predict in self.predict_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                feature_predictions = predict(model_outputs, feature_name)
-                # Flatten out the predictions to support Triton input/output
-                for predict_key, tensor_values in feature_predictions.items():
-                    predict_concat_key = output_feature_utils.get_feature_concat_name(feature_name, predict_key)
-                    predictions_flattened[predict_concat_key] = tensor_values
-            return predictions_flattened
+        model_outputs = self.model(preproc_inputs)
+        predictions_flattened: Dict[str, torch.Tensor] = {}
+        for module_dict_key, predict in self.predict_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            feature_predictions = predict(model_outputs, feature_name)
+            # Flatten out the predictions to support Triton input/output
+            for predict_key, tensor_values in feature_predictions.items():
+                predict_concat_key = output_feature_utils.get_feature_concat_name(feature_name, predict_key)
+                predictions_flattened[predict_concat_key] = tensor_values
+        return predictions_flattened
 
 
 class _InferencePostprocessor(nn.Module):
@@ -231,16 +226,15 @@ def __init__(self, model: "ECD", training_set_metadata: Dict[str, Any]):
             self.postproc_modules[module_dict_key] = feature.create_postproc_module(training_set_metadata[feature_name])
 
     def forward(self, predictions_flattened: Dict[str, torch.Tensor]) -> Dict[str, Any]:
-        with torch.no_grad():
-            postproc_outputs_flattened: Dict[str, Any] = {}
-            for module_dict_key, postproc in self.postproc_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                feature_postproc_outputs = postproc(predictions_flattened, feature_name)
-                # Flatten out the predictions to support Triton input/output
-                for postproc_key, tensor_values in feature_postproc_outputs.items():
-                    postproc_concat_key = output_feature_utils.get_feature_concat_name(feature_name, postproc_key)
-                    postproc_outputs_flattened[postproc_concat_key] = tensor_values
-            return postproc_outputs_flattened
+        postproc_outputs_flattened: Dict[str, Any] = {}
+        for module_dict_key, postproc in self.postproc_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            feature_postproc_outputs = postproc(predictions_flattened, feature_name)
+            # Flatten out the predictions to support Triton input/output
+            for postproc_key, tensor_values in feature_postproc_outputs.items():
+                postproc_concat_key = output_feature_utils.get_feature_concat_name(feature_name, postproc_key)
+                postproc_outputs_flattened[postproc_concat_key] = tensor_values
+        return postproc_outputs_flattened
 
 
 def save_ludwig_model_for_inference(
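
Rationale for the `no_grad` removals above: gradient mode is a dynamic, thread-local setting that nested calls inherit, and of the three stages only the predictor runs parameters with `requires_grad=True`; the pre- and postprocessors manipulate plain data tensors. A single guard around the parameter-touching call therefore suffices. A small sketch of the inherited behavior (generic module, not the Ludwig classes):

```python
import torch

linear = torch.nn.Linear(4, 2)  # parameters require gradients by default


def predictor(x: torch.Tensor) -> torch.Tensor:
    # No no_grad here: the caller's grad mode is inherited.
    return linear(x)


x = torch.randn(8, 4)  # plain data tensor, requires_grad=False
assert predictor(x).requires_grad  # tracked under the default grad mode

with torch.no_grad():  # one outer guard around the parameter-touching call
    y = predictor(x)
assert not y.requires_grad  # no autograd graph was recorded

z = (x - x.mean()) / x.std()  # "preprocessing" of data tensors
assert not z.requires_grad  # needs no guard to begin with
```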
diff --git a/tests/integration_tests/test_torchscript.py b/tests/integration_tests/test_torchscript.py
index 7165af1ea73..a15f6040dc8 100644
--- a/tests/integration_tests/test_torchscript.py
+++ b/tests/integration_tests/test_torchscript.py
@@ -622,9 +622,10 @@ def validate_torchscript_outputs(tmpdir, config, backend, training_data_csv_path
             assert output_name in feature_outputs
             output_values = feature_outputs[output_name]
 
+            assert utils.has_no_grad(output_values), f'"{feature_name}.{output_name}" tensors have gradients'
             assert utils.is_all_close(
                 output_values, output_values_expected
-            ), f"feature: {feature_name}, output: {output_name}"
+            ), f'"{feature_name}.{output_name}" tensors are not close to Ludwig model outputs'
 
 
 def initialize_torchscript_module(tmpdir, config, backend, training_data_csv_path, device=None):
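
The new assertion runs before the closeness check so an autograd-graph leak fails fast with a pointed message instead of surfacing as a numeric mismatch. A condensed, hypothetical mirror of the validation loop (feature names and tensors are made up; plain torch calls stand in for the utils helpers):

```python
import torch

# Hypothetical flattened outputs from the scripted and in-memory models.
scripted = {"category_1": {"probabilities": torch.tensor([0.2, 0.8])}}
expected = {"category_1": {"probabilities": torch.tensor([0.2, 0.8])}}

for feature_name, feature_outputs in scripted.items():
    for output_name, output_values in feature_outputs.items():
        # Fail fast if inference leaked gradient tracking...
        assert not output_values.requires_grad, f'"{feature_name}.{output_name}" tensors have gradients'
        # ...then check numerical agreement between the two models.
        assert torch.allclose(
            output_values, expected[feature_name][output_name]
        ), f'"{feature_name}.{output_name}" tensors are not close'
```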
diff --git a/tests/integration_tests/utils.py b/tests/integration_tests/utils.py
index 370ff0b7795..96e6328e842 100644
--- a/tests/integration_tests/utils.py
+++ b/tests/integration_tests/utils.py
@@ -501,6 +501,17 @@ def get_weights(model: torch.nn.Module) -> List[torch.Tensor]:
     return [param.data for param in model.parameters()]
 
 
+def has_no_grad(
+    val: Union[np.ndarray, torch.Tensor, str, list],
+):
+    """Checks that val, including any tensors nested in lists, requires no gradients."""
+    if isinstance(val, list):
+        return all(has_no_grad(v) for v in val)
+    if isinstance(val, torch.Tensor):
+        return not val.requires_grad
+    return True
+
+
 def is_all_close(
     val1: Union[np.ndarray, torch.Tensor, str, list],
     val2: Union[np.ndarray, torch.Tensor, str, list],
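
For illustration, how the new helper behaves across the value types it accepts (a hypothetical usage sketch; the import path assumes the file above):

```python
import torch

from tests.integration_tests.utils import has_no_grad

plain = torch.zeros(3)                       # no autograd graph attached
tracked = torch.ones(3, requires_grad=True)  # leaf tensor that requires grad

assert has_no_grad(plain)
assert has_no_grad("some string")               # non-tensors pass trivially
assert has_no_grad([plain, ["nested", plain]])  # recurses through lists
assert not has_no_grad([plain, tracked])        # any tracked tensor fails
```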