In [4]:
from pathlib import Path

import numpy as np
import tgt
import torch
from sklearn.metrics.cluster import contingency_matrix
from sklearn.preprocessing import LabelEncoder
from torchcodec.decoders import AudioDecoder
from tqdm import tqdm

from zerosyl.model import ZeroSylDiscrete

In [5]:
# Locate the LibriSpeech dev-set audio and alignment files, keyed by file stem.
waveform_dir = Path("data/waveforms/LibriSpeech")
alignment_dir = Path("data/alignments/LibriSpeech")

waveform_by_stem = {p.stem: p for p in waveform_dir.glob("dev*/**/*.flac")}
alignment_by_stem = {p.stem: p for p in alignment_dir.glob("dev*/**/*.TextGrid")}

assert len(waveform_by_stem) > 0, f"no .flac files found under {waveform_dir}/dev*"
assert len(alignment_by_stem) > 0, f"no .TextGrid files found under {alignment_dir}/dev*"

# Keep only the utterances that have both a waveform and an alignment.
common_stems = waveform_by_stem.keys() & alignment_by_stem.keys()
assert len(common_stems) > 0, "no utterance has both a waveform and an alignment"

# Iterate the stems in sorted order: the iteration order of a set of strings can
# change between interpreter sessions, which would make the processing order
# (and any per-utterance failure) non-reproducible.
ordered_stems = sorted(common_stems)

# Parallel lists: waveform_paths[i] and alignment_paths[i] share the same stem.
waveform_paths = [waveform_by_stem[stem] for stem in ordered_stems]
alignment_paths = [alignment_by_stem[stem] for stem in ordered_stems]

In [None]:
# Ensures all utterances from the dev-clean and dev-other sets are present.
EXPECTED_NUM_UTTERANCES = 5567
assert len(common_stems) == EXPECTED_NUM_UTTERANCES, (
    f"expected {EXPECTED_NUM_UTTERANCES} utterances, found {len(common_stems)}"
)

In [6]:
# Build the discrete tokenizer from its two checkpoint files, then move it to the GPU.
model = ZeroSylDiscrete.from_pretrained_checkpoint(
    "checkpoints/WavLM-Large.pt",
    "checkpoints/km10000-centroids-v020.pt",
)
model = model.cuda()

  WeightNorm.apply(module, name, dim)


In [7]:
# Build two aligned frame-level label sequences (100 Hz) over all utterances:
# the reference syllable at each frame and the model token at each frame.
all_syllables = []
all_tokens = []
# Inference only: disable autograd so no graph is recorded while tokenizing.
with torch.no_grad():
    for waveform_path, alignment_path in zip(tqdm(waveform_paths, desc="loading data"), alignment_paths):
        # extract ground truth syllables from the alignment file
        tg = tgt.read_textgrid(alignment_path, include_empty_intervals=True)
        tier = tg.get_tier_by_name("syllables")
        # sample the tier at the centre of every 10 ms frame
        timesteps = np.arange(0.5/100, tg.end_time, 1/100) # 100 Hz
        syllables = [tier.get_annotations_by_time(t)[0].text for t in timesteps]
        # empty intervals are relabelled as a single "SIL" class
        syllables = [("SIL" if s == "" else s) for s in syllables]

        # extract tokens from the speech
        decoder = AudioDecoder(waveform_path, sample_rate=16000, num_channels=1)
        audio = decoder.get_all_samples()
        # NOTE(review): starts/ends are presumably segment boundaries in 50 Hz
        # frames, as the repeat counts below imply - confirm against model.tokenize
        tokens, starts, ends = model.tokenize(audio.data.cuda())
        tokens_duped = torch.repeat_interleave(tokens, ends-starts, dim=0) # 50Hz
        tokens_duped = torch.repeat_interleave(tokens_duped, 2) # 100Hz
        tokens_duped = tokens_duped.cpu().numpy()

        # the two sequences may differ by a few frames at the end; anything larger
        # points at a misaligned utterance, so report which one it is
        assert abs(len(syllables) - len(tokens_duped)) <= 3, (
            f"{waveform_path}: {len(syllables)} syllable frames vs "
            f"{len(tokens_duped)} token frames"
        )
        minlen = min(len(syllables), len(tokens_duped))
        syllables = syllables[:minlen]
        tokens_duped = tokens_duped[:minlen]

        all_syllables.append(syllables)
        all_tokens.append(tokens_duped)

# flatten the per-utterance sequences into one long array each
all_syllables = np.concatenate(all_syllables, axis=0)
all_tokens = np.concatenate(all_tokens, axis=0)

loading data: 100%|██████████| 5567/5567 [03:49<00:00, 24.25it/s]


In [8]:
# Map the frame-level labels to contiguous integer ids, one encoder per label set.
syllable_label_encoder = LabelEncoder()
token_label_encoder = LabelEncoder()

all_syllable_ids = syllable_label_encoder.fit_transform(all_syllables)
all_token_ids = token_label_encoder.fit_transform(all_tokens)

# Number of distinct labels each encoder saw during fitting.
n_syllables, n_tokens = (
    len(encoder.classes_)
    for encoder in (syllable_label_encoder, token_label_encoder)
)

In [9]:
# Co-occurrence counts per frame: rows are syllable ids, columns are token ids.
joint_counts = contingency_matrix(
    labels_true=all_syllable_ids,
    labels_pred=all_token_ids,
)

In [10]:
# Purity and mutual-information metrics from the syllable-by-token count table.
# Rows index syllable classes y (I of them), columns index token classes z (J of them).

# joint probabilities
# No additive smoothing: adding a constant to every cell makes the table sum to
# more than 1 and inflates the marginals of rare classes. Zero cells are instead
# skipped in the mutual-information sum below (0 * log 0 is taken as 0).
p_yz = joint_counts / joint_counts.sum() # (I,J)
I, J = p_yz.shape

# marginal probabilities
# Assumes every row and column of joint_counts has at least one count (true for
# a table built from the observed labels), so the divisions below are safe.
p_y = np.sum(p_yz, axis=1) # (I,)
p_z = np.sum(p_yz, axis=0) # (J,)

# most likely token label for each syllable
z_star = np.argmax(p_yz, axis=1) # (I,)
# most likely syllable label for each token
y_star = np.argmax(p_yz, axis=0) # (J,)

# conditional probabilities
p_y_given_z = p_yz / p_z[None,:] # (I,J)
p_z_given_y = p_yz / p_y[:, None] # (I,J)

# syllable purity: expected probability of a token's most likely syllable
sp = np.sum(p_y_given_z[y_star,np.arange(J)] * p_z)

# cluster purity: expected probability of a syllable's most likely token
cp = np.sum(p_z_given_y[np.arange(I),z_star] * p_y)

# syllable-normalized mutual information: I(y;z) / H(y)
# Only the nonzero cells are gathered, which also avoids materialising the
# full (I,J) outer product of the marginals.
rows, cols = np.nonzero(p_yz)
p_nonzero = p_yz[rows, cols]
mutual_information = np.sum(p_nonzero * np.log(p_nonzero / (p_y[rows] * p_z[cols])))
syllable_entropy = -np.sum(p_y * np.log(p_y))
snmi = mutual_information / syllable_entropy

In [11]:
# Report the three metrics to four decimals; labels are padded so the values line up.
print(f"Syllable purity:                        {sp.item():.4f}")
print(f"Cluster purity:                         {cp.item():.4f}")
print(f"Syllable-normalized mutual information: {snmi.item():.4f}")

Syllable purity:                        0.6693
Cluster purity:                         0.1859
Syllable-normalized mutial information: 0.7896
