In [1]:
# Load ipytest and apply its default configuration before any `%%ipytest`
# test cell in this notebook is executed.
import ipytest
ipytest.autoconfig()


In [6]:
import json

from pyspark.sql import SparkSession

from src.utilities import remove_duplicates, fill_nulls,flatten_json,mask_dataframe


# One local SparkSession (all cores, app name "Test") shared by every test cell.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Test")
    .getOrCreate()
)

In [4]:
%%ipytest -qq
def test_remove_duplicates():
    df = spark.createDataFrame([
        ("Alice", "NY"), ("Alice", "NY"), ("Bob", "LA")
    ], ["name", "city"])
    result = remove_duplicates(df)
    assert result.count() == 2

[Stage 0:>                                                        (0 + 12) / 12]

[32m.[0m[32m                                                                                            [100%][0m


                                                                                

In [5]:
%%ipytest -qq
def test_fill_nulls():
    df = spark.createDataFrame([
        (None, "NY"), ("Bob", None)
    ], ["name", "city"])
    result = fill_nulls(df, {"name": "NA", "city": "Unknown City"})
    rows = result.collect()
    assert rows[0]["name"] == "NA"
    assert rows[1]["city"] == "Unknown City"


[32m.[0m[32m                                                                                            [100%][0m


In [20]:
%%ipytest -qq
def test_flatten_json_exploding_arrays():
    data = [{
        "id": 1,
        "name": "John",
        "address": {"city": "NY", "zipcode": 12345},
        "phones": [
            {"type": "home", "number": "1234"},
            {"type": "work", "number": "5678"}
        ]
    }]
    df = spark.read.json(spark.sparkContext.parallelize(data))

    result=flatten_json(df,explode_arrays=True)
    # Check flattened columns
    expected_cols = {"id", "name", "address_city", "address_zipcode", "phones_type", "phones_number"}
    assert set(result.columns) == expected_cols
    # Check number of output rows count 
    assert result.count() == 2

[32m.[0m[32m                                                                                            [100%][0m


In [22]:
%%ipytest -qq
def test_flatten_json_no_exploding_arrays():
    data = [{
        "id": 1,
        "name": "John",
        "address": {"city": "NY", "zipcode": 12345},
        "phones": [
            {"type": "home", "number": "1234"},
            {"type": "work", "number": "5678"}
        ]
    }]
    df = spark.read.json(spark.sparkContext.parallelize(data))

    result=flatten_json(df,explode_arrays=False)
    # Check flattened columns
    expected_cols = {"id", "name", "address_city", "address_zipcode", "phones"}
    assert set(result.columns) == expected_cols
    # Check number of output rows count 
    assert result.count() == 1

[32m.[0m[32m                                                                                            [100%][0m


In [7]:
%%ipytest -qq

# Shared input for the masking tests below: two fully populated rows plus one
# row whose name, email and phone are all null.
df = spark.createDataFrame(
    [
        (1, "Alice", "alice@example.com", "9876543210"),
        (2, "Bob", "bob@example.com", "9123456789"),
        (3, None, None, None),  # Null edge case
    ],
    schema=["id", "name", "email", "phone"],
)
def test_full_mask():
    masked = mask_dataframe(df, ["name"], default_mask="MASKED")
    rows = masked.select("name").collect()
    assert all(r.name == "MASKED" or r.name is None for r in rows)

def test_partial_mask():
    masked = mask_dataframe(df, {"name": "partial"})
    rows = [r.name for r in masked.select("name").collect()]

    for original, masked_val in zip(["Alice", "Bob", None], rows):
        if original is None:
            assert masked_val is None
        else:
            assert masked_val.startswith(original[:2])  # first 2 preserved
            assert set(masked_val[2:]) <= {"*"}        # rest are only '*'
            assert len(masked_val) == len(original)    # length unchanged


def test_hash_mask():
    masked = mask_dataframe(df, {"email": "hash"})
    rows = masked.select("email").collect()
    # Ensure SHA2 hash length is 64 hex chars
    assert all(len(r.email) == 64 or r.email is None for r in rows)

def test_custom_expr_mask():
    masked = mask_dataframe(df, {"phone": "expr:concat('XXX', substr(phone, -4, 4))"})
    rows = [r.phone for r in masked.select("phone").collect()]
    assert rows[0] == "XXX3210"
    assert rows[1] == "XXX6789"
    assert rows[2] is None  # null case handled

def test_skip_nonexistent_column():
    # Should not raise error if column doesn't exist
    masked = mask_dataframe(df, {"nonexistent": "full"})
    assert "nonexistent" not in masked.columns

[32m.[0m

                                                                                

[31mF[0m[31mF[0m[32m.[0m[32m.[0m[31m                                                                                        [100%][0m
[31m[1m________________________________________ test_partial_mask _________________________________________[0m

    [0m[94mdef[39;49;00m[90m [39;49;00m[92mtest_partial_mask[39;49;00m():[90m[39;49;00m
        masked = mask_dataframe(df, {[33m"[39;49;00m[33mname[39;49;00m[33m"[39;49;00m: [33m"[39;49;00m[33mpartial[39;49;00m[33m"[39;49;00m})[90m[39;49;00m
        rows = masked.select([33m"[39;49;00m[33mname[39;49;00m[33m"[39;49;00m).collect()[90m[39;49;00m
        [94massert[39;49;00m rows[[94m0[39;49;00m].name.startswith([33m"[39;49;00m[33mAl[39;49;00m[33m"[39;49;00m) [95mand[39;49;00m rows[[94m0[39;49;00m].name.endswith([33m"[39;49;00m[33m**[39;49;00m[33m"[39;49;00m)[90m[39;49;00m
>       [94massert[39;49;00m rows[[94m1[39;49;00m].name.startswith([33m"[39;49;00m[33mBo[39;49;00m[33m