Commit
merge
geoffreyangus committed Jun 17, 2022
2 parents 2ca3e78 + fe77793 commit cc52796
Showing 3 changed files with 60 additions and 25 deletions.
40 changes: 39 additions & 1 deletion ludwig/features/set_feature.py
@@ -102,6 +102,40 @@ def forward(self, v: TorchscriptPreprocessingInput) -> torch.Tensor:
return set_matrix


class _SetPostprocessing(torch.nn.Module):
"""Torchscript-enabled version of postprocessing done by SetFeatureMixin.add_feature_data."""

def __init__(self, metadata: Dict[str, Any]):
super().__init__()
self.idx2str = {i: v for i, v in enumerate(metadata["idx2str"])}
self.predictions_key = PREDICTIONS
self.probabilities_key = PROBABILITIES
self.unk = UNKNOWN_SYMBOL

def forward(self, preds: Dict[str, torch.Tensor]) -> Dict[str, Any]:
predictions = preds[self.predictions_key]
probabilities = preds[self.probabilities_key]

inv_preds: List[List[str]] = []
filtered_probs: List[torch.Tensor] = []
for sample_idx, sample in enumerate(predictions):
sample_preds: List[str] = []
pos_sample_idxs: List[int] = []
pos_class_idxs: List[int] = []
for class_idx, is_positive in enumerate(sample):
if is_positive == 1:
sample_preds.append(self.idx2str.get(class_idx, self.unk))
pos_sample_idxs.append(sample_idx)
pos_class_idxs.append(class_idx)
inv_preds.append(sample_preds)
filtered_probs.append(probabilities[pos_sample_idxs, pos_class_idxs])

return {
self.predictions_key: inv_preds,
self.probabilities_key: filtered_probs,
}


class _SetPredict(PredictModule):
def __init__(self, threshold):
super().__init__()
@@ -339,7 +373,7 @@ def idx2str(pred_set):
threshold = self.threshold

def get_prob(prob_set):
return [prob for prob in prob_set if prob >= threshold]
return np.array([prob for prob in prob_set if prob >= threshold])

result[probabilities_col] = backend.df_engine.map_objects(
result[probabilities_col],
@@ -348,6 +382,10 @@ def get_prob(prob_set):

return result

@staticmethod
def create_postproc_module(metadata: Dict[str, Any]) -> torch.nn.Module:
return _SetPostprocessing(metadata)

@staticmethod
def populate_defaults(output_feature):
set_default_value(output_feature, LOSS, {TYPE: SIGMOID_CROSS_ENTROPY, "weight": 1})
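For orientation, here is a minimal sketch (not part of the commit) of the decoding that _SetPostprocessing.forward performs: each multi-hot prediction row is mapped back to its vocabulary strings, and only the probabilities of the positive classes are kept. The vocabulary, the tensors, and the "<UNK>" fallback below are invented stand-ins for metadata["idx2str"] and UNKNOWN_SYMBOL.

import torch

# Toy stand-ins for metadata["idx2str"] and the predict-module outputs.
idx2str = {0: "red", 1: "green", 2: "blue"}
predictions = torch.tensor([[1, 0, 1], [0, 1, 0]])                # multi-hot rows
probabilities = torch.tensor([[0.9, 0.2, 0.8], [0.1, 0.7, 0.3]])

decoded, kept_probs = [], []
for sample_idx, sample in enumerate(predictions):
    pos_idxs = [i for i, flag in enumerate(sample) if flag == 1]
    decoded.append([idx2str.get(i, "<UNK>") for i in pos_idxs])   # "<UNK>" stands in for UNKNOWN_SYMBOL
    kept_probs.append(probabilities[sample_idx, pos_idxs])

print(decoded)     # [['red', 'blue'], ['green']]
print(kept_probs)  # [tensor([0.9000, 0.8000]), tensor([0.7000])]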
19 changes: 4 additions & 15 deletions tests/integration_tests/test_torchscript.py
@@ -222,11 +222,11 @@ def test_torchscript_e2e_tabular(csv_filename, tmpdir):
binary_feature(),
number_feature(),
category_feature(vocab_size=3),
set_feature(vocab_size=3),
vector_feature()
# TODO: future support
# sequence_feature(vocab_size=3),
# text_feature(vocab_size=3),
# set_feature(vocab_size=3),
]
backend = LocalTestBackend()
config = {"input_features": input_features, "output_features": output_features, TRAINER: {"epochs": 2}}
@@ -421,17 +421,6 @@ def validate_torchscript_outputs(tmpdir, config, backend, training_data_csv_path

assert output_name in feature_outputs
output_values = feature_outputs[output_name]
if isinstance(output_values, list):
# Strings should match exactly
assert np.all(
output_values == output_values_expected
), f"feature: {feature_name}, output: {output_name}"
else:
output_values = np.array(output_values)
# Shapes and values must both match
assert (
output_values.shape == output_values_expected.shape
), f"feature: {feature_name}, output: {output_name}"
assert np.allclose(
output_values, output_values_expected, atol=tolerance
), f"feature: {feature_name}, output: {output_name}"
assert utils.is_all_close(
output_values, output_values_expected
), f"feature: {feature_name}, output: {output_name}"
26 changes: 17 additions & 9 deletions tests/integration_tests/utils.py
@@ -24,7 +24,7 @@
import unittest
import uuid
from distutils.util import strtobool
from typing import List
from typing import List, Union

import cloudpickle
import numpy as np
@@ -468,14 +468,22 @@ def get_weights(model: torch.nn.Module) -> List[torch.Tensor]:
return [param.data for param in model.parameters()]


def is_all_close(list_tensors_1, list_tensors_2):
"""Returns whether all of list_tensors_1 is close to list_tensors_2."""
assert len(list_tensors_1) == len(list_tensors_2)
for i in range(len(list_tensors_1)):
assert list_tensors_1[i].size() == list_tensors_2[i].size()

is_close_values = [torch.isclose(list_tensors_1[i], list_tensors_2[i]) for i in range(len(list_tensors_1))]
return torch.all(torch.Tensor([torch.all(is_close_value) for is_close_value in is_close_values]))
def is_all_close(
val1: Union[np.ndarray, torch.Tensor, str, list],
val2: Union[np.ndarray, torch.Tensor, str, list],
tolerance=1e-8,
):
"""Checks if two values are close to each other."""
if isinstance(val1, list):
return all(is_all_close(v1, v2, tolerance) for v1, v2 in zip(val1, val2))

if isinstance(val1, str):
return val1 == val2
if isinstance(val1, torch.Tensor):
val1 = val1.detach().numpy()
if isinstance(val2, torch.Tensor):
val2 = val2.detach().numpy()
return val1.shape == val2.shape and np.allclose(val1, val2, atol=tolerance)


def run_api_experiment(input_features, output_features, data_csv):
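To illustrate the generalized helper (again, not part of the commit), a few assumed checks show how is_all_close handles the mixed output types that set features now produce: nested lists recurse element-wise, strings compare exactly, and tensors or arrays must match in both shape and approximate value. The values are made up for the example.

import numpy as np
import torch

# Assumes is_all_close is in scope as defined above.
assert is_all_close([["red", "blue"], ["green"]], [["red", "blue"], ["green"]])
assert is_all_close(
    [torch.tensor([0.9, 0.8]), torch.tensor([0.7])],
    [np.array([0.9, 0.8]), np.array([0.7])],
)
assert not is_all_close(torch.zeros(2, 3), torch.zeros(3, 2))  # shape mismatch -> False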
