small uses more memory (but is faster) than medium (ONNX quantized) #53
Could this be because:
Hey @kmn1024! Thanks for opening this super interesting issue.
We decided not to release VRAM (memory) numbers in our benchmarks, since they're very dependent on hardware, CUDA version and PyTorch version. But we record some of these numbers ourselves. In my provisional benchmark, averaging over 100 samples of the LibriSpeech dataset, I got:
One reason for higher memory could be more decoding steps during generation. The benchmark script I used is below:

```python
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm

# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-small.en"

# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True
)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# load a dummy LibriSpeech split for evaluation
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)


def inference(batch):
    # 1. Pre-process the audio data to log-mel spectrogram inputs
    audio = [sample["array"] for sample in batch["audio"]]
    input_features = processor(
        audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt"
    ).input_features
    input_features = input_features.to(device, dtype=torch_dtype)

    # 2. Auto-regressively generate the predicted token ids
    pred_ids = model.generate(input_features, max_new_tokens=128)

    # 3. Decode the token ids to the final transcription
    batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
    batch["reference"] = batch["text"]
    return batch


dataset = dataset.map(function=inference, batched=True, batch_size=16)

all_transcriptions = []
all_references = []

# iterate over the dataset and collect the predictions/references
for result in tqdm(dataset, desc="Evaluating..."):
    all_transcriptions.append(result["transcription"])
    all_references.append(result["reference"])

# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]

# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
```
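As a side note (not part of the original benchmark script), peak VRAM for a given checkpoint can be recorded around the evaluation above using PyTorch's CUDA memory statistics, reusing the `device` defined in the script:

```python
# Sketch: measure peak allocated VRAM during generation (assumes a CUDA device
# and reuses `device` from the benchmark script above).
import torch

torch.cuda.reset_peak_memory_stats(device)

# ... run the dataset.map(...) evaluation from the benchmark script here ...

peak_bytes = torch.cuda.max_memory_allocated(device)
print(f"Peak allocated VRAM: {peak_bytes / 1024**3:.2f} GiB")
```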
Hi @kmn1024 👋 I did the conversions to ONNX, so I might have an explanation for this. I believe this is due to the additional output nodes, corresponding to the computed attentions. The reason I exported with these outputs is so that users can generate word-level timestamps with these models (and this might not be the case for the previous medium models). If this is something you won't need, you can do the conversions yourself with Optimum:
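The exact Optimum command wasn't captured in the thread. As a rough sketch (assuming the standard Optimum ONNX export path and a hypothetical output directory `distil-small.en-onnx`), the conversion could look something like:

```python
# Sketch: export the checkpoint to ONNX with Optimum (output directory name is an assumption).
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

model_id = "distil-whisper/distil-small.en"

# export=True converts the PyTorch checkpoint to ONNX on the fly
ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

# save the encoder/decoder ONNX files plus the configs to a local folder
ort_model.save_pretrained("distil-small.en-onnx")
```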
Thanks all! I can confirm that converting and quantizing from scratch works. The numbers are now:
P.S. The Optimum quantization command doesn't work out of the box; I had to skip the Conv nodes as suggested in microsoft/onnxruntime#15888.
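For reference, a sketch of what skipping the Conv nodes can look like with onnxruntime's dynamic quantization (this is not the poster's exact script, and the file paths are assumptions):

```python
# Sketch: dynamically quantize one of the exported ONNX files while excluding
# Conv nodes, as suggested in microsoft/onnxruntime#15888. Paths are assumptions.
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

model_path = "distil-small.en-onnx/encoder_model.onnx"            # assumed file name
output_path = "distil-small.en-onnx/encoder_model_quantized.onnx"  # assumed file name

# collect the names of all Conv nodes so they are left un-quantized
graph = onnx.load(model_path).graph
conv_nodes = [node.name for node in graph.node if node.op_type == "Conv"]

quantize_dynamic(
    model_input=model_path,
    model_output=output_path,
    weight_type=QuantType.QUInt8,
    nodes_to_exclude=conv_nodes,
)
```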
Thanks for the great explanation @xenova!
Setup
- CUDA 12.2
- GTX 1080
- Copied all of the ONNX quantized models and the required config JSONs to their expected locations.

Code

Results