<a href="https://colab.research.google.com/github/MahmoudAshraf97/whisper-diarization/blob/main/Whisper_Transcription_%2B_NeMo_Diarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installing Dependencies

In [1]:
import os
import wget
from omegaconf import OmegaConf
import json
import shutil
from faster_whisper import WhisperModel
import whisperx
import torch
import librosa
import soundfile
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from deepmultilingualpunctuation import PunctuationModel
import re
import logging

[NeMo W 2023-05-23 10:02:36 optimizers:54] Apex was not found. Using the lamb or fused_adam optimizer will error out.
[NeMo W 2023-05-23 10:02:37 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.


# Helper Functions

In [3]:
punct_model_langs = [
    "en",
    "fr",
    "de",
    "es",
    "it",
    "nl",
    "pt",
    "bg",
    "pl",
    "cs",
    "sk",
    "sl",
]
wav2vec2_langs = [
    "en",
    "fr",
    "de",
    "es",
    "it",
    "nl",
    "pt",
    "ja",
    "zh",
    "uk",
    "pt",
    "ar",
    "ru",
    "pl",
    "hu",
    "fi",
    "fa",
    "el",
    "tr",
]


def create_config(output_dir):
    DOMAIN_TYPE = "telephonic"  # Can be meeting or telephonic based on domain type of the audio file
    CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
    CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
    MODEL_CONFIG = os.path.join(output_dir, CONFIG_FILE_NAME)
    if not os.path.exists(MODEL_CONFIG):
        MODEL_CONFIG = wget.download(CONFIG_URL, output_dir)

    config = OmegaConf.load(MODEL_CONFIG)

    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)

    config.batch_size = 16
    meta = {
        "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    pretrained_vad = "vad_multilingual_marblenet"
    pretrained_speaker_model = "titanet_large"

    config.num_workers = 1  # Workaround for multiprocessing hanging with ipython issue

    config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json")
    config.diarizer.out_dir = (
        output_dir  # Directory to store intermediate files and prediction outputs
    )

    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
    config.diarizer.oracle_vad = (
        False  # compute VAD provided with model_path to vad config
    )
    config.diarizer.clustering.parameters.oracle_num_speakers = True
    config.diarizer.clustering.parameters.num_speakers = 2

    # Here, we use our in-house pretrained NeMo VAD model
    config.diarizer.vad.model_path = pretrained_vad
    config.diarizer.vad.parameters.onset = 0.8
    config.diarizer.vad.parameters.offset = 0.6
    config.diarizer.vad.parameters.pad_offset = -0.05
    config.diarizer.msdd_model.model_path = (
        "diar_msdd_telephonic"  # Telephonic speaker diarization model
    )

    return config


def get_word_ts_anchor(s, e, option="start"):
    if option == "end":
        return e
    elif option == "mid":
        return (s + e) / 2
    return s


def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
    s, e, sp = spk_ts[0]
    wrd_pos, turn_idx = 0, 0
    wrd_spk_mapping = []
    for wrd_dict in wrd_ts:
        ws, we, wrd = (
            int(wrd_dict["start"] * 1000),
            int(wrd_dict["end"] * 1000),
            wrd_dict["text"],
        )
        wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
        while wrd_pos > float(e):
            turn_idx += 1
            turn_idx = min(turn_idx, len(spk_ts) - 1)
            s, e, sp = spk_ts[turn_idx]
            if turn_idx == len(spk_ts) - 1:
                e = get_word_ts_anchor(ws, we, option="end")
        wrd_spk_mapping.append(
            {"word": wrd, "start_time": ws, "end_time": we, "speaker": sp}
        )
    return wrd_spk_mapping


sentence_ending_punctuations = ".?!"


def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    left_idx = word_idx
    while (
        left_idx > 0
        and word_idx - left_idx < max_words
        and speaker_list[left_idx - 1] == speaker_list[left_idx]
        and not is_word_sentence_end(left_idx - 1)
    ):
        left_idx -= 1

    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1


def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    right_idx = word_idx
    while (
        right_idx < len(word_list)
        and right_idx - word_idx < max_words
        and not is_word_sentence_end(right_idx)
    ):
        right_idx += 1

    return (
        right_idx
        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
        else -1
    )


def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    is_word_sentence_end = (
        lambda x: x >= 0
        and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
    )
    wsp_len = len(word_speaker_mapping)

    words_list, speaker_list = [], []
    for k, line_dict in enumerate(word_speaker_mapping):
        word, speaker = line_dict["word"], line_dict["speaker"]
        words_list.append(word)
        speaker_list.append(speaker)

    k = 0
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k]
        if (
            k < wsp_len - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        ):
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            if min(left_idx, right_idx) == -1:
                k += 1
                continue

            spk_labels = speaker_list[left_idx : right_idx + 1]
            mod_speaker = max(set(spk_labels), key=spk_labels.count)
            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                k += 1
                continue

            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            k = right_idx

        k += 1

    k, realigned_list = 0, []
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k].copy()
        line_dict["speaker"] = speaker_list[k]
        realigned_list.append(line_dict)
        k += 1

    return realigned_list


