In [3]:
from pyspark.sql import SparkSession
MAX_MEMORY = '8g'
spark = SparkSession.builder.appName('taxi-fare-prediction_2nd')\
                    .config('spark.driver.memory', MAX_MEMORY)\
                    .config('spark.executor.memory', MAX_MEMORY)\
                    .getOrCreate()

In [25]:
import os
cwd = os.getcwd()
trip_data_path = os.path.join(cwd, 'learning_spark_data', 'trips', '*.csv')
file_path = f"file:///{trip_data_path.replace(os.sep, '/')}"
trip_df = spark.read.csv(file_path, inferSchema=True, header=True)
trip_df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



In [26]:
trip_df.createOrReplaceTempView('trips')

In [27]:
query = '''
    SELECT 
        passenger_count,
        PULocationID as pickup_location_id,
        DOLocationID as dropoff_location_id,
        trip_distance,
        HOUR(tpep_pickup_datetime) as pickup_time,
        DATE_FORMAT(TO_DATE(tpep_pickup_datetime), 'EEEE') AS day_of_week,
        total_amount
    FROM
        trips
    WHERE
        total_amount < 5000
        AND total_amount > 0
        AND trip_distance > 0
        AND trip_distance < 500
        AND passenger_count < 4
        AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
        AND TO_DATE(tpep_pickup_datetime) < '2021-08-01'
'''

In [28]:
data_df = spark.sql(query)
data_df.createOrReplaceTempView('data')

In [29]:
spark.sql('select * from data limit 5').show()

+---------------+------------------+-------------------+-------------+-----------+-----------+------------+
|passenger_count|pickup_location_id|dropoff_location_id|trip_distance|pickup_time|day_of_week|total_amount|
+---------------+------------------+-------------------+-------------+-----------+-----------+------------+
|              0|               138|                265|         16.5|          0|     Monday|       70.07|
|              1|                68|                264|         1.13|          0|     Monday|       11.16|
|              1|               239|                262|         2.68|          0|     Monday|       18.59|
|              1|               186|                 91|         12.4|          0|     Monday|        43.8|
|              2|               132|                265|          9.7|          0|     Monday|        32.3|
+---------------+------------------+-------------------+-------------+-----------+-----------+------------+



In [30]:
train_df, test_df = data_df.randomSplit([0.8, 0.2], seed=42)

# 파이프라인 생성
- 전처리 과정을 각 스테이지로 정의해서 쌓아줌
- [범주형] StringIndexer + onehotencoding > 'pickup_location_id', 'dropoff_location_id', 'day_of_week'
- [수치형] Vectorassembler, StandardScaler > 'passenger_count', 'trip_distance', 'pickup_time'

In [34]:
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder, StandardScaler

In [32]:
stages = []

In [33]:
cat_features = ['pickup_location_id', 'dropoff_location_id', 'day_of_week']

for cat in cat_features:
    cat_index = StringIndexer(inputCol=cat, outputCol=cat+'_idx').setHandleInvalid('keep')
    onehot_encode = OneHotEncoder(inputCols=[cat_index.getOutputCol()], outputCols=[cat+'_onehot'])
    stages += [cat_index, onehot_encode]  # col list
stages

[StringIndexer_6e318c54a75f,
 OneHotEncoder_a8819829b0e0,
 StringIndexer_b372654022cd,
 OneHotEncoder_42124236f939,
 StringIndexer_e16732074e59,
 OneHotEncoder_fb61272d411c]

In [38]:
# vectorassembler
from pyspark.ml.feature import VectorAssembler
num_features = ['passenger_count', 'trip_distance', 'pickup_time']

for num in num_features:
    num_assembler = VectorAssembler(
        inputCols=[num],
        outputCol=num+'_vector'
    )
    num_scaler = StandardScaler(inputCol=num_assembler.getOutputCol(), outputCol=num+'_scaled')
    stages += [num_assembler, num_scaler]
stages

