In [1]:
from pyannote.audio import Pipeline
from pyannote.audio import Model
from pyannote.audio import Inference
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cdist
import pandas as pd
import wave
from scipy.io import wavfile
from tqdm import tqdm
import numpy as np
import torchaudio.transforms as T
from openai import OpenAI
import ast
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, accuracy_score
from utils import *
import json

  from .autonotebook import tqdm as notebook_tqdm


# 1. GPT-4o on Fixed-Window Transcripts

In [51]:
rolling_fragments_pred_paths = {
    case_id: f'results/rolling_fragments/predictions/LFB{case_id}/LFB{case_id} model=TextModel vad_threshold=0.3 fe-threshold=0.csv' for case_id in [1, 2, 9, 10, 18]
}
rolling_fragments_dfs = {
    case_id: pd.read_csv(path, index_col='fragment_id') for case_id, path in rolling_fragments_pred_paths.items()
}
openai_key_path = 'openai_api_key.txt'

def get_prompts(case_id, fragment_id):
    transcription = rolling_fragments_dfs[case_id].loc[fragment_id, 'transcription']
    
    system_prompt = """
You are a binary classifier that determines whether a given phrase contains delivery of feedback from a trainer to a trainee where the trainee is conducting urology surgery using the da Vinci robot. The dialogue is between two speakers, a trainer and a trainee. There can be 6 types of feedback:
                 
1. Anatomic: familiarity with anatomic structures and landmarks. i.e. 'Stay in the correct plane, between the 2 fascial layers.'
2. Procedura: pertains to timing and sequence of surgical steps. i.e. 'You can switch to the left side now.'
3. Technical: performnace of a discrete task with appropriate knowledge of factors including exposure, instruments, and traction. i.e. 'Buzz it.'
4. Praise: a positive remark. i.e. 'Good job.'
5. Criticism: a negative remark. i.e. 'It should never be like this.'
"""
    user_prompt = f"""
Classify whether the following phrase contains the delivery of feedback considering.

Format your response as follows. DO NOT DO ANY OTHER FORMATTING.:
{{'feedback': 'yes'}} if the dialogue contains feedback
{{'feedback': 'no'}} if the dialogue does not contain feedback

Phrase:
{transcription}

For example:
{{'feedback': 'yes'}}
"""
    return {
        'system': system_prompt,
        'user': user_prompt
    }
    
def detect_feetback(case_id, fragment_id, verbose=False):
    prompts = get_prompts(case_id, fragment_id)
    set_openai_key(openai_key_path)
    client = OpenAI()
    
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": prompts['system']},
            {"role": "user", "content": prompts['user']}
        ],
        seed=42,
    )

    content = completion.choices[0].message.content

    print(content) if verbose else None
    
    try:
        classification = ast.literal_eval(content)
    except Exception as e:
        print(e)
        print(content)
        
        if 'yes' in content:
            classification = {'feedback': 'yes'}
        else:
            classification = {'feedback': 'no'}
    
    print(classification) if verbose else None
    
    return classification

In [3]:
# for each df add a new layer called gpt-4o_pred and have it be the classification from detect_feedback
for case_id, df in rolling_fragments_dfs.items():
    df['gpt-4o_pred'] = [detect_feetback(case_id, fragment_id) for fragment_id in tqdm(df.index, desc=f'Case {case_id}') if df.loc[fragment_id, 'transcription'] is not None]

Case 1: 100%|██████████| 2009/2009 [19:03<00:00,  1.76it/s]
Case 2: 100%|██████████| 919/919 [09:06<00:00,  1.68it/s]
Case 9: 100%|██████████| 1582/1582 [15:51<00:00,  1.66it/s]
Case 10: 100%|██████████| 1462/1462 [14:48<00:00,  1.65it/s]
Case 18: 100%|██████████| 1464/1464 [14:22<00:00,  1.70it/s]


