In [1]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

CACHE_DIR = (Path.cwd().parent / ".models" / "hfcache").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HOME"] = str(CACHE_DIR)
print(str(CACHE_DIR).split('/')[-3:])

['SF_EVAL', '.models', 'hfcache']


In [3]:
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

SR = 16_000
BATCH_SIZE = 5
WHISPER_ID = "openai/whisper-small.en"
GPT2_ID = "cwestnedge/gpt2-small-pubmed"
SHARED_VOCAB = 50257
ALPHA = 0.3
INIT_W_STEPS = 2
MAX_STEPS = 256

MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"  

DEVICE = (
    "cuda" if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print("Device:", DEVICE)

Device: mps


In [4]:
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

gpt2_tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

EOS_ID = processor.tokenizer.eos_token_id

In [5]:
def build_dataset(manifest_path: str, batch_size: int, num_proc: int = 4) -> Dataset:
    with open(manifest_path, encoding="utf-8") as f:
        rows = [json.loads(line) for line in f]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        batch["audio"] = [
            librosa.load(f"{AUDIO_DIR}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=batch_size, num_proc=num_proc)

def encode_audio(batch):
    # batch["audio"] is List[np.ndarray], each at its natural length
    feats = processor.feature_extractor(
        batch["audio"],
        sampling_rate=SR,
        padding="max_length",
        truncation=True, 
        max_length=processor.feature_extractor.n_samples,  # n_samples = chunk_length * sampling_rate
        return_attention_mask=True,
        return_tensors="pt" 
    )

    #  input_features : Tensor (B, T_max, 80)
    #  attention_mask : Tensor (B, T_max)
    batch["input_features"] = feats.input_features
    batch["attention_mask"] = feats.attention_mask
    return batch

ds = build_dataset(MANIFEST, batch_size=BATCH_SIZE)
ds_processed = ds.map(
    encode_audio, 
    batch_size=BATCH_SIZE, 
    batched=True,
    remove_columns=['uuid', 'file', 'category', 'index', 'text', 'audio']   # keep "text" for WER
    )

ds_processed.set_format(type="torch", columns=["input_features","attention_mask"])
loader = DataLoader(ds_processed, batch_size=BATCH_SIZE, shuffle=False)


Map (num_proc=4): 100%|██████████| 30/30 [00:00<00:00, 30.65 examples/s]
Map: 100%|██████████| 30/30 [00:00<00:00, 43.69 examples/s]


In [34]:
from transformers import LogitsProcessor, LogitsProcessorList
import torch.nn.functional as F

class ShallowFusion(LogitsProcessor):
    def __init__(self, lm, lm_tok, shared_vocab, special_mask, alpha=0.3):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False)
        self.tok = lm_tok
        self.V = shared_vocab
        self.special_mask = special_mask
        self.alpha = alpha
        self.step = 0

    @torch.no_grad()
    def __call__(self, input_ids, scores):
        self.step += 1
        if self.step <= 2:
            return scores

        B = input_ids.size(0) # batch_size * num_beams
        device = scores.device
        mask = self.special_mask.to(device)

        for beam in input_ids:
            print(processor.batch_decode(beam))
        keep = ~mask[input_ids[0]]
        filtered = input_ids[0, keep]
        gpt_ids = filtered.unsqueeze(0) if filtered.numel() else input_ids[:, -1:].clone()
        
        # if still no lexical token, skip fusion
        if (gpt_ids < self.V).sum() == 0:
            return scores
        gpt_ids = gpt_ids.to(device)

        w_lp_full = torch.log_softmax(scores, dim=-1)
        w_lp_shared = w_lp_full[:, : self.V]

        g_logits = self.lm(gpt_ids).logits[:, -1, :]
        g_lp = torch.log_softmax(g_logits, dim=-1)

        fused_shared = w_lp_shared + self.alpha * g_lp
        fused_shared = torch.where(
            mask[: self.V].unsqueeze(0), 
            w_lp_shared, 
            fused_shared
        )
        
        fused = torch.cat([fused_shared, w_lp_full[:, self.V:]], dim=-1)
        return fused # still log‑probs; generate() is fine with that?
    

special_ids = set(processor.tokenizer.all_special_ids)
special_mask = torch.tensor(
    [i in special_ids for i in range(processor.tokenizer.vocab_size)],
    dtype=torch.bool,
    device=DEVICE,
)

fusion_proc = ShallowFusion(
    lm=gpt2, lm_tok=gpt2_tok,
    shared_vocab=gpt2.config.vocab_size,
    special_mask=special_mask,
    alpha=0.25
)

In [35]:
from transformers import GenerationConfig
from tqdm import tqdm 

fused = []
refs = []

gen_cfg = GenerationConfig(
    num_beams=2,
    do_sample=False,
    max_length=448,
)

for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    with torch.no_grad():
        fused_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            logits_processor=LogitsProcessorList([fusion_proc]),
            generation_config=gen_cfg,
        )
    decoded = processor.batch_decode(fused_ids, skip_special_tokens=True)
    fused.extend(decoded)
    print(decoded)
    break

Decoding:   0%|          | 0/6 [00:00<?, ?it/s]

['<|startoftranscript|>', '<|notimestamps|>', ' The', ' e']
['<|startoftranscript|>', '<|notimestamps|>', ' the', ' e']
['<|startoftranscript|>', '<|notimestamps|>', ' Post', 'operative']
['<|startoftranscript|>', '<|notimestamps|>', ' Post', '-']
['<|startoftranscript|>', '<|notimestamps|>', ' Her', ' hem']
['<|startoftranscript|>', '<|notimestamps|>', ' Her', ' ha']
['<|startoftranscript|>', '<|notimestamps|>', ' Magnetic', ' resonance']
['<|startoftranscript|>', '<|notimestamps|>', ' Magnetic', ' Reson']
['<|startoftranscript|>', '<|notimestamps|>', ' We', ' started']
['<|startoftranscript|>', '<|notimestamps|>', ' we', ' started']
['<|startoftranscript|>', '<|notimestamps|>', ' The', ' e', 'ch']
['<|startoftranscript|>', '<|notimestamps|>', ' the', ' e', 'ch']
['<|startoftranscript|>', '<|notimestamps|>', ' Post', '-', 'operative']
['<|startoftranscript|>', '<|notimestamps|>', ' Post', 'operative', ' pathology']
['<|startoftranscript|>', '<|notimestamps|>', ' Her', ' hem', 'oglobin

Decoding:   0%|          | 0/6 [00:07<?, ?it/s]

['<|startoftranscript|>', '<|notimestamps|>', ' Magnetic', ' Reson', 'ance', ' Imaging', ' revealed', ' a', ' 3', ' cent', 'imeter', ' dem', 'y', 'el', 'inating', ' plaque', ' in', ' the', ' per', 'iv', 'entric', 'ular', ' white', ' matter', '.', ' (']
['<|startoftranscript|>', '<|notimestamps|>', ' We', ' started', ' ce', 'ft', 'ri', 'ax', 'one', ' 2', ' grams', ',', ' given', ' intraven', 'ously', ',', ' every', ' 24', ' hours', ' for', ' suspected', ' bacterial', ' men', 'ing', 'itis', '.']
['<|startoftranscript|>', '<|notimestamps|>', ' We', ' started', ' ce', 'ft', 'ri', 'ax', 'one', ' two', ' grams', ',', ' given', ' intraven', 'ously', ',', ' every', ' 24', ' hours', ' for', ' suspected', ' bacterial', ' men', 'ing', 'itis', '.']
[' The echocardiogram shows an ejection fraction of 35% with global hypokinesis.', ' Post-operative pathology confirmed a stage 2a adenocarcinoma of the sigmoid colon.', ' Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide.', ' Mag




In [None]:
vanilla = []
for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    with torch.no_grad():
        vanilla_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            generation_config=gen_cfg
        )
    decoded = processor.batch_decode(vanilla_ids, skip_special_tokens=True)
    vanilla.extend(decoded)

In [None]:
import pandas as pd 
print(pd.DataFrame({'vanilla':vanilla, 'fused':fused, 'gt':ds['text']}).to_markdown())

# BONEYARD


In [None]:
# class ShallowFusion(LogitsProcessor):
#     def __init__(self, lm, tok, shared_vocab, special_mask, alpha=0.25):
#         super().__init__()
#         self.lm   = lm.eval().requires_grad_(False)
#         self.tok  = tok
#         self.V    = shared_vocab
#         self.mask = special_mask
#         self.alpha = alpha
#         self.step  = 0

#     @torch.no_grad()
#     def __call__(self, input_ids, scores):
#         self.step += 1
#         if self.step <= 2:
#             return scores                    # warm‑up

#         B = input_ids.size(0)                # batch*beams
#         dev = scores.device
#         mask = self.mask.to(dev)

#         # —— build GPT context for **each** beam independently ——
#         gpt_ctx = []
#         for seq in input_ids:                # loop over beams
#             keep = ~mask[seq]
#             filtered = seq[keep]
#             ctx = filtered if filtered.numel() else seq[-1:]
#             gpt_ctx.append(ctx)

#         # pad to equal length & stack
#         gpt_ids = torch.nn.utils.rnn.pad_sequence(
#             gpt_ctx, batch_first=True,
#             padding_value=self.tok.eos_token_id
#         ).to(dev)                            # (B, L_ctx)

#         # —— LM logits for every beam ——
#         g_logits = self.lm(gpt_ids).logits[:, -1, :]   # (B, V)

#         # —— Whisper & GPT log‑probs ——
#         w_lp_full   = torch.log_softmax(scores, dim=-1)     # (B, V_whisper)
#         w_lp_shared = w_lp_full[:, : self.V]                # (B, V)

#         g_lp = torch.log_softmax(g_logits, dim=-1)          # (B, V)

#         fused = w_lp_shared + self.alpha * g_lp             # (B, V)
#         fused = torch.where(
#             mask[: self.V].unsqueeze(0),    # keep Whisper specials intact
#             w_lp_shared,
#             fused,
#         )

#         return torch.cat([fused, w_lp_full[:, self.V:]], dim=-1)
    
# special_ids = set(processor.tokenizer.all_special_ids)
# special_mask = torch.tensor(
#     [i in special_ids for i in range(processor.tokenizer.vocab_size)],
#     dtype=torch.bool,
#     device=DEVICE,
# )

# fusion_proc = ShallowFusion(
#     lm=gpt2, tok=gpt2_tok,
#     shared_vocab=gpt2.config.vocab_size,
#     special_mask=special_mask,
#     alpha=0.3
# )

# from transformers import GenerationConfig
# from tqdm import tqdm 

# fused = []
# refs = []

# gen_cfg = GenerationConfig(
#     num_beams=4,
#     do_sample=False,
#     max_length=448,
# )

# for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
#     feats = batch['input_features'].to(DEVICE)
#     masks = batch['attention_mask'].to(DEVICE)
#     with torch.no_grad():
#         fused_ids = whisper.generate(
#             input_features=feats,
#             attention_mask=masks,
#             logits_processor=LogitsProcessorList([fusion_proc]),
#             generation_config=gen_cfg,
#         )
#     decoded = processor.batch_decode(fused_ids, skip_special_tokens=True)
#     fused.extend(decoded)
#     print(decoded)
#     break

In [None]:
# from transformers import LogitsProcessor, LogitsProcessorList
# import torch.nn.functional as F
# from tqdm import tqdm

# class HelloWorldProcessor(LogitsProcessor):
#     """
#     A toy processor that adds +5.0 to the logit for token ID 7 at every step.
#     """
#     def __init__(self, elm, elm_tokenizer, shared_vocab, special_mask, alpha):
#         self.elm = elm
#         self.elm_tokenizer = elm_tokenizer
#         self.shared_vocab = shared_vocab
#         self.special_mask = special_mask
#         self.alpha = alpha
#         self.counter = 0 

#     def __call__(self, input_ids, scores):

#         self.counter +=1
#         if self.counter > 2:

#             keep = ~self.special_mask[input_ids[0]]
#             filtered = input_ids[0, keep]
            
#             if filtered.numel() == 0: 
#                 elm_ids = input_ids[:, -1:].clone()

#             else:
#                 elm_ids = filtered.unsqueeze(0)
#             with torch.no_grad():
#                 w_lp_full = F.log_softmax(scores, dim=-1)
#                 w_lp_shared = w_lp_full[:, :SHARED_VOCAB]

#                 elm_logits = self.elm(elm_ids).logits[:, -1, :]
#                 elm_lp = F.log_softmax(elm_logits, dim=-1)

#                 fused_shared = w_lp_shared + self.alpha * elm_lp
#                 fused_logits = torch.cat([
#                     fused_shared,
#                     w_lp_full[:, SHARED_VOCAB:],
#                 ], dim=-1)
#         else:
#             fused_logits = scores.clone()
#         return fused_logits
    
# special_ids = set(processor.tokenizer.all_special_ids)
# special_mask = torch.tensor(
#     [i in special_ids for i in range(processor.tokenizer.vocab_size)],
#     dtype=torch.bool,
#     device=DEVICE,
# )

# b = next(iter(loader))
# feats = b['input_features'].to(DEVICE)
# mask = b['attention_mask'].to(DEVICE)

# hello_proc = HelloWorldProcessor(
#     elm=gpt2, 
#     elm_tokenizer=gpt2_tok, 
#     shared_vocab=SHARED_VOCAB, 
#     special_mask=special_mask,
#     alpha=0.3
#     )

# lp_list = LogitsProcessorList([hello_proc])

# all_outputs = []
# all_references = []
# for batch_idx, batch in enumerate(tqdm(loader, desc="Processing batches")):
#     feats = batch['input_features'].to(DEVICE)
#     masks = batch['attention_mask'].to(DEVICE)
#     with torch.no_grad():
#         out_ids = whisper.generate(
#             input_features=feats,
#             attention_mask=mask,
#             logits_processor=lp_list,
#             max_new_tokens=50,
#             num_beams=3,
#             no_repeat_ngram_size=3,
#         )
#     decoded = processor.batch_decode(out_ids, skip_special_tokens=True)
#     all_outputs.extend(decoded)

In [None]:
## if i wanted to keep uuid in batch id have to use collator and do the following... 

# def collate_fn(batch):
#     uuids = [item["uuid"] for item in batch]
#     feats = torch.stack([
#         torch.as_tensor(item["input_features"], dtype=torch.float32)
#         for item in batch
#     ], dim=0)
#     masks = torch.stack([
#         torch.as_tensor(item["attention_mask"], dtype=torch.long)
#         for item in batch
#     ], dim=0)  # shape (B, T)

#     return {
#         "uuid":           uuids,
#         "input_features": feats,
#         "attention_mask": masks,
#     }

# loader = DataLoader(
#     ds,
#     collate_fn=collate_fn,
#     batch_size=BATCH_SIZE,
#     shuffle=False,   # or True if you want
# )