In [None]:
# discrete unit model
!wget https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km200/km.bin
# tts model
!wget https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/tts_checkpoint_best.pt
# waveglow
!wget https://dl.fbaipublicfiles.com/textless_nlp/gslm/waveglow_256channels_new.pt
# download dummpy speech
!wget https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav

In [None]:
# Install the model/tokenizer library and the unit-to-speech package.
# NOTE(review): versions are unpinned; pin them for reproducible runs.
%pip install transformers asrp

In [None]:
from itertools import groupby

import joblib
import torch
import torchaudio
import numpy
from transformers import (
    Wav2Vec2FeatureExtractor, 
    HubertModel,
)


class HubertCode(object):
    """
    HuBERT unit extraction tool.

    Directly use an instance as a function to get units: hidden states from
    one HuBERT layer are matched against k-means cluster centres, per-frame
    candidates are re-ranked with a beam search, and (by default) consecutive
    repeated units are collapsed.
    """

    def __init__(self,
        hubert_model,
        km_path,
        km_layer,
        return_diff=False,
        sampling_rate=16000,
    ):
        """
        Initialize an object for HuBERT unit extraction.

        * hubert_model: name or path given to ``from_pretrained`` for both the
          feature extractor and the HuBERT model
        * km_path: joblib file holding a k-means model with ``cluster_centers_``
        * km_layer: index into the model's ``hidden_states`` that is quantised
        * return_diff: if True, calls also return each frame's features minus
          its nearest cluster centre
        * sampling_rate: rate audio is resampled to before feature extraction
        """
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model)
        self.model = HubertModel.from_pretrained(hubert_model)
        self.model.eval()
        self.sampling_rate = sampling_rate
        # NOTE(review): joblib.load unpickles the file and can execute
        # arbitrary code; only load k-means models from trusted sources.
        self.km_model = joblib.load(km_path)
        self.km_layer = km_layer
        self.return_diff = return_diff
        # Centres are transposed so torch.matmul(x, C) yields frame-to-centre
        # dot products; Cnorm holds the squared norm of every centre.
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()
            self.model = self.model.cuda()

    def __call__(self, filepath, merge=True, top_k=6, beam_size=200):
        """
        Unit extraction.

        * filepath: audio file readable by ``torchaudio.load``
        * merge: collapse consecutive repeated units (default True); with
          False one unit per feature frame is returned
        * top_k: nearest cluster centres kept as candidates per frame
        * beam_size: hypotheses kept by the beam search after every frame

        Returns the list of unit ids, or ``(units, diff)`` when
        ``return_diff`` is set, where ``diff`` is each frame's features
        minus its nearest cluster centre.
        """
        with torch.no_grad():
            speech = self._load_speech(filepath)
            input_values = self.processor(
                speech,
                return_tensors="pt",
                sampling_rate=self.sampling_rate
            ).input_values
            if torch.cuda.is_available():
                input_values = input_values.cuda()
            hidden_states = self.model(
                input_values,
                output_hidden_states=True
            ).hidden_states
            # Drop only the batch axis so a single-frame clip stays 2-D.
            x = hidden_states[self.km_layer].squeeze(0)
            dist = self._distances(x)
            # torch.topk raises if k exceeds the number of centres.
            k = min(top_k, dist.shape[-1])
            min_dist = torch.topk(dist, k, dim=-1, largest=False)
            pred_ind_array = min_dist.indices.cpu().numpy()
            pred_values_array = min_dist.values.cpu().numpy()
            # The merged greedy path (nearest centre per frame) only supplies
            # the length normaliser for the beam-search score.
            greedy_len = len([u for u, _ in groupby(pred_ind_array[:, 0])])

            # top beamsearch result
            best = self._beam_search(
                pred_ind_array, pred_values_array, greedy_len, beam_size)
            unitcode = [u for u, _ in groupby(best)] if merge else list(best)
            if self.return_diff:
                # index_select needs a 1-D index: use each frame's nearest centre.
                nearest = min_dist.indices[:, 0].cpu()
                centres = self.C.t().cpu()
                return (
                    unitcode,
                    x.cpu() - torch.index_select(centres, 0, nearest))
            return unitcode

    def _load_speech(self, filepath):
        """
        Read an audio file and return a 1-D numpy waveform at
        ``self.sampling_rate``; multi-channel audio is averaged to mono.
        """
        speech, sr = torchaudio.load(filepath)
        # Averaging the channel axis equals squeeze(0) for mono files and
        # keeps stereo files from reaching the extractor as a batch of two.
        speech = speech.mean(0)
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=self.sampling_rate)
            speech = resampler(speech)
        return speech.numpy()

    def _distances(self, x):
        """
        Euclidean distance from every row of ``x`` to every cluster centre.

        * x: 2-D (frames, dim) feature tensor
        Returns a (frames, n_clusters) tensor.
        """
        # Match the centre dtype to the features so matmul cannot fail on a
        # float32/float64 mismatch.
        centres = self.C.to(x.dtype)
        centre_norms = self.Cnorm.to(x.dtype)
        squared = (
            x.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(x, centres)
            + centre_norms
        )
        # ||x||^2 - 2x.c + ||c||^2 can dip slightly below zero through
        # floating-point cancellation; clamp so sqrt never returns NaN.
        return torch.sqrt(squared.clamp_min(0))

    @staticmethod
    def _beam_search(pred_ind_array, pred_values_array, greedy_len,
                     beam_size=200):
        """
        Re-rank per-frame candidate units with a length-aware beam search.

        * pred_ind_array: per-frame rows of candidate unit ids
        * pred_values_array: matching rows of distances to those candidates
        * greedy_len: merged length of the greedy path; normalises the score
        * beam_size: hypotheses kept after every frame

        Each step adds (merged_len / greedy_len) * (dist / row_sum) to a
        hypothesis' score; lower is better. Returns the best (unmerged)
        token list.
        """
        # Hypothesis = (tokens, score, merged_len). merged_len is the number
        # of runs in tokens, tracked incrementally instead of re-running
        # groupby over the whole sequence for every candidate.
        beams = [([], 1.0, 0)]
        for i_row, v_row in zip(pred_ind_array, pred_values_array):
            row_total = numpy.sum(v_row)
            candidates = []
            for tokens, score, merged_len in beams:
                for unit, dist_value in zip(i_row, v_row):
                    extends_run = bool(tokens) and tokens[-1] == unit
                    new_len = merged_len if extends_run else merged_len + 1
                    norm_len_rate = new_len / greedy_len
                    norm_dist_rate = dist_value / row_total
                    candidates.append((
                        tokens + [unit],
                        score + norm_len_rate * norm_dist_rate,
                        new_len,
                    ))
            candidates.sort(key=lambda cand: cand[1])
            beams = candidates[:beam_size]
        return beams[0][0]

