# Example Training and Exporting with MLeap

In [3]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, Tokenizer, HashingTF
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
import pyspark, os

In [4]:
# PYSPARK_SUBMIT_ARGS is assigned before the session is built so that the
# MLeap Spark package coordinate is already in place when Spark starts.
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages ml.combust.mleap:mleap-spark_2.11:0.15.0 pyspark-shell'
MAX_MEMORY = "10g"

# The same memory limit is applied to both the executors and the driver.
spark = (
    pyspark.sql.SparkSession.builder
    .appName('mleap-example')
    .config("spark.executor.memory", MAX_MEMORY)
    .config("spark.driver.memory", MAX_MEMORY)
    .getOrCreate()
)

# Read the CSV (first row is the header) and keep only the two columns
# the pipeline uses: the raw text and its topic.
df = (
    spark.read.format("csv")
    .option("header", "true")
    .load("./data/20news_small_1.csv")
    .select("text", "topic")
)
df.cache()
display(df)

DataFrame[text: string, topic: string]

In [5]:
# Print the column names, types and nullability of the loaded DataFrame.
df.printSchema()

root
 |-- text: string (nullable = true)
 |-- topic: string (nullable = true)



In [6]:
# Index the `topic` strings into a numeric `label` column.
labelIndexer = StringIndexer(
    inputCol="topic",
    outputCol="label",
    handleInvalid="keep",
)

In [7]:
# Feature extraction: `text` is tokenized into `words`, and `words` is
# hashed into the `features` vector consumed by the classifier.
tokenizer = Tokenizer(
    inputCol="text",
    outputCol="words",
)
hashingTF = HashingTF(
    inputCol="words",
    outputCol="features",
)

In [8]:
# Decision tree classifier with default parameters, placed at the end of a
# pipeline that first indexes labels, tokenizes and hashes the text.
dt = DecisionTreeClassifier()
pipeline = Pipeline(
    stages=[
        labelIndexer,
        tokenizer,
        hashingTF,
        dt,
    ],
)

In [9]:
# Candidate settings to compare: two sizes for the hashed feature vector.
paramGrid = (
    ParamGridBuilder()
    .addGrid(hashingTF.numFeatures, [1000, 2000])
    .build()
)

# Cross-validate the whole pipeline over the grid, scoring each candidate
# with the default multiclass evaluator.
# The seed is pinned so the random fold assignment (and therefore the model
# that gets selected) is the same after a kernel restart; without it the
# split can differ from run to run.
cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=MulticlassClassificationEvaluator(),
    seed=42,
)

In [10]:
# Run the cross-validated fit of the pipeline over `df`.
cvModel = cv.fit(df)

In [11]:
# Keep the best model found by cross-validation; it is used for scoring and export.
model = cvModel.bestModel

In [12]:
# Score the input DataFrame with the selected model. The transformed frame is
# displayed here and is also handed to the MLeap export further down.
sparkTransformed = model.transform(df)
display(sparkTransformed)

DataFrame[text: string, topic: string, label: double, words: array<string>, features: vector, rawPrediction: vector, probability: vector, prediction: double]

# Export Trained Model with MLeap

In [13]:
import mleap.pyspark
# SimpleSparkSerializer is not referenced by name in this cell; importing the
# module is what makes `serializeToBundle` available on Spark transformers
# (see the MLeap PySpark documentation).
from mleap.pyspark.spark_support import SimpleSparkSerializer

# Build the export location from the current user's home directory instead of
# a hard-coded /home/jovyan path, so the cell also works outside the Jupyter
# docker image (where it resolves to the same location as before).
EXPORT_DIR = os.path.join(os.path.expanduser("~"), "mleap_python_model_export")
BUNDLE_PATH = os.path.join(EXPORT_DIR, "20news_pipeline-json.zip")

# The target directory is not created by the export call, so make sure it exists.
os.makedirs(EXPORT_DIR, exist_ok=True)

# Start from a fresh archive so a re-run does not write into a bundle left
# over from a previous run.
if os.path.exists(BUNDLE_PATH):
    os.remove(BUNDLE_PATH)

# The transformed DataFrame is passed alongside the model, as in the MLeap examples.
model.serializeToBundle("jar:file:" + BUNDLE_PATH, sparkTransformed)