# Plot figures in the paper

In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import torch
from fastabx import Dataset
from matplotlib.lines import Line2D
from matplotlib.ticker import ScalarFormatter
from scipy.stats import linregress, pearsonr
from sklearn.manifold import TSNE
from torch import nn

from spidr.config import DinoSRConfig, MaskingConfig, OptimizerConfig, SpidRConfig
from spidr.data.masks import MaskGenerator
from spidr.models.dinosr import ema_scheduler
from spidr.models.metrics import proba_phone_code
from spidr.models.spidr import exp_ema_scheduler
from spidr.optimizer import build_optimizer


def mpl_palette(name: str, n_colors: int) -> list[tuple[float, float, float]]:
    """Return ``n_colors`` RGB colors sampled from the matplotlib colormap ``name``.

    Adapted from seaborn.mpl_palette.

    Qualitative colormaps (listed below with their number of distinct colors) are read
    color by color from the start. If more colors are requested than such a colormap
    holds, the palette is cycled so that exactly ``n_colors`` colors are returned.
    Other (continuous) colormaps are sampled at ``n_colors`` evenly spaced points,
    excluding both ends of the colormap.

    Args:
        name: Name of a matplotlib colormap.
        n_colors: Number of colors to return.

    Returns:
        A list of ``(r, g, b)`` tuples (alpha channel dropped).
    """
    # fmt: off
    mpl_qual_pals = {
        "tab10": 10, "tab20": 20, "tab20b": 20, "tab20c": 20, "Set1": 9, "Set2": 8, "Set3": 12,
        "Accent": 8, "Paired": 12, "Pastel1": 9, "Pastel2": 8, "Dark2": 8,
    }
    # fmt: on
    n_colors = int(n_colors)
    bins = (
        np.linspace(0, 1, mpl_qual_pals[name])[:n_colors]
        if name in mpl_qual_pals
        else np.linspace(0, 1, n_colors + 2)[1:-1]
    )
    palette = list(map(tuple, plt.get_cmap(name)(bins)[:, :3]))
    # A qualitative map holds a fixed number of colors: cycle through them instead of
    # silently returning fewer colors than requested (callers index up to n_colors - 1).
    if palette and len(palette) < n_colors:
        palette = [palette[i % len(palette)] for i in range(n_colors)]
    return palette


# Directories, relative to the notebook's working directory.
assets = Path("./assets")  # static inputs: style sheet, phone mappings, alignments
results = Path("./results")  # logged metrics and evaluation CSVs
figures = Path("./figures")  # output PDFs
plt.style.use(assets / "paper.mplstyle")
figures.mkdir(exist_ok=True)

TEXTWIDTH = 6.5  # figure width (inches) matching the paper's text width

## 3 - Method

Figure 2: Codebook and prediction perplexities during training for SpidR and DinoSR on LibriSpeech dev-clean, with $K = 8$ codebooks. For each layer $k$, the codebook perplexity is computed over each batch with $\bm{p} = \bm{y}^k$ and then averaged across the dataset. The prediction perplexity uses $\bm{p} = \tilde{\bm{y}}^k$.

