In [52]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

CACHE_DIR = (Path.cwd().parent / ".models" / "hfcache").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ["HF_HUB_DISABLE_XET"] = "1"
print(str(CACHE_DIR).split('/')[-3:])

['SF_EVAL', '.models', 'hfcache']


In [53]:
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

SR = 16_000
BATCH_SIZE = 5
WHISPER_ID = "openai/whisper-small.en"
GPT2_ID = "cwestnedge/gpt2-medium-pubmed"

MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"  

DEVICE = (
    "cuda" if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print("Device:", DEVICE)

Device: mps


In [54]:
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

gpt2_tok = AutoTokenizer.from_pretrained(GPT2_ID)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

if gpt2_tok.pad_token is None:
    gpt2_tok.add_special_tokens({"pad_token": "<|pad|>"})
    gpt2.resize_token_embeddings(len(gpt2_tok))

PAD_ID = gpt2_tok.pad_token_id # e.g. 50257
EOS_ID = gpt2_tok.eos_token_id # 50256 (unchanged)
SHARED_VOCAB = EOS_ID + 1

print(PAD_ID, EOS_ID, SHARED_VOCAB)

50257 50256 50257


In [6]:
for i in range(SHARED_VOCAB):
    if processor.tokenizer.decode([i]) != gpt2_tok.decode([i]):
            print(f"Token mismatch at index {i}")

In [7]:
def build_dataset(manifest_path: str, batch_size: int, num_proc: int = 4) -> Dataset:
    with open(manifest_path, encoding="utf-8") as f:
        rows = [json.loads(line) for line in f]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        batch["audio"] = [
            librosa.load(f"{AUDIO_DIR}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=batch_size, num_proc=num_proc)

def encode_audio(batch):
    # batch["audio"] is List[np.ndarray], each at its natural length
    feats = processor.feature_extractor(
        batch["audio"], # for whatever reason processor doesnt support PT tensors so numpy array or list for now.
        sampling_rate=SR,
        padding="max_length",
        truncation=True, 
        max_length=processor.feature_extractor.n_samples,  # n_samples = chunk_length * sampling_rate
        return_attention_mask=True,
        return_tensors="pt" 
    )

    #  input_features : Tensor (B, T_max, 80)
    #  attention_mask : Tensor (B, T_max)
    batch["input_features"] = feats.input_features
    batch["attention_mask"] = feats.attention_mask
    return batch

ds = build_dataset(MANIFEST, batch_size=BATCH_SIZE)

# choosing NOT to overwrite ds with removed fields so we can eval on text field later,
# could also create a collator and pass fields we care about through, but seems like 
# too much extra code tbh, indices will still match if we dont shuffle
ds_processed = ds.map(
    encode_audio, 
    batch_size=BATCH_SIZE, 
    batched=True,
    remove_columns=['uuid', 'file', 'category', 'index', 'text', 'audio']
    )

ds_processed.set_format(type="torch", columns=["input_features","attention_mask"])
loader = DataLoader(ds_processed, batch_size=BATCH_SIZE, shuffle=False)

Map (num_proc=4): 100%|██████████| 30/30 [00:00<00:00, 31.36 examples/s]
Map: 100%|██████████| 30/30 [00:00<00:00, 47.03 examples/s]


In [47]:
from transformers import LogitsProcessor

class ShallowFusion(LogitsProcessor):
    def __init__(self, lm, shared_vocab, pad_id, alpha=0.3, warmup_steps=3, temperature=0.05):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False)
        self.V = shared_vocab
        self.pad_id = pad_id
        self.alpha = alpha
        self.warmup = warmup_steps
        self.temp = temperature
        self.step = 0

    @torch.inference_mode()
    def __call__(self, input_ids, scores):

        if self.step < self.warmup: 
            self.step+=1 
            return scores
        self.step+=1 

        oob_mask = input_ids >= self.V
        filtered_ids = input_ids.masked_fill(oob_mask, self.pad_id)
        attn_mask = (filtered_ids != self.pad_id).long()

        lm_logits = self.lm(
            input_ids=filtered_ids, 
            attention_mask=attn_mask
        ).logits[:,-1,:] # only want logits for next token

        # lm_logits_shared = lm_logits[:, : self.V] # (B, 50257)
        lm_lp = torch.log_softmax(lm_logits, dim=-1)[:, : self.V] # (B, 50257)
        w_lp = scores # whisper emits log probs already so no need to do log_softmax

        fused_slice = w_lp[:, : self.V] + self.alpha * lm_lp
        fused_lp = w_lp.clone()
        fused_lp[:, : self.V] = fused_slice

        # two things whisper emits log probs already
        fused_lp -= torch.logsumexp(fused_lp, dim=-1, keepdim=True)
        print(self.step)
        # print("Normalized fused tensor range :", fused_lp.min().item(), fused_lp.max().item())
        # print("Whisper logits range          :", scores.min().item(), scores.max().item())
        # print("ELM logits range              :", lm_logits.min().item(), lm_logits.max().item())
        # print("ELM log_prob range            :", lm_lp.min().item(), lm_lp.max().item())
        # print()
        return fused_lp #/ self.temp

class ShallowFusion(LogitsProcessor):
    def __init__(self, lm, shared_vocab, pad_id, alpha=0.3, warmup_steps=3, temperature=0.05):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False)
        self.V = shared_vocab
        self.pad_id = pad_id
        self.alpha = alpha
        self.warmup = warmup_steps
        self.temp = temperature
        self.step = 0
    
    def reset(self):
        self.step = 0

    @torch.inference_mode()
    def __call__(self, input_ids, scores):

        # if self.step == 0: 
        #     logsumexp  = torch.logsumexp(scores, dim=-1)
        #     prob_mass  = scores.exp().sum(dim=-1) 
        #     print("Whisper logsumexp (≈0):", logsumexp)
        #     print("Whisper prob mass  (≈1):", prob_mass)
        #     print("Whisper logits range:", scores.min().item(), scores.max().item())
    
        if self.step < self.warmup: 
            self.step+=1 
            return scores
        self.step+=1 

        oob_mask = input_ids >= self.V
        filtered_ids = input_ids.masked_fill(oob_mask, self.pad_id)
        attn_mask = (filtered_ids != self.pad_id).long()

        lm_logits = self.lm(
            input_ids=filtered_ids, 
            attention_mask=attn_mask
        ).logits[:,-1,:] # only want logits for next token

        # lm_logits_shared = lm_logits[:, : self.V] # (B, 50257)
        lm_lp = torch.log_softmax(lm_logits, dim=-1)[:, : self.V] # (B, 50257)
        w_lp = torch.log_softmax(scores, dim=-1) # whisper MAY emit log probs already so no need to do log_softmax
        # w_lp = scores # whisper emits log probs already so no need to do log_softmax

        fused_slice = w_lp[:, : self.V] + self.alpha * lm_lp
        fused_lp = w_lp.clone()
        fused_lp[:, : self.V] = fused_slice

        # two things whisper emits log probs already
        # fused_lp -= torch.logsumexp(fused_lp, dim=-1, keepdim=True)
        
        print("Normalized fused tensor range :", fused_lp.min().item(), fused_lp.max().item())
        print("Whisper logits range          :", scores.min().item(), scores.max().item())
        print("ELM logits range              :", lm_logits.min().item(), lm_logits.max().item())
        print("ELM log_prob range            :", lm_lp.min().item(), lm_lp.max().item())
        print(self.step)

        
        return fused_lp #/ self.temp
    
fusion_proc = ShallowFusion(
    lm=gpt2,
    shared_vocab=SHARED_VOCAB, # now 50 258 after adding <|pad|>
    pad_id=PAD_ID,# <— 50257
    alpha=0.34,
    warmup_steps=4,
    temperature = 0.05
)


In [48]:
from transformers import LogitsProcessorList
from tqdm import tqdm 

fused = []
refs = []

for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)

    with torch.no_grad():
        fused_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            logits_processor=LogitsProcessorList([fusion_proc]),
            num_beams=1,
            do_sample=False,
            # generation_config=gen_cfg,
        )
    decoded = processor.batch_decode(fused_ids, skip_special_tokens=True)
    fused.extend(decoded)
    fusion_proc.reset()
    # break
    # print(f"{'-'*20} BATCH_{idx} {'-'*20}")


