# Evaluate Overlapped Speech Detection (OSD) and Voice Activity Detection (VAD)

Evaluates OSD/VAD predictions produced by wav2vec2_audioFrameClassification.py or wav2vec2_audioFrameClassification_multitask.py.

Evaluation is done with pyannote.metrics (https://pyannote.github.io/).

Evaluation on the AMI corpus supports pyannote.db.odessa.ami (https://github.com/pyannote/pyannote-db-odessa-ami):
use task="AMI_OSD_dev" / "AMI_VAD_dev" for the ODESSA/AMI development set, or task="AMI_OSD_test" / "AMI_VAD_test" for the ODESSA/AMI test set.

Other datasets require reference annotations in text format (one .txt file per audio file):

- the format we use is a sequence of space-delimited integers, 100 values per second (one for every 10ms of audio); first value is the total number of labeled audio frames, then 1 = overlap/speech; 0 = non-overlap/non-speech (e.g. "61500 0 0 0 0 1 1 1 1 0 0 0 \[...\]")

(Note: many of the function names and variables refer to "overlaps", but the same code also works for VAD (or any other similar detection task), as long as the references are in the same format)

### How to use this notebook:
- run all cells until you reach "Evaluation starts here"
- pick one of the options for eval (custom references or pyannote.db)
- change the paths and settings there and run the cells

### Changelog:
2023-02-17
- fixed miss/FA values when evaluating using pyannote.db.odessa.ami - previously, the miss/FA values were accidentally swapped (watch out for this when loading previous saved results!)


In [None]:
from pyannote.core import Annotation, Timeline, Segment
from pyannote.metrics.detection import (
    DetectionCostFunction,
    DetectionErrorRate,
    DetectionPrecision,
    DetectionRecall,
    DetectionAccuracy,
    DetectionPrecisionRecallFMeasure
)

import os
import re
import numpy as np

from matplotlib import pyplot as plt

from tabulate import tabulate

In [None]:
# Global post-processing settings (both values are in seconds; they are only used
# when hypotheses are built with apply_min_lens=True):
#  - MIN_OL_LEN:     detected overlap/speech intervals shorter than this are ignored
#  - MIN_NON_OL_LEN: gaps between two detected intervals shorter than this are meant
#                    to be closed by merging the two intervals
MIN_OL_LEN, MIN_NON_OL_LEN = 0, 0

In [None]:
# loads OSD or VAD references from a text file
#  (the format we use is a sequence of space-delimited integers, 100 values per second (one for every 10ms of audio);
#  first value is the total number of labeled audio frames, then 1 = overlap/speech; 0 = non-overlap/non-speech
#   (e.g. "61500 0 0 0 0 1 1 1 1 0 0 0 [...]")

def get_overlap_reference(ref_file,frameRate=100,offset=0):
    """Load an OSD/VAD reference from a text file.

    Only the first line of the file is read. It holds whitespace-delimited numbers:
    the first is the total number of labeled frames, the remaining ones are the
    per-frame labels (1 = overlap/speech, 0 = non-overlap/non-speech), ``frameRate``
    values per second.

    Parameters
    ----------
    ref_file : str
        Path to the reference text file.
    frameRate : int
        Number of label values per second of audio.
    offset : float
        Seconds added to every timestamp.

    Returns
    -------
    (reference, uem)
        ``reference`` is the Annotation built from the labels. ``uem`` is a Segment
        from 0 to ``num_frames / frameRate + offset``.

    Raises
    ------
    ValueError
        If the first line of the file contains no values.
    """
    with open(ref_file,'r') as file:
        # split() (not split(" ")) also tolerates repeated spaces, tabs and a
        # completely empty line (which yields an empty list instead of [''])
        tokens = file.readline().split()

    if not tokens:
        raise ValueError("Failed to read anything from file %s"%(ref_file))

    values = [float(token) for token in tokens]

    num_values = int(values[0])
    uem = Segment(0,num_values / frameRate + offset)

    reference = get_annotation_from_labels(values[1:],frameRate,offset,apply_min_lens=False)

    return reference,uem

In [None]:
# converts filenames of audio segments into unique identifiers of the original audio files

def get_audio_id(filename):
    """Return the ID of the original recording a (segment) audio file belongs to.

    Known suffixes are repeatedly removed from the END of the basename (and only
    the end, to avoid false matches) until none matches any more. These are
    extensions, processing/room tags, ".Mix-Headset", a sample-rate tag such as
    "_16kHz", and a segment time tag "_t<start>-<end>".

    Example: "IS1009c.Mix-Headset_t700-710.wav" -> "IS1009c"
    """
    basename = os.path.basename(filename)

    suffixesToRemove = ['.wav','.mp4','.dereverb','.denoise','.OMEETING',
                        '.OHALLWAY','.OOFFICE','.meeting','.booth','.office',
                        '_OMEETING','_OHALLWAY','_OOFFICE','_meeting','_booth','_office',
                        '.Mix-Headset']

    # raw strings: "\." / "\d" in plain literals are invalid escape sequences
    # (SyntaxWarning since Python 3.12); "$" anchors each match to the end
    regexToRemove = [re.compile(r'(\.|_)\d+kHz$'),
                     re.compile(r'_t[0-9]+(\.[0-9]+)?\-[0-9]+(\.[0-9]+)?$')]

    found = True
    while found:
        found = False
        for suffix in suffixesToRemove:
            if basename.endswith(suffix):
                basename = basename[:-len(suffix)]
                found = True
        for rgx in regexToRemove:
            match = rgx.search(basename)
            if match is not None:
                basename = basename[:match.start()]
                found = True

    return basename

In [None]:
# get OSD or VAD hypotheses for each threshold, in pyannote's preferred format

def get_annotation_from_labels(labels,frameRate,offset=0,threshold=0.5,annotation=None,apply_min_lens=False,
                               min_ol_len=None,min_non_ol_len=None):
    """Convert frame-level scores/labels into a pyannote Annotation.

    A frame counts as overlap/speech when its value is >= ``threshold``. Every run
    of such frames becomes a Segment labelled 'overlap'. Frame ``i`` starts at
    ``i / frameRate + offset`` seconds.

    Parameters
    ----------
    labels : sequence of numbers (list or 1-D numpy array)
        Per-frame scores or 0/1 labels.
    frameRate : number
        Frames per second.
    offset : float
        Seconds added to every timestamp.
    threshold : float
        Decision threshold.
    annotation : Annotation, optional
        Existing annotation to add segments to. A new one is created if None.
    apply_min_lens : bool
        If True, intervals separated by a gap shorter than ``min_non_ol_len`` are
        merged, and then intervals shorter than ``min_ol_len`` are dropped.
    min_ol_len, min_non_ol_len : float, optional
        Minimum lengths in seconds. They default to the globals MIN_OL_LEN and
        MIN_NON_OL_LEN, which are only read when ``apply_min_lens`` is True.

    Returns
    -------
    Annotation
        The annotation, which is the same object as ``annotation`` when one was
        passed in.
    """
    if annotation is None:
        annotation = Annotation()

    if apply_min_lens:
        if min_ol_len is None:
            min_ol_len = MIN_OL_LEN
        if min_non_ol_len is None:
            min_non_ol_len = MIN_NON_OL_LEN

    intervals = []  # [start, end] pairs in seconds, in chronological order

    def add_interval(start, end):
        # merge with the previous interval when the gap between them is too short
        if apply_min_lens and intervals and start - intervals[-1][1] < min_non_ol_len:
            intervals[-1][1] = end
        else:
            intervals.append([start, end])

    num_values = len(labels)
    startTime = None  # start of the currently open interval; None when outside one
    for i, value in enumerate(labels):
        if startTime is None:
            if value >= threshold:
                startTime = i / frameRate + offset
        elif value < threshold:
            add_interval(startTime, i / frameRate + offset)
            startTime = None

    if startTime is not None:  # the labels end inside an interval: close it at the end
        add_interval(startTime, num_values / frameRate + offset)

    for startTime, endTime in intervals:
        if (not apply_min_lens) or endTime - startTime >= min_ol_len:
            annotation[Segment(startTime, endTime)] = 'overlap'

    return annotation

def get_overlap_hypotheses_w2v2(test_file,frameRate=50,audio_id_list=None,offset=0.02,
                                threshold=0.5,duration=None,shift=None):
    """Read predicted labels from w2v2 output, stitch them back together per audio
    file and turn them into hypotheses in the format required by pyannote.

    Parameters
    ----------
    test_file : str
        w2v2 output file. Each non-empty line is "<audio path>,[v1 v2 ...]".
    frameRate : number
        Number of predicted values per second.
    audio_id_list : collection, optional
        If given, only audio IDs in it are kept.
    offset : float
        Seconds added to every timestamp.
    threshold : float
        Decision threshold passed to get_annotation_from_labels.
    duration, shift : float, optional
        Segment length and hop in seconds, used when the segments overlap.

    Returns
    -------
    dict
        Maps audio_id to a pyannote Annotation.
    """
    predictions_all = _read_w2v2_predictions(test_file, audio_id_list)
    labels_all = _stitch_w2v2_predictions(predictions_all, frameRate, duration, shift)

    hypotheses = {}
    for audio_id, labels in labels_all.items():
        hypotheses[audio_id] = get_annotation_from_labels(labels,
                frameRate=frameRate,offset=offset,threshold=threshold,apply_min_lens=True)

    return hypotheses


# "_t<start>-<end>" segment time tag in w2v2 segment filenames (times in seconds)
_W2V2_SEGMENT_TIME_RE = re.compile(r'_t[0-9]+(\.[0-9]+)?\-[0-9]+(\.[0-9]+)?')


def _read_w2v2_predictions(test_file, audio_id_list=None):
    # Parse the w2v2 output file into dicts with path, values, id, startTime, endTime.
    predictions = []

    with open(test_file,'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            path,values_str = line.split(",")
            audio_filename = os.path.basename(path)

            audio_id = get_audio_id(audio_filename)
            if (audio_id_list is not None) and audio_id not in audio_id_list:
                continue  # skip files that are not on the list

            match = _W2V2_SEGMENT_TIME_RE.search(audio_filename)
            if match is not None:
                start_str, end_str = match.group()[2:].split('-')
                startTime, endTime = float(start_str), float(end_str)
            else:
                startTime = endTime = None

            idx1 = values_str.find('[')
            idx2 = values_str.find(']')
            # split() tolerates repeated whitespace inside the brackets
            values = [float(v) for v in values_str[idx1+1:idx2].split()]

            predictions.append({
                "path": path,
                "values": values,
                "id": audio_id,
                "startTime": startTime,
                "endTime": endTime
            })

    return predictions


def _stitch_w2v2_predictions(predictions, frameRate, duration=None, shift=None):
    # Merge segment-level label lists into one label list per audio_id.

    # Stitching assumes chronological order, so sort by start time. The sort is
    # stable, so segments without a time tag keep their order from the output file.
    ordered = sorted(predictions,
                     key=lambda p: p["startTime"] if p["startTime"] is not None else 0.0)

    # With overlapping segments (segment overlap = duration - shift), every segment
    # except the first drops the first half of the overlap. Each frame then comes
    # from the segment in which it has the most context on both sides.
    skip_frames = 0
    if duration is not None and shift is not None:
        skip_frames = max(0, round((duration - shift) * frameRate / 2))

    labels_all = {}
    for prediction in ordered:
        audio_id = prediction["id"]
        values = list(prediction["values"])
        startTime = prediction["startTime"]
        endTime = prediction["endTime"]

        if startTime is None or endTime is None:
            # no timestamps: segments are concatenated in order. "+ [0]" adds the
            # one missing label (w2v outputs only 999 labels for 20s audio); it is
            # needed for every segment, the first one included, or later ones shift
            labels_all[audio_id] = labels_all.get(audio_id, []) + values + [0]
            continue

        start_idx = round(startTime * frameRate)

        if audio_id not in labels_all:
            # earliest segment of this file: use all of it, placed at its own start
            labels_all[audio_id] = [0] * start_idx + values
            continue

        if duration is not None and shift is not None:
            start_idx += skip_frames
            if start_idx >= round(endTime * frameRate):
                continue  # this can happen in the very last segment - skip it completely
            values = values[skip_frames:]

        labels = labels_all[audio_id]
        if len(labels) < start_idx:
            # gap before this segment (e.g. a missing segment): pad with zeros so the
            # segment lands at its own time instead of being appended too early
            labels.extend([0] * (start_idx - len(labels)))
        # w2v labels are one value shorter than time*frameRate, so the written span
        # is determined by the number of labels, not by the segment's end time
        labels[start_idx:start_idx + len(values)] = values

    return labels_all
            
    
def get_overlap_hypotheses_CNN(test_dir,frameRate,offset,threshold=0.5,audio_id_list=None):
    """Read predicted labels from our old CNN-based overlap detection
    (Kunesova et al., 2019) and turn them into pyannote hypotheses.

    Every ``*.bin`` file in ``test_dir`` is read as raw float32 values. The start
    time of a segment is taken from an optional "_t<start>-<end>" tag in its
    filename (0 if absent). Segments of the same audio_id are accumulated into
    one Annotation. Min-length post-processing is applied per segment file, so
    intervals are not merged across segment boundaries.

    Parameters
    ----------
    test_dir : str
        Directory with the ``.bin`` prediction files.
    frameRate : number
        Number of predicted values per second.
    offset : float
        Seconds added to every timestamp, on top of the segment start time.
    threshold : float
        Decision threshold passed to get_annotation_from_labels.
    audio_id_list : collection, optional
        If given, only audio IDs in it are kept.

    Returns
    -------
    dict
        Maps audio_id to a pyannote Annotation.
    """
    segment_time_re = re.compile(r'_t[0-9]+(\.[0-9]+)?\-[0-9]+(\.[0-9]+)?')

    hypotheses = {}

    # sorted() makes the processing order (and the dict's key order) deterministic
    for filename in sorted(os.listdir(test_dir)):
        if not filename.endswith('.bin'):
            continue

        audio_id = get_audio_id(filename[:-len('.bin')])
        if (audio_id_list is not None) and audio_id not in audio_id_list:
            continue  # filter before reading the file

        match = segment_time_re.search(filename)
        if match is not None:
            startTime = float(match.group()[2:].split('-')[0])
        else:
            startTime = 0

        values = np.fromfile(os.path.join(test_dir, filename), dtype=np.float32)

        # annotation=None (first segment of this file) makes a new Annotation
        hypotheses[audio_id] = get_annotation_from_labels(values,
            frameRate,offset + startTime,threshold,
            annotation=hypotheses.get(audio_id),apply_min_lens=True)

    return hypotheses

In [None]:
# calculate precision, recall, etc for all files and thresholds from a specific set of hypotheses
#  (version with custom references in text format)
def evaluate_overlaps(hypotheses,thresholds,main_thresh,ref_suffix="_overlaps.txt",ref_dir=None):
    """Compute detection metrics for all files and thresholds against text references.

    Parameters
    ----------
    hypotheses : list of dict
        One dict per threshold, in the same order as ``thresholds``. Each maps
        audio_id to a pyannote Annotation.
    thresholds : list of float
    main_thresh : float
        Threshold whose per-file results are collected in ``results_pf``.
    ref_suffix : str
        The reference file of an audio file is ``audio_id + ref_suffix``.
    ref_dir : str, optional
        Directory with the reference files. Defaults to the global REF_DIR.

    Returns
    -------
    results : list of rows
        A header row, then one row per threshold with totals over all files.
    stats : dict
        Per-threshold lists of the totals (for plotting). "miss" and "FA" are
        relative to the total reference duration.
    results_pf : list of rows
        A header row, then per-file rows for ``main_thresh``, then "AVG" (mean of
        the per-file values) and "TOTAL" rows.

    Raises
    ------
    ValueError
        If no reference file is found for any audio_id.
    """
    if ref_dir is None:
        ref_dir = REF_DIR

    nThresh = len(thresholds)

    detectionErrorRate = DetectionErrorRate()
    detectionAccuracy = DetectionAccuracy()
    detectionPrecision = DetectionPrecision()
    detectionRecall = DetectionRecall()
    detectionFMeasure = DetectionPrecisionRecallFMeasure()

    results = [["threshold","error_rate","accuracy","precision","recall","fscore","TP-FP delta"]]
    results_pf = [["audio_id","error_rate","accuracy","precision","recall","fscore","TP-FP delta"]]

    # values from pyannote's detailed output, summed over all files (one entry per threshold)
    miss_total = [0] * nThresh
    FA_total = [0] * nThresh
    overlap_total = [0] * nThresh
    TP_total = [0] * nThresh
    TN_total = [0] * nThresh
    FP_total = [0] * nThresh
    FN_total = [0] * nThresh
    TP_FP_delta_total = [0] * nThresh
    retrieved_total = [0] * nThresh
    relevant_total = [0] * nThresh
    RR_total = [0] * nThresh

    # per-file metric values summed over files; divided by nfiles below
    DetErr_avg = [0] * nThresh
    accuracy_avg = [0] * nThresh
    precision_avg = [0] * nThresh
    recall_avg = [0] * nThresh
    fscore_avg = [0] * nThresh
    TP_FP_delta_avg = [0] * nThresh

    nfiles = 0

    for audio_id in hypotheses[0].keys():
        print(audio_id)

        ref_file = os.path.join(ref_dir, audio_id + ref_suffix)
        if not os.path.exists(ref_file):
            print("File '%s' not found. Skipping." % ref_file)
            continue

        nfiles += 1

        reference,uem = get_overlap_reference(ref_file,frameRate=100,offset=0)

        for iThresh in range(nThresh):
            hypothesis = hypotheses[iThresh][audio_id]

            # Detection Error Rate
            error_rate_details = detectionErrorRate(reference, hypothesis, detailed=True, uem=uem)
            error_rate = error_rate_details["detection error rate"]
            miss_total[iThresh] += error_rate_details["miss"]
            FA_total[iThresh] += error_rate_details["false alarm"]
            overlap_total[iThresh] += error_rate_details["total"]
            DetErr_avg[iThresh] += error_rate

            # Accuracy
            accuracy_details = detectionAccuracy(reference, hypothesis, detailed=True, uem=uem)
            accuracy = accuracy_details["detection accuracy"]
            TP = accuracy_details["true positive"]
            FP = accuracy_details["false positive"]
            TN = accuracy_details["true negative"]
            FN = accuracy_details["false negative"]
            TP_total[iThresh] += TP
            FP_total[iThresh] += FP
            TN_total[iThresh] += TN
            FN_total[iThresh] += FN
            accuracy_avg[iThresh] += accuracy

            # TP-FP delta: the TOTAL uses raw sums; per-file values are normalized
            # by the file's evaluated duration, so the AVG row is their mean
            TP_FP_delta_total[iThresh] += TP - FP
            file_duration = TP + FP + TN + FN
            TP_FP_delta = (TP - FP) / file_duration if file_duration > 0 else 0
            TP_FP_delta_avg[iThresh] += TP_FP_delta

            # Precision
            precision_details = detectionPrecision(reference, hypothesis, detailed=True, uem=uem)
            precision = precision_details["detection precision"]
            retrieved_total[iThresh] += precision_details["retrieved"]
            RR_total[iThresh] += precision_details["relevant retrieved"]
            precision_avg[iThresh] += precision

            # Recall
            recall_details = detectionRecall(reference, hypothesis, detailed=True, uem=uem)
            recall = recall_details["detection recall"]
            relevant_total[iThresh] += recall_details["relevant"]
            recall_avg[iThresh] += recall

            # F-score
            fscore_details = detectionFMeasure(reference, hypothesis, detailed=True, uem=uem)
            fscore = fscore_details["F[precision|recall]"]
            fscore_avg[iThresh] += fscore

            if thresholds[iThresh] == main_thresh:
                results_pf.append([audio_id,error_rate,accuracy,precision,recall,fscore,TP_FP_delta])

    if nfiles == 0:
        raise ValueError("No reference files found in '%s' (expected <audio_id>%s)."
                         % (ref_dir, ref_suffix))

    def _error_ratio(numerator, denominator):
        # 0 when there is nothing to count; a large sentinel when errors occur
        # although the reference is empty (same convention for DER, miss and FA)
        if denominator == 0:
            return 0 if numerator == 0 else 9999999
        return numerator / denominator

    DetErr_total = [0] * nThresh
    accuracy_total = [0] * nThresh
    precision_total = [0] * nThresh
    recall_total = [0] * nThresh
    fscore_total = [0] * nThresh
    miss_pctg_total = [0] * nThresh
    FA_pctg_total = [0] * nThresh

    for iThresh in range(nThresh):

        DetErr_avg[iThresh] /= nfiles
        accuracy_avg[iThresh] /= nfiles
        precision_avg[iThresh] /= nfiles
        recall_avg[iThresh] /= nfiles
        fscore_avg[iThresh] /= nfiles
        TP_FP_delta_avg[iThresh] /= nfiles

        frames_total = TP_total[iThresh] + TN_total[iThresh] + FP_total[iThresh] + FN_total[iThresh]
        if frames_total > 0:
            TP_FP_delta_total[iThresh] /= frames_total
            accuracy_total[iThresh] = (TP_total[iThresh] + TN_total[iThresh]) / frames_total
        else:
            TP_FP_delta_total[iThresh] = 0

        DetErr_total[iThresh] = _error_ratio(miss_total[iThresh] + FA_total[iThresh], overlap_total[iThresh])
        miss_pctg_total[iThresh] = _error_ratio(FN_total[iThresh], overlap_total[iThresh])
        FA_pctg_total[iThresh] = _error_ratio(FP_total[iThresh], overlap_total[iThresh])

        if retrieved_total[iThresh] == 0:
            precision_total[iThresh] = 1
        else:
            precision_total[iThresh] = RR_total[iThresh] / retrieved_total[iThresh]

        if relevant_total[iThresh] == 0:
            recall_total[iThresh] = 1
        else:
            recall_total[iThresh] = RR_total[iThresh] / relevant_total[iThresh]

        if (precision_total[iThresh] + recall_total[iThresh]) == 0:
            fscore_total[iThresh] = 0
        else:
            fscore_total[iThresh] = 2 * (precision_total[iThresh] * recall_total[iThresh]) / (precision_total[iThresh] + recall_total[iThresh])

        if thresholds[iThresh] == main_thresh:
            results_pf.append(["AVG",DetErr_avg[iThresh],accuracy_avg[iThresh],precision_avg[iThresh],recall_avg[iThresh],fscore_avg[iThresh],TP_FP_delta_avg[iThresh]])
            results_pf.append(["TOTAL",DetErr_total[iThresh],accuracy_total[iThresh],precision_total[iThresh],recall_total[iThresh],fscore_total[iThresh],TP_FP_delta_total[iThresh]])

        results.append([thresholds[iThresh],DetErr_total[iThresh],accuracy_total[iThresh],precision_total[iThresh],recall_total[iThresh],fscore_total[iThresh],TP_FP_delta_total[iThresh]])

    stats = {}
    stats["thresholds"] = thresholds
    stats["error_rate"] = DetErr_total
    stats["accuracy"] = accuracy_total
    stats["precision"] = precision_total
    stats["recall"] = recall_total
    stats["fscore"] = fscore_total
    stats["TP-FP delta"] = TP_FP_delta_total
    stats["miss"] = miss_pctg_total
    stats["FA"] = FA_pctg_total

    return results, stats, results_pf
    

In [None]:
# calculate VAD or OSD precision, recall, etc for all files and thresholds from a specific set of hypotheses
#  (version using pyannote.db.odessa.ami, AMI corpus only)
# FIXED 2023-02-17: previously, the miss/FA values were accidentally swapped
from pyannote.database import get_protocol

def evaluate_AMI(hypotheses,thresholds,main_thresh,task):
    """Evaluate VAD/OSD hypotheses against the AMI references from pyannote.database.

    Parameters
    ----------
    hypotheses : list of dict
        One dict per threshold, in the same order as ``thresholds``. Each maps
        audio_id to a pyannote Annotation.
    thresholds : list of float
    main_thresh : float
        pyannote's report tables are only displayed for this threshold.
    task : str
        "AMI_OSD_dev"/"AMI_OSD_test" evaluate against overlap references derived
        with to_overlap(); any other task evaluates against the speech annotation
        (VAD). A task ending in "_dev" uses the development set; anything else uses
        the test set.

    Returns
    -------
    (results, stats, None)
        There is no per-file table. Values are taken from the last (TOTAL) row of
        pyannote's report(), presumably in percent, unlike evaluate_overlaps.
    """
    print('Evaluation with pyannote refs')

    protocol = get_protocol('AMI.SpeakerDiarization.MixHeadset')

    detectionErrorRate = DetectionErrorRate()
    detectionAccuracy = DetectionAccuracy()
    detectionPrecision = DetectionPrecision()
    detectionRecall = DetectionRecall()
    detectionFMeasure = DetectionPrecisionRecallFMeasure()
    metrics = [detectionErrorRate, detectionAccuracy, detectionPrecision,
               detectionRecall, detectionFMeasure]

    # references for OSD have to be converted, otherwise the function evaluates VAD
    task_is_overlap = (task in ["AMI_OSD_dev","AMI_OSD_test"])

    AMI_set = protocol.development() if task.endswith("_dev") else protocol.test()

    # The references do not depend on the threshold, so load (and convert) them
    # once instead of re-iterating the protocol for every threshold.
    references = []
    for test_file in AMI_set:
        audio_id = get_audio_id(test_file["uri"])
        reference = to_overlap(test_file) if task_is_overlap else test_file['annotation']
        references.append((audio_id, reference, test_file['annotated']))

    results = [["threshold","error_rate","accuracy","precision","recall","fscore","miss", "FA"]]

    DetErr_total = []
    accuracy_total = []
    precision_total = []
    recall_total = []
    fscore_total = []
    miss_total = []
    FA_total = []

    for iThresh, threshold in enumerate(thresholds):
        display = (threshold == main_thresh)
        if display:
            print("Results for threshold %g:"%main_thresh)

        for audio_id, reference, uem in references:
            hypothesis = hypotheses[iThresh][audio_id]
            for metric in metrics:
                metric(reference, hypothesis, uem=uem)

        err = detectionErrorRate.report(display=display).values.tolist()
        acc = detectionAccuracy.report(display=display).values.tolist()
        prec = detectionPrecision.report(display=display).values.tolist()
        rec = detectionRecall.report(display=display).values.tolist()
        F = detectionFMeasure.report(display=display).values.tolist()

        # NOTE(review): assumes the error-rate report columns are
        # [error rate, total, false alarm, false alarm %, miss, miss %]
        # (indices per the 2023-02-17 fix) - confirm against the installed pyannote.metrics
        DetErr_total.append(err[-1][0])
        FA_total.append(err[-1][3])
        miss_total.append(err[-1][5])
        accuracy_total.append(acc[-1][0])
        precision_total.append(prec[-1][0])
        recall_total.append(rec[-1][0])
        fscore_total.append(F[-1][0])

        results.append(
            [thresholds[iThresh],DetErr_total[iThresh],accuracy_total[iThresh],
             precision_total[iThresh],recall_total[iThresh],fscore_total[iThresh],
             miss_total[iThresh],FA_total[iThresh]]
        )

        # metrics accumulate over calls, so clear them before the next threshold
        for metric in metrics:
            metric.reset()

    stats = {}
    stats["thresholds"] = thresholds
    stats["error_rate"] = DetErr_total
    stats["accuracy"] = accuracy_total
    stats["precision"] = precision_total
    stats["recall"] = recall_total
    stats["fscore"] = fscore_total
    # same keys as evaluate_overlaps() returns
    stats["miss"] = miss_total
    stats["FA"] = FA_total

    return results,stats,None


# Note: to_overlap() is taken from pyannote/metrics/cli.py by H. Bredin
#   (https://github.com/pyannote/pyannote-metrics/blob/develop/pyannote/metrics/cli.py)
def to_overlap(current_file: dict) -> Annotation:
    """Build the overlapped-speech reference for a pyannote.database file.

    Every pair of co-occurring reference tracks whose labels differ contributes
    its time intersection. The union (support) of all such intersections is
    returned as an Annotation.

    Parameters
    ----------
    current_file : `dict`
        File yielded by pyannote.database protocols; its "annotation" entry is used.

    Returns
    -------
    overlap : `pyannote.core.Annotation`
        Overlapped speech reference.
    """
    annotation = current_file["annotation"]
    overlapped = Timeline(uri=annotation.uri)
    for (seg_a, track_a), (seg_b, track_b) in annotation.co_iter(annotation):
        # two tracks of the same label (e.g. a track with itself) are not an overlap
        if annotation[seg_a, track_a] != annotation[seg_b, track_b]:
            overlapped.add(seg_a & seg_b)
    return overlapped.support().to_annotation()
                

In [None]:
# evaluate multiple sets of predictions
#   (stats are saved so that they don't need to be recalculated every time)
# TODO: check if thresholds match the saved ones, recalculate (only) the missing stats if not

import pickle

def evaluate_all(prediction_files, thresholds=None, ref_suffix="_overlaps.txt",
                main_thresh=0.5,audio_id_list=None,duration=20,shift=10,
                stats_file_suffix=".stats-v2",task=None):
    """Evaluate several sets of predictions, caching the results on disk.

    Each entry of ``prediction_files`` is a list:
      - ["w2v2", name, path_to_w2v2_output_file]
      - ["CNN", name, path_to_dir_with_bin_files, frameRate, offset]

    Results are pickled to ``path + stats_file_suffix`` and loaded from there on
    later calls. TODO: cached results are reused even if ``thresholds`` or
    ``task`` differ from the ones they were computed with.

    Returns
    -------
    list of dict
        One {"stats": ..., "name": ...} per evaluated entry; unrecognized formats
        are skipped.
    """
    if thresholds is None:
        thresholds = [0,0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.1,
              0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,1]

    stats_all = []

    for prediction_file in prediction_files:
        stats_file = prediction_file[2] + stats_file_suffix
        print(prediction_file)
        if os.path.exists(stats_file):
            print('Loading previously saved results.')
            results,stats,results_pf = load_saved_results(stats_file)
        else:
            print('No saved results found. Starting evaluation - this may take some time...')
            if prediction_file[0] == "w2v2":
                results,stats,results_pf = evaluate_overlaps_w2v2(
                        prediction_file[2],thresholds=thresholds,ref_suffix=ref_suffix,
                        main_thresh=main_thresh,audio_id_list=audio_id_list,
                        duration=duration,shift=shift,task=task
                )
            elif prediction_file[0] == "CNN":
                # main_thresh is forwarded here too, so per-file results are collected
                # for the requested threshold (previously the CNN default was used)
                results,stats,results_pf = evaluate_overlaps_CNN(
                    prediction_file[2],frameRate=prediction_file[3],offset=prediction_file[4],
                    thresholds=thresholds,main_thresh=main_thresh,audio_id_list=audio_id_list,
                    task=task,ref_suffix=ref_suffix)
            else:
                print("Unrecognized results format. Skipping.")
                continue
            save_results(stats_file,results,stats,results_pf)
        stats_all.append({"stats":stats, "name": prediction_file[1]})
        if results is not None:
            print(tabulate(results, headers="firstrow",tablefmt="simple"))
            print("")
        if results_pf is not None:
            print(tabulate(results_pf, headers="firstrow",tablefmt="simple"))
            print("\n")

    return stats_all

def load_saved_results(stats_file):
    """Read back the (results, stats, results_pf) triple written by save_results.

    The three objects are unpickled in the order they were stored. Only use this
    on locally produced files: unpickling untrusted data can execute code.
    """
    with open(stats_file, 'rb') as handle:
        loaded = tuple(pickle.load(handle) for _ in range(3))
    return loaded
    
    
def save_results(stats_file,results,stats,results_pf):
    """Pickle results, stats and results_pf (in this order) into stats_file.

    Any existing file is overwritten. load_saved_results reads the objects back
    in the same order.
    """
    with open(stats_file, 'wb') as handle:
        for obj in (results, stats, results_pf):
            pickle.dump(obj, handle)


def evaluate_overlaps_w2v2(overlaps_file,thresholds=None,main_thresh = 0.5,
        frameRate=50,offset=0.02,audio_id_list=None,duration=None,
        shift=None,ref_suffix="_overlaps.txt",task=None):
    """Build w2v2 hypotheses for every threshold and evaluate them.

    Tasks starting with "AMI_" are evaluated with evaluate_AMI (pyannote.database
    references). Any other task, including None, is evaluated with
    evaluate_overlaps (text references).

    Returns the (results, stats, results_pf) triple of the evaluation function used.
    """
    if thresholds is None:
        thresholds = [0.5]

    # frameRate and offset are forwarded (they were previously ignored)
    hypotheses = [
        get_overlap_hypotheses_w2v2(
            overlaps_file,frameRate=frameRate,audio_id_list=audio_id_list,offset=offset,
            threshold=threshold,duration=duration,shift=shift
        )
        for threshold in thresholds
    ]

    if task is not None and task.startswith("AMI_"):
        return evaluate_AMI(hypotheses,thresholds,main_thresh,task)
    return evaluate_overlaps(hypotheses,thresholds,main_thresh,ref_suffix=ref_suffix)

def evaluate_overlaps_CNN(overlaps_dir,frameRate,offset,thresholds=None,main_thresh = 0.5,
                          ref_suffix="_overlaps.txt",audio_id_list=None,task=None):
    """Build CNN hypotheses for every threshold and evaluate them.

    For CNN predictions, frameRate and offset are required. Tasks starting with
    "AMI_" are evaluated with evaluate_AMI. Any other task, including None, is
    evaluated with evaluate_overlaps.

    Returns the (results, stats, results_pf) triple of the evaluation function used.
    """
    if thresholds is None:
        thresholds = [0.5]

    hypotheses = [
        get_overlap_hypotheses_CNN(
            overlaps_dir,frameRate,offset,threshold=threshold,audio_id_list=audio_id_list
        )
        for threshold in thresholds
    ]

    if task is not None and task.startswith("AMI_"):
        return evaluate_AMI(hypotheses,thresholds,main_thresh,task)
    return evaluate_overlaps(hypotheses,thresholds,main_thresh,ref_suffix=ref_suffix)

In [None]:
# plot precision vs recall

def make_plots(stats_all,xlims=(0,1),ylims=(0,1)):
    """Plot precision/recall curves for the entries returned by evaluate_all.

    The function shows three kinds of figures:
      1. Precision vs. recall for all result sets (axes limited to xlims/ylims).
      2. Precision and recall vs. threshold for all result sets.
      3. One figure per result set with its precision and recall vs. threshold.
    """
    plt.figure(figsize=(6, 6))
    for stat_dict in stats_all:
        stats = stat_dict["stats"]
        plt.plot(stats["recall"],stats["precision"], label=stat_dict["name"])
    plt.xlim(xlims)
    plt.ylim(ylims)
    plt.xlabel("recall")
    plt.ylabel("precision")
    plt.legend(bbox_to_anchor=(1,1), loc="upper left")
    plt.show()

    plt.figure(figsize=(6, 6))
    for stat_dict in stats_all:
        stats = stat_dict["stats"]
        name = stat_dict["name"]
        thresholds = stats["thresholds"]
        plt.plot(thresholds,stats["precision"], label=name + " - precision")
        plt.plot(thresholds,stats["recall"], label=name + " - recall")
    plt.xlabel("threshold")
    plt.legend(bbox_to_anchor=(1,1), loc="upper left")
    plt.show()

    for stat_dict in stats_all:
        stats = stat_dict["stats"]
        name = stat_dict["name"]
        # each result set uses its own thresholds (previously the last set's were reused)
        thresholds = stats["thresholds"]
        plt.figure(figsize=(6, 6))
        plt.plot(thresholds,stats["precision"], label=name + " - precision")
        plt.plot(thresholds,stats["recall"], label=name + " - recall")
        plt.xlabel("threshold")
        plt.legend(bbox_to_anchor=(1,1), loc="upper left")
        plt.show()
    

In [None]:
# find the optimal threshold for each set of results

# threshold where purity and coverage (or some other 2 stats) are closest to each other
def get_closest_stats(stats_all, stat1 = "precision", stat2 = "recall", stat_main = "fscore"):
    """For each result set, pick the threshold where stat1 and stat2 are closest.

    Prints stat_main at that threshold and plots the picked values. Returns the
    list of picked threshold indices (None where no valid threshold exists).
    """
    return _pick_best_thresholds(
        stats_all, lambda stats, i: abs(stats[stat1][i] - stats[stat2][i]), stat_main)

# threshold where fscore (or some other stat) is highest
def get_highest_stats(stats_all, stat_main = "fscore"):
    """For each result set, pick the threshold where stat_main is highest.

    Prints the value, plots the picked values and returns the picked indices.
    """
    return _pick_best_thresholds(stats_all, lambda stats, i: -stats[stat_main][i], stat_main)

# threshold where error_rate (or some other stat) is lowest
def get_lowest_stats(stats_all, stat_main = "error_rate"):
    """For each result set, pick the threshold where stat_main is lowest.

    Prints the value, plots the picked values and returns the picked indices.
    """
    return _pick_best_thresholds(stats_all, lambda stats, i: stats[stat_main][i], stat_main)

def _pick_best_thresholds(stats_all, score, stat_main):
    # Shared selection logic: for each result set, pick the threshold index with
    # the lowest score(stats, i). Ties keep the first index and NaN scores are
    # ignored. stat_main is printed at the chosen threshold, and the chosen values
    # are plotted (NaN where none was found). Returns the chosen indices.
    best_indices = []
    best_crit = []

    for stat_dict in stats_all:
        stats = stat_dict["stats"]
        name = stat_dict["name"]
        thresholds = stats["thresholds"]

        best_idx = None
        best_score = None
        for iThresh in range(len(thresholds)):
            current = score(stats, iThresh)
            if np.isnan(current):
                continue
            # no fixed sentinel start value, so values at or beyond +-9999999
            # (e.g. the error-rate sentinel) can still be selected
            if best_score is None or current < best_score:
                best_score = current
                best_idx = iThresh

        best_indices.append(best_idx)
        if best_idx is None:
            best_crit.append(np.nan)
            print("%s:\n    no valid threshold"%name)
        else:
            best_crit.append(stats[stat_main][best_idx])
            print("%s:\n    %g (thresh %g)"%(name,stats[stat_main][best_idx],thresholds[best_idx]))

    plt.figure(figsize=(6, 6))
    plt.plot(best_crit)
    return best_indices
    

# Evaluation starts here
Pick one of the options below, change the paths to your own, and then run the cell.

## a) Evaluate OSD using custom references

In [None]:
# set paths, thresholds, etc. (OSD evaluated against custom reference files)
DATASET = 'AMI'

# the expected reference filename format is (audio_id + ref_suffix), e.g. "EN2002a_overlaps.txt"
REF_DIR = "/path/to/directory/with/overlap/references/"
ref_suffix = "_overlaps.txt"

# detection thresholds to sweep
thresholds = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
              0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
              0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1]

# model outputs: epochs 1-6 sit in per-epoch subfolders, epoch 7 in the run folder itself
mainDir_AMI_multitask = "/storage/plzen4-ntis/projects/speaker_recog/AMI/Wav2vec2Transformer_multitask_OSD-VAD-SCD/"
run_dir = mainDir_AMI_multitask + "AMITrain_fuzzy0.2_22-09-22_20s10s_7epochs_noSigmoid_yesWeights_seed1234/"

# the multitask model outputs predictions for 1. OSD, 2. VAD, 3. SCD; OSD is needed here
task_id = 1

# one entry per evaluated output: ["w2v2", label, path to the model output]
overlap_files_dev = []
overlap_files_eval = []
for epoch in range(1, 7):  # noSigmoid + base + auto weights
    epoch_dir = run_dir + "epoch%d/" % epoch
    overlap_files_eval.append(["w2v2", "AMI multi noSigmoid autoWeights - eval epoch %d" % epoch,
                               epoch_dir + "output_AMI.task%d.txt" % task_id])
    overlap_files_dev.append(["w2v2", "AMI multi noSigmoid autoWeights - dev epoch %d" % epoch,
                              epoch_dir + "output_AMI-dev.task%d.txt" % task_id])
overlap_files_eval.append(["w2v2", "AMI multi noSigmoid autoWeights - eval epoch 7",
                           run_dir + "output_AMI.task%d.txt" % task_id])
overlap_files_dev.append(["w2v2", "AMI multi noSigmoid autoWeights - dev epoch 7",
                          run_dir + "output_AMI-dev.task%d.txt" % task_id])

# duration and step size of audio segments (to make sure they are stitched back together correctly)
duration = 20
shift = 10

# results for this threshold will be printed in greater detail
main_thresh = 0.5

overlap_stats_all_dev = evaluate_all(
    overlap_files_dev,
    thresholds=thresholds,
    main_thresh=main_thresh,
    ref_suffix=ref_suffix,
    duration=duration,
    shift=shift,
)
overlap_stats_all_eval = evaluate_all(
    overlap_files_eval,
    thresholds=thresholds,
    main_thresh=main_thresh,
    ref_suffix=ref_suffix,
    duration=duration,
    shift=shift,
)
    

In [None]:
# pick the best threshold for every set of OSD results
#   a) highest F1 (active below)         -> get_highest_stats(...)
#   b) most similar precision and recall -> get_closest_stats(...)
for split_label, split_stats in (("dev set:", overlap_stats_all_dev),
                                 ("test set:", overlap_stats_all_eval)):
    print(split_label)
    get_highest_stats(split_stats)

# plots
# make_plots(overlap_stats_all_dev)
# make_plots(overlap_stats_all_eval)


## b) Evaluate VAD using custom references

In [None]:
# set paths, thresholds, etc. (VAD evaluated against custom reference files)
DATASET = 'AMI'

# the expected reference filename format is (audio_id + ref_suffix), e.g. "EN2002a.vad"
REF_DIR = "/path/to/directory/with/VAD/references/"
ref_suffix = ".vad"

# detection thresholds to sweep
thresholds = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
              0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
              0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1]

# model outputs: epochs 1-6 sit in per-epoch subfolders, epoch 7 in the run folder itself
mainDir_AMI_multitask = "/storage/plzen4-ntis/projects/speaker_recog/AMI/Wav2vec2Transformer_multitask_OSD-VAD-SCD/"
run_dir = mainDir_AMI_multitask + "AMITrain_fuzzy0.2_22-09-22_20s10s_7epochs_noSigmoid_yesWeights_seed1234/"

# one entry per evaluated output: ["w2v2", label, path to the model output]
VAD_files_dev = []
VAD_files_eval = []
# the multitask model outputs predictions for 1. OSD, 2. VAD, 3. SCD; VAD is needed here
task_id = 2
for epoch in range(1, 7):  # noSigmoid + base + auto weights
    epoch_dir = run_dir + "epoch%d/" % epoch
    VAD_files_eval.append(["w2v2", "AMI multi noSigmoid autoWeights - eval epoch %d" % epoch,
                           epoch_dir + "output_AMI.task%d.txt" % task_id])
    VAD_files_dev.append(["w2v2", "AMI multi noSigmoid autoWeights - dev epoch %d" % epoch,
                          epoch_dir + "output_AMI-dev.task%d.txt" % task_id])
VAD_files_eval.append(["w2v2", "AMI multi noSigmoid autoWeights - eval epoch 7",
                       run_dir + "output_AMI.task%d.txt" % task_id])
VAD_files_dev.append(["w2v2", "AMI multi noSigmoid autoWeights - dev epoch 7",
                      run_dir + "output_AMI-dev.task%d.txt" % task_id])

# duration and step size of audio segments (to make sure they are stitched back together correctly)
duration = 20
shift = 10

# results for this threshold will be printed in greater detail
main_thresh = 0.5

VAD_stats_all_dev = evaluate_all(
    VAD_files_dev,
    thresholds=thresholds,
    main_thresh=main_thresh,
    ref_suffix=ref_suffix,
    duration=duration,
    shift=shift,
)
VAD_stats_all_eval = evaluate_all(
    VAD_files_eval,
    thresholds=thresholds,
    main_thresh=main_thresh,
    ref_suffix=ref_suffix,
    duration=duration,
    shift=shift,
)


In [None]:
# pick the best threshold for every set of VAD results
#   a) highest F1                        -> get_highest_stats(...)
#   b) most similar precision and recall -> get_closest_stats(...)
#   c) lowest detection error rate (active below) -> get_lowest_stats(...)
for split_label, split_stats in (("dev set:", VAD_stats_all_dev),
                                 ("test set:", VAD_stats_all_eval)):
    print(split_label)
    get_lowest_stats(split_stats)

# plots
# make_plots(VAD_stats_all_dev)
# make_plots(VAD_stats_all_eval)

## c) Evaluate AMI OSD and VAD using pyannote.db.odessa.ami
(AMI corpus only)

Enabled by passing `task="AMI_OSD_test"`, `"AMI_OSD_dev"`, `"AMI_VAD_test"`, or `"AMI_VAD_dev"` to `evaluate_all()`; `REF_DIR` and `ref_suffix` are not used in this mode.

In [None]:
# set paths, thresholds, etc. (references come from pyannote.db.odessa.ami)
DATASET = 'AMI'

REF_DIR = None      # not used here
ref_suffix = None   # not used here

# detection thresholds to sweep
thresholds = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
              0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
              0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1]