In [None]:
def get_ppl(subset: str, num_codebooks: int = 8) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Load per-step codebook and prediction perplexities for DinoSR and SpidR.

    For each model, ``results/logs/{model}-{subset}.jsonl.gz`` is read once and restricted
    to the rows with ``rank == 0``. For each metric (``target_ppl``, then ``pred_ppl``) and
    each codebook ``i``, the ``{metric}_{i}`` values of all records logged at the same step
    are averaged, weighted by ``num_frames_{i}``.

    Args:
        subset: Evaluation subset used in the log file name (e.g. ``"dev-clean"``).
        num_codebooks: Number of logged codebooks, i.e. columns ``{metric}_0`` to
            ``{metric}_{num_codebooks - 1}``.

    Returns:
        ``(dinosr_target, dinosr_pred, spidr_target, spidr_pred)``: one DataFrame per
        (model, metric) with a ``step`` column and one column per codebook, sorted by step.
    """
    codebooks = range(num_codebooks)

    def frame_weighted(logs: pl.DataFrame, metric: str) -> pl.DataFrame:
        # sum(value * frames) / sum(frames) over the records of a step = frame-weighted mean.
        return (
            logs.with_columns(
                [(pl.col(f"{metric}_{i}") * pl.col(f"num_frames_{i}")).alias(f"{metric}_{i}") for i in codebooks]
            )
            .group_by("step")
            .agg([pl.col(f"{metric}_{i}").sum() for i in codebooks] + [pl.col(f"num_frames_{i}").sum() for i in codebooks])
            .with_columns([(pl.col(f"{metric}_{i}") / pl.col(f"num_frames_{i}")).alias(f"{metric}_{i}") for i in codebooks])
            .select(["step"] + [f"{metric}_{i}" for i in codebooks])
            .sort("step")
        )

    ppl = []
    for model in ("dinosr", "spidr"):
        # Read each log once and reuse it for both metrics.
        logs = pl.read_ndjson(results / "logs" / f"{model}-{subset}.jsonl.gz").filter(pl.col("rank") == 0)
        ppl.extend(frame_weighted(logs, metric) for metric in ("target_ppl", "pred_ppl"))
    return tuple(ppl)


# Figure 2: 2x4 grid with one panel per codebook (titled "Layer 5" to "Layer 12").
# Solid lines: codebook perplexity; dashed lines: prediction perplexity.
# Color C0: SpidR, color C3: DinoSR.
fig, ax = plt.subplots(figsize=(TEXTWIDTH, 3), nrows=2, ncols=4, sharex=True, sharey=True)
# The layout rect leaves the bottom 20% of the figure for the x label and the two legends.
fig.set_layout_engine("tight", pad=0, w_pad=0, h_pad=0.8, rect=(0, 0.2, 1, 1))
subset = "dev-clean"
dinosr_target, dinosr_pred, spidr_target, spidr_pred = get_ppl(subset)
for k in range(8):
    i, j = divmod(k, 4)  # panel row / column for codebook k
    ax[i, j].plot(spidr_target["step"], spidr_target[f"target_ppl_{k}"], color="C0", lw=1)
    ax[i, j].plot(dinosr_target["step"], dinosr_target[f"target_ppl_{k}"], color="C3", lw=1)
    ax[i, j].plot(spidr_pred["step"], spidr_pred[f"pred_ppl_{k}"], linestyle="--", color="C0", lw=1)
    ax[i, j].plot(dinosr_pred["step"], dinosr_pred[f"pred_ppl_{k}"], linestyle="--", color="C3", lw=1)
    ax[i, j].set_title(f"Layer {5 + k}", fontsize=10)
    if j == 0:
        ax[i, j].set_ylabel("Perplexity")
    # The y range ignores the first logged step (row 0 after sorting by step), presumably
    # to keep an early outlier from flattening the curves.
    ymin = min(
        spidr_target[1:, f"target_ppl_{k}"].min(),
        dinosr_target[1:, f"target_ppl_{k}"].min(),
        spidr_pred[1:, f"pred_ppl_{k}"].min(),
        dinosr_pred[1:, f"pred_ppl_{k}"].min(),
    )
    ymax = max(
        spidr_target[1:, f"target_ppl_{k}"].max(),
        dinosr_target[1:, f"target_ppl_{k}"].max(),
        spidr_pred[1:, f"pred_ppl_{k}"].max(),
        dinosr_pred[1:, f"pred_ppl_{k}"].max(),
    )
    ax[i, j].set_ylim(ymin * 0.95, ymax * 1.05)
    ax[i, j].set_xticks([0, 200_000, 400_000])
    ax[i, j].set_xticklabels(["0", "200k", "400k"])
    ax[i, j].xaxis.set_tick_params(labelsize=7)
    ax[i, j].yaxis.set_tick_params(labelsize=7)
    ax[i, j].grid(linestyle="dotted")
fig.supxlabel("Training steps", y=0.18, fontsize=10)
# Two proxy-artist legends below the grid: one for the model colors, one for the line styles.
# They are excluded from the tight layout so that they do not shrink the panels.
leg = ax[1, 0].legend(
    [Line2D([], [], color="C0", lw=2), Line2D([], [], color="C3", lw=2)],
    ["SpidR", "DinoSR"],
    bbox_to_anchor=(0.8, -0.49),
    loc="upper center",
    fontsize=10,
    title="Model",
    ncols=2,
)
leg.set_in_layout(False)
leg2 = ax[1, 2].legend(
    [Line2D([], [], color="black", lw=2), Line2D([], [], color="black", linestyle="--", lw=2)],
    ["From codebook", "From prediction"],
    bbox_to_anchor=(0.9, -0.49),
    loc="upper center",
    fontsize=10,
    title="Perplexity",
    ncols=2,
)
leg2.set_in_layout(False)
plt.savefig(figures / f"perplexity-{subset}.pdf")
plt.show()

## 4.4 - Evaluation of downstream spoken language modeling

Figure 3: Data scaling results for a 125M-parameter OPT model trained on Libri-Light, with different discrete units encoders. Zero-shot accuracy in %, chance level 50%. The speech encoders have V = 256 units. The log-likelihoods are normalized by the number of tokens, except for WUGGY with text.

In [None]:
# Figure 3: zero-shot spoken-LM score vs. number of training hours (log x axis),
# one panel per metric.
metrics = {"swuggy_all": "WUGGY all", "sblimp": "BLIMP", "tsc": "tSC"}
# Plot keyword arguments per unit type; entries without "color" take the next color of the cycle.
styles = {
    "spidr_codebook": {"label": "SpidR (Codebooks)", "linestyle": "-"},
    "spidr_kmeans": {"label": "SpidR (K-means)", "linestyle": "--"},
    "hubert_kmeans": {"label": "HuBERT (K-means)", "linestyle": "-."},
    "text": {"label": "Text (BPE)", "color": "#484848", "linestyle": ":"},
}
# Encoder layer whose units are plotted for each speech model ("text" rows have no layer).
# NOTE(review): this cell uses the key "spidr_codebook", while the other cells reading
# spoken-lm.csv use "spidr_codebooks" — confirm which key the CSV contains, otherwise the
# SpidR (Codebooks) curve is silently empty.
layers = {"spidr_codebook": 6, "spidr_kmeans": 6, "hubert_kmeans": 11}

df = pl.read_csv(results / "spoken-lm.csv")
fig, ax = plt.subplots(nrows=1, ncols=len(metrics), figsize=(TEXTWIDTH, 2), sharex=True)
fig.set_layout_engine("tight", pad=0.1, w_pad=1, rect=(0, 0.22, 1, 1))
for k, metric in enumerate(metrics):
    for model, style in styles.items():
        # Speech models: keep their chosen layer; text: keep the rows with a null layer.
        select_layer = (pl.col("layer") == layers[model]) if model != "text" else pl.col("layer").is_null()
        subdf = df.filter(pl.col("metric") == metric, pl.col("model") == model, select_layer).sort("hours")
        ax[k].plot(subdf["hours"], subdf["score"], lw=1.5, marker="o", markersize=4, **style)
    ax[k].set_xscale("log")
    # Plain tick labels (600, 6000, 60000) instead of powers of ten.
    ax[k].xaxis.set_major_formatter(ScalarFormatter())
    ax[k].minorticks_off()
    ax[k].set_xticks([600, 6000, 60_000])
    ax[k].grid(linestyle="dotted")
    ax[k].xaxis.set_tick_params(labelsize=7)
    ax[k].yaxis.set_tick_params(labelsize=7)
    ax[k].set_title(metrics[metric], fontsize=10)
ax[0].set_ylabel(r"Accuracy (\%)", fontsize=10)
fig.supxlabel("Training hours", y=0.2, fontsize=10)
# Single legend under the middle panel, excluded from the tight layout.
leg = ax[1].legend(bbox_to_anchor=(0.47, -0.4), loc="upper center", ncols=len(styles), fontsize=9)
leg.set_in_layout(False)
plt.savefig(figures / "scaling.pdf")
plt.show()

## 4.5 - Codebase and pretraining time

Figure 4: Approximate pretraining time for various hardware configurations with constant total batch size.

In [None]:
# Figure 4: pretraining time (hours) per number of GPUs, for A100 (top) and H100 (bottom),
# with and without compilation.
# Bar style keyed by the "compile" flag: solid gray without compile, hatched with compile.
styles = {
    False: {"fill": True, "facecolor": "gray", "edgecolor": "black"},
    True: {"fill": False, "hatch": "///", "edgecolor": "black"},
}
width = 0.3  # bar width; the "compile" bar of each group is shifted right by one width

df = pl.read_csv(results / "speed.csv")
fig, ax = plt.subplots(figsize=(TEXTWIDTH * 0.35, 2.15), nrows=2, sharey=True, sharex=True)
fig.set_layout_engine("tight", rect=(0.1, 0.06, 0.85, 1), pad=0.05, h_pad=0.5)
# One group of bars per distinct GPU count.
x = np.arange(len(df["n"].unique()))
for i, gpu in enumerate(["a100", "h100"]):
    for multiplier, with_compile in enumerate([False, True]):
        # NOTE(review): assumes exactly one row per GPU count for every (gpu, compile) pair,
        # so that the bar heights line up with `x`.
        subdf = df.filter((pl.col("gpu") == gpu) & (pl.col("compile") == with_compile)).sort("n", descending=False)
        offset = width * multiplier
        rects = ax[i].bar(
            x + offset,
            subdf["hours"],
            width,
            label={True: "W/ compile", False: "W/o compile"}[with_compile],
            **styles[with_compile],
        )
        # Write each bar's value above it, since the y axis has no ticks.
        ax[i].bar_label(rects, padding=3, fontsize=9)
    ax[i].set_yticks([])
    # Place each GPU-count tick between the two bars of its group.
    ax[i].set_xticks(x + width / 2, sorted(df["n"].unique()), fontsize=9)
    ax[i].set_title(f"{gpu.upper()}", y=0.7 if i == 1 else None, fontsize=10)
handles, labels = ax[0].get_legend_handles_labels()
leg = fig.legend(
    handles,
    labels,
    loc="center right",
    bbox_to_anchor=(1, 0.45),
    fontsize=6,
)
# The legend and the figure-level labels are excluded from the tight layout; the layout
# rect above leaves room for them.
leg.set_in_layout(False)
text = fig.supylabel("Pretraining time (hours)", fontsize=9)
text.set_in_layout(False)
text = fig.supxlabel("Number of GPUs", fontsize=9)
text.set_in_layout(False)
plt.savefig(figures / "speed.pdf")
plt.show()

## A.1 - SpidR pretraining

Figure 5: Learning rate schedule and EMA decay schedule of the teacher for DinoSR and SpidR.

In [None]:
# Figure 5: replay 400k steps to record the learning-rate schedule and the teacher EMA decay
# schedules of SpidR and DinoSR, using the default configs.
# One-parameter dummy module, presumably only there so that `build_optimizer` has parameters
# to attach to; only the scheduler's learning rate is recorded.
model = nn.Linear(1, 1)
cfg = OptimizerConfig()
opt, _, scheduler = build_optimizer(model, cfg)
d, s, ema_s, ema_d, lrs = DinoSRConfig(), SpidRConfig(), [], [], []
for i in range(400_000):  # Naive loop
    # opt.step() precedes scheduler.step(), the order PyTorch LR schedulers expect.
    opt.step()
    lrs.append(scheduler.get_last_lr()[0])
    scheduler.step()
    ema_s.append(exp_ema_scheduler(i, s.ema_start_decay, s.ema_timescale, s.ema_threshold))
    ema_d.append(ema_scheduler(i, d.ema_start_decay, d.ema_final_decay, d.ema_final_step, d.freeze_step))

# Learning rate vs. step (y axis in scientific notation).
fig, ax = plt.subplots(figsize=(3, 1.5))
fig.set_layout_engine("tight", pad=0.1)
ax.plot(lrs, color="k", lw=2)
ax.set_xlabel("Training steps")
ax.set_ylabel("Learning rate")
ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))
ax.set_yticks([0, 1e-4, 2e-4, 3e-4, 4e-4, 5e-4])
ax.set_xticks([0, 200_000, 400_000])
ax.grid(linestyle="dotted")
ax.set_xticklabels(["0", "200k", "400k"])
plt.savefig(figures / "lr-scheduler.pdf")
plt.show()

# Teacher EMA decay beta_t vs. step; SpidR is drawn above DinoSR (higher zorder).
fig, ax = plt.subplots(figsize=(3, 1.5))
fig.set_layout_engine("tight", pad=0.1)
ax.plot(ema_s, color="C0", lw=2, zorder=3, label="SpidR")
ax.plot(ema_d, color="C3", lw=2, zorder=2, label="DinoSR", linestyle="--")
ax.set_xlabel("Training steps")
ax.set_ylabel(r"$\beta_t$")
ax.set_xticks([0, 200_000, 400_000])
ax.set_xticklabels(["0", "200k", "400k"])
ax.legend(loc="lower right")
ax.grid(linestyle="dotted")
plt.savefig(figures / "ema-scheduler.pdf")
plt.show()

## A.2 - Masking procedure

Figure 6: Masked and unmasked frames.

In [None]:
def conv_length(length: torch.Tensor) -> torch.Tensor:
    """Return the number of frames produced by the convolutional encoder from `length` samples.

    Each of the seven convolution layers maps a length L to
    ``max(floor((L - kernel_size) / stride) + 1, 0)``; the layers are applied in order.
    Works element-wise on a tensor of lengths.
    """
    kernel_sizes = (10, 3, 3, 3, 3, 2, 2)
    strides = (5, 2, 2, 2, 2, 2, 2)
    frames = length
    for kernel_size, stride in zip(kernel_sizes, strides):
        frames = (frames - kernel_size).div(stride, rounding_mode="floor").add(1).clamp(min=0)
    return frames


def get_mask(sample_size: int, dim: int = 200) -> torch.Tensor:
    """Draw one (seeded) mask for an utterance of `sample_size` samples, tiled into an RGB image.

    The mask generator runs on a single all-False boolean sequence with as many frames as
    the convolutional encoder produces. Its output is inverted (1.0 where the generator
    returned False), converted to float, and repeated along a new leading axis of size
    `dim` and a trailing axis of size 3 so that it can be shown with ``imshow``.
    The resulting shape is presumably ``(dim, num_frames, 3)`` — TODO confirm the shape
    returned by ``MaskGenerator``.
    """
    torch.manual_seed(1)  # fixed seed: every call draws the same mask
    generator = MaskGenerator(MaskingConfig())
    num_frames = conv_length(torch.tensor(sample_size)).item()
    frames = torch.zeros((1, num_frames), dtype=torch.bool)
    kept = generator(frames)[0].bitwise_not().float()
    return kept.unsqueeze(-1).expand(dim, -1, 3)


# Figure 6: the same strip of frames drawn twice. x is a uniform light-gray image; x * mask
# zeroes (draws in black) the frames where the generator's mask was True, presumably the
# masked frames.
mask = get_mask(216000)
color = torch.tensor([225, 225, 225], dtype=torch.float32) / 255.0  # light gray, RGB in [0, 1]
x = torch.ones_like(mask) * color.view(1, 1, 3)

plt.figure(figsize=(10, 2))
plt.imshow(x)
plt.axis("off")
plt.savefig(figures / "x.pdf", bbox_inches="tight", pad_inches=0)
plt.show()
plt.figure(figsize=(10, 2))
plt.imshow(x * mask)
plt.axis("off")
plt.savefig(figures / "x-tilde.pdf", bbox_inches="tight", pad_inches=0)
plt.show()

## B.1 - Discriminability of continuous embeddings

Figure 7: ABX and MAP (in %, chance level 50% for ABX) by layer for SpidR, DinoSR and HuBERT.

In [None]:
# Figure 7 data: ABX error rate and MAP of the continuous embeddings, per model and layer.
# CSV model key -> display name; rows for other models are dropped.
model_mapping = {"hubert": "HuBERT", "spidr": "SpidR", "dinosr": "DinoSR", "dinosr_ours": "DinoSR (ours)"}

# Mean score per (model, layer) over all remaining rows of each CSV. The added "metric"
# column does not survive the aggregation (only model, layer and score are kept).
abx = (
    pl.read_csv(results / "abx-continuous.csv")
    .filter(pl.col("model").is_in(model_mapping))
    .with_columns(pl.lit("abx").alias("metric"), pl.col("model").replace_strict(model_mapping))
    .group_by(["model", "layer"])
    .agg(pl.col("score").mean())
)
speech_map = (
    pl.read_csv(results / "map-continuous.csv")
    .filter(pl.col("model").is_in(model_mapping))
    .with_columns(pl.lit("map").alias("metric"), pl.col("model").replace_strict(model_mapping))
    .group_by(["model", "layer"])
    .agg(pl.col("score").mean())
)

# Line style per display name; "DinoSR (ours)" is a faded version of the DinoSR style.
styles = {
    "SpidR": {"color": "C0"},
    "HuBERT": {"color": "C2", "linestyle": "-."},
    "DinoSR": {"color": "C3", "linestyle": "--"},
    "DinoSR (ours)": {"color": "C3", "linestyle": "--", "alpha": 0.3, "markeredgecolor": "none"},
}
common = {"marker": "o", "markersize": 4, "linewidth": 1}
# Plot order; earlier entries are drawn on top.
models = ["SpidR", "DinoSR", "DinoSR (ours)", "HuBERT"]

# "mpoli" is not a style bundled with matplotlib: `plt.style.use` raises OSError unless it is
# installed locally. Fall back to the repository's paper style so the notebook runs anywhere.
try:
    plt.style.use("mpoli")
except OSError:
    plt.style.use(assets / "paper.mplstyle")
# Figure 7: continuous ABX error rate (left) and MAP (right) by layer, one line per model.
fig, ax = plt.subplots(figsize=(TEXTWIDTH, 2), ncols=2, sharex=True, sharey=False)
fig.set_layout_engine("tight", pad=1, w_pad=1, rect=(0, 0.1, 1, 1))
for i, model in enumerate(models):
    abx_df = abx.filter(pl.col("model") == model).sort("layer")
    map_df = speech_map.filter(pl.col("model") == model).sort("layer")
    # Earlier models in `models` get a higher zorder and are drawn on top.
    ax[0].plot(abx_df["layer"], abx_df["score"], label=model, zorder=len(models) + 1 - i, **common, **styles[model])
    ax[1].plot(map_df["layer"], map_df["score"], label=model, zorder=len(models) + 1 - i, **common, **styles[model])
ax[0].grid(linestyle="dotted", alpha=0.4)
ax[1].grid(linestyle="dotted", alpha=0.4)
ax[0].set_xticks(list(range(1, 13)))
ax[0].set_yticks([4, 6, 8, 10, 12, 14])
ax[0].set_ylim(4 * 0.9, 14 * 1.1)
ax[0].set_xlabel("Layer")
ax[1].set_xlabel("Layer")
ax[0].set_ylabel(r"ABX error rate (\%)")
ax[1].set_ylabel(r"MAP (\%)")
# One legend centered below both panels, excluded from the tight layout.
leg = ax[0].legend(bbox_to_anchor=(1.1, -0.38), loc="upper center", ncols=4, fontsize=9)
leg.set_in_layout(False)
plt.savefig(figures / "continuous-by-layer.pdf")
plt.show()

## B.2 - Embeddings visualization

<div class="alert alert-warning">
To reproduce the t-SNE figures from the paper, first extract the features from SpidR layer 6 with:

```python extract_features.py /path/to/manifests/dev-clean.csv ./features```
where "dev-clean.csv" is your manifest file for LibriSpeech dev-clean.

Then execute the following cells after changing the variable `librispeech` to the actual path of your LibriSpeech copy.
</div>

Figure 8: t-SNE visualization of phone embeddings from SpidR layer 6 on LibriSpeech dev-clean. Embeddings are colored by phone class and by speaker gender.

In [None]:
def build_data(dataset: Dataset, *, n: int, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Sample `n` occurrences of every (phone, speaker) pair and mean-pool their features.

    The dataset labels are grouped by ``#phone`` and ``speaker``. Groups with fewer than
    `n` rows are skipped. From each remaining group, `n` rows are sampled with `seed`, and
    the features of each sampled row (``dataset.accessor[row_index]``) are averaged over
    dim 0.

    Returns:
        ``(embeddings, phones, speakers)`` as numpy arrays, aligned row by row.
    """
    groups = dataset.labels.with_row_index().partition_by("#phone", "speaker", as_dict=True)
    embeddings, phones, speakers = [], [], []
    for (phone, speaker), group in groups.items():
        if len(group) < n:
            continue
        sample = group.sample(seed=seed, n=n)
        phones.extend([phone] * len(sample))
        speakers.extend([speaker] * len(sample))
        for row_index in sample["index"]:
            embeddings.append(dataset.accessor[row_index].mean(dim=0).numpy())
    return np.array(embeddings), np.array(phones), np.array(speakers)


