In [1]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

# Seed torch's global RNG so any stochastic torch op in this notebook is repeatable.
torch.manual_seed(42)

# Keep all Hugging Face downloads in a project-local ".models" folder that sits
# next to the notebook directory, and point HF_HOME at it.
CACHE_DIR = (Path.cwd().parent / ".models").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HOME"] = str(CACHE_DIR)

# Show only the last three path components. Path.parts is separator-agnostic,
# unlike splitting the string form on a hard-coded '/'.
print(list(CACHE_DIR.parts[-3:]))

['SHALLOW_FUSION_EVAL', 'SF_EVAL', '.models']


In [2]:
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

# --- Run configuration ---------------------------------------------------
SR = 16_000            # sample rate (Hz) that audio is resampled to on load
BATCH_SIZE = 5
WHISPER_ID = "openai/whisper-small.en"
GPT2_ID = "cwestnedge/gpt2-small-pubmed"

MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"

# Pick the first available backend, preferring CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

  from .autonotebook import tqdm as notebook_tqdm


Device: mps


In [3]:
# Author's note: the "fast" tokenizers report token mismatches between the two
# models and get auto-loaded on Colab (A100), so slow tokenizers are requested
# explicitly with use_fast=False.
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR, use_fast=False)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

gpt2_tok = AutoTokenizer.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR, use_fast=False)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

EOS_ID = gpt2_tok.eos_token_id # 50256 (unchanged)

decoder_start_id = whisper.generation_config.decoder_start_token_id
print("Decoder start token ID:", decoder_start_id)
print("Decoder start token:", processor.decode([decoder_start_id]))

# The decoder prefix is the start token followed by every forced decoder token.
# `forced_decoder_ids` can be None (or absent) depending on the checkpoint and
# library version; treat that as "no forced tokens" instead of crashing on
# iteration.
PREFIX_TOK_IDS = [decoder_start_id]
forced_decoder_ids = getattr(whisper.generation_config, "forced_decoder_ids", None) or []
for _position, token_id in forced_decoder_ids:
    PREFIX_TOK_IDS.append(token_id)

print(f"\nComplete prefix: {PREFIX_TOK_IDS}")
print(f"Decoded: '{processor.decode(PREFIX_TOK_IDS)}'")
print(f"Total prefix length: {len(PREFIX_TOK_IDS)}")

Decoder start token ID: 50257
Decoder start token: <|startoftranscript|>

Complete prefix: [50257, 50362]
Decoded: '<|startoftranscript|><|notimestamps|>'
Total prefix length: 2


In [4]:
# Sanity check: walk every GPT-2 token id and report any id whose decoded text
# differs between the Whisper tokenizer and the GPT-2 tokenizer. The fusion
# step adds LM scores to ASR scores index-by-index, so ids must line up.
vocab_size = len(gpt2_tok.get_vocab())
for token_id in range(vocab_size):
    whisper_text = processor.tokenizer.decode([token_id])
    gpt2_text = gpt2_tok.decode([token_id])
    if whisper_text == gpt2_text:
        continue
    print(f"Token mismatch at index {token_id}\nwhisper token: [{whisper_text}]\n   gpt2 token: [{gpt2_text}]")

