# Analyze viral and mitochondrial reads and make some summary plots

Get variables from `snakemake`:

In [None]:
# Short handles on the namespaces that Snakemake injects into the notebook.
_smk_params = snakemake.params
_smk_output = snakemake.output

# input CSV of read counts
counts_csv = snakemake.input.counts

# collection date used as the default selection and for the species subset chart
collection_date_of_interest = _smk_params.collection_date_of_interest

# only keep viruses with at least this many total reads among samples
min_total_virus_reads = _smk_params.min_total_virus_reads

# thresholds for the subset chart; indexed below by the keys
# "min_species_percent" and "min_virus_reads"
species_for_subset_chart = _smk_params.species_for_subset_chart

# for log scales, show zeros as the minimum non-zero value divided by this factor
log_scale_axis_min_factor = _smk_params.log_scale_axis_min_factor

# HTML files that the charts are saved to
virus_counts_html = _smk_output.viral_counts_html
viral_reads_per_sample_html = _smk_output.viral_reads_per_sample_html
viral_subset_species_corr_html = _smk_output.viral_subset_species_corr_html
viral_all_species_corr_html = _smk_output.viral_all_species_corr_html

Import Python modules:

In [None]:
import altair as alt

import numpy

import pandas as pd

_ = alt.data_transformers.disable_max_rows()

## Quick summary of viral read counts
Read data and plot quick summary of viral counts on all dates and the collection date of interest.
The chart is interactive, so you can modify the dates of interest or cutoffs for how many total reads are needed for a virus to be shown:

In [None]:
# Load the read counts and give the collection-date column a snake_case name.
counts = pd.read_csv(counts_csv).rename(columns={"Collection date": "collection_date"})

# Species reads summed within each (sample, virus_id) group: the denominator
# for each species' share of reads.
_species_reads_per_group = (
    counts.groupby(["sample", "virus_id"])["species_reads"].transform("sum")
)

counts["percent_reads_from_virus"] = (
    100 * counts["virus_reads"] / counts["preprocessed_reads"]
)
counts["percent_reads_from_species"] = (
    100 * counts["species_reads"] / counts["preprocessed_reads"]
)
counts["percent_mito_reads_from_species"] = (
    100 * counts["species_reads"] / _species_reads_per_group
)

# "all" is used below as a pseudo-date pooling every real date, so it must not
# already be a real value.
assert "all" not in set(counts["collection_date"])
collection_dates = ["all", *sorted(counts["collection_date"].unique())]

# Put a copy of every row, relabelled as "all", in front of the per-date rows.
_pooled_rows = counts.assign(collection_date="all")
counts = pd.concat([_pooled_rows, counts], ignore_index=True)

# Reduce to the virus-level columns and drop the rows that become identical
# once the remaining columns are removed.
_sample_virus_cols = ["collection_date", "sample", "virus_id", "virus_name", "virus_reads"]
_per_sample_virus = counts.loc[:, _sample_virus_cols].drop_duplicates()
_per_sample_virus = _per_sample_virus.assign(
    has_virus=_per_sample_virus["virus_reads"] > 0,
)

# Per (date, virus): total reads, largest single-sample count, and number of
# samples with at least one read. Sorted descending by date, then total reads.
virus_counts = (
    _per_sample_virus
    .groupby(["collection_date", "virus_id", "virus_name"], as_index=False)
    .agg(
        total_virus_reads=("virus_reads", "sum"),
        max_virus_reads=("virus_reads", "max"),
        n_samples_w_virus_reads=("has_virus", "sum"),
    )
    .sort_values(by=["collection_date", "total_virus_reads"], ascending=False)
)

# Dates selected when the chart first loads: the pooled "all" pseudo-date plus
# the collection date of interest.
_initial_dates = [
    {"collection_date": date} for date in ("all", collection_date_of_interest)
]
select_collection_date = alt.selection_point(
    fields=["collection_date"],
    bind="legend",
    toggle="true",
    value=_initial_dates,
)

