## Determine how many crew members should be staffed for new cruise ships based on current ship information

In [1]:
import findspark
findspark.init('/home/nick/spark-3.0.1-bin-hadoop2.7')

In [2]:
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression
from pyspark.sql.functions import isnan, count, countDistinct, col, when, mean, corr
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator


In [3]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [4]:
spark = SparkSession.builder.appName('Linear_Regression').getOrCreate()

In [5]:
data = spark.read.csv('Linear_Regression/cruise_ship_info.csv', inferSchema=True, header=True)

In [6]:
data.show(2)

+---------+-----------+---+------------------+----------+------+------+-----------------+----+
|Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+---------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|    Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
+---------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 2 rows



In [7]:
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [8]:
data.describe().show()

+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|summary|Ship_name|Cruise_line|               Age|           Tonnage|       passengers|           length|            cabins|passenger_density|             crew|
+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|  count|      158|        158|               158|               158|              158|              158|               158|              158|              158|
|   mean| Infinity|       null|15.689873417721518| 71.28467088607599|18.45740506329114|8.130632911392404| 8.830000000000005|39.90094936708861|7.794177215189873|
| stddev|      NaN|       null| 7.615691058751413|37.229540025907866|9.677094775143416|1.793473548054825|4.4714172221480615| 8.63921711391542|3.503486564627034|
|    min|Adventure|    Azamara|   

In [9]:
data.select([count(when(isnan(c), c)).alias(c) for c in data.columns]).show()

+---------+-----------+---+-------+----------+------+------+-----------------+----+
|Ship_name|Cruise_line|Age|Tonnage|passengers|length|cabins|passenger_density|crew|
+---------+-----------+---+-------+----------+------+------+-----------------+----+
|        0|          0|  0|      0|         0|     0|     0|                0|   0|
+---------+-----------+---+-------+----------+------+------+-----------------+----+



In [10]:
data.select(countDistinct('Cruise_line')).show()

+---------------------------+
|count(DISTINCT Cruise_line)|
+---------------------------+
|                         20|
+---------------------------+



In [11]:
data.groupby('Cruise_line').count().show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



In [12]:
string_indexer = StringIndexer(inputCol='Cruise_line', outputCol='ix_cruise_line')

In [13]:
model = string_indexer.fit(data)

In [14]:
data_ix = model.transform(data)

In [15]:
data_ix.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+--------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|ix_cruise_line|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+--------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|          16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|          16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|           1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|           1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|           1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|        

In [16]:
data_ix.groupby('Cruise_line', 'ix_cruise_line').count().show()

+-----------------+--------------+-----+
|      Cruise_line|ix_cruise_line|count|
+-----------------+--------------+-----+
|            Costa|           5.0|   11|
|        Norwegian|           4.0|   13|
|              MSC|           7.0|    8|
|           Orient|          19.0|    1|
|Regent_Seven_Seas|          10.0|    5|
|           Disney|          18.0|    2|
|         Windstar|          15.0|    3|
|              P&O|           8.0|    6|
|  Royal_Caribbean|           0.0|   23|
|         Seabourn|          14.0|    3|
|             Star|           9.0|    6|
|         Princess|           2.0|   17|
|          Oceania|          13.0|    3|
|          Azamara|          16.0|    2|
| Holland_American|           3.0|   14|
|           Cunard|          12.0|    3|
|        Celebrity|           6.0|   10|
|        Silversea|          11.0|    4|
|          Crystal|          17.0|    2|
|         Carnival|           1.0|   22|
+-----------------+--------------+-----+



In [17]:
data_ix.show(2)

+---------+-----------+---+------------------+----------+------+------+-----------------+----+--------------+
|Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|ix_cruise_line|
+---------+-----------+---+------------------+----------+------+------+-----------------+----+--------------+
|  Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|          16.0|
|    Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|          16.0|
+---------+-----------+---+------------------+----------+------+------+-----------------+----+--------------+
only showing top 2 rows



In [18]:
data_ix.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'ix_cruise_line']