def read_gender(speaker_file: str, subset: str) -> dict[str, str]:
    data = {}
    with Path(speaker_file).open() as f:
        lines = f.readlines()[12:]
    for line in lines:
        x, y, z, *_ = line.split(" | ")
        if z.strip() == subset:
            data[x.strip()] = y.strip()
    return data


features, librispeech = Path("./features"), Path("../../LibriSpeech")  # To adapt to your setup

# Sonority class -> phones (JSON), inverted into phone -> class; ARPAbet -> TIPA symbols.
sonority_to_phones = json.loads((assets / "sonority_to_arpabet.json").read_text())
phone_to_sonority = {phone: sonority for sonority, phones in sonority_to_phones.items() for phone in phones}
arpabet_to_tipa = json.loads((assets / "arpabet_to_tipa.json").read_text())
palette = mpl_palette("magma", len(sonority_to_phones))  # one color per phone class

# NOTE(review): 50 is presumably the feature frame rate in Hz — TODO confirm against fastabx.
dataset = Dataset.from_item(assets / "phoneme-dev-clean.item", features, 50)
# 10 occurrences per (phone, speaker) pair (pairs with fewer are dropped), mean-pooled.
embeddings, phones, speakers = build_data(dataset, n=10, seed=0)
tsne = TSNE(n_components=2, random_state=0, n_jobs=-1).fit_transform(embeddings)
# The LibriSpeech distribution names the speaker list "SPEAKERS.TXT" (upper-case extension);
# "SPEAKERS.txt" is not found on case-sensitive file systems.
df = pl.DataFrame({"x": tsne[:, 0], "y": tsne[:, 1], "phone": phones, "speaker": speakers}).with_columns(
    pl.col("phone").replace_strict(phone_to_sonority).alias("sonority"),
    pl.col("speaker").replace_strict(read_gender(librispeech / "SPEAKERS.TXT", "dev-clean")).alias("gender"),
)
common = {"s": 5, "alpha": 0.4, "rasterized": True, "edgecolors": "white", "linewidth": 0.2}

