In [22]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

# Project-local model cache: a ".models" directory next to the notebook's
# parent directory, created on first run.
CACHE_DIR = (Path.cwd().parent / ".models").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Both variables are set here, i.e. before `datasets` / `transformers` are
# imported in the next cell. HF_HOME points the Hugging Face libraries at the
# project-local cache; HF_HUB_DISABLE_XET presumably turns off the Xet download
# backend (confirm against the huggingface_hub docs).
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HOME"] = str(CACHE_DIR)

# Print only the trailing path components. `Path.parts` is separator-agnostic,
# whereas splitting the string on '/' only works with POSIX-style paths.
print(list(CACHE_DIR.parts[-3:]))

['SHALLOW_FUSION_EVAL', 'SF_EVAL', '.models']


In [57]:
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

# --- audio / batching ---
SR = 16_000        # sampling rate handed to librosa and to the Whisper feature extractor
BATCH_SIZE = 5     # shared by the dataset map calls and the DataLoader

# --- model identifiers ---
WHISPER_ID = "openai/whisper-tiny.en"
GPT2_ID = "cwestnedge/gpt2-medium-pubmed"

# --- evaluation data (paths are relative to the notebook's working directory) ---
MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"

# Backend preference order: CUDA, then MPS, then plain CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

Device: mps


In [58]:
# ASR side: the Whisper processor and the Whisper model, moved to DEVICE and
# switched to eval mode.
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = whisper.to(DEVICE).eval()

# External LM side: GPT-2 tokenizer and LM head model, same device / eval mode.
gpt2_tok = AutoTokenizer.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR)
gpt2 = gpt2.to(DEVICE).eval()

# When the tokenizer ships without a pad token, register "<|pad|>" and grow the
# LM's embedding matrix so the new id has an embedding row.
if gpt2_tok.pad_token is None:
    gpt2_tok.add_special_tokens({"pad_token": "<|pad|>"})
    gpt2.resize_token_embeddings(len(gpt2_tok))

# Ids reused by the fusion code further down. SHARED_VOCAB is the number of ids
# up to and including EOS.
PAD_ID, EOS_ID = gpt2_tok.pad_token_id, gpt2_tok.eos_token_id
SHARED_VOCAB = EOS_ID + 1

print(PAD_ID, EOS_ID, SHARED_VOCAB)
# 50257 50256 50257

50257 50256 50257


In [59]:
# for i in range(PAD_ID):
#     if processor.tokenizer.decode([i]) != gpt2_tok.decode([i]):
#             print(f"Token mismatch at index {i}")

