In [78]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import countDistinct,corr

from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler,StringIndexer

from pyspark.ml.regression import LinearRegression

In [2]:
spark = SparkSession.builder.master("local").appName("RegressionShip").getOrCreate()

In [3]:
data = spark.read.csv("cruise_ship_info.csv",header=True,inferSchema=True)

In [4]:
data.show(n=5)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 5 rows



In [5]:
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



# Target Variable: crew

In [9]:
data.groupBy("Cruise_line").count().show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|            Costa|   11|
|              P&O|    6|
|           Cunard|    3|
|Regent_Seven_Seas|    5|
|              MSC|    8|
|         Carnival|   22|
|          Crystal|    2|
|           Orient|    1|
|         Princess|   17|
|        Silversea|    4|
|         Seabourn|    3|
| Holland_American|   14|
|         Windstar|    3|
|           Disney|    2|
|        Norwegian|   13|
|          Oceania|    3|
|          Azamara|    2|
|        Celebrity|   10|
|             Star|    6|
|  Royal_Caribbean|   23|
+-----------------+-----+



In [7]:
#getting the distinct values in the column Cruise Line

data.agg(countDistinct(data['Cruise_Line']).alias("Count")).show()

+-----+
|Count|
+-----+
|   20|
+-----+



In [8]:
#getting the distinct values in the column Ship name
data.agg(countDistinct(data['Ship_name']).alias("Count")).show()

+-----+
|Count|
+-----+
|  138|
+-----+



# StringIndexer

Using string indexer to encode the string values in ship name and cruise line into integers

In [10]:
si = StringIndexer(inputCol="Cruise_line",outputCol="Cruise_line_category")

In [13]:
data_cat = si.fit(data).transform(data)

In [16]:
data_cat.select(['Cruise_line_category']).show()

+--------------------+
|Cruise_line_category|
+--------------------+
|                16.0|
|                16.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
|                 1.0|
+--------------------+
only showing top 20 rows



In [25]:
type(countDistinct(data_cat['Cruise_line_category']))

pyspark.sql.column.Column

In [42]:
#show the sorted distinct values of a column 
data_cat.orderBy(["Cruise_line_category"]).select("Cruise_line_category").distinct().show()

+--------------------+
|Cruise_line_category|
+--------------------+
|                 0.0|
|                 1.0|
|                 2.0|
|                 3.0|
|                 4.0|
|                 5.0|
|                 6.0|
|                 7.0|
|                 8.0|
|                 9.0|
|                10.0|
|                11.0|
|                12.0|
|                13.0|
|                14.0|
|                15.0|
|                16.0|
|                17.0|
|                18.0|
|                19.0|
+--------------------+



In [50]:
data_cat.show(n=5)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+--------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_category|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+--------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|                16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|                16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|                 1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|                 1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|                 1.0|
+-----------+-----------+---+------------------+--------

In [64]:
assembler = VectorAssembler(
    inputCols=["Age",
              "Tonnage",
              "passengers",
              "length",
              "cabins",
              "passenger_density",
              "Cruise_line_category"],
    outputCol="features")

In [65]:
final_data = assembler.transform(data_cat)
final_data = final_data.select(["features","crew"])

In [66]:
train_data, test_data = final_data.randomSplit([0.8,0.2])

In [68]:
lr = LinearRegression(featuresCol='features',labelCol='crew',predictionCol='prediction')

In [70]:
linear_regression_model = lr.fit(train_data)

In [75]:
linear_regression_model.coefficients,linear_regression_model.intercept

(DenseVector([-0.0052, 0.0111, -0.1445, 0.3815, 0.8539, -0.007, 0.0464]),
 -0.8495173517553753)

In [76]:
test_results = linear_regression_model.evaluate(test_data)

In [77]:
print("RMSE: {}".format(test_results.rootMeanSquaredError))
print("MSE: {}".format(test_results.meanSquaredError))
print("R2: {}".format(test_results.r2))

RMSE: 0.6678459308837401
MSE: 0.4460181873979694
R2: 0.9547182375122512


In [79]:
data.select(corr("crew","passengers")).show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+



In [81]:
data.select(corr("crew","cabins")).show()

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+

