In [None]:
import pyspark
import pandas as pd
import numpy as np
from hashlib import sha256
from pyspark.sql import SparkSession
from pyspark.sql.types import ShortType, ByteType
import pyspark.sql.functions as F

In [None]:
# Create the Spark session for this notebook; getOrCreate() reuses an already running session if one exists.
spark = SparkSession.builder.appName("MNO_Event_Cleansing").getOrCreate()

In [None]:
# Hand-written sample of MNO events; None marks a missing value.
# Each key is a column, and list position i across all keys forms record i.
data = {
    'user_id':['00', '1', '1','1','2', '2', '2', '3'],
    'timestamp': ['2023-01-01T00:12:00', None, '2023-01-01T00:00:00', '2023-01-01T00:03:00', 
                  '2023-05-01T00:00:00', '2023-05-01T01:00:00', '2023-05-01T01:00:00', '2023-01-01T00:03:00'],
    'mcc': [254, 254, 254, 254, 254, 254, 254, 254], 
    'cell_id': ['214030412038931', '214030412038931', '214030412038931', '214030412038931', '214030412038935', None, None, None], 
    'latitude': [-3.62958, -3.62954, -3.62958, -3.62954, None, -3.62959, -3.62950, None],
    'longitude': [40.51873, 40.51870, 40.51873, 40.51870, None, None, 40.51874, None],
    'loc_error': [100, 100, 100, 100, None, None, 100, None]
    
}

In [None]:
# Build the pandas frame, then replace every user_id string with its raw SHA-256 digest (bytes).
mno_data = pd.DataFrame(data)
mno_data['user_id'] = mno_data['user_id'].map(
    lambda raw_id: sha256(raw_id.encode('utf-8')).digest()
)

In [None]:
# Blank out the first record's user_id so the sample contains an event with a missing user.
mno_data.loc[0, 'user_id'] = None
# Replace NaN with None across the frame (presumably so Spark receives nulls rather than NaN - confirm).
mno_data = mno_data.replace({float('nan'): None})

In [None]:
# Preview the first rows of the prepared pandas sample.
display(mno_data.head())

In [None]:
# --- Cleaning configuration ---
# Pattern passed to F.to_timestamp when parsing the raw 'timestamp' strings.
timestamp_format = "yyyy-MM-dd'T'HH:mm:ss"
# Time zone the raw timestamps are expressed in; they are converted from it to UTC.
input_timezone = 'America/Los_Angeles'
# Bounds of the data period; both are parsed with F.to_date in data_period_filtering.
data_period_start = '2023-01-01' # alternative form: '2023-01-01 00:00:00'
data_period_end = '2023-05-01'
# Toggle for the bounding-box step and the box itself (same CRS as the event coordinates).
do_bounding_box_filtering = True
bounding_box = {'min_lon': -180,
                'max_lon': 180,
                'min_lat': -90,
                'max_lat': 90
                }
# Output location of the cleaned events (written as parquet at the end of the notebook).
clean_mno_event_data_write_path = '../sample_data/output/clean_mno_data'

In [None]:
# Columns every input must provide, mapped to the Spark type each one is cast to.
# 'timestamp' is converted separately by convert_time_column_to_timestamp.
mandatory_columns_casting_dict = {
    "user_id": "binary",
    "timestamp": "timestamp",
    "mcc": "integer",
    "cell_id": "string",
    "latitude": "float",
    "longitude": "float"
}
# Columns that may be absent from the input, mapped to their target Spark type.
# If an optional column is known to be absent (and therefore always null), set its
# dtype to '' (empty string): cast_columns skips casting entries with a blank dtype.
optional_columns_casting_dict = {
    "loc_error": "float"
}

In [None]:
# Turn the pandas sample into a Spark DataFrame; no schema is passed, so Spark infers it.
spark_df = spark.createDataFrame(mno_data)

In [None]:
# Inspect the inferred schema and the raw rows before any cleaning.
spark_df.printSchema()
spark_df.show()