def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
    s, e, spk = spk_ts[0]
    prev_spk = spk

    snts = []
    snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}

    for wrd_dict in word_speaker_mapping:
        wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
        s, e = wrd_dict["start_time"], wrd_dict["end_time"]
        if spk != prev_spk:
            snts.append(snt)
            snt = {
                "speaker": f"Speaker {spk}",
                "start_time": s,
                "end_time": e,
                "text": "",
            }
        else:
            snt["end_time"] = e
        snt["text"] += wrd + " "
        prev_spk = spk

    snts.append(snt)
    return snts


def get_speaker_aware_transcript(sentences_speaker_mapping, f):
    for sentence_dict in sentences_speaker_mapping:
        sp = sentence_dict["speaker"]
        text = sentence_dict["text"]
        f.write(f"\n\n{sp}: {text}")


def format_timestamp(
    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert milliseconds >= 0, "non-negative timestamp expected"

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )


def write_srt(transcript, file):
    """
    Write a transcript to a file in SRT format.

    """
    for i, segment in enumerate(transcript, start=1):
        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> "
            f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n"
            f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )


def cleanup(path: str):
    """path could either be relative or absolute."""
    # check if file or directory exists
    if os.path.isfile(path) or os.path.islink(path):
        # remove file
        os.remove(path)
    elif os.path.isdir(path):
        # remove directory and all its content
        shutil.rmtree(path)
    else:
        raise ValueError("Path {} is not a file or dir.".format(path))

# Options

In [4]:
# Name of the audio file
audio_path = 'HI/HelloInternet-11.mp3'

# Whether to enable music removal from speech, helps increase diarization quality but uses alot of ram
enable_stemming = True

# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large')
whisper_model_name = 'medium.en'

# Processing

## Separating music from speech using Demucs

---

By isolating the vocals from the rest of the audio, it becomes easier to identify and track individual speakers based on the spectral and temporal characteristics of their speech signals. Source separation is just one of many techniques that can be used as a preprocessing step to help improve the accuracy and reliability of the overall diarization process.

In [5]:
if enable_stemming:
    # Isolate vocals from the rest of the audio

    return_code = os.system(
        f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{audio_path}" -o "temp_outputs"'
    )

    if return_code != 0:
        logging.warning(
            "Source splitting failed, using original audio file."
        )
        vocal_target = audio_path
    else:
        vocal_target = os.path.join(
            "temp_outputs", "htdemucs", os.path.splitext(audio_path)[0], "vocals.wav"
        )
else:
    vocal_target = audio_path

Selected model is a bag of 1 models. You will see that many progress bars per track.
Separated tracks will be stored in /home/omer/Workspace/whisper-diarization/temp_outputs/htdemucs
Separating track HI/HelloInternet-11.mp3


100%|████████████████████████████████████████████████████████████████████| 7353.45/7353.45 [03:19<00:00, 36.80seconds/s]
Killed


In [5]:
vocal_target = audio_path

## Transcriping audio using Whisper and realligning timestamps using Wav2Vec2
---
This code uses two different open-source models to transcribe speech and perform forced alignment on the resulting transcription.

The first model is called OpenAI Whisper, which is a speech recognition model that can transcribe speech with high accuracy. The code loads the whisper model and uses it to transcribe the vocal_target file.

The output of the transcription process is a set of text segments with corresponding timestamps indicating when each segment was spoken.


In [6]:
# Run on GPU with FP16
whisper_model = WhisperModel(whisper_model_name, device="cuda", compute_type="float16")

# or run on GPU with INT8
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_size, device="cpu", compute_type="int8")

segments, info = whisper_model.transcribe(
    vocal_target, beam_size=1, word_timestamps=True
)
whisper_results = []
for segment in segments:
    whisper_results.append(segment._asdict())
