In [None]:
import os
from dotenv import load_dotenv
import huggingface_hub

# Populate the process environment from the .env file one directory above
# the current working directory.
dotenv_path = os.path.join("../", ".env")
load_dotenv(dotenv_path)

# Authenticate with the Hugging Face Hub using the HUGGINGFACE_TOKEN variable.
huggingface_hub.login(
    token=os.getenv("HUGGINGFACE_TOKEN"),
)

In [2]:
import datasets

# Open the English ("en") configuration of VoxPopuli in streaming mode.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading
# script run locally -- keep it only for sources you trust.
voxpopuli = datasets.load_dataset(
    "facebook/voxpopuli",
    "en",
    streaming=True,
    trust_remote_code=True,
)

In [None]:
# Materialize the first five examples of the streamed training split.
train_split = voxpopuli["train"]
voxpopuli_head = [example for example in train_split.take(5)]

# Sampling rate as recorded on the first example's audio entry; reused later
# when preparing model inputs.
SAMPLING_RATE = voxpopuli_head[0]["audio"]["sampling_rate"]
print(voxpopuli_head)

In [None]:
from transformers import AutoTokenizer, AutoFeatureExtractor, SpeechEncoderDecoderModel

# Checkpoints used to warm-start the speech encoder-decoder model.
encoder_id = "facebook/wav2vec2-base-960h"  # acoustic model encoder
decoder_id = "facebook/bart-base"  # text decoder

# The feature extractor turns raw audio into encoder inputs; the tokenizer
# decodes the ids produced by the text decoder.
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
tokenizer = AutoTokenizer.from_pretrained(decoder_id)

model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id, decoder_id, encoder_add_adapter=True
)

# Encoder-side settings: zero out projection dropout, time masking and layerdrop.
model.config.encoder.feat_proj_dropout = 0.0
model.config.encoder.mask_time_prob = 0.0
model.config.encoder.layerdrop = 0.0

# Special-token ids are copied from the decoder's own config so the combined
# model and the decoder agree on them.
model.config.decoder_start_token_id = model.decoder.config.bos_token_id
model.config.pad_token_id = model.decoder.config.pad_token_id
model.config.eos_token_id = model.decoder.config.eos_token_id

# Generation length cap and remaining model-level settings.
model.config.max_length = 40
model.config.use_cache = False
model.config.processor_class = "Wav2Vec2Processor"


# Prepare the first sample's waveform as a PyTorch tensor for the encoder.
input_values = feature_extractor(
    voxpopuli_head[0]["audio"]["array"],
    sampling_rate=SAMPLING_RATE,
    return_tensors="pt",
).input_values

# Pass the values configured above explicitly, so generation uses a single
# source of truth (model.config) instead of a separately looked-up tokenizer id
# and does not depend on how generate() picks up defaults from the model config.
generated_ids = model.generate(
    input_values,
    decoder_start_token_id=model.config.decoder_start_token_id,
    max_length=model.config.max_length,
)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_text)

# TODO: load the sample's corresponding transcription and tokenize it to generate labels
# labels = tokenizer(voxpopuli_head[0]["text"], return_tensors="pt").input_ids