In [1]:
import os

import findspark

# Locate the Spark installation. Prefer the SPARK_HOME environment variable so
# the notebook is not tied to one user's home directory; fall back to the
# original local install path when SPARK_HOME is not set.
findspark.init(os.environ.get('SPARK_HOME', '/home/duynguyen/spark-2.1.0-bin-hadoop2.7'))

In [2]:
from pyspark.sql import SparkSession
# Start (or reuse) the Spark session named 'lr_example'.
spark = (
    SparkSession.builder
    .appName('lr_example')
    .getOrCreate()
)

In [3]:
# Load the cruise-ship CSV: the first row holds the column names and the
# column types are inferred from the values.
data = spark.read.csv(
    "cruise_ship_info.csv",
    header=True,
    inferSchema=True,
)

In [4]:
# Print the column names and the types inferred while reading the CSV.
data.printSchema()

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



In [5]:
# Preview the loaded rows.
data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Elation|   Carnival| 15|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|
|    Fantasy|   Carnival| 23| 

In [7]:
from pyspark.ml.regression import LinearRegression

In [8]:
from pyspark.ml.linalg import Vectors

In [9]:
from pyspark.ml.feature import VectorAssembler

In [10]:
from pyspark.ml.feature import StringIndexer

In [11]:
# Configure an indexer that encodes the 'Cruise_line' string column as a
# numeric label index in a new 'categoryIndex' column.
indexer = StringIndexer(
    inputCol="Cruise_line",
    outputCol="categoryIndex",
)

In [12]:
# Learn the cruise-line -> index mapping from the data, then apply it to the
# same data to add the 'categoryIndex' column.
new_data = (
    indexer
    .fit(data)
    .transform(data)
)

In [13]:
# Preview the data with the added 'categoryIndex' column.
new_data.show()

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|categoryIndex|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+-------------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|         16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|         16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|          1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|          1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|          1.0|
|    Ecstasy|   Carnival| 22|            70.367|     20.52|  8.55|  10.2|            34.29| 9.2|          1.0|
|

In [14]:
# List the available column names (used to choose the feature inputs below).
new_data.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'categoryIndex']

In [15]:
# Pack the predictor columns into a single 'features' vector column.
# 'crew' is left out because it is used as the regression label.
# NOTE(review): 'categoryIndex' is a StringIndexer label index and is fed to
# the regression as a plain number; consider one-hot encoding it instead --
# confirm this ordinal treatment is intended.
assembler = VectorAssembler(
    inputCols=[
        'Age',
        'Tonnage',
        'passengers',
        'length',
        'cabins',
        'passenger_density',
        'categoryIndex',
    ],
    outputCol='features',
)

In [17]:
# Add the assembled 'features' vector column to the indexed data.
output = assembler.transform(new_data)

In [20]:
# Inspect the assembled feature vectors.
output.select("features").show()

