# Data Quality Mini-Workshop
---


- Simple mini workshop working with Scala directly in a Jupyter notebook using Spark
- Simple tests and code to check for data quality

**To use Scala Spark directly, we use `spylon-kernel`, which can be installed into your Python environment:** 
```
pip install spylon-kernel
python -m spylon_kernel install --user
```
Then launch Jupyter and select the spylon-kernel

In [1]:
// Evaluate the kernel-provided `spark` value so the session object is echoed back
spark

Intitializing Scala interpreter ...

Spark Web UI available at http://192.168.35.161:4041
SparkContext available as 'sc' (version = 3.4.1, master = local[*], app id = local-1697816975153)
SparkSession available as 'spark'


res0: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@c4445ff


In [16]:
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._


In [85]:
// Schema of the sample records used by every check below.
// All seven columns are declared nullable.
val schema = new StructType()
    .add("id", IntegerType, nullable = true)
    .add("name", StringType, nullable = true)
    .add("age", IntegerType, nullable = true)
    .add("salary", IntegerType, nullable = true)
    .add("bonus", IntegerType, nullable = true)
    .add("compTotal", IntegerType, nullable = true)
    .add("registered", TimestampType, nullable = true)

schema: org.apache.spark.sql.types.StructType = StructType(StructField(id,IntegerType,true),StructField(name,StringType,true),StructField(age,IntegerType,true),StructField(salary,IntegerType,true),StructField(bonus,IntegerType,true),StructField(compTotal,IntegerType,true),StructField(registered,TimestampType,true))


In [88]:
// Sample rows as an RDD[Row], in schema column order:
// id, name, age, salary, bonus, compTotal, registered.
val timestamp = new java.sql.Timestamp(System.currentTimeMillis())
val rdd = {
    // Fully populated records; each one is stamped with the same `timestamp`.
    // id 5 appears twice on purpose (repeated key for the uniqueness/cardinality checks).
    val completeRows = Seq(
        (1, "John", 21, 1000, 500, 1500),
        (2, "Smith", 23, 2000, 200, 2200),
        (3, "Mary", 22, 3000, 1000, 4000),
        (4, "Jane", 24, 4000, 0, 5000),
        (5, "Joe", 25, 5000, 100, 5100),
        (5, "Joe", 25, 5000, 300, 5500)
    ).map { case (id, name, age, salary, bonus, compTotal) =>
        Row(id, name, age, salary, bonus, compTotal, timestamp)
    }

    // Records with missing values, used by the null check.
    val sparseRows = Seq(
        Row(7, "Adam", null, null, null, null, null),
        Row(8, null, null, null, null, null, null)
    )

    spark.sparkContext.parallelize(completeRows ++ sparseRows)
}

timestamp: java.sql.Timestamp = 2023-10-21 08:35:35.789
rdd: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = ParallelCollectionRDD[249] at parallelize at <console>:40


In [89]:
// Pair the raw rows with the explicit schema to obtain the working DataFrame,
// then print it (20 rows max, long cells truncated -- the same as a bare show()).
val df = spark.createDataFrame(rowRDD = rdd, schema = schema)
df.show(numRows = 20, truncate = true)

+---+-----+----+------+-----+---------+--------------------+
| id| name| age|salary|bonus|compTotal|          registered|
+---+-----+----+------+-----+---------+--------------------+
|  1| John|  21|  1000|  500|     1500|2023-10-21 08:35:...|
|  2|Smith|  23|  2000|  200|     2200|2023-10-21 08:35:...|
|  3| Mary|  22|  3000| 1000|     4000|2023-10-21 08:35:...|
|  4| Jane|  24|  4000|    0|     5000|2023-10-21 08:35:...|
|  5|  Joe|  25|  5000|  100|     5100|2023-10-21 08:35:...|
|  5|  Joe|  25|  5000|  300|     5500|2023-10-21 08:35:...|
|  7| Adam|null|  null| null|     null|                null|
|  8| null|null|  null| null|     null|                null|
+---+-----+----+------+-----+---------+--------------------+



df: org.apache.spark.sql.DataFrame = [id: int, name: string ... 5 more fields]


