In [238]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

# Keep Hugging Face downloads in a repo-local cache that lives one level above
# the notebook's working directory.
CACHE_DIR = (Path.cwd().parent / ".models" / "hfcache").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# NOTE(review): these variables are presumably read when the Hugging Face
# libraries are imported, so this cell should run before those imports - confirm.
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HOME"] = str(CACHE_DIR)
# Show only the last three path components; Path.parts does not assume that
# "/" is the path separator.
print(list(CACHE_DIR.parts[-3:]))

['SF_EVAL', '.models', 'hfcache']


In [253]:
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

# --- audio / batching -------------------------------------------------------
SR = 16_000          # sampling rate every clip is resampled to on load
BATCH_SIZE = 5       # used for both Dataset.map batches and the DataLoader

# --- checkpoints ------------------------------------------------------------
WHISPER_ID = "openai/whisper-tiny.en"
GPT2_ID = "cwestnedge/gpt2-large-pubmed"

# --- shallow-fusion decoding ------------------------------------------------
SHARED_VOCAB = 50257  # width of the Whisper logit slice that is fused with GPT-2
ALPHA = 0.3           # language-model weight
INIT_W_STEPS = 2      # leading decode steps taken by Whisper alone
MAX_STEPS = 256       # hard cap on decode steps per clip

# --- data locations ---------------------------------------------------------
MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Device:", device)

Device: mps


In [254]:
# def build_dataset(pattern: str) -> Dataset:
#     def process_audio(batch):
#         batch["audio"] = [
#             librosa.load(p, sr=SR, mono=True)[0].astype(np.float32)
#             for p in batch["path"]
#         ]
#         return batch

#     paths = glob.glob(pattern)
#     ds = Dataset.from_dict({"path": paths})
#     ds = ds.map(process_audio, batched=True, batch_size=BATCH_SIZE, num_proc=4).remove_columns("path")
#     return ds

