In [1]:
import os, tempfile, pathlib
# Process-wide environment tweaks; set before TensorFlow / Beam are imported.
os.environ.update(
    {
        # Scratch directory under the system temp dir.
        # NOTE(review): presumably consumed by Beam's SDK worker — confirm.
        "SEMI_PERSISTENT_DIRECTORY": str(pathlib.Path(tempfile.gettempdir()) / "beam_scratch"),
        # Keep TensorFlow's C++ logging quiet (errors only).
        "TF_CPP_MIN_LOG_LEVEL": "2",
        # Empty device list: run on CPU only.
        "CUDA_VISIBLE_DEVICES": "",
    }
)

In [2]:
import logging
# Root logger at INFO so library progress messages appear in the cell output.
logging.root.setLevel(logging.INFO)

In [3]:
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.orchestration import pipeline as tfx_pipeline
from tfx.orchestration.metadata import sqlite_metadata_connection_config
from tfx.components import CsvExampleGen, StatisticsGen
import os
import logging
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import tensorflow as tf
import pyarrow as pa
import pyarrow.dataset as ds

2025-09-28 11:55:54.194952: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-09-28 11:55:54.222170: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-09-28 11:55:54.222225: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO:absl:tensorflow_io is not available: No module named 'tensorflow_io'
INFO:absl:tensorflow_ranking is not available: No module named 'tensorflow_ranking'
INFO:absl:tensorflow_text is not available: No module named 'tensorflow_text'
INFO:absl:tensorflow_decision_forests is not available: No module named 'tensorflow_decision_forests'
INFO:absl:struct2tensor is

In [4]:
# --- Pipeline identity and input/output locations ---------------------------
PIPELINE_NAME = "tfx_spark_demo_parquet"
PARQUET_PATH = "./tfx/data/parquet_demo/data.parquet"
TFRECORD_DIR = "./tfx/data/parquet_tfr/"

# --- Everything else lives under one absolute base directory ----------------
BASE_DIR = os.path.abspath("./tfx/")
PIPELINE_ROOT = os.path.join(BASE_DIR, "pipelines", PIPELINE_NAME)
METADATA_PATH = os.path.join(BASE_DIR, "metadata", PIPELINE_NAME, "metadata.db")
SERVING_MODEL_DIR = os.path.join(BASE_DIR, "serving_model", PIPELINE_NAME)

# --- Beam: local DirectRunner, multi-threaded, two workers ------------------
BEAM_ARGS = [
    "--runner=DirectRunner",
    "--direct_running_mode=multi_threading",
    "--direct_num_workers=2",
]

# Create every directory the notebook writes into (no-op if already present).
# METADATA_PATH is a file, so only its parent directory is created.
for _required_dir in (
    TFRECORD_DIR,
    BASE_DIR,
    PIPELINE_ROOT,
    os.path.dirname(METADATA_PATH),
    SERVING_MODEL_DIR,
):
    os.makedirs(_required_dir, exist_ok=True)

In [5]:
# Name of the label column, e.g. "is_fraud"; set to None if there is no label yet.
LABEL_KEY = "is_fraud"
# NOTE(review): not referenced in the visible cells — presumably tells later
# modelling steps that the label is a two-class target; confirm where it is used.
BINARY_CLASSIFICATION = True

In [6]:
import os
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import pyarrow as pa
import tensorflow as tf