In [4]:
def build_dataset(manifest_path: str, batch_size: int, num_proc: int = 4) -> Dataset:
    """Build a ``Dataset`` from a JSON-lines manifest and attach decoded audio.

    Args:
        manifest_path: Path to a UTF-8 ``.jsonl`` file, one JSON object per line.
            Each row must carry a ``"file"`` entry naming an audio file under
            ``AUDIO_DIR``.
        batch_size: Number of rows handed to the audio loader per ``map`` call.
        num_proc: Number of worker processes used by ``Dataset.map``.

    Returns:
        The manifest rows plus an ``"audio"`` column holding each file loaded
        by librosa as mono float32 samples at ``SR`` Hz.
    """
    with open(manifest_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. an extra trailing newline) rather than letting
        # json.loads raise on an empty string.
        rows = [json.loads(line) for line in f if line.strip()]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        # librosa.load returns (samples, sample_rate); only the samples are kept.
        batch["audio"] = [
            librosa.load(f"{AUDIO_DIR}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=batch_size, num_proc=num_proc)

def encode_audio(batch):
    """Convert a batch of raw waveforms into fixed-length Whisper inputs.

    Reads ``batch["audio"]`` (one variable-length waveform per example), pads or
    truncates each to the extractor's ``n_samples`` window, and stores the
    resulting ``input_features`` and ``attention_mask`` tensors on the batch.
    """
    extractor = processor.feature_extractor

    # The waveforms are passed as-is (arrays / lists); the original author
    # noted that torch tensors were not accepted here.
    encoded = extractor(
        batch["audio"],
        sampling_rate=SR,
        padding="max_length",
        truncation=True,
        max_length=extractor.n_samples,  # n_samples = chunk_length * sampling_rate
        return_attention_mask=True,
        return_tensors="pt",
    )

    # NOTE(review): shapes are presumably (B, n_mels, n_frames) for
    # input_features and (B, n_frames) for attention_mask; the earlier note
    # here said (B, T_max, 80). TODO confirm against the extractor docs.
    batch["input_features"] = encoded.input_features
    batch["attention_mask"] = encoded.attention_mask
    return batch

# Full evaluation set; chain .select(range(n)) onto this call to subsample.
ds = build_dataset(MANIFEST, batch_size=BATCH_SIZE)

# `ds` is deliberately kept intact (not overwritten) so the reference `text`
# field is still available for scoring later. A custom collator could carry
# those fields through instead, but as long as the loader is not shuffled the
# row indices of `ds` and `ds_processed` stay aligned.
ds_processed = ds.map(
    encode_audio,
    batched=True,
    batch_size=BATCH_SIZE,
    remove_columns=list(ds.features),
)

ds_processed.set_format(type="torch", columns=["input_features", "attention_mask"])
loader = DataLoader(ds_processed, shuffle=False, batch_size=BATCH_SIZE)

Map (num_proc=4): 100%|██████████| 85/85 [00:03<00:00, 23.25 examples/s]
Map: 100%|██████████| 85/85 [00:02<00:00, 38.55 examples/s]


In [5]:
len(ds)

85

In [6]:
from transformers import LogitsProcessor

## this works
# class ShallowFusion(LogitsProcessor):
#     def __init__(self, lm, pad_id, eos_id, alpha=0.25, warmup=3):
#         super().__init__()
#         self.lm = lm.eval().requires_grad_(False).to(DEVICE)
#         self.pad_id = pad_id  # 50257 
#         self.eos_id = eos_id  # 50256
#         self.alpha = alpha 
#         self.warmup = warmup
#         self.step = 0 
    
#     def reset(self):
#         self.step = 0

#     @torch.inference_mode()
#     def __call__(self, input_ids, scores):
#         if self.step < self.warmup: 
#             self.step += 1 
#             return scores
#         self.step += 1 

#         # Find where valid tokens start (first token < eos_id)
#         valid_mask = input_ids < self.eos_id
        
#         # Find the first valid token position for each sequence
#         # This skips special tokens at the beginning
#         first_valid = valid_mask.long().argmax(dim=1, keepdim=True)
        
#         # Create indices for gathering
#         batch_size, seq_len = input_ids.shape
#         batch_indices = torch.arange(batch_size, device=input_ids.device).unsqueeze(1)
        
#         # Gather only the valid portion of each sequence
#         gather_indices = first_valid + torch.arange(seq_len - first_valid.max(), 
#                                                     device=input_ids.device).unsqueeze(0)
#         gather_indices = gather_indices.clamp(max=seq_len-1)
        
#         # Extract subsequences starting from first valid token
#         clean_ids = input_ids.gather(1, gather_indices)
#         clean_mask = valid_mask.gather(1, gather_indices)
        
#         # Apply your original logic on clean sequences
#         clean_ids[~clean_mask] = self.pad_id
#         attention_mask = clean_mask.long()
        
#         lm_logits = self.lm(
#             input_ids=clean_ids,
#             attention_mask=attention_mask
#         ).logits[:, -1, :]

#         lm_lp = torch.log_softmax(lm_logits, dim=-1)
        
#         fused = scores.clone()
#         fused[:, :self.eos_id] += self.alpha * lm_lp[:, :self.eos_id]
        
#         return fused

# class ShallowFusion(LogitsProcessor):
#     def __init__(self, lm, eos_id, prefix_tokens, alpha=0.25, warmup=3):
#         super().__init__()
#         self.lm = lm.eval().requires_grad_(False).to(DEVICE)
#         self.eos_id = eos_id
#         self.prefix_tokens = prefix_tokens  # <|startoftranscript|><|notimestamps|>
#         self.alpha = alpha 
#         self.warmup = warmup
#         self.step = 0 
    
#     def reset(self):
#         self.step = 0

#     @torch.inference_mode()
#     def __call__(self, input_ids, scores):
#         w_lp = torch.log_softmax(scores, dim=-1)
#         if self.step < self.warmup: 
#             self.step += 1 
#             return w_lp

#         prefix_len = len(self.prefix_tokens)
#         prefix_ids = torch.tensor(self.prefix_tokens, device=input_ids.device)
        
#         # make sure that the actual prefix matches expectation 
#         if self.step == self.warmup:  # (only need to check once)
#             prefix_matches = (input_ids[0, :prefix_len] == prefix_ids).all()
#             if not prefix_matches:
#                 print(f"WARNING: prefix mismatch:\nexpected {self.prefix_tokens}, got {input_ids[0, :prefix_len].tolist()}")
        
#         # make sure we dont have special tokens emitted AFTER decoder prefix
#         has_special_after_prefix = (input_ids[:, prefix_len:] > self.eos_id).any()
#         if has_special_after_prefix:
#             print(f"WARNING: special tokens found after prefix at step {self.step}")
        
#         clean_ids = input_ids[:, prefix_len:]
#         lm_logits = self.lm(input_ids=clean_ids).logits[:, -1, :]
#         lm_lp = torch.log_softmax(lm_logits, dim=-1)
        
#         # FUSION STEP 
#         # P_fused(y|x) = log P_ASR(y|x) + alpha × log P_LM(y)
#         fused = w_lp.clone()
#         fused[:, :self.eos_id] += self.alpha * lm_lp[:, :self.eos_id]
        
#         # optional normalization step
#         # fused -= torch.logsumexp(fused, dim=-1, keepdim=True)
#         self.step += 1
#         return fused

class NoOpLogitsProcessor(LogitsProcessor):
    """
    Baseline processor for the non-fused run: it adds no language-model term
    and only log-softmax normalises the incoming scores, which is the same
    thing ShallowFusion returns during its warm-up steps.
    """
    def __call__(self, input_ids, scores):
        # No fusion here: just hand back normalised log-probabilities.
        return torch.log_softmax(scores, dim=-1)
    

class ShallowFusion(LogitsProcessor):
    """Shallow-fusion logits processor: ASR log-probs plus scaled LM log-probs.

    On every call the incoming scores are log-softmax normalised. After
    ``warmup`` calls, the decoder prefix is stripped from ``input_ids``, the
    remaining tokens are scored by ``lm``, and ``alpha * log P_LM`` is added to
    the ASR log-probs of all token ids below ``eos_id``. Ids ``>= eos_id`` keep
    their plain ASR log-probs.

    Args:
        lm: Language model called as ``lm(input_ids=...)`` and read via
            ``.logits``; it is put in eval mode, frozen and moved to ``DEVICE``.
        eos_id: End-of-sequence token id; also the exclusive upper bound of the
            id range that receives the LM term.
        prefix_tokens: Decoder prefix ids (e.g. <|startoftranscript|><|notimestamps|>)
            that are removed before the LM sees the hypothesis.
        alpha: Weight of the LM log-probabilities.
        warmup: Number of initial calls that return plain ASR log-probs.
        use_mask: When True, hypotheses that already contain ``eos_id`` get no
            LM term. May also be toggled later via the ``use_mask`` attribute.

    The processor is stateful: call ``reset()`` before each new ``generate``
    call so the warm-up counter starts from zero.

    NOTE(review): the LM term is a log-probability (never positive) and EOS
    itself is excluded from it, so fusion presumably makes EOS relatively more
    attractive. That may explain early-terminated fused transcripts; confirm
    before raising ``alpha``.
    """

    def __init__(self, lm, eos_id, prefix_tokens, alpha=0.25, warmup=3, use_mask=False):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False).to(DEVICE)
        self.eos_id = eos_id
        self.prefix_tokens = prefix_tokens  # <|startoftranscript|><|notimestamps|>
        self.alpha = alpha
        self.warmup = warmup
        self.step = 0
        self.use_mask = use_mask
        self._prefix_checked = False  # the prefix sanity check runs once per reset

    def reset(self):
        """Restart the warm-up counter and re-arm the one-off prefix check."""
        self.step = 0
        self._prefix_checked = False

    @torch.inference_mode()
    def __call__(self, input_ids, scores):
        w_lp = torch.log_softmax(scores, dim=-1)
        prefix_len = len(self.prefix_tokens)

        # Return plain ASR log-probs during warm-up, and also whenever nothing
        # follows the prefix yet: the LM cannot be run on an empty context
        # (this happens when ``warmup`` is shorter than the prefix).
        if self.step < self.warmup or input_ids.shape[1] <= prefix_len:
            self.step += 1
            return w_lp

        # make sure that the actual prefix matches expectation (only need to check once)
        if not self._prefix_checked:
            self._prefix_checked = True
            prefix_ids = torch.tensor(self.prefix_tokens, device=input_ids.device)
            prefix_matches = (input_ids[0, :prefix_len] == prefix_ids).all()
            if not prefix_matches:
                print(f"WARNING: prefix mismatch:\nexpected {self.prefix_tokens}, got {input_ids[0, :prefix_len].tolist()}")

        # make sure we dont have special tokens emitted AFTER decoder prefix
        has_special_after_prefix = (input_ids[:, prefix_len:] > self.eos_id).any()
        if has_special_after_prefix:
            print(f"WARNING: special tokens found after prefix at step {self.step}")

        # Score the hypothesis so far (prefix removed) and keep only the
        # distribution over the next token.
        clean_ids = input_ids[:, prefix_len:]
        lm_logits = self.lm(input_ids=clean_ids).logits[:, -1, :]
        lm_lp = torch.log_softmax(lm_logits, dim=-1)

        # FUSION STEP
        # P_fused(y|x) = log P_ASR(y|x) + [HAS_EOS_MASK] * alpha * log P_LM(y)
        fused = w_lp.clone()
        if self.use_mask:
            # this condition is to help exclude terminated sequences from getting revived by our lm
            has_eos = (input_ids == self.eos_id).any(dim=1)
            fusion_mask = (~has_eos).float().unsqueeze(-1)
            fused[:, :self.eos_id] += fusion_mask*self.alpha * lm_lp[:, :self.eos_id]
        else:
            fused[:, :self.eos_id] += self.alpha * lm_lp[:, :self.eos_id]

        # optional normalization step
        # fused -= torch.logsumexp(fused, dim=-1, keepdim=True)
        self.step += 1
        return fused
  
# Fusion processor for the evaluation run: LM weight 0.1, two warm-up steps,
# and no LM term for hypotheses that already contain EOS.
fusion_proc = ShallowFusion(lm=gpt2, eos_id=EOS_ID, prefix_tokens=PREFIX_TOK_IDS, alpha=0.1, warmup=2)
fusion_proc.use_mask = True

In [7]:
from transformers import LogitsProcessorList
from tqdm import tqdm 

# Decode every batch with the shallow-fusion processor attached.
fused = []

for batch in tqdm(loader, total=len(loader), desc="Decoding"):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)

    # The processor counts warm-up steps internally. Reset it *before* each
    # generate call so a stale counter (e.g. from an interrupted or re-run
    # cell) cannot skip the warm-up on the first batch.
    fusion_proc.reset()

    with torch.inference_mode():
        fused_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            logits_processor=LogitsProcessorList([fusion_proc]),
            num_beams=2,
            do_sample=False,  # explicit, to mirror the vanilla baseline call
            return_timestamps=False,
            return_token_timestamps=False,
            # length_penalty=1.1,
            # repetition_penalty=1.1,
            # max_length=100,  # Safety limit
        )

    decoded = processor.batch_decode(
        fused_ids,
        skip_special_tokens=True,
        # output_word_offsets=True
        )

    fused.extend(decoded)

# Leave the processor in a clean state for whoever uses it next.
fusion_proc.reset()

Decoding:   0%|          | 0/17 [00:00<?, ?it/s]Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Decoding: 100%|██████████| 17/17 [00:51<00:00,  3.02s/it]


In [8]:
# Baseline run: identical decoding settings, but the only custom processor is
# the no-LM one, so differences from the fused run come from the LM term.
noop_processor = NoOpLogitsProcessor()
vanilla = []

for batch in tqdm(loader, total=len(loader), desc="Decoding"):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)

    generate_kwargs = dict(
        input_features=feats,
        attention_mask=masks,
        logits_processor=LogitsProcessorList([noop_processor]),
        num_beams=2,
        do_sample=False,
        return_timestamps=False,
        return_token_timestamps=False,
        # length_penalty=1.1,
        # repetition_penalty=1.1,
        # max_length=100,  # Safety limit
    )
    with torch.no_grad():
        vanilla_ids = whisper.generate(**generate_kwargs)

    # Token ids -> text, dropping special tokens.
    vanilla.extend(processor.batch_decode(vanilla_ids, skip_special_tokens=True))


Decoding: 100%|██████████| 17/17 [00:34<00:00,  2.01s/it]


In [9]:
import pandas as pd 

# One row per utterance: both hypotheses (edge whitespace stripped) next to the
# reference text. Rows line up because the loader iterates `ds` unshuffled.
results_df = pd.DataFrame({
    "vanilla": list(map(str.strip, vanilla)),
    "fused": list(map(str.strip, fused)),
    "reference": ds["text"],
})

In [15]:
from jiwer import (
    Compose,
    ToLowerCase,
    RemovePunctuation,
    RemoveMultipleSpaces,
    Strip,
    ReduceToListOfListOfWords,
    wer
)
from unidecode import unidecode
import re

# helper to handle both str and list[str]
def _map(func, x):
    """Apply ``func`` to ``x`` itself, or to each element when ``x`` is a list."""
    return [func(t) for t in x] if isinstance(x, list) else func(x)

def remove_diacritics(x):
    """Run ``unidecode`` over a string or over every string in a list."""
    return _map(unidecode, x)

def split_hyphens_and_slashes(x):
    """Replace every hyphen, en dash, em dash and slash with a space."""
    # replace any dash or slash with a space so we never glue words together
    return _map(lambda t: re.sub(r"[-–—/]", " ", t), x)

# A hyphen / en dash / em dash sitting directly between two digits. Lookarounds
# keep the digits out of the match, so chained ranges such as "1–2–3" are fully
# rewritten (a consuming pattern would skip the second dash).
_DIGIT_DASH_RE = re.compile(r"(?<=\d)[-–—](?=\d)")

