In [None]:
import altair as alt
import numpy as np
import pandas as pd

In [None]:
# Load the tab-separated input table. `snakemake` is not defined in this
# notebook; presumably it is injected by Snakemake's `notebook:` directive.
df = pd.read_table(snakemake.input.distances, sep="\t")

In [None]:
# Preview the first rows of the loaded table.
df.head()

In [None]:
# One-decimal copy of `log2_titer_mean`, used for the heatmap colour and the
# numbers printed in each cell.
df["log2_titer"] = df["log2_titer_mean"].round(1)

In [None]:
# Single font size shared by the heatmap's cell labels, axis labels/titles
# and legend labels/title.
font_size = 14

In [None]:
# Heatmap of log2 titers: one rectangle per (reference clade, test clade)
# pair, coloured by `log2_titer` and annotated with the same value.

# Bounds of the diverging "blueorange" colour scale and its centre.
titer_domain = [-2.0, 7.0]
titer_mid = 0.0

# Cell text is black when the value lies strictly between these cut-offs and
# white otherwise. They are the halfway points between the scale centre and
# each end of the domain (-1.0 and 3.5 for the bounds above). Deriving them
# keeps the label contrast in step with the colour scale if the domain changes.
light_cell_low = (titer_domain[0] + titer_mid) / 2
light_cell_high = (titer_mid + titer_domain[1]) / 2

base = alt.Chart(df).encode(
    x=alt.X("clade_reference:N", title="reference strain clade"),
    y=alt.Y("clade_test:N", title="test strain clade"),
).properties(width=700, height=700)

heatmap = base.mark_rect().encode(
    color=alt.Color(
        "log2_titer:Q",
        scale=alt.Scale(
            scheme="blueorange",
            domain=titer_domain,
            domainMid=titer_mid,
        ),
        legend=alt.Legend(direction="vertical", title="log2 titer"),
    )
)

text = base.mark_text(baseline="middle").encode(
    text="log2_titer:Q",
    color=alt.condition(
        (alt.datum.log2_titer < light_cell_high)
        & (alt.datum.log2_titer > light_cell_low),
        alt.value("black"),
        alt.value("white"),
    ),
)

# Layer labels over the rectangles and apply one font size to the text marks,
# axes and legend.
chart = (
    (heatmap + text)
    .configure_text(fontSize=font_size)
    .configure_axis(labelFontSize=font_size, titleFontSize=font_size)
    .configure_legend(labelFontSize=font_size, titleFontSize=font_size)
)
chart

In [None]:
# Save the heatmap to the path Snakemake provides as this rule's `plot` output.
chart.save(snakemake.output.plot)

In [None]:
# Number of distinct clades appearing as either the reference or the test clade.
len(set(df["clade_reference"]).union(df["clade_test"]))

In [None]:
# (rows, columns) of the full table.
df.shape

In [None]:
# Size of the subset where reference and test clade are the same (heatmap diagonal).
df[df["clade_reference"] == df["clade_test"]].shape

In [None]:
# Mean and standard deviation of log2 titers for same-clade pairs, to 2 d.p.
# Aggregate the unrounded `log2_titer_mean`. The `log2_titer` column is already
# rounded to one decimal, so aggregating it would add rounding error before the
# final 2-decimal rounding.
np.round(
    df.query("clade_reference == clade_test")["log2_titer_mean"].aggregate(["mean", "std"]),
    2
)

In [None]:
# Sum of the mean and standard deviation of same-clade log2 titers, to 2 d.p.
# Use the unrounded `log2_titer_mean` (the `log2_titer` column is pre-rounded to
# one decimal). Sum first and round once: summing two already-rounded values
# can be off in the last reported digit.
np.round(
    df.query("clade_reference == clade_test")["log2_titer_mean"].aggregate(["mean", "std"]).sum(),
    2
)

In [None]:
# Mean and standard deviation of log2 titers for pairs of different clades,
# to 2 d.p. Aggregate the unrounded `log2_titer_mean` rather than the
# one-decimal `log2_titer` column, so rounding happens only once, at the end.
np.round(
    df.query("clade_reference != clade_test")["log2_titer_mean"].aggregate(["mean", "std"]),
    2
)

In [None]:
# Full distribution summary (count, mean, std, quartiles) of log2 titers for
# pairs of different clades, computed on the unrounded values rather than the
# one-decimal `log2_titer` column.
df.query("clade_reference != clade_test")["log2_titer_mean"].describe()