# Slider for the read cutoff; starts at the configured minimum and goes up to
# five times that value.
_read_cutoff_slider = alt.binding_range(
    name="show viruses with at least this many reads",
    min=0,
    max=5 * min_total_virus_reads,
    step=1,
)
min_virus_reads = alt.param(value=min_total_virus_reads, bind=_read_cutoff_slider)

# Horizontal bar chart of total reads per virus, with one bar per selected
# collection date. Viruses are filtered by the slider cutoff and ordered by
# their largest total among the selected dates.
virus_counts_chart = (
    alt.Chart(virus_counts)
    .add_params(select_collection_date, min_virus_reads)
    .transform_filter(select_collection_date)
    # largest total for each virus among the dates that survived the filter
    .transform_joinaggregate(
        max_reads="max(total_virus_reads)", groupby=["virus_name"],
    )
    .transform_filter(alt.datum["max_reads"] >= min_virus_reads)
    .encode(
        alt.Y(
            "virus_name",
            title=None,
            # Sort definitions take the raw field name; a ":Q" type suffix is
            # not parsed here and would name a field that does not exist.
            sort=alt.SortField("max_reads", order="descending"),
        ),
        alt.X(
            "total_virus_reads",
            title="total virus reads across samples",
        ),
        alt.Color(
            "collection_date",
            scale=alt.Scale(domain=collection_dates),
            sort=collection_dates,
        ),
        alt.YOffset("collection_date", sort=collection_dates),
        tooltip=virus_counts.columns.tolist(),
    )
    .mark_bar()
    .properties(
        width=250,
        height=alt.Step(9),
        title=alt.TitleParams(
            "Total reads mapping to each coronavirus across all samples",
            subtitle=[
                "Click on interactive legend to select dates.",
                "Use bottom slider to filter which viruses to show.",
            ],
            dy=-5,
        ),
    )
)
display(virus_counts_chart)

print(f"Saving to {virus_counts_html=}")
virus_counts_chart.save(virus_counts_html)

Get a string annotation listing the top species for each sample; these annotations are used in the interactive plot tooltips:

In [None]:
_mito_pct_col = "percent_mito_reads_from_species"

# Species with at least 10% in the percent column, largest first, keeping at
# most five per sample.
_major_species = (
    counts.loc[:, ["sample", "species_name", _mito_pct_col]]
    .drop_duplicates()
    .sort_values(by=_mito_pct_col, ascending=False)
    .loc[lambda d: d[_mito_pct_col] >= 10]
    .groupby("sample")
    .head(5)
)

# Format each species as "name (NN%)".
_major_species = _major_species.assign(
    annotation=(
        _major_species["species_name"]
        + " ("
        + _major_species[_mito_pct_col].map("{:.0f}".format)
        + "%)"
    ),
)

# One row per sample with its species annotations joined by "; ".
top_species_per_sample = (
    _major_species
    .groupby("sample", as_index=False)
    .agg(top_species=("annotation", "; ".join))
)

# Let wide text columns display without truncation.
pd.set_option("display.max_colwidth", 500)

## Plot number of viral reads of each type for each sample

Plot interactive plots with both linear and log scale.
Samples are sorted by amount of viral reads.

In [None]:
_virus_id_cols = ["collection_date", "sample", "virus_name", "virus_id"]
_virus_metric_cols = ["virus_reads", "percent_reads_from_virus"]

# Long format: one row per (date, sample, virus, metric).
_virus_amounts = (
    counts.loc[:, _virus_id_cols + _virus_metric_cols]
    .drop_duplicates()
    .melt(
        id_vars=_virus_id_cols,
        value_vars=_virus_metric_cols,
        var_name="amount of virus metric",
        value_name="amount_of_virus",
    )
)

# Attach the tooltip annotation (joins on the shared column(s)).
# NOTE(review): merge defaults to an inner join, so any sample absent from
# top_species_per_sample is dropped here - confirm that is intended.
_virus_amounts = _virus_amounts.merge(top_species_per_sample, validate="many_to_one")

# Per metric: smallest positive value divided by the configured factor. This
# is what zeros are drawn as on the log scale.
_log_floor = (
    _virus_amounts.loc[_virus_amounts["amount_of_virus"] > 0]
    .groupby("amount of virus metric")["amount_of_virus"]
    .min()
    .div(log_scale_axis_min_factor)
    .rename("lower_limit_for_logscale")
    .reset_index()
)
_virus_amounts = _virus_amounts.merge(_log_floor, validate="many_to_one")

