### Note on what preprocessing should be done
Refer Inference and TO DO in the `01_initial_exploratory_analysis.ipynb` File.
1) Date format conversion
2) Age Column Cleaning
3) Removing unwanted Columns
4) Check for cleaning on `state`, `city_or_county` and `address`
5) Major cleaning reguired for the fields - `gun_stolen`, `gun_type`, `participant_age`, `participant_age_group`, `participant_gender`, `participant_status` and `participant_type`.
6) Clean Text Data - `incident_characterstics` and `notes`
7) Change data types too

And generate visualizations after cleaning too!

Leave encoding out!

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
# Importing packages
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
%matplotlib inline
plt.style.use('bmh')

In [5]:
# self created packages
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from scripts.visualizations import Visualization

In [6]:
# pyspark packages
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, sum, desc, explode, split, year, month, dayofweek, length, initcap, trim, lower, 
    regexp_extract, regexp_replace, max, explode, count, when)
from pyspark.sql.types import (
    StructType, StructField, IntegerType, StringType,
    FloatType, BooleanType, DateType, DoubleType)

### Setting Spark Session and Loading Data

In [8]:
spark = SparkSession.builder \
    .appName("MIS548 Project PreProcessing") \
    .getOrCreate()

spark

24/10/22 00:21:35 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [9]:
# creating data schema
ip_data_schema = StructType([
    StructField("incident_id", IntegerType(), True),
    StructField("date", DateType(), True),
    StructField("state", StringType(), True),
    StructField("city_or_county", StringType(), True),
    StructField("address", StringType(), True),
    StructField("n_killed", IntegerType(), True),
    StructField("n_injured", IntegerType(), True),
    StructField("incident_url", StringType(), True),
    StructField("source_url", StringType(), True),
    StructField("incident_url_fields_missing", BooleanType(), True),
    StructField("congressional_district", IntegerType(), True),
    StructField("gun_stolen", StringType(), True),
    StructField("gun_type", StringType(), True),
    StructField("incident_characteristics", StringType(), True),
    StructField("latitude", DoubleType(), True),
    StructField("location_description", StringType(), True),
    StructField("longitude", DoubleType(), True),
    StructField("n_guns_involved", IntegerType(), True),
    StructField("notes", StringType(), True),
    StructField("participant_age", StringType(), True),
    StructField("participant_age_group", StringType(), True),
    StructField("participant_gender", StringType(), True),
    StructField("participant_name", StringType(), True),
    StructField("participant_relationship", StringType(), True),
    StructField("participant_status", StringType(), True),
    StructField("participant_type", StringType(), True),
    StructField("sources", StringType(), True),
    StructField("state_house_district", IntegerType(), True),
    StructField("state_senate_district", IntegerType(), True)
])

In [10]:
ip_data = spark.read.option("header", "True") \
                .option("inferSchema", "True") \
                .option("quote", '"') \
                .option("escape", '"') \
                .option("sep", ",") \
                .option("ignoreLeadingWhiteSpace", "True") \
                .option("ignoreTrailingWhiteSpace", "True") \
                .option("multiLine", "True") \
                .option("mode", "PERMISSIVE") \
                .csv("../data/gun-violence-data_01-2013_03-2018.csv", schema = ip_data_schema)

In [11]:
print(f"Number of records in the data : {ip_data.count()}")
print(f"Number of columns: {len(ip_data.columns)}")

[Stage 0:>                                                          (0 + 1) / 1]

Number of records in the data : 239677
Number of columns: 29


                                                                                

In [12]:
ip_data.printSchema()