In [122]:
import pickle
# pickle.dump(rolling_fragments_dfs, open('results/rolling_fragments/predictions/rolling_fragments_dfs_with_gpt-4o_pred.pkl', 'wb'))
rolling_fragments_dfs = pickle.load(open('results/rolling_fragments/predictions/rolling_fragments_dfs_with_gpt-4o_pred.pkl', 'rb'))

In [129]:
true_labels_paths = {
    case_id: f'results/rolling_fragments/true_labels/LFB{case_id} model=TextModel vad_threshold=0.3 fe-threshold=0.csv' for case_id in [1, 2, 9, 10, 18]
}
true_labels_dfs = {
    case_id: pd.read_csv(path) for case_id, path in true_labels_paths.items()
}

In [126]:
def get_metrics(case_id: int, true_labels_dfs: dict[int, pd.DataFrame], rolling_fragments_dfs: dict[int, pd.DataFrame]) -> dict:
    rolling_fragment_df = rolling_fragments_dfs[case_id].reset_index().dropna()
    rolling_fragment_df['secs'] = rolling_fragment_df['start_time'].apply(lambda x: int(x.split(':')[-1]) + 60*int(x.split(':')[-2]) + 60*60*int(x.split(':')[-3]))
    true_labels_df = true_labels_dfs[case_id].reset_index().dropna()
    
    df = pd.merge(true_labels_df, rolling_fragment_df, on='secs')
    df['gpt-4o_pred'] = df['gpt-4o_pred'].apply(lambda x: 1. if 'yes' in x['feedback'] else 0.)
    y_true = df['fb_instance'].values
    y_pred_gpt_4o = df['gpt-4o_pred'].values
    y_pred_bert = df['pred'].values
    
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred_gpt_4o, average='binary')
    roc_auc = roc_auc_score(y_true, y_pred_gpt_4o)
    accuracy = accuracy_score(y_true, y_pred_gpt_4o)
    gpt4o_metrics = {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'accuracy': accuracy
    }
    
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred_bert, average='binary')
    roc_auc = roc_auc_score(y_true, y_pred_bert)
    accuracy = accuracy_score(y_true, y_pred_bert)
    bert_metrics = {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'accuracy': accuracy
    }
    return {
        'gpt-4o': gpt4o_metrics,
        'bert': bert_metrics
    }

metrics = {case_id: get_metrics(case_id, true_labels_dfs, rolling_fragments_dfs) for case_id in [1, 2, 9, 10, 18]}

# Average metrics
gpt4o_metrics = {k: (np.mean([v['gpt-4o'][k] for v in metrics.values()]), np.std([v['gpt-4o'][k] for v in metrics.values()])) for k in metrics[1]['gpt-4o'].keys()}
print('GPT-4o Metrics')
print(json.dumps(gpt4o_metrics, indent=4))

GPT-4o Metrics
{
    "precision": [
        0.5971557453987111,
        0.1621322716641781
    ],
    "recall": [
        0.6177625257897931,
        0.059592782229525874
    ],
    "f1": [
        0.6021951935632186,
        0.1102710011858039
    ],
    "roc_auc": [
        0.6865911150959215,
        0.029165390909575643
    ],
    "accuracy": [
        0.7176570425240889,
        0.04314773925405801
    ]
}


# 2. BERT on Dialogue Reconstruction

In [1]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertForSequenceClassification
from transformers import set_seed, TrainingArguments, Trainer, BertForSequenceClassification
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from tqdm import tqdm
import numpy as np

from models import TextModel
from models.dataset import TextDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
random_state = 42
set_seed(random_state)
auxes = {
    'dialogue': 'dialogue',
    'hallucination removal': 'reduced hallucinations',
    'trainee/trainer id': 'all phrases'
}
paths = {}
case_ids = [1, 2, 6, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 25, 26, 28, 29, 33]
test_case_ids = [1, 2, 9, 10, 18]
for case_id in case_ids:
    paths[case_id] = {}
    for k, aux in auxes.items():
        paths[case_id][k] = f"./results/extract_dialogue/aligned_fb_detection/LFB{case_id}_full '{aux}'.csv"

