# Count mutations from `matUtils` translated mutations

Import Python modules:

In [None]:
import Bio.SeqIO

import numpy

import pandas as pd

import yaml

Get variables from `snakemake`:

In [None]:
# Parameters supplied by the `snakemake` rule.
max_nt_mutations = snakemake.params.max_nt_mutations
max_reversions_to_ref = snakemake.params.max_reversions_to_ref
max_reversions_to_founder = snakemake.params.max_reversions_to_clade_founder
sites_to_exclude = snakemake.params.sites_to_exclude
site_include_range = snakemake.params.site_include_range
exclude_ref_to_founder_muts = snakemake.params.exclude_ref_to_founder_muts

# Input files.
input_tsv = snakemake.input.tsv
input_nt_mut_csv = snakemake.input.nt_mut_csv
ref_fasta = snakemake.input.ref_fasta
clade_founder_fasta = snakemake.input.clade_founder_fasta
usher_masked_sites_yaml = snakemake.input.usher_masked_sites
site_mask_csv = snakemake.input.site_mask
ref_to_founder_muts_csv = snakemake.input.ref_to_founder_muts

# Wildcards.
clade = snakemake.wildcards.clade

# Output files.
output_csv = snakemake.output.csv

Get reference and founder sequence:

In [None]:
# Read the single record in each FASTA file (reference first, then clade
# founder) and keep only its sequence as a plain string.
ref, founder = (
    str(Bio.SeqIO.read(fasta_path, "fasta").seq)
    for fasta_path in (ref_fasta, clade_founder_fasta)
)

Get the sites and mutations to exclude:

In [None]:
# Start from the sites given as a parameter (an empty set if none are given).
sites_to_exclude = set(sites_to_exclude) if sites_to_exclude else set()
print(f"There are {len(sites_to_exclude)} sites to exclude")

# Add every site listed in the site-mask CSV.
masked_sites = set(pd.read_csv(site_mask_csv)["site"])
print(f"There are {len(masked_sites)} masked sites")
sites_to_exclude.update(masked_sites)

# Optionally exclude the specific mutations listed in the
# reference-to-founder CSV.
muts_to_exclude = (
    set(pd.read_csv(ref_to_founder_muts_csv)["mutation"])
    if exclude_ref_to_founder_muts
    else set()
)
print(f"There are {len(muts_to_exclude)} mutations to exclude")

# Add the sites of every UShER mask that lists this clade.
with open(usher_masked_sites_yaml) as f:
    usher_masked_sites = yaml.safe_load(f)
for mask, mask_dict in usher_masked_sites.items():
    if clade not in mask_dict["clades"]:
        continue
    sites = mask_dict["sites"]
    print(f"Applying UShER mask {mask} of {len(sites)} sites")
    sites_to_exclude.update(sites)

Process mutations:

In [None]:
# Translated (coding) mutations per node, read from the tab-separated input.
translated_mat = pd.read_csv(input_tsv, sep="\t")

# Nucleotide mutations per node, non-coding as well as coding. The column is
# renamed so it does not clash with `nt_mutations` in `translated_mat`.
nt_muts = pd.read_csv(input_nt_mut_csv)
nt_muts = nt_muts.rename(
    columns={"nt_mutations": "nt_mutations_coding_and_noncoding"}
)

# Every node in the translated table must also be in the nucleotide table.
assert len(nt_muts) >= len(translated_mat)
assert translated_mat["node_id"].nunique() <= nt_muts["node_id"].nunique()
assert set(translated_mat["node_id"]).issubset(set(nt_muts["node_id"]))

# Columns (and their order) expected in the final output.
final_cols = [
    "protein",
    "aa_mutation",
    "nt_mutation",
    "codon_change",
    "synonymous",
    "noncoding",
    "nt_site",
    "reference_nt",
    "clade_founder_nt",
    "exclude",
    "count",
    "count_terminal",
    "count_non_terminal",
    "mean_log_size",
]

def get_noncoding(row):
    """Return the sorted noncoding mutations for one row.

    The noncoding mutations are those present in
    ``row["nt_mutations_coding_and_noncoding"]`` but absent from
    ``row["nt_mutations"]``. Duplicates are collapsed.

    Raises ``AssertionError`` if any coding mutation is missing from the
    combined coding-and-noncoding collection.
    """
    coding_muts = row["nt_mutations"]
    all_muts = row["nt_mutations_coding_and_noncoding"]
    coding_set = set(coding_muts)
    all_set = set(all_muts)
    assert coding_set <= all_set
    return sorted(all_set - coding_set)