root
 |-- incident_id: integer (nullable = true)
 |-- date: date (nullable = true)
 |-- state: string (nullable = true)
 |-- city_or_county: string (nullable = true)
 |-- address: string (nullable = true)
 |-- n_killed: integer (nullable = true)
 |-- n_injured: integer (nullable = true)
 |-- incident_url: string (nullable = true)
 |-- source_url: string (nullable = true)
 |-- incident_url_fields_missing: boolean (nullable = true)
 |-- congressional_district: integer (nullable = true)
 |-- gun_stolen: string (nullable = true)
 |-- gun_type: string (nullable = true)
 |-- incident_characteristics: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- location_description: string (nullable = true)
 |-- longitude: double (nullable = true)
 |-- n_guns_involved: integer (nullable = true)
 |-- notes: string (nullable = true)
 |-- participant_age: string (nullable = true)
 |-- participant_age_group: string (nullable = true)
 |-- participant_gender: string (nullable = true)
 |-- 

### Preprocessing

#### Duplicate Check

In [15]:
def check_duplicates_except(df, column_to_exclude=""):
    """
    Check for duplicate rows in a DataFrame, excluding a specified column.

    Parameters:
    df (DataFrame): The input DataFrame to check for duplicates.
    column_to_exclude (str): The column to exclude from the duplicate check.

    Returns:
    DataFrame: A DataFrame containing the duplicate rows and their counts.
    """
    columns_to_check = [col for col in df.columns if col != column_to_exclude]
    
    df_duplicates = df.groupBy(columns_to_check).count().filter("count > 1")
    
    return df_duplicates

In [16]:
ip_data_dup_chk = check_duplicates_except(ip_data)

print(f"Number of Duplicate Rows: {ip_data_dup_chk.count()}")

24/10/22 00:21:38 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 3:>                                                          (0 + 1) / 1]

Number of Duplicate Rows: 0


                                                                                

In [17]:
# Drop duplicates if there are any
# ip_data = ip_data.dropDuplicates()

# print("DataFrame after dropping duplicates:")
# ip_data.count()

#### Null Values Check

In [19]:
def get_null_counts(df):
    """
    Get counts and percentages of null values in each column of a DataFrame.

    Parameters:
    df (DataFrame): The input DataFrame to analyze.

    Returns:
    DataFrame: A DataFrame containing column names, null counts, and null percentages.
    """
    total_rows = df.count()
    
    null_counts = df.select([sum(col(c).isNull().cast('int')).alias(c) for c in df.columns])

    narrow_null_counts = null_counts.selectExpr(
                                    f"'{null_counts.columns[0]}' as column_name",
                                    f"{null_counts.columns[0]} as null_count",
                                    f"({null_counts.columns[0]} / {total_rows} * 100) as null_percentage")

    for c in null_counts.columns[1:]:
        next_col = null_counts.selectExpr(f"'{c}' as column_name", 
                                          f"{c} as null_count",
                                          f"({c} / {total_rows} * 100) as null_percentage")
        narrow_null_counts = narrow_null_counts.union(next_col)
    
    narrow_null_counts = narrow_null_counts.orderBy(desc("null_count"))
    
    return narrow_null_counts

In [20]:
narrow_null_counts = get_null_counts(ip_data)
narrow_null_counts.show(n=29, truncate=False)

                                                                                

+---------------------------+----------+-------------------+
|column_name                |null_count|null_percentage    |
+---------------------------+----------+-------------------+
|participant_relationship   |223903    |93.4186425898188   |
|location_description       |197588    |82.43928286819344  |
|participant_name           |122253    |51.00739745574252  |
|gun_stolen                 |99498     |41.51337007722894  |
|gun_type                   |99451     |41.493760352474375 |
|n_guns_involved            |99451     |41.493760352474375 |
|participant_age            |92298     |38.509327136104005 |
|notes                      |81017     |33.80257596682201  |
|participant_age_group      |42119     |17.573233977394576 |
|state_house_district       |38772     |16.17677123795775  |
|participant_gender         |36362     |15.171251309053435 |
|state_senate_district      |32335     |13.49107340295481  |
|participant_status         |27626     |11.526345873821851 |
|participant_type       

There are some columns which do not add much significance to our analysis. We are dropping those out to aid in the processing  speed.

Might drop participant_age_group, state_house_district, state_senate_district, participant_name Later

