# Pyannote-Whisper model for AMS

In [1]:
# Install the notebook's dependencies: OpenAI Whisper, pyannote.audio,
# ONNX Runtime, a faster-whisper fork, and the Hugging Face client libraries.
# NOTE(review): none of these installs is version-pinned — consider pinning for reproducibility.
! pip install git+https://github.com/openai/whisper.git
! pip install -q git+https://github.com/pyannote/pyannote-audio
! pip install onnxruntime
! pip install git+https://github.com/thomasmol/faster-whisper.git@master
! pip install -U huggingface_hub transformers

Collecting git+https://github.com/openai/whisper.git
  Cloning https://github.com/openai/whisper.git to /tmp/pip-req-build-y9jz24qj
  Running command git clone --filter=blob:none --quiet https://github.com/openai/whisper.git /tmp/pip-req-build-y9jz24qj
  Resolved https://github.com/openai/whisper.git to commit 279133e3107392276dc509148da1f41bfb532c7e
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting tiktoken (from openai-whisper==20231117)
  Downloading tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting triton>=2.0.0 (from openai-whisper==20231117)
  Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m 

In [2]:
import os, sys
from google.colab import drive
# Mount Google Drive so later cells can read and write files under /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [19]:
# Location of the audio sample to process; upload the sample to Google Drive first.
path = '/content/drive/MyDrive/Test/3.mp3'
#path = os.getcwd() # to get current path
print(path)

/content/drive/MyDrive/Test/3.mp3


In [17]:
# Import necessary libraries
import whisper
from pyannote.audio import Pipeline
from pyannote.core import Segment, Annotation, Timeline

# Define helper functions
def get_text_with_timestamp(transcribe_res):
    """Pair the time span of every transcription segment with its text.

    Args:
        transcribe_res: Transcription result exposing a ``'segments'``
            sequence whose items provide ``'start'``, ``'end'`` and
            ``'text'`` keys.

    Returns:
        A list of ``(Segment(start, end), text)`` tuples, in segment order.
    """
    return [
        (Segment(chunk['start'], chunk['end']), chunk['text'])
        for chunk in transcribe_res['segments']
    ]

def add_speaker_info_to_text(timestamp_texts, ann):
    """Attach a speaker label to every timestamped piece of text.

    Args:
        timestamp_texts: Iterable of ``(segment, text)`` pairs.
        ann: Diarization annotation; ``ann.crop(segment).argmax()`` is used
            to pick the speaker label for each segment.

    Returns:
        A list of ``(segment, speaker, text)`` tuples, in input order.
    """
    return [
        (span, ann.crop(span).argmax(), words)
        for span, words in timestamp_texts
    ]

def merge_cache(text_cache):
    """Collapse a run of ``(segment, speaker, text)`` entries into one entry.

    The merged entry spans from the start of the first segment to the end of
    the last one, carries the speaker of the first entry, and concatenates
    the texts without a separator.

    Args:
        text_cache: Non-empty list of ``(segment, speaker, text)`` tuples.

    Returns:
        A ``(Segment, speaker, sentence)`` tuple.

    Raises:
        IndexError: If ``text_cache`` is empty.
    """
    head, tail = text_cache[0], text_cache[-1]
    joined_text = ''.join(entry[-1] for entry in text_cache)
    return Segment(head[0].start, tail[0].end), head[1], joined_text

# Characters that mark the end of a sentence.
PUNC_SENT_END = ['.', '?', '!']


