Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename and unify evaluation i/o types #198

Merged
merged 2 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: temporian.core.mixins.EventSetOperationsMixin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: temporian.implementation.numpy.data.event_set.EventSetCollection
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: temporian.core.data.node.EventSetNodeCollection
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: temporian.implementation.numpy.data.event_set.NodeToEventSetMapping
17 changes: 8 additions & 9 deletions temporian/core/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@

from functools import wraps
from copy import copy
from typing import Any, Dict, Optional, Tuple, Callable, TypeVar
from temporian.core.data.node import EventSetNode
from typing import Any, Dict, Optional, Tuple, Callable
from temporian.core.data.node import EventSetNode, EventSetNodeCollection
from temporian.implementation.numpy.data.event_set import EventSet

T = TypeVar("T", bound=Callable)


# TODO: unify the fn's output type with run's EvaluationQuery, and add it to the
# public API so it shows in the docs.
# TODO: make compile change the fn's annotations to EventSetOrNode
def compile(fn: Optional[Callable] = None, *, verbose: int = 0) -> Any:
def compile(
fn: Optional[Callable[..., EventSetNodeCollection]] = None,
*,
verbose: int = 0
) -> Any:
"""Compiles a Temporian function.

A Temporian function is a function that takes
Expand Down Expand Up @@ -74,7 +73,7 @@ def compile(fn: Optional[Callable] = None, *, verbose: int = 0) -> Any:
The compiled function.
"""

def _compile(fn: T) -> T:
def _compile(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
is_eager = None
Expand Down
31 changes: 22 additions & 9 deletions temporian/core/data/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,32 @@

from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

from temporian.core.data.dtype import DType, IndexDType
from temporian.core.data.schema import Schema, FeatureSchema, IndexSchema
from temporian.core.mixins import EventSetOperationsMixin
from temporian.utils import string

if TYPE_CHECKING:
from temporian.core.evaluation import EvaluationInput, EvaluationResult
from temporian.core.operators.base import Operator
from temporian.implementation.numpy.data.event_set import EventSetCollection
from temporian.implementation.numpy.data.event_set import (
NodeToEventSetMapping,
)


class EventSetNode(EventSetOperationsMixin):
"""A EventSetNode is a reference to the input/output of ops in a compute graph.
"""An EventSetNode is a reference to the input/output of ops in a compute
graph.

Use [`tp.input_node()`][temporian.input_node] to create an EventSetNode manually, or
use [`event_set.node()`][temporian.EventSet.node] to create an EventSetNode
compatible with a given [`EventSet`][temporian.EventSet].
Use [`tp.input_node()`][temporian.input_node] to create an EventSetNode
manually, or use [`event_set.node()`][temporian.EventSet.node] to create an
EventSetNode compatible with a given [`EventSet`][temporian.EventSet].

An EventSetNode does not contain any data. Use
[`node.run()`][temporian.EventSetNode.run] to get the
[`EventSet`][temporian.EventSet] resulting from an [`EventSetNodes`][temporian.EventSetNode].
[`EventSet`][temporian.EventSet] resulting from an EventSetNode.
"""

def __init__(
Expand Down Expand Up @@ -143,10 +147,10 @@ def check_same_sampling(self, other: EventSetNode):

def run(
self,
input: EvaluationInput,
input: NodeToEventSetMapping,
verbose: int = 0,
check_execution: bool = True,
) -> EvaluationResult:
) -> EventSetCollection:
"""Evaluates the EventSetNode on the specified input.

See [`tp.run()`][temporian.run] for details.
Expand Down Expand Up @@ -174,6 +178,15 @@ def __repr__(self) -> str:
)


EventSetNodeCollection = Union[
EventSetNode, List[EventSetNode], Set[EventSetNode], Dict[str, EventSetNode]
]
"""A collection of [`EventSetNodes`][temporian.EventSetNode].

This can be a single EventSetNode, a list or set of EventSetNodes, or a
dictionary mapping names to EventSetNodes."""