[StringIndexer_6e318c54a75f,
 OneHotEncoder_a8819829b0e0,
 StringIndexer_b372654022cd,
 OneHotEncoder_42124236f939,
 StringIndexer_e16732074e59,
 OneHotEncoder_fb61272d411c,
 VectorAssembler_0fc409cf09ce,
 StandardScaler_552ff29606ac,
 VectorAssembler_eb142e822879,
 StandardScaler_ea56b1aec20c,
 VectorAssembler_3e90929562ae,
 StandardScaler_0d563182b800]

In [40]:
assembler_input = [cat+'_onehot' for cat in cat_features] + [num+'_scaled' for num in num_features]
assembler_input

['pickup_location_id_onehot',
 'dropoff_location_id_onehot',
 'day_of_week_onehot',
 'passenger_count_scaled',
 'trip_distance_scaled',
 'pickup_time_scaled']

In [41]:
assembler = VectorAssembler(inputCols=assembler_input, outputCol='feature_vector')
stages += [assembler]
stages

[StringIndexer_6e318c54a75f,
 OneHotEncoder_a8819829b0e0,
 StringIndexer_b372654022cd,
 OneHotEncoder_42124236f939,
 StringIndexer_e16732074e59,
 OneHotEncoder_fb61272d411c,
 VectorAssembler_0fc409cf09ce,
 StandardScaler_552ff29606ac,
 VectorAssembler_eb142e822879,
 StandardScaler_ea56b1aec20c,
 VectorAssembler_3e90929562ae,
 StandardScaler_0d563182b800,
 VectorAssembler_805631bd348e]

In [42]:
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=stages)
fitted_transform = pipeline.fit(train_df)
vtrain_df = fitted_transform.transform(train_df)
vtrain_df.printSchema()

root
 |-- passenger_count: integer (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- pickup_time: integer (nullable = true)
 |-- day_of_week: string (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- pickup_location_id_idx: double (nullable = false)
 |-- pickup_location_id_onehot: vector (nullable = true)
 |-- dropoff_location_id_idx: double (nullable = false)
 |-- dropoff_location_id_onehot: vector (nullable = true)
 |-- day_of_week_idx: double (nullable = false)
 |-- day_of_week_onehot: vector (nullable = true)
 |-- passenger_count_vector: vector (nullable = true)
 |-- passenger_count_scaled: vector (nullable = true)
 |-- trip_distance_vector: vector (nullable = true)
 |-- trip_distance_scaled: vector (nullable = true)
 |-- pickup_time_vector: vector (nullable = true)
 |-- pickup_time_scaled: vector (nullable = true)
 |-- feature_vector: vector (nul

In [43]:
vtrain_df.select('feature_vector').show(2)

+--------------------+
|      feature_vector|
+--------------------+
|(534,[62,312,528,...|
|(534,[62,281,527,...|
+--------------------+
only showing top 2 rows



In [44]:
from pyspark.ml.regression import LinearRegression
lr = LinearRegression(maxIter=50, solver='normal', labelCol='total_amount', featuresCol='feature_vector')

In [45]:
model = lr.fit(vtrain_df)

In [47]:
# 테스트데이터도 변환
vtest_df = fitted_transform.transform(test_df)
# 테스트데이터로 예측
pred = model.transform(vtest_df)

In [48]:
pred.cache()

DataFrame[passenger_count: int, pickup_location_id: int, dropoff_location_id: int, trip_distance: double, pickup_time: int, day_of_week: string, total_amount: double, pickup_location_id_idx: double, pickup_location_id_onehot: vector, dropoff_location_id_idx: double, dropoff_location_id_onehot: vector, day_of_week_idx: double, day_of_week_onehot: vector, passenger_count_vector: vector, passenger_count_scaled: vector, trip_distance_vector: vector, trip_distance_scaled: vector, pickup_time_vector: vector, pickup_time_scaled: vector, feature_vector: vector, prediction: double]

In [49]:
pred.select('total_amount', 'prediction').show(3)

+------------+------------------+
|total_amount|        prediction|
+------------+------------------+
|       12.35| 12.62900702706409|
|        11.8|14.466237679610787|
|        12.3|14.775495186138414|
+------------+------------------+
only showing top 3 rows



In [50]:
model.summary.r2, model.summary.rootMeanSquaredError

(0.7966956145005624, 5.863545681582043)

In [51]:
spark.stop()