# Counts of variants

This notebook analyzes the counts of the different variants and outputs variant-count files.

Import Python modules:

In [None]:
import os

import Bio.SeqIO

import altair as alt

import dms_variants.codonvarianttable

import pandas as pd

import yaml

Get configuration information:

In [None]:
# If you are running notebook interactively rather than in pipeline that handles
# working directories, you may have to first `os.chdir` to appropriate directory.
#
# Only change directory when `config.yaml` is not already in the working
# directory: this keeps the cell idempotent (re-running it does not apply the
# relative `chdir` a second time) and leaves the working directory alone when
# the pipeline has already set it.
if not os.path.isfile("config.yaml"):
    os.chdir("../test_example")

# Load the pipeline configuration; `config` is used throughout the notebook.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

Read information on the barcode runs:

In [None]:
# Table of barcode runs, read from the CSV path named in the configuration.
barcode_runs = pd.read_csv(config["processed_barcode_runs"])

# Sanity check: every row carries its own distinct `library_sample` label.
assert barcode_runs["library_sample"].nunique() == len(barcode_runs)

## Barcode sequencing alignment stats
Read the "fates" of barcode reads (e.g., did they align?):

In [None]:
# Path to the per-library-sample fates CSV for every barcode run.
fate_csvs = [
    os.path.join(config["barcode_fates_dir"], f"{library_sample}.csv")
    for library_sample in barcode_runs["library_sample"]
]

# Stack all fates, attach the run metadata, and flag which rows are the
# "valid barcode" fate (`not_valid` is its negation, used to order bars).
fates = (
    pd.concat([pd.read_csv(fate_csv) for fate_csv in fate_csvs])
    .merge(barcode_runs, on=["library", "sample"], validate="many_to_one")
    .drop(columns=["fastq_R1", "notes"])
    .assign(valid=lambda x: x["fate"].eq("valid barcode"))
    .assign(not_valid=lambda x: ~x["valid"])
)

Plot the barcode fates in an interactive plot:

In [None]:
# Scratch expression: displays the listed column names in reversed order; the result is not assigned.
list(reversed(["library", "date", "sample_type", "virus_batch", "antibody", "exclude_after_counts"]))

In [None]:
def _dropdown_selection(col):
    """Build a point selection on `col` bound to a drop-down that also offers "all"."""
    values = fates[col].dropna().unique()
    return alt.selection_point(
        fields=[col],
        bind=alt.binding_select(
            # `None` (labelled "all") means no filtering on this column
            options=[None] + values.tolist(),
            labels=["all"] + [str(v) for v in values],
            name=col,
        ),
    )


# One drop-down selection per metadata column; later charts reuse these.
selections = [
    _dropdown_selection(col)
    for col in [
        "exclude_after_counts",
        "antibody",
        "virus_batch",
        "sample_type",
        "date",
        "library",
    ]
]

sample_types = barcode_runs["sample_type"].unique().tolist()

# Stacked horizontal bars of read counts per library-sample, colored by fate.
fate_chart = (
    alt.Chart(fates)
    .mark_bar()
    .encode(
        x=alt.X("count", axis=alt.Axis(format=".2g"), scale=alt.Scale(nice=False)),
        y=alt.Y("library_sample", title=None),
        color=alt.Color(
            "fate",
            # fates in reverse alphabetical order
            scale=alt.Scale(domain=sorted(fates["fate"].unique(), reverse=True)),
        ),
        order="not_valid",
        tooltip=[
            alt.Tooltip(c, format=".3g") if c == "count" else c
            for c in fates.columns
            if c not in ["valid", "not_valid", "library_sample", "sample"]
        ],
    )
    .properties(width=300, height=alt.Step(13))
    .configure_axis(labelLimit=500)
)

# Register each drop-down on the chart and filter the data by it.
for selection in selections:
    fate_chart = fate_chart.add_parameter(selection).transform_filter(selection)

display(fate_chart)

## Barcode counts
Read the counts of all valid and invalid barcodes:

