In [None]:
import json
from pathlib import Path
from itertools import combinations
from typing import Any
import dataclasses

from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset, Audio
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, pipeline, Pipeline,
    WhisperProcessor, WhisperForConditionalGeneration
)
import pysrt
from IPython.display import clear_output
import IPython.display
import librosa

from asr.asr import (
    initialize_model_for_speech_segmentation,
    initialize_model_for_speech_classification,
    initialize_model_for_speech_recognition,
    transcribe
)
from asr.lm import SequenceScore
from asr.comparison import TokenizedText, MultipleTextsAlignment, filter_correction_suggestions
from asr.whisper_scores import whisper_pipeline_transcribe_with_word_scores

In [None]:
# Load the dataset, cast its 'audio' column to Audio(sampling_rate=16_000)
# and keep only the 'train' split.
dataset = (
    load_dataset('dangrebenkin/long_audio_youtube_lectures')
    .cast_column('audio', Audio(sampling_rate=16_000))
    ['train']
)

In [None]:
# Work with a single record of the dataset (index 2).
sample = dataset[2]
waveform = sample['audio']['array']  # the audio samples of this record
sample['name']  # last expression: displayed as the cell output

In [None]:
# Models: a wav2vec2-based segmenter and a Whisper large-v3 recognition pipeline, both set up for 'ru'.
segmenter = initialize_model_for_speech_segmentation('ru', 'bond005/wav2vec2-large-ru-golos')
whisper_pipeline = initialize_model_for_speech_recognition('ru', 'openai/whisper-large-v3')

# `transcribe` is called with stubs: the voice activity detector always answers
# 'Speech' with score 1 and the ASR always answers 'none'. Only the segment
# boundaries (`.start` / `.end`) of the returned items are used below.
results = transcribe(
    waveform,
    segmenter=segmenter,
    voice_activity_detector=lambda audio: [{'score': 1, 'label': 'Speech'}],
    asr=lambda audio: {'text': 'none'},
    min_segment_size=1,
    max_segment_size=20,
)

tokenized_segments = []  # one tokenized text per segment
scores_per_word = []     # flat list of per-word scores over all segments

for segment in tqdm(results):
    # presumably segment.start / segment.end are in seconds: they are multiplied
    # by 16_000 (the sampling rate the audio column was cast to) — TODO confirm
    waveform_segment = waveform[int(segment.start * 16_000):int(segment.end * 16_000)]
    # the helper returns a 3-tuple; its second element is not needed here
    tokenized_text_for_segment, _, scores_for_segment = (
        whisper_pipeline_transcribe_with_word_scores(waveform_segment, whisper_pipeline)
    )
    tokenized_segments.append(tokenized_text_for_segment)
    scores_per_word += scores_for_segment

# a single tokenized text covering all segments
tokenized_text = TokenizedText.concatenate(tokenized_segments)

In [None]:
from transformers.models.whisper.tokenization_whisper import bytes_to_unicode


# Manual forward pass over `waveform_segment`, using the individual parts of the pipeline.
feature_extractor = whisper_pipeline.feature_extractor
tokenizer = whisper_pipeline.tokenizer
model = whisper_pipeline.model
# NOTE(review): `_forward_params` is a private pipeline attribute — it may
# change between transformers versions; confirm before relying on it.
generate_kwargs = whisper_pipeline._forward_params

# features are moved to the model's device and dtype before generation
inputs = feature_extractor(
    waveform_segment,
    return_tensors='pt',
    sampling_rate=16_000,
).to(model.device, model.dtype)
# `result` is read below through its 'sequences' and 'scores' entries.
# NOTE(review): presumably return_token_timestamps=True (or something in
# generate_kwargs) is what makes 'scores' available — confirm against the
# Whisper generate documentation.
result = model.generate(
    **inputs,
    **generate_kwargs,
    return_dict_in_generate=True,
    return_token_timestamps=True,
)

# convert token ids and per-step scores to numpy
token_ids = result['sequences'][0].cpu().numpy()
# NOTE: despite its name, `logits` holds log-probabilities (log_softmax over the last axis),
# stacked per generation step
logits = torch.nn.functional.log_softmax(torch.stack(result['scores']), dim=-1).cpu().numpy()