In [7]:
def _infer_feature_types(parquet_path):
    """Map each Parquet column to the ``tf.train.Feature`` kind used to encode it.

    Only the schema is inspected (through ``pyarrow.dataset``).

    Args:
        parquet_path: A Parquet file, a directory, a list of paths, or a local
            glob pattern such as ``"data/part-*.parquet"``.

    Returns:
        dict of column name -> ``"int"``, ``"float"`` or ``"bytes"``:
        integer and boolean columns map to ``"int"``, floating-point and
        decimal columns to ``"float"``, everything else to ``"bytes"``.
        A list column is classified by its element type.
    """
    import glob  # function-scope import: only needed for pattern expansion

    source = parquet_path
    # pyarrow.dataset does not expand wildcards, while the Beam reader fed with
    # the same path does. Expand local patterns here so both see the same
    # files. A path that exists literally is never treated as a pattern, and
    # when nothing matches the original value is passed through unchanged so
    # pyarrow reports the error exactly as before.
    if (
        isinstance(parquet_path, str)
        and any(ch in parquet_path for ch in "*?[")
        and not os.path.exists(parquet_path)
    ):
        matches = sorted(glob.glob(parquet_path))
        if matches:
            source = matches

    schema: pa.Schema = ds.dataset(source, format="parquet").schema

    types = {}
    for field in schema:
        t = field.type
        # Rows carry list columns as Python lists whose elements are converted
        # one by one, so the element type decides the encoding. Only one level
        # is unwrapped: nested lists still fall back to "bytes".
        if (
            pa.types.is_list(t)
            or pa.types.is_large_list(t)
            or pa.types.is_fixed_size_list(t)
        ):
            t = t.value_type

        if pa.types.is_integer(t) or pa.types.is_boolean(t):
            # Booleans are stored as 0/1 int64 rather than b"True"/b"False".
            types[field.name] = "int"
        elif pa.types.is_floating(t) or pa.types.is_decimal(t):
            types[field.name] = "float"
        else:
            types[field.name] = "bytes"
    return types

In [8]:
def _to_tfexample(row, feature_types):
    """Serialize one row into a ``tf.train.Example`` byte string.

    Args:
        row: Mapping of column name -> value. A value may be a scalar, a
            list/tuple of scalars, or ``None``.
        feature_types: Mapping of column name -> ``"int"``, ``"float"`` or
            ``"bytes"``. Columns missing from it are encoded as ``"bytes"``.

    Returns:
        The serialized ``tf.train.Example``. ``None`` values become an empty
        feature of the column's kind; ``None`` elements inside a list are
        dropped. For ``"bytes"`` columns, values that are not already
        ``bytes``/``bytearray`` are converted with ``str(...)`` and UTF-8
        encoded.
    """
    features = {}
    for name, raw in row.items():
        kind = feature_types.get(name, "bytes")

        # Normalise to a sequence: None -> empty, scalar -> single element.
        if raw is None:
            items = []
        elif isinstance(raw, (list, tuple)):
            items = raw
        else:
            items = [raw]
        present = [item for item in items if item is not None]

        if kind == "int":
            int_values = [int(item) for item in present]
            features[name] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=int_values)
            )
        elif kind == "float":
            float_values = [float(item) for item in present]
            features[name] = tf.train.Feature(
                float_list=tf.train.FloatList(value=float_values)
            )
        else:
            byte_values = [
                item if isinstance(item, (bytes, bytearray)) else str(item).encode("utf-8")
                for item in present
            ]
            features[name] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=byte_values)
            )

    example = tf.train.Example(features=tf.train.Features(feature=features))
    return example.SerializeToString()


In [9]:
class ReadParquetAsDicts(beam.PTransform):
    """Composite transform that reads Parquet files via Beam's ``ReadFromParquet``.

    Args:
        file_pattern: File path or pattern handed to ``ReadFromParquet``.
        columns: Optional column selection forwarded to ``ReadFromParquet``;
            ``None`` passes no restriction.
    """

    def __init__(self, file_pattern, columns=None):
        super().__init__()
        self.columns = columns
        self.file_pattern = file_pattern

    def expand(self, p):
        """Apply the Parquet source to ``p`` and return the resulting PCollection."""
        # Imported lazily so parquetio is only loaded when the transform is expanded.
        from apache_beam.io import parquetio

        source = parquetio.ReadFromParquet(self.file_pattern, columns=self.columns)
        return p | "ReadParquet" >> source


In [10]:
def parquet_to_tfrecords(parquet_path, output_dir, num_shards=32, beam_runner_args=None):
    """Convert Parquet data into sharded TFRecord files of ``tf.train.Example``.

    Args:
        parquet_path: Parquet input, passed to both the schema inference and
            the Beam reader.
        output_dir: Directory for the shards; created if missing.
        num_shards: Number of output shards requested from the TFRecord sink.
        beam_runner_args: Optional list of Beam command-line style flags;
            ``None`` means no flags.

    Returns:
        The shard path prefix, ``<output_dir>/data``. Shard files use the
        ``.tfrecord`` suffix.
    """
    os.makedirs(output_dir, exist_ok=True)

    column_kinds = _infer_feature_types(parquet_path)
    shard_prefix = os.path.join(output_dir, "data")
    pipeline_options = PipelineOptions(beam_runner_args or [])
    tfrecord_sink = beam.io.tfrecordio.WriteToTFRecord(
        file_path_prefix=shard_prefix,
        file_name_suffix=".tfrecord",
        num_shards=num_shards,
    )

    # The pipeline is executed when the `with` block exits.
    with beam.Pipeline(options=pipeline_options) as pipeline:
        rows = pipeline | "ReadParquetDicts" >> ReadParquetAsDicts(parquet_path)
        examples = rows | "RowToTFExample" >> beam.Map(
            _to_tfexample, feature_types=column_kinds
        )
        _ = examples | "WriteTFRecords" >> tfrecord_sink

    return shard_prefix

