In [None]:
# pandas and plotting libraries for visualizations
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# module containing functions for manipulating pyspark dataframes
import pyspark.sql.functions as f

# class which will let us create spark objects
from pyspark.sql import SparkSession

# helper functions for intro class
from helpers import display, read_df

## [PySpark SQL docs](https://spark.apache.org/docs/latest/api/python/pyspark.sql.html)
 - the main functions you'll need to manipulate data in pyspark dataframes are in this module

## [Data Dictionary](https://digital.cityofchicago.org/index.php/chicago-taxi-data-released/)
 - details about the dataset used, here filtered down to just 2016 data

## Create a Spark Session

In [None]:
# Build the SparkSession used by every cell below (or reuse one that already
# exists, via getOrCreate). 'local[2]' runs Spark on this machine with 2 threads.
spark = SparkSession.builder.appName('data_exploration').master('local[2]').getOrCreate()

## Read in data file

In [None]:
# Load the taxi dataset through the course helper; the path is relative to the
# notebook's working directory.
# NOTE(review): file format and schema handling live in helpers.read_df, which
# is not visible here — check there if the load fails.
df = read_df(spark, '../taxi_2016')

In [None]:
# Preview the loaded DataFrame with the course's display helper, using its
# default for the optional second argument.
display(df)

In [None]:
# List the column names so they can be referenced in the aggregations below.
df.columns

In [None]:
# Same preview, passing 10 as the helper's second argument
# (presumably the number of rows to show — confirm in helpers.display).
display(df, 10)

In [None]:
# count() is an action: it runs a Spark job over the whole dataset.
# The result is stored because later cells use it as a denominator.
total_rows = df.count()
print(total_rows)

In [None]:
# Draw a ~10% random sample of the rows. `fraction` is a per-row sampling
# probability, so the resulting row count is approximately, not exactly,
# 10% of the total. A fixed seed makes the sample (and the count in the next
# cell) repeatable across kernel restarts for the same input data.
small = df.sample(fraction=0.1, seed=42)

In [None]:
# Row count of the sample. Spark does not guarantee exactly the requested
# fraction, so expect a value near — not equal to — 10% of total_rows.
small.count()

In [None]:
# How many distinct taxis appear in the data: a single global aggregation
# (no groupBy), so the result is a one-row DataFrame.
display(
    df.agg(
        f.countDistinct('taxi_id')
    )
)

In [None]:
# Total miles divided by the total number of rows, shown under the column
# name 'trip_miles'.
# NOTE(review): f.sum skips nulls while total_rows counts every row, so if
# trip_miles contains nulls this is lower than f.avg('trip_miles') — confirm
# which definition of "average" is intended.
display(
    df.agg(
        (f.sum('trip_miles') / total_rows).alias('trip_miles')
    )
)

In [None]:
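# df.agg(agg1, agg2)     yes - aggregations passed as separate arguments
# df.agg([agg1, agg2])   no  - agg() does not accept a list of columns
# df.agg(*[agg1, agg2])  yes - the * unpacks the list into separate arguments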

In [None]:
# Fraction of non-null values in every column: f.count(column) counts only
# non-null entries, and each ratio keeps the name of the column it describes.
display(
    df.agg(*(
        (f.count(name) / total_rows).alias(name)
        for name in df.columns
    ))
)

In [None]:
# One row per taxi with the number of trips it made; groupBy(...).count()
# names the resulting column 'count'.
trips_per_taxi = (
    df
    .groupBy('taxi_id')
    .count()
)

In [None]:
# Inspect the per-taxi trip counts (second argument presumably limits the
# number of rows shown — confirm in helpers.display).
display(trips_per_taxi, 10)

In [None]:
# Distribution of trips per taxi. toPandas() collects the one-column result
# to the driver so seaborn can plot it.
# NOTE(review): sns.distplot is deprecated since seaborn 0.11 (histplot /
# displot replace it) — check the installed seaborn version before migrating.
plt.figure()
trip_counts = trips_per_taxi.select('count').toPandas()
ax = sns.distplot(trip_counts)
ax.set_title('Trips Per Taxi');

In [None]:
# One row per taxi with the sum of its trip_miles, exposed as column 'miles'.
distance_traveled_per_taxi = df.groupBy('taxi_id').agg(f.sum('trip_miles').alias('miles'))

In [None]:
# Inspect the per-taxi mileage totals (second argument presumably limits the
# number of rows shown — confirm in helpers.display).
display(distance_traveled_per_taxi, 10)

In [None]:
# Distribution of total miles per taxi across the full range of values.
# NOTE(review): sns.distplot is deprecated since seaborn 0.11 (histplot /
# displot replace it) — check the installed seaborn version before migrating.
plt.figure()
miles_per_taxi = distance_traveled_per_taxi.select('miles').toPandas()
ax = sns.distplot(miles_per_taxi)
ax.set_title('Miles Traveled Per Taxi');

In [None]:
# where() is a lazy transformation: left as a bare expression, the cell only
# echoes the DataFrame's repr and never runs the filter. Passing it to the
# display helper shows the matching rows, like the other inspection cells.
# This Column-expression filter is equivalent to the SQL-string form
# 'miles < 1000' used in the plot that follows.
display(distance_traveled_per_taxi.where(f.col('miles') < 1000), 10)


In [None]:
# Same distribution restricted to taxis with fewer than 1000 total miles.
# Note that despite the "capped" wording in the title, rows at or above 1000
# are dropped by the filter rather than clipped to 1000.
# NOTE(review): sns.distplot is deprecated since seaborn 0.11 (histplot /
# displot replace it) — check the installed seaborn version before migrating.
plt.figure()
miles_under_1000 = (
    distance_traveled_per_taxi
    .where('miles < 1000')
    .select('miles')
    .toPandas()
)
ax = sns.distplot(miles_under_1000)
ax.set_title('Miles Traveled Per Taxi (capped at 1000)');

In [None]:
# Taxis with the most miles first — descending sort requested through the
# `ascending=False` keyword of orderBy.
display(distance_traveled_per_taxi.orderBy('miles', ascending=False), 10)

In [None]:
# Same ordering expressed with a sort expression: f.desc('miles') marks the
# column as descending instead of using the `ascending` keyword.
display(distance_traveled_per_taxi.orderBy(f.desc('miles')), 10)

## Exercises

In [None]:
# when do most trips occur? 

In [None]:
# what's the most common length for a trip in miles? in minutes?

In [None]:
# are there companies that only use cash or only use credit?

In [None]:
# Shut down the Spark session and release its resources. Run this last:
# DataFrames created from `spark` cannot be evaluated after it is stopped.
spark.stop()