In [26]:
import os
import findspark
# Locate the Spark installation. Prefer the SPARK_HOME environment variable so the
# notebook runs on other machines; fall back to the author's original local install.
findspark.init(os.environ.get('SPARK_HOME') or '/home/zeedlomo/spark-2.4.4-bin-hadoop2.7')

In [27]:
from pyspark.sql import SparkSession

In [28]:
# Reuse the running Spark session if there is one, otherwise start one named 'Project'.
spark = (
    SparkSession.builder
    .appName('Project')
    .getOrCreate()
)

In [29]:
from pyspark.ml.regression import LinearRegression

In [30]:
# Load the cruise ship CSV: first row is the header, column types are inferred from the data.
data = (
    spark.read
    .option('header', True)
    .option('inferSchema', True)
    .csv('cruise_ship_info.csv')
)

In [31]:
# Inspect the column names and the types Spark inferred from the CSV.
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [32]:
# Preview the first rows of the raw data.
data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [37]:
# Number of ships per cruise line (Spark does not guarantee the row order of a groupBy).
ships_per_line = data.groupBy('Cruise_line').count()
ships_per_line.show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



In [33]:
from pyspark.ml.feature import StringIndexer

In [35]:
# Map each Cruise_line string to a numeric index stored in a new cruise_cat column.
# By default StringIndexer orders labels by frequency, so the most common line gets 0.0.
indexer = StringIndexer(inputCol='Cruise_line', outputCol='cruise_cat')
cruise_line_index_model = indexer.fit(data)
indexed = cruise_line_index_model.transform(data)

In [36]:
# Spot-check one row to confirm the cruise_cat index column was added.
indexed.head(1)

[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_cat=16.0)]

In [38]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [40]:
# List all columns, including the cruise_cat column added by the StringIndexer.
indexed.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'cruise_cat']

In [41]:
# Predictor columns combined into a single 'features' vector for the crew regression.
# NOTE(review): cruise_cat is a StringIndexer index of a nominal category; passing it as one
# numeric column makes LinearRegression treat the (frequency-based) index order as meaningful.
# Consider one-hot encoding it instead.
feature_cols = ['Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density', 'cruise_cat']
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

In [42]:
# Append the 'features' vector column built from the assembler's input columns.
output = assembler.transform(indexed)

In [43]:
# Check that the assembled 'features' vector matches the source column values for one row.
output.head(1)

[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55, cruise_cat=16.0, features=DenseVector([6.0, 30.277, 6.94, 5.94, 3.55, 42.64, 16.0]))]