### Null Value Check

In [42]:
// Count missing values per column.
// isnan() only accepts float/double input, so it is applied to those column
// types only. Applying it to every column fails analysis on the TimestampType
// column `registered`, and silently coerces string columns to double.
val nullCheck = df.select(df.schema.fields.map { field =>
    val c = col(field.name)
    val isMissing = field.dataType match {
        case DoubleType | FloatType => c.isNull || isnan(c)
        case _ => c.isNull
    }
    // One output column per input column, holding the number of missing values
    sum(when(isMissing, 1).otherwise(0)).alias(field.name)
}: _*)

// Python
// null_check = df.select([
//     count(when((col(f.name).isNull() | isnan(f.name)) if isinstance(f.dataType, (DoubleType, FloatType)) else col(f.name).isNull(), f.name)).alias(f.name)
//     for f in df.schema.fields
// ])

nullCheck.show()

+---+----+---+------+
| id|name|age|salary|
+---+----+---+------+
|  0|   1|  2|     2|
+---+----+---+------+



nullCheck: org.apache.spark.sql.DataFrame = [id: bigint, name: bigint ... 2 more fields]


### Duplicate Check

In [45]:
// Count the rows that belong to a group of fully identical rows.
// Rows are grouped on every column; groups larger than one are duplicates.
// sum() over zero groups yields NULL, so coalesce it to 0 to report
// "no duplicates" as a number rather than as null.
val duplicates = df.groupBy(df.columns.map(col): _*)
    .count()
    .filter(col("count") > 1)
    .agg(coalesce(sum("count"), lit(0L)).alias("duplicates"))

// Python
// duplicates = df.groupBy(df.columns).count().where(col("count") > 1).select(coalesce(sum("count"), lit(0)).alias("duplicates"))

duplicates.show()

+----------+
|duplicates|
+----------+
|         2|
+----------+



duplicates: org.apache.spark.sql.DataFrame = [duplicates: bigint]


### Unique Key Check

In [61]:
// List every `id` value that occurs on more than one row, together with its
// row count. An empty result means `id` is usable as a unique key.
val uniqueKeyCheck = df
    .groupBy(col("id"))
    .agg(count(lit(1)).alias("count"))
    .filter(col("count").gt(1))

// Python
// unique_key_check = df.groupBy("id").count().where(col("count") > 1)

uniqueKeyCheck.show()

+---+-----+
| id|count|
+---+-----+
|  5|    2|
+---+-----+



uniqueKeyCheck: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, count: bigint]


### Data Type Check

In [64]:
// Check that the values of each column conform to the expected data type.
// Maps column name -> type the column is expected to hold.
val dataTypeCheck = Map(
    "id" -> IntegerType,
    "name" -> StringType,
    "age" -> IntegerType,
    "salary" -> IntegerType
)
dataTypeCheck.foreach { case (column, dataType) =>
    val original = col(column)
    val casted = original.cast(dataType)
    // A value that cannot be cast becomes NULL, and `NULL =!= x` evaluates to
    // NULL, which filter() drops. Such rows are therefore matched explicitly:
    // the value is present but its cast is NULL. The inequality is kept to
    // catch values the cast alters (e.g. 1.5 -> 1).
    val mismatches = df.filter((original.isNotNull && casted.isNull) || (casted =!= original))
        .count()
    println(s"Mismatches in $column: $mismatches")
}

// Python
// data_type_check = {
//     "id": IntegerType,
//     "name": StringType,
//     "age": IntegerType,
//     "salary": IntegerType
// }
// for column, data_type in data_type_check.items():
//     mismatches = df.filter(col(column).cast(data_type) != col(column)).count()

Mismatches in id: 0
Mismatches in name: 0
Mismatches in age: 0
Mismatches in salary: 0


dataTypeCheck: scala.collection.immutable.Map[String,org.apache.spark.sql.types.AtomicType with Product with Serializable] = Map(id -> IntegerType, name -> StringType, age -> IntegerType, salary -> IntegerType)


### Value Range Validation