In [22]:
trivial_columns = ["participant_relationship", "location_description", "sources", "source_url", 
                   "incident_url", "incident_url_fields_missing", "participant_name", "state_house_district",
                   "state_senate_district"]

In [23]:
ip_data = ip_data.drop(*trivial_columns)

For missing data, our plan is to impute the data. But some models such as Decision Trees, Random Forest, XGBoost do cater for missing data.

My plan is the use different sets of data and verify the performance. Let's see how it goes. So I would do the imputation after all the necessary preprocessing is done.

#### New Date Features

In [26]:
ip_data = ip_data.withColumn("year", year("date")) \
                .withColumn("month", month("date")) \
                .withColumn("day_of_week", dayofweek("date"))

In [27]:
ip_data.select("date", "year", "month", "day_of_week").show(5)

+----------+----+-----+-----------+
|      date|year|month|day_of_week|
+----------+----+-----+-----------+
|2013-01-01|2013|    1|          3|
|2013-01-01|2013|    1|          3|
|2013-01-01|2013|    1|          3|
|2013-01-05|2013|    1|          7|
|2013-01-07|2013|    1|          2|
+----------+----+-----+-----------+
only showing top 5 rows



#### Text Columms

First I will focus on the columns such as `state`, `city_or_county` and `address`.

In [29]:
# checking if there are any abbreviated state names in the data

abbreviated_states = ip_data.filter(length("state") == 2)
abbreviated_count = abbreviated_states.count()
print(f"Number of abbreviated state entries: {abbreviated_count}")

Number of abbreviated state entries: 0


                                                                                

In [30]:
def count_special_characters(ip_data, columns):
    """
    Check for special characters in specified columns and count the occurrences.
    
    Args:
        ip_data (DataFrame): Input DataFrame.
        columns (list): List of column names to check for special characters.

    Returns:
        DataFrame: A DataFrame with counts of special characters for each specified column.
    """
    special_char_pattern = r"[^a-zA-Z0-9\s,'_()-]"

    ip_data_with_special_chars = ip_data.select(
        *columns,
        *[
            (regexp_extract(col_name, special_char_pattern, 0) != "").alias(f"{col_name}_has_special_chars")
            for col_name in columns
        ]
    )

    count_true_values = ip_data_with_special_chars.agg(
        *[
            sum(col(f"{col_name}_has_special_chars").cast("int")).alias(f"count_{col_name}_special_chars")
            for col_name in columns
        ]
    )

    return count_true_values

In [31]:
columns_to_check = ["state", "city_or_county", "address"]
count_result = count_special_characters(ip_data, columns_to_check)

count_result.show()

[Stage 136:>                                                        (0 + 1) / 1]

+-------------------------+----------------------------------+---------------------------+
|count_state_special_chars|count_city_or_county_special_chars|count_address_special_chars|
+-------------------------+----------------------------------+---------------------------+
|                        0|                                58|                      15483|
+-------------------------+----------------------------------+---------------------------+



                                                                                

In [32]:
# checking county count in the data
county_count = ip_data.filter(lower(col("city_or_county")).contains("county")).count()

print(f"Count of entries with 'county' in 'city_or_county': {county_count}")

Count of entries with 'county' in 'city_or_county': 6331


In [33]:
ip_data = ip_data.withColumn("city_or_county", trim(col("city_or_county")))

Finalized transformation related to these columns are:
1) `state` : Making sure to trim extra spaces and also capitalizing the first letter of each word.
2) `city_or_state` : Making sure to trim extra spaces, replacing the special characters.
3) `address` : Making sure to trim extra spaces, replacing the special characters. And mapping `Street` to `St` an other common abbreviations.

In [35]:
street_type_mapping = {
    "Street": "St",
    "Avenue": "Ave",
    "Road": "Rd",
    "Boulevard": "Blvd",
    "Lane": "Ln",
    "Drive": "Dr",
    "Circle": "Cir",
    "Court": "Ct",
    "Terrace": "Ter",
    "Place": "Pl",
    "Highway": "Hwy",
}

