In [1]:
import os
import time
from functools import lru_cache
from pathlib import Path
from time import sleep
from typing import Any, Dict, Optional
from urllib.parse import urlparse

import awswrangler as wr
import boto3
import datahub.emitter.mce_builder as builder
import pandas as pd
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from ds_dqv_tool import dqv_check, dqv_check_yaml
from ds_dqv_tool.recipes import condition_description_map, metric_description_map
from jinja2 import DebugUndefined, Template

# DataHub GMS endpoint. Override it with the DATAHUB_GMS_URL environment variable
# (the variable name used by the DataHub CLI). The default is the localhost:8091
# endpoint this notebook was developed against.
DATAHUB_GMS_URL = os.environ.get("DATAHUB_GMS_URL", "http://localhost:8091")

graph = DataHubGraph(config=DatahubClientConfig(server=DATAHUB_GMS_URL))


@lru_cache
def get_athena_table_dataset_urn(
    catalog: str,
    database: str,
    table: str,
    region: str,
    profile_name: str = "sandbox",
) -> str:
    """Build the DataHub dataset URN for an Athena table from its S3 location.

    The URN uses the ``hive`` platform and only the *path* component of the table's
    S3 location (the bucket is dropped), e.g.
    ``urn:li:dataset:(urn:li:dataPlatform:hive,/iceberg/yellow_rides_hourly_actuals,PROD)``.

    Results are memoized with ``lru_cache``, so repeated calls with the same arguments
    do not call Athena again.

    :param catalog: Athena data catalog name, e.g. ``"AwsDataCatalog"``.
    :param database: Glue/Athena database containing the table.
    :param table: Table name within ``database``.
    :param region: AWS region used for the Athena client.
    :param profile_name: AWS profile used to create the boto3 session.
    :return: The dataset URN string.
    :raises ValueError: If the table metadata has no ``location`` parameter
        (likely a table without S3 storage, such as a view -- TODO confirm).
    """
    session = boto3.Session(profile_name=profile_name)
    athena_client = session.client("athena", region_name=region)
    table_metadata = athena_client.get_table_metadata(
        CatalogName=catalog, DatabaseName=database, TableName=table
    )

    # The table's physical S3 location; only its path part becomes the URN name.
    # (The full location could also be attached to the dataset as a symlink facet.)
    parameters = table_metadata["TableMetadata"].get("Parameters", {})
    s3_location = parameters.get("location")
    if not s3_location:
        raise ValueError(
            f"Athena table {catalog}.{database}.{table} has no 'location' parameter; "
            "cannot derive a dataset URN from its S3 path."
        )

    return builder.make_dataset_urn(
        platform="hive",
        name=urlparse(s3_location).path,
    )

In [2]:
from datahub.emitter.serialization_helper import pre_json_transform


def make_assertion_urn(dataset_urn: str, assertion_name: str) -> str:
    """Derive the DataHub assertion URN for one named check on one dataset.

    The URN wraps a GUID computed from a small identity dict. The same
    ``dataset_urn`` / ``assertion_name`` pair therefore always maps to the same URN,
    which lets a result report target an assertion that was upserted earlier.

    :param dataset_urn: URN of the dataset the assertion belongs to.
    :param assertion_name: Name that distinguishes this check on that dataset.
    :return: The assertion URN.
    """
    # Only these key/value pairs feed the hash, so they must be exactly what makes an
    # assertion unique. Same approach as the DataHub Great Expectations plugin:
    # https://github.com/datahub-project/datahub/blob/d2d9d36987f20a9f7d6c973073d1404edf33e667/metadata-ingestion-modules/gx-plugin/src/datahub_gx_plugin/action.py#L277-L289
    identity = {
        "platform": "pattern-ds-dqv",
        # Key name is imperfect: assertions and datasets are many-to-many.
        "dataset_urn": dataset_urn,
        "assertion_name": assertion_name,
    }
    guid = builder.datahub_guid(pre_json_transform(identity))
    return builder.make_assertion_urn(guid)

In [3]:
def substitute_map_into_string(string: str, values: dict[str, Any]) -> str:
    """
    Render ``string`` as a Jinja2 template with ``values`` as the context.

    Unknown placeholders are not an error: ``DebugUndefined`` leaves them in the
    output as written (e.g. ``{{ missing }}``).

    :param string: Template text containing ``{{ ... }}`` placeholders
    :param values: Mapping of placeholder names to substitution values
    :return: The rendered text
    """
    return Template(string, undefined=DebugUndefined).render(values)