def build_dataset(manifest_path: str) -> Dataset:
    """Load a JSONL manifest and attach the decoded audio for every row.

    Args:
        manifest_path: Path to a UTF-8 JSON-lines file. Each row must carry a
            "file" field naming an audio file relative to AUDIO_DIR.

    Returns:
        A Dataset holding the manifest columns plus an "audio" column with the
        mono waveform resampled to SR as float32.
    """
    with open(manifest_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line); json.loads would
        # otherwise raise on them.
        rows = [json.loads(line) for line in f if line.strip()]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        # librosa.load returns (waveform, sample_rate); only the waveform is kept.
        batch["audio"] = [
            librosa.load(f"{AUDIO_DIR}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=BATCH_SIZE, num_proc=4)

def collate(batch):
    """Turn a list of dataset rows into three parallel lists.

    Returns:
        (audio, refs, uuids): the "audio", "text" and "uuid" field of every
        row, in the order the rows were given. Values are passed through
        unchanged.
    """
    def column(key):
        return [row[key] for row in batch]

    # audio -> waveforms, text -> ground-truth transcripts, uuid -> join/debug key
    return column("audio"), column("text"), column("uuid")

# Decode all clips listed in the manifest once, then serve them in
# fixed-size batches from the main process (no worker processes, no pinning).
ds = build_dataset(MANIFEST)
loader = DataLoader(ds, batch_size=BATCH_SIZE, collate_fn=collate, num_workers=0, pin_memory=False)

Map (num_proc=4):   0%|          | 0/30 [00:00<?, ? examples/s]

Map (num_proc=4): 100%|██████████| 30/30 [00:01<00:00, 27.23 examples/s]


In [247]:
# Whisper: processor and model, both in eval mode on the selected device.
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(device).eval()

# GPT-2: the tokenizer comes from the stock GPT-2 repo, the weights from GPT2_ID.
# The tokenizer now uses the same cache_dir as every other download in this cell.
gpt_tok = AutoTokenizer.from_pretrained("openai-community/gpt2", cache_dir=CACHE_DIR)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(device).eval()

# End-of-sequence id that stops decoding, and the set of ids that must be
# stripped before text is handed to GPT-2.
EOS_ID = processor.tokenizer.eos_token_id
WHISPER_SPECIALS = set(processor.tokenizer.all_special_ids)

In [248]:
import torch.nn.functional as F

@torch.no_grad()
def fuse_whisper_gpt(wav: np.ndarray, alpha: float = 0.25) -> str:
    """Greedy-decode a single waveform with Whisper + GPT-2 shallow fusion.

    The first INIT_W_STEPS tokens are chosen by Whisper alone. Every later
    token is the argmax of ``log p_whisper + alpha * log p_gpt2`` over the
    first SHARED_VOCAB ids.

    Args:
        wav: One audio clip sampled at SR, passed straight to the processor.
        alpha: Weight of the GPT-2 log-probabilities.

    Returns:
        The decoded text with special tokens removed.
    """
    feats = processor(wav, sampling_rate=SR, return_tensors="pt").input_features.to(device)
    # The audio does not change while decoding, so encode it once instead of
    # re-running the encoder on every step.
    enc_out = whisper.get_encoder()(feats)

    dec_ids = torch.tensor([[whisper.config.decoder_start_token_id]], device=device)

    for step in range(MAX_STEPS):
        # The full prefix is re-fed each step, so a KV cache would be built and
        # discarded; switch it off.
        w_logits = whisper(
            encoder_outputs=enc_out,
            decoder_input_ids=dec_ids,
            use_cache=False,
        ).logits[:, -1]  # (1, Vw)

        if step < INIT_W_STEPS:  # allow <s> and task token
            next_id = w_logits.argmax(-1, keepdim=True)
        else:
            # GPT-2 context = decoded ids minus Whisper specials. Ids at or
            # above SHARED_VOCAB are dropped too: GPT-2 has no embedding for them.
            lm_ctx = [
                t for t in dec_ids[0].tolist()
                if t < SHARED_VOCAB and t not in WHISPER_SPECIALS
            ]
            if not lm_ctx:
                # Nothing usable decoded yet: condition GPT-2 on its own BOS
                # token rather than on Whisper-only ids.
                lm_ctx = [gpt_tok.bos_token_id]
            gpt_ids = torch.tensor([lm_ctx], dtype=torch.long, device=device)

            g_logits = gpt2(gpt_ids).logits[:, -1]  # (1, Vg)

            w_lp = F.log_softmax(w_logits[:, :SHARED_VOCAB], dim=-1)
            g_lp = F.log_softmax(g_logits, dim=-1)
            fused = w_lp + alpha * g_lp  # (1, Vg)
            next_id = fused.argmax(-1, keepdim=True)

        dec_ids = torch.cat([dec_ids, next_id], dim=-1)

        if next_id.item() == EOS_ID:
            break

    return processor.batch_decode(dec_ids, skip_special_tokens=True)[0]

In [249]:
import torch.nn.functional as F

@torch.no_grad()
def fuse_whisper_gpt1(wav: np.ndarray, alpha: float=0.25, variance_match: bool = False) -> str:
    """Greedy shallow-fusion decode with optional variance matching.

    The first INIT_W_STEPS tokens are chosen by Whisper alone. Every later
    token is the argmax of ``log p_whisper + alpha * log p_gpt2`` over the
    first SHARED_VOCAB ids.

    Args:
        wav: One audio clip sampled at SR, passed straight to the processor.
        alpha: Weight of the GPT-2 log-probabilities.
        variance_match: When true (and alpha > 0), divide each model's
            log-probabilities by their standard deviation before fusing.

    Returns:
        The decoded text with special tokens removed.
    """
    feats = processor(wav, sampling_rate=SR, return_tensors="pt").input_features.to(device)
    # Encode the audio once; it is identical for every decode step.
    enc_out = whisper.get_encoder()(feats)

    dec_ids = torch.tensor([[whisper.config.decoder_start_token_id]], device=device)

    for step in range(MAX_STEPS):
        # The full prefix is re-fed each step, so no KV cache is needed.
        w_logits = whisper(
            encoder_outputs=enc_out,
            decoder_input_ids=dec_ids,
            use_cache=False,
        ).logits[:, -1]  # (1, Vw)

        if step < INIT_W_STEPS:  # allow <s> and task token
            next_id = w_logits.argmax(-1, keepdim=True)
        else:
            # GPT-2 context = decoded ids minus Whisper specials. Ids at or
            # above SHARED_VOCAB are dropped too: GPT-2 has no embedding for them.
            lm_ctx = [
                t for t in dec_ids[0].tolist()
                if t < SHARED_VOCAB and t not in WHISPER_SPECIALS
            ]
            if not lm_ctx:
                # Nothing usable decoded yet: fall back to GPT-2's BOS token
                # instead of feeding it Whisper-only ids.
                lm_ctx = [gpt_tok.bos_token_id]
            gpt_ids = torch.tensor([lm_ctx], dtype=torch.long, device=device)

            g_logits = gpt2(gpt_ids).logits[:, -1]  # (1, Vg)

            w_lp = F.log_softmax(w_logits[:, :SHARED_VOCAB], dim=-1)
            g_lp = F.log_softmax(g_logits, dim=-1)

            if variance_match and alpha > 0:
                # Both divisors are clamped so a near-constant distribution
                # cannot cause a division by ~0.
                w_lp = w_lp / w_lp.std(unbiased=False).clamp_min(1e-6)
                g_lp = g_lp / g_lp.std(unbiased=False).clamp_min(1e-6)

            # w_lp is already limited to the shared slice, so a plain sum suffices.
            fused = w_lp + alpha * g_lp
            next_id = fused.argmax(-1, keepdim=True)

        dec_ids = torch.cat([dec_ids, next_id], dim=-1)

        if next_id.item() == EOS_ID:
            break

    return processor.batch_decode(dec_ids, skip_special_tokens=True)[0]

In [259]:
# Boolean lookup table over token ids: True where the id is a special token.
# It is sized with len(tokenizer) rather than tokenizer.vocab_size so that it
# can be indexed with any id the decoder emits.
# NOTE(review): vocab_size appears to exclude the added special tokens - confirm.
special_mask = torch.tensor(
    processor.tokenizer.get_special_tokens_mask(list(range(len(processor.tokenizer))),
                                        already_has_special_tokens=True),
    dtype=torch.bool, device=device
)                                 # 1 = special, incl. those < 50 257

@torch.no_grad()
def fuse_whisper_gpt_fast(wav, alpha=0.25, var_match=False):
    """Greedy shallow-fusion decode that grows the GPT-2 context incrementally.

    The first INIT_W_STEPS tokens are chosen by Whisper alone. Every later
    token is the argmax of ``log p_whisper + alpha * log p_gpt2`` over the
    first SHARED_VOCAB ids. Each emitted id below SHARED_VOCAB is appended to
    the GPT-2 context.

    Args:
        wav: One audio clip sampled at SR, passed straight to the processor.
        alpha: Weight of the GPT-2 log-probabilities.
        var_match: When true (and alpha is non-zero), divide each model's
            log-probabilities by their standard deviation before fusing.

    Returns:
        The decoded text with special tokens removed.
    """
    feats = processor(wav, sampling_rate=SR, return_tensors="pt").input_features.to(device)
    # Encode the audio once; it is identical for every decode step.
    enc_out = whisper.get_encoder()(feats)
    eos_id = processor.tokenizer.eos_token_id

    dec_ids = torch.tensor([[whisper.config.decoder_start_token_id]], device=device)
    gpt_ids = torch.empty_like(dec_ids, dtype=torch.long)[:, 0:0]

    for step in range(MAX_STEPS):
        # The full prefix is re-fed each step, so no KV cache is needed.
        w_logits = whisper(encoder_outputs=enc_out, decoder_input_ids=dec_ids,
                           use_cache=False).logits[:, -1]

        if step < INIT_W_STEPS:
            next_id = w_logits.argmax(-1, keepdim=True)
        else:
            if gpt_ids.numel() == 0:
                # Every id below SHARED_VOCAB has already been appended to
                # gpt_ids, so an empty context means dec_ids holds only ids
                # GPT-2 cannot embed. Seed with GPT-2's BOS token instead of
                # indexing a mask with those ids.
                gpt_ids = torch.tensor([[gpt_tok.bos_token_id]],
                                       dtype=torch.long, device=device)

            g_logits = gpt2(gpt_ids).logits[:, -1]

            w_lp = F.log_softmax(w_logits[:, :SHARED_VOCAB], dim=-1)
            g_lp = F.log_softmax(g_logits, dim=-1)

            if var_match and alpha:
                # Both divisors are clamped to avoid dividing by ~0.
                w_lp = w_lp / w_lp.std(unbiased=False).clamp_min(1e-6)
                g_lp = g_lp / g_lp.std(unbiased=False).clamp_min(1e-6)

            fused = w_lp + alpha * g_lp
            next_id = fused.argmax(-1, keepdim=True)

        dec_ids = torch.cat([dec_ids, next_id], dim=-1)

        # feed GPT only if the token is still < 50 257
        if next_id.item() < SHARED_VOCAB:
            gpt_ids = torch.cat([gpt_ids, next_id], dim=-1)

        if next_id.item() == eos_id:
            break

    return processor.batch_decode(dec_ids, skip_special_tokens=True)[0]

In [273]:
# Parallel lists, one entry per clip: plain Whisper output, fused output,
# and the reference transcript.
vanilla_txt, fusion_txt, gt_text = [], [], []

with torch.inference_mode():
    for wavs, refs, _uuids in loader:
        # VANILLA WHISPER
        feats = processor(
            wavs,
            sampling_rate=SR,
            return_tensors="pt",
            padding=True
        ).input_features.to(device) # (B, 80, T_max)

        # greedy decode
        vanilla_ids = whisper.generate(feats)
        vanilla_txt.extend(
            processor.batch_decode(vanilla_ids, skip_special_tokens=True)
        )

        # SHALLOW FUSION (one clip at a time). The LM weight comes from the
        # ALPHA config constant instead of a repeated literal.
        for wav in wavs:
            # fused = fuse_whisper_gpt1(wav, alpha=0.0, variance_match=False)
            fused = fuse_whisper_gpt_fast(wav, alpha=ALPHA)
            # fused = fuse_whisper_gpt(wav, alpha=ALPHA)
            fusion_txt.append(fused)

        gt_text.extend(refs)

In [274]:
# Print reference, plain Whisper and fused transcripts for each clip,
# followed by a blank separator line.
for clip_no, (truth, plain, blended) in enumerate(zip(gt_text, vanilla_txt, fusion_txt)):
    report = [
        f"--- Clip {clip_no} ---",
        "Source : " + truth.strip(),
        "Whisper: " + plain.strip(),
        "Fusion : " + blended.strip(),
        "",
    ]
    print("\n".join(report))

--- Clip 0 ---
Source : The echocardiogram shows an ejection fraction of thirty-five percent with global hypokinesis.
Whisper: The echocardiogram shows an ejection fraction of 35% with global hypo-kinesis.
Fusion : The echocardiogram shows an ejection fraction of 35% with global hypokinesis.

--- Clip 1 ---
Source : Post-operative pathology confirmed a stage two-A adenocarcinoma of the sigmoid colon.
Whisper: Post-operative pathology confirmed a stage 2A adenocarcinoma of the sigmoid colon.
Fusion : Post-operative pathology confirmed a stage 2A adenocarcinoma of the sigmoid colon.

--- Clip 2 ---
Source : Her hemoglobin A-one-C has stabilized at seven point one percent after switching to semaglutide.
Whisper: Her hemoglobin A1C has stabilized at 7.1% after switching to seem a glutide.
Fusion : Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide.

--- Clip 3 ---
Source : Magnetic resonance imaging revealed a three-centimeter demyelinating plaque in the periventricul

### BONEYARD

In [252]:
def detect_shared_prefix(wtok, gtok, hint=50257):
    """Verify that two tokenizers agree on the first `hint` token ids.

    Ids 0..hint-1 are checked in ascending order by comparing what each
    tokenizer's ``convert_ids_to_tokens`` returns for that id.

    Returns:
        `hint`, unchanged, when every checked id matches.

    Raises:
        ValueError: at the first id whose tokens differ.
    """
    first_mismatch = next(
        (
            tok_id
            for tok_id in range(hint)
            if wtok.convert_ids_to_tokens(tok_id) != gtok.convert_ids_to_tokens(tok_id)
        ),
        None,
    )
    if first_mismatch is not None:
        raise ValueError(
            f"Token {first_mismatch} differs between Whisper and GPT‑2 – "
            "prefix assumption broken. Use the full w2g map."
        )
    return hint

# Overwrite the SHARED_VOCAB constant with the checked value: the call raises
# ValueError if the two tokenizers disagree on any of the first 50257 ids.
SHARED_VOCAB = detect_shared_prefix(processor.tokenizer, gpt_tok)  # ➜ 50257
# Bare expression so the notebook displays the value.
SHARED_VOCAB

50257

In [None]:
# Boolean lookup table over token ids: True where the id is a special token.
# It is sized with len(tokenizer) rather than tokenizer.vocab_size so that it
# can be indexed with any id the decoder emits.
# NOTE(review): vocab_size appears to exclude the added special tokens - confirm.
special_mask = torch.tensor(
    processor.tokenizer.get_special_tokens_mask(list(range(len(processor.tokenizer))),
                                        already_has_special_tokens=True),
    dtype=torch.bool, device=device
)                                 # 1 = special, incl. those < 50 257

@torch.no_grad()
def fuse_whisper_gpt_fast(wav, alpha=0.25, var_match=False):
    """Greedy shallow-fusion decode that grows the GPT-2 context incrementally.

    The first INIT_W_STEPS tokens are chosen by Whisper alone. Every later
    token is the argmax of ``log p_whisper + alpha * log p_gpt2`` over the
    first SHARED_VOCAB ids. Each emitted id below SHARED_VOCAB is appended to
    the GPT-2 context.

    Args:
        wav: One audio clip sampled at SR, passed straight to the processor.
        alpha: Weight of the GPT-2 log-probabilities.
        var_match: When true (and alpha is non-zero), divide each model's
            log-probabilities by their standard deviation before fusing.

    Returns:
        The decoded text with special tokens removed.
    """
    feats = processor(wav, sampling_rate=SR,
                      return_tensors="pt").input_features.to(device)
    # Encode the audio once; it is identical for every decode step.
    enc_out = whisper.get_encoder()(feats)
    eos_id = processor.tokenizer.eos_token_id

    dec_ids = torch.tensor([[whisper.config.decoder_start_token_id]],
                           device=device)
    gpt_ids = torch.empty_like(dec_ids, dtype=torch.long)[:, 0:0]

    for step in range(MAX_STEPS):
        # The full prefix is re-fed each step, so no KV cache is needed.
        w_logits = whisper(encoder_outputs=enc_out, decoder_input_ids=dec_ids,
                           use_cache=False).logits[:, -1]

        if step < INIT_W_STEPS:
            next_id = w_logits.argmax(-1, keepdim=True)
        else:
            if gpt_ids.numel() == 0:
                # Every id below SHARED_VOCAB has already been appended to
                # gpt_ids, so an empty context means dec_ids holds only ids
                # GPT-2 cannot embed. Seed with GPT-2's BOS token instead of
                # indexing a mask with those ids.
                gpt_ids = torch.tensor([[gpt_tok.bos_token_id]],
                                       dtype=torch.long, device=device)

            g_logits = gpt2(gpt_ids).logits[:, -1]

            w_lp = F.log_softmax(w_logits[:, :SHARED_VOCAB], dim=-1)
            g_lp = F.log_softmax(g_logits, dim=-1)

            if var_match and alpha:
                # Both divisors are clamped to avoid dividing by ~0.
                w_lp = w_lp / w_lp.std(unbiased=False).clamp_min(1e-6)
                g_lp = g_lp / g_lp.std(unbiased=False).clamp_min(1e-6)

            fused = w_lp + alpha * g_lp
            next_id = fused.argmax(-1, keepdim=True)

        dec_ids = torch.cat([dec_ids, next_id], dim=-1)

        # feed GPT only if the token is still < 50 257
        if next_id.item() < SHARED_VOCAB:
            gpt_ids = torch.cat([gpt_ids, next_id], dim=-1)

        if next_id.item() == eos_id:
            break

    return processor.batch_decode(dec_ids, skip_special_tokens=True)[0]

tensor([False, False, False,  ..., False, False,  True], device='mps:0')


In [223]:
# Display every special-token id known to the Whisper tokenizer; wrapping the
# list in an array gives a compact multi-column printout.
np.array(processor.tokenizer.all_special_ids)

array([50256, 50257, 50258, 50259, 50260, 50261, 50262, 50263, 50264,
       50265, 50266, 50267, 50268, 50269, 50270, 50271, 50272, 50273,
       50274, 50275, 50276, 50277, 50278, 50279, 50280, 50281, 50282,
       50283, 50284, 50285, 50286, 50287, 50288, 50289, 50290, 50291,
       50292, 50293, 50294, 50295, 50296, 50297, 50298, 50299, 50300,
       50301, 50302, 50303, 50304, 50305, 50306, 50307, 50308, 50309,
       50310, 50311, 50312, 50313, 50314, 50315, 50316, 50317, 50318,
       50319, 50320, 50321, 50322, 50323, 50324, 50325, 50326, 50327,
       50328, 50329, 50330, 50331, 50332, 50333, 50334, 50335, 50336,
       50337, 50338, 50339, 50340, 50341, 50342, 50343, 50344, 50345,
       50346, 50347, 50348, 50349, 50350, 50351, 50352, 50353, 50354,
       50355, 50356, 50357, 50358, 50359, 50360, 50361, 50362])

In [None]:
# from transformers import LogitsProcessor
# import torch.nn.functional as F

# class GPT2FusionProcessor(LogitsProcessor):
#     def __init__(self, alpha, gpt2, shared_vocab, specials, variance_match):
#         self.alpha         = alpha
#         self.gpt2          = gpt2
#         self.shared_vocab  = shared_vocab
#         self.specials      = specials
#         self.variance      = variance_match

#     def __call__(self, input_ids, scores):
#         """
#         `scores` is the logit vector Whisper just produced for each batch item.
#         We add α·LM_log_probs on the shared slice.
#         """
#         if self.alpha == 0:
#             return scores  # vanilla path – identical to Whisper

#         # build GPT‑2 context (strip specials) for each item in batch
#         new_scores = scores.clone()
#         for b, ids in enumerate(input_ids):
#             ctx = [t for t in ids.tolist() if t not in self.specials] or ids.tolist()
#             g_logits = self.gpt2(torch.tensor([ctx], device=scores.device)
#                                  ).logits[:, -1, :]  # (1, Vg)

#             w_lp = F.log_softmax(scores[b], -1)
#             g_lp = F.log_softmax(g_logits[0], -1)

#             if self.variance:
#                 w_lp = w_lp / w_lp.std(unbiased=False)
#                 g_lp = g_lp / g_lp.std(unbiased=False).clamp_min(1e-6)

#             w_lp[:self.shared_vocab] += self.alpha * g_lp
#             new_scores[b] = w_lp

#         return new_scores

In [48]:
# import glob, numpy as np, soundfile as sf, scipy.signal as sps
# from IPython.display import Audio, display

# SR_TARGET = 16_000
# LOWPASS_HZ = 3_000
# WAV_PATH = glob.glob("../data/output/*.wav")[0]   # adjust if needed

# def butter_lowpass(wav, cutoff_hz, sr=SR_TARGET):
#     nyq = sr / 2
#     b, a = sps.butter(6, cutoff_hz / nyq, btype="low")
#     return sps.filtfilt(b, a, wav).astype(np.float32)

# wav, sr = sf.read(WAV_PATH, dtype="float32")
# wav = wav.mean(axis=1) if wav.ndim == 2 else wav
# if sr != SR_TARGET:
#     wav = sps.resample_poly(wav, SR_TARGET, sr)

# lp = butter_lowpass(wav, LOWPASS_HZ)
# sf.write("lowpass_only.wav", lp, SR_TARGET)

# print("Original :", WAV_PATH)
# print("Low-pass :", "lowpass_only.wav")

# display(Audio(wav, rate=SR_TARGET, autoplay=False))
# display(Audio(lp,  rate=SR_TARGET, autoplay=False))