Skip to content

Commit

Permalink
Merge pull request #318 from google/gbm_variadic_inputs
Browse files Browse the repository at this point in the history
Add support for operators with variadic inputs
  • Loading branch information
ianspektor committed Nov 27, 2023
2 parents 17b2f2f + 8d1e02d commit fd8dd8b
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 43 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ build_package
tmp_*
.cache/
.env
my_venv

# benchmark outputs
profile.*
Expand Down
44 changes: 37 additions & 7 deletions temporian/core/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,46 @@ def check(self) -> None:
with OperatorExceptionDecorator(self):
# Check that expected inputs are present
for expected_input in definition.inputs:
if (
not expected_input.is_optional
and expected_input.key not in self._inputs
):
raise ValueError(f'Missing input "{expected_input.key}".')
if expected_input.HasField("key"):
if (
not expected_input.is_optional
and expected_input.key not in self._inputs
):
raise ValueError(
f'Missing input "{expected_input.key}".'
)
elif expected_input.HasField("key_prefix"):
# Nothing to check
pass
else:
raise ValueError("Invalid operator definition")

# Check that no unexpected inputs are present
for available_input in self._inputs:
if available_input not in [v.key for v in definition.inputs]:
raise ValueError(f'Unexpected input "{available_input}".')
num_multi_input_matches = sum(
available_input.startswith(v.key_prefix)
for v in definition.inputs
if v.HasField("key_prefix")
)
if num_multi_input_matches > 1:
raise ValueError(
f'Input "{available_input}" matches multiple prefix'
" inputs."
)

if available_input in [
v.key for v in definition.inputs if v.HasField("key")
]:
if num_multi_input_matches != 0:
raise ValueError(
f'Input "{available_input}" matches both a prefix'
" and non-prefix input."
)
else:
if num_multi_input_matches != 1:
raise ValueError(
f'Unexpected input "{available_input}".'
)

# Check that expected outputs are present
for expected_output in definition.outputs:
Expand Down
16 changes: 5 additions & 11 deletions temporian/core/operators/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from temporian.proto import core_pb2 as pb
from temporian.utils.typecheck import typecheck

MAX_NUM_ARGUMENTS = 30
_INPUT_KEY_PREFIX = "input_"


class How(str, Enum):
Expand Down Expand Up @@ -59,11 +59,6 @@ def __init__(self, how: How, **inputs: EventSetNode):
if len(inputs) < 2:
raise ValueError("At least two arguments should be provided")

if len(inputs) >= MAX_NUM_ARGUMENTS:
raise ValueError(
f"Too many (>{MAX_NUM_ARGUMENTS}) arguments provided"
)

# Attributes
self._how = how
self.add_attribute("how", how)
Expand Down Expand Up @@ -110,10 +105,7 @@ def build_op_definition(cls) -> pb.OperatorDef:
type=pb.OperatorDef.Attribute.Type.STRING,
),
],
inputs=[
pb.OperatorDef.Input(key=f"input_{idx}", is_optional=idx >= 2)
for idx in range(MAX_NUM_ARGUMENTS)
],
inputs=[pb.OperatorDef.Input(key_prefix=_INPUT_KEY_PREFIX)],
outputs=[pb.OperatorDef.Output(key="output")],
)

