# Running mHubert using transformers

Here we run mHubert using the transformers abstractions and apply the quantization on top.
We use the model after the 2nd iteration, which is slightly worse, but this way we can use the released faiss index.

In [None]:
# follow docs https://huggingface.co/docs/transformers/en/model_doc/hubert#transformers.HubertModel.forward.example
# to extract latent representation for the audio file after 9th layer

import os
import torch
from transformers import HubertModel
import soundfile as sf
import faiss
from huggingface_hub import hf_hub_download

# Expose only the GPU with index 3 to this process, so every `.cuda()` call below lands on it.
# NOTE(review): the index is machine-specific - adjust it for your host. Presumably this has to
# run before the first CUDA call in the process to have any effect - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"


def get_model():
    """
    Loads the 2nd-iteration mHuBERT-147 checkpoint and the released faiss index.

    The index file is looked up in the current working directory and is
    downloaded from the Hugging Face Hub only when it is not there yet.

    Returns
    -------
    model: HubertModel
        pre-trained model, moved to the GPU
    index: faiss.IndexPreTransform
        the index exactly as read from disk
    index_ivf:
        result of `faiss.extract_index_ivf(index)`
    index_path: str
        location of the index file on disk
    """
    model = HubertModel.from_pretrained("utter-project/mHuBERT-147-base-2nd-iter").cuda()
    # Specify the repository and the filename you want to download
    repo_id = "utter-project/mHuBERT-147"
    filename = "mhubert147_faiss.index"
    download_dir = "./"

    index_path = os.path.join(download_dir, filename)
    if not os.path.isfile(index_path):
        # `local_dir` stores the file as <download_dir>/<filename>, i.e. exactly where the
        # check above looks. `cache_dir` would bury it in a nested cache tree, so the
        # check would never succeed on later runs.
        index_path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=download_dir)
    index: faiss.IndexPreTransform = faiss.read_index(index_path)
    index_ivf = faiss.extract_index_ivf(index)
    return model, index, index_ivf, index_path


def extract_latent(model, audio_path):
    """
    Reads an audio file and returns the model's hidden states taken at index 9
    of `hidden_states` (referred to in this notebook as "after 9th layer").

    Parameters
    ----------
    model:
        HubertModel living on the GPU
    audio_path: str
        path to a 16 kHz audio file

    Returns
    -------
    latent: torch.Tensor
        (1 x time x 768) continuous representation of the audio
    """
    wav, sr = sf.read(audio_path)
    assert sr == 16000, f"expected 16 kHz audio, got {sr} Hz"
    x = torch.tensor(wav).unsqueeze(0).float().cuda()
    # need to do mean / variance normalization.
    # torch's default var() is the unbiased estimator; the population variance
    # (unbiased=False, i.e. ddof=0 in numpy) is what HubertExtractor._compute_stats
    # in this notebook reproduces, so use it here too to keep both paths consistent.
    x = (x - x.mean()) / torch.sqrt(x.var(unbiased=False) + 1e-7)
    # https://huggingface.co/utter-project/mHuBERT-147/discussions/6#668544a0e270025784fb469c
    # inference only, no need to build the autograd graph
    with torch.no_grad():
        latent = model(x, output_hidden_states=True).hidden_states[9]
    return latent  # 1 x time x 768


def extract_labels_from_latent(index, index_ivf, latent):
    """
    Maps every frame of `latent` to the id returned by a 1-nearest-neighbour
    search in the quantizer of `index_ivf`.

    Parameters
    ----------
    index:
        faiss index whose first chain element is applied to the frames before the search
    index_ivf:
        index providing the `quantizer` that is searched
    latent: torch.Tensor
        (batch x time x dim) continuous representation

    Returns
    -------
    labels: numpy array
        (batch * time,) id of the nearest quantizer entry for each frame
    """
    # one row per frame: batch and time axes are merged
    n_frames = latent.size(0) * latent.size(1)
    queries = latent.reshape(n_frames, -1).float().detach().cpu().numpy()

    # the search has to happen in the pre-transformed space
    pre_transform = faiss.downcast_VectorTransform(index.chain.at(0))
    queries_t = pre_transform.apply_py(queries)

    # k=1 search; only the second returned array (ids) is used
    _, nearest = index_ivf.quantizer.search(queries_t, 1)
    return nearest[:, 0]  # (time)


