In [0]:
from abc import ABC, abstractmethod

from pyspark.sql import Window
from pyspark.sql.functions import *
from pyspark.sql.types import *

from typing import Union

import mlflow
import numpy as np
import pandas as pd

In [0]:
class SeriesData:
    """
    Helpers for retrieving and reshaping standard-price time series data.

    Every method is a class/static method; the class is used as a namespace and
    is never instantiated. Methods that query tables rely on a `spark` session
    being available as a global (as it is in a Databricks notebook).
    """

    # Columns joined back onto the frame after the moving averages are computed
    _remaining_cols = ['swire_period_date','unit_sales','is_promotion','promotional_price','dollar_sales',
                        'any_price_decrease_dollar_sales','standard_physical_volume',
                        'any_price_decrease_standard_physical_volume','state']

    # Column lists for ordering data and partitioning windows
    _df_order_cols = ['period_number','pod_id','product_key','key_brand','key_package']
    _window_order_cols = ['pod_id','product_key','period_number']
    _window_partition_cols = ['pod_id','product_key','key_brand','key_package']

    # Shared names, kept in one place for consistency
    DATABASE = 'stf_db'
    TABLE = f'{DATABASE}.product_sales'  # Table containing sales data
    STD_PRICE = 'standard_price'
    _TRANS_NUM = 'transaction_number'
    SERIES_TABLE = f'{DATABASE}.product_sales_series'

    #----- PUBLIC METHODS -----#
    @classmethod
    def calc_dollar_promotion(cls, df):
        """
        Add one `dollar_promo_<N>wk` column per `<name>_<N>wk` column found in df.

        Each new column is
        `avg_unit_sales * (<name>_<N>wk - avg_promotional_price)`, floored at 0
        and rounded to 2 decimals.
        - Parameters -
        df: Spark dataframe that contains `avg_unit_sales`, `avg_promotional_price`
            and at least one column whose name ends in `_<N>wk`.
        - Returns -
        The dataframe with the promotion columns appended.
        """
        promo_prefix = 'dollar_promo_'

        def window_weeks(column_name):
            # Return N for a name ending in '_<N>wk'; None for anything else.
            suffix = column_name.rsplit('_', maxsplit=1)[-1]
            digits = suffix[:-2]
            if suffix.endswith('wk') and digits.isdecimal():
                return int(digits)
            return None

        # df.columns is evaluated once here, so columns added inside the loop
        # are never fed back in as inputs.
        for column_name in df.columns:
            weeks = window_weeks(column_name)
            # Skip non-window columns and the outputs of a previous call, so that
            # re-running the method on its own result is harmless.
            if weeks is None or column_name.startswith(promo_prefix):
                continue

            promo_col = f'{promo_prefix}{weeks}wk'
            spend = col('avg_unit_sales') * (col(column_name) - col('avg_promotional_price'))
            df = df.withColumn(promo_col, round(when(spend < 0, 0).otherwise(spend), 2))

        return df


    @staticmethod
    def estimate_physical_volume_range(df):
        """
        Print mean, median and standard deviation of the per-period physical volume.

        The volume is summed per `swire_period_date`, the result is collected to
        the driver and summarised with numpy. Nothing is returned.
        - Parameters -
        df: Spark dataframe containing swire_period_date and standard_physical_volume columns
        """
        period_totals = df.groupBy('swire_period_date').sum('standard_physical_volume').collect()
        volume_by_period = [row.asDict().get('sum(standard_physical_volume)') for row in period_totals]

        avg_spv = np.mean(volume_by_period)
        med_spv = np.median(volume_by_period)
        stdDev_spv = np.std(volume_by_period)

        print(f"Average Physical Volume: {np.round(avg_spv,2)}\nMedian Physical Volume: {np.round(med_spv,2)}\nStandard Deviation: {np.round(stdDev_spv,2)}\n")
        print(f"Physical Volume Range Expectation: {np.round(avg_spv - stdDev_spv,2)} to {np.round(avg_spv + stdDev_spv,2)}\n")
        return None

    @classmethod
    def group_by_state_and_keyCol(cls, keyCol: str='key_brand'):
        """
        Retrieve SERIES_TABLE averaged by state, period and one additional key column.

        Every column whose dtype contains 'double' is averaged, rounded to 2
        decimals and renamed from `avg(<col>)` to `avg_<col>`.
        - Parameters -
        keyCol: Accepts the following values: 'key_brand','key_package'
        - Returns -
        Spark dataframe ordered by keyCol, state, period_number, swire_period_date.
        """
        group_cols = [keyCol, 'state', 'period_number', 'swire_period_date']

        # Compute grouped averages
        df = spark.sql(f"SELECT * FROM {cls.SERIES_TABLE}")
        value_cols = [name for name, dtype in df.dtypes if 'double' in dtype]
        df = df.groupBy(group_cols).mean(*value_cols).orderBy(group_cols)

        # Round values to 2 decimal places and give the columns SQL-friendly names
        for value_col in value_cols:
            mean_col = f'avg({value_col})'
            df = df.withColumn(mean_col, round(mean_col, 2)).withColumnRenamed(mean_col, f'avg_{value_col}')

        return df

    @classmethod
    def std_price_moving_average(cls, numWeeksWindow: Union[int, list]=0):
        """
        Retrieve sales data with standard price moving averages computed.

        Missing standard prices are imputed first; then one
        `standard_price_<N>wk` column is added per requested window, and the
        remaining sales columns are joined back on.
        - Parameters -
        numWeeksWindow: The number of weeks to use for the moving average window.
            Accepts an int or a list of ints. Windows <= 0 add no column; any
            other argument type adds no moving-average columns at all.
        - Returns -
        Spark dataframe ordered by pod_id, product_key, period_number.
        """
        # Retrieve standard_price data
        select_cols = [cls._TRANS_NUM,'period_number','pod_id','product_key','key_brand','key_package', cls.STD_PRICE]
        query = f"SELECT {', '.join(select_cols)} FROM {cls.TABLE} ORDER BY {', '.join(cls._df_order_cols)}"
        df_std_price = spark.sql(query)

        # Impute missing standard_price values
        df = cls._fill_standard_price(df_std_price)

        # Normalise the argument to a list of window sizes. bool is excluded on
        # purpose: it is an int subclass but was never treated as a window size.
        if isinstance(numWeeksWindow, list):
            windows = numWeeksWindow
        elif isinstance(numWeeksWindow, int) and not isinstance(numWeeksWindow, bool):
            windows = [numWeeksWindow]
        else:
            windows = []

        for weeks in windows:
            df = cls._append_moving_average(df, weeks)

        # Retrieve and join remaining table columns
        df = cls._get_remaining_columns(df)

        # Order results
        return df.orderBy(cls._window_order_cols)

    @classmethod
    def time_series_transform_by_keyCols(cls, df, keyCol1: str='state', keyCol2: str='key_brand'):
        """
        Select the columns that make up the final time series.

        Keeps the two key columns, period_number, swire_period_date,
        avg_standard_physical_volume and every column whose name contains
        'dollar_promo'.
        - Parameters:
        df: Spark Dataframe
        keyCol1 (str): name of the first key column for the target timeseries
        keyCol2 (str): name of the second key column for the target timeseries
        - Returns -
        Spark dataframe ordered by keyCol1, keyCol2, period_number, swire_period_date.
        """
        promo_cols = [c for c in df.columns if 'dollar_promo' in c]
        order_cols = [keyCol1, keyCol2, 'period_number','swire_period_date']
        select_cols = order_cols + ['avg_standard_physical_volume'] + promo_cols

        return df.select(select_cols).orderBy(order_cols)

    @classmethod
    def write_moving_avg_table(cls, df, tableName: str='product_sales_series'):
        """
        Write the moving-average dataframe to a delta table in DATABASE.

        Delegates to write_to_delta_table, so the table is dropped, rewritten and
        optimised.
        - Parameters:
        df: Spark Dataframe to persist
        tableName (str): name of the output delta table (without the database prefix)
        - Returns -
        Whatever `DataFrameWriter.saveAsTable` returns.
        """
        # The previous implementation wrote a notebook global (`mvg_avg_data`)
        # rather than `df`, and ignored `tableName`.
        return cls.write_to_delta_table(df, tableName)

    @classmethod
    def write_to_delta_table(cls, df, tableName):
        """
        Write the input dataframe to a delta table in DATABASE.

        Any existing table of that name is dropped first; the new table is
        optimised after the write.
        - Parameters:
        df: Spark Dataframe
        tableName (str): name of the output delta table
        - Returns -
        Whatever `DataFrameWriter.saveAsTable` returns.
        """
        table = f"{cls.DATABASE}.{tableName}"
        _ = spark.sql(f"DROP TABLE IF EXISTS {table}")
        result = df.write.saveAsTable(table, format='delta', mode='overwrite')
        _ = spark.sql(f"OPTIMIZE {table}")

        return result

    #----- PRIVATE METHODS -----#
    @classmethod
    def _append_moving_average(cls, df, numWeeks: int):
        """
        Append `standard_price_<numWeeks>wk`: the mean of the previous numWeeks prices.

        Lag k falls back to lag k-1 (or to the current price for k=1) when the
        window has no row that far back. Returns df unchanged when numWeeks <= 0.
        """
        if numWeeks <= 0:
            return df

        window_spec = Window.partitionBy(cls._window_partition_cols).orderBy(cls._window_order_cols)
        avg_col = f"{cls.STD_PRICE}_{numWeeks}wk"
        lag_cols = [f"lag_price{k}" for k in range(1, numWeeks + 1)]
        # Fallback for lag k is the column one step closer to the current row
        fallback_cols = [cls.STD_PRICE] + lag_cols[:-1]

        # Accumulate the sum of the lagged prices in the output column
        df = df.withColumn(avg_col, lit(0))
        for offset, (lag_col, fallback_col) in enumerate(zip(lag_cols, fallback_cols), start=1):
            lagged = lag(cls.STD_PRICE, offset).over(window_spec)
            df = df.withColumn(lag_col, when(lagged.isNotNull(), lagged).otherwise(col(fallback_col)))
            df = df.withColumn(avg_col, col(avg_col) + col(lag_col))

        # The lag columns were only scaffolding
        for lag_col in lag_cols:
            df = df.drop(lag_col)

        return df.withColumn(avg_col, round((col(avg_col) / numWeeks).astype(DoubleType()), 2))

    @classmethod
    def _fill_standard_price(cls, df):
        """
        Fill in missing standard price gaps with the values closest to them.

        A forward pass takes the last non-null price seen so far in each window;
        a second pass over the reversed ordering fills whatever is still null at
        the start of a partition.
        """
        # Forward fill
        Window_spec = Window.partitionBy(cls._window_partition_cols).orderBy(cls._window_order_cols)
        read_last = last(df[cls.STD_PRICE], ignorenulls=True).over(Window_spec)
        df = df.withColumn(cls.STD_PRICE, read_last)

        # Backward fill: same operation over the descending window ordering
        window_order_cols_r = [col(c).desc() for c in cls._window_order_cols]
        Window_spec_r = Window.partitionBy(cls._window_partition_cols).orderBy(window_order_cols_r)

        df_reverse = df.orderBy(cls._df_order_cols, ascending=False)
        read_last_r = last(df_reverse[cls.STD_PRICE], ignorenulls=True).over(Window_spec_r)
        df_reverse = df_reverse.withColumn(cls.STD_PRICE, read_last_r)
        df = df_reverse.orderBy(cls._df_order_cols, ascending=True)

        return df

    @classmethod
    def _get_remaining_columns(cls, df):
        """
        Append additional sales data columns to the standard price data.

        The columns in _remaining_cols are read from TABLE and left-joined on the
        transaction number, which is dropped from the result.
        """
        select_cols = [cls._TRANS_NUM] + cls._remaining_cols

        query = f"SELECT {', '.join(select_cols)} FROM {cls.TABLE}"
        df_remaining = spark.sql(query)

        full_df = df.join(df_remaining, cls._TRANS_NUM, how='left').drop(cls._TRANS_NUM)
        return full_df

