In [1]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

# Keep model downloads in a project-local ".models" directory next to the notebook folder.
CACHE_DIR = (Path.cwd().parent / ".models").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Set here, ahead of the cell that imports `transformers`
# (presumably so the Hugging Face libraries see them at import time — confirm).
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HOME"] = str(CACHE_DIR)

# Show only the last three path components (avoids printing the full local path).
# Path.parts is separator-agnostic, unlike str.split('/'), which returns a single
# element on backslash-separated paths.
print(list(CACHE_DIR.parts[-3:]))

['SHALLOW_FUSION_EVAL', 'SF_EVAL', '.models']


In [2]:
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

# --- Run configuration -------------------------------------------------------
SR = 16_000                      # sampling rate passed to librosa and the feature extractor
BATCH_SIZE = 5
WHISPER_ID = "openai/whisper-small.en"
GPT2_ID = "cwestnedge/gpt2-small-pubmed"

MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

  from .autonotebook import tqdm as notebook_tqdm


Device: mps


In [3]:
# Force the non-fast tokenizers (use_fast=False). Author's note: fast tokenizers are
# loaded automatically when running on Colab (A100) and then report token mismatches
# between the two models, so the flag is set to False to avoid that noise.
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR, use_fast=False)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

gpt2_tok = AutoTokenizer.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR, use_fast=False)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

# If the GPT-2 tokenizer has no pad token, register one and resize the embedding
# matrix to the new tokenizer length so the added id is valid model input.
if gpt2_tok.pad_token is None:
    gpt2_tok.add_special_tokens({"pad_token": "<|pad|>"})
    gpt2.resize_token_embeddings(len(gpt2_tok))

PAD_ID = gpt2_tok.pad_token_id # e.g. 50257
EOS_ID = gpt2_tok.eos_token_id # 50256 (unchanged)
# Token ids in [0, SHARED_VOCAB) are the range the token-comparison loop below checks
# between the Whisper and GPT-2 tokenizers.
SHARED_VOCAB = EOS_ID + 1

print(PAD_ID, EOS_ID, SHARED_VOCAB)
# 50257 50256 50257

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


50257 50256 50257


In [4]:
# Sanity check: every id in the shared range must decode to the same text in both
# tokenizers. Only mismatches are reported, so no output means they agree.
for token_id in range(SHARED_VOCAB):
    whisper_piece = processor.tokenizer.decode([token_id])
    gpt2_piece = gpt2_tok.decode([token_id])
    if whisper_piece == gpt2_piece:
        continue
    print(f"Token mismatch at index {token_id}\nwhisper token: {whisper_piece}\n   gpt2 token: {gpt2_piece} ")

