# Compare the different MLR fits

In [None]:
# Get variables from `snakemake`

# input: the growth-advantage CSV files, one per (protset, mlrfit) pair
growth_advantage_csvs = snakemake.input.growth_advantages

# params: the (protset, mlrfit) pairs; these are expected to pair up, in order,
# with the CSV files above
protsets_mlrfits = snakemake.params.protsets_mlrfits

# output: HTML files that the scatter-plot grid and the heatmap are saved to
scatter_chart_html, corr_chart_html = (
    snakemake.output.scatter_chart,
    snakemake.output.corr_chart,
)

In [None]:
import itertools
import math

import altair as alt

import pandas as pd

Read the growth advantages:

In [None]:
assert len(growth_advantage_csvs) == len(protsets_mlrfits)


def _read_growth_advantages(csvfile, protset, mlrfit):
    """Read one growth-advantage CSV and tag every row with its protset and mlrfit."""
    # sanity check that the file really belongs to this (protset, mlrfit) pair:
    # both identifiers must appear in its path
    assert protset in csvfile
    assert mlrfit in csvfile
    return pd.read_csv(csvfile).assign(protset=protset, mlrfit=mlrfit)


# stack all fits into one long data frame with a fresh 0..N-1 index
ga_df = pd.concat(
    [
        _read_growth_advantages(csvfile, protset, mlrfit)
        for csvfile, (protset, mlrfit) in zip(growth_advantage_csvs, protsets_mlrfits)
    ],
    ignore_index=True,
)

ga_df

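Pairwise correlations of the growth advantages between every pair of fits, plotted as scatter plots (only strains estimated by both fits are compared):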

In [None]:
def _paired_growth_advantages(protset1, mlrfit1, protset2, mlrfit2):
    """Put the growth advantages of two fits side by side, one row per strain.

    Returns a data frame with a ``strain`` column plus the columns
    ``growth_advantage_median``, ``growth_advantage_hpd_min`` and
    ``growth_advantage_hpd_max`` suffixed with ``1`` for the first fit
    (``protset1``/``mlrfit1``) and ``2`` for the second. Strains lacking any of
    these values in either fit are dropped, so only shared strains remain.
    """
    return (
        pd.concat(
            [
                ga_df.query("(protset == @protset1) & (mlrfit == @mlrfit1)").assign(fit=1),
                ga_df.query("(protset == @protset2) & (mlrfit == @mlrfit2)").assign(fit=2),
            ],
            ignore_index=True,
        )
        .melt(
            id_vars=["strain", "fit"],
            value_vars=[
                "growth_advantage_median",
                "growth_advantage_hpd_min",
                "growth_advantage_hpd_max",
            ],
        )
        # suffix each value name with its fit number, e.g. "growth_advantage_median1"
        .assign(variable=lambda x: x["variable"] + x["fit"].astype(str))
        # back to wide form; pivot_table's default aggfunc averages any duplicates
        .pivot_table(
            index="strain",
            columns="variable",
            values="value",
        )
        .dropna(axis=0)
        .reset_index()
    )


def _correlation_scatter(scatter_df, title1, title2, r, n):
    """Layered scatter plot of two fits' median growth advantages with HPD bars.

    ``scatter_df`` has the layout returned by ``_paired_growth_advantages``.
    The first fit goes on the x-axis (titled ``title1``) and the second on the
    y-axis (titled ``title2``); ``r`` and ``n`` are written in the upper-left
    corner.
    """
    base = alt.Chart(scatter_df).encode(tooltip=["strain"]).properties(
        width=170, height=170
    )
    points = base.mark_circle(color="black", size=50).encode(
        alt.X(
            "growth_advantage_median1",
            scale=alt.Scale(zero=False, nice=False, padding=5),
            title=title1,
        ),
        alt.Y(
            "growth_advantage_median2",
            scale=alt.Scale(zero=False, nice=False, padding=5),
            title=title2,
        ),
    )
    # horizontal bars: the first fit's HPD interval, drawn at the second fit's median
    xerror = base.mark_errorbar(color="black").encode(
        alt.X("growth_advantage_hpd_min1", title=title1),
        alt.X2("growth_advantage_hpd_max1"),
        alt.Y("growth_advantage_median2", title=title2),
    )
    # vertical bars: the second fit's HPD interval, drawn at the first fit's median
    yerror = base.mark_errorbar(color="black").encode(
        alt.Y("growth_advantage_hpd_min2", title=title2),
        alt.Y2("growth_advantage_hpd_max2"),
        alt.X("growth_advantage_median1", title=title1),
    )
    # annotation layer with no data of its own, positioned in pixels
    label = alt.Chart().mark_text(
        text=f"R = {r:.2f} (N = {n})", align="left", color="blue"
    ).encode(x=alt.value(7), y=alt.value(10))
    return points + xerror + yerror + label


scatter_charts = []
corr_records = []

for (protset1, mlrfit1), (protset2, mlrfit2) in itertools.combinations(
    protsets_mlrfits, 2
):
    scatter_df = _paired_growth_advantages(protset1, mlrfit1, protset2, mlrfit2)

    # Pearson R of the two fits' medians over the strains they share
    # (off-diagonal entry of the 2x2 correlation matrix)
    r = (
        scatter_df[["growth_advantage_median1", "growth_advantage_median2"]]
        .corr(method="pearson")
        .iloc[0, 1]
    )
    n = len(scatter_df)

    corr_records.append((protset1, mlrfit1, protset2, mlrfit2, r, n))
    scatter_charts.append(
        _correlation_scatter(
            scatter_df, [protset1, mlrfit1], [protset2, mlrfit2], r, n
        )
    )

# arrange the pairwise plots in a grid with `ncols` plots per row
ncols = 6
nrows = math.ceil(len(scatter_charts) / ncols)  # math.ceil already returns an int
scatter_chart = alt.vconcat(
    *[
        alt.hconcat(*scatter_charts[irow * ncols: (irow + 1) * ncols])
        for irow in range(nrows)
    ],
).configure_axis(grid=False).properties(
    title=alt.TitleParams(
        "correlation scatter plots for growth advantage estimates",
        anchor="middle",
    ),
)

print(f"Saving to {scatter_chart_html}")
scatter_chart.save(scatter_chart_html)

scatter_chart

Make a heatmap of the pairwise correlations between fits:

In [None]:
# one row per compared pair of fits, with a display name for each fit
corr_df = pd.DataFrame(
    corr_records,
    columns=["protset_1", "mlrfit_1", "protset_2", "mlrfit_2", "R", "N"],
).assign(
    name_1=lambda x: x["protset_1"] + " " + x["mlrfit_1"],
    name_2=lambda x: x["protset_2"] + " " + x["mlrfit_2"],
)

# Axis orders as plain lists of distinct names, in first-appearance order (the
# order in which the pairs were compared). Altair's `sort` takes an array of
# values; passing the Series itself relies on Altair's array-like conversion
# and repeats each name once per pair.
name_1_order = corr_df["name_1"].unique().tolist()
name_2_order = corr_df["name_2"].unique().tolist()

# drop-down menus to restrict the heatmap to one protset and/or one MLR model
protset_selection = alt.param(
    bind=alt.binding_select(
        options=["all"] + ga_df["protset"].unique().tolist(),
        name="sequence set (protset)",
    ),
    value="all",
)

mlrfit_selection = alt.param(
    bind=alt.binding_select(
        options=["all"] + ga_df["mlrfit"].unique().tolist(),
        name="MLR model (mlrfit)",
    ),
    value="all",
)

corr_chart = (
    alt.Chart(corr_df)
    .add_params(protset_selection, mlrfit_selection)
    # keep a cell only if both fits use the selected protset (or "all" is chosen)
    .transform_filter(
        (
            (alt.datum["protset_1"] == protset_selection)
            & (alt.datum["protset_2"] == protset_selection)
        )
        | (protset_selection == "all")
    )
    # likewise for the selected MLR model
    .transform_filter(
        (
            (alt.datum["mlrfit_1"] == mlrfit_selection)
            & (alt.datum["mlrfit_2"] == mlrfit_selection)
        )
        | (mlrfit_selection == "all")
    )
    .encode(
        alt.X("name_1", sort=name_1_order, title=None, axis=alt.Axis(labelLimit=500)),
        alt.Y("name_2", sort=name_2_order, title=None, axis=alt.Axis(labelLimit=500)),
        # fixed [-1, 1] domain so colors are comparable whatever the filters
        alt.Fill("R", scale=alt.Scale(scheme="redblue", domain=[-1, 1])),
        tooltip=[
            alt.Tooltip("R", format=".2f", title="correlation (R)"),
            alt.Tooltip("N", title="n strains (N)"),
            "protset_1",
            "mlrfit_1",
            "protset_2",
            "mlrfit_2",
        ],
    )
    .mark_rect(stroke="black")
    .properties(
        title="correlations among growth advantage estimates",
    )
)

print(f"Saving to {corr_chart_html}")
corr_chart.save(corr_chart_html)

corr_chart