
Fix timezone issue in materialize & materialize_incremental (#1439)
Signed-off-by: Tsotne Tabidze <tsotne@tecton.ai>
Tsotne Tabidze committed Apr 7, 2021
1 parent 9e5377c commit 6d7678f
Showing 8 changed files with 212 additions and 173 deletions.
5 changes: 2 additions & 3 deletions sdk/python/feast/cli.py
@@ -22,7 +22,6 @@
 import click
 import pkg_resources
 import yaml
-from pytz import utc

 from feast.client import Client
 from feast.config import Config
@@ -425,8 +424,8 @@ def materialize_command(start_ts: str, end_ts: str, repo_path: str, views: List[
     store = FeatureStore(repo_path=repo_path)
     store.materialize(
         feature_views=None if not views else views,
-        start_date=datetime.fromisoformat(start_ts).replace(tzinfo=utc),
-        end_date=datetime.fromisoformat(end_ts).replace(tzinfo=utc),
+        start_date=datetime.fromisoformat(start_ts),
+        end_date=datetime.fromisoformat(end_ts),
     )


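The CLI can drop the forced `.replace(tzinfo=utc)` because `FeatureStore.materialize` now normalizes its inputs (see `feature_store.py` below). For background, `datetime.fromisoformat` returns a tz-naive value for a plain timestamp and a tz-aware one when the string carries an offset; a quick standalone sketch (example timestamps only):

from datetime import datetime

naive = datetime.fromisoformat("2021-04-07T00:00:00")
print(naive.tzinfo)  # None -- tz-naive

aware = datetime.fromisoformat("2021-04-07T00:00:00+02:00")
print(aware.tzinfo)  # UTC+02:00 -- tz-aware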
5 changes: 5 additions & 0 deletions sdk/python/feast/feature_store.py
@@ -19,6 +19,7 @@
 import pandas as pd
 import pyarrow

+from feast import utils
 from feast.entity import Entity
 from feast.feature_view import FeatureView
 from feast.infra.provider import Provider, RetrievalJob, get_provider
@@ -359,6 +360,10 @@ def _materialize_single_feature_view(
             event_timestamp_column,
             created_timestamp_column,
         ) = _run_reverse_field_mapping(feature_view)

+        start_date = utils.make_tzaware(start_date)
+        end_date = utils.make_tzaware(end_date)

         provider = self._get_provider()
         table = provider.pull_latest_from_table_or_query(
             feature_view.input,
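Normalizing both bounds here matters because Python refuses to order tz-naive against tz-aware datetimes; mixing them anywhere downstream raises a TypeError, which is the class of failure this commit addresses. A minimal reproduction of the underlying error, independent of Feast:

from datetime import datetime

from pytz import utc

naive = datetime(2021, 4, 7)
aware = datetime(2021, 4, 7, tzinfo=utc)

try:
    naive < aware
except TypeError as err:
    print(err)  # can't compare offset-naive and offset-aware datetimes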
6 changes: 5 additions & 1 deletion sdk/python/feast/feature_view.py
@@ -17,6 +17,7 @@
 from google.protobuf.duration_pb2 import Duration
 from google.protobuf.timestamp_pb2 import Timestamp

+from feast import utils
 from feast.data_source import BigQuerySource, DataSource, FileSource
 from feast.feature import Feature
 from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
@@ -166,7 +167,10 @@ def from_proto(cls, feature_view_proto: FeatureViewProto):

         for interval in feature_view_proto.meta.materialization_intervals:
             feature_view.materialization_intervals.append(
-                (interval.start_time.ToDatetime(), interval.end_time.ToDatetime())
+                (
+                    utils.make_tzaware(interval.start_time.ToDatetime()),
+                    utils.make_tzaware(interval.end_time.ToDatetime()),
+                )
             )

         return feature_view
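The wrapping above is needed because protobuf's `Timestamp.ToDatetime()` returns a tz-naive `datetime` (the value itself is UTC), so materialization intervals read back from the registry would otherwise come back naive. A standalone check:

from google.protobuf.timestamp_pb2 import Timestamp

ts = Timestamp()
ts.GetCurrentTime()

dt = ts.ToDatetime()
print(dt.tzinfo)  # None -- the value is UTC, but the datetime is tz-naive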
21 changes: 6 additions & 15 deletions sdk/python/feast/infra/gcp.py
@@ -10,9 +10,8 @@
 import pyarrow
 from google.cloud import bigquery
 from jinja2 import BaseLoader, Environment
-from pytz import utc

-from feast import FeatureTable
+from feast import FeatureTable, utils
 from feast.data_source import BigQuerySource, DataSource
 from feast.feature_view import FeatureView
 from feast.infra.key_encoding_utils import serialize_entity_key
@@ -252,14 +251,14 @@ def _write_minibatch(

         entity = client.get(key)
         if entity is not None:
-            if entity["event_ts"] > _make_tzaware(timestamp):
+            if entity["event_ts"] > utils.make_tzaware(timestamp):
                 # Do not overwrite feature values computed from fresher data
                 continue
             elif (
-                entity["event_ts"] == _make_tzaware(timestamp)
+                entity["event_ts"] == utils.make_tzaware(timestamp)
                 and created_ts is not None
                 and entity["created_ts"] is not None
-                and entity["created_ts"] > _make_tzaware(created_ts)
+                and entity["created_ts"] > utils.make_tzaware(created_ts)
             ):
                 # Do not overwrite feature values computed from the same data, but
                 # computed later than this one
@@ -273,9 +272,9 @@
                 values={
                     k: v.SerializeToString() for k, v in features.items()
                 },
-                event_ts=_make_tzaware(timestamp),
+                event_ts=utils.make_tzaware(timestamp),
                 created_ts=(
-                    _make_tzaware(created_ts)
+                    utils.make_tzaware(created_ts)
                     if created_ts is not None
                     else None
                 ),
@@ -316,14 +315,6 @@ def compute_datastore_entity_id(entity_key: EntityKeyProto) -> str:
     return mmh3.hash_bytes(serialize_entity_key(entity_key)).hex()


-def _make_tzaware(t: datetime):
-    """ We assume tz-naive datetimes are UTC """
-    if t.tzinfo is None:
-        return t.replace(tzinfo=utc)
-    else:
-        return t


 class BigQueryRetrievalJob(RetrievalJob):
     def __init__(self, query):
         self.query = query
7 changes: 7 additions & 0 deletions sdk/python/feast/infra/local_sqlite.py
@@ -169,6 +169,13 @@ def pull_latest_from_table_or_query(
 ) -> pyarrow.Table:
     assert isinstance(data_source, FileSource)
     source_df = pd.read_parquet(data_source.path)
+    # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
+    source_df[event_timestamp_column] = source_df[event_timestamp_column].apply(
+        lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc)
+    )
+    source_df[created_timestamp_column] = source_df[created_timestamp_column].apply(
+        lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc)
+    )

     ts_columns = (
         [event_timestamp_column, created_timestamp_column]
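The per-element `apply` is the pandas counterpart of `make_tzaware`: `pd.Timestamp` exposes `.tz`, which is `None` for naive values. A self-contained sketch of the same normalization on a hypothetical `ts` column:

import pandas as pd
import pytz

df = pd.DataFrame(
    {
        "ts": [
            pd.Timestamp("2021-04-07 00:00:00"),            # tz-naive
            pd.Timestamp("2021-04-07 00:00:00", tz="UTC"),  # tz-aware
        ]
    }
)

# Leave tz-aware values alone; localize tz-naive ones to UTC.
df["ts"] = df["ts"].apply(
    lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc)
)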
11 changes: 11 additions & 0 deletions sdk/python/feast/utils.py
@@ -0,0 +1,11 @@
from datetime import datetime

from pytz import utc


def make_tzaware(t: datetime):
    """ We assume tz-naive datetimes are UTC """
    if t.tzinfo is None:
        return t.replace(tzinfo=utc)
    else:
        return t
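The helper is idempotent: tz-naive inputs gain UTC, tz-aware inputs pass through unchanged. Usage:

from datetime import datetime

from pytz import utc

from feast.utils import make_tzaware

print(make_tzaware(datetime(2021, 4, 7)))              # 2021-04-07 00:00:00+00:00
print(make_tzaware(datetime(2021, 4, 7, tzinfo=utc)))  # unchanged, already tz-aware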
176 changes: 176 additions & 0 deletions sdk/python/tests/test_materialize.py
@@ -0,0 +1,176 @@
import contextlib
import tempfile
import time
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Iterator, Tuple, Union

import pandas as pd
import pytest
from google.cloud import bigquery
from pytz import timezone, utc

from feast.data_format import ParquetFormat
from feast.data_source import BigQuerySource, FileSource
from feast.feature import Feature
from feast.feature_store import FeatureStore
from feast.feature_view import FeatureView
from feast.repo_config import LocalOnlineStoreConfig, OnlineStoreConfig, RepoConfig
from feast.value_type import ValueType


def create_dataset() -> pd.DataFrame:
    now = datetime.utcnow()
    ts = pd.Timestamp(now).round("ms")
    data = {
        "id": [1, 2, 1, 3, 3],
        "value": [0.1, 0.2, 0.3, 4, 5],
        "ts_1": [
            ts - timedelta(hours=4),
            ts,
            ts - timedelta(hours=3),
            # Use different time zones to test tz-naive -> tz-aware conversion
            (ts - timedelta(hours=4))
            .replace(tzinfo=utc)
            .astimezone(tz=timezone("Europe/Berlin")),
            (ts - timedelta(hours=1))
            .replace(tzinfo=utc)
            .astimezone(tz=timezone("US/Pacific")),
        ],
        "created_ts": [ts, ts, ts, ts, ts],
    }
    return pd.DataFrame.from_dict(data)


def get_feature_view(data_source: Union[FileSource, BigQuerySource]) -> FeatureView:
    return FeatureView(
        name="test_bq_correctness",
        entities=["driver_id"],
        features=[Feature("value", ValueType.FLOAT)],
        ttl=timedelta(days=5),
        input=data_source,
    )


# bq_source_type must be either "query" or "table"
@contextlib.contextmanager
def prep_bq_fs_and_fv(
    bq_source_type: str,
) -> Iterator[Tuple[FeatureStore, FeatureView]]:
    client = bigquery.Client()
    gcp_project = client.project
    bigquery_dataset = "test_ingestion"
    dataset = bigquery.Dataset(f"{gcp_project}.{bigquery_dataset}")
    client.create_dataset(dataset, exists_ok=True)
    dataset.default_table_expiration_ms = (
        1000 * 60 * 60 * 24 * 14
    )  # 2 weeks in milliseconds
    client.update_dataset(dataset, ["default_table_expiration_ms"])

    df = create_dataset()

    job_config = bigquery.LoadJobConfig()
    table_ref = f"{gcp_project}.{bigquery_dataset}.{bq_source_type}_correctness_{int(time.time())}"
    query = f"SELECT * FROM `{table_ref}`"
    job = client.load_table_from_dataframe(df, table_ref, job_config=job_config)
    job.result()

    bigquery_source = BigQuerySource(
        table_ref=table_ref if bq_source_type == "table" else None,
        query=query if bq_source_type == "query" else None,
        event_timestamp_column="ts",
        created_timestamp_column="created_ts",
        date_partition_column="",
        field_mapping={"ts_1": "ts", "id": "driver_id"},
    )

    fv = get_feature_view(bigquery_source)
    with tempfile.TemporaryDirectory() as repo_dir_name:
        config = RepoConfig(
            registry=str(Path(repo_dir_name) / "registry.db"),
            project=f"test_bq_correctness_{uuid.uuid4()}",
            provider="gcp",
        )
        fs = FeatureStore(config=config)
        fs.apply([fv])

        yield fs, fv


@contextlib.contextmanager
def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
    with tempfile.NamedTemporaryFile(suffix=".parquet") as f:
        df = create_dataset()
        f.close()
        df.to_parquet(f.name)
        file_source = FileSource(
            file_format=ParquetFormat(),
            file_url=f"file://{f.name}",
            event_timestamp_column="ts",
            created_timestamp_column="created_ts",
            date_partition_column="",
            field_mapping={"ts_1": "ts", "id": "driver_id"},
        )
        fv = get_feature_view(file_source)
        with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name:
            config = RepoConfig(
                registry=str(Path(repo_dir_name) / "registry.db"),
                project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
                provider="local",
                online_store=OnlineStoreConfig(
                    local=LocalOnlineStoreConfig(
                        path=str(Path(data_dir_name) / "online_store.db")
                    )
                ),
            )
            fs = FeatureStore(config=config)
            fs.apply([fv])

            yield fs, fv


def run_materialization_test(fs: FeatureStore, fv: FeatureView) -> None:
    now = datetime.utcnow()
    # Run materialize()
    # use both tz-naive & tz-aware timestamps to test that they're both correctly handled
    start_date = (now - timedelta(hours=5)).replace(tzinfo=utc)
    end_date = now - timedelta(hours=2)
    fs.materialize([fv.name], start_date, end_date)

    # check result of materialize()
    response_dict = fs.get_online_features(
        [f"{fv.name}:value"], [{"driver_id": 1}]
    ).to_dict()
    assert abs(response_dict[f"{fv.name}__value"][0] - 0.3) < 1e-6

    # check prior value for materialize_incremental()
    response_dict = fs.get_online_features(
        [f"{fv.name}:value"], [{"driver_id": 3}]
    ).to_dict()
    assert abs(response_dict[f"{fv.name}__value"][0] - 4) < 1e-6

    # run materialize_incremental()
    fs.materialize_incremental(
        [fv.name], now - timedelta(seconds=0),
    )

    # check result of materialize_incremental()
    response_dict = fs.get_online_features(
        [f"{fv.name}:value"], [{"driver_id": 3}]
    ).to_dict()
    assert abs(response_dict[f"{fv.name}__value"][0] - 5) < 1e-6


@pytest.mark.integration
@pytest.mark.parametrize(
    "bq_source_type", ["query", "table"],
)
def test_bq_materialization(bq_source_type: str):
    with prep_bq_fs_and_fv(bq_source_type) as (fs, fv):
        run_materialization_test(fs, fv)


def test_local_materialization():
    with prep_local_fs_and_fv() as (fs, fv):
        run_materialization_test(fs, fv)