Expand Down Expand Up @@ -249,5 +241,7 @@ def combine(
return inputs[0]

# NOTE: input name must match op. definition name
inputs_dict = {f"input_{idx}": input for idx, input in enumerate(inputs)}
inputs_dict = {
f"{_INPUT_KEY_PREFIX}{idx}": input for idx, input in enumerate(inputs)
}
return Combine(how=how, **inputs_dict).outputs["output"] # type: ignore
21 changes: 5 additions & 16 deletions temporian/core/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
from temporian.proto import core_pb2 as pb
from temporian.utils.typecheck import typecheck

# Maximum number of arguments taken by the glue operator
MAX_NUM_ARGUMENTS = 100
_INPUT_KEY_PREFIX = "input_"


class GlueOperator(Operator):
Expand All @@ -43,11 +42,6 @@ def __init__(
if len(inputs) < 2:
raise ValueError("At least two arguments should be provided.")

if len(inputs) >= MAX_NUM_ARGUMENTS:
raise ValueError(
f"Too many (>{MAX_NUM_ARGUMENTS}) arguments provided."
)

# inputs
output_features = []
output_feature_schemas = []
Expand Down Expand Up @@ -95,11 +89,7 @@ def __init__(
def build_op_definition(cls) -> pb.OperatorDef:
return pb.OperatorDef(
key="GLUE",
# TODO: Add support to array of nodes arguments.
inputs=[
pb.OperatorDef.Input(key=f"input_{idx}", is_optional=idx >= 2)
for idx in range(MAX_NUM_ARGUMENTS)
],
inputs=[pb.OperatorDef.Input(key_prefix=_INPUT_KEY_PREFIX)],
outputs=[pb.OperatorDef.Output(key="output")],
)

Expand Down Expand Up @@ -203,10 +193,9 @@ def glue(
"""
if len(inputs) == 1 and isinstance(inputs[0], EventSetNode):
return inputs[0]

# Note: The node should be called "input_{idx}" with idx in [0, MAX_NUM_ARGUMENTS).
inputs_dict = {f"input_{idx}": input for idx, input in enumerate(inputs)}

inputs_dict = {
f"{_INPUT_KEY_PREFIX}{idx}": input for idx, input in enumerate(inputs)
}
return GlueOperator(**inputs_dict).outputs["output"] # type: ignore


Expand Down
48 changes: 41 additions & 7 deletions temporian/core/operators/test/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@

from absl.testing import absltest
from absl.testing.parameterized import TestCase
from temporian.core.operators.glue import MAX_NUM_ARGUMENTS, glue

from temporian.core.operators.glue import glue
import tempfile
import os
from temporian.implementation.numpy.data.io import event_set
from temporian.test.utils import assertOperatorResult, f32
from temporian.core import serialization
from temporian.core import evaluation


class GlueTest(TestCase):
Expand Down Expand Up @@ -96,11 +99,6 @@ def test_no_evsets(self):
):
glue()

def test_too_many_evsets(self):
evsets = [event_set([], features={"a": []})] * (MAX_NUM_ARGUMENTS + 1)
with self.assertRaisesRegex(ValueError, "Too many"):
glue(*evsets)

def test_order_unchanged(self):
"""Tests that input evsets' order is kept.
Expand Down Expand Up @@ -164,6 +162,42 @@ def test_order_unchanged(self):

assertOperatorResult(self, result, expected)

def test_serialization(self):
    """Glue many inputs, round-trip the graph through disk, and run it.

    Uses 1000 inputs, so the glue operator is exercised with a large
    number of input nodes.
    """
    total_evsets = 1000

    # Every event set after the first shares the first one's sampling.
    first_evset = event_set([0], {"f0": [0]})
    evsets = [first_evset]
    for idx in range(1, total_evsets):
        evsets.append(
            event_set([0], {f"f{idx}": [idx]}, same_sampling_as=first_evset)
        )

    # Build the glue graph on the nodes of the event sets.
    nodes = [evset.node() for evset in evsets]
    glued_node = glue(*nodes)

    # Round-trip the graph through a file in a throwaway directory.
    named_nodes = {f"i_{idx}": node for idx, node in enumerate(nodes)}
    with tempfile.TemporaryDirectory() as workdir:
        graph_path = os.path.join(workdir, "my_fn.tem")
        serialization.save_graph(
            inputs=named_nodes,
            outputs={"output": glued_node},
            path=graph_path,
        )
        restored_inputs, restored_output = serialization.load_graph(
            path=graph_path, squeeze=True
        )

    # Feed the original event sets to the restored graph.
    feed = {
        restored_inputs[f"i_{idx}"]: evset
        for idx, evset in enumerate(evsets)
    }
    actual = evaluation.run(restored_output, feed)

    expected = event_set(
        [0], {f"f{idx}": [idx] for idx in range(total_evsets)}
    )
    self.assertEqual(actual, expected)


if __name__ == "__main__":
absltest.main()
86 changes: 86 additions & 0 deletions temporian/core/test/operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,92 @@ def build_fake_node():
):
t.check()

def test_check_operator_with_key_prefix(self):
    """Runs `check()` on an operator whose only input is a key prefix."""

    class ToyOperator(base.Operator):
        @classmethod
        def build_op_definition(cls) -> pb.OperatorDef:
            prefixed_input = pb.OperatorDef.Input(key_prefix="input_")
            return pb.OperatorDef(key="TOY", inputs=[prefixed_input])

        def _get_pandas_implementation(self):
            raise NotImplementedError()

    # With no inputs at all, the check is expected to pass.
    op = ToyOperator()
    op.check()

    # "input" does not start with the "input_" prefix.
    op.add_input("input", input_node(features=[]))
    with self.assertRaisesRegex(ValueError, 'Unexpected input "input"'):
        op.check()

    # Several inputs sharing the prefix are accepted.
    op = ToyOperator()
    for suffix in ("1", "2", "3"):
        op.add_input(f"input_{suffix}", input_node(features=[]))

    # Registering the same key twice is rejected when the input is added.
    with self.assertRaisesRegex(
        ValueError, 'Already existing input "input_3"'
    ):
        op.add_input("input_3", input_node(features=[]))

    op.check()

def test_check_operator_error_overlapping_prefixes(self):
    """An input key matching two declared prefixes must fail `check()`."""

    class ToyOperator(base.Operator):
        @classmethod
        def build_op_definition(cls) -> pb.OperatorDef:
            # "input" is itself a prefix of "input_", so both prefixes
            # match a key such as "input_1".
            overlapping = [
                pb.OperatorDef.Input(key_prefix=prefix)
                for prefix in ("input_", "input")
            ]
            return pb.OperatorDef(key="TOY", inputs=overlapping)

        def _get_pandas_implementation(self):
            raise NotImplementedError()

    op = ToyOperator()
    op.add_input("input_1", input_node(features=[]))

    expected_error = 'Input "input_1" matches multiple prefix inputs'
    with self.assertRaisesRegex(ValueError, expected_error):
        op.check()

def test_check_operator_error_overlapping_prefix_and_non_prefix(self):
    """A key that is both an exact input and a prefix match must fail."""

    class ToyOperator(base.Operator):
        @classmethod
        def build_op_definition(cls) -> pb.OperatorDef:
            # The exact key "input_1" also starts with the "input_" prefix.
            by_prefix = pb.OperatorDef.Input(key_prefix="input_")
            by_key = pb.OperatorDef.Input(key="input_1")
            return pb.OperatorDef(key="TOY", inputs=[by_prefix, by_key])

        def _get_pandas_implementation(self):
            raise NotImplementedError()

    op = ToyOperator()
    op.add_input("input_1", input_node(features=[]))

    expected_error = "matches both a prefix and non-prefix input"
    with self.assertRaisesRegex(ValueError, expected_error):
        op.check()


if __name__ == "__main__":
absltest.main()
56 changes: 56 additions & 0 deletions temporian/core/test/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,62 @@ def test_serialize(self):
& serialization._all_identifiers(restored.named_outputs.values())
)

def test_serialize_with_key_prefix(self):
    """Round-trips a graph containing a three-input `OpMultiIO1` operator.

    Checks that the restored graph has the same number of elements and the
    same named inputs/outputs as the original, and that it shares no
    identifiers with the original.
    """
    in_1, in_2, in_3 = (utils.create_simple_input_node() for _ in range(3))
    multi_op = utils.OpMultiIO1(input_1=in_1, input_2=in_2, input_3=in_3)

    original = graph.infer_graph_named_nodes(
        {"io_input_1": in_1, "io_input_2": in_2, "io_input_3": in_3},
        {"io_output": multi_op.outputs["output"]},
    )
    logging.info("original:\n%s", original)

    proto = serialization._serialize(original)
    logging.info("proto:\n%s", proto)

    restored = serialization._unserialize(proto)
    logging.info("restored:\n%s", restored)

    # Element counts must match for each kind of graph element.
    for attr in ("samplings", "features", "operators", "nodes"):
        self.assertEqual(
            len(getattr(original, attr)), len(getattr(restored, attr))
        )

    # Named inputs and outputs must keep the same keys.
    for attr in ("named_inputs", "named_outputs"):
        self.assertEqual(
            getattr(original, attr).keys(), getattr(restored, attr).keys()
        )
    # TODO: Deep equality tests.

    # Ensures that "original" and "restored" don't link to the same objects.
    selectors = (
        lambda g: g.samplings,
        lambda g: g.features,
        lambda g: g.operators,
        lambda g: g.nodes,
        lambda g: g.named_inputs.values(),
        lambda g: g.named_outputs.values(),
    )
    for select in selectors:
        shared_ids = serialization._all_identifiers(
            select(original)
        ) & serialization._all_identifiers(select(restored))
        self.assertFalse(shared_ids)

def test_serialize_autonode(self):
input_data = event_set(
timestamps=[1, 2, 3, 4],
Expand Down

0 comments on commit fd8dd8b

Please sign in to comment.