In [5]:
def build_dataset(manifest_path: str, batch_size: int, num_proc: int = 4) -> Dataset:
    """Load a JSONL manifest into a ``Dataset`` and attach the decoded audio.

    Args:
        manifest_path: path to a JSON-lines file, one JSON object per line. Each
            object must provide a ``file`` entry naming an audio file inside
            ``AUDIO_DIR``.
        batch_size: number of rows handed to the audio loader per ``map`` call.
        num_proc: number of worker processes used by ``Dataset.map``.

    Returns:
        The manifest rows as a ``Dataset`` with an extra ``audio`` column holding
        each file resampled to ``SR`` Hz, mono, float32.
    """
    with open(manifest_path, encoding="utf-8") as f:
        # Skip blank / whitespace-only lines (e.g. a trailing empty line) instead of
        # letting json.loads fail on them.
        rows = [json.loads(line) for line in f if line.strip()]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        # librosa.load returns (waveform, sample_rate); only the waveform is kept.
        batch["audio"] = [
            librosa.load(os.path.join(AUDIO_DIR, fname), sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=batch_size, num_proc=num_proc)

def encode_audio(batch):
    """Turn a batch of raw waveforms into Whisper encoder inputs.

    Reads ``batch["audio"]`` (one waveform per example, each at its natural length)
    and writes two new entries on the same batch:

    * ``input_features`` - features produced by the Whisper feature extractor.
      NOTE(review): per the WhisperFeatureExtractor docs these should be laid out as
      (batch, n_mels, n_frames), not (batch, frames, mels) - confirm.
    * ``attention_mask`` - the mask returned alongside the features.

    Every example is padded or truncated to the extractor's fixed window
    (``n_samples`` = chunk_length * sampling_rate).
    """
    extractor = processor.feature_extractor
    encoded = extractor(
        batch["audio"],  # numpy arrays / lists only; the author notes PT tensors are not accepted here
        sampling_rate=SR,
        padding="max_length",
        truncation=True,
        max_length=extractor.n_samples,
        return_attention_mask=True,
        return_tensors="pt",
    )

    batch["input_features"] = encoded.input_features
    batch["attention_mask"] = encoded.attention_mask
    return batch

ds = build_dataset(MANIFEST, batch_size=BATCH_SIZE)

# `ds` is deliberately left untouched so its `text` column can serve as the reference
# at evaluation time. Rows stay aligned with the decoded outputs because the loader
# below does not shuffle. (A custom collator that forwards the extra fields would
# also work but is more code than this needs.)
#
# Every input column is dropped from the processed copy. The list is derived from the
# dataset rather than hard-coded, so a manifest with extra or missing fields no longer
# breaks this cell.
ds_processed = ds.map(
    encode_audio,
    batched=True,
    batch_size=BATCH_SIZE,
    remove_columns=ds.column_names,
)

ds_processed.set_format(type="torch", columns=["input_features", "attention_mask"])
loader = DataLoader(ds_processed, batch_size=BATCH_SIZE, shuffle=False)

Map (num_proc=4): 100%|██████████| 85/85 [00:01<00:00, 69.68 examples/s]
Map: 100%|██████████| 85/85 [00:02<00:00, 38.50 examples/s]


In [6]:
len(ds)

85

In [None]:
from transformers import LogitsProcessor

class ShallowFusion(LogitsProcessor):
    """Logits processor that mixes an external LM's next-token log-probs into the decoder scores.

    After a warm-up period, token ids in ``[0, fusion_exclusive)`` are rescored as
    ``log p_asr + alpha * log p_lm`` and the whole row is renormalised. Ids at or
    above ``fusion_exclusive`` keep the ASR score alone, so the LM has no say on them
    (with ``fusion_exclusive=EOS_ID`` that includes end-of-text).

    Args:
        lm: causal LM, called as ``lm(input_ids=..., attention_mask=...)`` and expected
            to return an object whose ``.logits`` is indexed as (batch, seq, vocab).
            It is put in eval mode and its parameters are frozen.
        fusion_exclusive: exclusive upper bound of the fused token-id range. History
            tokens with a larger id are hidden from the LM.
        pad_id: LM pad id substituted for hidden history tokens.
        alpha: weight of the LM log-probabilities.
        warmup_steps: number of initial decoding steps per generation that skip fusion.
        temperature: stored as ``self.temp`` but currently not applied to the output.
    """

    def __init__(self, lm, fusion_exclusive, pad_id, alpha=0.3, warmup_steps=3, temperature=0.05):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False)
        self.fusion_excl = fusion_exclusive  # e.g. EOS_ID = 50256
        self.pad_id = pad_id
        self.alpha = alpha
        self.warmup = warmup_steps
        self.temp = temperature
        self.step = 0
        # Knobs for an experimental entropy-gated ("dynamic") alpha. That path is
        # currently disabled, so __call__ does not read these.
        self.alpha_scale = 0.35
        self.entropy_threshold = 1.3
        # History length seen on the previous call; used to spot a new generation.
        self._prev_len = -1

    def reset(self):
        """Restart the warm-up counter (call between independent ``generate`` runs)."""
        self.step = 0
        self._prev_len = -1

    # no_grad rather than inference_mode: inference tensors cannot be updated in place
    # once control returns to the caller's (non-inference-mode) generation loop, which
    # would break any later processor that edits the scores in place.
    @torch.no_grad()
    def __call__(self, input_ids, scores):
        """Return fused, renormalised log-probabilities for the next token.

        Args:
            input_ids: token ids decoded so far, indexed as (batch, seq).
            scores: next-token scores from the ASR decoder, indexed as (batch, vocab).

        Returns:
            Log-probabilities with the same shape as ``scores``. During warm-up these
            are just ``log_softmax(scores)``.
        """
        # Within one generation the history grows by one token per call. If it did not
        # grow, a new generation has started, so restart the warm-up even when the
        # caller forgot (or was interrupted before) calling reset().
        seq_len = input_ids.shape[-1]
        if seq_len <= self._prev_len:
            self.step = 0
        self._prev_len = seq_len

        w_lp = torch.log_softmax(scores, dim=-1)

        self.step += 1
        if self.step <= self.warmup:
            return w_lp

        # Hide history tokens the LM does not know (ids above the fused range): swap
        # them for the LM pad id and exclude them from attention.
        # NOTE(review): hidden tokens still occupy positions, so the LM presumably sees
        # the first real token at a non-zero position. Consider passing explicit
        # position_ids; re-check WER before changing.
        oob_mask = input_ids > self.fusion_excl
        filtered_ids = input_ids.masked_fill(oob_mask, self.pad_id)
        attn_mask = (filtered_ids != self.pad_id).long()

        lm_logits = self.lm(
            input_ids=filtered_ids,
            attention_mask=attn_mask
        ).logits[:, -1, :]  # only the next-token distribution is needed
        lm_lp = torch.log_softmax(lm_logits, dim=-1)[:, :self.fusion_excl]

        fused_slice = w_lp[:, :self.fusion_excl] + self.alpha * lm_lp
        fused_lp = w_lp.clone()
        fused_lp[:, :self.fusion_excl] = fused_slice

        # Renormalise so each row is again a log-probability distribution.
        fused_lp -= torch.logsumexp(fused_lp, dim=-1, keepdim=True)

        # Temperature scaling (fused_lp / self.temp) is intentionally left disabled.
        return fused_lp
    