def normalize_nums(x):
    """Unify any dash between two digits to a plain hyphen (12–16 → 12-16)."""
    return _map(lambda t: _DIGIT_DASH_RE.sub("-", t), x)

# Text normalisation applied identically to references and hypotheses before
# scoring. The steps run in list order, and the order matters.
transform = Compose([
    ToLowerCase(),
    remove_diacritics,
    split_hyphens_and_slashes,    # split first: dashes and slashes become spaces
    # NOTE(review): every dash has already been replaced by a space at this
    # point, so normalize_nums appears to have nothing left to rewrite here.
    # Confirm whether it was meant to run before the split.
    normalize_nums,
    RemovePunctuation(),           # now drop commas, periods, etc.
    RemoveMultipleSpaces(),
    Strip(),
    ReduceToListOfListOfWords(),   # produce [["word", ...], ...]
])

def compute_wer(ref, hyp):
    """Word error rate of ``hyp`` against ``ref``, with ``transform`` applied to both."""
    shared_transforms = {
        "reference_transform": transform,
        "hypothesis_transform": transform,
    }
    return wer(ref, hyp, **shared_transforms)

# rename & score
# If the column is already called "reference" this rename is a harmless no-op
# (pandas ignores labels that are not present by default).
results_df = results_df.rename(columns={"gt": "reference"})

# Per-utterance WER for each system against the same reference text.
for wer_col, hyp_col in (("wer_base", "vanilla"), ("wer_fused", "fused")):
    results_df[wer_col] = [
        compute_wer(ref_text, hyp_text)
        for ref_text, hyp_text in zip(results_df["reference"], results_df[hyp_col])
    ]

# The reported figures are the unweighted mean of the per-utterance WERs.
print("Base  WER (punct-insensitive):", results_df["wer_base"].mean())
print("Fused WER (punct-insensitive):", results_df["wer_fused"].mean())

Base  WER (punct-insensitive): 0.08241219037403308
Fused WER (punct-insensitive): 0.07002387319332835


In [13]:
# Absolute per-utterance WER gap between the two systems; keep only utterances
# where they disagree, largest gap first.
results_df['diff'] = (results_df['wer_base'] - results_df['wer_fused']).abs()
top_diffs = (
    results_df
    .sort_values(by='diff', ascending=False)
    .loc[lambda frame: frame['diff'] > 0]
)

In [14]:
# Report template: reference, both hypotheses, and the signed WER change
# (positive means fusion lowered the error for that utterance).
print_str = '''
GT:    {}
Base:  {}
Fused: {}
base err - fused err: {}'''

for _, row in top_diffs.iterrows():
    wer_delta = row['wer_base'] - row['wer_fused']
    print(print_str.format(row['reference'], row['vanilla'], row['fused'], wer_delta))


GT:    The urinary bladder is distended with smooth wall thickening. Bilateral hydroureteronephrosis present. Foley catheter placement recommended.
Base:  The urinary bladder is distended with smooth wall thickening. Bilateral hydrouriterinophrosis present. Foley catheter placement recommended.
Fused: The urinary bladder is distended with smooth wall thickening.
base err - fused err: -0.375

GT:    Endoscopic retrograde cholangiopancreatography revealed type I choledochal cyst with anomalous pancreaticobiliary junction.
Base:  Endoscopic retrograde colangiopancreatography revealed type 1 collodocal cyst with anomalous pancreatic obiliary junction.
Fused: Endoscopic retrograde cholangiopancreatography revealed type 1 collodocal cyst with anomalous pancreaticobiliary junction.
base err - fused err: 0.25

GT:    Multiparametric MRI revealed infiltrating ductal carcinoma with peritumoral lymphovascular invasion and axillary lymphadenopathy measuring 2.3 cm.
Base:  Multi-parametric MRI rev

# BONEYARD


In [None]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

# Project-local Hugging Face cache: a ".models" folder next to the notebook
# directory, exported through HF_HOME.
CACHE_DIR = (Path.cwd().parent / ".models").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HOME"] = str(CACHE_DIR)

# Path.parts is separator-agnostic, unlike splitting the string form on '/'.
print(list(CACHE_DIR.parts[-3:]))

In [None]:
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

# --- Scratch configuration (duplicates the main config cell) ----------------
SR = 16_000
BATCH_SIZE = 5
WHISPER_ID = "openai/whisper-small.en"
GPT2_ID = "cwestnedge/gpt2-small-pubmed"

CACHE_DIR = (Path.cwd().parent / ".models").resolve()
MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"

# Prefer CUDA, then Apple MPS, then CPU.
DEVICE = (
    "cuda" if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print("Device:", DEVICE)

# Author's note: the "fast" tokenizers report token mismatches between the two
# models and get auto-loaded on Colab (A100), so slow tokenizers are requested
# explicitly with use_fast=False.
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR, use_fast=False)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

gpt2_tok = AutoTokenizer.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR, use_fast=False)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

EOS_ID = gpt2_tok.eos_token_id # 50256 (unchanged)

print("Decoder start token ID:", whisper.generation_config.decoder_start_token_id)
print("Decoder start token:", processor.decode([whisper.generation_config.decoder_start_token_id]))

# Decoder prefix = start token + every forced decoder token. `forced_decoder_ids`
# can be None (or absent) depending on checkpoint and library version; treat
# that as "no forced tokens" instead of crashing on iteration.
PREFIX_TOK_IDS = [whisper.generation_config.decoder_start_token_id]
forced_decoder_ids = getattr(whisper.generation_config, "forced_decoder_ids", None) or []
for _position, token_id in forced_decoder_ids:
    PREFIX_TOK_IDS.append(token_id)

print(f"\nComplete prefix: {PREFIX_TOK_IDS}")
print(f"Decoded: '{processor.decode(PREFIX_TOK_IDS)}'")
print(f"Total prefix length: {len(PREFIX_TOK_IDS)}")

