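# Predicting Manufacturing Quality with Random Forest
This notebook uses the Spark Machine Learning Library (MLlib), Spark's built-in machine learning library, to build a model that predicts product quality from manufacturing process data.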

In [2]:
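```python
# Load the source table; `spark` and `display()` are provided by the notebook environment
factory = spark.table("factory_csv")
display(factory)
```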

| ID | Quality | ProcessA-Pressure | ProcessA-Humidity | ProcessA-Vibration | ProcessB-Light | ProcessB-Skill | ProcessB-Temp | ProcessB-Rotation | ProcessC-Density | ProcessC-PH | ProcessC-skewness | ProcessC-Time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 7.0 | 0.27 | 0.36 | 20.7 | 0.045 | 45.0 | 170.0 | 1.001 | 3.0 | 0.45 | 8.8 |
| 2 | 0 | 6.3 | 0.3 | 0.34 | 1.6 | 0.049 | 14.0 | 132.0 | 0.994 | 3.3 | 0.49 | 9.5 |
| 3 | 0 | 8.1 | 0.28 | 0.4 | 6.9 | 0.05 | 30.0 | 97.0 | 0.9951 | 3.26 | 0.44 | 10.1 |
| 4 | 0 | 7.2 | 0.23 | 0.32 | 8.5 | 0.058 | 47.0 | 186.0 | 0.9956 | 3.19 | 0.4 | 9.9 |
| 5 | 0 | 7.2 | 0.23 | 0.32 | 8.5 | 0.058 | 47.0 | 186.0 | 0.9956 | 3.19 | 0.4 | 9.9 |
| 6 | 0 | 8.1 | 0.28 | 0.4 | 6.9 | 0.05 | 30.0 | 97.0 | 0.9951 | 3.26 | 0.44 | 10.1 |
| 7 | 0 | 6.2 | 0.32 | 0.16 | 7.0 | 0.045 | 30.0 | 136.0 | 0.9949 | 3.18 | 0.47 | 9.6 |
| 8 | 0 | 7.0 | 0.27 | 0.36 | 20.7 | 0.045 | 45.0 | 170.0 | 1.001 | 3.0 | 0.45 | 8.8 |
| 9 | 0 | 6.3 | 0.3 | 0.34 | 1.6 | 0.049 | 14.0 | 132.0 | 0.994 | 3.3 | 0.49 | 9.5 |
| 10 | 0 | 8.1 | 0.22 | 0.43 | 1.5 | 0.044 | 28.0 | 129.0 | 0.9938 | 3.22 | 0.45 | 11.0 |


In [3]:
```python
# Select the explanatory (feature) variables: every column except the ID and the label
nonFeatureCols = ['ID', 'Quality']
featureCols = [c for c in factory.columns if c not in nonFeatureCols]
print(featureCols)
```

In [4]:
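```python
from pyspark.ml.feature import VectorAssembler

# Combine the feature columns into a single vector column named "features"
assembler = VectorAssembler(inputCols=featureCols, outputCol="features")
dataset = assembler.transform(factory)
# Keep only the label and the feature vector, then split roughly 85% / 15% into train and test sets
train, test = dataset.select("Quality", "features").randomSplit([0.85, 0.15], seed=1)
```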
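
Conceptually, `VectorAssembler` concatenates the values of `inputCols`, in the listed order, into one numeric vector per row. The plain-Python sketch below mimics that for the first row of the `factory_csv` table shown earlier. `assemble_row` is a hypothetical helper for illustration only. The real transformer also accepts vector-valued inputs, has a `handleInvalid` option for nulls, and may store the result as a sparse vector.

```python
# Header and first data row of the factory_csv table shown above
columns = ["ID", "Quality", "ProcessA-Pressure", "ProcessA-Humidity", "ProcessA-Vibration",
           "ProcessB-Light", "ProcessB-Skill", "ProcessB-Temp", "ProcessB-Rotation",
           "ProcessC-Density", "ProcessC-PH", "ProcessC-skewness", "ProcessC-Time"]
row = [1, 0, 7.0, 0.27, 0.36, 20.7, 0.045, 45.0, 170.0, 1.001, 3.0, 0.45, 8.8]

# Same exclusion rule as the notebook: drop the ID and the label
non_feature_cols = ["ID", "Quality"]
feature_cols = [c for c in columns if c not in non_feature_cols]


def assemble_row(record, input_cols):
    """Hypothetical helper: gather input_cols from a dict-like record into one list of floats,
    preserving the order of input_cols (as VectorAssembler does)."""
    return [float(record[c]) for c in input_cols]


record = dict(zip(columns, row))
features = assemble_row(record, feature_cols)
print(len(features), features[:3])
```

Because the vector positions follow `inputCols`, anything read back from the model per feature (such as importances) must be matched against `featureCols` in that same order.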

In [5]:
```python
from pyspark.ml.classification import RandomForestClassifier

# Train a random forest of 10 trees, using "Quality" as the label
model = RandomForestClassifier(labelCol="Quality", featuresCol="features", numTrees=10).fit(train)
```
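
A fitted `RandomForestClassificationModel` exposes `featureImportances`, a vector whose positions follow the assembler's `inputCols` order. Pairing it with `featureCols` shows which process variables the forest relies on most. The sketch below shows only the ranking step, in plain Python. `rank_importances` is a hypothetical helper, and the importance values are made-up placeholders, not results from this model.

```python
def rank_importances(names, importances):
    """Hypothetical helper: return (name, importance) pairs sorted from most to least important."""
    if len(names) != len(importances):
        raise ValueError("names and importances must have the same length")
    return sorted(zip(names, importances), key=lambda pair: pair[1], reverse=True)


# Placeholder values for illustration only (not output from the trained model)
names = ["ProcessA-Pressure", "ProcessB-Temp", "ProcessC-Time"]
importances = [0.2, 0.1, 0.7]
for name, score in rank_importances(names, importances):
    print(f"{name}: {score:.3f}")
```

In the notebook, the equivalent call would be `rank_importances(featureCols, model.featureImportances.toArray())`.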

In [6]:
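```python
# Score the test set; this adds the rawPrediction, probability and prediction columns
predictions = model.transform(test)
display(predictions)
```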

| Quality | features | rawPrediction | probability | prediction |
|---|---|---|---|---|
| 0 | [4.6, 0.445, 0.0, 1.4, 0.053, 11.0, 178.0, 0.99426, 3.79, 0.55, 10.2] | [9.38650240145145, 0.6134975985485513] | [0.9386502401451449, 0.06134975985485512] | 0.0 |
| 0 | [4.8, 0.29, 0.23, 1.1, 0.044, 38.0, 180.0, 0.98924, 3.28, 0.34, 11.9] | [6.472928137876985, 3.527071862123014] | [0.6472928137876985, 0.3527071862123014] | 0.0 |
| 0 | [5.0, 0.2, 0.4, 1.9, 0.015, 20.0, 98.0, 0.9897, 3.37, 0.55, 12.05] | [5.779829980267732, 4.220170019732268] | [0.5779829980267732, 0.4220170019732268] | 0.0 |
| 0 | [5.0, 0.35, 0.25, 7.8, 0.031, 24.0, 116.0, 0.99241, 3.39, 0.4, 11.3] | [7.043476990922786, 2.9565230090772148] | [0.7043476990922786, 0.2956523009077215] | 0.0 |
| 0 | [5.1, 0.11, 0.32, 1.6, 0.028, 12.0, 90.0, 0.99008, 3.57, 0.52, 12.2] | [5.969762107163133, 4.030237892836867] | [0.5969762107163132, 0.4030237892836867] | 0.0 |
| 0 | [5.1, 0.165, 0.22, 5.7, 0.047, 42.0, 146.0, 0.9934, 3.18, 0.55, 9.9] | [8.731361984971123, 1.268638015028877] | [0.8731361984971123, 0.1268638015028877] | 0.0 |
| 0 | [5.1, 0.29, 0.28, 8.3, 0.026, 27.0, 107.0, 0.99308, 3.36, 0.37, 11.0] | [6.644176880700185, 3.3558231192998154] | [0.6644176880700184, 0.33558231192998156] | 0.0 |
| 0 | [5.1, 0.35, 0.26, 6.8, 0.034, 36.0, 120.0, 0.99188, 3.38, 0.4, 11.5] | [6.136981869583936, 3.8630181304160645] | [0.6136981869583936, 0.38630181304160643] | 0.0 |
| 0 | [5.3, 0.36, 0.27, 6.3, 0.028, 40.0, 132.0, 0.99186, 3.37, 0.4, 11.6] | [5.466253944367265, 4.533746055632735] | [0.5466253944367265, 0.45337460556327347] | 0.0 |
| 0 | [5.4, 0.15, 0.32, 2.5, 0.037, 10.0, 51.0, 0.98878, 3.04, 0.58, 12.6] | [5.075680311991408, 4.924319688008592] | [0.5075680311991408, 0.4924319688008592] | 0.0 |
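
The three output columns in each row are linked. For a Spark random forest, `rawPrediction` is the sum, over the 10 trees, of each tree's class-probability vector at the leaf the row lands in, so its entries add up to `numTrees`. `probability` is that vector divided by its sum, and `prediction` is the index of the largest entry when no custom `thresholds` are set. The plain-Python sketch below reproduces this for the first test row above. It illustrates the relationship and is not Spark's own code.

```python
# rawPrediction of the first test row shown above: [class 0, class 1]
raw = [9.38650240145145, 0.6134975985485513]

total = sum(raw)                        # equals numTrees (10), up to floating-point error
probability = [v / total for v in raw]  # normalize the summed votes into class probabilities
# argmax over classes, returned as a double label like Spark's prediction column
prediction = float(max(range(len(probability)), key=probability.__getitem__))

print(total, probability, prediction)
```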


In [7]:
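```python
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Area under the ROC curve (the evaluator's default metric), computed on the test predictions
areaUnderROC = BinaryClassificationEvaluator(labelCol="Quality").evaluate(predictions)
```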
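
`BinaryClassificationEvaluator` scores each row by the class-1 entry of `rawPrediction`. By default it returns the area under the ROC curve. That area equals the probability that a randomly chosen positive row scores higher than a randomly chosen negative row, with ties counted as one half. The sketch below computes AUC from this pairwise definition in plain Python.

A few caveats apply:
- `roc_auc` is a hypothetical helper.
- The scores and labels are toy values, not taken from this dataset.
- The pairwise loop is only practical for small lists. Spark instead builds the ROC curve from score thresholds and integrates it with the trapezoidal rule.

```python
def roc_auc(scores, labels):
    """Hypothetical helper: AUC as P(score of a positive > score of a negative), ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy values: in the notebook, scores would be rawPrediction[1] per row and labels the Quality column
scores = [0.61, 3.53, 4.22, 2.96, 4.92]
labels = [0, 1, 1, 0, 0]
print(roc_auc(scores, labels))
```

An AUC of 0.5 means the scores separate the classes no better than chance, and 1.0 means every positive outranks every negative.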

In [8]:
```python
print("areaUnderROC: ", areaUnderROC)
```