From 01a6eb1ec164516d6387a7d89710e722f654ed43 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Sat, 13 Jun 2020 09:39:40 +0100
Subject: [PATCH] add arrays

---
Note: file contents were revised in review (array column names are now
backtick-quoted in the generated SQL). The blob ids on the "index" lines
below are carried over unchanged and no longer match the new contents.

 splink_data_normalisation/arrays.py | 34 +++++++++++++++++++++++++++++
 tests/test_array.py                 | 31 ++++++++++++++++++++++++++
 2 files changed, 65 insertions(+)
 create mode 100644 splink_data_normalisation/arrays.py
 create mode 100644 tests/test_array.py

diff --git a/splink_data_normalisation/arrays.py b/splink_data_normalisation/arrays.py
new file mode 100644
index 0000000..0ec54ae
--- /dev/null
+++ b/splink_data_normalisation/arrays.py
@@ -0,0 +1,34 @@
+from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.functions import expr, regexp_replace, col
+
+
+def fix_zero_length_arrays(df: DataFrame) -> DataFrame:
+    """For every field of type array, turn zero length arrays into true nulls
+
+    Column names are backtick-quoted before being placed in the SQL
+    expression, so names containing spaces, dots or other special
+    characters are handled.
+
+    Args:
+        df (DataFrame): Input Spark dataframe
+
+    Returns:
+        DataFrame: Spark Dataframe with clean arrays
+    """
+
+    array_cols = [name for name, dtype in df.dtypes if dtype.startswith("array")]
+
+    stmt = """
+    case
+    when size({c}) > 0 then {c}
+    else null
+    end
+    """
+
+    for c in array_cols:
+        # Escape embedded backticks by doubling them, then quote the name so
+        # it is parsed as a single identifier rather than as SQL syntax.
+        quoted = "`" + c.replace("`", "``") + "`"
+        df = df.withColumn(c, expr(stmt.format(c=quoted)))
+
+    return df
diff --git a/tests/test_array.py b/tests/test_array.py
new file mode 100644
index 0000000..5d2f078
--- /dev/null
+++ b/tests/test_array.py
@@ -0,0 +1,31 @@
+import pytest
+import pandas as pd
+
+from splink_data_normalisation.arrays import fix_zero_length_arrays
+from pyspark.sql import Row
+
+
+def test_fix_1(spark):
+
+    names_list = [
+        {"id": 1, "my_arr1": ["a", "b", "c"], "other_arr": [ ],"my_str": "a"},
+        {"id": 2, "my_arr1": [ ], "other_arr": [1],"my_str": "a"},
+
+    ]
+
+    df = spark.createDataFrame(Row(**x) for x in names_list)
+    df = df.select(list(names_list[0].keys()))
+
+    df = fix_zero_length_arrays(df)
+
+    df_result = df.toPandas()
+
+    df_expected = [
+        {"id": 1, "my_arr1": ["a", "b", "c"], "other_arr": None,"my_str": "a"},
+        {"id": 2, "my_arr1": None, "other_arr": [1] ,"my_str": "a"},
+    ]
+
+    df_expected = pd.DataFrame(df_expected)
+
+    pd.testing.assert_frame_equal(df_result,df_expected)
+