def query_pandas_from_athena(
    sql_query: str,
    glue_database: str,
    datalake_s3_bucket: str,
    s3_output_location: Optional[str] = None,
    ctx: Optional[Dict[str, Any]] = None,
    profile_name: str = "sandbox",
    region_name: str = "us-east-1",
) -> pd.DataFrame:
    """Run a SQL query on Athena and return the result as a pandas DataFrame.

    :param sql_query: SQL text; rendered as a Jinja2 template first when ``ctx`` is given.
    :param glue_database: Glue/Athena database the query is launched from.
    :param datalake_s3_bucket: Bucket used for Athena results when ``s3_output_location``
        is not given (results go under ``s3://<bucket>/athena-results``).
    :param s3_output_location: Explicit S3 location for Athena query results.
    :param ctx: Optional template context substituted into ``sql_query``.
    :param profile_name: AWS profile for the boto3 session.
    :param region_name: AWS region for the boto3 session.
    :return: The query result.
    """
    if ctx is not None:
        sql_query = substitute_map_into_string(sql_query, ctx)

    if s3_output_location is None:
        s3_output_location = f"s3://{datalake_s3_bucket}/athena-results"

    session = boto3.Session(profile_name=profile_name, region_name=region_name)

    # TODO: emit OpenLineage START/COMPLETE events around the query (previously stubbed
    # out here together with a `job_name` parameter).
    return wr.athena.read_sql_query(
        sql=sql_query,
        database=glue_database,
        s3_output=s3_output_location,
        boto3_session=session,
    )

In [4]:
THIS_DIR = Path("metaflow")
SQL_DIR = THIS_DIR / "sql"
# NOTE(review): sql_path, as_of_datetime and lookback_days are not used by any cell
# shown here -- presumably leftovers from the training-data query; confirm or remove.
sql_path = SQL_DIR / "prepare_training_data.sql"
glue_database = "nyc_taxi"
datalake_s3_bucket = "airflow-metaflow-6721"
as_of_datetime = "2025-06-01 00:00:00.000"
lookback_days = 30

# Rows pulled per table for the DQV checks. The queries have no ORDER BY, so this is
# an arbitrary sample rather than a deterministic one.
SAMPLE_ROW_LIMIT = 2000

_weather_day_columns = ", ".join(f"day_{day:02d}" for day in range(1, 32))

# One sampling query per table, qualified with `glue_database` so the database name
# is defined in exactly one place.
querys = {
    "raw_weather": (
        "SELECT unique_row_id, filename, region_type, region_code, region_name, year, month, "
        f"meteorological_element, {_weather_day_columns} "
        f"FROM {glue_database}.raw_weather LIMIT {SAMPLE_ROW_LIMIT}"
    ),
    "yellow_rides_hourly_forecast": (
        f"SELECT * from {glue_database}.yellow_rides_hourly_forecast LIMIT {SAMPLE_ROW_LIMIT}"
    ),
    "yellow_rides_hourly_actuals": (
        "SELECT year , month , day , hour , pulocationid , total_rides "
        f"from {glue_database}.yellow_rides_hourly_actuals LIMIT {SAMPLE_ROW_LIMIT}"
    ),
    "raw_yellow": (
        "SELECT unique_row_id , filename , vendorid , passenger_count , trip_distance , "
        "ratecodeid , store_and_fwd_flag , pulocationid , dolocationid , payment_type , "
        "fare_amount , extra , mta_tax , tip_amount , tolls_amount , improvement_surcharge , "
        "total_amount , congestion_surcharge , cbd_congestion_fee , airport_fee "
        f"from {glue_database}.raw_yellow LIMIT {SAMPLE_ROW_LIMIT}"
    ),
}


In [5]:
# Tables to sample and validate; `dfs` and `dataset_name_to_urn` are built in this order.
datasets = [
    "yellow_rides_hourly_forecast", "yellow_rides_hourly_actuals",
    "raw_weather", "raw_yellow",
]

In [6]:
# DataHub dataset URN for every validated table. `database` reuses the
# `glue_database` setting instead of repeating the literal.
dataset_name_to_urn = {
    table: get_athena_table_dataset_urn(
        catalog="AwsDataCatalog",
        database=glue_database,
        table=table,
        region="us-east-1",
    )
    for table in datasets
}

# Presumably scratch values for manual inspection (not used by the later cells shown):
# the raw_weather URN, read from the map rather than looked up a second time, and a
# sample assertion URN.
entity_urn = dataset_name_to_urn["raw_weather"]
assertion_urn = make_assertion_urn(
    dataset_urn=entity_urn,
    assertion_name="test_assertion",
)

