In [1]:
from pyannote.core import Annotation, Segment
import pandas as pd
import pickle
from pydub import AudioSegment

In [2]:
from dotenv import load_dotenv
import os, subprocess, sys, time
from pyannote.core import Segment, Annotation
from pyannote.core.notebook import Notebook
import matplotlib.pyplot as plt
load_dotenv()
import pickle, json
import torch
from pyannote.audio import Pipeline
# Restrict visible GPUs BEFORE loading the model: once a CUDA context exists
# (model loading may create one), changing CUDA_VISIBLE_DEVICES has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
pretrained_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.0",
    use_auth_token=os.getenv('HUGGINGFACE_TOKEN'),
)
if pretrained_pipeline is None:
    # NOTE(review): pyannote returns None (after printing a hint) when the gated
    # model cannot be fetched, e.g. HUGGINGFACE_TOKEN missing/without access.
    raise RuntimeError(
        "Could not load pyannote/speaker-diarization-3.0 - check HUGGINGFACE_TOKEN and model access."
    )
pretrained_pipeline.to(torch.device("cuda"))

KeyboardInterrupt: 

In [None]:
# Every per-clip input is derived from one id so the RTTM, face encodings and
# audio can never refer to different clips.
VIDEO_ID = "2qQs3Y9OJX0_c_01"
# Seconds subtracted from RTTM timestamps by convert_rttm_to_diarization
# (presumably the clip's start inside the full AVA video - TODO confirm).
offset = 900.488
ground_truth_rttm_file = f"../../../AVA-AVD/dataset/rttms/{VIDEO_ID}.rttm"
# read_pickle executes arbitrary code from the file: only load trusted artifacts.
encoding_df = pd.read_pickle(f"../../output/video_temp/{VIDEO_ID}/pywork/new_encoding_df.pckl")
audio_path = f"../../output/video_temp/{VIDEO_ID}/pywav/audio.wav"

In [None]:
from pyannote.metrics.diarization import DiarizationErrorRate
# DER with the library defaults spelled out: no collar around reference
# boundaries, and overlapped speech is scored.
metric = DiarizationErrorRate(collar=0.0, skip_overlap=False)

In [None]:
def convert_rttm_to_diarization(rttm_file, offset):
    """Read an RTTM file into ``{speaker_label: [(start, end), ...]}``.

    Args:
        rttm_file: path to an RTTM file; fields may be separated by any run of
            whitespace.
        offset: seconds subtracted from every turn start.

    Returns:
        Dict mapping each speaker label to its turns in start-time order. The
        start is rounded to 2 decimals, and the end is ``start + duration``
        (computed from the already-rounded start) rounded to 2 decimals.
    """
    # RTTM columns: type, file id, channel, onset, duration, <NA>, <NA>, speaker, <NA>, <NA>.
    rttm_df = pd.read_csv(
        rttm_file,
        sep=r"\s+",
        header=None,
        names=['temp', 'file_name', 'channel', 'start', 'duration', 'NA_1', 'NA_2', 'speaker_label', 'NA_3', 'NA_4'],
    ).sort_values(by="start")

    diarize_dict = {}
    for _, row in rttm_df.iterrows():
        start_time = round(row['start'] - offset, 2)
        end_time = round(start_time + row['duration'], 2)
        diarize_dict.setdefault(row['speaker_label'], []).append((start_time, end_time))

    return diarize_dict

def convert_diarization_output_to_pyannote(diarize_output):
    """Build a pyannote ``Annotation`` from ``{speaker: [(start, end), ...]}``.

    Each distinct ``(start, end, speaker)`` turn becomes one labelled track.
    When different speakers share exactly the same segment (e.g. overlapped
    speech after rounding), each one gets its own track. The shorthand
    ``annotation[segment] = label`` always writes track ``"_"``, so the last
    speaker would silently replace the others.
    """
    annotation = Annotation()
    seen_turns = set()

    for speaker, timelines in diarize_output.items():
        for timeline_start, timeline_end in timelines:
            turn = (timeline_start, timeline_end, speaker)
            if turn in seen_turns:
                continue  # exact duplicate of a turn already added
            seen_turns.add(turn)
            segment = Segment(timeline_start, timeline_end)
            # candidate="_" keeps the default track name when the segment is new.
            annotation[segment, annotation.new_track(segment, candidate="_")] = speaker

    return annotation

def convert_pyannote_to_diarization(pyannote_output):
    """Flatten a pyannote ``Annotation`` into ``{speaker: [(start, end), ...]}``.

    Boundaries are rounded to 2 decimals. Speakers appear in order of their
    first track, and each speaker's turns keep ``itertracks`` order.
    """
    speaker_turns = {}
    for segment, _track, label in pyannote_output.itertracks(yield_label=True):
        bounds = (round(segment.start, 2), round(segment.end, 2))
        speaker_turns.setdefault(label, []).append(bounds)
    return speaker_turns

In [None]:
# Ground truth of the configured clip: dict form (offset removed), then Annotation.
gt_1j20 = convert_rttm_to_diarization(rttm_file=ground_truth_rttm_file, offset=offset)
pyannote_gt = convert_diarization_output_to_pyannote(diarize_output=gt_1j20)

In [None]:
# Ground-truth turns {speaker: [(start, end), ...]} with `offset` subtracted.
# NOTE: despite the `1j20` suffix this is the clip named in `ground_truth_rttm_file`.
gt_1j20

In [None]:
# The same ground truth as a pyannote Annotation.
pyannote_gt

