In [9]:
import os
import json
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available
from datasets import Dataset, Audio
from tqdm.auto import tqdm

# Run on the first GPU in half precision when CUDA is present; otherwise fall
# back to CPU in full precision.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3"

# FlashAttention-2 needs a CUDA device and the optional `flash-attn` package.
# Requesting it unconditionally makes `from_pretrained` raise on any other
# setup (including the CPU fallback above), so only opt in when it is usable
# and otherwise use PyTorch's built-in scaled-dot-product attention.
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation=attn_implementation,
)
model.to(device)

# The processor bundles the tokenizer and the feature extractor that the
# pipeline needs to turn raw audio into model inputs and ids back into text.
processor = AutoProcessor.from_pretrained(model_id)

asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
# sample = dataset[0]["audio"]

# result = asr_pipeline(sample,return_timestamps=True)
# print(result["text"])


Device set to use cuda:0


In [18]:
# Folder containing the WAV files to transcribe.
audio_folder = "data/asr_testing/clean"

# Collect the WAV paths (case-insensitive extension match). os.listdir returns
# entries in arbitrary order, so sort them to keep the dataset order
# reproducible from run to run.
file_paths = sorted(
    os.path.join(audio_folder, fname)
    for fname in os.listdir(audio_folder)
    if fname.lower().endswith(".wav")
)

# Create a dataset from the list of file paths.
# The column "file" initially holds the file path (a string).
dataset = Dataset.from_dict({"file": file_paths})

# Cast the "file" column to the Audio feature so each path is decoded on
# access into "array", "sampling_rate" and "path".
# Resample to the rate the feature extractor expects: the arrays are later
# handed to the ASR pipeline, and audio decoded at a different native rate
# would otherwise be transcribed from mis-sampled input.
dataset = dataset.cast_column(
    "file", Audio(sampling_rate=processor.feature_extractor.sampling_rate)
)

# Attach the bare file name (no directory part) of each sample's audio file.
def add_filename(example):
    """Store the base name of ``example["file"]["path"]`` under ``"filename"``.

    The example is updated in place and the same object is returned.
    """
    audio_path = example["file"]["path"]
    _directory, base_name = os.path.split(audio_path)
    example["filename"] = base_name
    return example

# Run add_filename over every example so each row gains a "filename" column.
dataset = dataset.map(function=add_filename)

# Define a batched transcription function.
def transcribe(batch):
    """Transcribe one batch of decoded audio samples with ``asr_pipeline``.

    Args:
        batch: Mapping whose ``"file"`` entry is a list of decoded audio
            samples, each exposing ``"array"`` and ``"sampling_rate"``.

    Returns:
        ``{"transcription": [...]}`` with one text per input sample, in order.
    """
    # Pass the waveform together with its sampling rate. A bare array carries
    # no rate, so the pipeline would have to assume it already matches the
    # feature extractor; with the rate attached it can resample on mismatch.
    audio_inputs = [
        {"raw": sample["array"], "sampling_rate": sample["sampling_rate"]}
        for sample in batch["file"]
    ]
    if not audio_inputs:
        # Nothing to transcribe; skip the pipeline call for an empty batch.
        return {"transcription": []}

    # Transcribe the batch using the pipeline.
    results = asr_pipeline(audio_inputs)
    return {"transcription": [res["text"] for res in results]}



Map: 100%|██████████| 824/824 [00:14<00:00, 55.28 examples/s]  


In [None]:
# Apply the transcription function using batched mapping: `transcribe`
# receives up to `batch_size` examples per call and adds a "transcription"
# column to the dataset.
result_dataset = dataset.map(transcribe, batched=True, batch_size=6)

# Build a dictionary mapping filenames to transcriptions.
# Read only the two text columns instead of iterating whole rows: row
# iteration also materialises the "file" audio column for every entry, which
# is wasted work when only the filename and transcription are needed.
transcriptions = dict(
    zip(result_dataset["filename"], result_dataset["transcription"])
)



Map: 100%|██████████| 824/824 [05:20<00:00,  2.57 examples/s]

Transcriptions have been saved to asr_output.json





In [21]:
# Order entries by filename so the JSON output is stable across runs.
transcriptions = dict(sorted(transcriptions.items()))

# Save the dictionary to a JSON file.
# Write explicit UTF-8 rather than the platform default encoding, and keep
# non-ASCII characters as-is (ensure_ascii=False) so transcripts in any
# language stay human-readable instead of being emitted as \uXXXX escapes.
output_file = "asr_output.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(transcriptions, f, indent=4, ensure_ascii=False)

print(f"Transcriptions have been saved to {output_file}")

Transcriptions have been saved to asr_output.json
