# Setup

## Install dependencies

In [None]:
!pip install pympi-ling gradio

## Model definition

In [9]:
import torch
import torchaudio

class mHuBERTFinetuneModel(torch.nn.Module):
    """HuBERT-base encoder with a linear output head and greedy CTC decoding.

    ``vocab`` maps output indices to strings. Its ``"<pad>"`` entry is treated
    as the CTC blank and is dropped from decoded text; ``vocab.index`` raises
    ``ValueError`` if ``"<pad>"`` is missing. ``forward`` returns decoded
    strings rather than logits, so this module is meant for inference.
    """

    def __init__(self, vocab):
        super().__init__()
        self.vocab = vocab
        self.pad_token_index = self.vocab.index("<pad>")
        self.transformer = torchaudio.models.hubert_base()
        # 768 must match the feature width produced by self.transformer
        self.lm_head = torch.nn.Linear(768, len(vocab), bias=False)

    @classmethod
    def from_finetuned(cls, checkpoint_path):
        """Build a model from a checkpoint dict of the form
        ``{"metadata": {"vocab": [...]}, "weights": state_dict}``.

        The model is loaded on the CPU and returned in PyTorch's default
        (training) mode; call ``.eval()`` before running inference.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        vocab = checkpoint["metadata"]["vocab"]

        # Use cls rather than the hard-coded class name so subclasses build themselves
        model = cls(vocab)
        model.load_state_dict(checkpoint["weights"])

        return model

    def forward(self, audio_padded, audio_lengths=None):
        """Transcribe a batch of waveforms.

        Args:
            audio_padded: waveform batch passed straight to the HuBERT encoder
                (presumably (batch, samples) at 16 kHz -- TODO confirm).
            audio_lengths: optional per-item count of valid samples. When
                given, encoder frames past each item's end are not decoded.

        Returns:
            list[str]: one greedy-CTC transcription per batch item.
        """
        hidden_feats, hidden_lengths = self.transformer(audio_padded, audio_lengths)
        logits = self.lm_head(hidden_feats)
        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
        return self.decode(logprobs, hidden_lengths)

    def decode(self, logprobs, lengths=None):
        """Greedy CTC decoding: take the per-frame argmax, merge repeats, drop blanks.

        Args:
            logprobs: per-frame scores, assumed (batch, frames, vocab); only
                the argmax over the last dimension is used.
            lengths: optional per-item number of valid frames. Without it,
                padded frames at the end of shorter items are decoded too.

        Returns:
            list[str]: one decoded string per batch item.
        """
        indices = torch.argmax(logprobs, dim=-1)

        predictions = []
        for row, frame_ids in enumerate(indices):
            if lengths is not None:
                frame_ids = frame_ids[: int(lengths[row])]
            # Merge runs of the same unit *before* removing blanks, so a blank
            # between two identical units keeps them as separate symbols.
            collapsed = torch.unique_consecutive(frame_ids).tolist()
            predictions.append(
                "".join(self.vocab[i] for i in collapsed if i != self.pad_token_index)
            )

        return predictions

## Download ASR model

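In practice, this step downloads the fine-tuned checkpoint, e.g. with `gdown` for a link-shared file on Google Drive or `rclone` for another shared drive. Here, a randomly initialized model is saved instead to mock that download.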

In [10]:
# Use a randomly initialised model to mock a checkpoint downloaded from elsewhere
import os
import string

# Pure-Python equivalent of `mkdir -p`; does not depend on a Unix shell
os.makedirs("/content", exist_ok=True)

# Seed so the mock weights (and hence the mock transcripts) are reproducible
torch.manual_seed(0)

mock_vocab = ['<pad>'] + list(string.ascii_lowercase)
mock_model = mHuBERTFinetuneModel(vocab=mock_vocab)
torch.save(
    {"metadata": {"vocab": mock_vocab}, "weights": mock_model.state_dict()},
    "/content/mock_checkpoint.pt",
)

## Load VAD and ASR models

In [11]:
# A freshly loaded nn.Module is in training mode, where HuBERT's dropout makes
# transcriptions random; .eval() makes inference deterministic. Place the model
# on the GPU when one is available so inference does not run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = mHuBERTFinetuneModel.from_finetuned("/content/mock_checkpoint.pt").to(device).eval()

In [12]:
# Fetch Silero-VAD through torch.hub (reusing the local cache when present).
# Only get_speech_timestamps is used below; the other helpers are unpacked for convenience.
vad_model, vad_utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad",
    model="silero_vad",
    force_reload=False,
    trust_repo=True,
)
get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks = vad_utils

Using cache found in /root/.cache/torch/hub/snakers4_silero-vad_master


In [13]:
import sys

from pympi import Elan
from pathlib import Path
from tqdm import tqdm

def print_silero_progress(silero_pc):
    """Draw Silero-VAD progress as a single, self-overwriting 100-column bar.

    Passed to ``get_speech_timestamps`` as ``progress_tracking_callback``;
    ``silero_pc`` is the percentage processed so far. A newline is emitted
    when the value equals 100.0, so later output starts on a fresh line.
    """
    filled = "=" * round(silero_pc)
    stream = sys.stdout
    # '\r' moves the cursor back to column 0 so each call redraws the same line
    stream.write("\r")
    stream.write("[%-100s] %d%%" % (filled, silero_pc))
    stream.flush()

    finished = silero_pc == 100.0
    if finished:
        print("\n", end="")

def transcribe(audio_filepath, one_tier_per_channel):
    """Run VAD + ASR on an audio file and save the result as an ELAN .eaf file.

    Each speech region detected by Silero-VAD becomes one annotation on a tier
    named "Channel <n>". Relies on the notebook-level ``model``, ``vad_model``,
    ``get_speech_timestamps`` and ``print_silero_progress``.

    Args:
        audio_filepath: path to an audio file readable by ``torchaudio.load``.
        one_tier_per_channel: if True, every channel is transcribed onto its own
            tier; if False, multichannel audio is averaged to mono (one tier).

    Returns:
        str: path of the written .eaf file (same name as the audio, next to it).
    """
    import mimetypes

    sample_rate = 16_000  # rate expected by both Silero-VAD and HuBERT
    samples_per_ms = sample_rate // 1000

    audio_filepath = Path(audio_filepath)
    waveform, sr = torchaudio.load(audio_filepath)

    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)

    # If not mono and not set to one tier per channel then convert to mono
    if waveform.size(0) > 1 and not one_tier_per_channel:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Send segments to whatever device the model lives on, rather than
    # assuming CUDA (a CPU-resident model would fail with a device mismatch).
    device = next(model.parameters()).device

    eaf_data = Elan.Eaf()
    try:
        eaf_data.add_linked_file(audio_filepath.name)
    except KeyError:
        # pympi guesses the MIME type from a small table of extensions and
        # raises KeyError for others (e.g. .mp3, upper-case .WAV).
        mimetype = mimetypes.guess_type(audio_filepath.name)[0] or "application/octet-stream"
        eaf_data.add_linked_file(audio_filepath.name, mimetype=mimetype)
    # Remove 'default' tier from newly created eaf object
    eaf_data.remove_tier('default')

    for channel in range(waveform.size(0)):
        tier_id = f"Channel {channel}"
        eaf_data.add_tier(tier_id)

        print(f"Detecting time regions with speech on channel {channel}")

        channel_waveform = waveform[channel, :].unsqueeze(0)

        speech_timestamps = get_speech_timestamps(
            channel_waveform,
            vad_model,
            threshold=0.75,
            sampling_rate=sample_rate,
            progress_tracking_callback=print_silero_progress,
        )

        print(f"Transcribing speech in detected regions on channel {channel}")

        for segment_bounds in tqdm(speech_timestamps):
            start, end = segment_bounds['start'], segment_bounds['end']
            segment_samples = channel_waveform[:, start:end].to(device)

            with torch.inference_mode():
                # Zero-mean / unit-variance normalisation over the whole segment
                audio_normed = torch.nn.functional.layer_norm(segment_samples, segment_samples.shape)
                text = model(audio_normed)[0].strip()

            # VAD bounds are sample indices; ELAN annotations take integer milliseconds
            eaf_data.add_annotation(
                tier_id,
                start=round(start / samples_per_ms),
                end=round(end / samples_per_ms),
                value=text,
            )

    eaf_file = audio_filepath.with_suffix(".eaf")
    eaf_data.to_file(eaf_file)

    return str(eaf_file)

## Create Gradio interface

In [None]:
import gradio as gr

# The checkbox maps to transcribe()'s `one_tier_per_channel` argument.
asr = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(type="filepath", label="Audio file"),
        gr.Checkbox(
            label="One tier per audio channel",
            info="If unchecked, multichannel audio is mixed down to mono and transcribed onto a single tier.",
            value=True,
        ),
    ],
    outputs=gr.File(label="ELAN eaf file"),
    title="Transcribe Audio to ELAN eaf file",
)

# Run app

In [None]:
# debug=True blocks this cell while the app runs so that errors raised inside
# the app are printed in the cell output; interrupt the cell to stop the app.
asr.launch(debug=True)