def replace_street_types(address):
    """
    Replace full street type names in an address with their abbreviations.

    Parameters:
    address (str): The address string to modify.

    Returns:
    str: The modified address with street types replaced by abbreviations.
    """
    for full, abbr in street_type_mapping.items():
        address = regexp_replace(address, f"\\b{full}\\b", abbr)
    return address


In [36]:
ip_data = ip_data \
    .withColumn("state", initcap(trim(col("state")))) \
    .withColumn("city_or_county", regexp_replace(trim(col("city_or_county")), "[^a-zA-Z0-9\s,'_()-]", "")) \
    .withColumn("address", trim(regexp_replace(col("address"), "[^a-zA-Z0-9\s,'_()-]", ""))) \
    .withColumn("address", replace_street_types(col("address")))

#### Cleaning the wrongly formatted Data

Major cleaning reguired for the fields - `gun_stolen`, `gun_type`, `participant_age`, `participant_age_group`, `participant_gender`, `participant_status` and `participant_type`

These columns are having data in the form of `0::val1||1::val2`.

Need to figure out a way to handle this type of data!

First I will get the maximum number of `||` present in these columns to get an idea of how many values are present.

In [38]:
columns_to_check = [
    "gun_stolen", "gun_type", "participant_age", "participant_age_group", 
    "participant_gender", "participant_status", "participant_type"
]

In [39]:
def max_delimiters_count(df, col_name):
    """
    Calculate the maximum count of delimiters (specifically '||') in a specified column.

    Parameters:
    df (DataFrame): The input DataFrame to analyze.
    col_name (str): The name of the column to count delimiters in.

    Returns:
    int: The maximum count of delimiters found in the specified column.
    """
    delimiter_count_col = (length(col(col_name)) - length(regexp_replace(col(col_name), r"\|\|", "")))
    max_count = df.select(delimiter_count_col.alias(f"{col_name}_delimiter_count")) \
                  .agg(max(f"{col_name}_delimiter_count")).collect()[0][0]
    
    return max_count

In [40]:
for col_name in columns_to_check:
    max_count = max_delimiters_count(ip_data, col_name)
    print(f"Max number of '||' in {col_name}: {int(max_count)}")

Max number of '||' in gun_stolen: 798
Max number of '||' in gun_type: 798
Max number of '||' in participant_age: 130
Max number of '||' in participant_age_group: 204
Max number of '||' in participant_gender: 154
Max number of '||' in participant_status: 204
Max number of '||' in participant_type: 204


                                                                                

In [41]:
def max_delimiters_data(df, col_name):
    """
    Retrieve rows with the maximum count of delimiters (specifically '||') in a specified column.

    Parameters:
    df (DataFrame): The input DataFrame to analyze.
    col_name (str): The name of the column to count delimiters in.

    Returns:
    DataFrame: A DataFrame containing rows with the maximum delimiter count, including 'incident_id', 
                the specified column, and the delimiter count.
    """
    max_count = max_delimiters_count(df, col_name)
    
    result = df.withColumn(f"{col_name}_delimiter_count", 
                           (length(col(col_name)) - length(regexp_replace(col(col_name), r"\|\|", "")))) \
                .filter(col(f"{col_name}_delimiter_count") == max_count) \
                .select("incident_id", col_name, f"{col_name}_delimiter_count")
    
    return result

In [42]:
for col_name in columns_to_check:
    result_df = max_delimiters_data(ip_data, col_name)
    print(f"Incident IDs for max '||' in {col_name}:")
    result_df.show(1) 

Incident IDs for max '||' in gun_stolen:
+-----------+--------------------+--------------------------+
|incident_id|          gun_stolen|gun_stolen_delimiter_count|
+-----------+--------------------+--------------------------+
|     338106|0::Unknown||1::Un...|                       798|
+-----------+--------------------+--------------------------+
only showing top 1 row



                                                                                

Incident IDs for max '||' in gun_type:
+-----------+--------------------+------------------------+
|incident_id|            gun_type|gun_type_delimiter_count|
+-----------+--------------------+------------------------+
|     338106|0::Unknown||1::Un...|                     798|
+-----------+--------------------+------------------------+
only showing top 1 row

