In [3]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

In [4]:
# Spark session for this notebook: master "local[*]", with the driver and
# executor memory settings given below. getOrCreate() reuses an already
# running session if one exists.
spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "1g")
    .getOrCreate()
)

In [7]:
from pyspark.ml import PipelineModel
from pipeline_oriented_analytics.transformer import *
from pipeline_oriented_analytics.dataframe import *
from typing import List, Dict

# Columns kept from the raw trip CSV files.
column_names = [
    'pickup_datetime',
    'pickup_longitude',
    'pickup_latitude',
    'dropoff_longitude',
    'dropoff_latitude',
]

# Raw coordinate column name -> shortened name used by the later stages.
column_new_names = {
    'pickup_longitude': 'pickup_lon',
    'pickup_latitude': 'pickup_lat',
    'dropoff_longitude': 'dropoff_lon',
    'dropoff_latitude': 'dropoff_lat',
}

# Every renamed coordinate column is normalized to the 'double' type.
column_types = dict.fromkeys(
    ['pickup_lon', 'pickup_lat', 'dropoff_lon', 'dropoff_lat'],
    'double',
)

# Level handed to the CellId stages; it is also embedded in the names of the
# cell columns so the level is visible in the resulting schema.
level = 6
pickup_cell = 'pickup_cell_{}'.format(level)
dropoff_cell = 'dropoff_cell_{}'.format(level)

def PREPARE_TRIP_DATA(level: int, column_names: List[str], column_new_names: Dict[str, str], column_types: Dict[str, str]) -> PipelineModel:
    """Build the pipeline that prepares raw trip rows for cell-level processing.

    The returned pipeline chains, in order: ``SelectColumns``,
    ``RenameColumns``, ``NormalizeColumnTypes`` and two ``CellId`` stages (one
    for the pickup coordinates, one for the dropoff coordinates).

    The cell output columns are named ``pickup_cell_<level>`` and
    ``dropoff_cell_<level>`` from the ``level`` argument itself, so the column
    name always matches the level the cell was computed at.

    Args:
        level: Level passed to both ``CellId`` stages and embedded in the
            output column names.
        column_names: Columns to keep from the input frame.
        column_new_names: Mapping of existing column name -> new column name.
            After renaming, the frame must contain ``pickup_lat``,
            ``pickup_lon``, ``dropoff_lat`` and ``dropoff_lon``, because the
            ``CellId`` stages reference those names directly.
        column_types: Mapping of column name -> target type name.

    Returns:
        A ``PipelineModel`` made of the five stages described above.
    """
    # Derive the output names locally rather than reading module globals, so a
    # call with a different level cannot produce mislabelled columns.
    pickup_cell_col = f'pickup_cell_{level}'
    dropoff_cell_col = f'dropoff_cell_{level}'

    return PipelineModel([
        SelectColumns(column_names),
        RenameColumns(column_new_names),
        NormalizeColumnTypes(column_types),
        CellId(level, 'pickup_lat', 'pickup_lon', pickup_cell_col),
        CellId(level, 'dropoff_lat', 'dropoff_lon', dropoff_cell_col)
    ])


# First stage: the preparation pipeline applied to the incoming (train) rows.
train_preparation = PREPARE_TRIP_DATA(level, column_names, column_new_names, column_types)

# The test rows go through an identically configured preparation pipeline up
# front; the resulting frame is handed to the Union stage below.
prepared_test_trips = PREPARE_TRIP_DATA(level, column_names, column_new_names, column_types).transform(
    CsvDataFrame('../data/raw/test.csv', spark)
)

# Remaining stages: keep only the two cell columns, drop duplicate rows, add
# the distance between the cells, and write the result out as parquet.
# NOTE(review): Union presumably appends prepared_test_trips to the frame
# flowing through the pipeline - confirm against the transformer module.
distance_matrix_stages = [
    train_preparation,
    Union(prepared_test_trips),
    SelectColumns([pickup_cell, dropoff_cell]),
    DropDuplicates(),
    SphereDistance(pickup_cell, dropoff_cell),
    SaveToParquet('../data/processed/distance_matrix'),
]

# Cached because the cells that follow run several actions on this frame.
df = PipelineModel(distance_matrix_stages).transform(
    CsvDataFrame('../data/raw/train.csv', spark)
).cache()


# Quick look at the distance matrix: a two-row sample, the schema, and a few
# summary numbers.
df.show(2)
df.printSchema()
print(f'Distance entries: {df.count()}')

# Aggregate over the whole frame and read the single value by position, so
# the lookup does not depend on Spark's generated "avg(distance)" column name.
avg_distance = df.agg(f.avg("distance")).first()[0]
print(f'Avg distance {avg_distance}')

# Rows whose distance is exactly 0.
zero_distance_count = df.where(f.col("distance") == 0).count()
print(f'Zero distance trips count {zero_distance_count}')


+--------------+---------------+--------+
|pickup_cell_14|dropoff_cell_14|distance|
+--------------+---------------+--------+
|      89c258fd|       89c259a9|    1.35|
|      89c258f5|       89c25855|    1.63|
+--------------+---------------+--------+
only showing top 2 rows

root
 |-- pickup_cell_14: string (nullable = true)
 |-- dropoff_cell_14: string (nullable = true)
 |-- distance: float (nullable = true)

Distance entries: 126668
Avg distance 8.024320740666472
Zero distanxe trips count 1621