# model outputs: epochs 1-6 sit in per-epoch subfolders, epoch 7 in the run folder itself
mainDir_AMI_multitask = "/storage/plzen4-ntis/projects/speaker_recog/AMI/Wav2vec2Transformer_multitask_OSD-VAD-SCD/"
run_dir = mainDir_AMI_multitask + "AMITrain_fuzzy0.2_22-09-22_20s10s_7epochs_noSigmoid_yesWeights_seed1234/"

# duration and step size of audio segments (to make sure they are stitched back together correctly);
# shared by the OSD and VAD evaluations below
duration = 20
shift = 10

# ---- OSD --------------------------------------------------------------------

# the multitask model outputs predictions for 1. OSD, 2. VAD, 3. SCD; OSD is needed here
task_id = 1
# one entry per evaluated output: ["w2v2", label, path to the model output]
overlap_files_eval = []
overlap_files_dev = []
for epoch in range(1, 7):  # noSigmoid + base + auto weights
    epoch_dir = run_dir + "epoch%d/" % epoch
    overlap_files_eval.append(["w2v2", "AMI multi noSigmoid autoWeights - eval epoch %d" % epoch,
                               epoch_dir + "output_AMI.task%d.txt" % task_id])
    overlap_files_dev.append(["w2v2", "AMI multi noSigmoid autoWeights - dev epoch %d" % epoch,
                              epoch_dir + "output_AMI-dev.task%d.txt" % task_id])