# Figure 8: the same t-SNE embedding twice, colored by phone class (left) and by speaker
# gender (right), with one legend under each panel.
fig, ax = plt.subplots(figsize=(TEXTWIDTH, TEXTWIDTH / 2), ncols=2, sharex=True, sharey=True)
fig.set_layout_engine("tight", pad=0, rect=(0, 0.18, 1, 1))
for i, sonority in enumerate(sonority_to_phones):
    subdf = df.filter(pl.col("sonority") == sonority)
    ax[0].scatter(subdf["x"], subdf["y"], label=sonority.capitalize(), color=palette[i], **common)
ax[0].axis("off")
ax[0].set_aspect("equal")
legend = ax[0].legend(
    title="Phone class",
    fontsize=9,
    ncols=3,
    bbox_to_anchor=(0.5, 0.05),
    loc="upper center",
    markerscale=2,
)
# Show legend markers fully opaque, although the points themselves are drawn with alpha=0.4.
for lh in legend.legend_handles:
    lh.set_alpha(1)
legend.set_in_layout(False)
# No explicit color: the two genders take the first two colors of the axes color cycle.
for gender, name in zip(["F", "M"], ["Female", "Male"], strict=False):
    subdf = df.filter(pl.col("gender") == gender)
    ax[1].scatter(subdf["x"], subdf["y"], label=name, **common)
