## Get a list of all files

In [1]:
# Shell out to HDFS to show what is currently in the directory.
# NOTE(review): this lists /tmp/parsed-data/, the directory the final cell writes
# to, while the input CSVs are listed from /tmp/raw-data in the next cell —
# confirm which directory this check was meant for.
! hadoop fs -ls /tmp/parsed-data/

In [2]:
import os
import subprocess
from pyspark.sql.functions import col
from pyspark import SQLContext
from pyspark.sql.types import StringType, StructField, StructType

sqlc = SQLContext(sc)

# List the raw-data directory on HDFS. Each entry line of `hadoop fs -ls` ends
# with the file path; any other line is dropped by the `.csv` filter below.
# universal_newlines=True makes check_output return text on both Python 2 and 3
# (on Python 3 it would otherwise return bytes, which cannot be split on '\n').
listing = subprocess.check_output(['hadoop', 'fs', '-ls', '/tmp/raw-data'],
                                  universal_newlines=True)
# splitlines() + the blank-line guard avoid an IndexError on an empty listing.
paths = [line.split()[-1] for line in listing.splitlines() if line.strip()]
# Keep only the yellow-cab monthly CSVs; match on the file name, not the full path.
files = [path for path in paths
         if path.endswith('.csv') and 'yellow' in os.path.basename(path)]
print(files[:10])

['/tmp/raw-data/yellow_tripdata_2009-01.csv', '/tmp/raw-data/yellow_tripdata_2009-02.csv', '/tmp/raw-data/yellow_tripdata_2009-03.csv', '/tmp/raw-data/yellow_tripdata_2009-04.csv', '/tmp/raw-data/yellow_tripdata_2009-05.csv', '/tmp/raw-data/yellow_tripdata_2009-06.csv', '/tmp/raw-data/yellow_tripdata_2009-07.csv', '/tmp/raw-data/yellow_tripdata_2009-08.csv', '/tmp/raw-data/yellow_tripdata_2009-09.csv', '/tmp/raw-data/yellow_tripdata_2009-10.csv']


## Create a list of RDDs, one per file (loaded lazily — nothing is read until an action runs)

In [3]:
# Build one RDD per monthly CSV. sc.textFile is lazy: HDFS is only read when an
# action such as take() runs.
data_files = list(map(lambda path: sc.textFile('hdfs://' + path), files))

# Inspect a single month to see the raw line layout before parsing it.
data_file = data_files[1]
print(data_file.name())
data_file.take(3)

hdfs:///tmp/raw-data/yellow_tripdata_2009-02.csv


[u'vendor_name,Trip_Pickup_DateTime,Trip_Dropoff_DateTime,Passenger_Count,Trip_Distance,Start_Lon,Start_Lat,Rate_Code,store_and_forward,End_Lon,End_Lat,Payment_Type,Fare_Amt,surcharge,mta_tax,Tip_Amt,Tolls_Amt,Total_Amt',
 u'',
 u'DDS,2009-02-03 08:25:00,2009-02-03 08:33:39,1,1.6000000000000001,-73.992767999999998,40.758324999999999,,,-73.994709999999998,40.739722999999998,CASH,6.9000000000000004,0,,0,0,6.9000000000000004']

## Split each line into fields

In [4]:
# Split each line of the sample month (data_files[1], the file inspected above)
# on commas. Deriving from data_files[1] rather than from `data_file` itself
# keeps this cell idempotent: re-running `data_file = data_file.map(...)` would
# try to split rows that are already lists and fail in take().
# Note: this is a naive split; a quoted value containing a comma yields extra fields.
data_file = data_files[1].map(lambda line: line.split(','))
data_file.take(2)

[[u'vendor_name',
  u'Trip_Pickup_DateTime',
  u'Trip_Dropoff_DateTime',
  u'Passenger_Count',
  u'Trip_Distance',
  u'Start_Lon',
  u'Start_Lat',
  u'Rate_Code',
  u'store_and_forward',
  u'End_Lon',
  u'End_Lat',
  u'Payment_Type',
  u'Fare_Amt',
  u'surcharge',
  u'mta_tax',
  u'Tip_Amt',
  u'Tolls_Amt',
  u'Total_Amt'],
 [u'']]

## Construct a DataFrame

In [5]:
from pyspark.sql.types import StructField, StructType, StringType

# The first row of the split RDD is the header: a list of column names.
header = data_file.take(1)[0]
field_names = ' '.join(header)
column_count = len(header)

# Keep rows with exactly one value per header column and drop the header row.
# The lists are compared directly, so values containing spaces cannot make a
# data row look like the header. The result gets a new name so `data_file` is
# left intact and the cell can be re-run. Default arguments bind the current
# values into the lambda instead of reading globals a later cell reassigns.
rows = data_file.filter(
    lambda line, width=column_count, header_row=header: len(line) == width and line != header_row)

# Every column is loaded as a nullable string; typing happens later, if at all.
fields = [StructField(field_name, StringType(), True) for field_name in header]
schema = StructType(fields)
df = rows.toDF(schema=schema)

df.select('Trip_Pickup_DateTime', 'Passenger_Count').show(3)

+--------------------+---------------+
|Trip_Pickup_DateTime|Passenger_Count|
+--------------------+---------------+
| 2009-02-03 08:25:00|              1|
| 2009-02-28 00:26:00|              5|
| 2009-02-22 00:39:23|              1|
+--------------------+---------------+
only showing top 3 rows