In [60]:
def build_dataset(manifest_path: str, batch_size: int, num_proc: int = 4,
                  audio_dir: "str | None" = None) -> Dataset:
    """Load a JSON-lines manifest and attach the decoded waveform to every row.

    Args:
        manifest_path: Path to a JSONL file. Every non-blank line is parsed as
            one JSON record; each record needs a "file" entry naming an audio
            file inside the audio directory.
        batch_size: Number of rows per batch handed to `Dataset.map`.
        num_proc: Number of worker processes passed to `Dataset.map`.
        audio_dir: Directory that holds the audio files. When omitted, the
            module-level `AUDIO_DIR` is used (the original behaviour).

    Returns:
        A `Dataset` holding the manifest fields plus an "audio" column with one
        mono float32 waveform per row, resampled to `SR`.

    Raises:
        json.JSONDecodeError: If a non-blank manifest line is not valid JSON.
    """
    root = AUDIO_DIR if audio_dir is None else audio_dir

    with open(manifest_path, encoding="utf-8") as f:
        # Blank lines (e.g. an extra newline at the end of the file) are
        # skipped instead of crashing json.loads.
        rows = [json.loads(line) for line in f if line.strip()]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        # librosa.load returns (waveform, sample_rate); only the waveform is kept.
        batch["audio"] = [
            librosa.load(f"{root}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=batch_size, num_proc=num_proc)

def encode_audio(batch):
    """Convert a batch of raw waveforms into padded Whisper model inputs.

    Args:
        batch: Batched `Dataset.map` input whose "audio" entry holds one
            waveform per example, each at its natural length.

    Returns:
        The same batch object with two entries added, "input_features" and
        "attention_mask", both taken from the feature extractor's output.
    """
    extractor = processor.feature_extractor

    encoded = extractor(
        batch["audio"],  # arrays/lists only; the original author noted PT tensors were not accepted here
        sampling_rate=SR,
        padding="max_length",
        truncation=True,
        max_length=extractor.n_samples,  # original note: n_samples = chunk_length * sampling_rate
        return_attention_mask=True,
        return_tensors="pt",
    )

    # NOTE(review): the previous comment described input_features as
    # (B, T_max, 80); the Hugging Face docs describe Whisper features as
    # (batch, n_mels, frames) with a frame-level attention mask - confirm.
    batch["input_features"] = encoded.input_features
    batch["attention_mask"] = encoded.attention_mask
    return batch

# Raw dataset: manifest fields plus decoded waveforms.
ds = build_dataset(MANIFEST, batch_size=BATCH_SIZE)

# `ds` itself is deliberately left untouched so its "text" field stays
# available for evaluation later. The model-ready copy drops every original
# column; because the loader never shuffles, row i of `ds` still corresponds to
# row i of `ds_processed`. (A custom collator forwarding selected fields would
# be the alternative, at the cost of more code.)
columns_to_drop = ['uuid', 'file', 'category', 'index', 'text', 'audio']
ds_processed = ds.map(
    encode_audio,
    batched=True,
    batch_size=BATCH_SIZE,
    remove_columns=columns_to_drop,
)

# Return torch tensors for the two model inputs and batch them in order.
ds_processed.set_format(type="torch", columns=["input_features", "attention_mask"])
loader = DataLoader(ds_processed, shuffle=False, batch_size=BATCH_SIZE)

Map (num_proc=4): 100%|██████████| 50/50 [00:01<00:00, 27.27 examples/s]
Map: 100%|██████████| 50/50 [00:01<00:00, 47.97 examples/s]


In [61]:
from transformers import LogitsProcessor

class ShallowFusion(LogitsProcessor):
    """Logits processor that mixes an external LM into the ASR scores.

    After a warm-up period, every call
      1. turns the incoming ASR scores into log-probabilities,
      2. runs the external LM on the tokens decoded so far (ids above
         `fusion_exclusive` are replaced by `pad_id` and masked out),
      3. adds `alpha_dynamic * lm_log_prob` to the ASR log-probabilities of the
         token ids `[0, fusion_exclusive)`; ids from `fusion_exclusive` upward
         keep their ASR log-probability, and
      4. renormalises the fused scores with a log-sum-exp.

    `alpha_dynamic` is `alpha` scaled by a sigmoid gate on the entropy of the
    ASR distribution: low entropy gives roughly `alpha`, high entropy gives
    roughly `alpha * alpha_scale`.

    During warm-up the incoming scores are returned unchanged.

    Args:
        lm: External language model. It is called as
            `lm(input_ids=..., attention_mask=...)` and must return an object
            with a `.logits` attribute. It is put in eval mode and its
            parameters have gradients disabled.
        fusion_exclusive: Exclusive upper bound of the token ids that receive
            an LM contribution (e.g. the GPT-2 EOS id).
        pad_id: Id substituted for tokens above `fusion_exclusive` before the
            LM is called; positions equal to `pad_id` are masked out.
        alpha: Base LM weight.
        warmup_steps: Number of initial calls per generation that bypass fusion.
        temperature: Stored as `self.temp` but currently NOT applied to the
            returned scores.
        alpha_scale: Multiplier that `alpha` approaches when the ASR entropy is
            high.
        entropy_threshold: Entropy (in nats) at which the gate is at 0.5.
    """

    def __init__(self, lm, fusion_exclusive, pad_id, alpha=0.3, warmup_steps=3,
                 temperature=0.05, alpha_scale=0.35, entropy_threshold=1.5):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False)
        self.fusion_excl = fusion_exclusive  # e.g. EOS_ID = 50256
        self.pad_id = pad_id
        self.alpha = alpha
        self.warmup = warmup_steps
        self.temp = temperature
        self.alpha_scale = alpha_scale
        self.entropy_threshold = entropy_threshold
        self.step = 0
        # Prefix length seen on the previous call; used to detect the start of
        # a new generation without relying on the caller to invoke reset().
        self._last_len = None

    def reset(self):
        """Forget all per-generation state so the next call starts a new warm-up."""
        self.step = 0
        self._last_len = None

    @torch.inference_mode()
    def __call__(self, input_ids, scores):
        """Return fused, renormalised log-probabilities for the next token.

        Args:
            input_ids: Token ids decoded so far, indexed as (batch, sequence).
            scores: ASR scores for the next token, indexed as (batch, vocab).

        Returns:
            `scores` unchanged during warm-up; afterwards a tensor of the same
            shape holding normalised log-probabilities.
        """
        # Within one generation the prefix grows on every call. A prefix that
        # did not grow means a new generation has started, so the warm-up is
        # restarted even if the caller never called reset().
        cur_len = input_ids.shape[-1]
        if self._last_len is not None and cur_len <= self._last_len:
            self.step = 0
        self._last_len = cur_len

        if self.step < self.warmup:
            self.step += 1
            return scores
        self.step += 1

        w_lp = torch.log_softmax(scores, dim=-1)

        # ---- dynamic alpha ----
        w_probs = w_lp.exp()
        # Entropy of the ASR distribution; zero-probability entries are masked
        # so that 0 * -inf does not produce NaN.
        ent_contrib = torch.where(w_probs > 0,
                                  w_probs * w_lp,
                                  torch.zeros_like(w_lp))
        w_entropy = -(ent_contrib.sum(dim=-1, keepdim=True))  # shape: [B,1]
        # Smooth gate: maps the entropy to a weight between alpha (gate -> 0)
        # and alpha * alpha_scale (gate -> 1).
        gate = torch.sigmoid((w_entropy - self.entropy_threshold) * 2)
        alpha_dynamic = self.alpha * (1 + (self.alpha_scale - 1) * gate)

        # ---- external LM scores ----
        # Ids above the fusion bound are replaced by pad and masked out of the
        # LM's attention.
        oob_mask = input_ids > self.fusion_excl
        filtered_ids = input_ids.masked_fill(oob_mask, self.pad_id)
        attn_mask = (filtered_ids != self.pad_id).long()
        # Rows in which every token was masked give the LM nothing to condition
        # on (possible when warmup_steps is 0); those rows keep the ASR scores.
        has_context = attn_mask.sum(dim=-1, keepdim=True) > 0  # shape: [B,1]

        lm_logits = self.lm(
            input_ids=filtered_ids,
            attention_mask=attn_mask
        ).logits[:, -1, :]  # only the next-token logits are needed
        lm_lp = torch.log_softmax(lm_logits, dim=-1)[:, :self.fusion_excl]

        lm_term = torch.where(has_context,
                              alpha_dynamic * lm_lp,
                              torch.zeros_like(lm_lp))

        # ---- fuse and renormalise ----
        fused_lp = w_lp.clone()
        fused_lp[:, :self.fusion_excl] = w_lp[:, :self.fusion_excl] + lm_term
        fused_lp -= torch.logsumexp(fused_lp, dim=-1, keepdim=True)

        return fused_lp  # self.temp is intentionally not applied here
    
# Settings for the fusion processor used by the fused decoding loop.
fusion_cfg = dict(
    lm=gpt2,
    fusion_exclusive=EOS_ID,  # ids from EOS upward receive no LM contribution (e.g. 50256)
    pad_id=PAD_ID,            # e.g. 50257
    alpha=0.3,
    warmup_steps=4,
    temperature=1.0,
)
fusion_proc = ShallowFusion(**fusion_cfg)


In [64]:
from transformers import LogitsProcessorList
from tqdm import tqdm 

fused = []
refs = []

# Greedy decoding with the shallow-fusion processor attached. Batches are
# consumed in loader order, so `fused[i]` lines up with row i of the dataset.
for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)

    # Reset BEFORE generating: if an earlier run of this cell was interrupted
    # or failed inside generate(), the processor's step counter would still
    # hold the old value and the warm-up would be skipped for this batch.
    fusion_proc.reset()

    with torch.no_grad():
        fused_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            logits_processor=LogitsProcessorList([fusion_proc]),
            num_beams=1,
            do_sample=False,
        )
    decoded = processor.batch_decode(fused_ids, skip_special_tokens=True)
    fused.extend(decoded)