dfs = {case_id: {k: pd.read_csv(v) for k, v in paths[case_id].items()} for case_id in case_ids}

In [16]:
def parse_context_dialogue(context_dialogue):
    parsed = ' '.join([x[x.index('[')+1:x.index(']')].replace("'", '') for x in context_dialogue.split('\n')[1:-1]])
    # parsed = context_dialogue
    return parsed

In [30]:
def compute_metrics(eval_pred):
    if isinstance(eval_pred, dict):
        labels = eval_pred['label_ids']
        preds = eval_pred['predictions'].argmax(-1)
    else:
        labels = eval_pred.label_ids
        preds = eval_pred.predictions.argmax(-1)

    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    roc_auc = roc_auc_score(labels, preds)
    
    metrics =  {'accuracy': accuracy,
                'roc_auc': roc_auc,
                'precision': precision,
                'recall': recall,
                'f1': f1}
    return metrics

def train(
    output_dir: str,
    epochs: int,
    batch_size: int,
    warmup_steps: int,
    weight_decay: float,
    eval_save_strategy: str,
    save_steps: int,
    eval_steps: int,
    metric_for_best_model: str,
    report_to: str,
    seed: int,
    lr_scheduler_type: str,
    lr_scheduler_kwargs: dict,
    model: nn.Module,
    train_dataset: torch.utils.data.Dataset,
    eval_dataset: torch.utils.data.Dataset,
):
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        eval_strategy=eval_save_strategy,
        save_strategy=eval_save_strategy,
        save_steps=save_steps,
        eval_steps=eval_steps,
        load_best_model_at_end=True,
        metric_for_best_model=metric_for_best_model,
        report_to=report_to,
        seed=seed,
        lr_scheduler_type=lr_scheduler_type,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        save_total_limit=5,
        remove_unused_columns=False
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics
    )
    trainer.train()
    
    return trainer.model

def evaluate(model: TextModel, test_dataset: TextDataset, batch_size: int):
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
    model.eval()
    all_preds = []
    all_labels = []
    for batch in tqdm(test_loader):
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(model.device)
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        token_type_ids = batch['token_type_ids']
        labels = torch.tensor([1*x for x in batch['label']]).to(model.device)
        outputs = model.forward(input_ids, attention_mask, token_type_ids, labels=labels)
        all_preds.extend(outputs.logits.cpu().detach().numpy())
        all_labels.extend(labels.cpu().detach().numpy())
    metrics = compute_metrics({'label_ids': all_labels, 'predictions': np.array(all_preds)})
    return metrics

### 2.1 Dialogue

In [31]:
all_aligned_fb_detection_df = pd.concat((dfs[case_id]['dialogue'] for case_id in case_ids if case_id not in test_case_ids)).sample(frac=1, random_state=random_state)
all_aligned_fb_detection_df = all_aligned_fb_detection_df[['full_clip_path', 'context_dialogue', 'true_fb_instance']]

train_df = all_aligned_fb_detection_df.iloc[:int(0.8*len(all_aligned_fb_detection_df))].reset_index(drop=True)
val_df = all_aligned_fb_detection_df.iloc[int(0.8*len(all_aligned_fb_detection_df)):].reset_index(drop=True)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

train_dataset = TextDataset(
    transcriptions_df=train_df,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)
val_dataset = TextDataset(
    transcriptions_df=val_df,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)

