# Make plots

Import Python modules:

In [None]:
import itertools
import os

import altair as alt

import pandas as pd

# Lift altair's row limit on data embedded in charts so that the full tables
# below can be plotted. The return value is bound to `_`, presumably so the
# notebook cell does not display it.
_ = alt.data_transformers.disable_max_rows()

Get data and variables from `snakemake`:

In [None]:
def _read_input(input_name):
    """Read the CSV file passed to this notebook as snakemake input ``input_name``."""
    return pd.read_csv(getattr(snakemake.input, input_name))


# Each table is bound to a variable named after its snakemake input.
sars2_aligned_by_run = _read_input("sars2_aligned_by_run")
sars2_aligned_by_metagenomic_sample = _read_input("sars2_aligned_by_metagenomic_sample")
mito_composition_by_metagenomic_run = _read_input("mito_composition_by_metagenomic_run")
mito_composition_by_metagenomic_sample = _read_input("mito_composition_by_metagenomic_sample")
crits_christoph_read_counts = _read_input("crits_christoph_read_counts")
ngdc_to_crits_christoph = _read_input("ngdc_to_crits_christoph")

# Output location for plots, taken from the rule's named output `plotsdir`.
plotsdir = snakemake.output.plotsdir

## Compare mitochondrial DNA composition to Crits-Christoph et al
First, get the Crits-Christoph read counts in tidy format, assigning NGDC run accessions and summing counts within each run:

In [None]:
# Lookup from Crits-Christoph fastq file stem (the text before the first ".")
# to NGDC run accession. Several files can presumably map to one run, which is
# why counts are summed per run below.
cc_filename_to_run = (
    ngdc_to_crits_christoph
    .assign(
        Filename=lambda frame: frame["fastq Crits-Christoph"].map(
            lambda fastq: fastq.partition(".")[0]
        )
    )
    .rename(columns={"Run accession NGDC": "Run accession"})
    .loc[:, ["Filename", "Run accession"]]
    .drop_duplicates()
)

# Lookup from run accession to the sample it belongs to.
cc_run_to_sample = sars2_aligned_by_run.loc[:, ["Run accession", "sample"]].drop_duplicates()

# Long format: one row per (sample, run, species) with the summed
# Crits-Christoph read count. Every count column that is not an id column
# becomes a "species" value.
# NOTE(review): the left merge keeps runs missing from sars2_aligned_by_run with
# a NaN sample, but groupby drops NaN keys by default. Those runs are therefore
# silently excluded; confirm this is intended.
crits_christoph_read_counts_tidy = (
    crits_christoph_read_counts
    .drop(columns=["Location", "Sample_name"])
    .merge(cc_filename_to_run, validate="one_to_one")
    .merge(cc_run_to_sample, how="left", validate="many_to_one")
    .melt(
        id_vars=["sample", "Run accession", "Filename"],
        var_name="species",
        value_name="aligned reads Crits-Christoph",
    )
    .groupby(["sample", "Run accession", "species"], as_index=False)
    .agg({"aligned reads Crits-Christoph": "sum"})
)

crits_christoph_read_counts_tidy

In [None]:
 mito_composition_by_metagenomic_run.columns

Now merge the current study's read counts with those from Crits-Christoph et al for all species:

In [None]:
# Distinct species that have Crits-Christoph counts, in first-seen order.
crits_christoph_species = crits_christoph_read_counts_tidy["species"].drop_duplicates().tolist()

# Current-study count columns, renamed so they can sit next to the
# Crits-Christoph counts after the merge.
current_study_renames = {
    "aligned_reads": "aligned reads current study",
    "covered_bases": "covered bases current study",
}
mito_composition_to_compare = mito_composition_by_metagenomic_run.loc[
    :, ["Run accession", "species", "common_name", *current_study_renames]
].rename(columns=current_study_renames)

# Every Crits-Christoph species must also appear in the current study's table.
assert set(crits_christoph_species) <= set(mito_composition_to_compare["species"])

# One row per (run, species), with both studies' read counts side by side.
crits_christoph_vs_current = crits_christoph_read_counts_tidy.merge(
    mito_composition_to_compare,
    on=["Run accession", "species"],
    validate="one_to_one",
)

Now get the correlations by species and plot them:

In [None]:
# groupby(...).corr() stacks a 2x2 Pearson matrix per (species, common_name).
# After reset_index, the unnamed inner row level becomes the column "level_2".
# In the row labelled with the Crits-Christoph column, the "current study"
# entry is the cross-study correlation.
species_corr_matrices = (
    crits_christoph_vs_current
    .groupby(["species", "common_name"])
    [["aligned reads Crits-Christoph", "aligned reads current study"]]
    .corr(method="pearson")
    .reset_index()
)