# Fusion processor used for the fused decoding run below: GPT-2 as the external LM,
# fusing token ids below EOS_ID only, with the LM pad id standing in for history
# tokens that lie above that range.
fusion_proc = ShallowFusion(
    lm=gpt2,
    fusion_exclusive=EOS_ID, # e.g. EOS_ID = 50256
    pad_id=PAD_ID,# <— 50257
    alpha=0.15,
    warmup_steps=4,
    temperature = 1.0
)


In [9]:
from transformers import LogitsProcessorList
from tqdm import tqdm 

fused = []
refs = []  # not populated in this cell

for batch in tqdm(loader, total=len(loader), desc="Decoding"):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)

    # Reset *before* generating: the processor counts warm-up steps per generation, and
    # a previous run that was interrupted mid-batch would otherwise leave a stale
    # counter behind and silently skip the warm-up for this batch.
    fusion_proc.reset()

    with torch.no_grad():
        fused_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            logits_processor=LogitsProcessorList([fusion_proc]),
            num_beams=1,
            do_sample=False,
            length_penalty=1.0,
        )
    decoded = processor.batch_decode(fused_ids, skip_special_tokens=True)
    fused.extend(decoded)

Decoding:   0%|          | 0/17 [00:00<?, ?it/s]Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Decoding: 100%|██████████| 17/17 [00:41<00:00,  2.43s/it]


In [10]:
class NoOpLogitsProcessor(LogitsProcessor):
    """Baseline processor: only converts the scores to log-probabilities.

    Despite the name, the scores are not returned verbatim. They are passed through
    ``log_softmax`` over the last dimension and nothing else is applied.
    """

    def __call__(self, input_ids, scores):
        # input_ids is accepted for interface compatibility and is not used.
        return torch.log_softmax(scores, dim=-1)


