# Spark data frames from CSV files: handling headers & column types

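Start PySpark in an IPython notebook by running ``IPYTHON_OPTS="notebook" $SPARK_HOME/bin/pyspark`` from the command prompt. This works with Spark 1.x. ``IPYTHON_OPTS`` was removed in Spark 2.0; there, setting ``PYSPARK_DRIVER_PYTHON=jupyter`` and ``PYSPARK_DRIVER_PYTHON_OPTS=notebook`` serves the same purpose.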

In [1]:
from pyspark.sql import SQLContext
from pyspark.sql.types import *  
sqlContext = SQLContext(sc)

Change the path in the command below to reflect the directory where you have saved the file ``nyctaxisub.csv``

In [2]:
taxiFile = sc.textFile("file:///path/to/nyctaxisub.csv")  # placeholder path: point it at your copy of nyctaxisub.csv
taxiFile.count()  # counts the header line plus all data rows

In [3]:
taxiFile.take(2)  # peek at the first two lines: the header and the first data row

Let's isolate the header, in order to eventually use it to get the field names:

In [4]:
header = taxiFile.first()
header

The field names in the header are enclosed in double quotes. We want to get rid of these quotes and then use the header to build the fields of our schema:

In [5]:
schemaString = header.replace('"', '')
fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split(',')]
fields
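
A more robust way to parse the header is Python's standard ``csv`` module. It splits the line and removes the surrounding double quotes in one step, and it also copes with a comma inside a quoted name, which ``replace('"','')`` followed by a plain split does not. The sketch below assumes a hypothetical three-column header, not the actual header of ``nyctaxisub.csv``.

```python
import csv


def header_to_names(header_line):
    """Split one CSV header line into field names, dropping the surrounding quotes."""
    # csv.reader accepts any iterable of lines, so a one-element list is enough
    return next(csv.reader([header_line]))


# hypothetical three-column header, for illustration only
names = header_to_names('"_id","_rev","vendor_id"')
print(names)  # ['_id', '_rev', 'vendor_id']

# leading underscores can be dropped in the same pass if desired
clean_names = [name.lstrip("_") for name in names]
print(clean_names)  # ['id', 'rev', 'vendor_id']
```

In the notebook, the resulting names would feed the same list comprehension as before: ``[StructField(n, StringType(), True) for n in names]``.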

How many elements are there in the header (i.e., how many data columns)?

In [6]:
len(fields)  # the parsing code below reads p[0] to p[15], i.e. it expects 16 columns

OK, now let's modify the fields which should not be of type ``String``:

In [7]:
fields[2].dataType = TimestampType()
fields[3].dataType = FloatType()
fields[4].dataType = FloatType()
fields[7].dataType = IntegerType()
fields[8].dataType = TimestampType()
fields[9].dataType = FloatType()
fields[10].dataType = FloatType()
fields[11].dataType = IntegerType()
fields[13].dataType = FloatType()
fields[14].dataType = IntegerType()
fields
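
Assigning types by position is fragile. If the header has a different number or order of columns than the code assumes, the assignments either hit the wrong columns or fail with an ``IndexError``. One alternative is to look types up by column name. In the sketch below, plain strings stand in for the Spark type objects; in PySpark the values would be ``FloatType()``, ``IntegerType()`` and so on. The helper name and the partial mapping are hypothetical.

```python
# Hypothetical, partial name -> type mapping; unlisted columns default to "string".
# In PySpark the values would be type instances such as FloatType().
TYPE_BY_NAME = {
    "dropoff_latitude": "float",
    "dropoff_longitude": "float",
    "pickup_latitude": "float",
    "pickup_longitude": "float",
}


def assign_types(names, type_by_name, default="string"):
    """Return (name, type) pairs, looking each column up by name, not by index."""
    missing = set(type_by_name) - set(names)
    if missing:
        # a typo in the mapping would otherwise be silently ignored
        raise KeyError("columns not in header: %s" % sorted(missing))
    return [(name, type_by_name.get(name, default)) for name in names]


names = ["vendor_id", "pickup_latitude", "pickup_longitude",
         "dropoff_latitude", "dropoff_longitude"]
pairs = assign_types(names, TYPE_BY_NAME)
print(pairs[:2])  # [('vendor_id', 'string'), ('pickup_latitude', 'float')]
```

With real Spark types, the same lookup would read ``StructField(name, TYPE_BY_NAME.get(name, StringType()), True)``.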

Let's also get rid of the leading underscores in the first two field names (``_id`` and ``_rev``):

In [None]:
fields[0].name = 'id'
fields[1].name = 'rev'
fields

Now that we are satisfied with the data types, we can construct our schema, which we will use below to build the dataframe:

In [None]:
schema = StructType(fields)
schema

Isolate the header line and remove it from the actual data:

In [None]:
taxiHeader = taxiFile.filter(lambda l: "_id" in l)
taxiHeader.collect()

In [None]:
taxiNoHeader = taxiFile.subtract(taxiHeader)
taxiNoHeader.count()

We end up with 249,999 rows, as expected.
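
``subtract`` works, but it shuffles the whole dataset just to remove one line, and the result no longer keeps the original row order. A common alternative is ``RDD.mapPartitionsWithIndex``, which drops the first element of partition 0. This sketch assumes the header is the first line of the first partition, which should hold for a single input file. The function name is hypothetical, and the check below runs on plain Python lists rather than on Spark.

```python
from itertools import islice


def drop_header(partition_index, lines):
    """Skip the first line of partition 0; pass all other partitions through."""
    if partition_index == 0:
        return islice(lines, 1, None)
    return lines


# In PySpark (not run here): taxiNoHeader = taxiFile.mapPartitionsWithIndex(drop_header)
# Local check with two fake "partitions" of hypothetical lines:
partitions = [["header", "row1", "row2"], ["row3"]]
result = [line for i, p in enumerate(partitions) for line in drop_header(i, iter(p))]
print(result)  # ['row1', 'row2', 'row3']
```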

Before parsing the data, we have to import the necessary Python modules to handle ``datetimes``:

In [None]:
from dateutil.parser import parse  # from the python-dateutil package
# test it:
parse("2013-02-09 18:16:10")
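
If every timestamp in the file follows the same ``YYYY-MM-DD HH:MM:SS`` layout as the test string above, the standard library's ``datetime.strptime`` with a fixed format could replace ``dateutil``'s ``parse``. This avoids both the third-party dependency and the format guessing. That layout is an assumption about the data: a row with a different layout will raise ``ValueError`` rather than being guessed at. The helper name is hypothetical.

```python
from datetime import datetime

TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"  # assumed layout, matching "2013-02-09 18:16:10"


def parse_timestamp(text):
    """Parse a (possibly double-quoted) timestamp string with a fixed format."""
    return datetime.strptime(text.strip('"'), TIMESTAMP_FORMAT)


print(parse_timestamp('"2013-02-09 18:16:10"'))  # 2013-02-09 18:16:10
```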

We are now ready for our first attempt to parse the data with the correct types. For this purpose we build a temporary RDD, ``taxi_temp``:

In [None]:
taxi_temp = taxiNoHeader.map(lambda k: k.split(",")).map(lambda p: (p[0], p[1], parse(p[2].strip('"')), float(p[3]), float(p[4]) , p[5], p[6] , int(p[7]), parse(p[8].strip('"')), float(p[9]), float(p[10]), int(p[11]), p[12], float(p[13]), int(p[14]), p[15] ))
taxi_temp.top(2)
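
The long lambda above encodes each column's type by position. To make that easier to read and to check against the schema, you could instead keep a list of converter functions, one per column, and apply them with ``zip``. The function names, the converters and the three-column sample line below are hypothetical. For ``nyctaxisub.csv`` the list would need 16 entries in header order.

```python
def convert_row(fields, converters):
    """Apply one converter per field, failing loudly if the counts differ."""
    if len(fields) != len(converters):
        raise ValueError("expected %d fields, got %d" % (len(converters), len(fields)))
    return tuple(conv(value) for conv, value in zip(converters, fields))


def unquote(s):
    # string columns keep their value, minus the surrounding double quotes
    return s.strip('"')


# hypothetical layout: a quoted id, a float distance, an integer passenger count
CONVERTERS = [unquote, float, int]

row = convert_row('"abc",1.5,2'.split(","), CONVERTERS)
print(row)  # ('abc', 1.5, 2)
```

In the notebook this would be used as ``taxiNoHeader.map(lambda line: convert_row(line.split(","), CONVERTERS))``. A malformed line then fails with a message that names the field count, instead of an anonymous ``IndexError``.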

Finally, let's build our dataframe, using the ``taxi_temp`` RDD just produced and the ``schema`` variable computed above:

In [None]:
taxi_df = sqlContext.createDataFrame(taxi_temp, schema)
taxi_df.head(2)

We see that the ``StringType`` fields still carry their surrounding double quotes (quotes within quotes). We make a second attempt that strips them. This time we use Spark's ``rdd.toDF()`` method to build the dataframe directly from the ``taxiNoHeader`` RDD, without going through the temporary ``taxi_temp`` RDD:

In [None]:
def parse_trip(line):
    # split on commas, strip the surrounding double quotes from every field,
    # then cast each field according to its position in the schema
    p = [field.strip('"') for field in line.split(",")]
    return (p[0], p[1], parse(p[2]), float(p[3]), float(p[4]), p[5], p[6],
            int(p[7]), parse(p[8]), float(p[9]), float(p[10]), int(p[11]),
            p[12], float(p[13]), int(p[14]), p[15])

taxi_df = taxiNoHeader.map(parse_trip).toDF(schema)
taxi_df.head(2)
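
Both attempts split each line on a bare comma. That would misalign every column after any quoted value that itself contains a comma. A quote-aware alternative is to run Python's ``csv`` module on each line. It also removes the surrounding double quotes, so the separate ``strip('"')`` calls become unnecessary. This section does not establish whether the taxi file actually contains embedded commas; the sample line below is hypothetical.

```python
import csv


def split_csv_line(line):
    """Split one CSV line, respecting double-quoted fields that contain commas."""
    return next(csv.reader([line]))


line = '"abc","New York, NY",3.2'  # hypothetical line with an embedded comma
print(line.split(","))        # ['"abc"', '"New York', ' NY"', '3.2']
print(split_csv_line(line))   # ['abc', 'New York, NY', '3.2']
```

In PySpark, ``taxiNoHeader.map(split_csv_line)`` would apply it per line. Alternatively, ``taxiNoHeader.mapPartitions(csv.reader)`` builds one reader per partition. Neither handles a quoted field that spans several lines, because ``sc.textFile`` has already split the input at line breaks.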

Let's run some simple pandas-like queries. How many records per vendor are there in the dataset?

In [None]:
taxi_df.groupBy("vendor_id").count().show()
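
For readers coming from plain Python, ``groupBy("vendor_id").count()`` computes the number of rows per distinct key. ``collections.Counter`` does the same locally. The vendor values below are made up for illustration. In Spark, the counting is distributed, and the result is a dataframe with ``vendor_id`` and ``count`` columns.

```python
from collections import Counter

vendor_ids = ["CMT", "VTS", "CMT", "VTS", "VTS"]  # hypothetical sample values
counts = Counter(vendor_ids)
print(sorted(counts.items()))  # [('CMT', 2), ('VTS', 3)]
```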

Recall that we have missing values in the field ``store_and_fwd_flag``. How many are there?

In [None]:
taxi_df.filter(taxi_df.store_and_fwd_flag == '').count()

OK, the number of missing values is suspiciously close to the number of ``VTS`` vendor records. Is this a coincidence, or does vendor ``VTS`` indeed tend not to log this field?

In [None]:
taxi_df.filter((taxi_df.store_and_fwd_flag == '') & (taxi_df.vendor_id == 'VTS')).count()

Well, we have a finding! Indeed, all records from vendor ``VTS`` have a missing value in this field...
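
A combined filter like the one above needs ``&``, with each comparison in parentheses, rather than Python's ``and``. The reason is that ``and`` tries to turn the first Spark ``Column`` into a plain boolean, which current PySpark releases reject with a ``ValueError``. The parentheses matter because ``&`` binds more tightly than ``==``. The toy ``Expr`` class below is a hypothetical stand-in, not PySpark's ``Column``. It only mimics that behavior, to show why the operator choice matters.

```python
class Expr:
    """Hypothetical stand-in for a Spark Column: combines with &, refuses bool()."""

    def __init__(self, text):
        self.text = text

    def __eq__(self, other):
        # build a comparison expression instead of returning True/False
        return Expr("(%s = %r)" % (self.text, other))

    def __and__(self, other):
        return Expr("(%s AND %s)" % (self.text, other.text))

    def __bool__(self):
        # what `and` would need, and what a column expression cannot provide
        raise ValueError("cannot convert expression to bool; use & instead of and")


flag, vendor = Expr("store_and_fwd_flag"), Expr("vendor_id")
cond = (flag == "") & (vendor == "VTS")
print(cond.text)  # ((store_and_fwd_flag = '') AND (vendor_id = 'VTS'))
```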

The ``dtypes`` attribute and the ``printSchema()`` method can be used to get information about the schema:

In [None]:
taxi_df.dtypes

In [None]:
taxi_df.printSchema()

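We can run the SQL equivalent of the above pandas-like queries. First, we have to register the dataframe as a named temporary table, let's say ``taxi``. (``registerTempTable`` is the Spark 1.x method; from Spark 2.0 on it is deprecated in favor of ``createOrReplaceTempView``.)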

In [None]:
taxi_df.registerTempTable("taxi")

In [None]:
sqlContext.sql("SELECT vendor_id, COUNT(*) FROM taxi GROUP BY vendor_id ").show()

In [None]:
sqlContext.sql("SELECT COUNT(*) FROM taxi WHERE store_and_fwd_flag = '' ").show()

In [None]:
sqlContext.sql("SELECT COUNT(*) FROM taxi WHERE vendor_id = 'VTS' AND store_and_fwd_flag = '' ").show()
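
To see what the three queries compute without a cluster, you can run the same SQL locally against Python's built-in ``sqlite3`` module. This is only an illustration: SQLite is not Spark SQL (for one thing, it treats identifiers case-insensitively). The five rows below are hypothetical. An ``ORDER BY`` is added to the first query so that its output is deterministic.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE taxi (vendor_id TEXT, store_and_fwd_flag TEXT)")
rows = [("CMT", "N"), ("CMT", "Y"), ("VTS", ""), ("VTS", ""), ("CMT", "")]  # hypothetical
conn.executemany("INSERT INTO taxi VALUES (?, ?)", rows)

per_vendor = conn.execute(
    "SELECT vendor_id, COUNT(*) FROM taxi GROUP BY vendor_id ORDER BY vendor_id"
).fetchall()
missing = conn.execute(
    "SELECT COUNT(*) FROM taxi WHERE store_and_fwd_flag = ''"
).fetchone()[0]
missing_vts = conn.execute(
    "SELECT COUNT(*) FROM taxi WHERE vendor_id = 'VTS' AND store_and_fwd_flag = ''"
).fetchone()[0]
print(per_vendor, missing, missing_vts)  # [('CMT', 3), ('VTS', 2)] 3 2
```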

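Notice that, unlike standard SQL, table and column names here are case sensitive; for example, writing ``TAXI`` or ``vendor_ID`` in the queries will produce an error. This is the behavior of the Spark 1.x ``SQLContext``. It is controlled by the ``spark.sql.caseSensitive`` setting, which defaults to ``false`` (case insensitive) from Spark 2.0 onward.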

Let's change some column names to shorter versions:

In [None]:
taxi_df = (taxi_df.withColumnRenamed('dropoff_longitude', 'dropoff_long')
                  .withColumnRenamed('dropoff_latitude', 'dropoff_lat')
                  .withColumnRenamed('pickup_latitude', 'pickup_lat')
                  .withColumnRenamed('pickup_longitude', 'pickup_long'))
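
When several columns need renaming, a long chain of ``withColumnRenamed`` calls can be replaced by a small helper that applies an ``{old: new}`` mapping in a loop. The helper and the stub class are hypothetical. The helper relies only on each call returning a new dataframe, which is how ``withColumnRenamed`` behaves, and Spark leaves the frame unchanged when the old name is absent. The stub stands in for a Spark dataframe so that the logic can be checked locally.

```python
def rename_columns(df, mapping):
    """Apply withColumnRenamed once per (old, new) pair; return the new frame."""
    for old, new in mapping.items():
        df = df.withColumnRenamed(old, new)
    return df


class StubFrame:
    """Hypothetical stand-in exposing only columns and withColumnRenamed."""

    def __init__(self, columns):
        self.columns = list(columns)

    def withColumnRenamed(self, existing, new):
        # return a new object; names that are not present stay untouched
        return StubFrame(new if c == existing else c for c in self.columns)


SHORT_NAMES = {
    "dropoff_longitude": "dropoff_long",
    "dropoff_latitude": "dropoff_lat",
    "pickup_latitude": "pickup_lat",
    "pickup_longitude": "pickup_long",
}

renamed = rename_columns(StubFrame(["vendor_id", "pickup_latitude", "dropoff_longitude"]),
                         SHORT_NAMES)
print(renamed.columns)  # ['vendor_id', 'pickup_lat', 'dropoff_long']
```

In the notebook, the equivalent call would be ``taxi_df = rename_columns(taxi_df, SHORT_NAMES)``.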

In [None]:
taxi_df.dtypes

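Finally, let's make a row selection and store the results in a pandas dataframe. Note that ``toPandas()`` collects all selected rows into the driver's memory, so the selection must be small enough to fit there: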

In [None]:
import pandas as pd
taxi_CMT = taxi_df.filter("vendor_id = 'CMT' and store_and_fwd_flag != '' ").toPandas()

In [None]:
taxi_CMT.head()