Incident IDs for max '||' in participant_age:
+-----------+--------------------+-------------------------------+
|incident_id|     participant_age|participant_age_delimiter_count|
+-----------+--------------------+-------------------------------+
|     577157|0::34||1::23||2::...|                            130|
+-----------+--------------------+-------------------------------+

Incident IDs for max '||' in participant_age_group:


                                                                                

+-----------+---------------------+-------------------------------------+
|incident_id|participant_age_group|participant_age_group_delimiter_count|
+-----------+---------------------+-------------------------------------+
|     577157| 0::Adult 18+||1::...|                                  204|
+-----------+---------------------+-------------------------------------+



                                                                                

Incident IDs for max '||' in participant_gender:
+-----------+--------------------+----------------------------------+
|incident_id|  participant_gender|participant_gender_delimiter_count|
+-----------+--------------------+----------------------------------+
|     577157|0::Male||1::Male|...|                               154|
+-----------+--------------------+----------------------------------+

Incident IDs for max '||' in participant_status:


                                                                                

+-----------+--------------------+----------------------------------+
|incident_id|  participant_status|participant_status_delimiter_count|
+-----------+--------------------+----------------------------------+
|     577157|0::Killed||1::Kil...|                               204|
+-----------+--------------------+----------------------------------+

Incident IDs for max '||' in participant_type:
+-----------+--------------------+--------------------------------+
|incident_id|    participant_type|participant_type_delimiter_count|
+-----------+--------------------+--------------------------------+
|     577157|0::Victim||1::Vic...|                             204|
+-----------+--------------------+--------------------------------+



After looking at the `gun_stolen` and `gun_type` are having `Unknown` in their data rather than any meaning full values.
As far as rest of the columns go most of them are categorical data so we can creat columns for those and have a count of those values too apart from `participant_age` which we need to figure out a way to store.

And the data of `participant_` as prefix are related with theri indexing of the values I guess. So we will get the `n` values only where `n` is the lowest count of any of these columns.

##### Handle Unknown Values

We can remove these or replace with empty strings for all of these columns.

In [45]:
def clean_unknown_values(df, columns):
    """
    Clean 'n::Unknown', 'n:Unknown', and mixed patterns from specified columns,
    while preserving valid data.

    Parameters:
    df (DataFrame): The input DataFrame to clean.
    columns (list): List of column names to clean.

    Returns:
    DataFrame: The cleaned DataFrame with unknown patterns removed.
    """
    unknown_pattern = r"(\d+[:]{1,2}Unknown)(\|+)?"

    for col_name in columns:
        df = df.withColumn(
            col_name,
            regexp_replace(col(col_name), unknown_pattern, "")
        )

        df = df.withColumn(
            col_name,
            regexp_replace(col(col_name), r"\|\|+", "||")
        )

        df = df.withColumn(
            col_name,
            trim(regexp_replace(col(col_name), r"^\|\||\|\|$", ""))
        )

    return df

In [46]:
ip_data = clean_unknown_values(ip_data, columns_to_check)

In [47]:
for col_name in columns_to_check:
    max_count = max_delimiters_count(ip_data, col_name)
    print(f"Max number of '||' in {col_name}: {int(max_count)}")

Max number of '||' in gun_stolen: 138


                                                                                

Max number of '||' in gun_type: 170


                                                                                

Max number of '||' in participant_age: 130


                                                                                

Max number of '||' in participant_age_group: 204


                                                                                

Max number of '||' in participant_gender: 154


                                                                                

Max number of '||' in participant_status: 204


[Stage 209:>                                                        (0 + 1) / 1]

Max number of '||' in participant_type: 204


                                                                                

In [48]:
for col_name in columns_to_check:
    result_df = max_delimiters_data(ip_data, col_name)
    print(f"Incident IDs for max '||' in {col_name}:")
    result_df.show(1) 