vanilla = []
noop_processor = NoOpLogitsProcessor()

# Greedy decoding, with the same generation settings as the fused run.
baseline_generate_kwargs = dict(num_beams=1, do_sample=False, length_penalty=1.0)

for batch in tqdm(loader, total=len(loader), desc="Decoding"):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    with torch.no_grad():
        vanilla_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            logits_processor=LogitsProcessorList([noop_processor]),
            **baseline_generate_kwargs,
        )
    decoded = processor.batch_decode(vanilla_ids, skip_special_tokens=True)
    vanilla += decoded

Decoding: 100%|██████████| 17/17 [00:26<00:00,  1.56s/it]


In [11]:
import pandas as pd 

# One row per utterance: baseline hypothesis, fused hypothesis, reference text.
# Rows line up because both decoding loops walk the unshuffled loader in `ds` order.
results_df = pd.DataFrame({
    'vanilla': list(map(str.strip, vanilla)),
    'fused': list(map(str.strip, fused)),
    'reference': ds['text'],
})

In [12]:
from jiwer import (
    Compose,
    ToLowerCase,
    RemovePunctuation,
    RemoveMultipleSpaces,
    Strip,
    ReduceToListOfListOfWords,
    wer
)
from unidecode import unidecode
import re

# helper to handle both str and list[str]
def _map(func, x):
    return [func(t) for t in x] if isinstance(x, list) else func(x)

def remove_diacritics(x):
    return _map(unidecode, x)

def split_hyphens_and_slashes(x):
    # replace any dash or slash with a space so we never glue words together
    return _map(lambda t: re.sub(r"[-–—/]", " ", t), x)

def normalize_nums(x):
    # unify 12–16 → 12-16
    return _map(lambda t: re.sub(r"(\d)[-–—-](\d)", r"\1-\2", t), x)

# Text normalisation applied identically to references and hypotheses before scoring.
# Order matters: dashes and slashes are turned into spaces before punctuation is
# removed, so hyphenated words become separate words rather than being glued together.
# NOTE(review): because split_hyphens_and_slashes has already replaced every dash,
# normalize_nums can no longer find a digit-dash-digit match at its position here.
transform = Compose([
    ToLowerCase(),
    remove_diacritics,
    split_hyphens_and_slashes,    # split first
    normalize_nums,
    RemovePunctuation(),           # now drop commas, periods, etc.
    RemoveMultipleSpaces(),
    Strip(),
    ReduceToListOfListOfWords(),   # produce [["word", ...], ...]
])

def compute_wer(ref, hyp):
    """Word error rate of ``hyp`` against ``ref``, with both sides run through ``transform`` first."""
    shared_transforms = {
        "reference_transform": transform,
        "hypothesis_transform": transform,
    }
    return wer(ref, hyp, **shared_transforms)

# The frame built above already names the column `reference`; this rename only has an
# effect if a frame with a `gt` column is passed through this cell.
results_df = results_df.rename(columns={"gt": "reference"})

# Per-utterance WER, kept for the row-by-row comparison further down.
results_df["wer_base"] = [
    compute_wer(ref, hyp)
    for ref, hyp in zip(results_df["reference"], results_df["vanilla"])
]
results_df["wer_fused"] = [
    compute_wer(ref, hyp)
    for ref, hyp in zip(results_df["reference"], results_df["fused"])
]

# Mean of the per-utterance rates: every utterance counts equally, whatever its length.
print("Base  WER (punct-insensitive):", results_df["wer_base"].mean())
print("Fused WER (punct-insensitive):", results_df["wer_fused"].mean())