def build_dataset(manifest_path: str, batch_size: int, num_proc: int = 4) -> Dataset:
    """Build a ``Dataset`` from a JSON-lines manifest and attach decoded audio.

    Args:
        manifest_path: Path to a UTF-8 ``.jsonl`` file, one JSON object per line.
            Each row must carry a ``"file"`` entry naming an audio file under
            ``AUDIO_DIR``.
        batch_size: Number of rows handed to the audio loader per ``map`` call.
        num_proc: Number of worker processes used by ``Dataset.map``.

    Returns:
        The manifest rows plus an ``"audio"`` column holding each file loaded
        by librosa as mono float32 samples at ``SR`` Hz.
    """
    with open(manifest_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. an extra trailing newline) rather than letting
        # json.loads raise on an empty string.
        rows = [json.loads(line) for line in f if line.strip()]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        # librosa.load returns (samples, sample_rate); only the samples are kept.
        batch["audio"] = [
            librosa.load(f"{AUDIO_DIR}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=batch_size, num_proc=num_proc)

def encode_audio(batch):
    """Convert a batch of raw waveforms into fixed-length Whisper inputs.

    Reads ``batch["audio"]`` (one variable-length waveform per example), pads or
    truncates each to the extractor's ``n_samples`` window, and stores the
    resulting ``input_features`` and ``attention_mask`` tensors on the batch.
    """
    extractor = processor.feature_extractor

    # The waveforms are passed as-is (arrays / lists); the original author
    # noted that torch tensors were not accepted here.
    encoded = extractor(
        batch["audio"],
        sampling_rate=SR,
        padding="max_length",
        truncation=True,
        max_length=extractor.n_samples,  # n_samples = chunk_length * sampling_rate
        return_attention_mask=True,
        return_tensors="pt",
    )

    # NOTE(review): shapes are presumably (B, n_mels, n_frames) for
    # input_features and (B, n_frames) for attention_mask; the earlier note
    # here said (B, T_max, 80). TODO confirm against the extractor docs.
    batch["input_features"] = encoded.input_features
    batch["attention_mask"] = encoded.attention_mask
    return batch

# Scratch subset: only manifest rows 21..25.
ds = build_dataset(MANIFEST, batch_size=BATCH_SIZE).select(range(21, 26))

# `ds` is deliberately kept intact (not overwritten) so the reference `text`
# field is still available later. A custom collator could carry those fields
# through instead, but as long as the loader is not shuffled the row indices
# of `ds` and `ds_processed` stay aligned.
ds_processed = ds.map(
    encode_audio,
    batched=True,
    batch_size=2,
    remove_columns=list(ds.features),
)

ds_processed.set_format(type="torch", columns=["input_features", "attention_mask"])
loader = DataLoader(ds_processed, shuffle=False, batch_size=2)

# Sanity check: report every GPT-2 token id whose decoded text differs between
# the Whisper tokenizer and the GPT-2 tokenizer.
vocab_size = len(gpt2_tok.get_vocab())
for token_id in range(vocab_size):
    whisper_text = processor.tokenizer.decode([token_id])
    gpt2_text = gpt2_tok.decode([token_id])
    if whisper_text == gpt2_text:
        continue
    print(f"Token mismatch at index {token_id}\nwhisper token: {whisper_text}\n   gpt2 token: {gpt2_text} ")


In [None]:
# print("Decoder start token ID:", whisper.config.decoder_start_token_id)
# print("BOS token ID:", processor.tokenizer.bos_token_id)
# print("Suppress tokens:", whisper.config.suppress_tokens if hasattr(whisper.config, 'suppress_tokens') else None)

In [None]:
from transformers import LogitsProcessor, LogitsProcessorList
# import torch.nn.functional as F

class HelloWorldLP(LogitsProcessor):
    """Minimal stateful processor used to probe how ``generate`` calls processors.

    For the first ``warmup`` calls it returns log-softmaxed scores; afterwards
    it returns the incoming scores untouched. ``alpha`` is stored but unused.
    """

    def __init__(self, alpha=0.0, warmup=2):
        super().__init__()
        self.alpha = alpha
        self.warmup = warmup
        self.step = 0

    def reset(self):
        """Restart the call counter."""
        self.step = 0

    @torch.inference_mode()
    def __call__(self, input_ids, scores):
        in_warmup = self.step < self.warmup
        self.step += 1
        # Normalised log-probs while warming up, raw scores afterwards.
        return torch.log_softmax(scores, dim=-1) if in_warmup else scores

class ShallowFusion(LogitsProcessor):
    """Scratch shallow-fusion processor (simplified variant, no EOS masking).

    NOTE: this redefines the ``ShallowFusion`` name used by the main cells.
    ``prefix_tokens`` is accepted as an optional keyword so that call sites
    written for the main class (which pass ``prefix_tokens=...``) still work
    when this cell has been run.

    Args:
        lm: Language model called as ``lm(input_ids=...)`` and read via
            ``.logits``; it is put in eval mode, frozen and moved to ``DEVICE``.
        eos_id: EOS token id; only ids below it receive the LM term.
        alpha: Weight of the LM log-probabilities.
        warmup: Number of initial calls that return the scores untouched.
        prefix_tokens: Decoder prefix ids stripped before the LM is run.
            Defaults to ``[50257, 50362]``, the ids this class used to hard-code.
    """

    def __init__(self, lm, eos_id, alpha=0.25, warmup=3, prefix_tokens=None):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False).to(DEVICE)
        self.eos_id = eos_id # should be 50256
        self.alpha = alpha
        self.warmup = warmup
        self.step = 0
        self.prefix_tokens = [50257, 50362] if prefix_tokens is None else list(prefix_tokens)

    def reset(self):
        """Restart the warm-up counter; call before each new ``generate``."""
        self.step = 0

    @torch.inference_mode()
    def __call__(self, input_ids, scores):
        if self.step < self.warmup:
            self.step += 1
            return scores  # Return RAW SCORES, not w_lp!

        self.step += 1

        # Strip the decoder prefix so the LM only sees transcript tokens.
        lm_input_ids = input_ids[:, len(self.prefix_tokens):]
        if lm_input_ids.shape[1] == 0:
            # Nothing for the LM to condition on yet (warmup shorter than the
            # prefix): behave like a warm-up step instead of crashing.
            return scores

        lm_logits = self.lm(
            input_ids=lm_input_ids,
        ).logits[:, -1, :]  # just want next token logits

        # Convert to log probs
        w_lp = torch.log_softmax(scores, dim=-1)
        lm_lp = torch.log_softmax(lm_logits, dim=-1)

        # Apply fusion to ordinary tokens only (ids below eos_id).
        fused = w_lp.clone()
        fused[:, :self.eos_id] += self.alpha * lm_lp[:, :self.eos_id]

        return fused

class FixedShallowFusion(LogitsProcessor):
    """Shallow-fusion variant that never feeds EOS tokens to the LM.

    After the warm-up, the decoder prefix is stripped and the hypotheses are
    truncated at the earliest EOS position found in *any* row of the batch, so
    the LM input is guaranteed to be EOS-free. ``alpha * log P_LM`` is then
    added to the ASR log-probs of ids below ``eos_id``.

    Caveat: truncation is batch-wide, so rows that are still active lose the
    context after that position too.

    Args:
        lm: Language model called as ``lm(input_ids=...)``; put in eval mode
            and frozen (it is not moved to another device here).
        eos_id: EOS token id; also the exclusive upper bound for the LM term.
        prefix_tokens: Decoder prefix ids stripped before the LM is run.
        alpha: Weight of the LM log-probabilities.
        warmup: Number of initial calls that return plain ASR log-probs.
    """

    def __init__(self, lm, eos_id, prefix_tokens, alpha=0.25, warmup=2):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False)
        self.eos_id = eos_id
        self.prefix_tokens = prefix_tokens
        self.alpha = alpha
        self.warmup = warmup
        self.step = 0

    def reset(self):
        """Restart the warm-up counter; call before each new ``generate``."""
        self.step = 0

    @torch.inference_mode()
    def __call__(self, input_ids, scores):
        w_lp = torch.log_softmax(scores, dim=-1)
        prefix_len = len(self.prefix_tokens)

        # Plain ASR log-probs during warm-up or while nothing follows the prefix.
        if self.step < self.warmup or input_ids.shape[1] <= prefix_len:
            self.step += 1
            return w_lp

        clean_ids = input_ids[:, prefix_len:]

        # Truncate at the earliest EOS across the whole batch. Cutting at the
        # first EOS of whichever row happens to come first can leave an earlier
        # EOS in another row.
        eos_cols = (clean_ids == self.eos_id).nonzero(as_tuple=True)[1]
        if eos_cols.numel() > 0:
            clean_ids = clean_ids[:, : int(eos_cols.min())]

        if clean_ids.shape[1] == 0:
            self.step += 1
            return w_lp

        # Now GPT-2 gets clean input without EOS
        lm_logits = self.lm(input_ids=clean_ids).logits[:, -1, :]
        lm_lp = torch.log_softmax(lm_logits, dim=-1)

        # Simple fusion
        fused = w_lp.clone()
        fused[:, :self.eos_id] += self.alpha * lm_lp[:, :self.eos_id]

        self.step += 1
        return fused
    
    
# Scratch smoke run: build the probe processors and decode a single batch.
hw_proc = HelloWorldLP(warmup=2)

# NOTE(review): this call passes no `prefix_tokens`, so it presumably targets
# the ShallowFusion variant defined in this scratch cell rather than the
# main-cell class. Confirm which definition is live before running.
fusion_proc = ShallowFusion(
    lm=gpt2,
    eos_id=50256,
    alpha=0.15,
    warmup=2
)

# Only the first batch from the loader is used.
batch = next(iter(loader))
feats = batch['input_features'].to(DEVICE)
masks = batch['attention_mask'].to(DEVICE)

with torch.no_grad():

    # Beam search (2 beams) with the fusion processor; per-step scores are
    # requested so they can be inspected afterwards.
    out1 = whisper.generate(
        input_features=feats,
        attention_mask=masks,
        logits_processor=LogitsProcessorList([fusion_proc]),
        return_dict_in_generate=True,
        output_scores=True,
        return_timestamps=False,
        num_beams=2,
        # max_new_tokens= 10
    )
    # Clear the processor's step counter so a later generate call starts fresh.
    fusion_proc.reset()
    # out2 = whisper.generate(
    #     input_features=feats,
    #     attention_mask=masks,
    #     return_dict_in_generate=True,
    #     output_scores=True,
    # )
    
# print((out1.scores[-1] != out2.scores[-1]).sum().sum())

In [None]:
# Hard-coded decoder input_ids snapshots (presumably captured from a generate
# run) at two consecutive steps. They are placed on DEVICE rather than a pinned
# 'mps:0', so the cell also runs on CUDA/CPU machines and matches the device
# the LM lives on.
sequences = [
    torch.tensor([[50257, 50362,  6930,   329],
                  [50257, 50362,  6930,    13],
                  [50257, 50362,   383,  7976],
                  [50257, 50362,   262,  7976],
                  [50257, 50362,  6952,   345],
                  [50257, 50362,  1318,   318],
                  [50257, 50362,  6952,   345],
                  [50257, 50362,  6952,   921],
                  [50257, 50362,   329,   262],
                  [50257, 50362,   329,   345]], device=DEVICE),
    torch.tensor([[50257, 50362,  6930,   329,  4964],
                  [50257, 50362,  6930,   329,   262],
                  [50257, 50362,   383,  7976,  2436],
                  [50257, 50362,   262,  7976,  2436],
                  [50257, 50362,  1318,   318,   257],
                  [50257, 50362,  6952,   345,    13],
                  [50257, 50362,  6952,   345,    13],
                  [50257, 50362,  6952,   345,   329],
                  [50257, 50362,   329,   345,    13],
                  [50257, 50362,   329,   262,  1306]], device=DEVICE),
]

# Drop the decoder prefix so GPT-2 only sees transcript tokens.
prefix_len = len(PREFIX_TOK_IDS)
gpt2_ids = sequences[0][:, prefix_len:]

# Next-token logits from the LM for the first snapshot.
with torch.inference_mode():
    lm_scores = gpt2(
        input_ids=gpt2_ids
    ).logits[:, -1, :]


lm_lp = torch.log_softmax(lm_scores, dim=-1)
lm_lp

In [None]:
    # clean_ids = input_ids[:, prefix_len:]
    # lm_logits = self.lm(input_ids=clean_ids).logits[:, -1, :]
    # lm_lp = torch.log_softmax(lm_logits, dim=-1)

    # # FUSION STEP 
    # # P_fused(y|x) = log P_ASR(y|x) + [HAS_EOS_MASK] * alpha * log P_LM(y)
    # fused = w_lp.clone()
    # if self.use_mask: 
    #     # this condition is to help exclude terminated sequences from getting revived by our lm
    #     has_eos = (input_ids == self.eos_id).any(dim=1)
    #     fusion_mask = (~has_eos).float().unsqueeze(-1) 
    #     fused[:, :self.eos_id] += fusion_mask*self.alpha * lm_lp[:, :self.eos_id]
    # else: 
    #     fused[:, :self.eos_id] += self.alpha * lm_lp[:, :self.eos_id]
    
    # print(input_ids)

