In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from modules.my_spark_regression import *
from modules.my_pyspark import *
from modules.my_drawer import MyDrawer

In [3]:
# Project-local helpers, presumably provided by the modules.* imports above:
# a Spark wrapper (used below via spark.readFile) and a plotting helper.
# NOTE(review): the meaning of session=True / sql=True lives in modules.my_pyspark — confirm.
spark = MyPySpark(session=True, sql=True)
drawer = MyDrawer()  # NOTE(review): not used in the cells of this section

# 3. Xây dựng model

## 3.1. Chuẩn bị & chuẩn hóa dữ liệu, xác định input, output

* Đọc dữ liệu

In [4]:
# Raw flights dataset, relative to the notebook's working directory
# (no backslashes in the path, so a raw string is unnecessary).
file_path = './data/flights.csv'

In [5]:
# Load the raw flights CSV through the project Spark wrapper.
# NOTE(review): `delay` is read as a string because missing values appear as 'NA'
# (see the schema and head() outputs); it is not used as a feature or label in this section.
data = spark.readFile(file_path)

In [6]:
# Inspect the inferred column types before choosing inputs and output.
data.printSchema()

root
 |-- mon: integer (nullable = true)
 |-- dom: integer (nullable = true)
 |-- dow: integer (nullable = true)
 |-- carrier: string (nullable = true)
 |-- flight: integer (nullable = true)
 |-- org: string (nullable = true)
 |-- mile: integer (nullable = true)
 |-- depart: double (nullable = true)
 |-- duration: integer (nullable = true)
 |-- delay: string (nullable = true)



In [7]:
# Peek at the first row (returned as a single Row).
data.head()

Row(mon=11, dom=20, dow=6, carrier='US', flight=19, org='JFK', mile=2153, depart=9.48, duration=351, delay='NA')

* Xác định input và output

In [8]:
# All available columns, to pick the input feature(s) and the label.
data.columns

['mon',
 'dom',
 'dow',
 'carrier',
 'flight',
 'org',
 'mile',
 'depart',
 'duration',
 'delay']

In [9]:
# Single predictor: `mile` (presumably flight distance). Add more column names to extend the model.
input_features = ['mile']

In [10]:
# Pack the input columns into the single vector column that the regressor reads (featuresCol='features').
assembler = VectorAssembler(
    inputCols=input_features,
    outputCol='features',
)

In [11]:
# Append the assembled 'features' vector column; the label `duration` remains available (selected below).
data_pre = assembler.transform(data)

In [12]:
# Show two assembled feature vectors without truncating cell contents.
data_pre.select('features').show(n=2, truncate=False)

+--------+
|features|
+--------+
|[2153.0]|
|[316.0] |
+--------+
only showing top 2 rows



In [13]:
# Keep only what the model needs: the feature vector and the label.
final_data = data_pre.select(['features', 'duration'])

In [14]:
# First 20 (feature, label) pairs.
final_data.show(n=20)

+--------+--------+
|features|duration|
+--------+--------+
|[2153.0]|     351|
| [316.0]|      82|
| [337.0]|      82|
|[1236.0]|     195|
| [258.0]|      65|
| [550.0]|     102|
| [733.0]|     135|
|[1440.0]|     232|
|[1829.0]|     250|
| [158.0]|      60|
|[1464.0]|     210|
| [978.0]|     160|
| [719.0]|     151|
|[1745.0]|     264|
|[1097.0]|     190|
| [967.0]|     158|
|[1735.0]|     265|
| [802.0]|     160|
| [948.0]|     160|
| [944.0]|     166|
+--------+--------+
only showing top 20 rows



## 3.2. Chuẩn bị train/test data

In [15]:
# 80/20 train/test split. The fixed seed makes the split, and every metric computed
# from it, reproducible across Restart & Run All.
train_data, test_data = final_data.randomSplit([0.8, 0.2], seed=42)

In [16]:
# Summary statistics of the training split; only `duration` is reported (the vector `features` column is not summarized).
train_data.describe().show()

+-------+------------------+
|summary|          duration|
+-------+------------------+
|  count|             39957|
|   mean|151.94096153364868|
| stddev| 87.09705092087508|
|    min|                30|
|    max|               560|
+-------+------------------+



In [17]:
# Same summary for the test split, to compare its distribution with the training split.
test_data.describe().show()

+-------+------------------+
|summary|          duration|
+-------+------------------+
|  count|             10043|
|   mean|151.06900328587076|
| stddev| 86.83879557617304|
|    min|                31|
|    max|               560|
+-------+------------------+



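> * Phân phối của `duration` trong tập train và tập test gần như tương đương (count theo tỉ lệ ~80/20; mean, stddev, min, max gần bằng nhau), không có chênh lệch đáng kể về mặt thống kê.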

## 3.3. Xây dựng model

* Tạo model Linear Regression

In [18]:
# Linear regression with default hyperparameters: predict `duration` from the assembled `features` vector.
lr = LinearRegression(
    featuresCol='features',
    labelCol='duration',
    predictionCol='prediction',
)

* Fit model trên train data và gán model đã huấn luyện cho biến `lrModel`

In [19]:
# Fit on the training split only; test_data stays held out for evaluation.
lrModel = lr.fit(train_data)

* In ra coefficients và intercept

In [20]:
# Learned slope for `mile` and the intercept, displayed as one tuple.
(lrModel.coefficients, lrModel.intercept)

(DenseVector([0.1216]), 44.440062415664244)

## 3.4. Đánh giá model trên test data

In [21]:
# Evaluation summary on the held-out split; its residuals, RMSE, MSE and R² are inspected below.
test_results = lrModel.evaluate(test_data)

* Đánh giá phần dư

In [22]:
# Residuals are label minus prediction (duration - prediction).
test_results.residuals.show()

+-------------------+
|          residuals|
+-------------------+
| -9.586394736421823|
| -6.586394736421823|
| -6.586394736421823|
| -6.586394736421823|
| -6.586394736421823|
| -5.586394736421823|
| -5.586394736421823|
| -5.586394736421823|
| -4.586394736421823|
| -4.586394736421823|
| -4.586394736421823|
| -4.586394736421823|
| -4.586394736421823|
|-3.5863947364218234|
|-3.5863947364218234|
|-3.5863947364218234|
|-2.5863947364218234|
|-1.5863947364218234|
|-1.5863947364218234|
|-1.5863947364218234|
+-------------------+
only showing top 20 rows



* Đánh giá RMSE

In [23]:
# Root mean squared error on the test split, in the same unit as the label `duration`.
test_results.rootMeanSquaredError

17.003770681934345

In [33]:
from pyspark.ml.evaluation import RegressionEvaluator

In [34]:
# Cross-check the RMSE with RegressionEvaluator. Predictions are computed here rather than
# reusing `test_model`, which is only defined later (section 3.5); depending on it breaks
# Restart & Run All with a NameError.
# metricName='rmse' and predictionCol='prediction' are the evaluator's defaults, spelled out for clarity.
RegressionEvaluator(
    labelCol='duration',
    predictionCol='prediction',
    metricName='rmse',
).evaluate(lrModel.transform(test_data))

17.003770681934345

* Đánh giá mean squared error

In [24]:
# Mean squared error (the square of the RMSE above), in squared units of `duration`.
test_results.meanSquaredError

289.12821740380997

* Đánh giá $R^2$

In [25]:
# Coefficient of determination on the test split: share of the variance in `duration` explained by the model.
test_results.r2

0.9616552296008894

## 3.5. Dự đoán trên test data (so sánh prediction với duration thực tế)

In [26]:
# Score the test split; the output keeps `duration` next to the new `prediction` column.
test_model = lrModel.transform(test_data)
test_model.select(['prediction', 'duration']).show()

+-----------------+--------+
|       prediction|duration|
+-----------------+--------+
|52.58639473642182|      43|
|52.58639473642182|      46|
|52.58639473642182|      46|
|52.58639473642182|      46|
|52.58639473642182|      46|
|52.58639473642182|      47|
|52.58639473642182|      47|
|52.58639473642182|      47|
|52.58639473642182|      48|
|52.58639473642182|      48|
|52.58639473642182|      48|
|52.58639473642182|      48|
|52.58639473642182|      48|
|52.58639473642182|      49|
|52.58639473642182|      49|
|52.58639473642182|      49|
|52.58639473642182|      50|
|52.58639473642182|      51|
|52.58639473642182|      51|
|52.58639473642182|      51|
+-----------------+--------+
only showing top 20 rows



## 3.6. Lưu trữ & tải model

* Lưu model

In [27]:
# Directory where the trained model is persisted (plain string; no backslashes to escape).
file_path1 = './data/lrModel_flight'

In [28]:
# overwrite() keeps the notebook re-runnable: a plain save() fails when the
# target directory already exists from a previous run.
lrModel.write().overwrite().save(file_path1)

* Tải model

In [29]:
from pyspark.ml.regression import LinearRegressionModel

In [30]:
# Reload the persisted model from disk, to check the save/load round trip.
lrModel2 = LinearRegressionModel.load(file_path1)

## 3.7. Dự đoán dữ liệu mới

In [31]:
# Stand-in for "new" data: the test features without their labels, scored by the reloaded model.
# NOTE(review): `preditions` is a typo for `predictions`; kept because the next cell uses this name.
unlabeled_data = test_data.select(['features'])
preditions = lrModel2.transform(unlabeled_data)

In [32]:
# First 20 predictions of the reloaded model on the unlabeled features.
preditions.show()

+--------+-----------------+
|features|       prediction|
+--------+-----------------+
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
|  [67.0]|52.58639473642182|
+--------+-----------------+
only showing top 20 rows

