# Traffic accidents analysis using Apache Spark and Google Colab


Dataset: https://smoosavi.org/datasets/us_accidents

## Installation and configuration of PySpark

### Installing Java 8

In [0]:
!apt-get install openjdk-8-jdk-headless > /dev/null

### Installing PySpark

In [0]:
!wget -q http://www-eu.apache.org/dist/spark/spark-2.4.4/spark-2.4.4-bin-hadoop2.7.tgz
!tar xf spark-2.4.4-bin-hadoop2.7.tgz

In [0]:
!pip install -q findspark

In [0]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-2.4.4-bin-hadoop2.7"

### Pointing Colaboratory to Google Drive

In [0]:
from google.colab import drive
drive.mount("/content/gdrive")

### Getting files from Drive

In [0]:
CSV = "/content/gdrive/My Drive/Colab Datasets/US_Accidents_May19.csv"

### Starting a Spark session

In [0]:
import findspark
findspark.init("spark-2.4.4-bin-hadoop2.7")

from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

df = spark.read.csv(CSV, inferSchema=True, header=True)

## Basic inspection

### Dataframe basic inspection

**dtypes** Return df column names and data types.

In [0]:
df.dtypes

[('ID', 'string'),
 ('Source', 'string'),
 ('TMC', 'double'),
 ('Severity', 'int'),
 ('Start_Time', 'timestamp'),
 ('End_Time', 'timestamp'),
 ('Start_Lat', 'double'),
 ('Start_Lng', 'double'),
 ('End_Lat', 'double'),
 ('End_Lng', 'double'),
 ('Distance(mi)', 'double'),
 ('Description', 'string'),
 ('Number', 'double'),
 ('Street', 'string'),
 ('Side', 'string'),
 ('City', 'string'),
 ('County', 'string'),
 ('State', 'string'),
 ('Zipcode', 'string'),
 ('Country', 'string'),
 ('Timezone', 'string'),
 ('Airport_Code', 'string'),
 ('Weather_Timestamp', 'timestamp'),
 ('Temperature(F)', 'double'),
 ('Wind_Chill(F)', 'double'),
 ('Humidity(%)', 'double'),
 ('Pressure(in)', 'double'),
 ('Visibility(mi)', 'double'),
 ('Wind_Direction', 'string'),
 ('Wind_Speed(mph)', 'double'),
 ('Precipitation(in)', 'double'),
 ('Weather_Condition', 'string'),
 ('Amenity', 'boolean'),
 ('Bump', 'boolean'),
 ('Crossing', 'boolean'),
 ('Give_Way', 'boolean'),
 ('Junction', 'boolean'),
 ('No_Exit', 'boolean'),

**show** Displays the content of df.

In [0]:
df.show()

+----+--------+-----+--------+-------------------+-------------------+------------------+------------------+-------+-------+------------+--------------------+------+--------------------+----+------------+----------+-----+----------+-------+----------+------------+-------------------+--------------+-------------+-----------+------------+--------------+--------------+---------------+-----------------+-----------------+-------+-----+--------+--------+--------+-------+-------+----------+-------+-----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
|  ID|  Source|  TMC|Severity|         Start_Time|           End_Time|         Start_Lat|         Start_Lng|End_Lat|End_Lng|Distance(mi)|         Description|Number|              Street|Side|        City|    County|State|   Zipcode|Country|  Timezone|Airport_Code|  Weather_Timestamp|Temperature(F)|Wind_Chill(F)|Humidity(%)|Pressure(in)|Visibility(mi)|Wind_Direction|Wind_Speed(mph)

**head** Return first n rows.

In [0]:
df.head(3)

[Row(ID='A-1', Source='MapQuest', TMC=201.0, Severity=3, Start_Time=datetime.datetime(2016, 2, 8, 5, 46), End_Time=datetime.datetime(2016, 2, 8, 11, 0), Start_Lat=39.865147, Start_Lng=-84.058723, End_Lat=None, End_Lng=None, Distance(mi)=0.01, Description='Right lane blocked due to accident on I-70 Eastbound at Exit 41 OH-235 State Route 4.', Number=None, Street='I-70 E', Side='R', City='Dayton', County='Montgomery', State='OH', Zipcode='45424', Country='US', Timezone='US/Eastern', Airport_Code='KFFO', Weather_Timestamp=datetime.datetime(2016, 2, 8, 5, 58), Temperature(F)=36.9, Wind_Chill(F)=None, Humidity(%)=91.0, Pressure(in)=29.68, Visibility(mi)=10.0, Wind_Direction='Calm', Wind_Speed(mph)=None, Precipitation(in)=0.02, Weather_Condition='Light Rain', Amenity=False, Bump=False, Crossing=False, Give_Way=False, Junction=False, No_Exit=False, Railway=False, Roundabout=False, Station=False, Stop=False, Traffic_Calming=False, Traffic_Signal=False, Turning_Loop=False, Sunrise_Sunset='Night

**first** Return first row.

In [0]:
df.first()

Row(ID='A-1', Source='MapQuest', TMC=201.0, Severity=3, Start_Time=datetime.datetime(2016, 2, 8, 5, 46), End_Time=datetime.datetime(2016, 2, 8, 11, 0), Start_Lat=39.865147, Start_Lng=-84.058723, End_Lat=None, End_Lng=None, Distance(mi)=0.01, Description='Right lane blocked due to accident on I-70 Eastbound at Exit 41 OH-235 State Route 4.', Number=None, Street='I-70 E', Side='R', City='Dayton', County='Montgomery', State='OH', Zipcode='45424', Country='US', Timezone='US/Eastern', Airport_Code='KFFO', Weather_Timestamp=datetime.datetime(2016, 2, 8, 5, 58), Temperature(F)=36.9, Wind_Chill(F)=None, Humidity(%)=91.0, Pressure(in)=29.68, Visibility(mi)=10.0, Wind_Direction='Calm', Wind_Speed(mph)=None, Precipitation(in)=0.02, Weather_Condition='Light Rain', Amenity=False, Bump=False, Crossing=False, Give_Way=False, Junction=False, No_Exit=False, Railway=False, Roundabout=False, Station=False, Stop=False, Traffic_Calming=False, Traffic_Signal=False, Turning_Loop=False, Sunrise_Sunset='Night'

**take** Return the first n rows.

In [0]:
df.take(5) 

[Row(ID='A-1', Source='MapQuest', TMC=201.0, Severity=3, Start_Time=datetime.datetime(2016, 2, 8, 5, 46), End_Time=datetime.datetime(2016, 2, 8, 11, 0), Start_Lat=39.865147, Start_Lng=-84.058723, End_Lat=None, End_Lng=None, Distance(mi)=0.01, Description='Right lane blocked due to accident on I-70 Eastbound at Exit 41 OH-235 State Route 4.', Number=None, Street='I-70 E', Side='R', City='Dayton', County='Montgomery', State='OH', Zipcode='45424', Country='US', Timezone='US/Eastern', Airport_Code='KFFO', Weather_Timestamp=datetime.datetime(2016, 2, 8, 5, 58), Temperature(F)=36.9, Wind_Chill(F)=None, Humidity(%)=91.0, Pressure(in)=29.68, Visibility(mi)=10.0, Wind_Direction='Calm', Wind_Speed(mph)=None, Precipitation(in)=0.02, Weather_Condition='Light Rain', Amenity=False, Bump=False, Crossing=False, Give_Way=False, Junction=False, No_Exit=False, Railway=False, Roundabout=False, Station=False, Stop=False, Traffic_Calming=False, Traffic_Signal=False, Turning_Loop=False, Sunrise_Sunset='Night

**schema** Return the schema of df.

In [0]:
df.schema

StructType(List(StructField(ID,StringType,true),StructField(Source,StringType,true),StructField(TMC,DoubleType,true),StructField(Severity,IntegerType,true),StructField(Start_Time,TimestampType,true),StructField(End_Time,TimestampType,true),StructField(Start_Lat,DoubleType,true),StructField(Start_Lng,DoubleType,true),StructField(End_Lat,DoubleType,true),StructField(End_Lng,DoubleType,true),StructField(Distance(mi),DoubleType,true),StructField(Description,StringType,true),StructField(Number,DoubleType,true),StructField(Street,StringType,true),StructField(Side,StringType,true),StructField(City,StringType,true),StructField(County,StringType,true),StructField(State,StringType,true),StructField(Zipcode,StringType,true),StructField(Country,StringType,true),StructField(Timezone,StringType,true),StructField(Airport_Code,StringType,true),StructField(Weather_Timestamp,TimestampType,true),StructField(Temperature(F),DoubleType,true),StructField(Wind_Chill(F),DoubleType,true),StructField(Humidity(%)

**describe** Compute summary statistics.

In [0]:
df.describe().show()

+-------+--------+-------------+------------------+------------------+------------------+------------------+-----------------+-------------------+------------------+--------------------+------------------+------------------+-------+----------+---------+-------+------------------+-------+----------+------------+------------------+-----------------+------------------+------------------+------------------+--------------+------------------+--------------------+-----------------+--------------+--------------+-----------------+---------------------+
|summary|      ID|       Source|               TMC|          Severity|         Start_Lat|         Start_Lng|          End_Lat|            End_Lng|      Distance(mi)|         Description|            Number|            Street|   Side|      City|   County|  State|           Zipcode|Country|  Timezone|Airport_Code|    Temperature(F)|    Wind_Chill(F)|       Humidity(%)|      Pressure(in)|    Visibility(mi)|Wind_Direction|   Wind_Speed(mph)|   Precipi

**columns** Return the columns of df.

In [0]:
df.columns

['ID',
 'Source',
 'TMC',
 'Severity',
 'Start_Time',
 'End_Time',
 'Start_Lat',
 'Start_Lng',
 'End_Lat',
 'End_Lng',
 'Distance(mi)',
 'Description',
 'Number',
 'Street',
 'Side',
 'City',
 'County',
 'State',
 'Zipcode',
 'Country',
 'Timezone',
 'Airport_Code',
 'Weather_Timestamp',
 'Temperature(F)',
 'Wind_Chill(F)',
 'Humidity(%)',
 'Pressure(in)',
 'Visibility(mi)',
 'Wind_Direction',
 'Wind_Speed(mph)',
 'Precipitation(in)',
 'Weather_Condition',
 'Amenity',
 'Bump',
 'Crossing',
 'Give_Way',
 'Junction',
 'No_Exit',
 'Railway',
 'Roundabout',
 'Station',
 'Stop',
 'Traffic_Calming',
 'Traffic_Signal',
 'Turning_Loop',
 'Sunrise_Sunset',
 'Civil_Twilight',
 'Nautical_Twilight',
 'Astronomical_Twilight']

**count** Count the number of rows in df.

In [0]:
df.count()

2243939

**distinct + count** Count the number of distinct rows in df.

In [0]:
df.distinct().count()

2243939

**printSchema** Print the schema of df.

In [0]:
df.printSchema()

root
 |-- ID: string (nullable = true)
 |-- Source: string (nullable = true)
 |-- TMC: double (nullable = true)
 |-- Severity: integer (nullable = true)
 |-- Start_Time: timestamp (nullable = true)
 |-- End_Time: timestamp (nullable = true)
 |-- Start_Lat: double (nullable = true)
 |-- Start_Lng: double (nullable = true)
 |-- End_Lat: double (nullable = true)
 |-- End_Lng: double (nullable = true)
 |-- Distance(mi): double (nullable = true)
 |-- Description: string (nullable = true)
 |-- Number: double (nullable = true)
 |-- Street: string (nullable = true)
 |-- Side: string (nullable = true)
 |-- City: string (nullable = true)
 |-- County: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Zipcode: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Timezone: string (nullable = true)
 |-- Airport_Code: string (nullable = true)
 |-- Weather_Timestamp: timestamp (nullable = true)
 |-- Temperature(F): double (nullable = true)
 |-- Wind_Chill(F): double (n

**explain** Print the physical plan. Calling `df.explain(True)` also prints the logical plans.

In [0]:
df.explain()

== Physical Plan ==
*(1) FileScan csv [ID#10,Source#11,TMC#12,Severity#13,Start_Time#14,End_Time#15,Start_Lat#16,Start_Lng#17,End_Lat#18,End_Lng#19,Distance(mi)#20,Description#21,Number#22,Street#23,Side#24,City#25,County#26,State#27,Zipcode#28,Country#29,Timezone#30,Airport_Code#31,Weather_Timestamp#32,Temperature(F)#33,... 25 more fields] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/content/gdrive/My Drive/Colab Datasets/US_Accidents_May19.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<ID:string,Source:string,TMC:double,Severity:int,Start_Time:timestamp,End_Time:timestamp,St...


### Dataframe queries


**select** Allows us to access the columns of the DataFrame.

In [0]:
df.select("ID", "Description").show()
df.select("ID", "Temperature(F)", ((df["Temperature(F)"]-32)*(5/9)).alias("Temperature(C)")).show()
df.select(df["Humidity(%)"] > 90).show()

+----+--------------------+
|  ID|         Description|
+----+--------------------+
| A-1|Right lane blocke...|
| A-2|Accident on Brice...|
| A-3|Accident on OH-32...|
| A-4|Accident on I-75 ...|
| A-5|Accident on McEwe...|
| A-6|Accident on I-270...|
| A-7|Accident on Oakri...|
| A-8|Accident on I-75 ...|
| A-9|Accident on Notre...|
|A-10|Right hand should...|
|A-11|Accident on I-270...|
|A-12|One lane blocked ...|
|A-13|Accident on Rever...|
|A-14|Accident on Salem...|
|A-15|Accident on OH-16...|
|A-16|Accident on Wayne...|
|A-17|Accident on James...|
|A-18|Accident on Delph...|
|A-19|Accident on Stewa...|
|A-20|Accident on Hillc...|
+----+--------------------+
only showing top 20 rows

+----+--------------+------------------+
|  ID|Temperature(F)|    Temperature(C)|
+----+--------------+------------------+
| A-1|          36.9|2.7222222222222214|
| A-2|          37.9|3.2777777777777772|
| A-3|          36.0|2.2222222222222223|
| A-4|          35.1| 1.722222222222223|
| A-5|         
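
The second query converts Fahrenheit to Celsius with C = (F - 32) * 5/9. As a quick sanity check outside Spark, the same arithmetic can be run in plain Python on the first two temperatures of the table. The helper name `fahrenheit_to_celsius` is introduced here for illustration and is not part of the notebook.

```python
def fahrenheit_to_celsius(temp_f):
    """Convert a temperature from degrees Fahrenheit to degrees Celsius."""
    return (temp_f - 32) * (5 / 9)

# First two temperatures from the table above
for temp_f in (36.9, 37.9):
    print(temp_f, round(fahrenheit_to_celsius(temp_f), 4))
```

The rounded results (2.7222 and 3.2778) agree with the `Temperature(C)` column computed by Spark.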

**when** Returns a new column from a condition.

In [0]:
from pyspark.sql.functions import when

df.select("Precipitation(in)", when(df["Humidity(%)"] > 90, 1).otherwise(0)).show()

+-----------------+----------------------------------------------+
|Precipitation(in)|CASE WHEN (Humidity(%) > 90) THEN 1 ELSE 0 END|
+-----------------+----------------------------------------------+
|             0.02|                                             1|
|              0.0|                                             1|
|             null|                                             1|
|             null|                                             1|
|             null|                                             0|
|             0.03|                                             1|
|             null|                                             1|
|             null|                                             1|
|             null|                                             1|
|             0.02|                                             1|
|             null|                                             1|
|             0.02|                                           

**like** Works like a SQL LIKE clause.

In [0]:
df.select("City", df.City.like("Reynoldsburg")).show()

+------------+----------------------+
|        City|City LIKE Reynoldsburg|
+------------+----------------------+
|      Dayton|                 false|
|Reynoldsburg|                  true|
|Williamsburg|                 false|
|      Dayton|                 false|
|      Dayton|                 false|
| Westerville|                 false|
|      Dayton|                 false|
|      Dayton|                 false|
|      Dayton|                 false|
| Westerville|                 false|
|    Columbus|                 false|
|Reynoldsburg|                  true|
|      Dayton|                 false|
|      Dayton|                 false|
|    Columbus|                 false|
|      Dayton|                 false|
|      Dayton|                 false|
|      Dayton|                 false|
|      Dayton|                 false|
|      Dayton|                 false|
+------------+----------------------+
only showing top 20 rows
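
In the query above the pattern contains no wildcards, so `like` behaves as an exact match. SQL `LIKE` patterns may also use `%` (any sequence of characters) and `_` (exactly one character). The pure-Python sketch below models that matching rule so it can be tried without a Spark session. `sql_like` is a hypothetical helper: it ignores escape characters and is not how Spark implements `like`.

```python
import re

def sql_like(value, pattern):
    """Rough model of SQL LIKE: '%' matches any run of characters, '_' exactly one."""
    parts = []
    for ch in pattern:
        if ch == "%":
            parts.append(".*")
        elif ch == "_":
            parts.append(".")
        else:
            parts.append(re.escape(ch))
    return re.fullmatch("".join(parts), value, flags=re.DOTALL) is not None

for city in ("Dayton", "Reynoldsburg", "Williamsburg"):
    print(city,
          sql_like(city, "Reynoldsburg"),  # exact match only
          sql_like(city, "%burg"),         # any city ending in "burg"
          sql_like(city, "Day___"))        # "Day" followed by three characters
```

In PySpark the wildcard version of the query would be written as `df.City.like("%burg")`.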



**startswith + endswith** Both return a boolean column indicating whether each value starts or ends with the given string.

In [0]:
df.select("City", df.City.startswith("Day"), df.City.endswith("ville")).show()

+------------+---------------------+---------------------+
|        City|startswith(City, Day)|endswith(City, ville)|
+------------+---------------------+---------------------+
|      Dayton|                 true|                false|
|Reynoldsburg|                false|                false|
|Williamsburg|                false|                false|
|      Dayton|                 true|                false|
|      Dayton|                 true|                false|
| Westerville|                false|                 true|
|      Dayton|                 true|                false|
|      Dayton|                 true|                false|
|      Dayton|                 true|                false|
| Westerville|                false|                 true|
|    Columbus|                false|                false|
|Reynoldsburg|                false|                false|
|      Dayton|                 true|                false|
|      Dayton|                 true|                fals

**substr** Returns a column containing, for each row, the substring of the column value that starts at the given (1-based) position and has the given length.

In [0]:
df.select(df.City.substr(1, 3).alias("City_1_3")).take(4)

[Row(City_1_3='Day'),
 Row(City_1_3='Rey'),
 Row(City_1_3='Wil'),
 Row(City_1_3='Day')]
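
`substr(startPos, length)` counts positions from 1, whereas Python slicing counts from 0. The sketch below shows the correspondence on plain strings. `substr_1_based` is a hypothetical helper written for this example and assumes `start_pos >= 1`.

```python
def substr_1_based(text, start_pos, length):
    """Mimic a 1-based (start, length) substring with a 0-based Python slice."""
    begin = start_pos - 1
    return text[begin:begin + length]

for city in ("Dayton", "Reynoldsburg", "Williamsburg"):
    print(city, "->", substr_1_based(city, 1, 3))
```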

**between** Returns a boolean column that is true when the value lies between the given lower and upper bounds, both inclusive.

In [0]:
df.select("Pressure(in)", df["Pressure(in)"].between(29.65, 29.67)).show()

+------------+-----------------------------------------------------+
|Pressure(in)|((Pressure(in) >= 29.65) AND (Pressure(in) <= 29.67))|
+------------+-----------------------------------------------------+
|       29.68|                                                false|
|       29.65|                                                 true|
|       29.67|                                                 true|
|       29.64|                                                false|
|       29.65|                                                 true|
|       29.63|                                                false|
|       29.66|                                                 true|
|       29.66|                                                 true|
|       29.67|                                                 true|
|       29.62|                                                false|
|       29.64|                                                false|
|       29.62|                    

**groupBy** Groups the DataFrame using the specified columns.

In [0]:
df.groupBy("City").count().show()

+--------------+-----+
|          City|count|
+--------------+-----+
|       Palermo|   10|
|   Santa Paula|  261|
|        Osteen|   12|
|      Moreland|   29|
|      Bluffton|  374|
|        Grimes|   54|
|          Dows|   12|
|     Worcester|  344|
|     Rhinebeck|  284|
|       Hanover| 1088|
|West Sand Lake|   41|
|        Agawam|    9|
|         Leola|  176|
|  Harleysville|  616|
|      Rawlings|    8|
|         Tyler|  110|
|    Westampton|    6|
|        Nahant|    4|
|   Middlefield|   60|
|  Saint George|  398|
+--------------+-----+
only showing top 20 rows



**filter** Filters the rows of the DataFrame using the given condition.

In [0]:
df.filter(df["Precipitation(in)"] < 0.01).select("Precipitation(in)").show()

+-----------------+
|Precipitation(in)|
+-----------------+
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
|              0.0|
+-----------------+
only showing top 20 rows



**orderBy** Returns a new DataFrame sorted by the specified columns.

In [0]:
df.select("ID").orderBy("ID", ascending=False).show()
df.select("Severity", "Humidity(%)").orderBy(["Severity", "Humidity(%)"], ascending=[1,0]).show()

+--------+
|      ID|
+--------+
|A-999999|
|A-999998|
|A-999997|
|A-999996|
|A-999995|
|A-999994|
|A-999993|
|A-999992|
|A-999991|
|A-999990|
| A-99999|
|A-999989|
|A-999988|
|A-999987|
|A-999986|
|A-999985|
|A-999984|
|A-999983|
|A-999982|
|A-999981|
+--------+
only showing top 20 rows

+--------+-----------+
|Severity|Humidity(%)|
+--------+-----------+
|       0|       96.0|
|       0|       83.0|
|       0|       83.0|
|       0|       82.0|
|       0|       82.0|
|       0|       78.0|
|       0|       78.0|
|       0|       77.0|
|       0|       71.0|
|       0|       63.0|
|       0|       62.0|
|       0|       55.0|
|       0|       54.0|
|       0|       46.0|
|       0|       44.0|
|       0|       37.0|
|       0|       35.0|
|       1|      100.0|
|       1|      100.0|
|       1|      100.0|
+--------+-----------+
only showing top 20 rows
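
In the first result, `A-99999` appears between `A-999990` and `A-999989` because `ID` is a string column, and strings are compared character by character rather than by numeric value. The sketch below reproduces the effect in plain Python on a few IDs and shows one possible alternative, sorting on the numeric suffix. `numeric_suffix` is a hypothetical helper that assumes every ID has the form `A-<number>`.

```python
ids = ["A-999989", "A-99999", "A-999990", "A-999991"]

def numeric_suffix(accident_id):
    """Return the integer after the 'A-' prefix (assumes IDs look like 'A-123')."""
    return int(accident_id.split("-", 1)[1])

# Descending order using string comparison, as in the query above
lexicographic = sorted(ids, reverse=True)
# Descending order using the numeric part of the ID
numeric = sorted(ids, key=numeric_suffix, reverse=True)

print(lexicographic)
print(numeric)
```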



### Running SQL queries programmatically

In [0]:
df.createOrReplaceTempView("accidents")

spark.sql("SELECT * FROM accidents WHERE Stop == true").show()

+------+--------+-----+--------+-------------------+-------------------+------------------+-------------------+-------+-------+------------+--------------------+-------+----------------+----+--------------+-----------+-----+----------+-------+----------+------------+-------------------+--------------+-------------+-----------+------------+--------------+--------------+---------------+-----------------+-----------------+-------+-----+--------+--------+--------+-------+-------+----------+-------+----+---------------+--------------+------------+--------------+--------------+-----------------+---------------------+
|    ID|  Source|  TMC|Severity|         Start_Time|           End_Time|         Start_Lat|          Start_Lng|End_Lat|End_Lng|Distance(mi)|         Description| Number|          Street|Side|          City|     County|State|   Zipcode|Country|  Timezone|Airport_Code|  Weather_Timestamp|Temperature(F)|Wind_Chill(F)|Humidity(%)|Pressure(in)|Visibility(mi)|Wind_Direction|Wind_Speed

## Preprocessing

### Variable selection
Some variables are discarded from the dataset because they are not considered relevant to this analysis and would only slow down the queries.

In [0]:
COLUMNS_TO_DROP = ["ID", "End_Time", "Start_Lat", "Start_Lng", 
                   "End_Lat", "End_Lng", "Description", "Number", "Street", 
                   "Side", "City", "County", "Zipcode", "Country", "Timezone",
                   "Airport_Code", "Weather_Timestamp", "Wind_Chill(F)",
                   "Amenity", "Bump", "Crossing", "Give_Way", "Junction", 
                   "No_Exit", "Railway", "Roundabout", "Station", "Stop", 
                   "Traffic_Calming", "Traffic_Signal", "Turning_Loop",
                   "Sunrise_Sunset", "Nautical_Twilight", 
                   "Astronomical_Twilight"]

df = df.drop(*COLUMNS_TO_DROP)
df.columns

['Source',
 'TMC',
 'Severity',
 'Start_Time',
 'Distance(mi)',
 'State',
 'Temperature(F)',
 'Humidity(%)',
 'Pressure(in)',
 'Visibility(mi)',
 'Wind_Direction',
 'Wind_Speed(mph)',
 'Precipitation(in)',
 'Weather_Condition',
 'Civil_Twilight']

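### Null values management
Null values in string columns are replaced with a new category, "Unknown". Because `na.fill` is given a string value, it only applies to string columns; numeric columns keep their nulls.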

In [0]:
df = df.na.fill("Unknown")
df.filter(df["Wind_Direction"] == "Unknown").select("Wind_Direction").show()

+--------------+
|Wind_Direction|
+--------------+
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
|       Unknown|
+--------------+
only showing top 20 rows



### Duplicates management
If the goal were to build a machine learning model, it would be worth removing duplicate rows, since they only add redundant information.

In [0]:
records = df.count()

filtered_df = df.dropDuplicates()
after_drop = filtered_df.count()

print("number of records: {}\nnumber of records after removing duplicates: {}\nnumber of duplicated records: {}".format(records, after_drop, records-after_drop))

number of records: 2243939
number of records after removing duplicates: 2212561
number of duplicated records: 31378
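
The earlier `df.distinct().count()` returned the same value as `df.count()`, so these duplicates exist only because identifying columns such as `ID` were dropped during variable selection. A toy illustration in plain Python, using made-up rows rather than the real data:

```python
# Made-up rows: (ID, State, Severity)
rows = [
    ("A-1", "OH", 3),
    ("A-2", "OH", 3),
    ("A-3", "CA", 2),
]

# With the ID, every row is unique
distinct_with_id = set(rows)
# Drop the first field (the ID) from every row, then deduplicate
distinct_without_id = {row[1:] for row in rows}

print(len(rows), len(distinct_with_id), len(distinct_without_id))
```

Two of the three rows collapse into one as soon as the ID is removed.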


## Analysis

This section carries out a small statistical study of the data: bar and donut charts for the categorical variables, and a set of basic summary statistics for the numerical ones.

In [0]:
import plotly.graph_objects as go

def plot_categorical_variable(dataframe, variable_name, variable_description):
  # In order to plot, we transform our Spark dataframe into a Pandas dataframe
  # https://datascience.stackexchange.com/questions/37880/plotting-in-pyspark
  # Count the occurrences of each category once and reuse the result
  counts = dataframe.select(variable_name).toPandas()[variable_name].value_counts()

  labels = counts.index
  values = counts.values

  # Bars plot
  fig = go.Figure(go.Bar(x=labels, y=values, name=""), layout=go.Layout(
          title=go.layout.Title(text="Total number of "+variable_description), 
          xaxis=dict(type="category")
      ))
  fig.show()

  # Donut plot
  fig = go.Figure(go.Pie(labels=labels, values=values, hole=.5, name=""), layout=go.Layout(
          title=go.layout.Title(text="Percentage of "+variable_description)
      ))
  fig.show()

def compute_numerical_variable_statistics(dataframe, variable_name):
  dataframe.describe(variable_name).show()

### Categorical variables

#### Source variable

Indicates the source of the accident report (i.e., the API that reported the accident).

In [0]:
plot_categorical_variable(df, "Source", "patterns acquired from each source")

#### TMC variable

A traffic accident may have a Traffic Message Channel (TMC) code, which provides a more detailed description of the event.

In [0]:
plot_categorical_variable(df, "TMC", "accidents per TMC")

#### Severity variable

Shows the severity of the accident, a number between 1 and 4, where 1 indicates the least impact on traffic (i.e., short delay as a result of the accident) and 4 indicates a significant impact on traffic (i.e., long delay).

In [0]:
plot_categorical_variable(df, "Severity", "accidents per severity")

#### State variable
Shows the state in the address field.

In [0]:
plot_categorical_variable(df, "State", "accidents per state")

#### Wind_Direction variable

Shows wind direction.

In [0]:
plot_categorical_variable(df, "Wind_Direction", "accidents per wind direction")

#### Weather_Condition variable

Shows the weather condition (rain, snow, thunderstorm, fog, etc.)

In [0]:
plot_categorical_variable(df, "Weather_Condition", "accidents per weather condition")

In [0]:
from pyspark.sql.functions import udf

def gather_weather(s):
  # Collapse every thunderstorm and cloud variant into a single category
  if "Thunderstorm" in s:
    return "Thunderstorm"
  if "Cloud" in s:
    return "Cloudy"

  # Otherwise strip intensity and coverage modifiers, so that for example
  # "Light Rain" and "Heavy Rain" both become "Rain"
  modifiers = ["Mostly", "Partly", "Light", "Heavy", "Blowing", "Widespread",
               "Showers", "Grains", "Whirls", "Freezing", "Shallow", "Low",
               "Drifting", "Patches of", "Small"]
  for modifier in modifiers:
    if modifier in s:
      s = s.replace(modifier, "").strip()

  return s

gather_weather_udf = udf(gather_weather)

df = df.withColumn("Weather_Condition", gather_weather_udf("Weather_Condition"))

plot_categorical_variable(df, "Weather_Condition", "accidents per weather condition")

#### Civil_Twilight variable

Shows the period of day (i.e. day or night) based on civil twilight.
(https://en.wikipedia.org/wiki/Twilight#Civil_twilight)

In [0]:
plot_categorical_variable(df, "Civil_Twilight", "accidents for each period of the day (based on civil twilight)")

#### Start_Time variable

In [0]:
import time

# Holidays 
# https://www.independent.co.uk/life-style/us-federal-holidays-2019-when-calendar-list-americans-how-many-thanksgiving-christmas-memorial-day-a8688996.html
holidays = [
            (1,1), (21,1), (2,2), (14,2), (18,2), (17,3), (15,4), (19,4), 
            (21,4), (5,5), (12,5), (27,5), (16,6), (4,7), (2,9), (14,10), 
            (31,10), (11,11), (28,11), (25,12), (29,12)
           ]

def is_holiday(datetime):
  return (datetime.day, datetime.month) in holidays

is_holiday_udf = udf(is_holiday)

df = df.withColumn("Holiday", is_holiday_udf("Start_Time"))

plot_categorical_variable(df, "Holiday", "accidents in holiday")
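
The list stores (day, month) pairs taken from the 2019 calendar, while the dataset also contains earlier years (the first rows are from 2016). Fixed-date holidays such as 4 July match in every year, but floating ones do not. As an illustration, Martin Luther King Jr. Day falls on the third Monday of January, which presumably explains the `(21,1)` entry. The sketch below checks in which years that date is actually the third Monday. `third_monday_of_january` is a hypothetical helper introduced for this example.

```python
from datetime import date, timedelta

def third_monday_of_january(year):
    """Return the date of the third Monday of January in the given year."""
    first = date(year, 1, 1)
    # weekday() is 0 for Monday; days to wait until the first Monday
    offset = (0 - first.weekday()) % 7
    return first + timedelta(days=offset + 14)

for year in range(2016, 2020):
    d = third_monday_of_january(year)
    print(year, d.isoformat(), (d.day, d.month) == (21, 1))
```

Only 2019 matches, so accidents on this holiday in 2016 to 2018 would not be flagged by the (day, month) lookup.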

In [0]:
def day_parts(datetime):
  if datetime.hour < 6:
    return "Early Morning"
  elif 6 <= datetime.hour < 12:
    return "Morning"
  elif 12 <= datetime.hour < 19:
    return "Afternoon"
  elif 19 <= datetime.hour:
    return "Night"

day_parts_udf = udf(day_parts)

df = df.withColumn("Part of the day", day_parts_udf("Start_Time"))

plot_categorical_variable(df, "Part of the day", "accidents for each part of the day")

### Numerical variables

#### Correlation between numerical variables

In [0]:
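# Keep only the numeric columns before computing the correlation matrix
correlation = df.drop("Source", "TMC", "Severity", "State", "Start_Time",
                      "Wind_Direction", "Weather_Condition", 
                      "Civil_Twilight", "Holiday",
                      "Part of the day").toPandas().corr()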
correlation

| | Distance(mi) | Temperature(F) | Humidity(%) | Pressure(in) | Visibility(mi) | Wind_Speed(mph) | Precipitation(in) |
|---|---|---|---|---|---|---|---|
| Distance(mi) | 1.0 | -0.051803 | 0.019102 | -0.003422 | -0.014238 | 0.016701 | -0.001827 |
| Temperature(F) | -0.051803 | 1.0 | -0.300127 | -0.287857 | 0.168717 | -0.050278 | 0.042814 |
| Humidity(%) | 0.019102 | -0.300127 | 1.0 | -0.027111 | -0.383285 | -0.119372 | -0.018104 |
| Pressure(in) | -0.003422 | -0.287857 | -0.027111 | 1.0 | 0.063578 | -0.146369 | 0.052253 |
| Visibility(mi) | -0.014238 | 0.168717 | -0.383285 | 0.063578 | 1.0 | 0.001902 | -0.023521 |
| Wind_Speed(mph) | 0.016701 | -0.050278 | -0.119372 | -0.146369 | 0.001902 | 1.0 | -0.010058 |
| Precipitation(in) | -0.001827 | 0.042814 | -0.018104 | 0.052253 | -0.023521 | -0.010058 | 1.0 |


In [0]:
features = [col for col in correlation.columns]
fig = go.Figure(data=go.Heatmap(z=correlation, x=features, y=features), 
                layout = {
                  "title": "Correlation Map", 
                  "xaxis": {"ticks": ""}, 
                  "yaxis": {"ticks": ""}
                })
fig.show()

#### Distance variable
The length of the road extent affected by the accident.

In [0]:
compute_numerical_variable_statistics(df, "Distance(mi)")

+-------+------------------+
|summary|      Distance(mi)|
+-------+------------------+
|  count|           2243939|
|   mean|0.2879094992995438|
| stddev| 1.532340789546452|
|    min|               0.0|
|    max|333.63000488299997|
+-------+------------------+



#### Temperature(F) variable
Shows the temperature (in Fahrenheit).

In [0]:
compute_numerical_variable_statistics(df, "Temperature(F)")

+-------+------------------+
|summary|    Temperature(F)|
+-------+------------------+
|  count|           2181674|
|   mean| 61.23243940203674|
| stddev|19.146156587688665|
|    min|             -77.8|
|    max|             170.6|
+-------+------------------+
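
The count here (2181674) is lower than the 2243939 rows of the DataFrame because `describe` only takes non-null values into account. The sketch below mimics that behaviour on a tiny made-up sample with Python's `statistics` module. It uses the sample standard deviation, on the assumption that this matches Spark's `stddev` aggregate.

```python
import statistics

# Made-up temperatures; None stands for a missing measurement
temperatures = [36.9, None, 37.9, 36.0, None]

# Nulls are skipped before any statistic is computed
non_null = [t for t in temperatures if t is not None]

summary = {
    "count": len(non_null),
    "mean": statistics.mean(non_null),
    "stddev": statistics.stdev(non_null),  # sample standard deviation
    "min": min(non_null),
    "max": max(non_null),
}
print(summary)
```

The sample has five entries but a count of three, in the same way that the columns with missing measurements report fewer values than the total number of rows.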



#### Humidity(%) variable

Shows the humidity (in percentage).

In [0]:
compute_numerical_variable_statistics(df, "Humidity(%)")

+-------+------------------+
|summary|       Humidity(%)|
+-------+------------------+
|  count|           2179472|
|   mean| 65.92758200151229|
| stddev|22.430132917679703|
|    min|               4.0|
|    max|             100.0|
+-------+------------------+



#### Pressure(in) variable

Shows the air pressure (in inches).

In [0]:
compute_numerical_variable_statistics(df, "Pressure(in)")

+-------+------------------+
|summary|      Pressure(in)|
+-------+------------------+
|  count|           2186659|
|   mean|30.037469468261452|
| stddev|0.2267242220509422|
|    min|               0.0|
|    max|             33.04|
+-------+------------------+



#### Visibility(mi) variable

Shows visibility (in miles).

In [0]:
compute_numerical_variable_statistics(df, "Visibility(mi)")

+-------+------------------+
|summary|    Visibility(mi)|
+-------+------------------+
|  count|           2172579|
|   mean|  9.12409624690295|
| stddev|2.9863589915097206|
|    min|               0.0|
|    max|             140.0|
+-------+------------------+



#### Wind_Speed(mph) variable

Shows wind speed (in miles per hour).

In [0]:
compute_numerical_variable_statistics(df, "Wind_Speed(mph)")

+-------+------------------+
|summary|   Wind_Speed(mph)|
+-------+------------------+
|  count|           1800985|
|   mean| 8.844042176922095|
| stddev|4.9732003142137735|
|    min|               1.2|
|    max|             822.8|
+-------+------------------+



#### Precipitation(in) variable

Shows precipitation amount in inches, if there is any.

In [0]:
compute_numerical_variable_statistics(df, "Precipitation(in)")

+-------+--------------------+
|summary|   Precipitation(in)|
+-------+--------------------+
|  count|              264473|
|   mean|0.060438645910929595|
| stddev|  0.4396975552863648|
|    min|                 0.0|
|    max|                10.8|
+-------+--------------------+

