# Explainer: ZeroSyl's training-free boundary detection

This notebook walks through the steps of our method in an accessible way. The same logic is also implemented in the classes in `zerosyl/zerosyl.py`.

Install ZeroSyl and download sample data

In [None]:
!pip install zerosyl
!wget https://storage.googleapis.com/zerospeech-checkpoints/5895-34629-0010.flac
!wget https://storage.googleapis.com/zerospeech-checkpoints/5895-34629-0010.TextGrid

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
from IPython.display import Audio, display
from scipy.cluster.hierarchy import cut_tree, linkage
from scipy.signal import find_peaks

from zerosyl import WavLM

Load WavLM Large

In [None]:
# Build the WavLM model through the zerosyl helper.
# NOTE(review): presumably this downloads pretrained weights on first use — confirm.
wavlm = WavLM.from_remote()

Load a waveform

In [None]:
# Load the sample utterance: `wav` is the waveform tensor, `sr` its sample rate in Hz.
wav, sr = torchaudio.load("5895-34629-0010.flac")
# The rest of this notebook hard-codes 16 kHz audio (320-sample hop = 20 ms, playback
# at 16000 Hz), so fail loudly here instead of silently producing misaligned results.
assert sr == 16000, f"expected 16 kHz audio, got {sr} Hz"

# Play back at the rate reported by the file rather than a hard-coded constant.
display(Audio(wav, rate=sr))

Preprocess the waveform

In [None]:
# Loudness normalization: layer_norm over the full tensor shape rescales the whole
# waveform to zero mean and unit variance.
wav = torch.nn.functional.layer_norm(wav, normalized_shape=wav.shape)
# Symmetric zero-padding of (400 - 320) // 2 = 40 samples per side, so that the output
# features are perfectly aligned with 20 ms (320-sample) intervals.
# NOTE: this cell rebinds `wav` from itself, so re-running it pads again; re-run the
# loading cell first if you need to execute it a second time.
side_pad = (400 - 320) // 2
wav = torch.nn.functional.pad(wav, pad=(side_pad, side_pad))

Extract boundary features from layer 13

In [None]:
# Run WavLM in inference mode (no autograd tracking) and keep the layer-13 output.
with torch.inference_mode():
    layer13_out, _ = wavlm.extract_features(wav, output_layer=13)

# Drop the leading axis (size 1 for a single utterance) and move the features to NumPy
# for the signal processing that follows.
boundary_features = layer13_out.squeeze(0).cpu().numpy()

print(boundary_features.shape)

Perform boundary detection

In [None]:
# Per-frame L2 norm of the feature vectors -> a 1-D signal over time.
norms = np.linalg.norm(boundary_features, axis=-1)

# Z-score the norm signal so that the prominence threshold below is scale-independent.
norms = (norms - norms.mean()) / norms.std()

# Smooth with a 3-tap moving average. Edge-replicating padding followed by a "valid"
# convolution keeps the smoothed signal the same length as the input.
kernel = np.full(3, 1 / 3)
pad_len = len(kernel) // 2
norms_padded = np.pad(norms, pad_len, mode="edge")
norms_smooth = np.convolve(norms_padded, kernel, mode="valid")

# Prominence-based peak detection on the smoothed signal.
peaks, _ = find_peaks(norms_smooth, prominence=0.45)

# Boundaries (in frame indices): utterance start, every detected peak, utterance end.
boundaries = [0, *peaks.tolist(), len(boundary_features)]

Visualize the boundaries

In [None]:
# Mel spectrogram with a 400-sample window and a 320-sample hop, i.e. one spectrogram
# column per 320 input samples, followed by conversion to decibels (80 dB dynamic range).
mel_transform = torchaudio.transforms.MelSpectrogram(
    n_fft=1024, win_length=400, hop_length=320
)
to_db = torchaudio.transforms.AmplitudeToDB(top_db=80)
melspec = to_db(mel_transform(wav.squeeze()))

In [None]:
# Shared x-range (in spectrogram columns) so both panels line up.
xmin, xmax = 0, melspec.size(1)

fig, (ax_spec, ax_norm) = plt.subplots(2, 1, figsize=(12, 6))

