In [None]:
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline

from DataManipulation import DataManipulation
from Estimators.XGBoost import XGBoost
from Logging import Logging
from Transformers.FilterDepartment import FilterDepartment
from Transformers.ImputePrice import ImputePrice
from Transformers.LagFeature import LagFeature
from Transformers.LogTransformation import LogTransformation
from Transformers.MonthlyAggregate import MonthlyAggregate
from Transformers.NegativeSales import NegativeSales

import pandas as pd
import findspark

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
def initialize_session(name, master="local[*]", bind_address="localhost", ui_port="4050"):
    """Return a SparkSession configured for this notebook, creating it if needed.

    Args:
        name: Application name passed to ``appName``.
        master: Spark master URL. Defaults to ``"local[*]"``.
        bind_address: Value for ``spark.driver.bindAddress``. Defaults to
            ``"localhost"``.
        ui_port: Value for ``spark.ui.port``. Accepts an int or a str and is
            passed to Spark as a string. Defaults to ``"4050"``.

    Returns:
        Whatever ``SparkSession.builder...getOrCreate()`` returns: the existing
        session if one is already active, otherwise a new one.
    """
    builder = (
        SparkSession.builder
        .master(master)
        .appName(name)
        .config("spark.driver.bindAddress", bind_address)
        .config("spark.ui.port", str(ui_port))
    )
    return builder.getOrCreate()

In [None]:
# Bootstrap cell: create the Spark session, a logger, and load the raw data.
# NOTE(review): pyspark is already imported in the first cell, so this
# findspark.init() call runs after the imports it is usually meant to enable;
# confirm it is still required.
findspark.init()
spark = initialize_session("Assignment")
# Presumably turns on Arrow-based Spark <-> pandas conversion.
# NOTE(review): in Spark 3.x this key is superseded by
# "spark.sql.execution.arrow.pyspark.enabled"; verify against the Spark version in use.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
# The logger is obtained after the session exists, so the message below is
# emitted once initialization has already happened.
log = Logging.getLogger()
log.info("Initializing session")

# Project helper that supplies the input DataFrame and, later, the train/test split.
data = DataManipulation()
df = data.get_data()

In [None]:
# Small check: build a three-row pandas frame and convert it to a Spark
# DataFrame through the active session.
# NOTE(review): neither frame is referenced by the later cells visible here;
# this looks like a scratch cell; confirm whether it is still needed.
pandas_df = pd.DataFrame({"Letters":["X", "Y", "Z"]})
spark_df = spark.createDataFrame(pandas_df)

# Add the spark data frame to the catalog
spark_df.createOrReplaceTempView('spark_df')

# Print the converted frame to the cell output.
spark_df.show()

In [None]:
# Disabled alternative: restrict the data to a single store before the pipeline.
# df = data.filter_store(df, "WI_1")
# First pipeline stage (project transformer).
# NOTE(review): "FOODS_1" is passed as inputCol although it looks like a dept_id
# value rather than a column name; presumably rows are kept where
# dept_id == "FOODS_1"; verify against FilterDepartment.
filterDepartment = FilterDepartment(inputCol="FOODS_1", filterCol="dept_id")

In [None]:
# Project-defined preprocessing stages, declared in the order the pipeline uses them.
imputePrice = ImputePrice()
negativeSales = NegativeSales(column="sales")

# Presumably groups on `columns` and applies the named aggregate to each key of
# `expressions`; verify against MonthlyAggregate.
aggregate = MonthlyAggregate(
    columns=["store_id", "dept_id", "year", "month"],
    expressions={
        "sales": "sum",
        "sell_price": "avg",
        "event_name_1": "count",
        "event_name_2": "count",
        "snap_WI": "sum",
    },
)

logTransformation = LogTransformation(inputCols=["sales"])

# Lags 1 through 11 of the target (the range end is exclusive).
lagFeatures = LagFeature(
    partitionBy=["store_id", "dept_id"],
    orderBy=["year", "month"],
    lags=list(range(1, 12)),
    target="sales",
)

# Spark ML indexers writing the store and year index columns.
storeIndexer = StringIndexer(inputCol="store_id", outputCol="store_id_index")
yearIndexer = StringIndexer(inputCol="year", outputCol="year_index")

In [None]:
log.info("Initiating pipeline")

# Stages execute in list order: the project transformers first, then the two indexers.
pipeline_stages = [
    filterDepartment,
    imputePrice,
    negativeSales,
    aggregate,
    logTransformation,
    lagFeatures,
    storeIndexer,
    yearIndexer,
]
preprocessing = Pipeline(stages=pipeline_stages)

# Fit on the full input frame and transform that same frame.
transformed = preprocessing.fit(df).transform(df)

In [None]:
# Split the pipeline output into training and evaluation frames using the
# project helper; the split rule lives in DataManipulation.train_test_split.
train, test = data.train_test_split(transformed)

In [None]:
# Model features: the indexed/categorical and price columns, followed by lag_1 .. lag_11.
base_features = [
    "store_id_index",
    "month",
    "year_index",
    "event_name_1",
    "event_name_2",
    "sell_price",
]
lag_features = ["lag_{}".format(lag) for lag in range(1, 12)]
inputColumns = base_features + lag_features

# Fit the project XGBoost estimator on the training split with "sales" as the label.
xgbModel = XGBoost(inputCols=inputColumns, labelCol="sales").fit(train)

Training XGBoost
score:                                                                                                                 
4.416647634808213                                                                                                      
score:                                                                                                                 
4.61977159898341                                                                                                       
  2%|▉                                            | 2/100 [18:45<15:19:37, 563.03s/trial, best loss: 4.416647634808213]

In [16]:
# Score the evaluation split with the fitted model.
pred = xgbModel.transform(test)
# show() writes the first 10 rows to the cell output itself and returns None,
# so it is called directly; wrapping it in print() only appended a stray "None".
pred.show(10)

    prediction    actual
0     3.884693  3.766264
1     3.907605  3.991448
2     3.933316  3.898725
3     3.886194  4.073572
4     3.973794  3.770557
5     4.027556  3.922154
6     3.903761  3.970393
7     3.940236  4.006380
8     4.057255  4.146748
9     3.911772  4.128529
10    3.979000  3.827434
11    4.011597  3.823279
12    4.051314  3.981139
13    4.065004  3.773933
14    3.913321  4.146686
15    3.688709  3.850830
16    3.704654  3.932879
17    3.710896  4.117006
18    3.803989  3.881670
19    3.919174  3.850830
20    3.754980  4.121429
21    3.990922  3.908860
22    3.899374  3.791269
23    3.868634  3.890533
24    3.958189  3.940865
25    3.775538  4.104282
26    3.868951  3.920019
27    3.978200  4.025019
28    3.821599  3.822756
29    3.950344  3.860518
30    3.830292  3.796921
31    3.803753  3.931814
32    3.943211  3.942157
33    3.937515  4.023705
34    3.876850  3.913920
35    3.848201  3.876564
36    3.688139  3.970440
37    3.883985  3.875061
38    3.834059  3.926651
