# Self-Supervised Syllable Discovery Based on Speaker-Disentangled HuBERT

In [None]:
# Only required for Google Colab; when running locally, start the notebook from the
# repository root instead, since later cells use paths relative to it (src/..., data/...).
# Later cells move the model and audio to CUDA, so enable a GPU first:
# Runtime -> Change runtime type -> select "T4 GPU" and save
# Clone the repository and make it the working directory.
!git clone https://github.com/ryota-komatsu/speaker_disentangled_hubert
%cd speaker_disentangled_hubert

Install dependencies: clone SD-HuBERT, which provides the ground-truth LibriSpeech syllable alignments

In [None]:
# Clone SD-HuBERT into src/sdhubert; its files/librispeech_syllable_test.json is read
# later in this notebook as the ground-truth syllable alignment.
!git clone https://github.com/cheoljun95/sdhubert.git src/sdhubert

Download and extract the LibriSpeech `test-clean` subset

In [None]:
!wget -t 0 -c -P data/LibriSpeech https://www.openslr.org/resources/12/test-clean.tar.gz
!tar zxvf data/LibriSpeech/test-clean.tar.gz -C data

In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio

from src.s5hubert import S5HubertForSyllableDiscovery

In [None]:
# Ground-truth syllable alignments shipped with the SD-HuBERT clone.
syllable_alignment_path = "src/sdhubert/files/librispeech_syllable_test.json"
# Directory that the LibriSpeech archive is extracted into.
LIBRISPEECH_ROOT = "data/LibriSpeech"
# Utterance to visualise; this string is also its file_name key in the alignment file.
wav_name = "test-clean/61/70968/61-70968-0021.flac"
# Derived from wav_name so the audio path and the alignment key can never drift apart.
wav_path = f"{LIBRISPEECH_ROOT}/{wav_name}"

Load model

In [None]:
# Load the pretrained checkpoint (presumably a Hugging Face Hub repo id) and move it to the GPU.
# .eval() explicitly switches off training-only behaviour (e.g. dropout) for inference,
# regardless of the mode from_pretrained leaves the module in.
model = S5HubertForSyllableDiscovery.from_pretrained("ryota-komatsu/s5-hubert").cuda().eval()

Load audio and resample it to 16 kHz

In [None]:
# torchaudio.load gives back the audio tensor together with the file's sample rate.
waveform, sr = torchaudio.load(wav_path)
# Bring the audio to 16 kHz before it is passed to the model.
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)

Inference

In [None]:
# Pure inference: disable autograd so no graph is built. Also, the outputs are
# converted with .numpy() in the next cell, which raises on tensors that require grad.
with torch.no_grad():
    batch_outputs = model(waveform.cuda())
batch_outputs

In [None]:
# Outputs for the first (and here only) utterance in the batch.
utterance_outputs = batch_outputs[0]
# Predicted boundaries = running sum of segment durations (presumably in frames,
# since they are drawn on the same axis as the similarity matrix).
frame_boundary = utterance_outputs["durations"].cumsum(0).cpu().numpy()
# Pairwise dot products between dense features; assumes "dense" is 2-D
# (frames x feature_dim) -- TODO confirm.
dense_features = utterance_outputs["dense"]
frame_similarity = (dense_features @ dense_features.T).cpu().numpy()

Load the ground-truth syllable alignments

In [None]:
# Alignment timestamps are in seconds. Dividing by this frame shift maps them onto the
# frame-index axis shared with frame_similarity and the predicted boundaries.
FRAME_SHIFT_SEC = 0.02


def _to_frame(seconds) -> int:
    """Convert a timestamp in seconds (number or numeric string) to the nearest frame index."""
    return round(float(seconds) / FRAME_SHIFT_SEC)


# Parse the file up front so the handle is closed before any processing happens.
with open(syllable_alignment_path) as f:
    syllables = json.load(f)

# refs maps each utterance's file_name to its ground-truth annotations:
#   boundary -- sorted, de-duplicated frame indices of every syllable start and end
#   labels   -- syllable labels, in file order
#   ticks    -- midpoint frame of each syllable, used to place the x-axis labels
refs = {}
for item in syllables.values():
    spans = [(_to_frame(s["start"]), _to_frame(s["end"])) for s in item["syllables"]]
    refs[item["file_name"]] = {
        "boundary": np.unique(spans),
        "labels": [s["label"] for s in item["syllables"]],
        "ticks": [(start + end) / 2 for start, end in spans],
    }

if wav_name not in refs:
    raise KeyError(f"{wav_name!r} has no entry in {syllable_alignment_path}")

ref_boundary = refs[wav_name]["boundary"]
labels = refs[wav_name]["labels"]
ticks = refs[wav_name]["ticks"]

Plot results: frame similarity matrix with ground-truth (red) and predicted (white, dotted) syllable boundaries

In [None]:
fig, ax = plt.subplots()
ax.imshow(frame_similarity)
# Vertical lines span the image height, i.e. its rows (axis 0).
y_max = frame_similarity.shape[0] - 1
ax.vlines(ref_boundary, 0, y_max, colors="red", label="ground truth")
ax.vlines(frame_boundary, 0, y_max, colors="white", linestyles="dotted", label="prediction")
# Label the x axis with the ground-truth syllables, centred on each syllable span.
ax.set_xticks(ticks)
ax.set_xticklabels(labels, rotation=-60, color="red")
ax.set_yticks([])
ax.set_title("Frame similarity with syllable boundaries")
# Trailing ';' keeps the Legend object's repr out of the cell output.
ax.legend();