In [42]:
import re
import boto3
import time
import json, copy
from google import genai

In [2]:
database = 'copilot_demo'
table = 'employees_test'
region = 'us-east-1'
USER_DEFINED_PII = ["name", "email", "phone_number", "salary"]
s3_bucket = "s3://de-copilot-s3/athena-results/"

In [25]:
# metadeta stats of the tables in the glue
def glue_metadata(database, table, region):
    try:
        client = boto3.client('glue', region_name=region)
        resp = client.get_table(DatabaseName=database, Name=table)
        t = resp['Table']

        schema = []
        for col in t["StorageDescriptor"]["Columns"]:
            is_pk = False
            if (col['Comment']).lower() in ['pk','primary_key']:
                is_pk = True
            is_fk = False
            if (col['Comment']).lower() in ['fk','foreign_key']:
                is_fk = True

            schema.append({
                "name": col["Name"],
                "type": col["Type"],
                "nullable": False if is_pk else True,
                "partition_key": False,
                "primary_key" : is_pk,
                "foregin_key" : is_fk, 
                "comments" : col.get("Comment", "")
            })
  
        for col in t.get("PartitionKeys", []):
            schema.append({
                "name": col["Name"],
                "type": col["Type"],
                "nullable": True,
                "partition_key": True,
                "primary_key" : False,
                "foreign_key" : False,
                "comments" : col.get("Comment", "")  
            })
 
        return {
            "table_name": t["Name"],
            "database": t["DatabaseName"],
            "schema": schema
        }

    except Exception as e:
        print("error in glue_metadata:", str(e))
        return {}

def build_stats_sql(database, table, schema, pii_list):

    pii_columns = set(c.lower() for c in pii_list)
    selects = []

    row_cnts = f""" 
                SELECT
                'ROW_COUNT' as col_name,
                CAST(COUNT(*) AS VARCHAR) AS min_val,
                NULL as max_val,
                NULL as null_pct,
                NULL as distinct_count
                FROM "{database}"."{table}"
                """.strip()
    selects.append(row_cnts)

    for col in schema:
        col_name = col["name"]
        col_lower = col_name.lower()

        if col_lower in pii_columns:
            continue

        s = f"""
            SELECT
            '{col_name}' AS col_name,
            CAST(MIN("{col_name}") AS VARCHAR) AS min_val,
            CAST(MAX("{col_name}") AS VARCHAR) AS max_val,
            AVG(CASE WHEN "{col_name}" IS NULL THEN 1.0 ELSE 0 END) AS null_pct,
            APPROX_DISTINCT("{col_name}") AS distinct_count
            FROM "{database}"."{table}"
            """.strip()

        selects.append(s)

    if not selects:
        return None

    sql = "\nUNION ALL\n".join(selects)
    return sql


def athena_setup(database,region,query,s3_bucket):
    try:
        athena_client = boto3.client('athena',region_name=region)

        response = athena_client.start_query_execution(QueryString=query, QueryExecutionContext = {'Database' : database},
        ResultConfiguration = {'OutputLocation':s3_bucket})

        id = response['QueryExecutionId'] # generated by athena as a ticket number
        
        while True:
            stats = athena_client.get_query_execution(QueryExecutionId=id)
            status = stats['QueryExecution']['Status']['State']

            if status =='SUCCEEDED':
                break 
            elif status in ['FAILED', 'CANCELLED']:
                reason = stats['QueryExecution']['Status'].get('StateChangeReason', 'Unknown')
                raise Exception(f"Query Failed: {reason}")

            time.sleep(5)

        return athena_client.get_query_results(QueryExecutionId=id)

    except Exception as e:
        print("error in run_athena_query:", str(e))
        return None 

# get the column value stats, runs on athena
def get_athena_data(database, table, region, schema, pii_list, s3_bucket):
    try:
        sql = build_stats_sql(database, table, schema, pii_list)

        if not sql:
            print("no non-PII columns found for stats")
            return {}

        query_out = athena_setup(database, region, sql, s3_bucket)

        if not query_out:
            return {}

        rows = query_out["ResultSet"]["Rows"]
        col_stats = parse_stats_rows(rows)


        return col_stats

    except Exception as e:
        print("error in get_athena_data:", str(e))
        return {}

