## Spark Context

In [1]:
from pyspark import SparkContext
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.functions import *
from pyspark.sql.types import * #to support struct type schema
from functools import *

spark = SparkSession.builder.appName("Taxi")\
        .config("spark.driver.memory", "25g")\
        .config("spark.driver.cores", "4")\
        .getOrCreate()
sc=spark.sparkContext


## Libraries

In [2]:
# from pyspark.rdd import portable_hash
# from pyspark.statcounter import StatCounter

import os
import json
import numpy as np
from datetime import datetime
from operator import itemgetter
#from itertools import chain, imap
from shapely.geometry import shape, Point
from matplotlib import pyplot as plt
%matplotlib inline

## Function to print RDD content

In [3]:
from  pprint import pprint
def title(s):
    pprint("---- %s -----" %s)    
    
def see(s, v):
    pprint("---- %s -----" %s)
    pprint(v)

## Reading the data

In [4]:
taxiRaw_Rdd = sc.textFile("data/yellow_tripdata_2015-12.csv")
#header = sc.parallelize(taxiRawAll.take(1))
#taxiRaw = taxiRaw.coalesce(1) #Makes 1 file as an output, since it reduced the # of partitions to 1
#header.union(taxiRaw).coalesce(1).saveAsTextFile("../../data/ch08-geospatial/trip_data_sample.csv")

In [5]:
see("taxiRaw_Rdd",taxiRaw_Rdd.take(2))

'---- taxiRaw_Rdd -----'
['VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,pickup_longitude,pickup_latitude,RatecodeID,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount',
 '2,2015-12-01 00:00:00,2015-12-01 '
 '00:05:16,5,.96,-73.979942321777344,40.765380859375,1,N,-73.96630859375,40.763088226318359,1,5.5,0.5,0.5,1,0,0.3,7.8']


In [6]:
taxiRaw_Rdd.count()

11460574

## Cleaning Data

### 1. Parse the needed columns

In [7]:
def parse(fields):
    
    pickupTime = datetime.strptime(fields[1], '%Y-%m-%d %H:%M:%S')
    Count = int(fields[3])
    PU_lat = float(fields[5]) 
    PU_lng = float(fields[6])
    DO_lat = float(fields[9]) 
    DO_lng = float(fields[10])

    return (pickupTime, PU_lat, PU_lng, DO_lat, DO_lng, Count)


taxiParsed_rdd = taxiRaw_Rdd\
        .map(lambda line: line.split(','))\
        .filter(lambda fields: len(fields) == 19 and fields[0] != "VendorID")\
        .map(parse)
taxiParsed_rdd.cache()

see("taxiParsed_rdd", taxiParsed_rdd.take(2))

'---- taxiParsed_rdd -----'
[(datetime.datetime(2015, 12, 1, 0, 0),
  -73.97994232177734,
  40.765380859375,
  -73.96630859375,
  40.76308822631836,
  5),
 (datetime.datetime(2015, 12, 1, 0, 0),
  -73.97233581542969,
  40.76237869262695,
  -73.9936294555664,
  40.74599838256836,
  2)]


In [8]:
taxiParsed_rdd.count()

11460573

### 2. Setting the boundaries and eliminating the outliers

In [9]:
min_lat = 40.477399
max_lat = 40.917577
min_lng = -74.259090
max_lng = -73.700272

boundaries = {'min_lat':min_lat,'max_lat':max_lat,'min_lng':min_lng,'max_lng':max_lng}


def boundary_limit(boundaries,row):
    return (row[1] >= min_lng and row[1] <= max_lng)and\
            (row[2] >= min_lat and row[2] <= max_lat)and\
            (row[3] >= min_lng and row[3] <= max_lng)and\
            (row[4] >= min_lat and row[4] <= max_lat)

In [10]:
taxiReady_rdd = taxiParsed_rdd.filter(lambda x:boundary_limit(boundaries,x))
taxiReady_rdd.persist()
taxiParsed_rdd.unpersist()
see("taxiReady_rdd", taxiReady_rdd.take(2))