# Top panel: mel spectrogram with the predicted boundaries drawn in white.
ax_spec.imshow(melspec, aspect="auto", origin="lower")
ax_spec.axis("off")
for b in boundaries:
    ax_spec.axvline(b, c="w")
ax_spec.set_xlim(xmin, xmax)

# Bottom panel: raw and smoothed norm signals with faint boundary markers.
for b in boundaries:
    ax_norm.axvline(b, c="gray", alpha=0.2)

ax_norm.plot(norms, label="norms")
ax_norm.plot(norms_smooth, label="smoothed norms")
ax_norm.set_xlim(xmin, xmax)
ax_norm.legend()

Listen to the segments (with short silences in between)

In [None]:
# 8000 zero samples (0.5 s at the 16 kHz playback rate) inserted before, between and
# after the segments so that each one can be heard in isolation.
gap = np.zeros(8000)

listen_chunks = [gap]
for start_frame, end_frame in zip(boundaries[:-1], boundaries[1:]):
    # Frame indices -> sample indices (320 samples per frame); convert the torch slice
    # to NumPy explicitly so that all chunks have the same array type.
    segment = wav[0, start_frame * 320 : end_frame * 320].cpu().numpy()
    listen_chunks.append(segment)
    listen_chunks.append(gap)

# `np.concatenate` is available in every NumPy release; the `np.concat` alias only
# exists from NumPy 2.0 onwards.
listen_samples = np.concatenate(listen_chunks, axis=0)
display(Audio(listen_samples, rate=16000))

Extract semantic features from layer 22

In [None]:
# Run WavLM in inference mode (no autograd tracking) and keep the layer-22 output.
with torch.inference_mode():
    semantic_features, _ = wavlm.extract_features(wav, output_layer=22)

# Drop the leading axis (size 1 for a single utterance); the features stay a torch tensor.
semantic_features = semantic_features.squeeze(0)

# Print the shape of the features. (`print(x).shape` would fail: print() returns None.)
print(semantic_features.shape)

Meanpool semantic features within the predicted boundaries

In [None]:
# Start and end frame of every segment: consecutive pairs of predicted boundaries.
starts = torch.tensor(boundaries[:-1], device=wav.device)
ends = torch.tensor(boundaries[1:], device=wav.device)

# One embedding per segment: the mean over the frames in [start, end).
embeddings = torch.stack(
    [
        semantic_features[seg_start:seg_end].mean(dim=0)
        for seg_start, seg_end in zip(starts, ends)
    ]
)

In [None]:
# Shape of the stacked segment embeddings: one row per predicted segment.
embeddings.shape

K-means discretization

In [None]:
# Load the K-Means centroids. `map_location="cpu"` guarantees the tensor lands on the CPU
# regardless of the device it was saved from; the code below relies on that (distance
# computation against `embeddings.cpu()`, and later `centroids.numpy()`).
centroids = torch.hub.load_state_dict_from_url(
    "https://storage.googleapis.com/zerospeech-checkpoints/zerosyl-v040-centroids-k-10000.pt",
    map_location="cpu",
)
# Find the ID of the nearest centroid (Euclidean distance) for every segment embedding
ids = torch.cdist(embeddings.cpu(), centroids).argmin(1)

In [None]:
# Mel spectrogram annotated with the centroid ID of every segment.
plt.figure(figsize=(12, 3))
plt.subplot(1, 1, 1)
plt.imshow(melspec, aspect="auto", origin="lower")
plt.axis("off")
# The loop variable is deliberately not called `id`: top-level loop variables stay in
# the kernel namespace and would shadow the builtin `id()` for every later cell.
for b1, b2, unit_id in zip(boundaries[:-1], boundaries[1:], ids):
    plt.axvline(b1, c="w")
    plt.axvline(b2, c="w")
    # Label each segment at its horizontal centre with its centroid ID.
    plt.text(
        (b1 + b2) / 2,
        64,
        str(unit_id.item()),
        rotation=90,
        c="w",
        ha="center",
        va="center",
    )

Silences could be fragmented (such as at the start of utterance 5895-34629-0010).

