In [1]:
import pandas as pd
import numpy as np
from jax import vmap
import jax.numpy as jnp
from jax.nn import softmax
import evofr as ef

In [2]:
# Load per-location Pango lineage sequence counts, then keep only the USA rows
raw_seq = pd.read_csv(
    "../data/pango_location-variant-sequence-counts.tsv",
    sep="\t",
)
raw_seq = raw_seq.loc[raw_seq["location"] == "USA"]
variant_frequencies = ef.VariantFrequencies(raw_seq)

In [3]:
# Multinomial logistic regression over the variant frequencies.
# NOTE(review): tau=4.2 is a hard-coded setting passed straight to evofr;
# presumably a generation-time scale used when reporting growth advantages — confirm.
mlr = ef.MultinomialLogisticRegression(
    tau=4.2,
)

In [4]:
# Inference settings (presumably full-rank variational inference, per the class name):
#   num_samples - posterior draws to keep (matches the leading axis of 100 seen on `ga`)
#   lr / iters  - optimizer learning rate and iteration count
inference_method = ef.InferFullRank(
    num_samples=100,
    lr=0.01,
    iters=30_000,
)

In [5]:
# Fitting model: run the configured inference on the MLR model and the
# USA variant frequencies. NOTE(review): with iters=30_000 this is presumably
# the most expensive cell — consider caching the posterior.
posterior = inference_method.fit(mlr, variant_frequencies)
# `samples` is bound to the same object as `posterior.samples`, so keys added
# to it later (e.g. "freq_forecast") are also visible via `posterior.samples`.
# NOTE(review): holds unless `samples` is a property returning a fresh object — confirm.
samples = posterior.samples

In [6]:
def forecast_frequencies(samples, mlr, forecast_L):
    """
    Forecast variant frequencies from posterior regression coefficients.

    Parameters
    ----------
    samples : dict
        Posterior samples. ``samples["freq"].shape[1]`` is taken as the number
        of observed time points (assumed layout (draws, time, variants) —
        confirm). ``samples["beta"]`` holds one coefficient array per posterior
        draw along its leading axis.
    mlr : object
        Provides ``make_ols_feature(start=..., stop=...)``, used to build the
        design matrix for the forecast window.
    forecast_L : int
        Number of time points to forecast past the observed period.

    Returns
    -------
    jax array
        ``softmax(design @ beta_draw, axis=-1)`` for every draw, stacked along
        a leading draw axis: forecast frequencies normalized over the last axis.
    """
    # The forecast window begins immediately after the last observed index.
    n_observed = samples["freq"].shape[1]
    design = mlr.make_ols_feature(start=n_observed, stop=n_observed + forecast_L)

    coefs = jnp.array(samples["beta"])

    # Apply the shared design matrix to each posterior draw of beta.
    per_draw_logits = vmap(lambda beta_draw: jnp.dot(design, beta_draw))(coefs)

    return softmax(per_draw_logits, axis=-1)

# Forecast horizon: number of time points to project past the observed data
forecast_L = 30
samples["freq_forecast"] = forecast_frequencies(
    samples,
    mlr,
    forecast_L=forecast_L,
)

In [7]:
# Posterior samples stored under "ga" (presumably per-variant growth advantages);
# the leading axis is the posterior draw (100 draws, matching num_samples=100).
ga = samples["ga"]

In [8]:
# Inspect the dimensions of the growth-advantage draws
jnp.shape(ga)

(100, 163)

In [9]:
# Check whether `ga` is a JAX device array or a NumPy array
type(ga)

jaxlib.xla_extension.DeviceArray

In [10]:
# Peek at the first ten variant names held by the fitted data object.
# NOTE(review): presumably this order defines the variant axis of the samples — confirm.
posterior.data.var_names[:10]

['B.1.1.529',
 'BA.1',
 'BA.1.1',
 'BA.1.1.1',
 'BA.1.1.10',
 'BA.1.1.14',
 'BA.1.1.16',
 'BA.1.1.18',
 'BA.1.1.2',
 'BA.1.15']

In [11]:
# Summarize growth advantages relative to BA.2 with 80% intervals
# (columns: median_ga, ga_upper_80, ga_lower_80), labelled with location "USA"
ga_df = pd.DataFrame(
    ef.posterior.get_growth_advantage(
        samples,
        posterior.data,
        ps=[0.8],
        name="USA",
        rel_to="BA.2",
    )
)
ga_df.head(n=10)

Unnamed: 0,location,variant,median_ga,ga_upper_80,ga_lower_80
0,USA,B.1.1.529,1.600136,1.6094162,1.5920101
1,USA,BA.1,0.6470126,0.65158707,0.64206755
2,USA,BA.1.1,0.676453,0.67779565,0.67521626
3,USA,BA.1.1.1,0.6751293,0.6891342,0.6581392
4,USA,BA.1.1.10,0.7596299,0.7697921,0.74802005
5,USA,BA.1.1.14,0.6809459,0.6963756,0.66183364
6,USA,BA.1.1.16,0.73752725,0.7538507,0.72013676
7,USA,BA.1.1.18,0.6395916,0.6443845,0.634317
8,USA,BA.1.1.2,0.61307806,0.6288856,0.5919038
9,USA,BA.1.15,0.6287744,0.6336167,0.6225041


In [12]:
# Write the growth-advantage table as tab-separated text. index=True is the pandas
# default, so the 0..N-1 row index is written as an unnamed leading column.
# NOTE(review): consider index=False if no downstream reader relies on that column.
ga_df.to_csv(path_or_buf="growth_advantages.tsv", sep="\t", index=True)