In [None]:
def convert_to_ranges(lst, frame_rate):
    """Collapse ascending frame indices into ``(start_s, end_s)`` time ranges.

    A gap larger than ``threshold = round(frame_rate / 10) * 3`` frames
    (6 frames at 25 fps) starts a new range. Ranges spanning ``threshold``
    frames or fewer are dropped. Times are ``frame / frame_rate`` rounded to
    2 decimals.

    Args:
        lst: frame indices in ascending order; may be empty.
        frame_rate: frames per second.

    Returns:
        List of ``(start_seconds, end_seconds)`` tuples; ``[]`` for empty input.
    """
    if len(lst) == 0:
        return []

    ranges = []
    start = lst[0]
    threshold = round(frame_rate / 10) * 3

    for i in range(1, len(lst)):
        if lst[i] - lst[i - 1] > threshold:
            # Close the current run, keeping it only if it is long enough.
            if lst[i - 1] - start > threshold:
                ranges.append(
                    (round(start / frame_rate, 2), round(lst[i - 1] / frame_rate, 2))
                )
            start = lst[i]

    # Add the last range
    if lst[-1] - start > threshold:
        ranges.append((round(start / frame_rate, 2), round(lst[-1] / frame_rate, 2)))

    return ranges

def get_final_tracks(df, frameRate):
    """Turn per-frame cluster assignments into per-speaker time ranges.

    Args:
        df: DataFrame with ``Final_Cluster`` (integer cluster id) and ``Frame``
            (frame index) columns.
        frameRate: frames per second, used to convert frames to seconds.

    Returns:
        ``{"SPEAKER_<cluster:02d>": [(start_s, end_s), ...]}`` in ascending
        cluster order, with ranges built by ``convert_to_ranges``.
    """
    ordered = df.sort_values(by=["Final_Cluster", "Frame"])
    frames_by_speaker = {
        "SPEAKER_{:02d}".format(cluster): ordered.loc[ordered["Final_Cluster"] == cluster, "Frame"].to_list()
        for cluster in ordered["Final_Cluster"].unique()
    }
    return {
        speaker: convert_to_ranges(frames, frameRate)
        for speaker, frames in frames_by_speaker.items()
    }

In [None]:
# 25 = frame rate of the extracted video (assumed - TODO confirm against the clip).
vdo_1j20 = get_final_tracks(df=encoding_df, frameRate=25)

In [None]:
# Video-only diarization: {SPEAKER_xx: [(start_s, end_s), ...]} built from Final_Cluster frames.
vdo_1j20

In [None]:
# Video-only diarization as a pyannote Annotation (scored in the next cell).
pyannote_vdo = convert_diarization_output_to_pyannote(vdo_1j20)
pyannote_vdo

In [None]:
# DER of the video-only diarization (hypothesis) against the ground truth (reference).
metric(pyannote_gt, pyannote_vdo)

In [None]:
def perform_pyannote_diarization(pretrained_model, audio_path, min_cluster_size=12):
    """Diarize ``audio_path`` with the given clustering ``min_cluster_size``.

    Args:
        pretrained_model: a pyannote ``Pipeline``. It is re-instantiated IN PLACE,
            so the new ``min_cluster_size`` persists for later calls on it.
        audio_path: path of the audio file to diarize.
        min_cluster_size: value for the pipeline's ``clustering.min_cluster_size``.

    Returns:
        Whatever the pipeline returns for ``audio_path`` (a pyannote Annotation
        for speaker-diarization pipelines).
    """
    pretrained_model.instantiate({
        "clustering": {
            "min_cluster_size": min_cluster_size
        }
    })
    # (The former `parameters(instantiated=True)` call only read the parameters
    # back and discarded the result, so it was removed.)
    return pretrained_model(audio_path)

In [None]:
# Run the shared pipeline once per min_cluster_size, then round-trip each result
# through the dict format so every segment boundary is rounded to 2 decimals.
ad03_1j20, ad06_1j20, ad09_1j20, ad12_1j20 = (
    convert_pyannote_to_diarization(
        perform_pyannote_diarization(pretrained_pipeline, audio_path, min_cluster_size=cluster_size)
    )
    for cluster_size in (3, 6, 9, 12)
)
pyannote_ad03, pyannote_ad06, pyannote_ad09, pyannote_ad12 = (
    convert_diarization_output_to_pyannote(rounded)
    for rounded in (ad03_1j20, ad06_1j20, ad09_1j20, ad12_1j20)
)

In [None]:
# Audio-only diarization (min_cluster_size=3) with boundaries rounded to 2 decimals.
pyannote_ad03

In [None]:
# Same min_cluster_size=3 result as a {speaker: [(start, end), ...]} dict.
ad03_1j20

In [None]:
class PyannoteAudioSegment:
    """One ``ad_03`` speech turn annotated with its speaker in every group.

    Attributes:
        segment_idx: running index assigned when the segment was created.
        audio_segment_start, audio_segment_end: turn boundaries.
        ad_03_speaker, ad_06_speaker, ad_09_speaker, ad_12_speaker: speaker
            label of this turn in each audio diarization group.
        vd_speaker: mapped video speaker (``"Unknown"`` when unresolved).
        has_overlap: whether this turn overlaps another turn.

    Ordering and equality compare ``audio_segment_start`` only, so a list of
    segments sorts chronologically. Comparing with a non-segment returns
    ``NotImplemented``, so e.g. ``segment == None`` is ``False`` instead of
    raising ``AttributeError``.
    """

    # Group id -> attribute that holds this segment's speaker for that group.
    _SPEAKER_ATTR_BY_GROUP = {
        "ad_03": "ad_03_speaker",
        "ad_06": "ad_06_speaker",
        "ad_09": "ad_09_speaker",
        "ad_12": "ad_12_speaker",
        "vd": "vd_speaker",
    }

    def __init__(
        self,
        segment_idx,
        audio_segment_start,
        audio_segment_end,
        ad_03_speaker,
        ad_06_speaker,
        ad_09_speaker,
        ad_12_speaker,
        has_overlap,
        vd_speaker=None,
    ):
        self.segment_idx = segment_idx
        self.audio_segment_start = audio_segment_start
        self.audio_segment_end = audio_segment_end
        self.vd_speaker = vd_speaker
        self.ad_03_speaker = ad_03_speaker
        self.ad_06_speaker = ad_06_speaker
        self.ad_09_speaker = ad_09_speaker
        self.ad_12_speaker = ad_12_speaker
        self.has_overlap = has_overlap

    def get_speaker(self, group_id):
        """Return this segment's speaker in ``group_id``, or None (with a message) if invalid."""
        attr = self._SPEAKER_ATTR_BY_GROUP.get(group_id)
        if attr is None:
            print(f"Invalid group id - {group_id}")
            return None
        return getattr(self, attr)

    def __lt__(self, other):
        if not isinstance(other, PyannoteAudioSegment):
            return NotImplemented
        return self.audio_segment_start < other.audio_segment_start

    def __gt__(self, other):
        if not isinstance(other, PyannoteAudioSegment):
            return NotImplemented
        return self.audio_segment_start > other.audio_segment_start

    def __eq__(self, other):
        if not isinstance(other, PyannoteAudioSegment):
            return NotImplemented
        return self.audio_segment_start == other.audio_segment_start

    def __le__(self, other):
        if not isinstance(other, PyannoteAudioSegment):
            return NotImplemented
        return self.audio_segment_start <= other.audio_segment_start

    def __ge__(self, other):
        if not isinstance(other, PyannoteAudioSegment):
            return NotImplemented
        return self.audio_segment_start >= other.audio_segment_start

    def __ne__(self, other):
        if not isinstance(other, PyannoteAudioSegment):
            return NotImplemented
        return self.audio_segment_start != other.audio_segment_start

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash the same field that
        # equality compares so equal segments hash equally.
        return hash(self.audio_segment_start)

    def __repr__(self) -> str:
        return f"PyannoteAudioSegment({self.audio_segment_start}, {self.audio_segment_end}, {self.ad_03_speaker}, {self.ad_06_speaker}, {self.ad_09_speaker}, {self.ad_12_speaker}, {self.vd_speaker})"


