In [None]:
import os
import pandas as pd
from pandas.tseries.offsets import Day, BDay
import evofr as ef
import numpy as np
from jax import vmap
from jax.nn import softmax
import jax.numpy as jnp
import matplotlib.pyplot as plt

## I. Specification of analysis period (Observation dates) and geographical settings

In [None]:
import yaml

# Load the shared analysis configuration; the file handle gets its own name so
# that `config` only ever refers to the parsed mapping.
with open("../config.yaml", 'r') as config_file:
    config = yaml.safe_load(config_file)

# Estimation dates, locations and model names all live under the "main" key.
main_settings = config["main"]
dates = main_settings["estimation_dates"]
locations = main_settings["locations"]
models = main_settings["models"]

## II. Assigning parameters and forecasting period

In [None]:
# Parameters and forecasting period
seed_L = 14
forecast_L = 14
forecast_new = 30
ps = [0.95, 0.8, 0.5]

# Variant names; the generation-time list below is built in this same order
v_names = [
    'Delta',
    'Omicron 21L',
    'Omicron 21K',
    'Omicron 22A',
    'Omicron 22B',
    'Omicron 22C',
    'Omicron 22D',
    'Omicron 22E',
    'Omicron 23A',
    'other',
]

# One discretised gamma generation time per variant:
# mean 4.4 for 'Delta' and 'other', mean 3.1 for every Omicron clade, std 1.2 throughout.
gen = ef.pad_delays(
    [
        ef.discretise_gamma(mn=4.4 if variant in ('Delta', 'other') else 3.1, std=1.2)
        for variant in v_names
    ]
)

# Delay distribution: a single discretised log-normal
delays = ef.pad_delays([ef.discretise_lognorm(mn=3.1, std=1.0)])

# Spline basis handed to the renewal models
basis_fn = ef.Spline(order=4, k=10)

## III. FGA, GARW, MLR, and Piantham models definition

In [None]:
# Defining models
def _build_renewal_model(growth_likelihood):
    """Build a renewal model that differs from its siblings only in `RLik`."""
    return ef.RenewalModel(
        gen, delays, seed_L, forecast_new,
        RLik=growth_likelihood,            # Likelihood on effective reproduction number
        CLik=ef.ZINegBinomCases(0.05),     # Case likelihood
        SLik=ef.DirMultinomialSeq(100),    # Sequence likelihood
        v_names=v_names,
        basis_fn=basis_fn,
    )


model_type = {
    # Fixed Growth Advantage model for variants
    'FGA': _build_renewal_model(ef.FixedGA(0.1)),
    # Varying Growth Advantage Random Walk model
    'GARW': _build_renewal_model(ef.GARW(0.1, 0.01, prior_family='Normal')),
    # Multinomial Logistic Regression model
    'MLR': ef.MultinomialLogisticRegression(tau=4.2),
    # Piantham model
    'Piantham': ef.models.PianthamModel(
        gen=ef.discretise_gamma(mn=3.1, std=1.2),
        forecast_L=forecast_new + 14,
    ),
}

# Inference method shared by every model fit.
# Alternative kept for reference: ef.InferMAP(iters = 50_000, lr = 4e-3)
svi = ef.InferFullRank(iters=50_000, lr=4e-3, num_samples=500)

## IV. Naive Model specification

In [None]:
def naive_forecast(seq_count_date, pivot, period=60):
    """
    Naive nowcast/forecast of variant frequencies from a trailing mean.

    For every prediction date the predicted frequency of a (variant, location)
    pair is the mean of its observed frequencies over the 7 most recent
    observation dates strictly before that prediction date.

    Parameters
    ----------
    seq_count_date : pandas.DataFrame
        Sequence counts with columns 'date', 'location', 'variant' and
        'sequences'. Dates are compared as strings against 'YYYY-MM-DD'
        prediction dates. The frame is NOT modified.
    pivot : str or datetime-like
        Last nowcast date; forecasts start the day after.
    period : int, default 60
        Number of nowcast dates (ending on `pivot`) and of forecast dates
        (starting the day after `pivot`).

    Returns
    -------
    pandas.DataFrame
        One row per (location, variant, date) with columns 'variant',
        'location', 'freq', 'date', 'median_freq_nowcast' and
        'median_freq_forecast'. The nowcast column is NaN on forecast dates
        and the forecast column is NaN on nowcast dates.
    """
    # Work on a copy so the helper columns added below never leak into the
    # caller's frame (and no SettingWithCopy warning fires when it is a slice).
    seq_counts = seq_count_date.copy()

    # Defining forecast and nowcast dates from pivot date
    forecast_dates = (pd.date_range(start=pivot, periods=period, freq='D') + Day(1)).astype(str)
    nowcast_dates = pd.date_range(end=pivot, periods=period, freq='D').astype(str)

    # Defining prediction period for nowcasting and forecasting
    pred_dates = nowcast_dates.union(forecast_dates)

    # Computing frequency of variants for each location
    seq_counts['total_seq'] = seq_counts.groupby(['date', 'location'])['sequences'].transform('sum')
    seq_counts['freq'] = seq_counts['sequences'] / seq_counts['total_seq']

    # Distinct observation dates are loop-invariant: compute them once
    observed_dates = pd.Series(seq_counts['date'].unique())

    per_date_means = []
    for d in pred_dates:
        # The (up to) 7 most recent observation dates strictly before d
        earlier_dates = observed_dates[observed_dates < d]
        recent_dates = pd.to_datetime(earlier_dates).drop_duplicates().nlargest(n=7).astype(str)

        # Mean frequency over those recent dates, per variant and location
        mean_freq = (
            seq_counts[seq_counts['date'].isin(recent_dates)]
            .groupby(["variant", "location"])["freq"]
            .mean()
            .reset_index()
        )

        # Tag the means with the prediction date they apply to
        mean_freq["date"] = d
        per_date_means.append(mean_freq)

    sc = pd.concat(per_date_means).sort_values(by=["location", "variant", "date"])

    # Adding nowcast and forecast columns
    sc['median_freq_nowcast'] = sc['freq']
    sc['median_freq_forecast'] = sc['freq']

    # Each column is only defined on its own date range
    sc.loc[sc.date.isin(forecast_dates), 'median_freq_nowcast'] = np.nan
    sc.loc[sc.date.isin(nowcast_dates), 'median_freq_forecast'] = np.nan
    return sc.reset_index(drop=True)
    
