# Firearm Background Checks and Mortality Rates
## Dana Laufer

In [1]:
import pandas as pd
import numpy as np
import altair as alt

In [2]:
pd.set_option("display.max_columns", None)

In [3]:
# Load in the data
background_data = pd.read_csv("firearms.csv")
mortality_data = pd.read_csv('mortality.csv')

# Data Processing and Transforming

In [4]:
background_data.head()

Unnamed: 0,month,state,permit,permit_recheck,handgun,long_gun,other,multiple,admin,prepawn_handgun,prepawn_long_gun,prepawn_other,redemption_handgun,redemption_long_gun,redemption_other,returned_handgun,returned_long_gun,returned_other,rentals_handgun,rentals_long_gun,private_sale_handgun,private_sale_long_gun,private_sale_other,return_to_seller_handgun,return_to_seller_long_gun,return_to_seller_other,totals
0,2020-05,Alabama,28064.0,534.0,32394.0,13899.0,1468.0,1016,0.0,21.0,17.0,5.0,1853.0,1162.0,13.0,2.0,0.0,0.0,0.0,0.0,38.0,23.0,8.0,1.0,1.0,0.0,80519
1,2020-05,Alaska,55.0,218.0,4469.0,3196.0,429.0,219,0.0,1.0,2.0,0.0,131.0,99.0,0.0,35.0,14.0,0.0,0.0,0.0,8.0,5.0,2.0,0.0,0.0,0.0,8883
2,2020-05,Arizona,4737.0,577.0,28089.0,11334.0,2053.0,2402,0.0,9.0,8.0,3.0,1089.0,654.0,8.0,101.0,12.0,1.0,0.0,0.0,10.0,10.0,1.0,0.0,0.0,0.0,51098
3,2020-05,Arkansas,3868.0,603.0,11572.0,7005.0,528.0,453,4.0,17.0,12.0,1.0,824.0,995.0,2.0,0.0,0.0,0.0,0.0,0.0,14.0,15.0,4.0,0.0,0.0,0.0,25917
4,2020-05,California,23948.0,0.0,50811.0,28550.0,6499.0,0,0.0,0.0,0.0,0.0,575.0,437.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,110823


In [5]:
mortality_data.head()

Unnamed: 0,State,Year,Deaths,Population,Crude Rate
0,Alabama,1999.0,790,4430141,17.8
1,Alabama,2000.0,766,4447100,17.2
2,Alabama,2001.0,737,4467634,16.5
3,Alabama,2002.0,724,4480089,16.2
4,Alabama,2003.0,765,4503491,17.0


In [6]:
mortality_data.columns = mortality_data.columns.str.lower()
mortality_data.dropna(inplace = True)
mortality_data.loc[:, 'year'] = mortality_data.loc[:, 'year'].astype(int)
mortality_data.loc[:, 'year'] = mortality_data.loc[:, 'year'].astype(str)

In [7]:
background_data.fillna(0, inplace = True)
background_data.rename(columns = {"month": "date"}, inplace = True)

# Split the date into month and year
background_data[['year','month']] = background_data.date.str.split(pat = "-", expand=True)

# Reorder the columns
cols = [background_data.columns[-1]] + [col for col in background_data if col != background_data.columns[-1]]
cols = [cols[-1]] + [col for col in cols if col != cols[-1]]
background_data = background_data[cols]

# Estimate sales
background_data.loc[:, 'sales'] = background_data.loc[:, ['handgun', 'long_gun', 'other', 'multiple', 'multiple']].sum(axis = 1)

In [8]:
# Mortality data only has years, so transform background data 
background_data_per_year = background_data.groupby(['year', 'state'], as_index = False).sum()

# Mortality data only has data from 1999 to 2018
data_combined = background_data_per_year.query('year != "2020" & year != "2019" & year != "1998"')
data_combined = data_combined.merge(mortality_data)

