# Make a wrapped heatmap with multiple rows
This notebook makes wrapped heatmaps with multiple rows, designed for use as paper figures.

In [None]:
import altair as alt

import pandas as pd

# turn off Altair's max-rows check so the full data set can be passed to the chart;
# the return value is assigned to `_` so the notebook cell shows no output
_ = alt.data_transformers.disable_max_rows()

## Get configuration parameters

The next cell is tagged as `parameters` for `papermill` parameterization:

In [None]:
# cell with expected parameters, all set to None but will be parameterized by papermill
# CSV file read with `pd.read_csv`; must have columns "site", "sequential_site",
# "wildtype", "mutant", and the column named by `effect_col`
data_csv = None
# title placed on the overall chart
title = None
# name of the column in `data_csv` that sets the heatmap cell color (also shown in the tooltip)
effect_col = None
# if not empty / None, passed to `DataFrame.query` to filter the data before plotting
data_query_str = None
# number of sequential sites placed in each heatmap row before wrapping to the next row
sites_per_row = None
# step between sites that get an x-axis label within each row
site_label_freq = None
# passed as the `scheme` of the color scale
color_scheme = None
# passed as `domainMin` and `domainMax` of the color scale (values are clamped to the domain)
fixed_min = None
fixed_max = None
# characters to show on the y-axis, in this order; a string is split into its characters,
# and rows whose "mutant" or "wildtype" are not in it are dropped
alphabet = None
# file the chart is saved to
chart_html = None

## Read the input data

In [None]:
# Read the CSV, optionally filter it, and validate it before plotting.
# Validation raises `ValueError` rather than using `assert`, so the checks are
# not skipped when Python runs with optimizations enabled.

# fail early with a clear message instead of an opaque `TypeError` from `isin` below
if alphabet is None:
    raise ValueError("`alphabet` must be set to a string or list of characters")
# a string alphabet is split into its individual characters
if isinstance(alphabet, str):
    alphabet = list(alphabet)

# read "site" as a string because site labels are not necessarily integers
data = pd.read_csv(data_csv, dtype={"site": str})
print(f"Read {len(data)=} with {data.columns=}")

if data_query_str:
    data = data.query(data_query_str).reset_index(drop=True)
    print(f"After querying with {data_query_str=}, {len(data)=}")

req_cols = ["site", "sequential_site", "wildtype", "mutant", effect_col]
if not set(req_cols).issubset(data.columns):
    raise ValueError(f"{data.columns=} lacks {req_cols=}")

# keep only rows where both the mutant and wildtype characters are in the alphabet
data = data[data["mutant"].isin(alphabet) & data["wildtype"].isin(alphabet)].reset_index(drop=True)
print(f"After getting just amino acids of {alphabet=}, {len(data)=}")

data = data[req_cols]
# each (site, wildtype, mutant) combination must appear exactly once, otherwise
# several rectangles would be drawn on top of each other in the same heatmap cell
if len(data) != len(data.groupby(["site", "wildtype", "mutant"])):
    raise ValueError(
        "each combination of site, wildtype, and mutant must occur in exactly one row "
        f"(with no missing values), but that is not the case for the {len(data)} rows"
    )

In [None]:
# add a 0/1 `row` column: 1 where the site label contains "E1", 0 otherwise
# NOTE(review): `row` is not referenced by any of the chart code visible in this
# notebook -- confirm whether it is still needed
data = data.assign(row=data["site"].str.contains("E1").astype(int))

# base chart shared by all heatmap layers: amino acids on the y-axis in `alphabet`
# order, with each cell drawn as a fixed-size square
CELL_PX = 9
amino_acid_axis = alt.Y("mutant", sort=alphabet, title="amino acid")
heatmap_base = (
    alt.Chart(data)
    .encode(amino_acid_axis)
    .properties(width=alt.Step(CELL_PX), height=alt.Step(CELL_PX))
)

# background layer: for each site, impute a placeholder field for every character
# in `alphabet`, so a gray rectangle is drawn for every site / amino-acid cell,
# including those with no measurement in the data
bg_impute = dict(
    impute="_stat_dummy",
    key="mutant",
    keyvals=alphabet,
    groupby=["site"],
    value=None,
)
heatmap_bg = (
    heatmap_base
    .transform_impute(**bg_impute)
    .mark_rect(color="#E0E0E0", opacity=0.8)
)

