In [None]:
import torch
import torchaudio

# Fix torch's global RNG seed so any random operations in this notebook are repeatable.
torch.random.manual_seed(0)

# Prefer a CUDA GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

In [None]:
import IPython

# Clip to transcribe. The name follows Common Voice's `common_voice_<lang>_<id>` pattern,
# so this appears to be a Polish ("pl") recording.
# NOTE(review): WAV2VEC2_ASR_BASE_960H is an English (LibriSpeech 960 h) ASR model with an
# English character label set, so a Polish clip will only decode to an approximate,
# English-letter transcript. Confirm this pairing is intended.
SPEECH_FILE = "corpus/clips/common_voice_pl_20547774.mp3"

# The decoding step needs an ASR bundle: one whose model outputs per-label scores and that
# provides `get_labels()`. Pretraining-only bundles such as HUBERT_BASE have neither, so
# they cannot simply be swapped in here.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)

In [None]:
# Decode the clip; torchaudio.load returns the samples together with the file's rate.
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)

# The model expects audio at `bundle.sample_rate`. Keep the tensor as-is when the rates
# already agree; otherwise resample it (on `device`, since the move happened above).
waveform = (
    waveform
    if sample_rate == bundle.sample_rate
    else torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
)

In [None]:
# Intermediate model outputs for inspection (presumably a list of per-layer tensors —
# see torchaudio's Wav2Vec2Model.extract_features). Index the returned tuple instead of
# unpacking into `_`: assigning `_` in IPython shadows the output-history variable and
# stops the kernel from updating `_` / `__` / `___`.
# NOTE(review): no later cell shown here reads `features`; drop this cell if it is not
# needed, since it is an extra model call on the full clip.
with torch.inference_mode():
    features = model.extract_features(waveform)[0]

In [None]:
# Per-frame label scores from the ASR model; presumably shaped (batch, frames, labels),
# given that the decoder is fed `emission[0]` as `[num_seq, num_label]`. Index the
# returned tuple rather than unpacking into `_`, which would shadow IPython's
# output-history variable.
with torch.inference_mode():
    emission = model(waveform)[0]

In [None]:
class GreedyCTCDecoder(torch.nn.Module):
    """Best-path (greedy) CTC decoder: per-frame label scores -> transcript string.

    Args:
      labels (Sequence[str]): Symbol for each class index of the emission.
      blank (int): Class index of the CTC blank symbol. Defaults to 0.
    """

    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string.

        Picks the highest-scoring label per frame, merges consecutive repeats, and
        then drops blanks. Because blanks are removed after merging, a repeated symbol
        separated by a blank is kept twice.

        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          str: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        # Convert to Python ints in one transfer. Iterating the tensor directly would
        # force a device sync per element (for the `!=` test and the label lookup)
        # when the emission lives on a GPU.
        return "".join(self.labels[i] for i in indices.tolist() if i != self.blank)

In [None]:
# Use the label set of the same bundle that produced `model`, so the decoder's class
# indices line up with the model's output columns.
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
# emission[0] is the first entry along the batch axis. With torchaudio.load's
# (channels, frames) layout this is presumably the first audio channel —
# TODO confirm the intended handling of multi-channel clips.
transcript = decoder(emission[0])

In [None]:
# Print the decoded text, then end the cell with an audio object so the transcript can be
# checked by ear: Jupyter renders the trailing IPython.display.Audio expression inline.
print(transcript)
IPython.display.Audio(SPEECH_FILE)