In [1]:
import numpy as np
import pandas as pd
import pyspark
import os
import sys
import pyspark.sql.functions as F
import matplotlib.pyplot as plt
plt.style.use(style='seaborn')

from pyspark.ml.feature import VectorAssembler
from pyspark.sql import types as T
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Regression').getOrCreate()
spark

In [2]:
df = spark.read.csv('train.csv', header = True, inferSchema = True)
df.show()
df.printSchema()

+------------------+----------------+--------------------+-----------------+---------------------+-------------------+-----------------------+-----------------+---------------------+---------------+-------------------+--------+------------+-----------+-------------+-----------+---------------+---------+-------------+-----------+-----------+------------------+----------------------+-------------------+-----------------------+---------------------+-------------------------+-------------------+-----------------------+-----------------+---------------------+------------+----------------+-------------+-----------------+---------------+-------------------+-------------+-----------------+-----------+---------------+---------------------+-------------------------+----------------------+--------------------------+------------------------+----------------------------+----------------------+--------------------------+--------------------+------------------------+---------------+------------------

In [3]:
dfPySpark = df.select('mean_atomic_mass','mean_fie','mean_atomic_radius','mean_Density','mean_ElectronAffinity','mean_FusionHeat','mean_ThermalConductivity','mean_Valence','critical_temp')
dfPySpark.show()

+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+
|mean_atomic_mass|mean_fie|mean_atomic_radius|mean_Density|mean_ElectronAffinity|mean_FusionHeat|mean_ThermalConductivity|mean_Valence|critical_temp|
+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+
|      88.9444675| 775.425|            160.25|  4654.35725|              81.8375|         6.9055|              107.756645|        2.25|         29.0|
|       92.729214|  766.44|             161.2|   5821.4858|                90.89|         7.7844|              172.205316|         2.0|         26.0|
|      88.9444675| 775.425|            160.25|  4654.35725|              81.8375|         6.9055|              107.756645|        2.25|         19.0|
|      88.9444675| 775.425|            160.25|  4654.35725|              81.8375|         6.9055|   

In [4]:
from pyspark.sql.functions import col, count, isnan, when
dfPySpark.select([count(when(col(c).isNull(), c)).alias(c) for c in dfPySpark.columns]).show()

+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+
|mean_atomic_mass|mean_fie|mean_atomic_radius|mean_Density|mean_ElectronAffinity|mean_FusionHeat|mean_ThermalConductivity|mean_Valence|critical_temp|
+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+
|               0|       0|                 0|           0|                    0|              0|                       0|           0|            0|
+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+



In [5]:
features = dfPySpark.drop('critical_temp')

In [6]:
myAssembler = VectorAssembler(
    inputCols = features.columns,
    outputCol = 'features'
)

output = myAssembler.transform(dfPySpark).select('features','critical_temp')
output.show(truncate = False)