Decoding:   0%|          | 0/6 [00:00<?, ?it/s]

Normalized fused tensor range : -inf -1.9351720809936523
Whisper logits range          : -inf 40.020538330078125
ELM logits range              : -1.6271870136260986 106.66032409667969
ELM log_prob range            : -49.70732879638672 -0.42529767751693726
5
Normalized fused tensor range : -inf -2.651327133178711
Whisper logits range          : -inf 31.015838623046875
ELM logits range              : -1.859573245048523 100.27620697021484
ELM log_prob range            : -48.40901184082031 -0.36179062724113464
6
Normalized fused tensor range : -inf -2.418224811553955
Whisper logits range          : -inf 34.643035888671875
ELM logits range              : 11.719165802001953 98.54611206054688
ELM log_prob range            : -47.51453399658203 -0.16678456962108612
7
Normalized fused tensor range : -inf -2.133643865585327
Whisper logits range          : -inf 33.80644226074219
ELM logits range              : 22.612577438354492 100.30415344238281
ELM log_prob range            : -52.33537292480469

Decoding:  17%|█▋        | 1/6 [00:01<00:07,  1.54s/it]

Normalized fused tensor range : -inf -3.3472652435302734
Whisper logits range          : -inf 32.60832977294922
ELM logits range              : 11.01267147064209 93.56316375732422
ELM log_prob range            : -48.50297546386719 -0.3649786710739136
5
Normalized fused tensor range : -inf -1.7170156240463257
Whisper logits range          : -inf 24.240680694580078
ELM logits range              : 3.9479129314422607 98.84561157226562
ELM log_prob range            : -47.80802917480469 -0.3705364465713501
6
Normalized fused tensor range : -inf -3.82226300239563
Whisper logits range          : -inf 18.46592140197754
ELM logits range              : 5.164938449859619 96.26339721679688
ELM log_prob range            : -48.677978515625 -0.47918859124183655
7
Normalized fused tensor range : -inf -1.488022804260254
Whisper logits range          : -inf 32.815147399902344
ELM logits range              : 6.850248336791992 92.77094268798828
ELM log_prob range            : -48.92655944824219 -0.13163393