In [6]:
# Files for months 07-12 of 2016 are parsed with this fixed, positional header
# instead of the header line found in the file.
# NOTE(review): this assumes each data row in those files has exactly these 19
# fields in this order — confirm against the source files.
bad_2016_cols = ('vendorid tpep_pickup_datetime tpep_dropoff_datetime '
                 'passenger_count trip_distance pickup_longitude pickup_latitude '
                 'ratecodeid store_and_fwd_flag dropoff_longitude dropoff_latitude '
                 'payment_type fare_amount extra mta_tax tip_amount tolls_amount '
                 'improvement_surcharge total_amount').split()

# The final columns kept from every file, in output order.
columns = ('pickup_datetime dropoff_datetime tip_amount fare_amount total_amount '
           'vendor_id passenger_count trip_distance payment_type tolls_amount').split()

# Lower-cased header names in some files that must be renamed explicitly to
# line up with the names in `columns`.
hard_renames = dict(
    vendor_name='vendor_id',
    total_amt='total_amount',
    tolls_amt='tolls_amount',
    fare_amt='fare_amount',
    tip_amt='tip_amount',
    trip_pickup_datetime='pickup_datetime',
    trip_dropoff_datetime='dropoff_datetime',
)

## Combine all steps into a cleaning loop over every file

In [7]:
import sys

# Months of 2016 whose files are parsed with the fixed `bad_2016_cols` header.
LATE_2016_MONTHS = ('07', '08', '09', '10', '11', '12')


def _is_late_2016(year, month):
    """Return True only for the 2016-07 .. 2016-12 files.

    The previous inline test, `year != '2016' and month not in [...]`, also sent
    every July-December file of other years down the late-2016 branch. That
    gave those files the wrong column count and silently dropped all their rows.
    """
    return year == '2016' and month in LATE_2016_MONTHS


def _normalize_column_name(name):
    """Map a raw header name to the key used to match entries of `columns`.

    The name is lower-cased and passed through `hard_renames`. Then every '_'
    and the 'tpep' prefix are removed, so 'Trip_Pickup_DateTime' and
    'tpep_pickup_datetime' both become 'pickupdatetime'.
    """
    lowered = name.lower()
    renamed = hard_renames.get(lowered, lowered)
    return renamed.replace('_', '').replace('tpep', '')


def _clean_month(data_file, year, month):
    """Parse one monthly CSV RDD into a DataFrame holding exactly `columns`.

    Args:
        data_file: RDD of raw text lines for one month (built with sc.textFile).
        year: 'YYYY' string taken from the file name.
        month: 'MM' string taken from the file name.

    Returns:
        DataFrame whose string-typed columns are named and ordered as `columns`.
    """
    # Naive comma split; rows whose quoted values contain commas get the wrong
    # field count and are removed by the filter below.
    rows = data_file.map(lambda line: line.split(','))

    # The first line is the header. Keep its raw split so exactly that row can
    # be removed (even if it carries stray whitespace or '\r'), and a stripped,
    # non-empty version to name the columns.
    raw_header = rows.take(1)[0]
    if _is_late_2016(year, month):
        header = list(bad_2016_cols)
    else:
        header = [c.strip() for c in raw_header if c.strip()]
    column_count = len(header)

    # column_count and raw_header are locals of this call. Spark therefore
    # serializes this month's values with the lambda, rather than whatever a
    # module-level global holds at the moment the lambda is shipped.
    rows = rows.filter(lambda line: len(line) == column_count and line != raw_header)

    schema = StructType([StructField(name, StringType(), True) for name in header])
    month_df = sqlc.createDataFrame(rows, schema=schema)

    # Normalize every header name, then keep (and order) only the wanted columns.
    month_df = month_df.select([col(name).alias(_normalize_column_name(name))
                                for name in month_df.columns])
    return month_df.select([col(name.replace('_', '')).alias(name) for name in columns])


final_df = None
for data_file in data_files:
    # File names end in 'YYYY-MM.csv'.
    year, month = data_file.name()[-11:-4].split('-')
    sys.stdout.write('\rYear: {} Month: {}'.format(year, month))

    df = _clean_month(data_file, year, month)
    final_df = df if final_df is None else final_df.unionAll(df)

# End the '\r' progress line (a bare `print()` prints "()" under Python 2).
sys.stdout.write('\n')
print(df.columns)
df.select(*columns[:3]).show(3)

Year: 2017 Month: 06()
['pickup_datetime', 'dropoff_datetime', 'tip_amount', 'fare_amount', 'total_amount', 'vendor_id', 'passenger_count', 'trip_distance', 'payment_type', 'tolls_amount']
+-------------------+-------------------+----------+
|    pickup_datetime|   dropoff_datetime|tip_amount|
+-------------------+-------------------+----------+
|2017-06-08 07:52:31|2017-06-08 08:01:32|      1.86|
|2017-06-08 08:08:18|2017-06-08 08:14:00|      2.34|
|2017-06-08 08:16:49|2017-06-08 15:43:22|         0|
+-------------------+-------------------+----------+
only showing top 3 rows



In [None]:
# Persist the combined yellow-cab data as comma-separated CSV with a header row.
# No save mode is passed, so the writer's default mode applies.
# NOTE(review): Spark's default save mode raises if the path already exists — a
# re-run needs the old output removed first (or an explicit mode).
final_df.write.csv('hdfs:///tmp/parsed-data/data', header=True, sep=',')