+---------------------------------------------------------------------+-------------+
|features                                                             |critical_temp|
+---------------------------------------------------------------------+-------------+
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|29.0         |
|[92.729214,766.44,161.2,5821.4858,90.89,7.7844,172.205316,2.0]       |26.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|19.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|22.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|23.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|23.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|11.0         |
|[76.5177175,787.05,151.75,4434.35725,79.6075,6.9055,112.006645,2.25] |33.0         |
|[76.5177175,787.05,151.75,4434.35725,79.6075,6.9055,1

In [7]:
(test, train) = output.randomSplit([0.33, 0.66], seed = 9)

# Print the dataset
print(f'Train data set size: {train.count()} records')
print(f'Test data set size: {test.count()} records')

Train data set size: 14121 records
Test data set size: 7142 records


In [8]:
train.printSchema()
test.printSchema()

root
 |-- features: vector (nullable = true)
 |-- critical_temp: double (nullable = true)

root
 |-- features: vector (nullable = true)
 |-- critical_temp: double (nullable = true)



#### Linear Regression

In [9]:
from pyspark.ml.regression import LinearRegression
linReg = LinearRegression(featuresCol = 'features', labelCol = 'critical_temp')
linearModel = linReg.fit(train)

In [10]:
print('Coefficients: ' + str(linearModel.coefficients))
print('\nIntercept: ' + str(linearModel.intercept))

Coefficients: [0.3946753361366382,0.1723896244790861,0.6017383814211767,-0.006415579069635275,-0.22904545855413672,-0.19465597109011842,0.3320414751183857,-3.102585487798238]

Intercept: -188.245982001181


In [11]:
trainSummary = linearModel.summary
print('RMSE: %f' % trainSummary.rootMeanSquaredError)
print('\nr2: %f' % trainSummary.r2)

RMSE: 24.241180

r2: 0.496494


In [12]:
from pyspark.sql.functions import abs
predictions = linearModel.transform(test)
x = ((predictions['critical_temp']-predictions['prediction'])/predictions['critical_temp']) * 100
predictions = predictions.withColumn('Accuracy', abs(x))
predictions.select('prediction', 'critical_temp', 'Accuracy', 'features').show(truncate = False)

+------------------+-------------+------------------+-----------------------------------------------------------------+
|prediction        |critical_temp|Accuracy          |features                                                         |
+------------------+-------------+------------------+-----------------------------------------------------------------+
|32.89874666221135 |8.5          |287.04407837895707|[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]             |
|32.89874666221135 |12.8         |157.02145829852617|[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]             |
|32.89874666221135 |19.3         |70.45982726534379 |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]             |
|32.89874666221135 |19.5         |68.7115213446736  |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]             |
|51.53837900297978 |7.2          |615.8108194858303 |[35.61323386,879.6,131.2,2296.5286,134.214,5.5324,148.007096,2.4]|
|51.53837900297978 |19.7         |161.61

In [13]:
from pyspark.ml.evaluation import RegressionEvaluator
myPredEvaluator = RegressionEvaluator (
    predictionCol = 'prediction',
    labelCol = 'critical_temp',
    metricName = 'r2'
)

print('R Squared (R2) on test data = %g' % myPredEvaluator.evaluate(predictions))

R Squared (R2) on test data = 0.502938


In [14]:
def adjR2(x):
    r2 = trainSummary.r2
    n = dfPySpark.count()
    p = len(dfPySpark.columns)
    adjustedR2 = 1 - (1 - r2) * (n - 1) / (n-p-1)
    return adjustedR2

In [15]:
adjR2(train)

0.49628050290463377

In [16]:
adjR2(test)

0.49628050290463377

In [17]:
linReg = LinearRegression(
    featuresCol = 'features', 
    labelCol = 'critical_temp',
    maxIter = 50,
    regParam = 0.12,
    elasticNetParam = 0.2)
linearModel = linReg.fit(train)

In [18]:
linearModel.summary.rootMeanSquaredError

24.244527107448214

In [19]:
colName = features.columns
featuresRDD = features.rdd
featuresRDD.collect()

