In [1]:
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
from datasets import load_dataset
import soundfile as sf
import torch

# The tokenizer and the CTC model are both restored from the same pretrained checkpoint.
checkpoint = "facebook/wav2vec2-base-960h"
tokenizer = Wav2Vec2Tokenizer.from_pretrained(checkpoint)
model = Wav2Vec2ForCTC.from_pretrained(checkpoint)

# define function to read in sound file
def map_to_array(batch):
    """Read the sound file named by ``batch["file"]`` into ``batch["speech"]``.

    Args:
        batch: Mapping with a ``"file"`` entry that ``sf.read`` can open.

    Returns:
        The same ``batch`` object, with the decoded samples stored under
        ``"speech"``.
    """
    # sf.read yields a (samples, sample_rate) pair; the sample rate is not kept.
    samples, _sample_rate = sf.read(batch["file"])
    batch["speech"] = samples
    return batch

# Load the dummy LibriSpeech validation split and attach the decoded audio to each example.
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.map(map_to_array)

# Tokenize the first two utterances (batch size 2); padding="longest" pads the
# shorter one so both fit in a single tensor.
# Rows are sliced before the column is selected so that only two examples are
# read, instead of materializing the whole "speech" column and then slicing it.
input_values = tokenizer(ds[:2]["speech"], return_tensors="pt", padding="longest").input_values

# Inference only: no_grad avoids recording an autograd graph for the forward pass.
with torch.no_grad():
    logits = model(input_values).logits

# Greedy decoding: pick the highest-scoring token id per frame, then turn ids into text.
predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading:   0%|          | 0.00/5.02k [00:00<?, ?B/s]

Downloading and preparing dataset librispeech_asr/clean (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to C:\Users\kibong\.cache\huggingface\datasets\librispeech_asr\clean\2.1.0\468ec03677f46a8714ac6b5b64dba02d246a228d92cbbad7f3dc190fa039eab1...


Downloading:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

Dataset librispeech_asr downloaded and prepared to C:\Users\kibong\.cache\huggingface\datasets\librispeech_asr\clean\2.1.0\468ec03677f46a8714ac6b5b64dba02d246a228d92cbbad7f3dc190fa039eab1. Subsequent calls will reuse this data.


  0%|          | 0/73 [00:00<?, ?ex/s]

In [2]:
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
import soundfile as sf
import torch
from jiwer import wer


# Evaluation data: the "test" split of the "clean" LibriSpeech configuration.
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

# Model and tokenizer are restored from one checkpoint; the model is moved to the CUDA device.
checkpoint = "facebook/wav2vec2-base-960h"
model = Wav2Vec2ForCTC.from_pretrained(checkpoint).to("cuda")
tokenizer = Wav2Vec2Tokenizer.from_pretrained(checkpoint)

def map_to_array(batch):
    """Read the sound file named by ``batch["file"]`` into ``batch["speech"]``.

    Args:
        batch: Mapping with a ``"file"`` entry that ``sf.read`` can open.

    Returns:
        The same ``batch`` object, with the decoded samples stored under
        ``"speech"``.
    """
    # sf.read yields a (samples, sample_rate) pair; the sample rate is not kept.
    samples, _sample_rate = sf.read(batch["file"])
    batch["speech"] = samples
    return batch

# Run map_to_array over the evaluation set so that each example gains a "speech" entry.
librispeech_eval = librispeech_eval.map(map_to_array)

def map_to_pred(batch):
    """Transcribe ``batch["speech"]`` and store the text in ``batch["transcription"]``.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        batch: Mapping whose ``"speech"`` entry holds the audio to transcribe
            (a list of utterances when called through ``Dataset.map`` with
            ``batched=True``).

    Returns:
        The same ``batch`` object, with the decoded strings stored under
        ``"transcription"``.
    """
    input_values = tokenizer(batch["speech"], return_tensors="pt", padding="longest").input_values

    # Send the inputs to whatever device the model lives on instead of a
    # hard-coded "cuda", so the function keeps working if the model is on CPU
    # or on a different GPU.
    input_values = input_values.to(model.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy decoding: highest-scoring token id per frame, then ids -> text.
    predicted_ids = torch.argmax(logits, dim=-1)
    batch["transcription"] = tokenizer.batch_decode(predicted_ids)
    return batch

# Transcribe the evaluation set one example per call, dropping the "speech" column from the result.
result = librispeech_eval.map(
    map_to_pred,
    batched=True,
    batch_size=1,
    remove_columns=["speech"],
)

# Word error rate of the predicted transcriptions against the reference "text" column.
word_error_rate = wer(result["text"], result["transcription"])
print("WER:", word_error_rate)

Downloading:   0%|          | 0.00/2.25k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.76k [00:00<?, ?B/s]

Downloading and preparing dataset librispeech_asr/clean (download: 28.05 GiB, generated: 54.01 MiB, post-processed: Unknown size, total: 28.11 GiB) to C:\Users\kibong\.cache\huggingface\datasets\librispeech_asr\clean\2.1.0\7020b7c48a960d82be5eae749f537115c6a45c75d5207011fd55948bd95d4de0...


Downloading:   0%|          | 0.00/338M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/347M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/6.39G [00:00<?, ?B/s]

ConnectionError: HTTPConnectionPool(host='www.openslr.org', port=80): Read timed out.

In [3]:
# Re-issue the dataset load on its own; the attempt in the previous cell ended in a ConnectionError.
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")


Downloading and preparing dataset librispeech_asr/clean (download: 28.05 GiB, generated: 54.01 MiB, post-processed: Unknown size, total: 28.11 GiB) to C:\Users\kibong\.cache\huggingface\datasets\librispeech_asr\clean\2.1.0\7020b7c48a960d82be5eae749f537115c6a45c75d5207011fd55948bd95d4de0...


Downloading:   0%|          | 0.00/6.39G [00:00<?, ?B/s]

ConnectionError: HTTPConnectionPool(host='www.openslr.org', port=80): Read timed out.