# Total aligned reads per species in each study, shown next to the correlation.
species_read_totals = crits_christoph_vs_current.groupby("species", as_index=False).agg(
    aligned_reads_Crits_Christoph=("aligned reads Crits-Christoph", "sum"),
    aligned_reads_current_study=("aligned reads current study", "sum"),
)

crits_christoph_vs_current_species_corr = (
    species_corr_matrices
    .loc[lambda frame: frame["level_2"] == "aligned reads Crits-Christoph"]
    .rename(columns={"aligned reads current study": "correlation"})
    .drop(columns=["level_2", "aligned reads Crits-Christoph"])
    # drop undefined correlations (e.g. a group with constant counts)
    .loc[lambda frame: frame["correlation"].notnull()]
    .merge(species_read_totals)
    .sort_values("correlation")
)

crits_christoph_vs_current_species_corr

In [None]:
# Tooltip shows every column; float columns are formatted to 3 significant digits.
species_corr_tooltip = [
    alt.Tooltip(column, format=".3g") if dtype == float else column
    for column, dtype in crits_christoph_vs_current_species_corr.dtypes.items()
]

# One point per species at its cross-study correlation, with species ordered
# from highest to lowest correlation.
crits_christoph_vs_current_species_corr_chart = (
    alt.Chart(crits_christoph_vs_current_species_corr)
    .encode(
        x="correlation",
        y=alt.Y(
            "common_name",
            title=None,
            sort=alt.SortField("correlation", order="descending"),
        ),
        tooltip=species_corr_tooltip,
    )
    .mark_circle(opacity=1, size=50)
    .properties(width=185, height=alt.Step(13))
)

crits_christoph_vs_current_species_corr_chart

Now get the correlations by run:

In [None]:
# Same stacked 2x2 Pearson matrices as for species, but one per (run, sample).
run_corr_matrices = (
    crits_christoph_vs_current
    .groupby(["Run accession", "sample"])
    [["aligned reads Crits-Christoph", "aligned reads current study"]]
    .corr(method="pearson")
    .reset_index()
)

# Total aligned reads per run in each study, used by the chart's sliders.
run_read_totals = crits_christoph_vs_current.groupby("Run accession", as_index=False).agg(
    aligned_reads_Crits_Christoph=("aligned reads Crits-Christoph", "sum"),
    aligned_reads_current_study=("aligned reads current study", "sum"),
)

crits_christoph_vs_current_run_corr = (
    run_corr_matrices
    .loc[lambda frame: frame["level_2"] == "aligned reads Crits-Christoph"]
    .rename(columns={"aligned reads current study": "correlation"})
    .drop(columns=["level_2", "aligned reads Crits-Christoph"])
    # drop undefined correlations (e.g. a run with constant counts)
    .loc[lambda frame: frame["correlation"].notnull()]
    .merge(run_read_totals)
)

crits_christoph_vs_current_run_corr.sort_values("correlation")

Plot the correlations by run:

In [None]:
def _min_aligned_reads_slider(label):
    """Return a slider parameter labelled ``label`` (range 0-1000, initially 50).

    The slider sets the minimum total aligned reads a run needs to be plotted.
    """
    return alt.param(
        value=50,
        bind=alt.binding_range(name=label, min=0, max=1000),
    )


aligned_reads_current_slider = _min_aligned_reads_slider(
    "minimum aligned reads in current study"
)
aligned_reads_crits_christoph_slider = _min_aligned_reads_slider(
    "minimum aligned reads in Crits-Christoph et al"
)

# Tooltip shows every column; float columns are formatted to 4 significant digits.
run_corr_tooltip = [
    alt.Tooltip(column, format=".4g") if dtype == float else column
    for column, dtype in crits_christoph_vs_current_run_corr.dtypes.items()
]

run_corr_title = (
    "Correlation in read counts aligned to mitochondrial genomes of different "
    "species in Crits-Christoph et al and current study"
)

# One point per run, ordered by correlation. Each slider hides runs whose total
# aligned reads in that study fall below the chosen minimum.
crits_christoph_vs_current_run_corr_chart = (
    alt.Chart(crits_christoph_vs_current_run_corr)
    .encode(
        x=alt.X("Run accession", sort=alt.SortField("correlation", order="descending")),
        y=alt.Y("correlation", title="Pearson correlation"),
        tooltip=run_corr_tooltip,
    )
    .mark_circle(size=50, opacity=1)
    .properties(width=alt.Step(11), height=150, title=run_corr_title)
    .add_params(aligned_reads_current_slider, aligned_reads_crits_christoph_slider)
    .transform_filter(
        alt.datum["aligned_reads_current_study"] >= aligned_reads_current_slider
    )
    .transform_filter(
        alt.datum["aligned_reads_Crits_Christoph"] >= aligned_reads_crits_christoph_slider
    )
)

crits_christoph_vs_current_run_corr_chart

## Plot mitochondrial composition for samples
First, for each sample, keep only the species that are sufficiently abundant in either the full species set or the Crits-Christoph species set, and group all other species as "other":