ax[1].axis("off")
ax[1].set_aspect("equal")
legend = ax[1].legend(
    title="Gender",
    fontsize=9,
    ncols=2,
    bbox_to_anchor=(0.5, 0.05),
    loc="upper center",
    markerscale=2,
)
for lh in legend.legend_handles:
    lh.set_alpha(1)
legend.set_in_layout(False)
# Scatter points are rasterized; dpi sets their resolution inside the PDF.
plt.savefig(figures / "tsne.pdf", dpi=200)
plt.show()

Figure 9: t-SNE visualization of phone embeddings from SpidR layer 6 on LibriSpeech dev-clean, colored by individual phones within each phone class. Embeddings from other classes are shown in gray.

In [None]:
# Figure 9: one panel per phone class (2x3 grid). The phones of that class are colored
# individually and labeled with their TIPA symbol; all other phones are drawn faint gray.
fig, ax = plt.subplots(figsize=(TEXTWIDTH, TEXTWIDTH / 2), nrows=2, ncols=3)
fig.set_layout_engine("tight", pad=0.05, h_pad=0, w_pad=4, rect=(0, 0, 1, 1))
for k, sonority in enumerate(sonority_to_phones):
    i, j = divmod(k, 3)  # panel row / column
    class_phones = sonority_to_phones[sonority]
    in_class = set(class_phones)
    # Vowels use consecutive tab20 colors; the other classes use every other color
    # (one shade per tab20 pair).
    mult = 1 if sonority == "vowel" else 2
    # Local name: do not overwrite the module-level `palette` used by Figure 8.
    class_palette = mpl_palette("tab20", len(class_phones) * mult)
    colors = {phone: class_palette[n * mult] for n, phone in enumerate(class_phones)} | {
        phone: "gray" for phone in phone_to_sonority if phone not in in_class
    }
    for phone in phone_to_sonority:
        subdf = df.filter(pl.col("phone") == phone)
        highlighted = phone in in_class
        ax[i, j].scatter(
            subdf["x"],
            subdf["y"],
            s=1,
            label=f"[{arpabet_to_tipa[phone]}]" if highlighted else None,
            color=colors[phone],
            alpha=0.8 if highlighted else 0.02,
            edgecolors="none",
            rasterized=True,
            zorder=3 if highlighted else 1,
        )
    # Legend placed just left of the panel (anchored by its upper-right corner).
    legend = ax[i, j].legend(
        loc="upper right",
        bbox_to_anchor=(0.06, 1),
        ncols=2 if sonority == "vowel" else 1,
        columnspacing=0.1,
        markerscale=2.5,
        fontsize=5,
        handletextpad=0.05,
    )
    for lh in legend.legend_handles:
        lh.set_alpha(1)
    legend.set_in_layout(False)
    ax[i, j].set_title(sonority.capitalize(), fontsize=9)
    ax[i, j].set_axis_off()
    ax[i, j].set_aspect("equal")