In [19]:
assembler = VectorAssembler(inputCols=['Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'ix_cruise_line'],
outputCol='features')

In [20]:
output = assembler.transform(data_ix)
output.select('features').show(truncate=False)

+--------------------------------------------------+
|features                                          |
+--------------------------------------------------+
|[6.0,30.276999999999997,6.94,5.94,3.55,42.64,16.0]|
|[6.0,30.276999999999997,6.94,5.94,3.55,42.64,16.0]|
|[26.0,47.262,14.86,7.22,7.43,31.8,1.0]            |
|[11.0,110.0,29.74,9.53,14.88,36.99,1.0]           |
|[17.0,101.353,26.42,8.92,13.21,38.36,1.0]         |
|[22.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[15.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[23.0,70.367,20.56,8.55,10.22,34.23,1.0]          |
|[19.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[6.0,110.23899999999999,37.0,9.51,14.87,29.79,1.0]|
|[10.0,110.0,29.74,9.51,14.87,36.99,1.0]           |
|[28.0,46.052,14.52,7.27,7.26,31.72,1.0]           |
|[18.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[17.0,70.367,20.52,8.55,10.2,34.29,1.0]           |
|[11.0,86.0,21.24,9.63,10.62,40.49,1.0]            |
|[8.0,110.0,29.74,9.51,14.87,36.99,1.0]       

In [21]:
final_data = output.select('features','crew')
final_data.show(2, truncate=False)

+--------------------------------------------------+----+
|features                                          |crew|
+--------------------------------------------------+----+
|[6.0,30.276999999999997,6.94,5.94,3.55,42.64,16.0]|3.55|
|[6.0,30.276999999999997,6.94,5.94,3.55,42.64,16.0]|3.55|
+--------------------------------------------------+----+
only showing top 2 rows



In [22]:
train_data, test_data = final_data.randomSplit([0.7,0.3])

In [23]:
lr = LinearRegression(featuresCol='features', labelCol='crew', predictionCol='prediction')
lr_model = lr.fit(train_data)

In [24]:
print("Coefficients: %s" % str(lr_model.coefficients))
print("Intercept: %s" % str(lr_model.intercept))

Coefficients: [-0.02225861872396819,0.008326593574292208,-0.1488188399787325,0.4003307967041108,0.871540220715352,-0.006137370413617217,0.05989035381810061]
Intercept: -0.6642817781497415


In [25]:
trainingSummary = lr_model.summary
print("numIterations: %d" % trainingSummary.totalIterations)
print("objectiveHistory: %s" % str(trainingSummary.objectiveHistory))
trainingSummary.residuals.show(5)
print("RMSE: %f" % trainingSummary.rootMeanSquaredError)
print("r2: %f" % trainingSummary.r2)

numIterations: 1
objectiveHistory: [0.0]
+--------------------+
|           residuals|
+--------------------+
|-0.05577648015344394|
|   0.430995721122021|
| -1.0399489363177796|
| -1.2450234602111312|
| 0.24220817278043327|
+--------------------+
only showing top 5 rows

RMSE: 1.019509
r2: 0.922510


In [26]:
final_data.describe().show()  #RMSE of ~1 on 7.79 mean and 3.5 std

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|              158|
|   mean|7.794177215189873|
| stddev|3.503486564627034|
|    min|             0.59|
|    max|             21.0|
+-------+-----------------+



In [28]:
lr_predictions = lr_model.transform(test_data)
lr_predictions.select("prediction","crew","features").show(5)

lr_evaluator1 = RegressionEvaluator(predictionCol="prediction", \
                 labelCol="crew",metricName='r2')
lr_evaluator2 = RegressionEvaluator(predictionCol="prediction", \
                 labelCol="crew",metricName='rmse')
print(f"R Squared (R2) on test data = {lr_evaluator1.evaluate(lr_predictions)}")
print(f"RMSE on test data = {lr_evaluator2.evaluate(lr_predictions)}")

+------------------+-----+--------------------+
|        prediction| crew|            features|
+------------------+-----+--------------------+
| 9.392437827978158|  8.0|[5.0,86.0,21.04,9...|
|12.099793650764916| 12.2|[5.0,115.0,35.74,...|
|13.299296863106182|13.13|[5.0,133.5,39.59,...|
|15.165939305726127| 13.6|[5.0,160.0,36.34,...|
|4.5899489363177794| 3.55|[6.0,30.276999999...|
+------------------+-----+--------------------+
only showing top 5 rows

R Squared (R2) on test data = 0.9418884254617262
RMSE on test data = 0.7249610894384666


In [29]:
data.describe().show()

+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|summary|Ship_name|Cruise_line|               Age|           Tonnage|       passengers|           length|            cabins|passenger_density|             crew|
+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|  count|      158|        158|               158|               158|              158|              158|               158|              158|              158|
|   mean| Infinity|       null|15.689873417721518| 71.28467088607599|18.45740506329114|8.130632911392404| 8.830000000000005|39.90094936708861|7.794177215189873|
| stddev|      NaN|       null| 7.615691058751413|37.229540025907866|9.677094775143416|1.793473548054825|4.4714172221480615| 8.63921711391542|3.503486564627034|
|    min|Adventure|    Azamara|   

In [30]:
data.select(corr('crew', 'passengers')).show() # High Correlation

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+



In [31]:
data.select(corr('crew', 'cabins')).show() # High Correlation

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+



In [32]:
spark.stop()