# Evaluate Speaker Change Detection (SCD)

Evaluates SCD predictions produced by `wav2vec2_audioFrameClassification.py` or `wav2vec2_audioFrameClassification_multitask.py`.

evaluation is done via pyannote.metrics (https://pyannote.github.io/)

Evaluation on the AMI corpus can use the references provided by pyannote.db.odessa.ami (https://github.com/pyannote/pyannote-db-odessa-ami):
pass `task="AMI_SCD_dev"` (ODESSA/AMI development set) or `task="AMI_SCD_test"` (ODESSA/AMI test set) to `evaluate_all()`.

other datasets require reference annotations in RTTM format (one RTTM file per audio)

### How to use this notebook:
- run all cells until you reach "Evaluation starts here"
- pick one of the options for eval (custom references or pyannote.db)
- change the paths and settings there and run the cells

### Changelog:
2023-03-09:
  - fixed a mistake in the way audio chunks are concatenated in get_SCD_predictions_w2v2()  
    (this affected the reported results, but the difference was minor - around the 4th significant digit in most cases)
  - fixed a few edge cases that previously threw errors


In [None]:
from pyannote.core import Annotation, Timeline, Segment
from pyannote.metrics.segmentation import (
    SegmentationCoverage,
    SegmentationPurity,
    SegmentationPurityCoverageFMeasure,
    SegmentationPrecision,
    SegmentationRecall
)

import os
import re
import numpy as np

from matplotlib import pyplot as plt

from tabulate import tabulate

In [None]:
# loads SCD references from RTTM format (for corpora other than AMI, or if using custom RTTM files)

def get_SCD_reference(ref_file,frameRate=100,offset=0):
    """Load the reference speaker turns of one recording from an RTTM file.

    Every non-blank line is split on single spaces into the RTTM fields
    (type, uri, channel, start, duration, ..., speaker, ...). Lines whose
    type is "SPEAKER" become one segment each, labelled with the speaker name.

    Parameters
    ----------
    ref_file : str
        Path to the RTTM file.
    frameRate, offset
        Not used by this function; kept so that existing calls keep working.

    Returns
    -------
    reference : Annotation
        One segment per SPEAKER line.
    uem : Segment
        Region from 0 to the latest end time found on any line of the file.
    """
    reference = Annotation()

    maxTime = 0

    with open(ref_file,'r') as file:
        for line in file:
            if not line.strip():
                continue # tolerate blank lines (e.g. a trailing newline at the end of the file)
            lineType,uri,_,startTime,dur,_,_,spk,_ = line.split(" ",8)
            startTime = float(startTime)
            endTime = startTime + float(dur)
            if endTime > maxTime:
                maxTime = endTime
            if lineType == "SPEAKER":
                reference[Segment(startTime, endTime)] = spk

    # use the latest end time seen anywhere in the file: the last line does not
    # necessarily end last (overlapping turns), and an empty file has no last line at all
    uem = Segment(0,maxTime)

    return reference,uem

In [None]:
# converts filenames of audio segments into unique identifiers of the original audio files

def get_audio_id(filename):
    """Return the ID of the original recording for the filename of an audio chunk.

    Known suffixes (extensions, processing tags, room/microphone tags, sampling
    rate, chunk time range, 5-digit counters) are removed repeatedly from the
    end of the basename - and only from the end, to avoid false matches - until
    nothing more can be removed,
    e.g. "IS1009c.Mix-Headset_t700-710.wav" -> "IS1009c".

    Parameters
    ----------
    filename : str
        File name or path of an audio chunk.

    Returns
    -------
    str
        The basename with all recognised trailing parts removed.
    """
    basename = os.path.basename(filename)

    suffixesToRemove = ['.wav','.mp4','.dereverb','.denoise','.OMEETING',
                        '.OHALLWAY','.OOFFICE','.meeting','.booth','.office',
                        '_OMEETING','_OHALLWAY','_OOFFICE','_meeting','_booth','_office',
                        '.Mix-Headset']

    # raw strings: sequences such as "\d" are invalid escapes in ordinary string literals
    # the trailing "$" ensures we only find a match at the *end* of the name
    regexToRemove = [re.compile(r'(\.|_)\d+kHz$'),
                     re.compile(r'_t[0-9]+(\.[0-9]+)?\-[0-9]+(\.[0-9]+)?$'),
                     re.compile(r'_\d{5}$')]

    # keep going until one full pass removes nothing
    found = True
    while found:
        found = False
        for suffix in suffixesToRemove:
            if basename.endswith(suffix):
                basename = basename[:-len(suffix)]
                found = True
        for rgx in regexToRemove:
            match = rgx.search(basename)
            if match is not None:
                basename = basename[:match.start()]
                found = True

    return basename

In [None]:
# get SCD hypotheses for each threshold, in pyannote's preferred format

from scipy.signal import find_peaks

def get_annotation_from_labels(labels,frameRate,offset=0.02,threshold=0.5):
    """Turn a sequence of frame-level SCD scores into a segmentation.

    Peaks of `labels` that reach at least `threshold` are taken as change
    points. Frame indices are converted to seconds as index / frameRate + offset.
    The returned Annotation holds consecutive segments, all labelled 'segment',
    running from 0 up to len(labels) / frameRate + offset.
    """
    # frame indices of the detected change points
    peak_frames,_ = find_peaks(labels,height=threshold)

    # segment boundaries in seconds: one per peak, then the end of the signal
    boundaries = [frame / frameRate + offset for frame in peak_frames]
    boundaries.append(len(labels) / frameRate + offset)

    annotation = Annotation()
    segment_start = 0
    for segment_end in boundaries:
        annotation[Segment(segment_start, segment_end)] = 'segment'
        segment_start = segment_end

    return annotation

def get_SCD_hypotheses_w2v2(labels_all,frameRate=50,threshold=0.5,offset=0.02):
    """Convert the predicted labels of every recording into segment hypotheses.

    `labels_all` maps audio IDs to label sequences; the result maps the same
    IDs to the annotations built by get_annotation_from_labels() for the
    given threshold.
    """
    return {
        audio_id: get_annotation_from_labels(
            labels, frameRate=frameRate, offset=offset, threshold=threshold)
        for audio_id, labels in labels_all.items()
    }

def get_SCD_predictions_w2v2(test_file,frameRate=50,audio_id_list=None,duration=None,shift=None):
    """Read per-chunk w2v2 predictions and stitch them into one label list per recording.

    Every non-empty line of `test_file` must have the form "<path>,[v1 v2 ...]".
    The chunk's position inside the recording is taken from the
    "_t<start>-<end>" part of its filename (in seconds).

    Parameters
    ----------
    test_file : str
        Path to the file with the model outputs.
    frameRate : int
        Number of labels per second.
    audio_id_list : collection of str, optional
        If given, chunks of recordings whose ID is not in it are skipped.
    duration, shift : number, optional
        Length and step of the audio chunks. When both are given, the first
        round(shift * frameRate / 2) labels of each chunk are dropped before the
        chunk is written into the recording's label list.

    Returns
    -------
    dict
        Maps each audio ID to its list of labels.

    Raises
    ------
    ValueError
        If a chunk filename does not contain a "_t<start>-<end>" part.
    """
    chunk_time_rgx = re.compile(r'_t[0-9]+(\.[0-9]+)?\-[0-9]+(\.[0-9]+)?')

    predictions_all = []
    durations = {}

    with open(test_file,'r') as file:
        for line in file:
            line = line.strip()
            if len(line) == 0:
                continue
            path,values_str = line.split(",")
            audio_filename = os.path.basename(path)

            audio_id = get_audio_id(audio_filename)
            if (audio_id_list is not None) and not (audio_id in audio_id_list):
                continue # skip files that are not on the list

            match = chunk_time_rgx.search(audio_filename)
            if match is None:
                # TODO: do not require start and end times if there's only one segment per audio file
                #  (the labels of such a file could be used directly as values + [0];
                #   the "+ [0]" adds the one missing label - w2v outputs only 999 labels for 20s audio)
                raise ValueError('Could not determine the start time and end time of audio chunks')
            startTime,endTime = match.group()[2:].split('-') # [2:] drops the leading "_t"
            startTime = float(startTime)
            endTime = float(endTime)

            idx1 = values_str.find('[')
            idx2 = values_str.find(']')

            # split on any run of whitespace, so repeated spaces do not produce empty fields
            values = [float(value) for value in values_str[idx1+1:idx2].split()]

            predictions_all.append({
                "path":path,
                "values": values,
                "id": audio_id,
                "startTime": startTime,
                "endTime": endTime
            })

            # the duration of a recording is the *latest* end time among its chunks
            if (audio_id not in durations) or (endTime > durations[audio_id]):
                durations[audio_id] = endTime


    labels_all = {}
    for prediction in predictions_all:
        audio_id = prediction["id"]
        newLabels = prediction["values"]

        if audio_id not in labels_all:
            # allocate the whole recording up front; chunks are then written into their own slots
            labels_all[audio_id] = [0] * round(durations[audio_id] * frameRate)

        startIdx = round(prediction["startTime"] * frameRate)
        endIdx = round(prediction["endTime"] * frameRate)

        if duration is not None and shift is not None:
            # NOTE(review): the number of dropped leading frames is derived from `shift`; it equals
            #  half of the overlap between chunks (duration - shift) only when duration == 2 * shift
            #  (e.g. 20s chunks with a 10s step) - confirm before using other settings
            # NOTE(review): the leading frames are dropped from every chunk, including the first
            #  chunk of a recording, whose first frames therefore appear to stay 0 - confirm
            halfoverlap_frames = round(shift * frameRate / 2)
            startIdx = startIdx + halfoverlap_frames

            if startIdx >= endIdx: # this can happen in the very last segment
                continue # just skip this chunk completely

            newLabels = newLabels[halfoverlap_frames:]

        # w2v labels are one value shorter than time*frameRate => off by one will probably happen every time...
        #  ... and the last segment of each conversation will obviously have a much bigger difference,
        #  so the end index always follows the number of labels actually available
        endIdx = startIdx + len(newLabels)
        labels_all[audio_id][startIdx:endIdx] = newLabels

    return labels_all
            

In [None]:
# calculate coverage, purity, etc for all files and thresholds from a specific set of hypotheses
#  (version with custom RTTM references)
def evaluate_SCD(hypotheses,thresholds,main_thresh,pr_tolerance=None,pc_tolerance=None):
    """Evaluate SCD hypotheses against custom RTTM references, for every threshold.

    The reference of each recording is read from
    os.path.join(REF_DIR, audio_id + RTTM_SUFFIX); REF_DIR and RTTM_SUFFIX are
    module-level settings that must be defined before this function is called.
    Recordings without a reference file are reported and skipped.

    Parameters
    ----------
    hypotheses : list of dict
        One dict per threshold (same order as `thresholds`), mapping audio IDs
        to hypothesis annotations. The recordings to evaluate are taken from
        the first dict.
    thresholds : list
        Thresholds that were used to create `hypotheses`.
    main_thresh : number
        Threshold for which per-file results are collected as well.
    pr_tolerance : number, optional
        Passed as `tolerance` to the precision and recall metrics if not None.
    pc_tolerance : number, optional
        Passed as `tolerance` to the coverage, purity and F-measure metrics
        if not None.

    Returns
    -------
    results : list of list
        Table (header row first) with one row of pooled scores per threshold.
    stats : dict
        "thresholds" plus one list of pooled scores per metric.
    results_pf : list of list
        Table (header row first) with per-file scores for `main_thresh`,
        followed by an "AVG" row (mean of the per-file scores) and a "TOTAL"
        row (scores pooled over all files).

    Raises
    ------
    ValueError
        If no reference file was found for any of the recordings.
    """
    nThresh = len(thresholds)

    # only pass `tolerance` when one was requested, so the metrics' own defaults apply otherwise
    pr_kwargs = {} if pr_tolerance is None else {"tolerance": pr_tolerance}
    pc_kwargs = {} if pc_tolerance is None else {"tolerance": pc_tolerance}

    segmentationPrecision = SegmentationPrecision(**pr_kwargs)
    segmentationRecall = SegmentationRecall(**pr_kwargs)

    segmentationCoverage = SegmentationCoverage(**pc_kwargs)
    segmentationPurity = SegmentationPurity(**pc_kwargs)
    segmentationFMeasure = SegmentationPurityCoverageFMeasure(**pc_kwargs)

    results = [["threshold","coverage","purity","precision","recall","fscore"]]
    results_pf = [["audio_id","coverage","purity","precision","recall","fscore"]]

    # sums of the per-file scores (turned into means for the "AVG" row)
    coverage_avg = [0] * nThresh
    purity_avg = [0] * nThresh
    fscore_avg = [0] * nThresh
    precision_avg = [0] * nThresh
    recall_avg = [0] * nThresh

    # components accumulated over all files (for the pooled scores)
    cov_dur_total = [0] * nThresh
    pur_dur_total = [0] * nThresh
    all_dur_total = [0] * nThresh
    prec_matches_total = [0] * nThresh
    prec_bounds_total = [0] * nThresh
    rec_matches_total = [0] * nThresh
    rec_bounds_total = [0] * nThresh

    nfiles = 0

    for audio_id in hypotheses[0].keys():
        print(audio_id)

        ref_file = os.path.join(REF_DIR, audio_id + RTTM_SUFFIX)
        if not os.path.exists(ref_file):
            print("File '%s' not found. Skipping."%ref_file)
            continue

        nfiles += 1

        reference,uem = get_SCD_reference(ref_file,frameRate=100,offset=0)

        for iThresh in range(nThresh):

            hypothesis = hypotheses[iThresh][audio_id]

            # Coverage
            coverage_details = segmentationCoverage(reference, hypothesis, detailed=True, uem=uem)
            coverage = coverage_details["segmentation coverage"]
            cov_dur_total[iThresh] += coverage_details["intersection duration"]
            all_dur_total[iThresh] += coverage_details["total duration"]
            coverage_avg[iThresh] += coverage

            # Purity
            purity_details = segmentationPurity(reference, hypothesis, detailed=True, uem=uem)
            purity = purity_details["segmentation purity"]
            pur_dur_total[iThresh] += purity_details["intersection duration"]
            purity_avg[iThresh] += purity

            # F-score
            fscore_details = segmentationFMeasure(reference, hypothesis, detailed=True, uem=uem)
            fscore = fscore_details["segmentation F[purity|coverage]"]
            fscore_avg[iThresh] += fscore

            # Precision
            precision_details = segmentationPrecision(reference,hypothesis,detailed=True,uem=uem)
            precision = precision_details["segmentation precision"]
            prec_matches_total[iThresh] += precision_details["number of matches"]
            prec_bounds_total[iThresh] += precision_details["number of boundaries"]
            precision_avg[iThresh] += precision

            # Recall
            recall_details = segmentationRecall(reference,hypothesis,detailed=True,uem=uem)
            recall = recall_details["segmentation recall"]
            rec_matches_total[iThresh] += recall_details["number of matches"]
            rec_bounds_total[iThresh] += recall_details["number of boundaries"]
            recall_avg[iThresh] += recall

            if thresholds[iThresh] == main_thresh:
                results_pf.append([audio_id,coverage,purity,precision,recall,fscore])

    if nfiles == 0:
        # without this check the averaging below would fail with an uninformative ZeroDivisionError
        raise ValueError("No reference file was found for any recording "
                         "(REF_DIR=%r, RTTM_SUFFIX=%r); nothing to evaluate."%(REF_DIR,RTTM_SUFFIX))

    coverage_total = []
    purity_total = []
    precision_total = []
    recall_total = []
    fscore_total = []

    for iThresh in range(nThresh):

        # means of the per-file scores
        fscore_avg[iThresh] = fscore_avg[iThresh] / nfiles
        precision_avg[iThresh] = precision_avg[iThresh] / nfiles
        recall_avg[iThresh] = recall_avg[iThresh] / nfiles
        purity_avg[iThresh] = purity_avg[iThresh] / nfiles
        coverage_avg[iThresh] = coverage_avg[iThresh] / nfiles

        # pooled scores; a metric with nothing to measure counts as perfect (1)
        if prec_bounds_total[iThresh] == 0:
            precision_total.append(1)
        else:
            precision_total.append(prec_matches_total[iThresh] / prec_bounds_total[iThresh])

        if rec_bounds_total[iThresh] == 0:
            recall_total.append(1)
        else:
            recall_total.append(rec_matches_total[iThresh] / rec_bounds_total[iThresh])

        if all_dur_total[iThresh] == 0:
            coverage_total.append(1)
            purity_total.append(1)
        else:
            coverage_total.append(cov_dur_total[iThresh] / all_dur_total[iThresh])
            purity_total.append(pur_dur_total[iThresh] / all_dur_total[iThresh])

        # pooled F-score = harmonic mean of pooled coverage and pooled purity
        cov_pur_sum = coverage_total[iThresh] + purity_total[iThresh]
        if cov_pur_sum == 0:
            fscore_total.append(0)
        else:
            fscore_total.append(2 * (coverage_total[iThresh] * purity_total[iThresh]) / cov_pur_sum)

        if thresholds[iThresh] == main_thresh:
            results_pf.append(["AVG",coverage_avg[iThresh],purity_avg[iThresh],precision_avg[iThresh],recall_avg[iThresh],fscore_avg[iThresh]])
            results_pf.append(["TOTAL",coverage_total[iThresh],purity_total[iThresh],precision_total[iThresh],recall_total[iThresh],fscore_total[iThresh]])

        results.append([thresholds[iThresh],coverage_total[iThresh],purity_total[iThresh],precision_total[iThresh],recall_total[iThresh],fscore_total[iThresh]])

    stats = {}
    stats["thresholds"] = thresholds
    stats["coverage"] = coverage_total
    stats["purity"] = purity_total
    stats["precision"] = precision_total
    stats["recall"] = recall_total
    stats["fscore"] = fscore_total

    return results, stats, results_pf
    

In [None]:
# calculate coverage, purity, etc for all files and thresholds from a specific set of hypotheses
#  (version using pyannote.db.odessa.ami, AMI corpus only)

from pyannote.database import get_protocol, FileFinder

def evaluate_SCD_AMI(hypotheses,thresholds,main_thresh,task):
    """Evaluate SCD hypotheses on AMI using the references from pyannote.database.

    The files of the 'AMI.SpeakerDiarization.MixHeadset' protocol are used:
    the development set if `task` ends with "_dev", the test set otherwise.
    For every threshold, coverage, purity and their F-measure are accumulated
    over all files; the value stored per threshold is the first column of the
    last row of each metric's report (presumably the TOTAL row - confirm
    against the pyannote.metrics report format). The reports are displayed
    only for `main_thresh`.

    Returns (results, stats, None): `results` is a table with a header row and
    one row per threshold, `stats` holds the thresholds and one list per metric.
    There are no per-file results in this mode, hence the None.
    """
    print('Evaluation with pyannote refs')

    #preprocessors = {'audio': FileFinder()}
    protocol = get_protocol('AMI.SpeakerDiarization.MixHeadset')

    segmentationCoverage = SegmentationCoverage()
    segmentationPurity = SegmentationPurity()
    segmentationFMeasure = SegmentationPurityCoverageFMeasure()

    results = [["threshold","coverage","purity","fscore"]]
    coverage_total, purity_total, fscore_total = [], [], []

    for iThresh, threshold in enumerate(thresholds):
        if threshold == main_thresh:
            display = True
            print("Results for threshold %g:"%main_thresh)
        else:
            display = False

        # the protocol is queried again for every threshold to get a fresh iteration over the files
        use_dev_set = task[-4:] == "_dev"
        AMI_set = protocol.development() if use_dev_set else protocol.test()

        # accumulate the metrics over each file of the test/dev set
        for test_file in AMI_set:
            audio_id = get_audio_id(test_file["uri"])
            hypothesis = hypotheses[iThresh][audio_id]
            reference = test_file['annotation']
            uem = test_file['annotated']

            segmentationCoverage(reference, hypothesis, uem=uem)
            segmentationPurity(reference, hypothesis, uem=uem)
            segmentationFMeasure(reference, hypothesis, uem=uem)

        cov_report = segmentationCoverage.report(display=display)
        pur_report = segmentationPurity.report(display=display)
        F_report = segmentationFMeasure.report(display=display)

        # last row, first column of each report
        coverage_total.append(cov_report.values.tolist()[-1][0])
        purity_total.append(pur_report.values.tolist()[-1][0])
        fscore_total.append(F_report.values.tolist()[-1][0])

        results.append(
            [threshold,coverage_total[iThresh],purity_total[iThresh],fscore_total[iThresh]]
        )

        # start from scratch for the next threshold
        segmentationCoverage.reset()
        segmentationPurity.reset()
        segmentationFMeasure.reset()

    stats = {
        "thresholds": thresholds,
        "coverage": coverage_total,
        "purity": purity_total,
        "fscore": fscore_total,
    }

    return results,stats,None


In [None]:
# evaluate multiple sets of predictions
#   (stats are saved so that they don't need to be recalculated every time)
# TODO: check if thresholds match the saved ones, recalculate (only) the missing stats if not

import pickle

def evaluate_all(prediction_files, thresholds=None,
                 main_thresh=0.25,audio_id_list=None,duration=20,shift=10,task=None,
                 stats_file_suffix='.stats'):
    """Evaluate several prediction files, reusing previously saved results.

    Each item of `prediction_files` is indexed as [format, display name, path];
    only the format "w2v2" is recognised, other entries are reported and skipped.
    The results of an entry are cached in (path + stats_file_suffix): if that
    file exists it is loaded instead of running the evaluation, whatever the
    other arguments are.

    Prints the result tables of every entry and returns a list of
    {"stats": ..., "name": ...} dicts, one per evaluated entry.
    """
    # TODO: recalculate missing stats if thresholds don't match the saved ones

    if thresholds is None:
        thresholds = [0,0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.1,
              0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,1]

    stats_all = []

    for entry in prediction_files:
        cache_path = entry[2] + stats_file_suffix
        print(entry)

        if not os.path.exists(cache_path):
            print('No saved results found. Starting evaluation - this may take some time...')
            if entry[0] != "w2v2":
                print("Unrecognized results format ''%s''. Skipping."%entry[0])
                continue
            results,stats,results_pf = evaluate_SCD_w2v2(
                entry[2],thresholds=thresholds,main_thresh=main_thresh,
                audio_id_list=audio_id_list,duration=duration,shift=shift,task=task)
            save_results(cache_path,results,stats,results_pf)
        else:
            print('Loading previously saved results.')
            results,stats,results_pf = load_saved_results(cache_path)

        stats_all.append({"stats":stats, "name": entry[1]})

        # table with one row per threshold
        if results is not None:
            print(tabulate(results, headers="firstrow",tablefmt="simple"))
            print("")
        # table with one row per file (not available in every evaluation mode)
        if results_pf is not None:
            print(tabulate(results_pf, headers="firstrow",tablefmt="simple"))
            print("\n")

    return stats_all

def load_saved_results(stats_file):
    """Read back the three objects written by save_results().

    Returns the tuple (results, stats, results_pf), unpickled from `stats_file`
    in that order. Pickle can execute arbitrary code while loading, so only
    open stats files that you created yourself.
    """
    with open(stats_file, 'rb') as file:
        results, stats, results_pf = (pickle.load(file) for _ in range(3))

    return results,stats,results_pf
    
    
def save_results(stats_file,results,stats,results_pf):
    """Pickle `results`, `stats` and `results_pf` (in this order) into `stats_file`.

    An existing file is overwritten.
    """
    with open(stats_file, 'wb') as file:
        for item in (results, stats, results_pf):
            pickle.dump(item, file)
        
def evaluate_SCD_w2v2(SCD_file,thresholds=None,main_thresh = 0.5,
        frameRate=50,offset=0.02,audio_id_list=None,duration=None,
        shift=None,pr_tolerance=None,pc_tolerance=None,task=None):
    """Evaluate the SCD predictions stored in one w2v2 output file.

    The predictions are read and stitched together with
    get_SCD_predictions_w2v2(), turned into one set of hypotheses per threshold
    and evaluated either with the pyannote AMI references (when `task` starts
    with "AMI_") or with custom RTTM references (otherwise).

    Parameters
    ----------
    SCD_file : str
        Path to the file with the model outputs.
    thresholds : list, optional
        Thresholds to evaluate; [0.5] if omitted.
    main_thresh : number
        Threshold for which more detailed results are produced.
    frameRate, offset, audio_id_list, duration, shift
        Passed on to the functions that read the predictions and build the hypotheses.
    pr_tolerance, pc_tolerance
        Passed on to evaluate_SCD(); not used for the AMI evaluation.
    task : str, optional
        Selects the evaluation mode, see above.

    Returns
    -------
    tuple
        (results, stats, results_pf) as returned by the selected evaluation function.
    """
    # a fresh list per call: the thresholds are handed back to the caller inside `stats`,
    # so a shared default list could be modified from outside by accident
    if thresholds is None:
        thresholds = [0.5]

    labels_all = get_SCD_predictions_w2v2(
        SCD_file,frameRate=frameRate,audio_id_list=audio_id_list,duration=duration,shift=shift
    )

    hypotheses = [
        get_SCD_hypotheses_w2v2(labels_all,frameRate=frameRate,threshold=threshold,offset=offset)
        for threshold in thresholds
    ]

    if task is not None and task[0:4] == "AMI_":
        results,stats,results_pf = evaluate_SCD_AMI(hypotheses,thresholds,main_thresh,task)
    else:
        results,stats,results_pf = evaluate_SCD(hypotheses,thresholds,main_thresh,pr_tolerance=pr_tolerance,pc_tolerance=pc_tolerance)
    return results,stats,results_pf


In [None]:
# plot coverage vs purity

def make_plots(stats_all,xlims=[0,1],ylims=[0,1]):
    """Plot the stats collected by evaluate_all().

    Three kinds of plots are shown:
      1. purity against coverage, one curve per entry, limited to xlims/ylims
      2. purity and coverage against the threshold, all entries in one figure
      3. the same as 2., but shown separately for every entry
    """

    def plot_over_thresholds(entry):
        # purity and coverage curves of one entry, with the threshold on the x axis
        scores = entry["stats"]
        label = entry["name"]
        x_values = scores["thresholds"]
        plt.plot(x_values,scores["purity"], label=label + " - purity")
        plt.plot(x_values,scores["coverage"], label=label + " - coverage")

    def add_legend():
        # legend placed outside of the axes, next to the upper right corner
        plt.legend(bbox_to_anchor=(1,1), loc="upper left")

    # 1. coverage vs purity
    plt.figure(figsize=(6, 6))
    for entry in stats_all:
        scores = entry["stats"]
        label = entry["name"]
        plt.plot(scores["coverage"],scores["purity"], label=label)
    plt.xlim(xlims)
    plt.ylim(ylims)
    plt.xlabel("coverage")
    plt.ylabel("purity")
    add_legend()
    plt.show()

    # 2. all entries together
    plt.figure(figsize=(6, 6))
    for entry in stats_all:
        plot_over_thresholds(entry)
    add_legend()
    plt.show()

    # 3. each entry on its own (shown right after it is drawn)
    for entry in stats_all:
        plot_over_thresholds(entry)
        add_legend()
        plt.show()

In [None]:
# find the optimal threshold for each set of results

# threshold where purity and coverage (or some other 2 stats) are closest to each other
def get_closest_stats(stats_all, stat1 = "purity", stat2 = "coverage", stat_main = "fscore"):
    """Report `stat_main` at the threshold where `stat1` and `stat2` are closest.

    For every entry of `stats_all` the first threshold with the smallest
    absolute difference between the two stats is selected; the entry's name,
    its `stat_main` value at that threshold and the threshold are printed.
    The selected values of all entries are then plotted. Nothing is returned.
    """
    selected_scores = []

    for entry in stats_all:
        scores = entry["stats"]
        label = entry["name"]
        thresh_values = scores["thresholds"]

        chosen = None
        smallest_gap = 9999999
        for idx in range(len(thresh_values)):
            gap = abs(scores[stat1][idx] - scores[stat2][idx])
            if gap < smallest_gap: # strict comparison: ties keep the earlier threshold
                smallest_gap, chosen = gap, idx

        selected_scores.append(scores[stat_main][chosen])
        print("%s:\n    %g (thresh %g)"%(label,scores[stat_main][chosen],thresh_values[chosen]))

    plt.figure(figsize=(6, 6))
    #plt.bar([*range(1,len(selected_scores)+1)],selected_scores)
    plt.plot(selected_scores)
    
# threshold where fscore (or some other stat) is highest
def get_highest_stats(stats_all, stat_main = "fscore"):
    """Report the highest value of `stat_main` and the threshold where it occurs.

    For every entry of `stats_all` the first threshold with the highest value
    is selected; the entry's name, that value and the threshold are printed.
    The selected values of all entries are then plotted. Nothing is returned.
    """
    selected_scores = []

    for entry in stats_all:
        scores = entry["stats"]
        label = entry["name"]
        thresh_values = scores["thresholds"]

        chosen = None
        highest = -9999999
        for idx in range(len(thresh_values)):
            candidate = scores[stat_main][idx]
            if candidate > highest: # strict comparison: ties keep the earlier threshold
                highest, chosen = candidate, idx

        selected_scores.append(scores[stat_main][chosen])
        print("%s:\n    %g (thresh %g)"%(label,scores[stat_main][chosen],thresh_values[chosen]))

    plt.figure(figsize=(6, 6))
    #plt.bar([*range(1,len(selected_scores)+1)],selected_scores)
    plt.plot(selected_scores)
    

# Evaluation starts here
Pick one of the options below, change the paths and settings to your own, then run its cells.

## a) Evaluate SCD using custom RTTM references

In [None]:
# set paths, thresholds, etc
DATASET = 'AMI'

# directory with the reference annotations; read by evaluate_SCD()
REF_DIR = "/path/to/directory/with/RTTM/files/"
RTTM_SUFFIX = ".rttm" # the expected filename format is (audio_id + RTTM_SUFFIX), e.g. "EN2002a.rttm"

thresholds = [0,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,1]

# restricts the evaluation to the listed recordings
# NOTE: results are cached in (<model output path> + '.stats'); cached results are reused as they are,
#  so delete those files after changing this list (or any other setting)
audio_id_list = None # use everything in the model outputs

# audio_id_list = ["EN2002a","EN2002b","EN2002c","EN2002d",
#                  "ES2004a","ES2004b","ES2004c","ES2004d",
#                  "ES2014a","ES2014b","ES2014c","ES2014d",
#                  "IS1009a","IS1009b","IS1009c","IS1009d",
#                  "TS3003a","TS3003b","TS3003c","TS3003d",
#                  "TS3007a","TS3007b","TS3007c","TS3007d"]

mainDir_AMI_multitask = "/storage/plzen4-ntis/projects/speaker_recog/AMI/Wav2vec2Transformer_multitask_OSD-VAD-SCD/"

# directory of the evaluated training run
runDir_AMI_multitask = mainDir_AMI_multitask + "AMITrain_fuzzy0.2_22-09-22_20s10s_7epochs_noSigmoid_yesWeights_seed1234/"

# paths to the model outputs
# each entry is [format of the predictions, name shown in the results, path to the model output]
task_id = 3 # our multitask model outputs predictions for 1. OSD, 2. VAD, 3. SCD; we want SCD here
SCD_files_eval = []
SCD_files_dev = []
for epoch in range(1,7): # noSigmoid + base + auto weights
    SCD_files_eval.append(["w2v2","AMI multi noSigmoid autoWeights - eval epoch %d"%epoch,
                           runDir_AMI_multitask + "epoch%d/output_AMI.task%d.txt"%(epoch,task_id)])
    SCD_files_dev.append(["w2v2","AMI multi noSigmoid autoWeights - dev epoch %d"%epoch,
                          runDir_AMI_multitask + "epoch%d/output_AMI-dev.task%d.txt"%(epoch,task_id)])
# the outputs of the last epoch are stored directly in the run directory
SCD_files_eval.append(["w2v2","AMI multi noSigmoid autoWeights - eval epoch 7",
                       runDir_AMI_multitask + "output_AMI.task%d.txt"%(task_id)])
SCD_files_dev.append(["w2v2","AMI multi noSigmoid autoWeights - dev epoch 7",
                      runDir_AMI_multitask + "output_AMI-dev.task%d.txt"%(task_id)])

# duration and step size of audio segments (to make sure they are stitched back together correctly)
duration = 20
shift = 10

main_thresh=0.25 # results for this threshold will be printed in greater detail

# run evaluation
# (audio_id_list is passed on, so that the setting above actually takes effect)
stats_all_dev = evaluate_all(SCD_files_dev, thresholds=thresholds,main_thresh=main_thresh,audio_id_list=audio_id_list,duration=duration,shift=shift)
stats_all_eval = evaluate_all(SCD_files_eval, thresholds=thresholds,main_thresh=main_thresh,audio_id_list=audio_id_list,duration=duration,shift=shift)


In [None]:
# get the best threshold for each set of results
# (for every evaluated model output, the best score and its threshold are printed;
#  the best scores of all outputs are also plotted)

# a) highest F1
print("dev set:")
get_highest_stats(stats_all_dev)

print("test set:")
get_highest_stats(stats_all_eval)

# b) most similar Cov and Pur (uncomment to use)
#print("dev set:")
#get_closest_stats(stats_all_dev)
#print("test set:")
#get_closest_stats(stats_all_eval)

# plots (uncomment to use)
# make_plots(stats_all_dev)
# make_plots(stats_all_eval)

## b) Evaluate AMI SCD using pyannote.db.odessa.ami
(AMI corpus only)

enabled by passing `task="AMI_SCD_test"` or `task="AMI_SCD_dev"` to `evaluate_all()`

In [None]:
# set paths, thresholds, etc

DATASET = 'AMI'

REF_DIR = None      # not used here
ref_suffix = None   # not used here

thresholds = [0,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,1]

mainDir_AMI_multitask = "/storage/plzen4-ntis/projects/speaker_recog/AMI/Wav2vec2Transformer_multitask_OSD-VAD-SCD/"

# directory of the evaluated training run
runDir_AMI_multitask = mainDir_AMI_multitask + "AMITrain_fuzzy0.2_22-09-22_20s10s_7epochs_noSigmoid_yesWeights_seed1234/"

# paths to the model outputs
# each entry is [format of the predictions, name shown in the results, path to the model output]
task_id = 3 # our multitask model outputs predictions for 1. OSD, 2. VAD, 3. SCD; we want SCD here
SCD_files_eval = []
SCD_files_dev = []
for epoch in range(1,7): # noSigmoid + base + auto weights
    SCD_files_eval.append(["w2v2","AMI multi noSigmoid autoWeights - eval epoch %d"%epoch,
                           runDir_AMI_multitask + "epoch%d/output_AMI.task%d.txt"%(epoch,task_id)])
    SCD_files_dev.append(["w2v2","AMI multi noSigmoid autoWeights - dev epoch %d"%epoch,
                          runDir_AMI_multitask + "epoch%d/output_AMI-dev.task%d.txt"%(epoch,task_id)])
# the outputs of the last epoch are stored directly in the run directory
SCD_files_eval.append(["w2v2","AMI multi noSigmoid autoWeights - eval epoch 7",
                       runDir_AMI_multitask + "output_AMI.task%d.txt"%(task_id)])
SCD_files_dev.append(["w2v2","AMI multi noSigmoid autoWeights - dev epoch 7",
                      runDir_AMI_multitask + "output_AMI-dev.task%d.txt"%(task_id)])

# duration and step size of audio segments (to make sure they are stitched back together correctly)
duration = 20
shift = 10

# results for this threshold will be printed in greater detail
# (this is the value the calls below have always used; it is now set in one place only)
main_thresh=0.35

# results are cached in (<model output path> + stats_file_suffix); the suffix contains the main
# threshold, so results saved for different main thresholds do not overwrite each other
stats_file_suffix = '.pyannote-refs.thresh%g.stats'%main_thresh

# run evaluation

stats_all_eval_AMI = evaluate_all(
    SCD_files_eval, thresholds=thresholds,
    task="AMI_SCD_test", # this is the important part
    stats_file_suffix=stats_file_suffix,
    main_thresh=main_thresh,duration=duration,shift=shift
)

stats_all_dev_AMI = evaluate_all(
    SCD_files_dev, thresholds=thresholds,
    task="AMI_SCD_dev", # this is the important part
    stats_file_suffix=stats_file_suffix,
    main_thresh=main_thresh,duration=duration,shift=shift
)


In [None]:
# get the best threshold for each set of results
# (for every evaluated model output, the best score and its threshold are printed;
#  the best scores of all outputs are also plotted)

# a) highest F1
print("dev set:")
get_highest_stats(stats_all_dev_AMI)

print("test set:")
get_highest_stats(stats_all_eval_AMI)

# b) most similar Cov and Pur (uncomment to use)
#print("dev set:")
#get_closest_stats(stats_all_dev_AMI)
#print("test set:")
#get_closest_stats(stats_all_eval_AMI)

# plots (uncomment to use)
# make_plots(stats_all_dev_AMI)
# make_plots(stats_all_eval_AMI)