def input_node(
features: List[Tuple[str, DType]],
indexes: Optional[List[Tuple[str, IndexDType]]] = None,
Expand Down
45 changes: 19 additions & 26 deletions temporian/core/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,28 @@

import time
import sys
from typing import Dict, List, Set, Union, Optional
from typing import Dict, List, Set, Optional
from collections import defaultdict

from temporian.core.data.node import EventSetNode
from temporian.core.data.node import EventSetNode, EventSetNodeCollection
from temporian.core.operators.base import Operator
from temporian.implementation.numpy import evaluation as np_eval
from temporian.implementation.numpy.data.event_set import EventSet
from temporian.implementation.numpy.data.event_set import (
EventSet,
EventSetCollection,
NodeToEventSetMapping,
)
from temporian.core.graph import infer_graph
from temporian.core.schedule import Schedule, ScheduleStep
from temporian.core.operators.leak import LeakOperator

EvaluationQuery = Union[
EventSetNode, List[EventSetNode], Set[EventSetNode], Dict[str, EventSetNode]
]
EvaluationInput = Union[
# A dict of EventSetNodes to corresponding EventSet.
Dict[EventSetNode, EventSet],
# A single EventSet. Equivalent to {event_set.node() : event_set}.
EventSet,
# A list of EventSets. Feed each EventSet individually like EventSet.
List[EventSet],
]
EvaluationResult = Union[EventSet, List[EventSet], Dict[str, EventSet]]


def run(
query: EvaluationQuery,
input: EvaluationInput,
query: EventSetNodeCollection,
input: NodeToEventSetMapping,
verbose: int = 0,
check_execution: bool = True,
) -> EvaluationResult:
) -> EventSetCollection:
"""Evaluates [`EventSetNodes`][temporian.EventSetNode] on [`EventSets`][temporian.EventSet].

Performs all computation defined by the graph between the `query` EventSetNodes and
Expand Down Expand Up @@ -291,8 +282,8 @@ def build_schedule(


def has_leak(
output: EvaluationQuery,
input: Optional[EvaluationQuery] = None,
output: EventSetNodeCollection,
input: Optional[EventSetNodeCollection] = None,
) -> bool:
"""Tests if a node depends on a leak operator.

Expand Down Expand Up @@ -346,7 +337,9 @@ def has_leak(
return False


def _normalize_input(input: EvaluationInput) -> Dict[EventSetNode, EventSet]:
def _normalize_input(
input: NodeToEventSetMapping,
) -> Dict[EventSetNode, EventSet]:
"""Normalizes an input into a dictionary of node to evsets."""

if isinstance(input, dict):
Expand Down Expand Up @@ -374,7 +367,7 @@ def _normalize_input(input: EvaluationInput) -> Dict[EventSetNode, EventSet]:
)


def _normalize_query(query: EvaluationQuery) -> Set[EventSetNode]:
def _normalize_query(query: EventSetNodeCollection) -> Set[EventSetNode]:
"""Normalizes a query into a set of query EventSetNodes."""

if isinstance(query, EventSetNode):
Expand All @@ -390,14 +383,14 @@ def _normalize_query(query: EvaluationQuery) -> Set[EventSetNode]:
return set(query.values())

raise TypeError(
f"Evaluate query argument must be one of {EvaluationQuery}."
f"Evaluate query argument must be one of {EventSetNodeCollection}."
f" Received {type(query)} instead."
)


def _denormalize_outputs(
outputs: Dict[EventSetNode, EventSet], query: EvaluationQuery
) -> EvaluationResult:
outputs: Dict[EventSetNode, EventSet], query: EventSetNodeCollection
) -> EventSetCollection:
"""Converts outputs into the same format as the query."""

if isinstance(query, EventSetNode):
Expand Down
3 changes: 1 addition & 2 deletions temporian/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
EventSetNode,
Sampling,
Feature,
create_node_with_new_reference,
input_node,
)
from temporian.core.data.schema import Schema
Expand All @@ -57,7 +56,7 @@
INV_DTYPE_MAPPING = {v: k for k, v in DTYPE_MAPPING.items()}


# TODO: allow saved fn to return a single Node too
# TODO: allow saved fn to return a list or node too (EventSetNodeCollection)
def save(
fn: Callable[..., Dict[str, EventSetNode]],
path: str,
Expand Down
22 changes: 21 additions & 1 deletion temporian/implementation/numpy/data/event_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
import datetime
import sys

Expand Down Expand Up @@ -575,3 +575,23 @@ def creator(self) -> Optional[Operator]:
created EventSets have a `None` creator.
"""
return self.node()._creator


EventSetCollection = Union[EventSet, List[EventSet], Dict[str, EventSet]]
"""A collection of [`EventSets`][temporian.EventSet].

This can be a single EventSet, a list of EventSets, or a dictionary mapping
names to EventSets."""

NodeToEventSetMapping = Union[
Dict[EventSetNode, EventSet], EventSet, List[EventSet]
]
"""A mapping of [`EventSetNodes`][temporian.EventSetNode] to
[`EventSets`][temporian.EventSet].

If a dictionary, the mapping is defined by it.

If a single EventSet or a list of EventSets, each EventSet is mapped to its
own node using [`EventSet.node()`][temporian.EventSet.node], i.e., `[event_set]`
is equivalent to `{event_set.node() : event_set}`.
"""