In [0]:
class SeriesModel(ABC):
    """
    Abstract base for classes that model the time series data.

    Subclasses supply the data access methods; this class provides the shared
    feature helper, the model naming rule and the MLflow-backed prediction
    routine.
    """
    DATABASE = 'stf_db'

    #----- PUBLIC METHODS -----#
    @classmethod
    def append_time_features(cls, df):
        """
        Add weekofyear, month, quarter and year columns derived from swire_period_date.
        - Parameters:
        df: Spark Dataframe with a swire_period_date column
        """
        sp_date = df.swire_period_date
        df = df.withColumn('weekofyear',weekofyear(sp_date)).\
            withColumn('month',month(sp_date)).\
            withColumn('quarter',quarter(sp_date)).\
            withColumn('year',year(sp_date))

        return df

    @abstractmethod
    def get_filtered_series(self, state: str, keyBrand: str):
        pass

    @classmethod
    def get_model_name(cls, state: str, keyVal: str):
        """
        Build the model/run name for a state and key value.

        The two values are joined with '_' and every space and period is
        replaced by '_'.
        """
        model_name = f"{state}_{keyVal}".replace(" ", "_").replace(".", "_")
        return model_name

    @abstractmethod
    def get_swire_period_date(self, periodNumber: int, date_format :str='%m-%d-%Y'):
        pass

    @abstractmethod
    def predict(self, model_params_dict: dict, state: str, keyCol: str):
        pass

    @abstractmethod
    def quick_stats(self):
        pass

    #----- PRIVATE METHODS -----#
    @abstractmethod
    def _get_period_data(self):
        pass

    def _predict(self, model_params_dict: dict, state: str, keyCol: str, keyType: str, evalMetric: str='r2'):
        """
        This function runs a forecasting model for predicting physical volume.

        The model is looked up in MLflow: among the runs of every experiment
        whose name contains keyType, the run named after (state, keyCol) with the
        best evalMetric is loaded and applied to model_params_dict.
        - Parameters:
        model_params_dict: A dict containing independent variable parameters for the prediction
        state: The value of the state dimension to filter on
        keyCol: The value of the product dimension (brand name or package type) to filter on
        keyType: The type of model to search for (e.g.: 'keyBrand' or 'keyPackage')
        evalMetric: The metric used to evaluate the best model ('mae','rmse','mse','r2').
            For 'r2' the highest value wins; for the error metrics the lowest wins.
        - Returns -
        The output of the loaded model's predict().
        - Raises -
        IndexError: if evalMetric matches no known metric, or no run is named
            after (state, keyCol).
        """
        run_name_col = 'run_name'
        mlflow_metrics = ['mae','rmse','mse','r2']
        metric_cols = [f"metrics.{m}" for m in mlflow_metrics]
        run_cols = ['run_id', run_name_col] + metric_cols
        # Substring match kept as-is; the last match wins, which is what makes
        # 'mse' resolve to metrics.mse rather than metrics.rmse.
        eval_metric = [mc for mc in metric_cols if evalMetric in mc].pop()
        # r2 is a score (bigger is better); mae/rmse/mse are errors (smaller is better)
        higher_is_better = eval_metric == 'metrics.r2'

        experiments = mlflow.MlflowClient().search_experiments()
        experiment_ids = [exp.experiment_id for exp in experiments if keyType in exp.name]
        # NOTE(review): if no experiment matches, search_runs appears to fall back
        # to the active experiment rather than returning nothing - confirm.

        # Retrieve model run ID
        runs_df = mlflow.search_runs(experiment_ids, order_by=["tags.mlflow.runName ASC"]).\
                rename({'tags.mlflow.runName':run_name_col}, axis=1)[run_cols]

        # The previous code called get_model_name on a notebook global (`kbm`)
        model_name = self.get_model_name(state, keyCol)
        candidates = runs_df[runs_df[run_name_col] == model_name]
        if candidates.empty:
            raise IndexError(f"No MLflow run named '{model_name}' found in experiments matching '{keyType}'")

        # Best run first; NaN metrics sort last in either direction
        best_first = candidates.sort_values(eval_metric, ascending=not higher_is_better)
        run_id = best_first['run_id'].values[0]

        # Load model as a PyFuncModel
        model_uri = f'runs:/{run_id}/SparkML-linear-regression'
        model = mlflow.pyfunc.load_model(model_uri)

        # Predict on a Pandas DataFrame
        results = model.predict(pd.DataFrame(model_params_dict))

        return results

    @abstractmethod
    def _retrieve_data(self):
        pass

