In [1]:
import pickle

import pandas as pd


# Load the fitted bootstrap model.
# NOTE: `pickle.load` can execute arbitrary code; only load this file from a trusted source.
with open("_model_bootstrap.pickle", "rb") as f:
    model = pickle.load(f)

mut_escape_df = model.mut_escape_df

# One wildtype row (mutant == wildtype, zero escape) per epitope/site, so the heatmap
# can mark the wildtype identity at each site. The count/fraction columns are set to
# their maxima, presumably so these rows survive the `>=` slider filters on them.
wildtype_df = (
    mut_escape_df[["epitope", "site", "wildtype"]]
    .drop_duplicates()
    .assign(
        mutant=lambda x: x["wildtype"],
        escape_mean=0,
        escape_median=0,
        escape_std=0,
        **{
            col: mut_escape_df[col].max()
            for col in ["n_models", "times_seen", "frac_models"]
        },
    )
)

# `ignore_index=True` avoids duplicated index labels from the two concatenated parts.
df = pd.concat([mut_escape_df, wildtype_df], ignore_index=True)

In [158]:
import altair as alt

import natsort


def lineplot_and_heatmap(
    *,
    data_df,
    stat_col,
    category_col,
    alphabet=None,
    sites=None,
    addtl_tooltip_stats=None,
    addtl_slider_stats=None,
    init_floor_at_zero=True,
    init_site_escape_statistic="sum",
    cell_size=12,
    lineplot_width=5,
    lineplot_height=100,
    site_zoom_bar_width=500,
    site_zoom_bar_color_col=None,
    plot_title=None,
    show_single_category_label=False,
):
    """Lineplots and heatmaps of per-site and per-mutation values.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Data to plot. Must have columns "site", "wildtype", "mutant", `stat_col`, and
        `category_col`. The wildtype values (wildtype = mutant) should be included.
        They are exempt from the slider filters (but are dropped at any site where all
        mutants have been filtered out) and are not included in the site summary
        lineplot. Note that the per-site maximum used by the site-max slider is
        computed over all rows at the site, including the wildtype row.
    stat_col : str
        Column in `data_df` with statistic to plot.
    category_col : str
        Column in `data_df` with category to facet plots over.
    alphabet : array-like or None
        Alphabet letters in order. If `None`, use natsorted "mutant" col of `data_df`.
    sites : array-like or None
        Sites in order. If `None`, use natsorted "site" col of `data_df`.
    addtl_tooltip_stats : None or array-like
        Additional mutation-level stats to show in the heatmap tooltips. Values in
        `addtl_slider_stats` automatically included.
    addtl_slider_stats : None or dict
        Additional stats for which to have a slider, value is initial setting. Ignores
        wildtype and drops it when all mutants have been dropped at site.
    init_floor_at_zero : bool
        Initial value for option to put floor of zero on value in `stat_col`.
    init_site_escape_statistic : {'sum', 'mean', 'max', 'min'}
        Initial value for site escape statistic in lineplot.
    cell_size : float
        Size of cells in heatmap
    lineplot_width : float
        Width per site in lineplot.
    lineplot_height : float
        Height of line plot.
    site_zoom_bar_width : float
        Width of site zoom bar.
    site_zoom_bar_color_col : str or None
        Column in `data_df` with which to color zoom bar. Must be the same for all
        entries for a site.
    plot_title : str or None
        Overall plot title.
    show_single_category_label : bool
        Show the category label if just one category.

    Returns
    -------
    altair.VConcatChart
        Site zoom bar, site summary lineplot, and per-category heatmaps stacked
        vertically, with all interactive parameters attached.

    Raises
    ------
    ValueError
        If required columns are missing, a column name starts with "_stat",
        `data_df` has NA values, `sites` has sites not in `data_df`, the sites
        cannot be sorted by Altair, or `site_zoom_bar_color_col` varies within a site.
    """
    req_cols = ["site", "wildtype", "mutant", stat_col, category_col]
    if addtl_tooltip_stats is None:
        addtl_tooltip_stats = []
    req_cols += [c for c in addtl_tooltip_stats if c not in req_cols]
    # slider stats map column name -> initial cutoff, so always work with a dict
    addtl_slider_stats = {} if addtl_slider_stats is None else dict(addtl_slider_stats)
    req_cols += [c for c in addtl_slider_stats if c not in req_cols]
    if site_zoom_bar_color_col:
        req_cols.append(site_zoom_bar_color_col)
    req_cols = list(dict.fromkeys(req_cols))  # https://stackoverflow.com/a/17016257
    if not set(req_cols).issubset(data_df.columns):
        raise ValueError(f"Missing required columns\n{data_df.columns=}\n{req_cols=}")
    if any(c.startswith("_stat") for c in req_cols):  # used for calculated stats
        raise ValueError(f"No columns can start with '_stat' in {data_df.columns=}")
    data_df = data_df[req_cols].reset_index(drop=True)

    if data_df.isnull().any().any():
        raise ValueError(f"`data_df` cannot have NA values:\n{data_df.isnull().any()}")

    if alphabet is None:
        alphabet = natsort.natsorted(data_df["mutant"].unique())
    else:
        data_df = data_df.query("mutant in @alphabet")

    if sites is None:
        sites = natsort.natsorted(data_df["site"].unique(), alg=natsort.ns.SIGNED)
    else:
        data_df = data_df.query("site in @sites")
        if not set(sites).issubset(data_df["site"]):
            raise ValueError("`sites` has sites not in `data_df`")

    # Cannot sort by all sites explicitly: https://github.com/altair-viz/altair/issues/2663
    # But Altair can take an explicit sort list of ~1000 elements, which is enough that
    # regular sorting does the rest (assuming <10,000 sites). So `sort_sites` holds the
    # first `n_sort_sites` sites, and we check that plain sorting orders the remainder.
    n_sort_sites = 1002  # this many does not raise error
    sort_sites = sites[:n_sort_sites]
    if list(sites) != [*sort_sites, *sorted(sites[n_sort_sites:])]:
        raise ValueError(f"Cannot sort {len(sites)=} sites")

    # get tooltips for heatmap; any float dtype (float32/64, nullable Float) is numeric
    heatmap_tooltips = [
        alt.Tooltip(c, type="quantitative", format=".3g")
        if data_df[c].dtype.kind == "f"
        else alt.Tooltip(c, type="nominal")
        for c in req_cols
    ]

    # make floor at zero selection, setting floor to either 0 or min in data (no floor)
    min_stat = data_df[stat_col].min()
    max_stat = data_df[stat_col].max()
    floor_at_zero = alt.selection_point(
        name="floor_at_zero",
        bind=alt.binding_radio(
            options=[0, min_stat],
            labels=["yes", "no"],
            name=f"floor {stat_col} at zero",
        ),
        fields=["floor"],
        value=[{"floor": 0 if init_floor_at_zero else min_stat}],
    )

    # create sliders for max stat at site and any additional sliders; the site-max
    # slider starts at the minimum of the (filtered) data so nothing is hidden initially
    sliders = {
        "_stat_site_max": alt.selection_point(
            fields=["cutoff"],
            value=[{"cutoff": min_stat}],
            bind=alt.binding_range(
                name=f"minimum max of {stat_col} at site",
                min=min_stat,
                max=max_stat,
            ),
        )
    }
    for slider_stat, init_slider_stat in addtl_slider_stats.items():
        sliders[slider_stat] = alt.selection_point(
            fields=["cutoff"],
            value=[{"cutoff": init_slider_stat}],
            bind=alt.binding_range(
                min=data_df[slider_stat].min(),
                max=data_df[slider_stat].max(),
                name=f"minimum {slider_stat}",
            ),
        )

    # whether to show line on line plot
    line_selection = alt.selection_point(
        bind=alt.binding_radio(
            options=[True, False], labels=["yes", "no"], name="show line on site plot",
        ),
        fields=["_stat_show_line"],
        value=[{"_stat_show_line": True}],
    )

    # create site zoom bar
    site_brush = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke="black", strokeWidth=2),
    )
    if site_zoom_bar_color_col:
        site_zoom_bar_df = data_df[["site", site_zoom_bar_color_col]].drop_duplicates()
        if any(site_zoom_bar_df.groupby("site").size() > 1):
            raise ValueError(f"multiple {site_zoom_bar_color_col=} values for sites")
    else:
        site_zoom_bar_df = data_df[["site"]].drop_duplicates()
    site_zoom_bar = (
        alt.Chart(site_zoom_bar_df)
        .mark_rect()
        .encode(
            x=alt.X("site:O", sort=sort_sites),
            color=(
                alt.Color(
                    site_zoom_bar_color_col,
                    legend=alt.Legend(orient="left", title=site_zoom_bar_color_col),
                    # order colors according which ones in first sites
                    sort=(
                        site_zoom_bar_df
                        .set_index("site")
                        .loc[sort_sites]
                        [site_zoom_bar_color_col]
                        .unique()
                    ),
                )
                if site_zoom_bar_color_col
                else alt.value("gray")
            ),
            tooltip=site_zoom_bar_df.columns.tolist(),
        )
        .add_params(site_brush)
        .properties(width=site_zoom_bar_width, height=cell_size, title="site zoom bar")
    )

    # to make data in Chart smaller, access properties that are same across all sites
    # or categories via a transform_lookup. Make data frames with columns to do that.
    lookup_dfs = {}
    for lookup_col in ["site", category_col]:
        cols_to_lookup = [
            c for c in data_df.columns
            if all(data_df.groupby(lookup_col)[c].nunique(dropna=False) == 1)
            if c not in ["site", category_col]
        ]
        if cols_to_lookup:
            lookup_dfs[lookup_col] = data_df[[lookup_col, *cols_to_lookup]].drop_duplicates()
            assert len(lookup_dfs[lookup_col]) == data_df[lookup_col].nunique(), f"{lookup_col=}\n{lookup_dfs[lookup_col]=}\n{len(lookup_dfs[lookup_col])=}\n{data_df[lookup_col].nunique()=}"
            data_df = data_df.drop(columns=cols_to_lookup)

    # categories in order of first appearance, as plain Python values for serialization
    categories = data_df[category_col].unique().tolist()
    show_category_label = show_single_category_label or len(categories) > 1

    # make the base chart that holds the data and common elements
    base_chart = alt.Chart(data_df).encode(x=alt.X("site:O", sort=sort_sites))
    for lookup_col, lookup_df in lookup_dfs.items():
        base_chart = base_chart.transform_lookup(
            lookup=lookup_col,
            from_=alt.LookupData(
                data=lookup_df,
                key=lookup_col,
                fields=[c for c in lookup_df.columns if c != lookup_col],
            ),
        )

    # Transforms on base chart. The "_stat" columns is floor transformed stat_col.
    # NOTE: `_stat_site_max` is taken over all rows at a site, including wildtype.
    base_chart = (
        base_chart
        .transform_calculate(
            _stat=alt.expr.max(alt.datum[stat_col], floor_at_zero["floor"]),
        )
        .transform_joinaggregate(
            _stat_site_max="max(_stat)",
            groupby=["site"],
        )
    )
    # Filter data using slider stats. Wildtype rows are exempt here; they are removed
    # below only when no non-wildtype mutations remain at their site.
    is_wildtype = alt.datum.wildtype == alt.datum.mutant
    for slider_stat, slider in sliders.items():
        base_chart = base_chart.transform_filter(
            is_wildtype | (alt.datum[slider_stat] >= slider["cutoff"])
        )
    # Remove any sites that are only wildtype and filter with site zoom brush
    base_chart = (
        base_chart
        .transform_calculate(_stat_not_wildtype=alt.datum.wildtype != alt.datum.mutant)
        .transform_joinaggregate(
            _stat_site_has_non_wildtype="max(_stat_not_wildtype)",
            groupby=["site"],
        )
        .transform_filter(alt.datum["_stat_site_has_non_wildtype"])
        .transform_filter(site_brush)
    )

    # make the site chart
    site_summaries = ["sum", "mean", "max", "min"]
    site_summary_selection = alt.selection_point(
        bind=alt.binding_radio(options=site_summaries, name="site escape statistic"),
        fields=["site escape statistic"],
        value=[{"site escape statistic": init_site_escape_statistic}],
    )
    site_prop_cols = lookup_dfs["site"].columns if "site" in lookup_dfs else ["site"]
    lineplot_base = (
        base_chart
        .transform_filter(alt.datum.wildtype != alt.datum.mutant)
        .transform_aggregate(
            **{summary: f"{summary}(_stat)" for summary in site_summaries},
            groupby=[*site_prop_cols, category_col],
        )
        .transform_fold(site_summaries, ["site escape statistic", "site escape"])
        .transform_filter(site_summary_selection)
        .encode(
            y=alt.Y("site escape:Q", scale=alt.Scale(zero=True)),
            color=alt.Color(
                category_col,
                scale=alt.Scale(domain=categories),
                legend=alt.Legend(orient="left", title=category_col),
            ),
            tooltip=[
                "site",
                category_col,
                alt.Tooltip("site escape:Q", format=".3g"),
                *[f"{c}:N" for c in site_prop_cols if c != "site"],
            ],
        )
    )
    site_lineplot = (
        (
            (
                lineplot_base
                .mark_line(size=1)
                .transform_calculate(_stat_show_line="true")
                .transform_filter(line_selection)
            ) + lineplot_base.mark_circle(opacity=1)
        )
        .add_params(site_summary_selection, line_selection)
        .properties(width=alt.Step(lineplot_width), height=lineplot_height)
    )

    # make base chart for heatmaps
    heatmap_base = (
        base_chart
        .encode(
            y=alt.Y(
                "mutant",
                sort=alphabet,
                scale=alt.Scale(domain=alphabet),
                title=None,
            ),
        )
    )

    # wildtype text marks for heatmap
    heatmap_wildtype = (
        heatmap_base
        .transform_filter(alt.datum.wildtype == alt.datum.mutant)
        .mark_text(text="x", color="black")
    )

    # background fill for missing values in heatmap, imputing dummy stat
    # to get all cells
    heatmap_bg = (
        heatmap_base
        .transform_impute(
            impute="_stat_dummy",
            key="mutant",
            keyvals=alphabet,
            groupby=["site"],
            value=None,
        )
        .mark_rect(color="gray", opacity=0.25)
    )

    # Make heatmaps for each category and vertically concatenate. We do this in loop
    # rather than faceting to enable compound chart w wildtype marks and category
    # specific coloring.
    heatmaps = alt.vconcat(
        *[
            heatmap_bg
            + heatmap_base
            .transform_filter(alt.datum[category_col] == category)
            .encode(
                color=alt.Color(
                    "_stat:Q",
                    legend=alt.Legend(
                        orient="left",
                        title=stat_col,
                        titleOrient="left",
                        gradientLength=100,
                        gradientStrokeColor="black",
                        gradientStrokeWidth=0.5,
                    ),
                    scale=alt.Scale(domainMax=max_stat, zero=True, nice=False),
                ),
                stroke=alt.value("black"),
                tooltip=heatmap_tooltips,
            )
            .mark_rect()
            .properties(
                width=alt.Step(cell_size),
                height=alt.Step(cell_size),
                title=alt.TitleParams(
                    category if show_category_label else "",
                    color="black",
                    anchor="start",
                ),
            )
            + heatmap_wildtype
            for category in categories
        ],
        spacing=0,
    ).resolve_scale(x="shared")

    chart = (
        alt.vconcat(site_zoom_bar, site_lineplot, heatmaps)
        .add_params(floor_at_zero, site_brush, *sliders.values())
        .configure_axis(labelOverlap="parity", grid=False)
    )

    if plot_title:
        chart = chart.properties(
            title=alt.TitleParams(
                plot_title, anchor="start", align="left", fontSize=16,
            ),
        )

    return chart

Reference on how Altair handles chart data: https://joelostblom.github.io/altair-docs/user_guide/data.html

In [159]:
# Sites in this inclusive range are colored as the "440 loop" domain in the zoom bar.
LOOP_440_START, LOOP_440_END = 440, 450

escape_with_domain_df = df.assign(
    domain=lambda frame: (
        frame["site"]
        .between(LOOP_440_START, LOOP_440_END)
        .map({True: "440 loop", False: "other"})
    )
)

chart = lineplot_and_heatmap(
    data_df=escape_with_domain_df,
    stat_col="escape_median",
    category_col="epitope",
    site_zoom_bar_color_col="domain",
    addtl_tooltip_stats=["escape_mean", "escape_std"],
    addtl_slider_stats={"n_models": 3, "times_seen": 10},
    init_site_escape_statistic="mean",
)

chart

In [147]:
# Scratch check: `Series.unique()` keeps missing values such as `pd.NA` as their own
# entry (shown output: array([<NA>, 1], dtype=object)).
pd.Series([pd.NA, 1], dtype=object).unique()

array([<NA>, 1], dtype=object)