Skip to content
This repository has been archived by the owner on May 18, 2023. It is now read-only.

Commit

Permalink
add arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Jun 13, 2020
1 parent ad2132b commit 01a6eb1
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
26 changes: 26 additions & 0 deletions splink_data_normalisation/arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import expr, regexp_replace, col

def fix_zero_length_arrays(df: DataFrame) -> DataFrame:
    """For every column of array type, turn zero-length arrays into true nulls.

    Args:
        df (DataFrame): Input Spark dataframe.

    Returns:
        DataFrame: Spark dataframe in which every empty array in an
        array-typed column has been replaced with null. Non-array columns
        are left untouched.
    """

    # df.dtypes yields (column_name, type_string) pairs; array columns have
    # type strings like 'array<string>'.
    array_cols = [name for name, dtype in df.dtypes if dtype.startswith("array")]

    # Backtick-quote the column name inside the SQL expression so that
    # names containing spaces, dots or other special characters still parse.
    stmt = """
    case
    when size(`{c}`) > 0 then `{c}`
    else null
    end
    """

    for c in array_cols:
        df = df.withColumn(c, expr(stmt.format(c=c)))

    return df
31 changes: 31 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import pandas as pd

from splink_data_normalisation.arrays import fix_zero_length_arrays
from pyspark.sql import Row


def test_fix_1(spark):
    # Two rows: each has one populated array column and one empty array
    # column, so both columns get exercised in both states.
    input_rows = [
        {"id": 1, "my_arr1": ["a", "b", "c"], "other_arr": [], "my_str": "a"},
        {"id": 2, "my_arr1": [], "other_arr": [1], "my_str": "a"},
    ]

    df = spark.createDataFrame(Row(**row) for row in input_rows)
    # Pin column order to the insertion order of the first input dict.
    df = df.select(list(input_rows[0].keys()))

    result = fix_zero_length_arrays(df).toPandas()

    # Empty arrays are expected to come back as true nulls (None).
    expected = pd.DataFrame(
        [
            {"id": 1, "my_arr1": ["a", "b", "c"], "other_arr": None, "my_str": "a"},
            {"id": 2, "my_arr1": None, "other_arr": [1], "my_str": "a"},
        ]
    )

    pd.testing.assert_frame_equal(result, expected)

0 comments on commit 01a6eb1

Please sign in to comment.