def parse_stats_rows(rows):
    data_rows = rows[1:]

    stats = {'ROW_COUNT' : 0}

    for r in data_rows:
        vals = [c.get("VarCharValue") for c in r["Data"]]


        col_name = vals[0]
        min_val = vals[1]
        if col_name == 'ROW_COUNT':
            stats['ROW_COUNT'] = min_val if min_val else 0
            continue

        max_val = vals[2]
        null_pct = vals[3]
        distinct = vals[4]

        try:
            null_pct = float(null_pct) if null_pct is not None else None
        except:
            pass

        try:
            distinct = int(distinct) if distinct is not None else None
        except:
            pass

        stats[col_name] = {
            "min": min_val,
            "max": max_val,
            "null_pct": null_pct,
            "distinct_count": distinct
        }

    return stats

    

# remove the pii information in the column stats from athena
def filter_pii(col_stats, pii_list):
    pii_columns = set(c.lower() for c in pii_list)

    cleaned_stats = {}
    for col_name, vals in col_stats.items():
        if col_name.lower() not in pii_columns:
            cleaned_stats[col_name] = vals

    return cleaned_stats



ddl_obj = glue_metadata(database, table, region)
schema = ddl_obj.get("schema", [])

col_stats = get_athena_data(database, table, region, schema, USER_DEFINED_PII, s3_bucket)


col_stats = filter_pii(col_stats, USER_DEFINED_PII)

ddl_obj["column_stats"] = col_stats

print(json.dumps(ddl_obj, indent=2))



