In [None]:
import base64
import io
from pathlib import Path
from random import shuffle

import matplotlib.pyplot as plt
import numpy as np
import tgt
import torch
from IPython.display import HTML, Audio, display
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
from torchcodec.decoders import AudioDecoder

from zerosyl.model import ZeroSylBase, ZeroSylDiscrete

Helper functions

In [None]:
def cosine_sim_mat(features: torch.Tensor) -> np.ndarray:
    """Return the pairwise cosine-similarity matrix of a set of row vectors.

    Args:
        features: 2-D tensor of shape (n, d); each row is one feature vector
            (e.g. one frame or one segment embedding). May live on any device
            and may still be attached to an autograd graph.

    Returns:
        (n, n) NumPy array whose entry [i, j] is the cosine similarity between
        rows i and j.
    """
    # L2-normalise each row so that a plain dot product equals cosine similarity.
    unit_rows = torch.nn.functional.normalize(features, p=2, dim=1)
    similarity_matrix = unit_rows @ unit_rows.T
    # detach() so tensors that still require grad can be converted to NumPy.
    return similarity_matrix.detach().cpu().numpy()

def plot_sim_mat(
    similarity_matrix: np.ndarray,
    melspec: torch.Tensor,
    textgrid: tgt.TextGrid,
    frame_rate: float = 50.0,
) -> plt.Figure:
    """Plot a similarity matrix flanked by the spectrogram and syllable alignment.

    Layout (2x2 grid):
      - top-left: unused
      - top-right: mel spectrogram with time along x, syllable boundaries and labels
      - bottom-left: the same spectrogram with time running downwards along y
      - bottom-right: the similarity matrix, colour range fixed to [0, 1]

    Args:
        similarity_matrix: square matrix of pairwise similarities between frames.
        melspec: mel spectrogram, assumed (n_mels, n_frames) as produced in the
            data-loading cell.
        textgrid: alignment containing a "syllables" tier; interval times are
            converted to frame indices by multiplying with ``frame_rate``.
        frame_rate: frames per second of ``melspec`` / ``similarity_matrix``.
            Defaults to 50 (hop of 320 samples at 16 kHz).

    Returns:
        The created matplotlib Figure.
    """
    fontsize = 6
    end_frame = textgrid.end_time * frame_rate

    # Syllable boundaries in frames, plus a tick at each syllable's centre.
    syllables = list(textgrid.get_tier_by_name("syllables"))
    bounds = [(iv.start_time * frame_rate, iv.end_time * frame_rate) for iv in syllables]
    ticks = [(start + end) / 2 for start, end in bounds]
    labels = [iv.text for iv in syllables]

    fig, axes = plt.subplots(
        nrows=2,
        ncols=2,
        figsize=(8, 8),
        gridspec_kw={"height_ratios": [1, 8], "width_ratios": [1, 8]},
        constrained_layout=True,
    )
    ax_tl, ax_tr = axes[0]
    ax_bl, ax_br = axes[1]

    # --- Top left: empty corner ---
    ax_tl.axis("off")

    # --- Top right: spectrogram, time along x ---
    ax_tr.imshow(melspec.cpu().numpy(), aspect="auto", origin="lower")
    for start, end in bounds:
        ax_tr.axvline(start, color="white")
        ax_tr.axvline(end, color="white")
    ax_tr.set_xticks(ticks)
    ax_tr.set_xticklabels(labels, rotation=90, fontsize=fontsize)
    ax_tr.xaxis.set_ticks_position("top")
    ax_tr.xaxis.set_label_position("top")
    ax_tr.get_yaxis().set_visible(False)
    ax_tr.set_xlim(0, end_frame)

    # --- Bottom left: spectrogram rotated so time runs down the y axis ---
    ax_bl.imshow(melspec.cpu().T.flip(1).numpy(), aspect="auto", origin="upper")
    for start, end in bounds:
        ax_bl.axhline(start, color="white")
        ax_bl.axhline(end, color="white")
    ax_bl.set_yticks(ticks)
    ax_bl.set_yticklabels(labels, fontsize=fontsize)
    ax_bl.get_xaxis().set_visible(False)
    # Crop the time axis (y here) to the utterance. The limits are given
    # bottom-to-top as (end, 0) to keep the inverted orientation of origin="upper".
    ax_bl.set_ylim(end_frame, 0)

    # --- Bottom right: similarity matrix ---
    ax_br.imshow(similarity_matrix, aspect="equal", origin="upper", vmin=0, vmax=1)
    ax_br.axis("off")
    return fig