Incident IDs for max '||' in gun_stolen:


                                                                                

+-----------+--------------------+--------------------------+
|incident_id|          gun_stolen|gun_stolen_delimiter_count|
+-----------+--------------------+--------------------------+
|     366742|0::Stolen||1::Sto...|                       138|
+-----------+--------------------+--------------------------+



                                                                                

Incident IDs for max '||' in gun_type:


                                                                                

+-----------+--------------------+------------------------+
|incident_id|            gun_type|gun_type_delimiter_count|
+-----------+--------------------+------------------------+
|     623687|0::Handgun||1::Ha...|                     170|
+-----------+--------------------+------------------------+

Incident IDs for max '||' in participant_age:


                                                                                

+-----------+--------------------+-------------------------------+
|incident_id|     participant_age|participant_age_delimiter_count|
+-----------+--------------------+-------------------------------+
|     577157|0::34||1::23||2::...|                            130|
+-----------+--------------------+-------------------------------+



                                                                                

Incident IDs for max '||' in participant_age_group:


                                                                                

+-----------+---------------------+-------------------------------------+
|incident_id|participant_age_group|participant_age_group_delimiter_count|
+-----------+---------------------+-------------------------------------+
|     577157| 0::Adult 18+||1::...|                                  204|
+-----------+---------------------+-------------------------------------+



                                                                                

Incident IDs for max '||' in participant_gender:


                                                                                

+-----------+--------------------+----------------------------------+
|incident_id|  participant_gender|participant_gender_delimiter_count|
+-----------+--------------------+----------------------------------+
|     577157|0::Male||1::Male|...|                               154|
+-----------+--------------------+----------------------------------+



                                                                                

Incident IDs for max '||' in participant_status:


                                                                                

+-----------+--------------------+----------------------------------+
|incident_id|  participant_status|participant_status_delimiter_count|
+-----------+--------------------+----------------------------------+
|     577157|0::Killed||1::Kil...|                               204|
+-----------+--------------------+----------------------------------+



                                                                                

Incident IDs for max '||' in participant_type:


[Stage 239:>                                                        (0 + 1) / 1]

+-----------+--------------------+--------------------------------+
|incident_id|    participant_type|participant_type_delimiter_count|
+-----------+--------------------+--------------------------------+
|     577157|0::Victim||1::Vic...|                             204|
+-----------+--------------------+--------------------------------+



                                                                                

In [49]:
def clean_and_encode_cols(df, col_name):
    """
    Cleans and performs one-hot frequency encoding on the specified column.

    Parameters:
    df (DataFrame): Input DataFrame with a delimited column.
    col_name (str): Name of the column to clean and encode.

    Returns:
    DataFrame: DataFrame with one-hot frequency encoded columns.
    """
    pivot_df = ip_data.withColumn("value", explode(split(col(col_name), r"\|{1,2}"))) \
                        .withColumn("value", regexp_replace(col("value"),  r"(:|::)", "")) \
                        .withColumn("value", regexp_replace(col("value"), r"\d+", "")) \
                        .withColumn("value", regexp_replace(col("value"), r"[\[\]{}()]", "")) \
                        .withColumn("value", regexp_replace(col("value"), r"\s*[-_.]\s*", " ")) \
                        .withColumn("value", when(col("value") == "", "Unknown").otherwise(col("value"))) \
                        .withColumn("value", when(col("value") == "Other", "Unknown").otherwise(col("value"))) \
                        .withColumn("value", lower(trim(col("value")))) \
                        .withColumn("value", regexp_replace(col("value"), r"\s+", "_")) \
                        .groupBy("incident_id", "value").agg(count("value").alias("frequency")) \
                        .groupBy("incident_id").pivot("value").agg(sum("frequency")).fillna(0)

    for pivot_col in pivot_df.columns[1:]:
        pivot_df = pivot_df.withColumnRenamed(pivot_col, f"{col_name}_{pivot_col}_freq")

    return pivot_df

In [50]:
guns_info = ["gun_stolen", "gun_type"]

