In [1]:
from itertools import groupby
from itertools import product

import h5py
import Levenshtein
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics.cluster import contingency_matrix
from sklearn.preprocessing import LabelEncoder
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm

In [2]:
# Input/output locations for one clustering configuration.
# The units file and the output file live in the same cluster-id directory, so
# when switching `units_dataset_path` to another configuration, update
# `output_path` as well or the previous configuration's result is overwritten.
# Alternative units file, kept for reference:
#units_dataset_path = "/mnt/wsl/nvme/data/LibriSpeech/cluster-ids/wavlm-large-layer-11-kmeans-500/train-clean-100.h5"
units_dataset_path = "/mnt/wsl/nvme/data/LibriSpeech/cluster-ids/wavlm-large-layer-24-spherical-kmeans-500/train-clean-100.h5"
# NOTE(review): the "//" before the file name looks accidental; POSIX resolves it
# like a single "/", so the file still lands in the units directory -- confirm.
output_path = "/mnt/wsl/nvme/data/LibriSpeech/cluster-ids/wavlm-large-layer-24-spherical-kmeans-500//most-likely-phone-labels-train-clean-100.npy"
# HDF5 file with one whitespace-separated phone-label string per utterance key.
phonemes_dataset_path = "/mnt/wsl/nvme/data/LibriSpeech/phonemes/duplicated/librispeech-phonemes-mfa-us-arpa-train-clean-100.h5"

In [3]:
def strip_number(string: str):
    """Return ``string`` with every digit character removed.

    Used on phone labels to drop trailing digits, e.g. ``"AH0" -> "AH"``.
    Characters are tested with ``str.isdigit``; all others are kept in order.
    """
    kept_chars = []
    for ch in string:
        if ch.isdigit():
            continue
        kept_chars.append(ch)
    return ''.join(kept_chars)

In [4]:
# Load the phone-label sequence of every utterance, keyed by the HDF5 dataset name.
phonemes_by_key = {}
with h5py.File(phonemes_dataset_path, "r") as phoneme_file:
    for utt_key in tqdm(phoneme_file.keys()):
        # Each dataset is a single UTF-8 byte string of whitespace-separated labels.
        raw_labels = phoneme_file[utt_key][()].decode('utf-8').split()
        # Digits are removed from every label; the sequence is at 100 Hz.
        phonemes_by_key[utt_key] = [strip_number(label) for label in raw_labels]

100%|██████████| 28539/28539 [00:16<00:00, 1781.00it/s]


In [5]:
# Load the cluster-id sequence of every utterance, keyed by the HDF5 dataset name.
units_by_key = {}

with h5py.File(units_dataset_path, "r") as units_file:
    for utt_key in tqdm(units_file.keys()):
        units_50hz = units_file[utt_key][:]  # 50 Hz
        # Repeat each entry twice along the first axis to reach the phone rate (100 Hz).
        units_by_key[utt_key] = np.repeat(units_50hz, 2, axis=0)

100%|██████████| 28539/28539 [00:03<00:00, 7837.11it/s]


In [6]:
# Utterances present in both datasets; the check below requires that this is
# every utterance, i.e. both files contain exactly the same keys.
keys = set(units_by_key.keys()).intersection(set(phonemes_by_key.keys()))
n_shared_keys = len(keys)
assert n_shared_keys == len(units_by_key) and n_shared_keys == len(phonemes_by_key)

In [7]:
# Cut both sequences of every utterance to their common length so that the
# phone labels and the units line up one-to-one.
for key in tqdm(keys):
    n_common = min(len(phonemes_by_key[key]), len(units_by_key[key]))
    phonemes_by_key[key] = phonemes_by_key[key][:n_common]
    units_by_key[key] = units_by_key[key][:n_common]

100%|██████████| 28539/28539 [00:00<00:00, 112255.03it/s]


In [8]:
# Flatten the per-utterance sequences into two parallel 1-D arrays.
# `keys` is an unordered set whose iteration order changes between kernel runs,
# so both arrays are built from one explicit, sorted key list: the pairing of
# units and phones no longer rests on set iteration order, and the arrays are
# reproducible across runs.
ordered_keys = sorted(keys)
units = np.concatenate([units_by_key[key] for key in ordered_keys], axis=0)
phonemes = np.concatenate([phonemes_by_key[key] for key in ordered_keys], axis=0)
# Fails if the per-utterance trimming step was skipped or run out of order.
assert len(units) == len(phonemes), (
    f"units ({len(units)}) and phonemes ({len(phonemes)}) are not frame-aligned"
)

In [9]:
# Distinct cluster ids that actually occur in the data, sorted ascending.
unique_units = np.unique(units)
# Number of distinct ids (np.unique always returns a 1-D array).
unique_units.shape[0]

