In [7]:
import boto3
import json
import pyspark
import copy
from pyspark.sql import functions as f 
from pyspark.sql import SparkSession
from pyspark.sql.functions import expr
# Create the Spark session for this notebook, or reuse one that is already running.
spark = (
    SparkSession.builder
    .appName('de-copilot')
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/11/21 17:58:39 WARN Utils: Your hostname, LAPTOP-E66VD905, resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/11/21 17:58:39 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/21 17:58:40 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# S3 client bound to us-east-1; no credentials are passed explicitly here.
s3_client = boto3.client(
    's3',
    region_name='us-east-1',
)

In [3]:
# Download the contract JSON from S3 and parse it in one step, so that
# `contracts` is only ever bound to the parsed document.
contracts = json.loads(
    s3_client.get_object(
        Bucket='de-copilot-s3',
        Key='contracts/employees_test.json',
    )['Body'].read().decode('utf-8')
)

In [4]:
all_rules = contracts.get('data_quality', {}).get('rules', [])

# Keep a rule only when it has a Spark expression and is none of:
#   - a table-level rule (column '__TABLE__', compared case-insensitively),
#   - a 'pk' / 'fk' rule,
#   - a rule whose expression is just the literal 'true'.
final_rules = [
    candidate
    for candidate in all_rules
    if candidate.get('spark_exp')
    and candidate.get('column', '').upper() != '__TABLE__'
    and candidate.get('rule_type', '') not in ('pk', 'fk')
    and candidate.get('spark_exp').lower() != 'true'
]

In [5]:
# Display the rules that survived the filtering step.
final_rules

[{'column': 'emp_id',
  'rule_type': 'not_null',
  'condition': 'must not be null',
  'severity': 'ERROR',
  'action': 'FAIL_JOB',
  'description': 'Employee ID is a required field and serves as the primary key.',
  'spark_exp': 'emp_id IS NOT NULL'},
 {'column': 'emp_id',
  'rule_type': 'min',
  'condition': '0',
  'severity': 'ERROR',
  'action': 'DROP_ROW',
  'description': 'Employee ID must be a positive integer.',
  'spark_exp': 'emp_id > 0'},
 {'column': 'name',
  'rule_type': 'not_empty',
  'condition': 'if present, must not be empty or whitespace',
  'action': 'WARN',
  'description': 'Employee name, if provided, must contain non-whitespace characters.',
  'spark_exp': 'name IS NULL OR length(trim(name)) > 0'},
 {'column': 'salary',
  'rule_type': 'min',
  'condition': '0',
  'action': 'WARN',
  'description': 'Salary, if provided, must be a non-negative value.',
  'spark_exp': 'salary IS NULL OR salary >= 0'},
 {'column': 'department',
  'rule_type': 'not_empty',
  'condition'

In [40]:
# Build one CASE WHEN column per rule:
#   errors          -> the rule's description when the row violates it, else NULL
#   severity_checks -> the rule's severity when the row violates it, else NULL
# Both lists are later packed into arrays and stripped of NULLs.
errors = []
severity_checks = []
for rule in final_rules:
    spark_exp = rule['spark_exp']
    # A NULL message would be removed by the downstream null filter and hide the
    # failure, so fall back to the expression when no description is given.
    error_msg = rule.get('description') or f'Failed check: {spark_exp}'
    # 'severity' is optional in the contract (some rules only carry 'action'),
    # so it must not be read with rule['severity'].
    severity = rule.get('severity')

    # True when the row violates the rule. If the expression evaluates to NULL,
    # NOT (NULL) is NULL and the CASE WHEN does not match, so the row is not flagged.
    rule_violated = expr(f'NOT ({spark_exp})')

    errors.append(f.when(rule_violated, f.lit(error_msg)).otherwise(f.lit(None)))

    # Rules without a severity contribute no severity column.
    if severity:
        severity_checks.append(
            f.when(rule_violated, f.lit(severity)).otherwise(f.lit(None))
        )

In [41]:
# Display the generated per-rule error-message columns.
errors

[Column<'CASE WHEN NOT (emp_id IS NOT NULL) THEN 'Employee ID is a required field and serves as the primary key.' ELSE NULL END'>,
 Column<'CASE WHEN NOT (emp_id > 0) THEN 'Employee ID must be a positive integer.' ELSE NULL END'>,
 Column<'CASE WHEN NOT (name IS NULL OR length(trim(name)) > 0) THEN 'Employee name, if provided, must contain non-whitespace characters.' ELSE NULL END'>,
 Column<'CASE WHEN NOT (salary IS NULL OR salary >= 0) THEN 'Salary, if provided, must be a non-negative value.' ELSE NULL END'>,
 Column<'CASE WHEN NOT (department IS NULL OR length(trim(department)) > 0) THEN 'Department name, if provided, must contain non-whitespace characters.' ELSE NULL END'>,
 Column<'CASE WHEN NOT (joining_date IS NULL OR joining_date <= current_date()) THEN 'The joining date cannot be in the future.' ELSE NULL END'>]

In [42]:
# Display the generated per-rule severity columns.
severity_checks

[Column<'CASE WHEN NOT (emp_id IS NOT NULL) THEN 'ERROR' ELSE NULL END'>,
 Column<'CASE WHEN NOT (emp_id > 0) THEN 'ERROR' ELSE NULL END'>,
 Column<'CASE WHEN NOT (joining_date IS NULL OR joining_date <= current_date()) THEN 'ERROR' ELSE NULL END'>]

# Test

In [48]:
# Column names, in the same order as the fields of each tuple in `data`.
columns = ["emp_id", "name", "salary", "department", "joining_date"]

# Sample rows: 101 is clean, 102 has a negative salary, 103 has no name.
data = [
    (101, "Alice", 50000.0, "IT", "2023-01-01"),
    (102, "Bob", -100.0, "HR", "2023-01-01"),
    (103, None, 60000.0, "Sales", "2023-01-01"),
]

In [49]:
# Build the sample DataFrame; `columns` supplies the column names.
df = spark.createDataFrame(data, schema=columns)

In [50]:
# Evaluate every rule per row into two temporary arrays, strip their NULL
# entries into 'Reason' and 'dq_severity', then drop the temporaries.
df = (
    df
    .withColumn('all_errors', f.array(*errors))
    .withColumn('severity_check', f.array(*severity_checks))
    .withColumn('Reason', expr("filter(all_errors, x -> x is Not Null)"))
    .withColumn('dq_severity', expr("filter(severity_check, x -> x is Not Null)"))
    .drop('all_errors', 'severity_check')
)

In [51]:
# Print the rows without truncating long cell values.
df.show(truncate=False)

+------+-----+-------+----------+------------+----------------------------------------------------+-----------+
|emp_id|name |salary |department|joining_date|Reason                                              |dq_severity|
+------+-----+-------+----------+------------+----------------------------------------------------+-----------+
|101   |Alice|50000.0|IT        |2023-01-01  |[]                                                  |[]         |
|103   |NULL |60000.0|Sales     |2023-01-01  |[]                                                  |[]         |
+------+-----+-------+----------+------------+----------------------------------------------------+-----------+



In [52]:
# Split rows by whether any rule produced a failure message.
reason_count = f.size(f.col('Reason'))
df_invalid = df.where(reason_count > 0)
# NOTE(review): 'dq_severity' is kept on df_valid while 'Reason' is dropped — confirm this is intended.
df_valid = df.where(reason_count == 0).drop('Reason')

In [53]:
# Print the rows that failed at least one rule, without truncating cell values.
df_invalid.show(truncate=False)

+------+----+------+----------+------------+----------------------------------------------------+-----------+
|emp_id|name|salary|department|joining_date|Reason                                              |dq_severity|
+------+----+------+----------+------------+----------------------------------------------------+-----------+
+------+----+------+----------+------------+----------------------------------------------------+-----------+

