In [36]:
import os, glob, librosa, numpy as np, torch, json
from pathlib import Path

CACHE_DIR = (Path.cwd().parent / ".models" / "hfcache").resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HOME"] = str(CACHE_DIR)
print(str(CACHE_DIR).split('/')[-3:])

['SF_EVAL', '.models', 'hfcache']


In [77]:
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    GPT2LMHeadModel, AutoTokenizer
)

SR = 16_000
BATCH_SIZE = 5
WHISPER_ID = "openai/whisper-tiny.en"
GPT2_ID = "cwestnedge/gpt2-medium-pubmed"
SHARED_VOCAB = 50257
ALPHA = 0.3
INIT_W_STEPS = 2
MAX_STEPS = 256

MANIFEST = "../data/output/manifest.jsonl"
AUDIO_DIR = "../data/output"  

DEVICE = (
    "cuda" if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
print("Device:", DEVICE)

Device: mps


In [78]:
processor = WhisperProcessor.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR)
whisper = WhisperForConditionalGeneration.from_pretrained(WHISPER_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

gpt2_tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
gpt2 = GPT2LMHeadModel.from_pretrained(GPT2_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()

EOS_ID = processor.tokenizer.eos_token_id

In [79]:
def build_dataset(manifest_path: str, batch_size: int, num_proc: int = 4) -> Dataset:
    with open(manifest_path, encoding="utf-8") as f:
        rows = [json.loads(line) for line in f]

    ds = Dataset.from_list(rows)

    def add_audio(batch):
        batch["audio"] = [
            librosa.load(f"{AUDIO_DIR}/{fname}", sr=SR, mono=True)[0].astype(np.float32)
            for fname in batch["file"]
        ]
        return batch

    return ds.map(add_audio, batched=True, batch_size=batch_size, num_proc=num_proc)

def encode_audio(batch):
    # batch["audio"] is List[np.ndarray], each at its natural length
    feats = processor.feature_extractor(
        batch["audio"], # for whatever reason processor doesnt support PT tensors so numpy array or list for now.
        sampling_rate=SR,
        padding="max_length",
        truncation=True, 
        max_length=processor.feature_extractor.n_samples,  # n_samples = chunk_length * sampling_rate
        return_attention_mask=True,
        return_tensors="pt" 
    )

    #  input_features : Tensor (B, T_max, 80)
    #  attention_mask : Tensor (B, T_max)
    batch["input_features"] = feats.input_features
    batch["attention_mask"] = feats.attention_mask
    return batch

ds = build_dataset(MANIFEST, batch_size=BATCH_SIZE)
# choosing NOT to overwrite ds with removed fields so we can eval on text field later,
# could also create a collator and pass fields we care about through, but seems like 
# too much extra code tbh, indices will still match if we dont shuffle
ds_processed = ds.map(
    encode_audio, 
    batch_size=BATCH_SIZE, 
    batched=True,
    remove_columns=['uuid', 'file', 'category', 'index', 'text', 'audio']
    )

ds_processed.set_format(type="torch", columns=["input_features","attention_mask"])
loader = DataLoader(ds_processed, batch_size=BATCH_SIZE, shuffle=False)

Map (num_proc=4): 100%|██████████| 30/30 [00:01<00:00, 15.32 examples/s]
Map: 100%|██████████| 30/30 [00:00<00:00, 44.37 examples/s]


In [84]:
from transformers import LogitsProcessor, LogitsProcessorList
import torch.nn.functional as F

# Building a fusion class the goal of which is to easily hook into huggingface generate method
# hopefully this will allow us to lean more into native HF functionality

class ShallowFusion(LogitsProcessor):
    def __init__(self, lm, lm_tok, shared_vocab, special_mask, alpha=0.3):
        super().__init__()
        self.lm = lm.eval().requires_grad_(False)
        self.tok = lm_tok
        self.V = shared_vocab
        self.special_mask = special_mask
        self.alpha = alpha
        self.step = 0

    @torch.no_grad()
    def __call__(self, input_ids, scores):
        self.step += 1
        if self.step <= 2:
            return scores

        B = input_ids.size(0) # batch_size * num_beams NEW
        device = scores.device # this started as lazy but now i think its smart
        mask = self.special_mask.to(device)

        # ---- NEW ----
        elm_ctx= []
        for seq in input_ids: # iterate over our beams
            # print(processor.batch_decode(seq), self.step)
            keep = ~mask[seq]
            filtered = seq[keep]
            ctx = filtered if filtered.numel() else seq[-1:]
            elm_ctx.append(ctx)

        elm_ids = torch.nn.utils.rnn.pad_sequence(
            elm_ctx,batch_first=True,
            padding_value=self.tok.eos_token_id
        ).to(device)

        # if still no lexical token, skip fusion
        if (elm_ids < self.V).sum() == 0:
            return scores

        w_lp_full = torch.log_softmax(scores, dim=-1)
        w_lp_shared = w_lp_full[:, :self.V]

        elm_logits = self.lm(elm_ids).logits[:, -1, :]
        elm_lp = torch.log_softmax(elm_logits, dim=-1)

        fused_shared = w_lp_shared + self.alpha * elm_lp
        fused_shared = torch.where(
            mask[: self.V].unsqueeze(0), 
            w_lp_shared, 
            fused_shared
        )
        
        fused_ext = torch.cat([fused_shared, w_lp_full[:, self.V:]], dim=-1)
        return fused_ext # still log‑probs; generate() is fine with that?
    
special_ids = set(processor.tokenizer.all_special_ids)
special_mask = torch.tensor(
    [i in special_ids for i in range(processor.tokenizer.vocab_size)],
    dtype=torch.bool,
    device=DEVICE,
)

fusion_proc = ShallowFusion(
    lm=gpt2, lm_tok=gpt2_tok,
    shared_vocab=gpt2.config.vocab_size,
    special_mask=special_mask,
    alpha=0.35
)

In [None]:
from transformers import GenerationConfig
from tqdm import tqdm 

fused = []
refs = []

gen_cfg = GenerationConfig(
    num_beams=3,
    do_sample=True,  # Deterministic beam search
    max_length=448,
    repetition_penalty=1.2,  # 1.0 = no penalty, >1.0 = discourage repetition
    no_repeat_ngram_size=3,  # Prevent 3-gram repetitions
    length_penalty=1.0,      # 1.0 = neutral, <1.0 = shorter, >1.0 = longer
    # early_stopping=True,   
)

for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    with torch.no_grad():
        fused_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            logits_processor=LogitsProcessorList([fusion_proc]),
            generation_config=gen_cfg,
        )
    decoded = processor.batch_decode(fused_ids, skip_special_tokens=True)
    fused.extend(decoded)

Decoding:   0%|          | 0/6 [00:00<?, ?it/s]

In [86]:
vanilla = []
for idx, batch in enumerate(tqdm(loader, total=len(loader), desc="Decoding")):
    feats = batch['input_features'].to(DEVICE)
    masks = batch['attention_mask'].to(DEVICE)
    with torch.no_grad():
        vanilla_ids = whisper.generate(
            input_features=feats,
            attention_mask=masks,
            generation_config=gen_cfg
        )
    decoded = processor.batch_decode(vanilla_ids, skip_special_tokens=True)
    vanilla.extend(decoded)

Decoding: 100%|██████████| 6/6 [00:12<00:00,  2.05s/it]


In [87]:
import pandas as pd 
print(pd.DataFrame({'vanilla':vanilla, 'fused':fused, 'gt':ds['text']}).to_markdown())

|    | vanilla                                                                                                                                                             | fused                                                                                                                                                                    | gt                                                                                                                                                                                       |
|---:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------

|    | vanilla                                                                                                                                                             | fused                                                                                                                                                                    | gt                                                                                                                                                                                       |
|---:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|  0 | The echocardiogram shows an ejection fraction of 35% with global hypochinesis.                                                                                      | The echocardiogram shows an ejection fraction of 35% with global hypokinesis.                                                                                            | The echocardiogram shows an ejection fraction of thirty-five percent with global hypokinesis.                                                                                            |
|  1 | Post-operative pathology confirmed a stage 2A adenocarcinoma of the sigmoid colon.                                                                                  | Post-operative pathology confirmed a stage 2A adenocarcinoma of the sigmoid colon.                                                                                       | Post-operative pathology confirmed a stage two-A adenocarcinoma of the sigmoid colon.                                                                                                    |
|  2 | Her hemoglobin A1C has stabilized at 7.1% after switching to seem a glutide                                                                                         | Her hemoglobin A1c has stabilized at 7.1% after switching to semaglutide                                                                                                 | Her hemoglobin A-one-C has stabilized at seven point one percent after switching to semaglutide.                                                                                         |
|  3 | Magnetic resonance imaging revealed a 3 cm demyelinating plaque in the periventricular white matter.                                                                | Magnetic resonance imaging revealed a 3 cm demyelinating plaque in the periventricular white matter.                                                                     | Magnetic resonance imaging revealed a three-centimeter demyelinating plaque in the periventricular white matter.                                                                         |
|  4 | We started Seftrac Zone 2 grams given intravenously every 24 hours for suspected bacterial meningitis.                                                              | We started Seftrac's own two grams, given intravenously every 24 hours for suspected bacterial meningitis.                                                               | We started ceftriaxone, two grams given intravenously every twenty-four hours, for suspected bacterial meningitis.                                                                       |
|  5 | His BNP is 1,240 picograms per milliliter consistent with decompensated heart failure.                                                                              | His BNP is 1,240 picograms per milliliter. Consistent with decompensated heart failure                                                                                   | His B-N-P is one thousand two hundred forty picograms per milliliter, consistent with decompensated heart failure.                                                                       |
|  6 | She is allergic to fluoroquineolones and develops Steven's Johnson syndrome after taking Cybrofloxysin.                                                             | She is allergic to fluoroquineolones and develops Stevens-Johnson syndrome after taking Cyprofloxacin.                                                                   | She is allergic to fluoroquinolones and developed Stevens-Johnson syndrome after taking ciprofloxacin.                                                                                   |
|  7 | give 0.4 milligrams of sublingual nitroglycerin as needed if chest pain is not relieved by rest                                                                     | give 0.4 mg of sublingual nitroglycerin as needed, if chest pain is not relieved by rest                                                                                 | Give zero point four milligrams of sublingual nitroglycerin as needed if chest pain is not relieved by rest.                                                                             |
|  8 | We documented a mollum potty class three airway before intubation.                                                                                                  | We documented a mollum potty class three-airway before intubation.                                                                                                       | We documented a Mallampati class three airway before intubation.                                                                                                                         |
|  9 | Colonoscopy identified a sesile serrated lesion in the transverse colon removed with a cold snare.                                                                  | Colonoscopy identified a sessile serrated lesion in the transverse colon, removed with a cold snare.                                                                     | Colonoscopy identified a sessile serrated lesion in the transverse colon, removed with a cold snare.                                                                                     |
| 10 | The infants' app-gar scores were 8.9 at 1 and 5 minutes, respectively                                                                                               | The Infants' App-Gar scores were 8.9 at 1 and 5 minutes, respectively                                                                                                    | The infant’s Apgar scores were eight and nine at one and five minutes, respectively.                                                                                                     |
| 11 | Start methyl-prednisolone, one gram intravenously each day for acute optic neuritis.                                                                                | Start methyl-prednisolone, one gram intravenously each day for acute optic neuritis.                                                                                     | Start methylprednisolone, one gram intravenously each day, for acute optic neuritis.                                                                                                     |
| 12 | Artarial Blood Gas shows a PIO 2 of 55 millimeters of mercury on room air indicating moderate hypoxemia.                                                            | Artarial blood gas shows a PIO2 of 55 mm of mercury on room air, indicating moderate hypoxemia.                                                                          | Arterial blood gas shows a P-A-O-two of fifty-five millimeters of mercury on room air, indicating moderate hypoxemia.                                                                    |
| 13 | Current procedural terminology code 93306 for trans-thoracic echocardiography was partially denied because modifiers were missing.                                  | Current procedural terminology code 93306. For trans-thoracic echocardiography, was partially denied because modifiers were missing                                      | Current Procedural Terminology code nine three three zero six for transthoracic echocardiography was partially denied because modifiers were missing.                                    |
| 14 | The members deductible was met, so the $320 coin insurance on HCPCS code J1740 should be waived.                                                                    | The members' deductible was met, so the $320 coinsurance on HCPCS code J1740 should be waived.                                                                           | The member’s deductible was met, so the three-hundred-twenty-dollar coinsurance on H-C-P-C-S code J seventeen forty should be waived.                                                    |
| 15 | The patient presents with intermittent angina and a positive troponin 1 of 0.36 nanograms per milliliter                                                            | The patient presents with intermittent angina and a positive troponin 1 of 0.36 nanograms per milliliter                                                                 | The patient presents with intermittent angina and a positive troponin I of zero point three six nanograms per milliliter.                                                                |
| 16 | International classification of diseases 10th revision code J45.909 was flagged as unspecified asthma, a documentation update was requested                         | International Classification of Diseases, 10th revision code J45.909 was flagged as unspecified asthma. A documentation update was requested                             | International Classification of Diseases, tenth revision, code J forty-five point nine zero nine was flagged as unspecified asthma; a documentation update was requested.                |
| 17 | Doppler ultrasound detected a non-compressible femoral vein consistent with deep venous thrombosis.                                                                 | Doppler ultrasound detected a non-compressible femoral vein, consistent with deep venous thrombosis.                                                                     | Doppler ultrasound detected a non-compressible femoral vein, consistent with deep venous thrombosis.                                                                                     |
| 18 | The pharmacy rejected the GOP-1 authorization because the prior authorization expired on June 30, 2025.                                                             | The pharmacy rejected the GOP-1 authorization because the prior authorization expired on June 30, 2025.                                                                  | The pharmacy rejected the G-L-P-one authorization because the prior authorization expired on June thirtieth, twenty twenty-five.                                                         |
| 19 | Diagnosis-related group 330 payment was reduced due to a coding discrepancy with a secondary diagnosis of hypocalemia.                                              | Diagnosis-related group 330 payment was reduced due to a coding discrepancy with the secondary diagnosis of hypokalemia.                                                 | Diagnosis-related group three-thirty payment was reduced due to a coding discrepancy with a secondary diagnosis of hypokalemia.                                                          |
| 20 | the EOB shows a coordination of benefits adjustment after the Medicare crossover.                                                                                   | the EOB shows a coordination of benefits adjustment after the Medicare crossover.                                                                                        | The E-O-B shows a coordination-of-benefits adjustment after the Medicare crossover.                                                                                                      |
| 21 | We need operative notes to support current procedural terminology code 29881 for arthroscopic medial menisectomy.                                                   | We need operative notes to support current procedural terminology code.                                                                                                  | We need operative notes to support Current Procedural Terminology code two nine eight eight one for arthroscopic medial meniscectomy.                                                    |
| 22 | Modifier 25 was omitted on the Evaluation and Management Service, causing bundling with the injection.                                                              | Modifier 25 was omitted on the evaluation and management service, causing bundling with the injection.                                                                   | Modifier twenty-five was omitted on the evaluation and management service, causing bundling with the injection.                                                                          |
| 23 | Denial Code CO197 cites missing pre-certification for the Lumbar Laminesctomy.                                                                                      | Denial Code C0197 cites missing pre-certification for the Lumbar Laminectomy.                                                                                            | Denial code C-O one ninety-seven cites missing pre-certification for the lumbar laminectomy.                                                                                             |
| 24 | The Appeals Team requests a radiology report to validate current procedural terminology code 74177 for computed tomography of the abdomen and pelvis with contrast. | The Appeals Team requests a radiology report to validate current procedural terminology, code 7-4-1-7-7 for computed tomography of the abdomen and pelvis with contrast. | The appeals team requests a radiology report to validate Current Procedural Terminology code seven four one seven seven for computed tomography of the abdomen and pelvis with contrast. |
| 25 | and out of network penalty applied because NPI1215983746 lacks a single case agreement.                                                                             | and out of network penalty, applied because NPI1-215-983746 lacks a single case agreement.                                                                               | An out-of-network penalty applied because N-P-I one two one five nine eight three seven four six lacks a single-case agreement.                                                          |
| 26 | The ambulance claimed build A0428, but mileage code A0425 was missing triggering partial payment.                                                                   | The ambulance claimed "Build A-0428" but mileage code A-0525 was missing, triggering partial payment.                                                                    | The ambulance claim billed A zero four two eight, but mileage code A zero four two five was missing, triggering partial payment.                                                         |
| 27 | The prosthetic hip device falls under HCPCS code L8699, which is not covered without a K modifier.                                                                  | The prosthetic hip device falls under HCPCS code L8-699, which is not covered without a K-modifier.                                                                      | The prosthetic hip device falls under H-C-P-C-S code L eight six nine nine, which is not covered without a K modifier.                                                                   |
| 28 | The high cost threshold was exceeded. Specialty medication coded J3490 requires national drug code submission                                                       | The high-cost threshold was exceeded. Specialty medication coded J3490 requires national drug code submission                                                            | The high-cost threshold was exceeded; specialty medication coded J three four nine zero requires National Drug Code submission.                                                          |
| 29 | Claim edits routed, current procedural terminology 96372 to incidental when billed with 99214 on the same date of service.                                          | Claimed edits routed, current procedural terminology 96372 to incidental when billed with 99214 on the same date of service.                                             | Claim edits routed Current Procedural Terminology nine six three seven two to incidental when billed with nine nine two one four on the same date of service.                            |

# BONEYARD


In [None]:
# class ShallowFusionV1(LogitsProcessor):
#     def __init__(self, lm, lm_tok, shared_vocab, special_mask, alpha=0.3):
#         super().__init__()
#         self.lm = lm.eval().requires_grad_(False)
#         self.tok = lm_tok
#         self.V = shared_vocab
#         self.special_mask = special_mask
#         self.alpha = alpha
#         self.step = 0

#     @torch.no_grad()
#     def __call__(self, input_ids, scores):
#         self.step += 1
#         if self.step <= 2:
#             return scores

#         B = input_ids.size(0) # batch_size * num_beams
#         device = scores.device
#         mask = self.special_mask.to(device)

#         keep = ~mask[input_ids[0]]
#         filtered = input_ids[0, keep]
#         gpt_ids = filtered.unsqueeze(0) if filtered.numel() else input_ids[:, -1:].clone()
        
#         # if still no lexical token, skip fusion
#         if (gpt_ids < self.V).sum() == 0:
#             return scores
#         gpt_ids = gpt_ids.to(device)

#         w_lp_full = torch.log_softmax(scores, dim=-1)
#         w_lp_shared = w_lp_full[:, : self.V]

#         g_logits = self.lm(gpt_ids).logits[:, -1, :]
#         g_lp = torch.log_softmax(g_logits, dim=-1)

#         fused_shared = w_lp_shared + self.alpha * g_lp
#         fused_shared = torch.where(
#             mask[: self.V].unsqueeze(0), 
#             w_lp_shared, 
#             fused_shared
#         )
        
#         fused = torch.cat([fused_shared, w_lp_full[:, self.V:]], dim=-1)
#         return fused # still log‑probs; generate() is fine with that?

In [None]:
# from transformers import LogitsProcessor, LogitsProcessorList
# import torch.nn.functional as F
# from tqdm import tqdm

# class HelloWorldProcessor(LogitsProcessor):
#     """
#     A toy processor that adds +5.0 to the logit for token ID 7 at every step.
#     """
#     def __init__(self, elm, elm_tokenizer, shared_vocab, special_mask, alpha):
#         self.elm = elm
#         self.elm_tokenizer = elm_tokenizer
#         self.shared_vocab = shared_vocab
#         self.special_mask = special_mask
#         self.alpha = alpha
#         self.counter = 0 

#     def __call__(self, input_ids, scores):

#         self.counter +=1
#         if self.counter > 2:

#             keep = ~self.special_mask[input_ids[0]]
#             filtered = input_ids[0, keep]
            
#             if filtered.numel() == 0: 
#                 elm_ids = input_ids[:, -1:].clone()

#             else:
#                 elm_ids = filtered.unsqueeze(0)
#             with torch.no_grad():
#                 w_lp_full = F.log_softmax(scores, dim=-1)
#                 w_lp_shared = w_lp_full[:, :SHARED_VOCAB]

#                 elm_logits = self.elm(elm_ids).logits[:, -1, :]
#                 elm_lp = F.log_softmax(elm_logits, dim=-1)

#                 fused_shared = w_lp_shared + self.alpha * elm_lp
#                 fused_logits = torch.cat([
#                     fused_shared,
#                     w_lp_full[:, SHARED_VOCAB:],
#                 ], dim=-1)
#         else:
#             fused_logits = scores.clone()
#         return fused_logits
    
# special_ids = set(processor.tokenizer.all_special_ids)
# special_mask = torch.tensor(
#     [i in special_ids for i in range(processor.tokenizer.vocab_size)],
#     dtype=torch.bool,
#     device=DEVICE,
# )

# b = next(iter(loader))
# feats = b['input_features'].to(DEVICE)
# mask = b['attention_mask'].to(DEVICE)

# hello_proc = HelloWorldProcessor(
#     elm=gpt2, 
#     elm_tokenizer=gpt2_tok, 
#     shared_vocab=SHARED_VOCAB, 
#     special_mask=special_mask,
#     alpha=0.3
#     )

# lp_list = LogitsProcessorList([hello_proc])

# all_outputs = []
# all_references = []
# for batch_idx, batch in enumerate(tqdm(loader, desc="Processing batches")):
#     feats = batch['input_features'].to(DEVICE)
#     masks = batch['attention_mask'].to(DEVICE)
#     with torch.no_grad():
#         out_ids = whisper.generate(
#             input_features=feats,
#             attention_mask=mask,
#             logits_processor=lp_list,
#             max_new_tokens=50,
#             num_beams=3,
#             no_repeat_ngram_size=3,
#         )
#     decoded = processor.batch_decode(out_ids, skip_special_tokens=True)
#     all_outputs.extend(decoded)

In [None]:
## if i wanted to keep uuid in batch id have to use collator and do the following... 

# def collate_fn(batch):
#     uuids = [item["uuid"] for item in batch]
#     feats = torch.stack([
#         torch.as_tensor(item["input_features"], dtype=torch.float32)
#         for item in batch
#     ], dim=0)
#     masks = torch.stack([
#         torch.as_tensor(item["attention_mask"], dtype=torch.long)
#         for item in batch
#     ], dim=0)  # shape (B, T)

#     return {
#         "uuid":           uuids,
#         "input_features": feats,
#         "attention_mask": masks,
#     }

# loader = DataLoader(
#     ds,
#     collate_fn=collate_fn,
#     batch_size=BATCH_SIZE,
#     shuffle=False,   # or True if you want
# )