Load some data

In [None]:
waveforms_dir = Path("data/waveforms/LibriSpeech")
alignments_dir = Path("data/alignments/LibriSpeech")
STEM = "174-50561-0013"

# Prefer the full LibriSpeech copy when it is present *and* contains this utterance.
# Otherwise fall back to the sample stored in the repository, instead of crashing
# with StopIteration when the directories exist but lack the file.
wav_path = textgrid_path = None
if waveforms_dir.exists() and alignments_dir.exists():
    wav_path = next(waveforms_dir.rglob(f"{STEM}.flac"), None)
    textgrid_path = next(alignments_dir.rglob(f"{STEM}.TextGrid"), None)
if wav_path is None or textgrid_path is None:
    wav_path = Path("data/sample.flac")
    textgrid_path = Path("data/sample.TextGrid")

textgrid = tgt.read_textgrid(textgrid_path, include_empty_intervals=False)

decoder = AudioDecoder(wav_path, sample_rate=16000, num_channels=1)
audio = decoder.get_all_samples()
wav = audio.data.cuda()

# hop_length=320 samples at 16 kHz -> 50 frames per second, which is the
# "* 50" seconds-to-frames factor used by the plots in this notebook.
tMelSpectrogram = MelSpectrogram(
    sample_rate=16000, n_fft=1024, win_length=400, hop_length=320, n_mels=100
)
tAmplitudeToDB = AmplitudeToDB(top_db=80)

# [0] selects the single (mono) channel, leaving the spectrogram itself.
melspec = tAmplitudeToDB(tMelSpectrogram(audio.data))[0]

display(Audio(str(wav_path)))

transcription = " ".join(interval.text for interval in textgrid.get_tier_by_name("words"))
print(f"Transcription: {transcription}")

## 1. WavLM framewise features

We start by initializing ZeroSylBase from the official WavLM Large checkpoint.

In [None]:
# Build ZeroSylBase from the official WavLM Large weights, then move it to the GPU.
model = ZeroSylBase.from_pretrained_checkpoint("checkpoints/WavLM-Large.pt")
model = model.cuda()

ZeroSylBase extends WavLM, so we can extract framewise features just as we normally would.

In [None]:
with torch.inference_mode():
    features = (
        model.extract_features(wav, output_layer=None)[0]  # first item of the returned pair
        .squeeze(0)  # drop the leading size-1 dimension
        .cpu()
    )
print(features.shape)

When we visualize the cosine similarity of the WavLM features, we can see repeating sound patterns at repeated words.

In [None]:
# Frame-by-frame cosine similarity of the WavLM features, shown against the alignment.
features_sim_mat = cosine_sim_mat(features)
plot_sim_mat(similarity_matrix=features_sim_mat, melspec=melspec, textgrid=textgrid);

## 2. Meanpooled WavLM features within the detected boundaries

In `demo-detect-boundaries.ipynb`, we showed how to extract boundaries by performing prominence-based segmentation on layer 13 of WavLM.

With the ZeroSylBase class we can access these boundaries like this:

In [None]:
boundaries = model.boundaries(wav)
print(boundaries)

# Two stacked views of the spectrogram (x axis in frames, 50 per second):
# forced-alignment syllables on top, ZeroSyl-Base boundaries below.
fig, (ax_align, ax_pred) = plt.subplots(nrows=2, ncols=1, figsize=(10, 4), constrained_layout=True)

syllable_intervals = list(textgrid.get_tier_by_name("syllables"))
xticks = [(iv.start_time * 50 + iv.end_time * 50) / 2 for iv in syllable_intervals]
xtickslabels = [iv.text for iv in syllable_intervals]

ax_align.imshow(melspec, aspect="auto", origin="lower")
for iv in syllable_intervals:
    ax_align.axvline(iv.start_time * 50, color="white")
    ax_align.axvline(iv.end_time * 50, color="white")
ax_align.set_xticks(xticks)
ax_align.set_xticklabels(xtickslabels, rotation=90)
ax_align.get_yaxis().set_visible(False)
ax_align.xaxis.set_ticks_position("top")
ax_align.xaxis.set_label_position("top")
ax_align.set_title("Syllables from forced alignments")
ax_align.set_xlim(0, textgrid.end_time * 50)

