<a href="https://colab.research.google.com/github/harveenchadha/bol/blob/main/demos/hf/hindi/hf_quantization_him_4200.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchaudio transformers

Collecting torchaudio
  Downloading torchaudio-0.9.0-cp37-cp37m-manylinux1_x86_64.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 4.1 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.9.1-py3-none-any.whl (2.6 MB)
[K     |████████████████████████████████| 2.6 MB 20.9 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 57.4 MB/s 
Collecting huggingface-hub==0.0.12
  Downloading huggingface_hub-0.0.12-py3-none-any.whl (37 kB)
Collecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 45.6 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 48.9 MB/s 
Installing collected packages: tokenizers, sacremoses, pyyaml, 

In [2]:
import torch
from torch import Tensor
from torch.utils.mobile_optimizer import optimize_for_mobile
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
from transformers import Wav2Vec2ForCTC

# The Wav2Vec2 model emits a sequence of logits (unnormalised scores) over the characters.
# The following class adds the decoding step that turns them into a transcript (best path).
class SpeechRecognizer(torch.nn.Module):
    """Wrap an acoustic model with greedy (best-path) transcript decoding.

    The wrapped model is called as ``model(waveforms)`` and must return a tuple
    whose first element holds logits shaped ``[batch, num_seq, num_label]``.

    Args:
        model: Module producing ``(logits, _)`` for a waveform tensor.
        labels: Optional sequence of output tokens in which position ``i`` is
            the token for label id ``i``. When omitted, the built-in Hindi
            vocabulary below is used. The decoder treats ``"<s>"`` as the
            separator/blank token and ``"|"`` as the word delimiter.

    Raises:
        ValueError: If ``labels`` is given but empty.
    """

    def __init__(self, model, labels=None):
        super().__init__()
        self.model = model
        if labels is None:
            vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "ँ": 5, "ं": 6, "ः": 7, "अ": 8, "आ": 9, "इ": 10, "ई": 11, "उ": 12, "ऊ": 13, "ऋ": 14, "ए": 15, "ऐ": 16, "ऑ": 17, "ओ": 18, "औ": 19, "क": 20, "ख": 21, "ग": 22, "घ": 23, "ङ": 24, "च": 25, "छ": 26, "ज": 27, "झ": 28, "ञ": 29, "ट": 30, "ठ": 31, "ड": 32, "ढ": 33, "ण": 34, "त": 35, "थ": 36, "द": 37, "ध": 38, "न": 39, "प": 40, "फ": 41, "ब": 42, "भ": 43, "म": 44, "य": 45, "र": 46, "ल": 47, "व": 48, "श": 49, "ष": 50, "स": 51, "ह": 52, "़": 53, "ा": 54, "ि": 55, "ी": 56, "ु": 57, "ू": 58, "ृ": 59, "ॅ": 60, "े": 61, "ै": 62, "ॉ": 63, "ो": 64, "ौ": 65, "्": 66}
            # Order tokens by their id so that labels[i] is always the token
            # with id i, independent of the order the dict was written in.
            labels = sorted(vocab, key=vocab.get)

        # Kept as a plain list of strings so that `forward` can index it by label id.
        self.labels = list(labels)
        if not self.labels:
            raise ValueError("labels must contain at least one token")

    def forward(self, waveforms: Tensor) -> str:
        """Given a single channel speech data, return transcription.

        Only the first item of the batch is decoded.

        Args:
            waveforms (Tensor): Speech tensor. Shape `[1, num_frames]`.
        Returns:
            str: The resulting transcript, with `|` replaced by a space.
        """
        logits, _ = self.model(waveforms)  # [batch, num_seq, num_label]
        best_path = torch.argmax(logits[0], dim=-1)  # [num_seq,]
        prev = ''
        hypothesis = ''
        for i in best_path:
            char = self.labels[i]
            # Consecutive identical labels collapse into a single character.
            if char == prev:
                continue
            # '<s>' is never emitted; it only resets `prev`, so the same
            # character may legitimately appear again right after it.
            if char == '<s>':
                prev = ''
                continue
            hypothesis += char
            prev = char
        return hypothesis.replace('|', ' ')


# Fetch the pretrained Wav2Vec2 checkpoint from the Hugging Face Hub and convert
# it directly into torchaudio's model format, which supports TorchScript.
model = import_huggingface_model(
    Wav2Vec2ForCTC.from_pretrained("Harveenchadha/vakyansh-wav2vec2-hindi-him-4200")
)
# Quantization does not support weight normalization, so strip it from the
# positional convolution embedding before going any further.
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
# Put the network in eval mode and wrap it with the transcript decoder.
model = SpeechRecognizer(model.eval())

# Export pipeline: dynamic int8 quantization of the Linear layers, then
# TorchScript compilation, then the mobile optimization pass.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)
scripted_model = torch.jit.script(quantized_model)
optimized_model = optimize_for_mobile(scripted_model)

# Optional sanity check (requires a local audio file):
# waveform , _ = torchaudio.load('scent_of_a_woman_future.wav')
# print('Result:', optimized_model(waveform))

# Write the optimized TorchScript module to disk.
optimized_model.save("wav2vec2.pt")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1658.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=377774679.0, style=ProgressStyle(descri…




Removing weight_norm from ConvolutionalPositionalEmbedding


In [3]:
# Download the exported model file ('wav2vec2.pt', saved by the previous cell)
# from the notebook runtime. This relies on the Colab-specific `google.colab` package.
from google.colab import files
files.download('wav2vec2.pt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>