In [1]:
import findspark

In [2]:
findspark.init('/home/gerardo-rodriguez/spark-4.0.0-bin-hadoop3')

In [3]:
from pyspark.sql import SparkSession

In [4]:
spark = SparkSession.builder.appName('logreg').getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/08/19 10:50:52 WARN Utils: Your hostname, Lanz-Lenovo, resolves to a loopback address: 127.0.1.1; using 192.168.1.145 instead (on interface wlp2s0)
25/08/19 10:50:52 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/19 10:50:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/08/19 10:50:54 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [12]:
my_data = spark.read.csv('titanic.csv', header=True, inferSchema=True)

In [13]:
my_data.printSchema()

root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



In [14]:
my_data.columns

['PassengerId',
 'Survived',
 'Pclass',
 'Name',
 'Sex',
 'Age',
 'SibSp',
 'Parch',
 'Ticket',
 'Fare',
 'Cabin',
 'Embarked']

In [34]:
my_cols = my_data.select(['Survived',
                    'Pclass',
                    'Sex',
                    'Age',
                    'SibSp',
                    'Parch',                
                    'Fare',
                    'Embarked'])

In [35]:
my_cols.printSchema()

root
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Embarked: string (nullable = true)



In [36]:
my_final_data = my_cols.na.drop()

In [37]:
from pyspark.ml.feature import (VectorAssembler, VectorIndexer,
                                OneHotEncoder, StringIndexer)

In [38]:
gender_indexer = StringIndexer(inputCol='Sex', outputCol='SexIndex')
gender_encoder = OneHotEncoder(inputCol='SexIndex', outputCol='SexVec')

In [39]:
embark_indexer = StringIndexer(inputCol='Embarked', outputCol='EmbarkIndex')
embark_encoder = OneHotEncoder(inputCol='EmbarkIndex', outputCol='EmbarkVec')

In [68]:
assembler = VectorAssembler(inputCols=['Pclass', 'SexVec', 'EmbarkVec', 'Age', 'Sibsp', 'Parch', 'Fare'],
                           outputCol='features')

In [51]:
from pyspark.ml.classification import LogisticRegression

In [52]:
from pyspark.ml import Pipeline

In [53]:
log_reg_titanic = LogisticRegression(featuresCol='features', labelCol='Survived')

In [54]:
pipeline = Pipeline(stages=[gender_indexer, embark_indexer, gender_encoder, 
                            embark_encoder, assembler, log_reg_titanic])

In [55]:
train_data, test_data = my_final_data.randomSplit([0.7,0.3])

In [56]:
fit_model = pipeline.fit(train_data)

25/08/19 11:10:17 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


In [71]:
results = fit_model.transform(test_data)

In [59]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [60]:
my_eval = BinaryClassificationEvaluator(rawPredictionCol='prediction', labelCol='Survived')

In [61]:
results.select('Survived', 'prediction').show()

+--------+----------+
|Survived|prediction|
+--------+----------+
|       0|       1.0|
|       0|       1.0|
|       0|       0.0|
|       0|       1.0|
|       0|       1.0|
|       0|       1.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       0.0|
|       0|       1.0|
|       0|       0.0|
+--------+----------+
only showing top 20 rows


In [63]:
AUC = my_eval.evaluate(results)

In [64]:
AUC

0.8280430457587268