In [1]:
from pyspark import SparkConf, SparkContext
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils
from pyspark.sql import SparkSession
from pyspark.mllib.classification import LabeledPoint
from pyspark.mllib.evaluation import BinaryClassificationMetrics


In [2]:
# Spark configuration for a single-machine run. The application name is set
# exactly once here. Previously the session builder also called
# `.appName("spark session example")`, which silently replaced "SUSY".
conf = (
    SparkConf()
    .setMaster("local[*]")
    .setAppName("SUSY")
    # NOTE(review): in local mode the executors run inside the driver JVM, so the
    # spark.executor.* settings below likely have no effect -- confirm.
    .set("spark.executor.memory", "4G")
    .set("spark.driver.memory", "20G")
    .set("spark.executor.cores", "4")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .set("spark.default.parallelism", "4")
)

# Create a SparkSession (which owns the SparkContext) from that configuration.
# getOrCreate() reuses a session that is already running. In that case JVM-level
# settings such as driver memory are not re-applied, so restart the kernel
# after changing them.
spark = SparkSession.builder.config(conf=conf).getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/11/30 11:48:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/11/30 11:48:43 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
# Load the SUSY CSV. It has no header and 19 numeric columns: _c0 is the class
# label and _c1.._c18 are the features.
SUSY_CSV_PATH = '/work/li.baol/data/SUSY.csv'

# Declare every column as DOUBLE. Without a schema Spark reads every CSV column
# as a string, and an explicit schema also avoids a separate inference pass.
# FAILFAST surfaces malformed rows as an error instead of turning them into nulls.
SUSY_SCHEMA = ', '.join(f'_c{i} DOUBLE' for i in range(19))
df = spark.read.csv(SUSY_CSV_PATH, schema=SUSY_SCHEMA, sep=',', header=False, mode='FAILFAST')
df.show()

                                                                                

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                 _c0|                 _c1|                 _c2|                 _c3|                 _c4|                 _c5|                 _c6|                 _c7|                 _c8|                 _c9|                _c10|                _c11|                _c12|                _c13|                _c14|                _c15|                _c16|                _c17|                _c18|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------

In [4]:
def row_to_labeled_point(row):
    """Convert one SUSY row into an MLlib LabeledPoint.

    The first column is used as the label. The remaining columns, in order,
    become the feature vector. Every value is coerced to float.
    """
    label, *features = row
    return LabeledPoint(float(label), [float(value) for value in features])


data = df.rdd.map(row_to_labeled_point)
data.take(2)

                                                                                

[LabeledPoint(0.0, [0.9728614687919617,0.6538545489311218,1.1762245893478394,1.1571564674377441,-1.7398731708526611,-0.8743090629577637,0.5677649974822998,-0.17500004172325134,0.8100607395172119,-0.2525521218776703,1.9218870401382446,0.8896374106407166,0.41077184677124023,1.1456208229064941,1.9326320886611938,0.994464099407196,1.3678154945373535,0.04071449860930443]),
 LabeledPoint(1.0, [1.6679730415344238,0.06419061869382858,-1.225171446800232,0.5061022043228149,-0.33893898129463196,1.6725428104400635,3.475464344024658,-1.219136357307434,0.012954562902450562,3.7751736640930176,1.0459771156311035,0.568051278591156,0.48192843794822693,0.0,0.4484102725982666,0.20535576343536377,1.3218934535980225,0.3775840103626251])]

In [5]:
# Hold out ~30% of the points for evaluation. A fixed seed makes the split, and
# therefore the trained tree and every metric below, reproducible across kernel
# restarts (given the same input file and partitioning).
SPLIT_SEED = 42
(trainingData, testData) = data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [6]:
# Spot-check a few held-out points. collect() would pull the entire test split
# (~30% of the dataset) into driver memory and dump it into the notebook output.
testData.take(5)

                                                                                

[LabeledPoint(0.0, [0.43781763315200806,-1.1198827028274536,-1.336822509765625,0.5023199319839478,-1.7175148725509644,1.0170669555664062,0.21561898291110992,-0.46120038628578186,0.3236706852912903,0.17362567782402039,0.411898136138916,0.3705247640609741,0.7982602119445801,0.6713690757751465,0.3859100937843323,0.5155220031738281,0.47910958528518677,0.029057899489998817]),
 LabeledPoint(0.0, [0.8401151299476624,0.9582490921020508,-0.5856583714485168,1.317929744720459,-1.0637295246124268,1.2462345361709595,0.2618653178215027,-0.027341315522789955,0.3305169641971588,0.18318995833396912,1.5572975873947144,0.5310215950012207,0.3025917410850525,0.42670151591300964,1.5643757581710815,0.3385521471500397,1.137455701828003,0.01104000024497509]),
 LabeledPoint(0.0, [0.6983363032341003,1.6894314289093018,-1.1346700191497803,0.9665942788124084,1.5033674240112305,0.8809488415718079,0.24257320165634155,-0.22865428030490875,0.364132285118103,0.10934973508119583,0.6685540080070496,0.48994091153144836,0.

In [None]:
# Fit a binary classification tree on the training split. The empty
# categoricalFeaturesInfo treats every feature as continuous.
TREE_MAX_DEPTH = 8
model = DecisionTree.trainClassifier(
    trainingData,
    numClasses=2,
    categoricalFeaturesInfo={},
    impurity='gini',
    maxDepth=TREE_MAX_DEPTH,
)

In [None]:
# Predict the class of every test point by calling predict on the whole RDD of
# feature vectors. The result is a lazy RDD, and the next cell zips it back
# together with testData's labels.
test_features = testData.map(lambda lp: lp.features)
predictions = model.predict(test_features)
predictions

In [None]:
# Pair each prediction with its true label. In PySpark the tree model cannot be
# called inside an RDD transformation, so predictions were computed on the whole
# RDD above. Both sides here are map()s over the same testData, so zip lines
# them up element by element.
test_labels = testData.map(lambda lp: lp.label)
predsAndLabels = predictions.zip(test_labels)

# NOTE(review): the "scores" are hard 0/1 class predictions, so the ROC curve has
# a single operating point. The reported AUC is then (TPR + TNR) / 2 rather than
# a threshold-swept AUC.
metrics = BinaryClassificationMetrics(predsAndLabels)
auc = metrics.areaUnderROC
print(f'AUC = {auc}')


In [None]:
# Spot-check a handful of (prediction, label) pairs. collect() would ship every
# pair in the test split back to the driver and flood the notebook output.
predsAndLabels.take(10)