## Gathering state-level features

The purpose of this section is to gather state-level features that may affect the degree to which a given state is suspectible or resistant to a virus such as the flu or Covid-19. Collecting these state-level characteristics can help us identify which features are responsible for the correlation in viral infection rates between states, and thus can also be used to quantify the correlation between states based on fundamental attributes of the states rather than just the raw wILI time series. 

In [2]:
import json
import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)

The density of a state is a natural feature to include because the denser a location, the more easily a virus can spread (look no further than NYC right now). However, it wouldn't make sense to report the density of a state because, for example, the high population density in Manhattan shouldn't be influenced by the fact that upstate New York State has a massive amount of scarsely populated land. Instead, a more sensible measure is a weighted average of the densities of each county in a given state, where the weights are the fraction of the state population that lives in the given county.

In [3]:
# dataset that reports the land area in square miles of each county in the U.S.
land_df = pd.read_csv('land_area.csv')

# dataset that reports the population of each county in the U.S.
popn_df = pd.read_csv('population.csv')

In [4]:
land_df.head()

Unnamed: 0,Areaname,STCOU,LND010190F,LND010190D,LND010190N1,LND010190N2,LND010200F,LND010200D,LND010200N1,LND010200N2,LND110180F,LND110180D,LND110180N1,LND110180N2,LND110190F,LND110190D,LND110190N1,LND110190N2,LND110200F,LND110200D,LND110200N1,LND110200N2,LND110210F,LND110210D,LND110210N1,LND110210N2,LND210190F,LND210190D,LND210190N1,LND210190N2,LND210200F,LND210200D,LND210200N1,LND210200N2
0,UNITED STATES,0,0,3787425.08,0,0,0,3794083.06,0,0,0,3539289.16,0,0,0,3536341.73,0,0,0,3537438.44,0,0,0,3531905.43,0,0,0,251083.35,0,0,0,256644.62,0,0
1,ALABAMA,1000,0,52422.94,0,0,0,52419.02,0,0,0,50767.18,0,0,0,50750.23,0,0,0,50744.0,0,0,0,50645.33,0,0,0,1672.71,0,0,0,1675.01,0,0
2,"Autauga, AL",1001,0,604.49,0,0,0,604.45,0,0,0,597.04,0,0,0,596.01,0,0,0,595.97,0,0,0,594.44,0,0,0,8.48,0,0,0,8.48,0,0
3,"Baldwin, AL",1003,0,2027.08,0,0,0,2026.93,0,0,0,1589.42,0,0,0,1596.53,0,0,0,1596.35,0,0,0,1589.78,0,0,0,430.55,0,0,0,430.58,0,0
4,"Barbour, AL",1005,0,904.59,0,0,0,904.52,0,0,0,883.89,0,0,0,885.0,0,0,0,884.9,0,0,0,884.88,0,0,0,19.59,0,0,0,19.61,0,0


In [5]:
popn_df.head()

Unnamed: 0,Areaname,STCOU,PST045200F,PST045200D,PST045200N1,PST045200N2,PST045201F,PST045201D,PST045201N1,PST045201N2,PST045202F,PST045202D,PST045202N1,PST045202N2,PST045203F,PST045203D,PST045203N1,PST045203N2,PST045204F,PST045204D,PST045204N1,PST045204N2,PST045205F,PST045205D,PST045205N1,PST045205N2,PST045206F,PST045206D,PST045206N1,PST045206N2,PST045207F,PST045207D,PST045207N1,PST045207N2,PST045208F,PST045208D,PST045208N1,PST045208N2,PST045209F,PST045209D,PST045209N1,PST045209N2
0,UNITED STATES,0,0,282171957,0,0,0,285081556,0,0,0,287803914,0,0,0,290326418,0,0,0,293045739,0,0,0,295753151,0,0,0,298593212,0,0,0,301579895,0,0,0,304374846,0,0,0,307006550,0,0
1,ALABAMA,1000,0,4451849,0,0,0,4464034,0,0,0,4472420,0,0,0,4490591,0,0,0,4512190,0,0,0,4545049,0,0,0,4597688,0,0,0,4637904,0,0,0,4677464,0,0,0,4708708,0,0
2,"Autauga, AL",1001,0,43872,0,0,0,44434,0,0,0,45157,0,0,0,45762,0,0,0,46933,0,0,0,47870,0,0,0,49105,0,0,0,49834,0,0,0,50354,0,0,0,50756,0,0
3,"Baldwin, AL",1003,0,141358,0,0,0,144988,0,0,0,148141,0,0,0,151707,0,0,0,156573,0,0,0,162564,0,0,0,168516,0,0,0,172815,0,0,0,176212,0,0,0,179878,0,0
4,"Barbour, AL",1005,0,29035,0,0,0,29223,0,0,0,29289,0,0,0,29480,0,0,0,29458,0,0,0,29452,0,0,0,29556,0,0,0,29736,0,0,0,29836,0,0,0,29737,0,0


In [6]:
land_df = land_df[['Areaname', 'LND010190D']]
popn_df = popn_df[['Areaname', 'PST045200D']]

In [7]:
# limit analysis to Lower 48 states
lower_48 = ["AL", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA", 
            "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD", 
            "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", 
            "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC", 
            "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"]

state_end = tuple(', ' + abbrev for abbrev in lower_48)

In [8]:
# ignore AL and HI
filtered_land_df = land_df[land_df.Areaname.str.endswith(state_end)]
filtered_popn_df = popn_df[land_df.Areaname.str.endswith(state_end)]

In [9]:
filtered_popn_df.shape

(3111, 2)

In [10]:
# There are 5 counties in Virginia that are included twice in both the land area and population datasets
# so we need to ignore the duplicated row
virginia_counties_df = filtered_land_df[filtered_land_df.Areaname.str.endswith(', VA')]
indices_to_delete = []
counties_set = set()
for index, row in virginia_counties_df.iterrows():
    county = row['Areaname']
    if county not in counties_set:
        counties_set.add(county)
    else:
        indices_to_delete.append(index)
        
filtered_land_df = filtered_land_df[~filtered_land_df.index.isin(indices_to_delete)]
filtered_popn_df = filtered_popn_df[~filtered_popn_df.index.isin(indices_to_delete)]

In [11]:
len(filtered_popn_df)

3106

In [12]:
# merge land area and population datasets
combined_df = pd.merge(filtered_land_df, filtered_popn_df, on='Areaname', how='inner')

In [13]:
# extract state from Areaname column
combined_df['state'] = combined_df.Areaname.str[-2:]
combined_df.head()

Unnamed: 0,Areaname,LND010190D,PST045200D,state
0,"Autauga, AL",604.49,43872,AL
1,"Baldwin, AL",2027.08,141358,AL
2,"Barbour, AL",904.59,29035,AL
3,"Bibb, AL",625.5,19936,AL
4,"Blount, AL",650.65,51181,AL


In [14]:
# rename column names
combined_df.rename(columns={'Areaname': 'county', 'LND010190D': 'area', 'PST045200D': 'popn'}, inplace=True)

In [15]:
# fill in missing value of land area of Broomfield, CO from Wikipedia page
combined_df.loc[combined_df.county == 'Broomfield, CO', 'area'] = 33.00

In [16]:
# calculate density of each county by dividing population by land area
combined_df['density'] = combined_df['popn'] / combined_df['area']

In [17]:
# calculate total population of each state accross all counties
state2pop = combined_df.groupby('state').agg({'popn': sum}).to_dict()['popn']
combined_df['state_popn'] = [state2pop[state] for state in combined_df.state]
combined_df.head()

Unnamed: 0,county,area,popn,state,density,state_popn
0,"Autauga, AL",604.49,43872,AL,72.576883,4451849
1,"Baldwin, AL",2027.08,141358,AL,69.734791,4451849
2,"Barbour, AL",904.59,29035,AL,32.097414,4451849
3,"Bibb, AL",625.5,19936,AL,31.872102,4451849
4,"Blount, AL",650.65,51181,AL,78.661339,4451849


In [18]:
# calculate density metric for each state by weighing the density of each population by the fraction of 
# the state population that lives in the given state
state2density_metric = (combined_df.groupby('state').
                        apply(lambda x: round(x['popn'] * (x['density'] ** 1) / x['state_popn'], 1))
                        .groupby('state').sum()).to_dict()

In [19]:
# sort states in order of decreasing density
sorted_density_metrics = sorted(list(state2density_metric.values()), reverse=True)
density_metric2state = {v: k for k, v in state2density_metric.items()}
ordered_density_metric2state = {x: density_metric2state[x] for x in sorted_density_metrics}

In [20]:
# create dataframe with this first state-level feature
state_stats_df = pd.DataFrame(ordered_density_metric2state.keys(), columns=['density_metric'], 
                              index=ordered_density_metric2state.values())

In [21]:
state_stats_df.head()

Unnamed: 0,density_metric
NY,10711.4
NJ,2789.6
PA,1957.6
IL,1761.9
MD,1737.6


In [22]:
# dataset that lists the average latitude of each state
latlong_df = pd.read_csv('statelatlong.csv')
latlong_df.head()

Unnamed: 0,State,Latitude,Longitude,City
0,AL,32.601011,-86.680736,Alabama
1,AK,61.302501,-158.77502,Alaska
2,AZ,34.168219,-111.930907,Arizona
3,AR,34.751928,-92.131378,Arkansas
4,CA,37.271875,-119.270415,California


In [23]:
# include this latitude value in the feature dataframe
state_stats_df1 = (pd.merge(state_stats_df, latlong_df[['Latitude', 'State']],
                           left_index=True, right_on='State').drop(columns=['State']))
state_stats_df1.index = ordered_density_metric2state.values()

In [24]:
# states in Lower 48 that are on either the Atlantic or Pacific Ocean. This can potentially be an important
# feature because tourists and immigrants usually fly into the country in a coastal location
coastal_states = set('ME NH MA RI CT NY NJ PA MD DE VA NC SC GA FL WA OR CA'.split())
state_stats_df1['is_coastal'] = [int(state in coastal_states) for state in state_stats_df.index]

