# Compare empirical and predicted equilibrium nucleotide frequencies


In [1]:
if "snakemake" not in locals() and "snakemake" not in globals():
    # Running interactively (no `snakemake` object): use hard-coded paths
    # relative to this notebook's directory.
    other_viruses = "../results/other_virus_spectra/other_virus_spectra.json"
    sars2_predicted = "../results/equilibrium_freqs/predicted_equilibrium_freqs.csv"
    sars2_empirical = "../results/equilibrium_freqs/empirical_equilibrium_freqs.csv"
    plotfile = "../results/other_virus_spectra/spectra_comparison.html"
else:
    # Running under Snakemake: take input and output paths from the rule.
    other_viruses = snakemake.input.other_viruses
    sars2_predicted = snakemake.input.sars2_predicted
    sars2_empirical = snakemake.input.sars2_empirical
    plotfile = snakemake.output.plot

In [22]:
import json
import math

import altair as alt

import pandas as pd

In [7]:
# For every virus in the JSON, keep only its predicted ("equilibrium_frequencies")
# and empirical nucleotide frequencies, each a {nucleotide: freq} mapping.
with open(other_viruses, encoding="utf-8") as f:
    other_viruses_d = {
        virus: {k: d[k] for k in ["equilibrium_frequencies", "empirical_frequencies"]}
        for virus, d in json.load(f).items()
    }

# Reshape to one row per (virus, nucleotide) with `empirical` and `predicted`
# frequency columns.
other_df = (
    pd.concat(
        [
            pd.DataFrame(freq_d.items(), columns=["nucleotide", "freq"]).assign(
                virus=virus, freq_type=freq_type,
            )
            for virus, d in other_viruses_d.items()
            for freq_type, freq_d in d.items()
        ],
        ignore_index=True,
    )
    .assign(
        freq_type=lambda x: x["freq_type"].map(
            {"equilibrium_frequencies": "predicted", "empirical_frequencies": "empirical"}
        ),
    )
    # `pivot` raises on duplicate (virus, nucleotide, freq_type) entries, whereas
    # `pivot_table` would silently average them.
    .pivot(index=["virus", "nucleotide"], columns="freq_type", values="freq")
    # Drop the leftover "freq_type" name on the column axis.
    .rename_axis(columns=None)
    .reset_index()
)

other_df

freq_type,virus,nucleotide,empirical,predicted
0,WNV,A,0.290245,0.289728
1,WNV,C,0.265111,0.209878
2,WNV,G,0.245362,0.244591
3,WNV,T,0.199282,0.255803
4,denv1,A,0.398754,0.406151
5,denv1,C,0.21433,0.190453
6,denv1,G,0.198131,0.211187
7,denv1,T,0.188785,0.192209
8,denv2,A,0.412987,0.40867
9,denv2,C,0.220779,0.21492


In [9]:
# Predicted equilibrium frequencies for three SARS-CoV-2 clades, reshaped from
# one column per nucleotide to one row per (clade, nucleotide), with the clade
# name prefixed by "SARS-CoV-2 clade ".
predicted_wide = pd.read_csv(sars2_predicted).query(
    "virus in ['20A (B.1)', '21I (Delta)', '22B (Omicron BA.5)']"
)
predicted_df = predicted_wide.melt(
    id_vars="virus", var_name="nucleotide", value_name="predicted"
)
predicted_df["virus"] = "SARS-CoV-2 clade " + predicted_df["virus"]

predicted_df

Unnamed: 0,virus,nucleotide,predicted
0,SARS-CoV-2 clade 20A (B.1),A,0.1622
1,SARS-CoV-2 clade 21I (Delta),A,0.14284
2,SARS-CoV-2 clade 22B (Omicron BA.5),A,0.23271
3,SARS-CoV-2 clade 20A (B.1),C,0.07212
4,SARS-CoV-2 clade 21I (Delta),C,0.08094
5,SARS-CoV-2 clade 22B (Omicron BA.5),C,0.07349
6,SARS-CoV-2 clade 20A (B.1),G,0.0276
7,SARS-CoV-2 clade 21I (Delta),G,0.03045
8,SARS-CoV-2 clade 22B (Omicron BA.5),G,0.04548
9,SARS-CoV-2 clade 20A (B.1),T,0.73808


In [10]:
# Empirical nucleotide frequencies of the 'SARS-CoV-2 (MN908947)' row, one row
# per nucleotide. The constant `virus` column is dropped so that this frame
# shares only the `nucleotide` column with `predicted_df`.
empirical_long = (
    pd.read_csv(sars2_empirical)
    .query("virus == 'SARS-CoV-2 (MN908947)'")
    .melt(id_vars="virus", var_name="nucleotide", value_name="empirical")
)
empirical_df = empirical_long.drop(columns="virus")

empirical_df

Unnamed: 0,nucleotide,empirical
0,A,0.28966
1,C,0.13692
2,G,0.06492
3,T,0.5085