# clear gpu vram
del whisper_model
torch.cuda.empty_cache()

In [7]:
whisper_results[0]

{'id': 1,
 'seek': 2204,
 'start': 0.0,
 'end': 2.94,
 'text': ' I want to keep turning away because the microphone.',
 'tokens': [50363,
  314,
  765,
  284,
  1394,
  6225,
  1497,
  780,
  262,
  21822,
  13,
  50545],
 'temperature': 0.0,
 'avg_logprob': -0.38712222562279813,
 'compression_ratio': 1.46524064171123,
 'no_speech_prob': 0.0946044921875,
 'words': [Word(start=0.0, end=0.2, word=' I', probability=0.74169921875),
  Word(start=0.2, end=0.28, word=' want', probability=0.27734375),
  Word(start=0.28, end=0.42, word=' to', probability=0.99169921875),
  Word(start=0.42, end=0.64, word=' keep', probability=0.99755859375),
  Word(start=0.64, end=1.16, word=' turning', probability=0.99853515625),
  Word(start=1.16, end=1.76, word=' away', probability=0.998046875),
  Word(start=1.76, end=2.28, word=' because', probability=0.52099609375),
  Word(start=2.28, end=2.46, word=' the', probability=0.66357421875),
  Word(start=2.46, end=2.94, word=' microphone.', probability=0.9985351562

## Aligning the transcription with the original audio using Wav2Vec2
---
The second model used is called wav2vec2, which is a large-scale neural network that is designed to learn representations of speech that are useful for a variety of speech processing tasks, including speech recognition and alignment.

The code loads the wav2vec2 alignment model and uses it to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.

By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification.

If there's no Wav2Vec2 model available for your language, word timestamps generated by whisper will be used instead.

In [8]:
if info.language in wav2vec2_langs:
    device = "cuda"
    alignment_model, metadata = whisperx.load_align_model(
        language_code=info.language, device=device
    )
    result_aligned = whisperx.align(
        whisper_results, alignment_model, metadata, vocal_target, device
    )
    word_timestamps = result_aligned["word_segments"]
    # clear gpu vram
    del alignment_model
    torch.cuda.empty_cache()
else:
    word_timestamps = []
    for segment in whisper_results:
        for word in segment["words"]:
            word_timestamps.append({"text": word[2], "start": word[0], "end": word[1]})

Failed to align segment (" Yes."): backtrack failed, resorting to original...
Failed to align segment (" I don't know."): backtrack failed, resorting to original...
Failed to align segment (" Bought a couple of those."): backtrack failed, resorting to original...
Failed to align segment (" I don't know."): backtrack failed, resorting to original...
Failed to align segment: duration smaller than 0.02s time precision
Failed to align segment (" Yeah."): backtrack failed, resorting to original...


## Convert audio to mono for NeMo combatibility

In [6]:
signal, sample_rate = librosa.load(vocal_target, sr=None)
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, "temp_outputs")
os.makedirs(temp_path, exist_ok=True)
soundfile.write(os.path.join(temp_path, "mono_file.wav"), signal, sample_rate, "PCM_24")

## Speaker Diarization using NeMo MSDD Model
---
This code uses a model called Nvidia NeMo MSDD (Multi-scale Diarization Decoder) to perform speaker diarization on an audio signal. Speaker diarization is the process of separating an audio signal into different segments based on who is speaking at any given time.

In [8]:
# Initialize NeMo MSDD diarization model
msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to("cuda")
msdd_model.diarize()

del msdd_model
torch.cuda.empty_cache()

[NeMo I 2023-05-23 10:08:05 msdd_models:1092] Loading pretrained diar_msdd_telephonic model from NGC
[NeMo I 2023-05-23 10:08:05 cloud:58] Found existing object /home/omer/.cache/torch/NeMo/NeMo_1.17.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.
[NeMo I 2023-05-23 10:08:05 cloud:64] Re-using file from: /home/omer/.cache/torch/NeMo/NeMo_1.17.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo
[NeMo I 2023-05-23 10:08:05 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2023-05-23 10:08:05 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: true
    
[NeMo W 2023-05-23 10:08:05 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: false
    
[NeMo W 2023-05-23 10:08:05 modelPT:174] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple

[NeMo I 2023-05-23 10:08:05 features:287] PADDING: 16
[NeMo I 2023-05-23 10:08:05 features:287] PADDING: 16
[NeMo I 2023-05-23 10:08:06 save_restore_connector:247] Model EncDecDiarLabelModel was successfully restored from /home/omer/.cache/torch/NeMo/NeMo_1.17.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.
[NeMo I 2023-05-23 10:08:06 features:287] PADDING: 16
[NeMo I 2023-05-23 10:08:06 clustering_diarizer:127] Loading pretrained vad_multilingual_marblenet model from NGC
[NeMo I 2023-05-23 10:08:06 cloud:58] Found existing object /home/omer/.cache/torch/NeMo/NeMo_1.17.0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.
[NeMo I 2023-05-23 10:08:06 cloud:64] Re-using file from: /home/omer/.cache/torch/NeMo/NeMo_1.17.0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo
[NeMo I 2023-05-23 10:08:06 common:913] Instantiating model from pre-trained checkpoint


[NeMo W 2023-05-23 10:08:06 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/ami_train_0.63.json,/manifests/freesound_background_train.json,/manifests/freesound_laughter_train.json,/manifests/fisher_2004_background.json,/manifests/fisher_2004_speech_sampled.json,/manifests/google_train_manifest.json,/manifests/icsi_all_0.63.json,/manifests/musan_freesound_train.json,/manifests/musan_music_train.json,/manifests/musan_soundbible_train.json,/manifests/mandarin_train_sample.json,/manifests/german_train_sample.json,/manifests/spanish_train_sample.json,/manifests/french_train_sample.json,/manifests/russian_train_sample.json
    sample_rate: 16000
    labels:
    - background
    - speech
    batch_size: 256
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy: sca

[NeMo I 2023-05-23 10:08:06 features:287] PADDING: 16
[NeMo I 2023-05-23 10:08:07 save_restore_connector:247] Model EncDecClassificationModel was successfully restored from /home/omer/.cache/torch/NeMo/NeMo_1.17.0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.
[NeMo I 2023-05-23 10:08:07 msdd_models:864] Multiscale Weights: [1, 1, 1, 1, 1]
[NeMo I 2023-05-23 10:08:07 msdd_models:865] Clustering Parameters: {
        "oracle_num_speakers": 2,
        "max_num_speakers": 8,
        "enhanced_count_thres": 80,
        "max_rp_threshold": 0.25,
        "sparse_search_volume": 30,
        "maj_vote_spk_count": false
    }


[NeMo W 2023-05-23 10:08:07 clustering_diarizer:411] Deleting previous clustering diarizer outputs.


[NeMo I 2023-05-23 10:08:07 speaker_utils:93] Number of files to diarize: 1
[NeMo I 2023-05-23 10:08:07 clustering_diarizer:309] Split long audio file to avoid CUDA memory issue


splitting manifest: 100%|██████████| 1/1 [00:04<00:00,  4.23s/it]

[NeMo I 2023-05-23 10:08:11 vad_utils:101] The prepared manifest file exists. Overwriting!
[NeMo I 2023-05-23 10:08:11 classification_models:263] Perform streaming frame-level VAD
[NeMo I 2023-05-23 10:08:11 collections:298] Filtered duration for loading collection is 0.000000.
[NeMo I 2023-05-23 10:08:11 collections:301] Dataset loaded with 147 items, total duration of  2.05 hours.
[NeMo I 2023-05-23 10:08:11 collections:303] # 147 files loaded accounting to # 1 labels



vad: 100%|██████████| 147/147 [00:50<00:00,  2.94it/s]

[NeMo I 2023-05-23 10:09:01 clustering_diarizer:250] Generating predictions with overlapping input segments



                                                               

[NeMo I 2023-05-23 10:09:52 clustering_diarizer:262] Converting frame level prediction to speech/no-speech segment in start and end times format.


creating speech segments: 100%|██████████| 1/1 [00:04<00:00,  4.44s/it]


[NeMo I 2023-05-23 10:09:57 clustering_diarizer:287] Subsegmentation for embedding extraction: scale0, /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/subsegments_scale0.json
[NeMo I 2023-05-23 10:09:57 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2023-05-23 10:09:57 collections:298] Filtered duration for loading collection is 0.000000.
[NeMo I 2023-05-23 10:09:57 collections:301] Dataset loaded with 7151 items, total duration of  2.62 hours.
[NeMo I 2023-05-23 10:09:57 collections:303] # 7151 files loaded accounting to # 1 labels


[1/5] extract embeddings: 100%|██████████| 447/447 [00:13<00:00, 32.01it/s]


[NeMo I 2023-05-23 10:10:11 clustering_diarizer:389] Saved embedding files to /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/embeddings
[NeMo I 2023-05-23 10:10:11 clustering_diarizer:287] Subsegmentation for embedding extraction: scale1, /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/subsegments_scale1.json
[NeMo I 2023-05-23 10:10:11 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2023-05-23 10:10:12 collections:298] Filtered duration for loading collection is 0.000000.
[NeMo I 2023-05-23 10:10:12 collections:301] Dataset loaded with 8666 items, total duration of  2.72 hours.
[NeMo I 2023-05-23 10:10:12 collections:303] # 8666 files loaded accounting to # 1 labels


[2/5] extract embeddings: 100%|██████████| 542/542 [00:14<00:00, 36.42it/s]


[NeMo I 2023-05-23 10:10:27 clustering_diarizer:389] Saved embedding files to /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/embeddings
[NeMo I 2023-05-23 10:10:27 clustering_diarizer:287] Subsegmentation for embedding extraction: scale2, /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/subsegments_scale2.json
[NeMo I 2023-05-23 10:10:27 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2023-05-23 10:10:27 collections:298] Filtered duration for loading collection is 0.000000.
[NeMo I 2023-05-23 10:10:27 collections:301] Dataset loaded with 10885 items, total duration of  2.81 hours.
[NeMo I 2023-05-23 10:10:27 collections:303] # 10885 files loaded accounting to # 1 labels


[3/5] extract embeddings: 100%|██████████| 681/681 [00:18<00:00, 37.79it/s]


[NeMo I 2023-05-23 10:10:46 clustering_diarizer:389] Saved embedding files to /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/embeddings
[NeMo I 2023-05-23 10:10:46 clustering_diarizer:287] Subsegmentation for embedding extraction: scale3, /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/subsegments_scale3.json
[NeMo I 2023-05-23 10:10:46 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2023-05-23 10:10:46 collections:298] Filtered duration for loading collection is 0.000000.
[NeMo I 2023-05-23 10:10:46 collections:301] Dataset loaded with 14675 items, total duration of  2.91 hours.
[NeMo I 2023-05-23 10:10:46 collections:303] # 14675 files loaded accounting to # 1 labels


[4/5] extract embeddings: 100%|██████████| 918/918 [00:23<00:00, 38.66it/s]


[NeMo I 2023-05-23 10:11:11 clustering_diarizer:389] Saved embedding files to /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/embeddings
[NeMo I 2023-05-23 10:11:11 clustering_diarizer:287] Subsegmentation for embedding extraction: scale4, /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/subsegments_scale4.json
[NeMo I 2023-05-23 10:11:11 clustering_diarizer:343] Extracting embeddings for Diarization
[NeMo I 2023-05-23 10:11:11 collections:298] Filtered duration for loading collection is 0.000000.
[NeMo I 2023-05-23 10:11:11 collections:301] Dataset loaded with 22492 items, total duration of  3.03 hours.
[NeMo I 2023-05-23 10:11:11 collections:303] # 22492 files loaded accounting to # 1 labels


[5/5] extract embeddings: 100%|██████████| 1406/1406 [00:36<00:00, 38.37it/s]


[NeMo I 2023-05-23 10:11:52 clustering_diarizer:389] Saved embedding files to /home/omer/Workspace/whisper-diarization/temp_outputs/speaker_outputs/embeddings


clustering:   0%|          | 0/1 [00:00<?, ?it/s]


ValueError: Provided option as oracle num of speakers but num_speakers in manifest is null

## Mapping Spekers to Sentences According to Timestamps

In [None]:
# Reading timestamps <> Speaker Labels mapping

speaker_ts = []
with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
    lines = f.readlines()
    for line in lines:
        line_list = line.split(" ")
        s = int(float(line_list[5]) * 1000)
        e = s + int(float(line_list[8]) * 1000)
        speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])

wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")

## Realligning Speech segments using Punctuation
---

This code provides a method for disambiguating speaker labels in cases where a sentence is split between two different speakers. It uses punctuation markings to determine the dominant speaker for each sentence in the transcription.

```
Speaker A: It's got to come from somewhere else. Yeah, that one's also fun because you know the lows are
Speaker B: going to suck, right? So it's actually it hits you on both sides.
```

For example, if a sentence is split between two speakers, the code takes the mode of speaker labels for each word in the sentence, and uses that speaker label for the whole sentence. This can help to improve the accuracy of speaker diarization, especially in cases where the Whisper model may not take fine utterances like "hmm" and "yeah" into account, but the Diarization Model (Nemo) may include them, leading to inconsistent results.

The code also handles cases where one speaker is giving a monologue while other speakers are making occasional comments in the background. It ignores the comments and assigns the entire monologue to the speaker who is speaking the majority of the time. This provides a robust and reliable method for realigning speech segments to their respective speakers based on punctuation in the transcription.

In [None]:
if info.language in punct_model_langs:
    # restoring punctuation in the transcript to help realign the sentences
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    words_list = list(map(lambda x: x["word"], wsm))

    labled_words = punct_model.predict(words_list)

    ending_puncts = ".?!"
    model_puncts = ".,;:!?"

    # We don't want to punctuate U.S.A. with a period. Right?
    is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)

    for word_dict, labeled_tuple in zip(wsm, labled_words):
        word = word_dict["word"]
        if (
            word
            and labeled_tuple[1] in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += labeled_tuple[1]
            if word.endswith(".."):
                word = word.rstrip(".")
            word_dict["word"] = word

    

    wsm = get_realigned_ws_mapping_with_punctuation(wsm)
else:
    print(
        f'Punctuation restoration is not available for {whisper_results["language"]} language.'
    )

ssm = get_sentences_speaker_mapping(wsm, speaker_ts)

## Cleanup and Exporing the results

In [None]:
with open(f"{audio_path[:-4]}.txt", "w", encoding="utf-8-sig") as f:
    get_speaker_aware_transcript(ssm, f)

with open(f"{audio_path[:-4]}.srt", "w", encoding="utf-8-sig") as srt:
    write_srt(ssm, srt)

cleanup(temp_path)

In [None]:
import os
import json
import soundfile as sf
from pydub import AudioSegment


In [None]:

os.makedirs('segments_dir', exist_ok=True)

# Read the diarization output
with open('temp_outputs/pred_rttms/mono_file.rttm', 'r') as f:
    lines = f.readlines()

# Load the audio file
audio = AudioSegment.from_wav('temp_outputs/mono_file.wav')

# Sample rate of the audio file
sample_rate = sf.info('temp_outputs/mono_file.wav').samplerate

# List to store the manifest data
manifest_data = []

for line in lines:
    fields = line.split()
    speaker = fields[7]
    start_time = float(fields[3]) * 1000  # convert to milliseconds
    end_time = start_time + (float(fields[4]) * 1000)  # convert to milliseconds

    # Slice the audio segment
    segment = audio[start_time:end_time]

    # Save the audio segment
    segment_path = f"segments_dir/{speaker}_{start_time}_{end_time}.wav"
    segment.export(segment_path, format='wav')

    duration = (end_time - start_time) / 1000  # convert back to seconds

    # discard samples less than 1 second
    if duration < 1.0:
        continue
    
    # Append data to the manifest
    manifest_data.append({
        'audio_filepath': segment_path,
        'duration': duration,
        'label': "Grey" if speaker == "speaker_0" else "Brady",
        'text': 'N/A'
    })

# Save the manifest file
with open('manifest.json', 'w') as f:
    json.dump(manifest_data, f, indent=4)


In [None]:
import json
import random

# Set the seed so that the split is reproducible
random.seed(123)

# Read the original manifest file
with open('manifest.json', 'r') as f:
    manifest = json.loads(f.read())

# Shuffle the manifest
random.shuffle(manifest)

# Split into train, dev and test
total = len(manifest)
train_cutoff = int(total * 0.8)  # 80% for training
dev_cutoff = int(total * 0.9)  # 10% for dev, 10% for test

train_manifest = manifest[:train_cutoff]
dev_manifest = manifest[train_cutoff:dev_cutoff]
test_manifest = manifest[dev_cutoff:]

# Write out the split manifests
with open('train_manifest.json', 'w') as f:
    for item in train_manifest:
        f.write(json.dumps(item) + '\n')

with open('dev_manifest.json', 'w') as f:
    for item in dev_manifest:
        f.write(json.dumps(item) + '\n')

with open('test_manifest.json', 'w') as f:
    for item in test_manifest:
        f.write(json.dumps(item) + '\n')
