# Compute and plot pairwise sequence identities from an alignment

Import Python modules:

In [1]:
import itertools
import os

import altair as alt

import Bio.SeqIO

import pandas as pd

import yaml

# Lift Altair's default cap on the number of rows it will embed in a chart:
# the identities data frame built below has tens of thousands of rows.
# (Assigning to `_` keeps the notebook from echoing the return value.)
_ = alt.data_transformers.disable_max_rows()

Get variables from `snakemake`:

In [2]:
if not ("snakemake" in globals() or "snakemake" in locals()):
    # Interactive session: use paths relative to this notebook and pull the
    # remaining settings from the pipeline's config file.
    inputfasta = "../results/viruses/all_viruses_aligned.fasta"
    chartfile = "../results/identities/identities.html"
    csvfile = "../results/identities/identities.csv"
    with open("../config.yaml") as f:
        config = yaml.safe_load(f)
    alignment_ref, ref_regions = config["alignment_ref"], config["ref_regions"]
else:
    # Snakemake run: inputs, outputs and parameters all come from the
    # `snakemake` object that is present when the notebook runs in the pipeline.
    inputfasta, chartfile, csvfile = (
        snakemake.input.fasta,
        snakemake.output.chart,
        snakemake.output.csv,
    )
    alignment_ref, ref_regions = (
        snakemake.params.alignment_ref,
        snakemake.params.ref_regions,
    )

Read alignment:

In [3]:
def _read_alignment(fasta):
    """Read an aligned FASTA file into a dict mapping sequence ID to sequence.

    Raises ValueError if two records share an ID (a plain dict would silently
    keep only the last one) or if the sequences are not all the same length
    (later cells slice every sequence by the same alignment columns).
    """
    seqs = {}
    for record in Bio.SeqIO.parse(fasta, "fasta"):
        if record.id in seqs:
            raise ValueError(f"duplicate sequence ID {record.id!r} in {fasta}")
        seqs[record.id] = str(record.seq)
    lengths = {len(seq) for seq in seqs.values()}
    if len(lengths) > 1:
        raise ValueError(
            f"sequences in {fasta} are not all the same length: {sorted(lengths)}"
        )
    return seqs


alignment = _read_alignment(inputfasta)

Map between alignment-column numbering and reference-site numbering (both 1-based):

In [4]:
if alignment_ref not in alignment:
    raise ValueError(f"alignment_ref {alignment_ref!r} is not a sequence in the alignment")

# alignment column (1-based) -> number of reference nucleotides at or before
# that column; columns before the first reference nucleotide map to 0, and gap
# columns in the reference map to the preceding reference site.
alignment_to_ref_numbering = {}
# reference site (1-based) -> alignment column (1-based) holding that
# nucleotide. Built directly from the non-gap columns rather than by inverting
# the many-to-one map above, which would send each site to the last gap column
# following it.
ref_to_alignment_numbering = {}
i_ref = 0
for i_alignment, nt in enumerate(alignment[alignment_ref], start=1):
    if nt != "-":
        i_ref += 1
        ref_to_alignment_numbering[i_ref] = i_alignment
    alignment_to_ref_numbering[i_alignment] = i_ref

Compute pairwise percent identities among all viruses for each gene, either ignoring gaps or counting them as mismatches:

In [16]:
def _count_identities(seq_1, seq_2):
    """Count identities between two aligned sequences of equal length.

    Returns ``(n_identities, n_sites_no_gaps, n_sites_with_gaps)``. Columns
    gapped in both sequences are ignored entirely. Columns gapped in just one
    sequence count toward ``n_sites_with_gaps`` only, so they act as mismatches
    when gaps are counted. Every other column counts toward both totals and is
    an identity when the two characters are equal. Characters are compared
    exactly, so e.g. ``N`` only matches ``N``.
    """
    ident = n_no_gaps = n_w_gaps = 0
    for nt1, nt2 in zip(seq_1, seq_2):
        if nt1 == "-" and nt2 == "-":
            continue
        n_w_gaps += 1
        if nt1 == "-" or nt2 == "-":
            continue
        n_no_gaps += 1
        if nt1 == nt2:
            ident += 1
    return ident, n_no_gaps, n_w_gaps