def merge_sentence(spk_text):
    """Group consecutive segments into sentence-level entries per speaker.

    Segments accumulate in a pending buffer. The buffer is flushed through
    ``merge_cache`` when the speaker changes while the buffer is non-empty,
    or when a segment's text ends with a character from ``PUNC_SENT_END``.
    Whatever is left at the end is flushed as well.

    Args:
        spk_text: Iterable of ``(segment, speaker, text)`` tuples.

    Returns:
        A list of merged ``(Segment, speaker, sentence)`` tuples.
    """
    merged = []
    last_spk = None
    pending = []
    for seg, spk, text in spk_text:
        entry = (seg, spk, text)
        speaker_switched = last_spk is not None and spk != last_spk
        if speaker_switched and pending:
            # Close the previous speaker's run; the current segment opens a new one.
            merged.append(merge_cache(pending))
            pending = [entry]
        else:
            pending.append(entry)
            if text and text[-1] in PUNC_SENT_END:
                # Sentence finished: emit everything collected so far.
                merged.append(merge_cache(pending))
                pending = []
        last_spk = spk
    if pending:
        merged.append(merge_cache(pending))
    return merged

def diarize_text(transcribe_res, diarization_result):
    """Combine a transcription with a diarization into speaker-labeled sentences.

    Args:
        transcribe_res: Transcription result passed to ``get_text_with_timestamp``.
        diarization_result: Diarization annotation passed to
            ``add_speaker_info_to_text``.

    Returns:
        The list of ``(Segment, speaker, sentence)`` tuples produced by
        ``merge_sentence``.
    """
    return merge_sentence(
        add_speaker_info_to_text(
            get_text_with_timestamp(transcribe_res),
            diarization_result,
        )
    )


In [22]:
import whisper
from pyannote.audio import Pipeline
from pyannote.core import Segment
import csv
import torch

import onnxruntime as ort

# Show which device ONNX Runtime reports (the recorded run printed "GPU").
print(ort.get_device())

# Main processing workflow
def process_audio(audio_file, auth_token):
    """Transcribe an audio file and attribute each sentence to a speaker.

    Loads the pyannote speaker-diarization pipeline and a Whisper model, runs
    both on ``audio_file`` and merges the two results with ``diarize_text``.

    Args:
        audio_file: Path of the audio file to process.
        auth_token: Hugging Face access token used to download the
            diarization pipeline.

    Returns:
        A list of ``(Segment, speaker, sentence)`` tuples, or ``None`` when
        model initialization or audio processing fails (the error is printed).
    """
    # Initialize pipeline and model
    print("Initializing models...")
    try:
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1",
                                            use_auth_token=auth_token)
        if pipeline is None:
            # NOTE(review): from_pretrained appears to return None (instead of
            # raising) when the gated model cannot be downloaded — confirm.
            raise RuntimeError(
                "diarization pipeline could not be loaded; check the access "
                "token and that the model's user conditions were accepted"
            )
        # You can change the Whisper model: tiny→medium ".en" variants are
        # English-only, large and large-v2 are multilingual.
        model = whisper.load_model("tiny.en")
        # Use the GPU when present, but keep working on CPU-only runtimes
        # instead of aborting the whole run.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)
    except Exception as e:
        print(f"Error during model initialization: {e}")
        return None

    # Perform transcription and speaker diarization
    print("Processing audio for transcription and diarization...")
    try:
        asr_result = model.transcribe(audio_file)
        diarization_result = pipeline(audio_file)
    except Exception as e:
        print(f"Error during processing: {e}")
        return None

    # Merge results
    print("Merging transcription and diarization results...")
    final_result = diarize_text(asr_result, diarization_result)

    if not final_result:
        print("No results obtained after merging.")
    else:
        print("Processing complete.")

    return final_result

# Usage example
import os
from getpass import getpass

audio_file = path # Update with your actual sample path
# Never hard-code the Hugging Face token in the notebook: read it from the
# HF_TOKEN environment variable, or prompt for it when the variable is unset.
# Any token that was previously committed in this cell should be revoked.
auth_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
output_csv = "/content/drive/MyDrive/output.csv"  # Specify the CSV file path

# Process the audio and get the results
result = process_audio(audio_file, auth_token)

