In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import jax.numpy as jnp

In [None]:
# Lookup from two-letter postal abbreviation to full name, covering the
# 50 US states plus the District of Columbia (51 entries).
# It is passed to DataFrame.replace on the "location" column further down,
# so codes that are not keys here are left unchanged.
# Thanks: https://gist.github.com/JeffPaine/3083347
abv_to_full = {
    'AK': 'Alaska',
    'AL': 'Alabama',
    'AR': 'Arkansas',
    'AZ': 'Arizona',
    'CA': 'California',
    'CO': 'Colorado',
    'CT': 'Connecticut',
    'DC': 'District of Columbia',
    'DE': 'Delaware',
    'FL': 'Florida',
    'GA': 'Georgia',
    'HI': 'Hawaii',
    'IA': 'Iowa',
    'ID': 'Idaho',
    'IL': 'Illinois',
    'IN': 'Indiana',
    'KS': 'Kansas',
    'KY': 'Kentucky',
    'LA': 'Louisiana',
    'MA': 'Massachusetts',
    'MD': 'Maryland',
    'ME': 'Maine',
    'MI': 'Michigan',
    'MN': 'Minnesota',
    'MO': 'Missouri',
    'MS': 'Mississippi',
    'MT': 'Montana',
    'NC': 'North Carolina',
    'ND': 'North Dakota',
    'NE': 'Nebraska',
    'NH': 'New Hampshire',
    'NJ': 'New Jersey',
    'NM': 'New Mexico',
    'NV': 'Nevada',
    'NY': 'New York',
    'OH': 'Ohio',
    'OK': 'Oklahoma',
    'OR': 'Oregon',
    'PA': 'Pennsylvania',
    'RI': 'Rhode Island',
    'SC': 'South Carolina',
    'SD': 'South Dakota',
    'TN': 'Tennessee',
    'TX': 'Texas',
    'UT': 'Utah',
    'VA': 'Virginia',
    'VT': 'Vermont',
    'WA': 'Washington',
    'WI': 'Wisconsin',
    'WV': 'West Virginia',
    'WY': 'Wyoming'
}




In [None]:
# Raw CDC jurisdiction-level COVID-19 vaccination table, downloaded as CSV from:
#https://data.cdc.gov/Vaccinations/COVID-19-Vaccinations-in-the-United-States-Jurisdi/unsk-b7fc
# NOTE: the path is relative to the notebook's working directory and points
# into a local Downloads folder, so the file must be fetched manually first.
raw_vaccination = pd.read_csv("../../../Downloads/CDC-Vaccination.csv")

In [None]:
# Inspect the vaccination table exactly as loaded, before any column selection.
raw_vaccination

In [None]:
# Reduce the raw vaccination table to the three fields used downstream,
# lower-case the key columns, parse dates, order chronologically and
# expand state abbreviations to full names.
keep_cols = ["Date", "Location", "Series_Complete_Pop_Pct"]
vaccination = (
    raw_vaccination[keep_cols]
    .rename(columns={"Date": "date", "Location": "location"})
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .sort_values("date")
    .replace({"location": abv_to_full})
)

In [None]:
# Inspect the cleaned vaccination table (date, location, Series_Complete_Pop_Pct).
vaccination

In [None]:
# Raw CDC state-level cumulative cases and deaths table, downloaded as CSV from:
#https://data.cdc.gov/Case-Surveillance/United-States-COVID-19-Cases-and-Deaths-by-State-o/9mfq-cb36
# NOTE: the path is relative to the notebook's working directory and points
# into a local Downloads folder, so the file must be fetched manually first.
raw_cumcases = pd.read_csv("../../../Downloads/CDC-Cases-Deaths.csv")

In [None]:
# Population sizes per state: a header-less, tab-separated file whose first
# column is used as the location name and second as the population size.
state_pop = pd.read_csv(
    "../../../Downloads/state-population-sizes.tsv",
    sep="\t",
    header=None,
).rename(columns={0: "location", 1: "pop_size"})