In [None]:
# TODO: consider a clearer name for this function, e.g. convert_to_schema.
def check_existance_of_columns(df: pyspark.sql.dataframe.DataFrame,
                               mandatory_columns: list[str],
                               optional_columns: list[str]
                              ) -> pyspark.sql.dataframe.DataFrame:
    """Validate the mandatory columns, add missing optional ones and project to the schema.

    Args:
        df: Input DataFrame.
        mandatory_columns: Names that must all exist in ``df``.
        optional_columns: Names that may be absent; every absent one is added
            as an all-null column of type string.

    Returns:
        ``df`` restricted to ``mandatory_columns`` followed by
        ``optional_columns``, in that order.

    Raises:
        KeyError: If at least one mandatory column is missing. The message
            lists the missing names.
    """
    existing_columns = set(df.columns)

    # Membership test instead of comparing lengths, so repeated names in
    # mandatory_columns cannot trigger a spurious error.
    missing_mandatory_columns = list(dict.fromkeys(
        name for name in mandatory_columns if name not in existing_columns
    ))
    if missing_mandatory_columns:
        raise KeyError(
            f"Not all mandatory columns in df are present, missing: {missing_mandatory_columns}"
        )

    for optional_column in optional_columns:
        if optional_column not in existing_columns:
            # Placeholder column of nulls; cast_columns may re-cast it later.
            df = df.withColumn(optional_column, F.lit(None).cast('string'))
            existing_columns.add(optional_column)

    # list(...) lets callers pass any iterable of names (e.g. dict keys), not only lists.
    return df.select(list(mandatory_columns) + list(optional_columns))

In [None]:
def filter_nulls(df: pyspark.sql.dataframe.DataFrame,
                 filter_columns: list[str] = None
                ) -> pyspark.sql.dataframe.DataFrame:
    """Drop every row that holds a null in any of the inspected columns.

    Args:
        df: Input DataFrame.
        filter_columns: Columns to inspect. ``None`` (the default) is passed
            through as ``subset=None``, i.e. all columns are inspected.

    Returns:
        The DataFrame without the rows that contain a null in those columns.
    """
    return df.dropna(how='any', subset=filter_columns)

In [None]:
def filter_null_locations(df: pyspark.sql.dataframe.DataFrame) -> pyspark.sql.dataframe.DataFrame:
    """Keep rows that carry a location: a cell_id, or both longitude and latitude.

    Args:
        df: Input DataFrame with 'cell_id', 'longitude' and 'latitude' columns.

    Returns:
        The rows where 'cell_id' is not null, or where 'longitude' and
        'latitude' are both not null.
    """
    has_cell_id = F.col('cell_id').isNotNull()
    has_coordinates = F.col('longitude').isNotNull() & F.col('latitude').isNotNull()

    return df.where(has_cell_id | has_coordinates)

In [None]:
def convert_time_column_to_timestamp(df: pyspark.sql.dataframe.DataFrame,
                                     timestampt_format: str,
                                     input_timezone: str
                                    ) -> pyspark.sql.dataframe.DataFrame:
    """Parse the 'timestamp' column and convert it from the input time zone to UTC.

    Args:
        df: Input DataFrame with a 'timestamp' column.
        timestampt_format: Datetime pattern handed to ``F.to_timestamp``.
        input_timezone: Time zone the parsed values are interpreted in by
            ``F.to_utc_timestamp``.

    Returns:
        The DataFrame with 'timestamp' replaced by the converted value; rows
        whose resulting timestamp is null are removed.
    """
    parsed_timestamp = F.to_timestamp(F.col('timestamp'), timestampt_format)
    utc_timestamp = F.to_utc_timestamp(parsed_timestamp, input_timezone)

    converted = df.withColumn('timestamp', utc_timestamp)
    return converted.where(F.col('timestamp').isNotNull())

In [None]:
def data_period_filtering(df: pyspark.sql.dataframe.DataFrame,
                          data_period_start: str,
                          data_period_end: str
                         ) -> pyspark.sql.dataframe.DataFrame:
    """Keep events whose timestamp falls within the data period, whole days inclusive.

    Both bounds are reduced to calendar dates with ``F.to_date``. An event is
    kept when its timestamp is on or after the start date and before the day
    that follows the end date, so every event of the end date is retained.

    Args:
        df: Input DataFrame with a 'timestamp' column of timestamp type.
        data_period_start: First day of the period, in a form ``F.to_date``
            can parse (e.g. '2023-01-01').
        data_period_end: Last day of the period, same form as the start.

    Returns:
        The rows whose 'timestamp' lies within [start day, end day].
    """
    period_start = F.to_date(F.lit(data_period_start))
    # A date compared with a timestamp means midnight of that date, so an
    # inclusive upper bound on the end date itself would drop nearly the whole
    # end day. Use an exclusive bound on the following day instead.
    period_end_exclusive = F.date_add(F.to_date(F.lit(data_period_end)), 1)

    event_time = F.col('timestamp')
    return df.filter((event_time >= period_start) & (event_time < period_end_exclusive))

