# Imports

In [14]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from jiwer import wer

import torch
import glob
import os
import librosa
import torchaudio
import numpy as np

# Load model

In [21]:
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# English

## Load data

In [3]:
def load_transcripts(config="clean", split="test", base_path=""):
    transcripts={}
    
    for filepath in glob.glob(f"{base_path}/*/*/*.txt", recursive=True):
        with open(filepath, "r") as f:
            for line in f.readlines():
                tokens = line.split("\n")[0].split(" ")
                transcripts[tokens[0]] = " ".join(tokens[1:])

    return transcripts


def load_audio(config="clean", split="test", path=""):
    speech, sr = torchaudio.load(path)
    speech = speech.squeeze()
    resampler = torchaudio.transforms.Resample(sr, 16000)
    speech = resampler(speech)
    
    return speech
    
    
def load_dataset(config="clean", split="test"):
    BASE_PATH = f"../data/en/LibriSpeech/{split.lower()}-{config.lower()}"
    
    transcripts = load_transcripts(config, split, BASE_PATH)
    audio = {}
    
    for key, value in transcripts.items():
        audio_path = f"{BASE_PATH}/{'/'.join(key.split('-')[:2])}/{key}.flac"
        audio[key] = load_audio(config, split, audio_path)
        
    return transcripts, audio

transcripts, audio = load_dataset()

## Evaluate base performance (W.E.R)

In [20]:
%%time
wer_scores = []

for key in list(transcripts.keys())[:10]:    
    input_values = processor(audio[key], return_tensors="pt", sampling_rate=16000)["input_values"]
    logits = model(input_values)["logits"]
    
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    wer_scores.append(wer([transcription], [transcripts[key]]))
    
print(f"wer: {np.average(wer_scores)}")

wer: 0.021323529411764706
CPU times: user 15.6 s, sys: 2.51 s, total: 18.1 s
Wall time: 3.21 s


# Polish

## Load data

In [33]:
def load_polish_transcripts(split="test", base_path=""):
    transcripts={}
    
    with open(f"{base_path}/transcripts.txt", "r") as f:
        for line in f.readlines():
            tokens = line.split("\n")[0].split("\t")
            transcripts[tokens[0]] = " ".join(tokens[1:])

    return transcripts


def load_polish_audio(split="test", path=""):
    speech, sr = torchaudio.load(path)
    speech = speech.squeeze()
    resampler = torchaudio.transforms.Resample(sr, 16000)
    speech = resampler(speech)
    
    return speech
    
    
def load_polish_dataset(split="test"):
    BASE_PATH = f"../data/polish/mls_polish_opus/{split.lower()}"
    
    transcripts = load_polish_transcripts(split, BASE_PATH)
    audio = {}
    
    for key, value in transcripts.items():
        audio_path = f"{BASE_PATH}/audio/{'/'.join(key.split('_')[:2])}/{key}.opus"
        audio[key] = load_polish_audio(split, audio_path)
        
    return transcripts, audio

transcripts, audio = load_polish_dataset()

## Evaluate base performance (W.E.R)

In [34]:
%%time
wer_scores = []

for key in list(transcripts.keys())[:10]:    
    input_values = processor(audio[key], return_tensors="pt", sampling_rate=16000)["input_values"]
    logits = model(input_values)["logits"]
    
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    wer_scores.append(wer([transcription], [transcripts[key]]))
    
print(f"wer: {np.average(wer_scores)}")

wer: 1.307249343009336
CPU times: user 1min 8s, sys: 11.1 s, total: 1min 19s
Wall time: 17.1 s
