In [None]:
! pip install torch transformers datasets tqdm lxml librosa

In [15]:
import torch
import librosa
import torch.nn.functional as F
from transformers import WhisperProcessor, WhisperForConditionalGeneration, GPT2LMHeadModel, AutoTokenizer
import math
import json

### Load ASR model and audio
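- the reference transcript for this clip contains the medical terms "tetralogy of Fallot" and "transcatheter pulmonary valve replacement", which out-of-the-box Whisper mis-transcribes (compare the two predictions at the end of the notebook)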

In [16]:
# Read the recording from disk, resampling to 16 kHz while loading
# (the rate Whisper expects).
audio_file = "../assets_folder/my_procedure.m4a"
audio_array, sampling_rate = librosa.load(audio_file, sr=16000)

# The processor and the model are both taken from the same checkpoint.
whisper_checkpoint = "openai/whisper-tiny.en"
processor = WhisperProcessor.from_pretrained(whisper_checkpoint)
decoder = WhisperForConditionalGeneration.from_pretrained(whisper_checkpoint)

# Turn the waveform into the model's input features (PyTorch tensors).
processed_audio = processor(
    audio_array,
    sampling_rate=sampling_rate,
    return_tensors="pt",
)
input_features = processed_audio.input_features

# Seed the decoder input with the model's configured start token (shape [1, 1]).
decoder_input_ids = torch.tensor([decoder.config.decoder_start_token_id]).unsqueeze(0)

# Inference only: switch the model to eval mode (eval() returns the model itself).
decoder = decoder.eval()

  audio_array, sampling_rate = librosa.load(audio_file, sr=16000)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


### Load in LM expert

In [4]:
# Language-model "expert": tokenizer from the base GPT-2 checkpoint, weights from
# the gpt2-medium-pubmed checkpoint. eval() returns the model itself, so it can be
# chained directly onto the load.
gpt2_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("cwestnedge/gpt2-medium-pubmed").eval()

### Run shallow fusion
Token prediction is implemented in two stages. For the first few steps we rely solely on the ASR model; once some text context is available, the language model's score is added to the ASR score at every subsequent step. Formally:

$$
\textbf{Token Selection at Step }t:\quad
F(x,t)
=
\begin{cases}
\displaystyle
\arg\max_{y_t}\;\log P_{\text{ASR}}(y_t \!\mid\! x,\; y_{<t}),
& t < \text{initial\_steps},\\[0.75em]
\displaystyle
\arg\max_{y_t}\;\Bigl[
\log P_{\text{ASR}}(y_t \!\mid\! x,\; y_{<t})
\;+\;
\lambda\,\log P_{\text{LM}}(y_t \!\mid\! y_{<t})
\Bigr],
& t \ge \text{initial\_steps}.
\end{cases}
$$

This piecewise approach lets the system build confidence from the raw audio transcription before incorporating the domain expert's corrections: the language model needs some observed text to condition on, so the first tokens must come from the ASR model alone. In the code below, $\lambda$ corresponds to the `alpha` variable.

In [5]:
# Greedy shallow-fusion decoding: Whisper proposes the next token, and after a
# short warm-up GPT-2's log-probabilities are added to Whisper's with weight alpha.

# --- Decoding configuration ---
initial_steps = 2   # number of leading steps decoded with Whisper alone
num_steps = 100     # upper bound on decoding steps; the loop exits early on EOS
alpha = 0.25        # weight on the GPT-2 log-probabilities (lambda in the formula above)

# Whisper's special tokens are hidden from GPT-2. The list is kept under its
# original name; the set gives O(1) membership tests inside the decoding loop.
whisper_special_ids = processor.tokenizer.all_special_ids
whisper_special_id_set = set(whisper_special_ids)
eos_token_id = processor.tokenizer.eos_token_id


def build_gpt2_context(token_ids):
    """Return the GPT-2 input ids for a Whisper token sequence.

    Whisper special tokens are dropped. If nothing is left, GPT-2's own BOS token
    is used as the context. Passing Whisper special ids through instead is unsafe
    because they may lie outside GPT-2's vocabulary and index past its embedding
    table.

    Args:
        token_ids: list of Whisper token ids generated so far.

    Returns:
        A non-empty list of token ids that are valid GPT-2 inputs, assuming the
        non-special Whisper ids coincide with GPT-2 ids (the same assumption the
        fusion step below relies on).
    """
    text_ids = [tok for tok in token_ids if tok not in whisper_special_id_set]
    return text_ids or [gpt2_tokenizer.bos_token_id]


# Start from Whisper's decoder start token and build the matching GPT-2 context.
decoder_input_ids = torch.tensor([[decoder.config.decoder_start_token_id]], device=input_features.device)
gpt_input_ids = build_gpt2_context(decoder_input_ids[0].tolist())
gpt_input_tensor = torch.tensor([gpt_input_ids], device=decoder_input_ids.device)

# Containers kept for per-step diagnostics (not populated in this cell).
whisper_data = []
gpt2_data = []
step_tokens = []

with torch.no_grad():
    # The audio does not change between steps, so encode it once instead of
    # re-running the encoder on every iteration.
    encoder_outputs = decoder.get_encoder()(input_features)

    for step in range(num_steps):
        # The full prefix is re-fed each step, so a KV cache would be built and
        # immediately discarded; disable it.
        decoder_outputs = decoder(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            use_cache=False,
        )
        decoder_logits = decoder_outputs.logits[:, -1, :]  # shape: [batch, whisper_vocab_size]

        if step < initial_steps:
            # Warm-up: GPT-2 never sees Whisper's special tokens, so the first
            # steps use Whisper alone until there is text for GPT-2 to condition on.
            next_token = decoder_logits.argmax(dim=-1, keepdim=True)
        else:
            gpt2_outputs = gpt2_model(gpt_input_tensor)
            gpt2_logits = gpt2_outputs.logits[:, -1, :]  # shape: [batch, gpt2_vocab_size]

            # Work in log-probability space for both models.
            whisper_log_probs = F.log_softmax(decoder_logits, dim=-1)
            gpt2_log_probs = F.log_softmax(gpt2_logits, dim=-1)

            # Assumes ids 0..gpt2_vocab_size-1 mean the same token in both
            # vocabularies, i.e. GPT-2's vocab is a prefix of Whisper's (e.g. 50257).
            shared_domain_idx = gpt2_log_probs.shape[-1]

            # Fused score over the shared vocabulary only: log P_ASR + alpha * log P_LM.
            fused_logits = whisper_log_probs[:, :shared_domain_idx] + alpha * gpt2_log_probs
            next_token = fused_logits.argmax(dim=-1, keepdim=True)

            # NOTE: Whisper-only ids (>= shared_domain_idx) can never be selected
            # here. To include them, concatenate whisper_log_probs[:, shared_domain_idx:]
            # onto fused_logits, after calibrating them to the same scale as the
            # fused scores.

        # Extend Whisper's decoder input with the chosen token.
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

        # Refresh GPT-2's context from the updated sequence (special tokens removed).
        gpt_input_ids = build_gpt2_context(decoder_input_ids[0].tolist())
        gpt_input_tensor = torch.tensor([gpt_input_ids], device=decoder_input_ids.device)

        # Stop once the end-of-sequence token has been generated.
        if next_token.item() == eos_token_id:
            break

# Decode the full Whisper sequence to text, dropping special tokens.
final_output = processor.batch_decode(decoder_input_ids, skip_special_tokens=True)


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


### predicted output with whisper/ASR model only

In [13]:
# Baseline: let Whisper decode by itself through generate(), with no language model.
whisper_only_predicted_ids = decoder.generate(input_features)
whisper_only_transcription = processor.batch_decode(
    whisper_only_predicted_ids,
    skip_special_tokens=True,
)

# Show the first (only) transcript under its heading, one item per line.
baseline_text = whisper_only_transcription[0].strip()
print("Out of box whisper prediction:", baseline_text, sep="\n")

Out of box whisper prediction:
Hi, my name is Collins and I'm calling about a procedure that I had related to tetrology if below. I wanted to check on the status and make sure that the claim is being processed. It was for trans-cathodar pulmonary valve replacement on October at Northwestern Medicine.


### predicted output when leveraging shallow fusion + LM expert

In [14]:
# Convert the token ids produced by the shallow-fusion loop back into text,
# dropping special tokens during decoding.
final_output = processor.batch_decode(decoder_input_ids, skip_special_tokens=True)

# Show the first (only) transcript under its heading, one item per line.
fused_text = final_output[0].strip()
print("Prediction with shallow fusion:", fused_text, sep="\n")

Prediction with shallow fusion:
Hi, my name is Collins and I'm calling about a procedure that I had related to tetralogy of Fallot. I wanted to check on the status and make sure that the claim is being processed. It was for transcatheter pulmonary valve replacement on October at Northwestern Medicine.
