In [1]:
import pandas as pd
import json
from llama_cpp import Llama, LlamaGrammar, LLAMA_ROPE_SCALING_TYPE_YARN
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
import os
from tqdm.notebook import tqdm
import time

# evaluate_oe lives in mediqa-oe/evaluation: cd into that directory so the
# import resolves, then return to the original working directory, which the
# relative 'data/...' paths used below assume.
%cd mediqa-oe/evaluation
from evaluate_oe import evaluate_sample
# NOTE(review): bare name in the middle of a cell; IPython only displays a
# cell's last expression, so this line has no visible effect (likely a
# leftover sanity check that the import worked).
evaluate_sample
%cd ../..

# Load the training split (source of the few-shot examples below) and the
# test split that is run through the model.
# JSON text is UTF-8 by specification (RFC 8259), so decode explicitly as
# UTF-8 instead of relying on the platform's locale-dependent default.
with open('data/orders_data_transcript.json', 'r', encoding='utf-8') as f:
    traindata = json.load(f)


with open('data/test_orders_data_transcript.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

startmodelloading_time = time.time()

# Load the model at most once per kernel session: if `llm` already exists
# (this cell ran before), the NameError branch is skipped and nothing reloads.
try:
    llm
except NameError:
    print('Loading model')
    # Alternative smaller-model configuration (Qwen3 4B/8B/14B GGUFs), kept
    # for reference.
    # llm = Llama(
    #     # model_path="/home/justin/llama/Qwen3-8B-Q5_K_M.gguf",  
    #     model_path="/home/justin/llama/Qwen3-14B-Q4_K_M.gguf",  
    #     # model_path="/home/justin/llama/Qwen3-14B-UD-Q4_K_XL.gguf",  
    #     # model_path="/home/justin/llama/Qwen3-4B-Q6_K.gguf",  
    #     draft_model=LlamaPromptLookupDecoding(),
    #     logits_all=True,
    #     n_gpu_layers=-1,
    #     flash_attn=True,
    #     n_ctx=int(4096*5.7),
    #     verbose=False,
    #     n_threads=os.cpu_count() - 2
    # )
    llm = Llama(
        model_path="Qwen3-32B-Q4_K_M.gguf",
        n_gpu_layers=-1,          # keep all layers on the 4090
        n_ctx=int(4096*4.5),      # KV-cache size: 18432 tokens
        offload_kqv=True,
        # NOTE(review): the llama.cpp load log for this model reports
        # n_ctx_train = 40960, which is above n_ctx, so YaRN context extension
        # is likely unnecessary. llama.cpp documents yarn_ext_factor as an
        # extrapolation mix factor (-1 = model default, 0 = full
        # interpolation), so a value of 6 looks unintended. Confirm before
        # relying on these settings.
        rope_scaling_type=LLAMA_ROPE_SCALING_TYPE_YARN,   # or simply 2
        yarn_ext_factor=6,
        yarn_beta_fast=32,
        yarn_beta_slow=1,
        yarn_attn_factor=1,
        yarn_orig_ctx=4096,
        flash_attn=True,
        logits_all=True,
        # Speculative decoding with llama-cpp-python's prompt-lookup draft model.
        draft_model=LlamaPromptLookupDecoding(),
        # Leave a few cores free. Clamp to >= 1 because os.cpu_count() may
        # return None, and machines with <= 4 cores would otherwise get 0 or a
        # negative thread count.
        n_threads=max(1, (os.cpu_count() or 1) - 4),
        verbose=False,
    )

    print(f'Loaded model {time.time()-startmodelloading_time:.4f}s.')

def _as_prompt_text(value):
    """Render a dataset field (transcript or expected orders) for the prompt.

    Non-string values are serialized with ``json.dumps`` so the few-shot
    examples are valid JSON; the model is told its own output is parsed with
    ``json.loads()``. The former ``str(value).replace("'", '"')`` approach also
    rewrote apostrophes inside text (``"patient's"`` -> ``"patient"s"``) and
    left Python literals such as ``None``/``True``/``False`` in place, so
    the examples were not valid JSON.
    For ordinary content the output text is unchanged (same key order and
    separators). ``ensure_ascii=False`` keeps non-ASCII characters readable,
    as ``str()`` did. Strings are passed through unchanged.
    """
    if isinstance(value, str):
        return value
    return json.dumps(value, ensure_ascii=False)


# Two few-shot examples: the first and last items of the training split.
examplestr = f"""
EXAMPLE1START
Transcript:
{_as_prompt_text(traindata['train'][0]['transcript'])}
Desired output:
{_as_prompt_text(traindata['train'][0]['expected_orders'])}
EXAMPLE1END

EXAMPLE2START
Transcript:
{_as_prompt_text(traindata['train'][-1]['transcript'])}
Desired output:
{_as_prompt_text(traindata['train'][-1]['expected_orders'])}
EXAMPLE2END
"""

# System prompt: task description, output rules, then the few-shot examples.
system_prompt = f"""
You are an assistant for extracting medical orders from transcripts of patient doctor conversations.

Medical order extraction involves identifying and structuring various medical orders —such as medications, imaging studies, lab tests, and follow-ups— based on doctor-patient conversations. 

The conversation is given to you in the format of a list of dicts with turn_id, speaker (doctor or patient), and transcript for each turn.

You are to return a list of dicts with these keys: order_type, description, reason, and provenance. 

1. Return a list with one dict for each order from the conversation. Your output will be parsed with json.loads().
2. If there is only a single order, still return a list.
3. There are only four allowed values for the order_type: "medication", "lab", "followup", "imaging"
4. Provenance should be a list of ints relating to the turn ids where the order was made, including directly preceeding turns where the reason was mentioned.
5. Quote the reason verbatim from the text.
6. Only list *new* or repeat orders. Do not list things that the patient is to continue doing that are mentioned in passing. Do not list previous exams.
7. The above point (6.) is very important. Orders that have a "continue" status (e.g. "we will continue with xanax XXmg") are NOT considered valid orders, since we dont need to place a specific order in EHR for instance.
8. Lab orders are fine-grained, i.e. each test is one order.
9. If the doctor suggests an over the counter medication (e.g. pain killers), we also count that as an order - UNLESS the patient is already taking it (remember rule 6!)

Make sure your output is a list (i.e. starts and ends with square brackets) of dicts, separated by commas.

Examples:
{examplestr}
"""

def _parse_orders(raw_text, idx):
    """Turn one raw model completion into a list of order dicts.

    Any ``<think>...</think>`` trace is printed and dropped. The remaining text
    is decoded with ``json.loads``. If that fails, decoding is retried with the
    text wrapped in square brackets, for output that is a comma-separated run
    of dicts without the enclosing list. A single decoded dict is wrapped in a
    one-element list.

    Args:
        raw_text: the completion text returned by the model.
        idx: position of the sample in the test set; used only in log lines.

    Returns:
        The decoded list, or ``[]`` if the text is not valid JSON or decodes to
        something other than a list/dict. This guarantees the caller's
        ``len()`` is always safe.
    """
    text = raw_text
    if "</think>" in text:
        thought_trace = text.split('</think>')[0]
        print(thought_trace)
        text = text.split('</think>')[-1]
    try:
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = json.loads('[' + text + ']')
            print(f'{idx=} was fixed with adding square brackets')
    except json.JSONDecodeError:
        print(f'{idx} is not proper json.\n{text}')
        return []
    if isinstance(parsed, dict):
        print(f'{idx=} was a dict, we made it a list')
        parsed = [parsed]
    if not isinstance(parsed, list):
        # e.g. the model answered with a bare string / number / null.
        print(f'{idx} is not proper json.\n{text}')
        return []
    return parsed


# Run the model over every test sample and collect predictions keyed by sample id.
data_to_use = data['test']
all_preds = {}
all_metrics = {}
metrics_per_sample = {}
failed_idxs = []  # positions (in data_to_use) whose completion call raised
for idx, curr_sample in enumerate(tqdm(data_to_use)):
    curr_transcript = curr_sample['transcript']

    user_prompt = f"Please process this transcript. Reply only with the valid response in the desired format.\nTRANSCRIPT START{curr_transcript}\nTRANSCRIPT END. Please follow all the rules and instructions carefully"
    # ChatML prompt with an empty <think></think> prefilled in the assistant
    # turn, presumably to skip Qwen3's reasoning phase.
    # NOTE(review): because of the f-string's indentation, each template line
    # starts with four spaces, and `\\nothink` renders as the literal text
    # `\nothink` (Qwen3's documented soft switch is `/no_think`). This is left
    # unchanged so results stay comparable with earlier runs.
    prompt=f"""
    <|im_start|>system{system_prompt}<|im_end|>
    <|im_start|>user\n{user_prompt} \\nothink<|im_end|>
    <|im_start|>assistant\n<think>\n</think>\n
    """
    start = time.time()
    try:
        response = llm.create_completion(
            prompt=prompt,
            max_tokens=int(4096*0.5),
            temperature=0., top_k=1, top_p=1.0,
            stop=[]
        )
    except ValueError as e:
        # Presumably raised when the prompt does not fit the context window.
        # Record the failure with an empty prediction and skip the sample.
        # Without `continue`, the code below would read `response`, which is
        # undefined on the first sample and otherwise left over from the
        # previous sample, so that sample's orders would be stored under this id.
        print(e)
        failed_idxs.append(idx)
        all_preds[curr_sample['id']] = []
        continue
    curr_time = time.time()-start

    pred = _parse_orders(response['choices'][0]['text'], idx)
    all_preds[curr_sample['id']] = pred
    print(f'{idx=} took {curr_time:.4f}s {curr_sample["id"]=} {len(pred)=}')

print('Done')

with open('data/test_qw32B_v5dot6_ex_specdec_v2c.json', 'w') as f:
    json.dump(all_preds, f)

/home/justinenglemann/llm/mediqa-oe/evaluation
/home/justinenglemann/llm
Loading model


llama_context: n_ctx_per_seq (18432) < n_ctx_train (40960) -- the full capacity of the model will not be utilized
llama_kv_cache_unified: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility


Loaded model 2.4340s.


  0%|          | 0/100 [00:00<?, ?it/s]

idx=0 took 21.6742s curr_sample["id"]='primock57_1_5' len(pred)=1
idx=1 took 3.5893s curr_sample["id"]='acibench_D2N188_aci_clef_taskC_test3' len(pred)=2
idx=2 took 5.1326s curr_sample["id"]='primock57_2_1' len(pred)=1
idx=3 took 2.8207s curr_sample["id"]='acibench_D2N079_aci_valid' len(pred)=2
idx=4 took 3.3141s curr_sample["id"]='acibench_D2N069_virtassist_valid' len(pred)=2
idx=5 took 3.3752s curr_sample["id"]='acibench_D2N036_aci_train' len(pred)=3
idx=6 took 2.7577s curr_sample["id"]='acibench_D2N202_aci_clef_taskC_test3' len(pred)=2
idx=7 took 2.5224s curr_sample["id"]='acibench_D2N172_virtassist_clef_taskC_test3' len(pred)=1
idx=8 took 3.9195s curr_sample["id"]='acibench_D2N033_aci_train' len(pred)=4
idx=9 took 3.4064s curr_sample["id"]='acibench_D2N194_aci_clef_taskC_test3' len(pred)=2
idx=10 took 2.6786s curr_sample["id"]='acibench_D2N013_virtassist_train' len(pred)=2
idx=11 took 3.7879s curr_sample["id"]='acibench_D2N071_virtassist_valid' len(pred)=2
idx=12 took 4.7267s curr_

In [3]:
# Fallback for conversations that don't fit into VRAM: every sample whose
# 32B completion failed (its index is in `failed_idxs` from the inference
# cell) takes the saved 14B prediction instead, and the patched result is
# written to a separate "_fixed" file.

with open('data/test_qw14B_v5dot6_ex_specdec_v2c.json', 'r') as f:
    preds14B = json.load(f)

for idx in failed_idxs:
    sample_id = data_to_use[idx]["id"]
    print(idx, sample_id)
    all_preds[sample_id] = preds14B[sample_id]

with open('data/test_qw32B_v5dot6_ex_specdec_v2c_fixed.json', 'w') as f:
    json.dump(all_preds, f)

37 primock57_1_11
57 primock57_1_7
