In [2]:
#initializing spark session in Jupyter Notebook
import findspark

In [3]:
findspark.init()

In [4]:
from pyspark.sql import SparkSession

In [20]:
spark = SparkSession.builder.appName('LRCon').getOrCreate()

In [21]:
#import dataset from a csv file
data = spark.read.csv('cruise_ship_info.csv', inferSchema = True, header = True)

In [23]:
#exploring the dataset a little bit
data.describe()

DataFrame[summary: string, Ship_name: string, Cruise_line: string, Age: string, Tonnage: string, passengers: string, length: string, cabins: string, passenger_density: string, crew: string]

In [24]:
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [27]:
for item in data.head(1)[0]:
    print(item)

Journey
Azamara
6
30.276999999999997
6.94
5.94
3.55
42.64
3.55


In [29]:
data.head(1)

[Row(Ship_name='Journey', Cruise_line='Azamara', Age=6, Tonnage=30.276999999999997, passengers=6.94, length=5.94, cabins=3.55, passenger_density=42.64, crew=3.55)]

In [32]:
data.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew']

In [33]:
#making the given dataset spark compatible
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [34]:
assembbler = VectorAssembler(inputCols = ['Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density'],
                             outputCol = 'feature')
output = assembbler.transform(data)

In [35]:
output.select('feature').show()