A potentially important state-level feature is the number of airline passengers arriving in the state. As we've seen with Covid-19, clusters have started in particular locations because visiters have come into these places with the virus from foreigns countries. The most readily available source for this data are the 'List of airports in [state]' Wikipedia article for each state. Each of these pages contains the number of commerical passenger boardings in 2016 for each airport in the state. Although commerical passenger arrivals are not included, it's reasonable to assume that the number of boardings and arrivals are closely related to each other. The values in the dictionary below represents the sum of the number of commerical passenger arrivals for the major airports in each state. Note: the number of major airports variesby state (e.g. the only major airport in Massachusetts in Logan, there are no major airports in Delaware, and there are three major airports in Kentucky (Cincinatti, Louisville and Lexington). Finally, the number of annual boardings in each state in normalized by the population of the given state, as this metric represents the relative influence of air traffic on the given state.

In [25]:
state2passengers = {'NY': 50868391, 
                    'PA': 15285948 + 4670954 + 636916, 
                    'NJ': 19923009 + 589091,
                    'MD': 13371816,
                    'IL': round((83245472 / 2) + (22027737 / 2)),
                    'MA': 17759044,
                    'VA': 11470854 + 10596942 + 1777648 + 1602631,
                    'MO': 6793076 + 5391557 + 462126,
                    'CA': (39636042 + 25707101 + 10340164 + 5934639 + 5321603 + 5217242 
                           + 4969366 + 2104625 + 2077892 + 1386357 + 995801 + 761298),
                    'MI': 16847135 + 1334979 + 398508,
                    'CO': 28267394 + 657694,
                    'MN': 18123844,
                    'TX': 31283579 + 20062072 + 7554596 + 6285181 + 6095545 + 4179994 + 1414376,
                    'RI': 1803000,
                    'GA': 50501858 + 1056265,
                    'OH': 4083476 + 3567864 + 1019922 + 685553,
                    'CT': 2982194,
                    'IN': 4216766 + 360369 + 329957 + 204352,
                    'DE': 0,
                    'KY': 3269979 + 1631494 + 638316,
                    'FL': (20875813 + 20283541 + 14263270 + 9194994 + 4239261 + 3100624 + 2729129 
                           + 1321675 + 986766 + 915672 + 589860),
                    'NE': 2127387 + 162876,
                    'UT': 11143738,
                    'OR': 9071154,
                    'TN': 6338517 + 2016089 + 887103,
                    'LA': 5569705 + 364200,
                    'OK': 1796473 + 1342315,
                    'NC': 21511880 + 5401714 + 848261,
                    'KS': 781944,
                    'WA': 21887110 + 1570652,
                    'WI': 3496724 + 1043185 + 348026 + 314909,
                    'NH': 995403,
                    'AL': 1304467 + 527801 + 288209 + 173210,
                    'NM': 2341719,
                    'IA': 1216357 + 547786,
                    'AZ': 20896265 + 1594594 + 705731,
                    'SC': 1811695 + 991276 + 944849 + 553658,
                    'AR': 958824 + 673810,
                    'WV': 213412,
                    'ID': 1633507,
                    'NV': 22833267 + 1771864,
                    'ME': 886343 + 269013,
                    'MS': 491464 + 305157,
                    'VT': 593311,
                    'SD': 510105 + 272537,
                    'ND': 402976 + 273980 + 150634 + 132557 + 68829,
                    'MT': 553245 + 423213 + 381582 + 247816 + 176730 + 103239,
                    'WY': 342044 + 92805}

In [26]:
# population of each state according to the 2010 census
state2popn_2010 = {
        'AL': 4779736,
        'AR': 2915918,
        'AZ': 6392017,
        'CA': 37253956,
        'CO': 5029196,
        'CT': 3574097,
        'DE': 897934,
        'FL': 18801310,
        'GA': 9687653,
        'IA': 3046355,
        'ID': 1567582,
        'IL': 12830632,
        'IN': 6483802,
        'KS': 2853118,
        'KY': 4339367,
        'LA': 4533372,
        'MA': 6547629,
        'MD': 5773552,
        'ME': 1328361,
        'MI': 9883640,
        'MN': 5303925,
        'MO': 5988927,
        'MS': 2967297,
        'MT': 989415,
        'NC': 9535483,
        'ND': 672591,
        'NE': 1826341,
        'NH': 1316470,
        'NJ': 8791894,
        'NM': 2059179,
        'NV': 2700551,
        'NY': 19378102,
        'OH': 11536504,
        'OK': 3751351,
        'OR': 3831074,
        'PA': 12702379,
        'RI': 1052567,
        'SC': 4625364,
        'SD': 814180,
        'TN': 6346105,
        'TX': 25145561,
        'UT': 2763885,
        'VA': 8001024,
        'VT': 625741,
        'WA': 6724540,
        'WI': 5686986,
        'WV': 1852994,
        'WY': 563626
}

In [27]:
state_stats_df1['airport_boardings'] = [state2passengers[state] / state2popn_2010[state]
                                        for state in state_stats_df.index]

In [28]:
state_stats_df1.head()

Unnamed: 0,density_metric,Latitude,is_coastal,airport_boardings
NY,10711.4,40.705626,1,2.625045
NJ,2789.6,40.143006,1,2.33307
PA,1957.6,40.994593,1,1.621257
IL,1761.9,39.739318,0,4.102417
MD,1737.6,38.806352,1,2.316047


In [29]:
abbrev2state = {
        'AK': 'Alaska',
        'AL': 'Alabama',
        'AR': 'Arkansas',
        'AS': 'American Samoa',
        'AZ': 'Arizona',
        'CA': 'California',
        'CO': 'Colorado',
        'CT': 'Connecticut',
        'DC': 'District of Columbia',
        'DE': 'Delaware',
        'FL': 'Florida',
        'GA': 'Georgia',
        'GU': 'Guam',
        'HI': 'Hawaii',
        'IA': 'Iowa',
        'ID': 'Idaho',
        'IL': 'Illinois',
        'IN': 'Indiana',
        'KS': 'Kansas',
        'KY': 'Kentucky',
        'LA': 'Louisiana',
        'MA': 'Massachusetts',
        'MD': 'Maryland',
        'ME': 'Maine',
        'MI': 'Michigan',
        'MN': 'Minnesota',
        'MO': 'Missouri',
        'MP': 'Northern Mariana Islands',
        'MS': 'Mississippi',
        'MT': 'Montana',
        'NA': 'National',
        'NC': 'North Carolina',
        'ND': 'North Dakota',
        'NE': 'Nebraska',
        'NH': 'New Hampshire',
        'NJ': 'New Jersey',
        'NM': 'New Mexico',
        'NV': 'Nevada',
        'NY': 'New York',
        'OH': 'Ohio',
        'OK': 'Oklahoma',
        'OR': 'Oregon',
        'PA': 'Pennsylvania',
        'PR': 'Puerto Rico',
        'RI': 'Rhode Island',
        'SC': 'South Carolina',
        'SD': 'South Dakota',
        'TN': 'Tennessee',
        'TX': 'Texas',
        'UT': 'Utah',
        'VA': 'Virginia',
        'VI': 'Virgin Islands',
        'VT': 'Vermont',
        'WA': 'Washington',
        'WI': 'Wisconsin',
        'WV': 'West Virginia',
        'WY': 'Wyoming'
}

state2abbrev = {v: k for k, v in abbrev2state.items()}

In [30]:
# dataframe that reports the fraction of each state's population that falls into a set of age categories
age_df = pd.read_csv('age.csv')

In [31]:
# merge age dataframe with dataframe that contains the rest of the features
age_df['Location'] = [state2abbrev[state] for state in age_df.Location]
state_stats_df2 = (pd.merge(state_stats_df1, age_df, left_index=True, right_on='Location')
                  .drop(columns=['Location']))
state_stats_df2.index = ordered_density_metric2state.values()

In [32]:
state_stats_df2.head()

Unnamed: 0,density_metric,Latitude,is_coastal,airport_boardings,Children 0-18,Adults 19-25,Adults 26-34,Adults 35-54,Adults 55-64,65+
NY,10711.4,40.705626,1,2.625045,0.22,0.09,0.13,0.26,0.14,0.16
NJ,2789.6,40.143006,1,2.33307,0.23,0.08,0.11,0.27,0.14,0.16
PA,1957.6,40.994593,1,1.621257,0.22,0.08,0.12,0.25,0.14,0.18
IL,1761.9,39.739318,0,4.102417,0.24,0.09,0.12,0.26,0.13,0.15
MD,1737.6,38.806352,1,2.316047,0.23,0.08,0.12,0.27,0.14,0.15


In [33]:
# dataset that reports the average temperature of each state during each of the four seasons of the year
temps_df = pd.read_csv('temps.csv')

In [34]:
temps_df['State'] = [state2abbrev[state] for state in temps_df.State]

In [35]:
# merge temperature dataframe with dataframe that contains the rest of the features
state_stats_df3 = (pd.merge(state_stats_df2, temps_df, left_index=True, right_on='State')
                  .drop(columns=['State']))
state_stats_df3.index = ordered_density_metric2state.values()

In [36]:
state_stats_df3.head()

Unnamed: 0,density_metric,Latitude,is_coastal,airport_boardings,Children 0-18,Adults 19-25,Adults 26-34,Adults 35-54,Adults 55-64,65+,spring,summer,fall,winter
NY,10711.4,40.705626,1,2.625045,0.22,0.09,0.13,0.26,0.14,0.16,43.6,66.5,48.1,23.3
NJ,2789.6,40.143006,1,2.33307,0.23,0.08,0.11,0.27,0.14,0.16,50.6,72.2,54.8,33.0
PA,1957.6,40.994593,1,1.621257,0.22,0.08,0.12,0.25,0.14,0.18,47.4,68.6,50.9,28.4
IL,1761.9,39.739318,0,4.102417,0.24,0.09,0.12,0.26,0.13,0.15,51.6,73.4,53.8,28.3
MD,1737.6,38.806352,1,2.316047,0.23,0.08,0.12,0.27,0.14,0.15,52.8,73.3,56.1,34.7


It's possible that state-level political policies have an impact on the proliferation of virus infections. The Cook Partisan Voting Index taken from Wikipedia assigns a number to each state that indicates how strongly the state leads toward the Republican or Democratic Party based on recent state and federal elections. In our convention, a positive value signifies leaning Republican, while a negative value signifies leading Democratic.

In [37]:
state2partisan_score = {
        'AL': 14,
        'AR': 15,
        'AZ': 5,
        'CA': -12,
        'CO': 1,
        'CT': -6,
        'DE': -6,
        'FL': 2,
        'GA': 5,
        'IA': 3,
        'ID': 19,
        'IL': -7,
        'IN': 9,
        'KS': 13,
        'KY': 15,
        'LA': 11,
        'MA': -12,
        'MD': -12,
        'ME': -3,
        'MI': -1,
        'MN': -1,
        'MO': 9,
        'MS': 9,
        'MT': 11,
        'NC': 3,
        'ND': 17,
        'NE': 14,
        'NH': 0,
        'NJ': -7,
        'NM': -3,
        'NV': -1,
        'NY': -12,
        'OH': 3,
        'OK': 20,
        'OR': -5,
        'PA': 0,
        'RI': -10,
        'SC': 8,
        'SD': 15,
        'TN': 14,
        'TX': 8,
        'UT': 20,
        'VA': -1,
        'VT': -15,
        'WA': -7,
        'WI': 0,
        'WV': 19,
        'WY': 25
}

In [38]:
state_stats_df3['partisan_score'] = [state2partisan_score[state] for state in state_stats_df3.index]

The following dataset was taken from a Stat139 problem set last semester and contains a range of socioeconomic, demographic and health indicators. These include:

Cancer: prevalence of cancer per 100,000 individuals

Hispanic: percent of adults that are hispanic

Minority: percent of adults that are nonwhite

Female: percent of adults that are female

Income: median income

Nodegree: percent of adults who have not completed high school

Bachelor: percent of adults with a bachelor’s degree

Inactive: percent of adults who do not exercise in their leisure time

Obesity: percent of individuals with BMI > 30

Cancer: prevalence of cancer per 100,000 individuals

We're not considering unemployment rate, as these rates are likely no longer accurate for many states.

Just as with the density metric, the state-level value for each of these features is determined by calculating a weighted average of the measurements for each county, where the weights are the fraction of the state population that lives in the given county.

In [39]:
county_metrics_df = pd.read_csv('county_metrics.csv')

In [40]:
county_metrics_df.head()

Unnamed: 0,state,fipscode,county,population,hispanic,minority,female,unemployed,income,nodegree,bachelor,inactivity,obesity,density,cancer
0,Colorado,8117,Summit County,27239,15.173,4.918,45.996,2.5,68352,5.4,48.1,8.1,13.1,46.0,46.2
1,Colorado,8037,Eagle County,53653,30.04,5.169,47.231,3.1,76661,10.1,47.3,9.4,11.8,31.0,47.1
2,Idaho,16067,Minidoka County,19226,34.07,5.611,49.318,3.7,46332,24.1,11.8,18.3,34.2,80.0,61.8
3,Colorado,8113,San Miguel County,7558,10.154,4.747,46.808,3.7,59603,4.7,54.4,12.4,16.7,5.7,62.6
4,Utah,49051,Wasatch County,21600,13.244,4.125,48.812,3.4,65207,9.5,34.4,13.9,23.0,257.8,68.3


In [41]:
county_metrics_df['state'] = [state2abbrev[state] for state in county_metrics_df.state]

In [42]:
county_metrics_df = county_metrics_df[county_metrics_df.state.isin(lower_48)]
county_metrics_df.head()

Unnamed: 0,state,fipscode,county,population,hispanic,minority,female,unemployed,income,nodegree,bachelor,inactivity,obesity,density,cancer
0,CO,8117,Summit County,27239,15.173,4.918,45.996,2.5,68352,5.4,48.1,8.1,13.1,46.0,46.2
1,CO,8037,Eagle County,53653,30.04,5.169,47.231,3.1,76661,10.1,47.3,9.4,11.8,31.0,47.1
2,ID,16067,Minidoka County,19226,34.07,5.611,49.318,3.7,46332,24.1,11.8,18.3,34.2,80.0,61.8
3,CO,8113,San Miguel County,7558,10.154,4.747,46.808,3.7,59603,4.7,54.4,12.4,16.7,5.7,62.6
4,UT,49051,Wasatch County,21600,13.244,4.125,48.812,3.4,65207,9.5,34.4,13.9,23.0,257.8,68.3


In [43]:
state2pop_ = county_metrics_df.groupby('state').agg({'population': sum}).to_dict()['population']
county_metrics_df['state_popn'] = [state2pop_[state] for state in county_metrics_df.state]

In [44]:
county_metrics_df.head()

Unnamed: 0,state,fipscode,county,population,hispanic,minority,female,unemployed,income,nodegree,bachelor,inactivity,obesity,density,cancer,state_popn
0,CO,8117,Summit County,27239,15.173,4.918,45.996,2.5,68352,5.4,48.1,8.1,13.1,46.0,46.2,5022460
1,CO,8037,Eagle County,53653,30.04,5.169,47.231,3.1,76661,10.1,47.3,9.4,11.8,31.0,47.1,5022460
2,ID,16067,Minidoka County,19226,34.07,5.611,49.318,3.7,46332,24.1,11.8,18.3,34.2,80.0,61.8,1351143
3,CO,8113,San Miguel County,7558,10.154,4.747,46.808,3.7,59603,4.7,54.4,12.4,16.7,5.7,62.6,5022460
4,UT,49051,Wasatch County,21600,13.244,4.125,48.812,3.4,65207,9.5,34.4,13.9,23.0,257.8,68.3,2481585


In [45]:
metrics = ['hispanic', 'minority', 'female', 'unemployed', 'income', 'nodegree', 'bachelor', 'inactivity',
          'obesity', 'cancer']

for metric in metrics:
    state2metric = (county_metrics_df.groupby('state').
                    apply(lambda x: round((x['population'] * x[metric]) / x['state_popn'], 3))
                    .groupby('state').sum()).to_dict()
    
    denom = 1000 if metric == 'income' else 1
    state_stats_df3[metric] = [state2metric[state] / denom for state in state_stats_df3.index]

The more people travel between states, the more closely related the states should be in terms of rate of virus infections. The Census Bureau Journey to Work datset reports the number of people that commute from any given county in the county to any other county in the country. This means we can aggregate these county to county commuting flows to determine the number of people that commute between any two states. From this data, we can create a symmetric matrix where the $i,j$ and $j,i$ elements represent the number of people that commute from state $i$ to state $j$ plus the number of people that commute from state $j$ to state $i$. However, just as with the number of annual boardings in each state, the final value of the number of people who commute between two states in normalized by the popualation of the given state. This means that this commuting matrix is no longer symmetric because the populations of state $i$ and state $j$ are different. 

In [46]:
commuting_df_complete = pd.read_csv('commuting.csv')

In [47]:
commuting_df_complete.columns

Index(['State FIPS Code', 'County FIPS Code', 'State Name', 'County Name',
       'State FIPS Code.1', 'County FIPS Code.1', 'State Name.1',
       'County Name.1', 'Workers in Commuting Flow', ' Margin of Error',
       'Unnamed: 10', 'Unnamed: 11', 'Unnamed: 12', 'Unnamed: 13'],
      dtype='object')

In [48]:
commuting_df = commuting_df_complete[['State Name', 'State Name.1', 'Workers in Commuting Flow']]

In [49]:
commuting_df.rename(columns={'State Name': 'home_state', 
                             'State Name.1': 'work_state', 
                             'Workers in Commuting Flow': 'commuters'}, 
                   inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  errors=errors,


In [50]:
lower_48_full_name = [abbrev2state[abbrev] for abbrev in lower_48]
commuting_df = commuting_df[commuting_df.work_state.isin(lower_48_full_name)]

In [51]:
commuting_df['home_state'] = [state2abbrev[state] for state in commuting_df.home_state]
commuting_df['work_state'] = [state2abbrev[state] for state in commuting_df.work_state]

In [52]:
commuting_df.head(10)

Unnamed: 0,home_state,work_state,commuters
0,AL,AL,8828
1,AL,AL,22
2,AL,AL,7
3,AL,AL,309
4,AL,AL,17
5,AL,AL,11
6,AL,AL,210
7,AL,AL,2244
8,AL,AL,27
9,AL,AL,35


In [53]:
commuting_df['commuters'] = commuting_df['commuters'].apply(lambda x: int(''.join([y for y in x if y.isdigit()])))

In [54]:
commuting_groupby_df = (commuting_df.groupby(['work_state', 'home_state'], as_index=False)
                       .agg({'commuters': 'sum'}))

In [55]:
# calculate the number of commuters between two states for all pairs of states
for work_state in state_stats_df3.index:
    vals = []
    for home_state in state_stats_df3.index:
        try:
            num1 = int((commuting_groupby_df[(commuting_groupby_df.work_state == work_state)
                       & (commuting_groupby_df.home_state == home_state)].commuters))
            num2 = int((commuting_groupby_df[(commuting_groupby_df.work_state == home_state)
                       & (commuting_groupby_df.home_state == work_state)].commuters))
            num = num1 + num2
            
            num /= state2popn_2010[work_state]
            
        except TypeError:
            num = 0

        vals.append(num)

    state_stats_df3[work_state + '_dest'] = vals

In [56]:
state_stats_df3.head()

Unnamed: 0,density_metric,Latitude,is_coastal,airport_boardings,Children 0-18,Adults 19-25,Adults 26-34,Adults 35-54,Adults 55-64,65+,spring,summer,fall,winter,partisan_score,hispanic,minority,female,unemployed,income,nodegree,bachelor,inactivity,obesity,cancer,NY_dest,NJ_dest,PA_dest,IL_dest,MD_dest,MA_dest,VA_dest,CA_dest,RI_dest,MI_dest,TX_dest,MO_dest,MN_dest,CT_dest,GA_dest,OH_dest,CO_dest,DE_dest,FL_dest,IN_dest,UT_dest,KY_dest,NE_dest,TN_dest,OR_dest,LA_dest,NC_dest,OK_dest,WA_dest,KS_dest,WI_dest,NH_dest,AZ_dest,SC_dest,AL_dest,IA_dest,NM_dest,WV_dest,NV_dest,AR_dest,ID_dest,ME_dest,MS_dest,VT_dest,SD_dest,ND_dest,MT_dest,WY_dest
NY,10711.4,40.705626,1,2.625045,0.22,0.09,0.13,0.26,0.14,0.16,43.6,66.5,48.1,23.3,-12,18.737,29.756,51.436,5.347,61.622797,14.719,33.425,24.046,24.675,200.165,0.911538,0.061864,0.004793,0.000216,0.000722,0.001985,0.000495,0.000199,0.001692,0.000196,0.000148,0.000147,0.000207,0.033303,0.000321,0.000261,0.000237,0.002078,0.00057,0.000153,0.000171,0.000122,0.000156,0.000232,9.8e-05,0.000264,0.000368,8.9e-05,0.000126,0.000117,0.000155,0.001047,0.000143,0.000251,0.000122,0.000119,8.8e-05,0.000168,0.000233,9.9e-05,8e-05,0.000857,0.000131,0.012476,8.1e-05,0.0,0.000105,0.000209
NJ,2789.6,40.143006,1,2.33307,0.23,0.08,0.11,0.27,0.14,0.16,50.6,72.2,54.8,33.0,-7,19.372,27.143,51.199,5.676,72.829584,11.59,36.24,22.909,25.672,200.319,0.028068,0.823014,0.019686,0.000128,0.000793,0.000447,0.000315,7.2e-05,0.000472,8.9e-05,9.9e-05,0.000112,7.9e-05,0.001903,0.000142,0.000171,0.000154,0.016296,0.000245,8.1e-05,6.7e-05,9.1e-05,6.5e-05,9.8e-05,5.2e-05,8e-05,0.000205,5.9e-05,6.7e-05,6e-05,6.9e-05,0.000368,0.000132,0.000183,5.1e-05,1.5e-05,2.5e-05,0.000137,8.1e-05,4.5e-05,4.1e-05,0.000325,3.4e-05,0.000566,0.0,0.0,0.000114,0.0
PA,1957.6,40.994593,1,1.621257,0.22,0.08,0.12,0.25,0.14,0.18,47.4,68.6,50.9,28.4,0,6.767,17.371,51.067,5.199,54.9607,11.11,28.104,23.155,29.331,230.835,0.003142,0.028442,0.877457,0.000199,0.012778,0.000454,0.000731,7.9e-05,0.000352,0.000216,0.000143,0.000188,0.000158,0.000587,0.000222,0.002901,0.000164,0.069599,0.000248,0.000159,0.000125,0.000229,5.8e-05,0.000202,5.3e-05,0.00018,0.000314,0.000203,5.8e-05,0.000106,0.000117,0.000366,0.000156,0.00025,0.000151,9.1e-05,0.0001,0.011997,0.000114,0.000207,6.5e-05,0.000299,0.000101,0.000294,8e-05,0.000137,0.000137,0.000232
IL,1761.9,39.739318,0,4.102417,0.24,0.09,0.12,0.26,0.13,0.15,51.6,73.4,53.8,28.3,-7,16.918,22.729,50.902,5.963,59.266824,12.397,32.078,21.199,27.34,225.471,0.000143,0.000187,0.000201,0.900515,0.00015,0.000173,0.000197,0.000141,0.000231,0.000784,0.000194,0.018036,0.000492,0.00021,0.000285,0.000342,0.000297,0.000204,0.000304,0.015667,0.000116,0.001515,0.000491,0.000354,9.2e-05,0.000159,0.000204,0.000159,0.000138,0.000288,0.011498,0.000137,0.000223,0.000114,0.000131,0.018256,7.1e-05,0.000108,0.000336,0.000304,0.00019,9.6e-05,0.000181,0.000142,0.000141,0.000346,0.000173,0.000213
MD,1737.6,38.806352,1,2.316047,0.23,0.08,0.12,0.27,0.14,0.15,52.8,73.3,56.1,34.7,-12,9.407,40.312,51.537,5.266,75.296062,11.089,37.129,21.545,28.94,225.486,0.000215,0.000521,0.005808,6.8e-05,0.845245,0.000221,0.024575,3.8e-05,0.000227,6.7e-05,7.6e-05,9.7e-05,7.7e-05,0.00014,0.000167,0.000103,6.1e-05,0.045597,0.000184,7.6e-05,4.1e-05,7.4e-05,4.3e-05,8.7e-05,3.7e-05,6.8e-05,0.000281,1.5e-05,7.6e-05,4.5e-05,5.4e-05,0.000109,2.5e-05,0.000143,6.5e-05,2.3e-05,1.6e-05,0.015275,5e-05,4.2e-05,0.0,0.00014,5.9e-05,8.3e-05,0.0,0.000113,1.1e-05,0.0


States that are in close proximity may be similarly affected by viruses. Therefore, we include a column for each state in the design matrix that denotes whether that given states borders each of the other states.

In [57]:
# dictionary that maps each state in the Lower 48 to the states that directly border it or are not contiguous
# but are very close (e.g. NJ and CT)
state2neighbors = {'AL': {'AL', 'MS', 'TN', 'FL', 'GA', 'NC', 'SC'},
                  'GA': {'GA', 'TN', 'FL', 'AL', 'SC', 'NC', 'MS'},
                  'FL': {'FL', 'GA', 'AL', 'MS', 'SC'},
                  'MS': {'MS', 'AL', 'TN', 'FL', 'LA', 'AR', 'GA'},
                  'LA': {'LA', 'TX', 'AR', 'MS', 'OK', 'AL'},
                  'SC': {'SC', 'FL', 'GA', 'NC', 'TN'},
                  'NC': {'NC', 'SC', 'GA', 'TN', 'VA', 'KY'},
                  'AR': {'AR', 'LA', 'TX', 'MS', 'TN', 'OK', 'MO', 'KY'},
                  'VA': {'VA', 'NC', 'KY', 'WV', 'TN', 'DC', 'MD', 'DE'},
                  'MD': {'MD', 'DC', 'VA', 'WV', 'DE', 'NJ', 'PA'},
                  'DE': {'DE', 'MD', 'DC', 'NJ', 'PA'},
                  'NJ': {'NJ', 'DE', 'MD', 'PA', 'NY', 'NJ', 'CT'},
                  'NY': {'NY', 'NJ', 'PA', 'CT', 'MA', 'VT'},
                  'CT': {'CT', 'NY', 'RI', 'MA', 'NJ'},
                  'RI': {'RI', 'CT', 'MA'},
                  'MA': {'MA', 'CT', 'RI', 'NH', 'VT', 'NY'},
                  'NH': {'NH', 'VT', 'ME', 'MA'},
                  'ME': {'ME', 'NH', 'MA', 'VT'},
                  'VT': {'VT', 'NH', 'NY', 'MA'},
                  'PA': {'PA', 'NY', 'NJ', 'MD', 'WV', 'OH', 'DE'},
                  'WV': {'WV', 'DC', 'MD', 'PA', 'OH', 'KY', 'VA'},
                  'OH': {'OH', 'PA', 'WV', 'MI', 'IN', 'KY'},
                  'MI': {'MI', 'OH', 'WI', 'IN', 'IL'},
                  'KY': {'KY', 'WV', 'OH', 'IN', 'IL', 'MO', 'TN', 'VA', 'AR', 'NC'},
                  'TN': {'TN', 'KY', 'VA', 'NC', 'SC', 'GA', 'AL', 'MS', 'AR', 'MO', 'IL'},
                  'IN': {'IN', 'KY', 'OH', 'MI', 'IL', 'WI'},
                  'IL': {'IL', 'IN', 'MI', 'WI', 'IA', 'MO', 'KY', 'TN'},
                  'WI': {'WI', 'IL', 'MN', 'MI', 'IA'},
                  'MN': {'MN', 'MI', 'WI', 'IA', 'ND', 'SD', 'NE', 'IL'},
                  'IA': {'IA', 'WI', 'MN', 'IL', 'MO', 'KS', 'NE', 'SD'},
                  'MO': {'MO', 'IA', 'IL', 'KY', 'TN', 'AR', 'OK', 'KS', 'NE'},
                  'ND': {'ND', 'SD', 'MN', 'MT', 'WY'},
                  'SD': {'SD', 'ND', 'MN', 'IA', 'NE', 'MT', 'WY'},
                  'NE': {'NE', 'SD', 'IA', 'MO', 'KS', 'WY', 'CO'},
                  'KS': {'KS', 'NE', 'IA', 'MO', 'AR', 'OK', 'CO', 'TX', 'NM'},
                  'OK': {'OK', 'KS', 'MO', 'AR', 'TX', 'NM', 'CO', 'LA'},
                  'TX': {'TX', 'LA', 'AR', 'OK', 'NM', 'CO'},
                  'MT': {'MT', 'ND', 'SD', 'WY', 'ID'},
                  'WY': {'WY', 'MT', 'ND', 'SD', 'NE', 'CO', 'UT', 'ID'},
                  'CO': {'CO', 'WY', 'NE', 'KS', 'OK', 'TX', 'NM', 'UT', 'AZ'},
                  'NM': {'NM', 'CO', 'KS', 'OK', 'TX', 'AZ', 'UT'},
                  'ID': {'ID', 'MT', 'WY', 'UT', 'NV', 'WA', 'OR'},
                  'UT': {'UT', 'ID', 'WY', 'CO', 'NM', 'AZ', 'NV'},
                  'AZ': {'AZ', 'NM', 'CO', 'UT', 'NV', 'CA'},
                  'WA': {'WA', 'ID', 'OR'},
                  'OR': {'OR', 'WA', 'ID', 'NV', 'CA'},
                  'NV': {'NV', 'ID', 'OR', 'UT', 'AZ', 'CA'},
                  'CA': {'CA', 'OR', 'NV', 'AZ'}
                 }

In [58]:
for neighboring_state in state_stats_df3.index:
    states = [int(neighboring_state in state2neighbors[state]) for state in state_stats_df3.index]
    state_stats_df3[neighboring_state + '_is_neighbor'] = states  

In [59]:
state_stats_df3.head()

Unnamed: 0,density_metric,Latitude,is_coastal,airport_boardings,Children 0-18,Adults 19-25,Adults 26-34,Adults 35-54,Adults 55-64,65+,spring,summer,fall,winter,partisan_score,hispanic,minority,female,unemployed,income,nodegree,bachelor,inactivity,obesity,cancer,NY_dest,NJ_dest,PA_dest,IL_dest,MD_dest,MA_dest,VA_dest,CA_dest,RI_dest,MI_dest,TX_dest,MO_dest,MN_dest,CT_dest,GA_dest,OH_dest,CO_dest,DE_dest,FL_dest,IN_dest,UT_dest,KY_dest,NE_dest,TN_dest,OR_dest,LA_dest,NC_dest,OK_dest,WA_dest,KS_dest,WI_dest,NH_dest,AZ_dest,SC_dest,AL_dest,IA_dest,NM_dest,WV_dest,NV_dest,AR_dest,ID_dest,ME_dest,MS_dest,VT_dest,SD_dest,ND_dest,MT_dest,WY_dest,NY_is_neighbor,NJ_is_neighbor,PA_is_neighbor,IL_is_neighbor,MD_is_neighbor,MA_is_neighbor,VA_is_neighbor,CA_is_neighbor,RI_is_neighbor,MI_is_neighbor,TX_is_neighbor,MO_is_neighbor,MN_is_neighbor,CT_is_neighbor,GA_is_neighbor,OH_is_neighbor,CO_is_neighbor,DE_is_neighbor,FL_is_neighbor,IN_is_neighbor,UT_is_neighbor,KY_is_neighbor,NE_is_neighbor,TN_is_neighbor,OR_is_neighbor,LA_is_neighbor,NC_is_neighbor,OK_is_neighbor,WA_is_neighbor,KS_is_neighbor,WI_is_neighbor,NH_is_neighbor,AZ_is_neighbor,SC_is_neighbor,AL_is_neighbor,IA_is_neighbor,NM_is_neighbor,WV_is_neighbor,NV_is_neighbor,AR_is_neighbor,ID_is_neighbor,ME_is_neighbor,MS_is_neighbor,VT_is_neighbor,SD_is_neighbor,ND_is_neighbor,MT_is_neighbor,WY_is_neighbor
NY,10711.4,40.705626,1,2.625045,0.22,0.09,0.13,0.26,0.14,0.16,43.6,66.5,48.1,23.3,-12,18.737,29.756,51.436,5.347,61.622797,14.719,33.425,24.046,24.675,200.165,0.911538,0.061864,0.004793,0.000216,0.000722,0.001985,0.000495,0.000199,0.001692,0.000196,0.000148,0.000147,0.000207,0.033303,0.000321,0.000261,0.000237,0.002078,0.00057,0.000153,0.000171,0.000122,0.000156,0.000232,9.8e-05,0.000264,0.000368,8.9e-05,0.000126,0.000117,0.000155,0.001047,0.000143,0.000251,0.000122,0.000119,8.8e-05,0.000168,0.000233,9.9e-05,8e-05,0.000857,0.000131,0.012476,8.1e-05,0.0,0.000105,0.000209,1,1,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0
NJ,2789.6,40.143006,1,2.33307,0.23,0.08,0.11,0.27,0.14,0.16,50.6,72.2,54.8,33.0,-7,19.372,27.143,51.199,5.676,72.829584,11.59,36.24,22.909,25.672,200.319,0.028068,0.823014,0.019686,0.000128,0.000793,0.000447,0.000315,7.2e-05,0.000472,8.9e-05,9.9e-05,0.000112,7.9e-05,0.001903,0.000142,0.000171,0.000154,0.016296,0.000245,8.1e-05,6.7e-05,9.1e-05,6.5e-05,9.8e-05,5.2e-05,8e-05,0.000205,5.9e-05,6.7e-05,6e-05,6.9e-05,0.000368,0.000132,0.000183,5.1e-05,1.5e-05,2.5e-05,0.000137,8.1e-05,4.5e-05,4.1e-05,0.000325,3.4e-05,0.000566,0.0,0.0,0.000114,0.0,1,1,1,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
PA,1957.6,40.994593,1,1.621257,0.22,0.08,0.12,0.25,0.14,0.18,47.4,68.6,50.9,28.4,0,6.767,17.371,51.067,5.199,54.9607,11.11,28.104,23.155,29.331,230.835,0.003142,0.028442,0.877457,0.000199,0.012778,0.000454,0.000731,7.9e-05,0.000352,0.000216,0.000143,0.000188,0.000158,0.000587,0.000222,0.002901,0.000164,0.069599,0.000248,0.000159,0.000125,0.000229,5.8e-05,0.000202,5.3e-05,0.00018,0.000314,0.000203,5.8e-05,0.000106,0.000117,0.000366,0.000156,0.00025,0.000151,9.1e-05,0.0001,0.011997,0.000114,0.000207,6.5e-05,0.000299,0.000101,0.000294,8e-05,0.000137,0.000137,0.000232,1,1,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
IL,1761.9,39.739318,0,4.102417,0.24,0.09,0.12,0.26,0.13,0.15,51.6,73.4,53.8,28.3,-7,16.918,22.729,50.902,5.963,59.266824,12.397,32.078,21.199,27.34,225.471,0.000143,0.000187,0.000201,0.900515,0.00015,0.000173,0.000197,0.000141,0.000231,0.000784,0.000194,0.018036,0.000492,0.00021,0.000285,0.000342,0.000297,0.000204,0.000304,0.015667,0.000116,0.001515,0.000491,0.000354,9.2e-05,0.000159,0.000204,0.000159,0.000138,0.000288,0.011498,0.000137,0.000223,0.000114,0.000131,0.018256,7.1e-05,0.000108,0.000336,0.000304,0.00019,9.6e-05,0.000181,0.000142,0.000141,0.000346,0.000173,0.000213,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0
MD,1737.6,38.806352,1,2.316047,0.23,0.08,0.12,0.27,0.14,0.15,52.8,73.3,56.1,34.7,-12,9.407,40.312,51.537,5.266,75.296062,11.089,37.129,21.545,28.94,225.486,0.000215,0.000521,0.005808,6.8e-05,0.845245,0.000221,0.024575,3.8e-05,0.000227,6.7e-05,7.6e-05,9.7e-05,7.7e-05,0.00014,0.000167,0.000103,6.1e-05,0.045597,0.000184,7.6e-05,4.1e-05,7.4e-05,4.3e-05,8.7e-05,3.7e-05,6.8e-05,0.000281,1.5e-05,7.6e-05,4.5e-05,5.4e-05,0.000109,2.5e-05,0.000143,6.5e-05,2.3e-05,1.6e-05,0.015275,5e-05,4.2e-05,0.0,0.00014,5.9e-05,8.3e-05,0.0,0.000113,1.1e-05,0.0,0,1,1,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0


The proportion of each state that is vaccinated may affect the number of people who are infected with the flu. Therefore, we include information on the adult and child vaccination rate for each state.

In [60]:
flu_df = pd.read_csv('flu.csv')

In [61]:
flu_df['State'] = [state2abbrev[state] for state in flu_df.State]

In [62]:
state_stats_df4 = (pd.merge(state_stats_df3, flu_df, left_index=True, right_on='State').drop(columns=['State']))
state_stats_df4.index = state_stats_df3.index

In [63]:
state_stats_df4.head()

Unnamed: 0,density_metric,Latitude,is_coastal,airport_boardings,Children 0-18,Adults 19-25,Adults 26-34,Adults 35-54,Adults 55-64,65+,spring,summer,fall,winter,partisan_score,hispanic,minority,female,unemployed,income,nodegree,bachelor,inactivity,obesity,cancer,NY_dest,NJ_dest,PA_dest,IL_dest,MD_dest,MA_dest,VA_dest,CA_dest,RI_dest,MI_dest,TX_dest,MO_dest,MN_dest,CT_dest,GA_dest,OH_dest,CO_dest,DE_dest,FL_dest,IN_dest,UT_dest,KY_dest,NE_dest,TN_dest,OR_dest,LA_dest,NC_dest,OK_dest,WA_dest,KS_dest,WI_dest,NH_dest,AZ_dest,SC_dest,AL_dest,IA_dest,NM_dest,WV_dest,NV_dest,AR_dest,ID_dest,ME_dest,MS_dest,VT_dest,SD_dest,ND_dest,MT_dest,WY_dest,NY_is_neighbor,NJ_is_neighbor,PA_is_neighbor,IL_is_neighbor,MD_is_neighbor,MA_is_neighbor,VA_is_neighbor,CA_is_neighbor,RI_is_neighbor,MI_is_neighbor,TX_is_neighbor,MO_is_neighbor,MN_is_neighbor,CT_is_neighbor,GA_is_neighbor,OH_is_neighbor,CO_is_neighbor,DE_is_neighbor,FL_is_neighbor,IN_is_neighbor,UT_is_neighbor,KY_is_neighbor,NE_is_neighbor,TN_is_neighbor,OR_is_neighbor,LA_is_neighbor,NC_is_neighbor,OK_is_neighbor,WA_is_neighbor,KS_is_neighbor,WI_is_neighbor,NH_is_neighbor,AZ_is_neighbor,SC_is_neighbor,AL_is_neighbor,IA_is_neighbor,NM_is_neighbor,WV_is_neighbor,NV_is_neighbor,AR_is_neighbor,ID_is_neighbor,ME_is_neighbor,MS_is_neighbor,VT_is_neighbor,SD_is_neighbor,ND_is_neighbor,MT_is_neighbor,WY_is_neighbor,overall_vacc_rate,child_vacc_rate
NY,10711.4,40.705626,1,2.625045,0.22,0.09,0.13,0.26,0.14,0.16,43.6,66.5,48.1,23.3,-12,18.737,29.756,51.436,5.347,61.622797,14.719,33.425,24.046,24.675,200.165,0.911538,0.061864,0.004793,0.000216,0.000722,0.001985,0.000495,0.000199,0.001692,0.000196,0.000148,0.000147,0.000207,0.033303,0.000321,0.000261,0.000237,0.002078,0.00057,0.000153,0.000171,0.000122,0.000156,0.000232,9.8e-05,0.000264,0.000368,8.9e-05,0.000126,0.000117,0.000155,0.001047,0.000143,0.000251,0.000122,0.000119,8.8e-05,0.000168,0.000233,9.9e-05,8e-05,0.000857,0.000131,0.012476,8.1e-05,0.0,0.000105,0.000209,1,1,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,81.7,69.6
NJ,2789.6,40.143006,1,2.33307,0.23,0.08,0.11,0.27,0.14,0.16,50.6,72.2,54.8,33.0,-7,19.372,27.143,51.199,5.676,72.829584,11.59,36.24,22.909,25.672,200.319,0.028068,0.823014,0.019686,0.000128,0.000793,0.000447,0.000315,7.2e-05,0.000472,8.9e-05,9.9e-05,0.000112,7.9e-05,0.001903,0.000142,0.000171,0.000154,0.016296,0.000245,8.1e-05,6.7e-05,9.1e-05,6.5e-05,9.8e-05,5.2e-05,8e-05,0.000205,5.9e-05,6.7e-05,6e-05,6.9e-05,0.000368,0.000132,0.000183,5.1e-05,1.5e-05,2.5e-05,0.000137,8.1e-05,4.5e-05,4.1e-05,0.000325,3.4e-05,0.000566,0.0,0.0,0.000114,0.0,1,1,1,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,79.4,72.8
PA,1957.6,40.994593,1,1.621257,0.22,0.08,0.12,0.25,0.14,0.18,47.4,68.6,50.9,28.4,0,6.767,17.371,51.067,5.199,54.9607,11.11,28.104,23.155,29.331,230.835,0.003142,0.028442,0.877457,0.000199,0.012778,0.000454,0.000731,7.9e-05,0.000352,0.000216,0.000143,0.000188,0.000158,0.000587,0.000222,0.002901,0.000164,0.069599,0.000248,0.000159,0.000125,0.000229,5.8e-05,0.000202,5.3e-05,0.00018,0.000314,0.000203,5.8e-05,0.000106,0.000117,0.000366,0.000156,0.00025,0.000151,9.1e-05,0.0001,0.011997,0.000114,0.000207,6.5e-05,0.000299,0.000101,0.000294,8e-05,0.000137,0.000137,0.000232,1,1,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,82.5,69.7
IL,1761.9,39.739318,0,4.102417,0.24,0.09,0.12,0.26,0.13,0.15,51.6,73.4,53.8,28.3,-7,16.918,22.729,50.902,5.963,59.266824,12.397,32.078,21.199,27.34,225.471,0.000143,0.000187,0.000201,0.900515,0.00015,0.000173,0.000197,0.000141,0.000231,0.000784,0.000194,0.018036,0.000492,0.00021,0.000285,0.000342,0.000297,0.000204,0.000304,0.015667,0.000116,0.001515,0.000491,0.000354,9.2e-05,0.000159,0.000204,0.000159,0.000138,0.000288,0.011498,0.000137,0.000223,0.000114,0.000131,0.018256,7.1e-05,0.000108,0.000336,0.000304,0.00019,9.6e-05,0.000181,0.000142,0.000141,0.000346,0.000173,0.000213,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,83.0,60.1
MD,1737.6,38.806352,1,2.316047,0.23,0.08,0.12,0.27,0.14,0.15,52.8,73.3,56.1,34.7,-12,9.407,40.312,51.537,5.266,75.296062,11.089,37.129,21.545,28.94,225.486,0.000215,0.000521,0.005808,6.8e-05,0.845245,0.000221,0.024575,3.8e-05,0.000227,6.7e-05,7.6e-05,9.7e-05,7.7e-05,0.00014,0.000167,0.000103,6.1e-05,0.045597,0.000184,7.6e-05,4.1e-05,7.4e-05,4.3e-05,8.7e-05,3.7e-05,6.8e-05,0.000281,1.5e-05,7.6e-05,4.5e-05,5.4e-05,0.000109,2.5e-05,0.000143,6.5e-05,2.3e-05,1.6e-05,0.015275,5e-05,4.2e-05,0.0,0.00014,5.9e-05,8.3e-05,0.0,0.000113,1.1e-05,0.0,0,1,1,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,81.9,74.5


Smoking may also affect suspectibility to viruses such as the flu and Covid-19, so we include a feature that reports the fraction of adults who smoke in each state.

In [64]:
state2smoking_rate = {
        'AL': 20.9,
        'AR': 22.3,
        'AZ': 15.6,
        'CA': 11.3,
        'CO': 14.6,
        'CT': 12.7,
        'DE': 17.0,
        'FL': 16.1,
        'GA': 17.5,
        'IA': 17.1,
        'ID': 14.3,
        'IL': 15.5,
        'IN': 21.8,
        'KS': 17.4,
        'KY': 24.6,
        'LA': 23.1,
        'MA': 13.7,
        'MD': 13.8,
        'ME': 17.3,
        'MI': 19.3,
        'MN': 14.5,
        'MO': 20.8,
        'MS': 22.2,
        'MT': 17.2,
        'NC': 17.2,
        'ND': 18.3,
        'NE': 15.4,
        'NH': 15.7,
        'NJ': 13.7,
        'NM': 17.5,
        'NV': 17.6,
        'NY': 14.1,
        'OH': 21.1,
        'OK': 20.1,
        'OR': 16.1,
        'PA': 18.7,
        'RI': 14.9,
        'SC': 18.8,
        'SD': 19.3,
        'TN': 22.6,
        'TX': 15.7,
        'UT': 8.9,
        'VA': 16.4,
        'VT': 15.8,
        'WA': 13.5,
        'WI': 16,
        'WV': 26,
        'WY': 18.7
}

In [65]:
state_stats_df4['smoking_rate'] = [state2smoking_rate[state] / 100 for state in state_stats_df4.index]

In [66]:
state_stats_df4.to_csv('state_stats.csv', index_label=False)

## Bayesian Model

### Motivation

Before describing the model, it's important to first discuss the motivation behind it in the first place. The wILI time series clearly show that the states are affected differently by the flu. Therefore, we wanted to determine whether there are any state-level features that account for the disrepencies between the states. If we could identify these particular features, then we'd also be able to figure out which states are intrinsically linked based on their attributes. 

This information would then allow us to transfer this knowledge about the flu to Covid-19. Because both the flu and Covid are viruses, we'd expect some of the underlying risk factors of flu to generalize to Covid as well. We could then take one of two routes: first, we could assess if the interstate correlations discovered from the flu data apply in the case of Covid by comparing the number of Covid cases among different states. And second, we could assume that the flu relationships apply in the case of Covid and use these insights to look deeper than just the raw Covid numbers. For example, if the flu analysis reveals that two states share many similar characteristics, and one of these states has more Covid cases per 1000 people but also has more testing, then we may believe that the second state has more case of Covid than are reported. Alternatively, we can identify states that, based on their characteristics (e.g. high density, high obesity rate), are more susceptible to a major spike in Covid cases and thus should take additional precautions when opening up their states. 

### Model Formulation

If the state wILI rates are correlated with each other, then we should, in theory, be able to predict the wILI rate in a given state and for a given week from the wILI rates of all the other states for the same week. Because correlated states may have similar flu trajectories but have different raw wILI rates, it's more robust to predict the weekly percent change in wILI rather than the absolute change in wILI. This means that we want to predict the trend in the number of flu cases for each state based on the trends of all the other states at the same time. 

The big question is obviously how to use the percent change in the wILI rate of every other state to predict the percent change in the wILI rate for a single state. Because some states are more closely correlated with a given state than others, it makes sense to predict the percent change for a given state to be a weighted average of the percent changes of the other weeks, where the weights should ideally be proportional to the underlying correlation between the two states. For example, if we were trying to predict the trend in New York, we'd take into account the trend of every other state (except for Alaska and Hawaii), but the influence of each of these states on our overall prediction for New York would vary (e.g. the influence of New Jersey and Connecticut may be high, while the influenced of Idaho and Nebraska may be low).

Converting this into formal notation, let's define $\delta_i$ to be the percent change in the wILI rate between two consecutive weeks for state $i$, and define $\alpha_{ij}$ to be the weight coefficient of state $j$ on state $i$. We predict each $\delta_i$ as:

$$ \delta_i \sim N\left(\frac{\sum_{j=1}^{48}\alpha_{ij}\delta_jI(j \neq i)}{\sum_{j=1}^{48}\alpha_{ij}I(j \neq i)}, {\sigma_{i}}^2\right)$$

where ${\sigma_{i}}^2$ is a state-specific variance. Intuitively, the lower the value of ${\sigma}^2$ for a given state, the more the variation in the state's wILI trend can be explained by the wILI trends of the other states, and vice versa. 

Next, we want to link the $\alpha_{ij}$ weights to the features associated with each state such that states with more similar characteristics and high rates of interstate travel have higher $\alpha_{ij}$ and $\alpha_{ji}$ values and vice versa. Additionally, we only want a few of the $\alpha_{ij}$s corresponding to state $i$ to be large, and the rest to be small (in a similar spirit to regularization). We can accomplish both of these features as follows: first, each $\alpha_{ij}$ is modelled as being distributed according to an exponential distribution with a scale (i.e. inverse rate) parameter of $\lambda_{ij}$. Because an exponential distribution is right skewed and has most of its mass near zero, this ensures that most of the $\alpha_{ij}$ that are drawn from exponential distributions will take on relatively small values, while only a few will take on relatively large values. Next, we link the scale parameter ($\lambda_{ij}$) of this exponential distribution to the state-level features by setting the log of $\lambda_{ij}$ equal to the linear predictor function (taking the log is necessary to map the domain of the scale parameter (all positive real numbers) to the domain of the linear prediction function (the entire real line)). 

Translating this into formal notation:

$$ \alpha_{ij} \sim Expo(\lambda_{ij})$$

$$ log(\lambda_{ij}) = \beta_0 + \beta_1X_1 + ... + \beta_kX_k$$

In this case the linear predictor function is a little different that usual. Two of the predictors (normalized number of commuters between states $i$ and $j$ and the indicator of whether state $j$ borders state $i$) are included in the usual form of $\beta_iX_i$, where a unit increase in $X_i$ corresponds to a $\beta_i$ increase in the linear predictor. However, the rest of the predictors are state-level features such as obesity rate and density. This means that we don't care about the raw values of these features; instead, we only care about the difference between the values for state $i$ and state $j$. Therefore, each of the predictors is defined to be $|X_i - X_j|$, such that the predictor value is 0 when the two states have the same feature value, and increases as the difference between the two states grows. 

Finally, because this is a Bayesian model, we need to define a prior distribution for the model parameters, which in this case are the $\beta$ coefficient associated with each predictor variable and the ${\sigma}^2$ parameter associated with each state. Because we have no substantial prior domain knowledge, we placed uninformative priors on these parameters. Putting all of these components together produces the following generative model:

$$ \delta_i \sim N\left(\frac{\sum_{j=1}^{48}\alpha_{ij}\delta_jI(j \neq i)}{\sum_{j=1}^{48}\alpha_{ij}I(j \neq i)}, {\sigma_{i}}^2\right)$$

$$ \sigma_{i}^{2} \sim Inv-Gamma(2, 2)$$

$$ \alpha_{ij} \sim Expo(\lambda_{ij})$$

$$ log(\lambda_{ij}) = \beta_0 + \beta_1X_1 + ... + \beta_kX_k$$

$$ \beta_i \sim N(0, 5^2) $$

Performing inference for this model yields the posterior distribution of the $\beta$s and the ${\sigma}^2$, but we only really care about the $\beta$s. Because the exponential distribution is parameterized by a scale parameter rather than the usual rate parameter, the expected value of the distribution is equal to the scale parameter. This means that a larger $\lambda_{ij}$ value corresponds, on average, to a higher $\alpha_{ij}$ coefficient, and because the linear predictor function is defined to be the log of $\lambda_{ij}$, this in turn means that a larger linear predictor corresponds, on average, to a higher $\alpha_{ij}$ coefficient. For the two predictors that are not differences between the two given states, this means that a positive $\beta$ coefficent indicates that a unit increase in the predictor value produces a stronger correlation between the two given states and vice versa. On the other hand, for the rest of the predictors that are included as differences between certain features of the two states, a strong correlation between two given states is signified by a negative $\beta$ coefficient. This is the case because the predictor value represents the absolute differences between the features of the states, so a larger predictor value corresponds to a larger discrepancy between the states. Thus, the corresponding $\beta$ coefficient can be interpreted as a penalty parameter, such that states that are less similar in terms of the given feature are less correlated with each other (assuming the $\beta$ coefficient value is negative). 

Overall, the model provides us with two interpretative results. First, the $\beta$ coefficients indicate which features contribute to the correlation between the wILI time series of different states. And second, the $\beta$ coefficients tell us about the $\alpha_{ij}$ weights, which, in turn, inform us about which states are highly correlated with each other based on the fundamental characteristics of the states. 

Finally, one major advantage of this model is that the observations (i.e. the percent change in the wILI rate for a given week) are independent of each other conditioned on the percent changes of the other states for the same week. This means that unlike in a classic time seris model, the past wILI rates of a state are irrelevant to predicting the percent change in the wILI rate at any given time. This greatly simplifies things, as it's much easier to deal with independent observations than it is to handle observations that are correlated with previous observations.

In [67]:
import numpy as np
import pandas as pd
import pymc3 as pm 
pd.set_option('display.max_columns', None)
import matplotlib.pyplot as plt
import theano
theano.config.gcc.cxxflags = "-Wno-c++11-narrowing"

In [68]:
predictor_df = pd.read_csv('state_stats.csv')
predictor_df.drop(index='FL', inplace=True, errors='ignore')
flu_percent_change_df = pd.read_csv('flu_percent_change_imputed_48.csv')
week_nums = flu_percent_change_df.week_num
flu_percent_change_df.drop(columns='week_num', inplace=True)

In [69]:
flu_percent_change_df = flu_percent_change_df[predictor_df.index]

In [70]:
flu_percent_change_df.head()

Unnamed: 0,NY,NJ,PA,IL,MD,MA,VA,CA,RI,MI,TX,MO,MN,CT,GA,OH,CO,DE,IN,UT,KY,NE,TN,OR,LA,NC,OK,WA,KS,WI,NH,AZ,SC,AL,IA,NM,WV,NV,AR,ID,ME,MS,VT,SD,ND,MT,WY
0,0.405867,0.070409,0.294947,0.131695,0.590957,0.468467,0.081422,0.101601,-0.142493,-0.129229,0.013946,-0.155423,-0.058293,2.040475,0.0674,0.584302,-0.04853,2.680899,0.258545,-0.450393,0.567004,0.243452,0.818616,0.131938,-0.102285,0.009597,0.535902,1.039189,0.305868,0.251192,0.217754,0.11148,0.715931,-0.03503,0.529058,-0.078995,-0.004839,-0.018298,-0.031462,2.51373,0.215616,0.020946,-0.080357,-0.280169,0.075379,3.875146,-0.303818
1,-0.020577,-0.0929,-0.032551,0.079386,-0.307034,-0.161128,-0.008961,0.041377,0.800318,-0.103884,0.003238,0.063802,-0.037191,0.571465,0.174786,0.447786,0.279022,-0.085538,-0.105324,0.072137,-0.302815,0.210473,0.000116,-0.223095,0.186536,-0.148031,1.15963,-0.13045,-0.16925,1.042205,2.175391,0.271257,-0.491289,-0.05716,-0.52727,-0.147003,-0.039605,-0.078059,-0.237244,-0.034184,-0.030153,0.003588,0.09165,-0.003491,-0.213978,0.027138,0.002671
2,-0.008671,0.242309,-0.181483,-0.06925,-0.12837,0.052249,0.035554,-0.144643,-0.110708,0.757665,0.050378,-0.030485,0.223643,-0.108553,0.186326,-0.37062,0.131387,0.36635,-0.005582,0.715811,-0.038706,-0.378559,-0.196,0.550529,-0.206277,0.118257,-0.193036,0.056078,0.121565,-0.149577,-0.391046,-0.067719,0.074273,0.1721,-0.549816,0.406076,0.186668,0.421523,-0.195573,-0.09095,-0.197594,0.05589,-0.025826,0.013006,-0.402198,-0.804222,0.109795
3,0.469039,-0.345198,0.374768,0.021734,0.146858,0.03662,-0.002797,0.315925,0.643313,-0.062965,0.07127,0.216697,0.263575,0.189484,0.112533,1.524287,0.068209,-0.517188,0.449453,-0.137704,0.929797,-0.816918,0.314856,0.017457,0.456719,0.271463,-0.025306,-0.176671,-0.105529,0.085107,0.646422,0.334074,0.064433,0.244766,0.958139,0.010552,0.015538,-0.175751,1.635755,0.427453,-0.478527,0.065965,-0.132396,-0.026804,1.103971,0.113451,0.155681
4,-0.042857,0.204689,0.04268,0.04339,-0.022196,-0.097843,0.115514,0.088124,-0.099674,-0.06576,0.065136,0.04801,0.017769,-0.194537,0.284966,-0.377241,-0.260048,2.574204,-0.107092,0.018472,-0.230197,0.084329,0.124235,0.01361,-0.257559,-0.266281,0.58097,0.357173,0.506357,-0.057146,0.034113,-0.202738,0.14774,0.221805,0.166899,-0.000937,0.124406,-0.544448,0.098213,-0.182187,1.020853,0.171425,-0.100346,-0.046688,0.062152,0.084329,-0.182304


In [71]:
predictor_df.head()

Unnamed: 0,density_metric,Latitude,is_coastal,airport_boardings,Children 0-18,Adults 19-25,Adults 26-34,Adults 35-54,Adults 55-64,65+,spring,summer,fall,winter,partisan_score,hispanic,minority,female,unemployed,income,nodegree,bachelor,inactivity,obesity,cancer,NY_dest,NJ_dest,PA_dest,IL_dest,MD_dest,MA_dest,VA_dest,CA_dest,RI_dest,MI_dest,TX_dest,MO_dest,MN_dest,CT_dest,GA_dest,OH_dest,CO_dest,DE_dest,FL_dest,IN_dest,UT_dest,KY_dest,NE_dest,TN_dest,OR_dest,LA_dest,NC_dest,OK_dest,WA_dest,KS_dest,WI_dest,NH_dest,AZ_dest,SC_dest,AL_dest,IA_dest,NM_dest,WV_dest,NV_dest,AR_dest,ID_dest,ME_dest,MS_dest,VT_dest,SD_dest,ND_dest,MT_dest,WY_dest,NY_is_neighbor,NJ_is_neighbor,PA_is_neighbor,IL_is_neighbor,MD_is_neighbor,MA_is_neighbor,VA_is_neighbor,CA_is_neighbor,RI_is_neighbor,MI_is_neighbor,TX_is_neighbor,MO_is_neighbor,MN_is_neighbor,CT_is_neighbor,GA_is_neighbor,OH_is_neighbor,CO_is_neighbor,DE_is_neighbor,FL_is_neighbor,IN_is_neighbor,UT_is_neighbor,KY_is_neighbor,NE_is_neighbor,TN_is_neighbor,OR_is_neighbor,LA_is_neighbor,NC_is_neighbor,OK_is_neighbor,WA_is_neighbor,KS_is_neighbor,WI_is_neighbor,NH_is_neighbor,AZ_is_neighbor,SC_is_neighbor,AL_is_neighbor,IA_is_neighbor,NM_is_neighbor,WV_is_neighbor,NV_is_neighbor,AR_is_neighbor,ID_is_neighbor,ME_is_neighbor,MS_is_neighbor,VT_is_neighbor,SD_is_neighbor,ND_is_neighbor,MT_is_neighbor,WY_is_neighbor,overall_vacc_rate,child_vacc_rate,smoking_rate
NY,10711.4,40.705626,1,2.625045,0.22,0.09,0.13,0.26,0.14,0.16,43.6,66.5,48.1,23.3,-12,18.737,29.756,51.436,5.347,61.622797,14.719,33.425,24.046,24.675,200.165,0.911538,0.061864,0.004793,0.000216,0.000722,0.001985,0.000495,0.000199,0.001692,0.000196,0.000148,0.000147,0.000207,0.033303,0.000321,0.000261,0.000237,0.002078,0.00057,0.000153,0.000171,0.000122,0.000156,0.000232,9.8e-05,0.000264,0.000368,8.9e-05,0.000126,0.000117,0.000155,0.001047,0.000143,0.000251,0.000122,0.000119,8.8e-05,0.000168,0.000233,9.9e-05,8e-05,0.000857,0.000131,0.012476,8.1e-05,0.0,0.000105,0.000209,1,1,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,81.7,69.6,0.141
NJ,2789.6,40.143006,1,2.33307,0.23,0.08,0.11,0.27,0.14,0.16,50.6,72.2,54.8,33.0,-7,19.372,27.143,51.199,5.676,72.829584,11.59,36.24,22.909,25.672,200.319,0.028068,0.823014,0.019686,0.000128,0.000793,0.000447,0.000315,7.2e-05,0.000472,8.9e-05,9.9e-05,0.000112,7.9e-05,0.001903,0.000142,0.000171,0.000154,0.016296,0.000245,8.1e-05,6.7e-05,9.1e-05,6.5e-05,9.8e-05,5.2e-05,8e-05,0.000205,5.9e-05,6.7e-05,6e-05,6.9e-05,0.000368,0.000132,0.000183,5.1e-05,1.5e-05,2.5e-05,0.000137,8.1e-05,4.5e-05,4.1e-05,0.000325,3.4e-05,0.000566,0.0,0.0,0.000114,0.0,1,1,1,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,79.4,72.8,0.137
PA,1957.6,40.994593,1,1.621257,0.22,0.08,0.12,0.25,0.14,0.18,47.4,68.6,50.9,28.4,0,6.767,17.371,51.067,5.199,54.9607,11.11,28.104,23.155,29.331,230.835,0.003142,0.028442,0.877457,0.000199,0.012778,0.000454,0.000731,7.9e-05,0.000352,0.000216,0.000143,0.000188,0.000158,0.000587,0.000222,0.002901,0.000164,0.069599,0.000248,0.000159,0.000125,0.000229,5.8e-05,0.000202,5.3e-05,0.00018,0.000314,0.000203,5.8e-05,0.000106,0.000117,0.000366,0.000156,0.00025,0.000151,9.1e-05,0.0001,0.011997,0.000114,0.000207,6.5e-05,0.000299,0.000101,0.000294,8e-05,0.000137,0.000137,0.000232,1,1,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,82.5,69.7,0.187
IL,1761.9,39.739318,0,4.102417,0.24,0.09,0.12,0.26,0.13,0.15,51.6,73.4,53.8,28.3,-7,16.918,22.729,50.902,5.963,59.266824,12.397,32.078,21.199,27.34,225.471,0.000143,0.000187,0.000201,0.900515,0.00015,0.000173,0.000197,0.000141,0.000231,0.000784,0.000194,0.018036,0.000492,0.00021,0.000285,0.000342,0.000297,0.000204,0.000304,0.015667,0.000116,0.001515,0.000491,0.000354,9.2e-05,0.000159,0.000204,0.000159,0.000138,0.000288,0.011498,0.000137,0.000223,0.000114,0.000131,0.018256,7.1e-05,0.000108,0.000336,0.000304,0.00019,9.6e-05,0.000181,0.000142,0.000141,0.000346,0.000173,0.000213,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,83.0,60.1,0.155
MD,1737.6,38.806352,1,2.316047,0.23,0.08,0.12,0.27,0.14,0.15,52.8,73.3,56.1,34.7,-12,9.407,40.312,51.537,5.266,75.296062,11.089,37.129,21.545,28.94,225.486,0.000215,0.000521,0.005808,6.8e-05,0.845245,0.000221,0.024575,3.8e-05,0.000227,6.7e-05,7.6e-05,9.7e-05,7.7e-05,0.00014,0.000167,0.000103,6.1e-05,0.045597,0.000184,7.6e-05,4.1e-05,7.4e-05,4.3e-05,8.7e-05,3.7e-05,6.8e-05,0.000281,1.5e-05,7.6e-05,4.5e-05,5.4e-05,0.000109,2.5e-05,0.000143,6.5e-05,2.3e-05,1.6e-05,0.015275,5e-05,4.2e-05,0.0,0.00014,5.9e-05,8.3e-05,0.0,0.000113,1.1e-05,0.0,0,1,1,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,81.9,74.5,0.138


In [72]:
# predictors that are compared between states
comparison_predictors = ['density_metric', 'Latitude', 'is_coastal', 'airport_boardings', 'Children 0-18', 
                          'Adults 19-25', 'Adults 26-34', 'Adults 35-54', 'Adults 55-64', '65+', 
                         'partisan_score', 'hispanic', 'minority', 'female', 
                         'income', 'nodegree', 'bachelor', 'inactivity', 'obesity', 'cancer',
                         'overall_vacc_rate', 'child_vacc_rate', 'smoking_rate']
season_predictors = ['spring', 'fall', 'winter']

# predictors that are not compared between states
no_comparison_predictors = ['commuters', 'is_neighbor']

An important preprocessing step is to standardize each of the predictors (except for `is_coastal` and `is_neighbor` as these variables only take on the values 0 and 1. This ensures that the $\beta$ coefficients associated with each predictor are all on the same scale and thus are easily comparable to each other. Additionally, ensuring the the $\beta$ parameters lie in a similar range may help with the MCMC sampling. 

In [73]:
predictors_to_standardize = [x for x in comparison_predictors if x != 'is_coastal'] + season_predictors

# there are no observations during the summer so we don't need the summer weather predictor
predictor_df_standardized = predictor_df.drop(columns='summer')
for predictor in predictors_to_standardize:
    data = predictor_df_standardized[predictor]
    mean = np.mean(data)
    std = np.std(data)
    predictor_df_standardized[predictor] = [(x - mean) / std for x in data]

commute_columns = [column for column in predictor_df_standardized if column.endswith('_dest')]
commute_vals = predictor_df_standardized[commute_columns].to_numpy().flatten()
commute_mean = np.mean(commute_vals)
commute_std = np.std(commute_vals)

for commute_column in commute_columns:
    predictor_df_standardized[commute_column] = [(x - commute_mean) / commute_std 
                                                 for x in predictor_df_standardized[commute_column]]
    
comparison_preds_df = predictor_df_standardized[comparison_predictors + season_predictors]

In [74]:
predictor_df_standardized.head()

Unnamed: 0,density_metric,Latitude,is_coastal,airport_boardings,Children 0-18,Adults 19-25,Adults 26-34,Adults 35-54,Adults 55-64,65+,spring,fall,winter,partisan_score,hispanic,minority,female,unemployed,income,nodegree,bachelor,inactivity,obesity,cancer,NY_dest,NJ_dest,PA_dest,IL_dest,MD_dest,MA_dest,VA_dest,CA_dest,RI_dest,MI_dest,TX_dest,MO_dest,MN_dest,CT_dest,GA_dest,OH_dest,CO_dest,DE_dest,FL_dest,IN_dest,UT_dest,KY_dest,NE_dest,TN_dest,OR_dest,LA_dest,NC_dest,OK_dest,WA_dest,KS_dest,WI_dest,NH_dest,AZ_dest,SC_dest,AL_dest,IA_dest,NM_dest,WV_dest,NV_dest,AR_dest,ID_dest,ME_dest,MS_dest,VT_dest,SD_dest,ND_dest,MT_dest,WY_dest,NY_is_neighbor,NJ_is_neighbor,PA_is_neighbor,IL_is_neighbor,MD_is_neighbor,MA_is_neighbor,VA_is_neighbor,CA_is_neighbor,RI_is_neighbor,MI_is_neighbor,TX_is_neighbor,MO_is_neighbor,MN_is_neighbor,CT_is_neighbor,GA_is_neighbor,OH_is_neighbor,CO_is_neighbor,DE_is_neighbor,FL_is_neighbor,IN_is_neighbor,UT_is_neighbor,KY_is_neighbor,NE_is_neighbor,TN_is_neighbor,OR_is_neighbor,LA_is_neighbor,NC_is_neighbor,OK_is_neighbor,WA_is_neighbor,KS_is_neighbor,WI_is_neighbor,NH_is_neighbor,AZ_is_neighbor,SC_is_neighbor,AL_is_neighbor,IA_is_neighbor,NM_is_neighbor,WV_is_neighbor,NV_is_neighbor,AR_is_neighbor,ID_is_neighbor,ME_is_neighbor,MS_is_neighbor,VT_is_neighbor,SD_is_neighbor,ND_is_neighbor,MT_is_neighbor,WY_is_neighbor,overall_vacc_rate,child_vacc_rate,smoking_rate
NY,6.293909,0.28086,1,0.380592,-0.797652,0.391965,1.00752,0.711035,0.420201,-0.288009,-0.982757,-0.737546,-0.833857,-1.542134,0.716146,1.089934,1.176226,5.347,0.830535,0.778156,0.980192,0.350434,-1.211948,-0.534187,6.950922,0.329633,-0.115107,-0.150772,-0.146829,-0.136991,-0.1486,-0.15091,-0.139272,-0.150927,-0.151303,-0.151309,-0.150841,0.107066,-0.149959,-0.150421,-0.150608,-0.136264,-0.148012,-0.151267,-0.151124,-0.151508,-0.151242,-0.150648,-0.151697,-0.1504,-0.149592,-0.151764,-0.151474,-0.151543,-0.151246,-0.144295,-0.151343,-0.150498,-0.151509,-0.151529,-0.151769,-0.15115,-0.150643,-0.151683,-0.151831,-0.145776,-0.151433,-0.055232,-0.151826,-0.152458,-0.151639,-0.150826,1,1,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,-0.099426,0.987184,-0.945082
NJ,1.24809,0.150602,1,0.208113,-0.334817,-1.025139,-1.302404,1.723722,0.420201,-0.288009,-0.023899,0.235039,0.148711,-1.05764,0.77793,0.818879,0.823605,5.676,2.136029,-0.203394,1.553998,0.06693,-0.920232,-0.529721,0.066268,6.261077,0.000948,-0.151457,-0.146279,-0.148978,-0.149999,-0.151899,-0.148778,-0.151761,-0.151685,-0.151582,-0.151839,-0.137631,-0.151352,-0.151129,-0.151254,-0.025465,-0.150552,-0.151824,-0.151936,-0.151748,-0.15195,-0.151691,-0.152049,-0.151834,-0.150863,-0.151996,-0.151935,-0.151988,-0.151923,-0.149593,-0.151427,-0.151032,-0.152061,-0.152337,-0.152265,-0.151389,-0.151826,-0.152105,-0.152139,-0.149923,-0.152192,-0.148049,-0.152458,-0.152458,-0.151568,-0.152458,1,1,1,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,-0.703241,1.429462,-1.060531
PA,0.718145,0.347762,1,-0.212377,-0.797652,-1.025139,-0.147442,-0.301651,0.420201,0.840027,-0.462234,-0.331092,-0.317249,-0.379348,-0.448517,-0.194802,0.627209,5.199,0.054458,-0.353967,-0.104435,0.128269,0.150368,0.355264,-0.127974,0.06918,6.685338,-0.150909,-0.052883,-0.148923,-0.146758,-0.151846,-0.149718,-0.150777,-0.15134,-0.15099,-0.151226,-0.147881,-0.150724,-0.129847,-0.151179,0.389907,-0.150527,-0.151221,-0.151485,-0.150676,-0.152005,-0.150882,-0.152043,-0.151058,-0.150014,-0.150873,-0.152008,-0.15163,-0.151549,-0.149604,-0.151243,-0.15051,-0.151277,-0.151747,-0.151678,-0.058965,-0.151569,-0.150841,-0.151951,-0.150129,-0.151667,-0.150166,-0.151836,-0.151392,-0.151386,-0.150646,1,1,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0.110597,1.001005,0.382577
IL,0.593493,0.05714,0,1.253322,0.128018,0.391965,-0.147442,0.711035,-0.43847,-0.852027,0.113081,0.089877,-0.327379,-1.05764,0.53916,0.361001,0.381713,5.963,0.556084,0.049757,0.705621,-0.359447,-0.432186,0.199704,-0.151341,-0.150998,-0.150894,6.865024,-0.151287,-0.151107,-0.150923,-0.151362,-0.150659,-0.146348,-0.150947,-0.011908,-0.148624,-0.150818,-0.150233,-0.149795,-0.150146,-0.150869,-0.150091,-0.03037,-0.151555,-0.140648,-0.148635,-0.149702,-0.15174,-0.151222,-0.150869,-0.151215,-0.151386,-0.150213,-0.062859,-0.151392,-0.150718,-0.151566,-0.151437,-0.010197,-0.151901,-0.151617,-0.14984,-0.150092,-0.150976,-0.151713,-0.151045,-0.151349,-0.151357,-0.149758,-0.151111,-0.150799,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0.241861,-0.325827,-0.541012
MD,0.578015,-0.158861,1,0.198057,-0.334817,-1.025139,-0.147442,1.723722,0.420201,-0.852027,0.277457,0.423749,0.320913,-1.542134,-0.191649,2.184942,1.326499,5.266,2.423352,-0.360555,1.735211,-0.273174,0.035964,0.200139,-0.150781,-0.1484,-0.107198,-0.151931,6.434313,-0.150735,0.039047,-0.152158,-0.150688,-0.151936,-0.151867,-0.151699,-0.151854,-0.151363,-0.151153,-0.151658,-0.151982,0.202867,-0.151027,-0.151863,-0.152142,-0.151885,-0.152121,-0.151782,-0.152171,-0.15193,-0.15027,-0.152337,-0.151868,-0.152108,-0.15204,-0.151605,-0.152263,-0.151342,-0.151954,-0.152281,-0.152333,-0.033421,-0.152068,-0.152132,-0.152458,-0.151366,-0.151998,-0.15181,-0.152458,-0.151577,-0.152371,-0.152458,0,1,1,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,-0.04692,1.664422,-1.031669


In [75]:
# determine season from week of the year
def get_season(week):
    if week >= 52 or week < 13:
        return np.array([0, 0, 1])
    if 13 <= week < 26:
        return np.array([1, 0, 0])
    if 39 <= week < 52:
        return np.array([0, 1, 0])
    raise ValueError

In [76]:
predictor_num = len(comparison_predictors) + len(season_predictors) + len(no_comparison_predictors)
state_num = flu_percent_change_df.shape[1]
comparison_preds_num = len(comparison_predictors)
obs_num = len(flu_percent_change_df)

In [77]:
# indicate which season each observation fall into 
season_indictor_array = np.zeros((obs_num, state_num - 1, len(season_predictors)))
for i, week_num in enumerate(week_nums[1:]):
    season_indictor_array[i, :, :] = np.repeat(get_season(week_num)[np.newaxis, :], state_num - 1, axis=0)

`Y_target` is a 1D array that contains the percent change of each state for each week of the time series that is included in the analysis. This is the variable we want to predict for each observation. Because there are 47 states (Lower 48 except for Florida) and 217 observations for each state, this array has a length of $47*217=10199$. 

`Y_state_idx` is a 1D array of the same length as `Y_target` that represents the specific state associated with each `Y_target` value. Therefore, it takes on values between 0 and 46. This is necessary to pick out the variance parameter corresponding to the given state. 

`X` is a 3D design matrix. The first axis has a length equal to the total number of observations (10199). The second axis has a length of 46, which represents the $47-1=46$ other states from which we're trying to predict the final state. And the first axis has a length of 29, which contain the 28 predictors in addition to an intercept term, which is simply the value of 1. Therefore, this `X` matrix contains all the predictors for each state for each observation.

`X_flu` is a 2D array. The first axis has a length equal to the total number of observations (10199), while the second axis has a length of 46 and represents the percent change in wILI rate for all the $47-1=46$ other states from which we're trying to predict the final state. Therefore, this array is contains all the $\delta_jI(j \neq i)$ values for each observation.


In [78]:
Y_target = np.zeros(state_num * obs_num)
X = np.zeros((Y_target.shape[0], state_num - 1, predictor_num + 1))
Y_state_idx = np.zeros(Y_target.shape[0], dtype=int)
X_flu = np.zeros((Y_target.shape[0], state_num - 1))
X.shape

(10199, 46, 29)

In [79]:
for idx, state in enumerate(predictor_df_standardized.index):
    
    # response variable
    Y_target[obs_num * idx: obs_num * idx + obs_num] = flu_percent_change_df[state]
    
    # percent change of other states
    X_flu[obs_num * idx: obs_num * idx + obs_num, :] = flu_percent_change_df.drop(columns=state).to_numpy()
    
    # index of response state
    Y_state_idx[obs_num * idx: obs_num * idx + obs_num] = [idx] * obs_num
    
    state_comparison_preds = np.array(comparison_preds_df.loc[state])
    
    constant_design_matrix = np.zeros((X.shape[1], X.shape[2]))
    constant_design_matrix[:, 0] = np.ones(state_num - 1)
    
    # two predictors that aren't differences between states: neighboring state and number of commuters
    other_states_preds_df = predictor_df_standardized.drop(index=state)
    not_difference_matrix = other_states_preds_df[[state + '_is_neighbor', state + '_dest']].to_numpy()
    constant_design_matrix[:, 1: 1 + len(no_comparison_predictors)] = not_difference_matrix
    
    # the rest of the predictors are differences between two states
    other_states_comparison_preds_array = comparison_preds_df.drop(index=state).to_numpy()
    difference_matrix = abs((other_states_comparison_preds_array - state_comparison_preds) ** 1)
    constant_design_matrix[:, 1 + len(no_comparison_predictors):] = difference_matrix
    
    constant_design_matrix_3D = np.repeat(constant_design_matrix[np.newaxis, :, :], repeats=obs_num, axis=0)
    
    # pick out appropriate season and set the rest of the temperature predictors to zero
    constant_design_matrix_3D[:, :, -len(season_predictors):] *= season_indictor_array 
    
    X[obs_num * idx: obs_num * idx + obs_num, :, :] = constant_design_matrix_3D 

In [80]:
# randomly shuffle the observations 
np.random.seed(109)
indices = np.arange(len(Y_target))
np.random.shuffle(indices)
Y_target_random = Y_target[indices]
X_flu_random = X_flu[indices]
X_random = X[indices]
Y_state_idx_random = Y_state_idx[indices]

In [81]:
model = pm.Model()

with model:
    # define prior distribution for beta parameters 
    beta = pm.Normal('beta', mu=0, sigma=5, shape=predictor_num + 1)
    
    # define prior distribution for state-specific variance parameter
    sigma_sq = pm.InverseGamma('sigma_sq', alpha=2, beta=2, shape=state_num)
    
    # calculate the linear predictor for each state by multipling the 3D X design matrix with the vector
    # of beta parameters
    nu = pm.Deterministic('nu', pm.math.dot(X_random, beta))
    
    # calculate the lambda parameter for each state by exponentiating the linear predictor
    lambda_ = pm.Deterministic('lambda', pm.math.exp(nu))
    
    # sample an alpha random variable for each state from an exponential distribution with the 
    # corresponding rate parameter
    alpha = pm.Exponential('alpha', lam=1/lambda_, shape=(X_random.shape[0], state_num - 1))
    
    # calculate the mean of each response variable by taking the dot product between the alpha vector
    # and the vector of the percent change in the wILI rates of the other 46 states and dividing by the 
    # sum of the alpha weights
    mu = pm.Deterministic('mu', pm.math.sum(alpha * X_flu_random, axis=1) / pm.math.sum(alpha, axis=1))
    
    # define the response variable to be normally distributed about the mean and with a standard deviation that
    # is the square root of the variance parameter associated with the given state
    Y_obs = pm.Normal('Y_obs', mu=mu, sigma=pm.math.sqrt(sigma_sq[Y_state_idx_random]), observed=Y_target_random)

Just as we did in HW3, it's important to check whether the generative model is correctly specified. This can be done by hardcoding the  values for the parameters, generating response variables from these parameters and then trying to infer the parameters using MCMC.

In [82]:
# hardcode values for the beta and sigma parameters
np.random.seed(109)
betas = np.random.normal(0, 0.5, predictor_num + 1)
true_sigmas = abs(np.random.normal(0, 0.05, state_num))

In [83]:
# normalize the beta parameters so that they all have a similar effect on the linear predictor
normalize = ([1, 0.5, np.mean(list(predictor_df['NY_dest'])[1:])] + 
             [np.mean(predictor_df[pred]) for pred in comparison_predictors + season_predictors])

normalize_array = np.array(normalize)
true_betas = betas / (normalize_array)

In [84]:
# simulate Y_target values from these hardcoded parameters
np.random.seed(209)
betas_vec = true_betas.reshape(-1, 1)
Y_target_sim = np.zeros(X_random.shape[0])

for i in range(X_random.shape[0]):
    design_mat = X_random[i, :, :]
    nu = (design_mat @ betas_vec).flatten()
    lambda_ = np.exp(nu)
    alpha = np.random.exponential(scale=lambda_)
    mu = np.dot(alpha, X_flu_random[i, :]) / np.sum(alpha)
    Y_target_sim[i] = np.random.normal(mu, true_sigmas[Y_state_idx_random[i]])

In [85]:
# infer hardcode parameters
sim_model = pm.Model()

with sim_model:
    beta = pm.Normal('beta', mu=0, sigma=100, shape=predictor_num + 1)
    sigma_sq = pm.InverseGamma('sigma_sq', alpha=0.1, beta=0.1, shape=state_num)
    nu = pm.Deterministic('nu', pm.math.dot(X_random, beta))
    lambda_ = pm.Deterministic('lambda', pm.math.exp(nu))
    alpha = pm.Exponential('alpha', lam=1/lambda_, shape=(X_random.shape[0], state_num - 1))
    mu = pm.Deterministic('mu', pm.math.sum(alpha * X_flu_random, axis=1) / pm.math.sum(alpha, axis=1))
    
    Y_obs = pm.Normal('Y_obs', mu=mu, sigma=pm.math.sqrt(sigma_sq[Y_state_idx_random]), observed=Y_target_sim)
    
    trace_sim = pm.sample(500, tune=500, cores=4, init='adapt_diag')

In [915]:
with sim_model:
    pm.save_trace(trace_sim, 'sim.trace') 

In [86]:
with sim_model:
    trace_sim = pm.load_trace('sim.trace') 

In [None]:
sim_trace_df = pm.summary(trace_sim)

In [927]:
param_df = sim_trace_df.head(76).copy()

In [120]:
sim_trace_betas_df = pm.summary(trace_sim, var_names=[f'beta[{i}]' for i in range(predictor_num + 1)])
sim_trace_betas_df

KeyError: "['beta[0]' 'beta[1]' 'beta[2]' 'beta[3]' 'beta[4]' 'beta[5]' 'beta[6]'\n 'beta[7]' 'beta[8]' 'beta[9]' 'beta[10]' 'beta[11]' 'beta[12]' 'beta[13]'\n 'beta[14]' 'beta[15]' 'beta[16]' 'beta[17]' 'beta[18]' 'beta[19]'\n 'beta[20]' 'beta[21]' 'beta[22]' 'beta[23]' 'beta[24]' 'beta[25]'\n 'beta[26]' 'beta[27]' 'beta[28]'] var names are not present in dataset"

In [122]:
[f'beta[{i}]' for i in range(predictor_num + 1)]

['beta[0]',
 'beta[1]',
 'beta[2]',
 'beta[3]',
 'beta[4]',
 'beta[5]',
 'beta[6]',
 'beta[7]',
 'beta[8]',
 'beta[9]',
 'beta[10]',
 'beta[11]',
 'beta[12]',
 'beta[13]',
 'beta[14]',
 'beta[15]',
 'beta[16]',
 'beta[17]',
 'beta[18]',
 'beta[19]',
 'beta[20]',
 'beta[21]',
 'beta[22]',
 'beta[23]',
 'beta[24]',
 'beta[25]',
 'beta[26]',
 'beta[27]',
 'beta[28]']

In [None]:
sim_trace_betas_df.to_csv('sim_trace.csv', index_label=False)

In [938]:
param_df['true'] = ([format(beta, 'f') for beta in list(true_betas)] + 
                    [format(sigma ** 2, 'f') for sigma in list(true_sigmas)])

The sampling took a whopping 13 hours to sample just 500 times for each chain (with a 500 burn-in sample). However, the results confirm that the model was correctly specified, as the majority of the true $\beta$ values lie within the corresponding 94% credible interval. Therefore, performance inference for the actual data should yield reliable results.

However, carrying out inference on this synthetic data reveals several issues. First, many of the r_hat values are significantly larger than 1.0, which means that more than 500 samples are needed for the chains to converge to the posterior distribution. And second, the fact that the sampling took so long may indicate that the uninformative priors are too flat and make it difficult for the NUTS sampler to sample points from the true posterior distribution. To address these issues, the number of samples is increased from 500 to 1000 and a semi-informative prior is placed on the $\beta$ and $\sigma^2$ parameters ($N(0, 25)$ for each of the $\beta$s and $Inv-Gamma(2, 2)$ for each $\sigma^2$. 

In [951]:
param_df.iloc[:30]

Unnamed: 0,mean,sd,hpd_3%,hpd_97%,mcse_mean,mcse_sd,ess_mean,ess_sd,ess_bulk,ess_tail,r_hat,true
beta[0],0.666,0.091,0.457,0.82,0.034,0.025,7.0,7.0,8.0,16.0,1.43,-0.093867
beta[1],2.012,0.118,1.76,2.167,0.04,0.03,8.0,8.0,9.0,30.0,1.38,2.440637
beta[2],-255.124,53.736,-339.328,-158.694,22.04,17.085,6.0,6.0,6.0,12.0,1.91,-523.855855
beta[3],0.0,0.0,0.0,0.0,0.0,0.0,9.0,9.0,9.0,54.0,1.36,0.00037
beta[4],0.018,0.008,0.003,0.034,0.002,0.002,15.0,15.0,15.0,53.0,1.23,0.009313
beta[5],1.302,0.047,1.218,1.387,0.009,0.006,28.0,28.0,28.0,64.0,1.11,1.344916
beta[6],0.08,0.011,0.063,0.102,0.002,0.001,43.0,43.0,43.0,78.0,1.07,0.065428
beta[7],-1.056,1.388,-3.834,1.427,0.213,0.159,43.0,39.0,39.0,125.0,1.1,-0.156122
beta[8],-6.395,2.455,-10.656,-1.572,0.463,0.331,28.0,28.0,29.0,66.0,1.13,-1.018456
beta[9],-8.997,2.452,-13.586,-4.507,0.635,0.458,15.0,15.0,15.0,86.0,1.19,-3.457978


In [None]:
pm.traceplot(trace_sim, var_names=[f'beta[{i}]' for i in range(29)])

Unfortunately I ran into major issues running MCMC for the actual data. A burn-in of 500 and a sample of 1000 should have taken around 18 hours to finish. However, the first time I ran it, it was 80% complete after 14 hours and then my screen saver didn't turn off and the notebook shut down. I then tried running in a second time, and this time it again was 80% done after another 14 hours and then encountered a memory failure issue that terminated the notebook. Therefore, the third time I only asked for 500 samples, even though I knew this likely wouldn't be large enough for the sampler to converge. It took 14 hours to run but finished successfully. Even so, the model was so unwieldy that it took an additional three hours just to save the model and create a summary dataframe. 

In [17]:
with model:
    trace = pm.sample(500, tune=500, cores=4, init='adapt_diag')

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, sigma_sq, beta]
Sampling 4 chains, 0 divergences: 100%|██████████| 4000/4000 [13:50:34<00:00, 12.46s/draws]  
The acceptance probability does not match the target. It is 0.6705019417755961, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.6168183334333946, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.6555152878868541, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.


In [18]:
with model:
    pm.save_trace(trace, 'real_trace') 

In [None]:
with model:
    trace = pm.load_trace('real_trace')

Unfortunately, most of the r_hat values of the $\beta$ coefficients are extremely inflated (the average r_hat value is just under 2.0). This means that the sampler hasn't come close to converging and means that it's pointless to try to interpret the sign or the magnitude of the coefficients. At this point, we ran out of time. However, if we had more time, we'd randomly select a subset of the observations and get more samples for these observations, as it's better to have trustworthy results on less data than it is to have unreliable results on the entire datset. 

In [19]:
trace_df = pm.summary(trace)

In [110]:
trace_betas_df = trace_df.loc[[f'beta[{i}]' for i in range(predictor_num + 1)]]

In [119]:
trace_betas_df

Unnamed: 0,mean,sd,hpd_3%,hpd_97%,mcse_mean,mcse_sd,ess_mean,ess_sd,ess_bulk,ess_tail,r_hat
beta[0],2.409,0.157,2.065,2.649,0.056,0.042,8.0,8.0,8.0,30.0,1.47
beta[1],0.465,0.165,0.152,0.798,0.061,0.045,7.0,7.0,8.0,25.0,1.5
beta[2],-1.239,1.28,-4.177,0.681,0.571,0.435,5.0,5.0,6.0,11.0,1.74
beta[3],-0.027,0.039,-0.08,0.059,0.018,0.013,5.0,5.0,5.0,11.0,2.23
beta[4],-0.299,0.078,-0.447,-0.16,0.026,0.019,9.0,9.0,9.0,29.0,1.35
beta[5],-0.271,0.07,-0.392,-0.143,0.029,0.022,6.0,6.0,6.0,22.0,1.88
beta[6],-0.031,0.059,-0.153,0.078,0.027,0.021,5.0,5.0,5.0,12.0,2.57
beta[7],0.094,0.179,-0.259,0.397,0.085,0.065,4.0,4.0,5.0,13.0,2.7
beta[8],-0.083,0.078,-0.255,0.016,0.037,0.028,5.0,5.0,6.0,16.0,1.92
beta[9],0.008,0.058,-0.076,0.161,0.026,0.02,5.0,5.0,5.0,11.0,2.22


In [28]:
print('mean r_hat value:', 
      round(np.mean(trace_df.loc[[f'beta[{i}]' for i in range(predictor_num + 1)]]['r_hat']), 3))

mean r_hat value: 1.975


While the results of the inference were unreliable, it's still worthwhile to discuss what the next steps would have been in the analysis. First, we would check the sign and 94% credible interval of each of the $\beta$ coefficients to see if the majority of them make intuitive sense (i.e. negative coefficients for the difference predictors and positive coefficients for the non-difference predictors.) Next, we would evaluate the predictive power of the model and test the model assumptions at the same time. This could be done by first calculating the predictive power of a baseline naive model where the average of all the other states is used to predict for the percent change in the final state (in other words, where the weights associated with each state are the same). Because the likelihood function is modelled as a normal distribution, the optimal loss function is the mean squared error. The predictions would be performed for each state separately. 

After calculating the MSE for the naive model, we'd evaluate the Bayesian model as follows: first, we'd sample hundreds of times from the posterior distribution of each of the $\beta$ coefficients. Then, for each sample, we'd work our way up the model (i.e. sample an $\alpha$ for each state) and calculate the mean of the prediction. We'd then plot the residuals by subtracting the predicted percent change from the true percent change. Calculating the average of the square of the residuals would give us the MSE, which we'd compare to the baseline model to see if this model has any increased predictive power. Meanwhile, we'd plot these residuals to assess the assumption that the observations are normally distributed about the weighted average of the percent change of each of the other states. If this is the case, then we'd expect the distribution to being normally distributed around 0.0. Finally, we can calculate the variance of the residuals for each state and compare this sample variance to the posterior distribution of $\sigma^2$ for each state to check if they are consistent with each other. 

In [None]:
# code to assess predictions of naive model where all states are weighed equally
for idx, state in enumerate(flu_percent_change_df.columns):
    Y_target_state = Y_target[obs_num * idx: obs_num * idx + obs_num]
    X_flu_state = X_flu[obs_num * idx: obs_num * idx + obs_num, :]
    X_flu_state_mean = np.apply_along_axis(np.mean, 0, X_flu_state)
    residuals = Y_target_state - X_flu_state_mean
    plt.figure(figsize=(20, 10))
    plt.hist(residuals, bins=20, density=True)
    plt.xlabel('residuals')
    plt.ylabel('density')
    plt.main(f'{state} -- MSE: {np.mean(residuals ** 2)}, STD: {np.std(residuals)}')
    plt.show()

In [None]:
# code to assess predictions of Bayesian model
Y_pred_mean = np.zeros(X.shape[0])

for idx, state in enumerate(flu_percent_change_df.columns):
    Y_target_state = Y_target[obs_num * idx: obs_num * idx + obs_num]
    design_mat_state = X[obs_num * idx: obs_num * idx + obs_num, :, :]
    residuals = np.array(nsim * obs_num)
    for i in range(len(beta_sample)):
        nu = design_mat_state @ betas_vec
        lambda_ = np.exp(nu)
        alpha = np.random.exponential(scale=lambda_)
        mu = np.sum(alpha * X_flu[obs_num * idx: obs_num * idx + obs_num, :], axis=1) / np.sum(alpha, axis=1)
        residuals[obs_num * i: obs_num * i + obs_num] = Y_target_state - mu
    
    plt.figure(figsize=(20, 10))
    plt.hist(residuals, bins=20, density=True)
    plt.xlabel('residuals')
    plt.ylabel('density')
    plt.main(f'{state} -- MSE: {np.mean(residuals ** 2)}, STD: {np.std(residuals)}')
    plt.show()