In [78]:
import pandas as pd
import jax.numpy as jnp
import evofr as ef
from numpyro.infer.autoguide import AutoDelta, AutoMultivariateNormal
from pathlib import Path
import os
from jax import vmap
import jax.numpy as jnp
from jax.nn import softmax

In [79]:
# Current working directory: the relative `data/...` and `./plot-est/...` paths used below resolve against it.
working_dir = os.getcwd()
cwd = working_dir
cwd

'/Users/eslamabousamra/rt_frq_dyn_datasets/ncov-forecasting-fit'

In [9]:
# Run from the repository root so the relative `data/...` and `./plot-est/...` paths resolve.
# Set NCOV_FIT_DIR to override; otherwise the author's local checkout is used when it exists,
# and the kernel stays in its current directory when it does not (instead of crashing here).
_default_fit_dir = '/Users/eslamabousamra/rt_frq_dyn_datasets/ncov-forecasting-fit'
fit_dir = os.environ.get('NCOV_FIT_DIR')
if fit_dir is not None:
    os.chdir(fit_dir)  # explicit override: let os.chdir raise if it does not exist
elif os.path.isdir(_default_fit_dir):
    os.chdir(_default_fit_dir)

In [80]:
# Commit dates whose snapshots under data/<date>/ are fitted (sequence and case counts).
dates = [
    '2022-03-21', '2022-05-17', '2022-06-22', '2022-03-03',
    '2022-05-28', '2022-06-14', '2022-02-04', '2022-02-08',
]
# Further commit dates that were listed (commented out) in earlier versions of this cell:
# 2022-02-18, 2022-02-23, 2022-02-28, 2022-03-08, 2022-03-15, 2022-03-25,
# 2022-04-07, 2022-04-14, 2022-04-27, 2022-05-06, 2022-05-20, 2022-06-09

# Values of the `location` column the models are run on.
locations = [
    "USA",
    "United Kingdom",
]



In [81]:
# Defining model
seed_L = 14            # passed to RenewalModel as its seed length
forecast_L = 14        # forecast horizon, shared by RenewalModel and the MLR forecast below
ps = [0.95, 0.8, 0.5]  # interval levels passed to ef.get_freq

# Generation-time gamma parameters (mean, std) per variant. The order of this mapping
# defines both `v_names` and the order of the delays in `gen`, so the two cannot drift apart.
gen_time_params = {
    'Delta': (4.4, 1.2),
    'Omicron 21L': (3.1, 1.2),
    'Omicron 21K': (3.1, 1.2),
    'Omicron 22A': (3.1, 1.2),
    'Omicron 22B': (3.1, 1.2),
    'Omicron 22C': (3.1, 1.2),
    'other': (4.4, 1.2),
}
v_names = list(gen_time_params)

gen = ef.pad_delays(
    [ef.discretise_gamma(mn=mn, std=std) for mn, std in gen_time_params.values()]
)

# Single delay distribution passed to every RenewalModel as `delays`.
delays = ef.pad_delays([ef.discretise_lognorm(mn=3.1, std=1.0)])

basis_fn = ef.Spline(order=4, k=10)


def _renewal_model(r_lik):
    """Build a RenewalModel; the renewal variants below differ only in their Rt likelihood `r_lik`."""
    return ef.RenewalModel(
        gen, delays, seed_L, forecast_L, k=10,
        RLik=r_lik,                       # likelihood on the effective reproduction number
        CLik=ef.ZINegBinomCases(0.05),    # case likelihood
        SLik=ef.DirMultinomialSeq(100),   # sequence likelihood
        v_names=v_names,
        basis_fn=basis_fn,
    )


# Models keyed by the names used in the fitting loop.
model_type = {
    'FGA': _renewal_model(ef.FixedGA(0.1)),
    'GARW': _renewal_model(ef.GARW(0.1, 0.1)),
    'GARW-N': _renewal_model(ef.GARW(0.1, 0.01, prior_family='Normal')),
    'MLR': ef.MultinomialLogisticRegression(tau=4.2),
}

# defining inference method
svi_fullrank = ef.InferFullRank(iters=50_000, lr=4e-3, num_samples=1000)