In [None]:
# Build the unit extractor: HuBERT checkpoint features from hidden-state
# index 6, quantised with the k-means model downloaded above.
hc = HubertCode(
    hubert_model="facebook/hubert-base-ls960",
    km_path='./km.bin',
    km_layer=6,
)

In [None]:
# Extract the unit sequence for the sample utterance downloaded above.
code = hc(filepath='LJ037-0171.wav')

In [None]:
# Number of units extracted for the sample utterance.
len(code)

In [None]:
import asrp
import IPython.display as ipd

# Unit-to-speech: unit ids -> TTS checkpoint -> WaveGlow vocoder.
# `waveglow_checkpint` is kept exactly as written; presumably it is the
# keyword asrp expects - verify before "correcting" the spelling.
# NOTE(review): end_tok=201 and code_begin_pad=1 look tied to the unit
# vocabulary of the TTS checkpoint - confirm against the asrp docs.
cs = asrp.Code2Speech(
    tts_checkpoint='./tts_checkpoint_best.pt',
    waveglow_checkpint='waveglow_256channels_new.pt',
    end_tok=201,
    code_begin_pad=1,
)

# Synthesize the waveform and play it back in the notebook.
waveform = cs(code)
ipd.Audio(data=waveform, autoplay=False, rate=cs.sample_rate)