# Boundaries are scaled by 50 like the TextGrid times (presumably seconds).
ax_pred.imshow(melspec, aspect="auto", origin="lower")
for t in boundaries:
    ax_pred.axvline(t * 50, color="white")
ax_pred.axis("off")
ax_pred.set_title("ZeroSyl-Base boundaries")
ax_pred.set_xlim(0, textgrid.end_time * 50)
plt.show()

We can also meanpool the output embeddings within the segments. This gives us a single continuous embedding for each syllable-like segment.

In [None]:
# Meanpool within each detected segment: one embedding per syllable-like unit,
# together with the start and end of every segment.
embeddings, starts, ends = model.segment(wav)
print(embeddings.shape, starts, ends, sep="\n")

When we visualize the similarity of the meanpooled embeddings, we can easily identify repeating syllables:
 - `AE` and `P-AH-L` in "apple"
 - `L-EY` in "lady"
 - `D-IY` in "shady" and "lady"


However, we also see high similarity between certain syllables that sound similar but are distinct:
 - `N-AW` and `B-AW`
 - `SH-EY` and `L-EY`

In [None]:
# framewise_meanpooled_embeddings is a convenient way to get the segment
# embeddings duplicated per frame, so they can be plotted like framewise features.
embeddings_50Hz = model.framewise_meanpooled_embeddings(wav)
embeddings_sim_mat = cosine_sim_mat(embeddings_50Hz)
fig_unboosted = plot_sim_mat(similarity_matrix=embeddings_sim_mat, melspec=melspec, textgrid=textgrid)

We then clustered these embeddings so that we can assign pseudo-labels (discrete tokens) to each syllable-like segment.

To obtain the discrete tokens, we load ZeroSylDiscrete that includes a K-means codebook.

The K-means model was trained with cosine distance on the normalized embeddings.

In [None]:
# ZeroSylDiscrete = WavLM weights + a K-means codebook (10000 centroids).
wavlm_checkpoint = "checkpoints/WavLM-Large.pt"
kmeans_checkpoint = "checkpoints/km10000-centroids-v020.pt"
discrete_model = ZeroSylDiscrete.from_pretrained_checkpoint(wavlm_checkpoint, kmeans_checkpoint)
discrete_model = discrete_model.cuda()
tokens, starts, ends = discrete_model.tokenize(wav)

When we look at the cluster IDs for each of the four segments corresponding to `L-EY`,
we see `[3051, 3051, 3051, 3051]`. Great! 🥳

But when we look at other repeating syllables like `D-IY`, we see `[1905, 3056, 1905, 1905, 7992]`. Less great. 🥲

In [None]:
fig, (ax_align, ax_tokens) = plt.subplots(nrows=2, ncols=1, figsize=(12, 6), constrained_layout=True)

# Top: syllables from the forced alignment (TextGrid seconds scaled to 50 frames/s).
syllable_intervals = list(textgrid.get_tier_by_name("syllables"))
ax_align.imshow(melspec, aspect="auto", origin="lower")
for iv in syllable_intervals:
    ax_align.axvline(iv.start_time * 50, color="white")
    ax_align.axvline(iv.end_time * 50, color="white")
ax_align.set_xticks([(iv.start_time * 50 + iv.end_time * 50) / 2 for iv in syllable_intervals])
ax_align.set_xticklabels([iv.text for iv in syllable_intervals], rotation=90)
ax_align.get_yaxis().set_visible(False)
ax_align.xaxis.set_ticks_position("top")
ax_align.xaxis.set_label_position("top")
ax_align.set_title("Syllables from forced alignments")
ax_align.set_xlim(0, textgrid.end_time * 50)

# Bottom: ZeroSyl-Discrete segments labelled with their cluster IDs.
# starts/ends are used directly as frame positions here (no * 50 scaling).
segments = [(str(tok.item()), s.item(), e.item()) for tok, s, e in zip(tokens, starts, ends)]
ax_tokens.imshow(melspec, aspect="auto", origin="lower")
for _label, x1, x2 in segments:
    ax_tokens.axvline(x1, color="white")
    ax_tokens.axvline(x2, color="white")