500

In [10]:
# Fit an encoder that maps each distinct phone label string to an integer id.
# `LabelEncoder` and `contingency_matrix` are imported with the other
# dependencies in the first cell, so a later cell no longer depends on this
# one having been executed just for its imports.
label_encoder = LabelEncoder()
label_encoder.fit(phonemes)

In [11]:
# Inspect the label inventory learned by the encoder; position i holds the
# label that is encoded as integer id i.
label_encoder.classes_

array(['AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH',
       'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N',
       'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH', 'SIL', 'T', 'TH', 'UH',
       'UW', 'V', 'W', 'Y', 'Z', 'ZH'], dtype='<U3')

In [12]:
# Encode the phone label strings as integer class ids with the fitted encoder.
phonemes_encoded = label_encoder.transform(phonemes)

In [13]:
# Co-occurrence counts: one row per phone id, one column per distinct unit value.
contingency = contingency_matrix(labels_true=phonemes_encoded, labels_pred=units)

In [14]:
# Sanity check: rows follow the first argument (phone ids), columns follow the
# second (distinct unit values).
contingency.shape

(40, 500)

In [15]:
# For every unit, pick the phone id it co-occurs with most often.
#
# `contingency_matrix` orders its columns by the sorted distinct values found in
# `units` (the order `np.unique` returns), so column j corresponds to unit id j
# only when the observed ids are exactly 0..K-1. If a cluster id never occurs in
# this data, a plain argmax would silently shift every later unit onto the wrong
# phone label, so the correspondence is verified before it is relied on.
observed_unit_ids = np.unique(units)
assert contingency.shape[1] == observed_unit_ids.shape[0], (
    "contingency matrix is stale: its columns do not match the current `units`"
)
assert np.array_equal(observed_unit_ids, np.arange(observed_unit_ids.shape[0])), (
    "observed unit ids are not contiguous from 0; contingency columns cannot be "
    "used as unit ids"
)
# NOTE(review): ids missing only at the very end of the codebook cannot be
# detected from the data alone -- compare the count against the codebook size.
most_likely_phone_ids_for_units = np.argmax(contingency, axis=0)

In [16]:
# Convert the winning integer phone ids back to their string labels.
most_likely_phone_labels_for_units = label_encoder.inverse_transform(most_likely_phone_ids_for_units)

In [17]:
# Display the resulting labels: one entry per column of the contingency matrix.
most_likely_phone_labels_for_units

array(['ER', 'SIL', 'N', 'AE', 'W', 'Z', 'AE', 'SIL', 'M', 'SIL', 'S',
       'N', 'SIL', 'T', 'AH', 'IH', 'DH', 'P', 'SIL', 'SIL', 'SIL', 'AO',
       'K', 'D', 'F', 'SIL', 'L', 'OW', 'AA', 'IY', 'ER', 'NG', 'Y',
       'SIL', 'SIL', 'AY', 'R', 'SIL', 'T', 'T', 'AA', 'SIL', 'R', 'R',
       'S', 'T', 'AH', 'D', 'AY', 'HH', 'L', 'M', 'L', 'SIL', 'V', 'AH',
       'AH', 'AH', 'N', 'K', 'SIL', 'ER', 'IY', 'EY', 'S', 'EY', 'S',
       'SIL', 'W', 'R', 'IY', 'S', 'K', 'IY', 'SIL', 'T', 'T', 'T', 'OW',
       'AH', 'AH', 'EY', 'AW', 'T', 'S', 'L', 'R', 'EH', 'SIL', 'AE',
       'IH', 'P', 'T', 'AE', 'L', 'N', 'SIL', 'S', 'SIL', 'L', 'SIL', 'K',
       'EY', 'F', 'N', 'SIL', 'JH', 'UH', 'IH', 'T', 'IY', 'SIL', 'UW',
       'T', 'R', 'IH', 'ER', 'L', 'P', 'SIL', 'Z', 'SIL', 'SH', 'AY',
       'SIL', 'Y', 'AH', 'P', 'AY', 'SIL', 'AA', 'N', 'SIL', 'S', 'M',
       'EH', 'SIL', 'S', 'AE', 'F', 'SIL', 'SIL', 'S', 'S', 'N', 'R', 'S',
       'AE', 'IY', 'AH', 'IY', 'S', 'AE', 'B', 'AH', 'AY', 'ER',

In [18]:
# Persist the per-unit phone labels as a NumPy .npy file at the configured path.
np.save(file=output_path, arr=most_likely_phone_labels_for_units)

In [19]:
# Echo the path the labels were saved to.
output_path

'/mnt/wsl/nvme/data/LibriSpeech/cluster-ids/wavlm-large-layer-24-spherical-kmeans-500//most-likely-phone-labels-train-clean-100.npy'