{
  "table_name": "employees_test",
  "database": "copilot_demo",
  "schema": [
    {
      "name": "emp_id",
      "type": "int",
      "nullable": false,
      "partition_key": false,
      "primary_key": true,
      "foregin_key": false,
      "comments": "primary_key"
    },
    {
      "name": "name",
      "type": "string",
      "nullable": true,
      "partition_key": false,
      "primary_key": false,
      "foregin_key": false,
      "comments": ""
    },
    {
      "name": "salary",
      "type": "double",
      "nullable": true,
      "partition_key": false,
      "primary_key": false,
      "foregin_key": false,
      "comments": ""
    },
    {
      "name": "department",
      "type": "string",
      "nullable": true,
      "partition_key": false,
      "primary_key": false,
      "foregin_key": false,
      "comments": ""
    },
    {
      "name": "joining_date",
      "type": "date",
      "nullable": true,
      "partition_key": false,
      "primary_key": false,
  

In [26]:
payload = f"""
Pay Attention, Do not hallucinate, only work on what is there in the below and think deep for all the edge cases for the below requirements.
Do NOT assume any properties (file format, update frequency, row counts, S3 location, etc.) that are not explicitly present in the JSON. If unknown, either omit or mark as "unknown".

You are a Senior Data Engineer Copilot specializing in **AWS Glue and Athena**.

You will receive a schema object extracted directly from the **AWS Glue Data Catalog**.
Your goal is to generate a Data Quality Contract and Post-Load Tests suitable for an AWS Data Lake environment and also confulence style documentation for the table.

The JSON object will look like this (shape, not exact values):

{json.dumps(ddl_obj, indent=2)}

Where:
- table_name: name of the table
- database: database / schema name
- schema: list of columns, each like:
  - name: column name
  - type: data type (string, int, double, date, etc.)
  - nullable: true / false
- column_stats: ONLY for NON-PII columns, something like:
  - min, max, null_pct, distinct_count, top_values
- constraints: optional list of primary keys, unique keys, check constraints
- job_summary: includes inputs, filters, and grain if available
- rule_type MUST be exactly one of the allowed values. 
  Do NOT invent new rule_type names. If unsure, choose the closest one.

IMPORTANT PRIVACY RULES:
- PII columns (name, email, phone_number, salary, etc.) appear in "schema"
  but their stats and values are NOT provided.
- For PII columns:
  - You MAY define structural rules: not_null, not_empty, length, regex.
  - You MUST NOT include any concrete example values (no fake SSNs, emails, phones).
  - Only describe patterns, like "must be 9 digits", "must match email format".
- For non-PII columns:
  - You MAY use column_stats to propose ranges and allowed_values.
  - Still avoid writing specific sample values in descriptions; talk about rules.

----------------------------------------
THINKING / COVERAGE REQUIREMENTS
----------------------------------------

You must think column-by-column and constraint-by-constraint. 
Do not skip any column.
For coverage:
- Every column in "schema" (except purely technical partition columns) MUST appear in at least one rule in data_quality.rules.
- Do NOT skip columns just because nullable = true. If a column is nullable, you can still enforce rules like "if present, must not be empty" or "if present, must match pattern".

For every column in "schema":

1. COMPLETENESS
   IMPORTANT:
    - If nullable = true, you MUST NOT create a not_null rule.
    - Nullable columns must allow NULL in all spark_exp expressions.
    - Never generate contradictory rules for any column. If nullable = true:
        * Do NOT generate not_null
        * Do NOT generate conditions that fail for NULL

   - If nullable = false → always generate a not_null rule.
   - For ALL string-like columns (string, varchar, char), regardless of nullable:
       * Always generate a not_empty-style rule:
         - If nullable = true: use logic "value IS NULL OR trimmed length > 0".
         - If nullable = false: use logic "value IS NOT NULL AND trimmed length > 0".

2. VALIDITY
   Use the combination of:
   - column name
   - data type 
   - column_stats (only non-PII)
   - constraints (CHECK, PK, UNIQUE)
   to infer validity rules such as:
     * numeric columns >= 0 unless obviously not applicable
     * string columns with stable lengths → infer min_length / max_length or regex
     * year/date columns must not be in the future
     * codes (country_code, dep_code) must be in allowed_values if low-cardinality

3. RANGE RULES (NON-PII ONLY)
   Use column_stats[min_val, max_val, distinct_count, null_pct].
   Create soft WARNING rules with a 20–25% buffer around min/max or p95 if present.

4. ALLOWED VALUES (NON-PII ONLY)
   If distinct_count is small (< 50) AND stable → generate allowed_values.

5. PII COLUMNS
   - PII columns appear in schema but have NO column_stats.
   - For these columns you MUST generate:
       * not_null (if nullable = false)
       * not_empty for strings
       * regex or fixed-length patterns inferred ONLY from schema + column name
   - NEVER include example email or SSN values. Only describe patterns.

6. CROSS-COLUMN LOGIC (IF OBVIOUS)
   If year/date columns exist → ensure year <= current year.
   If ID + email exist → email should not be null if ID exists.
   If joining_date and resign_date exist → resign_date >= joining_date.

7. TABLE-LEVEL RULES
   - If constraints include primary key → include uniqueness rule.
   - Add table-level rule: row_count > 0.

8. Data Quality (pre-load PySpark)
   For each rule, you MUST output "spark_exp" using **Spark SQL syntax only**, not PySpark API.
    spark_exp MUST be a SQL expression that can be passed directly into:

    df.filter(expr(spark_exp))
    Examples of valid spark_exp:
      "salary >= 0"
      "salary IS NULL OR salary >= 0"
      "name IS NULL OR length(trim(name)) > 0"
      "joining_date <= current_date()"

    Examples of INVALID spark_exp (do NOT generate these):
      col('salary') >= 0
      F.col("name").isNull()
      dataframe.count() > 0
      salary.notNull()

9. TEST COVERAGE (Post-load SQL)
   You must generate SQL tests for:
       * uniqueness of PK/grain
       * null checks on required columns
       * each CHECK constraint
       * future-date violations
       * allowed_values validation (for low-cardinality columns)
       * numeric range violations

After generating rules and tests, REVIEW:
- Did you include ALL columns?
- Did you cover ALL non-nullable columns?
- Did you enforce ALL constraints?
- Did you create BOTH rules AND tests?


----------------------------------------
OUTPUT FORMAT (MUST BE VALID JSON)
----------------------------------------
Important: In the final JSON, the set of column names used in data_quality.rules (excluding "__TABLE__") MUST match the set of column names in "schema" (case-insensitive). Do not omit any columns.
           Coverage requirement does NOT override the schema.
           If a column is nullable, you may generate "if present, must..." rules
           (e.g., not_empty, min/max with NULL allowed), 
           but you MUST NOT force mandatory constraints such as not_null.

Return ONLY valid JSON in this exact structure (no extra comments):

{{
  "data_quality": {{
    "rules": [
      {{
        "column": "col_name_or__TABLE__for_table_level",
        "rule_type": "not_null | not_empty | min | max | allowed_values | regex | pk | fk | check_constraint | custom_sql",
        "condition": "value / list / SQL expression / description string",
        "severity": "ERROR | WARNING",
        "action": "FAIL_JOB | DROP_ROW | WARN",
        "description": "Short reasoning for the rule (no concrete example values).",
        "spark_exp": "A PySpark boolean expression that returns TRUE for valid rows and can be passed directly to pyspark.sql.functions.expr(). 
                      It MUST NOT reference any DataFrame variable and MUST NOT call actions like count(), groupBy(), collect(), etc. 
                      Examples: \"col('salary') >= 0\", \"(col('name').isNull()) OR (length(trim(col('name'))) > 0)\"."

      }}
    ]
  }},
  "tests": [
    {{
      "name": "test_name",
      "sql": "SELECT ...",
      "description": "What this test validates."
    }}
  ],
  "docs_markdown": "# Table Documentation\\n..."
}}

Do NOT include anything outside this JSON object.
"""


In [27]:
client = genai.Client()
response = client.models.generate_content(model='gemini-2.5-pro', contents=payload)
print(response.text)

```json
{
  "data_quality": {
    "rules": [
      {
        "column": "emp_id",
        "rule_type": "not_null",
        "condition": "must not be null",
        "severity": "ERROR",
        "action": "FAIL_JOB",
        "description": "Employee ID is a required field and serves as the primary key.",
        "spark_exp": "emp_id IS NOT NULL"
      },
      {
        "column": "emp_id",
        "rule_type": "pk",
        "condition": "primary_key",
        "severity": "ERROR",
        "action": "FAIL_JOB",
        "description": "Employee ID must be unique across all records.",
        "spark_exp": "true"
      },
      {
        "column": "emp_id",
        "rule_type": "min",
        "condition": "0",
        "severity": "ERROR",
        "action": "DROP_ROW",
        "description": "Employee ID must be a positive integer.",
        "spark_exp": "emp_id > 0"
      },
      {
        "column": "name",
        "rule_type": "not_empty",
        "condition": "if present, must not be empty 

# Push to S3

In [43]:
contract_json_path = "s3://de-copilot-s3/contracts/"
s3_client = boto3.client('s3',region_name=region)

In [45]:
clean_json_str = re.sub(r"```json\n|\n```", "", response.text).strip()

In [49]:
cleaned_json = json.loads(clean_json_str)

In [53]:
s3_client.put_object(Bucket='de-copilot-s3',Key='contracts/employees_test.json',Body=json.dumps(cleaned_json,indent=2),ContentType='application/json')

{'ResponseMetadata': {'RequestId': '0FVC1RT15N5N1XCK',
  'HostId': 'K0z10dwJhJ4cRdAzTmfpP5r2rfu+dcKHd3leppuDN9vc42FXWify2x3rkkaescimpB5dCXRHL2WTm9pHmmOozw==',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amz-id-2': 'K0z10dwJhJ4cRdAzTmfpP5r2rfu+dcKHd3leppuDN9vc42FXWify2x3rkkaescimpB5dCXRHL2WTm9pHmmOozw==',
   'x-amz-request-id': '0FVC1RT15N5N1XCK',
   'date': 'Sat, 22 Nov 2025 00:17:15 GMT',
   'x-amz-server-side-encryption': 'AES256',
   'etag': '"7486daece0e956f31450cf2f6c9ddc63"',
   'x-amz-checksum-crc32': 'W5rfaQ==',
   'x-amz-checksum-type': 'FULL_OBJECT',
   'content-length': '0',
   'server': 'AmazonS3'},
  'RetryAttempts': 0},
 'ETag': '"7486daece0e956f31450cf2f6c9ddc63"',
 'ChecksumCRC32': 'W5rfaQ==',
 'ChecksumType': 'FULL_OBJECT',
 'ServerSideEncryption': 'AES256'}