In [2]:
import pandas as pd
import jax.numpy as jnp
import evofr as ef
from numpyro.infer.autoguide import AutoDelta, AutoMultivariateNormal
from pathlib import Path
import os
from jax import vmap
import jax.numpy as jnp
from jax.nn import softmax

In [3]:
# Record the directory the kernel started in; the bare name on the last line
# displays it as the cell output.
cwd = os.getcwd()
cwd


'/Users/eabousam/freq_dynamics/ncov-forecasting-fit/notebooks'

In [4]:
# Work from the repository root so the relative `data/...` and `./plot-est/...`
# paths used further down resolve. The notebook is stored in `<repo>/notebooks`,
# so step up one level only when the kernel is currently there: this avoids a
# machine-specific absolute path and is safe to re-run (a second run is a no-op).
if Path.cwd().name == 'notebooks':
    os.chdir(Path.cwd().parent)

'/Users/eslamabousamra/rt_frq_dyn_datasets/ncov-forecasting-fit'

In [5]:
# Commit dates of the time-stamped data snapshots to run the models on
dates = [
    '2022-04-15',
    '2022-04-22',
    '2022-04-29',
    '2022-05-06',
    '2022-05-13',
    '2022-05-20',
    '2022-05-27',
    '2022-06-03',
    '2022-06-10',
    '2022-06-17',
    '2022-06-24',
    '2022-06-30',
]

# Locations to run the models on
locations = [
    "USA",
    "United Kingdom",
    "Brazil",
    "Australia",
    "South Africa",
    "Japan",
]



In [6]:
# Shared model settings
seed_L = 14
forecast_L = 14
ps = [0.95, 0.8, 0.5]  # levels handed to the ef posterior summary helpers

# Gamma generation-time (mean, std) for each variant; the key order fixes the
# order of both `v_names` and the rows of `gen`.
_gen_time_params = {
    'Delta': (4.4, 1.2),
    'Omicron 21L': (3.1, 1.2),
    'Omicron 21K': (3.1, 1.2),
    'Omicron 22A': (3.1, 1.2),
    'Omicron 22B': (3.1, 1.2),
    'Omicron 22C': (3.1, 1.2),
    'other': (4.4, 1.2),
}
v_names = list(_gen_time_params)

gen = ef.pad_delays(
    [ef.discretise_gamma(mn=mn, std=std) for mn, std in _gen_time_params.values()]
)

# Delays (single discretised log-normal)
delays = ef.pad_delays([ef.discretise_lognorm(mn=3.1, std=1.0)])

basis_fn = ef.Spline(order=4, k=10)


def _renewal_model(r_lik):
    """Build a renewal model; the variants below differ only in the R likelihood."""
    return ef.RenewalModel(
        gen, delays, seed_L, forecast_L, k=10,
        RLik=r_lik,                      # likelihood on effective reproduction number
        CLik=ef.ZINegBinomCases(0.05),   # case likelihood
        SLik=ef.DirMultinomialSeq(100),  # sequence likelihood
        v_names=v_names,
        basis_fn=basis_fn,
    )


# Models keyed by the short name used in output paths
model_type = {
    # Fixed growth advantage model for variants
    'FGA': _renewal_model(ef.FixedGA(0.1)),
    # Varying growth advantage random walk model
    'GARW': _renewal_model(ef.GARW(0.1, 0.01, prior_family='Normal')),
    # Multinomial logistic regression model
    'MLR': ef.MultinomialLogisticRegression(tau=4.2),
    # Piantham model
    'Piantham': ef.models.PianthamModel(gen=ef.discretise_gamma(mn=3.1, std=1.2)),
}

# Inference method shared by every model
svi_fullrank = ef.InferFullRank(iters=50_000, lr=4e-3, num_samples=1000)