data_combined_year = data_combined.groupby('year', as_index = False).sum()

In [9]:
# Make new variables with population parameter
data_combined.loc[:, 'total_rate'] = data_combined['totals']*100000/data_combined['population']
data_combined.loc[:, 'sale_rate'] = data_combined['sales']*100000/data_combined['population']
data_combined

Unnamed: 0,year,state,permit,permit_recheck,handgun,long_gun,other,multiple,admin,prepawn_handgun,prepawn_long_gun,prepawn_other,redemption_handgun,redemption_long_gun,redemption_other,returned_handgun,returned_long_gun,returned_other,rentals_handgun,rentals_long_gun,private_sale_handgun,private_sale_long_gun,private_sale_other,return_to_seller_handgun,return_to_seller_long_gun,return_to_seller_other,totals,sales,deaths,population,crude rate,total_rate,sale_rate
0,1999,Alabama,0.0,0.0,94544.0,149017.0,0.0,3195,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,246756,249951.0,790,4430141,17.8,5569.935584,5642.055185
1,1999,Alaska,2.0,0.0,14339.0,27790.0,0.0,942,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,43073,44013.0,88,624779,14.1,6894.117760,7044.570960
2,1999,Arizona,20503.0,0.0,78103.0,71365.0,0.0,3453,124.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,173548,156374.0,822,5023823,16.4,3454.500686,3112.649470
3,1999,Arkansas,4271.0,0.0,50523.0,126875.0,0.0,2946,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,184616,183290.0,388,2651860,14.6,6961.755145,6911.752506
4,1999,California,101132.0,0.0,371893.0,410119.0,0.0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,883144,782012.0,3054,33499204,9.1,2636.313388,2334.419648
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1015,2018,Virginia,18450.0,305.0,253744.0,168859.0,34835.0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,472.0,88.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,476760,457438.0,1035,8517685,12.2,5597.295509,5370.449835
1016,2018,Washington,177654.0,736.0,202826.0,148345.0,33596.0,9382,24.0,162.0,158.0,33.0,14738.0,12222.0,177.0,5619.0,1370.0,136.0,0.0,0.0,10377.0,8549.0,859.0,179.0,153.0,6.0,627301,403531.0,809,7535591,10.7,8324.509650,5355.001353
1017,2018,West Virginia,52481.0,0.0,75539.0,76310.0,4213.0,4892,84.0,155.0,174.0,1.0,12545.0,14418.0,21.0,131.0,10.0,17.0,0.0,0.0,368.0,284.0,28.0,2.0,4.0,1.0,241678,165846.0,343,1805832,19.0,13383.194007,9183.910796
1018,2018,Wisconsin,131991.0,2.0,155441.0,145652.0,12339.0,583,1.0,0.0,13.0,86.0,1645.0,3019.0,64.0,1103.0,335.0,26.0,0.0,0.0,0.0,220.0,0.0,0.0,0.0,0.0,452520,314598.0,598,5813568,10.3,7783.860101,5411.444400


In [10]:
# Make new variables divided by population
data_combined_year.loc[:, 'total_rate'] = data_combined_year['totals']*100000/data_combined_year['population']
data_combined_year.loc[:, 'sale_rate'] = data_combined_year['sales']*100000/data_combined_year['population']
data_combined_year