# Leave the processor in a clean state for any later use.
fusion_proc.reset()


Decoding: 100%|██████████| 10/10 [00:21<00:00,  2.13s/it]


In [65]:
# Baseline pass: the same greedy decoding over the same loader, but without
# any custom logits processor.
vanilla = []
greedy_kwargs = dict(num_beams=1, do_sample=False)

for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    with torch.no_grad():
        vanilla_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            **greedy_kwargs,
        )
    decoded = processor.batch_decode(vanilla_ids, skip_special_tokens=True)
    vanilla += decoded

Decoding: 100%|██████████| 10/10 [00:10<00:00,  1.08s/it]


In [66]:
import pandas as pd 

# Three-line report per utterance: ground truth, baseline and fused output.
print_str = "\nGT:    {}\nBase:  {}\nFused: {}"

# One row per utterance; the reference text comes from the unprocessed dataset.
results_df = pd.DataFrame({'vanilla': vanilla, 'fused': fused, 'gt': ds['text']})

# The two hypotheses are stripped before printing; the reference is printed as-is.
for idx, gt, base, fus in results_df[['gt', 'vanilla', 'fused']].itertuples(name=None):
    print(print_str.format(gt, base.strip(), fus.strip()))



GT:    There is evidence of bilateral ground-glass opacities consistent with atypical viral pneumonia.
Base:  There is evidence of bilateral ground glass opacities consistent with atypical viral pneumonia.
Fused: There is evidence of bilateral ground-glass opacities consistent with atypical viral pneumonia.

