# Exploratory Data Analysis - Final Traffic Accident Dataset

This notebook performs comprehensive exploratory data analysis on the final traffic accident dataset (`data/final/data.csv`).

Along the way, we filter out outliers. At the end, we save the cleaned dataset to a new CSV file, `data/final/data_post_eda.csv`.

## Install required packages

In [1]:
%pip install pandas matplotlib seaborn



In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import display
import warnings

warnings.filterwarnings('ignore')

print("Libraries imported successfully!")

## Variables

In [None]:
data_folder = "../../data"
final_data_folder = f"{data_folder}/final"

data_csv = f"{final_data_folder}/data.csv"
data_post_eda_csv = f"{final_data_folder}/data_post_eda.csv"

## Exploratory Data Analysis (EDA)

In [None]:
df = pd.read_csv(data_csv)

print(f"Dataset shape: {df.shape}")
print(f"Number of rows: {df.shape[0]}")
print(f"Number of columns: {df.shape[1]}")
print(f"Column names: {df.columns.tolist()}")
print("\nDataset info:")
df.info()  # info() prints its report directly and returns None, so it is not wrapped in display()

### Missing Values

Check for any missing values.

In [None]:
missing_values = df.isnull().sum()
missing_percent = (missing_values / len(df) * 100).round(2)

missing_df = pd.DataFrame({
    'Column': missing_values.index,
    'Missing_Count': missing_values.values,
    'Missing_Percent': missing_percent.values
})

missing_df = missing_df[missing_df['Missing_Count'] > 0].sort_values('Missing_Count', ascending=False)
print(f"\nColumns with missing values: {len(missing_df)}")
display(missing_df)

### Basic Statistical Summary

Use pandas to output summary statistics for both numerical and categorical columns.

In [None]:
df.describe()

In [None]:
df.describe(include='object')

### Outliers

Data collection took place between 2024 and 2025. However, during a manual audit we noticed that some news articles reported accidents that had happened several years earlier.

These accidents may have occurred under different conditions (e.g. different road conditions or traffic situations), so we treat them as outliers and filter them out.

In [None]:
df["accident_datetime"] = pd.to_datetime(df["accident_datetime"])

accidents_by_year = df["accident_datetime"].dt.year.value_counts().sort_index()

print("Accidents by Year:")
display(accidents_by_year)

plt.figure(figsize=(12, 5))
accidents_by_year.plot(kind='bar', color='coral')
plt.title('Number of Accidents by Year')
plt.xlabel('Year')
plt.ylabel('Number of Accidents')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

Let's filter out accidents that happened before 2024.

In [None]:
count_before = len(df)
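# .copy() gives an independent DataFrame, so adding columns later does not trigger SettingWithCopyWarning
df = df[df["accident_datetime"] >= "2024-01-01"].copy()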
count_after = len(df)

print(f"Removed {count_before - count_after} outliers")
print(f"Dataset shape after removing outliers: {df.shape}")
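
The cutoff filter can also be wrapped in a small reusable helper. The sketch below uses a hypothetical function, `filter_from_date`, which is not part of the original notebook. It additionally parses timestamps with `errors="coerce"`, so unparseable values become `NaT` and are dropped instead of raising an error. The toy rows are made up for illustration.

```python
import pandas as pd


def filter_from_date(df: pd.DataFrame, col: str, start: str) -> pd.DataFrame:
    """Keep rows whose `col` timestamp is on or after `start` (hypothetical helper).

    Unparseable timestamps become NaT via errors="coerce"; comparisons with NaT
    evaluate to False, so those rows are dropped as well.
    """
    dates = pd.to_datetime(df[col], errors="coerce")
    return df.loc[dates >= pd.Timestamp(start)].copy()


# made-up rows, not taken from the dataset
toy = pd.DataFrame({
    "accident_datetime": ["2019-05-01 10:00", "2024-03-02 08:30", "not a date", "2025-01-15 22:10"],
})
kept = filter_from_date(toy, "accident_datetime", "2024-01-01")
print(len(toy) - len(kept), "rows removed")
```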

### Accident Severity Analysis

In [None]:
severity_counts = df['accident_severity'].value_counts()
print("Accident Severity Distribution:")
print(severity_counts)

print(f"\nTotal accidents: {severity_counts.sum()}")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# bar plot
severity_counts.plot(kind='bar', ax=axes[0], color='steelblue')
axes[0].set_title('Accident Severity Distribution (Bar Chart)')
axes[0].set_xlabel('Severity')
axes[0].set_ylabel('Count')
axes[0].tick_params(axis='x', rotation=45)

# pie chart
axes[1].pie(severity_counts.values, labels=severity_counts.index, autopct='%1.1f%%', startangle=90)
axes[1].set_title('Accident Severity Distribution (Pie Chart)')

plt.tight_layout()
plt.show()

The distribution of the `accident_severity` target variable is highly imbalanced, with certain classes containing very few samples. Given the limited dataset size (219 instances), this results in insufficient per-class representation for reliable multi-class learning.

To address this, we reformulate the problem as a binary classification task: classifying accidents as fatal vs. non-fatal.

In [None]:
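fatal_buckets = {"fatal"}
non_fatal_buckets = {"grievious", "serious", "not injured", "slight"}  # must match the raw label values exactly

# report labels that fall into neither bucket (or are missing) instead of silently counting them as fatal
unmapped = set(df["accident_severity"].dropna().unique()) - fatal_buckets - non_fatal_buckets
if unmapped or df["accident_severity"].isna().any():
    print(f"Warning: unmapped severity labels {unmapped}; missing labels: {df['accident_severity'].isna().sum()}")

df["is_fatal"] = df["accident_severity"].isin(fatal_buckets).astype(int)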

summarised_fatality = df["is_fatal"].value_counts()
print("Fatality Distribution:")
print(summarised_fatality)
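
Even after merging classes, the binary target may still be imbalanced. One simple way to quantify this is the ratio between the majority and minority class proportions. The sketch below uses a made-up label series in place of `df["is_fatal"]`.

```python
import pandas as pd

# made-up labels standing in for df["is_fatal"] (1 = fatal, 0 = non-fatal)
is_fatal = pd.Series([0, 0, 0, 0, 0, 0, 0, 1, 0, 1])

# share of each class, ordered 0 then 1
proportions = is_fatal.value_counts(normalize=True).sort_index()
imbalance_ratio = proportions.max() / proportions.min()

print(proportions)
print(f"Majority/minority ratio: {imbalance_ratio:.1f}")
```

On the real data, `df["is_fatal"].value_counts(normalize=True)` gives the same breakdown.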

### Temporal Analysis

In [None]:
df['accident_datetime'] = pd.to_datetime(df['accident_datetime'])
df['accident_year'] = df['accident_datetime'].dt.year
df['accident_month'] = df['accident_datetime'].dt.month
df['accident_day_of_week'] = df['accident_datetime'].dt.dayofweek
df['accident_hour'] = df['accident_datetime'].dt.hour

print("Temporal features extracted successfully!")
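
`dt.dayofweek` encodes Monday as 0 and Sunday as 6, so the day-of-week breakdown later in this notebook is indexed 0 to 6. If readable labels are preferred, `dt.day_name()` returns the weekday names. The two timestamps below are arbitrary examples.

```python
import pandas as pd

# two arbitrary example timestamps
timestamps = pd.to_datetime(pd.Series(["2024-03-04 08:15", "2024-03-09 23:40"]))

day_numbers = timestamps.dt.dayofweek.tolist()  # Monday=0 ... Sunday=6
day_names = timestamps.dt.day_name().tolist()

print(day_numbers)  # [0, 5]
print(day_names)    # ['Monday', 'Saturday']
```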

#### Accidents by Year

In [None]:
accidents_by_year_stats = (
    df
    .groupby("accident_year")["is_fatal"]
    .value_counts()
    .unstack(fill_value=0)
)

accidents_by_year_stats.columns = ["non_fatal", "fatal"]

accidents_by_year_stats["total"] = accidents_by_year_stats["non_fatal"] + accidents_by_year_stats["fatal"]
accidents_by_year_stats["fatal_rate"] = accidents_by_year_stats["fatal"] / accidents_by_year_stats["total"]

print("Distribution of fatal vs. non-fatal accidents based on year of accident:")
print(accidents_by_year_stats)

accidents_by_year_stats[["non_fatal", "fatal"]].plot(
    kind="bar",
    stacked=True
)

plt.ylabel("Number of Accidents")
plt.title("Fatal vs Non-fatal Accidents by Year")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()
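
The same count-and-rate table is rebuilt for year, month, day of week, time category, region, street type and traffic level. A hypothetical helper, `fatality_breakdown` (not in the original notebook), could factor this out. It uses `pd.crosstab` and reindexes the outcome columns to `[0, 1]`, so the `non_fatal`/`fatal` renaming stays correct even if one class is absent from a slice. The toy frame is made up.

```python
import pandas as pd


def fatality_breakdown(df: pd.DataFrame, col: str) -> pd.DataFrame:
    """Count fatal vs. non-fatal accidents per value of `col` and add the fatality rate."""
    stats = pd.crosstab(df[col], df["is_fatal"])
    # guarantee both outcome columns exist, in 0-then-1 order, even if one class is missing
    stats = stats.reindex(columns=[0, 1], fill_value=0)
    stats.columns = ["non_fatal", "fatal"]
    stats["total"] = stats["non_fatal"] + stats["fatal"]
    stats["fatal_rate"] = stats["fatal"] / stats["total"]
    return stats


# made-up data
toy = pd.DataFrame({
    "accident_year": [2024, 2024, 2024, 2025, 2025],
    "is_fatal":      [0,    1,    0,    0,    0],
})
print(fatality_breakdown(toy, "accident_year"))
```

With such a helper, the monthly cell could reduce to `fatality_breakdown(df, "accident_month")` followed by the same stacked bar plot.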

#### Accidents by Month

In [None]:
accidents_by_month_stats = (
    df
    .groupby("accident_month")["is_fatal"]
    .value_counts()
    .unstack(fill_value=0)
)

accidents_by_month_stats.columns = ["non_fatal", "fatal"]

accidents_by_month_stats["total"] = accidents_by_month_stats["non_fatal"] + accidents_by_month_stats["fatal"]
accidents_by_month_stats["fatal_rate"] = accidents_by_month_stats["fatal"] / accidents_by_month_stats["total"]

print("Distribution of fatal vs. non-fatal accidents based on month of accident:")
print(accidents_by_month_stats)

accidents_by_month_stats[["non_fatal", "fatal"]].plot(
    kind="bar",
    stacked=True
)

plt.ylabel("Number of Accidents")
plt.title("Fatal vs Non-fatal Accidents by Month")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

#### Accidents by Day of Week

In [None]:
accidents_by_day_stats = (
    df
    .groupby("accident_day_of_week")["is_fatal"]
    .value_counts()
    .unstack(fill_value=0)
)

accidents_by_day_stats.columns = ["non_fatal", "fatal"]

accidents_by_day_stats["total"] = accidents_by_day_stats["non_fatal"] + accidents_by_day_stats["fatal"]
accidents_by_day_stats["fatal_rate"] = accidents_by_day_stats["fatal"] / accidents_by_day_stats["total"]

print("Distribution of fatal vs. non-fatal accidents based on day of week of accident:")
print(accidents_by_day_stats)

accidents_by_day_stats[["non_fatal", "fatal"]].plot(
    kind="bar",
    stacked=True
)

plt.ylabel("Number of Accidents")
plt.title("Fatal vs Non-fatal Accidents by Day of the Week")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

#### Accidents by time of day

In [None]:
accident_time_category_stats = (
    df
    .groupby("accident_time_category")["is_fatal"]
    .value_counts()
    .unstack(fill_value=0)
)

accident_time_category_stats.columns = ["non_fatal", "fatal"]

accident_time_category_stats["total"] = accident_time_category_stats["non_fatal"] + accident_time_category_stats["fatal"]
accident_time_category_stats["fatal_rate"] = accident_time_category_stats["fatal"] / accident_time_category_stats["total"]

print("Distribution of fatal vs. non-fatal accidents based on accident time of day:")
print(accident_time_category_stats)

accident_time_category_stats[["non_fatal", "fatal"]].plot(
    kind="bar",
    stacked=True
)

plt.ylabel("Number of Accidents")
plt.title("Fatal vs Non-fatal Accidents by Time Category")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

### Location Analysis

#### Top 15 cities with the most accidents

In [None]:
top_cities = df['city'].value_counts().head(15)
print("Top 15 Cities with Most Accidents:")
print(top_cities)

plt.figure(figsize=(12, 6))
top_cities.plot(kind='barh', color='mediumpurple')
plt.title('Top 15 Cities with Most Accidents')
plt.xlabel('Number of Accidents')
plt.ylabel('City')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()

#### Top 3 cities with the most fatalities

In [None]:
fatal_df = df[df['is_fatal'] == 1]

top_3_fatal_cities = fatal_df['city'].value_counts().head(3)

print("Top 3 Cities with Most Fatal Accidents:")
print(top_3_fatal_cities)

plt.figure(figsize=(8, 4))
top_3_fatal_cities.plot(kind='barh')
plt.title('Top 3 Cities with Most Fatal Accidents')
plt.xlabel('Number of Fatal Accidents')
plt.ylabel('City')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()
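
Raw fatal counts tend to favour cities with many accidents overall. A complementary view ranks cities by fatality rate, restricted to cities with at least a minimum number of accidents so that one-off events do not dominate. The function name `top_fatality_rate` and the `min_accidents` threshold are hypothetical choices, and the data below is made up.

```python
import pandas as pd


def top_fatality_rate(df: pd.DataFrame, col: str, min_accidents: int = 5, n: int = 3) -> pd.DataFrame:
    """Rank values of `col` by fatality rate, ignoring groups with fewer than `min_accidents` rows."""
    stats = df.groupby(col)["is_fatal"].agg(accidents="count", fatal="sum", fatal_rate="mean")
    stats = stats[stats["accidents"] >= min_accidents]
    # ties in fatal_rate are broken by preferring groups with more accidents
    return stats.sort_values(["fatal_rate", "accidents"], ascending=False).head(n)


# made-up data: city C has too few accidents to be ranked
toy = pd.DataFrame({
    "city":     ["A"] * 6 + ["B"] * 5 + ["C"] * 2,
    "is_fatal": [1, 0, 0, 0, 0, 0] + [1, 1, 0, 0, 0] + [1, 1],
})
print(top_fatality_rate(toy, "city", min_accidents=5))
```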

#### Analysis by region

In [None]:
region_category_stats = (
    df
    .groupby("region")["is_fatal"]
    .value_counts()
    .unstack(fill_value=0)
)

region_category_stats.columns = ["non_fatal", "fatal"]

region_category_stats["total"] = region_category_stats["non_fatal"] + region_category_stats["fatal"]
region_category_stats["fatal_rate"] = region_category_stats["fatal"] / region_category_stats["total"]

print(region_category_stats)

region_category_stats[["non_fatal", "fatal"]].plot(
    kind="bar",
    stacked=True
)

plt.ylabel("Number of Accidents")
plt.title("Fatal vs Non-fatal Accidents by Region")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

#### Analysis by street type

In [None]:
street_type_category_stats = (
    df
    .groupby("street_type")["is_fatal"]
    .value_counts()
    .unstack(fill_value=0)
)

street_type_category_stats.columns = ["non_fatal", "fatal"]

street_type_category_stats["total"] = street_type_category_stats["non_fatal"] + street_type_category_stats["fatal"]
street_type_category_stats["fatal_rate"] = street_type_category_stats["fatal"] / street_type_category_stats["total"]

print(street_type_category_stats)

street_type_category_stats[["non_fatal", "fatal"]].plot(
    kind="bar",
    stacked=True
)

plt.ylabel("Number of Accidents")
plt.title("Fatal vs Non-fatal Accidents by Street Type")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

### Driver Demographics Analysis

#### Gender of drivers

In [None]:
total_drivers = df['num_drivers_total'].sum()
print(f"Total drivers involved in accidents: {total_drivers}")

male_drivers = df['num_drivers_male'].sum()
female_drivers = df['num_drivers_female'].sum()
unknown_gender = df['num_drivers_gender_unknown'].sum()

print(f"\nDriver Gender Distribution:")
print(f"Male drivers: {male_drivers} ({male_drivers/total_drivers*100:.1f}%)")
print(f"Female drivers: {female_drivers} ({female_drivers/total_drivers*100:.1f}%)")
print(f"Unknown gender: {unknown_gender} ({unknown_gender/total_drivers*100:.1f}%)")

fig, ax = plt.subplots(figsize=(8, 8))
ax.pie([male_drivers, female_drivers, unknown_gender], 
       labels=['Male', 'Female', 'Unknown'], 
       autopct='%1.1f%%', 
       startangle=90,
       colors=['steelblue', 'pink', 'gray'])
ax.set_title('Driver Gender Distribution')
plt.show()
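
The percentages above divide each gender total by `num_drivers_total`, which assumes that the three gender columns add up to the total on every row. A quick consistency check can confirm this before relying on the percentages. The helper `gender_total_mismatches` is hypothetical and the rows are made up.

```python
import pandas as pd

GENDER_COLS = ["num_drivers_male", "num_drivers_female", "num_drivers_gender_unknown"]


def gender_total_mismatches(df: pd.DataFrame) -> pd.DataFrame:
    """Return rows where the per-gender driver counts do not add up to num_drivers_total."""
    parts = df[GENDER_COLS].sum(axis=1)
    return df[parts != df["num_drivers_total"]]


# made-up rows; the third row is deliberately inconsistent (0 + 1 + 0 != 2)
toy = pd.DataFrame({
    "num_drivers_male":           [1, 2, 0],
    "num_drivers_female":         [1, 0, 1],
    "num_drivers_gender_unknown": [0, 0, 0],
    "num_drivers_total":          [2, 2, 2],
})
print(gender_total_mismatches(toy))
```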

In [None]:
numerical_gender_buckets = [
    "num_drivers_male",
    "num_drivers_female",
    "num_drivers_gender_unknown",
]

for col in numerical_gender_buckets:
    # only consider accidents with at least one driver in this gender bucket
    filtered = df[df[col] > 0]

    if filtered.empty:
        continue # nothing to report

    numeric_gender_stats = (
        filtered
        .groupby(col)["is_fatal"]
        .agg(
            accidents="count",
            fatal="sum",
            fatality_rate="mean"
        )
    )

    print(f"Stats for '{col}' column")
    print(numeric_gender_stats)

    numeric_gender_stats["fatality_rate"].plot(kind="bar")
    plt.title(f"Fatality Rate vs Number of Drivers ({col})")
    plt.xlabel("Number of drivers")
    plt.ylabel("Fatality rate")
    plt.tight_layout()
    plt.show()

#### Age of drivers

In [None]:
driver_ages = {
    'Under 18': df['num_drivers_under_18'].sum(),
    '18-24': df['num_drivers_18_to_24'].sum(),
    '25-49': df['num_drivers_25_to_49'].sum(),
    '50-64': df['num_drivers_50_to_64'].sum(),
    '65+': df['num_drivers_65_plus'].sum(),
    'Unknown': df['num_drivers_age_unknown'].sum()
}

print("Driver Age Distribution:")
for age_group, count in driver_ages.items():
    print(f"{age_group}: {count:,} ({count/total_drivers*100:.1f}%)")

plt.figure(figsize=(12, 6))
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8', '#CCCCCC']
plt.bar(driver_ages.keys(), driver_ages.values(), color=colors)
plt.title('Distribution of Drivers involved in accidents by Age Group', fontsize=14, fontweight='bold')
plt.xlabel('Age Group', fontsize=12)
plt.ylabel('Number of Drivers', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
numerical_age_buckets = [
    "num_drivers_under_18",
    "num_drivers_18_to_24",
    "num_drivers_25_to_49",
    "num_drivers_50_to_64",
    "num_drivers_65_plus",
    "num_drivers_age_unknown",
]

for col in numerical_age_buckets:
    filtered = df[df[col] > 0]

    if filtered.empty:
        continue # nothing to report

    numeric_age_stats = (
        filtered
        .groupby(col)["is_fatal"]
        .agg(
            accidents="count",
            fatal="sum",
            fatality_rate="mean"
        )
    )

    print(f"Stats for '{col}' column")
    print(numeric_age_stats)

    numeric_age_stats["fatality_rate"].plot(kind="bar")
    plt.title(f"Fatality Rate vs Number of Drivers ({col})")
    plt.ylabel("Fatality rate")
    plt.tight_layout()
    plt.show()

### Vehicle Type Analysis

In [None]:
vehicle_types = {
    'Car': df['num_vehicle_car'].sum(),
    'Motorbike': df['num_vehicle_motorbike'].sum(),
    'Van': df['num_vehicle_van'].sum(),
    'Bus': df['num_vehicle_bus'].sum(),
    'Bicycle': df['num_vehicle_bicycle'].sum(),
    'Pedestrian': df['num_vehicle_pedestrian'].sum(),
    'Unknown': df['num_vehicle_unknown'].sum()
}

total_vehicles = sum(vehicle_types.values())

print("Vehicle Type Distribution:")
for vehicle, count in sorted(vehicle_types.items(), key=lambda x: x[1], reverse=True):
    print(f"{vehicle}: {count} ({count/total_vehicles*100:.1f}%)")

plt.figure(figsize=(12, 6))
plt.bar(vehicle_types.keys(), vehicle_types.values(), color='mediumseagreen')
plt.title('Vehicle Type Distribution')
plt.xlabel('Vehicle Type')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
numerical_vehicle_buckets = [
    "num_vehicle_unknown",
    "num_vehicle_pedestrian",
    "num_vehicle_bicycle",
    "num_vehicle_motorbike",
    "num_vehicle_car",
    "num_vehicle_van",
    "num_vehicle_bus",
]

for col in numerical_vehicle_buckets:
    filtered = df[df[col] > 0]

    if filtered.empty:
        continue # nothing to report

    numeric_vehicle_stats = (
        filtered
        .groupby(col)["is_fatal"]
        .agg(
            accidents="count",
            fatal="sum",
            fatality_rate="mean"
        )
    )

    print(f"Stats: {col}")
    print(numeric_vehicle_stats)

    numeric_vehicle_stats["fatality_rate"].plot(kind="bar")
    plt.title(f"Fatality Rate vs Number of Vehicles ({col})")
    plt.ylabel("Fatality rate")
    plt.tight_layout()
    plt.show()

### Injuries Analysis

In [None]:
total_injured = df['total_injured'].sum()
print(f"Total people injured: {total_injured}")
print(f"Average injuries per accident: {df['total_injured'].mean():.2f}")

plt.figure(figsize=(12, 6))
df['total_injured'].value_counts().sort_index().plot(kind='bar', color='indianred')
plt.title('Distribution of Number of Injuries per Accident')
plt.xlabel('Number of Injured')
plt.ylabel('Number of Accidents')
plt.tight_layout()
plt.show()

### Weather Conditions Analysis

#### Simple Weather Condition Statistics

In [None]:
print("Weather Conditions Summary:")
print(f"Temperature: Min={df['temperature_min'].min():.1f}°C, Max={df['temperature_max'].max():.1f}°C, Avg={df['temperature_mean'].mean():.1f}°C")
print(f"Precipitation (sum): Min={df['precipitation_sum'].min():.1f}mm, Max={df['precipitation_sum'].max():.1f}mm, Avg={df['precipitation_sum'].mean():.1f}mm")
print(f"Wind speed (max): Min={df['windspeed_max'].min():.1f}km/h, Max={df['windspeed_max'].max():.1f}km/h, Avg={df['windspeed_max'].mean():.1f}km/h")
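
A compact alternative is `DataFrame.agg`, which reports the min, max and mean of several weather columns in one table. Note that it applies every statistic to every listed column (e.g. the min and max of `temperature_mean`, rather than of `temperature_min` and `temperature_max` as above). The values below are made up; on the real data the same column list would be applied to `df`.

```python
import pandas as pd

weather_cols = ["temperature_mean", "precipitation_sum", "windspeed_max"]

# made-up values for three accidents
toy = pd.DataFrame({
    "temperature_mean": [12.0, 18.5, 27.0],
    "precipitation_sum": [0.0, 4.2, 0.0],
    "windspeed_max": [15.0, 32.0, 20.0],
})

# one row per statistic, one column per weather variable
summary = toy[weather_cols].agg(["min", "max", "mean"]).round(1)
print(summary)
```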

#### Raining Statistics

In [None]:
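# sort_index() puts False before True, so the 'No Rain'/'Rain' pie labels used later match the wedges
rain_counts = df['is_raining'].value_counts().sort_index()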
print(f"\nAccidents during rain: {rain_counts.get(True, 0)} ({rain_counts.get(True, 0)/len(df)*100:.1f}%)")
print(f"Accidents without rain: {rain_counts.get(False, 0)} ({rain_counts.get(False, 0)/len(df)*100:.1f}%)")

#### Visualise Distributions

In [None]:
# Visualize weather impact
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Temperature distribution
axes[0, 0].hist(df['temperature_mean'], bins=30, color='orange', edgecolor='black')
axes[0, 0].set_title('Distribution of Mean Temperature')
axes[0, 0].set_xlabel('Temperature (°C)')
axes[0, 0].set_ylabel('Frequency')

# Precipitation distribution
axes[0, 1].hist(df['precipitation_sum'], bins=30, color='blue', edgecolor='black')
axes[0, 1].set_title('Distribution of Precipitation')
axes[0, 1].set_xlabel('Precipitation (mm)')
axes[0, 1].set_ylabel('Frequency')

# Wind speed distribution
axes[1, 0].hist(df['windspeed_max'], bins=30, color='green', edgecolor='black')
axes[1, 0].set_title('Distribution of Max Wind Speed')
axes[1, 0].set_xlabel('Wind Speed (km/h)')
axes[1, 0].set_ylabel('Frequency')

# Rain vs no rain
rain_counts.plot(kind='pie', ax=axes[1, 1], autopct='%1.1f%%', 
                 labels=['No Rain', 'Rain'], colors=['gold', 'skyblue'])
axes[1, 1].set_title('Accidents: Rain vs No Rain')
axes[1, 1].set_ylabel('')

plt.tight_layout()
plt.show()

#### Investigate temperature conditions between fatal vs. non-fatal

In [None]:
plt.figure(figsize=(8, 5))
sns.boxplot(x='is_fatal', y='temperature_mean', data=df)
plt.title('Boxplot of Mean Temperature by Fatality')
plt.xlabel('Is Fatal (0 = non-fatal, 1 = fatal)')
plt.ylabel('Mean Temperature (°C)')
plt.show()

#### Investigate wind conditions between fatal vs. non-fatal

In [None]:
plt.figure(figsize=(8, 5))
sns.boxplot(x='is_fatal', y='windspeed_max', data=df)
plt.title('Boxplot of Windspeed by Fatality')
plt.xlabel('Is Fatal (0 = non-fatal, 1 = fatal)')
plt.ylabel('Max Wind Speed (km/h)')
plt.show()

#### Investigate rain conditions between fatal vs. non-fatal

In [None]:
plt.figure(figsize=(6,4))
sns.countplot(data=df, x='is_raining', hue='is_fatal')
plt.title('Fatality vs. Is Raining')
plt.xlabel('Is Raining')
plt.ylabel('Count')
plt.show()

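Interestingly, rain appears to be inversely correlated with fatal accidents. One possible explanation is that drivers are more careful when it rains. Note, however, that Malta gets relatively little rain throughout the year, so few accidents happen in the rain at all. Because the chart shows raw counts, the comparison is more meaningful when expressed as the fatality rate within each group, and with so few rainy-day accidents any conclusion remains tentative.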
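
To compare rates rather than raw counts, `pd.crosstab` with `normalize="index"` returns row-wise proportions, i.e. the share of fatal and non-fatal accidents within each rain group. The toy data below is invented for illustration.

```python
import pandas as pd

# made-up data
toy = pd.DataFrame({
    "is_raining": [False, False, False, False, True, True],
    "is_fatal":   [1,     0,     0,     1,     0,    0],
})

# each row sums to 1: columns 0 and 1 are the non-fatal and fatal shares within that rain group
rates = pd.crosstab(toy["is_raining"], toy["is_fatal"], normalize="index")
print(rates)
```

On the real data this would be `pd.crosstab(df["is_raining"], df["is_fatal"], normalize="index")`; with few rainy accidents, the rain row will be noisy.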

### Traffic Level Analysis

In [None]:
traffic_level_stats = (
    df
    .groupby("traffic_level")["is_fatal"]
    .value_counts()
    .unstack(fill_value=0)
)

traffic_level_stats.columns = ["non_fatal", "fatal"]

traffic_level_stats["total"] = traffic_level_stats["non_fatal"] + traffic_level_stats["fatal"]
traffic_level_stats["fatal_rate"] = traffic_level_stats["fatal"] / traffic_level_stats["total"]

print(traffic_level_stats)

traffic_level_stats[["non_fatal", "fatal"]].plot(
    kind="bar",
    stacked=True
)

plt.ylabel("Number of Accidents")
plt.title("Fatal vs Non-fatal Accidents by Traffic Levels")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

### Special Days Analysis (Weekend, Holidays, School Days)

#### Weekend by Fatality

In [None]:
plt.figure(figsize=(6,4))
sns.countplot(data=df, x='is_weekend', hue='is_fatal')
plt.title('Fatality vs. Weekend')
plt.xlabel('Is Weekend')
plt.ylabel('Count')
plt.show()

#### Public Holiday by Fatality

In [None]:
plt.figure(figsize=(6,4))
sns.countplot(data=df, x='is_public_holiday_mt', hue='is_fatal')
plt.title('Fatality vs. Public Holiday')
plt.xlabel('Is Public Holiday')
plt.ylabel('Count')
plt.show()

#### School Holiday by Fatality

In [None]:
plt.figure(figsize=(6,4))
sns.countplot(data=df, x='is_school_holiday_mt', hue='is_fatal')
plt.title('Fatality vs. Is School Holiday')
plt.xlabel('Is School Holiday')
plt.ylabel('Count')
plt.show()

#### School Day by Fatality

In [None]:
plt.figure(figsize=(6,4))
sns.countplot(data=df, x='is_school_day_mt', hue='is_fatal')
plt.title('Fatality vs. Is School Day')
plt.xlabel('Is School Day')
plt.ylabel('Count')
plt.show()

#### Overall

In [None]:
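# sort_index() puts False/0 before True/1, so the positional pie labels below match the wedges
weekend_counts = df['is_weekend'].value_counts().sort_index()
holiday_counts = df['is_public_holiday_mt'].value_counts().sort_index()
school_holiday_counts = df['is_school_holiday_mt'].value_counts().sort_index()
school_day_counts = df['is_school_day_mt'].value_counts().sort_index()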

# Visualize special days
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Weekend
weekend_counts.plot(kind='pie', ax=axes[0, 0], autopct='%1.1f%%', 
                    labels=['Weekday', 'Weekend'], colors=['lightblue', 'salmon'])
axes[0, 0].set_title('Accidents: Weekend vs Weekday')
axes[0, 0].set_ylabel('')

# Public Holiday
holiday_counts.plot(kind='pie', ax=axes[0, 1], autopct='%1.1f%%', 
                    labels=['Not Holiday', 'Holiday'], colors=['lightgreen', 'orange'])
axes[0, 1].set_title('Accidents: Public Holiday vs Not')
axes[0, 1].set_ylabel('')

# School Holiday
school_holiday_counts.plot(kind='pie', ax=axes[1, 0], autopct='%1.1f%%', 
                           labels=['Not School Holiday', 'School Holiday'], colors=['yellow', 'purple'])
axes[1, 0].set_title('Accidents: School Holiday vs Not')
axes[1, 0].set_ylabel('')

# School Day
school_day_counts.plot(kind='pie', ax=axes[1, 1], autopct='%1.1f%%', 
                       labels=['Not School Day', 'School Day'], colors=['pink', 'teal'])
axes[1, 1].set_title('Accidents: School Day vs Not')
axes[1, 1].set_ylabel('')

plt.tight_layout()
plt.show()
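
Pie labels passed as a list are matched to wedges by position, so they are only correct when the counts are in the expected order. A hypothetical helper, `labelled_counts`, reindexes the counts to `[False, True]` and renames the index, so the labels travel with the data. The toy series is made up, and missing values are dropped before counting.

```python
import pandas as pd


def labelled_counts(flag: pd.Series, false_label: str, true_label: str) -> pd.Series:
    """Count a boolean (or 0/1) flag in False-then-True order, indexed by readable labels."""
    counts = flag.dropna().astype(bool).value_counts()
    counts = counts.reindex([False, True], fill_value=0)
    return counts.rename(index={False: false_label, True: true_label})


# made-up flag values: three weekend accidents, one weekday accident
toy_weekend = pd.Series([True, True, True, False])
weekend = labelled_counts(toy_weekend, "Weekday", "Weekend")
print(weekend)
```

A call such as `weekend.plot(kind="pie", autopct="%1.1f%%")` would then take its labels from the index.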

### Correlation Analysis

In [None]:
corr_df = df[[
    "num_drivers_under_18",
    "num_drivers_18_to_24",
    "num_drivers_25_to_49",
    "num_drivers_50_to_64",
    "num_drivers_65_plus",
    "num_drivers_age_unknown",
    "num_drivers_male",
    "num_drivers_female",
    "num_drivers_gender_unknown",
    "num_drivers_total",
    "num_vehicle_unknown",
    "num_vehicle_pedestrian",
    "num_vehicle_bicycle",
    "num_vehicle_motorbike",
    "num_vehicle_car",
    "num_vehicle_van",
    "num_vehicle_bus",
    "is_weekend",
    "is_public_holiday_mt",
    "is_school_holiday_mt",
    "is_school_day_mt",
    "temperature_max",
    "temperature_min",
    "temperature_mean",
    "windspeed_max",
    "precipitation_sum",
    "is_raining",
    "traffic_level",
    "total_injured",
    "is_fatal",
]].copy()  # independent copy, so re-encoding traffic_level below leaves df untouched

corr_df["traffic_level"] = df["traffic_level"].map(
    lambda x: 0 if x == 'LOW' else (1 if x == 'MODERATE' else 2)
)

print("DataFrame to check for correlation:")
corr_df

In [None]:
correlation_matrix = corr_df.corr()

plt.figure(figsize=(20, 16))
sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', center=0, 
            square=True, linewidths=0.5, cbar_kws={"shrink": 0.8})
plt.title('Correlation Heatmap of all Features', fontsize=16)
plt.tight_layout()
plt.show()
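
With 30 columns the heatmap is dense. Listing each feature's correlation with `is_fatal`, sorted by absolute value, is easier to read; for a 0/1 target, the Pearson coefficient computed by `DataFrame.corr` is the point-biserial correlation. The helper name `correlations_with_target` and the toy columns are hypothetical.

```python
import pandas as pd


def correlations_with_target(frame: pd.DataFrame, target: str) -> pd.Series:
    """Pearson correlation of every column with `target`, strongest (by absolute value) first."""
    corr = frame.corr()[target].drop(target)
    return corr.reindex(corr.abs().sort_values(ascending=False).index)


# made-up columns: c is perfectly anti-correlated, a strongly correlated, b uncorrelated
toy = pd.DataFrame({
    "is_fatal": [0, 0, 1, 1],
    "a": [1, 2, 3, 4],
    "b": [3, 1, 2, 2],
    "c": [1, 1, 0, 0],
})
ranked = correlations_with_target(toy, "is_fatal")
print(ranked)
```

On the real data, `correlations_with_target(corr_df, "is_fatal")` would list the features most associated with fatality first.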

## Persist DataFrame

In [None]:
print(f"\nSaving {len(df)} rows to {data_post_eda_csv}")
df.to_csv(data_post_eda_csv, index=False)