# Compute and plot pairwise percent identities from the alignment

Import Python modules:

In [None]:
import itertools
import os

import altair as alt

import Bio.SeqIO

import pandas as pd

import yaml

# Remove altair's cap on the number of data rows embedded in a chart (the
# identity data frame has one row per gene x virus pair x gap setting); the
# return value is assigned to `_` so it is discarded.
_ = alt.data_transformers.disable_max_rows()

Get variables from `snakemake`, or define them manually when running interactively:

In [None]:
if "snakemake" not in globals() and "snakemake" not in locals():
    # Interactive use: hard-coded paths relative to this notebook, with the
    # parameters read from the pipeline's config file.
    inputfasta = "../results/viruses/all_viruses_aligned.fasta"
    chartfile = "../results/identities/identities.html"
    csvfile = "../results/identities/identities.csv"
    with open("../config.yaml") as f:
        config = yaml.safe_load(f)
    alignment_ref, ref_regions, viruses = (
        config["alignment_ref"],
        config["ref_regions"],
        config["viruses"],
    )
else:
    # Executed by snakemake: take inputs, outputs, and params from the rule.
    inputfasta = snakemake.input.fasta
    chartfile, csvfile = snakemake.output.chart, snakemake.output.csv
    alignment_ref = snakemake.params.alignment_ref
    ref_regions = snakemake.params.ref_regions
    viruses = snakemake.params.viruses

Get a mapping of all viruses (across all databases) to their accessions or URLs:

In [None]:
# Flatten the per-database {virus: accession} dicts into a single mapping;
# if a virus name repeats, the entry from the later database wins.
accessions = dict(
    itertools.chain.from_iterable(
        db_viruses.items() for db_viruses in viruses.values()
    )
)

accessions

Read alignment:

In [None]:
# Map each sequence ID in the aligned FASTA to its aligned sequence string.
alignment = dict(
    (record.id, str(record.seq))
    for record in Bio.SeqIO.parse(inputfasta, "fasta")
)

Get mappings between alignment column numbering and reference sequence numbering (both 1-based):

In [None]:
# alignment_to_ref_numbering: 1-based alignment column -> number of reference
#   (`alignment_ref`) nucleotides up to and including that column. Gap columns
#   in the reference therefore get the number of the preceding reference site
#   (0 before the first one).
# ref_to_alignment_numbering: 1-based reference site -> the alignment column
#   that holds that reference nucleotide. It is filled only at non-gap columns,
#   so a reference site followed by gap columns still maps to its own column
#   (inverting the first dict would instead pick the last gap column).
alignment_to_ref_numbering = {}
ref_to_alignment_numbering = {}
i_ref = 0
for i_alignment, nt in enumerate(alignment[alignment_ref], start=1):
    if nt != "-":
        i_ref += 1
        ref_to_alignment_numbering[i_ref] = i_alignment
    alignment_to_ref_numbering[i_alignment] = i_ref

Compute pairwise identities between all viruses for each gene (reference region), both with and without counting gaps as mismatches:

In [None]:
def _pair_counts(seq_1, seq_2):
    """Count identical and compared sites between two aligned sequences.

    Columns where both sequences have a gap (``-``) are skipped entirely.

    Returns:
        Tuple ``(n_identities, n_sites_no_gaps, n_sites_with_gaps)``.
        ``n_sites_no_gaps`` counts only columns where neither sequence has a
        gap. ``n_sites_with_gaps`` also counts columns where exactly one
        sequence has a gap, so such columns act as mismatches.
    """
    ident = n_no_gaps = n_w_gaps = 0
    for nt1, nt2 in zip(seq_1, seq_2):
        gap_1 = nt1 == "-"
        gap_2 = nt2 == "-"
        if gap_1 and gap_2:
            continue
        n_w_gaps += 1
        if gap_1 or gap_2:
            continue
        n_no_gaps += 1
        if nt1 == nt2:
            ident += 1
    return ident, n_no_gaps, n_w_gaps


# `ref_regions` are interpreted as 1-based, inclusive positions in the ungapped
# `alignment_ref` sequence. The aligned sequences contain gap columns, so
# convert those positions to alignment columns via the columns where the
# reference has a nucleotide (identical to the raw positions when the
# reference has no gaps).
ref_nt_columns = [
    i for i, nt in enumerate(alignment[alignment_ref], start=1) if nt != "-"
]

records = []
for gene, (start, end) in ref_regions.items():
    if not 1 <= start <= end <= len(ref_nt_columns):
        raise ValueError(
            f"region {gene}=({start}, {end}) is not within sites "
            f"1-{len(ref_nt_columns)} of reference {alignment_ref}"
        )
    # 1-based alignment columns of the region's first and last reference sites
    aln_start = ref_nt_columns[start - 1]
    aln_end = ref_nt_columns[end - 1]
    region_seqs = {
        virus: seq[aln_start - 1: aln_end] for virus, seq in alignment.items()
    }
    # Visit each unordered pair (including self-pairs) once; the counts are
    # symmetric, so also record the swapped pair to fill the full matrix.
    for virus_1, virus_2 in itertools.combinations_with_replacement(region_seqs, 2):
        counts = _pair_counts(region_seqs[virus_1], region_seqs[virus_2])
        records.append((gene, virus_1, virus_2, *counts))
        if virus_1 != virus_2:
            records.append((gene, virus_2, virus_1, *counts))

