In [5]:
import math
import re
from datetime import datetime
from itertools import chain
from typing import Iterable, Optional

import pyspark.sql.functions as F
from pyspark.sql import DataFrame
from pyspark.sql.types import StringType

In [7]:
def convert_time_to_period(date: str) -> Optional[str]:
    """
    Convert a timestamp string to a period label such as 2018-Q1.

    The label is ``{year}-{period_prefix}{n}``, where ``n`` is
    ``ceil(month / period_length)``, i.e. the 1-based index of the
    ``period_length``-month bucket the month falls into.

    Reads three notebook-level settings at call time: ``input_time_format``
    (a ``datetime.strptime`` format string), ``period_length`` (months per
    period) and ``period_prefix``.

    Args:
        date: Timestamp text expected to match ``input_time_format``.

    Returns:
        The period label, or ``None`` when ``date`` cannot be parsed
        (wrong format, or not a string at all, e.g. ``None``).

    Raises:
        NameError, ZeroDivisionError: when the settings above are missing or
            ``period_length`` is 0. These are configuration errors and are
            deliberately not turned into ``None``.
    """
    try:
        parsed = datetime.strptime(date, input_time_format)
    except (TypeError, ValueError):
        # Only unparseable input maps to None. A bare `except:` here used to
        # hide missing settings as well, which made every row come back None.
        return None

    period = math.ceil(parsed.month / period_length)
    return f'{parsed.year}-{period_prefix}{period}'
    
# Make the converter callable from Spark SQL as `convert_time`, returning strings.
spark.udf.register(
    name="convert_time",
    f=convert_time_to_period,
    returnType=StringType(),
)

In [3]:
def convert_to_long_format(
    attribute_wide_format_data: DataFrame,
    attribute_mappings: dict,
    entity_col: str = 'company',
    time_col: Optional[str] = None
) -> DataFrame:
    """
    Convert entity attribute data from wide format to long format.

    Input:
    - Attributes in wide format: [entity, attribute_1, attribute_2, ...]
    - attribute mappings: e.g.
        "contact_attribute_mapping": {
            "name":
            {
                "mapping_cols": ["official_name", "commercial_name"],
                "output_table": "company_names"
            },
            "address":
            {
                "mapping_cols": ["address"],
                "output_table": "company_addresses"
            }
        }
      Only the "mapping_cols" list of each entry is read by this function.

    Output:
    - Attribute data in long format: [entity, attribute, value]. When
      `time_col` is given it is kept as an extra identifier column right
      after the entity column. Rows whose value is blank after trimming are
      removed (null values do not pass the filter either).

    Args:
        attribute_wide_format_data: Wide dataframe containing `entity_col`,
            `time_col` (if given) and every mapped attribute column.
        attribute_mappings: Mapping of attribute name -> dict with a
            "mapping_cols" list of column names.
        entity_col: Name of the entity identifier column.
        time_col: Optional name of a time/period column to carry along.

    Returns:
        The melted dataframe with columns 'attribute' and 'value'.
    """
    # Columns to unpivot, in mapping order. dict.fromkeys drops repeats so a
    # column listed under two attributes is not selected twice, which would
    # leave the dataframe with duplicate (ambiguous) column names.
    attribute_columns = list(dict.fromkeys(
        column
        for mapping in attribute_mappings.values()
        for column in mapping['mapping_cols']
    ))

    # Identifier columns are the only thing that differs with/without time_col.
    id_columns = [entity_col] if time_col is None else [entity_col, time_col]

    melted_attribute_data = melt(
        data=attribute_wide_format_data.select(id_columns + attribute_columns),
        id_vars=id_columns,
        value_vars=attribute_columns,
        var_name='attribute',
        value_name='value'
    )

    return melted_attribute_data.filter(F.trim('value') != "")