In [None]:
# Directories holding the per-library-sample count CSVs, paired with the value
# of the `valid` flag to attach to rows read from each.
counts_dirs = [
    (config["barcode_counts_dir"], True),
    (config["barcode_counts_invalid_dir"], False),
]

# For every library-sample read the valid then the invalid counts and stack them.
counts = pd.concat(
    [
        pd.read_csv(os.path.join(subdir, f"{library_sample}.csv")).assign(valid=valid)
        for library_sample in barcode_runs["library_sample"]
        for subdir, valid in counts_dirs
    ]
)

Get the average number of counts per barcode, separating valid and invalid ones:

In [None]:
# Mean count and number of barcodes for each (library, sample, valid) group,
# annotated with the barcode-run metadata.
avg_counts = (
    counts.groupby(["library", "sample", "valid"], as_index=False)
    .agg(
        mean_counts=("count", "mean"),
        n_barcodes=("barcode", "count"),
    )
    .merge(barcode_runs, on=["library", "sample"], validate="many_to_one")
    .drop(columns=["fastq_R1", "notes"])
    .astype({"valid": str})  # to make interactive chart work
)

Plot the average counts per barcode for both valid and invalid ones.
The plot below is interactive: you can use the drop downs at bottom to select specific subsets, mouseover points for details, and click the legend to show only valid or invalid counts:

In [None]:
# Distinct `valid` labels, used to fix the color and shape scale domains.
validities = avg_counts["valid"].unique()

# Legend-bound selection: clicking a legend entry shows only points with that
# `valid` value. `alt.selection_point` is the current Altair API and matches the
# other selections in this notebook; `alt.selection_multi` is its deprecated
# predecessor.
valid_selection = alt.selection_point(
    fields=["valid"], bind="legend",
)

# One point per (library-sample, valid) showing the mean counts per barcode.
avg_counts_chart = (
    alt.Chart(avg_counts)
    .encode(
        x=alt.X("mean_counts", title="average counts per barcode"),
        y=alt.Y("library_sample", title=None),
        color=alt.Color("valid", title="valid barcode", scale=alt.Scale(domain=validities)),
        shape=alt.Shape("valid", scale=alt.Scale(domain=validities)),
        # points outside the legend selection are made fully transparent
        opacity=alt.condition(valid_selection, alt.value(0.8), alt.value(0)),
        tooltip=[alt.Tooltip(c, format=".3g") if c in {"mean_counts", "n_barcodes"} else c
                 for c in avg_counts.columns if c != "library_sample"],
    )
    .mark_point(filled=True, size=50)
    .properties(width=200, height=alt.Step(15))
    .configure_axis(labelLimit=500)
    .add_parameter(*selections, valid_selection)
)
# Filter the plotted rows by each of the drop-down selections.
for selection in selections:
    avg_counts_chart = avg_counts_chart.transform_filter(selection)

avg_counts_chart

## Invalid barcodes and possible library-to-library contamination
Look at the top invalid barcodes.

First get all invalid barcodes along with where their counts rank among **all** (valid and invalid) barcodes for that library / sample:

In [None]:
# Rank every barcode (valid and invalid) by count within its library / sample
# (1 = most abundant, ties broken by order of appearance), then keep only the
# invalid barcodes and attach the barcode-run metadata.
ranked_invalid = (
    counts.assign(
        overall_rank=lambda x: (
            x.groupby(["library", "sample"])["count"]
            .rank(ascending=False, method="first")
            .astype(int)
        )
    )
    .query("not valid")
    .merge(barcode_runs)
    .sort_values(["library", "overall_rank"])
)

How many invalid barcodes are among the 10 most abundant barcodes for each sample?
Plot this, and then list the top invalid barcode for any sample with an invalid barcode in the top 10:

In [None]:
# get top invalid barcode and number invalid in top 10
invalid_topn = (
    ranked_invalid
    .groupby(["library_sample", "library", "sample"])
    .aggregate(
        invalid_barcodes_in_top_10=pd.NamedAgg("overall_rank", lambda s: len(s[s <= 10])),
        top_invalid_rank=pd.NamedAgg("overall_rank", "first"),
        top_invalid_barcode=pd.NamedAgg("barcode", "first"),
    )
    .reset_index()
    .merge(barcode_runs.drop(columns=["fastq_R1", "notes"]), how="left")
)