Unnamed: 0,year,permit,permit_recheck,handgun,long_gun,other,multiple,admin,prepawn_handgun,prepawn_long_gun,prepawn_other,redemption_handgun,redemption_long_gun,redemption_other,returned_handgun,returned_long_gun,returned_other,rentals_handgun,rentals_long_gun,private_sale_handgun,private_sale_long_gun,private_sale_other,return_to_seller_handgun,return_to_seller_long_gun,return_to_seller_other,totals,sales,deaths,population,crude rate,total_rate,sale_rate
0,1999,1037700.0,0.0,2532530.0,5216322.0,0.0,103669,148115.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9038336,7956190.0,28874,279040168,576.9,3239.080619,2851.270502
1,2000,1227814.0,0.0,2187598.0,4778762.0,0.0,95681,131648.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8421503,7157722.0,28663,281421906,560.2,2992.483108,2543.413234
2,2001,1408338.0,0.0,2161178.0,4941987.0,0.0,96984,100126.0,1274.0,3266.0,0.0,26999.0,72322.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8812474,7297133.0,29573,284968955,572.1,3092.432999,2560.676478
3,2002,1363211.0,0.0,1838245.0,4407867.0,0.0,92892,76776.0,5082.0,11080.0,0.0,171051.0,392377.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8358581,6431896.0,30242,287625193,590.4,2906.067063,2236.207452
4,2003,1403496.0,0.0,1844608.0,4381438.0,0.0,99034,69946.0,5146.0,8656.0,0.0,181927.0,399702.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8393953,6424114.0,30136,290107933,576.2,2893.389682,2214.387567
5,2004,1345672.0,0.0,1984078.0,4505326.0,0.0,101737,51559.0,3943.0,6459.0,0.0,182539.0,390427.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8571740,6692878.0,29569,292805298,555.2,2927.45386,2285.77763
6,2005,1350193.0,0.0,2234190.0,4582777.0,0.0,109549,13158.0,3151.0,5782.0,0.0,183116.0,375703.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8857619,7036065.0,30694,295516599,571.2,2997.333832,2380.937323
7,2006,2036569.0,0.0,2434092.0,4787879.0,0.0,129899,41792.0,3146.0,5917.0,0.0,186385.0,356847.0,0.0,467.0,14.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9983007,7481769.0,30896,298379912,561.0,3345.736961,2507.464041
8,2007,3077655.0,0.0,2628665.0,4568878.0,0.0,324193,27318.0,2901.0,5504.0,0.0,180790.0,336230.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,11152134,7845929.0,31224,301231207,562.8,3702.18415,2604.620244
9,2008,3697854.0,0.0,3325522.0,4906513.0,0.0,186891,13883.0,2330.0,4205.0,0.0,199205.0,348096.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,12684499,8605817.0,31593,304093966,577.9,4171.24324,2829.98611


# Exploratory Analysis

In [11]:
alt.Chart(data_combined).mark_boxplot().encode(
    x = alt.X('year', title = 'Year'),
    y = alt.Y('totals', title = 'Total Background Checks')
).properties(
    title = "Total Background Checks by Year"
).configure_title(fontSize = 18).configure_axis(
    titleFontSize = 16,
    labelFontSize = 14
)


In [12]:
alt.Chart(data_combined).mark_line().encode(
    x = alt.X('year', title = 'Year'),
    y = alt.Y('totals', title = 'Total Background Checks'),
    color = 'state'
).properties(
    title = "Total Background Checks by Year"
).configure_title(fontSize = 18).configure_axis(
    titleFontSize = 16,
    labelFontSize = 14
).configure_legend(
    labelFontSize = 8,
    titleOpacity = 0,
    labelLimit = 500
)


In [13]:
alt.Chart(data_combined).mark_line().encode(
    x = alt.X('year', title = 'Year'),
    y = alt.Y('sales', title = 'Estimated Firearm Sales'),
    color = 'state'
).properties(
    title = "Estimated Firearm Sales by Year and State"
).configure_title(fontSize = 18).configure_axis(
    titleFontSize = 16,
    labelFontSize = 14
).configure_legend(
    labelFontSize = 8,
    titleOpacity = 0,
    labelLimit = 500
)


In [14]:
alt.Chart(data_combined).mark_line().encode(
    x = alt.X('year', title = 'Year'),
    y = alt.Y('sale_rate', title = 'Firearm Sales per 100,000'),
    color = 'state'
).properties(
    title = "Estimated Firearm Sales by Year and State per 100,000"
).configure_title(fontSize = 18).configure_axis(
    titleFontSize = 16,
    labelFontSize = 14
).configure_legend(
    labelFontSize = 8,
    titleOpacity = 0,
    labelLimit = 500
)