# skip start special tokens to align with logits
token_ids = token_ids[-len(logits):]

# skip all special tokens
# `tokenizer.all_special_ids` is read once (not once per token) and np.isin
# always yields a boolean mask, so `~is_special` is valid for any input length
is_special = np.isin(token_ids, tokenizer.all_special_ids)
token_ids = token_ids[~is_special]
logits = logits[~is_special]

# log-probability of each emitted token, taken from the first (index 0) entry of the step's scores
score_per_token = np.array([
    float(step_log_probs[0, token_id])
    for token_id, step_log_probs in zip(token_ids, logits)
])

# reproducing whisper bpe decoding:
# every token string is mapped back to the raw byte values it stands for
byte_decoder = {v: k for k, v in bytes_to_unicode().items()}
bytes_list_per_token = [
    [byte_decoder[x] for x in bytes_str]
    for bytes_str in tokenizer.convert_ids_to_tokens(token_ids)
]

# searching for token positions in the text
# `text` ends up as the longest decodable prefix; it stays '' when there are no tokens
text = ''
token_end_positions = []
# grow one byte buffer token by token instead of re-concatenating every prefix
concatenated_bytes = bytearray()
for token_bytes in bytes_list_per_token:
    concatenated_bytes.extend(token_bytes)
    try:
        text = concatenated_bytes.decode('utf-8', errors='strict')
        token_end_positions.append(len(text))
    except UnicodeDecodeError:
        token_end_positions.append(None)  # not a full utf-8 character

# the manual decoding must agree with the tokenizer's own (before space clean-up)
assert text == tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)

# cleaning up tokenization spaces, shifting token_end_positions
# (see .clean_up_tokenization() in PreTrainedTokenizerBase)
if tokenizer.clean_up_tokenization_spaces:
    for replace_from in [" .", " ?", " !", " ,", " ' ", " n't", " 'm", " 's", " 've", " 're"]:
        replace_to = replace_from.strip()
        delta_len = len(replace_to) - len(replace_from)
        # Resume the search right after the inserted replacement: this is the
        # single left-to-right pass that str.replace performs, so text produced
        # by a replacement is never matched again.
        search_from = 0
        while (start_pos := text.find(replace_from, search_from)) != -1:
            text = text[:start_pos] + replace_to + text[start_pos + len(replace_from):]
            token_end_positions = [
                (
                    token_end_pos
                    # None marks a token that ends inside a utf-8 character:
                    # it has no position to shift and must not be compared
                    if token_end_pos is None or token_end_pos <= start_pos
                    else token_end_pos + delta_len
                )
                for token_end_pos in token_end_positions
            ]
            search_from = start_pos + len(replace_to)

    # after the clean-up the text must agree with the tokenizer's default decoding
    assert text == tokenizer.decode(token_ids)

# tokenizing the text
# Deliberately NOT bound to `tokenized_text`: that notebook-level name holds the
# transcription of all segments (paired with `scores_per_word` when the results
# are saved), while `text` here covers a single segment only.
segment_tokenized_text = TokenizedText.from_text(text)

# matching words and tokens
# For each word: (index of its first token, index one past its last token).
tokens_range_per_word = []
for word in segment_tokenized_text.get_words():
    first_token_idx = None  # first token of the word, inclusive
    for token_idx, token_end_pos in enumerate(token_end_positions):
        if token_end_pos is None:
            # token ends inside a utf-8 character: it has no text position
            continue
        if token_end_pos > word.start and first_token_idx is None:
            first_token_idx = token_idx
        if token_end_pos >= word.stop:
            break
    tokens_range_per_word.append((first_token_idx, token_idx + 1))

# Token strings of every word; bytes that do not form valid utf-8 on their own
# are decoded with errors='replace'.
tokens_per_word = [
    [
        bytearray(raw_bytes).decode('utf-8', errors='replace')
        for raw_bytes in bytes_list_per_token[first_idx:stop_idx]
    ]
    for first_idx, stop_idx in tokens_range_per_word
]

