In [1]:
from pyspark.sql import SparkSession

In [3]:
# parquet 압축 코덱 선택, defalut:"snappy"
MAX_MEMORY = "5g"
spark = SparkSession.builder.appName("taxi-fare-prediction")\
            .config("spark.executor.memory", MAX_MEMORY)\
            .config("spark.driver.memory", MAX_MEMORY)\
            .config("spark.sql.parquet.compression.codec", None)\
            .getOrCreate()

22/04/08 00:31:38 WARN Utils: Your hostname, devkhk-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 172.30.1.27 instead (on interface en0)
22/04/08 00:31:38 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/04/08 00:31:39 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
trip_files = "/Users/devkhk/Documents/data-engineering-study/data/trips/*"

In [5]:
trips_df = spark.read.csv(f"file:///{trip_files}", inferSchema=True, header=True)

                                                                                

In [6]:
trips_df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: string (nullable = true)
 |-- tpep_dropoff_datetime: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



In [7]:
trips_df.createOrReplaceTempView("trips")

In [52]:
query = """
SELECT
    passenger_count,
    PULocationID as pickup_location_id,
    DOLocationID as dropoff_location_id,
    trip_distance,
    HOUR(tpep_pickup_datetime) as pickup_time,
    DATE_FORMAT(TO_DATE(tpep_pickup_datetime), 'EEEE') as day_of_week,
    total_amount
FROM
    trips
WHERE
    total_amount < 5000
    AND total_amount > 0
    AND trip_distance > 0
    AND trip_distance < 500
    AND passenger_count < 4
    AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
    AND TO_DATE(tpep_pickup_datetime) < '2021-08-01'
"""

data_df = spark.sql(query)
data_df.createOrReplaceTempView('data')

In [53]:
data_df.show()

+---------------+------------------+-------------------+-------------+-----------+-----------+------------+
|passenger_count|pickup_location_id|dropoff_location_id|trip_distance|pickup_time|day_of_week|total_amount|
+---------------+------------------+-------------------+-------------+-----------+-----------+------------+
|              0|               138|                265|         16.5|          0|     Monday|       70.07|
|              1|                68|                264|         1.13|          0|     Monday|       11.16|
|              1|               239|                262|         2.68|          0|     Monday|       18.59|
|              1|               186|                 91|         12.4|          0|     Monday|        43.8|
|              2|               132|                265|          9.7|          0|     Monday|        32.3|
|              1|               138|                141|          9.3|          0|     Monday|       43.67|
|              1|           

In [54]:
data_df.printSchema()

