In [19]:
from pyspark.sql import SparkSession

from pyspark.sql import functions as F
from pyspark.sql import types as T

In [2]:
spark = SparkSession.builder.appName('regression_ship').getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/02/22 18:21:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/02/22 18:21:31 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
23/02/22 18:21:31 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [5]:
df = spark.read.csv(
    path='data/cruise_ship_info.txt',
    inferSchema=True,
    header=True)
print(df.count(), len(df.columns))

158 9


In [7]:
df.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [8]:
df.show(5)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 5 rows



In [23]:
df.groupBy(['Cruise_line']).count().sort(F.desc('count')).show()

+-----------------+-----+
|      Cruise_line|count|
+-----------------+-----+
|  Royal_Caribbean|   23|
|         Carnival|   22|
|         Princess|   17|
| Holland_American|   14|
|        Norwegian|   13|
|            Costa|   11|
|        Celebrity|   10|
|              MSC|    8|
|              P&O|    6|
|             Star|    6|
|Regent_Seven_Seas|    5|
|        Silversea|    4|
|         Seabourn|    3|
|         Windstar|    3|
|           Cunard|    3|
|          Oceania|    3|
|          Crystal|    2|
|          Azamara|    2|
|           Disney|    2|
|           Orient|    1|
+-----------------+-----+



### Index String

In [24]:
from pyspark.ml.feature import StringIndexer

In [26]:
indexer = StringIndexer(inputCol='Cruise_line', outputCol='cruise_line_idx')
indexed = indexer.fit(df).transform(df)
indexed.show(5)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|cruise_line_idx|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|            1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|            1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|            1.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+

In [27]:
indexed.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'cruise_line_idx']

### Assemble Vector

In [29]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [35]:
assembler = VectorAssembler(inputCols=[
    'Age',
    'Tonnage',
    'passengers',
    'length',
    'cabins',
    'passenger_density',
    'crew',
    'cruise_line_idx'],
    outputCol='features'
)
output = assembler.transform(indexed)
output.show(5)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+--------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|cruise_line_idx|            features|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+---------------+--------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|[6.0,30.276999999...|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|           16.0|[6.0,30.276999999...|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|            1.0|[26.0,47.262,14.8...|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|            1.0|[11.0,110.0,29.74...|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8

In [36]:
final_data = output.select(['features', 'crew'])
train_data, test_data = final_data.randomSplit([0.7,0.3])
print(train_data.count())
print(test_data.count())

112
46


In [39]:
train_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               112|
|   mean| 8.088392857142868|
| stddev|3.3872629875720524|
|    min|              0.59|
|    max|              19.1|
+-------+------------------+



### LR

In [40]:
from pyspark.ml.regression import LinearRegression

In [42]:
train_data.columns

['features', 'crew']

In [43]:
ship_lr = LinearRegression(labelCol='crew')
trained_ship_model = ship_lr.fit(train_data)
ship_results = trained_ship_model.evaluate(test_data)

23/02/22 18:38:21 WARN Instrumentation: [76d552e5] regParam is zero, which might cause numerical instability and overfitting.


In [51]:
train_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|               112|
|   mean| 8.088392857142868|
| stddev|3.3872629875720524|
|    min|              0.59|
|    max|              19.1|
+-------+------------------+



In [45]:
print(ship_results.rootMeanSquaredError)
print(ship_results.r2)
print(ship_results.meanSquaredError)

1.6935818469962603e-14
1.0
2.868219472475265e-28


In [56]:
df.show(20)

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

### Prove

In [57]:
from pyspark.sql.functions import corr

In [68]:
df.select(corr('crew', 'Tonnage')).show()

+-------------------+
|corr(crew, Tonnage)|
+-------------------+
|  0.927568811544939|
+-------------------+

