In [1]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

# Keep Hugging Face downloads in a project-local cache: <parent of cwd>/.models/hfcache.
# The environment variables are set in this first cell, before the Hugging Face
# libraries are imported in the next one.
CACHE_DIR = (Path.cwd().parent / ".models" / "hfcache").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HOME"] = str(CACHE_DIR)
# Print only the last three path components so the full local path is not echoed
# into the notebook. Path.parts is separator-agnostic, unlike splitting on '/'.
print(list(CACHE_DIR.parts[-3:]))

['SF_EVAL', '.models', 'hfcache']


In [62]:
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

# --- Run configuration --------------------------------------------------------
SR = 16_000       # sampling rate (Hz) passed to the audio loader and feature extractor
BATCH_SIZE = 5    # batch size for the dataset maps and the DataLoader
WHISPER_ID = "openai/whisper-tiny.en"
GPT2_ID = "cwestnedge/gpt2-medium-pubmed"

MANIFEST = "../data/output/manifest.jsonl"  # JSON-lines manifest, one record per line
AUDIO_DIR = "../data/output"                # directory that manifest "file" entries are joined onto

# Pick the accelerator in priority order: CUDA, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

Device: mps


In [None]:
# Load the ASR model (Whisper) and the external language model (GPT-2), both
# moved to DEVICE and put in eval mode.
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

# Pass cache_dir for the tokenizer as well, so every artifact is cached in the
# same place as the other three loads.
gpt2_tok = AutoTokenizer.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

# If the tokenizer has no pad token, add a dedicated one and grow the embedding
# matrix so the new id has an embedding row.
if gpt2_tok.pad_token is None:
    gpt2_tok.add_special_tokens({"pad_token": "<|pad|>"})
    gpt2.resize_token_embeddings(len(gpt2_tok))

PAD_ID = gpt2_tok.pad_token_id # e.g. 50257
EOS_ID = gpt2_tok.eos_token_id # 50256 (unchanged)
# Number of leading token ids (0 .. EOS_ID) that ShallowFusion treats as common
# to both models; only that slice of the scores receives the LM term.
# NOTE(review): ShallowFusion also uses PAD_ID as its "masked" sentinel, which
# presumes PAD_ID lies outside this shared range -- confirm for this tokenizer.
SHARED_VOCAB = EOS_ID + 1

print(PAD_ID, EOS_ID, SHARED_VOCAB)

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


In [None]:
def build_dataset(manifest_path: str, batch_size: int, num_proc: int = 4) -> Dataset:
    """Load a JSON-lines manifest and attach the decoded waveform to every record.

    Args:
        manifest_path: Path to a UTF-8 JSON-lines file. Each non-blank line is
            parsed as one JSON record; records must provide a "file" entry, which
            is resolved relative to ``AUDIO_DIR``.
        batch_size: Number of rows handed to the audio loader per ``map`` call.
        num_proc: Number of worker processes passed to ``Dataset.map``.

    Returns:
        The manifest rows as a ``Dataset`` with an extra "audio" column holding
        each file loaded by librosa as mono audio at ``SR`` Hz, cast to float32.

    Raises:
        json.JSONDecodeError: If a non-blank manifest line is not valid JSON.
    """
    with open(manifest_path, encoding="utf-8") as f:
        # Blank lines (e.g. an empty line at the end of the file) carry no record;
        # skip them instead of letting json.loads fail on an empty string.
        rows = [json.loads(line) for line in f if line.strip()]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        # librosa.load returns (waveform, sample_rate); only the waveform is kept.
        batch["audio"] = [
            librosa.load(f"{AUDIO_DIR}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=batch_size, num_proc=num_proc)

def encode_audio(batch):
    """Turn a batch of raw waveforms into fixed-size Whisper encoder inputs.

    ``batch["audio"]`` holds one waveform per example, each at its natural
    length. Every clip is padded or truncated to the extractor's ``n_samples``
    (chunk_length * sampling_rate) so all examples come out the same size.

    Adds "input_features" and "attention_mask" to ``batch`` and returns it.
    """
    extractor = processor.feature_extractor
    encoded = extractor(
        batch["audio"],  # handed over as numpy arrays / lists (original note: PT tensors are not accepted here)
        sampling_rate=SR,
        padding="max_length",
        truncation=True,
        max_length=extractor.n_samples,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # NOTE(review): the Hugging Face docs describe Whisper input_features as
    # (batch, n_mels, frames) with a per-frame attention_mask (batch, frames);
    # an earlier comment here said (B, T_max, 80) -- confirm the layout.
    batch["input_features"] = encoded.input_features
    batch["attention_mask"] = encoded.attention_mask
    return batch

ds = build_dataset(MANIFEST, batch_size=BATCH_SIZE)

# Keep `ds` intact (it still has the reference text needed for evaluation) and
# write the encoded features into a separate dataset. Row order is the same in
# both as long as nothing is shuffled, so outputs can be matched by position.
ds_processed = ds.map(
    encode_audio,
    batch_size=BATCH_SIZE,
    batched=True,
    # Drop every input column instead of a hard-coded list, so this cell does not
    # break when the manifest gains or loses a field.
    remove_columns=ds.column_names,
)

# Only the two encoder inputs are exposed to the DataLoader, as torch tensors.
ds_processed.set_format(type="torch", columns=["input_features", "attention_mask"])
loader = DataLoader(ds_processed, batch_size=BATCH_SIZE, shuffle=False)

Map (num_proc=4): 100%|██████████| 30/30 [00:00<00:00, 37.71 examples/s]
Map: 100%|██████████| 30/30 [00:00<00:00, 46.55 examples/s]


In [None]:
from transformers import LogitsProcessor

class ShallowFusion(LogitsProcessor):
    """Logits processor that mixes an external LM's log-probabilities into the ASR scores.

    For every decoding step after the warmup, the next-token log-probabilities of
    ``lm`` (conditioned on the tokens decoded so far) are scaled by ``alpha`` and
    added to the first ``shared_vocab`` entries of the incoming scores. The result
    is renormalized with a log-sum-exp so each row is a log-distribution again.

    Args:
        lm: Causal language model; called with ``input_ids``/``attention_mask`` and
            expected to return an object with ``.logits``. It is switched to eval
            mode with gradients disabled.
        shared_vocab: Number of leading token ids shared by the ASR model and the
            LM. Ids at or above this value are hidden from the LM, and scores at or
            above this index receive no LM term.
        pad_id: Id substituted for out-of-range context tokens; positions equal to
            this id are masked out of the LM's attention.
        alpha: Weight of the LM log-probabilities.
        warmup_steps: Number of initial decoding steps of each generation that are
            passed through unfused.
        temperature: Stored as ``self.temp`` but currently not applied.

    Note:
        The warmup is tracked per generation. A call whose ``input_ids`` are not
        longer than those of the previous call is treated as the first step of a
        new decode, so one instance can be reused across ``generate()`` calls.
        A one-off direct call therefore also counts as a first step.
    """

    def __init__(self, lm, shared_vocab, pad_id, alpha=0.3, warmup_steps=3, temperature=0.05):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False)
        self.V = shared_vocab
        self.pad_id = pad_id
        self.alpha = alpha
        self.warmup = warmup_steps
        self.temp = temperature  # kept for interface compatibility; not used below
        self.step = 0
        self._prev_len = None  # context length seen on the previous call

    @torch.inference_mode()
    def __call__(self, input_ids, scores):
        """Return ``scores`` with the weighted LM log-probabilities fused in.

        Args:
            input_ids: 2-D tensor of token ids decoded so far, one row per hypothesis.
            scores: 2-D tensor of next-token scores for the same rows.

        Returns:
            ``scores`` itself during warmup, otherwise a new renormalized tensor
            of the same shape.
        """
        # Within one generation the context grows on every call. If it did not
        # grow, a new generation has started: restart the warmup counter so the
        # warmup is applied to every decode, not just the first one.
        cur_len = input_ids.shape[-1]
        if self._prev_len is None or cur_len <= self._prev_len:
            self.step = 0
        self._prev_len = cur_len

        if self.step < self.warmup:
            self.step += 1
            return scores
        self.step += 1

        # Hide ids the LM does not know (>= shared vocab): replace them with the
        # pad id and exclude those positions from attention.
        oob_mask = input_ids >= self.V
        filtered_ids = input_ids.masked_fill(oob_mask, self.pad_id)
        attn_mask = (filtered_ids != self.pad_id).long()

        lm_logits = self.lm(
            input_ids=filtered_ids,
            attention_mask=attn_mask,
        ).logits[:, -1, :]  # only the prediction for the next token is needed

        # Normalize over the LM's full vocabulary, then keep the shared slice.
        lm_lp = torch.log_softmax(lm_logits, dim=-1)[:, : self.V]
        # Original author's note: the scores arriving here are already
        # log-probabilities. The final renormalization below removes any
        # per-row constant offset either way.
        w_lp = scores

        fused_lp = w_lp.clone()
        fused_lp[:, : self.V] = w_lp[:, : self.V] + self.alpha * lm_lp

        # Renormalize so every row sums to 1 in probability space.
        fused_lp -= torch.logsumexp(fused_lp, dim=-1, keepdim=True)
        return fused_lp
    
fusion_proc = ShallowFusion(
    lm=gpt2,
    shared_vocab=SHARED_VOCAB,  # EOS_ID + 1: the added <|pad|> id is NOT part of the shared range
    pad_id=PAD_ID,              # e.g. 50257
    alpha=0.35,
    warmup_steps=3,
    temperature = 0.05
)

# sanity check
# Run the check on a throwaway processor with no warmup. A single call on a
# processor that still has warmup steps left returns the input unchanged, which
# would make this check pass without exercising the fusion path, and it would
# also use up one warmup step of the processor used for decoding.
check_proc = ShallowFusion(
    lm=gpt2,
    shared_vocab=SHARED_VOCAB,
    pad_id=PAD_ID,
    alpha=fusion_proc.alpha,
    warmup_steps=0,
)
with torch.no_grad():
    # Context mixes in-range ids with one id (50257) at/above SHARED_VOCAB.
    dummy_ids = torch.tensor([[50256, 50257, 200, 345]], device=DEVICE)
    dummy_lp = torch.log_softmax(torch.randn(1, whisper.config.vocab_size, device=DEVICE), dim=-1)
    out = check_proc(dummy_ids, dummy_lp)
    total_prob = torch.exp(out).sum().item()
    assert abs(total_prob - 1.0) < 1e-3, f"fused scores are not normalized: sum(exp) = {total_prob}"
    print("∑exp =", total_prob)                          # 1.0 ± 1e‑4
    print("range", out.min().item(), out.max().item())   # log-probabilities, so all values <= 0

In [None]:
from transformers import LogitsProcessorList
from tqdm import tqdm 

fused = []
refs = []  # NOTE(review): never filled in the visible cells; references are read from ds['text'] later

for batch in tqdm(loader, total=len(loader), desc="Decoding"):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    # The processor counts decoding steps to implement its warmup and the same
    # instance is reused for every batch. Restart the count here so each batch
    # gets the warmup, not only the first one decoded.
    fusion_proc.step = 0
    with torch.no_grad():
        fused_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            logits_processor=LogitsProcessorList([fusion_proc]),
            num_beams=3,
            do_sample=False,
        )
    decoded = processor.batch_decode(fused_ids, skip_special_tokens=True)
    fused.extend(decoded)


Decoding: 100%|██████████| 6/6 [00:23<00:00,  3.91s/it]


∑exp = 1.000000238418579
range -28.52570915222168 -4.04352331161499


In [None]:
# Repeat the normalization check with a context that contains only in-range ids.
# A dedicated zero-warmup processor is used so the outcome does not depend on how
# often `fusion_proc` has been called before this cell, and so this call does not
# alter the state of the processor used for decoding.
check_proc = ShallowFusion(
    lm=gpt2,
    shared_vocab=SHARED_VOCAB,
    pad_id=PAD_ID,
    alpha=fusion_proc.alpha,
    warmup_steps=0,
)
with torch.no_grad():
    dummy_ids = torch.tensor([[50256, 200, 345]], device=DEVICE)
    dummy_lp = torch.log_softmax(torch.randn(1, whisper.config.vocab_size, device=DEVICE), dim=-1)
    out = check_proc(dummy_ids, dummy_lp)
    total_prob = torch.exp(out).sum().item()
    assert abs(total_prob - 1.0) < 1e-3, f"fused scores are not normalized: sum(exp) = {total_prob}"
    print("∑exp =", total_prob)                          # 1.0 ± 1e‑4
    print("range", out.min().item(), out.max().item())   # log-probabilities, so all values <= 0

∑exp = 1.000000238418579
range -17.092540740966797 -5.121554374694824


In [None]:
# Baseline transcripts: Whisper on its own, without the fusion processor.
# NOTE(review): this baseline decodes greedily (num_beams=1) while the fused run
# uses num_beams=3, so differences between the two are not due to fusion alone.
vanilla = []
progress = tqdm(loader, total=len(loader), desc="Decoding")
for batch in progress:
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    generate_kwargs = dict(
        input_features=feats,
        attention_mask=masks,
        num_beams=1,
        do_sample=False,
    )
    with torch.no_grad():
        vanilla_ids = whisper.generate(**generate_kwargs)
    decoded = processor.batch_decode(vanilla_ids, skip_special_tokens=True)
    vanilla += decoded

Decoding: 100%|██████████| 6/6 [00:07<00:00,  1.29s/it]


In [None]:
import pandas as pd 
# Line up baseline, fused and reference transcripts by position; all three lists
# follow the unshuffled dataset order.
results_df = pd.DataFrame({'vanilla': vanilla, 'fused': fused, 'gt': ds['text']})

# Print one GT / Base / Fused triple per utterance, separated by a blank line.
for gt_text, base_text, fused_text in zip(
    results_df['gt'], results_df['vanilla'], results_df['fused']
):
    print(f"GT:    {gt_text}\nBase:  {base_text.strip()}\nFused: {fused_text.strip()}")
    print()

GT:    The echocardiogram shows an ejection fraction of thirty-five percent with global hypokinesis.
Base:  The echocardiogram shows an ejection fraction of 35% with global hypokinesis.
Fused: The echocardiogram shows an ejection fraction of 35% with global hypokinesis.

GT:    Post-operative pathology confirmed a stage two-A adenocarcinoma of the sigmoid colon.
Base:  Postoperative pathology confirmed a stage 2a adenocarcinoma of the sigmoid colon.
Fused: Post-operative pathology confirmed a stage 2a adenocarcinoma of the sigmoid colon.

GT:    Her hemoglobin A-one-C has stabilized at seven point one percent after switching to semaglutide.
Base:  Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide.
Fused: Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide.

GT:    Magnetic resonance imaging revealed a three-centimeter demyelinating plaque in the periventricular white matter.
Base:  Magnetic resonance imaging revealed a 3 cm demyelinating plaq

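# BONEYARD

Scratch cells kept for reference only: everything below is commented out and is not part of the evaluation pipeline.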


In [None]:
# from transformers import LogitsProcessor, LogitsProcessorList
# import torch.nn.functional as F

# class ShallowFusion(LogitsProcessor):
#     def __init__(self, lm, shared_vocab, eos, alpha=0.3, warmup_steps=3):
#         super().__init__()
#         self.lm = lm.eval().requires_grad_(False)
#         self.V = shared_vocab
#         self.eos = eos
#         self.alpha = alpha
#         self.warmup = warmup_steps
#         self.step = 0

#     @torch.no_grad()
#     def __call__(self, input_ids, scores):
#         print('printing input_ids.size(), scores.size(), step, input_ids, dec_ids')
#         print(input_ids.size(), scores.size(), self.step, input_ids, processor.batch_decode(input_ids))
#         self.step+=1 

#         return scores
    
# fusion_proc = ShallowFusion(
#     lm=gpt2,
#     shared_vocab=gpt2.config.vocab_size,
#     eos=EOS_ID,
#     alpha=0.3
# )

# batch = next(iter(loader))
# feats = batch['input_features'].to(DEVICE)
# masks = batch['attention_mask'].to(DEVICE)

# with torch.no_grad():

#     out1 = whisper.generate(
#         input_features=feats,
#         attention_mask=masks,
#         logits_processor=LogitsProcessorList([fusion_proc]),
#         return_dict_in_generate=True,
#         output_scores=True,
#         num_beams=2,
#     )

#     # out2 = whisper.generate(
#     #     input_features=feats,
#     #     attention_mask=masks,
#     #     return_dict_in_generate=True,
#     #     output_scores=True,
#     #     num_beams=2
#     # )

In [None]:
# steps = [
# torch.tensor([[50257, 50362],
#         [50257, 50362],
#         [50257, 50362]], device='mps:0'),
# torch.tensor([[50257, 50362,   383],
#         [50257, 50362,  2947],
#         [50257, 50362,  2332]], device='mps:0'),
# torch.tensor([[50257, 50362,   383,   304],
#         [50257, 50362,  2947,    12],
#         [50257, 50362,  2332, 16869]], device='mps:0'),
# torch.tensor([[50257, 50362,   383,   304,   354],
#         [50257, 50362,  2947,    12, 27173],
#         [50257, 50362,  2332, 16869, 49835]], device='mps:0')
# ]

# idx = 3
# with torch.no_grad():
#         dec_ids = steps[idx]
#         logits_a = whisper(feats, decoder_input_ids=dec_ids).logits[:, -1, :] 
#         out = whisper.generate(
#                 input_features=feats,
#                 attention_mask=masks,
#                 num_beams=1,
#                 do_sample=False,
#                 return_dict_in_generate=True,
#                 output_scores=True,
#                 max_new_tokens=len(steps)
#                 )
#         logits_b = out.scores[idx]

# print(logits_a)
# print(logits_b)

In [None]:
# oob_mask = dec_ids >= EOS_ID # create mask for gpt2 OOV tokens emitted by whisper
# pad_token = gpt2_tok.eos_token_id # replace with gpt2 pad token
# filtered = dec_ids.masked_fill(oob_mask, pad_token)
# attention_mask = (filtered != pad_token).long()

# with torch.no_grad():
#     logits_new = gpt2(input_ids=filtered, attention_mask=attention_mask).logits[:,-1, :]

# logits_new.size()