def extract_labels(model, index, index_ivf, audio_path):
    """
    End-to-end helper: audio file -> latent representation -> discrete labels.

    Chains `extract_latent` and `extract_labels_from_latent` and returns
    whatever the latter produces for the given file.
    """
    return extract_labels_from_latent(index, index_ivf, extract_latent(model, audio_path))

In [10]:
# Load the model and the faiss index once; `index_path` is reused below when building HubertExtractor.
model, index, index_ivf, index_path = get_model()
this_path = "rms_arctic_a0001.wav"
# Reference labels from the transformers + faiss pipeline; later cells plot other results against them.
orig_labels = extract_labels(model, index, index_ivf, this_path)

# Tracing mHubert

Below we implement mHubert together with the quantization as a single torch module, so that it can be traced.

In [3]:
from torch import nn
import numpy as np
from typing import Tuple


class TorchFaiss(nn.Module):
    """
    Torch re-implementation of the nearest-centroid lookup done with the faiss index.

    The index pre-transform and the quantizer centroids are read from the faiss
    index once and stored as frozen `nn.Linear` weights, so the lookup consists
    of plain tensor ops.
    """
    def __init__(self, index_path):
        """
        Parameters
        ----------
        index_path: str
            path to the faiss index file to read the transform and centroids from

        Raises
        ------
        ValueError
            if the pre-transform maps the zero vector to a non-zero one, i.e. it is
            not a pure matrix multiplication and can't be stored as a bias-free layer
        """
        super().__init__()

        # Load FAISS index
        index = faiss.read_index(index_path)
        index_ivf = faiss.extract_index_ivf(index)

        # get the transformation
        vt = index.chain.at(0)
        # The extraction below works only for a linear transform. Any offset would be
        # silently folded into every row of the matrix, so check for it explicitly.
        offset = vt.apply_py(np.zeros((1, vt.d_in), dtype=np.float32))
        if not np.allclose(offset, 0.0):
            raise ValueError("index pre-transform is not linear, can't convert it to nn.Linear")
        # pushing the identity through the transform yields its matrix, one row per input dim
        identity = np.eye(vt.d_in, dtype=np.float32)
        transform_mat = vt.apply_py(identity)
        self.transform = nn.Linear(transform_mat.shape[0], transform_mat.shape[1], bias=False)
        # nn.Linear stores weights as (out_features x in_features), hence the transpose
        self.transform.weight = nn.Parameter(torch.tensor(transform_mat).T)
        self.transform.requires_grad_(False)

        # get the index search
        nlist = index_ivf.nlist  # Number of clusters
        d = index_ivf.d  # Vector dimension
        centroids_npy = np.zeros((nlist, d), dtype=np.float32)
        index_ivf.quantizer.reconstruct_n(0, nlist, centroids_npy)
        centroids_arr = torch.tensor(centroids_npy)

        # squared norm of every centroid, stored as (1 x nlist) to broadcast over frames
        self.centroids_norm = nn.Parameter((centroids_arr ** 2).sum(dim=1, keepdim=True).T)
        self.centroids_norm.requires_grad_(False)
        # bias-free linear layer == dot product of a frame with every centroid
        self.centroids = nn.Linear(centroids_npy.shape[1], centroids_arr.shape[0], bias=False)
        self.centroids.weight = nn.Parameter(centroids_arr)
        self.centroids.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: torch.Tensor (batch x time x dim)
        latent to convert to centroid indices

        labels: torch.Tensor(batch x time)
        """
        x = self.transform(x)
        logits = self.centroids(x)  # (batch x time x nlist)
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * <x, c>.
        # ||x||^2 is identical for all centroids of a frame, so it can't change which
        # centroid is the closest. Leaving it out avoids adding a large constant that
        # only eats precision and may overflow to inf (especially in half precision).
        scores = self.centroids_norm - 2 * logits
        _, labels = torch.min(scores, dim=2)
        return labels


class HubertExtractor(torch.nn.Module):
    """
    Wrapper around hubert extractor, that combines
    hubert model and discretization using kmeans.

    Takes raw int16 audio (optionally zero-padded to a common length), applies
    padding-aware mean/variance normalization, runs the hubert model and maps
    its hidden states to centroid ids.
    """
    def __init__(self, hf_repo_id, index_path, full_precision: bool = True):
        """
        Parameters
        ----------
        hf_repo_id: str
            identifier passed to `HubertModel.from_pretrained`
        index_path: str
            path to the faiss index file, passed to `TorchFaiss`
        full_precision: bool
            selects the dtype the audio is converted to before normalization:
            float32 when True, float16 otherwise. Moving the module itself to
            half precision is up to the caller.
        """
        super().__init__()
        self._float_type = torch.float32 if full_precision else torch.float16
        self._model = HubertModel.from_pretrained(hf_repo_id)
        self._faiss = TorchFaiss(index_path)

    def _compute_stats(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes mean and std taking into account padding in the sequences.
        Padded regions are expected to hold zeros.

        Returns
        -------
        mean: torch.Tensor
            (batch,) corrected mean, per sequence in batch as if sequences were not padded
        std: torch.Tensor
            (batch,) corrected standard deviation as if sequences were not padded
        """
        total_len = torch.ones_like(audio_len) * audio.size(1)
        pad_len = total_len - audio_len
        # Both ratios are formed from the integer lengths first and only then cast to
        # the working dtype. Raw sample counts must not be mixed with float16 tensors:
        # anything above 65504 is not representable in half precision and becomes inf.
        correction = (total_len / audio_len).type(self._float_type)
        pad_fraction = (pad_len / total_len).type(self._float_type)

        mean = torch.mean(audio, dim=1)
        # apply correction, replacing denominator
        mean = mean * correction

        # this variance formula corresponds to torch.var(unbiased=False) or ddof=0 in numpy
        var = torch.mean(torch.square(audio - mean.unsqueeze(1)), dim=1)
        # subtract extra values that are added in padding regions:
        # every padded (zero) sample contributes mean^2 to the sum of squares
        var = var - pad_fraction * torch.square(mean)
        # apply correction, replacing denominator
        var = var * correction
        # rounding may push a tiny variance below zero, which would make sqrt return NaN
        var = torch.clamp(var, min=0.0)

        return mean, torch.sqrt(var + 1e-7)

    def _create_mask(self, audio: torch.Tensor, audio_len: torch.Tensor) -> torch.Tensor:
        """
        Creates binary mask with "0" for padded regions.
        Used to zerofy padding after normalization

        Returns
        -------
        mask: torch.Tensor
            (batch, samples_num) - binary mask with "0" for padded regions. can be multiplied
            with audio to zerofy padded regions
        """
        max_len = audio.size(1)
        # sample positions 0..max_len-1, created on the device of `audio_len`
        ranged = torch.arange(max_len, device=audio_len.device, dtype=audio_len.dtype)
        batch_ranged = ranged.expand(audio_len.size(0), max_len)
        # a position is valid when it lies before the end of its sequence
        mask = batch_ranged < audio_len.unsqueeze(1)
        return mask

    def _mean_var_norm(self, audio: torch.Tensor, audio_len: torch.Tensor) -> torch.Tensor:
        """
        Implements traceable version of mean_var_norm from Wav2Vec2FeatureExtractor:
        https://github.com/huggingface/transformers/blob/31d452c68b34c2567b62924ee0df40a83cbc52d5/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L81
        """
        mean, std = self._compute_stats(audio, audio_len)
        audio = (audio - mean.unsqueeze(1)) / std.unsqueeze(1)
        # need to make padding zero after normalization
        mask = self._create_mask(audio, audio_len)
        audio = audio * mask
        return audio

    def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> torch.Tensor:
        """
        Extracts labels from raw audio. Runs inference with pre-trained hubert model,
        checks which centroid is closest for each frame's continuous representation
        and returns its id.

        Parameters
        ----------
        audio: torch.Tensor
            (batch_size x samples_num) - audio samples in int16 in range (-INT_MAX; INT_MAX)
        audio_len: torch.Tensor
            (batch_size,) - integer tensor that specifies actual length of audio within batch.
            Has to live on the same device as `audio`: the normalization statistics
            and the padding mask are derived from it and combined with `audio` directly.

        Returns
        -------
        labels: torch.Tensor
            (batch_size x frames_num) - pseudo annotation for audio on frame level.
        """
        # convert audio from short to float
        audio = audio.type(self._float_type) / 32768.0  # to range (-1, 1)

        # run mean/var normalization
        audio = self._mean_var_norm(audio, audio_len)

        # No attention mask is passed to the model on purpose, see the `attention_mask` note in
        # https://huggingface.co/docs/transformers/en/model_doc/hubert#transformers.HubertModel.forward
        # If ever needed, it can be built as
        # (torch.arange(audio.size(1), device=audio.device)[None, :] < audio_len[:, None]).long()

        # run inference with transformer
        emb = self._model(audio, output_hidden_states=True).hidden_states[9]

        # compute labels
        labels = self._faiss(emb)
        return labels