plt.savefig(figures / "tsne-sonority.pdf", dpi=200)
plt.show()

Figure 10: P(phone | code) visualization for SpidR layer 6 using either codebook predictions (left) or K-means quantization (right), on LibriSpeech dev-clean and dev-other.

In [None]:
# Figure 10: P(phone | code) heatmaps for SpidR layer 6 units, from the codebook predictions
# and from K-means. The phone file does not depend on the method, so it is read only once.
phonemes = pl.read_ndjson(assets / "phoneme-dev-clean-and-other.jsonl.gz")
for method in ["codebooks", "kmeans"]:
    df = phonemes.join(
        pl.read_ndjson(results / f"units-{method}-spidr-l6-dev-clean-and-other.jsonl.gz"), on="name"
    )
    # Assumes the joined frame has exactly three columns (name, phones, codes) — TODO confirm
    # the column order of the two files.
    data = {name: (phones, codes) for name, phones, codes in df.iter_rows()}
    proba, phone_order, _ = proba_phone_code(data, num_units=256, num_phones=40, only_active=True)

    fig, ax = plt.subplots(figsize=(3, 3))
    fig.set_layout_engine("tight", pad=0)
    ax.imshow(proba, cmap="Blues", aspect="auto", interpolation="none")
    ax.spines[["top", "right", "left", "bottom"]].set_visible(False)
    # One labeled row per phone; x ticks and labels (codes) are hidden.
    ax.set_yticks(range(len(phone_order)))
    ax.set_yticklabels(phone_order, fontsize=5)
    ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
    plt.savefig(figures / f"proba-phone-code-{method}.pdf", bbox_inches="tight")
    plt.show()

## B.3 - Layer-wise analysis

Figure 11: ABX (in %, chance level 50%) and PNMI by layer on discrete units from SpidR using codebook predictions or K-means, and from HuBERT using K-means, with V = 256 units. ABX scores are averaged across subsets and speaker conditions, and PNMI is computed on LibriSpeech dev-clean and dev-other.

In [None]:
# Figure 11 setup: line style per display name, shared marker settings, and
# CSV model key -> display name (these CSVs use "spidr" / "hubert" for the K-means units).
styles = {
    "SpidR (Codebooks)": {"color": "C0"},
    "SpidR (K-means)": {"color": "C1", "linestyle": "--"},
    "HuBERT (K-means)": {"color": "C2", "linestyle": "-."},
}
common = {"marker": "o", "markersize": 4, "linewidth": 1}
models = {"spidr_codebooks": "SpidR (Codebooks)", "spidr": "SpidR (K-means)", "hubert": "HuBERT (K-means)"}

pnmi = pl.read_csv(results / "units-quality.csv")  # long format: model, layer, metric, score
abx = pl.read_csv(results / "abx-discrete.csv")

