In [1]:
import os, sys, subprocess, shutil, json
import pyspark, tfx, apache_beam

## Environment checks

In [2]:
# Point Spark at the distribution bundled inside the installed pyspark package.
SPARK_HOME = os.path.dirname(pyspark.__file__)
# Resolve the real java binary (following symlinks) and take the directory two
# levels up (<JAVA_HOME>/bin/java -> <JAVA_HOME>). Empty string if java is not on PATH.
JAVA_BIN = shutil.which("java")
JAVA_HOME = os.path.dirname(os.path.dirname(os.path.realpath(JAVA_BIN))) if JAVA_BIN else ""

os.environ["SPARK_HOME"] = SPARK_HOME
# Only export JAVA_HOME when a java binary was actually found, so a JAVA_HOME
# already present in the environment is not clobbered with an empty string.
if JAVA_HOME:
    os.environ["JAVA_HOME"] = JAVA_HOME

# Put Spark's bin directory first on PATH. Use os.pathsep rather than a literal
# ":" and skip the prepend when it is already first, so re-running this cell
# does not keep growing PATH. A missing PATH variable is tolerated.
_spark_bin = os.path.join(SPARK_HOME, "bin")
_current_path = os.environ.get("PATH", "")
if _current_path.split(os.pathsep)[0] != _spark_bin:
    os.environ["PATH"] = f"{_spark_bin}{os.pathsep}{_current_path}" if _current_path else _spark_bin

# Run PySpark driver and workers with the same interpreter as this kernel.
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
# Keep a caller-provided SPARK_LOCAL_IP; otherwise default to loopback.
os.environ["SPARK_LOCAL_IP"] = os.environ.get("SPARK_LOCAL_IP","127.0.0.1")  # match the job server bind

print("TFX:", tfx.__version__)
print("Beam:", apache_beam.__version__)
print("SPARK_HOME:", os.environ["SPARK_HOME"])
print("JAVA_HOME:", os.environ.get("JAVA_HOME", ""))

TFX: 1.16.0
Beam: 2.59.0
SPARK_HOME: /home/jpg/miniconda3/envs/tfx-spark-env/lib/python3.10/site-packages/pyspark
JAVA_HOME: /home/jpg/miniconda3/envs/tfx-spark-env


In [3]:
from pyspark.sql import SparkSession
import shlex  # subprocess is already imported in the first cell

# Start a throwaway local Spark session (2 threads) just to confirm that Spark
# can boot, and always stop it again -- even if reading the version fails --
# so no session is left running in the kernel.
spark = SparkSession.builder.master("local[2]").appName("sanity").getOrCreate()
try:
    print("Spark version:", spark.version)
finally:
    spark.stop()

# `java -version` writes to stderr, so fold stderr into the captured output.
print(subprocess.check_output(shlex.split("java -version"), stderr=subprocess.STDOUT).decode())


Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/09/28 10:57:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Spark version: 3.2.3
openjdk version "11.0.1" 2018-10-16 LTS
OpenJDK Runtime Environment Zulu11.2+3 (build 11.0.1+13-LTS)
OpenJDK 64-Bit Server VM Zulu11.2+3 (build 11.0.1+13-LTS, mixed mode)



In [4]:
import os, tempfile, pathlib

# Export the runtime environment variables in one step:
#   SEMI_PERSISTENT_DIRECTORY -> "<system temp dir>/beam_scratch"
#   TF_CPP_MIN_LOG_LEVEL      -> "2"
#   CUDA_VISIBLE_DEVICES      -> "" (empty string)
os.environ.update(
    {
        "SEMI_PERSISTENT_DIRECTORY": str(pathlib.Path(tempfile.gettempdir()) / "beam_scratch"),
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "CUDA_VISIBLE_DEVICES": "",
    }
)

## Proof of concept (POC)

In [5]:
import os
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.orchestration import pipeline as tfx_pipeline
from tfx.orchestration.metadata import sqlite_metadata_connection_config
from tfx.components import CsvExampleGen, StatisticsGen

2025-09-28 10:57:06.498789: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-09-28 10:57:06.535911: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-09-28 10:57:06.535949: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [6]:
# Project root. Defaults to the original developer location but can be
# overridden with the TFX_SPARK_HOME environment variable, so the notebook
# also runs on machines where that absolute path does not exist.
HOME = os.environ.get("TFX_SPARK_HOME", '/home/jpg/Desktop/tfx-spark')
PIPELINE_NAME = "tfx_spark_demo_parquet"
# Everything the pipeline reads or writes lives under <HOME>/tfx/.
PIPELINE_ROOT = os.path.join(HOME, "tfx", "pipelines", PIPELINE_NAME)
METADATA_PATH = os.path.join(HOME, "tfx", "metadata", PIPELINE_NAME, "metadata.db")
# NOTE(review): this directory is named "parquet_demo" but is passed to
# CsvExampleGen further down -- confirm which input format is intended.
DATA_ROOT = os.path.join(HOME, "tfx", "data", "parquet_demo")
# Create the directories up front; exist_ok keeps the cell re-runnable.
os.makedirs(PIPELINE_ROOT, exist_ok=True)
os.makedirs(os.path.dirname(METADATA_PATH), exist_ok=True)
os.makedirs(DATA_ROOT, exist_ok=True)

In [None]:
# Ingestion component: reads its input from DATA_ROOT.
example_gen = CsvExampleGen(
    input_base=DATA_ROOT,
)

# Statistics component: consumes the "examples" output of the ingestion step.
examples_channel = example_gen.outputs["examples"]
stats_gen = StatisticsGen(examples=examples_channel)

In [None]:
# Beam options, rendered as "--<flag>=<value>" strings in this order.
beam_pipeline_args = [
    f"--{flag}={value}"
    for flag, value in (
        ("runner", "DirectRunner"),
        ("direct_running_mode", "multi_processing"),
        ("direct_num_workers", 4),
    )
]

In [None]:
# Build the metadata connection config from METADATA_PATH (the metadata.db file
# path defined in the paths cell). build_pipeline() passes this object to the
# pipeline as `metadata_connection_config`.
metadata_config = sqlite_metadata_connection_config(METADATA_PATH)


In [None]:
def build_pipeline():
    """Assemble the TFX pipeline from the notebook-level settings.

    Reads the module-level names PIPELINE_NAME, PIPELINE_ROOT, example_gen,
    stats_gen, metadata_config and beam_pipeline_args; takes no arguments.

    Returns:
        The ``tfx_pipeline.Pipeline`` instance, with caching enabled and the
        two components listed in order (example_gen, then stats_gen).
    """
    pipeline_kwargs = dict(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        components=[example_gen, stats_gen],
        metadata_connection_config=metadata_config,
        enable_cache=True,
        beam_pipeline_args=beam_pipeline_args,
    )
    return tfx_pipeline.Pipeline(**pipeline_kwargs)


In [None]:
# Create the Beam DAG runner, then build the pipeline and hand it over to run.
runner = BeamDagRunner()
runner.run(build_pipeline())