In [29]:
import os

# Project root. Defaults to the original checkout; set MALACHOR5_ROOT to run
# the notebook from another location. Later cells rely on this working
# directory for relative paths (e.g. sys.path.append('scripts')).
PROJECT_ROOT = os.environ.get('MALACHOR5_ROOT', '/Users/markjos/projects/malachor5')
os.chdir(PROJECT_ROOT)
import kenlm

In [17]:
# SASOC 3-gram ARPA model, resolved against the project root the first cell
# chdir'd into (avoids repeating the absolute checkout path). The path stays
# absolute, so kenlm's "Reading ..." log is unchanged.
# NOTE: kenlm notes that building a binary file would make loading faster.
sasoc_3gram_path = os.path.abspath(os.path.join('data', 'SASOC', 'sasoc_3gram.arpa'))

In [14]:
# Load the SASOC n-gram LM and score one short test utterance.
# NOTE(review): `model` is rebound to the Whisper model further down; re-run
# this cell before re-running the KenLM inspection cells that use `model`.
model = kenlm.LanguageModel(sasoc_3gram_path)
print(f'{model.order}-gram model')

# Other sentences tried: 'language modeling is fun .', 'uyabona into ekhona ne'
sentence = 'kahle kahle'
print(sentence)

# Whole-sentence KenLM score, including the end-of-sentence transition.
sentence_score = model.score(sentence, eos=True)
print(sentence_score)

3-gram model
kahle kahle
-3.4717769622802734


Loading the LM will be faster if you build a binary file.
Reading /Users/markjos/projects/malachor5/data/SASOC/sasoc_3gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


In [10]:
# Per-position breakdown of the sentence score: the log prob, the length of
# the n-gram KenLM actually matched, and that n-gram; OOV tokens are flagged.
words = ['<s>'] + sentence.split() + ['</s>']
for pos, (log_prob, ngram_len, is_oov) in enumerate(model.full_scores(sentence), start=1):
    # `pos` indexes the token being predicted within `words`.
    matched_ngram = words[pos + 1 - ngram_len:pos + 1]
    print(f"{log_prob} {ngram_len}: {' '.join(matched_ngram)}")
    if is_oov:
        print(f'\t"{words[pos]}" is an OOV')

-2.771185874938965 2: <s> kahle
-0.2829204201698303 3: <s> kahle kahle
-5.055927753448486 1: potato
	"potato" is an OOV
-0.9160206913948059 1: </s>


In [11]:
# Find out-of-vocabulary words (the <s>/</s> markers in `words` are checked too).
oov_words = [token for token in words if token not in model]
for token in oov_words:
    print('"{0}" is an OOV'.format(token))

"potato" is an OOV


In [30]:
import torch
from transformers import LogitsProcessor
from pyctcdecode import build_ctcdecoder  # not used by the cells visible here
import sys

# Make the project's `scripts/` helpers importable. The path is relative to the
# working directory set in the first cell. The guard keeps the cell idempotent:
# re-running it no longer appends duplicate entries to sys.path.
if 'scripts' not in sys.path:
    sys.path.append('scripts')
from longform import load_and_resample


torchvision is not available - cannot save figures