# Scores of the same tokens, grouped by word in the same way.
token_scores_per_word = [
    list(score_per_token[first_idx:stop_idx])
    for first_idx, stop_idx in tokens_range_per_word
]

In [None]:
# Inspection: decode the bytes of all tokens at once (strict utf-8) and display the result.
bytearray(sum(bytes_list_per_token, [])).decode('utf-8', errors='strict')

In [None]:
# Inspection: display the current value of `text`.
text

In [None]:
# Inspection: the tokenizer's own decoding, with special tokens kept and no space clean-up.
tokenizer.decode(token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)

In [None]:
# NOTE(review): absolute, user-specific output directory — consider moving it to a
# configurable constant so the notebook runs on other machines.
output_dir = Path('/home/oleg/pisets_test_results_with_scores')
output_dir.mkdir(parents=True, exist_ok=True)

filepath = output_dir / f'{sample["name"]} Pisets WhisperV3 no-VAD (segments 1s-20s) with scores.json'

# ensure_ascii=False writes non-ASCII characters verbatim, so the file encoding
# is pinned to utf-8 instead of depending on the platform default.
with open(filepath, 'w', encoding='utf-8') as f:
    json.dump({
        'tokenized_text': dataclasses.asdict(tokenized_text),
        'scores_per_word': scores_per_word,
    }, f, ensure_ascii=False)

In [None]:
# Show the file that was just written. Reading `filepath` (instead of a
# hard-coded file name) keeps this in sync with whichever sample was processed.
print(filepath.read_text(encoding='utf-8'))

In [None]:
def get_all_subsets(elements: list[Any]) -> list[list[Any]]:
    """
    Returns all subsets of a list, smallest subsets first; within one size the
    order is the one produced by `itertools.combinations`.

    Every subset is a new `list` (not a tuple), so it can be concatenated with
    other lists.
    ```
    get_all_subsets([1, 2, 3])
    >>> [[], [1], [2], [3], [1, 2], [1, 3], [2, 3], [1, 2, 3]]
    ```
    """
    # a single flat comprehension: no repeated list copying as with sum(..., [])
    return [
        list(subset)
        for size in range(len(elements) + 1)
        for subset in combinations(elements, size)
    ]

# NOTE(review): `transcriptions` is not defined anywhere in the visible part of
# this notebook — presumably it comes from a cell that is not shown. Confirm
# that this cell survives Restart Kernel -> Run All.
base = transcriptions['galore']['whisperV3_long_segments_ru']
additional = transcriptions['galore']['w2v2_golos_lm']
truth = transcriptions['galore']['truth']

# Last expression is the cell output: wer() of `base` aligned against `truth`.
MultipleTextsAlignment.from_strings(truth, base).wer()

In [None]:
# Uncertainty mask obtained from aligning `base` with `additional`;
# its mean is printed as the ratio of uncertain words.
is_uncertain = MultipleTextsAlignment.from_strings(base, additional).get_uncertainty_mask()
print('Uncertain words ratio', is_uncertain.mean())
# wer() of `base` against `truth`, this time given the uncertainty mask
MultipleTextsAlignment.from_strings(truth, base).wer(uncertainty_mask=is_uncertain)

In [None]:
# Greedy resolution of the differences between `base` and `additional`:
# suggestions are visited one by one; for the current suggestion (plus up to
# `depth - 1` following ones as look-ahead) every accept/reject combination is
# scored, and the current suggestion is accepted iff it is part of the
# best-scoring combination.
alignment = MultipleTextsAlignment.from_strings(base, additional)
orig_indices_to_resolve = filter_correction_suggestions(alignment, skip_word_form_change=False)
indices_to_resolve = orig_indices_to_resolve.copy()
indices_accepted = []

# print(alignment.substitute(show_in_braces=indices_to_resolve))

depth = 2  # number of suggestions considered jointly (current one + look-ahead)

# number of characters of context kept around the considered suggestions
context_before = 100
context_after = 100