# plot number invalid barcodes in top 10
topn_chart = (
    alt.Chart(invalid_topn)
    .encode(
        x=alt.X(
            "invalid_barcodes_in_top_10",
            title="invalid barcodes in top 10",
            scale=alt.Scale(domain=(0, 10)),
        ),
        y=alt.Y("library_sample", title=None),
        tooltip=[alt.Tooltip(c, format=".3g") if c in {"mean_counts", "n_barcodes"} else c
                 for c in invalid_topn.columns if c != "library_sample"],
    )
    .mark_bar()
    .properties(width=200, height=alt.Step(15))
    .configure_axis(labelLimit=500)
    .add_parameter(*selections, valid_selection)
)
for selection in selections:
    topn_chart = topn_chart.transform_filter(selection)
display(topn_chart)

# top barcode for any samples with an invalid in top 10
print("\nTop barcode for samples with an invalid barcode in top 10:")
display(
    invalid_topn
    .query("invalid_barcodes_in_top_10 > 0")
    .reset_index(drop=True)
    [["library", "sample", "invalid_barcodes_in_top_10", "top_invalid_rank", "top_invalid_barcode"]]
)

Get which libraries each barcode maps to:

In [None]:
# For each (barcode, target): the comma-separated libraries that contain the
# barcode and how many distinct libraries that is.
barcodes_by_library = (
    pd.read_csv(config["codon_variants"])
    .groupby(["barcode", "target"], as_index=False)
    .agg(
        libraries_w_barcode=("library", lambda libs: ", ".join(libs.unique())),
        n_libraries_w_barcode=("library", "nunique"),
    )
)

# Number of barcodes for each target / library-combination.
display(
    barcodes_by_library.groupby(["target", "libraries_w_barcode"]).agg(
        n_barcodes=("barcode", "count")
    )
)

Now look at the overall barcode counts for each sample and see how many map to the expected library or to some other library.
Having many barcodes that map to a different library can be an indication of contamination unless there is a lot of expected overlap between the two libraries (which would be indicated in table above):

In [None]:
# Sum the counts of each sample by which libraries its barcodes map to, and
# express each sum as a fraction of that sample's total over mapped barcodes.
counts_by_library = (
    counts.merge(barcodes_by_library, on="barcode", validate="many_to_one")
    .groupby(
        ["library", "sample", "libraries_w_barcode", "target", "n_libraries_w_barcode"],
        as_index=False,
    )
    .agg(n_counts=("count", "sum"))
    .assign(
        frac_counts=lambda x: x["n_counts"].div(
            x.groupby(["library", "sample"])["n_counts"].transform("sum")
        ),
    )
    .merge(barcode_runs)
    # category is the library list for the "gene" target, otherwise the target name
    .assign(
        category=lambda x: x["libraries_w_barcode"].where(x["target"] == "gene", x["target"]),
    )
    .drop(
        columns=[
            "fastq_R1",
            "notes",
            "antibody_concentration",
            "target",
            "libraries_w_barcode",
        ]
    )
)

Plot which libraries overall barcode counts map to for each sample:

In [None]:
# Categories ordered by how many libraries share the barcode, then by name.
ordered_cats = (
    counts_by_library
    .sort_values(["n_libraries_w_barcode", "category"])["category"]
    .unique()
    .tolist()
)

# Clicking a legend entry restricts the chart to that category.
category_selection = alt.selection_point(fields=["category"], bind="legend")

