In [1]:
import time
from functools import lru_cache
from pathlib import Path
from time import sleep
from typing import Any, Dict, Optional
from urllib.parse import urlparse

import awswrangler as wr
import boto3
import datahub.emitter.mce_builder as builder
import pandas as pd
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from ds_dqv_tool import dqv_check, dqv_check_yaml
from ds_dqv_tool.recipes import condition_description_map, metric_description_map
from jinja2 import DebugUndefined, Template

# Shared DataHub client used by the assertion helpers in this notebook;
# it talks to a DataHub server at http://localhost:8091.
graph = DataHubGraph(config=DatahubClientConfig(server="http://localhost:8091"))

# [methods for methods in graph.__dir__() if "assertion" in methods]


@lru_cache
def get_athena_table_dataset_urn(catalog: str, database: str, table: str, region: str) -> str:
    """
    Build a DataHub dataset URN for an Athena table from its storage location.

    The table metadata is fetched from Athena with the ``sandbox`` AWS profile, and
    the path component of the table's ``location`` parameter is used as the dataset
    name on the ``hive`` platform, e.g.
    ``urn:li:dataset:(urn:li:dataPlatform:hive,/iceberg/yellow_rides_hourly_actuals,PROD)``.
    Results are memoised per argument combination by ``lru_cache``.

    :param catalog: Athena catalog name
    :param database: Athena database name
    :param table: Athena table name
    :param region: AWS region of the Athena client
    :return: The dataset URN string
    """
    athena = boto3.Session(profile_name="sandbox").client("athena", region_name=region)
    response = athena.get_table_metadata(CatalogName=catalog, DatabaseName=database, TableName=table)

    # The dataset also has a physical location, which could additionally be
    # recorded in a symlink facet.
    location_uri = response["TableMetadata"]["Parameters"]["location"]

    return builder.make_dataset_urn(platform="hive", name=urlparse(location_uri).path)

In [2]:
from datahub.emitter.serialization_helper import pre_json_transform


def make_assertion_urn(dataset_urn: str, assertion_name: str) -> str:
    """
    Derive an assertion URN from a dataset URN and an assertion name.

    :param dataset_urn: URN of the dataset the assertion is associated with
    :param assertion_name: Name of the assertion
    :return: The assertion URN built from a GUID of the identifying fields
    """
    # These key-value pairs are essentially hashed, so choose pairs that make each
    # assertion unique (example: https://github.com/datahub-project/datahub/blob/d2d9d36987f20a9f7d6c973073d1404edf33e667/metadata-ingestion-modules/gx-plugin/src/datahub_gx_plugin/action.py#L277-L289)
    identifying_fields = {
        "platform": "pattern-ds-dqv",
        # "dataset_urn" is a bad name: assertions and datasets have a
        # many-to-many relationship.
        "dataset_urn": dataset_urn,
        "assertion_name": assertion_name,
    }
    guid = builder.datahub_guid(pre_json_transform(identifying_fields))
    return builder.make_assertion_urn(guid)

In [3]:
def substitute_map_into_string(string: str, values: dict[str, Any]) -> str:
    """
    Render a Jinja2 template string with the given values.

    Undefined placeholders are handled by Jinja2's ``DebugUndefined``.

    :param string: The template string containing placeholders
    :param values: A dictionary of values to substitute into the template
    :return: The rendered string
    """
    return Template(string, undefined=DebugUndefined).render(values)