class KeyBrandModel(SeriesModel):
    """Series model keyed on state and key_brand."""

    def __init__(self):
        """Load the state/key_brand time series and cache its distinct brands, periods and states."""
        self._key_col = 'key_brand'
        self._period_dict = {}  # period_number -> swire_period_date, filled by _get_period_data

        self.series_data = self._retrieve_data()
        self.keys_list = self._get_unique_col_values(self._key_col)
        self.period_list = self._get_period_data()
        self.states_list = self._get_unique_col_values('state')

    #----- PUBLIC METHODS -----#
    def get_filtered_series(self, state: str, keyBrand: str):
        """Return the rows of the series for one state and one key brand."""
        matches_state = col('state') == state
        matches_brand = col(self._key_col) == keyBrand
        return self.series_data.where(matches_state & matches_brand)

    def get_swire_period_date(self, periodNumber: int, date_format :str='%m-%d-%Y'):
        """Return the date stored for periodNumber, formatted with date_format."""
        period_date = self._period_dict.get(periodNumber)
        return period_date.strftime(date_format)

    def predict(self, model_params_dict: dict, state: str, keyCol: str, evalMetric: str='r2'):
        """Predict with the model found for (state, keyCol) in the 'keyBrand_2' experiments."""
        return self._predict(model_params_dict, state, keyCol, 'keyBrand_2', evalMetric)

    def quick_stats(self):
        """Print the number of distinct brands, states and periods, and the first/last period dates."""
        first_period = self.period_list[0]
        last_period = self.period_list[-1]

        print(f"{'-'*3} Key Brand Quick Stats {'-'*3}\n{'-'*29}")

        print(f"• {len(self.keys_list)} unique {self._key_col.replace('_',' ')} values")
        print(f"• {len(self.states_list)} unique states")
        print(f"• {len(self.period_list)} unique periods")
        print(f"   🠆 Min: Week ending on {self.get_swire_period_date(first_period)}")
        print(f"   🠆 Max: Week ending on {self.get_swire_period_date(last_period)}")
        print("\n")
        return None

    #----- PRIVATE METHODS -----#
    def _get_period_data(self):
        """Record every distinct (period_number, swire_period_date) pair and return the period numbers."""
        number_col = 'period_number'
        date_col = 'swire_period_date'
        rows = self.series_data.select(number_col, date_col).distinct().orderBy(number_col).collect()

        for row in rows:
            record = row.asDict()
            self._period_dict.update({record.get(number_col): record.get(date_col)})

        return list(self._period_dict.keys())

    def _get_unique_col_values(self, colName:str):
        """Return the distinct values of colName, in ascending order."""
        rows = self.series_data.select(colName).distinct().orderBy(colName).collect()
        return [row.asDict().get(colName) for row in rows]

    def _retrieve_data(self):
        """Read the state/key-column time series table from DATABASE."""
        key_token = self._key_col.replace('_','')
        return spark.read.table(f"{self.DATABASE}.state_{key_token}_timeseries")