root
 |-- passenger_count: integer (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- pickup_time: integer (nullable = true)
 |-- day_of_week: string (nullable = true)
 |-- total_amount: double (nullable = true)



In [55]:
# 분석할 데이터 테이블이 준비 되었으면 학습용, 테스트용 데이터로 나눈다.
train_df, test_df = data_df.randomSplit([.8, .2], seed=1)

In [31]:
# *학습용 데이터를 꺼내 쓸 수 있도록 저장하기
data_dir = "/Users/devkhk/Documents/data-engineering-study/data/"

In [32]:
train_df.write.format("parquet").save(f"{data_dir}/train/")
test_df.write.format("parquet").save(f"{data_dir}/test/")

                                                                                

In [35]:
# 저장된 parquet 불러오기
train_df = spark.read.parquet(f"{data_dir}/train/")

In [36]:
test_df = spark.read.parquet(f"{data_dir}/test/")

In [56]:
train_df.printSchema()

root
 |-- passenger_count: integer (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- pickup_time: integer (nullable = true)
 |-- day_of_week: string (nullable = true)
 |-- total_amount: double (nullable = true)



In [75]:
# one hot Encoding : Wednesday -> 4 -> [0,0,0,1,0,0,0]

from pyspark.ml.feature import OneHotEncoder, StringIndexer
cat_feats = [
    "pickup_location_id",
    "dropoff_location_id",
    "day_of_week"
]

stages = []

for c in cat_feats:
    cat_indexer = StringIndexer(inputCol=c, outputCol= c + "_idx").setHandleInvalid("keep")
    onehot_encoder = OneHotEncoder(inputCols=[cat_indexer.getOutputCol()], outputCols=[c + "_onehot"])
    stages += [cat_indexer, onehot_encoder]

In [76]:
stages

[StringIndexer_62739c332208,
 OneHotEncoder_a5d8b7e8a003,
 StringIndexer_cdf76a70e9c2,
 OneHotEncoder_2b9708de83da,
 StringIndexer_beabeb2d762c,
 OneHotEncoder_db23d22fd30a]

In [77]:
from pyspark.ml.feature import VectorAssembler, StandardScaler

num_feats = [
    "passenger_count",
    "trip_distance",
    "pickup_time"
]

for n in num_feats:
    num_assembler = VectorAssembler(inputCols=[n], outputCol= n + "_vector")
    num_scalar = StandardScaler(inputCol=num_assembler.getOutputCol(), outputCol= n + "_scaled")
    stages += [num_assembler, num_scalar]

In [78]:
stages

[StringIndexer_62739c332208,
 OneHotEncoder_a5d8b7e8a003,
 StringIndexer_cdf76a70e9c2,
 OneHotEncoder_2b9708de83da,
 StringIndexer_beabeb2d762c,
 OneHotEncoder_db23d22fd30a,
 VectorAssembler_67b7c1afca8d,
 StandardScaler_e62b93e63109,
 VectorAssembler_909c188d7d37,
 StandardScaler_2ea98e00b5cc,
 VectorAssembler_e60e636f94ff,
 StandardScaler_3ff998f87394]

In [79]:
assembler_input = [c + "_onehot" for c in cat_feats] + [n + "_scaled" for n in num_feats]
assembler_input

['pickup_location_id_onehot',
 'dropoff_location_id_onehot',
 'day_of_week_onehot',
 'passenger_count_scaled',
 'trip_distance_scaled',
 'pickup_time_scaled']

In [80]:
assembler = VectorAssembler(inputCols=assembler_input, outputCol="feature_vector")
stages += [assembler]

In [81]:
stages

[StringIndexer_62739c332208,
 OneHotEncoder_a5d8b7e8a003,
 StringIndexer_cdf76a70e9c2,
 OneHotEncoder_2b9708de83da,
 StringIndexer_beabeb2d762c,
 OneHotEncoder_db23d22fd30a,
 VectorAssembler_67b7c1afca8d,
 StandardScaler_e62b93e63109,
 VectorAssembler_909c188d7d37,
 StandardScaler_2ea98e00b5cc,
 VectorAssembler_e60e636f94ff,
 StandardScaler_3ff998f87394,
 VectorAssembler_a9fb3e6a3d41]

In [82]:
from pyspark.ml import Pipeline

In [83]:
transform_stages = stages
pipeline = Pipeline(stages=transform_stages)
fitted_transformer = pipeline.fit(train_df)

                                                                                

In [84]:
vtrain_df = fitted_transformer.transform(train_df)

In [66]:
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(
        maxIter=50,
        solver="normal",
        labelCol="total_amount",
        featuresCol="feature_vector"
)

In [85]:
model = lr.fit(vtrain_df)

22/04/08 01:23:39 WARN Instrumentation: [b2a48522] regParam is zero, which might cause numerical instability and overfitting.
22/04/08 01:23:59 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
22/04/08 01:23:59 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS
22/04/08 01:24:13 WARN InstanceBuilder$NativeLAPACK: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK
22/04/08 01:24:13 WARN Instrumentation: [b2a48522] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.
22/04/08 01:24:13 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
22/04/08 01:24:13 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
                                                                                

In [86]:
vtest_df = fitted_transformer.transform(test_df)

In [88]:
predictions = model.transform(vtest_df)

In [89]:
predictions.cache()

DataFrame[passenger_count: int, pickup_location_id: int, dropoff_location_id: int, trip_distance: double, pickup_time: int, day_of_week: string, total_amount: double, pickup_location_id_idx: double, pickup_location_id_onehot: vector, dropoff_location_id_idx: double, dropoff_location_id_onehot: vector, day_of_week_idx: double, day_of_week_onehot: vector, passenger_count_vector: vector, passenger_count_scaled: vector, trip_distance_vector: vector, trip_distance_scaled: vector, pickup_time_vector: vector, pickup_time_scaled: vector, feature_vector: vector, prediction: double]

In [93]:
predictions.select(["day_of_week","trip_distance", "total_amount", "prediction"]).show()

+-----------+-------------+------------+------------------+
|day_of_week|trip_distance|total_amount|        prediction|
+-----------+-------------+------------+------------------+
|    Tuesday|          1.0|       10.55|12.695522792729275|
|   Saturday|          1.7|        13.3| 14.45055801477692|
|     Friday|          4.1|        21.3|21.108271361254214|
|     Sunday|         11.5|        41.3| 40.87993984204375|
|   Saturday|          1.7|       14.15| 13.90693298261399|
|  Wednesday|          0.7|         5.8|  9.62248222618894|
|  Wednesday|          5.0|        24.3|21.147909957146926|
|   Thursday|          1.5|         8.8| 9.969750900636763|
|     Monday|         13.4|       66.35| 62.65097273030503|
|     Monday|         15.0|       70.67| 66.37330532523579|
|  Wednesday|         14.2|       85.65| 89.80098581271078|
|  Wednesday|          0.1|        55.3|12.483544948638677|
|    Tuesday|          3.9|       21.95|23.136774384461823|
|   Thursday|          4.7|        27.8|

In [94]:
model.summary.rootMeanSquaredError

5.6485201652667625

In [95]:
model.summary.r2

0.80849012500813

In [96]:
spark.stop()