# Stack a linear copy and a log10 copy; the log10 copy clamps values from
# below at the per-metric floor before taking the logarithm.
_linear_amounts = _virus_amounts.assign(amount_scale="linear")
_log10_amounts = _virus_amounts.assign(
    amount_scale="log10",
    amount_of_virus=numpy.log10(
        _virus_amounts[["amount_of_virus", "lower_limit_for_logscale"]].max(axis=1)
    ),
)
n_virus_reads_df = pd.concat(
    [_linear_amounts, _log10_amounts],
    ignore_index=True,
).drop(columns="lower_limit_for_logscale")

In [None]:
# Total reads per virus pooled over all dates, largest first.
_all_date_counts = virus_counts[virus_counts["collection_date"] == "all"]
virus_tot_counts = (
    _all_date_counts
    .sort_values(by="total_virus_reads", ascending=False)
    .loc[:, ["virus_name", "total_virus_reads"]]
)

In [None]:
# Radio buttons choosing which metric of viral amount to plot.
_metric_choices = n_virus_reads_df["amount of virus metric"].unique()
amount_of_virus_metric = alt.selection_point(
    fields=["amount of virus metric"],
    bind=alt.binding_radio(
        name="metric for amount of virus",
        options=_metric_choices,
    ),
    value="virus_reads",
)

# Radio buttons choosing the linear or log10 copy of the data.
_scale_choices = n_virus_reads_df["amount_scale"].unique()
scale_metric = alt.selection_point(
    fields=["amount_scale"],
    bind=alt.binding_radio(
        name="axis scale for amount of virus",
        options=_scale_choices,
    ),
    value="linear",
)

# Legend-driven virus selection, seeded with the viruses above the read cutoff.
_above_cutoff = virus_tot_counts["total_virus_reads"] >= min_total_virus_reads
_seed_viruses = [
    {"virus_name": name}
    for name in virus_tot_counts.loc[_above_cutoff, "virus_name"]
]
virus_selection = alt.selection_point(
    fields=["virus_name"],
    bind="legend",
    toggle="true",
    value=_seed_viruses,
)

# Dropdown choosing the collection date; starts at the date of interest.
_date_dropdown = alt.binding_select(name="collection date", options=collection_dates)
date_selection = alt.selection_point(
    fields=["collection_date"],
    bind=_date_dropdown,
    value=collection_date_of_interest,
)

# Shared scale so each virus is encoded the same way wherever it is used.
virus_scale = alt.Scale(domain=virus_tot_counts["virus_name"].to_list())

# Encodings, named up front so the chart pipeline below stays short.
_sample_axis = alt.X("sample", sort=alt.SortField("max_amount", order="descending"))
_amount_axis = alt.Y(
    "amount_of_virus",
    title="amount of virus",
    scale=alt.Scale(nice=False, zero=False),
)
_virus_legend = alt.Legend(
    orient="bottom",
    columns=5,
    title="virus (click to select which viruses to show)",
    titleLimit=1000,
)
_virus_color = alt.Color("virus_name", scale=virus_scale, legend=_virus_legend)
_virus_shape = alt.Shape("virus_name", scale=virus_scale)
_point_tooltip = [
    "sample",
    "virus_name",
    alt.Tooltip("amount_of_virus", format=".2g"),
    "top_species",
]
_reads_chart_title = alt.TitleParams(
    "Amounts of genetic material from different coronaviruses in Huanan Market environmental samples",
    subtitle="Use interactive options at bottom to select collection date, which viruses to show, and how the amount of virus is quantified. Mouse over points for details.",
    dy=-5,
)

# Filter by each interactive selection in turn.
n_virus_reads_chart = alt.Chart(n_virus_reads_df)
for _selection in (amount_of_virus_metric, scale_metric, virus_selection, date_selection):
    n_virus_reads_chart = n_virus_reads_chart.transform_filter(_selection)

