In [13]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

# Project-local Hugging Face cache: <parent of cwd>/.models/hfcache, created on demand.
CACHE_DIR = Path.cwd().parent.joinpath(".models", "hfcache").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Must be exported before any Hugging Face library is imported further down.
os.environ.update({
    "HF_HUB_DISABLE_XET": "1",
    "HF_HOME": str(CACHE_DIR),
})

In [22]:
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

# --- audio / batching ---
SR = 16_000                  # sampling rate (Hz) used when loading audio
BATCH_SIZE = 5

# --- checkpoints ---
WHISPER_ID = "openai/whisper-tiny.en"
GPT2_ID = "cwestnedge/gpt2-large-pubmed"

# --- shallow-fusion decoding ---
SHARED_VOCAB = 50_257        # number of leading Whisper logits fused with the GPT-2 logits
ALPHA = 0.35                 # weight applied to the GPT-2 log-probabilities
INIT_W_STEPS = 2             # first decoding steps taken from Whisper alone
MAX_STEPS = 256              # hard cap on decoding steps per clip

# --- data locations ---
MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Device:", device)

Device: mps


In [35]:
# def build_dataset(pattern: str) -> Dataset:
#     def process_audio(batch):
#         batch["audio"] = [
#             librosa.load(p, sr=SR, mono=True)[0].astype(np.float32)
#             for p in batch["path"]
#         ]
#         return batch

#     paths = glob.glob(pattern)
#     ds = Dataset.from_dict({"path": paths})
#     ds = ds.map(process_audio, batched=True, batch_size=BATCH_SIZE, num_proc=4).remove_columns("path")
#     return ds

import numpy as np

def add_gaussian_noise(wav: np.ndarray, snr_db: float = 30.0,
                       rng: "np.random.Generator | None" = None) -> np.ndarray:
    """Add zero-mean white Gaussian noise to ``wav`` at a target SNR.

    Parameters
    ----------
    wav : np.ndarray
        Waveform samples.
    snr_db : float
        Desired signal-to-noise ratio in decibels (signal power / noise power).
    rng : np.random.Generator, optional
        Source of randomness. Pass a seeded generator for reproducible noise;
        when omitted the global ``np.random`` state is used.

    Returns
    -------
    np.ndarray
        ``wav`` plus float32 noise. Empty or all-zero input is returned
        unchanged (same object). If the noisy signal's absolute peak exceeds
        1.0, the result is divided by that peak.
    """
    if wav.size == 0:
        return wav  # nothing to corrupt

    # Measure the power in float64: squaring in the input dtype wraps around
    # for integer PCM and would yield a wrong (possibly negative) power.
    sig_power = np.mean(wav.astype(np.float64) ** 2)
    if sig_power == 0:
        return wav  # silent clip: SNR is undefined, leave it alone

    # SNR_dB = 10 * log10(P_signal / P_noise)  =>  P_noise = P_signal / 10^(SNR_dB / 10)
    noise_power = sig_power / (10.0 ** (snr_db / 10.0))

    # Both the legacy module and a Generator expose ``normal`` with this signature.
    sampler = np.random if rng is None else rng
    noise = sampler.normal(0.0, np.sqrt(noise_power), wav.shape).astype(np.float32)
    noisy = wav + noise

    # Rescale only when the sum left the [-1, 1] range, to avoid clipping.
    peak = np.max(np.abs(noisy))
    if peak > 1.0:
        noisy = noisy / peak

    return noisy