In [15]:
chart1 = alt.Chart(data_combined_year).mark_line(color = '#d62728').encode(
    x = alt.X('year', title = 'Year'),
    y = alt.Y('deaths', title = "Number of Deaths from Firearms", axis = alt.Axis(titleColor = '#d62728'))
)

chart2 = alt.Chart(data_combined_year).mark_line(color = '#1f77b4', clip = True).encode(
    x = alt.X('year', title = 'Year'),
    y = alt.Y('sales', title = 'Firearm Sales', axis = alt.Axis(titleColor = '#1f77b4'))
)

chart1 + chart2


alt.layer(chart1, chart2).resolve_scale(y = 'independent').properties(
    title = "Estimated Firearm Sales and Firearm Deaths per Year",
    width = 500
).configure_title(fontSize = 18).configure_axis(
    titleFontSize = 16,
    labelFontSize = 14
)

In [16]:
chart1 = alt.Chart(data_combined_year).mark_line(color = '#d62728').encode(
    x = alt.X('year', title = 'Year'),
    y = alt.Y('crude rate', title = "Deaths from Firearms per 100,000", axis = alt.Axis(titleColor = '#d62728'))
)

chart2 = alt.Chart(data_combined_year).mark_line(color = '#1f77b4', clip = True).encode(
    x = alt.X('year', title = 'Year'),
    y = alt.Y('sale_rate', title = 'Estimated Firearm Sales per 100,000', axis = alt.Axis(titleColor = '#1f77b4'))
)

chart1 + chart2


alt.layer(chart1, chart2).resolve_scale(y = 'independent').properties(
    title = "Estimated Firearm Sales and Firearm Deaths per 100,000",
    width = 500
).configure_title(fontSize = 18).configure_axis(
    titleFontSize = 16,
    labelFontSize = 14
)

In [17]:
chart = alt.Chart(data_combined).mark_point().encode(
    x = 'sale_rate',
    y = 'crude rate'
)

chart + chart.transform_regression('sale_rate', 'crude rate').mark_line(color="red")

# Principal Component Analysis

In [18]:
# Create dataframe indexed by state, with columns for every year containing sale rates data 
data_states = data_combined.groupby(['year', 'state'], as_index = False).sum()
data_states.set_index('state', inplace = True)
data_states = data_states.pivot(columns = 'year', values = 'sale_rate')

In [19]:
# Center and align the data
means = np.mean(data_states, axis = 0)
std = np.std(data_states, axis = 0)
data_states_centered = data_states - means
data_states_aligned = data_states_centered/std
data_states_aligned