# Figure 11: discrete-unit ABX error rate (left) and PNMI (right) by layer, one line per model.
fig, ax = plt.subplots(figsize=(TEXTWIDTH, 2), ncols=2, sharex=True, sharey=False)
fig.set_layout_engine("tight", h_pad=1, w_pad=1, rect=(0, 0.09, 1, 1))
for i, (name, model) in enumerate(models.items()):
    # Average ABX over the rows of each layer (e.g. subsets / speaker conditions), THEN sort:
    # polars' group_by does not preserve row order by default, so sorting before it leaves
    # the layers in arbitrary order and the plotted line zig-zags.
    abx_df = abx.filter(pl.col("model") == name).group_by("layer").agg(pl.col("score").mean()).sort("layer")
    pnmi_df = pnmi.filter(pl.col("model") == name, pl.col("metric") == "pnmi").sort("layer")
    # Earlier models get a higher zorder and are drawn on top.
    ax[0].plot(abx_df["layer"], abx_df["score"], label=model, zorder=len(models) + 1 - i, **common, **styles[model])
    ax[1].plot(pnmi_df["layer"], pnmi_df["score"], label=model, zorder=len(models) + 1 - i, **common, **styles[model])
ax[0].grid(linestyle="dotted", alpha=0.4)
ax[1].grid(linestyle="dotted", alpha=0.4)
ax[0].set_xticks(list(range(5, 13)))
ax[0].set_yticks([4, 6, 8, 10, 12, 14])
ax[0].set_ylim(4 * 0.9, 14 * 1.1)
ax[0].set_xlabel("Layer")
ax[1].set_xlabel("Layer")
ax[0].set_ylabel(r"ABX error rate (\%)")
ax[1].set_ylabel(r"PNMI")
# One legend centered below both panels, excluded from the tight layout.
leg = ax[1].legend(bbox_to_anchor=(-0.1, -0.38), loc="upper center", ncols=4, fontsize=9)
leg.set_in_layout(False)
plt.savefig(figures / "discrete-by-layer.pdf")
plt.show()

Figure 12: Zero-shot spoken language modeling from each layer of HuBERT and SpidR (in %, chance level 50%), with units from codebook predictions or from K-means quantization, with V = 256 units.

In [None]:
# Figure 12 setup: spoken-lm.csv model key -> (display name, line style).
model_mapping = {
    "spidr_codebooks": ("SpidR (Codebooks)", {"color": "C0"}),
    "spidr_kmeans": ("SpidR (K-means)", {"color": "C1", "linestyle": "--"}),
    "hubert_kmeans": ("HuBERT (K-means)", {"color": "C2", "linestyle": "-."}),
}
# Metric key -> axis label; only swuggy_all, sblimp and tsc are plotted in this figure.
metric_mapping = {
    "swuggy_all": "sWUGGY all",
    "swuggy_invocab": "sWUGGY in-vocab",
    "sblimp": "sBLIMP",
    "tsc": "tSC",
    "ssc": "sSC",
}
hours = 6000  # keep only the language models trained on this many hours

# Figure 12: spoken-LM score by encoder layer for each unit type, one panel per metric.
slm = pl.read_csv(results / "spoken-lm.csv").filter(pl.col("model").is_in(model_mapping), pl.col("hours") == hours)
fig, ax = plt.subplots(figsize=(TEXTWIDTH, 2), ncols=3, sharex=True, sharey=False)
fig.set_layout_engine("tight", pad=0.5, w_pad=1, rect=(0, 0.2, 1, 1))
for j, metric in enumerate(["swuggy_all", "sblimp", "tsc"]):
    for i, (model, (name, style)) in enumerate(model_mapping.items()):
        metric_df = slm.filter(pl.col("model") == model, pl.col("metric") == metric).sort("layer")
        ax[j].plot(
            metric_df["layer"],
            metric_df["score"],
            label=name,
            # Earlier models are drawn on top. The count comes from this cell's own mapping,
            # not from a `styles` dict left over from another cell, so the result does not
            # depend on the order in which cells were run.
            zorder=len(model_mapping) + 1 - i,
            marker="o",
            markersize=4,
            linewidth=1,
            **style,
        )
    ax[j].set_xticks(range(5, 13))
    ax[j].set_xticklabels(range(5, 13), fontsize=7)
    ax[j].tick_params(axis="y", labelsize=7)
    ax[j].grid(linestyle="dotted", alpha=0.4)
    ax[j].set_xlabel("Layer", fontsize=9)
    ax[j].set_ylabel(rf"{metric_mapping[metric]} (\%)", fontsize=9)
# One legend centered below the middle panel, excluded from the tight layout.
leg = ax[1].legend(bbox_to_anchor=(0.5, -0.38), loc="upper center", ncols=4, fontsize=9)
leg.set_in_layout(False)
plt.savefig(figures / "slm-by-layer.pdf")
plt.show()

Figure 13: Spoken language modeling against discriminability of the continuous representations. Dots are labeled by intermediate layer index. ABX for SpidR (Codebooks) is computed over codebook predictions.

Figure 14: Spoken language modeling against phonetic evaluation of the discrete units, with V = 256 units. Dots are labeled by intermediate layer index.

In [None]:
# Figures 13-14 setup: metric key -> display name, used as column names after pivoting.
metrics_mapping = {
    "swuggy_all": "sWUGGY all",
    "swuggy_invocab": "sWUGGY in-vocab",
    "sblimp": "sBLIMP",
    "tsc": "tSC",
    "ssc": "sSC",
    "codebook_perplexity": "Codebook perplexity",
    "phone_purity": "Phn. purity",
    "cluster_purity": "Clus. purity",
    "active_codewords": "Active units",
    "pnmi": "PNMI",
    "abx": "ABX",
    "map": "MAP",
    "abx_discrete": "ABX discrete",
}
# Model key -> display name. Two keys map to each K-means name because the source CSVs
# name them differently ("spidr"/"hubert" vs. "spidr_kmeans"/"hubert_kmeans" in spoken-lm.csv).
model_mapping = {
    "spidr_codebooks": "SpidR (Codebooks)",
    "spidr": "SpidR (K-means)",
    "spidr_kmeans": "SpidR (K-means)",
    "hubert": "HuBERT (K-means)",
    "hubert_kmeans": "HuBERT (K-means)",
}
hours = 6000  # keep only the language models trained on this many hours