In [82]:
# Export posterior frequencies: trailing nowcast window plus forecast, in one CSV.
def save_freq(samples, variant_data, ps, forecast_date, name, filepath, nowcast_L=14):
    """
    Write nowcast and forecast posterior frequencies for one fit to one CSV.

    The in-sample ("nowcast") frequencies from ``ef.get_freq(..., forecast=False)``
    are restricted to the last ``nowcast_L`` entries of ``variant_data.dates``.
    The forecast frequencies from ``ef.get_freq(..., forecast=True)`` are appended
    after them, and the result is written to
    ``{filepath}/freq_full_{forecast_date}.csv`` without the index.

    Parameters
    ----------
    samples : dict
        Posterior samples, passed straight to ``ef.get_freq``.
    variant_data :
        Data object the model was fit to; ``variant_data.dates`` defines the nowcast window.
    ps : list of float
        Interval levels passed to ``ef.get_freq``.
    forecast_date : str
        Used only to name the output file.
    name : str
        Label passed to ``ef.get_freq`` as ``name``.
    filepath : str
        Output directory; it must already exist.
    nowcast_L : int, default 14
        Number of trailing observed dates kept from the nowcast (14 was previously hard-coded).
        ``0`` or a negative value keeps no nowcast rows.
    """
    freq_now = pd.DataFrame(ef.get_freq(samples, variant_data, ps, name=name, forecast=False))
    # `dates[-0:]` would select *every* date, so a non-positive window is handled explicitly.
    nowcast_dates = variant_data.dates[-nowcast_L:] if nowcast_L > 0 else []
    freq_now = freq_now[freq_now['date'].isin(nowcast_dates)]

    freq_fr = pd.DataFrame(ef.get_freq(samples, variant_data, ps, name=name, forecast=True))
    freq_merged = pd.concat([freq_now, freq_fr])

    # NOTE(review): only the median and upper-95 nowcast columns are renamed; any other
    # interval columns produced by ef.get_freq (e.g. lower bounds, other levels in `ps`)
    # presumably keep their original names. Confirm whether that is intended before
    # changing it, since downstream readers of these CSVs may rely on the current headers.
    freq_merged = freq_merged.rename(
        columns={'median_freq': 'median_freq_nowcast', 'freq_upper_95': 'freq_nowcast_upper_95'}
    )

    freq_merged.to_csv(f'{filepath}/freq_full_{forecast_date}.csv', index=False)
    

In [83]:
def forecast_frequencies(samples, mlr, forecast_L):
    """
    Project posterior variant frequencies `forecast_L` steps past the fitted data.

    The forecast design matrix starts right after the last fitted time step
    (``samples["freq"].shape[1]``). Each posterior draw of ``beta`` is applied
    to it, and the resulting logits are softmax-normalised over the last axis.

    Parameters
    ----------
    samples : dict
        Posterior samples containing ``"freq"`` (axis 1 is read as the number of
        fitted time steps) and ``"beta"`` (mapped over its first axis, one draw each).
    mlr :
        Model object providing ``make_ols_feature(start=..., stop=...)``.
    forecast_L : int
        Number of time steps to forecast.

    Returns
    -------
    jax.Array
        ``softmax(X @ beta_draw, axis=-1)`` for every draw, stacked along axis 0.
    """
    n_fitted = samples["freq"].shape[1]
    design = mlr.make_ols_feature(start=n_fitted, stop=n_fitted + forecast_L)

    betas = jnp.array(samples["beta"])

    def _logits_for(beta_draw):
        # Logit frequencies by variant for a single posterior draw.
        return jnp.dot(design, beta_draw)

    return softmax(vmap(_logits_for)(betas), axis=-1)



In [84]:
models = ['MLR', 'GARW-N', 'FGA', 'GARW']


def model_run(model):
    """
    Fit one model for every location in `locations` and every commit date in `dates`.

    Inputs are read from ``data/{date}/seq-counts_{date}.tsv`` and
    ``data/{date}/case-counts_{date}.tsv`` and filtered to the location. Each fit
    is written by ``save_freq`` into ``./plot-est/cast_estimates_full_{model}/{location}/``.
    (location, date) pairs with no case rows or no sequence rows are skipped.

    Parameters
    ----------
    model : str
        Key of `model_type`; ``'MLR'`` is fit on sequences only, all other keys on
        cases + sequences.
    """
    for location in locations:
        # The output directory depends only on model and location; create it once.
        filepath = f"./plot-est/cast_estimates_full_{model}/{location}"
        os.makedirs(filepath, exist_ok=True)

        for date in dates:
            # read sequences and cases for this snapshot, restricted to the location
            raw_seq = pd.read_csv(f"data/{date}/seq-counts_{date}.tsv", sep="\t")
            raw_cases = pd.read_csv(f"data/{date}/case-counts_{date}.tsv", sep="\t")
            raw_seq = raw_seq[raw_seq.location == location]
            raw_cases = raw_cases[raw_cases.location == location]
            # Nothing to fit without data. The case check also applies to MLR, which only
            # uses sequences; an empty sequence subset would otherwise abort the whole run.
            if raw_cases.empty or raw_seq.empty:
                continue

            # define variant data and run the model
            if model == 'MLR':
                variant_data = ef.VariantFrequencies(raw_seq)
                posterior = svi_fullrank.fit(model_type[model], variant_data)
                # Add forecast frequency samples for MLR from the posterior betas, using the
                # same horizon as the renewal models (presumably read by
                # ef.get_freq(forecast=True) inside save_freq).
                posterior.samples["freq_forecast"] = forecast_frequencies(
                    posterior.samples, model_type[model], forecast_L
                )
            else:
                variant_data = ef.CaseFrequencyData(raw_cases=raw_cases, raw_seq=raw_seq)
                posterior = svi_fullrank.fit(model_type[model], variant_data)

            save_freq(posterior.samples, variant_data, ps, date, location, filepath)


for model in models:
    model_run(model)
        
    