class MappingClass:
    """Accumulate, per key speaker, the total duration co-assigned to each value speaker.

    ``mapping_dict`` has the shape ``{speaker_key_id: {speaker_value_id: total_duration}}``.
    Value speakers equal to ``"Unknown"`` are never recorded.
    """

    def __init__(self):
        self.mapping_dict = {}

    def add_mapping(self, speaker_key_id, speaker_value_id, audio_segment_start, audio_segment_end):
        """Add ``audio_segment_end - audio_segment_start`` to the (key, value) total."""
        if speaker_value_id == "Unknown":
            return
        value_totals = self.mapping_dict.setdefault(speaker_key_id, {})
        increment = audio_segment_end - audio_segment_start
        if speaker_value_id in value_totals:
            value_totals[speaker_value_id] += increment
        else:
            value_totals[speaker_value_id] = increment

    def get_max_mapping(self, speaker_key_id):
        """Return the value speaker with the largest total for ``speaker_key_id``.

        Returns None for an unseen key. Ties go to the value speaker recorded first.
        """
        value_totals = self.mapping_dict.get(speaker_key_id)
        if value_totals is None:
            return None
        return max(value_totals, key=value_totals.get)


class DiarizationOutput:
    """Fuse four audio diarizations with a video diarization.

    Every diarization is a dict ``{speaker_id: [(start, end), ...]}``. Groups are
    addressed by id: ``"ad_03"``, ``"ad_06"``, ``"ad_09"`` and ``"ad_12"`` for the
    four audio runs passed to ``__init__``, and ``"vd"`` for the video run.

    The audio groups form a hierarchy ``ad_03 -> ad_06 -> ad_09 -> ad_12``. A
    speaker may only be mapped onto a later ("parent") audio group, or onto or
    from ``"vd"``.

    Construction does three things:
    1. Turns every ``ad_03`` turn into a ``PyannoteAudioSegment`` labelled with
       its speakers in all other groups.
    2. Marks turns that overlap another turn with video speaker ``"Unknown"``.
    3. Fills those unknowns from the accumulated ad_03 -> video mapping.
    """

    # Audio group ids in hierarchy order (earlier = child, later = parent).
    _AUDIO_GROUPS = ("ad_03", "ad_06", "ad_09", "ad_12")
    _VIDEO_GROUP = "vd"
    _ALL_GROUPS = _AUDIO_GROUPS + (_VIDEO_GROUP,)

    def __init__(
        self, audio_diarization_03, audio_diarization_06, audio_diarization_09, audio_diarization_12, video_diarization
    ):
        self.audio_diarization_03 = audio_diarization_03
        self.audio_diarization_06 = audio_diarization_06
        self.audio_diarization_09 = audio_diarization_09
        self.audio_diarization_12 = audio_diarization_12
        self.video_diarization = video_diarization
        self.audio_segments: list[PyannoteAudioSegment] = self.get_audio_segments()
        self.ad03_video_mapping: MappingClass = self.perform_audio_video_mapping("ad_03")
        self.ad06_video_mapping: MappingClass = self.perform_audio_video_mapping("ad_06")
        self.ad09_video_mapping: MappingClass = self.perform_audio_video_mapping("ad_09")
        self.ad12_video_mapping: MappingClass = self.perform_audio_video_mapping("ad_12")
        self.predict_unknowns()

    @classmethod
    def _non_target_groups(cls, group_id):
        """Groups a ``group_id`` speaker may NOT be mapped onto; None if ``group_id`` is invalid.

        For an audio group this is itself plus every earlier audio group. For
        ``"vd"`` it is only ``"vd"``.
        """
        if group_id == cls._VIDEO_GROUP:
            return [cls._VIDEO_GROUP]
        if group_id in cls._AUDIO_GROUPS:
            return list(cls._AUDIO_GROUPS[: cls._AUDIO_GROUPS.index(group_id) + 1])
        return None

    @classmethod
    def _non_child_groups(cls, group_id):
        """Groups that cannot be children of ``group_id``; None if ``group_id`` is invalid.

        For an audio group this is itself plus every later audio group. For
        ``"vd"`` it is only ``"vd"``.
        """
        if group_id == cls._VIDEO_GROUP:
            return [cls._VIDEO_GROUP]
        if group_id in cls._AUDIO_GROUPS:
            return list(cls._AUDIO_GROUPS[cls._AUDIO_GROUPS.index(group_id):])
        return None

    def get_corresponding_speaker(self, present_group_id, target_group_id, speaker_id):
        """Map ``speaker_id`` of ``present_group_id`` onto the best-overlapping speaker of ``target_group_id``.

        Returns the target speaker id, ``"Unknown"`` if nothing overlaps, or
        None (after printing why) for an invalid or disallowed group pair or an
        unknown speaker.
        """
        groups_not_allowed = self._non_target_groups(present_group_id)
        if groups_not_allowed is None:
            print(f"Invalid present group id - {present_group_id}")
            return None

        if target_group_id in groups_not_allowed:
            print(f"For present_group {present_group_id}, target_group should not be among {groups_not_allowed}")
            return None

        if target_group_id not in self._ALL_GROUPS:
            print(f"Invalid target group id - {target_group_id}")
            return None

        present_group = self.get_group_by_id(present_group_id)
        if speaker_id not in present_group:
            print(f"Speaker_id {speaker_id} not in present group {present_group_id} keys - {present_group.keys()}")
            return None

        return self.get_mapping(present_group[speaker_id], self.get_group_by_id(target_group_id))

    def get_all_child_speakers(self, parent_group_id, child_group_id, parent_speaker_id):
        """List the ``child_group_id`` speakers whose corresponding parent speaker is ``parent_speaker_id``.

        Returns None (after printing why) for an invalid or disallowed group pair
        or an unknown parent speaker.
        """
        groups_not_allowed = self._non_child_groups(parent_group_id)
        if groups_not_allowed is None:
            print(f"Invalid parent group id - {parent_group_id}")
            return None

        if child_group_id in groups_not_allowed:
            print(f"For parent_group {parent_group_id}, child_group should not be among {groups_not_allowed}")
            return None

        if child_group_id not in self._ALL_GROUPS:
            print(f"Invalid child group id - {child_group_id}")
            return None

        parent_group = self.get_group_by_id(parent_group_id)
        if parent_speaker_id not in parent_group:
            print(
                f"parent_speaker_id {parent_speaker_id} not in parent_group {parent_group_id} keys - {parent_group.keys()}"
            )
            return None

        return [
            child_speaker_id
            for child_speaker_id in self.get_group_by_id(child_group_id)
            if self.get_corresponding_speaker(
                present_group_id=child_group_id, target_group_id=parent_group_id, speaker_id=child_speaker_id
            )
            == parent_speaker_id
        ]

    def get_native_speakers_in_parent_group(self, child_group_id, parent_group_id, child_speaker_id):
        """List every ``child_group_id`` speaker that shares ``child_speaker_id``'s parent speaker.

        The result includes ``child_speaker_id`` itself when it maps back to that
        parent. Returns None (after printing why) on invalid input or when no
        parent speaker can be resolved.
        """
        groups_not_allowed = self._non_target_groups(child_group_id)
        if groups_not_allowed is None:
            print(f"Invalid child group id - {child_group_id}")
            return None

        if parent_group_id in groups_not_allowed:
            print(f"For child_group {child_group_id}, parent_group should not be among {groups_not_allowed}")
            return None

        child_group = self.get_group_by_id(child_group_id)
        if child_speaker_id not in child_group:
            print(
                f"child_speaker_id {child_speaker_id} not in child_group {child_group_id} keys - {child_group.keys()}"
            )
            return None

        parent_speaker_id = self.get_corresponding_speaker(
            present_group_id=child_group_id, target_group_id=parent_group_id, speaker_id=child_speaker_id
        )
        if parent_speaker_id is None:
            return None

        return self.get_all_child_speakers(
            parent_group_id=parent_group_id, child_group_id=child_group_id, parent_speaker_id=parent_speaker_id
        )

    def get_audio_segments(self):
        """Build one PyannoteAudioSegment per ad_03 turn, sorted by start time.

        Each turn carries:
        - its ad_03 speaker;
        - that speaker's corresponding ad_06/ad_09/ad_12 speakers;
        - the video speaker overlapping the turn itself the most.

        Turns that overlap any other turn get ``has_overlap=True`` and video
        speaker ``"Unknown"``.
        """
        audio_segment_list = []
        segment_idx = 0
        for speaker_id, speech_segments in self.audio_diarization_03.items():
            if not speech_segments:
                continue
            # Coarser-group speakers depend only on the ad_03 speaker, so compute
            # them once per speaker instead of once per turn.
            coarser_speakers = {
                target_group_id: self.get_corresponding_speaker(
                    present_group_id="ad_03", target_group_id=target_group_id, speaker_id=speaker_id
                )
                for target_group_id in ("ad_06", "ad_09", "ad_12")
            }
            for speech_segment_start, speech_segment_end in speech_segments:
                audio_segment_list.append(
                    PyannoteAudioSegment(
                        segment_idx=segment_idx,
                        audio_segment_start=speech_segment_start,
                        audio_segment_end=speech_segment_end,
                        ad_03_speaker=speaker_id,
                        ad_06_speaker=coarser_speakers["ad_06"],
                        ad_09_speaker=coarser_speakers["ad_09"],
                        ad_12_speaker=coarser_speakers["ad_12"],
                        vd_speaker=self.get_mapping(
                            [(speech_segment_start, speech_segment_end)], self.video_diarization
                        ),
                        has_overlap=False,
                    )
                )
                segment_idx += 1

        audio_segment_list.sort()

        # Sorted by start: once a later turn starts at or after the current turn's
        # end, no subsequent turn can overlap the current one either.
        n_segments = len(audio_segment_list)
        for i in range(n_segments):
            current = audio_segment_list[i]
            for j in range(i + 1, n_segments):
                later = audio_segment_list[j]
                if later.audio_segment_start >= current.audio_segment_end:
                    break
                current.has_overlap = True
                later.has_overlap = True
                current.vd_speaker = "Unknown"
                later.vd_speaker = "Unknown"

        return audio_segment_list

    def perform_audio_video_mapping(self, audio_group_id):
        """Accumulate, per speaker of ``audio_group_id``, the duration assigned to each video speaker.

        Returns a MappingClass keyed by audio speaker, or None (with a message)
        for an unknown audio group id. MappingClass ignores ``"Unknown"`` video
        speakers.
        """
        if audio_group_id not in self._AUDIO_GROUPS:
            print(f"Invalid audio_group_id - {audio_group_id}")
            return None

        mapping_class = MappingClass()
        for audio_segment in self.audio_segments:
            mapping_class.add_mapping(
                speaker_key_id=audio_segment.get_speaker(audio_group_id),
                speaker_value_id=audio_segment.vd_speaker,
                audio_segment_start=audio_segment.audio_segment_start,
                audio_segment_end=audio_segment.audio_segment_end,
            )
        return mapping_class

    def predict_unknowns(self):
        """Resolve every ``"Unknown"`` video speaker from the ad_03 -> video mapping.

        A turn whose ad_03 speaker has a mapping takes the video speaker with the
        largest accumulated duration. That turn's duration is then added to the
        mapping, so later predictions see earlier ones. Otherwise the turn
        becomes ``"Unknown_<ad_03 speaker>"``.
        """
        # NOTE: an earlier version fell back to the ad_06/ad_09/ad_12 mappings
        # when ad_03 had none; that cascade is currently disabled.
        for audio_segment in self.audio_segments:
            if audio_segment.vd_speaker != "Unknown":
                continue
            vd_speaker = self.ad03_video_mapping.get_max_mapping(audio_segment.ad_03_speaker)
            if vd_speaker is None:
                audio_segment.vd_speaker = f"Unknown_{audio_segment.ad_03_speaker}"
                continue
            audio_segment.vd_speaker = vd_speaker
            self.ad03_video_mapping.add_mapping(
                speaker_key_id=audio_segment.ad_03_speaker,
                speaker_value_id=vd_speaker,
                audio_segment_start=audio_segment.audio_segment_start,
                audio_segment_end=audio_segment.audio_segment_end,
            )

    def get_diarization_output(self):
        """Return the fused result as ``{video speaker: [(start, end), ...]}`` in turn order."""
        diarize_output = {}
        for audio_segment in self.audio_segments:
            diarize_output.setdefault(audio_segment.vd_speaker, []).append(
                (audio_segment.audio_segment_start, audio_segment.audio_segment_end)
            )
        return diarize_output

    def get_group_by_id(self, group_id):
        """Return the diarization dict for ``group_id``, or None for an unknown id."""
        if group_id == "ad_03":
            return self.audio_diarization_03
        if group_id == "ad_06":
            return self.audio_diarization_06
        if group_id == "ad_09":
            return self.audio_diarization_09
        if group_id == "ad_12":
            return self.audio_diarization_12
        if group_id == "vd":
            return self.video_diarization
        return None

    @staticmethod
    def get_mapping(speaker_interval, target_group):
        """Return the ``target_group`` speaker overlapping ``speaker_interval`` the most.

        Overlap is measured as a percentage of ``speaker_interval``'s duration.
        Returns ``"Unknown"`` when no target speaker overlaps at all.
        """
        max_overlap = 0
        max_overlap_speaker = "Unknown"
        for speaker_id in target_group.keys():
            result_overlap, _ = find_overlap(
                speaker_interval,
                target_group[speaker_id],
            )
            if result_overlap > max_overlap:
                max_overlap = result_overlap
                max_overlap_speaker = speaker_id

        return max_overlap_speaker


