In [15]:
import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import tensorflow as tf

In [16]:
data = pd.read_csv("dataset/USAccidents20162021/US_Accidents_Dec19.csv")

In [20]:
data

Unnamed: 0,ID,Source,TMC,Severity,Start_Time,End_Time,Start_Lat,Start_Lng,End_Lat,End_Lng,...,Roundabout,Station,Stop,Traffic_Calming,Traffic_Signal,Turning_Loop,Sunrise_Sunset,Civil_Twilight,Nautical_Twilight,Astronomical_Twilight
0,A-1,MapQuest,201.0,3,2016-02-08 05:46:00,2016-02-08 11:00:00,39.865147,-84.058723,,,...,False,False,False,False,False,False,Night,Night,Night,Night
1,A-2,MapQuest,201.0,2,2016-02-08 06:07:59,2016-02-08 06:37:59,39.928059,-82.831184,,,...,False,False,False,False,False,False,Night,Night,Night,Day
2,A-3,MapQuest,201.0,2,2016-02-08 06:49:27,2016-02-08 07:19:27,39.063148,-84.032608,,,...,False,False,False,False,True,False,Night,Night,Day,Day
3,A-4,MapQuest,201.0,3,2016-02-08 07:23:34,2016-02-08 07:53:34,39.747753,-84.205582,,,...,False,False,False,False,False,False,Night,Day,Day,Day
4,A-5,MapQuest,201.0,2,2016-02-08 07:39:07,2016-02-08 08:09:07,39.627781,-84.188354,,,...,False,False,False,False,True,False,Day,Day,Day,Day
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2974330,A-2974354,Bing,,2,2019-08-23 18:03:25,2019-08-23 18:32:01,34.002480,-117.379360,33.99888,-117.37094,...,False,False,False,False,False,False,Day,Day,Day,Day
2974331,A-2974355,Bing,,2,2019-08-23 19:11:30,2019-08-23 19:38:23,32.766960,-117.148060,32.76555,-117.15363,...,False,False,False,False,False,False,Day,Day,Day,Day
2974332,A-2974356,Bing,,2,2019-08-23 19:00:21,2019-08-23 19:28:49,33.775450,-117.847790,33.77740,-117.85727,...,False,False,False,False,False,False,Day,Day,Day,Day
2974333,A-2974357,Bing,,2,2019-08-23 19:00:21,2019-08-23 19:29:42,33.992460,-118.403020,33.98311,-118.39565,...,False,False,False,False,False,False,Day,Day,Day,Day


In [19]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2974335 entries, 0 to 2974334
Data columns (total 49 columns):
 #   Column                 Dtype  
---  ------                 -----  
 0   ID                     object 
 1   Source                 object 
 2   TMC                    float64
 3   Severity               int64  
 4   Start_Time             object 
 5   End_Time               object 
 6   Start_Lat              float64
 7   Start_Lng              float64
 8   End_Lat                float64
 9   End_Lng                float64
 10  Distance(mi)           float64
 11  Description            object 
 12  Number                 float64
 13  Street                 object 
 14  Side                   object 
 15  City                   object 
 16  County                 object 
 17  State                  object 
 18  Zipcode                object 
 19  Country                object 
 20  Timezone               object 
 21  Airport_Code           object 
 22  Weather_Timestamp 

In [21]:
data.isna().mean()

ID                       0.000000e+00
Source                   0.000000e+00
TMC                      2.447845e-01
Severity                 0.000000e+00
Start_Time               0.000000e+00
End_Time                 0.000000e+00
Start_Lat                0.000000e+00
Start_Lng                0.000000e+00
End_Lat                  7.552155e-01
End_Lng                  7.552155e-01
Distance(mi)             0.000000e+00
Description              3.362096e-07
Number                   6.447172e-01
Street                   0.000000e+00
Side                     0.000000e+00
City                     2.790540e-05
County                   0.000000e+00
State                    0.000000e+00
Zipcode                  2.958645e-04
Country                  0.000000e+00
Timezone                 1.063431e-03
Airport_Code             1.913369e-03
Weather_Timestamp        1.234057e-02
Temperature(F)           1.884892e-02
Wind_Chill(F)            6.228696e-01
Humidity(%)              1.989453e-02
Pressure(in)

