# Voice Activity Detection (VAD) using NeMo

This notebook demonstrates how to perform Voice Activity Detection (VAD) using the `nemo` library. We will:
1. Load and preprocess an audio file.
2. Apply the VAD algorithm to detect speech segments.
3. Visualize and output the detected speech segments.


## Step 1: Install Requirements

In [None]:
# Install required packages
!pip install -q torch torchvision numpy matplotlib soundfile resampy "nemo_toolkit[asr]"

## Step 2: Load Libraries and Discover GPU Resources

In [None]:
# Import necessary libraries
import nemo.collections.asr as nemo_asr
import torch
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt

# Check to see what GPU resources are available
def get_best_device():
    if torch.cuda.is_available():
        print("Using CUDA")
        return "cuda"
    elif torch.backends.mps.is_available():
        print("Using MPS")
        return "mps"
    else:
        print("Using CPU")
        return "cpu"
device = get_best_device()


## Step 3: Load the Audio File

We load the audio file with `soundfile`, downmix it to mono if it has multiple channels, and resample it to 16 kHz if needed, so that the model receives a single-channel 16 kHz signal.

In [None]:
# Load the audio file
audio_filepath = "../../test_pcm.wav"
audio, sample_rate = sf.read(audio_filepath)

# Ensure the audio is a 1D array
if audio.ndim > 1:
    audio = np.mean(audio, axis=1)

# Resample audio if necessary
if sample_rate != 16000:
    import resampy
    audio = resampy.resample(audio, sample_rate, 16000)
    sample_rate = 16000

# Convert audio to torch tensor
audio_tensor = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)

# Plot the audio waveform
plt.figure(figsize=(15, 5))
plt.plot(np.linspace(0, len(audio) / sample_rate, num=len(audio)), audio)
plt.title('Audio Waveform')
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')
plt.show()
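
Before loading a file, it can help to check whether the downmix and resample steps above will actually run. The following is a minimal sketch using only the standard-library `wave` module, which reads uncompressed PCM WAV files only. The helper names `describe_wav` and `needs_preprocessing` and the demo file are hypothetical and not part of the original notebook.

```python
import os
import tempfile
import wave


def describe_wav(path):
    """Return (channels, sample_rate, duration_seconds) for a PCM WAV file."""
    with wave.open(path, "rb") as wav_file:
        channels = wav_file.getnchannels()
        sample_rate = wav_file.getframerate()
        duration = wav_file.getnframes() / sample_rate
    return channels, sample_rate, duration


def needs_preprocessing(channels, sample_rate, target_rate=16000):
    """Report which of the notebook's preprocessing steps apply to a file."""
    return {"downmix": channels > 1, "resample": sample_rate != target_rate}


# Hypothetical demo file: one second of 16-bit stereo silence at 44.1 kHz.
demo_path = os.path.join(tempfile.mkdtemp(), "demo_stereo.wav")
with wave.open(demo_path, "wb") as wav_file:
    wav_file.setnchannels(2)
    wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
    wav_file.setframerate(44100)
    wav_file.writeframes(b"\x00" * (44100 * 2 * 2))  # frames * channels * bytes

channels, rate, duration = describe_wav(demo_path)
print(f"channels={channels}, sample_rate={rate}, duration={duration:.2f}s")
print(needs_preprocessing(channels, rate))
```

For the stereo 44.1 kHz demo file, both steps are flagged, which corresponds to the `np.mean(audio, axis=1)` and `resampy.resample` branches in the cell above.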


## Step 4: Apply NeMo VAD

Next, we initialize the `nemo` VAD model and apply it to the audio file to detect speech segments.

In [None]:
# Initialize the NeMo VAD model
vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(model_name="vad_multilingual_marblenet").to(device)

# Move the audio tensor to the same device as the model
audio_tensor = audio_tensor.to(device)

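# Perform VAD (inference only, so disable dropout and gradient tracking)
vad_model.eval()
input_length = torch.tensor([audio_tensor.shape[1]], dtype=torch.int64, device=device)
with torch.no_grad():
    logits = vad_model(input_signal=audio_tensor, input_signal_length=input_length)
# Keep only the probability of class 1 (speech), flattened to one value per model output step
probs = torch.softmax(logits, dim=-1)[..., 1].reshape(-1).cpu().numpy()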

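# Process VAD output
vad_threshold = 0.5  # Probability above which a frame counts as speech
frame_duration = 0.02  # Seconds per frame, as assumed throughout this notebook
segments = []
current_segment = None

for i, prob in enumerate(probs):
    if prob > vad_threshold:  # Class 1 corresponds to speech
        if current_segment is None:
            # Open a new segment at the start of this frame
            current_segment = [i * frame_duration, (i + 1) * frame_duration]
        else:
            # Extend the open segment to the end of this frame
            current_segment[1] = (i + 1) * frame_duration
    elif current_segment is not None:
        # The first non-speech frame closes the open segment
        segments.append(current_segment)
        current_segment = None

# Close a segment that runs to the end of the audio
if current_segment is not None:
    segments.append(current_segment)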

# Print the VAD segments
print("Detected speech segments (in seconds):")
for start, end in segments:
    print(f"Start: {start:.2f}, End: {end:.2f}")

# Print VAD statistics: number of speech segments, total duration of speech
# segments, and speech ratio
num_speech_segments = len(segments)
total_duration = sum(end - start for start, end in segments)
total_audio_length = len(audio) / sample_rate
speech_ratio = total_duration / total_audio_length
print(f"\nNumber of speech segments: {num_speech_segments}")
print(f"Total length of audio: {total_audio_length:.2f} seconds")
print(f"Total duration of speech segments: {total_duration:.2f} seconds")
print(f"Speech ratio: {speech_ratio:.2f}")
print(f"Segments: \n{segments}")
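
The thresholding and statistics steps above can also be written as standalone functions, which makes them easy to check on a toy probability sequence without loading a model. This is a minimal sketch, not the notebook's own implementation: the names `probs_to_segments` and `speech_ratio` are hypothetical, and the 20 ms default frame duration is the value assumed in the cell above.

```python
def probs_to_segments(probs, threshold=0.5, frame_duration=0.02):
    """Group consecutive frames whose speech probability exceeds the threshold.

    Returns a list of (start_seconds, end_seconds) tuples.
    """
    segments = []
    start_frame = None  # index of the first frame of the open segment, if any
    for i, prob in enumerate(probs):
        if prob > threshold:
            if start_frame is None:
                start_frame = i
        elif start_frame is not None:
            # Frame i is non-speech, so the segment ends where frame i begins
            segments.append((start_frame * frame_duration, i * frame_duration))
            start_frame = None
    if start_frame is not None:
        # Speech continued through the final frame
        segments.append((start_frame * frame_duration, len(probs) * frame_duration))
    return segments


def speech_ratio(segments, total_audio_length):
    """Fraction of the audio covered by speech segments (0.0 for empty audio)."""
    if total_audio_length <= 0:
        return 0.0
    return sum(end - start for start, end in segments) / total_audio_length


# Toy input: frames 1-2 and 5-7 are above the 0.5 threshold.
toy_probs = [0.1, 0.9, 0.8, 0.2, 0.1, 0.7, 0.95, 0.9]
toy_segments = probs_to_segments(toy_probs)
for start, end in toy_segments:
    print(f"Start: {start:.2f}, End: {end:.2f}")
print(f"Speech ratio: {speech_ratio(toy_segments, len(toy_probs) * 0.02):.2f}")
```

Tracking the start frame index rather than a mutable `[start, end]` pair means the end time is computed once, when the segment closes.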


## Step 5: Visualize the Detected Speech Segments

We visualize the detected speech segments on the audio waveform to better understand where speech occurs.

In [None]:
# Plot the audio waveform with detected speech segments
plt.figure(figsize=(15, 5))
plt.plot(np.linspace(0, len(audio) / sample_rate, num=len(audio)), audio, label='Audio')
for start, end in segments:
    plt.axvspan(start, end, color='red', alpha=0.5, label='Speech Segment' if start == segments[0][0] else "")
plt.title('Audio Waveform with Detected Speech Segments')
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')
plt.legend()
plt.show()


## Step 6: Clean Up GPU Resources

In [None]:
# Release cached GPU memory once the model and tensors are no longer needed
del vad_model, audio_tensor, logits  # drop references so their memory can be freed
if device == "cuda":
    torch.cuda.empty_cache()
elif device == "mps":
    torch.mps.empty_cache()  # available in recent PyTorch releases (2.0+)

## Conclusion

In this notebook, we used the `nemo` library to detect speech segments in an audio file: we loaded and preprocessed the audio, applied the VAD model, and visualized the detected speech segments.