def find_overlap(intervals1, intervals2):
    """Measure how much two collections of ``(start, end)`` intervals overlap.

    Args:
        intervals1: iterable of ``(start, end)`` pairs.
        intervals2: iterable of ``(start, end)`` pairs.

    Returns:
        ``(pct1, pct2)``: the summed pairwise overlap as a percentage of the
        total duration of ``intervals1`` and of ``intervals2`` respectively.
        A side whose total duration is zero (e.g. empty) yields ``0.0`` instead
        of raising ``ZeroDivisionError``.

    Overlap is summed over every pair, so intervals that overlap each other
    within one collection can be counted more than once.
    """
    # Materialise once: both collections are traversed more than once, which
    # would silently yield nothing on the second pass for a generator.
    intervals1 = list(intervals1)
    intervals2 = list(intervals2)

    total_duration1 = sum(end - start for start, end in intervals1)
    total_duration2 = sum(end - start for start, end in intervals2)

    overlap = 0
    for start2, end2 in intervals2:
        for start1, end1 in intervals1:
            common_start = max(start1, start2)
            common_end = min(end1, end2)
            if common_start < common_end:
                overlap += common_end - common_start

    percentage_overlap1 = (overlap / total_duration1) * 100 if total_duration1 else 0.0
    percentage_overlap2 = (overlap / total_duration2) * 100 if total_duration2 else 0.0

    return percentage_overlap1, percentage_overlap2