In [22]:
null_columns = ['End_Lat', 'End_Lng', 'Number', 'Wind_Chill(F)', 'Precipitation(in)']
data = data.drop(null_columns, axis=1)

In [23]:
data.isnull().sum()

ID                            0
Source                        0
TMC                      728071
Severity                      0
Start_Time                    0
End_Time                      0
Start_Lat                     0
Start_Lng                     0
Distance(mi)                  0
Description                   1
Street                        0
Side                          0
City                         83
County                        0
State                         0
Zipcode                     880
Country                       0
Timezone                   3163
Airport_Code               5691
Weather_Timestamp         36705
Temperature(F)            56063
Humidity(%)               59173
Pressure(in)              48142
Visibility(mi)            65691
Wind_Direction            45101
Wind_Speed(mph)          440840
Weather_Condition         65932
Amenity                       0
Bump                          0
Crossing                      0
Give_Way                      0
Junction

In [24]:
data = data.dropna(axis=0).reset_index(drop=True)
print("total missing values", data.isnull().sum().sum())

total missing values 0


In [25]:
data

Unnamed: 0,ID,Source,TMC,Severity,Start_Time,End_Time,Start_Lat,Start_Lng,Distance(mi),Description,...,Roundabout,Station,Stop,Traffic_Calming,Traffic_Signal,Turning_Loop,Sunrise_Sunset,Civil_Twilight,Nautical_Twilight,Astronomical_Twilight
0,A-3,MapQuest,201.0,2,2016-02-08 06:49:27,2016-02-08 07:19:27,39.063148,-84.032608,0.01,Accident on OH-32 State Route 32 Westbound at ...,...,False,False,False,False,True,False,Night,Night,Day,Day
1,A-4,MapQuest,201.0,3,2016-02-08 07:23:34,2016-02-08 07:53:34,39.747753,-84.205582,0.01,Accident on I-75 Southbound at Exits 52 52B US...,...,False,False,False,False,False,False,Night,Day,Day,Day
2,A-5,MapQuest,201.0,2,2016-02-08 07:39:07,2016-02-08 08:09:07,39.627781,-84.188354,0.01,Accident on McEwen Rd at OH-725 Miamisburg Cen...,...,False,False,False,False,True,False,Day,Day,Day,Day
3,A-6,MapQuest,201.0,3,2016-02-08 07:44:26,2016-02-08 08:14:26,40.100590,-82.925194,0.01,Accident on I-270 Outerbelt Northbound near Ex...,...,False,False,False,False,False,False,Day,Day,Day,Day
4,A-7,MapQuest,201.0,2,2016-02-08 07:59:35,2016-02-08 08:29:35,39.758274,-84.230507,0.00,Accident on Oakridge Dr at Woodward Ave. Expec...,...,False,False,False,False,False,False,Day,Day,Day,Day
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1886973,A-2246283,MapQuest,201.0,3,2017-08-30 18:41:30,2017-08-30 19:11:07,34.495808,-118.623932,0.00,Accident on I-5 Southbound at Forest Rte-7N09 ...,...,False,False,False,False,False,False,Day,Day,Day,Day
1886974,A-2246284,MapQuest,201.0,3,2017-08-30 18:59:02,2017-08-30 19:27:57,34.031322,-118.433723,0.00,Left lane closed due to accident on I-10 at Na...,...,False,False,False,False,False,False,Day,Day,Day,Day
1886975,A-2246285,MapQuest,201.0,3,2017-08-30 18:57:52,2017-08-30 19:26:11,34.106785,-117.369102,0.00,Accident on Olive Ave at CA-66 Foothill Blvd.,...,False,False,False,False,False,False,Day,Day,Day,Day
1886976,A-2246286,MapQuest,201.0,3,2017-08-30 19:49:01,2017-08-30 20:18:00,33.924686,-118.103981,0.00,#1 lane blocked due to accident on I-605 North...,...,False,False,False,False,False,False,Night,Night,Day,Day


