In [13]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

# Project-local Hugging Face cache: <cwd>/../.models/hfcache, as an absolute path.
CACHE_DIR = Path.cwd().parent.joinpath(".models", "hfcache").resolve()
CACHE_DIR.mkdir(exist_ok=True, parents=True)

# Hub environment settings.
# NOTE(review): Hugging Face libraries typically read these when first
# imported, so keep this cell ahead of the transformers/datasets imports.
os.environ.update({
    "HF_HUB_DISABLE_XET": "1",
    "HF_HOME": str(CACHE_DIR),
})

In [22]:
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

# --- audio / batching ---
SR = 16_000          # sample rate (Hz) every clip is resampled to on load
BATCH_SIZE = 5       # clips per DataLoader batch; also the datasets.map chunk size

# --- models ---
WHISPER_ID = "openai/whisper-tiny.en"
GPT2_ID = "cwestnedge/gpt2-large-pubmed"

# --- shallow-fusion hyper-parameters ---
SHARED_VOCAB = 50_257  # leading token ids both models index identically; fusion sums log-probs over these
ALPHA = 0.35           # weight applied to GPT-2 log-probs in the fused score
INIT_W_STEPS = 2       # number of initial decode steps taken by Whisper alone
MAX_STEPS = 256        # hard cap on decode steps per clip

# --- data ---
MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"

# Pick the fastest available backend: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Device:", device)

Device: mps