# Corpus-level WER (total errors / total reference words), which is not skewed by
# short utterances. Passing lists scores all utterances in one call.
all_references = results_df["reference"].tolist()
print("Base  corpus WER (punct-insensitive):", compute_wer(all_references, results_df["vanilla"].tolist()))
print("Fused corpus WER (punct-insensitive):", compute_wer(all_references, results_df["fused"].tolist()))

Base  WER (punct-insensitive): 0.07469183832477919
Fused WER (punct-insensitive): 0.06996960771336627


In [13]:
# Absolute per-utterance gap between baseline and fused WER; show the ten largest.
results_df['diff'] = (results_df['wer_base'] - results_df['wer_fused']).abs()
top_diffs = results_df.sort_values('diff', ascending=False).head(10)
top_diffs

Unnamed: 0,vanilla,fused,reference,wer_base,wer_fused,diff
51,Dual energy CT characterized crystal depositio...,Dual energy CT characterized crystal depositio...,Dual-energy CT characterized crystal depositio...,0.214286,0.071429,0.142857
30,High-resolution MRI identified leptominin-geom...,High-resolution MRI identified leptomeningeal ...,High-resolution MRI identified leptomeningeal ...,0.142857,0.0,0.142857
6,Large heterogeneously enhancing mass centered ...,"Large heterogeneously enhancing mass, centered...",Large heterogeneously enhancing mass centered ...,0.0,0.136364,0.136364
69,Muscle biopsy revealed inclusion body myositis...,Muscle biopsy revealed inclusion body myositis...,Muscle biopsy revealed inclusion body myositis...,0.117647,0.0,0.117647
50,Next-generation sequencing identified pathogen...,Next-generation sequencing identified pathogen...,Next-generation sequencing identified pathogen...,0.222222,0.111111,0.111111
29,Histopathology demonstrated Rosai Dorfman dise...,Histopathology demonstrated Rosi-Dorfman disea...,Histopathology demonstrated Rosai-Dorfman dise...,0.111111,0.222222,0.111111
17,No evidence of pulmonary embolism to the subse...,No evidence of pulmonary embolism to the sub-s...,No evidence of pulmonary embolism to the subse...,0.0,0.090909,0.090909
41,High resolution CT delineated pulmonary alveol...,High resolution CT delineated pulmonary alveol...,High-resolution CT delineated pulmonary alveol...,0.0,0.083333,0.083333
40,MR neurography revealed hourglass-like constri...,MR neurography revealed hourglass-like constri...,MR neurography revealed hourglass-like constri...,0.076923,0.0,0.076923
45,Bone centigraphy exhibited the hot skull sign ...,Bone centigraphy exhibited the hot skull sign ...,Bone scintigraphy exhibited the 'hot skull' si...,0.214286,0.142857,0.071429


In [14]:
# Template for one comparison: reference, baseline hypothesis, fused hypothesis.
print_str = '''
GT:    {}
Base:  {}
Fused: {}'''

# Print the full, untruncated texts for the rows with the largest WER gap.
for texts in top_diffs[['reference', 'vanilla', 'fused']].itertuples(index=False, name=None):
    print(print_str.format(*texts))


GT:    Dual-energy CT characterized crystal deposition consistent with tophaceous pseudogout involving the atlantoaxial joint.
Base:  Dual energy CT characterized crystal deposition consistent with topaceous pseudogout involving the Atlanta axial joint.
Fused: Dual energy CT characterized crystal deposition consistent with topatious pseudogout involving the atlantoaxial joint.

GT:    High-resolution MRI identified leptomeningeal melanocytosis with diffuse T1 hyperintensity along the cerebellar folia.
Base:  High-resolution MRI identified leptominin-geomelanocytosis with diffuse T1 hyperintensity along the cerebellar folia.
Fused: High-resolution MRI identified leptomeningeal melanocytosis with diffuse T1 hyperintensity along the cerebellar folia.

GT:    Large heterogeneously enhancing mass centered in the right parotid gland measuring 4.2 by 3.8 by 3.5 centimeters with areas of internal necrosis.
Base:  Large heterogeneously enhancing mass centered in the right parotid gland measuri

