In [9]:
import ipytest
# Apply ipytest's default configuration; the test cells in this notebook rely on
# the %%ipytest cell magic being available after this call.
ipytest.autoconfig()


In [10]:
import json

from pyspark.sql import SparkSession
from src.utilities import remove_duplicates, fill_nulls, flatten_json


# One local-mode session shared by every test cell below; getOrCreate() hands
# back the already-running session when the cell is executed again.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Test")
    .getOrCreate()
)

In [4]:
%%ipytest -qq
def test_remove_duplicates():
    """remove_duplicates should leave one row per distinct (name, city) pair."""
    rows = [("Alice", "NY"), ("Alice", "NY"), ("Bob", "LA")]
    schema = ["name", "city"]
    people_df = spark.createDataFrame(rows, schema)

    deduped_df = remove_duplicates(people_df)

    # The input holds two distinct rows: ("Alice", "NY") and ("Bob", "LA").
    assert deduped_df.count() == 2

[Stage 0:>                                                        (0 + 12) / 12]

[32m.[0m[32m                                                                                            [100%][0m


                                                                                

In [5]:
%%ipytest -qq
def test_fill_nulls():
    """fill_nulls should replace nulls per column and leave existing values alone.

    Each input row has exactly one null, in a different column, so both
    replacement values from the mapping are exercised.
    """
    df = spark.createDataFrame([
        (None, "NY"), ("Bob", None)
    ], ["name", "city"])

    result = fill_nulls(df, {"name": "NA", "city": "Unknown City"})

    # Compare whole rows, sorted, instead of indexing collect() by position:
    # this does not depend on the order Spark returns rows in, and it also
    # verifies that the non-null values ("NY", "Bob") were not overwritten.
    actual_rows = sorted((row["name"], row["city"]) for row in result.collect())
    assert actual_rows == [("Bob", "Unknown City"), ("NA", "NY")]


[32m.[0m[32m                                                                                            [100%][0m


In [20]:
%%ipytest -qq
def test_flatten_json_exploding_arrays():
    """flatten_json(explode_arrays=True) flattens structs and explodes arrays.

    The fixture has one nested struct ("address") and one array of structs
    ("phones") with two elements, so the expected output is two rows.
    """
    records = [{
        "id": 1,
        "name": "John",
        "address": {"city": "NY", "zipcode": 12345},
        "phones": [
            {"type": "home", "number": "1234"},
            {"type": "work", "number": "5678"}
        ]
    }]
    # spark.read.json expects an RDD of JSON *strings*. Handing it raw dicts
    # relies on their Python repr being accepted by Spark's lenient parser,
    # which stops working as soon as a value is None/True/False, so serialise
    # each record explicitly.
    json_rdd = spark.sparkContext.parallelize(
        [json.dumps(record) for record in records]
    )
    df = spark.read.json(json_rdd)

    result = flatten_json(df, explode_arrays=True)

    # Nested fields are expected as <parent>_<child> columns.
    expected_cols = {"id", "name", "address_city", "address_zipcode", "phones_type", "phones_number"}
    assert set(result.columns) == expected_cols
    # One output row per element of the "phones" array.
    assert result.count() == 2

[32m.[0m[32m                                                                                            [100%][0m


In [22]:
%%ipytest -qq
def test_flatten_json_without_exploding_arrays():
    """flatten_json(explode_arrays=False) flattens structs but keeps arrays intact.

    Uses the same fixture shape as the exploding test: one nested struct
    ("address") and one two-element array of structs ("phones"). With
    explode_arrays=False the array column must survive as-is and the row
    count must stay at one.
    """
    records = [{
        "id": 1,
        "name": "John",
        "address": {"city": "NY", "zipcode": 12345},
        "phones": [
            {"type": "home", "number": "1234"},
            {"type": "work", "number": "5678"}
        ]
    }]
    # spark.read.json expects an RDD of JSON *strings*. Handing it raw dicts
    # relies on their Python repr being accepted by Spark's lenient parser,
    # which stops working as soon as a value is None/True/False, so serialise
    # each record explicitly.
    json_rdd = spark.sparkContext.parallelize(
        [json.dumps(record) for record in records]
    )
    df = spark.read.json(json_rdd)

    result = flatten_json(df, explode_arrays=False)

    # Struct fields become <parent>_<child> columns; "phones" stays a single column.
    expected_cols = {"id", "name", "address_city", "address_zipcode", "phones"}
    assert set(result.columns) == expected_cols
    # Nothing is exploded, so the single input record yields a single row.
    assert result.count() == 1

[32m.[0m[32m                                                                                            [100%][0m