In [26]:
{column: len(data[column].unique()) for column in data.columns if data.dtypes[column] == 'object'}

{'ID': 1886978,
 'Source': 2,
 'Start_Time': 1844922,
 'End_Time': 1839202,
 'Description': 1191939,
 'Street': 118970,
 'Side': 3,
 'City': 9825,
 'County': 1550,
 'State': 49,
 'Zipcode': 289393,
 'Country': 1,
 'Timezone': 4,
 'Airport_Code': 1828,
 'Weather_Timestamp': 368398,
 'Wind_Direction': 23,
 'Weather_Condition': 113,
 'Sunrise_Sunset': 2,
 'Civil_Twilight': 2,
 'Nautical_Twilight': 2,
 'Astronomical_Twilight': 2}

In [27]:
unneeded_columns = ['ID', 'Description', 'Street', 'City', 'Zipcode', 'Country']

data = data.drop(unneeded_columns, axis = 1)
data

Unnamed: 0,Source,TMC,Severity,Start_Time,End_Time,Start_Lat,Start_Lng,Distance(mi),Side,County,...,Roundabout,Station,Stop,Traffic_Calming,Traffic_Signal,Turning_Loop,Sunrise_Sunset,Civil_Twilight,Nautical_Twilight,Astronomical_Twilight
0,MapQuest,201.0,2,2016-02-08 06:49:27,2016-02-08 07:19:27,39.063148,-84.032608,0.01,R,Clermont,...,False,False,False,False,True,False,Night,Night,Day,Day
1,MapQuest,201.0,3,2016-02-08 07:23:34,2016-02-08 07:53:34,39.747753,-84.205582,0.01,R,Montgomery,...,False,False,False,False,False,False,Night,Day,Day,Day
2,MapQuest,201.0,2,2016-02-08 07:39:07,2016-02-08 08:09:07,39.627781,-84.188354,0.01,R,Montgomery,...,False,False,False,False,True,False,Day,Day,Day,Day
3,MapQuest,201.0,3,2016-02-08 07:44:26,2016-02-08 08:14:26,40.100590,-82.925194,0.01,R,Franklin,...,False,False,False,False,False,False,Day,Day,Day,Day
4,MapQuest,201.0,2,2016-02-08 07:59:35,2016-02-08 08:29:35,39.758274,-84.230507,0.00,R,Montgomery,...,False,False,False,False,False,False,Day,Day,Day,Day
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1886973,MapQuest,201.0,3,2017-08-30 18:41:30,2017-08-30 19:11:07,34.495808,-118.623932,0.00,R,Los Angeles,...,False,False,False,False,False,False,Day,Day,Day,Day
1886974,MapQuest,201.0,3,2017-08-30 18:59:02,2017-08-30 19:27:57,34.031322,-118.433723,0.00,R,Los Angeles,...,False,False,False,False,False,False,Day,Day,Day,Day
1886975,MapQuest,201.0,3,2017-08-30 18:57:52,2017-08-30 19:26:11,34.106785,-117.369102,0.00,L,San Bernardino,...,False,False,False,False,False,False,Day,Day,Day,Day
1886976,MapQuest,201.0,3,2017-08-30 19:49:01,2017-08-30 20:18:00,33.924686,-118.103981,0.00,R,Los Angeles,...,False,False,False,False,False,False,Night,Night,Day,Day


In [28]:
def get_years(df, column):
    return df[column].apply(lambda date:date[0:4])

def get_months(df, column):
    return df[column].apply(lambda date: date[5:7])
data['Start_time_Month'] = get_months(data, 'Start_Time')
data['Start_Time_Year'] = get_years(data, 'Start_Time')

data['End_Time_Month'] = get_months(data, 'End_Time')
data['End_Time_Year'] = get_years(data, 'End_Time')