# Run the naive model for every estimation date and export one file per location
for pivot in dates:
    seq_count_date = pd.read_csv(f"../data/time_stamped/{pivot}/seq_counts_{pivot}.tsv", sep="\t")
    naive_pred = naive_forecast(seq_count_date, pivot)

    # Create files for estimates for each country and pivot date
    for location in locations:
        filepath = f'../estimates/naive/{location}/'
        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(filepath, exist_ok=True)
        naive_pred[naive_pred.location == location].to_csv(
            filepath + f"freq_full_{pivot}.tsv", sep="\t", index=False
        )

## V. Helper functions to export variant frequencies and growth advantages

In [None]:
# Export posterior frequencies: recent nowcast plus forecast
def save_freq(samples, variant_data, ps, forecast_date, name, filepath):
    """
    Write nowcast and forecast posterior frequencies to one TSV file.

    The nowcast table is restricted to the last 60 entries of
    `variant_data.dates`, stacked on top of the forecast table, and the
    'median_freq' column is renamed to 'median_freq_nowcast' before the result
    is written to `{filepath}/freq_full_{forecast_date}.tsv`.
    """
    def _freq_table(forecast):
        # Posterior frequency summary from evofr as a DataFrame
        return pd.DataFrame(ef.get_freq(samples, variant_data, ps, name=name, forecast=forecast))

    # Nowcast: keep only the most recent 60 observation dates
    nowcast = _freq_table(False)
    recent_dates = variant_data.dates[-60:]
    nowcast = nowcast.loc[nowcast['date'].isin(recent_dates)]

    # Forecast
    forecast = _freq_table(True)

    # Stack, rename and export
    combined = pd.concat([nowcast, forecast]).rename(columns={'median_freq': 'median_freq_nowcast'})
    combined.to_csv(f'{filepath}/freq_full_{forecast_date}.tsv', sep="\t", index=False)
    return None

In [None]:
# Export posterior frequencies (forecast and nowcast) for the MLR model
def save_mlr_freq(samples, variant_data, ps, forecast_date, name, filepath):
    """
    Write MLR nowcast and forecast posterior frequencies to one TSV file.

    The export steps are identical to `save_freq`, so this function delegates
    to it rather than keeping a second copy of the same code. It is kept as a
    separate entry point so existing MLR call sites continue to work.

    Parameters
    ----------
    samples : posterior samples passed through to `ef.get_freq`
    variant_data : data object passed through to `ef.get_freq`; its `dates`
        attribute selects the nowcast window
    ps : credible-interval levels
    forecast_date : date label used in the output file name
    name : label passed to `ef.get_freq`
    filepath : directory the TSV file is written to

    Returns
    -------
    None
    """
    return save_freq(samples, variant_data, ps, forecast_date, name, filepath)

In [None]:
def forecast_frequencies(samples, mlr, forecast_L):
    """
    Forecast posterior frequencies from the posterior `beta` samples.

    Builds the feature matrix for the `forecast_L` time steps that follow the
    fitted period, multiplies it by each posterior draw of `beta`, and turns
    the resulting logits into frequencies with a softmax over the last axis.
    """
    # Forecast steps start right after the fitted frequencies end
    n_fitted = samples["freq"].shape[1]
    design = mlr.make_ols_feature(start=n_fitted, stop=n_fitted + forecast_L)

    beta_draws = jnp.array(samples["beta"])

    def _logits_for_draw(beta_draw):
        # Logit frequencies by variant for a single posterior draw
        return jnp.dot(design, beta_draw)

    # Map over the leading (sample) axis of beta
    logits = vmap(_logits_for_draw)(beta_draws)
    return softmax(logits, axis=-1)