GT:    Non-contrast CT of the kidneys shows a 7 mm calculus in the proximal left ureter with mild hydronephrosis.
Base:  Non-contrast CT of the kidneys shows a 7-millimeter calculus in the proximal left-year eater with mild hydro nephrosis.
Fused: Non-contrast CT of the kidneys shows a 7-millimeter calculus in the proximal left-year-eater with mild hydro nephrosis.

GT:    CT angiography demonstrates a filling defect in the right main pulmonary artery suggestive of acute pulmonary embolism.
Base:  CT angiography demonstrates a filling defect in the right main pulmonary artery suggestive of acute pulmonary embolism.
Fused: CT angiography demonstrates a filling defect in the right m

# BONEYARD


In [None]:
# from transformers import LogitsProcessor, LogitsProcessorList
# import torch.nn.functional as F

# class ShallowFusion(LogitsProcessor):
#     def __init__(self, lm, shared_vocab, eos, alpha=0.3, warmup_steps=3):
#         super().__init__()
#         self.lm = lm.eval().requires_grad_(False)
#         self.V = shared_vocab
#         self.eos = eos
#         self.alpha = alpha
#         self.warmup = warmup_steps
#         self.step = 0

#     @torch.no_grad()
#     def __call__(self, input_ids, scores):
#         print('printing input_ids.size(), scores.size(), step, input_ids, dec_ids')
#         print(input_ids.size(), scores.size(), self.step, input_ids, processor.batch_decode(input_ids))
#         self.step+=1 

#         return scores
    
# fusion_proc = ShallowFusion(
#     lm=gpt2,
#     shared_vocab=gpt2.config.vocab_size,
#     eos=EOS_ID,
#     alpha=0.3
# )

# batch = next(iter(loader))
# feats = batch['input_features'].to(DEVICE)
# masks = batch['attention_mask'].to(DEVICE)

# with torch.no_grad():

#     out1 = whisper.generate(
#         input_features=feats,
#         attention_mask=masks,
#         logits_processor=LogitsProcessorList([fusion_proc]),
#         return_dict_in_generate=True,
#         output_scores=True,
#         num_beams=2,
#     )

#     # out2 = whisper.generate(
#     #     input_features=feats,
#     #     attention_mask=masks,
#     #     return_dict_in_generate=True,
#     #     output_scores=True,
#     #     num_beams=2
#     # )

In [None]:
# Scratch check: do the scores recorded by generate() match a direct forward
# pass on the same prefix?
batch = next(iter(loader))
feats = batch['input_features'].to(DEVICE)
masks = batch['attention_mask'].to(DEVICE)

short_greedy = dict(
    num_beams=1,
    do_sample=False,
    return_dict_in_generate=True,
    output_scores=True,
    max_new_tokens=5,
)

with torch.no_grad():
    # Short greedy generation that also returns the per-step scores.
    out = whisper.generate(input_features=feats, attention_mask=masks, **short_greedy)

    # Every token of the generated sequences except the last one. The logits
    # at the final position of a forward pass on this prefix are the ones to
    # compare against the last recorded generation step.
    decoder_ids = out.sequences[:, :-1]
    direct_logits = whisper(feats, decoder_input_ids=decoder_ids).logits[:, -1, :].to(DEVICE)
    direct_lp = torch.log_softmax(direct_logits, dim=-1)  # computed but not printed below

# Scores recorded by generate() for its last decoding step.
gen_lp = out.scores[-1].to(DEVICE)

print(gen_lp)
print(direct_logits)


In [None]:
# Display the decoder prefix from the previous cell (the generated sequences without their last token).
decoder_ids

In [None]:
# Scratch version of the token filtering done inside ShallowFusion.
# Ids strictly above EOS are treated as out-of-vocabulary for GPT-2 (tokens
# emitted only by Whisper): they are replaced by the GPT-2 pad token and
# masked out of the attention.
oob_mask = decoder_ids > EOS_ID
filtered = decoder_ids.masked_fill(oob_mask, gpt2_tok.pad_token_id)
attention_mask = (filtered != gpt2_tok.pad_token_id).long()

with torch.no_grad():
    # Only the next-token logits (last position) are of interest.
    logits_new = gpt2(input_ids=filtered, attention_mask=attention_mask).logits[:, -1, :]

# GPT-2 must not influence termination (that is left to the ASR model), so
# only ids [0, EOS_ID) are kept. The previous slice `:EOS_ID-1` also dropped
# the last regular token id and disagreed with ShallowFusion's
# `[:, :fusion_excl]` slice.
logits_new[:, :EOS_ID].size()