In [11]:
# Input and output of the Parquet -> TFRecord conversion.
# Both default to the values from the configuration cell so each path is
# written in exactly one place; assign a different value here to override.
parquet_path = PARQUET_PATH   # file or glob
output_dir   = TFRECORD_DIR   # directory that receives the TFRecord shards

In [12]:
# Beam flags for a local run: DirectRunner in multi-threading mode with 2 workers.
# NOTE(review): same three flags as the BEAM_ARGS list in the configuration
# cell; this assignment rebinds the name to an equal list.
BEAM_ARGS = [
    "--runner=DirectRunner",
    "--direct_running_mode=multi_threading",
    "--direct_num_workers=2",
]


In [13]:
# Run the conversion into 4 shards; `prefix` is the returned shard path prefix.
prefix = parquet_to_tfrecords(
    parquet_path, output_dir, num_shards=4, beam_runner_args=BEAM_ARGS
)

INFO:apache_beam.runners.portability.fn_api_runner.worker_handlers:starting control server on port 39613
INFO:apache_beam.runners.portability.fn_api_runner.worker_handlers:starting data server on port 33073
INFO:apache_beam.runners.portability.fn_api_runner.worker_handlers:starting state server on port 44551
INFO:apache_beam.runners.portability.fn_api_runner.worker_handlers:starting logging server on port 33393
INFO:apache_beam.runners.worker.statecache:Creating state cache with size 104857600
INFO:apache_beam.runners.worker.sdk_worker:Creating insecure control channel for localhost:39613.
INFO:apache_beam.runners.worker.sdk_worker:Control channel established.
INFO:apache_beam.runners.worker.sdk_worker:Initializing SDKHarness with unbounded number of workers.
INFO:apache_beam.runners.worker.statecache:Creating state cache with size 104857600
INFO:apache_beam.runners.worker.sdk_worker:Creating insecure control channel for localhost:39613.
INFO:apache_beam.runners.worker.sdk_worker:Contr

In [14]:
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# Self-contained smoke test: a three-element pipeline on the same local
# DirectRunner flags, independent of the Parquet conversion.
opts = PipelineOptions([
    "--runner=DirectRunner",
    "--direct_running_mode=multi_threading",
    "--direct_num_workers=2",
])

# The pipeline is executed when the `with` block exits.
with beam.Pipeline(options=opts) as p:
    (
        p
        | "Create" >> beam.Create([1, 2, 3])
        | "Double" >> beam.Map(lambda x: x * 2)
        # Prints 2, 4 and 6. In the recorded run the output lines are
        # interleaved ("24" then "6"), so do not rely on print ordering here.
        | "Print" >> beam.Map(print)
    )

INFO:apache_beam.runners.portability.fn_api_runner.worker_handlers:starting control server on port 39049
INFO:apache_beam.runners.portability.fn_api_runner.worker_handlers:starting data server on port 35409
INFO:apache_beam.runners.portability.fn_api_runner.worker_handlers:starting state server on port 46327
INFO:apache_beam.runners.portability.fn_api_runner.worker_handlers:starting logging server on port 43229
INFO:apache_beam.runners.worker.statecache:Creating state cache with size 104857600
INFO:apache_beam.runners.worker.sdk_worker:Creating insecure control channel for localhost:39049.
INFO:apache_beam.runners.worker.sdk_worker:Control channel established.
INFO:apache_beam.runners.worker.sdk_worker:Initializing SDKHarness with unbounded number of workers.
INFO:apache_beam.runners.worker.statecache:Creating state cache with size 104857600
INFO:apache_beam.runners.worker.sdk_worker:Creating insecure control channel for localhost:39049.
INFO:apache_beam.runners.worker.sdk_worker:Contr

24

6
