# Imports

In [1]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import load_dataset
from jiwer import wer

import torch
import glob
import os
import librosa
import torchaudio
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


# Load model

In [2]:
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['project_q.bias', 'project_hid.bias', 'project_hid.weight', 'quantizer.weight_proj.weight', 'quantizer.codevectors', 'project_q.weight', 'quantizer.weight_proj.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti

# English

## Load data

In [4]:
def load_transcripts(config="clean", split="test", base_path=""):
    transcripts={}
    
    for filepath in glob.glob(f"{base_path}/*/*/*.txt", recursive=True):
        with open(filepath, "r") as f:
            for line in f.readlines()[:10]:
                tokens = line.split("\n")[0].split(" ")
                transcripts[tokens[0]] = " ".join(tokens[1:])

    return transcripts


def load_audio(config="clean", split="test", path=""):
    speech, sr = torchaudio.load(path)
    speech = speech.squeeze()
    resampler = torchaudio.transforms.Resample(sr, 16000)
    speech = resampler(speech)
    
    return speech
    
    
def load_dataset(config="clean", split="test"):
    BASE_PATH = f"../data/en/LibriSpeech/{split.lower()}-{config.lower()}"
    
    transcripts = load_transcripts(config, split, BASE_PATH)
    audio = {}
    
    for key in list(transcripts.keys())[:10]:
        audio_path = f"{BASE_PATH}/{'/'.join(key.split('-')[:2])}/{key}.flac"
        audio[key] = load_audio(config, split, audio_path)
        
    return transcripts, audio

transcripts, audio = load_dataset()

## Evaluate base performance (W.E.R)

In [5]:
%%time
wer_scores = []

for key in list(transcripts.keys())[:10]:    
    input_values = processor(audio[key], return_tensors="pt", sampling_rate=16000)["input_values"]
    logits = model(input_values)["logits"]
    
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    wer_scores.append(wer([transcription], [transcripts[key]]))
    
print(f"wer: {np.average(wer_scores)}")

wer: 11.633333333333335
CPU times: user 15.7 s, sys: 2.64 s, total: 18.4 s
Wall time: 3.35 s


# Polish

## Load data

In [6]:
def load_polish_transcripts(split="test", base_path=""):
    transcripts={}
    
    with open(f"{base_path}/transcripts.txt", "r") as f:
        for line in f.readlines()[:10]:
            tokens = line.split("\n")[0].split("\t")
            transcripts[tokens[0]] = " ".join(tokens[1:])

    return transcripts


def load_polish_audio(split="test", path=""):
    speech, sr = torchaudio.load(path)
    speech = speech.squeeze()
    resampler = torchaudio.transforms.Resample(sr, 16000)
    speech = resampler(speech)
    
    return speech
    
    
def load_polish_dataset(split="test"):
    BASE_PATH = f"../data/polish/mls_polish_opus/{split.lower()}"
    
    transcripts = load_polish_transcripts(split, BASE_PATH)
    audio = {}
    
    for key in list(transcripts.keys())[:10]:
        audio_path = f"{BASE_PATH}/audio/{'/'.join(key.split('_')[:2])}/{key}.opus"
        audio[key] = load_polish_audio(split, audio_path)
        
    return transcripts, audio

transcripts, audio = load_polish_dataset()

## Evaluate base performance (W.E.R)

In [7]:
%%time
wer_scores = []

for key in list(transcripts.keys())[:10]:    
    input_values = processor(audio[key], return_tensors="pt", sampling_rate=16000)["input_values"]
    logits = model(input_values)["logits"]
    
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    wer_scores.append(wer([transcription], [transcripts[key]]))
    
print(f"wer: {np.average(wer_scores)}")

wer: 29.0
CPU times: user 1min 5s, sys: 10.4 s, total: 1min 15s
Wall time: 14.8 s
