## 예측 모델 정확도 올리기 위해 다른 변수들을 추가해서 모델 학습

In [1]:
from pyspark.sql import SparkSession

In [2]:
# Memory budget applied to both the executors and the driver.
MAX_MEMORY = "5g"

spark = (
    SparkSession.builder
    .appName("taxi-fare-prediction-expands")
    .config("spark.executor.memory", MAX_MEMORY)
    .config("spark.driver.memory", MAX_MEMORY)
    .getOrCreate()
)

22/04/21 13:19:47 WARN Utils: Your hostname, devkhk-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 172.30.1.27 instead (on interface en0)
22/04/21 13:19:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/04/21 13:19:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [10]:
# Load every raw trip CSV under the trips directory (the trailing "*" is a glob).
trips_dir = "/Users/devkhk/Documents/data-engineering-study/data/trips/*"
# trips_dir is already absolute, so the scheme prefix is just "file://";
# "file:///" would yield a four-slash URI ("file:////Users/...").
trips_df = spark.read.csv(f"file://{trips_dir}", inferSchema=True, header=True)

                                                                                

In [11]:
# Show the column names and types that inferSchema produced for the raw CSVs.
trips_df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: string (nullable = true)
 |-- tpep_dropoff_datetime: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



In [13]:
# Register the raw trips as the temp view "trips" so they can be queried with Spark SQL.
trips_df.createOrReplaceTempView("trips")

In [14]:
# Build the modelling dataset: keep passenger count, pickup/dropoff location
# ids and trip distance, derive the hour and day-of-week from the pickup
# timestamp, and keep total_amount (used later as the regression label).
# The WHERE clause keeps only rows with 0 < total_amount < 5000,
# 0 < trip_distance < 500, passenger_count < 4 and a pickup date from
# 2021-01-01 up to (but not including) 2021-08-01.
query = """
SELECT
    passenger_count,
    PULocationID as pickup_location_id,
    DOLocationID as dropoff_location_id,
    trip_distance,
    HOUR(tpep_pickup_datetime) as hour,
    DAYOFWEEK(TO_DATE(tpep_pickup_datetime)) as day_of_week,
    total_amount
FROM
    trips
WHERE
    total_amount < 5000
    AND total_amount > 0
    AND trip_distance > 0
    AND trip_distance < 500
    AND passenger_count < 4
    AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
    AND TO_DATE(tpep_pickup_datetime) < '2021-08-01'
"""

data_df = spark.sql(query)
# Expose the filtered result as the temp view "data".
data_df.createOrReplaceTempView('data')

In [15]:
# Preview the first rows of the filtered dataset.
data_df.show()

+---------------+------------------+-------------------+-------------+----+-----------+------------+
|passenger_count|pickup_location_id|dropoff_location_id|trip_distance|hour|day_of_week|total_amount|
+---------------+------------------+-------------------+-------------+----+-----------+------------+
|              0|               138|                265|         16.5|   0|          2|       70.07|
|              1|                68|                264|         1.13|   0|          2|       11.16|
|              1|               239|                262|         2.68|   0|          2|       18.59|
|              1|               186|                 91|         12.4|   0|          2|        43.8|
|              2|               132|                265|          9.7|   0|          2|        32.3|
|              1|               138|                141|          9.3|   0|          2|       43.67|
|              1|               138|                 50|         9.58|   0|          2|    

In [16]:
# Split the rows into train and test sets with 0.8 / 0.2 weights; the seed is fixed at 1.
train_df, test_df = data_df.randomSplit([.8, .2], seed=1)

In [19]:
# Save both splits in parquet format.
# The explicit "file://" scheme pins the output to the local filesystem, the
# same place the later spark.read.parquet calls load from, instead of relying
# on whatever the session's default filesystem happens to be.
data_dir = "/Users/devkhk/Documents/data-engineering-study/data/"
train_df.write.parquet(f"file://{data_dir}train-review/", mode="overwrite")
test_df.write.parquet(f"file://{data_dir}test-review/", mode="overwrite")

                                                                                

In [20]:
# Drop the in-memory split handles; they are re-loaded from the parquet files further down.
del train_df, test_df

In [21]:
%whos

