Tiny VAD refactoring for postprocessing (NVIDIA#4625)
* binarization start index

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix frame len

Signed-off-by: fayejf <fayejf07@gmail.com>

* style fix

Signed-off-by: fayejf <fayejf07@gmail.com>

* rename UNIT_FRAME_LEN

Signed-off-by: fayejf <fayejf07@gmail.com>

* update overlap script and fix lgtm

Signed-off-by: fayejf <fayejf07@gmail.com>

* style fix

Signed-off-by: fayejf <fayejf07@gmail.com>
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
fayejf authored and Hainan Xu committed Nov 29, 2022
1 parent 8e5cbcc commit 90449fc
Showing 6 changed files with 44 additions and 30 deletions.
1 change: 0 additions & 1 deletion examples/asr/conf/vad/vad_inference_postprocessing.yaml
@@ -5,7 +5,6 @@ num_workers: 4
sample_rate: 16000

# functionality
-gen_overlap_seq: True # whether to generate predictions with overlapping input segments and smoothing filter
gen_seg_table: True # whether to convert frame-level predictions to speech/no-speech segments in start/end time format
write_to_manifest: True # whether to write the above segments to a single manifest JSON file.

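For orientation, these flags are read straight off the Hydra/OmegaConf config at inference time. A minimal sketch of loading and querying them, assuming OmegaConf is installed and using the config path from this repo:

```python
from omegaconf import OmegaConf

# Illustrative: load the inference config shown above.
cfg = OmegaConf.load("examples/asr/conf/vad/vad_inference_postprocessing.yaml")

# gen_overlap_seq is gone after this commit; overlap smoothing is now
# driven by cfg.vad.parameters.smoothing instead (see vad_infer.py below).
print(cfg.gen_seg_table, cfg.write_to_manifest)
```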
6 changes: 4 additions & 2 deletions examples/asr/speech_classification/vad_infer.py
@@ -122,9 +122,10 @@ def main(cfg):
logging.info(
f"Finish generating VAD frame level prediction with window_length_in_sec={cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={cfg.vad.parameters.shift_length_in_sec}"
)
+frame_length_in_sec = cfg.vad.parameters.shift_length_in_sec

# overlap smoothing filter
-if cfg.gen_overlap_seq:
+if cfg.vad.parameters.smoothing:
# Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
# smoothing_method is either majority vote (median) or average (mean)
logging.info("Generating predictions with overlapping input segments")
@@ -141,14 +142,15 @@ def main(cfg):
f"Finish generating predictions with overlapping input segments with smoothing_method={cfg.vad.parameters.smoothing} and overlap={cfg.vad.parameters.overlap}"
)
pred_dir = smoothing_pred_dir
+frame_length_in_sec = 0.01

# postprocessing and generate speech segments
if cfg.gen_seg_table:
logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
table_out_dir = generate_vad_segment_table(
vad_pred_dir=pred_dir,
postprocessing_params=cfg.vad.parameters.postprocessing,
-shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
+frame_length_in_sec=frame_length_in_sec,
num_workers=cfg.num_workers,
out_dir=cfg.table_out_dir,
)
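The net effect of this hunk: frame-level scores are produced once per window shift, so frame_length_in_sec starts out equal to the shift length and is only overridden when overlap smoothing re-samples predictions onto 0.01 s unit frames. A condensed sketch of that selection, with illustrative values standing in for the config:

```python
# Illustrative values standing in for cfg.vad.parameters.*
shift_length_in_sec = 0.01     # stride of the sliding inference window
smoothing = "median"           # "median", "mean", or False

# Raw frame-level predictions: one score per window shift.
frame_length_in_sec = shift_length_in_sec

if smoothing:
    # Overlap smoothing emits scores on 0.01 s unit frames, so downstream
    # postprocessing must switch to that resolution.
    frame_length_in_sec = 0.01
```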
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/clustering_diarizer.py
@@ -237,6 +237,7 @@ def _run_vad(self, manifest_file):
if not self._vad_params.smoothing:
# Shift the window by 10ms to generate the frame and use the prediction of the window to represent the label for the frame;
self.vad_pred_dir = self._vad_dir
+frame_length_in_sec = self._vad_shift_length_in_sec
else:
# Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
# smoothing_method is either majority vote (median) or average (mean)
@@ -250,13 +251,14 @@ def _run_vad(self, manifest_file):
num_workers=self._cfg.num_workers,
)
self.vad_pred_dir = smoothing_pred_dir
+frame_length_in_sec = 0.01

logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

table_out_dir = generate_vad_segment_table(
vad_pred_dir=self.vad_pred_dir,
postprocessing_params=self._vad_params,
-shift_length_in_sec=self._vad_shift_length_in_sec,
+frame_length_in_sec=frame_length_in_sec,
num_workers=self._cfg.num_workers,
)

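The smoothing filter mentioned in both branches works per unit frame: every 0.01 s frame is covered by several overlapping windows, and its final score is the median or mean of the window scores that span it. A toy sketch of the idea (pure NumPy, not the NeMo implementation):

```python
import numpy as np

def smooth_overlap(window_scores, frames_per_window, frames_per_shift, method="median"):
    """Toy overlap smoothing: spread each window score over the unit frames
    it covers, then reduce per frame with median (majority-vote-like) or mean."""
    n_frames = (len(window_scores) - 1) * frames_per_shift + frames_per_window
    per_frame = [[] for _ in range(n_frames)]
    for w, score in enumerate(window_scores):
        start = w * frames_per_shift
        for f in range(start, start + frames_per_window):
            per_frame[f].append(score)
    reduce_fn = np.median if method == "median" else np.mean
    return np.array([reduce_fn(s) for s in per_frame])

# Tiny numbers for readability: 3 windows, each covering 2 unit frames.
print(smooth_overlap([0.1, 0.9, 0.8], frames_per_window=2, frames_per_shift=1))
```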
46 changes: 24 additions & 22 deletions nemo/collections/asr/parts/utils/vad_utils.py
@@ -459,12 +459,12 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te
offset (float): offset threshold for detecting the end of a speech.
pad_onset (float): adding durations before each speech segment
pad_offset (float): adding durations after each speech segment;
-shift_length_in_sec (float): amount of shift of window for generating the frame.
+frame_length_in_sec (float): length of frame.
Returns:
speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
"""
-shift_length_in_sec = per_args.get('shift_length_in_sec', 0.01)
+frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01)

onset = per_args.get('onset', 0.5)
offset = per_args.get('offset', 0.5)
@@ -477,30 +477,30 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te

speech_segments = torch.empty(0)

-for i in range(1, len(sequence)):
+for i in range(0, len(sequence)):
# Current frame is speech
if speech:
# Switch from speech to non-speech
if sequence[i] < offset:
-if i * shift_length_in_sec + pad_offset > max(0, start - pad_onset):
+if i * frame_length_in_sec + pad_offset > max(0, start - pad_onset):
new_seg = torch.tensor(
-[max(0, start - pad_onset), i * shift_length_in_sec + pad_offset]
+[max(0, start - pad_onset), i * frame_length_in_sec + pad_offset]
).unsqueeze(0)
speech_segments = torch.cat((speech_segments, new_seg), 0)

-start = i * shift_length_in_sec
+start = i * frame_length_in_sec
speech = False

# Current frame is non-speech
else:
# Switch from non-speech to speech
if sequence[i] > onset:
-start = i * shift_length_in_sec
+start = i * frame_length_in_sec
speech = True

# if it's speech at the end, add final segment
if speech:
-new_seg = torch.tensor([max(0, start - pad_onset), i * shift_length_in_sec + pad_offset]).unsqueeze(0)
+new_seg = torch.tensor([max(0, start - pad_onset), i * frame_length_in_sec + pad_offset]).unsqueeze(0)
speech_segments = torch.cat((speech_segments, new_seg), 0)

# Merge the overlapped speech segments due to padding
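Two details make this hunk the heart of the refactor: the loop now starts at index 0, so speech at the very beginning of a file is no longer skipped, and all timestamps derive from frame_length_in_sec rather than the window shift. A self-contained restatement of the same onset/offset hysteresis, in plain Python with lists instead of tensors (a sketch, not the NeMo code):

```python
def binarize(scores, onset=0.5, offset=0.5, pad_onset=0.0, pad_offset=0.0,
             frame_length_in_sec=0.01):
    """Open a segment when the score rises past `onset`; close it when
    the score falls below `offset`; pad and clamp the boundaries."""
    segments, speech, start, i = [], False, 0.0, 0
    for i, score in enumerate(scores):      # starts at frame 0, as in the fix
        t = i * frame_length_in_sec
        if speech and score < offset:       # speech -> non-speech
            if t + pad_offset > max(0.0, start - pad_onset):
                segments.append((max(0.0, start - pad_onset), t + pad_offset))
            speech = False
        elif not speech and score > onset:  # non-speech -> speech
            start, speech = t, True
    if speech:                              # segment running to the end of file
        segments.append((max(0.0, start - pad_onset), i * frame_length_in_sec + pad_offset))
    return segments

# A clip that begins with speech now yields a segment starting at 0.0 s:
print(binarize([0.9, 0.9, 0.2, 0.1, 0.8, 0.9]))
# -> [(0.0, 0.02), (0.04, 0.05)] (up to float rounding)
```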
@@ -627,8 +627,8 @@ def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict
See description in generate_overlap_vad_seq.
Use this for single instance pipeline.
"""
+UNIT_FRAME_LEN = 0.01

-shift_length_in_sec = per_args['shift_length_in_sec']
speech_segments = binarization(sequence, per_args)
speech_segments = filtering(speech_segments, per_args)

Expand All @@ -637,7 +637,7 @@ def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict

speech_segments, _ = torch.sort(speech_segments, 0)

-dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + shift_length_in_sec
+dur = speech_segments[:, 1:2] - speech_segments[:, 0:1] + UNIT_FRAME_LEN
speech_segments = torch.column_stack((speech_segments, dur))

return speech_segments
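The UNIT_FRAME_LEN added to dur accounts for timestamps marking frame starts: a segment whose boundaries are read off 0.01 s frame starts extends one unit frame past its end value. A quick worked example:

```python
# Hypothetical segment with boundaries on 0.01 s frame starts.
UNIT_FRAME_LEN = 0.01
start, end = 0.50, 1.19
dur = end - start + UNIT_FRAME_LEN  # ~0.70 s: the final frame itself spans 0.01 s
```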
@@ -667,7 +667,7 @@ def generate_vad_segment_table_per_file(pred_filepath: str, per_args: dict) -> s


def generate_vad_segment_table(
-vad_pred_dir: str, postprocessing_params: dict, shift_length_in_sec: float, num_workers: int, out_dir: str = None,
+vad_pred_dir: str, postprocessing_params: dict, frame_length_in_sec: float, num_workers: int, out_dir: str = None,
) -> str:
"""
Convert frame level prediction to speech segment in start and end times format.
@@ -677,7 +677,7 @@ def generate_vad_segment_table(
Args:
vad_pred_dir (str): directory of prediction files to be processed.
postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering.
-shift_length_in_sec (float): amount of shift of window for generating the frame.
+frame_length_in_sec (float): frame length.
out_dir (str): output dir of generated table/csv file.
num_workers (int): number of processes for multiprocessing
Returns:
Expand All @@ -700,7 +700,7 @@ def generate_vad_segment_table(
os.mkdir(table_out_dir)

per_args = {
"shift_length_in_sec": shift_length_in_sec,
"frame_length_in_sec": frame_length_in_sec,
"out_dir": table_out_dir,
}
per_args = {**per_args, **postprocessing_params}
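Since binarization and filtering read all thresholds from a single flat dict, the rename only needs to happen in this merge. A sketch of what per_args ends up holding, with illustrative threshold values:

```python
# Illustrative values; real ones come from postprocessing_params in the config.
postprocessing_params = {"onset": 0.5, "offset": 0.3, "pad_onset": 0.1, "pad_offset": 0.1}
per_args = {"frame_length_in_sec": 0.01, "out_dir": "/tmp/vad_tables"}
per_args = {**per_args, **postprocessing_params}
# binarization() then reads per_args.get('frame_length_in_sec', 0.01), per the hunk above.
```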
@@ -778,7 +778,7 @@ def vad_tune_threshold_on_dev(
result_file: str = "res",
vad_pred_method: str = "frame",
focus_metric: str = "DetER",
-shift_length_in_sec: float = 0.01,
+frame_length_in_sec: float = 0.01,
num_workers: int = 20,
) -> Tuple[dict, dict]:
"""
@@ -788,6 +788,8 @@
vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median".
groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them.
focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS"
+frame_length_in_sec (float): frame length.
+num_workers (int): number of workers.
Returns:
best_threshold (float): threshold that gives lowest DetER.
"""
@@ -810,7 +812,7 @@
# Generate speech segments by performing binarization on the VAD prediction according to param.
# Filter speech segments according to param and write the result to rttm-like table.
vad_table_dir = generate_vad_segment_table(
-vad_pred, param, shift_length_in_sec=shift_length_in_sec, num_workers=num_workers
+vad_pred, param, frame_length_in_sec=frame_length_in_sec, num_workers=num_workers
)
# add reference and hypothesis to metrics
for filename in paired_filenames:
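Conceptually, the tuning loop sweeps a grid of postprocessing parameters and keeps the combination with the lowest focus metric. A stripped-down sketch of that pattern (the scoring function is a stand-in, not NeMo's DetER computation):

```python
import itertools

def tune_thresholds(score_fn, onsets, offsets):
    """Grid-search sketch: score_fn(params) returns an error to minimize."""
    best_params, best_err = None, float("inf")
    for onset, offset in itertools.product(onsets, offsets):
        params = {"onset": onset, "offset": offset, "frame_length_in_sec": 0.01}
        err = score_fn(params)
        if err < best_err:
            best_params, best_err = params, err
    return best_params, best_err

# Toy objective with its minimum at onset=0.7, offset=0.3:
fake_err = lambda p: abs(p["onset"] - 0.7) + abs(p["offset"] - 0.3)
print(tune_thresholds(fake_err, onsets=[0.5, 0.7, 0.9], offsets=[0.1, 0.3, 0.5]))
```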
@@ -938,14 +940,14 @@ def plot(
per_args(dict): a dict that stores the thresholds for postprocessing.
"""
plt.figure(figsize=[20, 2])
-FRAME_LEN = 0.01
+UNIT_FRAME_LEN = 0.01

audio, sample_rate = librosa.load(path=path2audio_file, sr=16000, mono=True, offset=offset, duration=duration)
dur = librosa.get_duration(y=audio, sr=sample_rate)

-time = np.arange(offset, offset + dur, FRAME_LEN)
+time = np.arange(offset, offset + dur, UNIT_FRAME_LEN)
frame, _ = load_tensor_from_file(path2_vad_pred)
-frame_snippet = frame[int(offset / FRAME_LEN) : int((offset + dur) / FRAME_LEN)]
+frame_snippet = frame[int(offset / UNIT_FRAME_LEN) : int((offset + dur) / UNIT_FRAME_LEN)]

len_pred = len(frame_snippet)
ax1 = plt.subplot()
@@ -969,14 +971,14 @@ def plot(
) # take whole frame here for calculating onset and offset
speech_segments = generate_vad_segment_table_per_tensor(frame, per_args_float)
pred = gen_pred_from_speech_segments(speech_segments, frame)
-pred_snippet = pred[int(offset / FRAME_LEN) : int((offset + dur) / FRAME_LEN)]
+pred_snippet = pred[int(offset / UNIT_FRAME_LEN) : int((offset + dur) / UNIT_FRAME_LEN)]

if path2ground_truth_label:
label = extract_labels(path2ground_truth_label, time)
-ax2.plot(np.arange(len_pred) * FRAME_LEN, label, 'r', label='label')
+ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, label, 'r', label='label')

-ax2.plot(np.arange(len_pred) * FRAME_LEN, pred_snippet, 'b', label='pred')
-ax2.plot(np.arange(len_pred) * FRAME_LEN, frame_snippet, 'g--', label='speech prob')
+ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, pred_snippet, 'b', label='pred')
+ax2.plot(np.arange(len_pred) * UNIT_FRAME_LEN, frame_snippet, 'g--', label='speech prob')
ax2.tick_params(axis='y', labelcolor='r')
ax2.legend(loc='lower right', shadow=True)
ax2.set_ylabel('Preds and Probas')
6 changes: 3 additions & 3 deletions scripts/voice_activity_detection/vad_overlap_posterior.py
@@ -83,19 +83,19 @@
start = time.time()
logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

+frame_length_in_sec = args.shift_length_in_sec
if args.gen_overlap_seq:
logging.info("Use overlap prediction. Change if you want to use basic frame level prediction")
vad_pred_dir = overlap_out_dir
-shift_length_in_sec = 0.01
+frame_length_in_sec = 0.01
else:
logging.info("Use basic frame level prediction")
vad_pred_dir = args.frame_folder
-shift_length_in_sec = args.shift_length_in_sec

table_out_dir = generate_vad_segment_table(
vad_pred_dir=vad_pred_dir,
postprocessing_params=postprocessing_params,
-shift_length_in_sec=args.shift_length_in_sec,
+frame_length_in_sec=frame_length_in_sec,
num_workers=args.num_workers,
out_dir=args.table_out_dir,
)
11 changes: 10 additions & 1 deletion scripts/voice_activity_detection/vad_tune_threshold.py
@@ -81,6 +81,9 @@
type=str,
default='DetER',
)
+parser.add_argument(
+"--frame_length_in_sec", help="frame_length_in_sec", type=float, default=0.01,
+)
args = parser.parse_args()

params = {}
@@ -125,7 +128,13 @@
)

best_threshold, optimal_scores = vad_tune_threshold_on_dev(
-params, args.vad_pred, args.groundtruth_RTTM, args.result_file, args.vad_pred_method, args.focus_metric
+params,
+args.vad_pred,
+args.groundtruth_RTTM,
+args.result_file,
+args.vad_pred_method,
+args.focus_metric,
+args.frame_length_in_sec,
)
logging.info(
f"Best combination of thresholds for binarization selected from input ranges is {best_threhsold}, and the optimal score is {optimal_scores}"
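Because the script passes everything positionally, the new frame_length_in_sec argument has to land exactly where the signature expects it. Passing the trailing arguments by keyword (names confirmed by the vad_utils.py signature hunk above) would make the call robust to further signature growth; a sketch:

```python
# Keyword form of the call above. The first arguments stay positional
# because their parameter names are not visible in this diff.
best_threshold, optimal_scores = vad_tune_threshold_on_dev(
    params,
    args.vad_pred,
    args.groundtruth_RTTM,
    result_file=args.result_file,
    vad_pred_method=args.vad_pred_method,
    focus_metric=args.focus_metric,
    frame_length_in_sec=args.frame_length_in_sec,
)
```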