In [5]:
# tracing in half precision works, but model mysteriously fails on longer inputs
extractor = HubertExtractor(
    "utter-project/mHuBERT-147-base-2nd-iter",
    index_path,
    full_precision=False,
)
# move to the GPU and convert the weights to float16
extractor = extractor.cuda().half()

# keep the samples as raw int16: the extractor does the conversion to float itself
wav, sr = sf.read("rms_arctic_a0001.wav", dtype="int16")
assert sr == 16000

# single-utterance batch plus its length in samples
x = torch.tensor(wav)[None, :].cuda()
x_len = torch.tensor([x.size(1)]).cuda()
# labels of the only sequence in the batch, as a numpy array
trace_labels = extractor(x, x_len)[0].detach().cpu().numpy()

In [None]:
import matplotlib.pylab as plt

# both label sequences should have the same number of frames
print(orig_labels.shape)
print(trace_labels.shape)

# overlay the two sequences; the second one is semi-transparent so mismatches stand out
plt.plot(orig_labels, label="original")
plt.plot(trace_labels, alpha=0.5, label="traced")
plt.legend()
plt.grid()

In [None]:
# trace the extractor and save jit
import os

out_path = "./mhubert147_fp16_cuda.jit"
# the traced module is used for inference only: freeze all the weights ...
extractor.requires_grad_(False)
# ... and make sure no train-time behaviour ends up recorded in the trace
extractor.eval()
with torch.no_grad():
    # trace the model and save result;
    # torch.jit.trace documents `example_inputs` as a tuple of tensors
    extractor_traced = torch.jit.trace(extractor, (x, x_len))
    extractor_traced.save(out_path)