+--------------------+
|            features|
+--------------------+
|[6.0,30.276999999...|
|[6.0,30.276999999...|
|[26.0,47.262,14.8...|
|[11.0,110.0,29.74...|
|[17.0,101.353,26....|
|[22.0,70.367,20.5...|
|[15.0,70.367,20.5...|
|[23.0,70.367,20.5...|
|[19.0,70.367,20.5...|
|[6.0,110.23899999...|
|[10.0,110.0,29.74...|
|[28.0,46.052,14.5...|
|[18.0,70.367,20.5...|
|[17.0,70.367,20.5...|
|[11.0,86.0,21.24,...|
|[8.0,110.0,29.74,...|
|[9.0,88.5,21.24,9...|
|[15.0,70.367,20.5...|
|[12.0,88.5,21.24,...|
|[20.0,70.367,20.5...|
+--------------------+
only showing top 20 rows



In [21]:
# Keep only what the model needs: the feature vector and the 'crew' label.
final_data = output.select("features", "crew")

In [22]:
# Preview the (features, crew) pairs used for modelling.
final_data.show()

+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
|[22.0,70.367,20.5...| 9.2|
|[15.0,70.367,20.5...| 9.2|
|[23.0,70.367,20.5...| 9.2|
|[19.0,70.367,20.5...| 9.2|
|[6.0,110.23899999...|11.5|
|[10.0,110.0,29.74...|11.6|
|[28.0,46.052,14.5...| 6.6|
|[18.0,70.367,20.5...| 9.2|
|[17.0,70.367,20.5...| 9.2|
|[11.0,86.0,21.24,...| 9.3|
|[8.0,110.0,29.74,...|11.6|
|[9.0,88.5,21.24,9...|10.3|
|[15.0,70.367,20.5...| 9.2|
|[12.0,88.5,21.24,...| 9.3|
|[20.0,70.367,20.5...| 9.2|
+--------------------+----+
only showing top 20 rows



In [23]:
# Split into roughly 70% training and 30% test rows. The seed is fixed so the
# split -- and every metric computed from it -- is the same on each run.
# NOTE: outputs recorded in this notebook before the seed was added came from
# an unseeded split and will differ after re-running.
train_data,test_data = final_data.randomSplit([0.7,0.3], seed=42)

In [24]:
# Summary statistics for the training split.
train_data.describe().show()

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|              110|
|   mean|7.889181818181828|
| stddev| 3.44161769404447|
|    min|             0.59|
|    max|             19.1|
+-------+-----------------+



In [26]:
# Summary statistics for the test split.
test_data.describe().show()

+-------+------------------+
|summary|              crew|
+-------+------------------+
|  count|                48|
|   mean| 7.576458333333331|
| stddev|3.6691343344424445|
|    min|               0.6|
|    max|              21.0|
+-------+------------------+



In [27]:
# Create the linear regression estimator: it reads its inputs from the
# 'features' vector column and uses 'crew' as the target.
lr = LinearRegression(featuresCol='features', labelCol='crew')

In [28]:
# Fit the linear regression on the training split.
lrModel = lr.fit(train_data)

In [31]:
# Evaluate the fitted model on the held-out test split; the result exposes
# the residuals and the error metrics read in the cells below.
test_result = lrModel.evaluate(test_data)

In [33]:
# Per-row residuals on the test split (actual crew minus predicted crew).
test_result.residuals.show()

+--------------------+
|           residuals|
+--------------------+
| -0.3988543788512722|
| -1.4357698583406773|
|  -1.427685260999107|
| -0.3250412921401882|
|  -0.705454308318723|
| 0.35118391727397125|
| -0.5905616479268758|
|-0.29961948198183563|
|  1.8118039847920464|
| 0.05586905692642219|
|-0.28047962449251784|
| -1.3150354799896853|
|  0.8470982363407682|
| -0.3627095952926762|
|  -0.244477571137649|
|  -1.297265450789844|
| 0.08647987478509478|
| 0.09923572566857164|
| 0.07188503383031808|
| 0.17389903690387554|
+--------------------+
only showing top 20 rows



In [34]:
# Drop the 'crew' label so only the feature vectors remain.
unlabeled_data = test_data.select('features')

In [35]:
# Apply the fitted model to the features; this adds a 'prediction' column.
predictions = lrModel.transform(unlabeled_data)

In [36]:
# Show the predicted crew size for each test row.
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[4.0,220.0,54.0,1...|21.398854378851272|
|[6.0,90.0,20.0,9....|10.435769858340677|
|[7.0,89.6,25.5,9....|11.297685260999106|
|[7.0,158.0,43.7,1...|13.925041292140188|
|[8.0,110.0,29.74,...|12.305454308318723|
|[9.0,81.0,21.44,9...| 9.648816082726029|
|[9.0,90.09,25.01,...| 9.280561647926875|
|[9.0,116.0,26.0,9...|11.299619481981836|
|[10.0,46.0,7.0,6....|2.6581960152079533|
|[10.0,77.0,20.16,...| 8.944130943073578|
|[10.0,90.09,25.01...| 8.860479624492518|
|[10.0,138.0,31.14...|13.165035479989685|
|[11.0,90.0,22.4,9...|10.152901763659232|
|[11.0,90.09,25.01...| 8.842709595292677|
|[11.0,91.62700000...| 9.244477571137649|
|[11.0,138.0,31.14...|13.147265450789844|
|[12.0,2.329,0.94,...|0.5135201252149052|
|[12.0,50.0,7.0,7....|4.3507642743314285|
|[12.0,108.865,27....|10.928114966169682|
|[13.0,25.0,3.82,5...|2.7761009630961246|
+--------------------+------------

In [37]:
# Show the test rows with their actual 'crew' values, for comparison with the
# model's predictions.
test_data.show()

+--------------------+-----+
|            features| crew|
+--------------------+-----+
|[4.0,220.0,54.0,1...| 21.0|
|[6.0,90.0,20.0,9....|  9.0|
|[7.0,89.6,25.5,9....| 9.87|
|[7.0,158.0,43.7,1...| 13.6|
|[8.0,110.0,29.74,...| 11.6|
|[9.0,81.0,21.44,9...| 10.0|
|[9.0,90.09,25.01,...| 8.69|
|[9.0,116.0,26.0,9...| 11.0|
|[10.0,46.0,7.0,6....| 4.47|
|[10.0,77.0,20.16,...|  9.0|
|[10.0,90.09,25.01...| 8.58|
|[10.0,138.0,31.14...|11.85|
|[11.0,90.0,22.4,9...| 11.0|
|[11.0,90.09,25.01...| 8.48|
|[11.0,91.62700000...|  9.0|
|[11.0,138.0,31.14...|11.85|
|[12.0,2.329,0.94,...|  0.6|
|[12.0,50.0,7.0,7....| 4.45|
|[12.0,108.865,27....| 11.0|
|[13.0,25.0,3.82,5...| 2.95|
+--------------------+-----+
only showing top 20 rows



In [38]:
from pyspark.sql.functions import corr

In [39]:
# R-squared of the model on the test split.
test_result.r2

0.96439590437483

In [40]:
# Root mean squared error on the test split (same units as 'crew').
test_result.rootMeanSquaredError

0.6850809670893258

In [41]:
# Mean squared error on the test split.
test_result.meanSquaredError

0.4693359314680458

In [43]:
# Sanity check: correlation between crew size and passenger count over the
# full dataset.
data.select(corr('crew','passengers')).show()

+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+



In [44]:
# Sanity check: correlation between crew size and number of cabins over the
# full dataset.
data.select(corr('crew','cabins')).show()

+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+