In [None]:
# Build the per-state cumulative case table: select and rename the key
# columns, parse dates, order chronologically, expand state abbreviations,
# attach population sizes and express cumulative cases per capita.
# NOTE(review): the join with state_pop is an inner merge, so any location
# that abv_to_full does not map to a name present in state_pop is dropped.
# Confirm that this is intended for non-state jurisdictions in the CDC file.
keep_cols = ["submission_date", "state", "tot_cases"]
cumcases = (
    raw_cumcases[keep_cols]
    .rename(columns={"submission_date": "date", "state": "location"})
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .sort_values("date")
    .replace({"location": abv_to_full})
    .merge(state_pop, on="location")
    .assign(frac_cases=lambda frame: frame["tot_cases"] / frame["pop_size"])
)

In [None]:
# Join vaccination and case tables on (date, location). Missing coverage is
# treated as zero and the percentage is rescaled to a fraction in [0, 1].
df = pd.merge(vaccination, cumcases, on=["date", "location"]).assign(
    Series_Complete_Pop_Pct=lambda frame: frame["Series_Complete_Pop_Pct"].fillna(0) / 100
)

# Really want to normalize by population

# Normalizing cases in each region
#grouper = df.groupby('location')['tot_cases']                                                                             
#maxes = grouper.transform('max')                                                                                   
#mins = grouper.transform('min') 

#df = df.assign(rel_cases =(df.tot_cases - mins)/(maxes - mins))                                                       

In [None]:
# Vaccination coverage against cumulative cases per capita, one point per
# (date, location) row. Explicit axes and labels make the figure readable
# on its own; plt.show() suppresses the PathCollection repr.
fig, ax = plt.subplots()
ax.scatter(df.Series_Complete_Pop_Pct, df.frac_cases)
ax.set_xlabel("Fraction of population with completed vaccine series")
ax.set_ylabel("Cumulative cases per capita")
plt.show()

In [None]:
df # Merged vaccination + case table; TODO: still needs the median R column

In [None]:
# Per-variant effective reproduction number estimates (median only),
# with the date column parsed to datetimes.
R_columns = ["date", "location", "variant", "median_R"]
R_df = pd.read_csv(
    "../estimates/variants-us/variants-us_Rt-combined-GARW.tsv", sep="\t"
)[R_columns].assign(date=lambda frame: pd.to_datetime(frame["date"]))

# Per-variant frequency estimates (median only), same treatment.
freq_columns = ["date", "location", "variant", "median_freq"]
freq_df = pd.read_csv(
    "../estimates/variants-us/variants-us_freq-combined-GARW.tsv", sep="\t"
)[freq_columns].assign(date=lambda frame: pd.to_datetime(frame["date"]))

In [None]:
# Combine frequency and R estimates (joined on their shared columns), then
# attach the vaccination / case covariates by date and location.
variant_estimates = pd.merge(freq_df, R_df)
merged_df = pd.merge(variant_estimates, df, on=["date", "location"])

# Keep rows with a plausible R (< 10), a non-negligible variant frequency
# (> 1%) and non-zero vaccination coverage.
keep_row = (
    (merged_df.median_R < 10)
    & (merged_df.median_freq > 0.01)
    & (merged_df["Series_Complete_Pop_Pct"] > 0)
)
merged_df = merged_df[keep_row]

In [None]:
# Persist the filtered regression table as TSV (written to the current working directory, without the index).
merged_df.to_csv("regression_analysis_df.tsv", sep="\t", index=False)

In [None]:
# Median R against vaccination coverage, one panel per variant.
grouped = merged_df[merged_df["Series_Complete_Pop_Pct"] > 0].groupby('variant')