[Row(mean_atomic_mass=88.9444675, mean_fie=775.425, mean_atomic_radius=160.25, mean_Density=4654.35725, mean_ElectronAffinity=81.8375, mean_FusionHeat=6.9055, mean_ThermalConductivity=107.756645, mean_Valence=2.25),
 Row(mean_atomic_mass=92.729214, mean_fie=766.44, mean_atomic_radius=161.2, mean_Density=5821.4858, mean_ElectronAffinity=90.89, mean_FusionHeat=7.7844, mean_ThermalConductivity=172.205316, mean_Valence=2.0),
 Row(mean_atomic_mass=88.9444675, mean_fie=775.425, mean_atomic_radius=160.25, mean_Density=4654.35725, mean_ElectronAffinity=81.8375, mean_FusionHeat=6.9055, mean_ThermalConductivity=107.756645, mean_Valence=2.25),
 Row(mean_atomic_mass=88.9444675, mean_fie=775.425, mean_atomic_radius=160.25, mean_Density=4654.35725, mean_ElectronAffinity=81.8375, mean_FusionHeat=6.9055, mean_ThermalConductivity=107.756645, mean_Valence=2.25),
 Row(mean_atomic_mass=88.9444675, mean_fie=775.425, mean_atomic_radius=160.25, mean_Density=4654.35725, mean_ElectronAffinity=81.8375, mean_Fus

In [20]:
featuresRDD = features.rdd.map(lambda row: row[0:])
featuresRDD.collect()

[(88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (92.729214, 766.44, 161.2, 5821.4858, 90.89, 7.7844, 172.205316, 2.0),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (76.5177175, 787.05, 151.75, 4434.35725, 79.6075, 6.9055, 112.006645, 2.25),
 (76.5177175, 787.05, 151.75, 4434.35725, 79.6075, 6.9055, 112.006645, 2.25),
 (76.5177175, 787.05, 151.75, 4434.35725, 79.6075, 6.9055, 112.006645, 2.25),
 (76.5177175, 787.05, 151.75, 4434.35725, 79.6075, 6.9055, 112.006645, 2.25),
 (111.273574, 821.54, 162.4, 6430.2858, 65.77, 5.9824, 87.865316, 2.2),
 (92.729214, 766.44, 161.2, 5821.4858, 90.89, 7.7844, 172.205316, 2.0)

In [21]:
from pyspark.mllib.feature import StandardScaler
scaler1 = StandardScaler().fit(featuresRDD)
scaledFeatures = scaler1.transform(featuresRDD)

In [22]:
for data in scaledFeatures.collect():
    print(data)

[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653675,0.6110960301336311,2.797603368646684,2.15391089493661]
[3.1246684262242628,8.760446222959084,8.001076718806639,2.0449332921173045,3.281003614797889,0.6888734974979709,4.470834927539427,1.9145874621658754]
[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653675,0.6110960301336311,2.797603368646684,2.15391089493661]
[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653675,0.6110960301336311,2.797603368646684,2.15391089493661]
[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653675,0.6110960301336311,2.797603368646684,2.15391089493661]
[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653675,0.6110960301336311,2.797603368646684,2.15391089493661]
[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653675,0.6110960301336311,2.79760336

[2.57592942392758,9.075458354247576,8.164870472975757,1.4879441106223177,2.788167198789217,0.7239260118033697,2.8235656039097314,2.15391089493661]
[3.010368077032484,8.464179370890951,8.983839243821349,1.6568464052384946,2.602349549904058,0.7012626719558246,2.326354294811708,2.488963700815638]
[3.010368077032484,8.464179370890951,8.983839243821349,1.6568464052384946,2.602349549904058,0.7012626719558246,2.326354294811708,2.488963700815638]
[3.010368077032484,8.464179370890951,8.983839243821349,1.6568464052384946,2.602349549904058,0.7012626719558246,2.326354294811708,2.488963700815638]
[3.010368077032484,8.464179370890951,8.983839243821349,1.6568464052384946,2.602349549904058,0.7012626719558246,2.326354294811708,2.488963700815638]
[3.235990054604913,8.977159938823737,8.338591121336943,1.89138164644809,2.9722701632019057,0.911975981252934,2.8105844862782075,2.15391089493661]
[3.5384165815743507,8.385540638551879,9.1228157625103,1.9795964338991123,2.7496319214342084,0.851702647515476,2.315

[2.5306305300565657,8.58328047423007,7.8819539885018255,1.60513895615399,2.450013371507677,0.6780772168989841,2.40943344765346,2.2975049545990505]
[2.5306305300565657,8.58328047423007,7.8819539885018255,1.60513895615399,2.450013371507677,0.6780772168989841,2.40943344765346,2.2975049545990505]
[2.828848002335647,9.750974209078851,7.435243749858775,2.270310325626939,2.427487832332477,0.7807303766926287,2.7900398166097355,2.6804224470322255]
[2.805073633220017,8.914066072179597,7.435243749858775,1.6103378029369426,2.6388091565818645,0.721439327501472,3.338881470070558,2.488963700815638]
[2.805073633220017,8.914066072179597,7.435243749858775,1.6103378029369426,2.6388091565818645,0.721439327501472,3.338881470070558,2.488963700815638]
[3.0829879455725298,8.168484021238074,8.396498005778488,1.6686324603221707,1.8587781326464736,0.6203731619303117,3.024715420847241,2.074136417665463]
[3.0829879455725298,8.168484021238074,8.396498005778488,1.6686324603221707,1.8587781326464736,0.620373161930311

[3.5012207059896703,8.790850151450645,7.469987879531011,2.005854954351549,2.4192092154909806,0.6715522517140127,2.933847597426575,2.7123322377492256]
[3.5012207059896703,8.790850151450645,7.469987879531011,2.005854954351549,2.4192092154909806,0.6715522517140127,2.933847597426575,2.7123322377492256]
[3.5012207059896703,8.790850151450645,7.469987879531011,2.005854954351549,2.4192092154909806,0.6715522517140127,2.933847597426575,2.7123322377492256]
[3.5012207059896703,8.790850151450645,7.469987879531011,2.005854954351549,2.4192092154909806,0.6715522517140127,2.933847597426575,2.7123322377492256]
[3.5012207059896703,8.790850151450645,7.469987879531011,2.005854954351549,2.4192092154909806,0.6715522517140127,2.933847597426575,2.7123322377492256]
[3.3815067226154665,8.39251296801799,8.524720387438213,1.2424039821571764,2.7123600957957574,0.4745938758390651,0.817982929839317,2.3932343277073445]
[3.3815067226154665,8.39251296801799,8.524720387438213,1.2424039821571764,2.7123600957957574,0.47459

[3.717440630566319,9.039072700857352,7.999422238095776,2.3070133759856946,1.8285154555758873,0.5361562737103264,2.9091834739266798,2.074136417665463]
[3.717440630566319,9.039072700857352,7.999422238095776,2.3070133759856946,1.8285154555758873,0.5361562737103264,2.9091834739266798,2.074136417665463]
[3.717440630566319,9.039072700857352,7.999422238095776,2.3070133759856946,1.8285154555758873,0.5361562737103264,2.9091834739266798,2.074136417665463]
[3.717440630566319,9.039072700857352,7.999422238095776,2.3070133759856946,1.8285154555758873,0.5361562737103264,2.9091834739266798,2.074136417665463]
[3.717440630566319,9.039072700857352,7.999422238095776,2.3070133759856946,1.8285154555758873,0.5361562737103264,2.9091834739266798,2.074136417665463]
[3.717440630566319,9.039072700857352,7.999422238095776,2.3070133759856946,1.8285154555758873,0.5361562737103264,2.9091834739266798,2.074136417665463]
[3.717440630566319,9.039072700857352,7.999422238095776,2.3070133759856946,1.8285154555758873,0.53615

[2.40602768467442,8.682950495450521,7.72312368142874,1.6753936424101104,2.6751243687777793,0.8924099350198523,2.4873201534426026,2.6804224470322255]
[2.40602768467442,8.682950495450521,7.72312368142874,1.6753936424101104,2.6751243687777793,0.8924099350198523,2.4873201534426026,2.6804224470322255]
[3.6399833895855283,8.169382096135474,8.288956650376605,1.9993193530357092,1.781982614769785,0.5920508348830853,2.722424394184301,2.3248562044688317]
[2.578394505953623,8.996019518527147,7.532030968231435,1.5576718867315174,2.873720929310408,0.6110960301336311,2.9079428685146356,2.15391089493661]
[2.57592942392758,9.075458354247576,8.164870472975757,1.4879441106223177,2.788167198789217,0.7239260118033697,2.8235656039097314,2.15391089493661]
[3.0936299248868835,8.860287694904112,8.48749453421796,1.7409487995021717,2.9722701632019057,0.6597677869323418,2.797603368646684,2.15391089493661]
[3.1516724891662626,9.01487909823056,8.425451445517536,1.7890732595876144,2.9722701632019057,0.69295307565873

[3.5640235456730487,7.435817924881311,7.941515353654232,2.014377491678616,3.0683827402114705,1.6858126673007998,0.9476215871012312,3.3505280587902817]
[3.5640235456730487,7.435817924881311,7.941515353654232,2.014377491678616,3.0683827402114705,1.6858126673007998,0.9476215871012312,3.3505280587902817]
[3.5640235456730487,7.435817924881311,7.941515353654232,2.014377491678616,3.0683827402114705,1.6858126673007998,0.9476215871012312,3.3505280587902817]
[3.5640235456730487,7.435817924881311,7.941515353654232,2.014377491678616,3.0683827402114705,1.6858126673007998,0.9476215871012312,3.3505280587902817]
[3.5640235456730487,7.435817924881311,7.941515353654232,2.014377491678616,3.0683827402114705,1.6858126673007998,0.9476215871012312,3.3505280587902817]
[2.7216124693901675,7.860444219415691,8.363408356817112,1.7203616296967517,2.73627543185917,1.9114726306402772,0.9995460576273261,3.3505280587902817]
[0.9784634086777201,8.343554261492155,6.5848398157261165,1.3334339451411679,1.1491394182725945,

[2.423247574052321,7.6729914272120885,9.083108185742027,2.597140114173136,1.4890680945143902,2.4203137244449806,2.0899599386753183,4.786468655414689]
[2.423247574052321,7.6729914272120885,9.083108185742027,2.597140114173136,1.4890680945143902,2.4203137244449806,2.0899599386753183,4.786468655414689]
[2.423247574052321,7.6729914272120885,9.083108185742027,2.597140114173136,1.4890680945143902,2.4203137244449806,2.0899599386753183,4.786468655414689]
[2.423247574052321,7.6729914272120885,9.083108185742027,2.597140114173136,1.4890680945143902,2.4203137244449806,2.0899599386753183,4.786468655414689]
[2.089961946971759,8.405657523568852,7.469987879531011,2.2147789842242,3.763281184318186,3.8140958509535157,3.751542995510354,4.786468655414689]
[3.878317699636116,6.91975124390615,8.537129005178299,2.161736696100829,1.64068230050131,0.32610077055136205,2.27169558551665,2.3932343277073445]
[2.502869159410746,9.823745517002308,7.130818996289845,2.3675160679052665,3.331902669664926,1.115674657669164

[2.6437587732652417,8.397656489755283,8.288956650376605,3.5555896714222603,2.1659172283845676,1.747758539590068,2.5832424086732195,4.30782178987322]
[2.6437587732652417,8.397656489755283,8.288956650376605,3.5555896714222603,2.1659172283845676,1.747758539590068,2.5832424086732195,4.30782178987322]
[2.6437587732652417,8.397656489755283,8.288956650376605,3.5555896714222603,2.1659172283845676,1.747758539590068,2.5832424086732195,4.30782178987322]
[1.9289254217519318,7.973030195220903,7.842246411733553,2.646845304699341,3.104481360684547,1.5884691537033782,5.594861699186722,3.3505280587902817]
[1.9289254217519318,7.973030195220903,7.842246411733553,2.646845304699341,3.104481360684547,1.5884691537033782,5.594861699186722,3.3505280587902817]
[1.9289254217519318,7.973030195220903,7.842246411733553,2.646845304699341,3.104481360684547,1.5884691537033782,5.594861699186722,3.3505280587902817]
[1.9289254217519318,7.973030195220903,7.842246411733553,2.646845304699341,3.104481360684547,1.588469153703

[1.84017473238925,10.925411672427641,6.162946812563234,1.212954476291315,4.046053710153995,0.750459493737775,1.2204599291381164,2.3932343277073445]
[2.5878382586145965,7.504683965919526,9.505001188904908,1.577832434954647,1.4624453619154967,1.1466623431259901,1.6031680274931788,2.6325577604780785]
[2.754930880967606,8.47509506702232,8.375816974557198,2.284243305259939,1.6005225852250127,1.453515646216044,1.6031680274931788,3.3505280587902817]
[3.043283686692503,8.324504180602657,8.537129005178299,2.124618803411754,2.5082524790839105,1.56044602070611,1.375998468941514,2.871881193248813]
[2.947394113043062,9.786026357595485,6.3201226339620105,2.2168866248118837,4.878126912058399,1.0825778630460408,0.7141345513887655,4.786468655414689]
[2.947394113043062,9.786026357595485,6.3201226339620105,2.2168866248118837,4.878126912058399,1.0825778630460408,0.7141345513887655,4.786468655414689]
[2.2712417642036833,9.712683547256773,6.427663989363894,2.229356831622345,3.8571375975481836,0.849543391395

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [23]:
df2 = scaledFeatures.map(lambda x: (x, )).toDF(['ScaledData'])

In [24]:
df2.show(truncate = False)

+-----------------------------------------------------------------------------------------------------------------------------------------------------+
|ScaledData                                                                                                                                           |
+-----------------------------------------------------------------------------------------------------------------------------------------------------+
|[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653675,0.6110960301336311,2.797603368646684,2.15391089493661]    |
|[3.1246684262242628,8.760446222959084,8.001076718806639,2.0449332921173045,3.281003614797889,0.6888734974979709,4.470834927539427,1.9145874621658754]|
|[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653675,0.6110960301336311,2.797603368646684,2.15391089493661]    |
|[2.997134962069021,8.863145206980386,7.953923971394317,1.63495204161325,2.9542208529653

#### Random Forest Regressor

In [25]:
# Random Forest
from pyspark.ml.regression import RandomForestRegressor
randomForestReg = RandomForestRegressor(featuresCol = 'features', labelCol = 'critical_temp')
randomForestModel = randomForestReg.fit(train)

In [26]:
predictions = randomForestModel.transform(test)
predictions.show()

+--------------------+-------------+------------------+
|            features|critical_temp|        prediction|
+--------------------+-------------+------------------+
|[25.5545,752.1,15...|          8.5|29.732315955132368|
|[25.5545,752.1,15...|         12.8|29.732315955132368|
|[25.5545,752.1,15...|         19.3|29.732315955132368|
|[25.5545,752.1,15...|         19.5|29.732315955132368|
|[35.61323386,879....|          7.2| 32.42439970111751|
|[35.61323386,879....|         19.7| 32.42439970111751|
|[35.61323386,879....|         26.0| 32.42439970111751|
|[35.61323386,879....|         26.0| 32.42439970111751|
|[35.61323386,879....|         27.5| 32.42439970111751|
|[35.61323386,879....|         28.5| 32.42439970111751|
|[37.99742,898.6,1...|         12.5|31.867813709169354|
|[37.99742,898.6,1...|         25.0|31.867813709169354|
|[37.99742,898.6,1...|         32.0|31.867813709169354|
|[38.83494,864.18,...|         18.0|33.778266400722245|
|[43.61088,799.34,...|        110.0|49.341867192

In [27]:
from pyspark.sql.functions import abs
predictions = randomForestModel.transform(test)
x = ((predictions['critical_temp']-predictions['prediction'])/predictions['critical_temp']) * 100
predictions = predictions.withColumn('Accuracy', abs(x))
predictions.select('prediction', 'critical_temp', 'Accuracy', 'features').show(truncate = False)

+------------------+-------------+------------------+-----------------------------------------------------------------+
|prediction        |critical_temp|Accuracy          |features                                                         |
+------------------+-------------+------------------+-----------------------------------------------------------------+
|29.732315955132368|8.5          |249.791952413322  |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]             |
|29.732315955132368|12.8         |132.2837183994716 |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]             |
|29.732315955132368|19.3         |54.05345054472729 |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]             |
|29.732315955132368|19.5         |52.47341515452496 |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]             |
|32.42439970111751 |7.2          |350.3388847377432 |[35.61323386,879.6,131.2,2296.5286,134.214,5.5324,148.007096,2.4]|
|32.42439970111751 |19.7         |64.590

In [28]:
from pyspark.ml.evaluation import RegressionEvaluator
myPredEvaluator = RegressionEvaluator (
    predictionCol = 'prediction',
    labelCol = 'critical_temp',
    metricName = 'rmse'
)

print('Root Mean Square Error (rmse) on test data = %g' % myPredEvaluator.evaluate(predictions))

Root Mean Square Error (rmse) on test data = 20.411


In [29]:
myPredEvaluator = RegressionEvaluator (
    predictionCol = 'prediction',
    labelCol = 'critical_temp',
    metricName = 'r2'
)

print('R Squared (R2) on test data = %g' % myPredEvaluator.evaluate(predictions))

R Squared (R2) on test data = 0.648556


#### Logistic Regression

In [30]:
# Import logistic Regression model
from pyspark.ml.classification import LogisticRegression

In [31]:
output.show(200, truncate=False)

+--------------------------------------------------------------------------------------+-------------+
|features                                                                              |critical_temp|
+--------------------------------------------------------------------------------------+-------------+
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]                 |29.0         |
|[92.729214,766.44,161.2,5821.4858,90.89,7.7844,172.205316,2.0]                        |26.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]                 |19.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]                 |22.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]                 |23.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]                 |23.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]   

In [32]:
#output.withColumn("critical_temp",col("critical_temp").cast("Integer")).show()
#output.withColumn("critical_temp",col("critical_temp")*0).show()

df_stats = output.select(
    F.mean(col('critical_temp')).alias('mean')
).collect()
mean = df_stats[0]['mean']
print(mean)

output.withColumn("critical_temp", F.when(F.col("critical_temp")>mean, 1).otherwise(0)).show()

34.42121913535249
+--------------------+-------------+
|            features|critical_temp|
+--------------------+-------------+
|[88.9444675,775.4...|            0|
|[92.729214,766.44...|            0|
|[88.9444675,775.4...|            0|
|[88.9444675,775.4...|            0|
|[88.9444675,775.4...|            0|
|[88.9444675,775.4...|            0|
|[88.9444675,775.4...|            0|
|[76.5177175,787.0...|            0|
|[76.5177175,787.0...|            1|
|[76.5177175,787.0...|            0|
|[76.5177175,787.0...|            0|
|[111.273574,821.5...|            0|
|[92.729214,766.44...|            0|
|[92.729214,766.44...|            0|
|[92.729214,766.44...|            0|
|[92.729214,766.44...|            0|
|[69.17125,753.08,...|            1|
|[88.9444675,775.4...|            0|
|[76.5177175,787.0...|            1|
|[76.5177175,787.0...|            0|
+--------------------+-------------+
only showing top 20 rows



In [33]:
data = output.withColumn("critical_temp", F.when(F.col("critical_temp")>mean, 1).otherwise(0)).select (
    F.col('features').alias('features'),
    F.col('critical_temp').alias('label')
)

In [34]:
model = LogisticRegression().fit(data)

In [35]:
model.summary.areaUnderROC

0.910304468561362

In [36]:
model.summary.pr.show()

+--------------------+-------------------+
|              recall|          precision|
+--------------------+-------------------+
|                 0.0|                0.0|
|                 0.0|                0.0|
|                 0.0|                0.0|
|9.981285090455397E-4| 0.2962962962962963|
|9.981285090455397E-4|0.20512820512820512|
|9.981285090455397E-4| 0.1568627450980392|
|9.981285090455397E-4|0.11428571428571428|
| 0.01634435433562071| 0.6616161616161617|
|0.020586400499064253| 0.7112068965517241|
| 0.02096069868995633|  0.702928870292887|
|0.023580786026200874| 0.7269230769230769|
|0.030193387398627574| 0.7096774193548387|
| 0.03505926388022458| 0.7356020942408377|
|0.053774173424828445| 0.7966728280961183|
| 0.05527136618839676| 0.8010849909584087|
|0.056269494697442295|  0.799645390070922|
|0.056269494697442295| 0.7802768166089965|
|0.059263880224578916| 0.7838283828382838|
| 0.07573300062383032| 0.8136729222520107|
| 0.07747972551466001| 0.8171052631578948|
+----------