while len(indices_to_resolve):
    print(f'{len(indices_to_resolve)} indices remaining')

    indices = indices_to_resolve[:depth]

    # all accept/reject combinations for the considered suggestions (computed once)
    variants: list[list[int]] = get_all_subsets(indices)

    scores = {}

    for indices_to_consider in variants:
        text = alignment.substitute(replace=indices_accepted + indices_to_consider)

        start_idx = alignment.matches[indices[0]].char_start1
        # the end position is shifted by the length change of the substituted text
        end_idx = alignment.matches[indices[-1]].char_end1 + len(text) - len(alignment.text1.text)
        # NOTE(review): start_idx is not shifted the same way — presumably an
        # acceptable approximation given the context window; confirm.

        start_idx -= context_before
        end_idx += context_after

        # keep the window inside the text
        start_idx = np.clip(start_idx, 0, len(text))
        end_idx = np.clip(end_idx, 0, len(text))

        text = text[start_idx:end_idx]

        # NOTE(review): `sequence_score` is not defined in the visible part of
        # this notebook (only the `SequenceScore` class is imported) — confirm
        # where it is created.
        scores[tuple(indices_to_consider)] = {
            'score': sequence_score(text),
            'text' : text
        }

    print([x['score'] for x in scores.values()])

    best_option = max(scores, key=lambda k: scores[k]['score'])

    should_accept_index = indices[0] in best_option

    if should_accept_index:
        indices_accepted.append(indices[0])

    # only the first suggestion is decided per iteration; the look-ahead ones
    # are reconsidered in the next iterations
    indices_to_resolve = indices_to_resolve[1:]

In [None]:
# Text obtained by substituting the suggestions listed in `indices_accepted`.
corrected = alignment.substitute(replace=indices_accepted)

In [None]:
# wer() of the corrected text aligned against `truth`.
MultipleTextsAlignment.from_strings(truth, corrected).wer()

In [None]:
# Uncertainty mask obtained from aligning `base` with `corrected`
# (this rebinds `is_uncertain`); its mean is printed as the ratio of uncertain words.
is_uncertain = MultipleTextsAlignment.from_strings(base, corrected).get_uncertainty_mask()
print('Uncertain words ratio', is_uncertain.mean())
# wer() of `base` against `truth`, given the new uncertainty mask
MultipleTextsAlignment.from_strings(truth, base).wer(uncertainty_mask=is_uncertain)

In [None]:
# Print `base` vs `additional` with every non-equal match shown in braces.
# NOTE: this rebinds the notebook-level name `alignment`.
alignment = MultipleTextsAlignment.from_strings(base, additional)

print(alignment.substitute(
    show_in_braces=[i for i, op in enumerate(alignment.matches) if not op.is_equal]
))

In [None]:
# Print `truth` vs `base` with every non-equal match shown in braces.
# NOTE: this rebinds the notebook-level name `alignment`.
alignment = MultipleTextsAlignment.from_strings(truth, base)

print(alignment.substitute(
    show_in_braces=[i for i, op in enumerate(alignment.matches) if not op.is_equal]
))

In [None]:
# Print `base` vs `corrected`, showing in braces the indices returned by
# filter_correction_suggestions (called with skip_word_form_change=False).
# NOTE: this rebinds the notebook-level name `alignment`.
alignment = MultipleTextsAlignment.from_strings(base, corrected)

print(alignment.substitute(
    show_in_braces=filter_correction_suggestions(alignment, skip_word_form_change=False)
))

In [None]:
# print(alignment.substitute(
#     show_in_braces=orig_indices_to_resolve,
#     pref_second=indices_accepted,
#     pref_first=set(orig_indices_to_resolve) - set(indices_accepted),
# ))

In [None]:
# A bare string expression: the cell only displays it, nothing is assigned.
# NOTE(review): appears to be a prompt meant to be followed by a text in the
# braces notation it describes — confirm where it is used.
'''
I have two speech recognition models (the first model is usually better) and compare their predictions. In the following text, the disagreement between models is highlighted in braces.

- {aaa|bbb} means that the second model wants to replace "aaa" with "bbb"
- {+xx} means that the second model wants to insert "xx" into the first model predictions
- {yy} means that the second model wants to remove "yy" from the first model predictions

Based on linguistic knowledge and common sense, please resolve the disagreement and write the final transcription without braces.

The text:
'''