In [71]:
// Flag rows whose `age` lies outside the inclusive range [lowerBound, upperBound].
// Rows with a NULL age evaluate to NULL and are therefore not reported here.
val lowerBound = 20
val upperBound = 23
val rangeCheck = {
    val ageInRange = col("age").between(lowerBound, upperBound)
    df.where(!ageInRange)
}

// Python
// lower_bound = 20
// upper_bound = 23
// range_check = df.filter((col("age") < lower_bound) | (col("age") > upper_bound))

rangeCheck.show()

+---+----+---+------+
| id|name|age|salary|
+---+----+---+------+
|  4|Jane| 24|  4000|
|  5| Joe| 25|  5000|
|  5| Joe| 25|  5000|
+---+----+---+------+



lowerBound: Int = 20
upperBound: Int = 23
rangeCheck: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, name: string ... 2 more fields]


### Categorical Value Validation

In [73]:
// Flag rows whose `name` is not one of the allowed values.
// Rows with a NULL name evaluate to NULL and are therefore not reported here.
val validCategories = Seq("Mary", "John", "Relu")
val categoryCheck = {
    val nameIsAllowed = col("name").isin(validCategories: _*)
    df.where(not(nameIsAllowed))
}

// Python
// valid_categories = ["Tanh", "Sigmoid", "Relu"]
// category_check = df.filter(~col("name").isin(valid_categories))

categoryCheck.show()

+---+-----+----+------+
| id| name| age|salary|
+---+-----+----+------+
|  2|Smith|  23|  2000|
|  4| Jane|  24|  4000|
|  5|  Joe|  25|  5000|
|  5|  Joe|  25|  5000|
|  7| Adam|null|  null|
+---+-----+----+------+



validCategories: Seq[String] = List(Mary, John, Relu)
categoryCheck: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, name: string ... 2 more fields]


### Date Range Check

In [80]:
// Flag rows whose `registered` timestamp falls outside the inclusive date
// range [lowerDate, upperDate].
// The comparison is done at day granularity. Comparing the timestamp with the
// string bound directly coerces the bound to midnight, which wrongly flags
// every time on the upper-bound day after 00:00:00 as out of range.
// Rows with a NULL `registered` evaluate to NULL and are not reported here.
val lowerDate = "2019-01-01"
val upperDate = "2022-12-31"
val dateCheck = {
    val registeredDay = to_date(col("registered"))
    df.filter((registeredDay < to_date(lit(lowerDate))) || (registeredDay > to_date(lit(upperDate))))
}

// Python
// lower_date = "2019-01-01"
// upper_date = "2022-12-31"
// registered_day = to_date(col("registered"))
// date_check = df.filter((registered_day < to_date(lit(lower_date))) | (registered_day > to_date(lit(upper_date))))

dateCheck.show()

+---+-----+---+------+--------------------+
| id| name|age|salary|          registered|
+---+-----+---+------+--------------------+
|  1| John| 21|  1000|2023-10-21 08:29:...|
|  2|Smith| 23|  2000|2023-10-21 08:29:...|
|  3| Mary| 22|  3000|2023-10-21 08:29:...|
|  4| Jane| 24|  4000|2023-10-21 08:29:...|
|  5|  Joe| 25|  5000|2023-10-21 08:29:...|
|  5|  Joe| 25|  5000|2023-10-21 08:29:...|
+---+-----+---+------+--------------------+



lowerDate: String = 2019-01-01
upperDate: String = 2022-12-31
dateCheck: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, name: string ... 3 more fields]


### Consistency Check

In [90]:
// Cross-column consistency: `compTotal` is expected to equal `salary` + `bonus`.
// Rows where any of the three values is NULL evaluate to NULL and are not reported.
val consistencyCheck = {
    val expectedTotal = col("salary") + col("bonus")
    df.where(not(expectedTotal === col("compTotal")))
}

// Python
// consistency_check = df.filter(col("salary") + col("bonus") != col("compTotal"))

consistencyCheck.show()

