In [1]:
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import pandas as pd
import numpy as np
import os
import pyspark
from pyspark.sql.functions import col
from pyspark.sql.functions import when


In [2]:
spark = pyspark.sql.SparkSession.builder.master(
    'local[2]').appName('H1B-3').getOrCreate()
training_data = spark.read.csv(
    '../DATA/training_downsampling.csv', header=True, inferSchema=True)
test_data = spark.read.csv(
    '../DATA/test_downsampling.csv', header=True, inferSchema=True)
training_data = training_data.drop(col('_c0'))
test_data = test_data.drop(col('_c0'))
cols = training_data.columns
cols.remove('EMPLOYER_NAME')
training_data = training_data.drop(col('EMPLOYER_NAME'))
training_data = training_data.withColumn(
    "CASE_DURATION", col("CASE_DURATION").cast("int"))
training_data = training_data.withColumn(
    "FULL_TIME_POSITION", col("FULL_TIME_POSITION").cast("int"))
test_data = test_data.drop(col('EMPLOYER_NAME'))
test_data = test_data.withColumn(
    "CASE_DURATION", col("CASE_DURATION").cast("int"))
test_data = test_data.withColumn(
    "FULL_TIME_POSITION", col("FULL_TIME_POSITION").cast("int"))

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/11/28 17:46:59 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
21/11/28 17:47:01 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
from pyspark.ml.feature import VectorAssembler
try:
  cols.remove('CASE_STATUS')
except:
  pass
assembler = VectorAssembler(
    inputCols=cols, outputCol="features", handleInvalid='skip')


In [4]:
len(cols)

420

In [5]:
training_data = assembler.transform(training_data)
test_data = assembler.transform(test_data)


In [6]:
training_data = training_data.withColumn("CASE_STATUS",
                                         when(training_data.CASE_STATUS ==
                                              'CERTIFIED', 1)
                                         .otherwise(when(training_data.CASE_STATUS == 'DENIED', 0)
                                                    .otherwise(when(training_data.CASE_STATUS == 'WITHDRAWN', 3)
                                                    .otherwise(2))))
test_data = test_data.withColumn("CASE_STATUS",
                                 when(test_data.CASE_STATUS ==
                                      'CERTIFIED', 1)
                                 .otherwise(when(test_data.CASE_STATUS == 'DENIED', 0)
                                            .otherwise(when(test_data.CASE_STATUS == 'WITHDRAWN', 3)
                                                       .otherwise(2))))


In [7]:
layers = [420, 128, 32, 4]
trainer = MultilayerPerceptronClassifier(
    maxIter=100, layers=layers, blockSize=128, seed=0)
trainer.setFeaturesCol("features")
trainer.setLabelCol('CASE_STATUS')
trainer.setMaxIter(1000)
model = trainer.fit(training_data)


21/11/28 17:47:52 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
21/11/28 17:48:03 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
21/11/28 17:48:03 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS
21/11/28 17:48:09 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
21/11/28 17:48:09 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
21/11/28 18:10:58 WARN BlockManager: Asked to remove block broadcast_693_piece0, which does not exist
21/11/28 18:10:58 WARN BlockManager: Asked to remove block broadcast_693, which does not exist
21/11/28 18:39:26 WARN BlockManager: Asked to remove block broadcast_1359_piece0, which does not exist
21/11/28 19:01:45 WARN BlockManager: Asked to remove block broa

In [8]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
predict_train = model.transform(training_data)
predict_test = model.transform(test_data)
evaluator = MulticlassClassificationEvaluator(
    labelCol='CASE_STATUS', probabilityCol='probability', metricName='f1')


In [9]:
evaluator.evaluate(predict_test)




0.7164745177639129

In [10]:
evaluator.evaluate(predict_train)



0.7502021850661849

In [11]:
test_prediction = np.array(predict_test.select('prediction').collect())
train_prediction = np.array(predict_train.select('prediction').collect())
test_label = np.array(predict_test.select('CASE_STATUS').collect())
train_label = np.array(predict_train.select('CASE_STATUS').collect())



In [12]:
from sklearn.metrics import f1_score
f1_score(train_label, train_prediction, average=None), f1_score(
    test_label, test_prediction, average=None)


(array([0.7964374 , 0.68176559, 0.81987425, 0.70273246]),
 array([0.15746584, 0.7619214 , 0.50339339, 0.19688664]))

In [14]:
filename = '../saved_models/nn_pyspark_model.sav'
model.save(filename)


Py4JJavaError: An error occurred while calling o395.save.
: java.io.IOException: Path ../saved_models/nn_pyspark_model.sav already exists. To overwrite it, please use write.overwrite().save(path) for Scala and use write().overwrite().save(path) for Java and Python.
	at org.apache.spark.ml.util.FileSystemOverwrite.handleOverwrite(ReadWrite.scala:683)
	at org.apache.spark.ml.util.MLWriter.save(ReadWrite.scala:167)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:748)


In [None]:
_