In [117]:
class LanguageModelRescorer(LogitsProcessor):
    """Shallow fusion of a KenLM n-gram LM into a Hugging Face decoder.

    At every decoding step, each row of ``scores`` (one row per hypothesis) is
    handled independently. The ``top_k`` best-scoring candidate tokens in the
    row are rescored as::

        score + alpha * (ln P_LM(prefix + token) - ln P_LM(prefix))

    ``prefix`` is that row's hypothesis, decoded with special tokens removed.
    Every other token is set to ``-inf`` so an unscored token can never outrank
    a rescored one.

    Why the earlier version did not work:

    * It scored only row 0 and added that single number to the whole
      vocabulary. A constant shift never changes which token wins.
    * ``scores * (1 - alpha)`` negates every score when ``alpha > 1`` (e.g.
      ``alpha=2``), so the decoder preferred its least likely tokens.

    Args:
        tokenizer: tokenizer used to decode ``input_ids``. It needs
            ``decode``; ``eos_token_id`` is used if present.
        lm_path: path to a KenLM model file.
        alpha: LM weight. ``0`` returns ``scores`` unchanged.
        top_k: candidates rescored per row. Each candidate costs one decode
            plus one LM query per step.

    Raises:
        ValueError: if ``top_k < 1``.

    Notes:
        * KenLM scores are log10. They are converted to natural log so the LM
          term is on the same scale as log-softmax scores.
        * A subword candidate that ends mid-word is scored as a whole (likely
          OOV) word, so it is penalized until the word is completed.
          TODO: consider scoring only completed words.
        * NOTE(review): if the LM was trained on normalized text (casing or
          punctuation stripped — not visible here), the decoded text should be
          normalized the same way before scoring.
    """

    # KenLM returns log10 probabilities; multiply by ln(10) to get natural log.
    _LOG10_TO_LN = 2.302585092994046

    def __init__(self, tokenizer, lm_path, alpha=0.5, top_k=10):
        super().__init__()
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.alpha = alpha  # Weight for LM fusion
        self.top_k = top_k
        self.tokenizer = tokenizer
        self.lm = kenlm.LanguageModel(lm_path)
        # Appending EOS ends the hypothesis, so its LM term is the </s> transition.
        self.eos_token_id = getattr(tokenizer, "eos_token_id", None)

    def _lm_logprob(self, text, eos=False):
        """Natural-log LM probability of ``text`` (scored from sentence start)."""
        return self.lm.score(text, bos=True, eos=eos) * self._LOG10_TO_LN

    def _decode(self, token_ids):
        """Decode a list of token ids to text, dropping special tokens."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=True)

    def _row_deltas(self, prefix_ids, candidate_ids):
        """LM log-prob gain from appending each candidate to ``prefix_ids``."""
        prefix_text = self._decode(prefix_ids)
        prefix_logprob = self._lm_logprob(prefix_text)
        deltas = []
        for token_id in candidate_ids:
            if token_id == self.eos_token_id:
                # Ending here: charge the </s> transition instead of a new word.
                candidate_logprob = self._lm_logprob(prefix_text, eos=True)
            else:
                candidate_logprob = self._lm_logprob(self._decode(prefix_ids + [token_id]))
            # Special tokens decode to the same text as the prefix: delta is 0.
            deltas.append(candidate_logprob - prefix_logprob)
        return deltas

    def __call__(self, input_ids, scores):
        """Return a copy of ``scores`` with the top-``top_k`` tokens per row LM-rescored.

        Args:
            input_ids: token ids generated so far, one row per hypothesis.
            scores: next-token scores with the same number of rows.

        Returns:
            A new tensor shaped like ``scores``. Entries outside each row's
            top-``top_k`` are ``-inf``. When ``alpha == 0``, ``scores`` itself
            is returned.
        """
        if self.alpha == 0:
            return scores
        k = min(self.top_k, scores.shape[-1])
        top_scores, top_ids = scores.topk(k, dim=-1)
        deltas = torch.tensor(
            [
                self._row_deltas(row_ids, cand_ids)
                for row_ids, cand_ids in zip(input_ids.tolist(), top_ids.tolist())
            ],
            dtype=scores.dtype,
            device=scores.device,
        )
        fused = torch.full_like(scores, float("-inf"))
        return fused.scatter_(-1, top_ids, top_scores + self.alpha * deltas)

In [None]:
from transformers import LogitsProcessorList, WhisperForConditionalGeneration, WhisperProcessor

# Whisper checkpoint: processor (feature extractor + tokenizer) and the model.
# NOTE(review): this rebinds `model`, which earlier cells used for the KenLM LM.
model_id = "openai/whisper-tiny"
processor, model = (
    WhisperProcessor.from_pretrained(model_id),
    WhisperForConditionalGeneration.from_pretrained(model_id),
)

Loading the LM will be faster if you build a binary file.
Reading /Users/markjos/projects/malachor5/data/SASOC/sasoc_3gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


In [126]:
# Prepare input: load one SASOC clip and convert it to Whisper input features.
# The path is resolved against the project root the first cell chdir'd into,
# instead of repeating the absolute checkout path.
# Other clip tried: data/SASOC/audio/AKHONA_12-10-31_24.wav
audio_path = os.path.abspath(os.path.join('data', 'SASOC', 'audio', 'BASH_12-11-23_128.wav'))
audio_input = load_and_resample(audio_path, flatten=True)
# NOTE(review): assumes load_and_resample returns 16 kHz audio — confirm.
input_features = processor(audio_input, return_tensors="pt", sampling_rate=16_000).input_features

In [77]:
# Sanity-check the feature tensor shape (output below: torch.Size([1, 80, 3000])).
input_features.size()

torch.Size([1, 80, 3000])

In [110]:
# Baseline: Whisper beam search without the LM, for comparison with the
# LM-fused run below.
# num_beams is a generate() argument. It was previously passed to batch_decode,
# which does not control decoding, so this baseline fell back to the generation
# config's default search. The unused decoder-prompt tensor was also dropped.
output = model.generate(input_features, num_beams=3)
decoded_text = processor.batch_decode(output, skip_special_tokens=False)
decoded_text

language=None	init_tokens=tensor([[50258, 50259, 50359, 50363]])
decoder_input_ids=tensor([[50258, 50259, 50359, 50363]])


["<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Okay, so I don't feel like having sex nor lo-seater in our relationship is a trouble."]

In [132]:
lm_rescorer = LanguageModelRescorer(processor.tokenizer, sasoc_3gram_path, alpha=2)
logits_processor = LogitsProcessorList([lm_rescorer])
output = model.generate(input_features, logits_processor=logits_processor, num_beams=3)
decoded_text = processor.batch_decode(output, skip_special_tokens=False)
decoded_text

Loading the LM will be faster if you build a binary file.
Reading /Users/markjos/projects/malachor5/data/SASOC/sasoc_3gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


language=None	init_tokens=tensor([[50258, 50264, 50359, 50363]])
decoder_input_ids=tensor([[50258, 50264, 50359, 50363]])


KeyboardInterrupt: 