In [1]:
import numpy as np
import pandas as pd
import pyspark
import os
import sys
import pyspark.sql.functions as F
import matplotlib.pyplot as plt
plt.style.use(style='seaborn')

from pyspark.ml.feature import VectorAssembler
from pyspark.sql import types as T
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Regression').getOrCreate()
spark

In [2]:
df = spark.read.csv('train.csv', header = True, inferSchema = True)
df.show()
df.printSchema()

+------------------+----------------+--------------------+-----------------+---------------------+-------------------+-----------------------+-----------------+---------------------+----------------+-------------------+--------+----------------+----------------+----------------+----------------+-----------------+---------+----------------+----------------+----------------+------------------+----------------------+-------------------+-----------------------+---------------------+-------------------------+-------------------+-----------------------+-----------------+---------------------+------------+----------------+----------------+-----------------+-----------------+-------------------+-------------+-----------------+----------------+----------------+---------------------+-------------------------+----------------------+--------------------------+------------------------+----------------------------+----------------------+--------------------------+--------------------+---------------

In [3]:
dfPySpark = df.select('mean_atomic_mass','mean_fie','mean_atomic_radius','mean_Density','mean_ElectronAffinity','mean_FusionHeat','mean_ThermalConductivity','mean_Valence','critical_temp')
dfPySpark.show()

+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+
|mean_atomic_mass|mean_fie|mean_atomic_radius|mean_Density|mean_ElectronAffinity|mean_FusionHeat|mean_ThermalConductivity|mean_Valence|critical_temp|
+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+
|      88.9444675| 775.425|            160.25|  4654.35725|              81.8375|         6.9055|              107.756645|        2.25|         29.0|
|       92.729214|  766.44|             161.2|   5821.4858|                90.89|         7.7844|              172.205316|         2.0|         26.0|
|      88.9444675| 775.425|            160.25|  4654.35725|              81.8375|         6.9055|              107.756645|        2.25|         19.0|
|      88.9444675| 775.425|            160.25|  4654.35725|              81.8375|         6.9055|   

In [4]:
from pyspark.sql.functions import col, count, isnan, when
dfPySpark.select([count(when(col(c).isNull(), c)).alias(c) for c in dfPySpark.columns]).show()

+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+
|mean_atomic_mass|mean_fie|mean_atomic_radius|mean_Density|mean_ElectronAffinity|mean_FusionHeat|mean_ThermalConductivity|mean_Valence|critical_temp|
+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+
|               0|       0|                 0|           0|                    0|              0|                       0|           0|            0|
+----------------+--------+------------------+------------+---------------------+---------------+------------------------+------------+-------------+



In [5]:
features = dfPySpark.drop('critical_temp')

In [6]:
myAssembler = VectorAssembler(
    inputCols = features.columns,
    outputCol = 'features'
)

output = myAssembler.transform(dfPySpark).select('features','critical_temp')
output.show(truncate = False)