overlap_files_eval.append(["w2v2", "AMI multi noSigmoid autoWeights - eval epoch 7",
                           run_dir + "output_AMI.task%d.txt" % task_id])
overlap_files_dev.append(["w2v2", "AMI multi noSigmoid autoWeights - dev epoch 7",
                          run_dir + "output_AMI-dev.task%d.txt" % task_id])

# results for this threshold will be printed in greater detail
main_thresh = 0.15

overlap_stats_all_dev_AMI = evaluate_all(
    overlap_files_dev,
    thresholds=thresholds,
    task="AMI_OSD_dev",  # selects the pyannote.db.odessa.ami references
    stats_file_suffix='.dev.pyannote-refs.stats-v2',
    main_thresh=main_thresh,
    duration=duration,
    shift=shift,
)
overlap_stats_all_eval_AMI = evaluate_all(
    overlap_files_eval,
    thresholds=thresholds,
    task="AMI_OSD_test",  # selects the pyannote.db.odessa.ami references
    stats_file_suffix='.eval.pyannote-refs.stats-v2',
    main_thresh=main_thresh,
    duration=duration,
    shift=shift,
)

# ---- VAD --------------------------------------------------------------------

# one entry per evaluated output: ["w2v2", label, path to the model output]
VAD_files_dev = []
VAD_files_eval = []
# the multitask model outputs predictions for 1. OSD, 2. VAD, 3. SCD; VAD is needed here
task_id = 2
for epoch in range(1, 7):  # noSigmoid + base + auto weights
    epoch_dir = run_dir + "epoch%d/" % epoch
    VAD_files_eval.append(["w2v2", "AMI multi noSigmoid autoWeights - eval epoch %d" % epoch,
                           epoch_dir + "output_AMI.task%d.txt" % task_id])
    VAD_files_dev.append(["w2v2", "AMI multi noSigmoid autoWeights - dev epoch %d" % epoch,
                          epoch_dir + "output_AMI-dev.task%d.txt" % task_id])
