# Correlate amino-acid fitness estimates from different mutation-annotated trees

Get variables from `snakemake`:

In [None]:
# Pull the rule's inputs, parameters, and output path from the `snakemake`
# object that Snakemake injects when it runs this notebook.
(aa_fitness_csvs, mats, min_expected_count, fitness_corr_chart_html) = (
    snakemake.input.aa_fitnesses,
    snakemake.params.mats,
    snakemake.params.min_expected_count,
    snakemake.output.fitness_corrs_chart,
)

Import Python modules:

In [None]:
import altair as alt

import itertools

import pandas as pd

# Lift Altair's default cap on how many rows it will embed in a chart, so the
# full fitness table can be plotted. The result is bound to `_`, presumably to
# keep the notebook from echoing the return value.
_ = alt.data_transformers.disable_max_rows()

Read data:

In [None]:
# The fitness CSVs are paired positionally with the MAT names, so the two
# lists must have the same length. Raise explicitly rather than `assert`,
# which is stripped when Python runs with -O.
if len(mats) != len(aa_fitness_csvs):
    raise ValueError(
        f"got {len(mats)} MAT names but {len(aa_fitness_csvs)} fitness CSVs"
    )

# Stack the per-MAT amino-acid fitness tables into one long data frame, with
# a `dataset` column recording which MAT each row came from. Rows whose gene
# is 'ORF1ab' are dropped, as are two annotation columns the chart below
# does not use.
aa_fitness = (
    pd.concat(
        [
            pd.read_csv(aa_fitness_csv).assign(dataset=mat)
            for aa_fitness_csv, mat in zip(aa_fitness_csvs, mats)
        ],
        ignore_index=True,
    )
    .query("gene != 'ORF1ab'")
    .drop(columns=["aa_differs_among_clade_founders", "subset_of_ORF1ab"])
)

aa_fitness

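Make a chart of the pairwise correlations in amino-acid fitness between datasets, with a slider that sets the minimum expected count: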

In [None]:
# Slider for the minimum expected count a mutation needs before its fitness
# is plotted. The slider's upper end is the smaller of 5x the configured
# cutoff and the 90th percentile of expected counts. It is never allowed
# below the configured cutoff, so the initial value always lies inside the
# slider's range.
expected_count_slider_max = max(
    min_expected_count,
    min(5 * min_expected_count, aa_fitness["expected_count"].quantile(0.9)),
)
expected_count_selection = alt.selection_single(
    bind=alt.binding_range(
        min=1,
        max=expected_count_slider_max,
        step=1,
        name="minimum expected count",
    ),
    fields=["cutoff"],
    init={"cutoff": min_expected_count},
)

# Keep only rows that meet the slider cutoff, then pivot so that there is one
# row per (gene, aa_site, aa) and one fitness column per dataset.
chart_base = (
    alt.Chart(aa_fitness)
    .transform_filter(alt.datum["expected_count"] >= expected_count_selection["cutoff"])
    .transform_pivot(
        groupby=["gene", "aa_site", "aa"],
        pivot="dataset",
        value="fitness",
        op="mean",
    )
)

# One scatter panel per unordered pair of datasets.
corr_charts = []
for dataset1, dataset2 in itertools.combinations(aa_fitness["dataset"].unique(), 2):
    # NOTE(review): dataset names are used directly as Vega-Lite field names.
    # A name containing "." or "[" would be parsed as nested-field access;
    # confirm MAT names never contain these characters.
    corr_chart = (
        chart_base
        .encode(
            x=alt.X(dataset1, type="quantitative"),
            y=alt.Y(dataset2, type="quantitative"),
            tooltip=[
                "gene",
                "aa_site",
                "aa",
                alt.Tooltip(dataset1, type="quantitative", format=".3g"),
                alt.Tooltip(dataset2, type="quantitative", format=".3g"),
            ],
        )
        .mark_circle(opacity=0.15)
        .properties(width=180, height=180)
    )

    # regression line and correlation coefficient: https://stackoverflow.com/a/60239699
    line = (
        corr_chart
        .transform_regression(dataset1, dataset2)
        .mark_line(color="orange", clip=True)
    )
    # With params=True the regression yields `coef` ([intercept, slope] for a
    # linear fit) and `rSquared`. sqrt(rSquared) alone is never negative, so
    # take the sign of the Pearson r from the slope; otherwise
    # anti-correlations would be reported as positive.
    params_r = (
        corr_chart
        .transform_regression(dataset1, dataset2, params=True)
        .transform_calculate(
            r="(datum.coef[1] < 0 ? -1 : 1) * sqrt(datum.rSquared)",
        )
        .transform_calculate(
            label='"r = " + format(datum.r, ".3f")',
        )
        .mark_text(align="left", color="orange", fontWeight="bold")
        .encode(
            x=alt.value(5),
            y=alt.value(8),
            text=alt.Text("label:N"),
        )
    )

    corr_charts.append(corr_chart + line + params_r)

# Lay the panels out in a grid with `ncol` panels per row.
ncol = 3

chart = (
    alt.vconcat(
        *[
            alt.hconcat(*corr_charts[i : i + ncol])
            for i in range(0, len(corr_charts), ncol)
        ]
    )
    .add_selection(expected_count_selection)
    .configure_axis(grid=False)
)

chart.save(fitness_corr_chart_html)

chart