Decoding:  33%|███▎      | 2/6 [00:03<00:06,  1.70s/it]

Normalized fused tensor range : -inf -0.33612510561943054
Whisper logits range          : -inf 19.406047821044922
ELM logits range              : 59.203094482421875 111.93832397460938
ELM log_prob range            : -48.54405212402344 -0.9667118191719055
28
Normalized fused tensor range : -inf -2.0789601802825928
Whisper logits range          : -inf 40.17632293701172
ELM logits range              : -1.00044846534729 80.90894317626953
ELM log_prob range            : -46.38007736206055 -0.07840103656053543
5
Normalized fused tensor range : -inf -2.491705894470215
Whisper logits range          : -inf 32.95386505126953
ELM logits range              : 4.79795503616333 80.26943969726562
ELM log_prob range            : -51.31562805175781 -0.03855551779270172
6
Normalized fused tensor range : -inf -1.78688645362854
Whisper logits range          : -inf 38.48463439941406
ELM logits range              : 20.14548110961914 79.6475601196289
ELM log_prob range            : -50.17220687866211 -0.12347

Decoding:  50%|█████     | 3/6 [00:05<00:05,  1.81s/it]

Normalized fused tensor range : -inf -1.4962034225463867
Whisper logits range          : -inf 34.24433135986328
ELM logits range              : 41.0994873046875 109.47005462646484
ELM log_prob range            : -51.608001708984375 -0.6767575740814209
30
Normalized fused tensor range : -inf -0.13740627467632294
Whisper logits range          : -inf 19.15380096435547
ELM logits range              : 48.542572021484375 108.31451416015625
ELM log_prob range            : -47.28108596801758 -0.38919511437416077
31
Normalized fused tensor range : -inf -3.318338632583618
Whisper logits range          : -inf 27.030517578125
ELM logits range              : 8.977217674255371 91.42965698242188
ELM log_prob range            : -51.25576400756836 -0.11485785990953445
5
Normalized fused tensor range : -inf -1.857423186302185
Whisper logits range          : -inf 33.765907287597656
ELM logits range              : 15.3555326461792 84.29269409179688
ELM log_prob range            : -50.1291618347168 -0.0495

