In [25]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

In [26]:
# Local Spark session on all available cores ("local[*]").
# getOrCreate() returns an already-running session if one exists, in which case
# the memory settings below may not be applied — restart the kernel to change them.
# NOTE(review): in local mode the executor runs inside the driver JVM, so
# spark.executor.memory is likely ignored here — confirm before relying on it.
spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "1g")
    .getOrCreate()
)

In [27]:
from pipeline_oriented_analytics.pipe import Pipe, IF
from pipeline_oriented_analytics.transformer import *
from typing import List, Dict
from pipeline_oriented_analytics.dataframe import CsvDataFrame, ParquetDataFrame
from pipeline_oriented_analytics import Phase

# Dataset to build inputs for. Phase.predict reads the test CSV (features only);
# switch to Phase.train to read the train CSV and include the label column.
phase = Phase.predict

# Raw CSV columns used as model features, and the training label column.
variables = ('pickup_datetime', 'pickup_longitude', 'pickup_latitude', 'dropoff_longitude', 'dropoff_latitude')
labels = ('trip_duration', )
lables = labels  # original misspelled name, kept in case later cells still reference it

# Raw column name -> shorter name used for the rest of the pipeline.
column_names = {'pickup_longitude': 'pickup_lon', 'pickup_latitude': 'pickup_lat', 'dropoff_longitude': 'dropoff_lon', 'dropoff_latitude': 'dropoff_lat', 'trip_duration': 'duration_sec'}

# Target types are keyed by the *renamed* column names (the values of column_names),
# because type normalization is applied after renaming.
variable_types = {'pickup_datetime': 'timestamp', 'pickup_lon': 'double', 'pickup_lat': 'double', 'dropoff_lon': 'double', 'dropoff_lat': 'double'}
# 'trip_duration' is renamed to 'duration_sec', so the label type must use that name
# (the previous key 'duration' matched no column).
label_types = {'duration_sec': 'int'}

if phase.is_predict():
    # Prediction inputs carry only the feature columns.
    columns = list(variables)
    column_types = variable_types
    data_path = '../data/raw/test.csv'
else:
    # Training inputs carry the features plus the label.
    columns = list(variables + labels)
    column_types = {**variable_types, **label_types}
    data_path = '../data/raw/train.csv'

# Optional fraction of rows to sample for a quick dry run (e.g. 0.0001);
# None processes the full dataset.
sample_fraction = None

# Build the model inputs for the selected phase:
#   1. optionally sample, then select the configured raw columns, rename them and cast types;
#   2. derive cell ids for pickup and dropoff at levels 6 and 14
#      (CellId(level, lat_col, lon_col, output_col), judging by the calls below);
#   3. left-join the precomputed distance matrix on the level-14 cell pair
#      (presumably adds trip-distance features — confirm against that dataset's schema);
#   4. drop raw coordinates and level-14 cells, and write the result to Parquet.
df = Pipe([
    *([Sample(sample_fraction)] if sample_fraction is not None else []),
    SelectColumns(columns),
    RenameColumns(column_names),
    NormalizeColumnTypes(column_types),
    CellId(6, 'pickup_lat', 'pickup_lon', 'pickup_cell_6'),
    CellId(6, 'dropoff_lat', 'dropoff_lon', 'dropoff_cell_6'),
    CellId(14, 'pickup_lat', 'pickup_lon', 'pickup_cell_14'),
    CellId(14, 'dropoff_lat', 'dropoff_lon', 'dropoff_cell_14'),
    Join(['pickup_cell_14', 'dropoff_cell_14'], Join.Method.left, ParquetDataFrame('../data/processed/distance_matrix', spark)),
    DropColumns(['pickup_lat', 'pickup_lon', 'dropoff_lon', 'dropoff_lat', 'pickup_cell_14', 'dropoff_cell_14']),
    SaveToParquet(f'../data/processed/{phase.name}/inputs'),
]).transform(CsvDataFrame(data_path, spark))

# NOTE(review): unless SaveToParquet returns a cached/re-read frame, count() may
# re-run the whole pipeline from the CSV — confirm if this cell is slow.
print(f'Saved {df.count()} rows of {phase.name} inputs')

Saved 625134 rows