Variable       Type            Data/Info
----------------------------------------
MAX_MEMORY     str             5g
SparkSession   type            <class 'pyspark.sql.session.SparkSession'>
data_df        DataFrame       DataFrame[passenger_count<...>nt, total_amount: double]
data_dir       str             /Users/devkhk/Documents/d<...>a-engineering-study/data/
query          str             \nSELECT\n    passenger_c<...>atetime) < '2021-08-01'\n
spark          SparkSession    <pyspark.sql.session.Spar<...>object at 0x7ff2027e58b0>
trips_df       DataFrame       DataFrame[VendorID: int, <...>estion_surcharge: double]
trips_dir      str             /Users/devkhk/Documents/d<...>eering-study/data/trips/*


In [24]:
# Load the train/test splits back from parquet.
# data_dir is already absolute, so the scheme prefix is just "file://";
# "file:///" would yield a four-slash URI ("file:////Users/...").
train_df = spark.read.parquet(f"file://{data_dir}train-review/")
test_df = spark.read.parquet(f"file://{data_dir}test-review/")

In [27]:
# Sanity-check the re-loaded splits by printing the first 5 rows of each.
train_df.show(5)
test_df.show(5)

+---------------+------------------+-------------------+-------------+----+-----------+------------+
|passenger_count|pickup_location_id|dropoff_location_id|trip_distance|hour|day_of_week|total_amount|
+---------------+------------------+-------------------+-------------+----+-----------+------------+
|              0|                 4|                  4|          0.1|   2|          1|         6.8|
|              0|                 4|                  4|          2.2|   2|          7|        15.3|
|              0|                 4|                 48|          2.8|  16|          7|        19.3|
|              0|                 4|                 79|          0.6|  14|          5|         8.3|
|              0|                 4|                 87|          2.7|  15|          6|        15.8|
+---------------+------------------+-------------------+-------------+----+-----------+------------+
only showing top 5 rows

+---------------+------------------+-------------------+----------

In [28]:
# Check the column types of the re-loaded training split.
train_df.printSchema()

root
 |-- passenger_count: integer (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- hour: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- total_amount: double (nullable = true)



In [55]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer
from pyspark.ml.feature import VectorAssembler, StandardScaler

In [56]:
# Separate the categorical and numeric features, then build the pipeline stages.
cat_features = ["pickup_location_id", "dropoff_location_id", "hour", "day_of_week"]
num_features = ["passenger_count", "trip_distance"]

stages = []

# Categorical columns: index the raw values, then one-hot encode the index.
# handleInvalid="keep" is set on the indexer for every categorical column.
for col_name in cat_features:
    indexer = StringIndexer(
        inputCol=col_name,
        outputCol=f"{col_name}_idx",
        handleInvalid="keep",
    )
    encoder = OneHotEncoder(
        inputCol=f"{col_name}_idx",
        outputCol=f"{col_name}_onehot",
    )
    stages.extend([indexer, encoder])

# Numeric columns: wrap each one in a single-element vector (StandardScaler
# takes a vector column), then scale it.
for col_name in num_features:
    to_vector = VectorAssembler(
        inputCols=[col_name],
        outputCol=f"{col_name}_vector",
    )
    scaler = StandardScaler(
        inputCol=f"{col_name}_vector",
        outputCol=f"{col_name}_std",
    )
    stages.extend([to_vector, scaler])

In [57]:
# Inspect the stages built so far.
stages

[StringIndexer_fc2486a1ca25,
 OneHotEncoder_cc5790cc8a13,
 StringIndexer_6e9b789b994f,
 OneHotEncoder_799eccab5615,
 StringIndexer_b78e0c5ab2d4,
 OneHotEncoder_97e28f3f9de4,
 StringIndexer_96bb03d0147f,
 OneHotEncoder_af4a1ee7ca13,
 VectorAssembler_d515f97b3d18,
 StandardScaler_7f68c742c379,
 VectorAssembler_bee950cfefe4,
 StandardScaler_16f1b098f881]

In [58]:
# Names of the preprocessed columns that feed the final feature vector:
# one-hot outputs first, then the scaled numeric outputs.
assembler = [f"{name}_onehot" for name in cat_features]
assembler += [f"{name}_std" for name in num_features]
assembler

['pickup_location_id_onehot',
 'dropoff_location_id_onehot',
 'hour_onehot',
 'day_of_week_onehot',
 'passenger_count_std',
 'trip_distance_std']

In [60]:
# Final stage: concatenate all preprocessed columns into one "features" vector.
vassembler = VectorAssembler(outputCol="features", inputCols=assembler)
stages.append(vassembler)

In [61]:
# Inspect the stage list again, now that the final VectorAssembler has been appended.
stages

[StringIndexer_fc2486a1ca25,
 OneHotEncoder_cc5790cc8a13,
 StringIndexer_6e9b789b994f,
 OneHotEncoder_799eccab5615,
 StringIndexer_b78e0c5ab2d4,
 OneHotEncoder_97e28f3f9de4,
 StringIndexer_96bb03d0147f,
 OneHotEncoder_af4a1ee7ca13,
 VectorAssembler_d515f97b3d18,
 StandardScaler_7f68c742c379,
 VectorAssembler_bee950cfefe4,
 StandardScaler_16f1b098f881,
 VectorAssembler_94ebbaede811]

In [62]:
from pyspark.ml.pipeline import Pipeline

# Bundle all preprocessing stages into a single Pipeline.
pipeline = Pipeline().setStages(stages)

In [63]:
# Fit the preprocessing pipeline on the training split; the fitted result is
# reused to transform both the training and the test split.
fitted_transformer = pipeline.fit(train_df)

                                                                                

In [65]:
# Apply the fitted pipeline to the training split to add the "features" column.
vtrain_df = fitted_transformer.transform(train_df)

In [68]:
# Preview the assembled feature vectors.
features_only = vtrain_df.select("features")
features_only.show()

+--------------------+
|            features|
+--------------------+
|(556,[62,311,543,...|
|(556,[62,311,543,...|
|(556,[62,273,527,...|
|(556,[62,280,526,...|
|(556,[62,308,524,...|
|(556,[62,290,540,...|
|(556,[62,279,529,...|
|(556,[62,299,534,...|
|(556,[62,288,528,...|
|(556,[62,310,531,...|
|(556,[62,303,532,...|
|(556,[62,303,543,...|
|(556,[62,301,523,...|
|(556,[62,301,533,...|
|(556,[62,301,527,...|
|(556,[62,282,533,...|
|(556,[62,266,536,...|
|(556,[62,266,528,...|
|(556,[62,370,525,...|
|(556,[62,324,526,...|
+--------------------+
only showing top 20 rows



In [76]:
# The training DataFrame has been passed through the vectorizing pipeline,
# so the regression model is defined next.
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(
    featuresCol="features",
    labelCol="total_amount",
    regParam=0.01,
    maxIter=30,
)

In [77]:
# Train the linear regression model on vtrain_df.
model = lr.fit(vtrain_df)

22/04/21 16:04:50 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
22/04/21 16:04:50 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS
22/04/21 16:05:00 WARN InstanceBuilder$NativeLAPACK: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK
                                                                                

In [78]:
# Transform the test split with the pipeline that was fitted on the training split.
vtest_df = fitted_transformer.transform(test_df)

In [80]:
# Run the trained model on the vectorized test split.
predictions = model.transform(vtest_df)

In [82]:
# Inspect the columns of the prediction output, including the intermediate pipeline columns.
predictions.printSchema()

root
 |-- passenger_count: integer (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- hour: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- pickup_location_id_idx: double (nullable = false)
 |-- pickup_location_id_onehot: vector (nullable = true)
 |-- dropoff_location_id_idx: double (nullable = false)
 |-- dropoff_location_id_onehot: vector (nullable = true)
 |-- hour_idx: double (nullable = false)
 |-- hour_onehot: vector (nullable = true)
 |-- day_of_week_idx: double (nullable = false)
 |-- day_of_week_onehot: vector (nullable = true)
 |-- passenger_count_vector: vector (nullable = true)
 |-- passenger_count_std: vector (nullable = true)
 |-- trip_distance_vector: vector (nullable = true)
 |-- trip_distance_std: vector (nullable = true)
 |-- features: vector (nullable = true)
 |-- prediction: dou

In [84]:
# Compare the actual total_amount with the model's prediction, next to the time features.
preview_cols = ("day_of_week", "hour", "total_amount", "prediction")
predictions.select(*preview_cols).show()

+-----------+----+------------+------------------+
|day_of_week|hour|total_amount|        prediction|
+-----------+----+------------+------------------+
|          5|  20|        12.3|14.635430095883565|
|          7|  23|       23.15|19.958199770493934|
|          4|  16|        16.3|16.487046381443452|
|          5|  16|         5.8| 8.741823421141877|
|          5|  20|        65.3|45.438651128967834|
|          6|  13|        13.3|47.755725589684076|
|          6|   2|        17.8|49.465441940257406|
|          4|  17|        76.3| 68.78800790045243|
|          1|  17|        17.3| 20.00459903495035|
|          2|  17|        24.3| 27.73866726998494|
|          7|  16|       27.35| 26.02304078097364|
|          3|  19|       32.75|30.188940913571383|
|          4|  10|         8.8|12.376017430698168|
|          4|  14|        12.8|17.114979497261892|
|          2|  16|        15.8| 17.69950086110069|
|          6|  19|       20.75|20.560003900267276|
|          7|  19|        20.3|

In [85]:
# Root-mean-squared error taken from the model's summary.
# NOTE(review): per the PySpark docs `model.summary` describes the fit on the
# training data (vtrain_df), so this is not a held-out score; use
# `model.evaluate(vtest_df).rootMeanSquaredError` for the test split — confirm.
model.summary.rootMeanSquaredError

5.622837560755986

In [86]:
# R^2 taken from the model's summary.
# NOTE(review): per the PySpark docs `model.summary` describes the fit on the
# training data (vtrain_df), so this is not a held-out score; use
# `model.evaluate(vtest_df).r2` for the test split — confirm.
model.summary.r2

0.8102252707480382