In [None]:
# Show which decoding defaults the model config carries, or 'N/A' when the
# attribute is not present.
print(f"Do sample: {getattr(whisper.config, 'do_sample', 'N/A')}")
print(f"Temperature: {getattr(whisper.config, 'temperature', 'N/A')}")
print(f"Num beams: {getattr(whisper.config, 'num_beams', 'N/A')}")

# First batch only.
batch = next(iter(loader))
feats = batch['input_features'].to(DEVICE)
masks = batch['attention_mask'].to(DEVICE)

# Plain greedy decoding (single beam, no sampling, no custom processor),
# keeping the per-step scores for offline inspection.
greedy_kwargs = dict(
    input_features=feats,
    attention_mask=masks,
    return_dict_in_generate=True,
    output_scores=True,
    do_sample=False,  # Force greedy/argmax
    num_beams=1,      # No beam search
    return_timestamps=False,
    # temperature=1.0,  # No temperature scaling
)
out1_greedy = whisper.generate(**greedy_kwargs)

In [None]:
# Offline what-if: take one step of the greedy run and compare the token chosen
# by raw ASR, by GPT-2 alone, and by shallow fusion.
prefix_len = out1_greedy.sequences.shape[1] - len(out1_greedy.scores)
print(f"Prefix length: {prefix_len}")

# Step to inspect; clamped so short transcripts don't index past the last step.
t = min(30, len(out1_greedy.scores) - 1)
input_ids_at_t = out1_greedy.sequences[:, :prefix_len + t].clone()
scores_at_t = out1_greedy.scores[t]

# Strip the decoder prefix before handing the context to GPT-2 (uses the
# configured prefix length instead of a hard-coded 2).
in_scope_ids = input_ids_at_t[:, len(PREFIX_TOK_IDS):]

with torch.inference_mode():
    gpt2_scores = gpt2(
        input_ids=in_scope_ids.to(DEVICE),
    ).logits[:, -1, :]

g_lp = torch.log_softmax(gpt2_scores, dim=-1)
w_lp = torch.log_softmax(scores_at_t, dim=-1)

# Apply fusion (alpha=0.3 for this probe)
alpha = 0.3

# Fuse in log-probability space, like the ShallowFusion processor does. Adding
# LM log-probs to raw scores would skew ordinary tokens against EOS and the
# special ids, which receive no LM term.
fused = w_lp.clone()
fused[:, :EOS_ID] +=  alpha * g_lp[:, :EOS_ID]

# Get next tokens with different strategies
next_token_fused = fused.argmax(dim=-1)
next_token_raw = scores_at_t.argmax(dim=-1)
next_token_gpt2 = g_lp.argmax(dim=-1)

# Decode for comparison
print("\nToken choices:")
print(f"Raw ASR:{processor.decode(next_token_raw[0].item())}")
print(f"GPT-2  :{processor.decode(next_token_gpt2[0].item())}")
print(f"Fusion :{processor.decode(next_token_fused[0].item())}")

# Show the actual sequences
actual_next = out1_greedy.sequences[:, prefix_len + t]
print(f"Actual next token in sequence: {processor.decode(actual_next[0].item())}")

# Compare full sequences
inputs_with_raw = torch.cat([input_ids_at_t, next_token_raw.unsqueeze(1)], dim=-1)
inputs_with_fused = torch.cat([input_ids_at_t, next_token_fused.unsqueeze(1)], dim=-1)
inputs_actual = out1_greedy.sequences[:, :prefix_len + t + 1]

print("\nFull sequences:")
print(f"Context: {processor.batch_decode(input_ids_at_t)[0]}")
print(f"With raw ASR: {processor.batch_decode(inputs_with_raw)[0]}")
print(f"With fusion: {processor.batch_decode(inputs_with_fused)[0]}")
print(f"Actual sequence: {processor.batch_decode(inputs_actual)[0]}")

print(f"\nNote: Step {t} was pure ASR during generation (no LogitsProcessor used)")

In [None]:
# PROPER FUSION TESTING - During Active Generation (Not After EOS)

def test_fusion_during_generation():
    """Test fusion at the right time - during active generation.

    Greedy-decodes the first example of the first batch (at most 20 new
    tokens), then replays a few decoding steps offline. For each step it
    prints what Whisper alone, GPT-2 alone, and the fused score would pick for
    several ``alpha`` values. Steps whose context is empty or already contains
    EOS are skipped. Prints only; returns nothing.
    """

    print("=== TESTING FUSION DURING ACTIVE GENERATION ===")

    # Get a fresh generation
    batch = next(iter(loader))
    feats = batch['input_features'][:1].to(DEVICE)
    masks = batch['attention_mask'][:1].to(DEVICE)

    # Generate with return_dict to get intermediate scores
    with torch.no_grad():
        result = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            return_dict_in_generate=True,
            output_scores=True,
            num_beams=1,
            do_sample=False,
            max_new_tokens=20
        )

    prefix_len = result.sequences.shape[1] - len(result.scores)
    print(f"Prefix length: {prefix_len}")
    print(f"Total generation steps: {len(result.scores)}")

    # Test fusion at different steps DURING generation
    test_steps = [2, 5, 8, 10] if len(result.scores) > 10 else list(range(min(len(result.scores), 5)))

    for step in test_steps:
        if step >= len(result.scores):
            continue

        print(f"\n--- TESTING STEP {step} (Active Generation) ---")

        # Get the context at this step
        input_ids_at_step = result.sequences[:, :prefix_len + step]
        scores_at_step = result.scores[step]

        # Extract clean context for GPT-2
        clean_context = input_ids_at_step[:, prefix_len:]

        print(f"Context: '{processor.decode(clean_context[0])}'")
        print(f"Contains EOS: {(clean_context == EOS_ID).any().item()}")

        # At step 0 nothing has been generated yet, so GPT-2 would be called
        # with an empty sequence and fail. Skip such steps.
        if clean_context.shape[1] == 0:
            print(f"  Skipping - no context for GPT-2 yet")
            continue

        # Only test if no EOS in context (active generation)
        if not (clean_context == EOS_ID).any():
            # Get predictions
            with torch.no_grad():
                gpt2_logits = gpt2(clean_context).logits[:, -1, :]

            w_lp = torch.log_softmax(scores_at_step, dim=-1)
            g_lp = torch.log_softmax(gpt2_logits, dim=-1)

            # Test different alpha values
            alphas = [0.2, 0.5, 0.8]

            whisper_pred = w_lp.argmax(dim=-1)[0].item()
            gpt2_pred = g_lp.argmax(dim=-1)[0].item()

            print(f"  Whisper wants: '{processor.decode(whisper_pred)}'")
            print(f"  GPT-2 wants: '{processor.decode(gpt2_pred)}'")

            for alpha in alphas:
                # LM term is added to ordinary tokens only (ids below EOS_ID).
                fused_scores = w_lp.clone()
                fused_scores[:, :EOS_ID] += alpha * g_lp[:, :EOS_ID]
                fused_pred = fused_scores.argmax(dim=-1)[0].item()

                if fused_pred != whisper_pred:
                    print(f"  α={alpha}: '{processor.decode(fused_pred)}' 🔄 CHANGED!")
                else:
                    print(f"  α={alpha}: '{processor.decode(fused_pred)}' ➡️ Same")
        else:
            print(f"  Skipping - sequence already ended")

def compare_fusion_effectiveness(alpha=0.7, warmup=2, max_new_tokens=30):
    """Decode one utterance with and without shallow fusion and print a comparison.

    Takes the first example of the first batch yielded by the notebook-level
    ``loader``, greedily decodes it twice with ``whisper`` (once unmodified,
    once with a ``ShallowFusion`` logits processor wrapping ``gpt2``) and prints
    both transcripts next to ``ds['text'][0]``. When the transcripts differ, a
    word-level unified diff is printed as well.

    Args:
        alpha: LM weight handed to ``ShallowFusion``. Defaults to 0.7, the
            value this check was originally hard-coded to.
        warmup: ``warmup`` argument handed to ``ShallowFusion``.
        max_new_tokens: Generation budget used for both decoding passes.

    Returns:
        tuple[str, str]: ``(vanilla_text, fusion_text)`` decoded with
        ``skip_special_tokens=True``.
    """
    from difflib import unified_diff

    print("\n" + "="*60)
    print("=== COMPARING VANILLA VS FUSION (Proper Alpha) ===")

    # Only the first example of the batch is decoded.
    batch = next(iter(loader))
    feats = batch['input_features'][:1].to(DEVICE)
    masks = batch['attention_mask'][:1].to(DEVICE)

    # Both passes must share identical decoding settings so that any change in
    # the transcript can be attributed to the logits processor alone.
    decode_kwargs = dict(
        input_features=feats,
        attention_mask=masks,
        num_beams=1,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        return_timestamps=False,
    )

    # Vanilla generation
    with torch.no_grad():
        vanilla_result = whisper.generate(**decode_kwargs)

    # NOTE(review): this relies on a ShallowFusion definition that accepts
    # `prefix_tokens`; a later cell in this notebook redefines the class with a
    # `pad_id` parameter instead, so cell execution order matters here.
    fusion_proc = ShallowFusion(
        lm=gpt2,
        eos_id=EOS_ID,
        prefix_tokens=PREFIX_TOK_IDS,
        alpha=alpha,
        warmup=warmup,
    )

    with torch.no_grad():
        fusion_result = whisper.generate(
            logits_processor=LogitsProcessorList([fusion_proc]),
            **decode_kwargs,
        )

    # Compare results
    vanilla_text = processor.decode(vanilla_result[0], skip_special_tokens=True)
    fusion_text = processor.decode(fusion_result[0], skip_special_tokens=True)

    # NOTE(review): assumes the first loader item corresponds to ds[0]
    # (i.e. the loader is not shuffled) -- confirm.
    print(f"Reference: {ds['text'][0]}")
    print(f"Vanilla:   {vanilla_text}")
    print(f"Fusion α={alpha}: {fusion_text}")
    print(f"Different: {vanilla_text != fusion_text}")

    if vanilla_text != fusion_text:
        print("✅ SUCCESS: Fusion with proper alpha shows differences!")

        # Word-level diff: each whitespace-separated word is one diff "line".
        diff = list(unified_diff(
            vanilla_text.split(),
            fusion_text.split(),
            fromfile='vanilla',
            tofile='fusion',
            lineterm=''
        ))
        if diff:
            print("\nWord-level differences:")
            for line in diff:
                print(f"  {line}")
    else:
        print("⚠️ Still identical - try even higher alpha")

    return vanilla_text, fusion_text