+---------------------------------------------------------------------+-------------+
|features                                                             |critical_temp|
+---------------------------------------------------------------------+-------------+
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|29.0         |
|[92.729214,766.44,161.2,5821.4858,90.89,7.7844,172.205316,2.0]       |26.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|19.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|22.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|23.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|23.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]|11.0         |
|[76.5177175,787.05,151.75,4434.35725,79.6075,6.9055,112.006645,2.25] |33.0         |
|[76.5177175,787.05,151.75,4434.35725,79.6075,6.9055,1

In [7]:
(test, train) = output.randomSplit([0.33, 0.66], seed = 9)

# Print the dataset
print(f'Train data set size: {train.count()} records')
print(f'Test data set size: {test.count()} records')

Train data set size: 14163 records
Test data set size: 7100 records


In [8]:
train.printSchema()
test.printSchema()

root
 |-- features: vector (nullable = true)
 |-- critical_temp: double (nullable = true)

root
 |-- features: vector (nullable = true)
 |-- critical_temp: double (nullable = true)



#### Linear Regression

In [9]:
from pyspark.ml.regression import LinearRegression
linReg = LinearRegression(featuresCol = 'features', labelCol = 'critical_temp')
linearModel = linReg.fit(train)

In [10]:
print('Coefficients: ' + str(linearModel.coefficients))
print('\nIntercept: ' + str(linearModel.intercept))

Coefficients: [0.40385740677321386,0.1744324236746792,0.6008573697772123,-0.006536867290909302,-0.22923110798083426,-0.19294058886407114,0.3340680166566386,-2.9748642348922294]

Intercept: -190.203427496262


In [11]:
trainSummary = linearModel.summary
print('RMSE: %f' % trainSummary.rootMeanSquaredError)
print('\nr2: %f' % trainSummary.r2)

RMSE: 24.257118

r2: 0.497754


In [12]:
from pyspark.sql.functions import abs
predictions = linearModel.transform(test)
x = ((predictions['critical_temp']-predictions['prediction'])/predictions['critical_temp']) * 100
predictions = predictions.withColumn('Accuracy', abs(x))
predictions.select('prediction', 'critical_temp', 'Accuracy', 'features').show(truncate = False)

+------------------+-------------+------------------+------------------------------------------------------------------------------------------------------------------------------+
|prediction        |critical_temp|Accuracy          |features                                                                                                                      |
+------------------+-------------+------------------+------------------------------------------------------------------------------------------------------------------------------+
|33.02229700302459 |8.5          |288.4976118002893 |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]                                                                          |
|33.02229700302459 |12.8         |157.9866953361296 |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]                                                                          |
|33.02229700302459 |19.5         |69.34511283602353 |[25.5545,752.1,155.0,1558.0,102.65,53.665,

In [13]:
from pyspark.ml.evaluation import RegressionEvaluator
myPredEvaluator = RegressionEvaluator (
    predictionCol = 'prediction',
    labelCol = 'critical_temp',
    metricName = 'r2'
)

print('R Squared (R2) on test data = %g' % myPredEvaluator.evaluate(predictions))

R Squared (R2) on test data = 0.500639


In [14]:
def adjR2(x):
    r2 = trainSummary.r2
    n = dfPySpark.count()
    p = len(dfPySpark.columns)
    adjustedR2 = 1 - (1 - r2) * (n - 1) / (n-p-1)
    return adjustedR2

In [15]:
adjR2(train)

0.49754172990467405

In [16]:
adjR2(test)

0.49754172990467405

In [17]:
linReg = LinearRegression(
    featuresCol = 'features', 
    labelCol = 'critical_temp',
    maxIter = 50,
    regParam = 0.12,
    elasticNetParam = 0.2)
linearModel = linReg.fit(train)

In [18]:
linearModel.summary.rootMeanSquaredError

24.260664944000634

In [19]:
colName = features.columns
featuresRDD = features.rdd
featuresRDD.collect()

[Row(mean_atomic_mass=88.9444675, mean_fie=775.425, mean_atomic_radius=160.25, mean_Density=4654.35725, mean_ElectronAffinity=81.8375, mean_FusionHeat=6.9055, mean_ThermalConductivity=107.756645, mean_Valence=2.25),
 Row(mean_atomic_mass=92.729214, mean_fie=766.44, mean_atomic_radius=161.2, mean_Density=5821.4858, mean_ElectronAffinity=90.89, mean_FusionHeat=7.7844, mean_ThermalConductivity=172.205316, mean_Valence=2.0),
 Row(mean_atomic_mass=88.9444675, mean_fie=775.425, mean_atomic_radius=160.25, mean_Density=4654.35725, mean_ElectronAffinity=81.8375, mean_FusionHeat=6.9055, mean_ThermalConductivity=107.756645, mean_Valence=2.25),
 Row(mean_atomic_mass=88.9444675, mean_fie=775.425, mean_atomic_radius=160.25, mean_Density=4654.35725, mean_ElectronAffinity=81.8375, mean_FusionHeat=6.9055, mean_ThermalConductivity=107.756645, mean_Valence=2.25),
 Row(mean_atomic_mass=88.9444675, mean_fie=775.425, mean_atomic_radius=160.25, mean_Density=4654.35725, mean_ElectronAffinity=81.8375, mean_Fus

In [20]:
featuresRDD = features.rdd.map(lambda row: row[0:])
featuresRDD.collect()

[(88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (92.729214, 766.44, 161.2, 5821.4858, 90.89, 7.7844, 172.205316, 2.0),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (88.9444675, 775.425, 160.25, 4654.35725, 81.8375, 6.9055, 107.756645, 2.25),
 (76.5177175, 787.05, 151.75, 4434.35725, 79.6075, 6.9055, 112.006645, 2.25),
 (76.5177175, 787.05, 151.75, 4434.35725, 79.6075, 6.9055, 112.006645, 2.25),
 (76.5177175, 787.05, 151.75, 4434.35725, 79.6075, 6.9055, 112.006645, 2.25),
 (76.5177175, 787.05, 151.75, 4434.35725, 79.6075, 6.9055, 112.006645, 2.25),
 (111.273574, 821.54, 162.4, 6430.2858, 65.77, 5.9824, 87.865316, 2.2),
 (92.729214, 766.44, 161.2, 5821.4858, 90.89, 7.7844, 172.205316, 2.0)

In [21]:
from pyspark.mllib.feature import StandardScaler
scaler1 = StandardScaler().fit(featuresRDD)
scaledFeatures = scaler1.transform(featuresRDD)

In [22]:
for data in scaledFeatures.collect():
    print(data)

[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542208529594003,0.6110960301344287,2.7976033686427524,2.153910894933352]
[3.1246684263445097,8.760446223279647,8.001076718941723,2.0449332921119043,3.281003614791262,0.6888734974988701,4.470834927533144,1.9145874621629795]
[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542208529594003,0.6110960301344287,2.7976033686427524,2.153910894933352]
[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542208529594003,0.6110960301344287,2.7976033686427524,2.153910894933352]
[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542208529594003,0.6110960301344287,2.7976033686427524,2.153910894933352]
[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542208529594003,0.6110960301344287,2.7976033686427524,2.153910894933352]
[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542208529594003,0

[3.3815067227455975,8.392512968325088,8.524720387582141,1.2424039821538957,2.7123600957902787,0.47459387583968454,0.8179829298381674,2.3932343277037242]
[3.3815067227455975,8.392512968325088,8.524720387582141,1.2424039821538957,2.7123600957902787,0.47459387583968454,0.8179829298381674,2.3932343277037242]
[3.3815067227455975,8.392512968325088,8.524720387582141,1.2424039821538957,2.7123600957902787,0.47459387583968454,0.8179829298381674,2.3932343277037242]
[3.03749669993977,8.929153736269063,7.755386087683897,1.9575088698814942,3.2976089802088433,0.5690613310809419,2.7521694569324833,2.8718811932444694]
[3.03749669993977,8.929153736269063,7.755386087683897,1.9575088698814942,3.2976089802088433,0.5690613310809419,2.7521694569324833,2.8718811932444694]
[3.03749669993977,8.929153736269063,7.755386087683897,1.9575088698814942,3.2976089802088433,0.5690613310809419,2.7521694569324833,2.8718811932444694]
[3.03749669993977,8.929153736269063,7.755386087683897,1.9575088698814942,3.2976089802088433

[3.639983391169746,8.169382096760975,8.28895665051655,1.999319353130792,1.7819826149208937,0.5920508349091418,2.722424394180475,2.324856204055048]
[3.429022558558892,8.539878948264061,7.906062160244598,1.9437177871510962,2.2339889127007133,0.719735183913719,2.577777654857985,2.735124945947116]
[2.453691081635069,9.902993851899021,6.419391577645535,1.768946140881358,3.753053241843246,0.5788104314579088,3.574364413278946,2.233685372523473]
[3.6621129762627733,8.355528597736528,8.196778347302928,1.8954427813095187,2.1532827112146413,0.6767523337537622,2.581486545609845,2.598368698649754]
[3.908858379569414,8.141194781207556,8.574354858543318,2.05659304964352,2.1165072416077697,0.8122873739199076,2.307479918526759,2.632557760474097]
[3.908858379569414,8.141194781207556,8.574354858543318,2.05659304964352,2.1165072416077697,0.8122873739199076,2.307479918526759,2.632557760474097]
[3.908858379569414,8.141194781207556,8.574354858543318,2.05659304964352,2.1165072416077697,0.8122873739199076,2.30

[2.5783945060528475,8.99601951885633,7.532030968358601,1.557671886727404,2.873720929304603,0.6110960301344287,2.907942868510549,2.153910894933352]
[2.503333642895321,9.267483166112333,7.4352437499843065,1.7477559692492994,2.3098063295855833,0.618963155915176,2.9494479411207015,2.1060462083792775]
[2.503333642895321,9.267483166112333,7.4352437499843065,1.7477559692492994,2.3098063295855833,0.618963155915176,2.9494479411207015,2.1060462083792775]
[2.503333642895321,9.267483166112333,7.4352437499843065,1.7477559692492994,2.3098063295855833,0.618963155915176,2.9494479411207015,2.1060462083792775]
[2.503333642895321,9.267483166112333,7.4352437499843065,1.7477559692492994,2.3098063295855833,0.618963155915176,2.9494479411207015,2.1060462083792775]
[2.5783945060528475,8.99601951885633,7.532030968358601,1.557671886727404,2.873720929304603,0.6110960301344287,2.907942868510549,2.153910894933352]
[2.503333642895321,9.267483166112333,7.4352437499843065,1.7477559692492994,2.3098063295855833,0.618963

[5.3780150867995165,6.5545612007981235,7.742977469943603,3.3651994716592872,0.9223197530852321,0.33008300519896017,1.6615830568327006,2.8718811932444694]
[5.3780150867995165,6.5545612007981235,7.742977469943603,3.3651994716592872,0.9223197530852321,0.33008300519896017,1.6615830568327006,2.8718811932444694]
[5.3780150867995165,6.5545612007981235,7.742977469943603,3.3651994716592872,0.9223197530852321,0.33008300519896017,1.6615830568327006,2.8718811932444694]
[4.084343206972418,8.155339465985632,6.924008699084184,2.3798774969197987,4.008751803527007,0.9185687919477765,1.1033949986779652,4.307821789866704]
[4.084343206972418,8.155339465985632,6.924008699084184,2.3798774969197987,4.008751803527007,0.9185687919477765,1.1033949986779652,4.307821789866704]
[3.934561379261668,7.23636358222215,7.46998787965713,2.567808782654425,2.5124639849210233,0.4539747497776583,1.9341865270943155,3.350528058785214]
[3.9859488454616443,7.951313103732172,7.1721810538900685,2.4601434759672047,2.438461812951367

[3.4109731382821424,7.350092562862033,8.43786006340008,2.653168226455386,3.0575531540633722,0.8583928017238373,5.361201581811761,2.3932343277037242]
[2.8195552109953534,9.052598257968635,7.991149824749488,2.831263856114181,0.05414793070950481,0.8628175068873623,1.3461418983871176,2.8718811932444694]
[2.635964389887168,8.050183021574885,8.512311769841846,3.078911625166352,3.7452318740740833,1.765457360246449,5.893427404703485,3.350528058785214]
[3.0469721274983064,7.492777665875886,8.255867003209081,2.9085440109957097,2.8626206035091544,1.2731351657182441,4.6385860336579645,3.1909791036049624]
[3.789758351172283,8.262781920057481,8.983839243973026,4.002175293713086,2.887889637840257,2.238900812743617,4.422234073132856,3.829174924325959]
[2.4370227847570836,8.714840330540639,6.204308870147118,2.0428306396069504,1.9836191949915265,1.4729843489374508,1.0255082928889323,3.829174924325959]
[2.4908439108112423,7.84958567381308,9.033473714934203,2.3486141615359077,1.7381485757751045,1.22121862

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [23]:
df2 = scaledFeatures.map(lambda x: (x, )).toDF(['ScaledData'])

In [24]:
df2.show(truncate = False)

+------------------------------------------------------------------------------------------------------------------------------------------------------+
|ScaledData                                                                                                                                            |
+------------------------------------------------------------------------------------------------------------------------------------------------------+
|[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542208529594003,0.6110960301344287,2.7976033686427524,2.153910894933352]|
|[3.1246684263445097,8.760446223279647,8.001076718941723,2.0449332921119043,3.281003614791262,0.6888734974988701,4.470834927533144,1.9145874621629795] |
|[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542208529594003,0.6110960301344287,2.7976033686427524,2.153910894933352]|
|[2.9971349621843597,8.863145207304708,7.953923971528606,1.6349520416089325,2.9542

#### Random Forest Regressor

In [25]:
# Random Forest
from pyspark.ml.regression import RandomForestRegressor
randomForestReg = RandomForestRegressor(featuresCol = 'features', labelCol = 'critical_temp')
randomForestModel = randomForestReg.fit(train)

In [26]:
predictions = randomForestModel.transform(test)
predictions.show()

+--------------------+-------------+------------------+
|            features|critical_temp|        prediction|
+--------------------+-------------+------------------+
|[25.5545,752.1,15...|          8.5|27.358817322435044|
|[25.5545,752.1,15...|         12.8|27.358817322435044|
|[25.5545,752.1,15...|         19.5|27.358817322435044|
|[25.5545,752.1,15...|         20.0|27.358817322435044|
|[38.3441833333333...|        115.0| 41.04820118665663|
|[48.73925,744.25,...|        12.25|24.750197148464732|
|[48.73925,744.25,...|         28.4|24.750197148464732|
|[48.73925,744.25,...|         28.6|24.750197148464732|
|[48.73925,744.25,...|         29.0|24.750197148464732|
|[48.73925,744.25,...|         29.8|24.750197148464732|
|[51.3599916666667...|         40.0| 50.69599200864067|
|[51.3599916666667...|         48.9| 50.69599200864067|
|[51.3599916666667...|         60.3| 50.69599200864067|
|[53.79222,847.14,...|        117.0| 59.66989035024519|
|[55.7218,869.0333...|         91.0| 46.26441231

In [27]:
from pyspark.sql.functions import abs
predictions = randomForestModel.transform(test)
x = ((predictions['critical_temp']-predictions['prediction'])/predictions['critical_temp']) * 100
predictions = predictions.withColumn('Accuracy', abs(x))
predictions.select('prediction', 'critical_temp', 'Accuracy', 'features').show(truncate = False)

+------------------+-------------+------------------+------------------------------------------------------------------------------------------------------------------------------+
|prediction        |critical_temp|Accuracy          |features                                                                                                                      |
+------------------+-------------+------------------+------------------------------------------------------------------------------------------------------------------------------+
|27.358817322435044|8.5          |221.8684390874711 |[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]                                                                          |
|27.358817322435044|12.8         |113.74076033152376|[25.5545,752.1,155.0,1558.0,102.65,53.665,120.0,2.5]                                                                          |
|27.358817322435044|19.5         |40.30162729453869 |[25.5545,752.1,155.0,1558.0,102.65,53.665,

In [28]:
from pyspark.ml.evaluation import RegressionEvaluator
myPredEvaluator = RegressionEvaluator (
    predictionCol = 'prediction',
    labelCol = 'critical_temp',
    metricName = 'rmse'
)

print('Root Mean Square Error (rmse) on test data = %g' % myPredEvaluator.evaluate(predictions))

Root Mean Square Error (rmse) on test data = 20.3782


In [29]:
myPredEvaluator = RegressionEvaluator (
    predictionCol = 'prediction',
    labelCol = 'critical_temp',
    metricName = 'r2'
)

print('R Squared (R2) on test data = %g' % myPredEvaluator.evaluate(predictions))

R Squared (R2) on test data = 0.647117


#### Logistic Regression

In [30]:
# Import logistic Regression model
from pyspark.ml.classification import LogisticRegression

In [31]:
output.show(200, truncate=False)

+--------------------------------------------------------------------------------------------------------------------+-------------+
|features                                                                                                            |critical_temp|
+--------------------------------------------------------------------------------------------------------------------+-------------+
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]                                               |29.0         |
|[92.729214,766.44,161.2,5821.4858,90.89,7.7844,172.205316,2.0]                                                      |26.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]                                               |19.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25]                                               |22.0         |
|[88.9444675,775.425,160.25,4654.35725,81.8375,6.9055,107.756645,2.25

In [32]:
#output.withColumn("critical_temp",col("critical_temp").cast("Integer")).show()
#output.withColumn("critical_temp",col("critical_temp")*0).show()

df_stats = output.select(
    F.mean(col('critical_temp')).alias('mean')
).collect()
mean = df_stats[0]['mean']
print(mean)

output.withColumn("critical_temp", F.when(F.col("critical_temp")>mean, 1).otherwise(0)).show()

34.42121913535251
+--------------------+-------------+
|            features|critical_temp|
+--------------------+-------------+
|[88.9444675,775.4...|            0|
|[92.729214,766.44...|            0|
|[88.9444675,775.4...|            0|
|[88.9444675,775.4...|            0|
|[88.9444675,775.4...|            0|
|[88.9444675,775.4...|            0|
|[88.9444675,775.4...|            0|
|[76.5177175,787.0...|            0|
|[76.5177175,787.0...|            1|
|[76.5177175,787.0...|            0|
|[76.5177175,787.0...|            0|
|[111.273574,821.5...|            0|
|[92.729214,766.44...|            0|
|[92.729214,766.44...|            0|
|[92.729214,766.44...|            0|
|[92.729214,766.44...|            0|
|[69.17125,753.08,...|            1|
|[88.9444675,775.4...|            0|
|[76.5177175,787.0...|            1|
|[76.5177175,787.0...|            0|
+--------------------+-------------+
only showing top 20 rows



In [33]:
data = output.withColumn("critical_temp", F.when(F.col("critical_temp")>mean, 1).otherwise(0)).select (
    F.col('features').alias('features'),
    F.col('critical_temp').alias('label')
)

In [34]:
model = LogisticRegression().fit(data)

In [35]:
model.summary.areaUnderROC

0.910305579853294

In [36]:
model.summary.pr.show()

+--------------------+-------------------+
|              recall|          precision|
+--------------------+-------------------+
|                 0.0|                0.0|
|                 0.0|                0.0|
|                 0.0|                0.0|
|9.981285090455397E-4| 0.2962962962962963|
|9.981285090455397E-4|0.20512820512820512|
|9.981285090455397E-4| 0.1568627450980392|
|9.981285090455397E-4|0.11428571428571428|
| 0.01634435433562071| 0.6616161616161617|
|0.020586400499064253| 0.7112068965517241|
| 0.02096069868995633|  0.702928870292887|
|0.023580786026200874| 0.7269230769230769|
|0.030193387398627574| 0.7096774193548387|
| 0.03505926388022458| 0.7356020942408377|
|0.053774173424828445| 0.7966728280961183|
| 0.05527136618839676| 0.8010849909584087|
|0.056269494697442295|  0.799645390070922|
|0.056269494697442295| 0.7802768166089965|
|0.059263880224578916| 0.7838283828382838|
| 0.07573300062383032| 0.8136729222520107|
| 0.07747972551466001| 0.8171052631578948|
+----------