# Running batched inference

Check how much padding affects the extracted hubert labels.
Some discrepancy is expected, because the attention mask is not used.

In [None]:
# now check that batching works
dtype = "int16"
# load both recordings as int16 tensors; rms is shorter and will be padded
wav1 = torch.tensor(sf.read("slt_arctic_a0001.wav", dtype=dtype)[0])
wav2 = torch.tensor(sf.read("rms_arctic_a0001.wav", dtype=dtype)[0])

# true lengths have to be captured before any padding happens
x_len = torch.tensor([wav1.size(0), wav2.size(0)]).cuda()
max_len = max(wav1.size(0), wav2.size(0))
print(f"padding {wav1.shape[0]} and {wav2.shape[0]} to {max_len}")

# zero-pad on the right up to the common length
wav1 = torch.nn.functional.pad(wav1, (0, max_len - wav1.size(0)))
wav2 = torch.nn.functional.pad(wav2, (0, max_len - wav2.size(0)))

# (2 x max_len) batch on the GPU
x = torch.stack([wav1, wav2]).cuda()

In [None]:
import matplotlib.pylab as plt

# reload the saved trace from disk to check the serialized artifact itself
traced_extractor = torch.jit.load(out_path).to("cuda")
# second item of the batch is the padded rms recording, the one `orig_labels` was computed for
batched_labels = traced_extractor(x, x_len)[1].detach().cpu().numpy()

print(orig_labels.shape)
print(batched_labels.shape)

# overlay reference labels with the ones obtained from the padded batch
plt.plot(orig_labels, label="original")
plt.plot(batched_labels, alpha=0.5, label="batch")
plt.legend()
plt.grid()

# Upload traced model

Upload the traced model to the HF repo

In [None]:
import os
from huggingface_hub import upload_file

# Push the traced model saved at `out_path` to the model repo on the Hugging Face Hub.
# The file keeps its local base name inside the repo.
# NOTE(review): presumably requires being logged in with write access to the repo - confirm.
upload_file(
    path_or_fileobj=out_path,
    path_in_repo=os.path.basename(out_path),
    repo_id="balacoon/mhubert-147",
    repo_type="model"
)