# Execute both diagnostics defined above, then print the takeaways.
test_fusion_during_generation()
compare_fusion_effectiveness()

# One print call; `sep` puts every insight on its own line.
print(
    "\n🎯 KEY INSIGHTS:",
    "1. Test fusion DURING generation, not after EOS",
    "2. GPT-2 predicting 'The' after EOS is normal - it doesn't understand stopping context",
    "3. Your fusion correctly chooses to end rather than continue inappropriately",
    "4. Use higher alpha (0.5-0.8) to see LM influence during active generation",
    sep="\n",
)

In [None]:
# Encode the literal special-token string with Whisper's tokenizer; the
# resulting id list is shown as the cell output.
processor.tokenizer.encode('<|startoftranscript|><|notimestamps|>')

In [None]:
# Scratch cell: replace every id >= EOS_ID with PAD_ID and build the matching
# attention mask (0 at replaced positions, 1 elsewhere).
# NOTE(review): in the visible part of the notebook `out1_greedy` and
# `input_ids_at_t` are only assigned in a later cell, so this cell appears to
# rely on out-of-order execution -- confirm.
out1_greedy.sequences  # bare expression, not the last line of the cell, so nothing is displayed
oob_mask = input_ids_at_t >= EOS_ID  # 50256
filtered_ids = input_ids_at_t.masked_fill(oob_mask, PAD_ID)  # 50257
attention_mask = (filtered_ids != PAD_ID).long()

In [None]:
# Check what generation strategy is configured. `generate()` takes its
# defaults from `generation_config` (the same object the setup cell reads
# `decoder_start_token_id` from), so report those values.
gen_cfg = whisper.generation_config
print(f"Do sample: {getattr(gen_cfg, 'do_sample', 'N/A')}")
print(f"Temperature: {getattr(gen_cfg, 'temperature', 'N/A')}")
print(f"Num beams: {getattr(gen_cfg, 'num_beams', 'N/A')}")

batch = next(iter(loader))
feats = batch['input_features'].to(DEVICE)
masks = batch['attention_mask'].to(DEVICE)

# Force greedy decoding (argmax) so the per-step scores can be replayed below.
out1_greedy = whisper.generate(
    input_features=feats,
    attention_mask=masks,
    return_dict_in_generate=True,
    output_scores=True,
    do_sample=False,  # Force greedy/argmax
    num_beams=1,      # No beam search
)

# One score tensor is emitted per generated token, so whatever remains of the
# sequence length is the decoder prompt.
prefix_len = out1_greedy.sequences.shape[1] - len(out1_greedy.scores)
print(f"Prefix length: {prefix_len}")

# Choose which step to analyze
t = 10

# The tokens that were in place when scores[t] was produced.
input_ids_at_t = out1_greedy.sequences[:, :prefix_len + t].clone()
scores_at_t = out1_greedy.scores[t]

print(f"Analyzing step {t}:")
print(f"Input shape: {input_ids_at_t.shape}")
print(f"This input was used to generate token at position {prefix_len + t}")

# Keep only ids below EOS_ID for GPT-2; everything from EOS_ID upwards
# (EOS itself and Whisper's special tokens) is dropped from each row.
batch_size = input_ids_at_t.shape[0]
gpt2_inputs = [row[row < EOS_ID] for row in input_ids_at_t]

# Right-pad to a common length so the rows can be batched.
max_len = max(len(seq) for seq in gpt2_inputs)
filtered_ids = torch.full((batch_size, max_len), PAD_ID, dtype=torch.long, device=DEVICE)
attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long, device=DEVICE)

for i, seq in enumerate(gpt2_inputs):
    filtered_ids[i, :len(seq)] = seq
    attention_mask[i, :len(seq)] = 1

print(f"GPT-2 input (first sequence): {processor.decode(gpt2_inputs[0])}")
print(f"Valid token IDs: {gpt2_inputs[0].tolist()}")

with torch.inference_mode():
    gpt2_scores = gpt2(
        input_ids=filtered_ids,
        attention_mask=attention_mask
    ).logits

# Rows are right-padded, so position -1 is a PAD slot for every row shorter
# than max_len. Read each row's logits at its last real token instead
# (clamped so a row with no valid tokens does not index with -1).
last_idx = (attention_mask.sum(dim=1) - 1).clamp(min=0)
row_idx = torch.arange(batch_size, device=gpt2_scores.device)

# Get log probabilities
g_lp = torch.log_softmax(gpt2_scores[row_idx, last_idx], dim=-1)
w_lp = torch.log_softmax(scores_at_t, dim=-1)

# Apply fusion: add the weighted LM log-probs to the ASR scores for ids below
# EOS_ID only, so GPT-2 has no say over EOS or Whisper's special tokens.
alpha = 0.1
fused = scores_at_t.clone()
fused[:, :EOS_ID] += alpha * g_lp[:, :EOS_ID]

# The next-token choices under each strategy are compared in the following cell.


In [None]:
# Greedy pick under each scoring rule: raw ASR scores, GPT-2 log-probs, fused scores.
next_token_raw = torch.argmax(scores_at_t, dim=-1)
next_token_gpt2 = torch.argmax(g_lp, dim=-1)
next_token_fused = torch.argmax(fused, dim=-1)

# Show the three candidates for the first sequence in the batch.
print("\nToken choices:")
print(f"Raw ASR:{processor.decode(next_token_raw[0].item())}")
print(f"GPT-2  :{processor.decode(next_token_gpt2[0].item())}")
print(f"Fusion :{processor.decode(next_token_fused[0].item())}")

# The token the greedy run really emitted at this position.
actual_next = out1_greedy.sequences[..., prefix_len + t]
print(f"Actual next token in sequence: {processor.decode(actual_next[0].item())}")

# Append each candidate to the shared context to view whole sequences.
inputs_actual = out1_greedy.sequences[:, : prefix_len + t + 1]
inputs_with_raw = torch.cat((input_ids_at_t, next_token_raw[:, None]), dim=1)
inputs_with_fused = torch.cat((input_ids_at_t, next_token_fused[:, None]), dim=1)

print("\nFull sequences:")
print(f"Context: {processor.batch_decode(input_ids_at_t)[0]}")
print(f"With raw ASR: {processor.batch_decode(inputs_with_raw)[0]}")
print(f"With fusion: {processor.batch_decode(inputs_with_fused)[0]}")
print(f"Actual sequence: {processor.batch_decode(inputs_actual)[0]}")

print(f"\nNote: Step {t} was pure ASR during generation (no LogitsProcessor used)")

In [None]:
# Inspect the token GPT-2 ranked first for the first sequence in the batch.
gpt2_token_id = next_token_gpt2[0].item()
print(f"GPT-2 predicted token ID: {gpt2_token_id}")

# Render the same id three ways: processor decode, tokenizer decode, and the
# raw vocabulary entry.
print(f"Decode with processor: '{processor.decode(gpt2_token_id)}'")
print(f"Decode with tokenizer: '{processor.tokenizer.decode([gpt2_token_id])}'")
print(f"Token string: '{processor.tokenizer.convert_ids_to_tokens([gpt2_token_id])[0]}'")

# Check if it's a space or special character
# NOTE(review): 50257 is presumably EOS_ID + 1, i.e. the end of the id range
# GPT-2 shares with Whisper's tokenizer -- confirm.
if gpt2_token_id < 50257:
    token = processor.tokenizer.convert_ids_to_tokens([gpt2_token_id])[0]
    print(f"Token repr: {repr(token)}")  # This will show \n, \t, spaces etc
    print(f"Token bytes: {token.encode('utf-8')}")

In [None]:
# Display `test_sequence`.
# NOTE(review): in the visible part of the notebook this name is assigned only
# in the following cell (via gpt2_tok.encode), so this cell appears to rely on
# out-of-order execution -- confirm.
test_sequence

In [None]:
# Check what GPT-2 actually received as input
print(f"Input to GPT-2: {processor.decode(filtered_ids[0])}")
print(f"Raw token IDs: {filtered_ids[0].tolist()}")

# Manually probe GPT-2 with a clear medical context that stops mid-word.
test_sequence = gpt2_tok.encode("The patient had pain in the abdomen and pel", return_tensors="pt").to(DEVICE)
with torch.no_grad():
    test_output = gpt2(test_sequence).logits[0, -1, :]
    test_probs = torch.softmax(test_output, dim=0)
    top5 = test_probs.topk(5)

print("\nGPT-2 predictions for 'abdomen and pel':")
for prob, idx in zip(top5.values, top5.indices):
    # NOTE(review): 50257 is presumably the end of the id range shared with
    # Whisper's tokenizer (EOS_ID + 1); ids beyond it are skipped -- confirm.
    if idx < 50257:
        print(f"  {processor.decode(idx)}: {prob:.3f}")