+--------------------+
|             feature|
+--------------------+
|[6.0,30.276999999...|
|[6.0,30.276999999...|
|[26.0,47.262,14.8...|
|[11.0,110.0,29.74...|
|[17.0,101.353,26....|
|[22.0,70.367,20.5...|
|[15.0,70.367,20.5...|
|[23.0,70.367,20.5...|
|[19.0,70.367,20.5...|
|[6.0,110.23899999...|
|[10.0,110.0,29.74...|
|[28.0,46.052,14.5...|
|[18.0,70.367,20.5...|
|[17.0,70.367,20.5...|
|[11.0,86.0,21.24,...|
|[8.0,110.0,29.74,...|
|[9.0,88.5,21.24,9...|
|[15.0,70.367,20.5...|
|[12.0,88.5,21.24,...|
|[20.0,70.367,20.5...|
+--------------------+
only showing top 20 rows



In [36]:
output.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+--------------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|             feature|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+--------------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|[6.0,30.276999999...|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|[6.0,30.276999999...|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|[26.0,47.262,14.8...|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|[11.0,110.0,29.74...|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|[17.0,101.353,26....|
|    Ecstasy|   Carnival| 22|            70.367|     20.

In [38]:
#therefore the final data needed for building the predictive model is
final_data = output.select('feature', 'crew')
final_data.show()
#where 'crew' column denotes the label data

+--------------------+----+
|             feature|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
|[22.0,70.367,20.5...| 9.2|
|[15.0,70.367,20.5...| 9.2|
|[23.0,70.367,20.5...| 9.2|
|[19.0,70.367,20.5...| 9.2|
|[6.0,110.23899999...|11.5|
|[10.0,110.0,29.74...|11.6|
|[28.0,46.052,14.5...| 6.6|
|[18.0,70.367,20.5...| 9.2|
|[17.0,70.367,20.5...| 9.2|
|[11.0,86.0,21.24,...| 9.3|
|[8.0,110.0,29.74,...|11.6|
|[9.0,88.5,21.24,9...|10.3|
|[15.0,70.367,20.5...| 9.2|
|[12.0,88.5,21.24,...| 9.3|
|[20.0,70.367,20.5...| 9.2|
+--------------------+----+
only showing top 20 rows



In [39]:
#now that we have our final spark compatible data, perform the customery train test split
train_data, test_data = final_data.randomSplit([0.7, 0.3])

In [41]:
train_data.describe().show()

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|              106|
|   mean|8.008113207547183|
| stddev|3.390815638135448|
|    min|             0.59|
|    max|             21.0|
+-------+-----------------+



In [42]:
test_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|                52|
|   mean| 7.358076923076923|
| stddev|3.7182417510776538|
|    min|              0.59|
|    max|              13.6|
+-------+------------------+



In [43]:
#initiate the linear regression model
from pyspark.ml.regression import LinearRegression

In [46]:
lrModel = LinearRegression(featuresCol = 'feature', labelCol = 'crew', predictionCol = 'predictions')

In [47]:
trained_Model = lrModel.fit(train_data)

In [48]:
trained_Model.coefficients

DenseVector([-0.0022, 0.0243, -0.1438, 0.3018, 0.7787, 0.0023])

In [51]:
trained_Model.featuresCol

Param(parent='LinearRegression_b49961f78894', name='featuresCol', doc='features column name')

In [53]:
results = trained_Model.evaluate(test_data)

In [55]:
results.r2

0.9469733579129627

In [56]:
results.rootMeanSquaredError

0.8479458177803906

In [57]:
#mimicking an unlabeled data situation
unlabeled_data = test_data.select('feature')

In [58]:
unlabeled_data.show()

+--------------------+
|             feature|
+--------------------+
|[5.0,133.5,39.59,...|
|[5.0,160.0,36.34,...|
|[6.0,30.276999999...|
|[6.0,30.276999999...|
|[6.0,158.0,43.7,1...|
|[7.0,158.0,43.7,1...|
|[8.0,91.0,22.44,9...|
|[8.0,110.0,29.74,...|
|[9.0,81.0,21.44,9...|
|[9.0,85.0,19.68,9...|
|[9.0,90.09,25.01,...|
|[10.0,58.825,15.6...|
|[10.0,77.0,20.16,...|
|[10.0,105.0,27.2,...|
|[10.0,110.0,29.74...|
|[10.0,138.0,31.14...|
|[11.0,90.09,25.01...|
|[11.0,138.0,31.14...|
|[12.0,2.329,0.94,...|
|[12.0,25.0,3.88,5...|
+--------------------+
only showing top 20 rows



In [66]:
predictions = trained_Model.transform(unlabeled_data)

In [60]:
#what are all available columns in these predictions?
predictions.columns

['feature', 'predictions']

In [62]:
#visualizing the predictions
predictions.show()

+--------------------+------------------+
|             feature|       predictions|
+--------------------+------------------+
|[5.0,133.5,39.59,...| 13.12244490803336|
|[5.0,160.0,36.34,...|15.716528820444779|
|[6.0,30.276999999...|3.8387081479411327|
|[6.0,30.276999999...|3.8387081479411327|
|[6.0,158.0,43.7,1...|14.496123616704763|
|[7.0,158.0,43.7,1...|14.454724166496687|
|[8.0,91.0,22.44,9...|10.168547594244865|
|[8.0,110.0,29.74,...|12.372192303370396|
|[9.0,81.0,21.44,9...| 9.538602050278547|
|[9.0,85.0,19.68,9...| 9.258471062477161|
|[9.0,90.09,25.01,...| 9.537137014986083|
|[10.0,58.825,15.6...|7.1508463746579425|
|[10.0,77.0,20.16,...| 8.672739726789604|
|[10.0,105.0,27.2,...| 11.41115433169151|
|[10.0,110.0,29.74...|12.367873255668789|
|[10.0,138.0,31.14...|13.617295732731039|
|[11.0,90.09,25.01...| 9.190204461069879|
|[11.0,138.0,31.14...|13.615136208880235|
|[12.0,2.329,0.94,...|0.6543658945219479|
|[12.0,25.0,3.88,5...| 2.943702470502175|
+--------------------+------------

In [69]:
#comparing against the real values listed in the test data
test_data.show()

+--------------------+-----+
|             feature| crew|
+--------------------+-----+
|[5.0,133.5,39.59,...|13.13|
|[5.0,160.0,36.34,...| 13.6|
|[6.0,30.276999999...| 3.55|
|[6.0,30.276999999...| 3.55|
|[6.0,158.0,43.7,1...| 13.6|
|[7.0,158.0,43.7,1...| 13.6|
|[8.0,91.0,22.44,9...| 11.0|
|[8.0,110.0,29.74,...| 11.6|
|[9.0,81.0,21.44,9...| 10.0|
|[9.0,85.0,19.68,9...| 8.69|
|[9.0,90.09,25.01,...| 8.69|
|[10.0,58.825,15.6...|  7.0|
|[10.0,77.0,20.16,...|  9.0|
|[10.0,105.0,27.2,...|10.68|
|[10.0,110.0,29.74...| 11.6|
|[10.0,138.0,31.14...|11.85|
|[11.0,90.09,25.01...| 8.48|
|[11.0,138.0,31.14...|11.85|
|[12.0,2.329,0.94,...|  0.6|
|[12.0,25.0,3.88,5...| 2.87|
+--------------------+-----+
only showing top 20 rows

