# Mapping of sequential to reference site numbers

In [None]:
# Bind the input and output locations supplied by the `snakemake` object
# to short module-level names used by the cells below.
ref_fasta, variant_fasta = snakemake.input.ref, snakemake.input.variant
mutations_csv, numbering_csv = (
    snakemake.output.mutations,
    snakemake.output.numbering,
)

In [None]:
import io
import os
import re
import string
import subprocess
import tempfile

import Bio.SeqIO

import pandas as pd

import yaml

In [None]:
# The reference record is used as-is (it is not translated here).
ref = str(Bio.SeqIO.read(ref_fasta, "fasta").seq)

# The variant record is treated as a coding sequence: it must be a whole
# number of codons, and is translated before being compared to `ref`.
seq = Bio.SeqIO.read(variant_fasta, "fasta").seq
if len(seq) % 3 != 0:
    raise ValueError(f"{len(seq)=} not multiple of 3")
seq = str(seq.translate())
# A single terminal stop codon is allowed and dropped. `endswith` (rather
# than indexing `seq[-1]`) keeps this safe when the translation is empty.
if seq.endswith("*"):
    seq = seq[:-1]
# Any stop codon that remains is internal to the protein.
if "*" in seq:
    raise ValueError(f"premature stop codons in {seq=}")
# Fail with a clear message instead of handing an empty record to the aligner.
if not seq:
    raise ValueError("variant sequence translates to an empty protein")

In [None]:
# Write the translated variant and the reference to a temporary FASTA file
# and align them by running `mafft` on it. The temporary file is removed
# when the `with` block exits, so `mafft` must run inside it.
with tempfile.NamedTemporaryFile("w") as f:
    f.write(f">sequence\n{seq}\n>reference\n{ref}\n")
    f.flush()  # make sure the records are on disk before `mafft` reads them
    res = subprocess.run(["mafft", f.name], capture_output=True)

# Surface aligner failures here, with its own diagnostics, rather than as a
# confusing `KeyError` on an empty alignment later on.
if res.returncode != 0:
    mafft_stderr = res.stderr.decode("utf-8", errors="replace")
    raise RuntimeError(
        f"mafft failed with exit code {res.returncode}:\n{mafft_stderr}"
    )

# Parse the aligned FASTA from stdout into {record id: aligned string}.
alignment = {
    s.id: str(s.seq)
    for s in Bio.SeqIO.parse(io.StringIO(res.stdout.decode("utf-8")), "fasta")
}

missing_ids = {"sequence", "reference"} - set(alignment)
if missing_ids:
    raise ValueError(f"mafft output lacks records: {sorted(missing_ids)}")

In [None]:
# Walk the pairwise alignment column by column to relate sequential sites
# (1, 2, ... along the variant) to reference sites:
#   - gap in the variant: a deletion, recorded as "<ref aa><ref site>-";
#   - gap in the reference: an insertion, numbered as the preceding
#     reference site plus a lowercase letter (e.g. "214a", "214b", ...);
#   - otherwise: an aligned position, numbered by the reference site.
if len(alignment["sequence"]) != len(alignment["reference"]):
    # `zip` would silently truncate to the shorter string.
    raise ValueError(
        "aligned sequences differ in length: "
        f"{len(alignment['sequence'])} != {len(alignment['reference'])}"
    )

records = []
deletions = []
site = ref_site = ref_letter = 0
for aa, ref_aa in zip(alignment["sequence"], alignment["reference"]):
    if aa == "-":
        if ref_aa == "-":
            raise ValueError("alignment column has a gap in both sequences")
        ref_site += 1
        ref_letter = 0
        deletions.append(f"{ref_aa}{ref_site}-")
    elif ref_aa == "-":
        site += 1
        ref_letter += 1
        # `ref_letter` is 1-based, so all 26 letters ("a" through "z") are
        # usable; only a 27th consecutive insertion cannot be labelled.
        if ref_letter > len(string.ascii_lowercase):
            raise ValueError(
                f"more than {len(string.ascii_lowercase)} consecutive "
                f"insertions after reference site {ref_site}"
            )
        letter = string.ascii_lowercase[ref_letter - 1]
        records.append((site, f"{ref_site}{letter}", aa, ref_aa))
    else:
        site += 1
        ref_site += 1
        ref_letter = 0
        records.append((site, ref_site, aa, ref_aa))

# One row per variant residue; deleted reference sites have no row.
df = pd.DataFrame(
    records, columns=["sequential_site", "reference_site", "aa", "reference_aa"],
)

# Deletions come first, followed by every row of `df` whose residue differs
# from the reference residue, written as "<reference aa><reference site><aa>".
mutations = deletions + [
    f"{row.reference_aa}{row.reference_site}{row.aa}"
    for row in df.itertuples(index=False)
    if row.aa != row.reference_aa
]

def _mutation_type(m):
    """Classify a mutation string by where its gap character sits."""
    if m.endswith("-"):
        return "deletion"
    if m.startswith("-"):
        return "insertion"
    return "substitution"


# Tabulate the mutations with their type, grouped by type while keeping the
# original order within each type (a stable sort on the type alone).
mutations_df = pd.DataFrame({"mutation": mutations})
mutations_df["mutation_type"] = mutations_df["mutation"].map(_mutation_type)
mutations_df = mutations_df.sort_values("mutation_type", kind="stable")

# Show a per-type count of mutations in the notebook output.
print("Here are the number of mutations:")
mutation_counts = mutations_df.groupby("mutation_type").agg(n=("mutation", "count"))
display(mutation_counts)

# Write the mutation table and the site-numbering table, without the index.
mutations_df.to_csv(mutations_csv, index=False)
df.to_csv(numbering_csv, index=False)