VAD_files_eval.append(["w2v2", "AMI multi noSigmoid autoWeights - eval epoch 7",
                       run_dir + "output_AMI.task%d.txt" % task_id])
VAD_files_dev.append(["w2v2", "AMI multi noSigmoid autoWeights - dev epoch 7",
                      run_dir + "output_AMI-dev.task%d.txt" % task_id])

# results for this threshold will be printed in greater detail
main_thresh = 0.4

VAD_stats_all_dev_AMI = evaluate_all(
    VAD_files_dev,
    thresholds=thresholds,
    task="AMI_VAD_dev",  # selects the pyannote.db.odessa.ami references
    stats_file_suffix='.dev.pyannote-refs.stats-v2',
    main_thresh=main_thresh,
    duration=duration,
    shift=shift,
)
VAD_stats_all_eval_AMI = evaluate_all(
    VAD_files_eval,
    thresholds=thresholds,
    task="AMI_VAD_test",  # selects the pyannote.db.odessa.ami references
    stats_file_suffix='.eval.pyannote-refs.stats-v2',
    main_thresh=main_thresh,
    duration=duration,
    shift=shift,
)


In [None]:
# best threshold per result set on the pyannote.db.odessa.ami references:
# highest F1 for OSD, lowest detection error rate for VAD.
# Dev and test printouts are labelled like in the custom-reference cells,
# otherwise the two blocks of output are indistinguishable.
print("overlaps:")
print("dev set:")
get_highest_stats(overlap_stats_all_dev_AMI)
print("test set:")
get_highest_stats(overlap_stats_all_eval_AMI)

print("\nVAD:")
print("dev set:")
get_lowest_stats(VAD_stats_all_dev_AMI)
print("test set:")
get_lowest_stats(VAD_stats_all_eval_AMI)