# 4. Train a Machine Learning model

set up the Spark context

In [None]:
import os

# Submit arguments picked up by PySpark when it launches: pull in the MemSQL
# Spark connector package, then start the pyspark shell.
# Connector artifact: https://search.maven.org/artifact/com.memsql/memsql-spark-connector_2.11
args = (
    '--packages '
    '"com.memsql:memsql-spark-connector_2.11:3.0.0-spark-2.4.4" '
    'pyspark-shell'
)
os.environ['PYSPARK_SUBMIT_ARGS'] = args

In [None]:
!pip install findspark
import findspark
findspark.init()

import mlflow.spark
import pyspark
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
sc = SparkContext(appName="MLTraining")

spark = SparkSession(sc)

query MemSQL data

In [None]:
# Columns pulled from the `lineitem` table.
lineitem_columns = [
    'l_partkey',
    'l_suppkey',
    'l_quantity',
    'l_discount',
    'l_tax',
    'l_extendedprice',
]

# Read the table through the "memsql" data source and cap the result at one
# million rows.
data = (
    spark.read.format("memsql")
    .option("ddlEndpoint", "memsql")
    .option("user", "root")
    .option("password", "")
    .option("database", "tpch")
    .load("lineitem")
    .select(*lineitem_columns)
    .limit(1000000)
)

assemble the feature vector (the columns used as predictors in the model)

In [None]:
from pyspark.ml.feature import VectorAssembler

# Predictor columns that get packed into the single "features" vector column.
feature_columns = [
    'l_partkey',
    'l_suppkey',
    'l_quantity',
    'l_discount',
    'l_tax',
]

assembler = VectorAssembler(
    outputCol="features",
    inputCols=feature_columns,
)

# Same rows as `data`, with the assembled "features" column appended.
data_2 = assembler.transform(data)

separate data into 2 parts: training data and validation data

In [None]:
# Split roughly 70% / 30% into training and validation sets. The fixed seed
# makes the random assignment repeatable for the same input rows, so the
# evaluation metric can be compared from one run to the next.
train, test = data_2.randomSplit([0.7, 0.3], seed=42)

set up the linear regression algorithm to predict the `l_extendedprice` column

In [None]:
from pyspark.ml.regression import LinearRegression

# Configure the estimator only; nothing is trained until fit() is called.
algo = LinearRegression(
    labelCol="l_extendedprice",
    featuresCol="features",
)

train the model, and capture the time it takes

In [None]:
%%time
# Fit the linear regression estimator on the training split. The %%time cell
# magic (which must stay on the first line of the cell) reports how long the
# fit takes.
model = algo.fit(train)

validate the model against the test data

In [None]:
%%time
# Evaluate the fitted model on the held-out test split; the returned summary
# object carries the metrics (its r2 attribute is read afterwards).
evaluation_summary = model.evaluate(test)

how well did our model do?

In [None]:
# R^2 (coefficient of determination) of the model on the test split, taken
# from the evaluation summary.
r_squared = evaluation_summary.r2
# Bare expression as the cell's last line so the notebook displays the value.
r_squared

save the trained model

In [None]:
# Persist the fitted model with MLflow's Spark flavor under the relative path
# "spark-model".
# NOTE(review): save_model presumably refuses to overwrite an existing path,
# so re-running this cell may fail until the directory is removed — confirm.
mlflow.spark.save_model(model, "spark-model")

take our model out for a quick spin

In [None]:
%%time
# Score the test split with the fitted model; the resulting DataFrame's
# `prediction` column is displayed in the following cell.
predictions = model.transform(test)

In [None]:
# Display the predictor columns next to the model's predicted value.
shown_columns = [
    'l_partkey',
    'l_suppkey',
    'l_quantity',
    'l_discount',
    'l_tax',
    'prediction',
]
predictions.select(*shown_columns).show()