In [20]:
import numpy as np
import pandas as pd

# import librosa
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# from datasets import load_dataset, load_metric

import IPython.display as ipd

torch.random.manual_seed(0)
print(torch.__version__)
print(torchaudio.__version__)

1.13.0.dev20220725
0.13.0.dev20220725


In [176]:
model_name = "facebook/wav2vec2-base-960h"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(model_name, device)

processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)

facebook/wav2vec2-base-960h cpu


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [156]:
cv_dataset = torchaudio.datasets.COMMONVOICE(
    root="_cv_corpus/en",
    tsv="test.tsv",
)

cv_dataset[0]

(tensor([[ 0.0000e+00, -3.8697e-12, -9.5148e-12,  ...,  1.1894e-06,
           1.4817e-06,  1.8046e-06]]),
 32000,
 {'client_id': '000abb3006b78ea4c1144e55d9d158f05a9db0110160510fef2b006f2c2c8e35f7bb538b04542511834b61503cdda5b0331566a5cf59dc0d375a44afc4d10777',
  'path': 'common_voice_en_27710027.mp3',
  'sentence': 'Joe Keaton disapproved of films, and Buster also had reservations about the medium.',
  'up_votes': '3',
  'down_votes': '1',
  'age': '',
  'gender': '',
  'accents': '',
  'locale': 'en',
  'segment': ''})

In [157]:
from tqdm.notebook import tqdm
def resample(ds, new_sr):
    len_longest_sentence = 0
    data = []
    for wav, sr, metadata in tqdm(ds):
        waveform = torchaudio.functional.resample(wav, sr, new_sr)
        data.append({
            "speech": waveform,
            "sample_rate": new_sr,
            "sentence": metadata["sentence"]
        })
        if len(metadata["sentence"]) > len_longest_sentence:
            len_longest_sentence = len(metadata["sentence"])
    return data, len_longest_sentence

cv_dataset, len_longest_sentence = resample(cv_dataset, processor.feature_extractor.sampling_rate)
cv_dataset[0]
print(f"len longest sentence = {len_longest_sentence}")

  0%|          | 0/16345 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [20]:
def predict(sample):
    features = processor(
        sample["speech"][0], 
        sampling_rate=processor.feature_extractor.sampling_rate, 
        return_tensors="pt", 
        padding=True
    )

    input_values = features.input_values.to(device)

    with torch.no_grad():
        logits = model(input_values).logits 

    pred_ids = torch.argmax(logits, dim=-1)

    sample["predicted"] = processor.batch_decode(pred_ids)
    return sample

In [21]:
result = []
for x in tqdm(cv_dataset):
    result.append(predict(x))

  0%|          | 0/16345 [00:00<?, ?it/s]

In [22]:
print(result[0])

{'speech': tensor([[-7.3703e-13, -7.9181e-12, -7.6020e-12,  ...,  7.5047e-07,
          8.4476e-07,  1.6297e-06]]), 'sample_rate': 16000, 'sentence': 'Joe Keaton disapproved of films, and Buster also had reservations about the medium.', 'predicted': ['JO KEEPSAN DISAPPROVED OF THONES AND BUSTER ALSO HAD HESERVATIONS ABOUT THE MEDIUM']}


In [23]:
from numpy import average
import jiwer
from tqdm.notebook import tqdm

transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.ExpandCommonEnglishContractions(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
]) 

accum_wer = 0
hypothesies = []
truths = []

for sample in tqdm(result):
        ground_truth = sample["sentence"]
        truths.append(ground_truth)

        hypothesis = sample["predicted"][0]
        hypothesies.append(hypothesis)
        accum_wer += jiwer.wer(ground_truth, hypothesis, truth_transform=transformation, hypothesis_transform=transformation)

average_wer = accum_wer / len(result)
print(f"average_wer: {average_wer}")

total_wer = jiwer.wer(truths, hypothesies, truth_transform=transformation, hypothesis_transform=transformation)
print(f"total_wer: {total_wer}")

  0%|          | 0/16345 [00:00<?, ?it/s]

average_wer: 0.4095130565965886
total_wer: 0.4011559352830882


-----------
#### CTC Test

In [None]:
sample = cv_dataset[0]
features = processor(
    sample["speech"][0], 
    sampling_rate=processor.feature_extractor.sampling_rate, 
    return_tensors="pt", 
    padding=True
)

with torch.no_grad():
    logits = model(features.input_values).logits 

pred_ids = torch.argmax(logits, dim=-1)

sample["predicted"] = processor.batch_decode(pred_ids)

