In [1]:
import getpass
import re

import pandas as pd
import soundfile as sf
import torch
from IPython.display import Audio
from pyannote.audio import Pipeline



In [3]:
# Load the dataset CSV (its columns include path, transcript, domain and all_timestamps).
dataset = pd.read_csv('diarized_tag_afrispeech_dialog_v1.1_49.csv')

In [4]:
# Prompt for the Hugging Face token interactively so it is never hard-coded in the notebook.
HUGGINGFACE_ACCESS_TOKEN = getpass.getpass(prompt="Enter huggingface access token: ")

Enter huggingface access token: ········


In [5]:
def merge_consecutive_segments(diarization):
    """Collapse runs of consecutive turns by the same speaker into single segments.

    Args:
        diarization: object exposing ``itertracks(yield_label=True)`` that yields
            ``(turn, track, speaker)`` triples, where ``turn`` has ``start`` and
            ``end`` attributes (e.g. the output of the diarization pipeline).

    Returns:
        list[tuple]: ``(start, end, speaker)`` tuples, in the order the turns
        were yielded. A new tuple starts whenever the speaker label changes.
        Empty input yields an empty list.
    """
    merged_segments = []
    last_start, last_end, last_speaker = None, None, None

    for turn, _, speaker in diarization.itertracks(yield_label=True):
        if last_speaker is not None and speaker == last_speaker:
            # Same speaker keeps talking: grow the running segment. max() guards
            # against a turn that ends before the running segment does (a turn
            # nested inside the previous one), which must not shrink it.
            last_end = max(last_end, turn.end)
            continue

        # Speaker changed: flush the finished segment (if any) and start a new one.
        if last_speaker is not None:
            merged_segments.append((last_start, last_end, last_speaker))
        last_start, last_end, last_speaker = turn.start, turn.end, speaker

    # Flush the segment still being accumulated when the loop ends.
    if last_speaker is not None:
        merged_segments.append((last_start, last_end, last_speaker))

    return merged_segments

In [6]:
# Load the pretrained speaker-diarization pipeline using the token entered above.
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=HUGGINGFACE_ACCESS_TOKEN)
# NOTE(review): from_pretrained can hand back None instead of raising when the
# checkpoint cannot be fetched (bad token / model terms not accepted) — confirm
# against the installed pyannote.audio version. Fail with a clear message here
# rather than an AttributeError on the next line.
if pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'; check the access token "
        "and that the model's user conditions have been accepted."
    )
# Use the GPU when one is available; otherwise fall back to CPU instead of crashing.
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))


In [8]:
# Smoke test on one recording: diarize it, then merge consecutive same-speaker turns.
diarization = pipeline("/data3/mardhiyah/AfriSpeech-Dialog/data/medical/247554f8-f233-4861-bc1a-8fc327b5d5df_2b500b633e5d5ecce35433cbbb859ddc_8bW4oSXn.wav")
pred_segments = merge_consecutive_segments(diarization)

In [9]:
# Inspect the merged (start, end, speaker) tuples for the sample recording.
pred_segments