In [17]:
def build_dataset(manifest_path: str, audio_dir=None) -> Dataset:
    """Build a Dataset from a JSONL manifest and attach the decoded audio.

    Each non-blank manifest line is a JSON object whose ``file`` field names an
    audio file inside ``audio_dir``. Every row gains an ``audio`` column holding
    the mono waveform, resampled to ``SR`` Hz and cast to float32.

    Args:
        manifest_path: path to the ``.jsonl`` manifest.
        audio_dir: directory the ``file`` entries are relative to. ``None``
            (the default) means ``AUDIO_DIR``, looked up at call time.

    Returns:
        A ``Dataset`` with the manifest columns plus ``audio``. Indexing it
        yields ``audio`` as an ``np.ndarray``; other columns stay Python objects.

    Raises:
        ValueError: if the manifest contains no rows.
    """
    audio_root = Path(AUDIO_DIR if audio_dir is None else audio_dir)

    with open(manifest_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a stray empty line) instead of crashing in json.loads.
        rows = [json.loads(line) for line in f if line.strip()]
    if not rows:
        raise ValueError(f"manifest {manifest_path!r} contains no rows")

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        # Load each referenced file as a mono float32 waveform at SR Hz.
        batch["audio"] = [
            librosa.load(audio_root / fname, sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    ds = ds.map(add_audio, batched=True, batch_size=BATCH_SIZE, num_proc=4)
    # Arrow stores the waveforms as lists; without a numpy format they come
    # back as Python float lists, which are slow and memory-hungry for audio.
    return ds.with_format("numpy", columns=["audio"], output_all_columns=True)

def collate(batch):
    """Turn a list of dataset rows into parallel lists for the eval loop.

    Returns ``(audio, refs, uuids)``:
      * audio -- each row's ``audio`` waveform (array or float list, depending
        on how the dataset is formatted),
      * refs  -- each row's ground-truth ``text``,
      * uuids -- each row's ``uuid``, kept for debugging / joining results.
    """
    fields = ("audio", "text", "uuid")
    audio, refs, uuids = ([row[key] for row in batch] for key in fields)
    return audio, refs, uuids

ds = build_dataset(MANIFEST)

# Sequential (unshuffled) batches of (audio, refs, uuids), loaded in the main process.
loader = DataLoader(
    ds,
    collate_fn=collate,
    batch_size=BATCH_SIZE,
    pin_memory=False,
    num_workers=0,
)

Map (num_proc=4): 100%|██████████| 30/30 [00:00<00:00, 35.26 examples/s]


In [18]:
# Whisper acoustic model plus its processor (feature extractor + tokenizer).
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(device).eval()

# Domain LM used for shallow fusion. The tokenizer comes from the base GPT-2
# repo; store it in the same project cache as every other download.
gpt_tok = AutoTokenizer.from_pretrained("openai-community/gpt2", cache_dir=CACHE_DIR)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(device).eval()

# Fusion combines log-probs over the first SHARED_VOCAB ids, so both models
# must cover that id range; fail here rather than mid-decode.
assert whisper.config.vocab_size >= SHARED_VOCAB, whisper.config.vocab_size
assert gpt2.config.vocab_size >= SHARED_VOCAB, gpt2.config.vocab_size

EOS_ID = processor.tokenizer.eos_token_id
# Whisper control tokens, stripped from the context handed to GPT-2.
WHISPER_SPECIALS = set(processor.tokenizer.all_special_ids)

In [23]:
import torch.nn.functional as F

@torch.no_grad()
def fuse_whisper_gpt(wav: np.ndarray) -> str:
    """Greedy-decode one waveform with Whisper + GPT-2 shallow fusion.

    The first ``INIT_W_STEPS`` tokens after the decoder start token are
    Whisper's greedy choice alone. After that, each token is
    ``argmax(log p_whisper + ALPHA * log p_gpt2)`` over the first
    ``SHARED_VOCAB`` ids, the range both models index identically.
    Decoding stops at ``EOS_ID`` or after ``MAX_STEPS`` steps.

    Args:
        wav: mono waveform sampled at ``SR`` Hz (array or list of floats).

    Returns:
        The decoded transcript with special tokens removed.
    """
    feats = processor(wav, sampling_rate=SR,
                      return_tensors="pt").input_features.to(device)

    # The audio is fixed for the whole decode: encode it once and reuse the
    # encoder output, instead of re-running the encoder for every token.
    enc_out = whisper.get_encoder()(feats)

    dec_ids = torch.tensor(
        [[whisper.config.decoder_start_token_id]], device=device
    )

    for step in range(MAX_STEPS):
        # The full prefix is re-fed each step and no KV cache is carried over,
        # so don't build one.
        w_logits = whisper(encoder_outputs=enc_out, decoder_input_ids=dec_ids,
                           use_cache=False).logits[:, -1]  # (1, Vw)

        if step < INIT_W_STEPS:
            # Whisper-only warm-up steps.
            next_id = w_logits.argmax(-1, keepdim=True)
        else:
            # GPT-2 context = text decoded so far. Drop Whisper special tokens
            # and any id >= SHARED_VOCAB (outside GPT-2's vocabulary, presumably
            # Whisper's start/timestamp tokens) so GPT-2 never sees an
            # out-of-range embedding index.
            ctx = [t for t in dec_ids[0].tolist()
                   if t < SHARED_VOCAB and t not in WHISPER_SPECIALS]
            if not ctx:
                # No text yet: prompt GPT-2 with its own BOS token.
                ctx = [gpt2.config.bos_token_id]
            gpt_ids = torch.tensor([ctx], dtype=torch.long, device=device)

            g_logits = gpt2(gpt_ids).logits[:, -1, :SHARED_VOCAB]  # (1, SHARED_VOCAB)

            w_lp = F.log_softmax(w_logits[:, :SHARED_VOCAB], dim=-1)
            g_lp = F.log_softmax(g_logits, dim=-1)
            fused = w_lp + ALPHA * g_lp  # (1, SHARED_VOCAB)
            next_id = fused.argmax(-1, keepdim=True)

        dec_ids = torch.cat([dec_ids, next_id], dim=-1)

        if next_id.item() == EOS_ID:
            break

    return processor.batch_decode(dec_ids, skip_special_tokens=True)[0]

In [24]:
# Transcribe every clip twice -- plain Whisper (batched) and Whisper + GPT-2
# shallow fusion (one clip at a time) -- keeping outputs aligned by position
# with the ground-truth references.
vanilla_txt, fusion_txt, gt_text = [], [], []

with torch.inference_mode():
    for wavs, refs, ids in loader:
        # VANILLA WHISPER
        # Use the processor's default padding -- the same call
        # fuse_whisper_gpt makes -- so both arms feed Whisper identically
        # prepared features. (padding=True would pad only to the longest clip
        # in the batch, giving the vanilla arm differently shaped inputs.)
        feats = processor(
            wavs,
            sampling_rate=SR,
            return_tensors="pt",
        ).input_features.to(device)  # (B, n_mels, n_frames)

        # decode with the model's default generation settings
        vanilla_ids = whisper.generate(feats)
        vanilla_txt.extend(
            processor.batch_decode(vanilla_ids, skip_special_tokens=True)
        )

        # SHALLOW FUSION
        for wav in wavs:
            fusion_txt.append(fuse_whisper_gpt(wav))

        gt_text.extend(refs)

In [25]:
# Side-by-side view per clip: reference text, plain Whisper, then fusion output.
for clip_no, (ref_txt, whisper_out, fused_out) in enumerate(zip(gt_text, vanilla_txt, fusion_txt)):
    print(f"--- Clip {clip_no} ---")
    print("Source :", ref_txt)
    print("Whisper:", whisper_out.strip())
    print("Fusion :", fused_out.strip(), "\n")

--- Clip 0 ---
Source : The echocardiogram shows an ejection fraction of thirty-five percent with global hypokinesis.
Whisper: The echocardiogram shows an ejection fraction of 35% with global hypo-kinesis.
Fusion : The echocardiogram shows an ejection fraction of 35% with global hypokinesis. 

--- Clip 1 ---
Source : Post-operative pathology confirmed a stage two-A adenocarcinoma of the sigmoid colon.
Whisper: Post-operative pathology confirmed a stage 2A adenocarcinoma of the sigmoid colon.
Fusion : Post-operative pathology confirmed a stage 2A adenocarcinoma of the sigmoid colon. 

--- Clip 2 ---
Source : Her hemoglobin A-one-C has stabilized at seven point one percent after switching to semaglutide.
Whisper: Her hemoglobin A1C has stabilized at 7.1% after switching to seem a glutide.
Fusion : Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide. 

--- Clip 3 ---
Source : Magnetic resonance imaging revealed a three-centimeter demyelinating plaque in the periventri