final_data = ip_data

for col_name in guns_info:
    guns_one_hot_df = clean_and_encode_cols(final_data, col_name)
    final_data = final_data.join(guns_one_hot_df, on="incident_id", how="left")

                                                                                

For participants data columns we are not going to to for `participant_age` as it has so many values and since we have `participant_age_group` we will use that.

In [52]:
def clean_and_encode_particpant_cols(df, col_name):
    """
    Cleans and performs one-hot frequency encoding on the specified column.

    Parameters:
    df (DataFrame): Input DataFrame with a delimited column.
    col_name (str): Name of the column to clean and encode.

    Returns:
    DataFrame: DataFrame with one-hot frequency encoded columns.
    """
    pivot_df = ip_data.withColumn("value", explode(split(col(col_name), r"\|{1,2}"))) \
                        .withColumn("value", regexp_replace(col("value"),  r"(:|::)", "")) \
                        .withColumn("value", regexp_replace(col("value"), r"\+", "plus")) \
                        .withColumn("value", regexp_replace(col("value"), r"^\d+", "")) \
                        .withColumn("value", when(col("value") == "", "Unknown").otherwise(col("value"))) \
                        .withColumn("value", when(col("value") == "Other", "Unknown").otherwise(col("value"))) \
                        .withColumn("value", lower(trim(col("value")))) \
                        .withColumn("value", regexp_replace(col("value"), r"[\s-]+", "_")) \
                        .groupBy("incident_id", "value").agg(count("value").alias("frequency")) \
                        .groupBy("incident_id").pivot("value").agg(sum("frequency")).fillna(0)

    for pivot_col in pivot_df.columns[1:]:
        pivot_df = pivot_df.withColumnRenamed(pivot_col, f"{col_name}_{pivot_col}_freq")

    return pivot_df

In [53]:
participant_cols = ["participant_age_group", "participant_gender", "participant_status", "participant_type"]

In [54]:
cleaned_ip_data = final_data

for col_name in participant_cols:
    participants_one_hot_df = clean_and_encode_particpant_cols(cleaned_ip_data, col_name)
    cleaned_ip_data = cleaned_ip_data.join(participants_one_hot_df, on="incident_id", how="left")

                                                                                

Aggregating `participant_gender_male,_female_freq` to both male and female freq then dropping it, since only one occurance of it is there.

And combining and optimizing the `participant_status` columns by aggregating the counts based on the relationships. The goal is to simplify the representation of participant statuses while ensuring that the totals remain accurate.

For each main category, aggregate counts from the relevant columns into a single count. For example, if we have columns for `injured,_arrested`, you would add those counts to both `injured_freq` and `arrested_freq`.

And then dropping the unwanted columns

In [56]:
cleaned_ip_data = cleaned_ip_data.withColumn("participant_status_arrested_freq", 
                                             col("participant_status_arrested_freq") +
                                             col("participant_status_injured,_arrested_freq") +
                                             col("participant_status_killed,_arrested_freq") +
                                             col("participant_status_unharmed,_arrested_freq")) \
                                .withColumn("participant_status_injured_freq", 
                                            col("participant_status_injured_freq") +
                                            col("participant_status_injured,_arrested_freq") +
                                            col("participant_status_killed,_injured_freq") +
                                            col("participant_status_injured,_unharmed_freq") +
                                            col("participant_status_injured,_unharmed,_arrested_freq")) \
                                .withColumn("participant_status_killed_freq", 
                                            col("participant_status_killed_freq") +
                                            col("participant_status_killed,_injured_freq") +
                                            col("participant_status_killed,_unharmed_freq") +
                                            col("participant_status_killed,_arrested_freq") +
                                            col("participant_status_killed,_unharmed,_arrested_freq")) \
                                .withColumn("participant_status_unharmed_freq", 
                                            col("participant_status_unharmed_freq") +
                                            col("participant_status_injured,_unharmed_freq") +
                                            col("participant_status_killed,_unharmed_freq") +
                                            col("participant_status_unharmed,_arrested_freq") +
                                            col("participant_status_killed,_unharmed,_arrested_freq")) \
                                .withColumn("participant_gender_female_freq", 
                                            col("participant_gender_female_freq") +
                                            col("participant_gender_male,_female_freq")) \
                                .withColumn("participant_gender_male_freq", 
                                            col("participant_gender_male_freq") +
                                            col("participant_gender_male,_female_freq"))