Decoding:  67%|██████▋   | 4/6 [00:06<00:03,  1.73s/it]

Normalized fused tensor range : -inf -0.7353101372718811
Whisper logits range          : -inf 36.73586654663086
ELM logits range              : 17.859764099121094 112.14183807373047
ELM log_prob range            : -50.44358825683594 -0.2706511318683624
26
Normalized fused tensor range : -inf -0.23918099701404572
Whisper logits range          : -inf 35.99046325683594
ELM logits range              : 27.458290100097656 117.62235260009766
ELM log_prob range            : -46.609493255615234 -0.6591063737869263
27
Normalized fused tensor range : -inf -0.17940889298915863
Whisper logits range          : -inf 17.884384155273438
ELM logits range              : 65.79255676269531 116.16921997070312
ELM log_prob range            : -46.35974884033203 -0.5107797384262085
28
Normalized fused tensor range : -inf -1.6928091049194336
Whisper logits range          : -inf 40.68667221069336
ELM logits range              : 9.298093795776367 84.22431945800781
ELM log_prob range            : -46.7899780273437

Decoding:  83%|████████▎ | 5/6 [00:08<00:01,  1.78s/it]

Normalized fused tensor range : -inf -0.3733363747596741
Whisper logits range          : -inf 34.37839126586914
ELM logits range              : 5.947858810424805 90.58150482177734
ELM log_prob range            : -49.65442657470703 -0.02327992208302021
5
Normalized fused tensor range : -inf -3.952346086502075
Whisper logits range          : -inf 34.306251525878906
ELM logits range              : 11.750798225402832 92.72699737548828
ELM log_prob range            : -51.214115142822266 -0.020344629883766174
6
Normalized fused tensor range : -inf -2.8211357593536377
Whisper logits range          : -inf 36.67088317871094
ELM logits range              : 8.596049308776855 100.23770904541016
ELM log_prob range            : -49.35661315917969 -0.23527118563652039
7
Normalized fused tensor range : -inf -1.7603093385696411
Whisper logits range          : -inf 37.5537109375
ELM logits range              : 12.630240440368652 90.07604217529297
ELM log_prob range            : -51.30717849731445 -0.010

Decoding: 100%|██████████| 6/6 [00:11<00:00,  1.98s/it]

Normalized fused tensor range : -inf -0.6077675819396973
Whisper logits range          : -inf 24.178359985351562
ELM logits range              : 51.659603118896484 108.32563018798828
ELM log_prob range            : -44.24417495727539 -1.7745331525802612
40





In [50]:
vanilla = []
for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    with torch.no_grad():
        vanilla_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            num_beams=1,
            do_sample=False
        )
    decoded = processor.batch_decode(vanilla_ids, skip_special_tokens=True)
    vanilla.extend(decoded)

Decoding: 100%|██████████| 6/6 [00:07<00:00,  1.19s/it]


In [55]:
import pandas as pd 
results_df = pd.DataFrame({'vanilla':vanilla, 'fused':fused, 'gt':ds['text']})

print_str = '''
GT:    {}
Base:  {}
Fused: {}'''
for idx, row in results_df.iterrows():
    row_str = print_str.format(
        row['gt'], 
        row['vanilla'].strip(), 
        row['fused'].strip()
    )
    print(row_str)



GT:    The echocardiogram shows an ejection fraction of thirty-five percent with global hypokinesis.
Base:  The echocardiogram shows an ejection fraction of 35% with global hypokinesis.
Fused: The echocardiogram shows an ejection fraction of 35% with global hypokinesis.