# wildtype layer: mark cells where the mutant equals the wildtype with an "x"
is_wildtype = alt.datum["wildtype"] == alt.datum["mutant"]
heatmap_wildtype = heatmap_base.transform_filter(is_wildtype).mark_text(
    text="x",
    color="black",
)

# mutation layer: one outlined rectangle per row of data, colored by `effect_col`

# color scale with its midpoint at zero; values outside the domain are clamped to it
effect_scale = alt.Scale(
    scheme=color_scheme,
    domainMid=0,
    domainMin=fixed_min,
    domainMax=fixed_max,
    clamp=True,
)
# tooltip shows the site, both amino acids, and the effect to two decimals
effect_tooltip = [
    "site",
    "mutant",
    "wildtype",
    alt.Tooltip(effect_col, format=".2f"),
]

heatmap_muts = heatmap_base.encode(
    alt.Color(effect_col, scale=effect_scale),
    tooltip=effect_tooltip,
).mark_rect(stroke="black", opacity=1, strokeOpacity=1)

# Build one layered heatmap per row of at most `sites_per_row` sites, in
# `sequential_site` order, labeling every `site_label_freq`-th site of each row.

# a zero or negative step would fail in `range` or silently produce no rows at all
for param_name, param_value in [
    ("sites_per_row", sites_per_row),
    ("site_label_freq", site_label_freq),
]:
    if not isinstance(param_value, int) or param_value < 1:
        raise ValueError(f"{param_name} must be a positive integer, got {param_value!r}")

heatmap_rows = []
# `.tolist()` gives plain Python scalars rather than NumPy ones, so the values
# interpolated into the filter expressions below are ordinary number literals
sequential_sites = sorted(data["sequential_site"].unique().tolist())
# the sequential-site -> site-label lookup is the same for every row, so build it once
sequential_to_site = data.set_index("sequential_site")["site"].to_dict()

for row_start in range(0, len(sequential_sites), sites_per_row):
    row_sites = sequential_sites[row_start: row_start + sites_per_row]
    first_site, last_site = row_sites[0], row_sites[-1]
    # only the bottom row gets an x-axis title
    last_row = last_site == sequential_sites[-1]
    # Sites are plotted as nominal labels (they need not be integers), so the axis
    # is told explicitly which labels to show. Stepping through the sites actually
    # present in this row (rather than through a numeric range) works when the
    # sequential numbering has gaps, and includes the row's last site when it falls
    # on the step.
    to_label_values = [sequential_to_site[s] for s in row_sites[::site_label_freq]]
    x_encoding = alt.X(
        "site:N",
        title="site" if last_row else None,
        sort=alt.SortField("sequential_site"),
        scale=alt.Scale(nice=False, zero=False),
        axis=alt.Axis(values=to_label_values, labelAngle=0),
    )
    # `row_sites` is sorted, so its ends bound the sequential sites in this row
    in_this_row = (
        (alt.datum["sequential_site"] >= first_site)
        & (alt.datum["sequential_site"] <= last_site)
    )
    heatmap_rows.append(
        (heatmap_bg + heatmap_muts + heatmap_wildtype)
        .encode(x_encoding)
        .transform_filter(in_this_row)
    )

# stack the per-row heatmaps vertically, then apply figure-wide styling
legend_style = dict(
    orient="bottom",
    gradientStrokeWidth=1,
    gradientStrokeColor="black",
    titleAnchor="middle",
    titleFontSize=16,
    titleLimit=200,
)
stacked_rows = alt.vconcat(*heatmap_rows, spacing=6)
heatmap = (
    stacked_rows
    .configure_axis(tickColor="black", tickSize=4, titleFontSize=16)
    .configure_legend(**legend_style)
    .properties(title=title)
    .configure_title(anchor="middle", fontSize=18)
)

# write the chart to the requested output file
print(f"Saving {chart_html=}")
heatmap.save(chart_html)

# last expression in the cell, so the notebook displays the chart
heatmap