-
-
Notifications
You must be signed in to change notification settings - Fork 144
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
167 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,160 @@ | ||
""" | ||
import asyncio | ||
import os | ||
import time | ||
from functools import wraps | ||
from typing import Union | ||
|
||
""" | ||
import torch | ||
from termcolor import colored | ||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline | ||
|
||
|
||
def async_retry(max_retries=3, exceptions=(Exception,), delay=1):
    """
    Decorator factory that adds retry logic to an async function.

    :param max_retries: Total number of attempts before the last
        exception is allowed to propagate.
    :param exceptions: Tuple of exception types that trigger a retry.
    :param delay: Seconds to wait between attempts.
    :return: A decorator that wraps an async function with retries.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Count down the attempts remaining *after* the current one.
            for retries_left in range(max_retries - 1, -1, -1):
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    # No attempts left: surface the original exception.
                    if retries_left <= 0:
                        raise
                    print(f"Retry after exception: {e}, Attempts remaining: {retries_left}")
                    await asyncio.sleep(delay)

        return wrapper

    return decorator
|
||
|
||
class DistilWhisperModel:
    """
    Encapsulates the Distil-Whisper model for English speech recognition.

    Supports synchronous and asynchronous transcription of short and
    long-form audio, plus a simulated "real-time" chunked mode.

    Args:
        model_id: The model ID to use. Defaults to "distil-whisper/distil-large-v2".

    Attributes:
        device: Inference device ("cuda:0" when CUDA is available, else "cpu").
        torch_dtype: torch dtype for inference (float16 on GPU, float32 on CPU).
        model_id: The Hugging Face model ID in use.
        model: The loaded speech-seq2seq model instance.
        processor: The matching processor instance.

    Usage:
        model_wrapper = DistilWhisperModel()
        transcription = model_wrapper('path/to/audio.mp3')
        # For async usage
        transcription = asyncio.run(model_wrapper.async_transcribe('path/to/audio.mp3'))
    """

    def __init__(self, model_id="distil-whisper/distil-large-v2"):
        # Prefer GPU + half precision when CUDA is available.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model_id = model_id
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_id)
        # ASR pipeline is built lazily on the first transcribe() call and
        # cached, so repeated calls do not pay construction cost each time.
        self._asr_pipeline = None

    def __call__(self, inputs: Union[str, dict]):
        """Make instances callable; delegates to :meth:`transcribe`."""
        return self.transcribe(inputs)

    def transcribe(self, inputs: Union[str, dict]):
        """
        Synchronously transcribe the given audio input using the Distil-Whisper model.

        :param inputs: A string representing the file path or a dict with audio data.
        :return: The transcribed text.
        """
        # Build the pipeline once and reuse it; the original rebuilt it on
        # every call. getattr() keeps this safe even if the attribute was
        # never initialised (e.g. an instance created via __new__).
        if getattr(self, "_asr_pipeline", None) is None:
            self._asr_pipeline = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                max_new_tokens=128,
                torch_dtype=self.torch_dtype,
                device=self.device,
            )
        return self._asr_pipeline(inputs)["text"]

    @async_retry()
    async def async_transcribe(self, inputs: Union[str, dict]):
        """
        Asynchronously transcribe the given audio input using the Distil-Whisper model.

        :param inputs: A string representing the file path or a dict with audio data.
        :return: The transcribed text.
        """
        # get_running_loop() is the supported API inside a coroutine;
        # asyncio.get_event_loop() here is deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        # Run the blocking transcription in the default thread pool.
        return await loop.run_in_executor(None, self.transcribe, inputs)

    def real_time_transcribe(self, audio_file_path, chunk_duration=5):
        """
        Simulate real-time transcription of an audio file, processing and
        printing results in chunks with colored output for readability.

        :param audio_file_path: Path to the audio file to be transcribed.
        :param chunk_duration: Duration in seconds of each audio chunk to be processed.
        """
        if not os.path.isfile(audio_file_path):
            print(colored("The audio file was not found.", "red"))
            return

        try:
            with torch.no_grad():
                # NOTE(review): AutoProcessor does not expose an
                # `audio_file_to_array` helper in current transformers
                # releases — confirm this call against the installed
                # version (loading via librosa/datasets may be required).
                audio_input = self.processor.audio_file_to_array(audio_file_path)
                sample_rate = audio_input.sampling_rate
                # Loop-invariant chunk size in samples, hoisted out of the
                # comprehension below.
                chunk_samples = sample_rate * chunk_duration
                chunks = [
                    audio_input.array[i : i + chunk_samples]
                    for i in range(0, len(audio_input.array), chunk_samples)
                ]

                print(colored("Starting real-time transcription...", "green"))

                for i, chunk in enumerate(chunks):
                    # Featurize the current chunk for the model.
                    processed_inputs = self.processor(
                        chunk,
                        sampling_rate=sample_rate,
                        return_tensors="pt",
                        padding=True,
                    )
                    # NOTE(review): Whisper feature extractors return
                    # `input_features`, not `input_values` — verify this
                    # attribute before relying on this code path.
                    processed_inputs = processed_inputs.input_values.to(self.device)

                    # Generate and decode the transcription for this chunk.
                    logits = self.model.generate(processed_inputs)
                    transcription = self.processor.batch_decode(
                        logits, skip_special_tokens=True
                    )[0]

                    # Print the chunk's transcription with a progress label.
                    print(
                        colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow")
                        + transcription
                    )

                    # Wait the chunk's duration to simulate real-time pacing.
                    time.sleep(chunk_duration)

        except Exception as e:
            # Deliberate best-effort behavior: report the failure to the
            # console rather than crash the caller.
            print(colored(f"An error occurred during transcription: {e}", "red"))