In [12]:
# Attach the single set of empirical frequencies to every clade's predicted
# frequencies. `nucleotide` is the only column the two frames share, i.e. the
# same key the default `on=None` would pick. `validate` raises if a nucleotide
# appears more than once in `empirical_df`.
sars2_df = pd.merge(
    predicted_df,
    empirical_df,
    on="nucleotide",
    how="inner",
    validate="many_to_one",
)

sars2_df

Unnamed: 0,virus,nucleotide,predicted,empirical
0,SARS-CoV-2 clade 20A (B.1),A,0.1622,0.28966
1,SARS-CoV-2 clade 21I (Delta),A,0.14284,0.28966
2,SARS-CoV-2 clade 22B (Omicron BA.5),A,0.23271,0.28966
3,SARS-CoV-2 clade 20A (B.1),C,0.07212,0.13692
4,SARS-CoV-2 clade 21I (Delta),C,0.08094,0.13692
5,SARS-CoV-2 clade 22B (Omicron BA.5),C,0.07349,0.13692
6,SARS-CoV-2 clade 20A (B.1),G,0.0276,0.06492
7,SARS-CoV-2 clade 21I (Delta),G,0.03045,0.06492
8,SARS-CoV-2 clade 22B (Omicron BA.5),G,0.04548,0.06492
9,SARS-CoV-2 clade 20A (B.1),T,0.73808,0.5085


In [31]:
def cosine_similarity(a, b):
    """Return the cosine similarity of two equal-length numeric vectors.

    ``a`` and ``b`` must support element-wise ``*`` and ``.sum()`` (e.g. pandas
    Series or NumPy arrays). The result is dot(a, b) / (|a| * |b|).
    """
    dot_product = (a * b).sum()
    norm_product = math.sqrt((a * a).sum() * (b * b).sum())
    return dot_product / norm_product

# Display names for the short virus identifiers used in the other-virus JSON.
# Identifiers not listed here (e.g. the SARS-CoV-2 clade labels) are kept as-is.
rename_virus = {
    "denv3": "dengue virus 3",
    "denv2": "dengue virus 2",
    "rsv-b": "RSV-B",
    "rsv-a": "RSV-A",
    "denv4": "dengue virus 4",
    "denv1": "dengue virus 1",
    "flu_vic": "influenza B Victoria",
    "flu_yam": "influenza B Yamagata",
    "evA71": "enterovirus A71",
    "flu_h1n1pdm": "influenza A H1N1",
    "flu_h3n2": "influenza A H3N2",
    "WNV": "West Nile virus",
    "evD68": "enterovirus D68",
}

# Per-virus cosine similarity between predicted and empirical nucleotide
# frequencies, sorted from most to least similar.
similarity_df = (
    pd.concat([sars2_df, other_df], ignore_index=True)
    # Select only the two frequency columns before `apply`. Otherwise the
    # grouping column is also passed to the function, which pandas >= 2.2
    # deprecates. This form also works on older pandas.
    .groupby("virus")[["predicted", "empirical"]]
    .apply(lambda g: cosine_similarity(g["predicted"], g["empirical"]))
    .rename("similarity")
    .sort_values(ascending=False)
    .reset_index()
    .assign(virus=lambda x: x["virus"].map(lambda v: rename_virus.get(v, v)))
)

similarity_df

Unnamed: 0,virus,similarity
0,dengue virus 3,0.999839
1,dengue virus 2,0.999824
2,RSV-B,0.999743
3,RSV-A,0.999182
4,dengue virus 4,0.998817
5,dengue virus 1,0.998579
6,influenza B Victoria,0.997791
7,influenza B Yamagata,0.996416
8,enterovirus A71,0.996077
9,influenza A H1N1,0.994748


In [63]:
# Flag SARS-CoV-2 entries (any virus name containing "SARS") so they get a
# distinct color and marker shape.
plot_df = similarity_df.assign(is_sars=lambda x: x["virus"].str.contains("SARS"))

# Order the y axis by the row order of `similarity_df`, which is sorted by
# decreasing similarity.
virus_order = similarity_df["virus"].tolist()

# One point per virus. Color range is gray / pink; presumably gray is used for
# is_sars=False under the default ascending domain order (TODO confirm).
similarity_chart = (
    alt.Chart(plot_df)
    .encode(
        x=alt.X(
            "similarity",
            scale=alt.Scale(zero=False),
            title="similarity predicted to empirical nucleotide frequencies",
        ),
        y=alt.Y("virus", sort=virus_order, title=None),
        color=alt.Color(
            "is_sars", legend=None, scale=alt.Scale(range=["gray", "#CC79A7"])
        ),
        shape=alt.Shape("is_sars", legend=None),
    )
    .mark_point(size=70, filled=True, opacity=1)
    .properties(height=alt.Step(16), width=270)
)

# Write the chart to `plotfile`, then display it as the cell output.
similarity_chart.save(plotfile)

similarity_chart