n_virus_reads_chart = (
    n_virus_reads_chart
    # per-sample maximum among the remaining rows; used to order the x-axis
    .transform_joinaggregate(max_amount="max(amount_of_virus)", groupby=["sample"])
    .encode(
        x=_sample_axis,
        y=_amount_axis,
        color=_virus_color,
        shape=_virus_shape,
        tooltip=_point_tooltip,
    )
    .mark_point(filled=True, size=50, opacity=0.75)
    .add_params(amount_of_virus_metric, virus_selection, scale_metric, date_selection)
    .properties(height=150, width=alt.Step(12), title=_reads_chart_title)
)

display(n_virus_reads_chart)

print(f"Saving to {viral_reads_per_sample_html=}")
n_virus_reads_chart.save(viral_reads_per_sample_html)

## Correlation of viral reads and host mitochondrial reads

Only keep viruses with a minimal amount of reads across all dates:

In [None]:
print(f"Only keep viruses with at least {min_total_virus_reads=} across all samples")

# Names of viruses whose reads pooled over all dates reach the cutoff.
_pooled_virus_counts = virus_counts[virus_counts["collection_date"] == "all"]
_has_enough_reads = _pooled_virus_counts["total_virus_reads"] >= min_total_virus_reads
viruses_to_keep = _pooled_virus_counts.loc[_has_enough_reads, "virus_name"].tolist()

We will make two plots, one with all species and one with just the subset of species that are most relevant on the collection date of interest.
Get the species to keep for the subset plot:

In [None]:
# Thresholds that define the species subset.
_subset_min_species_percent = species_for_subset_chart["min_species_percent"]
_subset_min_virus_reads = species_for_subset_chart["min_virus_reads"]

# Human-readable description of the subset; reused in the chart subtitle.
subset_desc = (
    f"species with at least {_subset_min_species_percent}% of "
    f"mitochondrial reads in at least one sample with at least {_subset_min_virus_reads} "
    f"viral read on {collection_date_of_interest}."
)
print(subset_desc)

# Rows on the date of interest, for a kept virus, that meet both thresholds.
_meets_subset_criteria = (
    (counts["collection_date"] == collection_date_of_interest)
    & counts["virus_name"].isin(viruses_to_keep)
    & (counts["virus_reads"] >= _subset_min_virus_reads)
    & (counts["percent_mito_reads_from_species"] >= _subset_min_species_percent)
)
subset_species = sorted(
    counts.loc[_meets_subset_criteria, "species_name"].unique().tolist()
)

print(f"For subset chart, keeping {len(subset_species)=}\n{subset_species=}")

Now build the data frame for the correlation plots, with read counts on both linear and log10 scales:

In [None]:
# Read counts for the kept viruses only.
_corr_cols = ["collection_date", "sample", "virus_name", "species_name", "virus_reads", "species_reads"]
_corr_linear = counts.loc[counts["virus_name"].isin(viruses_to_keep), _corr_cols]

# For the log10 copy, counts are clamped from below at this value before the
# logarithm so that zeros stay finite.
_log_clip_floor = 1 / log_scale_axis_min_factor
_corr_log10 = _corr_linear.assign(
    amount_scale="log10",
    **{
        reads_col: numpy.log10(_corr_linear[reads_col].clip(lower=_log_clip_floor))
        for reads_col in ("virus_reads", "species_reads")
    },
)

# Stack both scales and attach the tooltip annotation (joins on the shared column(s)).
# NOTE(review): merge defaults to an inner join, so any sample absent from
# top_species_per_sample is dropped here - confirm that is intended.
corr_df = pd.concat(
    [_corr_linear.assign(amount_scale="linear"), _corr_log10],
).merge(top_species_per_sample, validate="many_to_one")

Make the correlation charts, one for the subset of species on the collection date of interest and one for all species:

In [None]:
# Radio buttons choosing whether read counts are shown on linear or log10 scale.
_corr_scale_choices = corr_df["amount_scale"].unique()
axis_metric = alt.selection_point(
    fields=["amount_scale"],
    bind=alt.binding_radio(
        name="axis scales for number of reads",
        options=_corr_scale_choices,
    ),
    value="linear",
)

