<a href="https://colab.research.google.com/github/hsbakshi/ml-notebooks/blob/main/whisper_diarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
from pyannote.audio import Audio


In [None]:
import locale
# Force the reported preferred encoding to UTF-8 for the rest of the session.
# NOTE(review): this looks like the common Colab workaround for tools that
# refuse to run when the locale is not UTF-8 — confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"
%pip install -qq cog
%pip install -qq faster_whisper


In [None]:
from cog import BasePredictor, Input, Path, BaseModel
import os
import time
import wave
import torch
from faster_whisper import WhisperModel
import datetime
import contextlib
import numpy as np
from pyannote.audio import Audio
from pyannote.core import Segment
from sklearn.cluster import AgglomerativeClustering
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from typing import Any
from sklearn.metrics import silhouette_score


In [None]:

class ModelOutput(BaseModel):
    """Output schema returned by ``Predictor.predict``."""

    # Transcript groups produced by the predictor; typed as Any, so the
    # structure is not validated here.
    segments: Any

class Predictor(BasePredictor):
    """Transcribe an audio file with Whisper and label each segment by speaker.

    The pipeline is: convert the input to 16 kHz mono WAV, transcribe it,
    compute one speaker embedding per transcript segment, cluster the
    embeddings, and emit (optionally grouped) segments tagged ``SPEAKER n``.
    """

    def setup(self):
        """Load the Whisper model and the speaker-embedding model.

        Both models use the GPU when ``torch.cuda.is_available()`` and fall
        back to the CPU otherwise, so the two stay on the same device.
        """
        model_name = "medium"
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # float16 is kept for the GPU path; on CPU fall back to int8, the
        # compute type the faster-whisper documentation uses for CPU inference.
        compute_type = "float16" if use_cuda else "int8"
        self.model = WhisperModel(model_name, device=device, compute_type=compute_type)
        self.embedding_model = PretrainedSpeakerEmbedding("speechbrain/spkrec-ecapa-voxceleb",
                                                          device=torch.device(device))

    def predict(self, audio: Path = Input(description="An audio file", default=None),
                group_segments: bool = Input(description="Group segments of the same speaker shorter apart than 2 seconds", default=True),
                num_speakers: int = Input(description="Number of speakers", ge=0, le=25, default=0),
                prompt: str = Input(description="Prompt, to be used as context", default="Some people speaking."),
                offset_seconds: int = Input(description="Offset in seconds for chunking inputs", default=0, ge=0)) -> ModelOutput:
        """Run transcription plus speaker labelling on ``audio``.

        Args:
            audio: Path of the audio file to process.
            group_segments: Merge consecutive segments of the same speaker
                that are at most 2 seconds apart.
            num_speakers: Number of speakers; 0 means estimate it.
            prompt: Initial prompt passed to Whisper.
            offset_seconds: Seconds added to every reported start/end.

        Returns:
            A ``ModelOutput`` whose ``segments`` is the list of result dicts.

        Raises:
            ValueError: If no audio file was supplied.
            RuntimeError: If any pipeline step fails.
        """
        if audio is None:
            raise ValueError("An audio file is required")
        segments = self.speech_to_text(audio, num_speakers, prompt, offset_seconds, group_segments)
        # ModelOutput declares only `segments`, so no other field is passed.
        return ModelOutput(segments=segments)

    def convert_time(self, secs, offset_seconds=0):
        """Return ``secs`` (rounded) plus ``offset_seconds`` as a timedelta."""
        return datetime.timedelta(seconds=(round(secs) + offset_seconds))

    def speech_to_text(self, filepath, num_speakers, prompt="People talking.", offset_seconds=0, group_segments=True):
        """Run the full pipeline on ``filepath`` and return the result dicts.

        Returns an empty list when Whisper produces no segments.

        Raises:
            RuntimeError: Wrapping (and chained from) any underlying failure.
        """
        time_start = time.time()
        audio_file_wav = None
        try:
            audio_file_wav = self.convert_audio_to_wav(filepath)
            duration = self.get_audio_duration(audio_file_wav)
            segments = self.transcribe_audio(audio_file_wav, prompt)
            segments = self.convert_segments(segments)
            if segments:
                embeddings = self.create_embeddings(segments, audio_file_wav, duration)
                speaker_count = self.find_speaker_count(embeddings, num_speakers)
                output = self.assign_speaker_labels(segments, embeddings, speaker_count, offset_seconds, group_segments)
            else:
                # Nothing was transcribed, so there is nothing to embed or cluster.
                output = []
            time_diff = time.time() - time_start
            system_info = f"Processing time: {time_diff:.5} seconds"
            print(system_info)
            return output
        except Exception as e:
            raise RuntimeError("Error running inference with local model", e) from e
        finally:
            # Remove only the temporary WAV this method created. When the
            # input was already a .wav file it is the caller's file and is
            # left untouched.
            if audio_file_wav is not None and str(audio_file_wav) != str(filepath):
                with contextlib.suppress(FileNotFoundError):
                    os.remove(audio_file_wav)

    def convert_audio_to_wav(self, filepath):
        """Return the path (as ``str``) of a WAV version of ``filepath``.

        Files whose extension is not ``.wav`` are converted with ffmpeg to
        16 kHz mono 16-bit PCM next to the input; ``.wav`` inputs are returned
        as-is.

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with an error.
        """
        import subprocess  # local import keeps this block self-contained

        source_path = str(filepath)
        root, file_ending = os.path.splitext(source_path)
        print(f'File ending: "{file_ending}"')
        if file_ending == '.wav':
            return source_path

        # Swap only the trailing extension; str.replace would also rewrite
        # earlier occurrences and breaks on paths with no extension.
        audio_file_wav = root + ".wav"
        print("Starting conversion to wav")
        command = ["ffmpeg", "-y", "-i", source_path,
                   "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", audio_file_wav]
        try:
            # Argument list (no shell) so file names with quotes or spaces are
            # safe; -y stops ffmpeg from stalling on an existing output file.
            subprocess.run(command, check=True)
        except subprocess.CalledProcessError:
            # Do not leave a partial output behind.
            with contextlib.suppress(FileNotFoundError):
                os.remove(audio_file_wav)
            raise
        return audio_file_wav

    def get_audio_duration(self, audio_file_wav):
        """Return the duration in seconds of the WAV file at ``audio_file_wav``."""
        # str() because wave.open treats non-str arguments as file objects.
        with contextlib.closing(wave.open(str(audio_file_wav), 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            duration = frames / float(rate)
        print(f"Conversion to wav ready, duration of audio file: {duration}")
        return duration

    def transcribe_audio(self, audio_file_wav, prompt):
        """Transcribe the WAV file with Whisper and return the segment list.

        ``prompt`` is passed to the model as ``initial_prompt``.
        """
        print("Starting whisper")
        options = dict(beam_size=5, best_of=5)
        transcribe_options = dict(task="transcribe",
                                  word_timestamps=True,
                                  vad_filter=True,
                                  **options)
        print(prompt)
        segments, _ = self.model.transcribe(str(audio_file_wav), **transcribe_options, initial_prompt=prompt)
        print("Done with whisper")
        # The result is materialised so later steps can index and re-iterate it.
        result = list(segments)
        if result:
            print(f"Sample segment::{result[0]}")
        return result

    def convert_segments(self, segments):
        """Convert Whisper segment objects into plain dicts.

        NOTE(review): start/end are truncated to whole seconds by ``int``, so
        a segment shorter than one second can end up with start == end —
        confirm this is intended.
        """
        return [
            {
                'start': int(s.start),
                'end': int(s.end),
                'text': s.text,
                'words': s.words
            }
            for s in segments]

    def create_embeddings(self, segments, audio_file_wav, duration):
        """Return an array with one speaker embedding row per segment.

        Each segment's audio is cropped from the WAV file (the end is clamped
        to ``duration``) and fed to the embedding model; NaNs become 0.
        """
        print("Starting embedding")
        # 192 columns — presumably the embedding size of the model loaded in
        # setup(); verify if the model is changed.
        embeddings = np.zeros(shape=(len(segments), 192))
        audio = Audio()
        for i, segment in enumerate(segments):
            waveform, _ = audio.crop(audio_file_wav, Segment(segment["start"], min(duration, segment["end"])))
            embeddings[i] = self.embedding_model(waveform[None])
        embeddings = np.nan_to_num(embeddings)
        print(f'Embedding shape: {embeddings.shape}')
        return embeddings

    def find_speaker_count(self, embeddings, speaker_count_override):
        """Return the number of speakers to cluster into.

        A non-zero ``speaker_count_override`` is returned unchanged. Otherwise
        candidate counts from 2 to 9 are scored with the silhouette score and
        the best one is returned. With too few segments to score any
        candidate, 1 is returned.
        """
        if speaker_count_override != 0:
            return speaker_count_override

        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1,
        # so short recordings cannot try every candidate up to 9.
        max_candidates = min(9, len(embeddings) - 1)
        if max_candidates < 2:
            return 1

        score_num_speakers = {}
        for num_speakers in range(2, max_candidates + 1):
            clustering = AgglomerativeClustering(num_speakers).fit(embeddings)
            score_num_speakers[num_speakers] = silhouette_score(
                embeddings, clustering.labels_, metric='euclidean')
        best_num_speaker = max(score_num_speakers, key=score_num_speakers.get)
        print(f"The best number of speakers: {best_num_speaker} with {score_num_speakers[best_num_speaker]} score")
        return best_num_speaker

    @staticmethod
    def _make_group(segment, offset_seconds):
        """Build the output dict for a single labelled segment."""
        return {
            'start': str(round(segment["start"] + offset_seconds)),
            'end': str(round(segment["end"] + offset_seconds)),
            'speaker': segment["speaker"],
            'text': segment["text"],
            'words': segment["words"]
        }

    def assign_speaker_labels(self, segments, embeddings, num_speakers, offset_seconds, group_segments):
        """Cluster the embeddings, label the segments, and build the output.

        Every segment dict gets a ``speaker`` key (``SPEAKER 1``, ...). When
        ``group_segments`` is true, consecutive segments of the same speaker
        at most 2 seconds apart are merged: the end time, text and words of
        the later segment are folded into the group.

        Returns:
            A list of dicts with ``start``, ``end``, ``speaker``, ``text``
            and ``words``; empty when ``segments`` is empty.
        """
        if not segments:
            return []

        if len(segments) < 2:
            # Clustering needs at least two samples; one segment is one speaker.
            labels = [0] * len(segments)
        else:
            # Never ask for more clusters than there are segments.
            n_clusters = max(1, min(num_speakers, len(segments)))
            labels = AgglomerativeClustering(n_clusters).fit(embeddings).labels_

        for segment, label in zip(segments, labels):
            segment["speaker"] = 'SPEAKER ' + str(label + 1)

        output = []
        current_group = self._make_group(segments[0], offset_seconds)
        for previous, segment in zip(segments, segments[1:]):
            time_gap = segment["start"] - previous["end"]
            same_speaker = segment["speaker"] == previous["speaker"]
            if same_speaker and time_gap <= 2 and group_segments:
                current_group["end"] = str(round(segment["end"] + offset_seconds))
                current_group["text"] += " " + segment["text"]
                if segment["words"]:
                    # Build a new list so the input segment's own list is not mutated.
                    current_group["words"] = list(current_group["words"] or []) + list(segment["words"])
            else:
                output.append(current_group)
                current_group = self._make_group(segment, offset_seconds)
        output.append(current_group)
        print("Embedding complete")
        return output
        

# New Section

# New Section

Set up the Predictor class (this loads the Whisper and speaker-embedding models)

In [None]:
# Create the predictor and load its models once; later cells reuse `predictor`.
predictor = Predictor()
predictor.setup()

Upload an audio file and run the predictor on it

In [None]:
import google.colab

# Ask the user for a file; the key taken from the returned mapping is used
# below as the audio path (presumably the uploaded file's name — verify).
uploaded = google.colab.files.upload()
if not uploaded:
    # popitem() on an empty mapping would raise an unhelpful KeyError.
    raise ValueError("No file was uploaded; re-run this cell and choose an audio file.")
own_file, _ = uploaded.popitem()
OWN_FILE = {'audio': own_file}



In [None]:
# Transcribe the uploaded file without merging adjacent segments;
# num_speakers=0 asks the predictor to estimate the speaker count itself.
predictor.predict(OWN_FILE['audio'], group_segments=False, num_speakers=0, prompt="People talking", offset_seconds=0)