'---- taxiReady_rdd -----'
[(datetime.datetime(2015, 12, 1, 0, 0),
  -73.97994232177734,
  40.765380859375,
  -73.96630859375,
  40.76308822631836,
  5),
 (datetime.datetime(2015, 12, 1, 0, 0),
  -73.97233581542969,
  40.76237869262695,
  -73.9936294555664,
  40.74599838256836,
  2)]


In [11]:
taxiReady_rdd.count()

11265525

In [12]:
taxiReady_df = taxiReady_rdd.toDF(["datetime","PU_lng","PU_lat","DO_lng","DO_lat","Count"])

In [13]:
taxiReady_df.printSchema()

root
 |-- datetime: timestamp (nullable = true)
 |-- PU_lng: double (nullable = true)
 |-- PU_lat: double (nullable = true)
 |-- DO_lng: double (nullable = true)
 |-- DO_lat: double (nullable = true)
 |-- Count: long (nullable = true)



In [14]:
taxiReady_df.show(3)

+-------------------+------------------+------------------+------------------+-----------------+-----+
|           datetime|            PU_lng|            PU_lat|            DO_lng|           DO_lat|Count|
+-------------------+------------------+------------------+------------------+-----------------+-----+
|2015-12-01 00:00:00|-73.97994232177734|   40.765380859375|   -73.96630859375|40.76308822631836|    5|
|2015-12-01 00:00:00|-73.97233581542969| 40.76237869262695| -73.9936294555664|40.74599838256836|    2|
|2015-12-01 00:00:00| -73.9688491821289|40.764530181884766|-73.97454833984375|40.79164123535156|    1|
+-------------------+------------------+------------------+------------------+-----------------+-----+
only showing top 3 rows



### 3. Save the clean file to CSV

In [15]:
taxiReady_df.coalesce(1).write.save("taxiReady_df.csv",format="csv")

In [16]:
# saving rdd to csv
"""
def toCSVLine(data):
    return ','.join(str(d) for d in data)
taxiReady_rdd.map(toCSVLine).saveAsTextFile('taxiReady_df.csv')
"""

"\ndef toCSVLine(data):\n    return ','.join(str(d) for d in data)\ntaxiReady_rdd.map(toCSVLine).saveAsTextFile('taxiReady_df.csv')\n"

## Preprocessing

### load cleaned data (optional):

In [26]:
customSchema = StructType([ \
    StructField("_c0", TimestampType(), True),\
    StructField("_c1", FloatType(), True),\
    StructField("_c2", FloatType(), True),\
    StructField("_c3", FloatType(), True),\
    StructField("_c4", FloatType(), True),\
    StructField("_c4", LongType(), True)])

taxiReady_df = spark.read.csv("taxiReady_df.csv",inferSchema=True)
taxiReady_df = taxiReady_df.withColumnRenamed('_c0', 'datetime')\
                           .withColumnRenamed('_c1', 'PU_lng')\
                           .withColumnRenamed('_c2', 'PU_lat')\
                           .withColumnRenamed('_c3', 'DO_lng')\
                           .withColumnRenamed('_c4', 'DO_lat')\
                           .withColumnRenamed('_c5', 'Count')

In [27]:
taxiReady_df.dtypes

[('datetime', 'timestamp'),
 ('PU_lng', 'double'),
 ('PU_lat', 'double'),
 ('DO_lng', 'double'),
 ('DO_lat', 'double'),
 ('Count', 'int')]

In [28]:
taxiReady_df.show(3)

+-------------------+------------------+------------------+------------------+-----------------+-----+
|           datetime|            PU_lng|            PU_lat|            DO_lng|           DO_lat|Count|
+-------------------+------------------+------------------+------------------+-----------------+-----+
|2015-12-01 00:00:00|-73.97994232177734|   40.765380859375|   -73.96630859375|40.76308822631836|    5|
|2015-12-01 00:00:00|-73.97233581542969| 40.76237869262695| -73.9936294555664|40.74599838256836|    2|
|2015-12-01 00:00:00| -73.9688491821289|40.764530181884766|-73.97454833984375|40.79164123535156|    1|
+-------------------+------------------+------------------+------------------+-----------------+-----+
only showing top 3 rows



