In [None]:
%pip install pyannote.metrics
%pip install pyannote.core

In [None]:
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.core import Annotation, Timeline, Segment
import os

# Folder containing the hypothesis (system output) RTTM files to score.
FOLDERPATH_TO_TARGET_RTTM = 'outputs/pyannote_community_1/AMI/'
# Folder containing the reference RTTM files; each one is looked up under the
# same filename as its hypothesis file.
FOLDERPATH_TO_REFERENCE = '/home/digitalhub/Desktop/data/diarization_benchmarks/ami_dataset/rttm/'

# Keep only the directory entries whose name ends in ".rttm".
target_rttm_files = [
    entry
    for entry in os.listdir(FOLDERPATH_TO_TARGET_RTTM)
    if entry.endswith(".rttm")
]

def read_rttm_file_into_annotation(rttm_path):
    '''
    Read rttm file into pyannote annotation class

    Only lines whose first field is "SPEAKER" are used; blank lines and any
    other record types are ignored. Start and end times are rounded to two
    decimals before being stored.

    :param rttm_path: path to rttm file
    :return: speaker segments in pyannote annotation class
    '''

    speaker_segments = Annotation()

    with open(rttm_path, 'r') as f:
        for track_id, line in enumerate(f):
            parts = line.split()

            # A blank line (e.g. a trailing newline) has no fields at all;
            # skip it instead of failing on parts[0].
            if not parts or parts[0] != "SPEAKER":
                continue

            speaker = parts[7]
            start_time = float(parts[3])
            duration = float(parts[4])
            end_time = start_time + duration

            segment = Segment(round(start_time, 2), round(end_time, 2))

            # Give every line its own track. Indexing by segment alone puts
            # everything on one shared default track, so two speakers with
            # the same (rounded) segment would overwrite each other and
            # overlapped speech would be lost.
            speaker_segments[segment, track_id] = speaker

    return speaker_segments

def compute_der(hyp_rttm_path, ref_rttm_path):
    """
    Score a predicted diarization against its ground truth with the
    Diarization Error Rate (DER), keeping overlapped speech in the
    evaluation (skip_overlap=False).

    :param hyp_rttm_path: Predicted diarization rttm path (str)
    :param ref_rttm_path: Ground truth diarization rttm path (str)
    :return: DER score (float)
    """
    der_metric = DiarizationErrorRate(skip_overlap=False)

    # Both files go through the same RTTM reader.
    predicted = read_rttm_file_into_annotation(hyp_rttm_path)
    ground_truth = read_rttm_file_into_annotation(ref_rttm_path)

    # The metric is called with the reference first, then the hypothesis.
    return der_metric(ground_truth, predicted)

# Score every hypothesis RTTM against the reference file of the same name.
scores = []
count = 0  # stays 0 when there is nothing to score

for count, rttm_file in enumerate(target_rttm_files, start=1):

    # os.path.join keeps the paths valid even if a folder constant is later
    # edited to drop its trailing slash.
    rttm_filepath = os.path.join(FOLDERPATH_TO_TARGET_RTTM, rttm_file)
    reference_rttm_path = os.path.join(FOLDERPATH_TO_REFERENCE, rttm_file)

    DER_score = compute_der(rttm_filepath, reference_rttm_path)

    scores.append(DER_score)

    # Progress message every 10 files.
    if count % 10 == 0:
        print(f"We are at {count}!")

# This is the unweighted mean of the per-file scores: every file counts
# equally regardless of how much speech it contains.
if scores:
    print(f"Average DER score : {sum(scores)/len(scores)}")
else:
    # Avoid a ZeroDivisionError when the target folder has no .rttm files.
    print(f"No .rttm files found in {FOLDERPATH_TO_TARGET_RTTM}")



We are at 10!
We are at 20!
We are at 30!
We are at 40!
We are at 50!
We are at 60!
We are at 70!
We are at 80!
We are at 90!
We are at 100!
We are at 110!
We are at 120!
We are at 130!
We are at 140!
We are at 150!
We are at 160!
We are at 170!
Average DER score : 0.1648839377808588
