In [1]:
"""
# Running PySpark at Google Colab
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://downloads.apache.org/spark/spark-3.0.0-preview2/spark-3.0.0-preview2-bin-hadoop2.7.tgz
!tar -xvf spark-3.0.0-preview2-bin-hadoop2.7.tgz
!pip install -q findspark
"""

'\n# Running PySpark at Google Colab\n!apt-get install openjdk-8-jdk-headless -qq > /dev/null\n!wget -q https://downloads.apache.org/spark/spark-3.0.0-preview2/spark-3.0.0-preview2-bin-hadoop2.7.tgz\n!tar -xvf spark-3.0.0-preview2-bin-hadoop2.7.tgz\n!pip install -q findspark\n'

In [0]:
import os

# Tell the process where the Java 8 runtime and the unpacked Spark distribution live.
os.environ.update({
    "JAVA_HOME": "/usr/lib/jvm/java-8-openjdk-amd64",
    "SPARK_HOME": "/content/spark-3.0.0-preview2-bin-hadoop2.7",
})

import findspark
# init() is called before the pyspark import on the next line; presumably it uses
# SPARK_HOME to make the pyspark package importable — TODO confirm against findspark docs.
findspark.init()
from pyspark.sql import SparkSession

In [0]:
# Create the Spark session for this notebook, or reuse one that is already running.
spark = (
    SparkSession.builder
    .appName('hyundai_cruise')
    .getOrCreate()
)

In [0]:
from pyspark.ml.regression import LinearRegression

In [0]:
# Load the cruise-ship CSV: the first row supplies column names and column types
# are inferred from the values.
# NOTE(review): the path presumably needs Google Drive mounted at /content/drive;
# no mount call is visible in this notebook — confirm.
data = spark.read.csv(
    '/content/drive/My Drive/Colab Notebooks/cruise_ship_info.csv',
    header=True,
    inferSchema=True,
)

In [6]:
# Print the column names and inferred types of the loaded DataFrame.
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [7]:
# Preview the first rows of the raw data.
data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [8]:
# Count how many rows belong to each cruise line.
data.groupBy('Cruise_line').count().show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|          Azamara|    2|
|         Carnival|   22|
|        Celebrity|   10|
|            Costa|   11|
|          Crystal|    2|
|           Cunard|    3|
|           Disney|    2|
| Holland_American|   14|
|              MSC|    8|
|        Norwegian|   13|
|          Oceania|    3|
|           Orient|    1|
|              P&O|    6|
|         Princess|   17|
|Regent_Seven_Seas|    5|
|  Royal_Caribbean|   23|
|         Seabourn|    3|
|        Silversea|    4|
|             Star|    6|
|         Windstar|    3|
+-----------------+-----+



In [0]:
# Make Cruise_line column indexed
from pyspark.ml.feature import StringIndexer

In [0]:
# Encode the categorical Cruise_line column as a numeric label index
# (most frequent label gets the lowest index; unseen labels raise an error).
# The estimator gets its own name so the imported StringIndexer class is not
# overwritten by an instance — otherwise re-running this cell fails because the
# name no longer refers to a callable class.
cruise_line_indexer = StringIndexer(
    inputCol="Cruise_line",
    outputCol="indexed",
    handleInvalid="error",
    stringOrderType='frequencyDesc',
)
model = cruise_line_indexer.fit(data)
# `td` is the original data plus the new 'indexed' column.
td = model.transform(data)

In [11]:
# Preview the data with the 'indexed' column appended.
td.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|indexed|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|   16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|   16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|    1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|    1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|    1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|    1.0|
|    Elation|   Carnival| 15|            70.367|     20

In [0]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [0]:
# Combine the predictor columns into a single 'features' vector column.
# NOTE(review): 'indexed' is a frequency-rank label index for Cruise_line, so a
# linear model will treat cruise lines as ordered numbers; consider one-hot
# encoding it before assembling.
assembler = VectorAssembler(
    inputCols=[
        'Age',
        'Tonnage',
        'passengers',
        'length',
        'cabins',
        'passenger_density',
        'indexed',
    ],
    outputCol='features',
)

In [0]:
# Apply the assembler to the indexed data, adding the 'features' vector column.
output = assembler.transform(td)

In [15]:
# Check the schema after assembling the feature vector.
output.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)
 |-- indexed: double (nullable = false)
 |-- features: vector (nullable = true)



In [0]:
# Keep only what the regression needs: the feature vector and the 'crew' column used as label.
final_data = output.select('features', 'crew')