ax_tokens.set_xticks([(x1 + x2) / 2 for _label, x1, x2 in segments])
ax_tokens.set_xticklabels([label for label, _x1, _x2 in segments], rotation=90)
ax_tokens.get_yaxis().set_visible(False)
ax_tokens.set_title("ZeroSyl-Discrete boundaries and tokens")
ax_tokens.set_xlim(0, textgrid.end_time * 50)
plt.show()


Can we do better?

## 3. Boosted framewise features

We now turn to a method inspired by SyllableLM and Sylber. In their work, they continue training the SSL model to predict the continuous meanpooled representations. SyllableLM called this model Syl**Boost**.

So we train our own boosting model by continuing the pretraining of WavLM.
Our training objective is slightly different though: Instead of predicting the continuous representations, we predict the discrete cluster ID.

In [None]:
# Same ZeroSylBase class, but initialised from the boosted checkpoint.
boosted_model = ZeroSylBase.from_pretrained_checkpoint("checkpoints/zerosyl-boost-v020-step-5000.pt")
boosted_model = boosted_model.cuda()

Now when extracting features, the features in the final layer are much more piecewise-constant within the boundaries.

Unlike in Section 2 above, these are not yet meanpooled.

In [None]:
with torch.inference_mode():
    boosted_features, _ = boosted_model.extract_features(wav, output_layer=None)
    boosted_features = boosted_features.squeeze(0).cpu()
boosted_features_sim_mat = cosine_sim_mat(boosted_features)
# The trailing ";" suppresses the repr of the returned Figure. Without it, the inline
# backend shows the plot twice: once as the cell result, once when flushing open figures.
plot_sim_mat(boosted_features_sim_mat, melspec, textgrid);

## 4. Boosted features meanpooled within the detected boundaries

Now we can go ahead and meanpool the boosted features within the boundaries.

In [None]:
# Use distinct names so the discrete-token `starts`/`ends` from the ZeroSylDiscrete
# cell are not silently overwritten. Re-running the token plot after this cell
# would otherwise pair those tokens with the boosted model's segments.
boosted_embeddings, boosted_starts, boosted_ends = boosted_model.segment(wav)
# Convenient way to access the framewise (duplicated) embeddings.
boosted_embeddings_50Hz = boosted_model.framewise_meanpooled_embeddings(wav)
boosted_embeddings_sim_mat = cosine_sim_mat(boosted_embeddings_50Hz)
fig_boosted = plot_sim_mat(boosted_embeddings_sim_mat, melspec, textgrid)


This looks very similar to the figure in Section 3. So we will need to compare side-by-side.

## Compare unboosted (left) and boosted (right) embeddings

In [None]:


def fig_to_base64(fig):
    """Render ``fig`` as a tightly-cropped PNG and return it base64-encoded (str).

    The returned text can be placed directly after ``data:image/png;base64,``
    in an HTML ``<img>`` tag.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png', bbox_inches='tight')
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')

# Encode both figures so they can be embedded inline next to each other.
b64_fig1, b64_fig2 = (fig_to_base64(f) for f in (fig_unboosted, fig_boosted))

# Display side-by-side
html = f"""
<div style="display: flex; justify-content: center; gap: 10px;">
  <img src="data:image/png;base64,{b64_fig1}" style="max-width: 45%; height: auto;"/>
  <img src="data:image/png;base64,{b64_fig2}" style="max-width: 45%; height: auto;"/>
</div>
"""
display(HTML(html))


The differences are subtle, but boosting seems better:

1. The boosted embeddings (on the right) are generally less similar when they are not the same syllable.
    - (there is much more dark purple in the figure on the right)
2. On the left plot, the syllable `SH-EY` is similar to all the `D-IY` syllables that follow (turquoise shade)
    - on the right plot, these similarities are lower (towards a darker blue)
    - So, SYLLABLES THAT ARE NOT SIMILAR APPEAR LESS SIMILAR
3. On the left plot, the repeated syllables `AE` and `P-AH-L` in "apple" are a green color (quite high similarity)
    - on the right these are as yellow (similar) as we can get.
    - So, SYLLABLES THAT ARE SIMILAR APPEAR MORE SIMILAR

There are several more such examples in the plot.

But we do not yet know if boosting really helps. We will need to cluster and look at the purity and normalized mutual information metrics.