# Stacked bars of the fraction of counts in each category per library-sample;
# the `order` column is each category's position in `ordered_cats` and sets
# the stacking order.
counts_by_library_chart = (
    alt.Chart(
        counts_by_library.assign(
            order=lambda x: x["category"].map(ordered_cats.index)
        )
    )
    .mark_bar()
    .encode(
        x=alt.X("frac_counts", scale=alt.Scale(domain=[0, 1])),
        y=alt.Y("library_sample", title=None),
        color=alt.Color("category", scale=alt.Scale(domain=ordered_cats)),
        order="order",
        tooltip=[
            alt.Tooltip(c, format=".2g") if c in {"n_counts", "frac_counts"} else c
            for c in counts_by_library.columns
            if c not in {"library_sample"}
        ],
    )
    .properties(width=250, height=alt.Step(15))
    .configure_axis(labelLimit=500)
    .add_parameter(*selections, category_selection)
    .transform_filter(category_selection)
)
# Also filter by each of the drop-down selections.
for selection in selections:
    counts_by_library_chart = counts_by_library_chart.transform_filter(selection)

counts_by_library_chart

## Get `CodonVariantTable` for valid variant counts
Get the variant counts for samples with sufficient valid barcode counts, and raise an error if there are any samples without sufficient counts that don't have `exclude_after_counts` specified:

In [None]:
print(f"Requiring {config['min_avg_counts']=} average counts per variant\n")

# Keep only the valid barcodes; the `valid` flag is then no longer needed.
valid_counts = counts.query("valid").drop(columns="valid")

# Mean valid count per library / sample, flagged by whether it reaches the
# configured minimum, plus the barcode-run metadata.
avg_counts = (
    valid_counts.groupby(["library", "sample"], as_index=False)
    .agg(avg_counts=("count", "mean"))
    .assign(adequate_counts=lambda x: x["avg_counts"].ge(config["min_avg_counts"]))
    .merge(barcode_runs, how="left", validate="one_to_one")
    .drop(columns=["fastq_R1", "notes"])
)

# Bar chart of the average counts, colored by `exclude_after_counts`.
avg_counts_chart = (
    alt.Chart(avg_counts)
    .mark_bar()
    .encode(
        x=alt.X("avg_counts"),
        y=alt.Y("library_sample", title=None),
        color="exclude_after_counts",
        tooltip=[
            alt.Tooltip(c, format=".2g") if c in {"avg_counts"} else c
            for c in avg_counts.columns
            if c not in {"library_sample"}
        ],
    )
    .properties(width=250, height=alt.Step(15))
    .configure_axis(labelLimit=500)
    .add_parameter(*selections)
)
for selection in selections:
    avg_counts_chart = avg_counts_chart.transform_filter(selection)
display(avg_counts_chart)

# Samples below the minimum that are not marked 'yes' for exclusion are an error.
insufficient_counts = avg_counts.query(
    "(not adequate_counts) & (exclude_after_counts != 'yes')"
)
if not insufficient_counts.empty:
    raise ValueError(
        "Samples w/o `exclude_after_counts` specified have insufficient counts:\n"
        + str(insufficient_counts[["library", "sample", "avg_counts"]])
    )

Now create a [CodonVariantTable](https://jbloomlab.github.io/dms_variants/dms_variants.codonvarianttable.html) for the samples not specified for exclusion:

In [None]:
# Gene sequence as a plain string, read from the FASTA file named in the config.
gene_record = Bio.SeqIO.read(config["gene_sequence_codon"], "fasta")
geneseq = str(gene_record.seq)

# Table of barcoded codon variants, with "gene" as the primary target.
variants = dms_variants.codonvarianttable.CodonVariantTable(
    barcode_variant_file=config["codon_variants"],
    geneseq=geneseq,
    substitutions_col="codon_substitutions",
    substitutions_are_codon=True,
    allowgaps=True,
    primary_target="gene",
)

# Valid counts for the samples whose `exclude_after_counts` is 'no'.
# NOTE(review): the insufficient-counts check earlier in the notebook treats any
# value other than 'yes' as "not excluded", whereas this keeps only exactly
# 'no' - confirm the column can only hold 'yes' / 'no'.
counts_to_add = valid_counts.merge(
    barcode_runs[["library", "sample", "exclude_after_counts"]],
    on=["library", "sample"],
    how="left",
    validate="many_to_one",
).query("exclude_after_counts == 'no'")

variants.add_sample_counts_df(counts_to_add)

## Mutations in each sample