In [1]:
!pip install "numpy<2"



In [2]:
from itertools import groupby

import numpy as np
import torch

from dpdp import dpdp
from dpwfst import dpwfst

Let's load some data.

`features` is a $T \times D$ tensor. It contains the HuBERT features for a single utterance.

`codebook_centroids` is a $K \times D$ tensor. It contains the codebook centroids from a $K$-means clustering of the features.

`codebook_labels` is a NumPy array of length $K$. It contains the most likely phone label for each codebook entry.
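Everything below assumes these three objects have consistent shapes, so it can help to check them right after loading. The helper `check_inputs` is hypothetical (it is not part of `dpdp` or `dpwfst`), and the random tensors are synthetic stand-ins with an assumed $D = 768$, not the real data.

```python
import numpy as np
import torch


def check_inputs(features, centroids, labels):
    """Return (T, K, D) after checking the shapes described above (hypothetical helper)."""
    if features.ndim != 2 or centroids.ndim != 2:
        raise ValueError("features and centroids must both be 2-D")
    T, D = features.shape
    K, D_c = centroids.shape
    if D != D_c:
        raise ValueError(f"feature dim {D} != centroid dim {D_c}")
    if len(labels) != K:
        raise ValueError(f"expected {K} labels, got {len(labels)}")
    return T, K, D


# Synthetic stand-ins with the same layout (random values, not HuBERT features).
T, K, D = 50, 200, 768
shape = check_inputs(torch.randn(T, D), torch.randn(K, D), np.array(["SIL"] * K))
print(shape)  # (50, 200, 768)
```

With the real data this would be `check_inputs(features, codebook_centroids, codebook_labels)`, which should report $K = 200$ if the file names are accurate.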

In [16]:
features = torch.load("data/1272-128104-0000-hubert-bshall-layer-7.pt")
codebook_centroids = torch.load("data/hubert-bshall-layer-7-kmeans-200-centroids.pt")
codebook_labels = np.load("data/hubert-bshall-layer-7-kmeans-200-labels.npy")

We can extract $K$-means units from the utterance by assigning each frame to its nearest centroid, and then label them using the codebook labels.

In [17]:
distances = torch.cdist(features, codebook_centroids, p=2.0)
units = torch.argmin(distances, dim=1)
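The same assignment can be computed with a single matrix product. Squaring does not change which distance is smallest, and $\lVert x_t \rVert^2$ is the same for every centroid, so $\arg\min_k \lVert x_t - c_k \rVert_2 = \arg\min_k \left(\lVert c_k \rVert^2 - 2\, x_t^\top c_k\right)$. A minimal sketch of that identity (the function name `quantize` is hypothetical), checked against the `cdist` version on random data:

```python
import torch


def quantize(features: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    """Index of the nearest centroid for each frame (hypothetical helper).

    Uses argmin_k ||x - c_k||^2 = argmin_k (||c_k||^2 - 2 x.c_k): the ||x||^2 term
    is the same for every k, and squaring does not change which distance is smallest.
    """
    scores = (centroids ** 2).sum(dim=1) - 2.0 * features @ centroids.T  # shape (T, K)
    return scores.argmin(dim=1)


# Random float64 data, so rounding differences stay far away from any near-ties.
torch.manual_seed(0)
x = torch.randn(100, 16, dtype=torch.float64)
c = torch.randn(20, 16, dtype=torch.float64)
assert torch.equal(quantize(x, c), torch.cdist(x, c, p=2.0).argmin(dim=1))
```

On the notebook data, `quantize(features, codebook_centroids)` should reproduce `units`, up to exact floating-point ties.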

Before we label the units, let's remove consecutive duplicate units.

In [18]:
units_deduped = [k.item() for k, g in groupby(units)]
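PyTorch also has a built-in for this step: `torch.unique_consecutive` collapses runs of equal values and, with `return_counts=True`, also returns the length of each run in frames, which the `groupby` version discards. A small sketch on a toy sequence (`toy_units` is made up for illustration):

```python
from itertools import groupby

import torch

toy_units = torch.tensor([3, 3, 7, 7, 7, 3, 5, 5])

# Collapse runs of identical units; `counts` holds each run's length in frames.
values, counts = torch.unique_consecutive(toy_units, return_counts=True)
print(values.tolist())  # [3, 7, 3, 5]
print(counts.tolist())  # [2, 3, 1, 2]

# The values match the groupby-based deduplication.
assert values.tolist() == [k.item() for k, _ in groupby(toy_units)]
```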

Now we can map the units to phone labels.

In [19]:
units_labels = codebook_labels[units_deduped]
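The deduplicated labels drop all timing information. If you also want phone-like segments with start and end times, one option is to label every frame first and then merge consecutive frames that share a label; this also merges neighboring units that map to the same phone (for example two different units that are both labeled T). The sketch assumes a 20 ms frame shift (the usual 50 Hz HuBERT frame rate for 16 kHz audio); `phone_segments`, `FRAME_SHIFT_S`, and the toy inputs are hypothetical.

```python
from itertools import groupby

import numpy as np

FRAME_SHIFT_S = 0.02  # assumption: 20 ms per HuBERT frame (50 Hz)


def phone_segments(units, codebook_labels, frame_shift=FRAME_SHIFT_S):
    """Merge consecutive frames with the same phone label into (label, start, end).

    `units` holds one codebook index per frame; start and end are in seconds.
    """
    frame_labels = codebook_labels[np.asarray(units)]
    segments = []
    start = 0
    for label, run in groupby(frame_labels):
        length = sum(1 for _ in run)
        segments.append((str(label), start * frame_shift, (start + length) * frame_shift))
        start += length
    return segments


# Units 1 and 3 both map to "M", so frames 2-3 become a single M segment.
toy_labels = np.array(["SIL", "M", "IH", "M"])
toy_segments = phone_segments([0, 0, 1, 3, 2, 2], toy_labels)
for segment in toy_segments:
    print(segment)
```

On the notebook data this would be `phone_segments(units, codebook_labels)`, using the frame-level `units` rather than `units_deduped`.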

In [20]:
for i, (unit, label) in enumerate(zip(units_deduped, units_labels)):
    if i < 10:
        continue
    print(f"{unit:03d}: {label}")
    if i > 30:
        break


096: SIL
155: M
087: IH
121: SH
112: T
041: T
061: ER
141: ER
107: T
179: K
068: K
154: F
133: W
065: IH
159: L
189: L
140: L
125: D
009: T
061: ER
059: ER
080: R


The LibriSpeech transcription for this utterance is:

`Mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.`

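We can see that the labeled units roughly correspond to the phones expected from the transcription, although there are errors: for example, "Mister" comes out with SH where we would expect S, and there is a spurious F between the K and W of "Quilter".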

Here is a visualization of the units along with the true phone labels.

![](dpdp-visualization.svg)

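To obtain the DPDP (duration-penalized dynamic programming) units, we can run the following code. The penalty weight `lmbda` trades off how closely the units fit the features against how many segments there are: larger values give fewer, longer segments.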

In [21]:
lmbda = 8 # <- change this to see the effect of the penalty

dpdp_units = dpdp(features, codebook_centroids, lmbda=lmbda)
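To make the role of `lmbda` concrete, here is a minimal sketch of a duration-penalized DP under a simplifying assumption: each frame gets one codebook unit, the cost is the squared Euclidean distance to that unit, and every new segment costs a constant `lmbda`. That objective can be minimized exactly with a Viterbi pass over the $K$ units. This is not the `dpdp` module's implementation, and its penalty and distance (and hence the scale of `lmbda`) may differ; `dpdp_sketch` and the toy data are hypothetical.

```python
import numpy as np


def dpdp_sketch(features, centroids, lmbda):
    """Hypothetical duration-penalized DP: one unit per frame, lmbda per new segment.

    Minimizes sum_t ||x_t - c_{z_t}||^2 + lmbda * (number of unit changes) with a
    Viterbi pass over the K codebook entries, O(T * K) after the distances.
    """
    # Squared Euclidean distances, shape (T, K), via ||x||^2 - 2 x.c + ||c||^2.
    d = (
        (features ** 2).sum(axis=1)[:, None]
        - 2.0 * features @ centroids.T
        + (centroids ** 2).sum(axis=1)[None, :]
    )
    d = np.maximum(d, 0.0)  # clip tiny negatives caused by rounding
    T, K = d.shape
    cost = d[0].copy()
    stay = np.zeros((T, K), dtype=bool)  # stay[t, k]: best path into k at t came from k
    best_prev = np.zeros(T, dtype=int)   # cheapest unit at t - 1, used when switching
    for t in range(1, T):
        j = int(cost.argmin())
        switch = cost[j] + lmbda  # cost of starting a new segment at frame t
        stay[t] = cost <= switch
        best_prev[t] = j
        cost = np.minimum(cost, switch) + d[t]
    # Backtrack from the cheapest final unit.
    z = np.empty(T, dtype=int)
    z[-1] = int(cost.argmin())
    for t in range(T - 1, 0, -1):
        z[t - 1] = z[t] if stay[t, z[t]] else best_prev[t]
    return z


toy_centroids = np.array([[0.0, 0.0], [1.0, 1.0]])
# 10 frames at centroid 0, one outlier frame at centroid 1, then 10 more at centroid 0.
toy_features = np.concatenate([np.zeros((10, 2)), np.ones((1, 2)), np.zeros((10, 2))])
print(dpdp_sketch(toy_features, toy_centroids, lmbda=0.0))  # outlier gets its own segment
print(dpdp_sketch(toy_features, toy_centroids, lmbda=5.0))  # outlier absorbed into unit 0
```

With `lmbda=0` this reduces to the plain nearest-centroid assignment; with `lmbda=5` keeping unit 0 through the outlier costs 2, while switching to unit 1 and back costs 10, so the outlier disappears. On the notebook data it could be called as `dpdp_sketch(features.numpy(), codebook_centroids.numpy(), lmbda)`, bearing in mind that a good `lmbda` for this sketch need not match the one for `dpdp`.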

Just like before, we can remove consecutive duplicate units and then map the units to phone labels.

In [22]:
dpdp_units_deduped = [k.item() for k, g in groupby(dpdp_units)]
dpdp_units_labels = codebook_labels[dpdp_units_deduped]
for i, (unit, label) in enumerate(zip(dpdp_units_deduped, dpdp_units_labels)):
    print(f"{unit:03d}: {label}")
    if i > 30:
        break

162: SIL
016: SIL
135: SIL
094: SIL
171: SIL
013: SIL
171: SIL
096: SIL
155: M
087: IH
121: SH
112: T
041: T
061: ER
107: T
068: K
065: IH
159: L
189: L
140: L
125: D
061: ER
059: ER
080: R
087: IH
147: Z
167: Z
052: AH
088: Y
001: AH
107: T
075: SIL


We can achieve the same result using the DP-WFST algorithm.

In [23]:
dpwfst_units = dpwfst(features, codebook_centroids, lmbda=lmbda)

In [24]:
dpwfst_units_deduped = [k.item() for k, g in groupby(dpwfst_units)]
dpwfst_units_labels = codebook_labels[dpwfst_units_deduped]
for i, (unit, label) in enumerate(zip(dpwfst_units_deduped, dpwfst_units_labels)):
    print(f"{unit:03d}: {label}")
    if i > 30:
        break

162: SIL
016: SIL
135: SIL
094: SIL
171: SIL
013: SIL
171: SIL
096: SIL
155: M
087: IH
121: SH
112: T
041: T
061: ER
107: T
068: K
065: IH
159: L
189: L
140: L
125: D
061: ER
059: ER
080: R
087: IH
147: Z
167: Z
052: AH
088: Y
001: AH
107: T
075: SIL


In [25]:
%timeit dpdp(features, codebook_centroids, lmbda=lmbda)

33.6 ms ± 1.21 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [26]:
%timeit dpwfst(features, codebook_centroids, lmbda=lmbda)

789 ms ± 79.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


When it searches the full codebook, DP-WFST is much slower than DPDP: 789 ms versus 33.6 ms per call, roughly 23 times slower.

However, DP-WFST is more flexible: we can limit the search to a few nearest neighbors with `num_neighbors`, which speeds it up significantly.
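Why does limiting the neighbors help so much? If each frame may switch between any of the $K$ units and every unit-to-unit transition is scored explicitly, the work per frame grows like $K^2$; if each frame may only use its $n$ nearest codebook entries, only $n \times n$ transitions per frame remain. The sketch below illustrates that idea with a plain Viterbi search over per-frame candidate lists, assuming `num_neighbors` means "the nearest codebook entries to each frame" and using a constant switch penalty `lmbda` as a stand-in cost. It does not use a WFST library and is not the `dpwfst` implementation; `pruned_viterbi_sketch` and the toy data are hypothetical.

```python
import numpy as np


def pruned_viterbi_sketch(features, centroids, lmbda, num_neighbors):
    """Hypothetical pruned search: frame t may only use its num_neighbors closest centroids.

    Each frame costs O(num_neighbors ** 2) transitions instead of growing with K.
    """
    d = (
        (features ** 2).sum(axis=1)[:, None]
        - 2.0 * features @ centroids.T
        + (centroids ** 2).sum(axis=1)[None, :]
    )
    d = np.maximum(d, 0.0)  # squared Euclidean distances, shape (T, K)
    T = d.shape[0]
    # Candidate units per frame: the num_neighbors smallest distances, shape (T, n).
    cand = np.argsort(d, axis=1)[:, :num_neighbors]
    cost = d[0, cand[0]]
    back = np.zeros((T, num_neighbors), dtype=int)
    for t in range(1, T):
        # Transition cost: 0 to keep the same unit, lmbda to switch, shape (n_prev, n_cur).
        trans = lmbda * (cand[t - 1][:, None] != cand[t][None, :])
        total = cost[:, None] + trans
        back[t] = total.argmin(axis=0)
        cost = total.min(axis=0) + d[t, cand[t]]
    # Backtrack over candidate positions, then map them back to codebook indices.
    pos = np.empty(T, dtype=int)
    pos[-1] = int(cost.argmin())
    for t in range(T - 1, 0, -1):
        pos[t - 1] = back[t, pos[t]]
    return cand[np.arange(T), pos]


toy_centroids = np.array([[0.0, 0.0], [1.0, 1.0]])
toy_features = np.concatenate([np.zeros((10, 2)), np.ones((1, 2)), np.zeros((10, 2))])
print(pruned_viterbi_sketch(toy_features, toy_centroids, lmbda=5.0, num_neighbors=2))
print(pruned_viterbi_sketch(toy_features, toy_centroids, lmbda=5.0, num_neighbors=1))
```

Pruning makes the search approximate: a unit outside a frame's candidate list can never be chosen at that frame. In the toy example, `num_neighbors=2` covers the whole codebook and the outlier is smoothed away, while `num_neighbors=1` forces every frame onto its nearest centroid regardless of `lmbda`.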

In [27]:
%timeit dpwfst(features, codebook_centroids, lmbda=lmbda, num_neighbors=5)

2.75 ms ± 199 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Now let's show the labeled units when we limit the search to the 5 nearest neighbors.

In [28]:
dpwfst_units_5 = dpwfst(features, codebook_centroids, lmbda=lmbda, num_neighbors=5)
dpwfst_units_5_deduped = [k.item() for k, g in groupby(dpwfst_units_5)]
dpwfst_units_5_labels = codebook_labels[dpwfst_units_5_deduped]
for i, (unit, label) in enumerate(zip(dpwfst_units_5_deduped, dpwfst_units_5_labels)):
    print(f"{unit:03d}: {label}")
    if i > 30:
        break

162: SIL
016: SIL
135: SIL
094: SIL
171: SIL
013: SIL
171: SIL
096: SIL
155: M
087: IH
121: SH
112: T
041: T
061: ER
107: T
068: K
065: IH
159: L
189: L
140: L
125: D
061: ER
059: ER
080: R
087: IH
147: Z
167: Z
052: AH
088: Y
001: AH
107: T
075: SIL


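The first 32 deduplicated units are identical to those from the full search, and the execution time drops from 789 ms to 2.75 ms per call: about 290 times faster than the full DP-WFST search and about 12 times faster than DPDP.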
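The printouts only cover the first 32 deduplicated units. To compare the full frame-level outputs, a simple measure is the fraction of frames on which two runs choose the same unit. A sketch with a hypothetical helper `frame_agreement`:

```python
import numpy as np


def frame_agreement(a, b) -> float:
    """Fraction of frames on which two frame-level unit sequences pick the same unit."""
    a, b = np.asarray(a), np.asarray(b)
    if a.shape != b.shape:
        raise ValueError("sequences must have the same number of frames")
    return float((a == b).mean())


print(frame_agreement([1, 1, 2, 3], [1, 1, 2, 2]))  # 0.75
```

For example, `frame_agreement(dpwfst_units, dpwfst_units_5)` or `frame_agreement(dpdp_units, dpwfst_units)`, assuming both calls return CPU tensors or arrays with one unit per frame.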