In [32]:
from pyspark.sql import SparkSession

In [33]:
spark = SparkSession.builder.appName('shipping_crew').getOrCreate()

In [34]:
from pyspark.ml.regression import LinearRegression

In [35]:
df = spark.read.csv("/home/marcelo/Documents/shipping_crew/cruise_ship_info.csv", header=True, inferSchema=True)

In [36]:
df.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [37]:
from pyspark.ml.feature import StringIndexer

In [39]:
indexer = StringIndexer(inputCol='Cruise_line', outputCol='Cruise_line_index')
indexed = indexer.fit(df).transform(df)
indexed.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|Cruise_line_index|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-----------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|             16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|             16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|              1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|              1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|              1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|       

In [40]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import MinMaxScaler

In [41]:
assembler = VectorAssembler(
    inputCols=['Cruise_line_index', 'Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density'],
    outputCol='features'
)

data_features = assembler.transform(indexed)

# Scale

scaler = MinMaxScaler(inputCol='features', outputCol='scaled_features')
scaler_model = scaler.fit(data_features.select('features'))
df = scaler_model.transform(data_features)

#Removing NAN

data = df.dropna(how='any').dropDuplicates()


data.select('crew','scaled_features').show()

+-----+--------------------+
| crew|     scaled_features|
+-----+--------------------+
|  9.2|[0.05263157894736...|
|  9.3|[0.05263157894736...|
| 12.2|[0.42105263157894...|
|  8.5|[0.42105263157894...|
| 8.58|[0.0,0.1363636363...|
|13.13|[0.36842105263157...|
| 9.99|[0.31578947368421...|
| 21.0|[0.0,0.0,1.0,0.99...|
|  5.2|[0.42105263157894...|
|  9.2|[0.05263157894736...|
|  3.8|[0.21052631578947...|
| 9.45|[0.94736842105263...|
| 3.73|[0.10526315789473...|
| 5.88|[0.15789473684210...|
|12.38|[0.10526315789473...|
| 8.58|[0.31578947368421...|
| 13.6|[0.0,0.0227272727...|
| 6.44|[0.15789473684210...|
| 5.88|[0.15789473684210...|
|  5.2|[0.10526315789473...|
+-----+--------------------+
only showing top 20 rows



In [42]:
lr = LinearRegression(featuresCol='scaled_features', labelCol='crew', predictionCol='prediction')

In [43]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [44]:
model_train = lr.fit(train_data)

23/01/02 11:55:15 WARN Instrumentation: [c57193ca] regParam is zero, which might cause numerical instability and overfitting.


In [45]:
test_results = model_train.evaluate(test_data)

In [46]:
rmse = test_results.rootMeanSquaredError
mse = test_results.meanSquaredError
r2 = test_results.r2

In [47]:
print(rmse)
print(mse)
print(r2)

0.7884601968924153
0.6216694820836263
0.946326104889263