In [17]:
# Preview the modelling dataset.
final_data.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
|[22.0,70.367,20.5...| 9.2|
|[15.0,70.367,20.5...| 9.2|
|[23.0,70.367,20.5...| 9.2|
|[19.0,70.367,20.5...| 9.2|
|[6.0,110.23899999...|11.5|
|[10.0,110.0,29.74...|11.6|
|[28.0,46.052,14.5...| 6.6|
|[18.0,70.367,20.5...| 9.2|
|[17.0,70.367,20.5...| 9.2|
|[11.0,86.0,21.24,...| 9.3|
|[8.0,110.0,29.74,...|11.6|
|[9.0,88.5,21.24,9...|10.3|
|[15.0,70.367,20.5...| 9.2|
|[12.0,88.5,21.24,...| 9.3|
|[20.0,70.367,20.5...| 9.2|
+--------------------+----+
only showing top 20 rows



In [0]:
# 70/30 train/test split. A fixed seed is passed so the same rows land in each
# split on every Restart & Run All. The outputs recorded below were produced by
# an unseeded split, so their exact numbers will change when this cell is re-run.
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=42)

In [19]:
# Summary statistics of the training split.
train_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               117|
|   mean| 7.882649572649582|
| stddev|3.5404376782720357|
|    min|              0.59|
|    max|              21.0|
+-------+------------------+



In [20]:
# Summary statistics of the test split.
test_data.describe().show()

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|               41|
|   mean|7.541707317073172|
| stddev|3.426110697597952|
|    min|              1.6|
|    max|             13.6|
+-------+-----------------+



In [0]:
# Linear regression estimator predicting 'crew' from the assembled 'features' vector
# ('features' is the library default, spelled out here for readability).
lr = LinearRegression(featuresCol='features', labelCol='crew')

In [0]:
# Fit the regression on the training split.
lr_model = lr.fit(train_data)

In [0]:
# Evaluate the fitted model on the held-out test split.
test_results = lr_model.evaluate(test_data)

In [24]:
# Show the per-row residuals on the test split.
test_results.residuals.show()

+--------------------+
|           residuals|
+--------------------+
| -1.3924169315005983|
|  -0.071510113814659|
|-0.46691603203013976|
| 0.25193442420935597|
| -0.3362474605862342|
| -0.2581228388939705|
| -0.5995611940684267|
| -0.4187012931301819|
| -0.6278660036247175|
|  -0.335283255025427|
|  0.8411007313431327|
|  -0.207507225148591|
|  0.7285492502155329|
| -1.1847932896085744|
|  -1.180597155485545|
|-0.03776773108837...|
| -0.5871158704094128|
| -1.4661800330015184|
| 0.29478159521451275|
| 0.19437261562378927|
+--------------------+
only showing top 20 rows



In [25]:
# Root mean squared error on the test split.
test_results.rootMeanSquaredError

0.5941209788018735

In [26]:
# R-squared on the test split.
test_results.r2

0.9691772872221183

In [0]:
# Drop the 'crew' label so the test rows stand in for new, unlabeled input.
unlabeled_data = test_data.select('features')

In [28]:
# Preview the feature-only rows.
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[5.0,86.0,21.04,9...|
|[5.0,133.5,39.59,...|
|[6.0,112.0,38.0,9...|
|[6.0,113.0,37.82,...|
|[7.0,158.0,43.7,1...|
|[9.0,59.058,17.0,...|
|[9.0,110.0,29.74,...|
|[10.0,90.09,25.01...|
|[10.0,105.0,27.2,...|
|[11.0,86.0,21.24,...|
|[11.0,108.977,26....|
|[12.0,25.0,3.88,5...|
|[12.0,77.104,20.0...|
|[12.0,88.5,21.24,...|
|[12.0,138.0,31.14...|
|[13.0,30.27699999...|
|[13.0,61.0,13.8,7...|
|[13.0,63.0,14.4,7...|
|[13.0,101.509,27....|
|[14.0,30.27699999...|
+--------------------+
only showing top 20 rows



In [0]:
# Score the unlabeled rows; the result carries a 'prediction' column next to 'features'.
prediction = lr_model.transform(unlabeled_data)

In [30]:
# Preview the predicted crew values.
prediction.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[5.0,86.0,21.04,9...| 9.392416931500598|
|[5.0,133.5,39.59,...| 13.20151011381466|
|[6.0,112.0,38.0,9...| 11.36691603203014|
|[6.0,113.0,37.82,...|11.748065575790644|
|[7.0,158.0,43.7,1...|13.936247460586234|
|[9.0,59.058,17.0,...| 7.658122838893971|
|[9.0,110.0,29.74,...|12.199561194068426|
|[10.0,90.09,25.01...| 8.998701293130182|
|[10.0,105.0,27.2,...|11.307866003624717|
|[11.0,86.0,21.24,...| 9.635283255025428|
|[11.0,108.977,26....|11.158899268656867|
|[12.0,25.0,3.88,5...| 3.077507225148591|
|[12.0,77.104,20.0...| 8.861450749784467|
|[12.0,88.5,21.24,...|10.484793289608575|
|[12.0,138.0,31.14...|13.030597155485545|
|[13.0,30.27699999...| 4.037767731088377|
|[13.0,61.0,13.8,7...| 6.587115870409413|
|[13.0,63.0,14.4,7...| 6.776180033001518|
|[13.0,101.509,27....|11.205218404785487|
|[14.0,30.27699999...|3.5356273843762107|
+--------------------+------------

In [0]:
from pyspark.sql.functions import corr

In [32]:
# Summary statistics for every column of the raw data.
data.describe().show()

+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|summary|Ship_name|Cruise_line|               Age|           Tonnage|       passengers|           length|            cabins|passenger_density|             crew|
+-------+---------+-----------+------------------+------------------+-----------------+-----------------+------------------+-----------------+-----------------+
|  count|      158|        158|               158|               158|              158|              158|               158|              158|              158|
|   mean| Infinity|       null|15.689873417721518| 71.28467088607599|18.45740506329114|8.130632911392404| 8.830000000000005|39.90094936708861|7.794177215189873|
| stddev|      NaN|       null| 7.615691058751413|37.229540025907866|9.677094775143416|1.793473548054825|4.4714172221480615| 8.63921711391542|3.503486564627034|
|    min|Adventure|    Azamara|   

In [33]:
# Correlation between crew and passengers over the full dataset.
data.select(corr('crew', 'passengers')).show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+



In [34]:
# Correlation between crew and cabins over the full dataset.
data.select(corr('crew', 'cabins')).show()

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+