data['Weather_Timestamp_Month'] = get_months(data, 'Weather_Timestamp')
data['Weather_Timestamp_Year'] = get_years(data, 'Weather_Timestamp')
data = data.drop(['Start_Time', 'End_Time', 'Weather_Timestamp'], axis=1)

In [29]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1886978 entries, 0 to 1886977
Data columns (total 41 columns):
 #   Column                   Dtype  
---  ------                   -----  
 0   Source                   object 
 1   TMC                      float64
 2   Severity                 int64  
 3   Start_Lat                float64
 4   Start_Lng                float64
 5   Distance(mi)             float64
 6   Side                     object 
 7   County                   object 
 8   State                    object 
 9   Timezone                 object 
 10  Airport_Code             object 
 11  Temperature(F)           float64
 12  Humidity(%)              float64
 13  Pressure(in)             float64
 14  Visibility(mi)           float64
 15  Wind_Direction           object 
 16  Wind_Speed(mph)          float64
 17  Weather_Condition        object 
 18  Amenity                  bool   
 19  Bump                     bool   
 20  Crossing                 bool   
 21  Give_Way

In [30]:
def onehot_encode(df, columns, prefixes):
    df = df.copy()
    for column, prefix in zip(columns, prefixes):
        dummies = pd.get_dummies(df[column], prefix=prefix)
        df = pd.concat([df, dummies], axis=1)
        df = df.drop(column, axis=1)
    return df
{column: len(data[column].unique()) for column in data.columns if data.dtypes[column] == 'object'}

{'Source': 2,
 'Side': 3,
 'County': 1550,
 'State': 49,
 'Timezone': 4,
 'Airport_Code': 1828,
 'Wind_Direction': 23,
 'Weather_Condition': 113,
 'Sunrise_Sunset': 2,
 'Civil_Twilight': 2,
 'Nautical_Twilight': 2,
 'Astronomical_Twilight': 2,
 'Start_time_Month': 12,
 'Start_Time_Year': 4,
 'End_Time_Month': 12,
 'End_Time_Year': 5,
 'Weather_Timestamp_Month': 12,
 'Weather_Timestamp_Year': 4}

In [31]:
data = data.drop(['Airport_Code', 'County'], axis =1 )
{column: len(data[column].unique()) for column in data.columns if data.dtypes[column] == 'object'}

{'Source': 2,
 'Side': 3,
 'State': 49,
 'Timezone': 4,
 'Wind_Direction': 23,
 'Weather_Condition': 113,
 'Sunrise_Sunset': 2,
 'Civil_Twilight': 2,
 'Nautical_Twilight': 2,
 'Astronomical_Twilight': 2,
 'Start_time_Month': 12,
 'Start_Time_Year': 4,
 'End_Time_Month': 12,
 'End_Time_Year': 5,
 'Weather_Timestamp_Month': 12,
 'Weather_Timestamp_Year': 4}

In [32]:
data = onehot_encode(
    data,
    columns=['Side',  'State', 'Timezone',  'Wind_Direction', 'Weather_Condition'],
    prefixes=['SI', 'CO', 'ST', 'TZ', 'AC', 'WD', 'WC']
)
data

