In [1]:
import os
import time
from jiwer import wer
from tqdm import tqdm

import torch
import torchaudio

import fairseq_mod

from utils import Wav2VecCtc, W2lViterbiDecoder, postprocess_features, post_process_sentence

In this notebook, we will show how to quantize a wav2vec 2.0 model. We will use the dev-clean dataset from LibriSpeech which can be downloaded [here](https://www.openslr.org/12). We will quantize the wav2vec 2.0 large model (wav2vec_big_960h) and it can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt). Without further ado let's get started!


### Step 1: Specify paths to wav2vec 2.0 model and dataset. Create the letter dictionary.

In [2]:
model_path = "/home/models/wav2vec2/wav2vec_big_960h.pt"
data_path = "/home/datasets"
target_dict = fairseq_mod.data.Dictionary.load('ltr_dict.txt')

### Step 2: Initialize wav2vec 2.0 model

In [13]:
w2v = torch.load(model_path)
model = Wav2VecCtc.build_model(w2v["args"], target_dict)
model.load_state_dict(w2v["model"], strict=True)
model.eval()

Wav2VecCtc(
  (w2v_encoder): Wav2VecEncoder(
    (w2v_model): Wav2Vec2Model(
      (feature_extractor): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
            (3): GELU()
          )
          (1): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (2): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (3): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU()
          )
          (4): Sequen

### Step 3: Define a helper method which calculates the model size

In [11]:
def get_model_size(model):
    torch.save(model.state_dict(), 'temp_saved_model.pt')
    model_size_in_mb = os.path.getsize('temp_saved_model.pt') >> 20
    os.remove('temp_saved_model.pt')
    return model_size_in_mb

### Step 4: Get the model size of the original wav2vec 2.0 model

In [14]:
print("original model size is {:.2f} MB".format(get_model_size(model)))

original model size is 1203.00 MB


### Step 5: Quantize the original wav2vec 2.0 model

In [4]:
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
quantized_model.prepare_for_inference_after_quantization()

In [6]:
print("quantized model size is {:.2f} MB".format(get_model_size(quantized_model)))

quantized model size is 338.00 MB


You can see that the quantized model has a much smaller size!

### Step 6: Create decoder

In [7]:
decoder = W2lViterbiDecoder(target_dict)

### Step 7: Create data loader

In [8]:
dev_clean_librispeech_data = torchaudio.datasets.LIBRISPEECH(data_path, url='dev-clean', download=False)
data_loader = torch.utils.data.DataLoader(dev_clean_librispeech_data, batch_size=1, shuffle=False)

### Step 8: Define a helper method which converts one audio sample into text

In [9]:
def process_data_sample(data_sample, model, decoder, target_dict):
    encoder_input = dict()
    feature = postprocess_features(data_sample[0][0][0], data_sample[1]).unsqueeze(0)
    padding_mask = torch.BoolTensor(feature.size(1)).fill_(False).unsqueeze(0)
    
    encoder_input["source"] = feature
    encoder_input["padding_mask"] = padding_mask
    encoder_input["features_only"] = True
    encoder_input["mask"] = False
    
    encoder_out = model(**encoder_input)
    emissions = model.get_normalized_probs(encoder_out, log_probs=True)
    emissions = emissions.transpose(0, 1).float().cpu().contiguous()
    
    decoder_out = decoder.decode(emissions)
    hyp_pieces = target_dict.string(decoder_out[0][0]["tokens"].int().cpu())
    prediction = post_process_sentence(hyp_pieces, 'letter')
    
    return prediction

### Step 9: Calculate the WER of the quantized model on the entire dataset

In [10]:
predictions, ground_truths, start_time = [], [], time.time()
for i, data_sample in enumerate(tqdm(data_loader)):
    prediction = process_data_sample(data_sample, quantized_model, decoder, target_dict)
    predictions.append(prediction)
    ground_truths.append(data_sample[2][0])
inference_time = time.time() - start_time
wer_score = wer(ground_truths, predictions)
print("WER is {:.2f}%. Inference took {} seconds".format(wer_score*100, int(inference_time)))

100%|██████████| 2703/2703 [16:03<00:00,  3.26it/s]


WER is 2.75%. Inference took 963 seconds


Note that the quantized model is in CPU, so inference is slow. We used a 48 CPU core machine for this experiment.