# Build, display and save two correlation charts: one restricted to the subset
# species on the collection date of interest, and one with all species.
# Each chart is a grid of small scatter panels (rows = viruses, columns =
# species) of virus reads versus species reads, annotated with the correlation.
for (df, desc, outfile) in [
    (
        corr_df.query("species_name in @subset_species").query("collection_date == @collection_date_of_interest"),
        subset_desc,
        viral_subset_species_corr_html,
    ),
    (corr_df, "all species.", viral_all_species_corr_html),
]:

    # Shared base: pick the linear or log10 copy of the data.
    corr_base = (
        alt.Chart(df)
        .add_params(axis_metric)
        .transform_filter(axis_metric)
    )

    # A date selector is only meaningful when several dates are present.
    if df["collection_date"].nunique() > 1:
        corr_base = corr_base.add_params(date_selection).transform_filter(date_selection)

    corr_scatter = (
        corr_base
        .encode(
            alt.Shape("virus_name", scale=virus_scale, legend=None),
            alt.Color(
                "virus_name",
                scale=virus_scale,
                legend=None,
            ),
            tooltip=["sample", "virus_name", alt.Tooltip("virus_reads", format=".2g"), alt.Tooltip("species_reads", format=".2g"), "top_species"],
        )
        .mark_point(filled=True, size=50, opacity=0.65)
    )

    # Text layer showing r: the square root of the regression's rSquared,
    # signed by the slope (coef[1]) of virus reads regressed on species reads.
    params_r = (
        corr_base
        .transform_regression("species_reads", "virus_reads", params=True)
        .transform_calculate(
            r=alt.expr.if_(
                alt.datum["coef"][1] >= 0,
                alt.expr.sqrt(alt.datum["rSquared"]),
                -alt.expr.sqrt(alt.datum["rSquared"]),
            ),
            label='"r = " + format(datum.r, ".2f")',
        )
        .mark_text(align="left", color="black", fontWeight=500, fontSize=11, opacity=0.7)
        .encode(
            # fixed pixel position in the upper-left corner of each panel
            x=alt.value(3),
            y=alt.value(8),
            text=alt.Text("label:N"),
        )
    )

    # One faceted column of panels per species; only the first (i == 0) gets
    # a y-axis and row labels.
    corr_charts = []
    for i, species in enumerate(sorted(df["species_name"].unique())):
        corr_charts.append(
            (
                corr_scatter
                .encode(
                    alt.X("species_reads", axis=alt.Axis(title=None, labelOverlap=True, tickCount=2, format=".1g"), scale=alt.Scale(nice=False, padding=2)),
                    alt.Y(
                        "virus_reads",
                        axis=alt.Axis(title=None) if i == 0 else None,
                        scale=alt.Scale(nice=False, padding=2)
                    ),
                )
                + params_r
            )
            .transform_filter(alt.datum["species_name"] == species)
            .properties(width=79, height=88)
            .facet(
                row=alt.Row(
                    "virus_name",
                    header=(
                        alt.Header(labelPadding=3, title=None)
                        if i == 0
                        else alt.Header(title=None, labels=False)
                    ),
                ),
                column=alt.Column(
                    "species_name",
                    header=alt.Header(
                        title=None,
                        titleOrient="bottom",
                        titlePadding=5,
                        labelPadding=3,
                        labelOrient="bottom",
                    ),
                ),
                spacing=1,
            )
        )

    corr_chart = (
        alt.hconcat(*corr_charts, spacing=1)
        .configure_axis(grid=False)
        .properties(
            title=alt.TitleParams(
                "Reads mapping to different coronaviruses versus mitochondria of different species for all Huanan Market environmental samples",
                subtitle=[
                    # the noun "reads" was missing after the number
                    f"Showing viruses with at least {min_total_virus_reads} reads across all samples and {desc}",
                    "Numbers in upper-left of each panel show Pearson correlation r.",
                    "Use interactive options at bottom of plot to show read counts on log10 or linear scale, etc; mouseover points for details.",
                ],
                dy=-10,
                anchor="middle",
            ),
        )
    )

    display(corr_chart)

    print(f"Saving to {outfile}")
    corr_chart.save(outfile)