# Fuse the four audio runs (min_cluster_size 3/6/9/12) with the video diarization;
# arguments follow DiarizationOutput's positional order.
diarize_output = DiarizationOutput(ad03_1j20, ad06_1j20, ad09_1j20, ad12_1j20, vdo_1j20)
final_diarization_output = diarize_output.get_diarization_output()

In [None]:
# Fused audio-visual diarization as a pyannote Annotation.
pyannote_combined = convert_diarization_output_to_pyannote(final_diarization_output)
pyannote_combined

In [None]:
# Ground truth again, for side-by-side comparison with the fused output below.
gt_1j20

In [None]:
# Fused result: {video speaker or "Unknown_<ad_03 speaker>": [(start, end), ...]}.
final_diarization_output

In [None]:
# Ground-truth Annotation for visual comparison with `pyannote_combined`.
pyannote_gt

In [None]:
# DER of the fused audio-visual diarization against the ground truth.
metric(pyannote_gt, pyannote_combined)

In [None]:
# Ground truth of the second clip, 1j20qq1JyX4_c_01. Its offset is 899.993 s
# (the value used for this clip in every later cell), not 900.488 s, which
# belongs to 2qQs3Y9OJX0_c_01.
temp_ground_truth_rttm_file = '../../../AVA-AVD/dataset/rttms/1j20qq1JyX4_c_01.rttm'
temp_diarization = convert_rttm_to_diarization(temp_ground_truth_rttm_file, 899.993)
convert_diarization_output_to_pyannote(temp_diarization)

