In [1]:
import pandas as pd
import jax.numpy as jnp
import evofr as ef
from numpyro.infer.autoguide import AutoDelta, AutoMultivariateNormal
from pathlib import Path
import os
from jax import vmap
import jax.numpy as jnp
from jax.nn import softmax
import matplotlib.pyplot as plt
import numpy as np
from pandas.tseries.offsets import Day, BDay



In [4]:
# Work from the project directory so the relative data/ and plot-est2/ paths used below resolve.
# The default is the original author's checkout; set NCOV_FIT_DIR to point at another location.
os.chdir(os.environ.get('NCOV_FIT_DIR', '/Users/eabousam/freq_dynamics/ncov-forecasting-fit'))
cwd = os.getcwd()

'/Users/eabousam/freq_dynamics/ncov-forecasting-fit'

## I. Analysis period (observation dates) and locations

In [5]:
# Analysis (pivot) dates: the 1st and the 15th of every month of 2022.
# Each date names a data snapshot directory under data/time_stamped/.
dates = [f'2022-{month:02d}-{day:02d}' for month in range(1, 13) for day in (1, 15)]

# Locations the models are fitted on
locations = ["USA", "United Kingdom", "Brazil", "Australia", "South Africa", "Japan"]

## II. Model parameters and forecasting period

In [6]:
# Settings passed to the renewal models (seed_L, forecast_new) and to the export helpers (ps).
# forecast_L is not referenced by the model-definition cell.
seed_L = 14
forecast_L = 14
forecast_new = 30
ps = [0.95, 0.8, 0.5]  # interval levels forwarded to ef.get_freq / growth-advantage summaries

# Variant labels; the generation-time list below is built in exactly this order.
v_names = [
    'Delta',
    'Omicron 21L',
    'Omicron 21K',
    'Omicron 22A',
    'Omicron 22B',
    'Omicron 22C',
    'Omicron 22D',
    'Omicron 22E',
    'Omicron 23A',
    'other',
]

# One discretised-gamma generation time per entry of v_names:
# Delta and 'other' use mean 4.4, every Omicron clade uses mean 3.1; all use std 1.2.
gen = ef.pad_delays([
    ef.discretise_gamma(mn=4.4 if variant in ('Delta', 'other') else 3.1, std=1.2)
    for variant in v_names
])

# Delay distribution (discretised log-normal) passed to the renewal models
delays = ef.pad_delays([ef.discretise_lognorm(mn=3.1, std=1.0)])

# Spline basis passed to the renewal models
basis_fn = ef.Spline(order=4, k=10)

## III. Model definitions: FGA, GARW, MLR, and Piantham

In [8]:
# Model registry, keyed by the names used by model_run below.


def _renewal_model(r_lik):
    """Build a RenewalModel around the growth/R likelihood `r_lik`.

    Case and sequence likelihoods, variant names and spline basis are the same for every
    renewal model; fresh likelihood objects are created on every call, so models never share them.
    """
    return ef.RenewalModel(
        gen, delays, seed_L, forecast_new,
        RLik=r_lik,
        CLik=ef.ZINegBinomCases(0.05),   # case likelihood
        SLik=ef.DirMultinomialSeq(100),  # sequence likelihood
        v_names=v_names,
        basis_fn=basis_fn,
    )


model_type = {
    # Fixed growth advantage per variant
    'FGA': _renewal_model(ef.FixedGA(0.1)),
    # Growth advantage varying over time as a random walk
    'GARW': _renewal_model(ef.GARW(0.1, 0.01, prior_family='Normal')),
    # Multinomial logistic regression (fitted on sequence counts only in model_run)
    'MLR': ef.MultinomialLogisticRegression(tau=4.2),
    # Piantham model with the Omicron generation time used above.
    # NOTE(review): horizon is forecast_new + 14 here vs forecast_new for FGA/GARW — confirm intended.
    'Piantham': ef.models.PianthamModel(gen=ef.discretise_gamma(mn=3.1, std=1.2), forecast_L=forecast_new + 14),
}