After clustering, multiple centroids correspond to silences.
We find that language modeling performance improves when these entries are collapsed to a single vocabulary item.
We do this in an unsupervised manner by performing hierarchical clustering on the centroids.
From informal inspection, we know that the two main branches in agglomerative hierarchical clustering correspond to silences and non-silences, respectively.
We pick the smaller branch, which represents silences, and map these items to one vocabulary item.
This reduces the vocabulary size from 10,000 to 9,116.

In [None]:
# Ward-linkage agglomerative clustering over the centroid vectors.
linkage_matrix = linkage(
    centroids.numpy(), method="ward", metric="euclidean", optimal_ordering=False
)
# Cut the dendrogram into its 2 main branches: every centroid gets a branch label (0 or 1).
branch = cut_tree(linkage_matrix, n_clusters=2)[:, 0]
# The smaller branch should be silences: flip the labels if branch 1 is the larger one,
# so that label 1 always marks the smaller branch.
n_ones = branch.sum()
n_zeros = (1 - branch).sum()
silences = 1 - branch if n_ones > n_zeros else branch
# Boolean torch mask over centroids: True = silence centroid.
silences = torch.from_numpy(silences).bool()

Now we can identify segments that are silences.

In [None]:
# Index the per-centroid silence mask with each segment's centroid ID:
# True marks a segment whose nearest centroid is a silence centroid.
print(silences[ids])

Create a mapping that merges all the silence (=True) entries into a single token: the non-silence entries are placed at the start of the codebook, and the single silence token is placed at the end.

In [None]:
# Permutation that lists the non-silence centroids first and the silence centroids last,
# i.e. silences[order] looks like [0,0,....,0,0,1,1,...,1,1].
order = torch.argsort(silences)
# Position of the first silence in the sorted mask (argmax returns the first maximum);
# this becomes the ID of the single merged silence token.
SIL = int(silences[order].long().argmax())
# The inverse permutation of `order` sends each old centroid ID to its new position;
# clamping at SIL then folds every silence position onto the one silence token.
mapping = torch.argsort(order).clamp(max=SIL)

print(f"The new (single) silence token has the ID: {SIL}")

The new sequence of ids is:

In [None]:
# Translate every segment's original centroid ID into the collapsed vocabulary.
ids_remapped = mapping[ids]

# Display the remapped ID sequence.
ids_remapped

Now we can collapse consecutive duplicates of the silence token, merging each run of adjacent silence segments into a single segment

In [None]:
# True at position i (for i >= 1) when segment i is a silence that directly follows
# another silence, i.e. a repeat that should be merged into the previous segment.
same_as_previous = ids_remapped[1:] == ids_remapped[:-1]
current_is_silence = ids_remapped[1:] == SIL
repeated_silence = torch.logical_and(same_as_previous, current_is_silence)

# A segment opens a merged segment unless it is a repeated silence.
not_repeated = torch.ones_like(ids_remapped, dtype=torch.bool)
not_repeated[1:] = ~repeated_silence

# A segment closes a merged segment unless the *next* segment is a repeated silence.
is_end = torch.ones_like(ids_remapped, dtype=torch.bool)
is_end[:-1] = ~repeated_silence

# Keep the first start, the last end and one ID for every run of merged silences.
starts_merged = starts[not_repeated]
ends_merged = ends[is_end]
ids_merged = ids_remapped[not_repeated]

In [None]:
# Mel spectrogram annotated with the merged segments and their remapped IDs.
plt.figure(figsize=(12, 3))
plt.subplot(1, 1, 1)
plt.imshow(melspec, aspect="auto", origin="lower")
plt.axis("off")
# The loop variable is deliberately not called `id`: top-level loop variables stay in
# the kernel namespace and would shadow the builtin `id()` for every later cell.
for start, end, unit_id in zip(starts_merged, ends_merged, ids_merged):
    plt.axvline(start, c="w")
    plt.axvline(end, c="w")
    # Label each merged segment at its horizontal centre with its ID.
    plt.text(
        (start + end) / 2,
        64,
        str(unit_id.item()),
        rotation=90,
        c="w",
        ha="center",
        va="center",
    )