# NeMo Diarization Process
This notebook walks through the process of speaker diarization using the NVIDIA NeMo library.
We will be using a local audio file `../../test_pcm.wav` for the diarization process.

## Install Requirements

In [None]:
# Setup installers: (shell command, human-readable description) pairs.
# All pip steps share one prefix; PIP_ROOT_USER_ACTION=ignore suppresses pip's
# "running as root" warning inside containers.
_PIP_INSTALL = "PIP_ROOT_USER_ACTION=ignore pip install -q"

commands = [
    ("apt-get install -y -q sox libsndfile1 ffmpeg", "Install sox, libsndfile1, and ffmpeg"),
    (f"{_PIP_INSTALL} wget", "Install wget"),
    (f"{_PIP_INSTALL} text-unidecode", "Install text-unidecode"),
    (f"{_PIP_INSTALL} git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]", "Install NeMo"),
    (f"{_PIP_INSTALL} torchaudio -f https://download.pytorch.org/whl/torch_stable.html", "Install TorchAudio"),
]

# Only the finished list is needed by later cells.
del _PIP_INSTALL

# Import the utils module which sets up the environment
from modules import utils
from modules import disable_warnings

# Use LogTools from the project-local `modules.utils` helper.
# NOTE(review): presumably command_state runs each (command, description) pair
# in `commands` and logs its progress — confirm against modules/utils.py.
log_tools = utils.LogTools()

# Execute the install steps defined above
log_tools.command_state(commands)

## Import Required Libraries

In [None]:
import nemo.collections.asr as nemo_asr
import numpy as np
from IPython.display import Audio, display
import librosa
import os
import matplotlib.pyplot as plt
from scipy.signal import resample
import pprint
# Shared pretty-printer (4-space indent); used below to dump the RTTM lines and transcripts.
pp = pprint.PrettyPrinter(indent=4)
import torch

# Check to see what GPU resources are available
def get_best_device():
    """Pick the best available torch device and print which one was chosen.

    Preference order is CUDA, then Apple-silicon MPS, then CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        print("Using CUDA")
        return torch.device("cuda")
    # torch.backends.mps only exists on newer PyTorch builds; guard the lookup
    # so older installs fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("Using MPS")
        return torch.device("mps")
    print("Using CPU")
    return torch.device("cpu")

# NOTE(review): `device` is not referenced by any later cell in this notebook.
device = get_best_device()

## Setting Up Directories
Create necessary directories for storing data and specify the audio file to be used.

In [None]:
ROOT = os.getcwd()
# Directory for downloaded configs, manifests and NeMo outputs. Create it up
# front: later cells open files inside it for writing, and wget.download(url,
# <missing dir>) would save to a *file* named after the directory instead.
data_dir = os.path.join(ROOT, 'data')
os.makedirs(data_dir, exist_ok=True)

# Audio to diarize, relative to the notebook's working directory.
AUDIO_FILENAME = '../../test_pcm.wav'

# NOTE: despite its name this is a single path string, not a list.
audio_file_list = AUDIO_FILENAME
print("Input audio file list: \n", audio_file_list)

## Loading the Audio File
Load the audio file, resample it to 16 kHz if necessary, and play it back.

In [None]:
# Load the audio file at its native sample rate (sr=None disables librosa's own resampling)
signal, sample_rate = librosa.load(AUDIO_FILENAME, sr=None)

# Work at 16 kHz from here on: later cells colour and play the signal assuming
# 16000 samples per second. Resample with scipy.signal.resample only when needed.
new_sample_rate = 16000
if sample_rate != new_sample_rate:
    print(f"Current sample rate: {sample_rate}")
    print(f"Resampling signal to {new_sample_rate}")
    num_samples = int(len(signal) * new_sample_rate / sample_rate)
    signal, sample_rate = resample(signal, num_samples), new_sample_rate

# Inline player for listening to the (possibly resampled) audio
display(Audio(signal, rate=sample_rate))


## Displaying the Waveform
Define a function to display the waveform of the audio signal.

In [None]:
def display_waveform(signal, text='Audio', overlay_color=None, sample_rate=16000):
    """Scatter-plot an audio waveform, optionally colouring each sample.

    Args:
        signal: 1-D sequence of audio samples.
        text: figure title.
        overlay_color: optional per-sample colour sequence (same length as
            ``signal``), e.g. the output of ``get_color``. When omitted or
            empty, every sample is drawn in black.
        sample_rate: samples per second, used to relabel the x-axis ticks from
            sample indices to seconds. Defaults to 16000, the rate the loading
            cell resamples to.
    """
    fig, ax = plt.subplots(1, 1)
    fig.set_figwidth(20)
    fig.set_figheight(2)
    sample_index = np.arange(len(signal))
    # Draw the points once, with the overlay colours if given, otherwise black.
    # (An overlay on top of a black base layer covers it completely, so
    # plotting both only doubles the cost for long recordings.)
    if overlay_color is not None and len(overlay_color):
        colors = overlay_color
    else:
        colors = 'k'
    ax.scatter(sample_index, signal, s=1, marker='o', c=colors)
    fig.suptitle(text, fontsize=16)
    ax.set_xlabel('time (secs)', fontsize=18)
    ax.set_ylabel('signal strength', fontsize=14)
    # NOTE: amplitudes outside +/-0.5 fall outside the visible range.
    ax.axis([0, len(signal), -0.5, +0.5])
    # Keep the auto tick positions (minus the last one) but label them in seconds.
    tick_positions = ax.get_xticks()[:-1]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_positions / sample_rate)

# Palette for "<name>_<k>" speaker labels; index k wraps around this list.
COLORS = "b g c m y".split()

def get_color(signal, speech_labels, sample_rate=16000):
    """Build a per-sample colour array for ``display_waveform``.

    Args:
        signal: audio samples; only its length is used.
        speech_labels: iterable of ``"<start_sec> <end_sec> <label>"`` strings
            (e.g. from NeMo's ``rttm_to_labels``). ``label`` is either
            ``"speech"`` or ``"<name>_<k>"`` with an integer speaker index ``k``.
        sample_rate: samples per second used to convert seconds to indices.

    Returns:
        numpy object array of matplotlib colour names, one per sample; samples
        not covered by any label stay ``'k'`` (black).
    """
    # dtype=object keeps multi-character names such as 'red' intact; an array
    # built from ['k'] * n has dtype '<U1' and would silently truncate them.
    c = np.full(len(signal), 'k', dtype=object)
    for time_stamp in speech_labels:
        start, end, label = time_stamp.split()
        start, end = int(float(start) * sample_rate), int(float(end) * sample_rate)
        if label == "speech":
            code = 'red'
        else:
            # Wrap so sessions with more speakers than colours don't raise IndexError.
            code = COLORS[int(label.split('_')[-1]) % len(COLORS)]
        c[start:end] = code
    return c

# Plot the loaded waveform without any speaker colouring yet.
display_waveform(signal)

## Preparing Configuration for Diarization
Download and load the configuration file for the diarization process.

In [None]:
from omegaconf import OmegaConf
import wget

DOMAIN_TYPE = "telephonic" # Can be meeting or telephonic based on domain type of the audio file
CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"

CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"

# Download the inference config once and reuse the cached copy on re-runs.
CONFIG = os.path.join(data_dir, CONFIG_FILE_NAME)
if not os.path.exists(CONFIG):
    # Ensure the directory exists and pass an explicit file path: given a
    # directory that does not exist, wget.download would treat it as a file name.
    os.makedirs(data_dir, exist_ok=True)
    CONFIG = wget.download(CONFIG_URL, CONFIG)

cfg = OmegaConf.load(CONFIG)
print(OmegaConf.to_yaml(cfg))

## Creating a Manifest File
Create a manifest file required for the diarization process.

In [None]:
import json

# Single-entry manifest: one JSON object per line describing the audio to diarize.
# Unknown fields (duration, speaker count, reference RTTM/UEM) are left as None.
manifest_path = os.path.join(data_dir, 'input_manifest.json')
meta = dict(
    audio_filepath=AUDIO_FILENAME,
    offset=0,
    duration=None,
    label='infer',
    text='-',
    num_speakers=None,
    rttm_filepath=None,
    uem_filepath=None,
)
with open(manifest_path, 'w') as fp:
    fp.write(json.dumps(meta) + '\n')

cfg.diarizer.manifest_filepath = manifest_path

## Setting Up the Diarization Configuration
Configure the diarization settings including the VAD model, ASR model, and other parameters.

In [None]:
pretrained_speaker_model='titanet_large'
# manifest_filepath was already set when the manifest file was written; it is
# intentionally not reassigned here.
cfg.diarizer.out_dir = data_dir #Directory to store intermediate files and prediction outputs
cfg.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
# Do not use an oracle speaker count for clustering.
cfg.diarizer.clustering.parameters.oracle_num_speakers=False

# Using Neural VAD and Conformer ASR
cfg.diarizer.vad.model_path = 'vad_multilingual_marblenet'
cfg.diarizer.asr.model_path = 'stt_en_conformer_ctc_large'
cfg.diarizer.oracle_vad = False # ----> Not using oracle VAD
cfg.diarizer.asr.parameters.asr_based_vad = False

## Running the ASR Model
Run the ASR model to get word-level timestamps.

In [None]:
from nemo.collections.asr.parts.utils.decoder_timestamps_utils import ASRDecoderTimeStamps

# Build the configured ASR model and decode the manifest's audio into words
# plus word-level timestamps.
asr_decoder_ts = ASRDecoderTimeStamps(cfg.diarizer)
asr_model = asr_decoder_ts.set_asr_model()
word_hyp, word_ts_hyp = asr_decoder_ts.run_ASR(asr_model)

# Results are keyed by session id (presumably derived from the audio file name).
session_id = 'test_pcm'
print("Decoded word output dictionary: \n", word_hyp[session_id])
print("Word-level timestamps dictionary: \n", word_ts_hyp[session_id])

## Running the Diarization Process
Run the diarization process using the ASR model output.

In [None]:
from nemo.collections.asr.parts.utils.diarization_utils import OfflineDiarWithASR

# Diarizer that consumes the ASR word timestamps; it reuses the decoder's
# anchor offset (presumably so word times and speaker segments line up).
asr_diar_offline = OfflineDiarWithASR(cfg.diarizer)
asr_diar_offline.word_ts_anchor_offset = asr_decoder_ts.word_ts_anchor_offset

diar_hyp, diar_score = asr_diar_offline.run_diarization(cfg, word_ts_hyp)
session_id = 'test_pcm'
print("Diarization hypothesis output: \n", diar_hyp[session_id])

## Displaying the Diarization (Speaker Attribution) Results
Read the predicted RTTM file and display the waveform with speaker labels.

In [None]:
def read_file(path_to_file):
    """Return the lines of a text file as a list, without line terminators."""
    with open(path_to_file) as handle:
        return handle.read().splitlines()

# Predicted RTTM for this session, written under the configured out_dir (data_dir).
predicted_speaker_label_rttm_path = f"{data_dir}/pred_rttms/test_pcm.rttm"
pred_rttm = read_file(predicted_speaker_label_rttm_path)

pp.pprint(pred_rttm)

from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels
pred_labels = rttm_to_labels(predicted_speaker_label_rttm_path)

# Use the signal's actual sample rate (instead of relying on the hard-coded
# 16000 default) so the colour overlay and the player stay in sync with `signal`.
color = get_color(signal, pred_labels, sample_rate=sample_rate)
display_waveform(signal,'Audio with Speaker Labels', color)
display(Audio(signal, rate=sample_rate))

## Displaying the Transcription Results
Display the transcription results along with speaker labels.

In [None]:
# Attach speaker labels to the ASR words for every session.
# NOTE(review): the .txt/.json files read below are presumably written by this
# call into <out_dir>/pred_rttms/ — confirm against OfflineDiarWithASR.
trans_info_dict = asr_diar_offline.get_transcript_with_speaker_labels(diar_hyp, word_hyp, word_ts_hyp)

# Plain-text transcript with speaker labels, one line per speaker segment
transcription_path_to_file = f"{data_dir}/pred_rttms/test_pcm.txt"
transcript = read_file(transcription_path_to_file)
print("\n".join(transcript))

# JSON version of the transcript, kept as raw lines
# (inspect with `pp.pprint(json_contents)` if needed)
transcription_path_to_file = f"{data_dir}/pred_rttms/test_pcm.json"
json_contents = read_file(transcription_path_to_file)

## Evaluating the Diarization Results
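Evaluate the diarization results by calculating cpWER and WER. cpWER (concatenated minimum-permutation WER) compares each speaker's concatenated words against the reference speakers under the best speaker mapping, so it penalises both recognition and speaker-attribution errors. Plain WER compares only the speaker-agnostic word sequences. Note that the cell below scores a hand-written example hypothesis (`trans_info_dict`) against a hand-written reference CTM, not the pipeline output from the previous cells.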

In [None]:
from nemo.collections.asr.metrics.der import concat_perm_word_error_rate
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.parts.utils.diarization_utils import (
    convert_word_dict_seq_to_text,
    convert_ctm_to_text,
)

# Set the hypothesis and reference transcripts.
# Hand-written per-speaker hypothesis text: hyp1 is speaker_0's words, hyp2 speaker_1's.
hyp1 = "Kestrel IT Charles speaking Hey how can I help? Viruses detected! Umm all right well what exactly is the message popping up Yeah? Uh no worries ah okay um where are you currently? yeah ahh north or south? all right um just um for the sake of like isolating that machine and air gaping it rather than like trying to connect it to the network and then just jump into it am I able to get you to take that to a Irfan Sheikh over at the training hut? Umm it is just the best rather than like just connecting it to the network and seeing because like a sort of message like that it is popping up from the anti virus so like you don't really want to connect to the corporate network where you've got printers servers etc you know it would be better to all cheers thank you"
hyp2 = "Yeah Charles how are ya? Yeah <unk> here from Kestrel mate I just opened up my computer and it just says there's viruses detected so I thought I would give you a call before I go further Oh mate it says Trojan virus detected Yeah I'm at work, at Kestrel South Yeah I have meetings but ok righto no worries mate oh that's all right yeah well I'll see what I can do yeah sure yep thank you"

# Combine hypotheses into a list
hypotheses = [hyp1, hyp2]

# Hand-written example transcript: one entry per speaker segment with start/end
# times, the segment's text and its speaker label.
# NOTE: this *replaces* the trans_info_dict computed by the pipeline earlier, so
# the cpWER/WER below score this example rather than the model's actual output.
trans_info_dict = {
    "test_pcm": {
        "words": [
            {'start': 2.36, 'end': 4.15, 'word': 'Kestrel IT Charles speaking', 'speaker': 'speaker_0'},
            {'start': 4.79, 'end': 5.75, 'word': 'Yeah Charles how are ya?', 'speaker': 'speaker_1'},
            {'start': 7.11, 'end': 8.03, 'word': 'Hey how can I help?', 'speaker': 'speaker_0'},
            {'start': 8.59, 'end': 15.31, 'word': 'Yeah <unk> here from Kestrel mate I just opened up my computer and it just says there\'s viruses detected so I thought I would give you a call before I go further', 'speaker': 'speaker_1'},
            {'start': 17.11, 'end': 21.91, 'word': 'Viruses detected! Umm all right well what exactly is the message popping up', 'speaker': 'speaker_0'},
            {'start': 30.19, 'end': 33.95, 'word': 'Uh no worries ah okay um where are you currently?', 'speaker': 'speaker_0'},
            {'start': 34.07, 'end': 36.47, 'word': 'I\'m at work, at Kestrel', 'speaker': 'speaker_1'},
            {'start': 37.95, 'end': 39.15, 'word': 'yeah ahh north or south?', 'speaker': 'speaker_0'},
            {'start': 40.27, 'end': 40.51, 'word': 'South', 'speaker': 'speaker_1'},
            {'start': 42.75, 'end': 56.79, 'word': 'all right um just um for the sake of like isolating that machine and air gaping it rather than like trying to connect it to the network and then just jump into it am I able to get you to take that to a Irfan Sheikh over at the training hut?', 'speaker': 'speaker_0'},
            {'start': 59.07, 'end': 59.91, 'word': 'Yeah I have meetings but ok righto no worries mate', 'speaker': 'speaker_1'},
            {'start': 62.63, 'end': 77.87, 'word': 'Umm it is just the best rather than like just connecting it to the network and seeing because like a sort of message like that it is popping up from the anti virus so like you don\'t really want to connect to the corporate network where you\'ve got printers servers etc you know it would be better to', 'speaker': 'speaker_0'},
            {'start': 78.07, 'end': 80.87, 'word': 'oh that\'s all right yeah well I\'ll see what I can do', 'speaker': 'speaker_1'},
            # start was 83.10 (> end 83.03); the reference CTM puts this segment's
            # first word ('all') at 82.10.
            {'start': 82.10, 'end': 83.03, 'word': 'alright cheers thank you', 'speaker': 'speaker_0'},
            {'start': 83.42, 'end': 85.87, 'word': 'yeah sure yep thank you', 'speaker': 'speaker_1'},
        ]
    }
}

# One word-sequence list per session (only test_pcm here)
uniq_id = "test_pcm"
word_seq_lists = [trans_info_dict[uniq_id]["words"]]

# Print the first session in word_seq_lists
print("word_seq_lists:\n", word_seq_lists[0])

# Example reference CTM for test_pcm, hand-written for this demo. Each line has
# CTM-style fields: "<session> <channel> <start> <duration> <word> NA lex <speaker>".
# NOTE(review): some entries are not in start-time order ('<unk>' 9.33 precedes
# 'mate' 9.17; 'here' 9.57 follows 'I' 10.46; 'it' 11.73 follows 'go' 15.16;
# 'a' 68.23 precedes 'uh' 68.01), and some start times are duplicated
# ('Umm'/'um' at 18.33 and 62.45, 'uh'/'a' at 54.79). If the reference text is
# assembled in list order, those words will be out of sequence — confirm whether
# this list should be sorted/deduplicated.
test_pcm_ctm = [
    "test_pcm 1 3.04 0.55 Kestrel NA lex speaker_0",
    "test_pcm 1 3.59 0.02 IT NA lex speaker_0",
    "test_pcm 1 3.61 0.21 Charles NA lex speaker_0",
    "test_pcm 1 3.82 0.37 speaking NA lex speaker_0",
    "test_pcm 1 5.18 0.14 Yeah NA lex speaker_1",
    "test_pcm 1 5.32 0.24 Charles NA lex speaker_1",
    "test_pcm 1 5.56 0.15 how NA lex speaker_1",
    "test_pcm 1 5.71 0.02 are NA lex speaker_1",
    "test_pcm 1 5.73 0.12 ya NA lex speaker_1",
    "test_pcm 1 7.34 0.28 Hey NA lex speaker_0",
    "test_pcm 1 7.63 0.11 how NA lex speaker_0",
    "test_pcm 1 7.74 0.14 can NA lex speaker_0",
    "test_pcm 1 7.92 0.20 help NA lex speaker_0",
    "test_pcm 1 8.99 0.18 Yeah NA lex speaker_1",
    "test_pcm 1 9.33 0.24 <unk> NA lex speaker_1",
    "test_pcm 1 9.17 0.16 mate NA lex speaker_1",
    "test_pcm 1 9.75 0.19 from NA lex speaker_1",
    "test_pcm 1 9.94 0.26 Kestrel NA lex speaker_1",
    "test_pcm 1 10.46 0.09 I NA lex speaker_1",
    "test_pcm 1 9.57 0.18 here NA lex speaker_1",
    "test_pcm 1 10.55 0.18 just NA lex speaker_1",
    "test_pcm 1 10.75 0.28 opened NA lex speaker_1",
    "test_pcm 1 11.76 0.21 just NA lex speaker_1",
    "test_pcm 1 11.97 0.29 says NA lex speaker_1",
    "test_pcm 1 12.26 0.24 um NA lex speaker_1",
    "test_pcm 1 12.53 0.20 there's NA lex speaker_1",
    "test_pcm 1 12.74 0.58 viruses NA lex speaker_1",
    "test_pcm 1 13.32 0.45 detected NA lex speaker_1",
    "test_pcm 1 14.06 0.17 give NA lex speaker_1",
    "test_pcm 1 14.23 0.08 you NA lex speaker_1",
    "test_pcm 1 14.31 0.06 a NA lex speaker_1",
    "test_pcm 1 14.37 0.17 call NA lex speaker_1",
    "test_pcm 1 14.54 0.20 before NA lex speaker_1",
    "test_pcm 1 15.15 0.01 I NA lex speaker_1",
    "test_pcm 1 15.16 0.16 go NA lex speaker_1",
    "test_pcm 1 11.73 0.02 it NA lex speaker_1",
    "test_pcm 1 15.21 0.09 uh NA lex speaker_0",
    "test_pcm 1 17.16 0.65 Viruses NA lex speaker_0",
    "test_pcm 1 17.81 0.49 detected NA lex speaker_0",
    "test_pcm 1 18.33 0.56 Umm NA lex speaker_0",
    "test_pcm 1 18.33 0.55 um NA lex speaker_0",
    "test_pcm 1 19.16 0.17 all NA lex speaker_0",
    "test_pcm 1 19.33 0.25 right NA lex speaker_0",
    "test_pcm 1 19.96 0.22 uh NA lex speaker_0",
    "test_pcm 1 20.18 0.15 well NA lex speaker_0",
    "test_pcm 1 20.33 0.16 what NA lex speaker_0",
    "test_pcm 1 20.49 0.48 exactly NA lex speaker_0",
    "test_pcm 1 20.97 0.11 is NA lex speaker_0",
    "test_pcm 1 21.08 0.08 the NA lex speaker_0",
    "test_pcm 1 21.16 0.32 message NA lex speaker_0",
    "test_pcm 1 21.48 0.25 popping NA lex speaker_0",
    "test_pcm 1 21.73 0.29 up NA lex speaker_0",
    "test_pcm 1 22.89 0.17 Oh NA lex speaker_1",
    "test_pcm 1 23.44 0.01 uh NA lex speaker_1",
    "test_pcm 1 23.46 0.12 it NA lex speaker_1",
    "test_pcm 1 23.58 0.35 says NA lex speaker_1",
    "test_pcm 1 23.99 0.48 Trojan NA lex speaker_1",
    "test_pcm 1 24.48 0.44 virus NA lex speaker_1",
    "test_pcm 1 24.95 0.52 detected NA lex speaker_1",
    "test_pcm 1 27.11 0.37 Yeah NA lex speaker_0",
    "test_pcm 1 28.42 0.68 Yep NA lex speaker_1",
    "test_pcm 1 30.26 0.45 Uh NA lex speaker_0",
    "test_pcm 1 30.77 0.54 no NA lex speaker_0",
    "test_pcm 1 31.32 0.53 worries NA lex speaker_0",
    "test_pcm 1 31.88 0.68 ah NA lex speaker_0",
    "test_pcm 1 32.71 0.34 okay NA lex speaker_0",
    "test_pcm 1 33.50 0.29 um NA lex speaker_0",
    "test_pcm 1 33.79 0.14 where NA lex speaker_0",
    "test_pcm 1 33.93 0.06 are NA lex speaker_0",
    "test_pcm 1 33.99 0.15 you NA lex speaker_0",
    "test_pcm 1 34.14 0.51 currently NA lex speaker_0",
    "test_pcm 1 35.68 0.13 I'm NA lex speaker_1",
    "test_pcm 1 35.81 0.10 at NA lex speaker_1",
    "test_pcm 1 35.92 0.01 Kestrel NA lex speaker_1",
    "test_pcm 1 36.18 0.01 uh NA lex speaker_0",
    "test_pcm 1 36.55 0.09 uh NA lex speaker_0",
    "test_pcm 1 38.07 0.41 yeah NA lex speaker_0",
    "test_pcm 1 38.48 0.02 ahh NA lex speaker_0",
    "test_pcm 1 38.51 0.12 uh NA lex speaker_0",
    "test_pcm 1 38.63 0.26 north NA lex speaker_0",
    "test_pcm 1 38.89 0.03 or NA lex speaker_0",
    "test_pcm 1 38.95 0.38 south NA lex speaker_0",
    "test_pcm 1 40.29 0.12 South NA lex speaker_1",
    "test_pcm 1 40.41 0.39 all NA lex speaker_0",
    "test_pcm 1 42.95 0.27 right NA lex speaker_0",
    "test_pcm 1 43.26 0.53 um NA lex speaker_0",
    "test_pcm 1 45.53 0.66 just NA lex speaker_0",
    "test_pcm 1 46.38 0.42 um NA lex speaker_0",
    "test_pcm 1 46.81 0.12 for NA lex speaker_0",
    "test_pcm 1 46.93 0.09 the NA lex speaker_0",
    "test_pcm 1 47.02 0.19 sake NA lex speaker_0",
    "test_pcm 1 47.21 0.13 of NA lex speaker_0",
    "test_pcm 1 47.34 0.14 like NA lex speaker_0",
    "test_pcm 1 47.50 0.57 isolating NA lex speaker_0",
    "test_pcm 1 48.07 0.12 that NA lex speaker_0",
    "test_pcm 1 48.19 0.29 machine NA lex speaker_0",
    "test_pcm 1 48.48 0.15 and NA lex speaker_0",
    "test_pcm 1 48.64 0.17 air NA lex speaker_0",
    "test_pcm 1 48.81 0.33 gaping NA lex speaker_0",
    "test_pcm 1 49.14 0.13 it NA lex speaker_0",
    "test_pcm 1 49.27 0.30 rather NA lex speaker_0",
    "test_pcm 1 49.57 0.11 than NA lex speaker_0",
    "test_pcm 1 49.68 0.15 like NA lex speaker_0",
    "test_pcm 1 49.83 0.20 trying NA lex speaker_0",
    "test_pcm 1 50.03 0.07 to NA lex speaker_0",
    "test_pcm 1 50.10 0.32 connect NA lex speaker_0",
    "test_pcm 1 50.42 0.09 it NA lex speaker_0",
    "test_pcm 1 50.51 0.07 to NA lex speaker_0",
    "test_pcm 1 50.58 0.10 the NA lex speaker_0",
    "test_pcm 1 50.68 0.32 network NA lex speaker_0",
    "test_pcm 1 51.00 0.11 and NA lex speaker_0",
    "test_pcm 1 51.11 0.20 then NA lex speaker_0",
    "test_pcm 1 51.32 0.20 just NA lex speaker_0",
    "test_pcm 1 51.52 0.22 jump NA lex speaker_0",
    "test_pcm 1 51.74 0.23 into NA lex speaker_0",
    "test_pcm 1 51.97 0.13 it NA lex speaker_0",
    "test_pcm 1 52.55 0.19 am NA lex speaker_0",
    "test_pcm 1 52.74 0.09 i NA lex speaker_0",
    "test_pcm 1 52.83 0.14 able NA lex speaker_0",
    "test_pcm 1 52.97 0.06 to NA lex speaker_0",
    "test_pcm 1 53.05 0.15 get NA lex speaker_0",
    "test_pcm 1 53.20 0.13 you NA lex speaker_0",
    "test_pcm 1 53.33 0.12 to NA lex speaker_0",
    "test_pcm 1 53.46 0.23 take NA lex speaker_0",
    "test_pcm 1 53.69 0.24 that NA lex speaker_0",
    "test_pcm 1 53.94 0.40 to NA lex speaker_0",
    "test_pcm 1 54.36 0.39 uh NA lex speaker_0",
    "test_pcm 1 54.79 0.15 uh NA lex speaker_0",
    "test_pcm 1 54.79 0.15 a NA lex speaker_0",
    "test_pcm 1 54.94 0.37 Irfan NA lex speaker_0",
    "test_pcm 1 55.54 0.01 uh NA lex speaker_0",
    "test_pcm 1 55.56 0.34 over NA lex speaker_0",
    "test_pcm 1 55.90 0.13 at NA lex speaker_0",
    "test_pcm 1 56.04 0.29 the NA lex speaker_0",
    "test_pcm 1 56.35 0.32 training NA lex speaker_0",
    "test_pcm 1 56.67 0.19 hut NA lex speaker_0",
    "test_pcm 1 59.26 0.30 Yeah NA lex speaker_1",
    "test_pcm 1 59.56 0.01 I NA lex speaker_1",
    "test_pcm 1 59.58 0.18 have NA lex speaker_1",
    "test_pcm 1 59.76 0.40 meetings NA lex speaker_1",
    "test_pcm 1 60.69 0.29 ok NA lex speaker_1",
    "test_pcm 1 60.98 0.28 righto NA lex speaker_1",
    "test_pcm 1 61.32 0.20 no NA lex speaker_1",
    "test_pcm 1 61.81 0.04 worries NA lex speaker_1",
    "test_pcm 1 61.85 0.03 mate NA lex speaker_1",
    "test_pcm 1 62.45 0.25 Umm NA lex speaker_0",
    "test_pcm 1 62.45 0.25 um NA lex speaker_0",
    "test_pcm 1 62.75 0.30 it NA lex speaker_0",
    "test_pcm 1 63.05 0.08 is NA lex speaker_0",
    "test_pcm 1 63.41 0.19 just NA lex speaker_0",
    "test_pcm 1 63.60 0.13 the NA lex speaker_0",
    "test_pcm 1 63.73 0.34 best NA lex speaker_0",
    "test_pcm 1 64.07 0.21 rather NA lex speaker_0",
    "test_pcm 1 64.28 0.11 than NA lex speaker_0",
    "test_pcm 1 64.39 0.24 like NA lex speaker_0",
    "test_pcm 1 64.63 0.19 just NA lex speaker_0",
    "test_pcm 1 64.83 0.39 connecting NA lex speaker_0",
    "test_pcm 1 65.22 0.11 it NA lex speaker_0",
    "test_pcm 1 65.33 0.07 to NA lex speaker_0",
    "test_pcm 1 65.40 0.10 the NA lex speaker_0",
    "test_pcm 1 65.50 0.34 network NA lex speaker_0",
    "test_pcm 1 65.84 0.12 and NA lex speaker_0",
    "test_pcm 1 65.96 0.39 seeing NA lex speaker_0",
    "test_pcm 1 66.42 0.44 um NA lex speaker_0",
    "test_pcm 1 66.91 0.66 because NA lex speaker_0",
    "test_pcm 1 67.64 0.23 like NA lex speaker_0",
    "test_pcm 1 68.23 0.23 a NA lex speaker_0",
    "test_pcm 1 68.01 0.45 uh NA lex speaker_0",
    "test_pcm 1 68.66 0.20 sort NA lex speaker_0",
    "test_pcm 1 68.86 0.08 of NA lex speaker_0",
    "test_pcm 1 68.94 0.28 message NA lex speaker_0",
    "test_pcm 1 69.22 0.16 like NA lex speaker_0",
    "test_pcm 1 69.38 0.16 that NA lex speaker_0",
    "test_pcm 1 69.54 0.02 it NA lex speaker_0",
    "test_pcm 1 69.56 0.22 is NA lex speaker_0",
    "test_pcm 1 69.78 0.24 popping NA lex speaker_0",
    "test_pcm 1 70.02 0.15 up NA lex speaker_0",
    "test_pcm 1 70.17 0.11 from NA lex speaker_0",
    "test_pcm 1 70.28 0.06 the NA lex speaker_0",
    "test_pcm 1 70.34 0.25 anti NA lex speaker_0",
    "test_pcm 1 70.59 0.53 virus NA lex speaker_0",
    "test_pcm 1 71.93 0.33 um NA lex speaker_0",
    "test_pcm 1 72.27 0.25 so NA lex speaker_0",
    "test_pcm 1 72.74 0.23 like NA lex speaker_0",
    "test_pcm 1 72.98 0.36 you NA lex speaker_0",
    "test_pcm 1 73.34 0.14 don't NA lex speaker_0",
    "test_pcm 1 73.49 0.22 really NA lex speaker_0",
    "test_pcm 1 73.71 0.15 want NA lex speaker_0",
    "test_pcm 1 74.30 0.11 the NA lex speaker_0",
    "test_pcm 1 74.41 0.29 corporate NA lex speaker_0",
    "test_pcm 1 74.70 0.33 network NA lex speaker_0",
    "test_pcm 1 75.03 0.20 where NA lex speaker_0",
    "test_pcm 1 75.40 0.51 printers NA lex speaker_0",
    "test_pcm 1 75.92 0.38 servers NA lex speaker_0",
    "test_pcm 1 76.30 0.53 etc NA lex speaker_0",
    "test_pcm 1 77.03 0.10 you NA lex speaker_0",
    "test_pcm 1 77.13 0.18 know NA lex speaker_0",
    "test_pcm 1 77.32 0.21 it NA lex speaker_0",
    "test_pcm 1 77.53 0.03 would NA lex speaker_0",
    "test_pcm 1 77.56 0.08 be NA lex speaker_0",
    "test_pcm 1 77.64 0.27 better NA lex speaker_0",
    "test_pcm 1 77.99 0.14 to NA lex speaker_0",
    "test_pcm 1 78.18 0.01 oh NA lex speaker_1",
    "test_pcm 1 78.54 0.14 uh NA lex speaker_1",
    "test_pcm 1 79.08 0.24 that's NA lex speaker_1",
    "test_pcm 1 79.32 0.15 all NA lex speaker_1",
    "test_pcm 1 79.81 0.03 right NA lex speaker_1",
    "test_pcm 1 79.85 0.23 yeah NA lex speaker_1",
    "test_pcm 1 80.08 0.30 well NA lex speaker_1",
    "test_pcm 1 80.40 0.21 i'll NA lex speaker_1",
    "test_pcm 1 80.62 0.11 see NA lex speaker_1",
    "test_pcm 1 80.73 0.07 what NA lex speaker_1",
    "test_pcm 1 80.80 0.06 i NA lex speaker_1",
    "test_pcm 1 80.86 0.13 can NA lex speaker_1",
    "test_pcm 1 80.99 0.08 do NA lex speaker_1",
    "test_pcm 1 82.10 0.06 all NA lex speaker_0",
    "test_pcm 1 82.33 0.29 cheers NA lex speaker_0",
    "test_pcm 1 82.62 0.24 thank NA lex speaker_0",
    "test_pcm 1 82.86 0.17 you NA lex speaker_0",
    "test_pcm 1 83.42 0.18 yeah NA lex speaker_1",
    "test_pcm 1 83.60 0.20 sure NA lex speaker_1",
    "test_pcm 1 83.87 0.19 yep NA lex speaker_1",
    "test_pcm 1 84.84 0.04 thank NA lex speaker_1",
    "test_pcm 1 85.56 0.21 you NA lex speaker_1",
]

# Function to write CTM data to a file
def write_ctm(path, the_list):
    """Overwrite ``path`` with the strings in ``the_list``, one newline-terminated line each."""
    with open(path, "w") as out_file:
        out_file.writelines(entry + "\n" for entry in the_list)

# Write the reference CTM into the notebook's data directory. data_dir was set
# (as an absolute path) in the setup cell. Do not reassign it here: doing so
# would silently switch every later cell to a cwd-relative "data" path.
write_ctm(f"{data_dir}/test_pcm.ctm", test_pcm_ctm)

# Flatten the hypothesis (word dictionaries) and the reference (CTM file) into
# per-speaker transcripts plus one speaker-agnostic "mixed" transcript each.
word_seq_list = trans_info_dict["test_pcm"]["words"]
ctm_file_path = f"{data_dir}/test_pcm.ctm"

spk_hypothesis, mix_hypothesis = convert_word_dict_seq_to_text(word_seq_list)
spk_reference, mix_reference = convert_ctm_to_text(ctm_file_path)

print(f"spk_hypothesis: {spk_hypothesis}")
print(f"mix_hypothesis: {mix_hypothesis}", end="\n\n")
print(f"spk_reference: {spk_reference}")
print(f"mix_reference: {mix_reference}")

# cpWER scores the per-speaker transcripts (concatenated, best speaker
# permutation); WER scores only the mixed transcripts.
cpWER, concat_hyp, concat_ref = concat_perm_word_error_rate([spk_hypothesis], [spk_reference])
WER = word_error_rate([mix_hypothesis], [mix_reference])

print(f"cpWER: {cpWER[0]}")
print(f"WER: {WER}")

# Inspect the concatenated hypothesis/reference that cpWER compared
print(f"concat_hyp: {concat_hyp[0]}")
print(f"concat_ref: {concat_ref[0]}")

## Modifying Speaker Labels for Evaluation
Artificially flip a speaker label and recalculate cpWER and WER to check if the evaluation reflects the change.

In [None]:
import copy

# Deep-copy so the flip below leaves word_seq_lists (and the trans_info_dict it
# was built from) untouched.
word_seq_lists_modified = copy.deepcopy(word_seq_lists)
# Artificially flip the speaker of the last segment to speaker_0 and check whether cpWER reflects it
word_seq_lists_modified[0][-1]['speaker'] = 'speaker_0'
print(word_seq_lists_modified[0])

spk_hypothesis_modified, mix_hypothesis_modified = convert_word_dict_seq_to_text(word_seq_lists_modified[0])

# The last segment's words should now appear in speaker_0's transcript
print(f"spk_hypothesis_modified: {spk_hypothesis_modified}")
print(f"mix_hypothesis_modified: {mix_hypothesis_modified}\n")

print(f"spk_reference: {spk_reference}")
print(f"mix_reference: {mix_reference}")

# Recalculate cpWER and WER: the flip should raise cpWER, while the
# speaker-agnostic WER is presumably unchanged.
cpWER_modified, concat_hyp, concat_ref = concat_perm_word_error_rate([spk_hypothesis_modified], [spk_reference])
WER_modified = word_error_rate([mix_hypothesis_modified], [mix_reference])

print(f"cpWER: {cpWER_modified[0]}")
print(f"WER: {WER_modified}")

## Creating a New Manifest File
Rewrite the input manifest so it references the reference CTM file, then let NeMo compute cpWER and WER for the session.

In [None]:
# Rewrite the manifest to point at the reference CTM (and a fixed speaker count
# of 2) so NeMo can score the transcript against it.
meta = {
    'audio_filepath': AUDIO_FILENAME,
    'offset': 0,
    'duration':None,
    'label': 'infer',
    'text': '-',
    'num_speakers': 2,
    'rttm_filepath': None,
    'ctm_filepath': f"{data_dir}/test_pcm.ctm",
    'uem_filepath' : None
}

manifest_path = os.path.join(data_dir, 'input_manifest.json')
with open(manifest_path, 'w') as fp:
    json.dump(meta, fp)
    fp.write('\n')

cfg.diarizer.manifest_filepath = manifest_path

# We need to call `make_file_lists` again to update manifest file to `asr_diar_offline` instance
asr_diar_offline.make_file_lists()

trans_info_dict = asr_diar_offline.get_transcript_with_speaker_labels(diar_hyp, word_hyp, word_ts_hyp)
session_result_dict = OfflineDiarWithASR.evaluate(hyp_trans_info_dict=trans_info_dict,
                                                    audio_file_list=asr_diar_offline.audio_file_list,
                                                    ref_ctm_file_list=asr_diar_offline.ctm_file_list)

# Scores are keyed by session id. (A bare `session_result_dict['test_pcm']`
# expression mid-cell is never displayed, so bind it and print explicitly.)
session_result = session_result_dict['test_pcm']
print("cpWER:", session_result['cpWER'])
print("WER:", session_result['WER'])

## Downloading and Extracting ARPA Language Model
Download and extract the ARPA language model for better alignment during diarization.

In [None]:
import gzip
import shutil
def gunzip(file_path,output_path):
    """Decompress the gzip file ``file_path`` into ``output_path``.

    The data is written to ``output_path + '.part'`` first and renamed into
    place only after decompression succeeds. An interrupted or failed run
    therefore never leaves a truncated file at ``output_path``, which matters
    because the download cell skips all work once that path exists.

    Raises:
        OSError: if ``file_path`` is missing or is not valid gzip data.
    """
    partial_path = os.fspath(output_path) + ".part"
    try:
        # The `with` block closes both files; copyfileobj streams in chunks.
        with gzip.open(file_path, "rb") as f_in, open(partial_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(partial_path, output_path)
    except BaseException:
        # Clean up the half-written temp file, then propagate the error.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

# ARPA source URL
ARPA_URL = 'https://kaldi-asr.org/models/5/4gram_big.arpa.gz'

# Path to the decompressed ARPA model
arpa_model_path = os.path.join(data_dir, '4gram_big.arpa')

# If 4gram_big.arpa already exists, don't download it again
if not os.path.exists(arpa_model_path):
    arpa_gz_path = os.path.join(data_dir, '4gram_big.arpa.gz')
    # A leftover (possibly partial) archive from an interrupted run would be
    # decompressed instead of the fresh download. wget also saves under a
    # suffixed name when the target file exists. So remove it first.
    if os.path.exists(arpa_gz_path):
        os.remove(arpa_gz_path)
    # Use the path wget actually wrote to.
    downloaded_gz_path = wget.download(ARPA_URL, arpa_gz_path)
    gunzip(downloaded_gz_path, arpa_model_path)
    os.remove(downloaded_gz_path)

# Set the path to the ARPA model in the config
cfg.diarizer.asr.ctc_decoder_parameters.pretrained_language_model = arpa_model_path

## Reloading Modules and Re-running ASR with ARPA Model
Reload necessary modules and re-run ASR with the ARPA language model for better alignment.

In [None]:
import importlib
import nemo.collections.asr.parts.utils.decoder_timestamps_utils as decoder_timestamps_utils
importlib.reload(decoder_timestamps_utils) # This module should be reloaded after you install pyctcdecode.

# reload() does not rebind names brought in earlier with `from ... import`, so
# `ASRDecoderTimeStamps` would still be the pre-reload class. Refresh it from
# the reloaded module.
ASRDecoderTimeStamps = decoder_timestamps_utils.ASRDecoderTimeStamps

asr_decoder_ts = ASRDecoderTimeStamps(cfg.diarizer)
asr_model = asr_decoder_ts.set_asr_model()
word_hyp, word_ts_hyp = asr_decoder_ts.run_ASR(asr_model)

print("Decoded word output dictionary: \n", word_hyp['test_pcm'])
print("Word-level timestamps dictionary: \n", word_ts_hyp['test_pcm'])

## Setting Realigning Language Model Parameters
Set parameters for realigning language model and create a new instance with these settings.

In [None]:
import importlib

arpa_model_path = os.path.join(data_dir, '4gram_big.arpa')
cfg.diarizer.asr.realigning_lm_parameters.arpa_language_model = arpa_model_path
cfg.diarizer.asr.realigning_lm_parameters.logprob_diff_threshold = 1.2

import nemo.collections.asr.parts.utils.diarization_utils as diarization_utils
importlib.reload(diarization_utils) # This module should be reloaded after you install arpa.

# reload() does not rebind the `OfflineDiarWithASR` name imported earlier with
# `from ... import`; refresh it so the new instance uses the reloaded class.
OfflineDiarWithASR = diarization_utils.OfflineDiarWithASR

# Create a new instance with realigning language model
asr_diar_offline = OfflineDiarWithASR(cfg.diarizer)
asr_diar_offline.word_ts_anchor_offset = asr_decoder_ts.word_ts_anchor_offset

asr_diar_offline.get_transcript_with_speaker_labels(diar_hyp, word_hyp, word_ts_hyp)

transcription_path_to_file = f"{data_dir}/pred_rttms/test_pcm.txt"
transcript = read_file(transcription_path_to_file)
pp.pprint(transcript)

# Free up Resources
*Free cached GPU memory. Downloaded files in `data/` are not removed.*

Run the cell below to clean up.

In [None]:
import gc
import torch

# Drop the notebook's references to the large model objects first. The CUDA
# cache can only release memory that no live tensor still uses.
for _name in ("asr_model", "asr_decoder_ts", "asr_diar_offline"):
    globals().pop(_name, None)
del _name
gc.collect()

if torch.cuda.is_available():
    # Free up GPU memory held by PyTorch's caching allocator
    torch.cuda.empty_cache()
    print("GPU memory freed")
else:
    print("No CUDA device in use; nothing to free on the GPU")