In [9]:
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


# Global figure styling, applied once for every plot in this notebook:
# frameless legends, no top/right spines, and LaTeX-rendered serif text.
matplotlib.rcParams.update(
    {
        "legend.frameon": False,
        "axes.spines.right": False,
        "axes.spines.top": False,
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": "CMU Serif",
    }
)

## Varying correlation factor

In [None]:
# Read the correlation results (JSON Lines: one record per line) and show them.
df = pd.read_json(
    path_or_buf="../data/correlation/results.jsonl",
    lines=True,
)
df

In [None]:
# Two facets, one per value of "shuffle".
# Left: original correlated data, Right: data after random permutation

palette = sns.color_palette("flare", as_cmap=True)
g = sns.relplot(
    data=df,
    kind="line",
    x="corr_factor",
    y="recall_mean",
    col="shuffle",
    hue="k_per_bucket",
    style="interleaved",
    palette=palette,
    markers=True,
)

# Figure geometry and axis text.
g.figure.set_size_inches(9, 3)
g.figure.subplots_adjust(wspace=0.3)  # breathing space for columns
g.set_axis_labels("Correlation factor", "Recall")
g.set_titles(col_template="")

# Legend: move it to the right of the axes, hide the raw variable-name
# headings, and swap the raw values for readable labels.
g.legend.set_bbox_to_anchor((1.05, 0.5))
for t in g.legend.get_texts():
    text = t.get_text()
    if text in ("k_per_bucket", "interleaved"):
        t.set_visible(False)
    elif text.isnumeric():
        t.set_text(f"$k_b = {text}$")
    elif text in ("False", "True"):
        t.set_text("Interleaved" if text == "True" else "Contiguous")

plt.savefig("figures/appendix-correlation-sim.pdf", bbox_inches="tight")

## SparQ results (interleaved vs contiguous)

In [12]:
# Read the first SparQ results file (JSON Lines: one record per line).
df1 = pd.read_json(
    path_or_buf="../data/sparq_v1.jsonl",
    lines=True,
)

In [13]:
# Keep only the "repetition" task rows with k_per_bucket in {1, 2, 4, 8}.
df1 = df1.loc[
    df1["task_name"].eq("repetition")
    & df1["topk_k_per_bucket"].isin([1, 2, 4, 8])
]

In [14]:
# Read the second SparQ results file and keep only the "squad" task rows
# with k_per_bucket in {1, 2, 4, 8}.
df2 = pd.read_json("../data/sparq_v2.jsonl", lines=True).loc[
    lambda frame: frame["task_name"].eq("squad")
    & frame["topk_k_per_bucket"].isin([1, 2, 4, 8])
]

In [None]:
# Stack the two filtered result sets into one frame for plotting.
# ignore_index=True gives the result a fresh 0..n-1 index: df1 and df2 each
# keep the row labels of their source files after filtering, so a plain concat
# would contain duplicate index labels, which can make pandas/seaborn
# alignment fail with "cannot reindex on an axis with duplicate labels".
# NOTE(review): astype(bool) turns NaN into True; if either results file lacks
# "topk_interleaved" for some records, confirm that this is the intended value.
df = pd.concat([df1, df2], ignore_index=True).assign(
    topk_interleaved=lambda frame: frame["topk_interleaved"].astype(bool)
)
df

In [None]:
# Left: Repetition, Right: SQuAD
# Only use k_mult == 1

# Y-axis label for each task. The key order is also passed as `col_order`, and
# labels are looked up by task name below, so a panel can never receive the
# other task's label if the row order of `df` changes.
ylabel = {"repetition": "Repetition match", "squad": "SQuAD accuracy"}

palette = sns.color_palette("flare", as_cmap=True)
g = sns.relplot(
    data=df[df["topk_k_mult"] == 1],
    x="topk_k_per_bucket",
    y="score",
    style="topk_interleaved",
    col="task_name",
    col_order=list(ylabel),
    kind="line",
    markers=True,
    color=palette(0.99),  # single colour taken from the top end of the colormap
    facet_kws={"sharey": False},  # each panel gets its own y-axis
)
g.figure.set_size_inches(9, 3)
g.figure.subplots_adjust(wspace=0.3)  # breathing space for columns
g.set_axis_labels(x_var="$k_b$")
for task, ax in g.axes_dict.items():
    # Log-2 x-axis with a tick at each of the plotted k_b values.
    ax.set_xscale("log", base=2)
    ax.set_xticks([1, 2, 4, 8])
    ax.set_xticklabels([1, 2, 4, 8])
    ax.set_title("")
    ax.set_ylabel(ylabel[task])

# Legend: move it to the right of the axes, drop the title, and replace the
# raw boolean values with readable labels.
g.legend.set_bbox_to_anchor((1.05, 0.5))
g.legend.set_title("")
for t in g.legend.get_texts():
    text = t.get_text()
    if text == "False":
        t.set_text("Contiguous")
    elif text == "True":
        t.set_text("Interleaved")

plt.savefig("figures/appendix-correlation-sparq.pdf", bbox_inches="tight")