# Make a wrapped heatmap with multiple rows
This notebook makes wrapped heatmaps with multiple rows, designed for use as paper figures.

In [None]:
import altair as alt

import pandas as pd

# Turn off Altair's maximum-row check on chart data. The result is bound to
# `_`, presumably so the return value is not echoed as the cell's output.
_ = alt.data_transformers.disable_max_rows()

## Get configuration parameters

We get the parameters passed by `snakemake`:

In [None]:
# Input / output paths supplied by the snakemake rule running this notebook.
data_csv = snakemake.input.data_csv
chart_html = snakemake.output.chart_html

# Every plot option arrives bundled in one dictionary.
params_dict = snakemake.params.params_dict

# Chart title, the column mapped to color, and an optional row filter
# (applied with `DataFrame.query` when non-empty).
title = params_dict["title"]
effect_col = params_dict["effect_col"]
data_query_str = params_dict["data_query_str"]

# Layout: how many sites go on each wrapped row, and how often to label one.
sites_per_row = params_dict["sites_per_row"]
site_label_freq = params_dict["site_label_freq"]

# Color scale: a named scheme or list of colors, plus the domain limits.
color_scheme = params_dict["color_scheme"]
fixed_min = params_dict["fixed_min"]
fixed_max = params_dict["fixed_max"]

# Characters to keep and their order on the y-axis.
alphabet = params_dict["alphabet"]

# Optional entry (read later via its "col" and "cutoff" keys); None if absent.
dark_gray_muts = params_dict.get("dark_gray_muts")

## Read the input data

In [None]:
# The alphabet may be given as one string of characters or as a list.
if isinstance(alphabet, str):
    alphabet = list(alphabet)

# Read `site` as strings rather than letting pandas infer a numeric dtype.
data = pd.read_csv(data_csv, dtype={"site": str})
print(f"Read {len(data)=} with {data.columns=}")

if data_query_str:
    data = data.query(data_query_str).reset_index(drop=True)
    print(f"After querying with {data_query_str=}, {len(data)=}")

req_cols = ["site", "sequential_site", "wildtype", "mutant", effect_col]
# Raise explicitly instead of using `assert`: assertions are stripped when
# Python runs with `-O`, which would silently skip this validation.
if not set(req_cols).issubset(data.columns):
    raise ValueError(f"{data.columns=} lacks {req_cols=}")

# The optional dark-gray column must exist and must survive the column
# subsetting below so the chart can filter on it.
if dark_gray_muts:
    gray_col = dark_gray_muts["col"]
    if gray_col not in data.columns:
        raise ValueError(f"{dark_gray_muts['col']=} not in {data.columns=}")
    if gray_col not in req_cols:
        req_cols.append(gray_col)

# Keep only rows whose wildtype and mutant are both in the alphabet.
in_alphabet = data["mutant"].isin(alphabet) & data["wildtype"].isin(alphabet)
data = data[in_alphabet].reset_index(drop=True)
print(f"After getting just amino acids of {alphabet=}, {len(data)=}")

data = data[req_cols]

# Each (site, wildtype, mutant) combination may appear at most once,
# otherwise several values would be drawn into the same heatmap cell.
n_unique_muts = len(data.groupby(["site", "wildtype", "mutant"]))
if len(data) != n_unique_muts:
    raise ValueError(
        f"{len(data)=} rows but only {n_unique_muts=} unique "
        "(site, wildtype, mutant) combinations"
    )

## Make heatmap

In [None]:
# A string is passed as the `scheme` of the scale; anything else is assumed
# to be a list of hex colors and is passed as the scale's `range`.
if isinstance(color_scheme, str):
    scale_colors = {"scheme": color_scheme}
else:
    scale_colors = {"range": color_scheme}

# Scale with its midpoint at 0 and limits at the configured min / max;
# `clamp=True` is set in both cases.
color_scale = alt.Scale(
    **scale_colors,
    domainMid=0,
    domainMin=fixed_min,
    domainMax=fixed_max,
    clamp=True,
)


In [None]:
# Base chart shared by every layer: `mutant` on the y-axis in `alphabet`
# order, with a fixed step size for both width and height.
cell_px = 9
mutant_axis = alt.Y("mutant", sort=alphabet, title="amino acid")
heatmap_base = (
    alt.Chart(data)
    .encode(mutant_axis)
    .properties(width=alt.Step(cell_px), height=alt.Step(cell_px))
)

