In [None]:
from transformers import SpeechEncoderDecoderModel, AutoProcessor
import torch
from datasets import load_dataset

# Hub identifiers for the checkpoint and the dataset used in this demo.
MODEL_ID = "matejhornik/wav2vec2-base_bart-base_voxpopuli-en"
DATASET_ID = "facebook/voxpopuli"
DATASET_CONFIG = "en"
DATASET_SPLIT = "test"  # "validation"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the processor and the model from the same checkpoint, then move the
# model's weights onto the selected device.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = SpeechEncoderDecoderModel.from_pretrained(MODEL_ID)
model = model.to(device)

# Report the selected device and which dataset/config/split is being streamed.
# Adjacent string literals are concatenated at compile time, but each fragment
# is only interpolated if it carries its own `f` prefix; without it the
# placeholders on the second line would be printed literally.
print(
    f"Using device: {device}\nStreaming one sample from '{DATASET_ID}' "
    f"(config: '{DATASET_CONFIG}', split: '{DATASET_SPLIT}')..."
)
# Open the requested split in streaming mode and pull only its first example.
streamed_dataset = load_dataset(
    path=DATASET_ID,
    name=DATASET_CONFIG,
    split=DATASET_SPLIT,
    streaming=True,
)
sample_iterator = iter(streamed_dataset)
sample = next(sample_iterator)

# Pull the raw waveform and its sampling rate out of the sample's audio entry.
audio_entry = sample["audio"]
audio_input = audio_entry["array"]
input_sampling_rate = audio_entry["sampling_rate"]

# Convert the waveform into PyTorch tensors for the model.
inputs = processor(
    audio_input,
    sampling_rate=input_sampling_rate,
    return_tensors="pt",
    padding=True,
)
input_features = inputs["input_values"].to(device)

# Inference only: disable gradient tracking while generating token ids.
with torch.no_grad():
    predicted_ids = model.generate(input_features, max_length=128)

# Decode the generated ids to text and keep the first (only) batch entry.
decoded_batch = processor.batch_decode(predicted_ids, skip_special_tokens=True)
transcription = decoded_batch[0]

# Show the dataset's reference text next to the model output.
print(f"\nOriginal: {sample['normalized_text']}")
print(f"Transcribed: {transcription}")