# Build `mutation_counts`: one row per mutation with the number of nodes it
# occurs on (all nodes, terminal nodes only, non-terminal nodes) and the mean
# log node size. Columns are checked against `final_cols` after this block.
if len(translated_mat) == 0 or len(nt_muts) == 0:
    # no data, just define an empty data frame
    mutation_counts = pd.DataFrame(columns=final_cols)

else:
    # Build a per-node table of mutation lists and summary statistics, then
    # keep only nodes that pass the reversion and mutation-count thresholds.
    filtered_nodes = (
        translated_mat
        # this next merge will drop nodes with only noncoding mutations
        .merge(nt_muts, on="node_id", validate="one_to_one")
        # Drop nodes whose `nt_mutations` string contains a comma.
        # NOTE(review): presumably a comma marks several nucleotide changes
        # in one codon -- confirm against the `matUtils` output format.
        .query("not nt_mutations.str.contains(',')")
        .assign(
            # split the ";"-delimited strings into per-node lists
            nt_mutations=lambda x: x["nt_mutations"].str.split(";"),
            nt_mutations_coding_and_noncoding=lambda x: x["nt_mutations_coding_and_noncoding"].str.split(";"),
            codon_changes=lambda x: x["codon_changes"].str.split(";"),
            aa_mutations=lambda x: x["aa_mutations"].str.split(";"),
            # mutations in the combined list that are not in the coding list
            noncoding_nt_mutations=lambda x: x.apply(get_noncoding, axis=1),
            # number of unique entries in the node's `nt_mutations` list
            n_nt_mutations=lambda x: x["nt_mutations"].map(lambda ms: len(set(ms))),
            # number of unique mutations whose last character (the mutant
            # nucleotide) equals `ref` at that 1-based site
            n_reversions_to_ref=lambda x: x["nt_mutations"].map(
                lambda ms: sum(m[-1] == ref[int(m[1:-1]) - 1] for m in set(ms))
            ),
            # same, but compared against the clade founder sequence
            n_reversions_to_founder=lambda x: x["nt_mutations"].map(
                lambda ms: sum(m[-1] == founder[int(m[1:-1]) - 1] for m in set(ms))
            ),
            # a node is treated as terminal when exactly one leaf shares
            # its mutations
            is_terminal=lambda x: x["leaves_sharing_mutations"] == 1,
            # natural log of the leaf count; clipped at 1 so the log is
            # never taken of a value below 1
            log_size=lambda x: numpy.log(x["leaves_sharing_mutations"].clip(lower=1)),
        )
        .drop(columns="nt_mutations_coding_and_noncoding")
        .query("n_reversions_to_ref <= @max_reversions_to_ref")
        .query("n_reversions_to_founder <= @max_reversions_to_founder")
        .query("n_nt_mutations <= @max_nt_mutations")
    )

    # get information on coding mutations
    coding_mutations = (
        filtered_nodes
        .drop(columns="noncoding_nt_mutations")
        # the three list columns are exploded together, one row per entry
        .explode(["aa_mutations", "nt_mutations", "codon_changes"])
        .assign(
            # each `aa_mutations` entry is split on ":" into the protein
            # (first part) and the amino-acid mutation (second part)
            protein=lambda x: x["aa_mutations"].str.split(":").str[0],
            aa_mutation=lambda x: x["aa_mutations"].str.split(":").str[1],
            # synonymous when the first and last characters are the same
            synonymous=lambda x: x["aa_mutation"].map(lambda m: m[0] == m[-1]),
        )
        .rename(columns={"nt_mutations": "nt_mutation", "codon_changes": "codon_change"})
        # Collapse rows that share a node and nucleotide mutation into one
        # row, joining their annotations with ";"; such a row is synonymous
        # only if every joined entry is synonymous.
        .groupby(["node_id", "nt_mutation", "is_terminal", "log_size"], as_index=False)
        .aggregate(
            protein=pd.NamedAgg("protein", lambda s: ";".join(s)),
            aa_mutation=pd.NamedAgg("aa_mutation", lambda s: ";".join(s)),
            codon_change=pd.NamedAgg("codon_change", lambda s: ";".join(s)),
            synonymous=pd.NamedAgg("synonymous", "all"),
        )
        .assign(noncoding=False)
    )

    # get information on non-coding mutations
    noncoding_mutations = (
        filtered_nodes
        .drop(columns=["aa_mutations", "nt_mutations", "codon_changes"])
        .explode("noncoding_nt_mutations")
        .rename(columns={"noncoding_nt_mutations": "nt_mutation"})
        # exploding an empty list leaves a null entry; drop those rows
        .query("nt_mutation.notnull()")
        .assign(
            # placeholder annotations for mutations with no coding effect
            aa_mutation="noncoding",
            codon_change="noncoding",
            protein="noncoding",
            synonymous=False,
            noncoding=True,
        )
        # same columns, in the same order, as the coding table
        [coding_mutations.columns]
    )

    # combine noncoding and coding mutations
    mutations = pd.concat([coding_mutations, noncoding_mutations],ignore_index=True)

    # Tidy counts: each mutation is counted once over all nodes
    # (`terminal_nodes_only=False`) and once over terminal nodes only
    # (`terminal_nodes_only=True`).
    mutation_counts_tidy = (
        pd.concat(
            [
                mutations.assign(terminal_nodes_only=False),
                mutations.query("is_terminal").assign(terminal_nodes_only=True),
            ],
        )
        .drop(columns="is_terminal")
        .groupby(
            [
                "terminal_nodes_only",
                "protein",
                "aa_mutation",
                "nt_mutation",
                "codon_change",
                "synonymous",
                "noncoding",
            ],
            as_index=False,
            dropna=False,
        )
        .aggregate(
            # number of nodes with the mutation, and their mean log size
            count=pd.NamedAgg("node_id", "count"),
            mean_log_size=pd.NamedAgg("log_size", "mean"),
        )
        .sort_values("count", ascending=False)
        .assign(
            # site is everything between the first and last character
            nt_site=lambda x: x["nt_mutation"].str[1:-1].astype(int),
            # sequences are 0-indexed strings, sites are 1-based
            reference_nt=lambda x: x["nt_site"].map(lambda r: ref[r - 1]),
            clade_founder_nt=lambda x: x["nt_site"].map(lambda r: founder[r - 1]),
            # flag (rather than drop) mutations at excluded sites, in the
            # excluded-mutation set, or outside the included site range
            exclude=lambda x: (
                x["nt_site"].isin(sites_to_exclude)
                | x["nt_mutation"].isin(muts_to_exclude)
                | (x["nt_site"] < site_include_range["start"])
                | (x["nt_site"] > site_include_range["end"])
            ),
        )
    )

    # Pivot from tidy to wide: one row per mutation with separate `count`
    # (all nodes) and `count_terminal` (terminal nodes only) columns.
    mutation_counts = (
        mutation_counts_tidy
        .assign(terminal_nodes_only=lambda x: x["terminal_nodes_only"].map({False: "count", True: "count_terminal"}))
        .pivot_table(
            # index on every column except the counts, the pivoted flag,
            # and `mean_log_size` (which is merged back in below)
            index=[
                c for c in mutation_counts_tidy.columns
                if c not in {"count", "terminal_nodes_only", "mean_log_size"}
            ],
            values="count",
            columns="terminal_nodes_only",
            fill_value=0,
        )
        .assign(
            # first add empty columns if empty data frame
            count=lambda x: x["count"] if "count" in x.columns else 0,
            count_terminal=lambda x: x["count_terminal"] if "count_terminal" in x.columns else 0,
            # compute non-terminal counts
            count_non_terminal=lambda x: x["count"] - x["count_terminal"],
        )
        .sort_values("count", ascending=False)
        .reset_index()
    )

    # add log node size
    # (taken from the all-nodes rows of the tidy table, so the value is the
    # mean over all nodes rather than over terminal nodes only)
    mutation_counts = mutation_counts.merge(
        (
            mutation_counts_tidy
            .query("not terminal_nodes_only")
            [["protein", "aa_mutation", "nt_mutation", "codon_change", "mean_log_size"]]
        ),
        on=["protein", "aa_mutation", "nt_mutation", "codon_change"],
        how="outer",
        validate="one_to_one",
    )
    # every mutation must have received a mean log size from the merge
    assert mutation_counts["mean_log_size"].notnull().all()

# The output must have exactly the expected columns, in the expected order.
assert list(mutation_counts.columns) == final_cols

# Write the counts without the data frame index.
mutation_counts.to_csv(output_csv, index=False)

# Display the result as the cell output.
mutation_counts