year,1999,2000,2001,2002,2003,2004,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018
state,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
Alabama,1.018127,0.896363,0.881964,0.82614,0.867361,0.845369,0.851475,0.766697,0.761717,0.965955,1.024525,0.939291,0.852036,0.971253,1.108842,1.13603,1.488297,0.185819,-0.362141,-0.412079
Alaska,1.701905,1.620964,1.541836,1.485907,1.458296,1.638599,1.869583,1.875373,2.062545,2.406928,2.220899,2.545652,2.270227,2.154305,2.228341,2.466039,2.068098,2.007253,2.029792,2.058471
Arizona,-0.215051,-0.297718,-0.341349,-0.400668,-0.529468,-0.459511,-0.325039,-0.301983,-0.451183,-0.521947,-0.563676,-0.50749,-0.443096,-0.419118,-0.510924,-0.507219,-0.48422,-0.418416,-0.36422,-0.364758
Arkansas,1.637151,1.45481,1.255748,0.987212,0.969777,0.921855,0.885277,0.781861,0.623604,0.506898,0.485598,0.378494,0.303895,0.24156,0.184728,0.141488,0.202274,0.105207,0.162801,0.153124
California,-0.594466,-0.687758,-0.734368,-0.908109,-1.162761,-1.140553,-1.058131,-1.09209,-0.615756,-1.163301,-1.161478,-1.111451,-1.083941,-1.107511,-1.067393,-0.943634,-1.079366,-0.789553,-1.080363,-1.083756
Colorado,1.301888,2.051773,2.387886,0.861357,0.666578,0.743406,0.78696,0.665391,0.990322,1.391058,1.737577,1.29031,1.404457,0.786501,0.973452,0.821072,0.886892,1.054083,1.139775,1.085886
Connecticut,-1.07938,-1.611313,-1.283761,-0.820537,-0.813348,-0.943675,-0.947948,-0.958458,-0.851302,-0.719908,-0.611109,-0.580331,-0.678747,-0.666963,-0.653156,-0.352135,-0.126586,-0.302004,-0.839965,-0.851878
Delaware,-0.613309,-0.632114,-0.577096,-0.497231,-0.540511,-0.591138,-0.598814,-0.705329,-0.762263,-0.741774,-0.676411,-0.726647,-0.703117,-0.645669,-0.46694,-0.15695,-0.12921,-0.080885,-0.158302,-0.158377
District of Columbia,-1.730786,-1.610015,-1.621878,-1.818127,-1.833125,-1.834284,-1.823765,-1.790648,-1.785156,-1.723233,-1.776259,-1.7724,-1.790066,-1.827231,-1.880799,-1.826596,-1.855384,-1.984118,-1.90233,-1.824746
Florida,-0.910764,-0.822228,-0.729326,-0.747561,-0.738026,-0.718341,-0.59269,-0.545472,-0.511313,-0.457214,-0.392815,-0.360016,-0.358475,-0.342101,-0.335838,-0.286146,-0.236746,-0.125973,-0.055352,-0.126991


In [20]:
# Perform SVD
u, s, vt = np.linalg.svd(data_states_aligned, full_matrices = False)
first_2_pcs = (data_states_aligned @ vt.T).loc[:, [0,1]]
first_2_pcs.columns = ['pc1', 'pc2']
first_2_pcs = first_2_pcs.reset_index(drop = True)

In [21]:
state_names = list(data_states.index)
first_2_pcs['state'] = state_names

alt.Chart(first_2_pcs).mark_text().encode(
    x = 'pc1',
    y = 'pc2',
    text = 'state'
).properties(width=450).interactive()

In [22]:
brush = alt.selection(type='interval')

labels = alt.Chart(first_2_pcs).mark_text().encode(
 x = alt.X('pc1', title = "First Principal Component"),
 y = alt.Y('pc2', title = "Second Principal Component"),
 text = 'state',
 color = alt.condition(brush, alt.value('blue'), alt.value('grey'))
).properties(width=450).add_selection(brush)

# Base chart for data tables
ranked_text = alt.Chart(first_2_pcs).mark_text().encode(
    y = alt.Y('row_number:O', axis = None)
).transform_window(
    row_number='row_number()'
).transform_filter(
    brush
).transform_window(
    rank='rank(row_number)'
).transform_filter(
    alt.datum.rank<20
)

# Data Tables
origin = ranked_text.encode(text = 'state:N').properties(title = 'State', width=100)
                                                         
# Build chart
alt.hconcat(
    labels,origin
).properties(
    title = "Firearm Sales per 100,000 Principal Component Analysis"
)

In [23]:
df = pd.DataFrame({'v':vt[0, :], 'Column names':list(data_states.columns)})

alt.Chart(df).mark_bar().encode(
    x = 'Column names',
    y = alt.Y('v', title = ""),
    opacity=alt.value(0.7)
).configure_axis(
    labelFontSize=12,
    titleFontSize=14
).configure_axisX(
    labelAngle = 0
).properties(title = "Scalar Multiplied by Each Column for First PCA", width = 700)