def average_scores(df: pl.DataFrame) -> pl.DataFrame:
    """Average ``score`` over all rows sharing the same (model, layer, metric).

    Returns:
        A frame with columns ``model, layer, metric, score``, sorted by model then layer.
    """
    keys = ["model", "layer", "metric"]
    averaged = df.group_by(keys).agg(pl.col("score").mean())
    return averaged.sort(["model", "layer"]).select([*keys, "score"])


# Figures 13-14 data: gather spoken-LM scores and unit/embedding quality metrics in a single
# long table (model, layer, metric, score).
slm = pl.read_csv(results / "spoken-lm.csv").filter(pl.col("hours") == hours).select(pl.exclude("hours"))
abx = pl.read_csv(results / "abx-continuous.csv").with_columns(pl.lit("abx").alias("metric"))
speech_map = pl.read_csv(results / "map-continuous.csv").with_columns(pl.lit("map").alias("metric"))
discrete = pl.read_csv(results / "units-quality.csv")
abx_discrete = pl.read_csv(results / "abx-discrete.csv").with_columns(pl.lit("abx_discrete").alias("metric"))
df = (
    pl.concat([slm, average_scores(abx), average_scores(speech_map), discrete, average_scores(abx_discrete)])
    .filter(pl.col("model").is_in(model_mapping))
    .with_columns(pl.col("model").replace_strict(model_mapping), pl.col("metric").replace_strict(metrics_mapping))
    # One row per (model, layer), one column per metric; rows missing any metric are dropped.
    .pivot(on="metric", values="score", index=["model", "layer"])
    .drop_nulls()
    # Back to long format on the spoken-LM metrics only: "metric" / "slm" columns, with the
    # x-axis quality metrics kept as regular columns.
    .unpivot(
        on=["sBLIMP", "sWUGGY all", "sWUGGY in-vocab", "tSC", "sSC"],
        variable_name="metric",
        value_name="slm",
        index=["model", "layer", "ABX", "MAP", "ABX discrete", "PNMI"],
    )
)

# Figures 13 and 14: one figure per pair of x-axis metrics (continuous ABX / MAP, then
# discrete ABX / PNMI). Rows are the x-axis metrics, columns the spoken-LM metrics.
styles = ["o", "s", "D"]  # marker per model, in the order of the model list below
text_style = {"size": 3, "ha": "center", "va": "center", "zorder": 4, "math_fontfamily": "dejavuserif"}
for metrics in [["ABX", "MAP"], ["ABX discrete", "PNMI"]]:
    fig, ax = plt.subplots(figsize=(TEXTWIDTH, 4), nrows=2, ncols=3, sharey="col", sharex="row")
    fig.set_layout_engine("tight", pad=0.1, w_pad=0, h_pad=2, rect=(0, 0.1, 1, 1))
    for i, metric_y in enumerate(["sWUGGY all", "sBLIMP", "tSC"]):
        selection = df.filter(pl.col("metric") == metric_y)
        for j, metric_x in enumerate(metrics):
            axis = ax[j, i]
            # Pearson r and least-squares line over all models and layers of this panel.
            coeff = pearsonr(selection[metric_x], selection["slm"])[0]
            line = linregress(selection[metric_x], selection["slm"])
            x = np.array([selection[metric_x].min(), selection[metric_x].max()])
            axis.plot(
                x,
                line.intercept + line.slope * x,
                color="gray",
                linewidth=0.5,
                linestyle="dotted",
                zorder=0,
            )
            axis.text(
                0.7,
                0.6 if metric_x.startswith("ABX") else 0.2,
                rf"$\mathbf{{r = {coeff:.2f}}}$",
                transform=axis.transAxes,
                fontsize=6,
                ha="center",
                va="center",
            )
            for k, model in enumerate(["SpidR (Codebooks)", "SpidR (K-means)", "HuBERT (K-means)"]):
                subdf = selection.filter(pl.col("model") == model)
                axis.scatter(
                    subdf[metric_x],
                    subdf["slm"],
                    label=model,
                    s=25,
                    zorder=3,
                    alpha=0.7,
                    edgecolor=f"C{k}",
                    facecolor="none",
                    linewidth=1,
                    marker=styles[k],
                )
                # Write each point's layer index at its center, in the model's color.
                for idx, layer in enumerate(subdf["layer"].cast(str).to_list()):
                    axis.annotate(
                        r"$\mathbf{" + layer + "}$",
                        (subdf[metric_x][idx], subdf["slm"][idx]),
                        color=f"C{k}",
                        **text_style,
                    )
            # Labels, grid and tick sizes only depend on the panel: set them once, not per model.
            axis.set_xlabel(metric_x + (r" (\%)" if metric_x != "PNMI" else ""), fontsize=9)
            axis.set_ylabel(metric_y + r" (\%)", fontsize=9)
            axis.grid(linestyle="dotted", alpha=0.4)
            axis.tick_params(axis="both", labelsize=7)
    # One legend centered below the bottom row, excluded from the tight layout.
    leg = ax[1, 0].legend(bbox_to_anchor=(1.7, -0.32), loc="upper center", ncols=3)
    leg.set_in_layout(False)
    plt.savefig(figures / f"slm-{'-'.join(m.replace(' ', '-').lower() for m in metrics)}.pdf")
    plt.show()