In [None]:
def get_time_varying_growth_advantage(samples, variant_data, ps, name, date, site, filepath):
    """
    Export the per-variant posterior summary of `site` to a TSV file.

    The summary comes from `ef.posterior.get_site_by_variant` (without
    forecast) and is written to `{filepath}/full_growth_advantages_{date}.tsv`.
    """
    site_summary = ef.posterior.get_site_by_variant(
        samples, variant_data, ps, name=name, site=site, forecast=False
    )
    out_path = f'{filepath}/full_growth_advantages_{date}.tsv'
    pd.DataFrame(site_summary).to_csv(out_path, sep="\t", index=False, header=True)
    return None

def get_fixed_growth_advantage(samples, variant_data, ps, name,date, filepath):
    """
    Export posterior growth advantages relative to "Omicron 21L" to a TSV file.

    The summary comes from `ef.posterior.get_growth_advantage` and is written
    to `{filepath}/full_growth_advantages_{date}.tsv`.
    """
    ga_summary = ef.posterior.get_growth_advantage(
        samples, variant_data, ps, name=name, rel_to="Omicron 21L"
    )
    out_path = f'{filepath}/full_growth_advantages_{date}.tsv'
    pd.DataFrame(ga_summary).to_csv(out_path, sep="\t", index=False, header=True)
    return None

## VI. Running models and exporting results

In [None]:
def run_model(model, location, date, raw_seq, raw_cases, filepath=None):
    """
    Fit one model for one location and date, then export its estimates.

    Parameters
    ----------
    model : str
        Key into `model_type`: 'MLR', 'Piantham', 'FGA' or 'GARW'. Any other
        value is skipped without fitting or writing anything.
    location : str
        Location label, passed as `name` to the export helpers.
    date : str
        Estimation date, used in the output file names.
    raw_seq : pandas.DataFrame
        Sequence counts for `location`.
    raw_cases : pandas.DataFrame
        Case counts for `location`; only used by 'FGA' and 'GARW'.
    filepath : str, optional
        Output directory. Defaults to `../estimates/{model}/{location}`, the
        directory the driver loop uses, so the function no longer depends on a
        global named `filepath` having been assigned by its caller.

    Returns
    -------
    None
    """
    sequence_only_models = ('MLR', 'Piantham')
    case_and_sequence_models = ('FGA', 'GARW')

    # Unrecognised model names are skipped silently
    if model not in sequence_only_models + case_and_sequence_models:
        return None

    if filepath is None:
        filepath = f"../estimates/{model}/{location}"
    os.makedirs(filepath, exist_ok=True)

    # Build the data object the model expects
    if model in sequence_only_models:
        variant_data = ef.VariantFrequencies(raw_seq, pivot="Omicron 21L")
    else:
        variant_data = ef.CaseFrequencyData(raw_cases=raw_cases, raw_seq=raw_seq, pivot="Omicron 21L")

    # Fit with the shared inference method
    posterior = svi.fit(model_type[model], variant_data)

    # Save frequencies
    if model == 'MLR':
        # The forecast is computed here from the posterior betas and stored under "freq_forecast".
        # NOTE(review): the 74-step horizon is hard-coded and not derived
        # from forecast_new - confirm it is intended.
        posterior.samples["freq_forecast"] = forecast_frequencies(posterior.samples, model_type[model], 74)
        save_mlr_freq(posterior.samples, variant_data, ps, date, location, filepath)
    else:
        save_freq(posterior.samples, variant_data, ps, date, location, filepath)

    # Save growth advantages: time-varying for GARW, fixed for the others
    if model == 'GARW':
        get_time_varying_growth_advantage(posterior.samples, variant_data, ps, location, date, "ga", filepath)
    else:
        get_fixed_growth_advantage(posterior.samples, variant_data, ps, location, date, filepath)
    return None

# Fit every configured model for every estimation date and location
for date in dates:
    # The input files depend only on the date: read them once per date
    # instead of once per location.
    seq_counts_all = pd.read_csv(f"../data/time_stamped/{date}/seq_counts_{date}.tsv", sep="\t")
    case_counts_all = pd.read_csv(f"../data/time_stamped/{date}/case_counts_{date}.tsv", sep="\t")

    for location in locations:
        # Filter data
        raw_cases = case_counts_all[case_counts_all.location == location]
        raw_seq = seq_counts_all[seq_counts_all.location == location]

        # Skip locations with no cases or no sequences for this date
        if len(raw_cases) == 0 or len(raw_seq) == 0:
            continue

        for model in models:
            # Output directory for this model/location pair
            filepath = f"../estimates/{model}/{location}"
            os.makedirs(filepath, exist_ok=True)

            run_model(model, location, date, raw_seq, raw_cases)