In [7]:
# Load each table's sample (the LIMIT lives in `querys`) into a name -> DataFrame map.
dfs = {}
for table in datasets:
    print(table)
    table_sql = querys[table]
    dfs[table] = query_pandas_from_athena(
        sql_query=table_sql, glue_database=glue_database, datalake_s3_bucket=datalake_s3_bucket
    )


yellow_rides_hourly_forecast
DEBUG: Attempting to get workgroup: 'primary'
DEBUG: Current caller identity:
  Account: <redacted-aws-account-id>
  UserId: <redacted-user-id>:<redacted-email>
  Arn: arn:aws:sts::<redacted-aws-account-id>:assumed-role/<redacted-sso-role>/<redacted-email>
yellow_rides_hourly_actuals
DEBUG: Attempting to get workgroup: 'primary'
DEBUG: Current caller identity:
  Account: <redacted-aws-account-id>
  UserId: <redacted-user-id>:<redacted-email>
  Arn: arn:aws:sts::<redacted-aws-account-id>:assumed-role/<redacted-sso-role>/<redacted-email>
raw_weather
DEBUG: Attempting to get workgroup: 'primary'
DEBUG: Current caller identity:
  Account: <redacted-aws-account-id>
  UserId: <redacted-user-id>:<redacted-email>
  Arn: arn:aws:sts::<redacted-aws-account-id>:assumed-role/<redacted-sso-role>/<redacted-email>
raw_yellow
DEBUG: Attempting to get workgroup: 'primary'
DEBUG: Current call

In [8]:
def datahub_upsert_assertion(
    entity_urn,
    assertion_urn,
    status,
    checks_description,
    properties,
):
    """Create or update a custom ``DQV`` assertion on a dataset in DataHub.

    Only the assertion definition is written. ``status`` and ``properties`` are
    accepted but not used by this function. Nothing is returned.

    :param entity_urn: URN of the dataset the assertion is attached to.
    :param assertion_urn: URN of the assertion to create or update.
    :param status: Unused.
    :param checks_description: Human-readable description stored on the assertion.
    :param properties: Unused.
    """
    assertion_definition = dict(
        urn=assertion_urn,
        entity_urn=entity_urn,
        type="DQV",  # category under which DataHub groups this custom assertion
        description=checks_description,
        # Alternatively: platform_urn="urn:li:dataPlatform:great-expectations".
        platform_name="metaflow",
    )
    graph.upsert_custom_assertion(**assertion_definition)


def datahub_update_assertion_result(
    entity_urn,
    assertion_urn,
    status,
    checks_description,
    properties,
):
    """Report one run result for an existing DataHub assertion, timestamped now.

    :param entity_urn: Unused here.
    :param assertion_urn: URN of the assertion the result belongs to.
    :param status: Result type, e.g. ``"SUCCESS"`` or ``"FAILURE"``.
    :param checks_description: Unused here.
    :param properties: Iterable of ``{"key": ..., "value": ...}`` dicts to attach to
        the result, or None.
    """
    # report_assertion_result types `properties` as List[Dict[str, str]], so numeric
    # values (e.g. expected thresholds or numpy scalars) are stringified before sending.
    # NOTE(review): non-string values are a plausible cause of the server error seen
    # from reportAssertionResult -- confirm.
    string_properties = None
    if properties is not None:
        string_properties = [
            {
                "key": str(entry["key"]),
                "value": None if entry["value"] is None else str(entry["value"]),
            }
            for entry in properties
        ]
    graph.report_assertion_result(
        urn=assertion_urn,
        timestamp_millis=int(time.time() * 1000),
        type=status,
        properties=string_properties,
    )

In [None]:
def _format_calculated_value(value) -> str:
    """Render a check's calculated value as a string property value.

    Numbers (including numpy scalars) go through ``float``; values ``float`` rejects
    (e.g. None) fall back to ``str``.
    """
    try:
        return str(float(value))
    except (TypeError, ValueError):
        return str(value)