In [None]:
# Every Crits-Christoph species must appear in the per-sample composition table.
assert set(crits_christoph_species).issubset(mito_composition_by_metagenomic_sample["species"])

group_as_other_cutoff = 5  # group as "other" if < this % in both species sets
# Species whose aligned-read rank within a sample is >= this are grouped as
# "other", so a sample shows at most (max_species_per_sample - 1) named
# species plus "other".
max_species_per_sample = 10

# require at least this many reads aligned to mitochondrial genomes
min_mito_reads = 100

composition_df = (
    mito_composition_by_metagenomic_sample
    .assign(
        # flag species that belong to the Crits-Christoph species set
        cc_species=lambda x: x["species"].isin(crits_christoph_species),
        total_mito=lambda x: x.groupby("sample")["aligned_reads"].transform("sum"),
        # reads counted only for Crits-Christoph species (0 for all others)
        aligned_reads_cc=lambda x: x["aligned_reads"].where(x["cc_species"], 0),
        total_mito_cc=lambda x: x.groupby("sample")["aligned_reads_cc"].transform("sum"),
        percent=lambda x: 100 * x["aligned_reads"] / x["total_mito"],
        # NaN (0/0) for samples with no Crits-Christoph-species reads; the
        # groupby sum below turns such all-NaN groups into 0
        percent_cc=lambda x: 100 * x["aligned_reads_cc"] / x["total_mito_cc"],
        # 1 = most aligned reads in the sample; ties broken by row order
        rank=lambda x: x.groupby("sample")["aligned_reads"].transform("rank", method="first", ascending=False),
        other=lambda x: (
            (x[["percent", "percent_cc"]].max(axis=1) < group_as_other_cutoff)
            | (x["rank"] >= max_species_per_sample)
        ),
        species=lambda x: x["species"].where(~x["other"], "other"),
        common_name=lambda x: x["common_name"].where(~x["other"], "other"),
    )
    # inclusive threshold, matching "at least" in the comment on min_mito_reads
    .query("total_mito >= @min_mito_reads")
    .groupby(["sample", "species", "common_name"], as_index=False)
    [["aligned_reads", "aligned_reads_cc", "percent", "percent_cc"]]
    .aggregate("sum")
)

# Reshape to one row per (sample, species, species set) holding both the read
# count and the percent for that species set.
composition_df = (
    composition_df
    .rename(columns={"aligned_reads": "all", "aligned_reads_cc": "crits_christoph"})
    .melt(
        id_vars=["sample", "species", "common_name"],
        value_vars=["all", "crits_christoph"],
        value_name="aligned_reads",
        var_name="species_set",
    )
    .merge(
        composition_df
        .rename(columns={"percent": "all", "percent_cc": "crits_christoph"})
        .melt(
            id_vars=["sample", "species", "common_name"],
            value_vars=["all", "crits_christoph"],
            value_name="percent",
            var_name="species_set",
        )
    )
    .assign(
        species_set=lambda x: x["species_set"].map(
            {
                "all": "chordates (current study)",
                "crits_christoph": "mammals (Crits-Christoph et al)",
            }
        ),
    )
)

composition_df

Now make pie charts:

In [None]:
# Samples offered in the dropdown, in their order of appearance in composition_df.
composition_samples = composition_df["sample"].unique().tolist()

# Prefer sample Q61 as the initial selection, but fall back to the first
# available sample if Q61 is absent (e.g. removed by the min_mito_reads filter).
# Otherwise the initial selection would match no rows and the charts would start
# empty.
default_sample = (
    "Q61"
    if "Q61" in composition_samples
    else next(iter(composition_samples), "Q61")
)

sample_selection = alt.selection_point(
    fields=["sample"],
    bind=alt.binding_select(
        options=composition_samples,
        name="sample",
    ),
    value=default_sample,
)

# Tooltip shows every column; float columns are formatted to 3 significant digits.
composition_tooltip = [
    alt.Tooltip(c, format=".3g") if composition_df[c].dtype == float else c
    for c in composition_df.columns
]

# One pie per species set (faceted into columns) for the selected sample,
# sliced by percent of mitochondrial reads.
composition_chart = (
    alt.Chart(composition_df)
    .encode(
        theta="percent",
        color=alt.Color(
            "common_name",
            title="species (common name)",
            legend=alt.Legend(
                orient="right", columns=1, symbolSize=125, symbolType="square", labelLimit=0,
            ),
        ),
        column=alt.Column(
            "species_set",
            title=None,
            sort="descending",
            header=alt.Header(labelFontSize=11, labelFontWeight="bold"),
        ),
        tooltip=composition_tooltip,
    )
    .mark_arc()
    .add_params(sample_selection)
    .transform_filter(sample_selection)
    .properties(height=175, width=175)
)

composition_chart