[(0.70596875, 7.06784375, 'SPEAKER_01'),
 (7.06784375, 9.632843750000003, 'SPEAKER_00'),
 (9.66659375, 10.392218750000001, 'SPEAKER_01'),
 (10.45971875, 12.214718750000003, 'SPEAKER_00'),
 (12.29909375, 16.90596875, 'SPEAKER_01'),
 (17.17596875, 19.90971875, 'SPEAKER_00'),
 (19.90971875, 20.550968750000003, 'SPEAKER_01'),
 (20.027843750000002, 25.191593750000003, 'SPEAKER_00'),
 (25.44471875, 32.17784375, 'SPEAKER_01'),
 (32.59971875, 39.41721875, 'SPEAKER_00'),
 (39.41721875, 41.57721875, 'SPEAKER_01'),
 (41.13846875, 42.64034375, 'SPEAKER_00'),
 (42.336593750000006, 54.250343750000006, 'SPEAKER_01'),
 (54.55409375, 60.35909375000001, 'SPEAKER_00'),
 (60.51096875, 67.73346875, 'SPEAKER_01'),
 (68.77971875, 71.54721875, 'SPEAKER_00'),
 (71.54721875, 72.45846875000001, 'SPEAKER_01'),
 (72.03659375000001, 73.11659375, 'SPEAKER_00'),
 (73.74096875000001, 78.21284375, 'SPEAKER_01'),
 (78.75284375000001, 86.48159375, 'SPEAKER_00'),
 (80.55846875, 91.79721875, 'SPEAKER_01'),
 (92.59034375, 9

In [11]:
def get_audio_duration(file_path):
    """Return the duration of an audio file in seconds.

    Only the file's metadata is read (``sf.info``) rather than decoding every
    sample into memory; the value is still ``frames / samplerate``.

    Args:
        file_path: path to an audio file readable by ``soundfile``.

    Returns:
        float: number of frames divided by the sample rate.
    """
    info = sf.info(file_path)
    return info.frames / info.samplerate


In [12]:
# Duration, in seconds, of the sample recording diarized above.
get_audio_duration('/data3/mardhiyah/AfriSpeech-Dialog/data/medical/247554f8-f233-4861-bc1a-8fc327b5d5df_2b500b633e5d5ecce35433cbbb859ddc_8bW4oSXn.wav')

160.57199546485262

In [14]:
# Path-prefixing step, left disabled.
# NOTE(review): the prefix is concatenated without a '/' separator, which would explain
# paths such as '.../AfriSpeech-Dialogdata/...' that are repaired later — confirm.
# dataset['path'] = dataset['path'].apply(lambda x: '/data3/mardhiyah/AfriSpeech-Dialog'+x)
dataset.head()

Unnamed: 0,audio_id,path,transcript,age_group,gender,accent,domain,country,audio_duration,doctor_summary,all_timestamps
0,,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,"Speaker 1: My name is justina, I’m a medical d...",26-40,Male,Hausa,OSCE-Doctor-Patient,NG,240.0,Joel Solomon is a 45-year-old male who has bee...,False
1,f9bd30b2-684d-470e-b65c-ced156da0bbb,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,"\r\n\r\n00:04:08\r\n[Speaker 1]: Hi, I am Dr. ...",26-40,Female,Yoruba,OSCE-Doctor-Patient,NG,235.932993,"Habemuko, a 35 year old, presents with rectal ...",False
2,78fb7265-0373-49d5-b933-c28c95cf3eea,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,00:01:06\r\n[Speaker 1]: Are you there?\r\n00:...,26-40,Male,Yoruba,OSCE-Doctor-Patient,NG,906.433991,"Mr. Ade, a 35 year old unmarried Yoruba male t...",True
3,,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,[Speaker 1]: My name is dorathy and i am conse...,26-40,Male,Hausa,OSCE-Doctor-Patient,NG,240.0,"Chidinma Okunmi, a 25 year old Igbo student, p...",False
4,,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,[Speaker 1]: Hello. How are you?\r\n[Speaker 2...,26-40,Male,Hausa,OSCE-Doctor-Patient,NG,240.0,Mr. Emeka is a 65 year old male who presents w...,False


In [31]:
# Keep only rows flagged all_timestamps == True; reset_index() stores the original
# row labels in a new 'index' column.
diarize_dataset = dataset[dataset['all_timestamps']==True].reset_index()

In [17]:
# len(diarize_dataset)

In [37]:
from tqdm import tqdm

# Placeholder column; each cell is overwritten below with that file's merged segments.
diarize_dataset['pred_segments'] = ''
audio_paths = diarize_dataset['path']
for row_idx, audio_path in tqdm(enumerate(audio_paths), total=len(audio_paths), desc="Processing"):
    # Diarize one recording and store its merged (start, end, speaker) tuples on the row.
    diarize_dataset.at[row_idx, 'pred_segments'] = merge_consecutive_segments(pipeline(audio_path))


Processing:  30%|████████▍                   | 10/33 [30:48<1:10:51, 184.83s/it]


ValueError: File /data3/mardhiyah/AfriSpeech-Dialogdata/non-medical/30a5b6a2-8370-4dac-bc7f-e4ae7c1d8562_a4a6daf501c5da95d8f941189e8b6856_NowwEqPI.wav does not exist

In [41]:
# Select the rows whose domain is NOT 'OSCE-Doctor-Patient'.
# NOTE(review): the variable name says "medical", but the filter keeps the non-OSCE rows — confirm the intended naming.
medical_diarized_ds = diarize_dataset[diarize_dataset['domain']!='OSCE-Doctor-Patient'].reset_index()
# Repair paths that are missing the '/' between 'AfriSpeech-Dialog' and 'data'.
medical_diarized_ds['path'] = medical_diarized_ds['path'].apply(lambda x: str(x).replace('Dialogdata', 'Dialog/data'))

In [42]:
# Placeholder column; each cell is overwritten below with that file's merged segments.
medical_diarized_ds['pred_segments'] = ''
subset_paths = medical_diarized_ds['path']
for row_idx, audio_path in tqdm(enumerate(subset_paths), total=len(subset_paths), desc="Processing"):
    # Diarize one recording (using the repaired path) and store its merged segments on the row.
    medical_diarized_ds.at[row_idx, 'pred_segments'] = merge_consecutive_segments(pipeline(audio_path))

Processing: 100%|████████████████████████████| 23/23 [2:03:16<00:00, 321.60s/it]


In [43]:
# Rows whose domain is 'OSCE-Doctor-Patient' (the complement of the non-OSCE subset).
osce_diarized_ds = diarize_dataset[diarize_dataset['domain']=='OSCE-Doctor-Patient'].reset_index()

In [44]:
# Join the two subsets back together vertically (OSCE rows first).
# ignore_index=True already gives the result a fresh 0..n-1 index, so no further
# reset_index() is needed; calling it here would try to insert yet another index
# column and fail when both 'index' and 'level_0' columns already exist.
# 'level_0' is only the helper column left by the per-subset reset_index() calls,
# so drop it if present.
result = (
    pd.concat([osce_diarized_ds, medical_diarized_ds], axis=0, ignore_index=True)
    .drop(columns=['level_0'], errors='ignore')
)

In [54]:
# result.drop(columns=['level_0'], inplace=True)

In [55]:
# Rows 1, 26 and 28 of `result` are the recordings with incomplete diarization; confirm
# that (e.g. with the print below) and then remove them.
# NOTE(review): these hard-coded labels depend on the row order of `result` — re-check if upstream filtering changes.
# print(result.iloc[28]['transcript'])
final_diarized_ds = result.drop(index=[1,26,28]).reset_index(drop=True)

In [72]:
def convert_time_to_seconds(timestamp):
    """Convert a colon-separated ``MM:SS:mmm`` timestamp string to seconds.

    The three fields are read as minutes, seconds and milliseconds.
    NOTE(review): the transcripts carry a two-digit third field (e.g. ``00:01:06``);
    dividing it by 1000 may not match the annotators' unit — confirm.

    Args:
        timestamp: string with exactly three ``:``-separated numeric fields.

    Returns:
        float: ``minutes * 60 + seconds + milliseconds / 1000``.

    Raises:
        ValueError: if there are not exactly three fields or a field is not numeric.
    """
    parts = [float(field) for field in timestamp.split(':')]
    minutes, seconds, milliseconds = parts
    return minutes * 60 + seconds + milliseconds / 1000


def extract_segments(transcript):
    # Regular expression to match the timestamp and speaker tag
    timestamp_pattern = r'(\d{2}:\d{2}:\d{2})'
    speaker_pattern = r'\[([^\]]+)\]'
    
    lines = transcript.strip().splitlines()
    segments = []
    
    start_time = None
    speaker_tag = None
    
    for i in range(len(lines)):
        if re.match(timestamp_pattern, lines[i]):  # Line is a timestamp
            if start_time and speaker_tag:
                # If we have both start and speaker, the current timestamp is the end time
                end_time = convert_time_to_seconds(lines[i])
                segments.append((start_time, end_time, speaker_tag))
                start_time = None
                speaker_tag = None
            # Set the new start time, converting to seconds
            start_time = convert_time_to_seconds(lines[i])
        elif re.match(speaker_pattern, lines[i]):  # Line contains a speaker tag
            speaker_tag = re.findall(speaker_pattern, lines[i])[0]
    
    return segments


In [78]:
# Optional pre-step (disabled): force a line break before every speaker tag.
# final_diarized_ds['transcript'] = final_diarized_ds['transcript'].apply(lambda x: str(x).replace('[', '\r\n['))
# Parse every transcript into its reference (start, end, speaker) segments.
final_diarized_ds['ref_segments'] = final_diarized_ds['transcript'].apply(extract_segments)


In [82]:
# Display the frame now holding both predicted and reference segments.
final_diarized_ds

Unnamed: 0,index,audio_id,path,transcript,age_group,gender,accent,domain,country,audio_duration,doctor_summary,all_timestamps,pred_segments,ref_segments
0,2,78fb7265-0373-49d5-b933-c28c95cf3eea,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,00:01:06\r\n\r\n[Speaker 1]: Are you there?\r\...,26-40,Male,Yoruba,OSCE-Doctor-Patient,NG,906.433991,"Mr. Ade, a 35 year old unmarried Yoruba male t...",True,"[(0.03096875, 1.44846875, SPEAKER_01), (3.9459...","[(1.006, 1.027, Speaker 1), (4.0, 8.038, Speak..."
1,6,6d5f1e84-dbc5-42d3-898e-bfe9c2cfe166,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,"00:00:06\r\n\r\n[Speaker 1]: Good evening, I'm...",26-40,Female,Urhobo,OSCE-Doctor-Patient,NG,204.622993,Johnny Dobra is a 35-year-old male who present...,True,"[(0.03096875, 2.61284375, SPEAKER_00), (2.6128...","[(0.006, 2.028, Speaker 1), (2.049, 6.009, Spe..."
2,11,ab8fd39d-52bc-48fa-bae9-ff636627f1e5,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,00:00:08\r\n\r\n[Speaker 1]: My name is Nwachu...,26-40,Female,Urhobo,OSCE-Doctor-Patient,NG,343.26,"The patient, a 45 year old female accontant, p...",True,"[(0.03096875, 2.7478437500000004, SPEAKER_01),...","[(0.008, 2.04, Speaker 1), (2.049, 5.023, Spea..."
3,12,5129fd8c-7b8c-4d05-a03a-196bcae4deff,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,00:00:50\r\n\r\n[Speaker 1]: My name is Chiama...,26-40,Female,Urhobo,OSCE-Doctor-Patient,NG,337.587982,"John Peter, a 45-year-old construction worker,...",True,"[(0.7565937500000001, 3.87846875, SPEAKER_01),...","[(0.05, 3.044, Speaker 1), (4.016, 6.035, Spea..."
4,13,58efe9ef-6c76-4d17-b1b2-f397552a0c0e,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,00:01:13\r\n\r\n[Speaker 1]: Good day. I am Do...,26-40,Female,Urhobo,OSCE-Doctor-Patient,NG,213.635986,"Peter, a 45-year-old male, presented to the ho...",True,"[(1.0434687500000002, 32.025968750000004, SPEA...","[(1.013, 5.006, Speaker 1), (5.028, 6.028, Spe..."
5,16,b6f48ed7-e629-4e50-af63-1c9067c34ef2,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,"00:00:35\r\n\r\n[Speaker 1]: Hello, how are yo...",26-40,Female,Urhobo,OSCE-Doctor-Patient,NG,220.120998,"The patient, Chike, is a 35-year-old male pres...",True,"[(0.03096875, 219.11909375000002, SPEAKER_00)]","[(0.035, 1.022, Speaker 1), (2.016, 2.053, Spe..."
6,17,b9ffbd7f-d6a5-4511-a7aa-b1e0b52dfc03,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,"00:00:08\r\n\r\n[Speaker 1]: Hi, my name is Dr...",26-40,Female,Urhobo,OSCE-Doctor-Patient,NG,359.890998,"\r\nMrs. Oge, a 26-year-old woman, presented w...",True,"[(0.03096875, 4.334093750000001, SPEAKER_01), ...","[(0.008, 4.011, Speaker 1), (5.01, 9.045, Spea..."
7,18,f2905988-70f5-42a3-86db-19066b25bc48,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,00:05:34\r\n\r\n[Speaker 1]: Good afternoon. M...,26-40,Male,Yoruba,OSCE-Doctor-Patient,NG,240.36898,"Mrs. Ayike, a 45-year-old female, presented wi...",True,"[(3.9797187500000004, 11.82659375, SPEAKER_01)...","[(5.034, 11.059, Speaker 1), (12.042, 15.038, ..."
8,19,9d1962f3-e879-452d-b57b-cf4a7f133cd8,/data3/mardhiyah/AfriSpeech-Dialog/data/medica...,00:00:04\r\n\r\n[Speaker 1]: my name is Precio...,26-40,Female,Urhobo,OSCE-Doctor-Patient,NG,444.86,"Gift Joseph, a 32-year-old Igala Christian far...",True,"[(0.03096875, 15.454718750000001, SPEAKER_00),...","[(0.004, 3.025, Speaker 1), (3.037, 6.025, Spe..."
9,20,a71d3ac2-bd1f-4a67-bbff-c9d8d678ae10,/data3/mardhiyah/AfriSpeech-Dialog/data/non-me...,00:00:52\r\n\r\n[Speaker 1]: My name is Abulor...,26-40,Female,Urhobo,Chit-Chat-NG,NG,1205.075986,Not Applicable,True,"[(0.8072187500000001, 4.11471875, SPEAKER_00),...","[(0.052, 4.006, Speaker 1), (4.052, 9.022, Spe..."


Calculate Metrics

In [83]:
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate

def create_pyannote_annotation(segments_list):
    """Build an ``Annotation`` from ``(start, end, speaker)`` tuples.

    Args:
        segments_list: iterable of ``(start, end, speaker_tag)`` triples.

    Returns:
        Annotation: each ``Segment(start, end)`` key is assigned its speaker tag.
    """
    result = Annotation()
    for seg_start, seg_end, label in segments_list:
        result[Segment(seg_start, seg_end)] = label
    return result

# One metric instance, reused for every file and later queried with abs(der_metric)
# for the dataset-level value.
der_metric = DiarizationErrorRate()

In [85]:
# Score each recording: build reference and predicted annotations and print the file's DER.
segment_pairs = zip(final_diarized_ds['ref_segments'], final_diarized_ds['pred_segments'])
for ref_segs, pred_segs in tqdm(segment_pairs, total=len(final_diarized_ds['transcript']), desc="Processing"):
    file_der = der_metric(
        create_pyannote_annotation(ref_segs),
        create_pyannote_annotation(pred_segs),
    )
    print(f"DER: {100 * file_der:.2f}%")
# abs(der_metric) gives the value reported for the whole dataset
ds_der = abs(der_metric)
print(f"Absolute DER for dataset: {100 * ds_der:.2f}%")

Processing:   0%|                                        | 0/30 [00:00<?, ?it/s]

DER: 22.31%
DER: 41.62%


Processing:   7%|██▏                             | 2/30 [00:00<00:01, 17.78it/s]

DER: 29.10%
DER: 46.78%
DER: 44.84%


Processing:  17%|█████▎                          | 5/30 [00:00<00:01, 23.89it/s]

DER: 40.92%
DER: 16.47%
DER: 21.34%
DER: 40.81%
DER: 5.27%
DER: 33.09%


Processing:  37%|███████████▎                   | 11/30 [00:00<00:00, 37.34it/s]

DER: 48.95%
DER: 4.27%
DER: 8.38%
DER: 7.48%
DER: 44.95%
DER: 7.35%
DER: 15.80%
DER: 27.63%
DER: 1.47%
DER: 10.45%


Processing:  70%|█████████████████████▋         | 21/30 [00:00<00:00, 60.43it/s]

DER: 51.84%
DER: 2.56%
DER: 5.71%
DER: 23.69%
DER: 1.95%
DER: 7.51%
DER: 37.14%
DER: 3.25%
DER: 48.28%


Processing: 100%|███████████████████████████████| 30/30 [00:00<00:00, 56.13it/s]


Absolute DER for dataset: 21.30%


In [86]:
# Save predictions and reference segments; the file name records the dataset DER (0.2130)
# and the number of recordings scored (30).
final_diarized_ds.to_csv('pyannote_diarization_der_0.2130_30.csv', index=False)