# 0-based alignment column holding each reference nucleotide. `ref_regions`
# gives 1-based, inclusive reference coordinates, which differ from alignment
# columns wherever the reference sequence has gaps, so map them first.
ref_columns = [i for i, nt in enumerate(alignment[alignment_ref]) if nt != "-"]
viruses = list(alignment)

records = []
for gene, (start, end) in ref_regions.items():
    if not 1 <= start <= end <= len(ref_columns):
        raise ValueError(
            f"{gene} region {start}-{end} is not within reference "
            f"{alignment_ref} (length {len(ref_columns)})"
        )
    # from the column of reference site `start` through that of site `end`,
    # so insertions relative to the reference inside the gene are included
    region = slice(ref_columns[start - 1], ref_columns[end - 1] + 1)
    region_seqs = {virus: alignment[virus][region] for virus in viruses}
    # each unordered pair (including self pairs) is computed once; identity is
    # symmetric, so the reversed pair is also recorded to fill the full matrix
    for virus_1, virus_2 in itertools.combinations_with_replacement(viruses, 2):
        counts = _count_identities(region_seqs[virus_1], region_seqs[virus_2])
        records.append((gene, virus_1, virus_2, *counts))
        if virus_1 != virus_2:
            records.append((gene, virus_2, virus_1, *counts))

df = (
    pd.DataFrame(
        records,
        columns=[
            "gene",
            "virus_1",
            "virus_2",
            "n identities",
            "no",
            "yes",
        ],
    )
    # one row per (gene, pair, gap treatment); "no"/"yes" answer whether
    # gaps are counted as mismatches
    .melt(
        id_vars=["gene", "virus_1", "virus_2", "n identities"],
        value_vars=["no", "yes"],
        var_name="count gaps as mismatches",
        value_name="n sites",
    )
    # drop pairs with no comparable sites to avoid dividing by zero
    .query("`n sites` > 0")
    .assign(percent_identity=lambda x: 100 * x["n identities"] / x["n sites"])
)

print(f"Writing identities to {csvfile}")

# dirname is "" for a bare filename, which os.makedirs rejects
os.makedirs(os.path.dirname(csvfile) or ".", exist_ok=True)
df.to_csv(csvfile, index=False, float_format="%.5g")

df

Writing identites to ../results/identities/identities.csv


Unnamed: 0,gene,virus_1,virus_2,n identities,count gaps as mismatches,n sites,percent_identity
0,ORF1a,SARS-CoV-2_Wuhan-Hu-1,SARS-CoV-2_Wuhan-Hu-1,13100,no,13100,100.000000
1,ORF1a,SARS-CoV-2_Wuhan-Hu-1,RaTG13,12579,no,13097,96.044896
2,ORF1a,RaTG13,SARS-CoV-2_Wuhan-Hu-1,12579,no,13097,96.044896
3,ORF1a,SARS-CoV-2_Wuhan-Hu-1,SARS-CoV-1,9881,no,12998,76.019388
4,ORF1a,SARS-CoV-1,SARS-CoV-2_Wuhan-Hu-1,9881,no,12998,76.019388
...,...,...,...,...,...,...,...
39361,nsp16,GD_Pangolin,RShSTT182,820,yes,894,91.722595
39362,nsp16,RShSTT200,RShSTT200,894,yes,894,100.000000
39363,nsp16,RShSTT200,GD_Pangolin,820,yes,894,91.722595
39364,nsp16,GD_Pangolin,RShSTT200,820,yes,894,91.722595


Make plot:

In [66]:
# Gene shown initially: spike if it is one of the configured regions,
# otherwise the first configured region. An init value that is not among the
# dropdown options would make the gene filter below drop every row.
default_gene = "spike" if "spike" in ref_regions else next(iter(ref_regions))

gene_selection = alt.selection_single(
    fields=["gene"],
    bind=alt.binding_select(name="gene", options=list(ref_regions)),
    init={"gene": default_gene},
)