ncols = 3
# Integer ceiling division; keep at least one row so the grid is always valid.
nrows = max(1, -(-grouped.ngroups // ncols))

# Scale the figure height with the number of rows so panels are not squashed;
# squeeze=False guarantees a 2-D axes array whatever the grid shape.
fig, axes = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(12, 4 * nrows), sharey=True, squeeze=False
)
panels = axes.flatten()
for (key, group), ax in zip(grouped, panels):
    group.plot(ax=ax, kind="scatter", x="Series_Complete_Pop_Pct", y="median_R")
    # Identify the panel by its variant (a legend would have no labelled artists).
    ax.set_title(str(key))
# Hide grid cells that have no variant to show.
for spare in panels[grouped.ngroups:]:
    spare.set_visible(False)
fig.tight_layout()
plt.show()

In [None]:
# Median R against cumulative cases per capita, one panel per variant.
grouped = merged_df[merged_df["Series_Complete_Pop_Pct"] > 0].groupby('variant')

ncols = 3
# Integer ceiling division; keep at least one row so the grid is always valid.
nrows = max(1, -(-grouped.ngroups // ncols))

# Scale the figure height with the number of rows so panels are not squashed;
# squeeze=False guarantees a 2-D axes array whatever the grid shape.
fig, axes = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(12, 4 * nrows), sharey=True, squeeze=False
)
panels = axes.flatten()
for (key, group), ax in zip(grouped, panels):
    group.plot(ax=ax, kind="scatter", x="frac_cases", y="median_R")
    # Identify the panel by its variant (a legend would have no labelled artists).
    ax.set_title(str(key))
# Hide grid cells that have no variant to show.
for spare in panels[grouped.ngroups:]:
    spare.set_visible(False)
fig.tight_layout()
plt.show()

In [None]:
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive, SVI, Trace_ELBO
import jax.numpy as jnp
from jax import random, vmap
from jax.nn import normalize

In [None]:
# Integer-encode the categorical predictors; the *_levels arrays keep the
# original labels so model indices can be mapped back for plotting.
state, state_levels = pd.factorize(merged_df["location"])
variant, variant_levels = pd.factorize(merged_df["variant"])

# Keyword arguments for RegressionModel.
data = {
    "state": state,
    "N_state": len(state_levels),
    "variant": variant,
    "N_variant": len(variant_levels),
    "vaccination": jnp.array(merged_df["Series_Complete_Pop_Pct"]),
    "cases": jnp.array(merged_df["frac_cases"]),
    "R": jnp.array(merged_df["median_R"]),
}

In [None]:
def RegressionModel(vaccination, cases, state, variant, N_state, N_variant, R):
    """Hierarchical regression of R on vaccination coverage and cumulative cases.

    Each variant has its own Normal/HalfNormal hyperprior for the vaccination
    and case slopes; every (state, variant) pair draws its slopes from that
    variant's hyperprior and gets an unpooled standard-normal intercept. The
    observed R is LogNormal around the resulting linear predictor.

    Args:
        vaccination: per-observation vaccination covariate.
        cases: per-observation cumulative-case covariate.
        state: per-observation integer state index (first index into the
            per-pair parameters).
        variant: per-observation integer variant index (second index).
        N_state: number of distinct states (size of the inner plate).
        N_variant: number of distinct variants (size of the outer plate).
        R: observed R values the likelihood is conditioned on.
    """
    with numpyro.plate("pool_by_variant", N_variant):
        # Variant-level hyperpriors shared by all states.
        vacc_loc = numpyro.sample("mu_vaccination", dist.Normal(0.0, 1.0))
        vacc_scale = numpyro.sample("sigma_vaccination", dist.HalfNormal(1.0))

        case_loc = numpyro.sample("mu_cases", dist.Normal(0.0, 1.0))
        case_scale = numpyro.sample("sigma_cases", dist.HalfNormal(1.0))

        with numpyro.plate("draw_by_state", N_state):
            # Partially pooled slopes for each (state, variant) pair.
            vacc_slope = numpyro.sample(
                "beta_vaccination", dist.Normal(vacc_loc, vacc_scale)
            )
            case_slope = numpyro.sample(
                "beta_cases", dist.Normal(case_loc, case_scale)
            )

            # Intercept is not pooled: fixed standard-normal prior per pair.
            intercept = numpyro.sample("alpha", dist.Normal(0.0, 1.0))

    # Select each observation's parameters by its (state, variant) pair.
    pair = (state, variant)
    log_R_mean = (
        intercept[pair]
        + (vacc_slope[pair] * vaccination)
        + (case_slope[pair] * cases)
    )

    noise_scale = numpyro.sample("sigma_Y", dist.HalfNormal(0.1))
    numpyro.sample("observed_R", dist.LogNormal(log_R_mean, noise_scale), obs=R)

In [None]:
# NUTS configuration for RegressionModel.
num_warmup, num_samples = 500, 500

# Derive two keys from a fixed seed; rng_key_ is the one meant for MCMC.
rng_key, rng_key_ = random.split(random.PRNGKey(0))
kernel = NUTS(RegressionModel)

# The MCMC run itself is currently disabled; uncomment to sample with NUTS.
#mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
#mcmc.run(rng_key_, **data)
#mcmc.print_summary()
#samples_1 = mcmc.get_samples()

In [None]:
from numpyro.infer.autoguide import AutoMultivariateNormal


# Run SVI
# Fit a multivariate-normal variational approximation (autoguide) to
# RegressionModel with Adam, using the Trace_ELBO loss.
optimizer = numpyro.optim.Adam(step_size=0.0005)
guide = AutoMultivariateNormal(RegressionModel)
svi = SVI(RegressionModel, guide, optimizer, loss=Trace_ELBO())

# 10,000 optimisation steps from a fixed seed; the data dict is unpacked
# into the model's keyword arguments.
svi_result = svi.run(random.PRNGKey(0), 10_000,**data)
# Optimised variational parameters, used afterwards to sample from the guide.
params = svi_result.params

In [None]:
# get posterior samples
# Draw 1000 samples of the latent sites from the fitted guide.
predictive = Predictive(guide, params=params, num_samples=1000)

# Unpack the data dict into keyword arguments, exactly as in svi.run.
# Passing the dict positionally would bind it to the model's first
# parameter (`vaccination`) rather than supplying the model inputs.
samples = predictive(random.PRNGKey(1), **data)

In [None]:
# Sanity-check the array shape of the sampled vaccination slopes.
samples["beta_vaccination"].shape

In [None]:
# Median over axis 0 of the sampled vaccination hyper-means
# (presumably axis 0 is the posterior-draw axis; confirm against the shape above).
jnp.median(samples["mu_vaccination"], axis=0)
#jnp.quantile(samples["beta_vaccination"], jnp.array([0.25, 0.75]), axis=0)[:,:,8]

In [None]:
# By variant plot vaccination effect

# x-axis is variant, y-axis is magnitude

def plot_effect(beta, state_levels, variant_levels, colors, title=None):
    """Plot per-state posterior effect sizes, grouped by variant.

    For each variant, every state's posterior median is drawn as a point at
    the variant's x position, with a vertical bar spanning that state's
    25%-75% posterior interval. A dashed line marks zero effect.

    Args:
        beta: posterior draws indexed as (draw, state, variant).
        state_levels: state labels; only the count is used.
        variant_levels: variant labels, in the order of beta's last axis.
            Underscores are shown as spaces on the x axis.
        colors: one colour per variant, aligned with variant_levels.
        title: optional axes title.
    """
    n_state = len(state_levels)
    n_variant = len(variant_levels)
    interval_lw = 1.5

    fig = plt.figure(figsize=(14, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.axhline(y=0, lw=2, linestyle='dashed', color="k")

    # Summarise over the draw axis.
    beta_med = jnp.median(beta, axis=0)
    beta_q = jnp.quantile(beta, jnp.array([0.25, 0.75]), axis=0)

    for v in range(n_variant):
        xs = [v] * n_state
        # Interquartile range per state. vlines is used because fill_between
        # with a constant x only produces a zero-width polygon.
        ax.vlines(xs, beta_q[0, :, v], beta_q[1, :, v],
                  color=colors[v],
                  linewidth=interval_lw,
                  zorder=2)
        ax.scatter(xs, beta_med[:, v],
                   color=colors[v],
                   edgecolors="k",
                   s=45,
                   zorder=3)

    # Adding variant labels
    ax.set_xticks(list(range(n_variant)))
    ax.set_xticklabels([v.replace("_", " ") for v in variant_levels], rotation=0)

    ax.set_ylabel("Effect size")

    if title is not None:
        ax.set_title(title)

In [None]:
# Fixed colour per variant name; `colors` follows the order of variant_levels
# so it lines up with the variant axis of the fitted parameters.
v_names = ['Alpha', 'Beta', 'Delta', 'Epsilon', 'Gamma', 'Iota', 'Mu', 'Omicron', 'other']
v_colors = ["#2e5eaa", "#5adbff", "#56e39f", "#b4c5e4", "#f03a47", "#f5bb00", "#9e4244", "#9932CC", "#808080"]
color_map = dict(zip(v_names, v_colors))
colors = [color_map[name] for name in variant_levels]


In [None]:
# Per-state vaccination slopes (beta_vaccination) by variant.
plot_effect(samples["beta_vaccination"], state_levels, variant_levels, colors, title="Vaccination")

In [None]:
# Per-state cumulative-case slopes (beta_cases) by variant.
plot_effect(samples["beta_cases"], state_levels, variant_levels, colors, title="Fraction cumulative cases")

In [None]:
# Inspect the filtered table that was used to build the model inputs.
merged_df