# Pre requirements

In [1]:
!pip install pyspark


Collecting pyspark
  Downloading pyspark-3.5.0.tar.gz (316.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.9/316.9 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.0-py2.py3-none-any.whl size=317425345 sha256=7c0296f223dd8deadfd89548cbb98ac29b05dd33c81cc9ec6f36fe0247d70e70
  Stored in directory: /root/.cache/pip/wheels/41/4e/10/c2cf2467f71c678cfc8a6b9ac9241e5e44a01940da8fbb17fc
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.0


In [None]:
# Config.yml
database:
  name: 'my_database'
  connection:
    host: 'localhost'
    port: 5432
    username: 'user'
    password: 'password'
mandatory_columns:
  - age
  - sex
schema:
  - {'field_name': 'id', 'field_type': 'StringType'}
  - {'field_name': 'name', 'field_type': 'StringType'}
  - {'field_name': 'age', 'field_type': 'IntegerType'}

# Validators.py

In [10]:
# your_package/validators.py
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

def validate_negative_values(df: DataFrame, column_name: str) -> None:
    validation_result = df.filter(col(column_name) < 0)

    if validation_result.count() > 0:
        failed_rows = validation_result.select(column_name).collect()
        error_message = {
            'validator': 'NegativeValuesValidator',
            'column': column_name,
            'failed_values': [row[column_name] for row in failed_rows]
        }
        print(f'Validation Error: {error_message}')
    else:
        print(f'Validation Passed: No negative values in {column_name}.')


# Transformers.py

In [11]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

def transform_square_column(df: DataFrame, column_name: str) -> DataFrame:
    return df.withColumn(column_name, col(column_name) ** 2)

## Check date format

In [12]:
# Check if date have correct format
def check_date_format(df, column_to_check, date_format):

  # Convert date column to string
  formatted_date_column = date_format(df[column_to_check], date_format).cast(StringType())

  # Conditional check if format of column is the same as the one specified in date_format
  is_same_format_condition = when(formatted_date_column != df[column_to_check], df[column_to_check])

  # Filter df with correct/incorrect formats
  correct_format_df = df.filter(is_same_format_condition.isNull())
  incorrect_format_df = df.filter(is_same_format_condition.isNotNull())

    # Change return accordingly
  return correct_format_df, incorrect_format_df

In [37]:
correct_format_df, incorrect_format_df = check_date_format(df, 'date', 'yyyy-MM-dd')
correct_format_df.show()
incorrect_format_df.show()

NameError: ignored

## Check no duplicates

In [14]:
# Check if a column doesn't have duplicates
def check_no_duplicates(df, column_to_check):

    # Create a Window partition for all the unique values of the column to check
    window_partition = Window().partitionBy(column_to_check)

    # Add column with frequency per each unique value
    df_with_frequency = df.withColumn('Frequency', count(column_to_check).over(window_partition))

    # Filter rows where the count is greater than 1 (indicating duplicates)
    df_wth_duplicates = df_with_frequency.filter(col('Frequency') > 1)
    df_no_duplicates = df_with_frequency.filter(col('Frequency') == 1)
    # Change return accordingly
    return df_wth_duplicates, df_no_duplicates


In [38]:
df_wth_duplicates, df_no_duplicates = check_no_duplicates(df, 'province')
df_wth_duplicates.show()
df_no_duplicates.show()

NameError: ignored

## Check if null values

In [16]:
def check_null_values(df, column_to_check):
    # Filter df with/without nulls
  df_with_null = df.filter(df[column_to_check].isNull())
  df_without_null = df.filter(df[column_to_check].isNotNull())
  return df_with_null, df_without_null


In [17]:
df_wth_duplicates = df_wth_duplicates.withColumn('continent_name', f.when(f.col('continent_name') == 'North America', None).otherwise(f.col('continent_name')))
df_wth_duplicates.show(1)
df_with_null, df_without_null = check_null_values(df_wth_duplicates, 'continent_name')
df_with_null.show(1)
df_without_null.show(1)

NameError: ignored

## Check column type

In [18]:
# Create spark session
def check_column_types(df, schema_json):
  # Create empty list for mismatched columns
  mismatched_columns = []

  # Check each field's type
  for field in schema_json:
    field_name = field['field_name']
    expected_type = field['field_type']

    # Get the actual type of the column
    actual_type = type(df.schema[field_name].dataType).__name__

    # Check if actual type matches expected type
    if actual_type != expected_type:
      # String with summary of the column
      query = f'column {field_name} expected {expected_type}, but instead is {actual_type}'
      mismatched_columns.append(query)
      # To delete
      print(query)
  if mismatched_columns:
    print('mismatched columns saved in JSON (FIX THIS)')
    # Save output in json file or log.
  else:
    print('All column types match schema.')




In [39]:
# get schema_json from config file
spark = SparkSession.builder.appName('DataValidation').getOrCreate()
data_validator = DataValidator(spark)
config = data_validator.load_config('/content/config.yml')
schema_json = config.get('schema', [])
schema_json

[{'field_name': 'id', 'field_type': 'StringType'},
 {'field_name': 'name', 'field_type': 'StringType'},
 {'field_name': 'age', 'field_type': 'IntegerType'}]

In [40]:
check_column_types(source_df, schema_json)

NameError: ignored

## Compare Schemas

In [41]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType

def compare_schemas(spark, source_df, target_df):
    # Get source/target schema
    source_schema = source_df.schema
    target_schema = target_df.schema

    if source_schema == target_schema:
        print('Schema for both dataframes match.')
    else:
        print('Schema do not match, please check details below')

        # Get dictionary with field_name: field_type
        source_columns = {field.name: field.dataType for field in source_schema.fields}
        target_columns = {field.name: field.dataType for field in target_schema.fields}

        # Get mismatched fields
        mismatched_columns = set(source_columns.keys()) ^ set(target_columns.keys())

        for column in mismatched_columns:
            source_data_type = source_columns.get(column, 'not present in Source')
            target_data_type = target_columns.get(column, 'not present in Target')
            print(f'Column: {column}, Source DataType: {source_data_type}, Target DataType: {target_data_type}')




In [42]:
# Create a Spark session
spark = SparkSession.builder.appName('SchemaComparisonTest').getOrCreate()

# Define schemas for source and target DataFrames
source_schema = StructType([
    StructField('id', IntegerType(), True),
    StructField('name', StringType(), True),
    StructField('age', IntegerType(), True)
])

target_schema = StructType([
    StructField('id', IntegerType(), True),
    StructField('name', StringType(), True),
    StructField('score', FloatType(), True)
])

# Sample data for source and target DataFrames
source_data = [(1, 'Alice', 25), (2, 'Bob', 30), (3, 'Charlie', 22)]
target_data = [(1, 'Alice', 95.1), (2, 'Bob', 80.1), (3, 'Charlie', 75.1)]

# Create source and target DataFrames
source_df = spark.createDataFrame(source_data, schema=source_schema)
target_df = spark.createDataFrame(target_data, schema=target_schema)

# Print the content of source and target DataFrames
print('Source DataFrame:')
source_df.show()

print('Target DataFrame:')
target_df.show()

# Compare Schemas
compare_schemas(spark, source_df, target_df)
# Stop the Spark session
#spark.stop()

Source DataFrame:
+---+-------+---+
| id|   name|age|
+---+-------+---+
|  1|  Alice| 25|
|  2|    Bob| 30|
|  3|Charlie| 22|
+---+-------+---+

Target DataFrame:
+---+-------+-----+
| id|   name|score|
+---+-------+-----+
|  1|  Alice| 95.1|
|  2|    Bob| 80.1|
|  3|Charlie| 75.1|
+---+-------+-----+

Schema do not match, please check details below
Column: age, Source DataType: IntegerType(), Target DataType: not present in Target
Column: score, Source DataType: not present in Source, Target DataType: FloatType()


## Check all required columns are in dataframe




In [23]:
# get columns from config file
spark = SparkSession.builder.appName('DataValidation').getOrCreate()
data_validator = DataValidator(spark)
config = data_validator.load_config('/content/config.yml')
mandatory_columns = config.get('mandatory_columns', [])

def check_mandatory_columns(df, mandatory_columns):
  if mandatory_columns == None:
    print('n')
  else:
    # Get all field_name from df as list
    fields = [field.name for field in df.schema.fields]

    # Check all mandatory columns are present in df
    missing_columns = set(mandatory_columns) - set(fields)

  if missing_columns:
      print(f'Missing columns: {', '.join(missing_columns)}')
  else:
      print('All mandatory columns are present in the dataframe.')






SyntaxError: ignored

In [24]:
check_mandatory_columns(target_df, mandatory_columns)

NameError: ignored

## Check column names are the same in Source & Target

In [25]:
# get columns from config file
spark = SparkSession.builder.appName('DataValidation').getOrCreate()
data_validator = DataValidator(spark)
config = data_validator.load_config('/content/config.yml')
mandatory_columns = config.get('mandatory_columns', [])

def check_sourceFields_equal_targetFields(source_df, target_df):
    # Get Source/Target fields
    source_field_names = [field.name for field in source_df.schema.fields]
    target_field_names = [field.name for field in target_df.schema.fields]

    # Check if different fields
    wrong_fields = set(source_field_names) - set(target_field_names)

    if wrong_fields:
        print(f'''Columns names are not the same.
Fields from Source not present in target:
{wrong_fields}
              ''')
    else:
        print('Columns names are the same in source and target dataframes.')






NameError: ignored

In [26]:
check_sourceFields_equal_targetFields(source_df, target_df)

NameError: ignored

## Check condition by sql query

does the constrains need to be in a table? can they be pre-defined in a config file?

In [27]:
def execute_sql_query_in_df(df,temp_table_name, query):
  # Create temp delta table
  df.createOrReplaceTempView(temp_table_name)

  # Execute sql query in delta table
  output = spark.sql(query)

  # Return whatever the output is
  return output


In [28]:
execute_sql_query_in_df(target_df,'temp_table_name', 'select count(*) from temp_table_name').show()

NameError: ignored

## Compare output of 2 sql queries in dfs

In [29]:
def compare_source_and_target_output_by_sql(source_df, target_df, source_query, target_query):
  # Run query for each dataframe
  source_output = execute_sql_query_in_df(source_df, 'source_df', source_query)
  print(source_output)
  target_output = execute_sql_query_in_df(target_df,'target_df', target_query)
  print(target_output)
  # Compare if source & Target output are equal.
  if source_output.exceptAll(target_output).count() == 0 and target_output.exceptAll(source_output).count() == 0:
    print('both output are the same.')
  else:
    print('Source & Target outputs are different.')

In [30]:
source_query = f'select * from source_df'
target_query = f'select * from source_df'
compare_source_and_target_output_by_sql(source_df, target_df, source_query, target_query)

NameError: ignored

# main.py

In [31]:
from pyspark.sql import SparkSession, DataFrame
import yaml
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, StringType, IntegerType


In [32]:
#validators import validate_negative_values
#from your_package.transformations import transform_square_column


In [33]:


class DataValidator:
    # Define init function
    def __init__(self, spark: SparkSession):
        self.spark = spark
        self.data_df = None

    # Get config file from local
    def load_config(self, file_path):
        with open(file_path, 'r') as config_file:
          config = yaml.safe_load(config_file)
        return config

    # Apply transformations to df
    def apply_transformations(self, transformations_config):
        transformed_df = self.data_df

        for transformation in transformations_config:
            transformation_name = transformation['name']
            transformation_params = transformation.get('params', {})

            if transformation_name == 'square_column':
                column_name = transformation_params.get('column')
                transformed_df = transform_square_column(transformed_df, column_name)

        self.data_df = transformed_df

In [34]:
spark = SparkSession.builder.appName('DataValidation').getOrCreate()


In [35]:
#spark_session.stop()

In [36]:
data_validator = DataValidator(spark)


In [43]:
config = data_validator.load_config('/content/config.yml')
config

{'app_settings': {'title': 'My App',
  'version': '1.0',
  'options2': ['feature1', 'feature2'],
  'options': {'feature1': True, 'feature2': False}},
 'database': {'name': 'my_database',
  'connection': {'host': 'localhost',
   'port': 5432,
   'username': 'user',
   'password': 'password'}}}

In [44]:
# Get dictionary from config
config['app_settings'].get('options', {})

{'feature1': True, 'feature2': False}

In [31]:
# Get list from config
config['app_settings'].get('options2', [])

'feature1'

In [15]:
# Define schemas for source and target DataFrames
source_schema = StructType([
    StructField('id', IntegerType(), True),
    StructField('name', StringType(), True),
    StructField('age', IntegerType(), True)
])

source_data = [(1, 'Alice', 25), (2, 'Bob', 30), (3, 'Charlie', 22)]
source_df = spark.createDataFrame(source_data, schema=source_schema)
source_df.show()

+---+-------+---+
| id|   name|age|
+---+-------+---+
|  1|  Alice| 25|
|  2|    Bob| 30|
|  3|Charlie| 22|
+---+-------+---+



In [None]:
data_validator.