def _iter_dqv_assertion_payloads(dqv_results, dataset_name_to_urn):
    """Yield one keyword-argument dict per checked condition in a DQV report.

    Each dict has ``entity_urn``, ``assertion_urn``, ``status``, ``checks_description``
    and ``properties`` and can be splatted into ``datahub_upsert_assertion`` or
    ``datahub_update_assertion_result``. Passed checks come first, then failed ones.
    A missing ``"passed"``/``"failed"`` key is treated as an empty list.
    """
    for outcome in ("passed", "failed"):
        status = "SUCCESS" if outcome == "passed" else "FAILURE"
        for result in dqv_results.get(outcome, []):
            entity_urn = dataset_name_to_urn[result["dataset_name"]]
            for column, metrics in result["checks"].items():
                for metric, conditions in metrics.items():
                    metric_desc = metric_description_map.get(metric, metric)
                    for check in conditions:
                        condition = check["condition"]
                        value = check["value"]
                        cond_desc = condition_description_map.get(condition, condition)
                        yield {
                            "entity_urn": entity_urn,
                            # This name is hashed into the URN, so it must stay stable
                            # across runs for results to land on the same assertion.
                            "assertion_urn": make_assertion_urn(
                                dataset_urn=entity_urn,
                                assertion_name=f"{column}_{metric}_{condition}_{value}",
                            ),
                            "status": status,
                            "checks_description": f"Column: {column} - {metric_desc} value {cond_desc} {value}",
                            # DataHub property values are strings.
                            "properties": [
                                {"key": "column", "value": str(column)},
                                {"key": "metric", "value": str(metric)},
                                {"key": "condition", "value": str(condition)},
                                {"key": "expected", "value": str(value)},
                                {"key": "actual", "value": _format_calculated_value(check["calculated_value"])},
                            ],
                        }


def log_dqv_report_datahub(dqv_results, dataset_name_to_urn, settle_seconds: float = 5):
    """Publish every condition of a DQV report to DataHub as a custom assertion plus a result.

    All payloads are built first, so a bad report (e.g. an unmapped dataset name) fails
    before anything is written. Then every assertion is upserted, the function waits
    ``settle_seconds``, and a result is reported for each assertion. The wait is
    presumably there to let DataHub register the upserted assertions before results
    reference them -- the 5s default is empirical.

    :param dqv_results: Report with ``"passed"``/``"failed"`` lists of per-dataset results
        (the shape produced by ``dqv_check_yaml``).
    :param dataset_name_to_urn: Maps each result's ``dataset_name`` to its dataset URN.
    :param settle_seconds: Delay between the upsert pass and the result pass.
    :raises KeyError: If a result's ``dataset_name`` is missing from ``dataset_name_to_urn``.
    """
    payloads = list(_iter_dqv_assertion_payloads(dqv_results, dataset_name_to_urn))
    if not payloads:
        return

    for payload in payloads:
        datahub_upsert_assertion(**payload)

    sleep(settle_seconds)

    for payload in payloads:
        datahub_update_assertion_result(**payload)

In [10]:
# Run the YAML-defined data-quality checks against the sampled tables and show the report.
DQV_CHECKS_PATH = "assertion_checks.yaml"
dqv_results = dqv_check_yaml(DQV_CHECKS_PATH, datasets=dfs)
dqv_results