# Check if the result is valid before attempting to print and save
if result:
    print("Saving the final result to CSV...")
    try:
        # Make sure the destination folder exists before writing.
        os.makedirs(os.path.dirname(output_csv), exist_ok=True)
        with open(output_csv, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Start Time", "End Time", "Speaker", "Text"])
            for seg, spk, sent in result:
                line = f'{seg.start:.2f} {seg.end:.2f} {spk} {sent}'
                print(line)
                writer.writerow([f'{seg.start:.2f}', f'{seg.end:.2f}', spk, sent])
        print(f"Results successfully saved to {output_csv}")
    except Exception as e:
        print(f"Error saving CSV file: {e}")
else:
    print("No result to print or save.")


GPU
Initializing models...
Processing audio for transcription and diarization...
Merging transcription and diarization results...
Processing complete.
Saving the final result to CSV...
0.00 11.44 SPEAKER_00  Hello, my name is Rajesh and I'm here to help you with all of your issues.
11.44 18.72 SPEAKER_00  Hey Rajesh, we've been using your product for a few weeks now and today it crashed and we cannot launch it anymore.
18.72 22.12 SPEAKER_00  Okay sir, thank you for letting me know.
22.12 24.12 SPEAKER_00  And what am I supposed to do now?
24.12 27.80 SPEAKER_00  What is the priority of the issue?
27.80 29.24 SPEAKER_00  What do you mean the priority?
29.24 31.12 SPEAKER_00  It's urgent.
31.12 42.32 SPEAKER_00  Priority from 1 to 10 where one means the severity is very low and I can close the ticket now and 10 means the severity is very high and I cannot close the ticket yet.
42.32 44.36 SPEAKER_00  It's 20.
44.36 46.08 SPEAKER_00  20 meaning?
46.08 49.88 SPEAKER_00  Meaning it's peopl

# Model metrics
We use WER and ROUGE to measure ASR accuracy, and speaker segmentation accuracy to measure diarization quality.

In [32]:
import locale
# Make locale.getpreferredencoding() always report UTF-8.
# NOTE(review): presumably a workaround so the `!` shell commands below do not
# fail on a non-UTF-8 locale in Colab — confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"

# Install the metric libraries used below (ROUGE and WER).
! pip install rouge_score
! pip install jiwer



In [37]:
import csv
import string
from rouge_score import rouge_scorer
from jiwer import wer  # For WER calculation

# Function to clean text content by removing punctuation and converting to lowercase. This is to delete the format influence.
def clean_text(text):
    """Normalize text by dropping ASCII punctuation and lower-casing it.

    Args:
        text: The string to normalize.

    Returns:
        ``text`` without any character from ``string.punctuation``,
        converted to lowercase.
    """
    kept_chars = [ch for ch in text if ch not in string.punctuation]
    return ''.join(kept_chars).lower()

# Function to read a CSV file and return lists of cleaned text and speaker information
def read_csv(file_path):
    """Read a transcription CSV and return its cleaned texts and speaker labels.

    The first row is treated as a header and skipped. Speaker labels come
    from the third column (index 2) and texts from the fourth column
    (index 3); texts are normalized with ``clean_text``.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A ``(texts, speakers)`` tuple of two equally long lists.

    Raises:
        IndexError: If a non-blank row has fewer than four columns.
    """
    texts = []
    speakers = []
    # newline='' lets the csv module handle line endings itself, including
    # newlines embedded in quoted fields; the encoding is made explicit.
    with open(file_path, mode='r', newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        next(reader, None)  # Skip the header row; tolerate an empty file
        for row in reader:
            if not row:
                # Blank lines (e.g. a trailing newline) carry no data.
                continue
            speakers.append(row[2])            # Keep speaker label as is
            texts.append(clean_text(row[3]))   # Clean the spoken content
    return texts, speakers

# Function to concatenate all text segments for WER and ROUGE calculation
def concatenate_texts(texts):
    """Join all text segments into one string separated by single spaces.

    Args:
        texts: Iterable of strings.

    Returns:
        The space-joined string (empty when ``texts`` is empty).
    """
    joined = str.join(' ', texts)
    return joined

# Function to calculate WER for the entire transcription
def calculate_wer_concatenated(machine_texts, human_texts):
    """Compute and print the word error rate over the whole transcription.

    Both lists are joined with ``concatenate_texts``; the human text is the
    reference and the machine text is the hypothesis.

    Args:
        machine_texts: List of machine-transcribed text segments.
        human_texts: List of ground-truth text segments.

    Returns:
        The value returned by ``wer`` for the concatenated texts.
    """
    hypothesis = concatenate_texts(machine_texts)
    reference = concatenate_texts(human_texts)

    error_rate = wer(reference, hypothesis)
    print(f"WER for concatenated transcription: {error_rate:.4f}")
    return error_rate

# Function to calculate ROUGE scores for the entire transcription
def calculate_rouge_concatenated(machine_texts, human_texts):
    """Compute and print ROUGE-1/2/L over the whole transcription.

    Both lists are joined with ``concatenate_texts``; the human text is
    scored as the target and the machine text as the prediction, with
    stemming enabled.

    Args:
        machine_texts: List of machine-transcribed text segments.
        human_texts: List of ground-truth text segments.

    Returns:
        The score mapping returned by ``RougeScorer.score`` (keys
        ``'rouge1'``, ``'rouge2'`` and ``'rougeL'``).
    """
    prediction = concatenate_texts(machine_texts)
    target = concatenate_texts(human_texts)

    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = rouge.score(target, prediction)

    print("ROUGE scores for concatenated transcription:")
    for label, key in (("ROUGE-1", 'rouge1'), ("ROUGE-2", 'rouge2'), ("ROUGE-L", 'rougeL')):
        print(f"{label}: {scores[key].fmeasure:.4f}")
    return scores

# Function to save the concatenated scores to a CSV file
def save_concatenated_scores_to_csv(file_path, wer_score, rouge_scores, segmentation_accuracy):
    """Write the evaluation metrics to a two-column CSV file.

    Args:
        file_path: Destination CSV path (overwritten if it exists).
        wer_score: Word error rate, formatted with four decimals.
        rouge_scores: Mapping with ``'rouge1'``, ``'rouge2'`` and ``'rougeL'``
            entries, each exposing an ``fmeasure`` attribute.
        segmentation_accuracy: Speaker segmentation accuracy, formatted with
            four decimals.
    """
    rouge_rows = (('ROUGE-1', 'rouge1'), ('ROUGE-2', 'rouge2'), ('ROUGE-L', 'rougeL'))
    with open(file_path, mode='w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(['Metric', 'Score'])
        out.writerow(['WER', f"{wer_score:.4f}"])
        for label, key in rouge_rows:
            out.writerow([label, f"{rouge_scores[key].fmeasure:.4f}"])
        out.writerow(['Speaker Segmentation Accuracy', f"{segmentation_accuracy:.4f}"])

def compare_speakers(machine_speakers, human_speakers):
    """Compute and print row-by-row agreement between two speaker label lists.

    This function was called below but was not defined anywhere in the
    notebook, so a fresh "Restart & Run All" failed with NameError.
    NOTE(review): this is a reconstruction based on the recorded output
    format ("accuracy (matches/total)") — confirm it matches the intended
    metric, in particular how rows without a counterpart are counted.

    Args:
        machine_speakers: Speaker labels from the machine transcription.
        human_speakers: Speaker labels from the ground truth.

    Returns:
        The fraction of rows whose labels are equal. Rows that exist in only
        one of the lists count as mismatches; 0.0 when both lists are empty.
    """
    total = max(len(machine_speakers), len(human_speakers))
    matches = sum(
        1 for machine, human in zip(machine_speakers, human_speakers)
        if machine == human
    )
    accuracy = matches / total if total else 0.0
    print(f"Speaker Segmentation Accuracy: {accuracy:.4f} ({matches}/{total})")
    return accuracy


# Example usage
machine_texts, machine_speakers = read_csv('/content/drive/MyDrive/output.csv') #change to the compared machine_transcription.
human_texts, human_speakers = read_csv('/content/drive/MyDrive/3.csv')#change to the compared ground truth.

#print(f"Length of machine_speakers: {len(machine_speakers)}")  # Print the length of machine_speakers
#print(f"Length of human_speakers: {len(human_speakers)}")  # Print the length of human_speakers


# Calculate WER for concatenated text
wer_score = calculate_wer_concatenated(machine_texts, human_texts)

# Calculate ROUGE scores for concatenated text
rouge_scores = calculate_rouge_concatenated(machine_texts, human_texts)

# Compare speaker segmentation
segmentation_accuracy = compare_speakers(machine_speakers, human_speakers)

# Save concatenated scores to a CSV file
save_concatenated_scores_to_csv('/content/drive/MyDrive/concatenated_score_output.csv', wer_score, rouge_scores, segmentation_accuracy)

WER for concatenated transcription: 0.0272
ROUGE scores for concatenated transcription:
ROUGE-1: 0.9797
ROUGE-2: 0.9678
ROUGE-L: 0.9797
Speaker Segmentation Accuracy: 0.1875 (15/80)


# Optional: Agent Identification
We can identify the agent by loosely matching common greeting phrases (a case-insensitive substring search) in the transcript. This is not very accurate yet, and conversations with more than two speakers are not handled. The model is still at an early stage.

In [None]:
# Import necessary libraries
import whisper
from pyannote.audio import Pipeline
from pyannote.core import Segment, Annotation, Timeline

# Define helper functions
def get_text_with_timestamp(transcribe_res):
    """Pair the time span of every transcription segment with its text.

    Args:
        transcribe_res: Transcription result exposing a ``'segments'``
            sequence whose items provide ``'start'``, ``'end'`` and
            ``'text'`` keys.

    Returns:
        A list of ``(Segment(start, end), text)`` tuples, in segment order.
    """
    return [
        (Segment(chunk['start'], chunk['end']), chunk['text'])
        for chunk in transcribe_res['segments']
    ]

def add_speaker_info_to_text(timestamp_texts, ann):
    """Attach a speaker label to every timestamped piece of text.

    Args:
        timestamp_texts: Iterable of ``(segment, text)`` pairs.
        ann: Diarization annotation; ``ann.crop(segment).argmax()`` is used
            to pick the speaker label for each segment.

    Returns:
        A list of ``(segment, speaker, text)`` tuples, in input order.
    """
    return [
        (span, ann.crop(span).argmax(), words)
        for span, words in timestamp_texts
    ]

def merge_cache(text_cache):
    """Collapse a run of ``(segment, speaker, text)`` entries into one entry.

    The merged entry spans from the start of the first segment to the end of
    the last one, carries the speaker of the first entry, and concatenates
    the texts without a separator.

    Args:
        text_cache: Non-empty list of ``(segment, speaker, text)`` tuples.

    Returns:
        A ``(Segment, speaker, sentence)`` tuple.

    Raises:
        IndexError: If ``text_cache`` is empty.
    """
    head, tail = text_cache[0], text_cache[-1]
    joined_text = ''.join(entry[-1] for entry in text_cache)
    return Segment(head[0].start, tail[0].end), head[1], joined_text

# Characters that mark the end of a sentence.
PUNC_SENT_END = ['.', '?', '!']


def merge_sentence(spk_text):
    """Group consecutive segments into sentence-level entries per speaker.

    Segments accumulate in a pending buffer. The buffer is flushed through
    ``merge_cache`` when the speaker changes while the buffer is non-empty,
    or when a segment's text ends with a character from ``PUNC_SENT_END``.
    Whatever is left at the end is flushed as well.

    Args:
        spk_text: Iterable of ``(segment, speaker, text)`` tuples.

    Returns:
        A list of merged ``(Segment, speaker, sentence)`` tuples.
    """
    merged = []
    last_spk = None
    pending = []
    for seg, spk, text in spk_text:
        entry = (seg, spk, text)
        speaker_switched = last_spk is not None and spk != last_spk
        if speaker_switched and pending:
            # Close the previous speaker's run; the current segment opens a new one.
            merged.append(merge_cache(pending))
            pending = [entry]
        else:
            pending.append(entry)
            if text and text[-1] in PUNC_SENT_END:
                # Sentence finished: emit everything collected so far.
                merged.append(merge_cache(pending))
                pending = []
        last_spk = spk
    if pending:
        merged.append(merge_cache(pending))
    return merged

def diarize_text(transcribe_res, diarization_result):
    """Combine a transcription with a diarization into speaker-labeled sentences.

    Args:
        transcribe_res: Transcription result passed to ``get_text_with_timestamp``.
        diarization_result: Diarization annotation passed to
            ``add_speaker_info_to_text``.

    Returns:
        The list of ``(Segment, speaker, sentence)`` tuples produced by
        ``merge_sentence``.
    """
    return merge_sentence(
        add_speaker_info_to_text(
            get_text_with_timestamp(transcribe_res),
            diarization_result,
        )
    )

# Function to identify agent based on specific phrases (with fuzzy matching)
def identify_agent(merged_spk_text, agent_phrases=None):
    """Find the speaker who utters a typical agent greeting.

    Matching is a case-insensitive substring search (not a true fuzzy match):
    the first segment, in order, whose text contains any of the phrases
    determines the agent.

    Args:
        merged_spk_text: Iterable of ``(segment, speaker, text)`` tuples.
        agent_phrases: Optional list of phrases that identify the agent.
            Defaults to the built-in greetings "Thank you for calling" and
            "How may I help you". Empty phrases are ignored.

    Returns:
        The speaker label of the first matching segment, or ``None`` when no
        segment matches.
    """
    if agent_phrases is None:
        agent_phrases = ["Thank you for calling", "How may I help you"]
    # Lower-case the phrases once instead of once per segment; drop empty
    # phrases, which would otherwise match every text.
    needles = [phrase.lower() for phrase in agent_phrases if phrase]

    for _, spk, text in merged_spk_text:
        if spk is None or not text:
            # A segment without a speaker label or text cannot name the agent.
            continue
        haystack = text.lower()
        if any(needle in haystack for needle in needles):
            return spk

    return None

# If no agent is matched, we still leave the labels as Speaker xx.
# Function to map speaker labels
def map_speaker_labels(merged_spk_text, agent_speaker=None):
    """Replace raw speaker labels with human-readable role labels.

    When ``agent_speaker`` is given, that speaker becomes "Agent" and every
    other speaker becomes "Customer" (this assumes a two-party conversation).
    Otherwise the distinct speakers are numbered "Speaker 1", "Speaker 2", …
    in sorted order, and segments without a speaker (``None``) are labeled
    "Unknown".

    Args:
        merged_spk_text: List of ``(segment, speaker, text)`` tuples.
        agent_speaker: Raw label of the agent, or ``None`` if unknown.

    Returns:
        A new list of ``(segment, label, text)`` tuples in the input order.
    """
    speaker_labels = {}
    if agent_speaker is not None:
        speaker_labels[agent_speaker] = "Agent"
        # Assuming binary speakers; adjust if more speakers are possible
        for _, spk, _ in merged_spk_text:
            if spk != agent_speaker:
                speaker_labels[spk] = "Customer"
    else:
        # Assign generic labels like Speaker 1, Speaker 2, etc.
        # None (no speaker found for a segment) is excluded from the sort:
        # comparing None with string labels would raise TypeError. Sorting by
        # str() also keeps mixed label types comparable.
        known_speakers = {spk for _, spk, _ in merged_spk_text if spk is not None}
        for idx, spk in enumerate(sorted(known_speakers, key=str), start=1):
            speaker_labels[spk] = f"Speaker {idx}"
        speaker_labels[None] = "Unknown"

    # Apply labels to the merged_spk_text
    return [
        (seg, speaker_labels.get(spk, f"Speaker {spk}"), text)
        for seg, spk, text in merged_spk_text
    ]

In [None]:
import whisper
from pyannote.audio import Pipeline
from pyannote.core import Segment
import csv
import torch

import onnxruntime as ort

# Show which device ONNX Runtime reports.
print(ort.get_device())

# Takes about 3 minutes to process one audio file
# Main processing workflow
def process_audio(audio_file, auth_token):
    """Transcribe an audio file, diarize it, and label speakers by role.

    Loads the pyannote speaker-diarization pipeline and a Whisper model, runs
    both on ``audio_file``, merges the results with ``diarize_text`` and then
    tries to identify the agent with ``identify_agent``.

    Args:
        audio_file: Path of the audio file to process.
        auth_token: Hugging Face access token used to download the
            diarization pipeline.

    Returns:
        A list of ``(Segment, label, sentence)`` tuples, or ``None`` when
        model initialization or audio processing fails (the error is
        printed). When an agent is identified the labels are "Agent" and
        "Customer"; otherwise the diarization labels are kept unchanged.
    """
    # Initialize pipeline and model
    print("Initializing models...")
    try:
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1",
                                            use_auth_token=auth_token)
        if pipeline is None:
            # NOTE(review): from_pretrained appears to return None (instead of
            # raising) when the gated model cannot be downloaded — confirm.
            raise RuntimeError(
                "diarization pipeline could not be loaded; check the access "
                "token and that the model's user conditions were accepted"
            )
        model = whisper.load_model("tiny.en")  # change this version
        # Use the GPU when present, but keep working on CPU-only runtimes
        # instead of aborting the whole run.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)
    except Exception as e:
        print(f"Error during model initialization: {e}")
        return None

    # Perform transcription and speaker diarization
    print("Processing audio for transcription and diarization...")
    try:
        asr_result = model.transcribe(audio_file)
        diarization_result = pipeline(audio_file)
    except Exception as e:
        print(f"Error during processing: {e}")
        return None

    # Merge results
    print("Merging transcription and diarization results...")
    merged_spk_text = diarize_text(asr_result, diarization_result)

    # Identify the agent
    print("Identifying the agent...")
    agent_speaker = identify_agent(merged_spk_text)

    if agent_speaker is None:
        # No greeting matched: keep the diarization labels rather than
        # labeling every speaker "Customer".
        final_result = list(merged_spk_text)
    else:
        # Assume a two-speaker conversation: the agent is "Agent", everyone
        # else is "Customer".
        final_result = map_speaker_labels(merged_spk_text, agent_speaker)

    if not final_result:
        print("No results obtained after merging.")
    else:
        print("Processing complete.")

    return final_result

# Usage example
import os
from getpass import getpass

audio_file = path # Update with your actual file path
# Never hard-code the Hugging Face token in the notebook: read it from the
# HF_TOKEN environment variable, or prompt for it when the variable is unset.
# Any token that was previously committed in this cell should be revoked.
auth_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
output_csv = "/content/drive/MyDrive/output/output.csv"  # Specify the CSV file path

# Process the audio and get the results
result = process_audio(audio_file, auth_token)

# Check if the result is valid before attempting to print and save
if result:
    print("Saving the final result to CSV...")
    try:
        # The "output" folder may not exist yet; create it before writing.
        os.makedirs(os.path.dirname(output_csv), exist_ok=True)
        with open(output_csv, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Start Time", "End Time", "Speaker", "Text"])
            for seg, spk, sent in result:
                line = f'{seg.start:.2f} {seg.end:.2f} {spk} {sent}'
                print(line)
                writer.writerow([f'{seg.start:.2f}', f'{seg.end:.2f}', spk, sent])
        print(f"Results successfully saved to {output_csv}")
    except Exception as e:
        print(f"Error saving CSV file: {e}")
else:
    print("No result to print or save.")