# Plot N to S ratio over time

In [None]:
import pickle
import pandas as pd
import torch
import glob
from pyrocov import mutrans, pangolin, stats
import matplotlib.pyplot as plt
import numpy as np
from pyrocov.sarscov2 import GENE_TO_POSITION, GENE_STRUCTURE, aa_mutation_to_position
import datetime

In [None]:
# Width of one time bin. Sensible choices: 7 (week), 14 (fortnight), 28 (month).
TIMESTEP = 14  # in days
GENERATION_TIME = 5.5  # in days
START_DATE = "2019-12-01"


def date_range(stop):
    """Return the start dates of the first ``stop`` time bins as a NumPy array.

    Bin ``t`` starts ``t * TIMESTEP`` days after ``START_DATE``.
    """
    origin = datetime.datetime.strptime(START_DATE, "%Y-%m-%d")
    bin_width = datetime.timedelta(days=TIMESTEP)
    bin_starts = []
    for t in range(stop):
        bin_starts.append(origin + bin_width * t)
    return np.array(bin_starts)


def date_map(i):
    """Return the datetime that lies ``i`` days after ``START_DATE``."""
    origin = datetime.datetime.strptime(START_DATE, "%Y-%m-%d")
    bin_width = datetime.timedelta(days=TIMESTEP)
    # Scaling one bin by i and floor-dividing by the bin length in days
    # leaves a timedelta of exactly i days (for integer i).
    elapsed = (bin_width * i) // TIMESTEP
    return origin + elapsed

In [None]:
# Parameters selecting which preprocessed dataset to load; the clade and
# mutation thresholds are embedded in the cached file names below.
min_region_size = 50
min_num_mutations = 1
max_num_clades = 3000
ambiguous = False

columns_filename = f"results/columns.{max_num_clades}.pkl"
features_filename = f"results/features.{max_num_clades}.{min_num_mutations}.pt"

In [None]:
# Load the GISAID dataset onto the CPU from the cached columns/features files.
input_dataset = mutrans.load_gisaid_data(
    device="cpu",
    columns_filename=columns_filename,
    features_filename=features_filename,
    min_region_size=min_region_size,
)

In [None]:
# Prefer the nextclade counts; if they cannot be loaded for any reason,
# fall back to the "aaSubstitutions" entry of the stats pickle.
# NOTE: pickle.load must only ever be pointed at trusted, locally produced files.
try:
    with open("results/nextclade.counts.pkl", "rb") as counts_file:
        all_mutations = pickle.load(counts_file)
except Exception:
    with open("results/stats.pkl", "rb") as stats_file:
        all_mutations = pickle.load(stats_file)["aaSubstitutions"]

print(f"Loaded {len(all_mutations)} mutations")

In [None]:
def get_NS_ratio(fit):
    """Return the ratio of mean positive coefficient in gene N to gene S.

    Coefficients are clamped at zero before averaging, so negative values
    count as zero.

    Args:
        fit: Mapping with a coefficient tensor at ``fit["mean"]["coef"]``.
            Assumed to be indexed in the same order as
            ``input_dataset["mutations"]`` -- TODO confirm.

    Returns:
        The N-to-S ratio as a Python float.
    """
    mutations = input_dataset["mutations"]
    gene_id = {gene_name: i for i, gene_name in enumerate(GENE_TO_POSITION)}
    # The gene name is the part of each mutation string before the colon.
    gene_ids = torch.tensor([gene_id[m.split(":")[0]] for m in mutations])
    coef = fit["mean"]["coef"]

    n_mean = coef[gene_ids == gene_id["N"]].clamp(min=0).mean()
    s_mean = coef[gene_ids == gene_id["S"]].clamp(min=0).mean()

    return (n_mean / s_mean).item()

In [None]:
# Position of N:P207S. The N window in get_NS_peak_ratio (28792..28992) is
# +-100 around 28892 -- presumably the value shown here; confirm.
aa_mutation_to_position('N:P207S') # +- 100

In [None]:
# Position of S:T478K. The S window in get_NS_peak_ratio (22894..23094) is
# +-100 around 22994 -- presumably the value shown here; confirm.
aa_mutation_to_position('S:T478K') # +- 100

In [None]:
def get_NS_peak_ratio(fit):
    """Return the ratio of mean coefficient in a fixed N window to an S window.

    Unlike ``get_NS_ratio``, coefficients are not clamped, and mutations are
    selected by position rather than by gene name. Both window bounds are
    exclusive.

    Args:
        fit: Mapping with a coefficient tensor at ``fit["mean"]["coef"]``.
            Assumed to be indexed in the same order as
            ``input_dataset["mutations"]`` -- TODO confirm.

    Returns:
        The N-window to S-window ratio as a Python float.
    """
    mutations = input_dataset["mutations"]
    position = torch.tensor([aa_mutation_to_position(m) for m in mutations])
    coef = fit["mean"]["coef"]

    # Each window is 200 wide: 28892 +- 100 for N and 22994 +- 100 for S.
    N_start = 28792
    N_end = 28992

    S_start = 22894
    S_end = 23094

    N_value = coef[(position > N_start) & (position < N_end)].mean()
    S_value = coef[(position > S_start) & (position < S_end)].mean()

    return (N_value / S_value).item()

