In [None]:
"""
You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.

Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
"""
# If you're using Google Colab and not running locally, run this cell.

## Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install unidecode

## Install NeMo
BRANCH = 'main'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]

## Install TorchAudio
!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html

## Introduction
Who speaks when? Speaker diarization is the task of segmenting an audio recording by speaker label. 
A diarization system consists of a Voice Activity Detection (VAD) model, which finds the time stamps where speech occurs while ignoring the background, and a speaker embedding model, which extracts speaker embeddings from those time-stamped segments. The embeddings are then grouped into clusters according to the number of speakers present in the recording.

In NeMo we support both **oracle VAD** and **non-oracle VAD** diarization. 

In this tutorial, we first demonstrate how to perform diarization with oracle VAD time stamps (that is, we assume the speech time stamps are already known) and a pretrained speaker verification model, which is covered in the tutorial on [Speaker Identification and Verification in NeMo](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb).

In the VAD DIARIZATION section we show how to perform VAD and then diarization when ground-truth speech time stamps are not available (non-oracle VAD). We also have tutorials for [VAD training in NeMo](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/Voice_Activity_Detection.ipynb) and [online offline microphone inference](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/Online_Offline_Microphone_VAD_Demo.ipynb), where you can customize your model and train or fine-tune it on your own data.

For demonstration purposes we use simulated audio from the [an4 dataset](http://www.speech.cs.cmu.edu/databases/an4/).

In [None]:
import os
import wget
ROOT = os.getcwd()
data_dir = os.path.join(ROOT,'data')
os.makedirs(data_dir, exist_ok=True)
an4_audio = os.path.join(data_dir,'an4_diarize_test.wav')
an4_rttm = os.path.join(data_dir,'an4_diarize_test.rttm')
if not os.path.exists(an4_audio):
    an4_audio_url = "https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav"
    an4_audio = wget.download(an4_audio_url, data_dir)
if not os.path.exists(an4_rttm):
    an4_rttm_url = "https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.rttm"
    an4_rttm = wget.download(an4_rttm_url, data_dir)

Let's plot and listen to the audio and visualize the RTTM speaker labels

In [None]:
import IPython
import matplotlib.pyplot as plt
import numpy as np
import librosa

sr = 16000
signal, sr = librosa.load(an4_audio,sr=sr) 

fig,ax = plt.subplots(1,1)
fig.set_figwidth(20)
fig.set_figheight(2)
plt.plot(np.arange(len(signal)),signal,'gray')
fig.suptitle('Reference merged an4 audio', fontsize=16)
plt.xlabel('time (secs)', fontsize=18)
ax.margins(x=0)
plt.ylabel('signal strength', fontsize=16);
a,_ = plt.xticks();plt.xticks(a,a/sr);

IPython.display.Audio(an4_audio)

We use [pyannote_metrics](https://pyannote.github.io/pyannote-metrics/) for visualization and score calculation, so all labels in RTTM format are eventually converted to pyannote objects. We provide two helper functions: `rttm_to_labels` (for NeMo intermediate processing) and `labels_to_pyannote_object` (for the scoring and visualization format).
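
To make the intermediate label format more concrete, here is a minimal sketch of what an RTTM-to-labels conversion might look like. It is not NeMo's implementation: the function name `parse_rttm_lines` is hypothetical, the sample lines are made up, and the sketch assumes the usual RTTM field order (`SPEAKER <file> <channel> <onset> <duration> ... <speaker> ...`) and a `"start end speaker"` label string.

```python
def parse_rttm_lines(lines):
    """Convert RTTM lines into 'start end speaker' label strings (illustrative only)."""
    labels = []
    for line in lines:
        fields = line.split()
        # Skip blank lines and anything that is not a SPEAKER record.
        if not fields or fields[0] != "SPEAKER":
            continue
        onset = float(fields[3])     # segment start in seconds
        duration = float(fields[4])  # segment length in seconds
        speaker = fields[7]          # speaker name column
        labels.append(f"{onset:.3f} {onset + duration:.3f} {speaker}")
    return labels


# Made-up RTTM lines, only for illustration.
sample_rttm = [
    "SPEAKER an4_diarize_test 1 0.500 2.000 <NA> <NA> speaker_a <NA> <NA>",
    "SPEAKER an4_diarize_test 1 3.000 1.250 <NA> <NA> speaker_b <NA> <NA>",
]
toy_labels = parse_rttm_lines(sample_rttm)
print(toy_labels)  # ['0.500 2.500 speaker_a', '3.000 4.250 speaker_b']
```

RTTM stores an onset and a duration, while the label strings in this sketch store a start and an end, which is why the two values are added.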

In [None]:
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object

Let's load ground truth RTTM labels and view the reference Annotation timestamps visually

In [None]:
# view the sample rttm file
!cat {an4_rttm}

In [None]:
labels = rttm_to_labels(an4_rttm)
reference = labels_to_pyannote_object(labels)
print(labels)
reference

Speaker diarization scripts commonly expect the following arguments:
1. manifest_filepath : Path to a manifest file containing JSON lines of the format: {'audio_filepath': /path/to/audio_file, 'offset': 0, 'duration':None, 'label': 'infer', 'text': '-', 'num_speakers': None, 'rttm_filepath': /path/to/rttm/file, 'uem_filepath': /path/to/uem/filepath}
2. out_dir : Directory where outputs and intermediate files are stored. 
3. oracle_vad : If True, speech activity labels are extracted from the RTTM files; if False, either vad.model_path or external_manifestpath containing speech activity labels has to be passed. 

The mandatory fields are audio_filepath, offset, duration, label and text. For the remaining fields: to evaluate with a known number of speakers, pass that value as num_speakers, otherwise None. To score the system against known RTTMs, pass rttm_filepath, otherwise None. A UEM file restricts scoring to part of the audio; pass uem_filepath if you want to evaluate only on that part, otherwise None.


**Note:** we expect the audio file and its corresponding RTTM file to have the **same base name**, and the name should be **unique**. 

For example, if the audio file name is **test_an4**.wav, we expect the corresponding RTTM file (if provided) to be named **test_an4**.rttm (note the matching **test_an4** base name).


Let's create a manifest with the an4 audio and RTTM available. If you have more than one file, you may also use the script `pathsfiles_to_manifest.py` to generate a manifest file from a list of audio files and, optionally, RTTM files.
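
For several files, the pairing by base name described above could be sketched as follows. This is only an illustration under stated assumptions, not the `pathsfiles_to_manifest.py` script: `build_manifest_entries` is a hypothetical helper and the file paths are made up.

```python
import json
import os


def build_manifest_entries(audio_paths, rttm_paths=None, num_speakers=None):
    """Pair audio and RTTM files by base name and return manifest dicts (illustrative)."""
    rttm_by_name = {}
    for path in rttm_paths or []:
        base = os.path.splitext(os.path.basename(path))[0]
        rttm_by_name[base] = path

    entries = []
    seen = set()
    for path in audio_paths:
        base = os.path.splitext(os.path.basename(path))[0]
        if base in seen:
            # Base names must be unique across the manifest.
            raise ValueError(f"duplicate base name: {base}")
        seen.add(base)
        entries.append({
            "audio_filepath": path,
            "offset": 0,
            "duration": None,
            "label": "infer",
            "text": "-",
            "num_speakers": num_speakers,
            "rttm_filepath": rttm_by_name.get(base),  # None when no RTTM matches
            "uem_filepath": None,
        })
    return entries


entries = build_manifest_entries(
    ["/data/test_an4.wav", "/data/other.wav"],  # hypothetical audio paths
    ["/labels/test_an4.rttm"],                  # hypothetical RTTM path
)
# One JSON object per line, as the manifest format expects.
manifest_text = "".join(json.dumps(entry) + "\n" for entry in entries)
print(manifest_text)
```

Here `other.wav` has no RTTM with a matching base name, so its `rttm_filepath` is written as `null`.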

In [None]:
# Create a manifest for input with the format below. 
# {'audio_filepath': /path/to/audio_file, 'offset': 0, 'duration':None, 'label': 'infer', 'text': '-', 
# 'num_speakers': None, 'rttm_filepath': /path/to/rttm/file, 'uem_filepath': /path/to/uem/filepath}
import json
meta = {
    'audio_filepath': an4_audio, 
    'offset': 0, 
    'duration':None, 
    'label': 'infer', 
    'text': '-', 
    'num_speakers': 2, 
    'rttm_filepath': an4_rttm, 
    'uem_filepath' : None
}
with open('data/input_manifest.json','w') as fp:
    json.dump(meta,fp)
    fp.write('\n')

!cat data/input_manifest.json

output_dir = os.path.join(ROOT, 'oracle_vad')
os.makedirs(output_dir,exist_ok=True)

# ORACLE-VAD DIARIZATION

Oracle-VAD diarization computes speaker embeddings from known speech time stamps rather than depending on VAD output. This step can also be used to run speaker diarization with RTTMs generated by any external VAD, not just the VAD model from NeMo.

The first step is to convert the reference audio RTTM (VAD) time stamps into an oracle manifest file. This manifest file is then sent to our speaker diarizer to extract embeddings.

This is just an argument in our config; the system automatically computes the oracle manifest from the RTTMs provided through the input manifest file.

Our config file is based on [hydra](https://hydra.cc/docs/intro/). 
With a hydra config, users must provide values for the variables filled with **???**; these are mandatory fields, and the scripts expect them for a successful run. Variables filled with **null** are optional; they can be provided if needed but are not mandatory.

In [None]:
from omegaconf import OmegaConf
MODEL_CONFIG = os.path.join(data_dir,'offline_diarization.yaml')
if not os.path.exists(MODEL_CONFIG):
    config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/offline_diarization.yaml"
    MODEL_CONFIG = wget.download(config_url,data_dir)

config = OmegaConf.load(MODEL_CONFIG)
print(OmegaConf.to_yaml(config))

Now we can perform speaker diarization based on timestamps generated from ground truth rttms rather than generating through VAD
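
The speaker embedding settings in the next cell (`window_length_in_sec = 1.5`, `shift_length_in_sec = 0.75`) describe a sliding window over each speech segment. As a rough illustration of that idea, the sketch below splits one segment into overlapping windows. `split_into_windows` is a hypothetical helper, and the way the last, shorter window is kept is an assumption rather than NeMo's documented behaviour.

```python
def split_into_windows(start, end, window=1.5, shift=0.75):
    """Split a speech segment [start, end] into overlapping sub-segments (illustrative)."""
    windows = []
    t = start
    # Emit full-length windows while one still fits strictly inside the segment.
    while t + window < end:
        windows.append((t, t + window))
        t += shift
    # Keep the remainder as a final, possibly shorter, window (an assumption).
    windows.append((t, end))
    return windows


windows = split_into_windows(0.0, 4.0)
print(windows)
# [(0.0, 1.5), (0.75, 2.25), (1.5, 3.0), (2.25, 3.75), (3.0, 4.0)]
```

With a shift of half the window length, each window overlaps its neighbour by 50%, so every part of the segment contributes to more than one embedding.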

In [None]:
pretrained_speaker_model='ecapa_tdnn'
config.diarizer.manifest_filepath = 'data/input_manifest.json'
config.diarizer.out_dir = output_dir #Directory to store intermediate files and prediction outputs

config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5
config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75
config.diarizer.oracle_vad = True # ----> ORACLE VAD 
config.diarizer.clustering.parameters.oracle_num_speakers = True

In [None]:
from nemo.collections.asr.models import ClusteringDiarizer
oracle_model = ClusteringDiarizer(cfg=config)

In [None]:
# And let's diarize
oracle_model.diarize()

A DER of 0 means the speaker embeddings were clustered correctly. Let's view the predicted RTTM.
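
For intuition about what the diarization error rate (DER) measures, here is a rough frame-level sketch. The scores in this tutorial come from pyannote metrics, not from this code; `to_frames` and `toy_der` are hypothetical helpers, the segments are made up, and the sketch assumes no overlapping speech and no collar around segment boundaries.

```python
from itertools import permutations


def to_frames(segments, n_frames, step):
    """Rasterize (start, end, speaker) segments into one label per frame (None = silence)."""
    frames = [None] * n_frames
    for start, end, speaker in segments:
        for i in range(round(start / step), min(round(end / step), n_frames)):
            frames[i] = speaker
    return frames


def toy_der(reference, hypothesis, step=0.01):
    """Frame-level DER = (missed + false alarm + confused frames) / reference speech frames."""
    duration = max(end for _, end, _ in reference + hypothesis)
    n_frames = round(duration / step)
    ref = to_frames(reference, n_frames, step)
    hyp = to_frames(hypothesis, n_frames, step)

    ref_speakers = sorted({s for s in ref if s is not None})
    hyp_speakers = sorted({s for s in hyp if s is not None})
    # Pad with None so a hypothesis speaker may stay unmatched.
    candidates = ref_speakers + [None] * len(hyp_speakers)

    best = None
    # Brute-force every hypothesis-to-reference speaker mapping; fine for a few speakers.
    for perm in permutations(candidates, len(hyp_speakers)):
        mapping = dict(zip(hyp_speakers, perm))
        errors = 0
        for r, h in zip(ref, hyp):
            if r is None and h is None:
                continue  # silence in both: no error
            if r is None or h is None or mapping[h] != r:
                errors += 1  # false alarm, miss, or speaker confusion
        best = errors if best is None else min(best, errors)

    return best / sum(r is not None for r in ref)


ref_segments = [(0.0, 2.0, "A"), (2.0, 4.0, "B")]
same_up_to_names = [(0.0, 2.0, "spk1"), (2.0, 4.0, "spk0")]
partly_wrong = [(0.0, 2.0, "spk1"), (2.0, 3.0, "spk0"), (3.0, 4.0, "spk1")]
print(toy_der(ref_segments, same_up_to_names))  # 0.0
print(toy_der(ref_segments, partly_wrong))      # 0.25
```

The first hypothesis uses different speaker names but the same segmentation, so the best mapping gives a DER of 0. In the second, one second out of four seconds of reference speech is attributed to the wrong speaker.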

In [None]:
!cat {output_dir}/pred_rttms/an4_diarize_test.rttm

In [None]:
pred_labels = rttm_to_labels(output_dir+'/pred_rttms/an4_diarize_test.rttm')
hypothesis = labels_to_pyannote_object(pred_labels)
hypothesis

In [None]:
reference

# VAD DIARIZATION

In this method we compute VAD time stamps by running a NeMo VAD model on the input manifest file, then use these speech time stamps to extract speaker embeddings, and finally cluster the embeddings into the number of speakers.

Before we proceed, let's look at the speaker diarization config, which we depend on for VAD computation
and speaker embedding extraction.

In [None]:
print(OmegaConf.to_yaml(config))

As can be seen, most of the variables in the config are self-explanatory, 
with VAD variables under the vad section and speaker-related variables under the speaker_embeddings section. 

To perform VAD-based diarization we can ignore `oracle_vad_manifest` in the `speaker_embeddings` section for now and need to fill in the rest. We also need to provide the pretrained `model_path` of the VAD and speaker embedding .nemo models.

In [None]:
pretrained_vad = 'vad_marblenet'
pretrained_speaker_model = 'ecapa_tdnn'

Note that in this tutorial we use the VAD model MarbleNet-3x2, introduced in [ICASSP MarbleNet](https://arxiv.org/pdf/2010.13886.pdf). You might need to tune it on a dev set similar to your dataset if you would like to improve performance.

The speakerNet-M-Diarization model achieves a 7.3% confusion error rate on the CH109 set with oracle VAD. This model is trained on the voxceleb1, voxceleb2, Fisher and SwitchBoard datasets. For better performance on your own dataset, fine-tune the speaker verification model with a dev set similar to your test set.

In [None]:
output_dir = os.path.join(ROOT,'outputs')
config.diarizer.manifest_filepath = 'data/input_manifest.json'
config.diarizer.out_dir = output_dir #Directory to store intermediate files and prediction outputs

config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5
config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75
config.diarizer.oracle_vad = False # compute VAD provided with model_path to vad config
config.diarizer.clustering.parameters.oracle_num_speakers=True

# Here we use our in-house pretrained NeMo VAD
config.diarizer.vad.model_path = pretrained_vad
config.diarizer.vad.window_length_in_sec = 0.15
config.diarizer.vad.shift_length_in_sec = 0.01
config.diarizer.vad.parameters.onset = 0.8 
config.diarizer.vad.parameters.offset = 0.6
config.diarizer.vad.parameters.min_duration_on = 0.1
config.diarizer.vad.parameters.min_duration_off = 0.4

Now that we have set all the variables we need, let's initialize the clustering model with the above config.

In [None]:
from nemo.collections.asr.models import ClusteringDiarizer
sd_model = ClusteringDiarizer(cfg=config)

And diarize with a single line of code.

In [None]:
sd_model.diarize()

As can be seen, we first performed VAD; then, with the time stamps that VAD created in `{output_dir}/vad_outputs`, we calculated speaker embeddings (`{output_dir}/speaker_outputs/embeddings/`), which were then clustered using spectral clustering. 

To generate the VAD-predicted time stamps, we perform VAD inference to get frame-level predictions → (optional: use decision smoothing) → given a `threshold`, write the speech segments to an RTTM-like time stamp manifest.

We use VAD decision smoothing (87.5% overlap median) as described [here](https://github.com/NVIDIA/NeMo/blob/stable/nemo/collections/asr/parts/utils/vad_utils.py).

You can also tune the threshold on your dev set using this provided [script](https://github.com/NVIDIA/NeMo/blob/stable/scripts/voice_activity_detection/vad_tune_threshold.py).
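
To illustrate how frame-level predictions could become speech segments, here is a minimal sketch that uses the `onset`, `offset`, `min_duration_on` and `min_duration_off` values from the config above. It is not the code in `vad_utils.py`: `binarize` and `postprocess` are hypothetical helpers, the probabilities are made up, and the order of the two post-processing steps is an assumption.

```python
def binarize(probs, onset=0.8, offset=0.6, frame_shift=0.01):
    """Turn frame-level speech probabilities into (start, end) segments (illustrative)."""
    segments = []
    in_speech = False
    start = 0.0
    for i, p in enumerate(probs):
        t = round(i * frame_shift, 3)
        if not in_speech and p >= onset:
            # Speech starts only once the probability reaches the onset threshold.
            in_speech, start = True, t
        elif in_speech and p < offset:
            # Speech ends only once the probability drops below the offset threshold.
            in_speech = False
            segments.append((start, t))
    if in_speech:
        segments.append((start, round(len(probs) * frame_shift, 3)))
    return segments


def postprocess(segments, min_duration_on=0.1, min_duration_off=0.4):
    """Drop very short speech segments, then bridge very short gaps (order is an assumption)."""
    kept = [(s, e) for s, e in segments if e - s >= min_duration_on]
    merged = []
    for s, e in kept:
        if merged and s - merged[-1][1] < min_duration_off:
            merged[-1] = (merged[-1][0], e)  # gap too short: extend the previous segment
        else:
            merged.append((s, e))
    return merged


# Made-up probabilities, one value every 0.1 s to keep the example readable.
probs = [0.1, 0.9, 0.7, 0.65, 0.5, 0.2, 0.85, 0.9, 0.3]
raw_segments = binarize(probs, frame_shift=0.1)
print(raw_segments)               # [(0.1, 0.4), (0.6, 0.8)]
print(postprocess(raw_segments))  # [(0.1, 0.8)]
```

Note that the value 0.7 is below `onset` but above `offset`, so it keeps an ongoing speech segment alive without being able to start a new one. Setting `onset` equal to `offset` reduces this to a single threshold.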

In [None]:
# VAD predicted time stamps
# you can also use single threshold(=onset=offset) for binarization and plot here
from nemo.collections.asr.parts.utils.vad_utils import plot
plot(
    an4_audio,
    'outputs/vad_outputs/overlap_smoothing_output_median_0.875/an4_diarize_test.median', 
    an4_rttm,
    per_args = config.diarizer.vad.parameters, #threshold
    ) 

print(f"postprocessing_params: {config.diarizer.vad.parameters}")

Predicted outputs are written to `output_dir/pred_rttms`. Let's see how we predicted along with the VAD prediction.

In [None]:
!cat outputs/pred_rttms/an4_diarize_test.rttm

In [None]:
pred_labels = rttm_to_labels('outputs/pred_rttms/an4_diarize_test.rttm')
hypothesis = labels_to_pyannote_object(pred_labels)
hypothesis

In [None]:
reference

# Storing and Restoring models

Now we can save the whole config and model parameters in a single .nemo file and restore from it at any time.

In [None]:
oracle_model.save_to(os.path.join(output_dir,'diarize.nemo'))

Restore from saved model

In [None]:
del oracle_model
import nemo.collections.asr as nemo_asr
restored_model = nemo_asr.models.ClusteringDiarizer.restore_from(os.path.join(output_dir,'diarize.nemo'))

# ADD ON - ASR 

In [None]:
IPython.display.Audio(an4_audio)

In [None]:
quartznet = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
for fname, transcription in zip([an4_audio], quartznet.transcribe(paths2audio_files=[an4_audio])):
  print(f"Audio in {fname} was recognized as:\n{transcription}")