# Training a simple Sklearn model

### Installation - RUN ONCE

In [24]:
# !pip install spark mlflow pyspark mlflow --user

In [1]:
# install java on machine
# !wget https://download.oracle.com/java/19/latest/jdk-19_linux-x64_bin.deb
# !sudo dpkg -i jdk-19_linux-x64_bin.deb
# !echo Y | sudo apt-get install -f

In [2]:
# # # Set Javahome
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/jdk-19"
os.environ["PATH"] = os.environ["JAVA_HOME"]+":"+os.environ["PATH"]

## Import libraries

In [3]:
# General
import random
import string

# Preprocessing
from pyspark import *
from pyspark.sql import SparkSession, SQLContext


# Training
import mlflow
from pyspark.ml.feature import FeatureHasher, VectorAssembler
from pyspark.ml.regression import GBTRegressor
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator



## Helpers

In [4]:
def generate_uuid(length: int = 8) -> str:
    return "".join(random.choices(string.ascii_lowercase + string.digits, k=length))

## Define constants

In [5]:
UUID = generate_uuid(4)

In [6]:
# Data ingestion
FILE_PATH = "rawfiles/"
SALES_TABLE_NAME = f"sales_data_set"
FEATURES_TABLE_NAME = f"features_data_set"
STORES_TABLE_NAME = f"stores_data_set"

FINAL_DATASET_SCHEMA_NAME = "databricks_vertex"
FINAL_DATASET_TABLE_NAME = f"kaggle_retail_dataset_{UUID}"
MODEL_NAME = f"Weekly_Sales_GBTR_model_{UUID}"

## Ingest data

In [7]:
### create spark context
warehouse_location='hive/'
spark = SparkSession.builder.master("local[1]").appName("MLFLOW-TRAINING").config("spark.jars.packages", "org.mlflow:mlflow-spark:1.11.0")\
.config("spark.sql.warehouse.dir", warehouse_location)\
.config("spark.sql.catalogImplementation", "hive").getOrCreate()
sqlContext = SQLContext(spark)

:: loading settings :: url = jar:file:/home/jupyter/.local/lib/python3.7/site-packages/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/jupyter/.ivy2/cache
The jars for the packages stored in: /home/jupyter/.ivy2/jars
org.mlflow#mlflow-spark added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-cc0f1c1f-aec2-4e82-94c5-0e7ee8fc8329;1.0
	confs: [default]
	found org.mlflow#mlflow-spark;1.11.0 in central
	found org.slf4j#slf4j-api;1.7.25 in central
:: resolution report :: resolve 191ms :: artifacts dl 5ms
	:: modules in use:
	org.mlflow#mlflow-spark;1.11.0 from central in [default]
	org.slf4j#slf4j-api;1.7.25 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   2   |   0   |   0   |   0   ||   2   |   0   |
	---------------------------------------------------------------------
:: re

23/02/14 15:27:40 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [8]:
sales_df = spark.read.csv(FILE_PATH+SALES_TABLE_NAME+".csv", header="true", inferSchema="true")
features_df = spark.read.csv(FILE_PATH+FEATURES_TABLE_NAME+".csv", header="true", inferSchema="true")
stores_df = spark.read.csv(FILE_PATH+STORES_TABLE_NAME+".csv", header="true", inferSchema="true")

                                                                                

In [9]:
query_schema = 'CREATE SCHEMA IF NOT EXISTS {}'.format(FINAL_DATASET_SCHEMA_NAME)
query_table = 'DROP TABLE IF EXISTS {0}.{1}'.format(FINAL_DATASET_TABLE_NAME,FINAL_DATASET_TABLE_NAME)
sqlContext.sql(query_schema)
sqlContext.sql(query_table)

23/02/14 15:27:48 WARN HiveConf: HiveConf of name hive.stats.jdbc.timeout does not exist
23/02/14 15:27:48 WARN HiveConf: HiveConf of name hive.stats.retries.wait does not exist
23/02/14 15:27:50 WARN ObjectStore: Version information not found in metastore. hive.metastore.schema.verification is not enabled so recording the schema version 2.3.0
23/02/14 15:27:50 WARN ObjectStore: setMetaStoreSchemaVersion called but recording version is disabled: version = 2.3.0, comment = Set by MetaStore jupyter@10.128.0.44
23/02/14 15:27:51 WARN ObjectStore: Failed to get database global_temp, returning NoSuchObjectException
23/02/14 15:27:51 WARN ObjectStore: Failed to get database kaggle_retail_dataset_goq6, returning NoSuchObjectException


DataFrame[]

In [10]:
sales_df.createOrReplaceTempView(SALES_TABLE_NAME+UUID)
features_df.createOrReplaceTempView(FEATURES_TABLE_NAME+UUID)
stores_df.createOrReplaceTempView(STORES_TABLE_NAME+UUID)

In [14]:
query = f"CREATE TABLE {FINAL_DATASET_SCHEMA_NAME}.{FINAL_DATASET_TABLE_NAME} AS SELECT sa.Store, sa.Dept, CAST(LEFT(sa.Date, 2) AS int) AS day, CAST(SUBSTRING(sa.Date, 4, 2) AS int) AS month, CAST(RIGHT(sa.Date, 4) AS int) AS year, sa.Weekly_Sales, sa.IsHoliday, f.Temperature, f.Fuel_Price, f.MarkDown1, f.MarkDown2, f.MarkDown3, f.MarkDown4, f.MarkDown5, f.CPI, f.Unemployment, st.Type, CAST(st.Size AS decimal) AS Size FROM {SALES_TABLE_NAME+UUID} AS sa INNER JOIN {FEATURES_TABLE_NAME+UUID} AS f on (f.Store=sa.Store AND f.Date=sa.Date) INNER JOIN {STORES_TABLE_NAME+UUID} as st ON st.Store=sa.Store limit 10000"
sqlContext.sql(query)