# Background layer: impute an entry for every alphabet character at every
# site so cells without a measurement are still drawn, in light gray.
imputed_cells = heatmap_base.transform_impute(
    impute="_stat_dummy",
    key="mutant",
    keyvals=alphabet,
    groupby=["site"],
    value=None,
)
heatmap_bg = imputed_cells.mark_rect(color="#E0E0E0", opacity=0.8)

# Wildtype layer: an "x" wherever the mutant equals the wildtype.
is_wildtype = alt.datum["wildtype"] == alt.datum["mutant"]
heatmap_wildtype = heatmap_base.transform_filter(is_wildtype).mark_text(
    text="x", color="black"
)

# Main layer: black-outlined cells colored by the effect column.
effect_color = alt.Color(effect_col, scale=color_scale)
effect_tooltip = [
    "site",
    "mutant",
    "wildtype",
    alt.Tooltip(effect_col, format=".2f"),
]
heatmap_muts = heatmap_base.encode(effect_color, tooltip=effect_tooltip).mark_rect(
    stroke="black", opacity=1, strokeOpacity=1
)

if dark_gray_muts:
    gray_value = alt.datum[dark_gray_muts["col"]]
    gray_cutoff = dark_gray_muts["cutoff"]

    # Only mutations at or above the cutoff keep their effect color ...
    heatmap_muts = heatmap_muts.transform_filter(gray_value >= gray_cutoff)

    # ... those below the cutoff are drawn as flat silver cells instead.
    heatmap_dark_gray = (
        heatmap_base
        .transform_filter(gray_value < gray_cutoff)
        .transform_calculate(filtered="0")
        .mark_rect(stroke="black", opacity=1, strokeOpacity=1, color="silver")
    )

# Build one layered chart per wrapped row of at most `sites_per_row` sites.
heatmap_rows = []

# `.tolist()` yields native Python numbers. NumPy >= 2 scalars repr as e.g.
# `np.int64(5)`, and that text would otherwise end up inside the Vega filter
# expressions built from `alt.datum` comparisons below.
sequential_sites = sorted(data["sequential_site"].unique().tolist())

# Lookup from sequential site number to site label; identical for every row,
# so build it once rather than on each iteration.
sequential_to_site = data.set_index("sequential_site")["site"].to_dict()

for row_start in range(0, len(sequential_sites), sites_per_row):
    row_sites = sequential_sites[row_start: row_start + sites_per_row]
    # `sequential_sites` is sorted, so the ends of the slice are its min / max.
    first_site, last_site = row_sites[0], row_sites[-1]
    last_row = last_site == sequential_sites[-1]

    # Label only every `site_label_freq`-th site. The end of the range is
    # inclusive so that a row holding a single site still gets a label and
    # the row's final site is labeled when it falls on the label interval.
    to_label_values = [
        sequential_to_site[seq_site]
        for seq_site in range(first_site, last_site + 1, site_label_freq)
        if seq_site in sequential_to_site
    ]

    # Layer order is drawing order: background first, wildtype marks on top.
    if dark_gray_muts:
        row_charts = heatmap_bg + heatmap_dark_gray + heatmap_muts + heatmap_wildtype
    else:
        row_charts = heatmap_bg + heatmap_muts + heatmap_wildtype

    heatmap_rows.append(
        row_charts
        .encode(
            alt.X(
                "site:N",
                # only the bottom row carries the x-axis title
                title="site" if last_row else None,
                sort=alt.SortField("sequential_site"),
                scale=alt.Scale(nice=False, zero=False),
                axis=alt.Axis(values=to_label_values, labelAngle=0),
            ),
        )
        # restrict this row's chart to its own span of sequential sites
        .transform_filter(
            (alt.datum["sequential_site"] >= first_site)
            & (alt.datum["sequential_site"] <= last_site)
        )
    )

# Stack the per-row charts vertically with a small gap between rows.
stacked_rows = alt.vconcat(*heatmap_rows, spacing=6)

# Legend options: centered title, placed below the chart, black outline.
legend_options = {
    "orient": "bottom",
    "gradientStrokeWidth": 1,
    "gradientStrokeColor": "black",
    "titleAnchor": "middle",
    "titleFontSize": 16,
    "titleLimit": 200,
}

heatmap = (
    stacked_rows
    .configure_axis(tickColor="black", tickSize=4, titleFontSize=16)
    .configure_legend(**legend_options)
    .properties(title=title)
    .configure_title(anchor="middle", fontSize=18)
)

# Write the chart to the output path given by snakemake.
print(f"Saving {chart_html=}")
heatmap.save(chart_html)

# Last expression of the cell, so the chart is also rendered in the notebook.
heatmap