GT:    Post-operative pathology confirmed a stage two-A adenocarcinoma of the sigmoid colon.
Base:  Postoperative pathology confirmed a stage 2a adenocarcinoma of the sigmoid colon.
Fused: Postoperative pathology confirmed a stage 2a adenocarcinoma of the sigmoid colon.

GT:    Her hemoglobin A-one-C has stabilized at seven point one percent after switching to semaglutide.
Base:  Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide.
Fused: Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide.

GT:    Magnetic resonance imaging revealed a three-centimeter demyelinating plaque in the periventricular white matter.
Base:  Magnetic resonance imaging revealed a 3 cm demyelinating plaq

# BONEYARD


In [None]:
# from transformers import LogitsProcessor, LogitsProcessorList
# import torch.nn.functional as F

# class ShallowFusion(LogitsProcessor):
#     def __init__(self, lm, shared_vocab, eos, alpha=0.3, warmup_steps=3):
#         super().__init__()
#         self.lm = lm.eval().requires_grad_(False)
#         self.V = shared_vocab
#         self.eos = eos
#         self.alpha = alpha
#         self.warmup = warmup_steps
#         self.step = 0

#     @torch.no_grad()
#     def __call__(self, input_ids, scores):
#         print('printing input_ids.size(), scores.size(), step, input_ids, dec_ids')
#         print(input_ids.size(), scores.size(), self.step, input_ids, processor.batch_decode(input_ids))
#         self.step+=1 

#         return scores
    
# fusion_proc = ShallowFusion(
#     lm=gpt2,
#     shared_vocab=gpt2.config.vocab_size,
#     eos=EOS_ID,
#     alpha=0.3
# )

# batch = next(iter(loader))
# feats = batch['input_features'].to(DEVICE)
# masks = batch['attention_mask'].to(DEVICE)

# with torch.no_grad():

#     out1 = whisper.generate(
#         input_features=feats,
#         attention_mask=masks,
#         logits_processor=LogitsProcessorList([fusion_proc]),
#         return_dict_in_generate=True,
#         output_scores=True,
#         num_beams=2,
#     )

#     # out2 = whisper.generate(
#     #     input_features=feats,
#     #     attention_mask=masks,
#     #     return_dict_in_generate=True,
#     #     output_scores=True,
#     #     num_beams=2
#     # )

In [None]:
# steps = [
# torch.tensor([[50257, 50362],
#         [50257, 50362],
#         [50257, 50362]], device='mps:0'),
# torch.tensor([[50257, 50362,   383],
#         [50257, 50362,  2947],
#         [50257, 50362,  2332]], device='mps:0'),
# torch.tensor([[50257, 50362,   383,   304],
#         [50257, 50362,  2947,    12],
#         [50257, 50362,  2332, 16869]], device='mps:0'),
# torch.tensor([[50257, 50362,   383,   304,   354],
#         [50257, 50362,  2947,    12, 27173],
#         [50257, 50362,  2332, 16869, 49835]], device='mps:0')
# ]

# idx = 3
# with torch.no_grad():
#         dec_ids = steps[idx]
#         logits_a = whisper(feats, decoder_input_ids=dec_ids).logits[:, -1, :] 
#         out = whisper.generate(
#                 input_features=feats,
#                 attention_mask=masks,
#                 num_beams=1,
#                 do_sample=False,
#                 return_dict_in_generate=True,
#                 output_scores=True,
#                 max_new_tokens=len(steps)
#                 )
#         logits_b = out.scores[idx]

# print(logits_a)
# print(logits_b)

In [None]:
# oob_mask = dec_ids >= EOS_ID # create mask for gpt2 OOV tokens emitted by whisper
# pad_token = gpt2_tok.eos_token_id # replace with gpt2 pad token
# filtered = dec_ids.masked_fill(oob_mask, pad_token)
# attention_mask = (filtered != pad_token).long()

# with torch.no_grad():
#     logits_new = gpt2(input_ids=filtered, attention_mask=attention_mask).logits[:,-1, :]

# logits_new.size()