gaps_selection = alt.selection_single(
    fields=["count gaps as mismatches"],
    bind=alt.binding_radio(
        name="count gaps as mismatches",
        options=df["count gaps as mismatches"].unique().tolist(),
    ),
    init={"count gaps as mismatches": "no"},
)

# Dropdown choosing a "focal" virus. "focal" is not a data column; the value
# is read in the calculate expressions below via `w_respect_selection.focal`.
w_respect_selection = alt.selection_single(
    fields=["focal"],
    name="w_respect_selection",
    bind=alt.binding_select(
        name="highlight identities with respect to",
        options=sorted(set(df["virus_1"]).union(set(df["virus_2"]))),
    ),
    init={"focal": alignment_ref},
)

base_heatmap = (
    alt.Chart(df)
    .add_selection(gene_selection)
    .add_selection(gaps_selection)
    .add_selection(w_respect_selection)
    .transform_filter(gene_selection)
    .transform_filter(gaps_selection)
    .transform_calculate(
        is_focal_1="datum.virus_1 == w_respect_selection.focal",
        is_focal_2="datum.virus_2 == w_respect_selection.focal",
        # identity of virus_1 to the focal virus, nonzero only on rows whose
        # virus_2 is the focal virus (and symmetrically for virus_2)
        ident_to_focal_1=alt.expr.if_(
            alt.datum["is_focal_2"],
            alt.datum["percent_identity"],
            0,
        ),
        ident_to_focal_2=alt.expr.if_(
            alt.datum["is_focal_1"],
            alt.datum["percent_identity"],
            0,
        ),
        is_focal=alt.datum["is_focal_1"] | alt.datum["is_focal_2"],
    )
    # spread each virus's identity-to-focal value to all of its rows so the
    # axes can be sorted by similarity to the focal virus
    .transform_joinaggregate(
        ident_to_focal_1="max(ident_to_focal_1)",
        groupby=["virus_1"],
    )
    .transform_joinaggregate(
        ident_to_focal_2="max(ident_to_focal_2)",
        groupby=["virus_2"],
    )
    .encode(
        x=alt.X("virus_1", title=None, sort=alt.SortField("ident_to_focal_1", order="descending")),
        y=alt.Y("virus_2", title=None, sort=alt.SortField("ident_to_focal_2", order="descending")),
        fill=alt.Fill("percent_identity", title=None),
        tooltip=[
            alt.Tooltip(c, format=".2f") if df[c].dtype == float else c for c in df.columns
        ],
        # outline cells involving the focal virus more heavily
        strokeWidth=alt.condition(alt.datum["is_focal"], alt.value(2), alt.value(0.5)),
        strokeOpacity=alt.condition(alt.datum["is_focal"], alt.value(1), alt.value(0.4)),
    )
    .mark_rect(stroke="black")
    .properties(width=alt.Step(16), height=alt.Step(16))
)

base_heatmap

In [33]:
# NOTE(review): exploratory cell left over from development. Its recorded
# output is help text for `altair.expr`, not the repr of `alt`, so it appears
# stale. Consider deleting this cell.
alt

Help on package altair.expr in altair:

NAME
    altair.expr - Tools for creating transform & filter expressions with a python syntax

PACKAGE CONTENTS
    consts
    core
    funcs
    tests (package)

DATA
    E = E
        E: the transcendental number e (alias to Math.E)
    
    LN10 = LN10
        LN10: the natural log of 10 (alias to Math.LN10)
    
    LN2 = LN2
        LN2: the natural log of 2 (alias to Math.LN2)
    
    LOG10E = LOG10E
        LOG10E: the base 10 logarithm e (alias to Math.LOG10E)
    
    LOG2E = LOG2E
        LOG2E: the base 2 logarithm of e (alias to Math.LOG2E)
    
    NaN = NaN
        NaN: not a number (same as JavaScript literal NaN)
    
    PI = PI
        PI: the transcendental number pi (alias to Math.PI)
    
    SQRT1_2 = SQRT1_2
        SQRT1_2: the square root of 0.5 (alias to Math.SQRT1_2)
    
    SQRT2 = SQRT2
        SQRT2: the square root of 2 (alias to Math.SQRT1_2)
    
    abs = <function expr.abs(*args)>
        abs(*args)
        Re