{'passed': [{'dataset_name': 'yellow_rides_hourly_forecast',
   'dataset_owner': {'data_team@company.com': 'U123456789'},
   'dataset_type': 'pandas',
   'checks': {'year': {'missing_percent': [{'condition': 'eq',
       'value': 0,
       'criticality': 'fail',
       'calculated_value': np.float64(0.0)}],
     'min': [{'condition': 'gte',
       'value': 2020,
       'criticality': 'fail',
       'calculated_value': np.int64(2025)}],
     'max': [{'condition': 'lte',
       'value': 2030,
       'criticality': 'fail',
       'calculated_value': np.int64(2025)}]},
    'month': {'missing_percent': [{'condition': 'eq',
       'value': 0,
       'criticality': 'fail',
       'calculated_value': np.float64(0.0)}],
     'min': [{'condition': 'gte',
       'value': 1,
       'criticality': 'fail',
       'calculated_value': np.int64(6)}],
     'max': [{'condition': 'lte',
       'value': 12,
       'criticality': 'fail',
       'calculated_value': np.int64(6)}]},
    'day': {'missing_percen

In [11]:
# Per-table numeric summary, rounded to 5 decimals for readability.
for name, df in dfs.items():
    print(name)
    summary = df.describe().round(5)
    display(summary)

yellow_rides_hourly_forecast


Unnamed: 0,year,month,day,hour,pulocationid,forecasted_total_rides
count,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0
mean,2025.0,6.0,1.0,11.7185,131.081,17.913
std,0.0,0.0,0.0,6.98925,76.83548,44.0481
min,2025.0,6.0,1.0,0.0,1.0,0.0
25%,2025.0,6.0,1.0,6.0,63.0,1.0
50%,2025.0,6.0,1.0,12.0,132.0,3.0
75%,2025.0,6.0,1.0,18.0,197.0,12.25
max,2025.0,6.0,1.0,23.0,265.0,501.0


yellow_rides_hourly_actuals


Unnamed: 0,year,month,day,hour,pulocationid,total_rides
count,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0
mean,2025.0,4.5355,15.854,11.885,139.929,35.567
std,0.0,0.49886,8.63649,6.58525,76.13635,76.73315
min,2025.0,4.0,1.0,0.0,1.0,1.0
25%,2025.0,4.0,8.0,7.0,74.0,2.0
50%,2025.0,5.0,16.0,12.0,141.5,4.0
75%,2025.0,5.0,23.0,17.0,210.0,23.0
max,2025.0,5.0,31.0,23.0,265.0,634.0


raw_weather


Unnamed: 0,region_code,year,month,day_01,day_02,day_03,day_04,day_05,day_06,day_07,...,day_22,day_23,day_24,day_25,day_26,day_27,day_28,day_29,day_30,day_31
count,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,...,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,0.0,0.0,0.0
mean,24914.0635,2025.0,2.0,6.25494,4.645,6.46402,7.25362,4.39847,4.85298,5.8801,...,-3.48685,1.17165,5.12,8.39902,9.73764,9.75692,8.89356,,,
std,13658.00576,0.0,0.0,6.49199,7.14557,8.53979,10.56498,11.41114,10.82636,10.50875,...,5.8607,4.8583,4.25444,3.99603,4.93625,5.71505,4.83217,,,
min,1001.0,2025.0,2.0,-15.86,-20.24,-23.02,-25.57,-24.48,-25.21,-21.89,...,-15.78,-9.96,-9.43,-7.33,-4.28,-8.3,-5.12,,,
25%,13080.0,2025.0,2.0,1.975,0.21,1.1775,0.8125,-4.145,-2.86,-1.2,...,-7.44,-2.72,1.845,5.5,5.6175,5.13,5.87,,,
50%,23615.0,2025.0,2.0,6.89,5.37,8.26,10.25,5.2,3.96,4.535,...,-4.17,0.49,4.955,8.045,9.91,9.76,8.84,,,
75%,38067.5,2025.0,2.0,10.96,10.555,12.7425,14.84,13.995,14.44,14.65,...,-0.0275,4.5825,8.0025,11.18,13.8525,14.925,12.4425,,,
max,48045.0,2025.0,2.0,22.58,23.94,24.47,25.73,25.83,25.72,26.27,...,20.01,20.76,21.78,20.44,22.73,23.51,22.14,,,


raw_yellow


Unnamed: 0,vendorid,passenger_count,trip_distance,ratecodeid,pulocationid,dolocationid,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,cbd_congestion_fee,airport_fee
count,2000.0,1686.0,2000.0,1686.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,1686.0,2000.0,1686.0
mean,1.8015,1.19454,2.9805,2.70403,165.0975,166.702,1.006,17.33175,1.52988,0.47775,3.13806,0.46744,0.952,26.30719,2.24792,0.5325,0.11077
std,0.5561,0.59283,4.10651,12.52552,65.40532,68.06811,0.66572,18.77381,2.02159,0.13835,3.67421,2.03134,0.2859,23.17514,0.88437,0.36126,0.48287
min,1.0,0.0,0.0,1.0,4.0,1.0,0.0,-122.7,-6.0,-0.5,0.0,-6.94,-1.0,-135.39,-2.5,-0.75,-1.75
25%,2.0,1.0,0.91,1.0,132.0,125.0,1.0,8.6,0.0,0.5,0.0,0.0,1.0,15.9575,2.5,0.0,0.0
50%,2.0,1.0,1.65,1.0,162.0,163.0,1.0,13.5,1.0,0.5,2.625,0.0,1.0,21.06,2.5,0.75,0.0
75%,2.0,1.0,3.02,1.0,234.0,234.0,1.0,20.5,2.5,0.5,4.29,0.0,1.0,29.4225,2.5,0.75,0.0
max,7.0,6.0,71.32,99.0,264.0,265.0,4.0,275.0,12.5,0.5,34.92,31.06,1.0,334.63,2.5,0.75,6.75


In [12]:
# Publish every DQV condition to DataHub: assertion definitions first, then their results.
log_dqv_report_datahub(dqv_results=dqv_results, dataset_name_to_urn=dataset_name_to_urn)

GraphError: Error executing graphql query: [{'message': 'An unknown error occurred.', 'locations': [{'line': 10, 'column': 17}], 'path': ['reportAssertionResult'], 'extensions': {'code': 500, 'type': 'SERVER_ERROR', 'classification': 'DataFetchingException'}}]