# MAP inference shared by every model
svi = ef.InferMAP(iters=50_000, lr=4e-3)

## IV. Naive model specification

In [16]:
def naive_forecast(seq_count_date, pivot, period=60, window=7):
    """Naive baseline: predict each variant's frequency as its mean recent observed frequency.

    Prediction dates are the ``period`` days ending on ``pivot`` (nowcast) plus the ``period``
    days after it (forecast). For each prediction date ``d`` and each location, the prediction is
    the mean daily frequency over that location's ``window`` most recent dates with sequence data
    strictly before ``d``.

    Parameters
    ----------
    seq_count_date : pd.DataFrame
        Sequence counts with columns ``date``, ``location``, ``variant`` and ``sequences``.
        Dates are compared as strings, so they are assumed to be ISO ``YYYY-MM-DD`` (TODO confirm
        for every snapshot). NOTE: ``total_seq`` and ``freq`` columns are added to this frame in
        place; this is kept for backward compatibility (the scratch cell at the end reads ``freq``).
    pivot : str or date-like
        Analysis date separating nowcast (up to and including ``pivot``) from forecast.
    period : int, default 60
        Number of days on each side of the pivot.
    window : int, default 7
        Number of most recent sequencing dates (per location) averaged for each prediction.

    Returns
    -------
    pd.DataFrame
        Columns ``variant``, ``location``, ``freq``, ``date``, ``median_freq_nowcast`` and
        ``median_freq_forecast``, sorted by location, variant, date. ``median_freq_nowcast`` is NaN
        on forecast dates and ``median_freq_forecast`` is NaN on nowcast dates. Location/date pairs
        with no earlier data are omitted.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be a positive integer")

    columns = ['variant', 'location', 'freq', 'date', 'median_freq_nowcast', 'median_freq_forecast']

    # Nowcast: period days ending on the pivot; forecast: the period days after it.
    nowcast_dates = pd.date_range(end=pivot, periods=period, freq='D').astype(str)
    forecast_dates = (pd.date_range(start=pivot, periods=period, freq='D') + Day(1)).astype(str)
    pred_dates = nowcast_dates.union(forecast_dates)

    # Daily share of each variant among all sequences of a (date, location).
    seq_count_date['total_seq'] = seq_count_date.groupby(['date', 'location'])['sequences'].transform('sum')
    seq_count_date['freq'] = seq_count_date['sequences'] / seq_count_date['total_seq']

    pieces = []
    for _, loc_counts in seq_count_date.groupby('location', sort=False):
        # Dates with data for *this* location, ascending (ISO strings sort chronologically).
        # Windows are per location so sparsely sampled locations are not dropped when
        # other locations report more recently.
        loc_dates = pd.Series(loc_counts['date'].unique()).sort_values()
        for d in pred_dates:
            recent_dates = loc_dates[loc_dates < d].iloc[-window:]
            if recent_dates.empty:
                continue
            # NOTE(review): a variant with no row on a date is skipped in the mean (not counted as 0).
            mean_freq = (
                loc_counts[loc_counts['date'].isin(recent_dates)]
                .groupby(['variant', 'location'])['freq']
                .mean()
                .reset_index()
            )
            mean_freq['date'] = d
            pieces.append(mean_freq)

    if not pieces:
        return pd.DataFrame(columns=columns)

    # Unique index before the .loc assignments below.
    sc = (
        pd.concat(pieces, ignore_index=True)
        .sort_values(by=['location', 'variant', 'date'])
        .reset_index(drop=True)
    )
    sc['median_freq_nowcast'] = sc['freq']
    sc['median_freq_forecast'] = sc['freq']
    sc.loc[sc['date'].isin(forecast_dates), 'median_freq_nowcast'] = np.nan
    sc.loc[sc['date'].isin(nowcast_dates), 'median_freq_forecast'] = np.nan
    return sc[columns]


for pivot in dates:
    seq_count_date = pd.read_csv(f"data/time_stamped/{pivot}/seq_counts_{pivot}.tsv", sep="\t")
    naive_pred = naive_forecast(seq_count_date, pivot)
    print(naive_pred)
    # One CSV per location and pivot date, mirroring the model-output directory layout.
    for location in locations:
        filepath = f'plot-est2/cast_estimates_full_dummy/{location}/'
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(filepath, exist_ok=True)
        naive_pred[naive_pred.location == location].to_csv(
            os.path.join(filepath, f"freq_full_{pivot}.csv"), index=False
        )
        
        

     variant        location      freq        date  median_freq_nowcast  \
0      Delta       Australia  1.000000  2021-11-03                  1.0   
1      Delta       Australia  1.000000  2021-11-04                  1.0   
2      Delta       Australia  1.000000  2021-11-05                  1.0   
3      Delta       Australia  1.000000  2021-11-06                  1.0   
4      Delta       Australia  1.000000  2021-11-07                  1.0   
...      ...             ...       ...         ...                  ...   
1857   other  United Kingdom  0.001226  2022-02-26                  NaN   
1858   other  United Kingdom  0.001226  2022-02-27                  NaN   
1859   other  United Kingdom  0.001226  2022-02-28                  NaN   
1860   other  United Kingdom  0.001226  2022-03-01                  NaN   
1861   other  United Kingdom  0.001226  2022-03-02                  NaN   

      median_freq_forecast  
0                      NaN  
1                      NaN  
2           

## V. Helper functions to export variant frequencies and growth advantages

In [10]:
# Export posterior frequencies (nowcast + forecast) for one fitted model
def save_freq(samples, variant_data, ps, forecast_date, name, filepath, nowcast_window=60):
    """Write nowcast and forecast posterior frequencies for one fit to a single CSV.

    Parameters
    ----------
    samples : dict
        Posterior samples (``svi.fit(...).samples`` in this notebook).
    variant_data :
        evofr data object the model was fitted on; its ``dates`` define the nowcast period.
    ps : list of float
        Interval levels forwarded to ``ef.get_freq``.
    forecast_date : str
        Analysis (pivot) date; used only in the output file name.
    name : str
        Label forwarded to ``ef.get_freq`` (the location in this notebook).
    filepath : str
        Existing output directory.
    nowcast_window : int, default 60
        Number of most recent fitted dates kept for the nowcast part.

    Writes ``{filepath}/freq_full_{forecast_date}.csv``.

    Raises
    ------
    ValueError
        If ``nowcast_window`` is smaller than 1 (``dates[-0:]`` would silently keep everything).
    """
    if nowcast_window < 1:
        raise ValueError("nowcast_window must be a positive integer")

    # Nowcast: posterior over the fitted period, restricted to the last `nowcast_window` dates.
    freq_now = pd.DataFrame(ef.get_freq(samples, variant_data, ps, name=name, forecast=False))
    nowcast_dates = variant_data.dates[-nowcast_window:]
    freq_now = freq_now[freq_now['date'].isin(nowcast_dates)]

    # Forecast: posterior beyond the fitted period.
    freq_fr = pd.DataFrame(ef.get_freq(samples, variant_data, ps, name=name, forecast=True))

    freq_merged = pd.concat([freq_now, freq_fr])
    # Tag the nowcast median and upper-95 columns (presumably the forecast frame uses
    # distinct *_forecast names — TODO confirm).
    # NOTE(review): the other nowcast interval columns (lower_95, 80 and 50 bands) keep their
    # original names; left unchanged because downstream readers may rely on this CSV schema.
    freq_merged = freq_merged.rename(
        columns={'median_freq': 'median_freq_nowcast', 'freq_upper_95': 'freq_nowcast_upper_95'}
    )

    freq_merged.to_csv(f'{filepath}/freq_full_{forecast_date}.csv', index=False)
  

In [11]:
# Export MLR posterior frequencies (nowcast + forecast)
def save_mlr_freq(samples, variant_data, ps, forecast_date, name, filepath):
    """Write MLR nowcast and forecast posterior frequencies to ``{filepath}/freq_full_{forecast_date}.csv``.

    The export is exactly the one performed by ``save_freq`` (same nowcast window, same column
    renames, same file name), so this wrapper simply delegates to it and is kept for existing callers.
    ``model_run`` fills ``samples["freq_forecast"]`` via ``forecast_frequencies`` before calling this;
    presumably that is what ``ef.get_freq(..., forecast=True)`` reads — TODO confirm.
    """
    save_freq(samples, variant_data, ps, forecast_date, name, filepath)

In [12]:
def forecast_frequencies(samples, mlr, forecast_L):
    """
    Project posterior variant frequencies past the fitted period using posterior MLR coefficients.

    Parameters
    ----------
    samples : dict
        Posterior samples containing "freq" and "beta". Axis 1 of "freq" is taken as the number
        of fitted time points (presumably shape (n_samples, T, n_variants) — TODO confirm).
    mlr :
        Model exposing ``make_ols_feature(start, stop)``, which builds the design matrix.
    forecast_L : int
        Number of time points to project beyond the fitted period.

    Returns
    -------
    Softmax over the last axis of the per-sample logits ``features @ beta[s]``.
    """
    n_fitted = samples["freq"].shape[1]
    features = mlr.make_ols_feature(start=n_fitted, stop=n_fitted + forecast_L)

    # Same design matrix for every posterior sample; map jnp.dot over the sample axis of beta.
    per_sample_dot = vmap(jnp.dot, in_axes=(None, 0), out_axes=0)
    logits = per_sample_dot(features, jnp.array(samples["beta"]))
    return softmax(logits, axis=-1)

In [13]:
# Per-variant site summary (used for the GARW growth advantage "ga")

def get_gr_adv(samples, variant_data, ps, name, date, site, filepath):
    """Write the posterior summary of a per-variant sample site to CSV.

    ``ef.posterior.get_site_by_variant`` is called with ``forecast=False`` for ``site``
    (``model_run`` passes "ga" for GARW). No re-referencing to another variant is done here.
    The file is ``{filepath}/full_growth_advantages_{date}.csv`` whatever ``site`` is. Returns None.
    """
    site_summary = ef.posterior.get_site_by_variant(
        samples, variant_data, ps, name=name, site=site, forecast=False
    )
    out_path = f'{filepath}/full_growth_advantages_{date}.csv'
    pd.DataFrame(site_summary).to_csv(out_path, index=False, header=True)


In [14]:
def get_mlr_gr_adv(samples, variant_data, ps, name, date, filepath, rel_to="Omicron 21L"):
    """Write posterior growth advantages relative to a reference variant to CSV.

    Used for the MLR, Piantham and FGA fits in ``model_run``.

    Parameters
    ----------
    samples : dict
        Posterior samples.
    variant_data :
        evofr data object the model was fitted on.
    ps : list of float
        Interval levels forwarded to ``ef.posterior.get_growth_advantage``.
    name : str
        Label forwarded to evofr (the location in this notebook).
    date : str
        Analysis date; used only in the file name.
    filepath : str
        Existing output directory.
    rel_to : str, default "Omicron 21L"
        Reference variant. The default matches the original hard-coded value and the pivot that
        ``model_run`` uses for its data objects; keep them in sync if the pivot changes.

    Writes ``{filepath}/full_growth_advantages_{date}.csv`` and returns None.
    """
    growth_adv = pd.DataFrame(
        ef.posterior.get_growth_advantage(samples, variant_data, ps, name=name, rel_to=rel_to)
    )
    growth_adv.to_csv(f'{filepath}/full_growth_advantages_{date}.csv', index=False, header=True)

    return None

## VI. Running models and exporting results

In [15]:
models = ['GARW', 'MLR', 'Piantham', 'FGA']


def model_run(model, pivot_variant="Omicron 21L", mlr_forecast_L=74):
    """Fit one model for every analysis date and location; export frequencies and growth advantages.

    Parameters
    ----------
    model : str
        One of 'FGA', 'GARW', 'MLR', 'Piantham' (keys of ``model_type``).
    pivot_variant : str, default "Omicron 21L"
        Pivot variant passed to the evofr data objects.
    mlr_forecast_L : int, default 74
        Number of time points the MLR posterior is projected forward with ``forecast_frequencies``.
        NOTE(review): FGA/GARW are built with ``forecast_new`` and Piantham with
        ``forecast_new + 14``, so exported forecast horizons differ between models — confirm intended.

    Outputs go to ``./plot-est2/cast_estimates_full_{model}/{location}/``. Location/date pairs with
    no case rows (as before) or no sequence rows are skipped.

    Raises
    ------
    ValueError
        If ``model`` is not one of the handled names.
    """
    handled = ('MLR', 'Piantham', 'FGA', 'GARW')
    if model not in handled:
        raise ValueError(f"Unknown model {model!r}; expected one of {handled}")

    # Output directories are loop-invariant per location: create them once up front.
    filepaths = {}
    for location in locations:
        filepath = f"./plot-est2/cast_estimates_full_{model}/{location}"
        os.makedirs(filepath, exist_ok=True)
        filepaths[location] = filepath

    for date in dates:
        # Each snapshot file holds every location: read it once per date, not once per (location, date).
        all_seq = pd.read_csv(f"data/time_stamped/{date}/seq_counts_{date}.tsv", sep="\t")
        all_cases = pd.read_csv(f"data/time_stamped/{date}/case_counts_{date}.tsv", sep="\t")

        for location in locations:
            filepath = filepaths[location]
            raw_cases = all_cases[all_cases.location == location]
            raw_seq = all_seq[all_seq.location == location]
            if len(raw_cases) == 0 or len(raw_seq) == 0:
                continue

            # MLR and Piantham are fitted on sequences only; the renewal models also use cases.
            if model in ('MLR', 'Piantham'):
                variant_data = ef.VariantFrequencies(raw_seq, pivot=pivot_variant)
            else:
                variant_data = ef.CaseFrequencyData(raw_cases=raw_cases, raw_seq=raw_seq, pivot=pivot_variant)

            posterior = svi.fit(model_type[model], variant_data)
            samples = posterior.samples

            # Posterior frequencies
            if model == 'MLR':
                # MLR has no built-in forecast: project it from the posterior coefficients.
                samples["freq_forecast"] = forecast_frequencies(samples, model_type[model], mlr_forecast_L)
                save_mlr_freq(samples, variant_data, ps, date, location, filepath)
            else:
                save_freq(samples, variant_data, ps, date, location, filepath)

            # Growth advantages
            if model == 'GARW':
                get_gr_adv(samples, variant_data, ps, location, date, "ga", filepath)
            else:
                get_mlr_gr_adv(samples, variant_data, ps, location, date, filepath)

    return None


for model in models:
    model_run(model)



## Misc code (scratch, not part of the pipeline)

In [21]:
# Scratch: inspect the date windows that naive_forecast builds around a pivot date.
# Not used by the pipeline above. Relies on `dates` and on `seq_count_date` (with the `freq`
# column naive_forecast adds in place) left over from the naive-model cell.

# The original cell read `date` before assigning it and only ran thanks to leftover kernel state;
# every later line saw the value its `for date in dates` loop left behind, i.e. the last date.
date = dates[-1]

# Forecast: the 60 days after the pivot; nowcast: the 60 days ending on it.
forecast_dates = (pd.date_range(start=date, periods=60, freq='D') + Day(1)).astype(str)
nowcast_dates = pd.date_range(end=date, periods=60, freq='D').astype(str)

# 120-day prediction period: 60 days before the pivot through 59 days after it.
pred_dates = (pd.date_range(start=date, periods=120, freq='D') - Day(60)).astype(str)

# 7-day window ending 59 days before the pivot. This is the value the original nested loops
# ended on, since every iteration overwrote the previous one.
recent_dates = (pd.date_range(end=date, periods=7, freq='D') - Day(59)).astype(str)
seq_count_mean = (
    seq_count_date[seq_count_date.date.isin(recent_dates)]
    .groupby(["variant", "location"])["freq"]
    .mean()
    .reset_index()
)

test = (pd.date_range(start=date, periods=60, freq='D') - Day(14)).astype(str)
print(test)