In [None]:
# ad_03 speakers whose best-matching video speaker is SPEAKER_00.
diarize_output.get_all_child_speakers(parent_group_id="vd", child_group_id="ad_03", parent_speaker_id="SPEAKER_00")

In [None]:
def create_annotation_plot(
    diarization_output,
    save_path,
    video_name,
    video_duration,
    plot_name="diarization",
    offset = 0
):
    """Render a diarization dict as a pyannote timeline and save it as a PNG.

    Args:
        diarization_output: ``{speaker: [(start, end), ...]}``.
        save_path: directory the PNG is written to.
        video_name: file-name prefix.
        video_duration: clip length; the x-axis spans
            ``[offset, offset + video_duration]``.
        plot_name: file-name suffix (default ``"diarization"``).
        offset: start of the x-axis, for timestamps not relative to the clip.

    Writes ``<save_path>/<video_name>_<plot_name>.png``.
    """
    custom_diarization = Annotation()
    seen_turns = set()

    for speaker_key, timelines in diarization_output.items():
        for timeline in timelines:
            turn = (timeline[0], timeline[1], speaker_key)
            if turn in seen_turns:
                continue  # exact duplicate of a turn already plotted
            seen_turns.add(turn)
            segment = Segment(timeline[0], timeline[1])
            # Give each speaker on an identical segment its own track; the
            # shorthand `annotation[segment] = label` always writes track "_"
            # and would silently keep only the last speaker.
            custom_diarization[segment, custom_diarization.new_track(segment, candidate="_")] = speaker_key

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        Notebook().plot_annotation(custom_diarization, ax, legend=True)

        ax.set_xlabel("Time")
        ax.set_yticks([])  # To hide the y-axis
        ax.set_xlim(offset, video_duration + offset)

        saveFileName = os.path.join(save_path, f"{video_name}_{plot_name}.png")
        fig.savefig(saveFileName, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails, so repeated calls
        # in a notebook do not accumulate open figures.
        plt.close(fig)
    
# def convert_diarization_output_to_pyannote(diarize_output, video_name, offset):
#     annotation = Annotation()
    
#     _diarization_output = diarize_output[video_name]
#     diarize_dict = {}
    
#     for speaker, timelines in _diarization_output.items():
#         for timeline_start, timeline_end in timelines:
#             annotation[Segment(timeline_start + offset, timeline_end + offset)] = speaker
#             if speaker not in diarize_dict.keys():
#                 diarize_dict[speaker] = [(timeline_start + offset, timeline_end + offset)]
#             else:
#                 diarize_dict[speaker].append((timeline_start + offset, timeline_end + offset))

#     return annotation, diarize_dict

# def _convert_diarization_output_to_pyannote(diarize_output):
#     annotation = Annotation()
    
#     for speaker, timelines in diarize_output.items():
#         for timeline_start, timeline_end in timelines:
#             annotation[Segment(timeline_start, timeline_end)] = speaker

#     return annotation

In [None]:
offset = 899.993
# Point at this clip's audio BEFORE measuring its duration; previously the
# duration was read from the previous clip's `audio_path`.
audio_path = "../../output/video_temp/1j20qq1JyX4_c_01/pywav/audio.wav"
videoDuration = len(AudioSegment.from_file(audio_path)) / 1000  # pydub length is in ms


def diarize_shifted(pipeline, wav_path, min_cluster_size, time_offset):
    """Diarize ``wav_path`` with ``min_cluster_size`` and shift every turn by ``time_offset`` seconds.

    Re-instantiates the shared ``pipeline`` in place, so the clustering setting
    persists for later calls. Returns ``{speaker: [(start, end), ...]}``.
    """
    pipeline.instantiate({"clustering": {"min_cluster_size": min_cluster_size}})
    shifted = {}
    for turn, _, speaker_key in pipeline(wav_path).itertracks(yield_label=True):
        shifted.setdefault(speaker_key, []).append((turn.start + time_offset, turn.end + time_offset))
    return shifted


ado_1j20_03 = diarize_shifted(pretrained_pipeline, audio_path, 3, offset)
ado_1j20_06 = diarize_shifted(pretrained_pipeline, audio_path, 6, offset)
ado_1j20_09 = diarize_shifted(pretrained_pipeline, audio_path, 9, offset)
ado_1j20_12 = diarize_shifted(pretrained_pipeline, audio_path, 12, offset)

In [None]:
# Extra run with min_cluster_size=1 (stored under the `_02` name), shifted by
# `offset` like the other `ado_1j20_*` dicts, then plotted.
pretrained_pipeline.instantiate({"clustering": {"min_cluster_size": 1}})
pretrained_pipeline.parameters(instantiated=True)
diarization = pretrained_pipeline(audio_path)
ado_1j20_02 = {}
for turn, _track, speaker in diarization.itertracks(yield_label=True):
    ado_1j20_02.setdefault(speaker, []).append((turn.start + offset, turn.end + offset))
create_annotation_plot(ado_1j20_02, ".", "1j20_02", videoDuration, offset=offset)

In [None]:
# One timeline PNG per audio diarization run.
create_annotation_plot(diarization_output=ado_1j20_03, save_path=".", video_name="1j20_03", video_duration=videoDuration, offset=offset)
create_annotation_plot(diarization_output=ado_1j20_06, save_path=".", video_name="1j20_06", video_duration=videoDuration, offset=offset)
create_annotation_plot(diarization_output=ado_1j20_09, save_path=".", video_name="1j20_09", video_duration=videoDuration, offset=offset)
create_annotation_plot(diarization_output=ado_1j20_12, save_path=".", video_name="1j20_12", video_duration=videoDuration, offset=offset)

In [None]:
# Presumably {video_name: {speaker: [(start, end), ...]}}, as the former
# 3-argument converter assumed - confirm against the producer.
# pickle executes code from the file: only load trusted, locally produced artifacts.
with open("../../output/run_output/video_diarization_output_AVA_AVD.pckl", "rb") as fh:
    vdo = pickle.load(fh)
# Shift this clip's turns by `offset` (same time base as the `ado_1j20_*` dicts)
# and build the Annotation with the active single-argument converter; the
# 3-argument version is commented out.
vdo_1j20 = {
    speaker: [(start + offset, end + offset) for start, end in timelines]
    for speaker, timelines in vdo["1j20qq1JyX4_c_01"].items()
}
pyannote_vdo = convert_diarization_output_to_pyannote(vdo_1j20)

In [None]:
# Timeline PNG of the video diarization.
create_annotation_plot(diarization_output=vdo_1j20, save_path=".", video_name="vdo_1j20", video_duration=videoDuration, offset=offset)

In [None]:
# Re-run min_cluster_size=6 explicitly. Without this, the shared pipeline keeps
# whatever clustering setting the last executed cell instantiated (1 above).
# The turns must also be shifted by `offset` like every other `ado_1j20_*` dict,
# otherwise they will not line up with `vdo_1j20` when fused below.
pretrained_pipeline.instantiate({"clustering": {"min_cluster_size": 6}})
diarization = pretrained_pipeline(audio_path)
ado_1j20_06 = {}
for duration, _, speaker_key in diarization.itertracks(yield_label=True):
    ado_1j20_06.setdefault(speaker_key, []).append((duration.start + offset, duration.end + offset))

In [None]:
# Inspect the video diarization dict {speaker: [(start, end), ...]}.
vdo_1j20

In [None]:
# Speaker labels produced with min_cluster_size=9.
ado_1j20_09.keys()

In [None]:
# Video diarization dict, for comparison with the audio speaker labels above.
vdo_1j20

In [None]:
# Clip duration in seconds (pydub length in ms / 1000).
videoDuration

In [None]:
# `va_mapping` is never assigned in this notebook (left over from a deleted
# cell), so the bare name raised NameError on Restart & Run All. The
# audio/video speaker mappings are computed further down as `au_vi_mapping` /
# `vi_au_mapping`; show `va_mapping` only if an interactive session still has it.
globals().get("va_mapping")

In [None]:
# Video diarization dict (repeated inspection).
vdo_1j20

In [None]:
# `_convert_diarization_output_to_pyannote` is commented out above (NameError);
# the active `convert_diarization_output_to_pyannote` takes the same single
# dict argument and builds the same Annotation.
pyanote_03 = convert_diarization_output_to_pyannote(ado_1j20_03)
pyanote_06 = convert_diarization_output_to_pyannote(ado_1j20_06)
pyanote_09 = convert_diarization_output_to_pyannote(ado_1j20_09)
pyanote_12 = convert_diarization_output_to_pyannote(ado_1j20_12)

In [None]:
# The active converter takes a single argument (the 2-argument call raised
# TypeError). `vdo_1j20` from the pickle-loading cell is already shifted by
# `offset`, so no further offset is applied here.
vdo_1j20_annotation = convert_diarization_output_to_pyannote(vdo_1j20)

In [None]:
def convert_rttm_to_pyannote(rttm_file, offset):
    """Read an RTTM file into a pyannote ``Annotation`` and a speaker dict.

    Args:
        rttm_file: path to an RTTM file; fields may be separated by any run of
            whitespace.
        offset: accepted for call-site symmetry with ``convert_rttm_to_diarization``
            but NOT applied. Timestamps are returned exactly as in the file,
            unlike ``convert_rttm_to_diarization``, which subtracts it.

    Returns:
        ``(annotation, diarize_dict)``. ``diarize_dict`` is
        ``{speaker_label: [(start, end), ...]}`` in file order. In the
        annotation, each distinct turn is its own track, so speakers sharing an
        identical segment are all kept.
    """
    # RTTM columns: type, file id, channel, onset, duration, <NA>, <NA>, speaker, <NA>, <NA>.
    rttm_df = pd.read_csv(rttm_file, sep=r"\s+", header=None,
                          names=['temp', 'file_name', 'channel', 'start', 'duration', 'NA_1', 'NA_2', 'speaker_label', 'NA_3', 'NA_4'])

    annotation = Annotation()
    diarize_dict = {}
    seen_turns = set()

    for _, row in rttm_df.iterrows():
        start_time = row['start']
        end_time = start_time + row['duration']
        label = row['speaker_label']
        diarize_dict.setdefault(label, []).append((start_time, end_time))

        turn = (start_time, end_time, label)
        if turn in seen_turns:
            continue  # exact duplicate of a turn already in the annotation
        seen_turns.add(turn)
        segment = Segment(start_time, end_time)
        # `annotation[segment] = label` always writes track "_", so a second
        # speaker on an identical segment would replace the first.
        annotation[segment, annotation.new_track(segment, candidate="_")] = label

    return annotation, diarize_dict

In [None]:
# Ground truth for 1j20qq1JyX4_c_01 (timestamps kept exactly as in the RTTM).
rttm_file = '../../../AVA-AVD/dataset/rttms/1j20qq1JyX4_c_01.rttm'
ground_truth, diarize_dict = convert_rttm_to_pyannote(rttm_file=rttm_file, offset=899.993)

In [None]:
# Timeline PNG of the ground truth.
create_annotation_plot(diarization_output=diarize_dict, save_path=".", video_name="1j20_gt", video_duration=videoDuration, offset=offset)

In [None]:
# Export the video-only diarization Annotation in RTTM format.
with open('1j20_pyannote_vdo.rttm', 'w') as rttm_out:
    pyannote_vdo.write_rttm(rttm_out)

In [None]:
# Ground-truth turns per speaker, with timestamps as they appear in the RTTM.
diarize_dict

In [None]:
# `ado` is never assigned in this notebook (stale name from a deleted cell),
# so the bare name raised NameError on Restart & Run All. The audio runs live
# in `ado_1j20_02` ... `ado_1j20_12`; show `ado` only if a live session has it.
globals().get("ado")

In [None]:
# Prediction RTTM for 1j20qq1JyX4_c_01 from AVA-AVD/save/token/avaavd (timestamps as in the file).
rttm_file = '../../../AVA-AVD/save/token/avaavd/rttms/1j20qq1JyX4_c_01.rttm'
predicted, predicted_dict = convert_rttm_to_pyannote(rttm_file=rttm_file, offset=899.993)

In [None]:
# Timeline PNG of the RTTM prediction.
create_annotation_plot(diarization_output=predicted_dict, save_path=".", video_name="1j20_predicted", video_duration=videoDuration, offset=offset)

In [None]:
from pyannote.metrics.diarization import DiarizationErrorRate
# Fresh DER metric (discards components accumulated by earlier calls), with the
# library defaults spelled out: no collar, overlapped speech scored.
metric = DiarizationErrorRate(collar=0.0, skip_overlap=False)

In [None]:
# DER of the RTTM prediction loaded from AVA-AVD/save/... against the ground truth.
metric(ground_truth, predicted)

In [None]:
# `pyanote_1` was never defined (NameError). Score every audio-only run
# instead. The `ado_1j20_*` dicts were shifted by `offset`, presumably onto the
# RTTM's absolute time base that `ground_truth` keeps.
{
    f"min_cluster_size={cluster_size}": metric(ground_truth, convert_diarization_output_to_pyannote(audio_run))
    for cluster_size, audio_run in ((3, ado_1j20_03), (6, ado_1j20_06), (9, ado_1j20_09), (12, ado_1j20_12))
}

In [None]:
# Audio diarization with min_cluster_size=3, turns shifted by `offset`.
ado_1j20_03

In [None]:
def get_mapping(diarization_result_1, diarization_result_2, min_overlap_percent=10):
    """Map every speaker of ``diarization_result_2`` onto a speaker of ``diarization_result_1``.

    Args:
        diarization_result_1: ``{speaker: [(start, end), ...]}`` candidates.
        diarization_result_2: ``{speaker: [(start, end), ...]}`` speakers to map.
        min_overlap_percent: minimum overlap score (percent) required to accept
            a match; defaults to the previously hard-coded 10.

    Returns:
        ``{speaker_2: speaker_1}``. Speakers with no candidate reaching
        ``min_overlap_percent`` get ``"Unknown_<n>"`` with a running ``n``.
    """
    result_mapping = {}
    unknown_speaker_count = 0
    for speaker_result_2, intervals_2 in diarization_result_2.items():
        max_overlap = 0
        max_overlap_speaker = None
        for speaker_result_1, intervals_1 in diarization_result_1.items():
            # NOTE(review): the score is find_overlap's first value, i.e. the
            # overlap as a share of the *candidate's* (speaker_1's) duration,
            # unlike DiarizationOutput.get_mapping, which normalises by the
            # speaker being mapped. Confirm this is intended.
            result_overlap, _ = find_overlap(intervals_1, intervals_2)
            if result_overlap > max_overlap:
                max_overlap = result_overlap
                max_overlap_speaker = speaker_result_1

        if max_overlap_speaker is None or max_overlap < min_overlap_percent:
            result_mapping[speaker_result_2] = f"Unknown_{unknown_speaker_count}"
            unknown_speaker_count += 1
        else:
            result_mapping[speaker_result_2] = max_overlap_speaker

    return result_mapping

def get_audio_video_mapping(final_video_output, final_audio_output):
    """Return ``(audio -> video, video -> audio)`` speaker mappings.

    Both arguments are ``{speaker: [(start, end), ...]}``. Note the order:
    the VIDEO diarization comes first.
    """
    audio_to_video = get_mapping(final_video_output, final_audio_output)
    video_to_audio = get_mapping(final_audio_output, final_video_output)
    return audio_to_video, video_to_audio

In [None]:
# get_audio_video_mapping takes (final_video_output, final_audio_output). The
# positional call passed the audio run first, which swapped both mappings, so
# pass them by keyword.
au_vi_mapping, vi_au_mapping = get_audio_video_mapping(
    final_video_output=vdo_1j20, final_audio_output=ado_1j20_03
)

In [None]:
# First mapping returned by get_audio_video_mapping (intended: audio speaker ->
# video speaker; check the argument order in the cell that computes it).
au_vi_mapping

In [None]:
# Second mapping returned by get_audio_video_mapping (intended: video speaker -> audio speaker).
vi_au_mapping

In [None]:
# Video diarization dict, to cross-check the mappings above.
vdo_1j20

In [None]:
# Fuse the four 1j20 audio runs with the video diarization (positional order:
# min_cluster_size 3, 6, 9, 12, then video).
dia_output = DiarizationOutput(ado_1j20_03, ado_1j20_06, ado_1j20_09, ado_1j20_12, vdo_1j20)

In [None]:
# Which ad_12 speaker does ad_09's SPEAKER_04 map onto? (present, target, speaker)
dia_output.get_corresponding_speaker("ad_09", "ad_12", "SPEAKER_04")

In [None]:
# ad_09 speakers that map onto ad_12's SPEAKER_03. (parent, child, parent speaker)
dia_output.get_all_child_speakers("ad_12", "ad_09", "SPEAKER_03")

In [None]:
# ad_09 speakers sharing SPEAKER_06's ad_12 parent. (child, parent, child speaker)
dia_output.get_native_speakers_in_parent_group("ad_09", "ad_12", "SPEAKER_06")