In [24]:
df = pd.DataFrame({'v':vt[1, :], 'Column names':list(data_states.columns)})
alt.Chart(df).mark_bar().encode(
 x='Column names',
 y='v',
 opacity=alt.value(0.7)
 ).configure_axis(
 labelFontSize=12,
 titleFontSize=14
 ).configure_axisX(
 labelAngle = 0
 ).properties(width = 700)


In [25]:
explained_var = pd.DataFrame({
 'PC #': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 'Fraction of Variance Explained' : s[range(0,10)]**2/sum(s**2)
})
# Draw your Altair visualization
alt.Chart(explained_var,
 title = "Variance Explained by Principal Components"
).mark_bar(size=30).encode(
 alt.X('PC #:O'),
 alt.Y('Fraction of Variance Explained:Q')
).configure_axisX(labelAngle=0).properties(width=500)


# Regression Model

In [26]:
X = data_combined[['sale_rate']]
Y = data_combined.loc[:, 'crude rate']

In [27]:
from sklearn import linear_model as lm
linear_model = lm.LinearRegression()

In [28]:
# Fit the linear model
linear_model.fit(X, Y)

# Generate predictions
Y_hat = linear_model.predict(X)

# Coefficients in the linear regression model
coefs = linear_model.coef_
intercept = linear_model.intercept_
print("Y intercept:", intercept, "\nSlope:", coefs)

Y intercept: 7.890055727489272 
Slope: [0.00095711]


In [29]:
regression = pd.DataFrame({'x': data_combined['sale_rate'], 'y': Y_hat})

chart1 = alt.Chart(data_combined).mark_point().encode(
    x = alt.X('sale_rate', title = 'Firearm Sales per 100,000'),
    y = alt.Y('crude rate', title = 'Firearm Deaths per 100,000')
)

chart2 = alt.Chart(regression).mark_line(color = 'red').encode(
    x = 'x',
    y = 'y'
)

alt.layer(chart1 + chart2).properties(
    title = "Estimated Firearm Sales and Firearm Deaths per 100,000",
    width = 500
).configure_title(fontSize = 18).configure_axis(
    titleFontSize = 16,
    labelFontSize = 14
)


In [30]:
residuals = pd.DataFrame({'True Deaths per 100,000': Y, 'Residuals': Y - Y_hat})

alt.Chart(residuals, title = "Residual Plot of Deaths by Firearm per 100,000").mark_point().encode(
    y = 'Residuals',
    x = 'True Deaths per 100,000'
).configure_axis(
 labelFontSize = 12,
 titleFontSize = 16
).configure_title(
 fontSize = 17
)

In [31]:
from sklearn.metrics import mean_squared_error as mse

rmse = np.sqrt(mse(Y, Y_hat))
print("Root mean squared error:", rmse)

Root mean squared error: 3.942136777472344


There's lots of sales rates very close to zero, so I'm going to get rid of these outliers and see if my model improves.

In [32]:
data_narrowed = data_combined.query('sale_rate > 150')

X_narrow = data_narrowed[['sale_rate']]
Y_narrow = data_narrowed.loc[:, 'crude rate']

data_combined