print(sample["predicted"])
print(pred_ids)
print(logits)

['JO KEEPSAN DISAPPROVED OF THONES AND BUSTER ALSO HAD HESERVATIONS ABOUT THE MEDIUM']
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0, 29,  0,  0,  0,  0,  8,  0,  0,  0,  0,  0,  0,  0,  4,  0, 26,
          0,  0,  5,  0,  0,  5,  0, 23,  0,  0, 12,  0,  0,  0,  7,  0,  9,  0,
          0,  4,  4, 14,  0,  0, 10,  0, 12,  0,  0,  0,  0,  7, 23,  0,  0,  0,
         23,  0, 13,  0,  0,  0,  8,  8,  0, 25,  0,  0,  5, 14,  0,  4,  4,  0,
          0,  8, 20,  0,  4,  4,  4,  0,  0,  0,  6, 11,  0,  0,  0,  0,  8,  0,
          0,  0,  9,  5,  0,  0, 12, 12,  4,  4,  0,  0,  0,  0,  7,  9,  0, 14,
          0,  4,  4,  4,  0, 24,  0,  0,  0, 16,  0,  0, 12, 12,  0,  6,  0,  0,
          0,  5, 13, 13,  0,  4,  4,  0,  0,  7, 15,  0,  0,  0,  0, 12,  0,  0,
          8,  0,  4,  4,  0, 11,  0,  0,  7,  0,  0, 14,  4,  4,  4, 11,  0,  0,
          5,  0,  0,  

In [111]:
import json
import re

re_chars_to_remove = re.compile(r"[^A-Z ']")

with open("vocab.json", "r") as fp:
    vocab_dict = json.load(fp)

def sentence_to_tensor(sentence, vocab, vocab_size, pad_len):
    sentence = sentence.upper()
    sentence = re_chars_to_remove.sub('', sentence).replace(' ', '|')
    t = torch.zeros([pad_len], dtype=torch.int)
    for i,x in enumerate(sentence):
        t[i] = vocab[x]
    return t, len(sentence)

In [84]:
pad_i = 0
target_len = 300
target_tensor_from_predicted, real_len = sentence_to_tensor(sample["predicted"][0], vocab_dict, len(vocab_dict), target_len)
print(target_tensor_from_predicted)
print(target_tensor_from_predicted.unsqueeze(0).shape)
print(logits.shape)