In [44]:
# Keep only what the model needs: the feature vector and the label (crew).
model_columns = ['features', 'crew']
final_data = output.select(*model_columns)
final_data.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
|[22.0,70.367,20.5...| 9.2|
|[15.0,70.367,20.5...| 9.2|
|[23.0,70.367,20.5...| 9.2|
|[19.0,70.367,20.5...| 9.2|
|[6.0,110.23899999...|11.5|
|[10.0,110.0,29.74...|11.6|
|[28.0,46.052,14.5...| 6.6|
|[18.0,70.367,20.5...| 9.2|
|[17.0,70.367,20.5...| 9.2|
|[11.0,86.0,21.24,...| 9.3|
|[8.0,110.0,29.74,...|11.6|
|[9.0,88.5,21.24,9...|10.3|
|[15.0,70.367,20.5...| 9.2|
|[12.0,88.5,21.24,...| 9.3|
|[20.0,70.367,20.5...| 9.2|
+--------------------+----+
only showing top 20 rows



In [57]:
# Randomly split the rows into roughly 70% training / 30% test (tuple-unpacked).
# A fixed seed makes the split, and every metric computed from it, reproducible on re-run.
SPLIT_SEED = 42
train, test = final_data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [58]:
# Summary statistics for the training split. Only the numeric label (crew) is summarized;
# the vector 'features' column is not.
train.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               113|
|   mean|  7.88212389380532|
| stddev|3.5086549751312512|
|    min|               0.6|
|    max|              21.0|
+-------+------------------+



In [59]:
# Same summary for the test split, to compare against the training split.
test.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|                45|
|   mean| 7.573333333333335|
| stddev|3.5201401056827795|
|    min|              0.59|
|    max|             13.13|
+-------+------------------+



In [60]:
# Linear regression predicting crew from the assembled feature vector.
# 'features' and 'prediction' are the estimator defaults, spelled out here for clarity.
lr = LinearRegression(
    featuresCol='features',
    labelCol='crew',
    predictionCol='prediction',
)

In [61]:
# Fit the linear regression on the training split.
lr_model = lr.fit(train)

In [62]:
# Evaluate the fitted model on the held-out test split; the returned summary exposes
# residuals, RMSE and R^2 (used in the cells below).
test_results = lr_model.evaluate(test)

In [63]:
# Residuals on the test set: actual crew minus predicted crew (label - prediction),
# so a positive value means the model under-predicted.
test_results.residuals.show()

+--------------------+
|           residuals|
+--------------------+
|  0.4347767866719856|
| 0.09945942143159492|
|  -1.129254290460782|
| 0.30614485587423346|
| -0.5804741836781364|
|  0.4200517753102808|
| -0.6783181974787293|
| -0.5577573389877273|
| -0.9112122334170483|
| -0.3444707939430849|
|  0.6981632137422125|
| -0.3520069225145974|
|  0.8136168854780728|
| -1.2908527383121644|
|-0.09157272411432782|
|  -1.201986557338449|
|  0.7095216360874179|
|  0.7208800584326234|
|  -1.175583594752748|
|  3.1710201738963093|
+--------------------+
only showing top 20 rows



In [64]:
# Root Mean Squared Error: the square root of the mean squared residual, in the same units
# as crew. It weights large errors more heavily than the mean absolute error, so it is a
# typical error size rather than literally the average amount the model is off by.
test_results.rootMeanSquaredError

0.8354866723127298

In [65]:
# R^2 (coefficient of determination): fraction of the variance in crew on the test set
# explained by the model; the closer to 1, the better.
test_results.r2

0.9423871988009469

In [71]:
from pyspark.sql.functions import corr

In [74]:
# Why does a plain linear model fit so well? Check how strongly the target (crew)
# correlates with a few predictors, starting with passengers (Pearson correlation).
data.agg(corr('crew', 'passengers')).show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+



In [76]:
# Pearson correlation between crew and cabins.
data.agg(corr('crew', 'cabins')).show()

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+



In [77]:
# Pearson correlation between crew and length.
data.agg(corr('crew', 'length')).show()

+------------------+
|corr(crew, length)|
+------------------+
|0.8958566271016579|
+------------------+



Several features (passengers, cabins, length) are strongly correlated with the target variable crew (r ≈ 0.90–0.95), which explains why a simple linear regression performs so well.

In [66]:
# Simulate new, unlabeled data by dropping the label (crew) from the test split,
# leaving only the 'features' column to predict on.
unlabeled_data = test.drop('crew')
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[5.0,115.0,35.74,...|
|[5.0,133.5,39.59,...|
|[6.0,90.0,20.0,9....|
|[8.0,77.499,19.5,...|
|[8.0,110.0,29.74,...|
|[9.0,81.0,21.44,9...|
|[9.0,90.09,25.01,...|
|[10.0,110.0,29.74...|
|[11.0,85.0,18.48,...|
|[11.0,86.0,21.24,...|
|[11.0,91.0,20.32,...|
|[11.0,91.62700000...|
|[11.0,108.977,26....|
|[11.0,138.0,31.14...|
|[12.0,25.0,3.88,5...|
|[12.0,88.5,21.24,...|
|[12.0,91.0,20.32,...|
|[13.0,91.0,20.32,...|
|[14.0,63.0,14.4,7...|
|[14.0,76.8,19.6,8...|
+--------------------+
only showing top 20 rows



In [67]:
# Generate predictions for the feature-only rows; this adds a 'prediction' column.
predictions = lr_model.transform(unlabeled_data)

In [68]:
# Display each feature vector alongside its predicted crew value.
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[5.0,115.0,35.74,...|11.765223213328014|
|[5.0,133.5,39.59,...|13.030540578568406|
|[6.0,90.0,20.0,9....|10.129254290460782|
|[8.0,77.499,19.5,...| 8.693855144125767|
|[8.0,110.0,29.74,...|12.180474183678136|
|[9.0,81.0,21.44,9...|  9.57994822468972|
|[9.0,90.09,25.01,...| 9.368318197478729|
|[10.0,110.0,29.74...|12.157757338987727|
|[11.0,85.0,18.48,...| 8.911212233417048|
|[11.0,86.0,21.24,...| 9.644470793943086|
|[11.0,91.0,20.32,...| 9.291836786257788|
|[11.0,91.62700000...| 9.352006922514597|
|[11.0,108.977,26....|11.186383114521927|
|[11.0,138.0,31.14...|13.140852738312164|
|[12.0,25.0,3.88,5...| 2.961572724114328|
|[12.0,88.5,21.24,...| 10.50198655733845|
|[12.0,91.0,20.32,...| 9.280478363912582|
|[13.0,91.0,20.32,...| 9.269119941567377|
|[14.0,63.0,14.4,7...| 6.785583594752748|
|[14.0,76.8,19.6,8...|  8.82897982610369|
+--------------------+------------