In [21]:
#export posterior frequencies forecast and no forecast
def save_freq(samples, variant_data, ps, forecast_date, name, filepath):
    """
    Write the recent-period and forecast frequency summaries to one CSV.

    The `forecast=False` summary from `ef.get_freq` is restricted to rows whose
    'date' is among the last 14 entries of `variant_data.dates`, then stacked on
    top of the `forecast=True` summary. The columns 'median_freq' and
    'freq_upper_95' are renamed to 'median_freq_nowcast' and
    'freq_nowcast_upper_95' before saving.

    Parameters
    ----------
    samples, variant_data, ps, name
        Passed straight through to `ef.get_freq`.
    forecast_date
        Used only to build the output file name.
    filepath
        Directory that receives `freq_full_<forecast_date>.csv`.
    """
    interval_renames = {
        'median_freq': 'median_freq_nowcast',
        'freq_upper_95': 'freq_nowcast_upper_95',
    }

    # Summary without forecast, trimmed to the most recent 14 dates
    nowcast = pd.DataFrame(
        ef.get_freq(samples, variant_data, ps, name=name, forecast=False)
    )
    recent_dates = variant_data.dates[-14:]
    nowcast = nowcast.loc[nowcast['date'].isin(recent_dates)]

    # Summary with forecast=True, appended below the trimmed rows
    forecast = pd.DataFrame(
        ef.get_freq(samples, variant_data, ps, name=name, forecast=True)
    )

    combined = pd.concat([nowcast, forecast]).rename(columns=interval_renames)
    combined.to_csv(f'{filepath}/freq_full_{forecast_date}.csv', index=False)
  

In [22]:
  #export posterior frequencies forecast and no forecast
def save_mlr_freq(samples, variant_data, ps, forecast_date, name, filepath):
    """
    Write the `forecast=True` frequency summary from `ef.get_freq` to CSV.

    The columns 'median_freq' and 'freq_upper_95' are renamed to
    'median_freq_nowcast' and 'freq_nowcast_upper_95', and the table is saved
    as `<filepath>/freq_full_<forecast_date>.csv` without the index.
    """
    interval_renames = {
        'median_freq': 'median_freq_nowcast',
        'freq_upper_95': 'freq_nowcast_upper_95',
    }
    summary = ef.get_freq(samples, variant_data, ps, name=name, forecast=True)
    table = pd.DataFrame(summary).rename(columns=interval_renames)
    table.to_csv(f'{filepath}/freq_full_{forecast_date}.csv', index=False)

In [23]:
def forecast_frequencies(samples, mlr, forecast_L):
    """
    Use the posterior beta samples to forecast posterior frequencies.

    Parameters
    ----------
    samples : mapping
        Must contain "freq" and "beta". The length of axis 1 of
        `samples["freq"]` is taken as the first forecast index.
    mlr : object
        Provides `make_ols_feature(start=..., stop=...)`, which supplies the
        feature matrix for the forecast window.
    forecast_L : int
        Number of steps in the forecast window.

    Returns
    -------
    Softmax over the last axis of the per-sample product `X @ beta`.
    """
    n_fitted = samples["freq"].shape[1]

    # Features for the window [n_fitted, n_fitted + forecast_L)
    features = mlr.make_ols_feature(start=n_fitted, stop=n_fitted + forecast_L)
    posterior_beta = jnp.array(samples["beta"])

    # Apply the same feature matrix to each posterior draw of beta
    per_sample_logits = vmap(lambda beta_s: jnp.dot(features, beta_s))(posterior_beta)
    return softmax(per_sample_logits, axis=-1)

In [24]:
def get_gr_adv(samples, variant_data, ps, name, date, site, filepath):
    """
    Export the per-variant summary of one posterior site to CSV.

    The summary comes from `ef.posterior.get_site_by_variant(...,
    forecast=False)` for the given `site` and is saved as
    `<filepath>/full_growth_advantages_<date>.csv` with a header row and no
    index column. Returns None.
    """
    site_summary = ef.posterior.get_site_by_variant(
        samples, variant_data, ps, name=name, site=site, forecast=False
    )
    table = pd.DataFrame(site_summary)
    table.to_csv(
        f'{filepath}/full_growth_advantages_{date}.csv', index=False, header=True
    )


