Try VAD with auditok #78

Closed
wants to merge 9 commits
2 changes: 1 addition & 1 deletion setup.py
@@ -38,6 +38,6 @@
include_package_data=True,
extras_require={
'dev': ['matplotlib', 'jsonschema', 'transformers'],
'vad': ['onnxruntime', 'torchaudio'],
'vad': ['auditok'],
},
)
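The vad extra now pulls in auditok instead of onnxruntime and torchaudio. As a minimal sketch of guarding this optional dependency at runtime (the helper name and error message are illustrative and not part of this PR; the extra itself would typically be installed with pip install ".[vad]" from the repository root):

from importlib.util import find_spec

def require_auditok():
    # auditok is only needed when VAD is requested; it is declared under the
    # "vad" extra in setup.py rather than as a hard dependency.
    if find_spec("auditok") is None:
        raise RuntimeError(
            "VAD support requires the auditok package (pip install auditok)."
        )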
67 changes: 42 additions & 25 deletions whisper_timestamped/transcribe.py
@@ -3,7 +3,7 @@
__author__ = "Jérôme Louradour"
__credits__ = ["Jérôme Louradour"]
__license__ = "GPLv3"
__version__ = "1.13.3"
__version__ = "1.13.5"

# Set some environment variables
import os
@@ -15,6 +15,13 @@
import torch
import torch.nn.functional as F

from importlib.util import find_spec
if find_spec("intel_extension_for_pytorch") is not None:
try:
import intel_extension_for_pytorch
except ImportError:
pass

# For alignment
import numpy as np
import dtw
@@ -1761,15 +1768,14 @@ def split_tokens_on_spaces(tokens: torch.Tensor, tokenizer, remove_punctuation_f

return words, word_tokens, word_tokens_indices

silero_vad_model = None
def get_vad_segments(audio,
output_sample=False,
min_speech_duration=0.1,
min_silence_duration=0.1,
dilatation=0.5,
):
"""
Get speech segments from audio using Silero VAD
Get speech segments from audio using Auditok
parameters:
audio: torch.Tensor
audio data *in 16kHz*
@@ -1782,28 +1788,27 @@ def get_vad_segments(audio,
dilatation: float
how much (in sec) to enlarge each speech segment detected by the VAD
"""
global silero_vad_model, silero_get_speech_ts

if silero_vad_model is None:
import onnxruntime
onnxruntime.set_default_logger_severity(3) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model."
repo_or_dir = os.path.expanduser("~/.cache/torch/hub/snakers4_silero-vad_master")
source = "local"
if not os.path.exists(repo_or_dir):
repo_or_dir = "snakers4/silero-vad"
source = "github"
silero_vad_model, utils = torch.hub.load(repo_or_dir=repo_or_dir, model="silero_vad", onnx=True, source=source)
silero_get_speech_ts = utils[0]
import auditok

# Cheap normalization of the volume
audio = audio / max(0.1, audio.abs().max())

segments = silero_get_speech_ts(audio, silero_vad_model,
min_speech_duration_ms = round(min_speech_duration * 1000),
min_silence_duration_ms = round(min_silence_duration * 1000),
return_seconds = False,
data = (audio.numpy() * 32767).astype(np.int16).tobytes()

segments = auditok.split(
data,
sampling_rate=SAMPLE_RATE, # sampling frequency in Hz
channels=1, # number of channels
sample_width=2, # number of bytes per sample
min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds
max_dur=len(audio)/SAMPLE_RATE, # maximum duration of an event
max_silence=min_silence_duration, # maximum duration of tolerated continuous silence within an event
energy_threshold=50,
drop_trailing_silence=True,
)

segments = [{"start": s._meta.start * SAMPLE_RATE, "end": s._meta.end * SAMPLE_RATE} for s in segments]

if dilatation > 0:
dilatation = round(dilatation * SAMPLE_RATE)
new_segments = []
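For readers unfamiliar with auditok, below is a self-contained sketch of the detection scheme used in this hunk, assuming a mono 16 kHz float tensor named audio on CPU. The energy threshold, dilation value and the _meta attribute access mirror the diff above; overlapping segments are not merged here, which the full function still has to handle.

import auditok
import numpy as np
import torch

SAMPLE_RATE = 16000

def detect_speech(audio: torch.Tensor, min_speech=0.1, min_silence=0.1, dilation=0.5):
    # Cheap volume normalization, then conversion to 16-bit PCM bytes for auditok
    audio = audio / max(0.1, audio.abs().max())
    data = (audio.numpy() * 32767).astype(np.int16).tobytes()

    regions = auditok.split(
        data,
        sampling_rate=SAMPLE_RATE,
        channels=1,
        sample_width=2,               # bytes per sample (int16)
        min_dur=min_speech,           # shortest audio event to keep, in seconds
        max_dur=len(audio) / SAMPLE_RATE,
        max_silence=min_silence,      # longest silence tolerated inside an event
        energy_threshold=50,
        drop_trailing_silence=True,
    )

    # Region boundaries are in seconds; convert to sample indices and dilate them
    pad = round(dilation * SAMPLE_RATE)
    return [
        {
            "start": max(0, round(r._meta.start * SAMPLE_RATE) - pad),
            "end": min(len(audio), round(r._meta.end * SAMPLE_RATE) + pad),
        }
        for r in regions
    ]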
@@ -1837,7 +1842,7 @@ def remove_non_speech(audio,
plot=False,
):
"""
Remove non-speech segments from audio (using Silero VAD),
Remove non-speech segments from audio (using Auditok VAD),
glue the speech segments together and return the result along with
a function to convert timestamps from the new audio to the original audio
"""
@@ -1858,9 +1863,12 @@
if plot:
import matplotlib.pyplot as plt
plt.figure()
plt.plot(audio)
for s,e in segments:
plt.axvspan(s, e, color='red', alpha=0.1)
max_num_samples = 10000
step = (audio.shape[-1] // max_num_samples) + 1
times = [i*step/SAMPLE_RATE for i in range((audio.shape[-1]-1) // step + 1)]
plt.plot(times, audio[::step])
for s, e in segments:
plt.axvspan(s/SAMPLE_RATE, e/SAMPLE_RATE, color='red', alpha=0.1)
if isinstance(plot, str):
plt.savefig(f"{plot}.VAD.jpg", bbox_inches='tight', pad_inches=0)
else:
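The remove_non_speech docstring above says the function also returns a helper that converts timestamps from the glued, speech-only audio back to the original timeline. A minimal illustration of that idea (the function name is hypothetical, not this PR's code; segments are (start, end) pairs in samples, as produced by get_vad_segments):

def make_time_mapper(segments, sample_rate=16000):
    # For each kept segment, remember where it starts in the glued audio
    offsets = []  # (start_in_glued, start_in_original, duration), all in seconds
    pos = 0.0
    for start, end in segments:
        dur = (end - start) / sample_rate
        offsets.append((pos, start / sample_rate, dur))
        pos += dur

    def to_original(t):
        # Map a time t (in seconds) of the glued audio back to the original audio
        if not offsets:
            return t
        for glued_start, orig_start, dur in offsets:
            if t <= glued_start + dur:
                return orig_start + max(0.0, t - glued_start)
        # Past the last kept segment: clamp to the end of the last one
        glued_start, orig_start, dur = offsets[-1]
        return orig_start + dur

    return to_original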
@@ -2042,9 +2050,18 @@ def write_csv(transcript, file, sep = ",", text_first=True, format_timestamps=No
# CUDA initialization may fail on old GPU card
def force_cudnn_initialization(device=None, s=32):
if device is None:
device = torch.device('cuda')
device = get_default_device()
torch.nn.functional.conv2d(torch.zeros(s, s, s, s, device=device), torch.zeros(s, s, s, s, device=device))

def get_default_device():
if torch.cuda.is_available():
device = "cuda"
elif find_spec('torch.xpu') is not None and torch.xpu.is_available():
device = "xpu"
else:
device = "cpu"
return device
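For context, the new device selection can be exercised as follows. This is a usage sketch rather than code from this PR; loading uses the standard openai-whisper API, and "xpu" is only reachable when intel_extension_for_pytorch is installed.

import whisper

device = get_default_device()            # "cuda", "xpu" or "cpu"
if device == "cuda":
    force_cudnn_initialization(device)   # work around cuDNN init failures on old GPUs
model = whisper.load_model("small", device=device)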

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
@@ -2238,7 +2255,7 @@ def get_do_write(output_format):
parser.add_argument('audio', help="audio file(s) to transcribe", nargs='+')
parser.add_argument('--model', help=f"name of the Whisper model to use. Examples: {', '.join(whisper.available_models())}", default="small")
parser.add_argument("--model_dir", default=None, help="the path to save model files; uses ~/.cache/whisper by default", type=str)
parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--device", default=get_default_device(), help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", default=None, help="directory to save the outputs", type=str)
valid_formats = ["txt", "vtt", "srt", "tsv", "csv", "json"]
def str2output_formats(string):