def build_dataset(manifest_path: str) -> Dataset:
    """Build a ``Dataset`` from a JSON-lines manifest and attach decoded audio.

    Each non-blank line of ``manifest_path`` is parsed as one JSON object and
    becomes one row. Every row must carry a ``"file"`` entry naming an audio
    file inside ``AUDIO_DIR``; the file is loaded with librosa at ``SR`` Hz,
    mixed down to mono, cast to float32 and stored in a new ``"audio"`` column.

    Parameters
    ----------
    manifest_path : str
        Path to the UTF-8 encoded ``.jsonl`` manifest.

    Returns
    -------
    Dataset
        The manifest rows with the additional ``"audio"`` column.
    """
    with open(manifest_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. an extra empty line at the end of the file);
        # json.loads would raise on them.
        rows = [json.loads(line) for line in f if line.strip()]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        # librosa.load returns (samples, sample_rate); only the samples are kept.
        batch["audio"] = [
            librosa.load(f"{AUDIO_DIR}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=BATCH_SIZE, num_proc=4)

def collate(batch):
    """Split a list of example dicts into parallel lists.

    Returns ``(audio, refs, uuids)`` where each list holds, in batch order,
    the ``"audio"``, ``"text"`` and ``"uuid"`` entry of every example.
    """
    audio, refs, uuids = [], [], []
    for example in batch:
        audio.append(example["audio"])   # waveform
        refs.append(example["text"])     # ground-truth transcript
        uuids.append(example["uuid"])    # identifier, for debugging / joins
    return audio, refs, uuids

# Load every manifest row together with its decoded audio.
ds = build_dataset(MANIFEST)

# Batches are assembled in the main process (no workers, no pinned memory);
# `collate` keeps the waveforms as plain per-example lists.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate,
    num_workers=0,
    pin_memory=False,
)
loader = DataLoader(ds, **loader_kwargs)

Map (num_proc=4): 100%|██████████| 30/30 [00:00<00:00, 30.76 examples/s]


In [36]:
# Whisper processor and model, moved to `device` and put in eval mode.
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(device).eval()

# GPT-2 tokenizer (base checkpoint) and the language model used for fusion.
# The tokenizer now uses the same explicit cache directory as every other download.
gpt_tok = AutoTokenizer.from_pretrained("openai-community/gpt2", cache_dir=CACHE_DIR)
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(device).eval()

# Token ids consulted by the decoding loop: the stop token and the set of
# Whisper special tokens that are stripped before text is handed to GPT-2.
EOS_ID = processor.tokenizer.eos_token_id
WHISPER_SPECIALS = set(processor.tokenizer.all_special_ids)

In [37]:
import torch.nn.functional as F

@torch.no_grad()
def fuse_whisper_gpt(wav: np.ndarray) -> str:
    """Greedily decode a single waveform with Whisper + GPT-2 shallow fusion.

    The first ``INIT_W_STEPS`` tokens are chosen from Whisper's logits alone.
    Every later token maximises
    ``log_softmax(whisper_logits[:SHARED_VOCAB]) + ALPHA * log_softmax(gpt2_logits)``,
    where GPT-2 is conditioned on the tokens decoded so far with Whisper
    special tokens removed. Decoding stops at ``EOS_ID`` or after
    ``MAX_STEPS`` steps.

    Relies on the module-level ``processor``, ``whisper``, ``gpt2``, ``device``
    and the decoding constants.

    Parameters
    ----------
    wav : np.ndarray
        Waveform sampled at ``SR`` Hz.

    Returns
    -------
    str
        The decoded text with special tokens skipped.
    """
    feats = processor(wav, sampling_rate=SR,
                      return_tensors="pt").input_features.to(device)

    # The encoder input never changes during decoding, so encode the audio
    # once and reuse the result instead of re-running the encoder every step.
    enc_out = whisper.get_encoder()(feats)

    # Token that primes GPT-2 while no text token has been produced yet.
    # Feeding it raw Whisper ids instead would index past GPT-2's embedding
    # table, because Whisper's special tokens have ids >= SHARED_VOCAB.
    lm_bos = gpt2.config.bos_token_id
    if lm_bos is None:
        lm_bos = EOS_ID

    dec_ids = torch.tensor(
        [[whisper.config.decoder_start_token_id]], device=device
    )

    for step in range(MAX_STEPS):
        # use_cache=False: the full prefix is re-fed each step, so a KV cache
        # would be built only to be thrown away.
        w_logits = whisper(encoder_outputs=enc_out, decoder_input_ids=dec_ids,
                           use_cache=False).logits[:, -1]  # (1, Vw)

        if step < INIT_W_STEPS:
            # Let Whisper alone emit its leading (start / task) tokens.
            next_id = w_logits.argmax(-1, keepdim=True)
        else:
            # GPT-2 context: decoded tokens minus Whisper specials, restricted
            # to ids that exist in the shared vocabulary.
            text_ids = [
                t for t in dec_ids[0].tolist()
                if t not in WHISPER_SPECIALS and t < SHARED_VOCAB
            ]
            # dtype is explicit: torch.tensor([[]]) would otherwise be float.
            gpt_ids = torch.tensor([text_ids or [lm_bos]],
                                   dtype=torch.long, device=device)
            g_logits = gpt2(gpt_ids).logits[:, -1]  # (1, Vg)

            # Fuse over the shared vocabulary only.
            w_lp = F.log_softmax(w_logits[:, :SHARED_VOCAB], dim=-1)
            g_lp = F.log_softmax(g_logits, dim=-1)
            fused = w_lp + ALPHA * g_lp  # (1, Vg)
            next_id = fused.argmax(-1, keepdim=True)

        dec_ids = torch.cat([dec_ids, next_id], dim=-1)

        if next_id.item() == EOS_ID:
            break

    return processor.batch_decode(dec_ids, skip_special_tokens=True)[0]

In [38]:
# Transcripts collected over the whole loader; the three lists stay index-aligned.
vanilla_txt, fusion_txt, gt_text = [], [], []

with torch.inference_mode():
    for wavs, refs, ids in loader:
        # --- baseline: Whisper on the whole batch ---
        batch_inputs = processor(
            wavs,
            sampling_rate=SR,
            return_tensors="pt",
            padding=True,
        )
        feats = batch_inputs.input_features.to(device)

        # Whisper's own generation, with its default settings.
        vanilla_ids = whisper.generate(feats)
        vanilla_txt += processor.batch_decode(vanilla_ids, skip_special_tokens=True)

        # --- shallow fusion: decoded one clip at a time ---
        for wav in wavs:
            fusion_txt.append(fuse_whisper_gpt(wav))

        # --- references ---
        gt_text += refs

In [39]:
# Side-by-side report: baseline, fused and reference transcript for each clip.
for i, (v, f, g) in enumerate(zip(vanilla_txt, fusion_txt, gt_text)):
    print(
        f"--- Clip {i} ---",
        f"Whisper: {v.strip()}",
        f"Fusion : {f.strip()}",
        "Source : " + str(g),
        "",  # blank separator line between clips
        sep="\n",
    )

--- Clip 0 ---
Whisper: The echocardiogram shows an ejection fraction of 35% with global hypo-kinesis.
Fusion : The echocardiogram shows an ejection fraction of 35% with global hypokinesis.
Source : The echocardiogram shows an ejection fraction of thirty-five percent with global hypokinesis.

--- Clip 1 ---
Whisper: Post-operative pathology confirmed a stage 2A adenocarcinoma of the sigmoid colon.
Fusion : Post-operative pathology confirmed a stage 2A adenocarcinoma of the sigmoid colon.
Source : Post-operative pathology confirmed a stage two-A adenocarcinoma of the sigmoid colon.

--- Clip 2 ---
Whisper: Her hemoglobin A1C has stabilized at 7.1% after switching to seem a glutide.
Fusion : Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide.
Source : Her hemoglobin A-one-C has stabilized at seven point one percent after switching to semaglutide.

--- Clip 3 ---
Whisper: Magnetic resonance imaging revealed a 3 cm demyelinating plaque in the periventricular white mat

In [None]:
import glob, numpy as np, soundfile as sf, scipy.signal as sps
from IPython.display import Audio, display

SR_TARGET = 16_000   # sampling rate (Hz) the clip is brought to
LOWPASS_HZ = 3_000   # low-pass cutoff (Hz)

# glob returns matches in arbitrary order: sort so the same clip is picked on
# every run, and fail with a clear message when nothing matches.
_wav_candidates = sorted(glob.glob("../data/output/*.wav"))
if not _wav_candidates:
    raise FileNotFoundError("no .wav files found in ../data/output")
WAV_PATH = _wav_candidates[0]   # adjust if needed

def butter_lowpass(wav, cutoff_hz, sr=SR_TARGET, order=6):
    """Apply a zero-phase Butterworth low-pass filter to ``wav``.

    Parameters
    ----------
    wav : array-like
        Samples to filter (filtered along the last axis).
    cutoff_hz : float
        Cutoff frequency in Hz; must lie strictly between 0 and ``sr / 2``.
    sr : int
        Sampling rate of ``wav`` in Hz.
    order : int
        Order of the Butterworth design (default 6, the previous fixed value).

    Returns
    -------
    np.ndarray
        The filtered signal as float32.
    """
    # Second-order sections are numerically more robust than the (b, a)
    # transfer-function form, which matters once `order` is raised.
    sos = sps.butter(order, cutoff_hz, btype="low", fs=sr, output="sos")
    # Forward-backward filtering cancels the phase delay of a single pass.
    return sps.sosfiltfilt(sos, wav).astype(np.float32)

# Read the clip as float32 and average the channels if it is not mono.
wav, sr = sf.read(WAV_PATH, dtype="float32")
if wav.ndim == 2:
    wav = wav.mean(axis=1)

# Bring the clip to the target rate when the file uses a different one.
if sr != SR_TARGET:
    wav = sps.resample_poly(wav, SR_TARGET, sr)

# Filter and save the result next to the notebook.
lp = butter_lowpass(wav, LOWPASS_HZ)
lowpass_path = "lowpass_only.wav"
sf.write(lowpass_path, lp, SR_TARGET)

print("Original :", WAV_PATH)
print("Low-pass :", lowpass_path)

# Players for an A/B comparison: original first, filtered second.
for clip in (wav, lp):
    display(Audio(clip, rate=SR_TARGET, autoplay=False))

Original : ../data/output/9a5d0da4-0949-4f67-b803-1f5b36918b5c.wav
Low-pass : lowpass_only.wav