# Long format: one row per (gene, virus pair, whether gaps count as mismatches);
# "n sites" is the denominator of the percent identity for that gap setting.
df = (
    pd.DataFrame(
        records,
        columns=[
            "gene",
            "virus_1",
            "virus_2",
            "n identities",
            "no",
            "yes",
        ],
    )
    .melt(
        id_vars=["gene", "virus_1", "virus_2", "n identities"],
        value_vars=["no", "yes"],
        var_name="count gaps as mismatches",
        value_name="n sites",
    )
    # drop pairs with no comparable sites (avoids dividing by zero)
    .query("`n sites` > 0")
    .assign(
        percent_identity=lambda x: 100 * x["n identities"] / x["n sites"],
        virus_1_accession=lambda x: x["virus_1"].map(accessions),
        virus_2_accession=lambda x: x["virus_2"].map(accessions),
    )
)

print(f"Writing identities to {csvfile}")

# dirname is "" for a bare filename, which os.makedirs rejects
os.makedirs(os.path.dirname(csvfile) or ".", exist_ok=True)
df.to_csv(csvfile, index=False, float_format="%.5g")

df

Make an interactive heatmap of the identities:

In [None]:
# Gene shown initially: "spike" when it is a configured region, otherwise the
# first configured region, so the initial value is always a dropdown option
# (an init value matching no option would leave the chart empty).
init_gene = "spike" if "spike" in ref_regions else next(iter(ref_regions), "spike")

# NOTE(review): `selection_single` / `add_selection` are the older Altair API
# and may be deprecated in newer Altair releases (`selection_point` /
# `add_params`); confirm against the installed version before upgrading.
gene_selection = alt.selection_single(
    fields=["gene"],
    bind=alt.binding_select(name="gene", options=list(ref_regions)),
    init={"gene": init_gene},
)

gaps_selection = alt.selection_single(
    fields=["count gaps as mismatches"],
    bind=alt.binding_radio(
        name="count gaps as mismatches",
        options=df["count gaps as mismatches"].unique().tolist(),
    ),
    init={"count gaps as mismatches": "no"},
)

# Named so the calculate expressions below can reference its `focal` value.
w_respect_selection = alt.selection_single(
    fields=["focal"],
    name="w_respect_selection",
    bind=alt.binding_select(
        name="highlight identities with respect to",
        options=sorted(set(df["virus_1"]).union(set(df["virus_2"]))),
    ),
    init={"focal": alignment_ref},
)

base_heatmap = (
    alt.Chart(df)
    .transform_filter(gene_selection)
    .transform_filter(gaps_selection)
    .transform_calculate(
        is_focal_1="datum.virus_1 == w_respect_selection.focal",
        is_focal_2="datum.virus_2 == w_respect_selection.focal",
        # identity of virus_1 to the focal virus (taken from rows whose
        # virus_2 is focal), else 0; likewise for virus_2 below
        ident_to_focal_1=alt.expr.if_(
            alt.datum["is_focal_2"],
            alt.datum["percent_identity"],
            0,
        ),
        ident_to_focal_2=alt.expr.if_(
            alt.datum["is_focal_1"],
            alt.datum["percent_identity"],
            0,
        ),
        is_focal=alt.datum["is_focal_1"] | alt.datum["is_focal_2"],
    )
    # spread each virus's identity-to-focal to all of its rows, so the axes
    # can be sorted with the viruses most identical to the focal one first
    .transform_joinaggregate(
        ident_to_focal_1="max(ident_to_focal_1)",
        groupby=["virus_1"],
    )
    .transform_joinaggregate(
        ident_to_focal_2="max(ident_to_focal_2)",
        groupby=["virus_2"],
    )
    .encode(
        x=alt.X("virus_1", title=None, sort=alt.SortField("ident_to_focal_1", order="descending")),
        y=alt.Y("virus_2", title=None, sort=alt.SortField("ident_to_focal_2", order="descending")),
        tooltip=[
            alt.Tooltip(c, format=".2f") if df[c].dtype == float else c for c in df.columns
        ],
        # emphasize the borders of cells involving the focal virus
        strokeWidth=alt.condition(alt.datum["is_focal"], alt.value(2), alt.value(0.5)),
        strokeOpacity=alt.condition(alt.datum["is_focal"], alt.value(1), alt.value(0.4)),
    )
    .properties(width=alt.Step(20), height=alt.Step(20))
)

heatmap = (
    base_heatmap
    .encode(fill=alt.Fill("percent_identity", title=None))
    .mark_rect(stroke="black")
)

text = (
    base_heatmap
    .encode(
        text=alt.Text("percent_identity", format=".3g"),
    )
    .mark_text(color="gray", size=8)
)

chart = (
    (heatmap + text)
    .add_selection(gene_selection)
    .add_selection(gaps_selection)
    .add_selection(w_respect_selection)
)

# ensure the output directory exists (dirname is "" for a bare filename)
os.makedirs(os.path.dirname(chartfile) or ".", exist_ok=True)
chart.save(chartfile)

chart