Unnamed: 0,Source,TMC,Severity,Start_Lat,Start_Lng,Distance(mi),Temperature(F),Humidity(%),Pressure(in),Visibility(mi),...,AC_Thunder / Windy,AC_Thunder / Wintry Mix / Windy,AC_Thunder in the Vicinity,AC_Thunderstorm,AC_Thunderstorms and Rain,AC_Thunderstorms and Snow,AC_Widespread Dust,AC_Widespread Dust / Windy,AC_Wintry Mix,AC_Wintry Mix / Windy
0,MapQuest,201.0,2,39.063148,-84.032608,0.01,36.0,100.0,29.67,10.0,...,0,0,0,0,0,0,0,0,0,0
1,MapQuest,201.0,3,39.747753,-84.205582,0.01,35.1,96.0,29.64,9.0,...,0,0,0,0,0,0,0,0,0,0
2,MapQuest,201.0,2,39.627781,-84.188354,0.01,36.0,89.0,29.65,6.0,...,0,0,0,0,0,0,0,0,0,0
3,MapQuest,201.0,3,40.100590,-82.925194,0.01,37.9,97.0,29.63,7.0,...,0,0,0,0,0,0,0,0,0,0
4,MapQuest,201.0,2,39.758274,-84.230507,0.00,34.0,100.0,29.66,7.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1886973,MapQuest,201.0,3,34.495808,-118.623932,0.00,100.0,18.0,28.85,10.0,...,0,0,0,0,0,0,0,0,0,0
1886974,MapQuest,201.0,3,34.031322,-118.433723,0.00,77.0,64.0,29.69,10.0,...,0,0,0,0,0,0,0,0,0,0
1886975,MapQuest,201.0,3,34.106785,-117.369102,0.00,102.2,16.0,29.73,6.0,...,0,0,0,0,0,0,0,0,0,0
1886976,MapQuest,201.0,3,33.924686,-118.103981,0.00,88.0,39.0,29.68,10.0,...,0,0,0,0,0,0,0,0,0,0


In [33]:
def get_binary_column(df, column):
    if column == 'Source':
        return df[column].apply(lambda x: 1 if x == 'MapQuest' else 0)
    else:
        return df[column].apply(lambda x: 1 if x == 'Day' else 0)
data['Sunrise_Sunset'] = get_binary_column(data, 'Sunrise_Sunset')
data['Civil_Twilight'] = get_binary_column(data, 'Civil_Twilight')
data['Nautical_Twilight'] = get_binary_column(data, 'Nautical_Twilight')
data['Astronomical_Twilight'] = get_binary_column(data, 'Astronomical_Twilight')
data

Unnamed: 0,Source,TMC,Severity,Start_Lat,Start_Lng,Distance(mi),Temperature(F),Humidity(%),Pressure(in),Visibility(mi),...,AC_Thunder / Windy,AC_Thunder / Wintry Mix / Windy,AC_Thunder in the Vicinity,AC_Thunderstorm,AC_Thunderstorms and Rain,AC_Thunderstorms and Snow,AC_Widespread Dust,AC_Widespread Dust / Windy,AC_Wintry Mix,AC_Wintry Mix / Windy
0,MapQuest,201.0,2,39.063148,-84.032608,0.01,36.0,100.0,29.67,10.0,...,0,0,0,0,0,0,0,0,0,0
1,MapQuest,201.0,3,39.747753,-84.205582,0.01,35.1,96.0,29.64,9.0,...,0,0,0,0,0,0,0,0,0,0
2,MapQuest,201.0,2,39.627781,-84.188354,0.01,36.0,89.0,29.65,6.0,...,0,0,0,0,0,0,0,0,0,0
3,MapQuest,201.0,3,40.100590,-82.925194,0.01,37.9,97.0,29.63,7.0,...,0,0,0,0,0,0,0,0,0,0
4,MapQuest,201.0,2,39.758274,-84.230507,0.00,34.0,100.0,29.66,7.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1886973,MapQuest,201.0,3,34.495808,-118.623932,0.00,100.0,18.0,28.85,10.0,...,0,0,0,0,0,0,0,0,0,0
1886974,MapQuest,201.0,3,34.031322,-118.433723,0.00,77.0,64.0,29.69,10.0,...,0,0,0,0,0,0,0,0,0,0
1886975,MapQuest,201.0,3,34.106785,-117.369102,0.00,102.2,16.0,29.73,6.0,...,0,0,0,0,0,0,0,0,0,0
1886976,MapQuest,201.0,3,33.924686,-118.103981,0.00,88.0,39.0,29.68,10.0,...,0,0,0,0,0,0,0,0,0,0


In [34]:
y = data['Severity'].copy()
X = data.drop('Severity', axis=1).copy()
y.unique()

array([2, 3, 1, 4])