In [None]:
"""
Convert dataframe from wide format to long format. Equivalent to pd.melt
Reference: https://stackoverflow.com/questions/41670103/how-to-melt-spark-dataframe
"""
def melt(
        data: DataFrame,
        id_vars: Iterable[str],
        value_vars: Iterable[str],
        var_name: str="attribute",
        value_name: str="value") -> DataFrame:
    """
    Convert a dataframe from wide format to long format (in the spirit of pd.melt).

    Every column in `value_vars` becomes one output row per input row, holding
    the column's name in `var_name` and its value in `value_name`; the
    `id_vars` columns are repeated on each of those rows.

    Reference: https://stackoverflow.com/questions/41670103/how-to-melt-spark-dataframe

    Args:
        data: Wide dataframe to unpivot.
        id_vars: Names of the identifier columns to keep as-is.
        value_vars: Names of the columns to unpivot.
        var_name: Name of the output column holding the source column name.
        value_name: Name of the output column holding the value.

    Returns:
        Dataframe with columns [*id_vars, var_name, value_name].

    NOTE(review): the values are packed into a single map column, so the
    `value_vars` columns presumably need a common type - confirm for new inputs.
    """
    id_columns = list(id_vars)

    # Flat [name_1, col_1, name_2, col_2, ...] list: the key/value pairs of a
    # per-row map from column name to that column's value.
    map_entries = list(chain.from_iterable(
        (F.lit(column), F.col(column)) for column in value_vars
    ))
    var_value_pairs = F.create_map(map_entries)

    # Name explode's two output columns directly. Renaming the default
    # 'key'/'value' afterwards would also rename an id column that happens to
    # be called 'key' or 'value', and would collapse both outputs onto one
    # name when var_name == 'value'.
    return data.select(
        *id_columns,
        F.explode(var_value_pairs).alias(var_name, value_name)
    )

In [3]:
# Source table read by the queries below; replace the placeholder before running.
staging_table = 'INSERT TABLE NAME HERE/DB CONNECTION'
# Separator placed between buyer_id and buyer_name when building the `buyer` column.
buyer_join_token = '=='
# Lower bound compared (as a string) against the trimmed input time column.
# The empty string does not exclude any non-null value.
# TODO: decide whether a real minimum time is needed.
input_min_time = ''

# Number of partitions passed to .repartition() after the activity query.
# This assignment previously had no right-hand side (a SyntaxError). 200
# mirrors Spark's default spark.sql.shuffle.partitions and is only a
# placeholder - TODO: tune to the data volume / cluster size.
partitions = 200

# NOTE(review): other cells also read input_time_format, period_length,
# period_prefix, input_time_col, output_time_col, output_db and output_table,
# none of which are assigned in this notebook - define them here or confirm
# they are injected before the notebook runs.

In [None]:
## Inputs to be obtained from HA: the attribute mapping and the entity column name.
# NOTE(review): the two lines below read from differently named objects
# (`properties` vs `property_dict`), and neither is assigned anywhere in the
# visible notebook - confirm whether they are meant to be the same dict.

attribute_dict = properties['attribute_mapping']
entity_col = property_dict['entity_col']

In [None]:
# Full staging table; nothing else in this cell reads it.
property_df = spark.sql(f"SELECT * FROM {staging_table}")

# Distinct award rows that have a party id and an award date, each tagged with
# its period label (convert_time UDF) and a combined buyer id/name key.
activity_data = (
    spark.sql(f"""
    SELECT DISTINCT
        parties_id,
        awards_id,
        tender.lots.id,
        awards_date,
        convert_time(awards_date) AS {output_time_col},
        concat(buyer_id, '{buyer_join_token}', buyer_name) AS buyer,
        item_description
    FROM {staging_table}
    WHERE parties_id != '' 
        AND awards_date IS NOT NULL
        AND TRIM({input_time_col}) >= '{input_min_time}'
    """)
    .repartition(partitions)
    # The UDF yields NULL for dates it cannot convert; drop those rows.
    .filter(F.col(output_time_col).isNotNull())
)

In [None]:
# Remove quote marks, '*', parentheses, ':', ';' and '-' from item_description,
# then trim surrounding whitespace.
# The hyphen is placed last in the character class so it is a literal '-'.
# Written as `)-:` it formed a character range that also deleted every digit
# as well as '+', ',', '.' and '/'.
activity_data = activity_data.withColumn(
    'item_description',
    F.trim(F.regexp_replace(F.col('item_description'), "[\"'*():;-]", ""))
)

In [None]:
# Unpivot the configured attribute columns into (attribute, value) rows,
# keeping the entity column and the period column as identifiers.
melted_activity_data = convert_to_long_format(
    attribute_wide_format_data=activity_data,
    attribute_mappings=attribute_dict,
    entity_col=entity_col,
    time_col=output_time_col,
)

In [None]:
# Persist the long-format result, replacing any existing contents of the target table.
(
    melted_activity_data.write
    .mode("overwrite")
    .saveAsTable(f'{output_db}.{output_table}')
)