From 9c22511a2a95ddb1dfd0c75391ad5e2202adcace Mon Sep 17 00:00:00 2001 From: Mathieu Guillame-Bert Date: Thu, 23 Nov 2023 16:47:10 +0100 Subject: [PATCH 1/3] variadic inputs for glue implementation --- .gitignore | 1 + temporian/core/operators/base.py | 42 ++++++++++++--- temporian/core/operators/glue.py | 21 ++------ temporian/core/operators/test/test_glue.py | 48 +++++++++++++++--- temporian/core/test/operator_test.py | 59 ++++++++++++++++++++++ temporian/core/test/serialization_test.py | 56 ++++++++++++++++++++ temporian/core/test/utils.py | 30 +++++++++++ temporian/proto/core.proto | 10 +++- 8 files changed, 235 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index 7dee94767..2291dad0e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ build_package tmp_* .cache/ .env +my_venv # benchmark outputs profile.* diff --git a/temporian/core/operators/base.py b/temporian/core/operators/base.py index 8dcf296d7..f8ce7641e 100644 --- a/temporian/core/operators/base.py +++ b/temporian/core/operators/base.py @@ -159,16 +159,44 @@ def check(self) -> None: with OperatorExceptionDecorator(self): # Check that expected inputs are present for expected_input in definition.inputs: - if ( - not expected_input.is_optional - and expected_input.key not in self._inputs - ): - raise ValueError(f'Missing input "{expected_input.key}".') + if expected_input.HasField("key"): + if ( + not expected_input.is_optional + and expected_input.key not in self._inputs + ): + raise ValueError( + f'Missing input "{expected_input.key}".' + ) + elif expected_input.HasField("key_prefix"): + # Nothing to check + pass + else: + raise ValueError("Invalid operator definition") # Check that no unexpected inputs are present for available_input in self._inputs: - if available_input not in [v.key for v in definition.inputs]: - raise ValueError(f'Unexpected input "{available_input}".') + num_multi_input_matches = sum( + available_input.startswith(v.key_prefix) + for v in definition.inputs + if v.HasField("key_prefix") + ) + if num_multi_input_matches > 1: + raise ValueError( + f'Input "{available_input}" matches multiple prefix' + " inputs." + ) + + if available_input in [v.key for v in definition.inputs]: + if num_multi_input_matches != 0: + raise ValueError( + f'Input "{available_input}" matches both a prefix' + " and non-prefix input." + ) + else: + if num_multi_input_matches != 1: + raise ValueError( + f'Unexpected input "{available_input}".' + ) # Check that expected outputs are present for expected_output in definition.outputs: diff --git a/temporian/core/operators/glue.py b/temporian/core/operators/glue.py index 838c082bc..c8e957f8f 100644 --- a/temporian/core/operators/glue.py +++ b/temporian/core/operators/glue.py @@ -26,8 +26,7 @@ from temporian.proto import core_pb2 as pb from temporian.utils.typecheck import typecheck -# Maximum number of arguments taken by the glue operator -MAX_NUM_ARGUMENTS = 100 +_INPUT_KEY_PREFIX = "input_" class GlueOperator(Operator): @@ -43,11 +42,6 @@ def __init__( if len(inputs) < 2: raise ValueError("At least two arguments should be provided.") - if len(inputs) >= MAX_NUM_ARGUMENTS: - raise ValueError( - f"Too many (>{MAX_NUM_ARGUMENTS}) arguments provided." - ) - # inputs output_features = [] output_feature_schemas = [] @@ -95,11 +89,7 @@ def __init__( def build_op_definition(cls) -> pb.OperatorDef: return pb.OperatorDef( key="GLUE", - # TODO: Add support to array of nodes arguments. - inputs=[ - pb.OperatorDef.Input(key=f"input_{idx}", is_optional=idx >= 2) - for idx in range(MAX_NUM_ARGUMENTS) - ], + inputs=[pb.OperatorDef.Input(key_prefix=_INPUT_KEY_PREFIX)], outputs=[pb.OperatorDef.Output(key="output")], ) @@ -203,10 +193,9 @@ def glue( """ if len(inputs) == 1 and isinstance(inputs[0], EventSetNode): return inputs[0] - - # Note: The node should be called "input_{idx}" with idx in [0, MAX_NUM_ARGUMENTS). - inputs_dict = {f"input_{idx}": input for idx, input in enumerate(inputs)} - + inputs_dict = { + f"{_INPUT_KEY_PREFIX}{idx}": input for idx, input in enumerate(inputs) + } return GlueOperator(**inputs_dict).outputs["output"] # type: ignore diff --git a/temporian/core/operators/test/test_glue.py b/temporian/core/operators/test/test_glue.py index c4f0f2de9..6490feff5 100644 --- a/temporian/core/operators/test/test_glue.py +++ b/temporian/core/operators/test/test_glue.py @@ -14,10 +14,13 @@ from absl.testing import absltest from absl.testing.parameterized import TestCase -from temporian.core.operators.glue import MAX_NUM_ARGUMENTS, glue - +from temporian.core.operators.glue import glue +import tempfile +import os from temporian.implementation.numpy.data.io import event_set from temporian.test.utils import assertOperatorResult, f32 +from temporian.core import serialization +from temporian.core import evaluation class GlueTest(TestCase): @@ -96,11 +99,6 @@ def test_no_evsets(self): ): glue() - def test_too_many_evsets(self): - evsets = [event_set([], features={"a": []})] * (MAX_NUM_ARGUMENTS + 1) - with self.assertRaisesRegex(ValueError, "Too many"): - glue(*evsets) - def test_order_unchanged(self): """Tests that input evsets' order is kept. @@ -164,6 +162,42 @@ def test_order_unchanged(self): assertOperatorResult(self, result, expected) + def test_serialization(self): + # Generate some input event sets. + num_inputs = 1000 + input_0 = event_set([0], {"f0": [0]}) + inputs = [ + event_set([0], {f"f{i}": [i]}, same_sampling_as=input_0) + for i in range(1, num_inputs) + ] + inputs = [input_0, *inputs] + + # Glue the inputs together. + input_nodes = [e.node() for e in inputs] + output_node = glue(*input_nodes) + + # Save and restore the graph. + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, "my_fn.tem") + serialization.save_graph( + inputs={f"i_{i}": v for i, v in enumerate(input_nodes)}, + outputs={"output": output_node}, + path=path, + ) + loaded_inputs, loaded_output = serialization.load_graph( + path=path, squeeze=True + ) + + # Execute the loaded graph. + output = evaluation.run( + loaded_output, + {loaded_inputs[f"i_{i}"]: v for i, v in enumerate(inputs)}, + ) + expeted_output = event_set( + [0], {f"f{i}": [i] for i in range(0, num_inputs)} + ) + self.assertEqual(output, expeted_output) + if __name__ == "__main__": absltest.main() diff --git a/temporian/core/test/operator_test.py b/temporian/core/test/operator_test.py index bbacc4ae6..1f205bb05 100644 --- a/temporian/core/test/operator_test.py +++ b/temporian/core/test/operator_test.py @@ -68,6 +68,65 @@ def build_fake_node(): ): t.check() + def test_check_operator_with_key_prefix(self): + class ToyOperator(base.Operator): + @classmethod + def build_op_definition(cls) -> pb.OperatorDef: + return pb.OperatorDef( + key="TOY", + inputs=[pb.OperatorDef.Input(key_prefix="input_")], + ) + + def _get_pandas_implementation(self): + raise NotImplementedError() + + def build_fake_node(): + return input_node(features=[]) + + t = ToyOperator() + t.check() + + t.add_input("input", build_fake_node()) + with self.assertRaisesRegex(ValueError, 'Unexpected input "input"'): + t.check() + + t = ToyOperator() + t.add_input("input_1", build_fake_node()) + t.add_input("input_2", build_fake_node()) + t.add_input("input_3", build_fake_node()) + + with self.assertRaisesRegex( + ValueError, 'Already existing input "input_3"' + ): + t.add_input("input_3", build_fake_node()) + + t.check() + + def test_check_operator_with_key_prefix_and_invalid_def(self): + class ToyOperator(base.Operator): + @classmethod + def build_op_definition(cls) -> pb.OperatorDef: + return pb.OperatorDef( + key="TOY", + inputs=[ + pb.OperatorDef.Input(key_prefix="input_"), + pb.OperatorDef.Input(key_prefix="input"), + ], + ) + + def _get_pandas_implementation(self): + raise NotImplementedError() + + def build_fake_node(): + return input_node(features=[]) + + t = ToyOperator() + t.add_input("input_1", build_fake_node()) + with self.assertRaisesRegex( + ValueError, 'Input "input_1" matches multiple prefix inputs' + ): + t.check() + if __name__ == "__main__": absltest.main() diff --git a/temporian/core/test/serialization_test.py b/temporian/core/test/serialization_test.py index a36672df7..f59b7dbc8 100644 --- a/temporian/core/test/serialization_test.py +++ b/temporian/core/test/serialization_test.py @@ -104,6 +104,62 @@ def test_serialize(self): & serialization._all_identifiers(restored.named_outputs.values()) ) + def test_serialize_with_key_prefix(self): + i1 = utils.create_simple_input_node() + i2 = utils.create_simple_input_node() + i3 = utils.create_simple_input_node() + o2 = utils.OpMultiIO1(input_1=i1, input_2=i2, input_3=i3) + + original = graph.infer_graph_named_nodes( + {"io_input_1": i1, "io_input_2": i2, "io_input_3": i3}, + {"io_output": o2.outputs["output"]}, + ) + logging.info("original:\n%s", original) + + proto = serialization._serialize(original) + logging.info("proto:\n%s", proto) + + restored = serialization._unserialize(proto) + logging.info("restored:\n%s", restored) + + self.assertEqual(len(original.samplings), len(restored.samplings)) + self.assertEqual(len(original.features), len(restored.features)) + self.assertEqual(len(original.operators), len(restored.operators)) + self.assertEqual(len(original.nodes), len(restored.nodes)) + self.assertEqual( + original.named_inputs.keys(), restored.named_inputs.keys() + ) + self.assertEqual( + original.named_outputs.keys(), restored.named_outputs.keys() + ) + # TODO: Deep equality tests. + + # Ensures that "original" and "restored" don't link to the same objects. + self.assertFalse( + serialization._all_identifiers(original.samplings) + & serialization._all_identifiers(restored.samplings) + ) + self.assertFalse( + serialization._all_identifiers(original.features) + & serialization._all_identifiers(restored.features) + ) + self.assertFalse( + serialization._all_identifiers(original.operators) + & serialization._all_identifiers(restored.operators) + ) + self.assertFalse( + serialization._all_identifiers(original.nodes) + & serialization._all_identifiers(restored.nodes) + ) + self.assertFalse( + serialization._all_identifiers(original.named_inputs.values()) + & serialization._all_identifiers(restored.named_inputs.values()) + ) + self.assertFalse( + serialization._all_identifiers(original.named_outputs.values()) + & serialization._all_identifiers(restored.named_outputs.values()) + ) + def test_serialize_autonode(self): input_data = event_set( timestamps=[1, 2, 3, 4], diff --git a/temporian/core/test/utils.py b/temporian/core/test/utils.py index 9acb1a6c6..1a0bd6884 100644 --- a/temporian/core/test/utils.py +++ b/temporian/core/test/utils.py @@ -1,4 +1,5 @@ """Utilities for unit testing.""" + from typing import Dict, List, Optional from temporian.core.data.dtype import DType @@ -30,6 +31,10 @@ def create_input_node(name: Optional[str] = None): ) +def create_simple_input_node(name: Optional[str] = None): + return input_node(features=[("f1", DType.FLOAT64)]) + + def create_input_event_set(name: Optional[str] = None) -> EventSet: return event_set( timestamps=[0, 2, 4, 6], @@ -128,6 +133,30 @@ def __init__(self, input_1: EventSetNode, input_2: EventSetNode): self.check() +class OpMultiIO1(base.Operator): + @classmethod + def build_op_definition(cls) -> pb.OperatorDef: + return pb.OperatorDef( + key="OpMultiIO1", + inputs=[pb.OperatorDef.Input(key_prefix="input_")], + outputs=[pb.OperatorDef.Output(key="output")], + ) + + def __init__(self, **inputs: EventSetNode): + super().__init__() + for k, v in inputs.items(): + self.add_input(k, v) + self.add_output( + "output", + create_node_new_features_existing_sampling( + features=[("o", DType.BOOLEAN)], + sampling_node=next(iter(inputs.values())), + creator=self, + ), + ) + self.check() + + class OpI1O2(base.Operator): @classmethod def build_op_definition(cls) -> pb.OperatorDef: @@ -245,6 +274,7 @@ def __init__( OpI2O1, OpI1O2, OpWithAttributes, + OpMultiIO1, ] # Utilities to register and unregister test operators. diff --git a/temporian/proto/core.proto b/temporian/proto/core.proto index d06b99fdf..7f76a2a29 100644 --- a/temporian/proto/core.proto +++ b/temporian/proto/core.proto @@ -190,8 +190,14 @@ message OperatorDef { optional bool is_serializable = 5 [default = true]; message Input { - // String identifier of the input. Should be unique. - optional string key = 1; + oneof type { + // String identifier of the input. Should be unique. + string key = 1; + + // String prefix to identify multiple inputs. An input should only match + // one "key_prefix". If set, "is_optional" is ignored. + string key_prefix = 2; + } // If true, the input is optional. optional bool is_optional = 3; From 99351daf533d7de8bb539a16035da7528ec3fd87 Mon Sep 17 00:00:00 2001 From: Mathieu Guillame-Bert Date: Thu, 23 Nov 2023 16:49:23 +0100 Subject: [PATCH 2/3] variadic inputs for combine --- temporian/core/operators/combine.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/temporian/core/operators/combine.py b/temporian/core/operators/combine.py index 9ae9616d1..db86322c7 100644 --- a/temporian/core/operators/combine.py +++ b/temporian/core/operators/combine.py @@ -28,7 +28,7 @@ from temporian.proto import core_pb2 as pb from temporian.utils.typecheck import typecheck -MAX_NUM_ARGUMENTS = 30 +_INPUT_KEY_PREFIX = "input_" class How(str, Enum): @@ -59,11 +59,6 @@ def __init__(self, how: How, **inputs: EventSetNode): if len(inputs) < 2: raise ValueError("At least two arguments should be provided") - if len(inputs) >= MAX_NUM_ARGUMENTS: - raise ValueError( - f"Too many (>{MAX_NUM_ARGUMENTS}) arguments provided" - ) - # Attributes self._how = how self.add_attribute("how", how) @@ -110,10 +105,7 @@ def build_op_definition(cls) -> pb.OperatorDef: type=pb.OperatorDef.Attribute.Type.STRING, ), ], - inputs=[ - pb.OperatorDef.Input(key=f"input_{idx}", is_optional=idx >= 2) - for idx in range(MAX_NUM_ARGUMENTS) - ], + inputs=[pb.OperatorDef.Input(key_prefix=_INPUT_KEY_PREFIX)], outputs=[pb.OperatorDef.Output(key="output")], ) @@ -249,5 +241,7 @@ def combine( return inputs[0] # NOTE: input name must match op. definition name - inputs_dict = {f"input_{idx}": input for idx, input in enumerate(inputs)} + inputs_dict = { + f"{_INPUT_KEY_PREFIX}{idx}": input for idx, input in enumerate(inputs) + } return Combine(how=how, **inputs_dict).outputs["output"] # type: ignore From 8d1e02dcc62f8ae655df05bed1076080999f310b Mon Sep 17 00:00:00 2001 From: Mathieu Guillame-Bert Date: Thu, 23 Nov 2023 16:59:36 +0100 Subject: [PATCH 3/3] Extra test --- temporian/core/operators/base.py | 4 +++- temporian/core/test/operator_test.py | 29 +++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/temporian/core/operators/base.py b/temporian/core/operators/base.py index f8ce7641e..0ee03fb17 100644 --- a/temporian/core/operators/base.py +++ b/temporian/core/operators/base.py @@ -186,7 +186,9 @@ def check(self) -> None: " inputs." ) - if available_input in [v.key for v in definition.inputs]: + if available_input in [ + v.key for v in definition.inputs if v.HasField("key") + ]: if num_multi_input_matches != 0: raise ValueError( f'Input "{available_input}" matches both a prefix' diff --git a/temporian/core/test/operator_test.py b/temporian/core/test/operator_test.py index 1f205bb05..9ca55c818 100644 --- a/temporian/core/test/operator_test.py +++ b/temporian/core/test/operator_test.py @@ -102,7 +102,7 @@ def build_fake_node(): t.check() - def test_check_operator_with_key_prefix_and_invalid_def(self): + def test_check_operator_error_overlapping_prefixes(self): class ToyOperator(base.Operator): @classmethod def build_op_definition(cls) -> pb.OperatorDef: @@ -127,6 +127,33 @@ def build_fake_node(): ): t.check() + def test_check_operator_error_overlapping_prefix_and_non_prefix( + self, + ): + class ToyOperator(base.Operator): + @classmethod + def build_op_definition(cls) -> pb.OperatorDef: + return pb.OperatorDef( + key="TOY", + inputs=[ + pb.OperatorDef.Input(key_prefix="input_"), + pb.OperatorDef.Input(key="input_1"), + ], + ) + + def _get_pandas_implementation(self): + raise NotImplementedError() + + def build_fake_node(): + return input_node(features=[]) + + t = ToyOperator() + t.add_input("input_1", build_fake_node()) + with self.assertRaisesRegex( + ValueError, "matches both a prefix and non-prefix input" + ): + t.check() + if __name__ == "__main__": absltest.main()