+---+----+---+------+-----+---------+--------------------+
| id|name|age|salary|bonus|compTotal|          registered|
+---+----+---+------+-----+---------+--------------------+
|  4|Jane| 24|  4000|    0|     5000|2023-10-21 08:35:...|
|  5| Joe| 25|  5000|  300|     5500|2023-10-21 08:35:...|
+---+----+---+------+-----+---------+--------------------+



consistencyCheck: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, name: string ... 5 more fields]


### Cardinality Check

In [91]:
// List the distinct values present in `salary` (NULL counts as one value)
// to get a feel for how varied the column is.
val cardinalityCheck = df
    .select(col("salary"))
    .dropDuplicates()

// Python
// cardinality_check = df.select("salary").distinct()

cardinalityCheck.show()

+------+
|salary|
+------+
|  1000|
|  2000|
|  3000|
|  4000|
|  5000|
|  null|
+------+



cardinalityCheck: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [salary: int]


### Anomaly Detection

In [101]:
// Identify statistical anomalies or outliers in numeric columns.
// A value is an anomaly when it lies more than three sample standard
// deviations away from its column mean.
val numericColumns = Seq("age", "salary", "bonus", "compTotal")

// One avg and one stddev expression per column, aliased <column>_mean / <column>_stddev
val statExprs = numericColumns.flatMap(c =>
    Seq(avg(col(c)).alias(s"${c}_mean"), stddev(col(c)).alias(s"${c}_stddev"))
)

// A global aggregation always yields exactly one row
val dfStats = df.select(statExprs: _*).head()

// Iterate through each numeric column and check for anomalies
numericColumns.foreach { column =>
    val meanIdx = dfStats.fieldIndex(s"${column}_mean")
    val sigmaIdx = dfStats.fieldIndex(s"${column}_stddev")

    // A statistic can be NULL (e.g. no non-null values in the column).
    // getAs[Double] would silently unbox that NULL to 0.0, so check with
    // isNullAt first and report zero anomalies when no threshold can be built.
    val anomalyCheck =
        if (dfStats.isNullAt(meanIdx) || dfStats.isNullAt(sigmaIdx)) 0L
        else {
            val mean = dfStats.getDouble(meanIdx)
            // Named `sigma` so the imported Spark `stddev` function is not shadowed
            val sigma = dfStats.getDouble(sigmaIdx)
            val lowerLimit = mean - 3 * sigma
            val upperLimit = mean + 3 * sigma
            df.filter((col(column) < lowerLimit) || (col(column) > upperLimit)).count()
        }

    println(s"Number of anomalies in $column: $anomalyCheck")
}

// Python
// numeric_columns = ["age", "salary", "bonus", "compTotal"]

// stat_exprs = []
// for col in numeric_columns:
//     stat_exprs.append(F.avg(F.col(col)).alias(f"{col}_mean"))
//     stat_exprs.append(F.stddev(F.col(col)).alias(f"{col}_stddev"))

// df_stats = df.select(*stat_exprs).collect()[0]

// for col in numeric_columns:
//     mean = df_stats[f"{col}_mean"]
//     stddev = df_stats[f"{col}_stddev"]
    
//     anomaly_check = df.filter((F.col(col) < (mean - 3 * stddev)) | (F.col(col) > (mean + 3 * stddev))).count()

//     print(f"Number of anomalies in {col}: {anomaly_check}")


Number of anomalies in age: 0
Number of anomalies in salary: 0
Number of anomalies in bonus: 0
Number of anomalies in compTotal: 0


numericColumns: Seq[String] = List(age, salary, bonus, compTotal)
statExprs: Seq[org.apache.spark.sql.Column] = List(avg(age) AS age_mean, stddev_samp(age) AS age_stddev, avg(salary) AS salary_mean, stddev_samp(salary) AS salary_stddev, avg(bonus) AS bonus_mean, stddev_samp(bonus) AS bonus_stddev, avg(compTotal) AS compTotal_mean, stddev_samp(compTotal) AS compTotal_stddev)
dfStats: org.apache.spark.sql.Row = [23.333333333333332,1.632993161855452,3333.3333333333335,1632.993161855452,350.0,361.93922141707714,3883.3333333333335,1665.4328766620006]