tensor([29,  8,  4, 26,  5,  5, 23, 12,  7,  9,  4, 14, 10, 12,  7, 23, 23, 13,
         8, 25,  5, 14,  4,  8, 20,  4,  6, 11,  8,  9,  5, 12,  4,  7,  9, 14,
         4, 24, 16, 12,  6,  5, 13,  4,  7, 15, 12,  8,  4, 11,  7, 14,  4, 11,
         5, 12,  5, 13, 25,  7,  6, 10,  8,  9, 12,  4,  7, 24,  8, 16,  6,  4,
         6, 11,  5,  4, 17,  5, 14, 10, 16, 17,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0, 

In [85]:
target_lengths = torch.tensor([real_len])
input_lengths = torch.tensor([logits.shape[1]])
print(target_lengths.shape)
print(input_lengths.shape)

torch.Size([1])
torch.Size([1])


In [86]:
ctcloss = torch.nn.CTCLoss(blank=pad_i)
loss = ctcloss(logits.squeeze().unsqueeze(1), target_tensor_from_predicted.unsqueeze(0), input_lengths, target_lengths)
print(loss)

tensor(-48.4535)


In [106]:
for i in range(3):
    target_tensor_from_predicted[i] = torch.randint(28, (1,))[0]

In [107]:
loss = ctcloss(logits.squeeze().unsqueeze(1), target_tensor_from_predicted.unsqueeze(0), input_lengths, target_lengths)
print(loss)

tensor(-48.2106)


------------

In [170]:
def transform_and_resample(dataset, new_sample_rate, vocab, batch_size=1):
    batch_wavs = []
    batch_inputs = []
    batch_sentences = []
    batch_targets = []
    target_lengths = []
    TEST_BATCH_CAP = 4
    count = 0
    for wav, sr, metadata in dataset:
        if(len(batch_wavs) == batch_size):
            yield {
                "wavs": torch.stack(batch_wavs),
                "inputs": torch.stack(batch_inputs),
                "sentences": batch_sentences,
                "targets": torch.stack(batch_targets),
                "target_lengths": torch.tensor(target_lengths),
            }
            count += 1
            batch_wavs = []
            batch_inputs = []
            batch_sentences = []
            batch_targets = []
            target_lengths = []
        if count == TEST_BATCH_CAP: break
        
        waveform = torchaudio.functional.resample(wav, sr, new_sample_rate)
        batch_wavs.append(waveform)
        features = processor(
            waveform[0], 
            sampling_rate=processor.feature_extractor.sampling_rate, 
            return_tensors="pt",
            padding=True
        )
        batch_inputs.append(features.input_values)
        batch_sentences.append(metadata["sentence"])
        target, length = sentence_to_tensor(metadata["sentence"], vocab, len(vocab), pad_len=300)
        batch_targets.append(target)
        target_lengths.append(length)

dataloaders = {}
dataset_sizes = {}
def reset_dataloaders():
    for phase in ['train', 'test']:
        dataset = torchaudio.datasets.COMMONVOICE(
            root="_cv_corpus/en",
            tsv=f"{phase}.tsv",
        )
        dataset_sizes[phase] = len(dataset)
        dataloaders[phase] = transform_and_resample(dataset, processor.feature_extractor.sampling_rate, vocab=vocab_dict, batch_size=1)


reset_dataloaders()
data = next(dataloaders['test'])
print(data["wavs"].shape)
print(data["inputs"].shape)
print(data["targets"].shape)
print(data["wavs"])
# reset_dataloaders()

torch.Size([1, 1, 100800])
torch.Size([1, 1, 100800])
torch.Size([1, 300])
tensor([[[-7.3703e-13, -7.9181e-12, -7.6020e-12,  ...,  7.5047e-07,
           8.4476e-07,  1.6297e-06]]])


In [177]:
import time
import copy

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_wer = 100.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        reset_dataloaders()

        # Each epoch has a training and validation phase
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_wer = 0.0

            # Iterate over data.
            for d in dataloaders[phase]:
                inputs = d["inputs"].to(device)
                targets = d["targets"].to(device)
                target_lengths = d["target_lengths"].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    logits = model(inputs[0]).logits
                    # print(f'logits = {logits.shape}')
                    # print(f'logits = {logits}')
                    pred_ids = torch.argmax(logits, dim=-1)

                    pred_sentence = processor.batch_decode(pred_ids)
                    print(f'sentence = {d["sentences"][0]}')
                    print(f'pred_sentence = {pred_sentence}')

                    wer = torchaudio.functional.edit_distance(d["sentences"][0], pred_sentence)
                    print(f'wer = {wer}')

                    input_lengths = torch.tensor([logits.shape[1]]).to(device)
                    # print(f'logits = {logits.shape}')
                    # print(f'pred_ids = {pred_ids.shape}')
                    # print(f'targets = {targets.shape}')
                    # print(f'target_lengths = {target_lengths.shape}')
                    loss = criterion(logits.squeeze().unsqueeze(1), targets, input_lengths, target_lengths)
                    print(f'loss = {loss}')

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_wer += wer
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_wer = running_wer / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} WER: {epoch_wer:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_wer < best_wer:
                best_wer = epoch_wer
                best_model_wts = copy.deepcopy(model.state_dict())


    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_wer:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model



In [178]:
# reset_dataloaders()
ctc_loss = torch.nn.CTCLoss()
optimizer_ft = torch.optim.SGD(model.parameters(), lr=0.00001)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model, ctc_loss, optimizer_ft, exp_lr_scheduler, num_epochs=25)

Epoch 0/24
----------
sentence = The Kerry cable stations are recognised as World Heritage Communications Sites.
pred_sentence = ['HE CARY CABLE SATION ARE RECOGNIED A EL HAR IS IATIONE']
wer = 79
loss = -38.75562286376953
sentence = So too was the sign at Bromley, Kent.
pred_sentence = ['E<unk> C </s>E </s> C </s> <unk> <unk> </s> </s> E C  <unk></s> E   <unk>  <unk>   E </s> <unk> C <unk> </s><unk>  </s> E  <unk>']
wer = 37
loss = -45.74892807006836
sentence = Briles had good reason to record Hey Hank.
pred_sentence = ['']
wer = 42
loss = nan
sentence = Tornjaks have a clear, self-confident, serious and calm disposition.
pred_sentence = ['']
wer = 68
loss = nan
train Loss: nan WER: 0.0002
sentence = Joe Keaton disapproved of films, and Buster also had reservations about the medium.
pred_sentence = ['']
wer = 83
loss = nan
sentence = She'll be all right.
pred_sentence = ['']
wer = 20
loss = nan
sentence = six
pred_sentence = ['']
wer = 3
loss = nan
sentence = All's well that ends well

KeyboardInterrupt: 