In [1]:
import pandas as pd
import numpy as np
from jax import vmap
import jax.numpy as jnp
from jax.nn import softmax
import evofr as ef

In [2]:
# Load per-location Pango lineage sequence counts (tab-separated), keep only
# the USA rows, and wrap them in the frequency dataset the model is fit to.
raw_seq = (
    pd.read_csv("../data/pango_location-variant-sequence-counts.tsv", sep="\t")
    .loc[lambda counts: counts["location"] == "USA"]
)
variant_frequencies = ef.VariantFrequencies(raw_seq)

_How do I set the baseline Pango lineage in the MLR model? For this analysis, I'd like all estimates to be relative to BA.2._

In [3]:
# Multinomial logistic regression over lineage frequencies.
# NOTE(review): tau=4.2 is presumably the generation time (days) used to turn
# growth rates into growth advantages — confirm against the evofr docs.
mlr = ef.MultinomialLogisticRegression(
    tau=4.2,
)

In [4]:
# Full-rank approximate inference (presumably variational — confirm in the
# evofr docs). num_samples=100 matches the leading dimension of the
# posterior draws inspected further down.
inference_method = ef.InferFullRank(
    iters=30_000,
    lr=0.01,
    num_samples=100,
)

In [5]:
# Fit the model. `posterior` keeps the fitted data (later used for variant
# labels via `posterior.data`); `samples` is its mapping of posterior draws.
posterior = inference_method.fit(
    mlr,
    variant_frequencies,
)
samples = posterior.samples

In [6]:
def forecast_frequencies(samples, mlr, forecast_L):
    """
    Forecast variant frequencies past the fitted period from posterior beta.

    Parameters
    ----------
    samples : mapping
        Posterior draws. Reads ``samples["freq"]`` (only its second dimension,
        taken as the number of fitted time points) and ``samples["beta"]``
        (coefficients with the posterior draw on the leading axis).
    mlr : model object
        Must provide ``make_ols_feature(start=..., stop=...)``, which builds
        the design matrix for the forecast window.
    forecast_L : int
        Number of time points to forecast beyond the fitted period.

    Returns
    -------
    jax array
        ``softmax(X @ beta_s)`` over the last axis for each draw ``s``,
        stacked along a new leading axis. Presumably shaped
        (draws, forecast_L, variants) — TODO confirm against evofr.
    """
    # The forecast window begins immediately after the last fitted time point.
    start = samples["freq"].shape[1]
    features = mlr.make_ols_feature(start=start, stop=start + forecast_L)

    posterior_beta = jnp.array(samples["beta"])

    # Apply jnp.dot(features, beta_s) independently to every posterior draw.
    per_draw_dot = vmap(jnp.dot, in_axes=(None, 0), out_axes=0)
    return softmax(per_draw_dot(features, posterior_beta), axis=-1)

# Forecast horizon, as a number of time points past the fitted period; the
# forecast draws are stored next to the other posterior samples.
forecast_L = 30
samples["freq_forecast"] = forecast_frequencies(
    samples, mlr, forecast_L=forecast_L
)

In [7]:
# Posterior growth-advantage draws. This array carries no variant labels, so
# its column order has to be matched to variant names separately.
ga = samples["ga"]

_This should hold a growth advantage for each Pango lineage, but the shape below shows 136 columns rather than the 141 lineages I expected. I also don't know what order the columns are in, and I couldn't find a list of variant labels by inspecting `posterior`._

In [8]:
# Inspect the dimensions of the growth-advantage draws
# (presumably draws x variants — confirm).
np.shape(ga)

(100, 136)

In [9]:
# Confirm which array type the sampler returned.
ga.__class__

jaxlib.xla_extension.DeviceArray

In [10]:
# Variant labels stored on the fitted data object; preview the first 10.
posterior.data.var_names[:10]

['B.1.1.529',
 'BA.1',
 'BA.1.1',
 'BA.1.1.1',
 'BA.1.1.10',
 'BA.1.1.14',
 'BA.1.1.16',
 'BA.1.1.18',
 'BA.1.1.2',
 'BA.1.15']

In [11]:
# Summarise posterior growth advantages per variant, re-expressed relative to
# BA.2, with 80% intervals (ga_lower_80 / ga_upper_80), labelled as "USA".
ga_df = pd.DataFrame(
    ef.posterior.get_growth_advantage(
        samples,
        posterior.data,
        ps=[0.8],
        name="USA",
        rel_to="BA.2",
    )
)
ga_df.head(10)

Unnamed: 0,location,variant,median_ga,ga_upper_80,ga_lower_80
0,USA,B.1.1.529,1.6048825,1.6147666,1.595982
1,USA,BA.1,0.64536273,0.65045166,0.6396261
2,USA,BA.1.1,0.67618734,0.67786837,0.6743834
3,USA,BA.1.1.1,0.6812716,0.70418715,0.65972924
4,USA,BA.1.1.10,0.7644537,0.78079414,0.7476678
5,USA,BA.1.1.14,0.68550634,0.6990193,0.66925836
6,USA,BA.1.1.16,0.73289734,0.7463902,0.71903044
7,USA,BA.1.1.18,0.6426995,0.6492533,0.6380768
8,USA,BA.1.1.2,0.61856157,0.6465513,0.5933481
9,USA,BA.1.15,0.631348,0.63662994,0.62601155


In [12]:
# Save the growth-advantage summary as a TSV. index=False stops pandas from
# writing its default integer index as an unnamed leading column.
ga_df.to_csv("growth_advantages.tsv", sep="\t", index=False)