# BONEYARD


In [None]:
# from transformers import LogitsProcessor, LogitsProcessorList
# import torch.nn.functional as F

# class ShallowFusion(LogitsProcessor):
#     def __init__(self, lm, shared_vocab, eos, alpha=0.3, warmup_steps=3):
#         super().__init__()
#         self.lm = lm.eval().requires_grad_(False)
#         self.V = shared_vocab
#         self.eos = eos
#         self.alpha = alpha
#         self.warmup = warmup_steps
#         self.step = 0

#     @torch.no_grad()
#     def __call__(self, input_ids, scores):
#         print('printing input_ids.size(), scores.size(), step, input_ids, dec_ids')
#         print(input_ids.size(), scores.size(), self.step, input_ids, processor.batch_decode(input_ids))
#         self.step+=1 

#         return scores
    
# fusion_proc = ShallowFusion(
#     lm=gpt2,
#     shared_vocab=gpt2.config.vocab_size,
#     eos=EOS_ID,
#     alpha=0.3
# )

# batch = next(iter(loader))
# feats = batch['input_features'].to(DEVICE)
# masks = batch['attention_mask'].to(DEVICE)

# with torch.no_grad():

#     out1 = whisper.generate(
#         input_features=feats,
#         attention_mask=masks,
#         logits_processor=LogitsProcessorList([fusion_proc]),
#         return_dict_in_generate=True,
#         output_scores=True,
#         num_beams=2,
#     )

#     # out2 = whisper.generate(
#     #     input_features=feats,
#     #     attention_mask=masks,
#     #     return_dict_in_generate=True,
#     #     output_scores=True,
#     #     num_beams=2
#     # )

In [None]:
# Scratch check: compare the scores generate() records for its last step with the
# logits of a plain forward pass over the same decoder prefix.
batch = next(iter(loader))
feats = batch['input_features'].to(DEVICE)
masks = batch['attention_mask'].to(DEVICE)

# Greedy-decode a few tokens, keeping the generated sequences and per-step scores
with torch.no_grad():
    out = whisper.generate(
        input_features=feats,
        attention_mask=masks,
        num_beams=1,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=5
    )
# Re-score the LAST generation step: feed every token except the final one and read
# the logits at the last position.
decoder_ids = out.sequences[:,0:-1]  # all tokens but the last (not just the start token)
with torch.no_grad():
    direct_logits = whisper(feats, decoder_input_ids=decoder_ids).logits[:, -1, :].to(DEVICE)
    direct_lp = torch.log_softmax(direct_logits, dim=-1)  # NOTE(review): computed but not used below

gen_lp = out.scores[-1].to(DEVICE)  # scores recorded by generate() for its final step

# NOTE(review): this prints generate()'s scores next to the raw forward-pass logits,
# not next to direct_lp — confirm which pair is meant to be compared.
print(gen_lp)
print(direct_logits)


In [None]:
# Display the decoder prefix ids (generated sequence minus its last token) from the cell above.
decoder_ids

In [None]:
# Hide ids GPT-2 does not know: everything strictly above EOS_ID is Whisper-only
# vocabulary. Replace those ids with the GPT-2 pad token and mask them out.
oob_mask = decoder_ids > EOS_ID
filtered = decoder_ids.masked_fill(oob_mask, gpt2_tok.pad_token_id)
attention_mask = (filtered != gpt2_tok.pad_token_id).long()

with torch.no_grad():
    logits_new = gpt2(input_ids=filtered, attention_mask=attention_mask).logits[:, -1, :]

# Keep ids [0, EOS_ID): GPT-2 must not influence termination, which is left to the ASR
# model. The slice end is EOS_ID (not EOS_ID - 1, which also dropped the last regular
# token) so it matches the range ShallowFusion fuses.
logits_new[:, :EOS_ID].size()