class KeyPackageModel(SeriesModel):
    """Series model keyed on state and key_package."""

    def __init__(self):
        """Load the state/key_package time series and cache its distinct packages, periods and states."""
        self._key_col = 'key_package'
        self._period_dict = {}  # period_number -> swire_period_date, filled by _get_period_data

        self.series_data = self._retrieve_data()
        self.keys_list = self._get_unique_col_values(self._key_col)
        self.period_list = self._get_period_data()
        self.states_list = self._get_unique_col_values('state')

    #----- PUBLIC METHODS -----#
    def get_filtered_series(self, state: str, keyPackage: str):
        """Return the rows of the series for one state and one key package."""
        matches_state = col('state') == state
        matches_package = col(self._key_col) == keyPackage
        return self.series_data.where(matches_state & matches_package)

    def get_swire_period_date(self, periodNumber: int, date_format :str='%m-%d-%Y'):
        """Return the date stored for periodNumber, formatted with date_format."""
        period_date = self._period_dict.get(periodNumber)
        return period_date.strftime(date_format)

    def predict(self, model_params_dict: dict, state: str, keyCol: str, evalMetric: str='r2'):
        """Predict with the model found for (state, keyCol) in the 'keyPackage' experiments."""
        return self._predict(model_params_dict, state, keyCol, 'keyPackage', evalMetric)

    def quick_stats(self):
        """Print the number of distinct packages, states and periods, and the first/last period dates."""
        first_period = self.period_list[0]
        last_period = self.period_list[-1]

        print(f"{'-'*3} Key Package Quick Stats {'-'*3}\n{'-'*31}")

        print(f"• {len(self.keys_list)} unique {self._key_col.replace('_',' ')} values")
        print(f"• {len(self.states_list)} unique states")
        print(f"• {len(self.period_list)} unique periods")
        print(f"   🠆 Min: Week ending on {self.get_swire_period_date(first_period)}")
        print(f"   🠆 Max: Week ending on {self.get_swire_period_date(last_period)}")
        print("\n")
        return None

    #----- PRIVATE METHODS -----#
    def _get_period_data(self):
        """Record every distinct (period_number, swire_period_date) pair and return the period numbers."""
        number_col = 'period_number'
        date_col = 'swire_period_date'
        rows = self.series_data.select(number_col, date_col).distinct().orderBy(number_col).collect()

        for row in rows:
            record = row.asDict()
            self._period_dict.update({record.get(number_col): record.get(date_col)})

        return list(self._period_dict.keys())

    def _get_unique_col_values(self, colName:str):
        """Return the distinct values of colName, in ascending order."""
        rows = self.series_data.select(colName).distinct().orderBy(colName).collect()
        return [row.asDict().get(colName) for row in rows]

    def _retrieve_data(self):
        """Read the state/key-column time series table from DATABASE."""
        key_token = self._key_col.replace('_','')
        return spark.read.table(f"{self.DATABASE}.state_{key_token}_timeseries")