# Also check if the model weights look reasonable
print(f"\nGPT-2 weight stats:")
print(f"Mean: {gpt2.lm_head.weight.mean().item():.4f}")
print(f"Std: {gpt2.lm_head.weight.std().item():.4f}")

# Test a few medical terms
medical_tests = [
    "The patient's hep",  # -> hepatic/hepatitis
    "The cardiac cath",   # -> catheterization  
    "Diagnosed with pneum" # -> pneumonia
]

for test in medical_tests:
    # Encode with GPT-2's own tokenizer, as for `test_sequence` above.
    # Whisper's tokenizer wraps the text in its special tokens
    # (start-of-transcript prefix ... end-of-text) by default, which would feed
    # GPT-2 ids outside its vocabulary and make the "next token" a prediction
    # after an end-of-text marker rather than a continuation of the prompt.
    tokens = gpt2_tok.encode(test, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = gpt2(tokens).logits[0, -1, :]
    next_token = processor.decode(out.argmax().item())
    print(f"\n'{test}' -> '{next_token}'")

In [None]:
class ShallowFusion(LogitsProcessor):
    """Logits processor that adds weighted LM log-probabilities to ASR scores.

    At every decoding step after ``warmup`` steps, the language model is run on
    the tokens generated so far and ``alpha * log_softmax(lm_logits)`` is added
    to the incoming scores for all token ids below ``eos_id``. Ids at or above
    ``eos_id`` are left untouched, so the LM never influences termination.

    Args:
        lm: Causal language model; put in eval mode, frozen and moved to
            ``DEVICE``.
        pad_id: Id substituted for every input token >= ``eos_id`` before the
            LM is called; those positions are also masked out of attention.
        eos_id: First id that is excluded from fusion.
        alpha: Weight applied to the LM log-probabilities.
        warmup: Number of initial decoding steps passed through unchanged.
    """

    def __init__(self, lm, pad_id, eos_id, alpha=0.25, warmup=3):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False).to(DEVICE)
        self.pad_id = pad_id # should be 50257 
        self.eos_id = eos_id # should be 50256
        self.alpha = alpha 
        self.warmup = warmup
        self.step = 0 
        # Length of `input_ids` seen on the previous call; used to notice the
        # start of a new generate() run.
        self._last_len = -1
    
    def reset(self):
        """Restart the warmup counter (also happens automatically, see __call__)."""
        self.step = 0
        self._last_len = -1

    # no_grad rather than inference_mode: tensors created under inference_mode
    # are "inference tensors" and cannot be modified in place once control is
    # back in generate(), which would break any later in-place processing of
    # the returned scores.
    @torch.no_grad()
    def __call__(self, input_ids, scores):
        """Return ``scores`` with the LM contribution added.

        Args:
            input_ids: Token ids decoded so far, shape (batch, cur_len).
            scores: Next-token scores from the ASR model, shape (batch, vocab).

        Returns:
            ``scores`` itself during warmup, otherwise a fused copy.
        """
        # Within one generate() call the sequence grows by one token per step.
        # A length that did not grow means a new call has started, so the
        # warmup counter is restarted instead of silently carrying over.
        cur_len = input_ids.shape[1]
        if cur_len <= self._last_len:
            self.step = 0
        self._last_len = cur_len

        if self.step < self.warmup: 
            self.step += 1 
            return scores
        self.step += 1 

        # Hide everything the LM must not condition on (EOS and ids above it).
        oov_mask = input_ids >= self.eos_id # gpt2 and whispers EOS token
        padded_input_ids = input_ids.masked_fill(oov_mask, self.pad_id) # PAD_ID
        attention_mask = (padded_input_ids != self.pad_id).long()
        
        lm_logits = self.lm(
            input_ids=padded_input_ids,
            attention_mask=attention_mask
        ).logits[:,-1,:] # just want next token logits

        lm_lp = torch.log_softmax(lm_logits, dim=-1)
        
        fused = scores.clone()
        fused[:, :self.eos_id] += self.alpha * lm_lp[:, :self.eos_id]
        # optional normalization step
        # fused -= torch.logsumexp(fused, dim=-1, keepdim=True)
        return fused

# Sanity check: rebuild the fused step-0 scores by hand and compare them with
# what the processor produces inside generate().
alpha = 0.3
warmup = 1

# Pass 1 -- warmup=1 leaves step 0 untouched, so scores[0] is plain ASR output.
fusion_proc = ShallowFusion(gpt2, PAD_ID, EOS_ID, alpha=alpha, warmup=warmup)
out = whisper.generate(
    input_features=feats,
    attention_mask=masks,
    logits_processor=[fusion_proc],
    output_scores=True,
    return_dict_in_generate=True,
)

# Whatever is not covered by a score tensor is the decoder prompt (expected: 2).
prefix_len = out.sequences.size(1) - len(out.scores)
pure_asr_step0 = out.scores[0]
input_for_step0 = out.sequences[:, :prefix_len]

# Mirror the processor's masking: ids >= EOS_ID become PAD_ID and are ignored.
oov = input_for_step0.ge(EOS_ID)
gpt2_inp = input_for_step0.masked_fill(oov, PAD_ID)
attn = gpt2_inp.ne(PAD_ID).long()

lm_logits = gpt2(gpt2_inp, attention_mask=attn).logits[:, -1]
lm_logp = lm_logits.log_softmax(dim=-1)

fused_manual = pure_asr_step0.clone()
fused_manual[:, :EOS_ID] += alpha * lm_logp[:, :EOS_ID]

# Pass 2 -- warmup=0, so scores[0] now already carries the LM contribution.
fusion_proc2 = ShallowFusion(gpt2, PAD_ID, EOS_ID, alpha=alpha, warmup=0)
out2 = whisper.generate(
    input_features=feats,
    attention_mask=masks,
    logits_processor=[fusion_proc2],
    output_scores=True,
    return_dict_in_generate=True,
)

# Largest absolute deviation between the two, ignoring NaN entries.
fused_actual = out2.scores[0]
diff = torch.abs(fused_actual - fused_manual)
valid_mask = torch.isnan(diff).logical_not()
print(f"Max diff: {diff[valid_mask].max().item():.6e}")

In [None]:
# Count the entries where the step-`t` scores of `out3` differ from `fused`.
# The second `.sum()` acts on an already fully reduced value.
# NOTE(review): `out3` is not assigned anywhere in the visible part of the
# notebook -- confirm where it comes from.
(out3.scores[t] != fused).sum().sum()

In [None]:
# Display `inputs_fused`.
# NOTE(review): in the visible part of the notebook this name is assigned only
# in a later cell (context ids concatenated with the fused argmax token), so
# this cell appears to rely on out-of-order execution -- confirm.
inputs_fused

In [None]:
# Display `b`.
# NOTE(review): in the visible part of the notebook `b` is assigned only in a
# later cell, as processor.batch_decode(inputs_fused) -- confirm run order.
b

In [None]:
# Display `c`.
# NOTE(review): in the visible part of the notebook `c` is assigned only in a
# later cell, as processor.batch_decode(inputs_raw) -- confirm run order.
c

In [None]:
# Replay one decoding step by round-tripping the context through text:
# decode Whisper ids -> re-tokenize with GPT-2 -> fuse next-token scores.
# NOTE(review): pairing `sequences[:, :t]` with `scores[t-1]` is only aligned
# if the decoder prompt is exactly one token long; the setup cell builds a
# longer prefix from forced_decoder_ids -- confirm the intended offset.
# NOTE(review): `out1` is only assigned in a commented-out cell within the
# visible part of the notebook -- confirm where it comes from.
t = 5
input_ids_at_t = out1.sequences[:,:t].clone()
scores_at_t = out1.scores[t-1]
texts = processor.batch_decode(
    input_ids_at_t,
    skip_special_tokens=True,
)
# texts = [t.strip() for t in texts] # this doesnt seem to make a difference
gpt2_enc = gpt2_tok(
    texts,
    return_tensors="pt",
    padding=True,           # pads to longest in batch
    truncation=False,       # adjust as you like
)
gpt2_inputs = gpt2_enc.input_ids.to(DEVICE)
# Use the tokenizer's own mask: an all-ones mask would let GPT-2 attend to the
# padding that `padding=True` adds to the shorter texts.
gpt2_attention = gpt2_enc.attention_mask.to(DEVICE)

# Index of the last real token in every row, valid for left or right padding
# (falls back to 0 for a row that is entirely padding).
token_positions = torch.arange(gpt2_attention.size(1), device=DEVICE)
last_idx = (gpt2_attention * token_positions).max(dim=1).values

with torch.inference_mode():
    gpt2_all_logits = gpt2(
        input_ids = gpt2_inputs,
        attention_mask = gpt2_attention
    ).logits
    # Next-token logits must come from the last real token, not from a pad slot.
    gpt2_scores = gpt2_all_logits[
        torch.arange(gpt2_all_logits.size(0), device=gpt2_all_logits.device),
        last_idx,
    ]

g_lp = torch.log_softmax(gpt2_scores, dim=-1)
w_lp = torch.log_softmax(scores_at_t, dim=-1)

# The LM weight is 0 here, so `fused` equals the ASR log-probs.
fused = w_lp.clone()
fused[:, :g_lp.size(1)] += 0 * g_lp
next_token = fused.argmax(dim=-1).unsqueeze(1)
next_token_raw = scores_at_t.argmax(dim=-1).unsqueeze(1)

inputs_raw = torch.cat([input_ids_at_t, next_token_raw], dim=-1)
inputs_fused = torch.cat([input_ids_at_t, next_token], dim=-1)

