In [8]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
import torch
import torchaudio
from jiwer import wer
from tqdm import tqdm

dataset = torchaudio.datasets.LIBRISPEECH("data", url="test-clean", download=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)
# Step 3: Load Wav2Vec 2.0 Model and Tokenizer
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Function to preprocess audio
def preprocess_audio(waveform, sample_rate):
    # Resample to 16 kHz if needed
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    return waveform

# Step 4: Transcribe audio and evaluate
total_wer = 0
num_samples = 0
model = model.to(DEVICE)
for waveform, sample_rate, utterance, _, _, _ in tqdm(dataset):
    waveform = preprocess_audio(waveform, sample_rate)

    input_values = tokenizer(waveform.squeeze().numpy(), return_tensors="pt").input_values.to(DEVICE)

    # Use Wav2Vec 2.0 to transcribe the audio
    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)[0]

    # Calculate and accumulate WER
    total_wer += wer(utterance.lower(), transcription.lower())
    num_samples += 1

    # print(f"Ground Truth: {utterance}")
    # print(f"Transcription: {transcription}")
    # print("---")

# Calculate average WER
average_wer = total_wer / num_samples
print(f"Average WER: {average_wer}")

# Note: You might want to limit the number of samples for testing purposes.


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2CTCTokenizer'. 
The class this function is called from is 'Wav2Vec2Tokenizer'.


cuda


Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_v', 'wav2vec2.encoder.pos_conv_embed.conv.weight_g']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0']
You sho

Average WER: 0.03825904128930878