23/02/14 15:28:15 WARN ResolveSessionCatalog: A Hive serde table will be created as there is no table provider specified. You can set spark.sql.legacy.createHiveTableByDefault to false so that native data source table will be created instead.


AnalysisException: `databricks_vertex`.`kaggle_retail_dataset_goq6` already exists.

In [15]:
preprocessed_df = sqlContext.table(FINAL_DATASET_SCHEMA_NAME+"."+FINAL_DATASET_TABLE_NAME)
preprocessed_df = preprocessed_df.withColumn("Size_double", preprocessed_df['Size'].cast("double"))

## Training-Test splitting

In [18]:
(train, test) = preprocessed_df.randomSplit([0.8, 0.2])

## Model Training

In [16]:
#set tracking server
mlflow.set_tracking_uri("http://34.123.222.224:5000")

In [19]:
# enable MLFlow autologging
mlflow.spark.autolog()

# start MLFlow run
with mlflow.start_run(run_name='Weekly_Sales_GBTR'+UUID) as run:
    
    # add hasher
    hasher = FeatureHasher(inputCols=['IsHoliday', 'day', 'month', 'year', 'Temperature', 'Fuel_Price', 'Size_double', 'Store', 'Dept', 'MarkDown1', 'MarkDown2', 'MarkDown3', 'MarkDown4', 'MarkDown5', 'CPI', 'Unemployment', 'Type'],
                       outputCol="features")
    
    # add assembler
    feat_cols=["features"]
    assembler = VectorAssembler(inputCols=feat_cols, outputCol="features_dense")
    
    # model definition with parameters
    gbtr = GBTRegressor(featuresCol='features_dense', labelCol='Weekly_Sales', maxIter=1)
    
    # define pipeline
    pipeline = Pipeline(stages=[hasher, assembler, gbtr])
    
    # training pipeline
    gbtr_model = pipeline.fit(train)
    
    # predictions on test data
    gbtr_predictions = gbtr_model.transform(test)

    # log model
    mlflow.spark.log_model(gbtr_model, MODEL_NAME)
  
    # log the rmse
    rmse=RegressionEvaluator(labelCol="Weekly_Sales", predictionCol="prediction", metricName="rmse")
    rmse=rmse.evaluate(gbtr_predictions) 
    mlflow.log_metric("RMSE", rmse)
    
    # log mae
    mae=RegressionEvaluator(labelCol="Weekly_Sales", predictionCol="prediction", metricName="mae")
    mae=mae.evaluate(gbtr_predictions) 
    mlflow.log_metric("MAE", mae)
    
    # log r^2
    r2=RegressionEvaluator(labelCol="Weekly_Sales", predictionCol="prediction", metricName="r2")
    r2=r2.evaluate(gbtr_predictions)
    mlflow.log_metric("R-Square", r2)
    
mlflow.end_run()

23/02/14 15:28:37 WARN DAGScheduler: Broadcasting large task binary with size 14.7 MiB


[Stage 10:>                                                         (0 + 0) / 1]

23/02/14 15:28:39 WARN DAGScheduler: Broadcasting large task binary with size 14.7 MiB


[Stage 11:>                                                         (0 + 1) / 1]

23/02/14 15:28:41 WARN DAGScheduler: Broadcasting large task binary with size 17.3 MiB


[Stage 11:>                                                         (0 + 1) / 1]

23/02/14 15:31:08 WARN DAGScheduler: Broadcasting large task binary with size 1033.7 KiB


[Stage 13:>                                                         (0 + 1) / 1]

23/02/14 15:31:09 WARN DAGScheduler: Broadcasting large task binary with size 18.0 MiB
23/02/14 15:31:11 WARN MemoryStore: Not enough space to cache rdd_64_0 in memory! (computed 272.8 MiB so far)
23/02/14 15:31:11 WARN BlockManager: Persisting block rdd_64_0 to disk instead.
23/02/14 15:31:44 ERROR Executor: Exception in task 0.0 in stage 13.0 (TID 13)
java.lang.OutOfMemoryError: Java heap space
	at java.base/java.lang.reflect.Array.newArray(Native Method)
	at java.base/java.lang.reflect.Array.newInstance(Array.java:78)
	at java.base/java.io.ObjectInputStream.readArray(ObjectInputStream.java:2146)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1748)
	at java.base/java.io.ObjectInputStream$FieldValues.<init>(ObjectInputStream.java:2625)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2476)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2280)
	at java.base/java.io.ObjectInputStream.readObject0(O

ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.7/site-packages/pyspark/sql/utils.py", line 190, in deco
    return f(*a, **kw)
  File "/home/jupyter/.local/lib/python3.7/site-packages/py4j/protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)
py4j.protocol.Py4JJavaError: <unprintable Py4JJavaError object>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.7/site-packages/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.7/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(com

Py4JError: py4j does not exist in the JVM