In [34]:
params_model = {
    'text_model': 'bert-base-uncased',
    # 'text_model': 'results/extract_dialogue/bert/checkpoint-4500',
    # 'config': 'bert-base-uncased',
    'class_weights': None,
    'num_classes': 2
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextModel(params_model, device)



Initializing Text Model!


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


No class weighting!


In [35]:
model = train(
    output_dir='results/extract_dialogue/bert/dialogue',
    epochs=3,
    batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    eval_save_strategy='steps',
    save_steps=500,
    eval_steps=500,
    metric_for_best_model='eval_f1',
    report_to='none',
    seed=random_state,
    lr_scheduler_type='linear',
    lr_scheduler_kwargs=None,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

Step,Training Loss,Validation Loss,Accuracy,Roc Auc,Precision,Recall,F1
500,0.5101,0.556685,0.722953,0.712605,0.425904,0.693916,0.527838
1000,0.5096,0.450168,0.818413,0.709008,0.611364,0.511407,0.556936
1500,0.4601,0.406383,0.832414,0.726149,0.651972,0.534221,0.587252
2000,0.4118,0.44432,0.832414,0.665852,0.758893,0.365019,0.49294
2500,0.3769,0.487823,0.834535,0.724127,0.663462,0.524715,0.585987
3000,0.3296,0.421886,0.845566,0.719032,0.728814,0.490494,0.586364
3500,0.3024,0.441688,0.852779,0.735869,0.739946,0.524715,0.614016


In [36]:
test_aligned_fb_detection = pd.concat((dfs[case_id]['dialogue'] for case_id in case_ids if test_case_ids)).sample(frac=1, random_state=random_state)
test_aligned_fb_detection = test_aligned_fb_detection[['full_clip_path', 'context_dialogue', 'true_fb_instance']].reset_index(drop=True)

test_dataset = TextDataset(
    transcriptions_df=test_aligned_fb_detection,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)

In [37]:
bert_metrics = evaluate(model, test_dataset, batch_size=32)

100%|██████████| 481/481 [00:26<00:00, 18.47it/s]


In [38]:
bert_metrics

{'accuracy': 0.8782280621869512,
 'roc_auc': 0.7832625105090458,
 'precision': 0.8544061302681992,
 'recall': 0.5991402471789361,
 'f1': 0.704358812381554}

### 2.2 +Reduced Hallucinations

In [39]:
all_aligned_fb_detection_df = pd.concat((dfs[case_id]['hallucination removal'] for case_id in case_ids if case_id not in test_case_ids)).sample(frac=1, random_state=random_state)
all_aligned_fb_detection_df = all_aligned_fb_detection_df[['full_clip_path', 'context_dialogue', 'true_fb_instance']]

train_df = all_aligned_fb_detection_df.iloc[:int(0.8*len(all_aligned_fb_detection_df))].reset_index(drop=True)
val_df = all_aligned_fb_detection_df.iloc[int(0.8*len(all_aligned_fb_detection_df)):].reset_index(drop=True)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

train_dataset = TextDataset(
    transcriptions_df=train_df,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)
val_dataset = TextDataset(
    transcriptions_df=val_df,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)

In [40]:
params_model = {
    'text_model': 'bert-base-uncased',
    # 'text_model': 'results/extract_dialogue/bert/checkpoint-4500',
    # 'config': 'bert-base-uncased',
    'class_weights': None,
    'num_classes': 2
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextModel(params_model, device)



Initializing Text Model!


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


No class weighting!


In [41]:
model = train(
    output_dir='results/extract_dialogue/bert/hallucination_removal',
    epochs=3,
    batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    eval_save_strategy='steps',
    save_steps=500,
    eval_steps=500,
    metric_for_best_model='eval_f1',
    report_to='none',
    seed=random_state,
    lr_scheduler_type='linear',
    lr_scheduler_kwargs=None,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

Step,Training Loss,Validation Loss,Accuracy,Roc Auc,Precision,Recall,F1
500,0.5672,0.54856,0.752922,0.73733,0.641593,0.683962,0.6621


In [42]:
test_aligned_fb_detection = pd.concat((dfs[case_id]['hallucination removal'] for case_id in case_ids if test_case_ids)).sample(frac=1, random_state=random_state)
test_aligned_fb_detection = test_aligned_fb_detection[['full_clip_path', 'context_dialogue', 'true_fb_instance']].reset_index(drop=True)

test_dataset = TextDataset(
    transcriptions_df=test_aligned_fb_detection,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)

In [43]:
bert_metrics = evaluate(model, test_dataset, batch_size=32)
bert_metrics

100%|██████████| 137/137 [00:11<00:00, 12.00it/s]


{'accuracy': 0.8059496567505721,
 'roc_auc': 0.796465840825861,
 'precision': 0.7803602556653109,
 'recall': 0.740761169332598,
 'f1': 0.7600452744765138}

### 2.3 +Trainee/Trainer ID

In [47]:
all_aligned_fb_detection_df = pd.concat((dfs[case_id]['trainee/trainer id'] for case_id in case_ids if case_id not in test_case_ids)).sample(frac=1, random_state=random_state)
all_aligned_fb_detection_df = all_aligned_fb_detection_df[['full_clip_path', 'context_dialogue', 'true_fb_instance']]

train_df = all_aligned_fb_detection_df.iloc[:int(0.8*len(all_aligned_fb_detection_df))].reset_index(drop=True)
val_df = all_aligned_fb_detection_df.iloc[int(0.8*len(all_aligned_fb_detection_df)):].reset_index(drop=True)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

train_dataset = TextDataset(
    transcriptions_df=train_df,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)
val_dataset = TextDataset(
    transcriptions_df=val_df,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)

In [48]:
params_model = {
    'text_model': 'bert-base-uncased',
    # 'text_model': 'results/extract_dialogue/bert/checkpoint-4500',
    # 'config': 'bert-base-uncased',
    'class_weights': None,
    'num_classes': 2
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextModel(params_model, device)



Initializing Text Model!


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


No class weighting!


In [49]:
model = train(
    output_dir='results/extract_dialogue/bert/trainee-trainer-id',
    epochs=3,
    batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    eval_save_strategy='steps',
    save_steps=500,
    eval_steps=500,
    metric_for_best_model='eval_f1',
    report_to='none',
    seed=random_state,
    lr_scheduler_type='linear',
    lr_scheduler_kwargs=None,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

Step,Training Loss,Validation Loss,Accuracy,Roc Auc,Precision,Recall,F1
500,0.4623,0.337381,0.841402,0.823325,0.679144,0.783951,0.727794


In [50]:
test_aligned_fb_detection = pd.concat((dfs[case_id]['hallucination removal'] for case_id in case_ids if test_case_ids)).sample(frac=1, random_state=random_state)
test_aligned_fb_detection = test_aligned_fb_detection[['full_clip_path', 'context_dialogue', 'true_fb_instance']].reset_index(drop=True)

test_dataset = TextDataset(
    transcriptions_df=test_aligned_fb_detection,
    tokenizer=tokenizer,
    text_col='context_dialogue',
    label_col='true_fb_instance',
    file_col='full_clip_path'
)

In [51]:
bert_metrics = evaluate(model, test_dataset, batch_size=32)
bert_metrics

100%|██████████| 137/137 [00:11<00:00, 12.17it/s]


{'accuracy': 0.7453089244851259,
 'roc_auc': 0.714074210051639,
 'precision': 0.7859477124183006,
 'recall': 0.5306122448979592,
 'f1': 0.633519920974646}

# 3. Other Hallucination Techniques

In [1]:
from utils import *
import torch
from models import ExtractDialogueModel
from utils import whisper_transcribe
from utils import set_openai_key

In [3]:
openai_key_path = 'openai_api_key.txt'

# Run second transcriptions
for case_id in [2, 9, 10, 18]:
    device = torch.device("cuda")
    params_extract_dialogue = {
        'speaker_diarization_model': 'pyannote/speaker-diarization-3.1',
        'speaker_embedding_model': 'pyannote/embedding',
        'hf_token_path': 'huggingface_token.txt',
        'openai_key_path': openai_key_path, 
        'transcribe_fn': whisper_transcribe,
        'full_audio_path': f'../../full_audios/LFB{case_id}_full.wav',
        'interval': 180,
        'console_times_path': '../../annotations/console_times/combined_console_times_secs.csv',
        'fb_annot_path': '../../clips_no_wiggle/fbk_cuts_no_wiggle_0_4210.csv',
        'vad_activity_path': f'../../full_VADs/LFB{case_id}_full_activity.csv',
        'diarizations_save_path': f'results/extract_dialogue/diarizations/LFB{case_id}_full.csv',
        'transcriptions_save_path': f'results/extract_dialogue/transcriptions/LFB{case_id}_full_2.csv',
        'identifications_save_path': f'results/extract_dialogue/dfifications/LFB{case_id}_full_2.csv',
        'fb_detection_save_path': f"results/extract_dialogue/fb_detection/LFB{case_id}_full_2 'dialogue' thresh={0.8}.csv",
        'audio_clips_dir': 'results/extract_dialogue/audio_clips',
        'trainer_anchors_dir': 'results/extract_dialogue/anchors/trainer',
        'trainee_anchors_dir': 'results/extract_dialogue/anchors/trainee',
        'tmp_dir': 'tmp',
        'seed': 42,
        'min_n_speakers': 2,
        'max_n_speakers': 2,
        'embedding_dist_thresh': 0.8
    }
    set_openai_key(openai_key_path)
    model = ExtractDialogueModel(params_extract_dialogue, device)

    model.full_diarization(load_saved=True)
    model.full_transcription(load_saved=False)

Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 698/698 [14:24<00:00,  1.24s/it]
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 1466/1466 [30:51<00:00,  1.26s/it] 
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 1050/1050 [20:18<00:00,  1.16s/it]
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 1038/1038 [21:04<00:00,  1.22s/it]


In [16]:
def get_hallucinations(df1, df2, verbose=False):
    df1.replace(np.nan, '', inplace=True)
    df2.replace(np.nan, '', inplace=True)
    
    hallucinations = np.zeros(len(df1), dtype=bool)
    for i in range(len(df1)):
        if df1.iloc[i].transcription != df2.iloc[i].transcription:
            hallucinations[i] = True
            if verbose:
                print(f"df 1: {df1.iloc[i].transcription}")
                print(f"df 2: {df2.iloc[i].transcription}")
        if df1.iloc[i].transcription == '' and df2.iloc[i].transcription == '':
            hallucinations[i] = True
            if verbose:
                print(f"df 1: {df1.iloc[i].transcription}")
                print(f"df 2: {df2.iloc[i].transcription}")
    return hallucinations

transcriptions_df1 = {case_id: pd.read_csv(f'results/extract_dialogue/transcriptions/LFB{case_id}_full.csv') for case_id in [1, 2, 9, 10, 18]}
transcriptions_df2 = {case_id: pd.read_csv(f'results/extract_dialogue/transcriptions/LFB{case_id}_full_2.csv') for case_id in [1, 2, 9, 10, 18]}
hallucinations = {case_id: get_hallucinations(transcriptions_df1[case_id], transcriptions_df2[case_id]) for case_id in [1, 2, 9, 10, 18]}
final_transcriptions_df = {case_id: transcriptions_df1[case_id].copy().iloc[~hallucinations[case_id]] for case_id in [1, 2, 9, 10, 18]}

for case_id in [1, 2, 9, 10, 18]:
    final_transcriptions_df[case_id].to_csv(f'results/extract_dialogue/transcriptions/LFB{case_id}_full_-reduced_hallucinations_base.csv', index=False)

In [19]:
openai_key_path = 'openai_api_key.txt'
all_metrics = {}
# Run second transcriptions
for case_id in [1, 2, 9, 10, 18]:
    device = torch.device("cuda")
    params_extract_dialogue = {
        'speaker_diarization_model': 'pyannote/speaker-diarization-3.1',
        'speaker_embedding_model': 'pyannote/embedding',
        'hf_token_path': 'huggingface_token.txt',
        'openai_key_path': openai_key_path, 
        'transcribe_fn': whisper_transcribe,
        'full_audio_path': f'../../full_audios/LFB{case_id}_full.wav',
        'interval': 180,
        'console_times_path': '../../annotations/console_times/combined_console_times_secs.csv',
        'fb_annot_path': '../../clips_no_wiggle/fbk_cuts_no_wiggle_0_4210.csv',
        'vad_activity_path': f'../../full_VADs/LFB{case_id}_full_activity.csv',
        'diarizations_save_path': f'results/extract_dialogue/diarizations/LFB{case_id}_full.csv',
        'transcriptions_save_path': f'results/extract_dialogue/transcriptions/LFB{case_id}_full_-reduced_hallucinations_base.csv',
        'identifications_save_path': f'results/extract_dialogue/identifications/LFB{case_id}_full_-reduced_hallucinations_base.csv',
        'fb_detection_save_path': f"results/extract_dialogue/fb_detection/LFB{case_id}_full_-reduced_hallucinations_base 'dialogue' thresh={0.8}.csv",
        'audio_clips_dir': 'results/extract_dialogue/audio_clips',
        'trainer_anchors_dir': 'results/extract_dialogue/anchors/trainer',
        'trainee_anchors_dir': 'results/extract_dialogue/anchors/trainee',
        'tmp_dir': 'tmp',
        'seed': 42,
        'min_n_speakers': 2,
        'max_n_speakers': 2,
        'embedding_dist_thresh': 0.8
    }
    set_openai_key(openai_key_path)
    model = ExtractDialogueModel(params_extract_dialogue, device)

    model.full_diarization(load_saved=True)
    model.full_transcription(load_saved=True)
    model.full_identification(load_saved=False)
    model.full_fb_detection(load_saved=False, aux='dialogue')
    model.full_aligned_fb_detection(load_saved=False)
    all_metrics[case_id] = model.evaluate(weighting='binary', model_type='fb')

Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 1507/1507 [00:04<00:00, 372.37it/s]



Number of contexts: 1502
aux: dialogue


100%|██████████| 1497/1497 [15:50<00:00,  1.57it/s]
100%|██████████| 1497/1497 [03:00<00:00,  8.31it/s] 
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 622/622 [00:01<00:00, 412.78it/s]



Number of contexts: 617
aux: dialogue


100%|██████████| 612/612 [06:40<00:00,  1.53it/s]
100%|██████████| 612/612 [01:09<00:00,  8.80it/s]
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 1269/1269 [00:03<00:00, 395.50it/s]



Number of contexts: 1264
aux: dialogue


100%|██████████| 1259/1259 [13:32<00:00,  1.55it/s]
100%|██████████| 1259/1259 [01:32<00:00, 13.65it/s]
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 968/968 [00:02<00:00, 424.36it/s]



Number of contexts: 963
aux: dialogue


100%|██████████| 958/958 [10:21<00:00,  1.54it/s]
100%|██████████| 958/958 [01:48<00:00,  8.82it/s]
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


100%|██████████| 878/878 [00:02<00:00, 427.71it/s]



Number of contexts: 873
aux: dialogue


100%|██████████| 868/868 [09:24<00:00,  1.54it/s]
100%|██████████| 868/868 [01:10<00:00, 12.28it/s]


In [41]:
openai_key_path = 'openai_api_key.txt'
all_metrics = {}
# Run second transcriptions
for case_id in [1, 2, 9, 10, 18]:
    device = torch.device("cuda")
    params_extract_dialogue = {
        'speaker_diarization_model': 'pyannote/speaker-diarization-3.1',
        'speaker_embedding_model': 'pyannote/embedding',
        'hf_token_path': 'huggingface_token.txt',
        'openai_key_path': openai_key_path, 
        'transcribe_fn': whisper_transcribe,
        'full_audio_path': f'../../full_audios/LFB{case_id}_full.wav',
        'interval': 180,
        'console_times_path': '../../annotations/console_times/combined_console_times_secs.csv',
        'fb_annot_path': '../../clips_no_wiggle/fbk_cuts_no_wiggle_0_4210.csv',
        'vad_activity_path': f'../../full_VADs/LFB{case_id}_full_activity.csv',
        'diarizations_save_path': f'results/extract_dialogue/diarizations/LFB{case_id}_full.csv',
        'transcriptions_save_path': f'results/extract_dialogue/transcriptions/LFB{case_id}_full_-reduced_hallucinations_base.csv',
        'identifications_save_path': f'results/extract_dialogue/identifications/LFB{case_id}_full_-reduced_hallucinations_base.csv',
        'fb_detection_save_path': f"results/extract_dialogue/fb_detection/LFB{case_id}_full_-reduced_hallucinations_base 'dialogue' thresh={0.8}.csv",
        'audio_clips_dir': 'results/extract_dialogue/audio_clips',
        'trainer_anchors_dir': 'results/extract_dialogue/anchors/trainer',
        'trainee_anchors_dir': 'results/extract_dialogue/anchors/trainee',
        'tmp_dir': 'tmp',
        'seed': 42,
        'min_n_speakers': 2,
        'max_n_speakers': 2,
        'embedding_dist_thresh': 0.8
    }
    set_openai_key(openai_key_path)
    model = ExtractDialogueModel(params_extract_dialogue, device)

    model.full_diarization(load_saved=True)
    model.full_transcription(load_saved=True)
    model.full_identification(load_saved=True)
    model.full_fb_detection(load_saved=True, aux='dialogue')
    model.full_aligned_fb_detection(load_saved=True)
    all_metrics[case_id] = model.evaluate(weighting='binary', model_type='fb')

Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`
Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--embedding/snapshots/4db4899737a38b2d618bbd74350915aa10293cb2/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.
Model was trained with pyannote.audio 0.0.1, yours is 3.3.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.4.0+cu118. Bad things might happen unless you revert torch to 1.x.


In [42]:
# Average metrics
avg_metrics = {k: (np.mean([v[k] for v in all_metrics.values()]), np.std([v[k] for v in all_metrics.values()])) for k in all_metrics[1].keys()}
print('Metrics')
print(json.dumps(avg_metrics, indent=4))

Metrics
{
    "precision": [
        0.568697884891402,
        0.06462143686647753
    ],
    "recall": [
        0.6128425566418169,
        0.0808942812293493
    ],
    "f1": [
        0.5879814215561103,
        0.06591532017344578
    ],
    "roc_auc": [
        0.68542848891328,
        0.04795825307630654
    ],
    "accuracy": [
        0.7234343853815178,
        0.0724305793607616
    ]
}


In [43]:
all_metrics

{1: {'precision': 0.6550802139037433,
  'recall': 0.6805555555555556,
  'f1': 0.667574931880109,
  'roc_auc': 0.6554640241961159,
  'accuracy': 0.6558533145275035},
 2: {'precision': 0.6387434554973822,
  'recall': 0.648936170212766,
  'f1': 0.6437994722955145,
  'roc_auc': 0.6902268399701963,
  'accuracy': 0.6966292134831461},
 9: {'precision': 0.5095541401273885,
  'recall': 0.45714285714285713,
  'f1': 0.4819277108433735,
  'roc_auc': 0.651417119954194,
  'accuracy': 0.744807121661721},
 10: {'precision': 0.5081967213114754,
  'recall': 0.6138613861386139,
  'f1': 0.5560538116591929,
  'roc_auc': 0.653084539223153,
  'accuracy': 0.6655405405405406},
 18: {'precision': 0.5319148936170213,
  'recall': 0.6637168141592921,
  'f1': 0.5905511811023622,
  'roc_auc': 0.7769499212227408,
  'accuracy': 0.8543417366946778}}