### 2. Generating the fishnet

In [29]:
min_lat = 40.477399
max_lat = 40.917577
min_lng = -74.259090
max_lng = -73.700272

boundaries = {'min_lat':min_lat,'max_lat':max_lat,'min_lng':min_lng,'max_lng':max_lng}

In [30]:
def fishnet(df,boundaries = boundaries, lat_split = 20, lng_split = 30):
    """
    boundaries: dictionary contains lat/lng min/max points
    lat_split: number of lat splits
    lng_split: number of lng splits
    """
    lat_step = (boundaries['max_lat'] - boundaries['min_lat']) / lat_split
    lng_step = (boundaries['max_lng'] - boundaries['min_lng']) / lng_split
    min_lat = boundaries['min_lat']
    min_lng = boundaries['min_lng']

    #return (df-min_lat)//lat_split
    return df.withColumn('Plat_grid', floor((df.PU_lat - min_lat)/lat_step))\
             .withColumn('Plng_grid', floor((df.PU_lng - min_lng)/lng_step))\
             .withColumn('Dlat_grid', floor((df.DO_lat - min_lat)/lat_step))\
             .withColumn('Dlng_grid', floor((df.DO_lng - min_lng)/lng_step))
    

In [64]:
taxi_grid_df = fishnet(taxiReady_df).cache()
taxi_grid_df.show(5)

+-------------------+------------------+------------------+------------------+------------------+-----+---------+---------+---------+---------+
|           datetime|            PU_lng|            PU_lat|            DO_lng|            DO_lat|Count|Plat_grid|Plng_grid|Dlat_grid|Dlng_grid|
+-------------------+------------------+------------------+------------------+------------------+-----+---------+---------+---------+---------+
|2015-12-01 00:00:00|-73.97994232177734|   40.765380859375|   -73.96630859375| 40.76308822631836|    5|       13|       14|       12|       15|
|2015-12-01 00:00:00|-73.97233581542969| 40.76237869262695| -73.9936294555664| 40.74599838256836|    2|       12|       15|       12|       14|
|2015-12-01 00:00:00| -73.9688491821289|40.764530181884766|-73.97454833984375| 40.79164123535156|    1|       13|       15|       14|       15|
|2015-12-01 00:00:01|-73.99393463134766| 40.74168395996094|-73.99766540527344|40.747467041015625|    1|       12|       14|       12|   

## Processing SQL Query

In [65]:
taxi_grid_df.createOrReplaceTempView("taxi_all")

In [66]:
def runQuery(sqlQuery):
    """
    Receives:SQL Query as string
    Returns first 10 rows of the result
    """
    
    spark.sql(sqlQuery).createOrReplaceTempView("out_table")

    title("Query first 10 Rows")
    spark.sql("SELECT * FROM out_table").show(10)

In [69]:
runQuery("SELECT datetime, Plat_grid, Plng_grid, Dlat_grid, Dlng_grid FROM taxi_all\
         WHERE datetime BETWEEN '2015-12-06 11:55:00' AND '2015-12-06 12:30:00'\
         AND Plat_grid = 15\
         AND Plng_grid = 17\
         ORDER BY datetime")

'---- Query first 10 Rows -----'
+-------------------+---------+---------+---------+---------+
|           datetime|Plat_grid|Plng_grid|Dlat_grid|Dlng_grid|
+-------------------+---------+---------+---------+---------+
|2015-12-06 11:55:34|       15|       17|       14|       15|
|2015-12-06 11:57:16|       15|       17|       14|       15|
|2015-12-06 12:05:12|       15|       17|       14|       16|
|2015-12-06 12:08:33|       15|       17|       13|       14|
|2015-12-06 12:11:27|       15|       17|        9|       18|
|2015-12-06 12:15:08|       15|       17|       12|       15|
|2015-12-06 12:17:38|       15|       17|       15|       16|
|2015-12-06 12:20:40|       15|       17|       10|       13|
|2015-12-06 12:25:39|       15|       17|       14|       17|
|2015-12-06 12:26:16|       15|       17|       15|       16|
+-------------------+---------+---------+---------+---------+
only showing top 10 rows