# a: context, b: context + fused pick, c: context + raw pick, d: actual output
a = processor.batch_decode(input_ids_at_t)
b = processor.batch_decode(inputs_fused)
c = processor.batch_decode(inputs_raw)
d = processor.batch_decode(out1.sequences[:,:t+1])


### ------------ TESTING END ---------------

In [None]:
# # =============================================================
# #  One-cell evaluation (uses your original jiwer transform)
# #  -------------------------------------------------------------
# #  Metrics per model:
# #    • Global WER                --> same as your old script
# #    • Medical-Term Recall (MTR) --> fraction of terms perfectly present
# #    • Medical-Term-only WER     --> WER on tokens that belong to terms
# #
# #  Expects a DataFrame `results_df` with columns:
# #        reference, vanilla, fused, medical_terms
# #  where medical_terms is list[str]  (or a string repr like "['a','b']").
# # =============================================================

# import re, ast, itertools, pandas as pd
# from jiwer import (
#     Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces,
#     Strip, ReduceToListOfListOfWords, wer
# )
# from unidecode import unidecode


# # ---------- helper to handle both str & list[str] -------------------
# def _map(func, x):
#     return [func(t) for t in x] if isinstance(x, list) else func(x)

# def remove_diacritics(x):
#     return _map(unidecode, x)

# def split_hyphens_and_slashes(x):
#     return _map(lambda t: re.sub(r"[-–—/]", " ", t), x)

# def normalize_nums(x):
#     return _map(lambda t: re.sub(r"(\d)[-–—-](\d)", r"\1-\2", t), x)

# # ---------- your original jiwer transform ---------------------------
# transform = Compose([
#     ToLowerCase(),
#     remove_diacritics,
#     split_hyphens_and_slashes,
#     normalize_nums,
#     RemovePunctuation(),
#     RemoveMultipleSpaces(),
#     Strip(),
#     ReduceToListOfListOfWords(),   # -> [["word", ...], ...]
# ])

# def compute_wer(ref, hyp):
#     return wer(
#         ref, hyp,
#         reference_transform=transform,
#         hypothesis_transform=transform,
#     )

# # ---------- lightweight normaliser for term metrics -----------------
# _punc_rx   = re.compile(r"[^\w\s]")
# _range_rx  = re.compile(r"(\d)[-–—-](\d)")
# _split_rx  = re.compile(r"[-–—/]")

# def _normalise(text: str) -> str:
#     text = unidecode(text.lower())
#     text = _range_rx.sub(r"\1-\2", text)
#     text = _split_rx.sub(" ", text)
#     text = _punc_rx.sub(" ", text)
#     return re.sub(r"\s+", " ", text).strip()

# def _term_recall(row, hyp_text):
#     hyp_norm = _normalise(hyp_text)
#     hits = sum(1 for t in row["medical_terms"] if _normalise(t) in hyp_norm)
#     return hits / len(row["medical_terms"])

# def _extract_term_tokens(row, text):
#     tokens = _normalise(text).split()
#     keep   = [False] * len(tokens)
#     for term in row["medical_terms"]:
#         ttoks = _normalise(term).split()
#         for i in range(len(tokens) - len(ttoks) + 1):
#             if tokens[i:i+len(ttoks)] == ttoks:
#                 for j in range(i, i+len(ttoks)):
#                     keep[j] = True
#     return " ".join(tok for tok, flag in zip(tokens, keep) if flag)

# # ---------- main evaluation routine ---------------------------------
# def evaluate(df: pd.DataFrame) -> None:
#     # ensure list[str] in medical_terms
#     if isinstance(df["medical_terms"].iloc[0], str):
#         df["medical_terms"] = df["medical_terms"].apply(ast.literal_eval)

#     # Global WER (your existing metric)
#     df["wer_vanilla"] = df.apply(
#         lambda r: compute_wer(r["reference"], r["vanilla"]), axis=1)
#     df["wer_fused"]   = df.apply(
#         lambda r: compute_wer(r["reference"], r["fused"]), axis=1)

#     # Medical-Term Recall
#     df["mtr_vanilla"] = df.apply(
#         lambda r: _term_recall(r, r["vanilla"]), axis=1)
#     df["mtr_fused"]   = df.apply(
#         lambda r: _term_recall(r, r["fused"]), axis=1)

#     # Medical-Term-only WER
#     df["mtwer_vanilla"] = df.apply(
#         lambda r: wer(
#             _extract_term_tokens(r, r["reference"]),
#             _extract_term_tokens(r, r["vanilla"]),
#             reference_transform=transform,
#             hypothesis_transform=transform), axis=1)
#     df["mtwer_fused"]   = df.apply(
#         lambda r: wer(
#             _extract_term_tokens(r, r["reference"]),
#             _extract_term_tokens(r, r["fused"]),
#             reference_transform=transform,
#             hypothesis_transform=transform), axis=1)

#     # -------- summary printout --------------------------------------
#     print("\n=== Global WER ===")
#     print(f"  vanilla : {df['wer_vanilla'].mean():.4f}")
#     print(f"  fused   : {df['wer_fused'].mean():.4f}")

#     print("\n=== Medical-Term Recall ===")
#     print(f"  vanilla : {df['mtr_vanilla'].mean():.4f}")
#     print(f"  fused   : {df['mtr_fused'].mean():.4f}")

#     print("\n=== Medical-Term-only WER ===")
#     print(f"  vanilla : {df['mtwer_vanilla'].mean():.4f}")
#     print(f"  fused   : {df['mtwer_fused'].mean():.4f}")

# # ---------- run on your DataFrame -----------------------------------

# import pandas as pd 

# results_df = pd.DataFrame(
#     {
#         "vanilla":[i.strip() for i in vanilla], 
#         "fused":[i.strip() for i in fused], 
#         "reference":ds['text'],
#         "medical_terms":ds['medical_terms']
#     }
# )

# evaluate(results_df.copy())

In [None]:
# from transformers import LogitsProcessor, LogitsProcessorList
# import torch.nn.functional as F

# class ShallowFusion(LogitsProcessor):
#     def __init__(self, lm, shared_vocab, eos, alpha=0.3, warmup_steps=3):
#         super().__init__()
#         self.lm = lm.eval().requires_grad_(False)
#         self.V = shared_vocab
#         self.eos = eos
#         self.alpha = alpha
#         self.warmup = warmup_steps
#         self.step = 0

#     @torch.no_grad()
#     def __call__(self, input_ids, scores):
#         print('printing input_ids.size(), scores.size(), step, input_ids, dec_ids')
#         print(input_ids.size(), scores.size(), self.step, input_ids, processor.batch_decode(input_ids))
#         self.step+=1 

#         return scores
    
# fusion_proc = ShallowFusion(
#     lm=gpt2,
#     shared_vocab=gpt2.config.vocab_size,
#     eos=EOS_ID,
#     alpha=0.3
# )

# batch = next(iter(loader))
# feats = batch['input_features'].to(DEVICE)
# masks = batch['attention_mask'].to(DEVICE)

# with torch.no_grad():

#     out1 = whisper.generate(
#         input_features=feats,
#         attention_mask=masks,
#         logits_processor=LogitsProcessorList([fusion_proc]),
#         return_dict_in_generate=True,
#         output_scores=True,
#         num_beams=2,
#     )

#     # out2 = whisper.generate(
#     #     input_features=feats,
#     #     attention_mask=masks,
#     #     return_dict_in_generate=True,
#     #     output_scores=True,
#     #     num_beams=2
#     # )

In [None]:
# Pull one batch and move it to the compute device.
batch = next(iter(loader))
feats = batch['input_features'].to(DEVICE)
masks = batch['attention_mask'].to(DEVICE)

with torch.no_grad():
    # Short greedy run that also records the per-step scores.
    out = whisper.generate(
        input_features=feats,
        attention_mask=masks,
        max_new_tokens=5,
        do_sample=False,
        num_beams=1,
        output_scores=True,
        return_dict_in_generate=True,
    )

    # Everything except the final generated token, i.e. the decoder input that
    # produced the last recorded step.
    decoder_ids = out.sequences[:, :-1]

    # A plain forward pass on that context; keep only the last position.
    direct_logits = whisper(feats, decoder_input_ids=decoder_ids).logits[:, -1, :].to(DEVICE)
    direct_lp = direct_logits.log_softmax(dim=-1)

# Scores of the last step exactly as generate() returned them (this cell does
# not normalise them), printed next to the direct forward-pass logits.
gen_lp = out.scores[-1].to(DEVICE)

print(gen_lp)
print(direct_logits)


In [None]:
# Display the decoder ids that the preceding cell sliced from `out.sequences`.
decoder_ids

In [None]:
# Hide every id GPT-2 must not condition on: EOS and everything above it
# (>= EOS_ID), matching this cell's intent and the masking used elsewhere in
# the notebook.
oob_mask = decoder_ids >= EOS_ID # create mask for gpt2 OOV tokens emitted by whisper
# Replace the masked ids with GPT-2's pad token and exclude them from attention.
filtered = decoder_ids.masked_fill(oob_mask, gpt2_tok.pad_token_id)
attention_mask = (filtered != gpt2_tok.pad_token_id).long()

with torch.no_grad():
    logits_new = gpt2(input_ids=filtered, attention_mask=attention_mask).logits[:,-1, :]

# GPT-2 must not influence termination, so only ids below EOS_ID are kept.
# The slice end is exclusive: `:EOS_ID` drops EOS alone, whereas `:EOS_ID-1`
# would also drop the last regular token.
logits_new[:, :EOS_ID].size()