In [None]:
import re
from itertools import chain
from typing import Iterable

import pyspark.sql.functions as F
from pyspark.sql import DataFrame
from pyspark.sql.types import StringType, IntegerType

In [None]:
def concat_address(column_values: list, join_token: str=' ') -> str:
    """
    Concatenate multiple address fields into a single normalized string.

    Each field is upper-cased and stripped of the characters ``. , ( ) : /``.
    Runs of underscores, semicolons and whitespace are collapsed to one space.
    Fields that are falsy (None, '') or that become empty after this cleanup
    are skipped, so they never leave doubled join tokens in the result.

    :param column_values: address parts (e.g. street number, street name, city);
        ``None`` is treated as "no parts".
    :param join_token: separator placed between the cleaned parts.
    :return: the joined, normalized address, or ``''`` if no part survives.
    """
    if column_values is None:
        return ''

    cleaned_parts = []
    for value in column_values:
        # Skip null/empty (and other falsy) entries before touching them.
        if not value:
            continue
        text = re.sub(r'[.,():/]', '', str(value).upper())
        text = re.sub(r'[_;\s]+', ' ', text).strip()
        # A part made only of punctuation/whitespace (e.g. ".") is now empty.
        if text:
            cleaned_parts.append(text)
    return join_token.join(cleaned_parts)
# Expose concat_address to Spark SQL (called by the contact query below) with a string return type.
spark.udf.register("concat_address", concat_address, StringType())

def is_valid_phone_number(phone_number: str, min_length: int=7) -> int:
    """
    Flag phone numbers that are missing, too short, or made of one repeated digit.

    Examples of invalid numbers: ``None`` (SQL NULL), ``'123456'`` (shorter
    than ``min_length``), and ``'1111111'`` or ``'11222222'``. In the last two,
    only one distinct character remains after the first two characters, which
    are likely an area code.

    :param phone_number: phone number; non-string values are converted with
        ``str()``, and ``None`` is always invalid.
    :param min_length: minimum number of characters for a valid number.
    :return: 1 if the number looks valid, 0 otherwise.
    """
    # Spark passes SQL NULL as None; len(None) would fail the whole job.
    if phone_number is None:
        return 0
    phone_number = str(phone_number)
    if len(phone_number) < min_length:
        return 0
    # Ignore the first two characters (likely an area code) unless nothing
    # would be left to inspect.
    significant = phone_number if len(phone_number) <= 2 else phone_number[2:]
    return 0 if len(set(significant)) == 1 else 1
# Expose is_valid_phone_number to Spark SQL; it returns a 0/1 integer flag.
spark.udf.register("is_valid_phone_number", is_valid_phone_number, IntegerType())

In [None]:
'''
Convert entity attribute data from wide format to long format.
Input: 
- Attributes in wide format: [entity, attribute_1, attribute_2, ...]
- attribute mappings: e.g.
    "contact_attribute_mapping": {
        "name": 
        {
            "mapping_cols": ["official_name", "commercial_name"],
            "output_table": "company_names"
        },
        "address": 
        {
            "mapping_cols": ["address"],
            "output_table": "company_addresses"
        }
    }
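Output: 
- Attribute data in long format: [entity, attribute, value] (with the time column after
  entity when `time_col` is given). `attribute` holds the source column name, and rows
  whose value is null or empty after trimming are dropped.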
'''

def convert_to_long_format(
    attribute_wide_format_data: DataFrame,
    attribute_mappings: dict,
    entity_col: str = 'company', 
    time_col: str = None
) -> DataFrame:
    """
    Melt the mapped attribute columns of a wide DataFrame into long format.

    :param attribute_wide_format_data: wide DataFrame containing ``entity_col``,
        ``time_col`` (if given) and every column listed in the mappings.
    :param attribute_mappings: dict of attribute name -> mapping dict; only each
        mapping's ``'mapping_cols'`` list is read here.
    :param entity_col: identifier column kept on every output row.
    :param time_col: optional second identifier column kept on every output row.
    :return: DataFrame with columns [entity_col, (time_col,) 'attribute', 'value'].
        'attribute' is the source column name. Rows whose value is null or
        empty after trimming are dropped.
    """
    # Source columns in mapping order, without repeats. A column listed under
    # two attributes would otherwise be selected twice, which makes its
    # reference ambiguous and duplicates the map key that melt() builds.
    attribute_columns = list(dict.fromkeys(
        column
        for mapping in attribute_mappings.values()
        for column in mapping['mapping_cols']
    ))

    id_columns = [entity_col] if time_col is None else [entity_col, time_col]

    melted_attribute_data = melt(
        data=attribute_wide_format_data.select(id_columns + attribute_columns),
        id_vars=id_columns,
        value_vars=attribute_columns,
        var_name='attribute',
        value_name='value'
    )

    # null != "" evaluates to null, so null values are filtered out as well.
    return melted_attribute_data.filter(F.trim('value') != "")


In [None]:
"""
Convert a Spark DataFrame from wide format to long format (the equivalent of pandas.melt).
Reference: https://stackoverflow.com/questions/41670103/how-to-melt-spark-dataframe
"""
def melt(
        data: DataFrame, 
        id_vars: Iterable[str], 
        value_vars: Iterable[str], 
        var_name: str="attribute", 
        value_name: str="value") -> DataFrame:
    """
    Unpivot ``data`` from wide to long format (Spark counterpart of ``pandas.melt``).

    Each input row yields one output row per column in ``value_vars``.

    :param data: wide DataFrame.
    :param id_vars: columns copied unchanged onto every output row.
    :param value_vars: columns to unpivot. They are packed into one map column,
        so their types must be mutually compatible.
    :param var_name: name of the output column holding the source column name.
    :param value_name: name of the output column holding the source value.
    :return: DataFrame with columns [*id_vars, var_name, value_name].
    """
    # map(lit('col_1'), col_1, lit('col_2'), col_2, ...)
    var_value_pairs = F.create_map(*[
        expr
        for column in value_vars
        for expr in (F.lit(column), F.col(column))
    ])

    # Name the exploded key/value columns directly. Renaming the default
    # 'key'/'value' columns afterwards would also rename any id column that
    # happens to be called 'key' or 'value'.
    return data.select(
        *id_vars,
        F.explode(var_value_pairs).alias(var_name, value_name)
    )

In [None]:
# Load distinct contact records. The address is assembled by the
# concat_address UDF, and is_valid_phone is the UDF's 0/1 phone flag.
# Rows whose entity column is '' (or NULL, since NULL != '' is not true) are excluded.
contact_data = spark.sql(f"""
    SELECT DISTINCT
        {entity_col},
        establishment,
        establishment_type,
        official_name,
        commercial_name,
        email,
        phone_number,
        is_valid_phone_number(phone_number) AS is_valid_phone,
        TRIM(UPPER(concat_address(array(street_number, 
                                        street_name, 
                                        address_complement,
                                        district, 
                                        city_name,
                                        state, 
                                        zip_code), ' '))) AS address
    FROM {output_db}.{input_table}
    WHERE {entity_col} != '' 
    """)
contact_data = contact_data.repartition(partitions)

if display_data:
    display(contact_data.limit(100))

In [None]:
# Blank out phone numbers flagged invalid so the blank-value filter applied
# during the long-format conversion drops them.
contact_data = contact_data.withColumn(
    'phone_number',
    F.when(F.col('is_valid_phone') == 0, '').otherwise(F.col('phone_number'))
)

# Minor cleanup on each attribute column:
#   1. remove double/single quotes, asterisks and parentheses;
#   2. strip trailing spaces, then any trailing dots.
# Built with the Column API rather than an F.expr SQL string, so column names
# are never spliced into SQL text (names needing backtick quoting would break).
for attribute in attribute_columns:
    cleaned = F.regexp_replace(F.col(attribute), "[\"\'*()]", "")
    # Spark regexes are Java regexes; \z anchors at the absolute end of the
    # string, so only dots that are the final characters are removed.
    cleaned = F.regexp_replace(F.rtrim(cleaned), r'\.+\z', '')
    contact_data = contact_data.withColumn(attribute, cleaned)

In [None]:
# Attach each entity's headquarter contact attributes to every establishment
# row as `headquarter_<attribute>` columns.
# Project and rename first, then de-duplicate. HQ rows that differ only in
# columns not selected here (e.g. `establishment`) would otherwise survive
# `.distinct()` and multiply the rows produced by the left join below.
hq_contact_data = (
    contact_data
    .filter(F.col('establishment_type') == 'headquarter')
    .select(
        [F.col(entity_col)]
        + [F.col(column).alias(f'headquarter_{column}') for column in attribute_columns]
    )
    .distinct()
)

# Join on the configured entity column so the cell works for any entity_col.
# NOTE(review): an entity with several HQ rows whose attributes differ will
# still match more than one HQ row and fan out here. Confirm this is acceptable.
contact_data = contact_data.join(hq_contact_data, on=entity_col, how='left').cache()
contact_data = contact_data.drop('is_valid_phone', 'establishment')

print(f'Total records: {contact_data.count()}')
print(f'Number of companies: {contact_data.select(entity_col).distinct().count()}')

if display_data:
    display(contact_data.limit(100))

In [None]:
# Unpivot the mapped contact attribute columns into one row per
# (entity, attribute, value), keyed by the configured entity column.
melted_contact_data = convert_to_long_format(
    entity_col=entity_col,
    attribute_mappings=attribute_dict,
    attribute_wide_format_data=contact_data,
)