In [1]:
from pyspark.sql import SparkSession

In [2]:
MAX_MEMORY="5g"
spark = SparkSession.builder.appName("taxi_fare_pred")\
        .config("spark.executor.memory", MAX_MEMORY)\
        .config("spark.driver.memory", MAX_MEMORY)\
        .getOrCreate()

In [3]:
trip_files = "/Users/ji/data-engineering/01-spark/data/trips/*"

In [6]:
trips_df = spark.read.parquet(f"file:///{trip_files}", inferSchema=True, header=True)

In [7]:
trips_df.printSchema()

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)



In [8]:
trips_df.createOrReplaceTempView("trips")

In [11]:
query="""
SELECT
    trip_distance,
    total_amount
FROM
    trips
WHERE
    total_amount<5000
    AND total_amount > 0
    AND trip_distance > 0
    AND trip_distance < 500
    AND passenger_count < 4
    AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
    AND TO_DATE(tpep_pickup_datetime) < '2021-08-01'
"""

In [12]:
data_df = spark.sql(query)
data_df.createOrReplaceTempView('data')

In [13]:
data_df.show()

[Stage 2:>                                                          (0 + 1) / 1]

+-------------+------------+
|trip_distance|total_amount|
+-------------+------------+
|          8.4|       35.15|
|          0.9|         8.8|
|          3.4|        15.3|
|         1.96|       13.39|
|         0.77|        9.54|
|         3.65|       15.36|
|          8.9|       43.67|
|         2.98|        13.3|
|          8.9|        29.3|
|         7.48|        31.3|
|        12.14|       45.24|
|         11.8|        33.3|
|         1.44|         8.8|
|         1.65|       15.38|
|         8.16|       34.56|
|          7.4|       34.92|
|        13.65|        44.8|
|         3.14|       16.56|
|         15.5|       48.18|
|         20.1|       66.35|
+-------------+------------+
only showing top 20 rows



                                                                                

In [14]:
data_df.describe().show()



+-------+------------------+------------------+
|summary|     trip_distance|      total_amount|
+-------+------------------+------------------+
|  count|          13083045|          13083045|
|   mean|2.8808531003293565| 17.96854399035855|
| stddev|3.8197071180448856|12.972389388491917|
|    min|              0.01|              0.01|
|    max|             475.5|            4973.3|
+-------+------------------+------------------+



                                                                                

In [15]:
train_df, test_df = data_df.randomSplit([0.8, 0.2], seed=1)

In [16]:
print(train_df.count())
print(test_df.count())

                                                                                

10466766




2616279


                                                                                

In [17]:
from pyspark.ml.feature import VectorAssembler

In [25]:
vassembler = VectorAssembler(inputCols=["trip_distance"], outputCol="features")

In [26]:
vtrain_df = vassembler.transform(train_df)

In [27]:
vtrain_df.show()

[Stage 13:>                                                         (0 + 1) / 1]

+-------------+------------+--------+
|trip_distance|total_amount|features|
+-------------+------------+--------+
|         0.01|        2.81|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
|         0.01|         3.3|  [0.01]|
+-------------+------------+--------+
only showing top 20 rows



                                                                                

In [28]:
from pyspark.ml.regression import LinearRegression

In [29]:
lr = LinearRegression(
    maxIter=50,
    labelCol="total_amount",
    featuresCol="features"
)

In [30]:
model = lr.fit(vtrain_df)

22/06/13 02:10:09 WARN Instrumentation: [a643d971] regParam is zero, which might cause numerical instability and overfitting.
22/06/13 02:10:14 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
22/06/13 02:10:14 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS
22/06/13 02:10:32 WARN InstanceBuilder$NativeLAPACK: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK
                                                                                

In [31]:
vtest_df = vassembler.transform(test_df)

In [33]:
pred = model.transform(vtest_df)

In [34]:
pred.show()

[Stage 18:>                                                         (0 + 1) / 1]

+-------------+------------+--------+-----------------+
|trip_distance|total_amount|features|       prediction|
+-------------+------------+--------+-----------------+
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982713937|
|         0.01|         3.3|  [0.01]|9.387592982

                                                                                

In [36]:
model.summary.rootMeanSquaredError

6.26219770991078

In [37]:
model.summary.r2

0.7680086135590561

In [38]:
from pyspark.sql.types import DoubleType
dist = [1.1,5.5,10.5,30.0]
dist_df = spark.createDataFrame(dist, DoubleType()).toDF('trip_distance')

In [39]:
dist_df.show()

[Stage 19:>                                                         (0 + 1) / 1]

+-------------+
|trip_distance|
+-------------+
|          1.1|
|          5.5|
|         10.5|
|         30.0|
+-------------+



                                                                                

In [40]:
vdist_df = vassembler.transform(dist_df)

In [41]:
vdist_df.show()

+-------------+--------+
|trip_distance|features|
+-------------+--------+
|          1.1|   [1.1]|
|          5.5|   [5.5]|
|         10.5|  [10.5]|
|         30.0|  [30.0]|
+-------------+--------+



In [42]:
model.transform(vdist_df).show()

+-------------+--------+------------------+
|trip_distance|features|        prediction|
+-------------+--------+------------------+
|          1.1|   [1.1]|12.645981386029183|
|          5.5|   [5.5]|25.799108885650362|
|         10.5|  [10.5]| 40.74584468067442|
|         30.0|  [30.0]| 99.03811428126826|
+-------------+--------+------------------+

