In [1]:
import pathlib
import pprint

import tensorflow as tf
from absl import logging
from tfx import v1 as tfx
from tfx.components import ExampleValidator
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.proto import example_gen_pb2
from tfx.v1 import proto

from mymllib import Utils, rawdata_duckdb_sql
from mymllib.tfx.example_gen.duckdb.component import DuckDBExampleGen

# Shared pretty-printer instance for ad-hoc inspection of nested objects.
# NOTE(review): `pp` is not referenced in the cells visible here.
pp = pprint.PrettyPrinter()
# Set absl's logging verbosity to INFO for the rest of the notebook.
logging.set_verbosity(logging.INFO)  # Set default logging level.

2023-06-21 16:16:41.037904: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Record the library versions this notebook was executed with.
print(f'TensorFlow version: {tf.__version__}')
print(f'TFX version: {tfx.__version__}')

TensorFlow version: 2.10.1
TFX version: 1.11.0


In [3]:
# ---------------------------------------------------------------------------
# Set up variables: competition name and the directory layout derived from it.
# ---------------------------------------------------------------------------
COMPETITION = 'store-sales-time-series-forecasting'
PIPELINE_NAME = COMPETITION

# Inside IPython the first entry of the `_dh` directory history is used as the
# base directory; otherwise fall back to the directory containing this file.
CURRENT_DIR: pathlib.Path
if '_dh' in globals():
    CURRENT_DIR = pathlib.Path(globals()['_dh'][0])
else:
    CURRENT_DIR = pathlib.Path(__file__).parent

# Input data lives next to the notebook; all TFX outputs go under out/tfx/<pipeline>.
DATA_DIR = CURRENT_DIR / "data"
TFX_DIR = CURRENT_DIR / "out/tfx" / PIPELINE_NAME
PIPELINE_ROOT = TFX_DIR / 'pipelines'
METADATA_PATH = TFX_DIR / 'metadata.db'
SERVING_MODEL_DIR = TFX_DIR / 'serving_model'

# Echo the resolved locations so a reader can verify them at a glance.
print("Data dir: %s" % DATA_DIR)
print("tfx dir: %s" % TFX_DIR)
print("Pipelines root: %s" % PIPELINE_ROOT)

Data dir: /Users/ismailsimsek/development/StoreSalesTimeSeriesForecasting/data
tfx dir: /Users/ismailsimsek/development/StoreSalesTimeSeriesForecasting/out/tfx/store-sales-time-series-forecasting
Pipelines root: /Users/ismailsimsek/development/StoreSalesTimeSeriesForecasting/out/tfx/store-sales-time-series-forecasting/pipelines


In [4]:
# Call the project's download helper with the data directory and competition name.
# NOTE(review): presumably fetches the Kaggle competition files and needs API
# credentials — confirm in mymllib.Utils.
Utils.download(datadir=DATA_DIR, competition=COMPETITION)

In [5]:
from tfx.orchestration import metadata
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

# Use a local SQLite file (METADATA_PATH) as the metadata store for the
# interactive context, and PIPELINE_ROOT as the root for component outputs.
metadata_connection_config = metadata.sqlite_metadata_connection_config(
    METADATA_PATH.as_posix()
)
context = InteractiveContext(
    pipeline_root=PIPELINE_ROOT.as_posix(),
    metadata_connection_config=metadata_connection_config,
)

In [6]:
# Components are collected in this list as they are created.
components = []

# Two named output splits with an 8:2 hash-bucket ratio (train : eval).
split_config = proto.SplitConfig(
    splits=[
        proto.SplitConfig.Split(name='train', hash_buckets=8),
        proto.SplitConfig.Split(name='eval', hash_buckets=2),
    ]
)
output = proto.Output(split_config=split_config)

# Example generator fed by the project's DuckDB query for the "train" data.
example_gen = DuckDBExampleGen(
    query=rawdata_duckdb_sql(data="train"),
    output_config=output,
)
components.append(example_gen)

In [7]:
# Execute the example generator in the interactive context with caching enabled.
context.run(example_gen, enable_cache=True)

INFO:absl:Running driver for DuckDBExampleGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for DuckDBExampleGen
INFO:absl:Generating examples.


FloatProgress(value=0.0, layout=Layout(width='100%'), style=ProgressStyle(bar_color='black'))


KeyboardInterrupt



In [None]:
# Compute statistics over the examples emitted by example_gen.
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'],
)
context.run(statistics_gen, enable_cache=True)

In [None]:
# Display the 'statistics' output of statistics_gen through the interactive context.
context.show(statistics_gen.outputs['statistics'])

In [None]:
# Derive a schema from the computed statistics, with feature-shape inference on.
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True,
)
context.run(schema_gen, enable_cache=True)

In [None]:
# Validate the computed statistics against the inferred schema.
# ExampleValidator is already imported in the notebook's first cell, so it is
# not re-imported here.
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'],
)
# enable_cache is passed explicitly to match the other context.run calls.
context.run(example_validator, enable_cache=True)


In [None]:
import tensorflow_data_validation as tfdv

# Disabled scratch call: would generate statistics directly from a TFRecord location.
# NOTE(review): `path` is not defined in any visible cell — define it before
# re-enabling this line.
# stats = tfdv.generate_statistics_from_tfrecord(data_location=path)