In [None]:
# Collect the fitted-model checkpoints. The wildcard is dot-separated field 18
# of the path, which the cells below read back with split('.')[18].
model_files = glob.glob('results/mutrans.svi.3000.1.50.coef_scale=0.05.reparam-localinit.full.10001.0.05.0.1.10.0.200.12.*..pt')

In [None]:
# Show the matched checkpoint paths.
model_files

In [None]:
# Evaluate both N/S summaries for every fitted-model checkpoint.
days, ns_ratio, ns_peak_ratio = [], [], []

for model_filename in model_files:
    fit = torch.load(model_filename, map_location="cpu")
    # Dot-separated field 18 of the path is the wildcard part of the glob.
    model_days = model_filename.split(".")[18]
    days.append(model_days)
    ns_ratio.append(get_NS_ratio(fit))
    ns_peak_ratio.append(get_NS_peak_ratio(fit))

# The values parsed from the file names are strings; convert them to integers.
days = [int(day) for day in days]

df = pd.DataFrame(
    {
        "days": days,
        "ns_ratio": ns_ratio,
        "ns_peak_ratio": ns_peak_ratio,
    }
)

In [None]:
# Convert each day offset to a calendar date for plotting.
df["mapped_time"] = [date_map(int(day)) for day in df["days"]]

In [None]:
# Scatter plot of the gene-level N/S ratio against date, saved for the paper.
df.plot.scatter(
    x="mapped_time",
    y="ns_ratio",
    title="N / S ratio mean",
    rot=25,
)
plt.savefig("paper/N_S_ratio.png", dpi=300)

In [None]:
# Scatter plot of the window-based N/S ratio against date, saved for the paper.
df.plot.scatter(
    x="mapped_time",
    y="ns_peak_ratio",
    title="N / S peak ratio mean",
    rot=90,
)
plt.savefig("paper/N_S_peak_ratio.png", dpi=300)

# Stack plot

In [None]:
def get_gene_ratios(fit):
    """Return each gene's share of the total mean positive coefficient.

    For every gene in ``GENE_TO_POSITION``, the coefficients of that gene's
    mutations are clamped at zero and averaged. The per-gene means are then
    normalized to sum to one.

    Args:
        fit: Mapping with a coefficient tensor at ``fit["mean"]["coef"]``.
            Assumed to be indexed in the same order as
            ``input_dataset["mutations"]`` -- TODO confirm.

    Returns:
        A tuple ``(genes, components)``: the alphabetically sorted gene names
        and a tensor with one normalized share per gene, in the same order.
    """
    mutations = input_dataset["mutations"]
    gene_id = {gene_name: i for i, gene_name in enumerate(GENE_TO_POSITION)}
    # The gene name is the part of each mutation string before the colon.
    gene_ids = torch.tensor([gene_id[m.split(":")[0]] for m in mutations])
    coef = fit["mean"]["coef"]

    genes = sorted(gene_id)
    gene_means = []
    for gene in genes:
        selected = coef[gene_ids == gene_id[gene]].clamp(min=0)
        # The mean of an empty selection is NaN, which would turn every
        # normalized share into NaN; a gene without mutations contributes 0.
        if selected.numel():
            gene_means.append(selected.mean())
        else:
            gene_means.append(selected.new_zeros(()))
    components = torch.stack(gene_means)
    components /= components.sum(0)

    return genes, components

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import pyrocov.mutrans as mutrans

In [None]:
# Show the checkpoint paths that the stack plot will iterate over.
model_files

In [None]:
# Compute the per-gene shares for every fitted-model checkpoint.
series, days = [], []

for model_filename in model_files:
    fit = torch.load(model_filename, map_location="cpu")
    # Dot-separated field 18 of the path is the wildcard part of the glob.
    model_days = model_filename.split(".")[18]
    days.append(model_days)
    genes, components = get_gene_ratios(fit)
    series.append(components)

# One row per checkpoint, in the order of model_files.
series = torch.stack(series)

In [None]:
# Calendar date of each checkpoint, plus the index order that sorts them.
times = [date_map(int(day)) for day in days]
idx = np.argsort(times)

In [None]:
import seaborn as sns

In [None]:
# Colour palette used for the layers of the stack plot below.
pal = sns.color_palette('tab20')

In [None]:
# Stacked area chart of the per-gene shares, in chronological order.
fig, ax = plt.subplots()
plt.stackplot(
    [times[i] for i in idx],
    series[idx,].T,
    labels=genes,
    colors=pal,
)
plt.legend(loc="upper left", prop={"size": 8})

# One major tick per month, labelled with abbreviated month and year.
ax.xaxis.set_major_locator(matplotlib.dates.MonthLocator())
ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%b %Y"))
plt.xticks(rotation=45)

fig.tight_layout()
plt.savefig("paper/gene_ratios.png", dpi=300)