In [5]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.ml.regression import LinearRegression
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler, StringIndexer

In [2]:
# Start (or reuse an already running) Spark session that every later cell uses.
spark = (
    SparkSession.builder
    .appName('hyundai_test')
    .getOrCreate()
)

22/02/15 11:15:22 WARN Utils: Your hostname, ganesh-pi resolves to a loopback address: 127.0.1.1; using 192.168.1.119 instead (on interface eth0)
22/02/15 11:15:22 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/02/15 11:15:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Load the cruise-ship dataset: the first row supplies the column names and
# Spark infers each column's type from the data.
df = spark.read.csv(
    'cruise_ship_info.csv',
    header=True,
    inferSchema=True,
)
df.printSchema()

[Stage 1:>                                                          (0 + 1) / 1]

root
 |-- Ship_name: string (nullable = true)
 |-- Cruise_line: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tonnage: double (nullable = true)
 |-- passengers: double (nullable = true)
 |-- length: double (nullable = true)
 |-- cabins: double (nullable = true)
 |-- passenger_density: double (nullable = true)
 |-- crew: double (nullable = true)



                                                                                

In [4]:
# Preview the first 20 rows without truncating long cell values.
df.show(n=20, truncate=False)

                                                                                

+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|Ship_name  |Cruise_line|Age|Tonnage           |passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|Journey    |Azamara    |6  |30.276999999999997|6.94      |5.94  |3.55  |42.64            |3.55|
|Quest      |Azamara    |6  |30.276999999999997|6.94      |5.94  |3.55  |42.64            |3.55|
|Celebration|Carnival   |26 |47.262            |14.86     |7.22  |7.43  |31.8             |6.7 |
|Conquest   |Carnival   |11 |110.0             |29.74     |9.53  |14.88 |36.99            |19.1|
|Destiny    |Carnival   |17 |101.353           |26.42     |8.92  |13.21 |38.36            |10.0|
|Ecstasy    |Carnival   |22 |70.367            |20.52     |8.55  |10.2  |34.29            |9.2 |
|Elation    |Carnival   |15 |70.367            |20.52     |8.55  |10.2  |34.29            |9.2 |
|Fantasy    |Carnival   |23 |7

In [7]:
# Encode the categorical Cruise_line strings as numeric indices so the column
# can be used as a model feature. By default StringIndexer orders labels by
# descending frequency, so the most common cruise line maps to 0.0.
indexer = StringIndexer(inputCol='Cruise_line', outputCol='Cruise_line_index')
indexed = indexer.fit(df).transform(df)

In [8]:
# Spot-check the cruise-line -> index mapping produced by the StringIndexer.
indexed.select('Cruise_line', 'Cruise_line_index').show(truncate=False)

+-----------+-----------------+
|Cruise_line|Cruise_line_index|
+-----------+-----------------+
|Azamara    |16.0             |
|Azamara    |16.0             |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
|Carnival   |1.0              |
+-----------+-----------------+
only showing top 20 rows



In [13]:
output.columns

['Ship_name',
 'Cruise_line',
 'Age',
 'Tonnage',
 'passengers',
 'length',
 'cabins',
 'passenger_density',
 'crew',
 'Cruise_line_index',
 'features']

In [11]:
#combining all the numeric features into vectors
assembler = VectorAssembler(inputCols=['Cruise_line_index', 'Age', 'Tonnage', 'passengers',
                            'length', 'cabins', 'passenger_density'], outputCol='features')

output = assembler.transform(indexed)
# output.show()

In [15]:
# Hold out 30% of the rows for evaluation. The fixed seed makes the split, and
# therefore the metrics computed from it, reproducible across kernel restarts
# (given the same input data and partitioning).
train_data, test_data = output.select('features', 'crew').randomSplit([0.7, 0.3], seed=42)
train_data.describe().show()

                                                                                

+-------+-----------------+
|summary|             crew|
+-------+-----------------+
|  count|              114|
|   mean|7.579473684210531|
| stddev|3.147236136002752|
|    min|             0.59|
|    max|             13.6|
+-------+-----------------+



In [16]:
# Fit a linear regression (regParam left at its default of zero, i.e. no
# regularization) that predicts crew size from the assembled feature vector,
# then score it on the held-out split.
lr = LinearRegression(featuresCol='features', labelCol='crew')
lr_model = lr.fit(train_data)

test_results = lr_model.evaluate(test_data)
rmse, r2 = test_results.rootMeanSquaredError, test_results.r2
print(rmse, r2)

22/02/15 11:29:59 WARN Instrumentation: [242bbc20] regParam is zero, which might cause numerical instability and overfitting.
22/02/15 11:30:00 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
22/02/15 11:30:00 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS
22/02/15 11:30:00 WARN InstanceBuilder$NativeLAPACK: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK
                                                                                

1.1822631818026546 0.9220459690901849


In [17]:
# Keep only the feature vectors (drop the 'crew' label) and let the fitted
# model add a 'prediction' column for each held-out row.
unlablelled_data = test_data.select(['features'])
predictions = lr_model.transform(unlablelled_data)
predictions.show(truncate=False)

+--------------------------------------------------+------------------+
|features                                          |prediction        |
+--------------------------------------------------+------------------+
|[0.0,4.0,220.0,54.0,11.82,27.0,40.74]             |20.602817543499373|
|[0.0,6.0,158.0,43.7,11.25,18.0,36.16]             |13.855356761097509|
|[0.0,7.0,158.0,43.7,11.12,18.0,36.16]             |13.793499202213347|
|[0.0,11.0,138.0,31.14,10.2,15.57,44.32]           |12.9266849018226  |
|[0.0,12.0,90.09,25.01,9.62,10.5,36.02]            |8.878219783156325 |
|[0.0,13.0,138.0,31.14,10.2,15.57,44.32]           |12.915201649603976|
|[0.0,17.0,70.0,20.76,8.67,9.02,33.72]             |7.585497580461531 |
|[0.0,18.0,70.0,18.0,8.67,9.0,38.89]               |7.932585220058812 |
|[1.0,8.0,110.0,29.74,9.51,14.87,36.99]            |11.924233202323432|
|[1.0,11.0,110.0,29.74,9.53,14.88,36.99]           |11.923379960330038|
|[1.0,17.0,101.353,26.42,8.92,13.21,38.36]         |10.660337121