In [1]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

In [2]:
# Local Spark session on all available cores.
# NOTE(review): getOrCreate() hands back an already-running session if this kernel
# has one; JVM-level settings such as the memory options below presumably only
# apply when this call actually starts Spark (i.e. on a fresh kernel).
spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "1g")
    .getOrCreate()
)

In [3]:
from pipeline_oriented_analytics.dataframe import CsvDataFrame, ParquetDataFrame
from pipeline_oriented_analytics import Phase

# Features produced by the training phase of the pipeline.
features_df = ParquetDataFrame(f'../data/processed/{Phase.train.name.lower()}/features', spark)

# Hold out a fraction of the rows as a test set for evaluating the fitted model.
# An explicit seed makes the split reproducible: without it, every
# Restart & Run All evaluates on a different test set and the reported MAE drifts.
test_data_frac = 0.1
split_seed = 42
test_features_df, train_features_df = features_df.randomSplit(
    [test_data_frac, 1 - test_data_frac], seed=split_seed
)

In [9]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor

label_col = 'duration_min'

# Categorical grid-cell columns and the index columns their StringIndexers emit.
# handleInvalid='keep' sends values unseen during fit to an extra index instead of failing.
indexed_cols = {
    'pickup_cell_8': 'pickup_cell_8_idx',
    'dropoff_cell_8': 'dropoff_cell_8_idx',
}
# Remaining columns are passed to the assembler unchanged, after the indexed ones.
numeric_cols = ['distance', 'month', 'day_of_month', 'day_of_week', 'hour',
                'requests_pickup_cell', 'requests_dropoff_cell']

indexers = [
    StringIndexer(inputCol=source, handleInvalid='keep', outputCol=target)
    for source, target in indexed_cols.items()
]
assembler = VectorAssembler(inputCols=list(indexed_cols.values()) + numeric_cols,
                            outputCol='features')
# NOTE(review): if the tree treats the indexed cell columns as categorical, maxBins must be
# at least their number of distinct values — confirm 100 is enough for the cell vocabulary.
regressor = DecisionTreeRegressor(maxDepth=7, featuresCol='features', labelCol=label_col, maxBins=100)

model = Pipeline(stages=[*indexers, assembler, regressor]).fit(train_features_df)

In [13]:
# Persist the fitted pipeline so it can be reloaded later;
# overwrite() replaces any model previously saved at this path.
model_path = '../model/trip_duration_min'
print(f'Saving model to {model_path}')
model.write().overwrite().save(model_path)
# Plain string: there is nothing to interpolate, so no f-prefix.
print('Model saved...')

Saving model to ../model/trip_duration_min
Model saved...


In [16]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml import PipelineModel

# Reload the pipeline from disk (rather than reusing the in-memory fit) so the
# evaluation also checks that the saved model round-trips.
model = PipelineModel.load(model_path)
predictions_df = model.transform(test_features_df)
# Mean absolute error of the predictions against the label on the held-out split.
# NOTE(review): despite the `_cv` suffix no cross-validation is performed here;
# the name is kept in case later cells reference it.
mae_cv = RegressionEvaluator(labelCol=label_col, metricName='mae').evaluate(predictions_df)
print(f'Mean absolute error: {mae_cv}')

Best model MAE: 6.872564895418289


In [20]:
# Average trip duration on the held-out split, to put the MAE above in scale.
# DataFrame.agg is shorthand for a global groupBy().agg().
predictions_df.agg(f.mean(label_col)).show()

+------------------+
| avg(duration_min)|
+------------------+
|16.010745581160823|
+------------------+