Unnamed: 0,year,state,permit,permit_recheck,handgun,long_gun,other,multiple,admin,prepawn_handgun,prepawn_long_gun,prepawn_other,redemption_handgun,redemption_long_gun,redemption_other,returned_handgun,returned_long_gun,returned_other,rentals_handgun,rentals_long_gun,private_sale_handgun,private_sale_long_gun,private_sale_other,return_to_seller_handgun,return_to_seller_long_gun,return_to_seller_other,totals,sales,deaths,population,crude rate,total_rate,sale_rate
0,1999,Alabama,0.0,0.0,94544.0,149017.0,0.0,3195,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,246756,249951.0,790,4430141,17.8,5569.935584,5642.055185
1,1999,Alaska,2.0,0.0,14339.0,27790.0,0.0,942,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,43073,44013.0,88,624779,14.1,6894.117760,7044.570960
2,1999,Arizona,20503.0,0.0,78103.0,71365.0,0.0,3453,124.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,173548,156374.0,822,5023823,16.4,3454.500686,3112.649470
3,1999,Arkansas,4271.0,0.0,50523.0,126875.0,0.0,2946,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,184616,183290.0,388,2651860,14.6,6961.755145,6911.752506
4,1999,California,101132.0,0.0,371893.0,410119.0,0.0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,883144,782012.0,3054,33499204,9.1,2636.313388,2334.419648
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1015,2018,Virginia,18450.0,305.0,253744.0,168859.0,34835.0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,472.0,88.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,476760,457438.0,1035,8517685,12.2,5597.295509,5370.449835
1016,2018,Washington,177654.0,736.0,202826.0,148345.0,33596.0,9382,24.0,162.0,158.0,33.0,14738.0,12222.0,177.0,5619.0,1370.0,136.0,0.0,0.0,10377.0,8549.0,859.0,179.0,153.0,6.0,627301,403531.0,809,7535591,10.7,8324.509650,5355.001353
1017,2018,West Virginia,52481.0,0.0,75539.0,76310.0,4213.0,4892,84.0,155.0,174.0,1.0,12545.0,14418.0,21.0,131.0,10.0,17.0,0.0,0.0,368.0,284.0,28.0,2.0,4.0,1.0,241678,165846.0,343,1805832,19.0,13383.194007,9183.910796
1018,2018,Wisconsin,131991.0,2.0,155441.0,145652.0,12339.0,583,1.0,0.0,13.0,86.0,1645.0,3019.0,64.0,1103.0,335.0,26.0,0.0,0.0,0.0,220.0,0.0,0.0,0.0,0.0,452520,314598.0,598,5813568,10.3,7783.860101,5411.444400


In [33]:
# Fit the linear model
linear_model.fit(X_narrow, Y_narrow)

# Generate predictions
Y_hat_narrow = linear_model.predict(X_narrow)

# Coefficients in the linear regression model
coefs_narrow = linear_model.coef_
intercept_narrow = linear_model.intercept_
print("Y intercept:", intercept_narrow, "\nSlope:", coefs_narrow)

Y intercept: 7.169716097053719 
Slope: [0.00108975]


In [34]:
regression_narrow = pd.DataFrame({'x': data_narrowed['sale_rate'], 'y': Y_hat_narrow})

chart1 = alt.Chart(data_narrowed).mark_point().encode(
    x = alt.X('sale_rate', title = 'Firearm Sales per 100,000'),
    y = alt.Y('crude rate', title = 'Firearm Deaths per 100,000')
)

chart2 = alt.Chart(regression_narrow).mark_line(color = 'red').encode(
    x = 'x',
    y = 'y'
)

alt.layer(chart1 + chart2).properties(
    title = "Estimated Firearm Sales and Firearm Deaths per 100,000",
    width = 500
).configure_title(fontSize = 18).configure_axis(
    titleFontSize = 16,
    labelFontSize = 14
)


In [35]:
residuals = pd.DataFrame({'True Deaths per 100,000': Y, 'Residuals': Y - Y_hat})

alt.Chart(residuals, title = "Residual Plot of Deaths by Firearm per 100,000").mark_point().encode(
    y = 'Residuals',
    x = 'True Deaths per 100,000'
).configure_axis(
 labelFontSize = 12,
 titleFontSize = 16
).configure_title(
 fontSize = 17
)

In [36]:
rmse_narrow = np.sqrt(mse(Y_narrow, Y_hat_narrow))
print("Root mean squared error:", rmse_narrow)

Root mean squared error: 3.3390103136704723