def query_pandas_from_athena(
    sql_query: str,
    glue_database: str,
    datalake_s3_bucket: str,
    s3_output_location: Optional[str] = None,
    ctx: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    """
    Run a SQL query on Athena via AWS Data Wrangler and return the result as a DataFrame.

    NOTE: ``job_name`` validation and OpenLineage START-event emission were
    sketched for this function but are currently disabled.

    :param sql_query: SQL text; rendered as a Jinja2 template when ``ctx`` is given
    :param glue_database: Database the query is executed against
    :param datalake_s3_bucket: Bucket used to build the default result location
    :param s3_output_location: Where Athena writes results; defaults to
        ``s3://<datalake_s3_bucket>/athena-results``
    :param ctx: Optional template values substituted into ``sql_query``
    :return: The query result
    """
    # Templating is applied only when a context is supplied.
    rendered_sql = sql_query if ctx is None else substitute_map_into_string(sql_query, ctx)

    output_uri = (
        f"s3://{datalake_s3_bucket}/athena-results" if s3_output_location is None else s3_output_location
    )

    return wr.athena.read_sql_query(
        sql=rendered_sql,
        database=glue_database,
        s3_output=output_uri,
        boto3_session=boto3.Session(profile_name="sandbox", region_name="us-east-1"),
    )

In [4]:
# Notebook-level configuration: local SQL file paths and the Athena/S3 targets
# passed to query_pandas_from_athena.
# NOTE(review): sql_path, as_of_datetime and lookback_days are not referenced by
# any cell visible here — confirm whether they are still needed.
THIS_DIR = Path("metaflow")
SQL_DIR = THIS_DIR / "sql"
sql_path = SQL_DIR / "prepare_training_data.sql"
glue_database = "nyc_taxi"
datalake_s3_bucket = "airflow-metaflow-6721"
as_of_datetime = "2025-06-01 00:00:00.000"
lookback_days = 30


# Maps a dataset name to the SQL statement that loads it from the nyc_taxi database.
querys = {
    "raw_weather": """SELECT unique_row_id, filename, region_type, region_code, region_name, year, month, meteorological_element, day_01, day_02, day_03, day_04, day_05, day_06, day_07, day_08, day_09, day_10, day_11, day_12, day_13, day_14, day_15, day_16, day_17, day_18, day_19, day_20, day_21, day_22, day_23, day_24, day_25, day_26, day_27, day_28, day_29, day_30, day_31 FROM nyc_taxi.raw_weather""",
    "yellow_rides_hourly_forecast": """SELECT * from nyc_taxi.yellow_rides_hourly_forecast""",
    "yellow_rides_hourly_actuals": """SELECT year , month , day , hour , pulocationid , total_rides from nyc_taxi.yellow_rides_hourly_actuals""",
    "raw_yellow": """SELECT unique_row_id , filename , vendorid , passenger_count , trip_distance , ratecodeid , store_and_fwd_flag , pulocationid , dolocationid , payment_type , fare_amount , extra , mta_tax , tip_amount , tolls_amount , improvement_surcharge , total_amount , congestion_surcharge , cbd_congestion_fee , airport_fee from nyc_taxi.raw_yellow""",
}


In [5]:
# Names of the nyc_taxi tables to load and check; each is used as a key into
# `querys` and as the table name when resolving its DataHub dataset URN.
datasets = [
    "yellow_rides_hourly_forecast",
    "yellow_rides_hourly_actuals",
    "raw_weather",
    "raw_yellow",
]

In [6]:
# Dataset URN for a single table (raw_weather); used below to build a sample
# assertion URN.
entity_urn = get_athena_table_dataset_urn(
    catalog="AwsDataCatalog",
    database="nyc_taxi",
    table="raw_weather",
    region="us-east-1",
)

# Dataset name -> DataHub dataset URN for every table listed in `datasets`.
dataset_name_to_urn = {
    table: get_athena_table_dataset_urn(
        catalog="AwsDataCatalog",
        database="nyc_taxi",
        table=table,
        region="us-east-1",
    )
    for table in datasets
}


# Sample assertion URN derived from the raw_weather dataset URN and a fixed name.
assertion_urn = make_assertion_urn(
    dataset_urn=entity_urn,
    assertion_name="test_assertion",
)

In [7]:
# Load every dataset into a pandas DataFrame, keyed by dataset name.
dfs = {}
for table in datasets:
    print(table)  # progress marker: printed before the table is queried
    dfs[table] = query_pandas_from_athena(
        sql_query=querys[table],
        glue_database=glue_database,
        datalake_s3_bucket=datalake_s3_bucket,
    )


yellow_rides_hourly_forecast
yellow_rides_hourly_actuals
raw_weather
raw_yellow


In [None]:
def datahub_upsert_assertion(
    entity_urn,
    assertion_urn,
    status,
    checks_description,
    properties,
):
    """
    Upsert a custom assertion of type "DQV" on a DataHub entity.

    :param entity_urn: URN of the entity the assertion is attached to
    :param assertion_urn: URN identifying the assertion
    :param status: Unused here (the signature mirrors ``datahub_update_assertion_result``)
    :param checks_description: Text stored as the assertion's description
    :param properties: Unused here (the signature mirrors ``datahub_update_assertion_result``)
    """
    assertion_definition = {
        "urn": assertion_urn,
        "entity_urn": entity_urn,
        # The type categorizes the assertion in DataHub.
        "type": "DQV",
        "description": checks_description,
        # A platform_urn (e.g. "urn:li:dataPlatform:great-expectations") can be
        # given instead of a platform name; an external_url linking to a
        # monitoring tool is optional and not set.
        "platform_name": "metaflow",
    }
    graph.upsert_custom_assertion(**assertion_definition)


def datahub_update_assertion_result(
    entity_urn,
    assertion_urn,
    status,
    checks_description,
    properties,
):
    """
    Report a result for an assertion, timestamped with the current wall-clock time.

    :param entity_urn: Unused here (the signature mirrors ``datahub_upsert_assertion``)
    :param assertion_urn: URN of the assertion the result belongs to
    :param status: Result type passed through to DataHub (callers in this notebook
        pass "SUCCESS" or "FAILURE")
    :param checks_description: Unused here (the signature mirrors ``datahub_upsert_assertion``)
    :param properties: Key/value entries attached to the reported result
    """
    reported_at_ms = int(time.time() * 1000)  # epoch seconds -> milliseconds
    graph.report_assertion_result(
        urn=assertion_urn,
        timestamp_millis=reported_at_ms,
        type=status,
        properties=properties,
    )

In [18]:
def _iter_dqv_assertions(dqv_results, dataset_name_to_urn):
    """
    Yield one keyword-argument dict per DQV check condition, "passed" results first.

    Each dict carries exactly the keyword arguments accepted by
    ``datahub_upsert_assertion`` and ``datahub_update_assertion_result``.
    """
    for outcome, assertion_status in (("passed", "SUCCESS"), ("failed", "FAILURE")):
        for result in dqv_results[outcome]:
            for column, metrics in result["checks"].items():
                for metric, conditions in metrics.items():
                    # Fall back to the raw metric/condition key when no description exists.
                    metric_desc = metric_description_map.get(metric, metric)
                    for check in conditions:
                        condition = check["condition"]
                        value = check["value"]
                        actual = check["calculated_value"]
                        cond_desc = condition_description_map.get(condition, condition)
                        dataset_urn = dataset_name_to_urn[result["dataset_name"]]
                        yield {
                            "entity_urn": dataset_urn,
                            "assertion_urn": make_assertion_urn(
                                dataset_urn=dataset_urn,
                                assertion_name=f"{column}_{metric}_{condition}_{value}",
                            ),
                            "status": assertion_status,
                            "checks_description": f"Column: {column} - {metric_desc} value {cond_desc} {value}",
                            "properties": [
                                {"key": "column", "value": column},
                                {"key": "metric", "value": metric},
                                {"key": "condition", "value": condition},
                                {"key": "expected", "value": value},
                                {"key": "actual", "value": float(actual)},
                            ],
                        }


def log_dqv_report_datahub(dqv_results, dataset_name_to_urn, settle_seconds=8):
    """
    Publish DQV check outcomes to DataHub as assertions and assertion results.

    All assertions are upserted first; after a pause, a result is reported for
    each of them in the same order ("passed" checks before "failed" ones).

    :param dqv_results: Mapping with "passed" and "failed" lists of result dicts,
        each holding "dataset_name" and a nested "checks" mapping of
        column -> metric -> list of condition dicts
    :param dataset_name_to_urn: Mapping from dataset name to DataHub dataset URN
    :param settle_seconds: Pause between the upsert phase and the reporting phase
        (presumably gives DataHub time to register the upserted assertions —
        TODO confirm)
    :raises KeyError: If a result refers to a dataset missing from
        ``dataset_name_to_urn`` or lacks an expected key
    """
    # Materialise once so both phases see the same assertions in the same order,
    # and malformed results fail before anything is sent to DataHub.
    assertions = list(_iter_dqv_assertions(dqv_results, dataset_name_to_urn))
    if not assertions:
        return  # nothing to publish, so there is nothing to wait for either

    for assertion in assertions:
        datahub_upsert_assertion(**assertion)

    sleep(settle_seconds)

    for assertion in assertions:
        datahub_update_assertion_result(**assertion)

In [26]:
# Run the DQV checks defined in assertion_checks.yaml against the loaded DataFrames.
dqv_results = dqv_check_yaml("assertion_checks.yaml", datasets=dfs)
dqv_results  # bare last expression so the notebook displays the result

{'passed': [{'dataset_name': 'yellow_rides_hourly_forecast',
   'dataset_owner': {'data_team@company.com': 'U123456789'},
   'dataset_type': 'pandas',
   'checks': {'year': {'missing_percent': [{'condition': 'eq',
       'value': 0,
       'criticality': 'fail',
       'calculated_value': np.float64(0.0)}],
     'min': [{'condition': 'gte',
       'value': 2020,
       'criticality': 'fail',
       'calculated_value': np.int64(2025)}],
     'max': [{'condition': 'lte',
       'value': 2030,
       'criticality': 'fail',
       'calculated_value': np.int64(2025)}]},
    'month': {'missing_percent': [{'condition': 'eq',
       'value': 0,
       'criticality': 'fail',
       'calculated_value': np.float64(0.0)}],
     'min': [{'condition': 'gte',
       'value': 1,
       'criticality': 'fail',
       'calculated_value': np.int64(6)}],
     'max': [{'condition': 'lte',
       'value': 12,
       'criticality': 'fail',
       'calculated_value': np.int64(6)}]},
    'day': {'missing_percen

In [20]:
# Show summary statistics (rounded to 5 decimals) for each loaded DataFrame.
for name, df in dfs.items():
    print(name)
    display(df.describe().round(5))

yellow_rides_hourly_forecast


Unnamed: 0,year,month,day,hour,pulocationid,forecasted_total_rides
count,6240.0,6240.0,6240.0,6240.0,6240.0,6240.0
mean,2025.0,6.0,1.0,11.5,133.16923,18.25465
std,0.0,0.0,0.0,6.92274,77.05168,44.1934
min,2025.0,6.0,1.0,0.0,1.0,0.0
25%,2025.0,6.0,1.0,5.75,65.75,0.0
50%,2025.0,6.0,1.0,11.5,134.5,3.0
75%,2025.0,6.0,1.0,17.25,200.25,12.0
max,2025.0,6.0,1.0,23.0,265.0,501.0


yellow_rides_hourly_actuals


Unnamed: 0,year,month,day,hour,pulocationid,total_rides
count,240879.0,240879.0,240879.0,240879.0,240879.0,240879.0
mean,2025.0,4.53853,16.09038,11.91919,138.06191,35.0111
std,0.0,0.49854,8.67531,6.69065,76.27138,76.06195
min,2025.0,4.0,1.0,0.0,1.0,1.0
25%,2025.0,4.0,9.0,7.0,72.0,2.0
50%,2025.0,5.0,16.0,12.0,140.0,4.0
75%,2025.0,5.0,24.0,18.0,209.0,22.0
max,2025.0,6.0,31.0,23.0,265.0,966.0


raw_weather


Unnamed: 0,region_code,year,month,day_01,day_02,day_03,day_04,day_05,day_06,day_07,...,day_22,day_23,day_24,day_25,day_26,day_27,day_28,day_29,day_30,day_31
count,37284.0,37284.0,37284.0,37284.0,37284.0,37284.0,37284.0,37284.0,37284.0,37284.0,...,37284.0,37284.0,37284.0,37284.0,37284.0,37284.0,37284.0,24856.0,24856.0,12428.0
mean,24879.27036,2025.0,3.0,7.01274,5.36881,6.87188,7.73093,8.2236,7.1648,5.84627,...,5.10806,6.9507,8.6073,9.60103,9.87041,9.38722,9.76666,12.84167,12.87304,11.12627
std,13708.05752,0.0,0.81651,8.65796,9.16588,10.49438,10.81167,12.06425,11.72068,9.69948,...,10.09983,9.78653,9.32848,9.1186,9.26281,9.12876,9.16151,10.17275,10.17342,9.48659
min,1001.0,2025.0,2.0,-24.73,-28.61,-27.24,-28.46,-27.23,-30.66,-28.6,...,-24.91,-18.42,-17.51,-17.35,-15.8,-15.88,-16.78,-9.03,-10.28,-11.56
25%,13045.0,2025.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.68,0.33,0.19,0.3,2.15,2.69,2.2
50%,23211.0,2025.0,3.0,5.37,2.74,4.07,5.67,6.43,4.64,3.53,...,2.015,4.31,6.815,8.29,8.775,7.98,8.73,13.13,13.72,10.685
75%,39007.0,2025.0,4.0,12.97,11.16,13.65,14.8,15.4,13.0,11.15,...,12.03,14.0,15.67,16.64,17.01,16.24,16.91,21.1825,20.8,19.0525
max,48045.0,2025.0,4.0,41.87,65.7,85.99,106.16,136.83,147.66,137.77,...,38.69,64.14,67.19,92.55,36.53,73.7,102.12,93.26,92.38,61.14


raw_yellow


Unnamed: 0,vendorid,passenger_count,trip_distance,ratecodeid,pulocationid,dolocationid,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,cbd_congestion_fee,airport_fee
count,8562398.0,6620492.0,8562398.0,6620492.0,8562398.0,8562398.0,8562398.0,8562398.0,8562398.0,8562398.0,8562398.0,8562398.0,8562398.0,8562398.0,6620492.0,8562398.0,6620492.0
mean,1.85981,1.30163,7.34798,2.45004,161.90209,161.65919,0.95476,18.17578,1.21855,0.47693,2.91449,0.50062,0.95475,26.74798,2.20511,0.52996,0.1443
std,0.68048,0.73944,657.5107,11.43111,66.23322,70.32285,0.74956,19.29194,1.8605,0.14042,3.98424,2.12195,0.27985,23.682,0.93749,0.36081,0.52587
min,1.0,0.0,0.0,1.0,1.0,1.0,0.0,-1777.5,-17.39,-0.5,-90.44,-148.17,-1.0,-1793.37,-2.5,-0.75,-1.75
25%,2.0,1.0,1.05,1.0,116.0,112.0,1.0,8.6,0.0,0.5,0.0,0.0,1.0,15.73,2.5,0.0,0.0
50%,2.0,1.0,1.84,1.0,161.0,162.0,1.0,13.5,0.0,0.5,2.16,0.0,1.0,21.23,2.5,0.75,0.0
75%,2.0,1.0,3.61,1.0,233.0,234.0,1.0,22.45,2.5,0.5,4.0,0.0,1.0,30.35,2.5,0.75,0.0
max,7.0,9.0,386088.4,99.0,265.0,265.0,4.0,1777.5,133.6,22.14,525.0,148.17,1.0,1793.37,2.5,1.25,6.75


In [27]:
# Publish every DQV check outcome to DataHub: upsert the assertions, then report their results.
log_dqv_report_datahub(dqv_results, dataset_name_to_urn)