In [1]:
import os
import findspark
findspark.init(os.getenv('SPARK_HOME'))
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import corr

In [2]:
spark = SparkSession.builder.appName("lr_exercice").getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/08/16 20:10:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/08/16 20:10:12 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/08/16 20:10:12 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [28]:
data = spark.read.csv('cruise_ship_info.csv', inferSchema=True, header=True)

In [29]:
data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [30]:
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [31]:
data.groupBy('Cruise_line').count().show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



In [32]:
indexer = StringIndexer(inputCol="Cruise_line",outputCol="Cruise_line_enc")
indexer_model = indexer.fit(data)
data = indexer_model.transform(data)
data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_enc|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|            1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|            1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|            1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|

In [39]:
encoder = OneHotEncoder(inputCol="Cruise_line_enc", outputCol="Cruise_line_vec")
encoder_model = encoder.fit(data)
data = encoder_model.transform(data)
data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+---------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_enc|Cruise_line_vec|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+---------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|(19,[16],[1.0])|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|(19,[16],[1.0])|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|            1.0| (19,[1],[1.0])|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|            1.0| (19,[1],[1.0])|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0| 

In [41]:
data.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'Cruise_line_enc',
 'Cruise_line_vec']

In [65]:
assembler = VectorAssembler(inputCols=['Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density', 'Cruise_line_vec'],
                           outputCol='features')
output = assembler.transform(data)
output.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- Cruise_line_enc: double (nullable = false)
 |-- Cruise_line_vec: vector (nullable = true)
 |-- features: vector (nullable = true)



In [66]:
output.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+---------------+--------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_enc|Cruise_line_vec|            features|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+---------------+--------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|(19,[16],[1.0])|(25,[0,1,2,3,4,5,...|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|(19,[16],[1.0])|(25,[0,1,2,3,4,5,...|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|            1.0| (19,[1],[1.0])|(25,[0,1,2,3,4,5,...|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|    

In [67]:
final_data = output.select('features', 'crew')
final_data.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|(25,[0,1,2,3,4,5,...|3.55|
|(25,[0,1,2,3,4,5,...|3.55|
|(25,[0,1,2,3,4,5,...| 6.7|
|(25,[0,1,2,3,4,5,...|19.1|
|(25,[0,1,2,3,4,5,...|10.0|
|(25,[0,1,2,3,4,5,...| 9.2|
|(25,[0,1,2,3,4,5,...| 9.2|
|(25,[0,1,2,3,4,5,...| 9.2|
|(25,[0,1,2,3,4,5,...| 9.2|
|(25,[0,1,2,3,4,5,...|11.5|
|(25,[0,1,2,3,4,5,...|11.6|
|(25,[0,1,2,3,4,5,...| 6.6|
|(25,[0,1,2,3,4,5,...| 9.2|
|(25,[0,1,2,3,4,5,...| 9.2|
|(25,[0,1,2,3,4,5,...| 9.3|
|(25,[0,1,2,3,4,5,...|11.6|
|(25,[0,1,2,3,4,5,...|10.3|
|(25,[0,1,2,3,4,5,...| 9.2|
|(25,[0,1,2,3,4,5,...| 9.3|
|(25,[0,1,2,3,4,5,...| 9.2|
+--------------------+----+
only showing top 20 rows



In [68]:
train_data, test_data = final_data.randomSplit([0.8, 0.2])
train_data.describe().show()
test_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               123|
|   mean| 7.763577235772362|
| stddev|3.5487513689777996|
|    min|              0.59|
|    max|              21.0|
+-------+------------------+

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|                35|
|   mean| 7.901714285714286|
| stddev|3.3874339461450456|
|    min|              0.88|
|    max|              13.6|
+-------+------------------+



In [69]:
lr = LinearRegression(featuresCol="features", labelCol="crew", predictionCol="prediction")
lr_model = lr.fit(train_data)

22/08/16 21:04:58 WARN Instrumentation: [5d4a7cdd] regParam is zero, which might cause numerical instability and overfitting.


In [70]:
predictions = lr_model.evaluate(test_data)

In [71]:
predictions.r2, predictions.rootMeanSquaredError, predictions.meanSquaredError, predictions.meanAbsoluteError

(0.9629073136609574,
 0.6430139801102667,
 0.4134669786172464,
 0.47834286218897426)

In [85]:
for col in ['Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density']:
    data.select(corr('crew', col)).show()

+-------------------+
|    corr(crew, Age)|
+-------------------+
|-0.5306565039638852|
+-------------------+

+-------------------+
|corr(crew, Tonnage)|
+-------------------+
|  0.927568811544939|
+-------------------+

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+

+------------------+
|corr(crew, length)|
+------------------+
|0.8958566271016579|
+------------------+

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+

+-----------------------------+
|corr(crew, passenger_density)|
+-----------------------------+
|         -0.15550928421699717|
+-----------------------------+