In [25]:
def get_mlr_gr_adv(samples, variant_data, ps, name,date, filepath):
    """
    Export growth advantages relative to the "other" variant to CSV.

    The summary comes from `ef.posterior.get_growth_advantage(...,
    rel_to="other")` and is saved as
    `<filepath>/full_growth_advantages_<date>.csv` with a header row and no
    index column. Returns None.
    """
    advantage_summary = ef.posterior.get_growth_advantage(
        samples, variant_data, ps, name=name, rel_to="other"
    )
    table = pd.DataFrame(advantage_summary)
    table.to_csv(
        f'{filepath}/full_growth_advantages_{date}.csv', index=False, header=True
    )

In [1]:
# Keys of `model_type` to fit; the driver loop below runs them in this order.
models = ['Piantham','GARW','MLR', 'FGA']

#function to input models 
def model_run(model):
    """
    Fit one model to every location / snapshot date and export its frequencies.

    For each combination of the module-level `locations` and `dates`, the
    sequence and case counts are read from `data/time_stamped/<date>/`,
    filtered to the location, fit with `svi_fullrank`, and the frequency
    summary is written under `./plot-est/cast_estimates_full_<model>/<location>`.

    Parameters
    ----------
    model : str
        Key into `model_type`. 'MLR' and 'Piantham' are fit on sequence counts
        only (`ef.VariantFrequencies`); any other key is treated as a renewal
        model and fit on cases plus sequences (`ef.CaseFrequencyData`).

    Notes
    -----
    - The branches are mutually exclusive. Previously they were independent
      `if` statements with an `else` attached only to the 'FGA' test, so 'MLR'
      and 'Piantham' also fell through to the renewal branch, were refit on
      case data and had their CSV overwritten.
    - Location/date pairs with no case rows are skipped for every model.
    - Growth-advantage export (`get_gr_adv` / `get_mlr_gr_adv`) is not invoked.
    """
    for location in locations:
        for date in dates:
            # Output directory (created even if this pair is skipped below)
            filepath = f"./plot-est/cast_estimates_full_{model}/{location}"
            os.makedirs(filepath, exist_ok=True)

            # Read the snapshot and keep only this location
            raw_seq = pd.read_csv(f"data/time_stamped/{date}/seq_counts_{date}.tsv", sep="\t")
            raw_cases = pd.read_csv(f"data/time_stamped/{date}/case_counts_{date}.tsv", sep="\t")
            raw_cases = raw_cases[raw_cases.location == location]
            raw_seq = raw_seq[raw_seq.location == location]
            if len(raw_cases) == 0:
                continue

            if model == 'MLR':
                variant_data = ef.VariantFrequencies(raw_seq)
                posterior = svi_fullrank.fit(model_type[model], variant_data)
                # The forecast is computed from the posterior betas and stored
                # under "freq_forecast"; the horizon is fixed at 28 steps here.
                posterior.samples["freq_forecast"] = forecast_frequencies(
                    posterior.samples, model_type[model], 28
                )
                save_mlr_freq(posterior.samples, variant_data, ps, date, location, filepath)
            elif model == 'Piantham':
                variant_data = ef.VariantFrequencies(raw_seq)
                posterior = svi_fullrank.fit(model_type[model], variant_data)
                save_mlr_freq(posterior.samples, variant_data, ps, date, location, filepath)
            else:
                # Renewal models ('FGA', 'GARW') need both cases and sequences
                variant_data = ef.CaseFrequencyData(raw_cases=raw_cases, raw_seq=raw_seq)
                posterior = svi_fullrank.fit(model_type[model], variant_data)
                save_freq(posterior.samples, variant_data, ps, date, location, filepath)

    return None
 
# Fit and export every model listed in `models`, one after another.
for model in models:
    model_run(model)

NameError: name 'locations' is not defined