In [None]:
def bounding_box_filtering(df: pyspark.sql.dataframe.DataFrame,
                           bounding_box: dict
                          ) -> pyspark.sql.dataframe.DataFrame:
    """Drop events whose coordinates lie outside the bounding box.

    Only coordinates that are present are constrained. A null latitude or
    longitude is left alone, so events located by cell_id only (which
    ``filter_null_locations`` accepts) are not removed by this step.

    Args:
        df: Input DataFrame with 'latitude' and 'longitude' columns.
        bounding_box: Mapping with the keys 'min_lat', 'max_lat', 'min_lon'
            and 'max_lon'. The values must be in the same CRS as the event
            coordinates; the bounds are inclusive.

    Returns:
        The rows whose non-null latitude and longitude fall inside the box.
    """
    latitude = F.col('latitude')
    longitude = F.col('longitude')

    # between() yields NULL for a null input and filter() discards NULL, so
    # missing coordinates have to be let through explicitly.
    lat_condition = latitude.isNull() | latitude.between(bounding_box['min_lat'], bounding_box['max_lat'])
    lon_condition = longitude.isNull() | longitude.between(bounding_box['min_lon'], bounding_box['max_lon'])

    return df.filter(lat_condition & lon_condition)


In [None]:
def cast_columns(df: pyspark.sql.dataframe.DataFrame,
                 mandatory_columns_casting_dict: dict,
                 optional_columns_casting_dict: dict
                ) -> pyspark.sql.dataframe.DataFrame:
    """Cast columns to their target types and drop rows left with invalid values.

    Args:
        df: Input DataFrame.
        mandatory_columns_casting_dict: Column name -> Spark type for mandatory
            columns. 'timestamp' is skipped here; it is handled by
            ``convert_time_column_to_timestamp``.
        optional_columns_casting_dict: Column name -> Spark type for optional
            columns. Entries with a blank type are not cast.

    Returns:
        The cast DataFrame without rows that have a null in a mandatory
        non-location column or that lack a location.
    """
    # Fresh dicts are built so the caller's dictionaries stay untouched.
    mandatory_casts = {
        name: dtype
        for name, dtype in mandatory_columns_casting_dict.items()
        if name != 'timestamp'
    }
    # A blank dtype marks an optional column that is entirely null: no cast.
    optional_casts = {
        name: dtype
        for name, dtype in optional_columns_casting_dict.items()
        if dtype.strip()
    }

    for name, dtype in {**mandatory_casts, **optional_casts}.items():
        df = df.withColumn(name, F.col(name).cast(dtype))

    # Location columns get their own null rule, and optional columns may
    # legitimately be null, so neither takes part in the plain null filter.
    not_null_checked = {'cell_id', 'latitude', 'longitude', *optional_casts}
    filter_columns = [name for name in mandatory_casts if name not in not_null_checked]

    without_nulls = filter_nulls(df, filter_columns)
    return filter_null_locations(without_nulls)

In [None]:
# Step 1: verify the mandatory columns, add missing optional ones and keep only the schema columns.
df = check_existance_of_columns(
    spark_df,
    list(mandatory_columns_casting_dict),
    list(optional_columns_casting_dict),
)
df.show()

In [None]:
# Step 2: drop events that have no user_id or no timestamp.
df = filter_nulls(df, ['user_id', 'timestamp'])
df.show()

In [None]:
# Step 3: drop events that have neither a cell_id nor a complete longitude/latitude pair.
df = filter_null_locations(df)
df.show()

In [None]:
# Step 4: parse the timestamp strings and convert them from input_timezone to UTC.
df = convert_time_column_to_timestamp(df, timestamp_format, input_timezone)
df.show()
# Confirm that 'timestamp' now has the timestamp type.
df.printSchema()

In [None]:
# Step 5: keep only the events that fall inside the configured data period.
df = data_period_filtering(df, data_period_start, data_period_end)
df.show()

In [None]:
# Step 6 (optional): apply the bounding-box filter only when it is switched on in the configuration.
if do_bounding_box_filtering:
    df = bounding_box_filtering(df, bounding_box)
df.show()

In [None]:
# Step 7: cast the columns to their target types and drop the rows that end up invalid.
df = cast_columns(df, mandatory_columns_casting_dict, optional_columns_casting_dict)
df.show()
df.printSchema()

In [None]:
# Derive year/month/day from the timestamp with compact integer types;
# these columns are the partition keys of the parquet write below.
df = df.withColumn("year", F.year("timestamp").cast(ShortType())) \
       .withColumn("month", F.month("timestamp").cast(ByteType())) \
       .withColumn("day", F.dayofmonth("timestamp").cast(ByteType()))
df.show()
df.printSchema()

In [None]:
# Order the events by user and, within each user, by time.
df = df.orderBy('user_id', 'timestamp')
df.show()

In [None]:
# Persist the cleaned events as parquet, partitioned by year/month/day.
# NOTE(review): mode='append' adds to existing output, so re-running this cell
# writes the same rows again - confirm this is intended.
df.write.parquet(clean_mno_event_data_write_path, mode='append', partitionBy = ['year', 'month', 'day'])