In [57]:
drop_cols = ["participant_status_injured,_arrested_freq", "participant_status_killed,_arrested_freq",
             "participant_status_unharmed,_arrested_freq", "participant_status_killed,_injured_freq",
             "participant_status_injured,_unharmed_freq", "participant_status_injured,_unharmed,_arrested_freq",
             "participant_status_killed,_unharmed_freq", "participant_status_killed,_arrested_freq",
             "participant_gender_male,_female_freq", "participant_age_group", "participant_gender",
             "participant_status", "participant_type", "participant_status_killed,_unharmed,_arrested_freq"]

In [58]:
cleaned_ip_data = cleaned_ip_data.drop(*drop_cols)

#### Cleaning Text Columns

We have `incident_characteristics` and `notes`

Starting with `incident_characteristics` then will implement similar nlp pre-processing tasks on both `incident_characteristics` and `notes`

In [60]:
def clean_txt_data(df, column_names):
    """
    Clean the specified column of the DataFrame by replacing certain patterns and normalizing the text.

    Parameters:
    df (DataFrame): The input DataFrame.
    column_name (str): The name of the column to clean.

    Returns:
    DataFrame: The DataFrame with the cleaned column.
    """
    for column_name in column_names:
        df = df.withColumn(column_name,
                           regexp_replace(col(column_name), r"\|{1,2}", "; ")) \
               .withColumn(column_name,
                           regexp_replace(col(column_name), r"/", " ")) \
               .withColumn(column_name,
                           regexp_replace(col(column_name), r"[^\w\s;]", "")) \
               .withColumn(column_name,
                           lower(trim(regexp_replace(col(column_name), r"\s{2,}", " "))))
    
    return df

In [61]:
columns_to_clean = ["notes", "incident_characteristics"]
cleaned_ip_data = clean_txt_data(cleaned_ip_data, columns_to_clean)

In [62]:
cleaned_ip_data = cleaned_ip_data.withColumn("notes",
                                             regexp_replace(col("notes"), r'[\d\s\.]+$', '')) \
                                .withColumn("notes",
                                            regexp_replace(col("notes"), r'\s*;\s*$', '')) \
                                .withColumn("notes",
                                            regexp_replace(col("notes"), r'\byr\b', 'year')) \
                                .withColumn("notes",
                                            regexp_replace(col("notes"), r'\binured\b', 'injured')) \
                                .withColumn("notes",
                                            regexp_replace(col("notes"), r'\s+', ' ')) \
                                .withColumn("notes", trim(col("notes")))

In [63]:
cleaned_ip_data.printSchema()

root
 |-- incident_id: integer (nullable = true)
 |-- date: date (nullable = true)
 |-- state: string (nullable = true)
 |-- city_or_county: string (nullable = true)
 |-- address: string (nullable = true)
 |-- n_killed: integer (nullable = true)
 |-- n_injured: integer (nullable = true)
 |-- congressional_district: integer (nullable = true)
 |-- gun_stolen: string (nullable = true)
 |-- gun_type: string (nullable = true)
 |-- incident_characteristics: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- n_guns_involved: integer (nullable = true)
 |-- notes: string (nullable = true)
 |-- participant_age: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- gun_stolen_not_stolen_freq: long (nullable = true)
 |-- gun_stolen_stolen_freq: long (nullable = true)
 |-- gun_stolen_unknown_freq: long (nullable = true)
 |-- gun_type_ak_freq: lon

#### Handling `NULL` Values

In [66]:
# narrow_null_counts = get_null_counts(cleaned_ip_data)
# narrow_null_counts.show(n=50, truncate=False)