In [1]:
import evofr as ef
import os
import pandas as pd

from jax import vmap
import jax.numpy as jnp
from jax.nn import softmax

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Observation dates to replay: the first day of every month in 2022.
# Each one names a data snapshot directory read by model_run.
OBS_DATES = [
    "2022-01-01", "2022-02-01", "2022-03-01",
    "2022-04-01", "2022-05-01", "2022-06-01",
    "2022-07-01", "2022-08-01", "2022-09-01",
    "2022-10-01", "2022-11-01", "2022-12-01",
]
# Locations to fit; matched against the `location` column of the sequence counts.
LOCATIONS = ["United Kingdom"]
# Interval levels forwarded to ef.get_freq (presumably credible-interval widths — confirm).
PS = [0.5]

In [3]:
# Model registry: the key is looked up by model_run and doubles as an output sub-directory.
MODELS = dict(MLR=ef.MultinomialLogisticRegression(tau=4.2))
# MAP inference settings shared by every fit in model_run.
inference_method = ef.InferMAP(lr=4e-2, iters=30_000)

In [4]:
# Helper functions for forecasting and saving MLR frequency estimates
def forecast_frequencies(samples, forecast_L, model=None):
    """
    Use posterior beta to forecast posterior frequencies.

    Parameters
    ----------
    samples : dict
        Posterior samples. The length of axis 1 of ``samples["freq"]`` is
        taken as the number of fitted time steps. ``samples["beta"]`` holds
        one coefficient matrix per posterior sample along axis 0.
    forecast_L : int
        Number of time steps to forecast past the end of the fitted data.
    model : object, optional
        Object providing ``make_ols_feature(start=..., stop=...)``. Defaults to
        ``MODELS["MLR"]``, so existing callers behave exactly as before.

    Returns
    -------
    jax.Array
        Softmax over the last axis of the per-sample logits ``X @ beta[s]``;
        presumably shaped (sample, forecast step, variant).
    """
    if model is None:
        model = MODELS["MLR"]

    # Feature matrix covering the time steps right after the fitted window.
    last_T = samples["freq"].shape[1]
    X = model.make_ols_feature(start=last_T, stop=last_T + forecast_L)

    # Posterior beta
    beta = jnp.array(samples["beta"])

    # Multiply the shared feature matrix by each posterior sample's beta.
    dot_by_sample = vmap(jnp.dot, in_axes=(None, 0), out_axes=0)
    logits = dot_by_sample(X, beta)  # Logit frequencies by variant
    return softmax(logits, axis=-1)

def forecast_from_obs_date(posterior, obs_date, days_from_obs):
    """
    Store a frequency forecast in ``posterior.samples["freq_forecast"]``.

    The horizon is the number of days from the last entry of
    ``posterior.data.dates`` to ``obs_date``, plus ``days_from_obs``.
    Mutates ``posterior.samples`` in place and returns None.
    """
    last_data_date = posterior.data.dates[-1]
    gap_days = (pd.to_datetime(obs_date) - last_data_date).days
    horizon = gap_days + days_from_obs

    posterior.samples["freq_forecast"] = forecast_frequencies(posterior.samples, horizon)

def save_mlr_freq(samples, variant_data, ps, obs_date, thres, name, filepath, nowcast_window=60):
    """
    Write nowcast and forecast frequency summaries to a TSV file.

    Nowcast rows (``forecast=False``) are restricted to the last
    ``nowcast_window`` entries of ``variant_data.dates``. Forecast rows
    (``forecast=True``) are kept in full. Both are concatenated, the nowcast
    summary columns are renamed, and the result is written to
    ``{filepath}/frequencies_{obs_date}_{thres}.tsv``.

    Parameters
    ----------
    samples : dict
        Posterior samples passed through to ``ef.get_freq``.
    variant_data : object
        evofr data object; its ``dates`` attribute defines the nowcast window.
    ps : list of float
        Interval levels forwarded to ``ef.get_freq``.
    obs_date, thres :
        Used only to build the output file name.
    name : str
        Passed to ``ef.get_freq`` as ``name``.
    filepath : str
        Output directory (must already exist).
    nowcast_window : int, default 60
        Number of trailing dates from ``variant_data.dates`` kept in the nowcast.
        A value <= 0 keeps no nowcast rows.

    Returns
    -------
    None
    """
    freq_now = pd.DataFrame(ef.get_freq(samples, variant_data, ps, name=name, forecast=False))

    # Guard: dates[-0:] would select *every* date rather than none.
    recent_dates = variant_data.dates[-nowcast_window:] if nowcast_window > 0 else []

    # Compare as timestamps on both sides. If get_freq emits the "date" column
    # as strings while variant_data.dates holds datetimes, a raw isin() would
    # match nothing and silently drop every nowcast row.
    in_window = pd.to_datetime(freq_now["date"]).isin(pd.to_datetime(recent_dates))
    freq_now = freq_now[in_window]

    freq_fr = pd.DataFrame(ef.get_freq(samples, variant_data, ps, name=name, forecast=True))

    # NOTE(review): with PS = [.5] the interval column may not be named
    # "freq_upper_95"; rename() ignores missing keys, so confirm the intended name.
    freq_merged = pd.concat([freq_now, freq_fr]).rename(
        columns={
            "median_freq": "median_freq_nowcast",
            "freq_upper_95": "freq_nowcast_upper_95",
        }
    )

    freq_merged.to_csv(f"{filepath}/frequencies_{obs_date}_{thres}.tsv", index=False, sep="\t")

In [5]:
def model_run(model, location, forecast_days=30, pivot="Omicron 21L"):
    """
    Fit, forecast and save frequencies for one model/location across OBS_DATES.

    For every observation date, the matching sequence-count snapshot is read,
    restricted to ``location``, and split by the ``thres`` column. Each
    threshold subset is:

    - fit with ``inference_method``;
    - forecast ``forecast_days`` past the observation date;
    - written out by ``save_mlr_freq``.

    Parameters
    ----------
    model : str
        Key into MODELS; also used as an output sub-directory.
    location : str
        Value of the ``location`` column to keep; also an output sub-directory.
    forecast_days : int, default 30
        Days past each observation date to forecast.
    pivot : str, default "Omicron 21L"
        Pivot variant passed to ``ef.VariantFrequencies``.

    Returns
    -------
    None
    """
    # The output directory depends only on model and location, so create it once.
    filepath = f"../estimates/down_scaled/thresholding/{model}/{location}"
    os.makedirs(filepath, exist_ok=True)

    for date in OBS_DATES:
        # Prep data
        raw_seq = pd.read_csv(f"../data/down_scaled/thresholding/{date}/seq_counts_{date}.tsv", sep="\t")
        raw_seq = raw_seq[raw_seq.location == location]

        # groupby skips missing (NaN) thresholds, which could not be cast with
        # int() anyway. sort=False keeps thresholds in first-seen order.
        for thres, raw_seq_thres in raw_seq.groupby("thres", sort=False):
            # Defining data for model run
            variant_data = ef.VariantFrequencies(raw_seq_thres.copy(), pivot=pivot)
            posterior = inference_method.fit(MODELS[model], variant_data)

            # Forecast past observation date
            forecast_from_obs_date(posterior, date, forecast_days)

            # Save frequencies
            save_mlr_freq(posterior.samples, variant_data, PS, date, int(thres), location, filepath)

# Run the MLR pipeline once for each configured location.
for location in LOCATIONS:
    model_run(model="MLR", location=location)

