In [2]:
import pandas as pd
import json
from llama_cpp import Llama, LlamaGrammar
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
import os
from tqdm.notebook import tqdm
import time

%cd mediqa-oe/evaluation
from evaluate_oe import evaluate_sample
evaluate_sample
%cd ../..

# Load the dataset; the rest of this cell reads its 'train' and 'dev' splits.
with open('data/orders_data_transcript.json', 'r') as f:
    data = json.load(f)

# Ids of every dev sample, kept around for interactive inspection.
allids = [sample['id'] for sample in data['dev']]

# Expected orders of the first 12 dev samples, written out as a small
# id -> expected_orders file.
dev_extract = data['dev'][:12]
dev_extract_orders = dict(
    (sample['id'], sample['expected_orders']) for sample in dev_extract
)
with open('data/dev_extract_orders.json', 'w') as f:
    json.dump(dev_extract_orders, f)

startmodelloading_time = time.time()

# Load the model only once per kernel session: if `llm` already exists from a
# previous run of this cell it is reused as-is.
try:
    llm
except NameError:
    print('Loading model')
    llm = Llama(
        # model_path="/home/justin/llama/Qwen3-8B-Q5_K_M.gguf",
        # model_path="/home/justin/llama/Qwen3-14B-Q4_K_M.gguf",
        # model_path="/home/justin/llama/Qwen3-14B-UD-Q4_K_XL.gguf",
        model_path="/home/justin/llama/Qwen3-14B-UD-Q5_K_XL.gguf",
        # model_path="/home/justin/llama/Qwen3-4B-Q6_K.gguf",
        draft_model=LlamaPromptLookupDecoding(),
        logits_all=True,
        n_gpu_layers=-1,
        flash_attn=True,
        n_ctx=int(4096*5.7),
        verbose=False,
        # Keep two cores free, but never pass a non-positive thread count:
        # os.cpu_count() may return None, and on a 1-2 core machine the plain
        # subtraction would give 0 or -1.
        n_threads=max(1, (os.cpu_count() or 1) - 2),
    )
    print(f'Loaded model {time.time()-startmodelloading_time:.4f}s.')

# Few-shot block for the system prompt: the first and the last training sample,
# each shown as its transcript followed by its expected orders.
# NOTE(review): the values are interpolated via str() and every single quote is
# then swapped for a double quote, so the "JSON" shown to the model is a
# quote-swapped Python repr. Text containing an apostrophe would come out with
# unbalanced quotes — consider json.dumps(); confirm before changing, since it
# alters the prompt.
examplestr = \
f"""
EXAMPLE1START
Transcript:
{data['train'][0]['transcript']}
Desired output:
{data['train'][0]['expected_orders']}
EXAMPLE1END

EXAMPLE2START
Transcript:
{data['train'][-1]['transcript']}
Desired output:
{data['train'][-1]['expected_orders']}
EXAMPLE2END
""".replace('\'', '"')

# System prompt: task description, numbered extraction rules, then the few-shot
# examples built above.
system_prompt = f"""
You are an assistant for extracting medical orders from transcripts of patient doctor conversations.

Medical order extraction involves identifying and structuring various medical orders —such as medications, imaging studies, lab tests, and follow-ups— based on doctor-patient conversations. 

The conversation is given to you in the format of a list of dicts with turn_id, speaker (doctor or patient), and transcript for each turn.

You are to return a list of dicts with these keys: order_type, description, reason, and provenance. 

1. Return a list with one dict for each order from the conversation. Your output will be parsed with json.loads().
2. If there is only a single order, still return a list.
3. There are only four allowed values for the order_type: "medication", "lab", "followup", "imaging"
4. Provenance should be a list of ints relating to the turn ids where the order was made, including directly preceeding turns where the reason was mentioned.
5. Quote the reason verbatim from the text.
6. Only list *new* or repeat orders. Do not list things that the patient is to continue doing that are mentioned in passing. Do not list previous exams.
7. The above point (6.) is very important. Orders that have a "continue" status (e.g. "we will continue with xanax XXmg") are NOT considered valid orders, since we dont need to place a specific order in EHR for instance.
8. Lab orders are fine-grained, i.e. each test is one order.
9. If the doctor suggests an over the counter medication (e.g. pain killers), we also count that as an order - UNLESS the patient is already taking it (remember rule 6!)

Make sure your output is a list (i.e. starts and ends with square brackets) of dicts, separated by commas.

Examples:
{examplestr}
"""

# (report label, key in the metrics dict, metric-name prefix) for each scored
# field, in the order they are printed and stored.
METRIC_FIELDS = (
    ('Desc', 'description', 'Rouge1'),
    ('Reas', 'reason', 'Rouge1'),
    ('OTyp', 'order_type', 'Strict'),
    ('Prov', 'provenance', 'MultiLabel'),
)

data_to_use = data['dev'][:]
all_preds = {}           # sample id -> parsed list of predicted orders
all_metrics = {}         # sample id -> metrics dict returned by evaluate_sample
metrics_per_sample = {}  # sample id -> [4 F1 scores, seconds, #preds, #targets]
for idx, curr_sample in enumerate(tqdm(data_to_use)):
    curr_transcript = curr_sample['transcript']

    user_prompt = f"Please process this transcript. Reply only with the valid response in the desired format.\nTRANSCRIPT START{curr_transcript}\nTRANSCRIPT END. Please follow all the rules and instructions carefully"
    # Hand-written ChatML prompt; the assistant turn is pre-filled with an empty
    # <think></think> block so generation starts directly at the answer.
    # NOTE(review): `\\nothink` inserts the literal text "\nothink" into the user
    # turn. Qwen3 documents its soft switch as "/no_think" — confirm intent
    # before changing, since it alters the prompt.
    prompt=f"""
    <|im_start|>system{system_prompt}<|im_end|>
    <|im_start|>user\n{user_prompt} \\nothink<|im_end|>
    <|im_start|>assistant\n<think>\n</think>\n
    """
    start = time.time()
    response = llm.create_completion(
        prompt=prompt,
        max_tokens=int(4096*0.33),
        temperature=0., top_k=1, top_p=1.0,
        stop=[]
    )
    curr_time = time.time()-start
    print(f'{idx=} took {curr_time:.4f}s {curr_sample["id"]=}')

    pred = response['choices'][0]['text']
    if "</think>" in pred:
        # The model produced a reasoning trace anyway: show it, keep the rest.
        thought_trace = pred.split('</think>')[0]
        print(thought_trace)
        pred = pred.split('</think>')[-1]

    # Parse the answer. Only JSON decoding errors are treated as "bad output";
    # anything else (including KeyboardInterrupt) propagates.
    try:
        pred = json.loads(pred)
    except json.JSONDecodeError:
        try:
            # Repair attempt: comma-separated dicts without the enclosing list.
            pred = json.loads('['+pred+']')
            print(f'{idx=} was fixed with adding square brackets')
        except json.JSONDecodeError:
            print(f'{idx} is not proper json.\n{pred}')
            pred = []
    if isinstance(pred, dict):
        print(f'{idx=} was a dict, we made it a list')
        pred = [pred]
    elif not isinstance(pred, list):
        # Valid JSON that is not a list (number, string, null, ...) cannot be a
        # list of orders and would break len(pred) below.
        print(f'{idx=} was not a list of orders, we dropped it')
        pred = []

    all_preds[curr_sample['id']] = pred

    # evaluate_sample is called with {id: orders} dicts for truth and prediction.
    t = {curr_sample['id']: curr_sample['expected_orders']}
    p = {curr_sample['id']: pred}
    metrics = evaluate_sample(t, p)
    # One "<label> F1:<f1> (<recall>/<precision>)" entry per scored field.
    metricstr = "  ".join(
        f"{label} F1:{metrics[field][kind + '_f1']:.3f} "
        f"({metrics[field][kind + '_recall']:.3f}/{metrics[field][kind + '_precision']:.3f})"
        for label, field, kind in METRIC_FIELDS
    )

    print(metricstr)
    all_metrics[curr_sample['id']] = metrics
    metrics_per_sample[curr_sample['id']] = [
        *(metrics[field][kind + '_f1'] for _, field, kind in METRIC_FIELDS),
        curr_time,
        len(pred),
        len(curr_sample['expected_orders']),
    ]
    print()

print('Done')
# Output names used for earlier runs:
# 'dev_ALL_qw14BUDQ4KXL_v5dot6_ex_specdec'
# with open('data/dev_ALL_qw8BQ5KM_v5dot6_ex_specdec_v2c.json', 'w') as f:
#     json.dump(all_preds, f)
with open('data/dev_ALL_qw14BXLQ5_v5dot6_ex_specdec_v2c.json', 'w') as f:
    json.dump(all_preds, f)

import pandas as pd
# One row per sample; column order mirrors the lists stored in metrics_per_sample.
score_columns = ['descF1', 'reasonF1', 'ordertypeF1', 'provF1']
df = pd.DataFrame(metrics_per_sample).T
df.columns = score_columns + ['processingtime', 'len_pred', 'len_target']
# Per-sample average over the four F1 scores only.
df['avg'] = df[score_columns].mean(axis=1)
column_means = df.mean(axis=0)
print(column_means.round(3))
df.round(3)


/mnt/g/ProjectsOverflow/MEDIQA/mediqa-oe/evaluation
/mnt/g/ProjectsOverflow/MEDIQA
Loading model


llama_context: n_ctx_per_seq (23347) < n_ctx_train (40960) -- the full capacity of the model will not be utilized


Loaded model 3.4683s.


  0%|          | 0/100 [00:00<?, ?it/s]

idx=0 took 57.3408s curr_sample["id"]='acibench_D2N182_virtscribe_clef_taskC_test3'
Desc F1:0.190 (0.250/0.154)  Reas F1:0.267 (0.500/0.182)  OTyp F1:0.500 (0.500/0.500)  Prov F1:0.333 (0.333/0.333)

idx=1 took 21.3361s curr_sample["id"]='acibench_D2N174_virtassist_clef_taskC_test3'
Desc F1:0.909 (0.833/1.000)  Reas F1:0.877 (0.833/0.926)  OTyp F1:1.000 (1.000/1.000)  Prov F1:0.638 (0.611/0.667)

idx=2 took 18.9744s curr_sample["id"]='primock57_5_12'
Desc F1:0.077 (1.000/0.040)  Reas F1:0.167 (1.000/0.091)  OTyp F1:1.000 (1.000/1.000)  Prov F1:0.400 (0.500/0.333)

idx=3 took 24.1653s curr_sample["id"]='acibench_D2N140_virtscribe_clinicalnlp_taskC_test2'
Desc F1:0.541 (0.870/0.393)  Reas F1:0.091 (0.500/0.050)  OTyp F1:0.667 (1.000/0.500)  Prov F1:0.462 (0.600/0.375)

idx=4 took 18.4539s curr_sample["id"]='acibench_D2N067_aci_train'
Desc F1:0.800 (1.000/0.667)  Reas F1:0.000 (0.000/0.000)  OTyp F1:0.800 (1.000/0.667)  Prov F1:0.706 (0.750/0.667)

idx=5 took 51.2579s curr_sample["id"]='p

KeyboardInterrupt: 

In [2]:
import pandas as pd
import json
from llama_cpp import Llama, LlamaGrammar
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
import os
from tqdm.notebook import tqdm
import time
import matplotlib.pyplot as plt

%cd mediqa-oe/evaluation
from evaluate_oe import evaluate_sample
evaluate_sample
%cd ../..

# Load the dataset; the rest of this cell reads its 'train' and 'dev' splits.
with open('data/orders_data_transcript.json', 'r') as f:
    data = json.load(f)

allids = [_['id'] for _ in data['dev']]

startmodelloading_time = time.time()

mname = "/home/justin/llama/Qwen3-14B-UD-Q4_K_XL.gguf"
# Load the model only once per kernel session: an existing `llm` is reused.
try:
    llm
except NameError:
    print('Loading model')
    llm = Llama(
        model_path=mname,
        draft_model=LlamaPromptLookupDecoding(),
        logits_all=True,
        n_gpu_layers=-1,
        flash_attn=True,
        # n_ctx=int(4096*5.7),
        # n_ctx=int(4096*7),
        n_ctx=int(4096*7),
        verbose=False,
        # Keep two cores free, but never pass a non-positive thread count:
        # os.cpu_count() may return None, and on a 1-2 core machine the plain
        # subtraction would give 0 or -1.
        n_threads=max(1, (os.cpu_count() or 1) - 2),
    )
    print(f'Loaded model {time.time()-startmodelloading_time:.4f}s.')
else:
    # The results file is later named after `mname`, so a cached model loaded
    # from a different path would be saved under the wrong name. Falls back to
    # `mname` (no warning) if the object exposes no `model_path`.
    loaded_path = getattr(llm, 'model_path', mname)
    if loaded_path != mname:
        print(f'WARNING: reusing already loaded model {loaded_path}, not {mname}. '
              'Run `del llm` and re-run this cell to switch models.')

def shorten_transcript(s):
    """Return ``str(s)`` with the verbose transcript markers abbreviated.

    Replaces, in this order and everywhere in the text (spoken text included):
    'turn_id' -> 'turn', 'DOCTOR' -> 'DOC', 'PATIENT' -> 'PAT',
    'transcript' -> 'txt'.
    """
    abbreviations = (
        ('turn_id', 'turn'),
        ('DOCTOR', 'DOC'),
        ('PATIENT', 'PAT'),
        ('transcript', 'txt'),
    )
    text = str(s)
    for long_form, short_form in abbreviations:
        text = text.replace(long_form, short_form)
    return text


# Training samples used as few-shot examples, mapped to extra turn ids that
# get_salient_transcript keeps in addition to the gold provenance turns.
# NOTE(review): the name suggests these are turns predicted in an earlier run —
# confirm where the numbers come from.
train_examples_predicted_turns = {
    'acibench_D2N074_virtscribe_valid': [10, 14],
    'primock57_2_6': [184, 185, 186, 189, 190, 180, 181, 183, 191, 192],
    'acibench_D2N025_virtscribe_train': [80, 88, 90, 96, 133],
    # NOTE(review): identical to the list of the previous entry — looks like a
    # copy-paste; verify the turns for this sample.
    'primock57_1_2': [80, 88, 90, 96, 133],    
}

def get_salient_transcript(sample):
    """Return the parts of a sample's transcript that surround its orders.

    The seed turns are every provenance entry of ``sample['expected_orders']``
    plus the extra turns listed for ``sample['id']`` in
    ``train_examples_predicted_turns``. Each seed is widened by six positions
    in both directions and the selected transcript entries are returned in
    ascending order without duplicates.

    Turn numbers are used directly as indices into ``sample['transcript']``.
    NOTE(review): this assumes turn ids coincide with list positions — confirm
    against the dataset.

    Raises:
        KeyError: if ``sample['id']`` has no entry in
            ``train_examples_predicted_turns``.
    """
    transcript = sample['transcript']
    seed_turns = {
        turn
        for order in sample['expected_orders']
        for turn in order['provenance']
    }
    seed_turns.update(train_examples_predicted_turns[sample['id']])

    # The previous chain of cumulative "+=" expansions (-3, -2, -1, +1, +2, +3,
    # each applied to the already grown list) amounts to offsets -6..+6.
    window = range(-6, 7)
    salient_turns = sorted({turn + offset for turn in seed_turns for offset in window})

    # Drop positions outside the transcript on BOTH sides: a negative index
    # would otherwise wrap around and pull in turns from the end.
    return [transcript[idx] for idx in salient_turns if 0 <= idx < len(transcript)]

# Earlier few-shot block (first and last training sample, full shortened
# transcripts). This value is replaced by the later `examplestr = "".join(...)`
# assignment in this cell before the system prompt is built, so it never
# reaches the model.
examplestr = \
f"""
EXAMPLE1START
Transcript:
{shorten_transcript(data['train'][0]['transcript'])}
Desired output:
{data['train'][0]['expected_orders']}
EXAMPLE1END

EXAMPLE2START
Transcript:
{shorten_transcript(data['train'][-1]['transcript'])}
Desired output:
{data['train'][-1]['expected_orders']}
EXAMPLE2END
""".replace('\'', '"')

def get_train_sample_from_idname(idname):
    """Return the training sample identified by ``idname``.

    A sample whose id equals ``idname`` is preferred. If there is none, the
    first sample whose id merely contains ``idname`` is returned, which keeps
    partial names working. Preferring the exact id stops a short name such as
    'primock57_1_2' from resolving to 'primock57_1_20'.

    Raises:
        IndexError: if no training sample id contains ``idname``.
    """
    train_samples = data['train']
    for sample in train_samples:
        if sample['id'] == idname:
            return sample
    for sample in train_samples:
        if idname in sample['id']:
            return sample
    raise IndexError(f'no training sample with an id containing {idname!r}')

# Few-shot examples: for every id in train_examples_predicted_turns, the trimmed
# ("salient") and shortened transcript paired with that sample's expected orders.
example_transcripts = [get_salient_transcript(get_train_sample_from_idname(idname)) for idname in train_examples_predicted_turns]
example_transcripts = [shorten_transcript(_) for _ in example_transcripts]
example_outputs = [get_train_sample_from_idname(idname)['expected_orders'] for idname in train_examples_predicted_turns]

# NOTE(review): values are interpolated via str() and every single quote is then
# swapped for a double quote, so text containing an apostrophe would come out
# with unbalanced quotes — consider json.dumps(); confirm before changing, since
# it alters the prompt. Each example has a START marker but no END marker.
examplestr = "".join([f"EXAMPLE{idx+1}START\nTranscript:\n{transcript}\nExpected output:\n{output}\n" for idx, (transcript, output) in enumerate(zip(example_transcripts, example_outputs))]).replace('\'', '"')

# System prompt: task description, numbered extraction rules, then the few-shot
# examples built above. Field names match the output of shorten_transcript.
system_prompt = f"""
You are an assistant for extracting medical orders from transcripts of patient doctor conversations.

Medical order extraction involves identifying and structuring various medical orders —such as medications, imaging studies, lab tests, and follow-ups— based on doctor-patient conversations. 

The conversation is given to you in the format of a list of dicts with turn, speaker (doctor or patient), and txt for each turn.

You are to return a list of dicts with these keys: order_type, description, reason, and provenance. 

1. Return a list with one dict for each order from the conversation. Your output will be parsed with json.loads().
2. If there is only a single order, still return a list.
3. There are only four allowed values for the order_type: "medication", "lab", "followup", "imaging"
4. Provenance should be a list of ints relating to the turn ids that ground the order. This is the turn where the order was made, and maybe include directly preceeding turns where the reason was mentioned.
5. Quote the reason verbatim from the text.
6. Only list *new* or repeat orders. Do not list things that the patient is to continue doing that are mentioned in passing. Do not list previous exams.
7. The above point (6.) is very important. Orders that have a "continue" status (e.g. "we will continue with xanax XXmg") are NOT considered valid orders, since we dont need to place a specific order in EHR for instance.
8. Lab orders are fine-grained, i.e. each test is one order.
9. If the doctor suggests an over the counter medication (e.g. pain killers), we also count that as an order - UNLESS the patient is already taking it (remember rule 6!)
10. Only count orders that are actually made, not ones that are conditional (if your pain persists, maybe we...).

Make sure your output is a list (i.e. starts and ends with square brackets) of dicts, separated by commas.

Examples:
{examplestr}
"""


# (report label, key in the metrics dict, metric-name prefix) for each scored
# field, in the order they are printed and stored.
METRIC_FIELDS = (
    ('Desc', 'description', 'Rouge1'),
    ('Reas', 'reason', 'Rouge1'),
    ('OTyp', 'order_type', 'Strict'),
    ('Prov', 'provenance', 'MultiLabel'),
)

# Dev samples, minus any whose id is also used as a few-shot example.
data_to_use = [_ for _ in data['dev'] if _['id'] not in train_examples_predicted_turns]#data['train'][:]
all_preds = {}           # sample id -> parsed list of predicted orders
all_metrics = {}         # sample id -> metrics dict returned by evaluate_sample
metrics_per_sample = {}  # sample id -> [4 F1 scores, seconds, #preds, #targets, #tokens]
for idx, curr_sample in enumerate(tqdm(data_to_use)):
    curr_transcript = shorten_transcript(curr_sample['transcript'])

    user_prompt = f"Please process this transcript.\nTRANSCRIPTSTART\n{curr_transcript}\nTRANSCRIPTEND\nPlease follow all the rules and instructions carefully. Reply only with the valid response in the desired format."
    # Hand-written ChatML prompt; the assistant turn is pre-filled with an empty
    # <think></think> block and an opening '[' so the model continues a list.
    # NOTE(review): `\\nothink` inserts the literal text "\nothink" into the user
    # turn. Qwen3 documents its soft switch as "/no_think" — confirm intent
    # before changing, since it alters the prompt.
    prompt=f"""
    <|im_start|>system{system_prompt}<|im_end|>
    <|im_start|>user\n{user_prompt} \\nothink<|im_end|>
    <|im_start|>assistant\n<think>\n</think>\n["""
    start = time.time()
    response = llm.create_completion(
        prompt=prompt,
        max_tokens=int(4096*0.5),
        temperature=0., top_k=1, top_p=1.0,
        stop=[]
    )
    curr_time = time.time()-start

    # Re-attach the '[' that was part of the prompt, not of the completion.
    pred = '['+response['choices'][0]['text']
    if "</think>" in pred:
        # The model produced a reasoning trace anyway: show it, keep the rest.
        thought_trace = pred.split('</think>')[0]
        print(thought_trace)
        pred = pred.split('</think>')[-1]

    # Parse the answer. Only JSON decoding errors are treated as "bad output";
    # anything else (including KeyboardInterrupt) propagates.
    try:
        pred = json.loads(pred)
    except json.JSONDecodeError:
        # Repair attempts, most likely first. With the '[' already in place the
        # realistic damage is a missing closing bracket; wrapping in a second
        # pair of brackets is kept for answers that lost the opening one.
        repairs = (
            (pred + ']', 'a closing square bracket'),
            ('[' + pred + ']', 'square brackets'),
        )
        for repaired, what in repairs:
            try:
                pred = json.loads(repaired)
            except json.JSONDecodeError:
                continue
            print(f'{idx=} was fixed with adding {what}')
            break
        else:
            print(f'{idx} is not proper json.\n{pred}')
            pred = []
    if isinstance(pred, dict):
        print(f'{idx=} was a dict, we made it a list')
        pred = [pred]
    elif not isinstance(pred, list):
        # Valid JSON that is not a list (number, string, null, ...) cannot be a
        # list of orders and would break len(pred) below.
        print(f'{idx=} was not a list of orders, we dropped it')
        pred = []

    all_preds[curr_sample['id']] = pred

    # evaluate_sample is called with {id: orders} dicts for truth and prediction.
    t = {curr_sample['id']: curr_sample['expected_orders']}
    p = {curr_sample['id']: pred}
    metrics = evaluate_sample(t, p)
    # One "<label>F1:<f1> (<recall>/<precision>)" entry per scored field.
    metricstr = "  ".join(
        f"{label}F1:{metrics[field][kind + '_f1']:.3f} "
        f"({metrics[field][kind + '_recall']:.3f}/{metrics[field][kind + '_precision']:.3f})"
        for label, field, kind in METRIC_FIELDS
    )

    num_tokens = response['usage']['total_tokens']
    num_targets = len(curr_sample['expected_orders'])
    num_preds = len(pred)
    print(f'{idx=}: {curr_time:.4f}s #targets={num_targets} #preds={num_preds} #tokens={num_tokens} {curr_sample["id"]}')

    print(metricstr)
    all_metrics[curr_sample['id']] = metrics
    metrics_per_sample[curr_sample['id']] = [
        *(metrics[field][kind + '_f1'] for _, field, kind in METRIC_FIELDS),
        curr_time,
        num_preds,
        num_targets,
        num_tokens,
    ]
    # print()
    # print()

print('Done')
# File stem derived from the model file name: directory and '.gguf' removed,
# 'Qwen' abbreviated to 'qw'.
model_file = mname.rsplit('/', 1)[-1]
mname_short = model_file.partition('.gguf')[0].replace('Qwen','qw')
savestr = f'dev_ALL_{mname_short}_shortenedexs'
with open(f'data/{savestr}.json', 'w') as f:
    json.dump(all_preds, f)

import pandas as pd
# One row per sample; column order mirrors the lists stored in metrics_per_sample.
score_columns = ['descF1', 'reasonF1', 'ordertypeF1', 'provF1']
df = pd.DataFrame(metrics_per_sample).T
df.columns = score_columns + ['processingtime', 'len_pred', 'len_target', 'total_tokens']
# Per-sample average over the four F1 scores only.
df['avg'] = df[score_columns].mean(axis=1)
column_means = df.mean(axis=0)
print(column_means.round(3))
df.round(3)


/mnt/g/ProjectsOverflow/MEDIQA/mediqa-oe/evaluation
/mnt/g/ProjectsOverflow/MEDIQA


  0%|          | 0/100 [00:00<?, ?it/s]

idx=0: 13.1717s #targets=2 #preds=2 #tokens=13023 acibench_D2N182_virtscribe_clef_taskC_test3
DescF1:0.190 (0.250/0.154)  ReasF1:0.107 (0.375/0.062)  OTypF1:0.500 (0.500/0.500)  ProvF1:0.250 (0.333/0.200)
idx=1: 5.5012s #targets=3 #preds=3 #tokens=10253 acibench_D2N174_virtassist_clef_taskC_test3
DescF1:0.909 (0.833/1.000)  ReasF1:0.333 (0.333/0.333)  OTypF1:1.000 (1.000/1.000)  ProvF1:0.407 (0.611/0.306)
idx=2: 4.9969s #targets=1 #preds=2 #tokens=10336 primock57_5_12
DescF1:0.667 (1.000/0.500)  ReasF1:0.057 (1.000/0.029)  OTypF1:0.667 (1.000/0.500)  ProvF1:0.500 (0.500/0.500)
idx=3: 7.3529s #targets=2 #preds=4 #tokens=11598 acibench_D2N140_virtscribe_clinicalnlp_taskC_test2
DescF1:0.553 (0.717/0.450)  ReasF1:0.021 (0.500/0.011)  OTypF1:0.667 (1.000/0.500)  ProvF1:0.154 (0.200/0.125)
idx=4: 5.0602s #targets=2 #preds=4 #tokens=10292 acibench_D2N067_aci_train
DescF1:0.667 (1.000/0.500)  ReasF1:0.000 (0.000/0.000)  OTypF1:0.667 (1.000/0.500)  ProvF1:0.600 (0.750/0.500)
idx=5: 6.0899s #tar

Unnamed: 0,descF1,reasonF1,ordertypeF1,provF1,processingtime,len_pred,len_target,total_tokens,avg
acibench_D2N182_virtscribe_clef_taskC_test3,0.190,0.107,0.500,0.250,13.172,2.0,2.0,13023.0,0.262
acibench_D2N174_virtassist_clef_taskC_test3,0.909,0.333,1.000,0.407,5.501,3.0,3.0,10253.0,0.662
primock57_5_12,0.667,0.057,0.667,0.500,4.997,2.0,1.0,10336.0,0.473
acibench_D2N140_virtscribe_clinicalnlp_taskC_test2,0.553,0.021,0.667,0.154,7.353,4.0,2.0,11598.0,0.349
acibench_D2N067_aci_train,0.667,0.000,0.667,0.600,5.060,4.0,2.0,10292.0,0.483
...,...,...,...,...,...,...,...,...,...
acibench_D2N125_aci_clinicalnlp_taskB_test1,0.667,0.000,0.667,0.222,3.103,1.0,2.0,9250.0,0.389
acibench_D2N199_aci_clef_taskC_test3,0.892,0.200,0.909,0.909,6.053,5.0,6.0,10573.0,0.728
primock57_5_5,0.857,0.963,1.000,0.857,4.718,1.0,1.0,10985.0,0.919
primock57_5_11,0.200,0.000,0.500,0.333,8.173,3.0,1.0,13075.0,0.258


In [3]:
idx=0: 11.2263s #targets=2 #preds=2 #tokens=13025 acibench_D2N182_virtscribe_clef_taskC_test3
DescF1:0.200 (0.250/0.167)  ReasF1:0.107 (0.375/0.062)  OTypF1:0.500 (0.500/0.500)  ProvF1:0.250 (0.333/0.200)
idx=1: 3.7126s #targets=3 #preds=3 #tokens=10253 acibench_D2N174_virtassist_clef_taskC_test3
DescF1:0.909 (0.833/1.000)  ReasF1:0.333 (0.333/0.333)  OTypF1:1.000 (1.000/1.000)  ProvF1:0.407 (0.611/0.306)
idx=2: 2.7383s #targets=1 #preds=2 #tokens=10306 primock57_5_12
DescF1:0.667 (1.000/0.500)  ReasF1:0.087 (1.000/0.045)  OTypF1:0.667 (1.000/0.500)  ProvF1:0.500 (0.500/0.500)
idx=3: 4.8973s #targets=2 #preds=4 #tokens=11598 acibench_D2N140_virtscribe_clinicalnlp_taskC_test2
DescF1:0.553 (0.717/0.450)  ReasF1:0.021 (0.500/0.011)  OTypF1:0.667 (1.000/0.500)  ProvF1:0.336 (0.800/0.212)
idx=4: 3.3289s #targets=2 #preds=4 #tokens=10299 acibench_D2N067_aci_train

SyntaxError: invalid decimal literal (1412554130.py, line 1)

In [None]:
import pandas as pd
import json
from llama_cpp import Llama, LlamaGrammar
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
import os
from tqdm.notebook import tqdm
import time
import matplotlib.pyplot as plt

%cd mediqa-oe/evaluation
from evaluate_oe import evaluate_sample
evaluate_sample
%cd ../..

# Load the dataset; the rest of this cell reads its 'train' and 'dev' splits.
with open('data/orders_data_transcript.json', 'r') as f:
    data = json.load(f)

# Ids of every dev sample, kept around for interactive inspection.
allids = [sample['id'] for sample in data['dev']]

# Expected orders of the first 12 dev samples, written out as a small
# id -> expected_orders file.
dev_extract = data['dev'][:12]
dev_extract_orders = dict(
    (sample['id'], sample['expected_orders']) for sample in dev_extract
)
with open('data/dev_extract_orders.json', 'w') as f:
    json.dump(dev_extract_orders, f)

startmodelloading_time = time.time()

# Model selection: the last assignment wins.
mname = "/home/justin/llama/Qwen3-8B-UD-Q5_K_XL.gguf"
mname = "/home/justin/llama/Qwen3-14B-UD-Q4_K_XL.gguf"
# Load the model only once per kernel session: an existing `llm` is reused.
try:
    llm
except NameError:
    print('Loading model')
    llm = Llama(
        # model_path="/home/justin/llama/Qwen3-8B-Q5_K_M.gguf",
        # lora_path='llama_8b_unsloth_mediqa_oe_adapter_model.safetensors',
        model_path=mname,
        # lora_path='unsloth.Q4_K_M.gguf',
        # model_path="/home/justin/llama/Qwen3-14B-Q4_K_M.gguf",
        # model_path="/home/justin/llama/Qwen3-14B-UD-Q4_K_XL.gguf",
        # model_path="/home/justin/llama/Qwen3-4B-Q6_K.gguf",
        draft_model=LlamaPromptLookupDecoding(),
        logits_all=True,
        n_gpu_layers=-1,
        flash_attn=True,
        # n_ctx=int(4096*5.7),
        # n_ctx=int(4096*7),
        n_ctx=int(4096*7),
        verbose=False,
        # Keep two cores free, but never pass a non-positive thread count:
        # os.cpu_count() may return None, and on a 1-2 core machine the plain
        # subtraction would give 0 or -1.
        n_threads=max(1, (os.cpu_count() or 1) - 2),
    )
    print(f'Loaded model {time.time()-startmodelloading_time:.4f}s.')
else:
    # The results file is later named after `mname`, so a cached model loaded
    # from a different path would be saved under the wrong name. Falls back to
    # `mname` (no warning) if the object exposes no `model_path`.
    loaded_path = getattr(llm, 'model_path', mname)
    if loaded_path != mname:
        print(f'WARNING: reusing already loaded model {loaded_path}, not {mname}. '
              'Run `del llm` and re-run this cell to switch models.')

def shorten_transcript(s):
    """Return ``str(s)`` with the verbose transcript markers abbreviated.

    Replaces, in this order and everywhere in the text (spoken text included):
    'turn_id' -> 'turn', 'DOCTOR' -> 'DOC', 'PATIENT' -> 'PAT',
    'transcript' -> 'txt'.
    """
    abbreviations = (
        ('turn_id', 'turn'),
        ('DOCTOR', 'DOC'),
        ('PATIENT', 'PAT'),
        ('transcript', 'txt'),
    )
    text = str(s)
    for long_form, short_form in abbreviations:
        text = text.replace(long_form, short_form)
    return text


# Training samples used as few-shot examples, mapped to extra turn ids that
# get_salient_transcript keeps in addition to the gold provenance turns.
# NOTE(review): the name suggests these are turns predicted in an earlier run —
# confirm where the numbers come from.
train_examples_predicted_turns = {
    'acibench_D2N074_virtscribe_valid': [10, 14],
    'primock57_2_6': [184, 185, 186, 189, 190, 180, 181, 183, 191, 192],
    'acibench_D2N025_virtscribe_train': [80, 88, 90, 96, 133],
    # NOTE(review): identical to the list of the previous entry — looks like a
    # copy-paste; verify the turns for this sample.
    'primock57_1_2': [80, 88, 90, 96, 133],    
}

def get_salient_transcript(sample):
    """Return the parts of a sample's transcript that surround its orders.

    The seed turns are every provenance entry of ``sample['expected_orders']``
    plus the extra turns listed for ``sample['id']`` in
    ``train_examples_predicted_turns``. Each seed is widened by six positions
    in both directions and the selected transcript entries are returned in
    ascending order without duplicates.

    Turn numbers are used directly as indices into ``sample['transcript']``.
    NOTE(review): this assumes turn ids coincide with list positions — confirm
    against the dataset.

    Raises:
        KeyError: if ``sample['id']`` has no entry in
            ``train_examples_predicted_turns``.
    """
    transcript = sample['transcript']
    seed_turns = {
        turn
        for order in sample['expected_orders']
        for turn in order['provenance']
    }
    seed_turns.update(train_examples_predicted_turns[sample['id']])

    # The previous chain of cumulative "+=" expansions (-3, -2, -1, +1, +2, +3,
    # each applied to the already grown list) amounts to offsets -6..+6.
    window = range(-6, 7)
    salient_turns = sorted({turn + offset for turn in seed_turns for offset in window})

    # Drop positions outside the transcript on BOTH sides: a negative index
    # would otherwise wrap around and pull in turns from the end.
    return [transcript[idx] for idx in salient_turns if 0 <= idx < len(transcript)]

# Earlier few-shot block (first and last training sample, full shortened
# transcripts). This value is replaced by the later `examplestr = "".join(...)`
# assignment in this cell before the system prompt is built, so it never
# reaches the model.
examplestr = \
f"""
EXAMPLE1START
Transcript:
{shorten_transcript(data['train'][0]['transcript'])}
Desired output:
{data['train'][0]['expected_orders']}
EXAMPLE1END

EXAMPLE2START
Transcript:
{shorten_transcript(data['train'][-1]['transcript'])}
Desired output:
{data['train'][-1]['expected_orders']}
EXAMPLE2END
""".replace('\'', '"')

def get_train_sample_from_idname(idname):
    """Return the training sample identified by ``idname``.

    A sample whose id equals ``idname`` is preferred. If there is none, the
    first sample whose id merely contains ``idname`` is returned, which keeps
    partial names working. Preferring the exact id stops a short name such as
    'primock57_1_2' from resolving to 'primock57_1_20'.

    Raises:
        IndexError: if no training sample id contains ``idname``.
    """
    train_samples = data['train']
    for sample in train_samples:
        if sample['id'] == idname:
            return sample
    for sample in train_samples:
        if idname in sample['id']:
            return sample
    raise IndexError(f'no training sample with an id containing {idname!r}')

# Few-shot examples: for every id in train_examples_predicted_turns, the trimmed
# ("salient") and shortened transcript paired with that sample's expected orders.
example_transcripts = [get_salient_transcript(get_train_sample_from_idname(idname)) for idname in train_examples_predicted_turns]
example_transcripts = [shorten_transcript(_) for _ in example_transcripts]
example_outputs = [get_train_sample_from_idname(idname)['expected_orders'] for idname in train_examples_predicted_turns]

# NOTE(review): values are interpolated via str() and every single quote is then
# swapped for a double quote, so text containing an apostrophe would come out
# with unbalanced quotes — consider json.dumps(); confirm before changing, since
# it alters the prompt. Each example has a START marker but no END marker.
examplestr = "".join([f"EXAMPLE{idx+1}START\nTranscript:\n{transcript}\nExpected output:\n{output}\n" for idx, (transcript, output) in enumerate(zip(example_transcripts, example_outputs))]).replace('\'', '"')

# System prompt: task description, numbered extraction rules, then the few-shot
# examples built above. Field names match the output of shorten_transcript.
system_prompt = f"""
You are an assistant for extracting medical orders from transcripts of patient doctor conversations.

Medical order extraction involves identifying and structuring various medical orders —such as medications, imaging studies, lab tests, and follow-ups— based on doctor-patient conversations. 

The conversation is given to you in the format of a list of dicts with turn, speaker (doctor or patient), and txt for each turn.

You are to return a list of dicts with these keys: order_type, description, reason, and provenance. 

1. Return a list with one dict for each order from the conversation. Your output will be parsed with json.loads().
2. If there is only a single order, still return a list.
3. There are only four allowed values for the order_type: "medication", "lab", "followup", "imaging"
4. Provenance should be a list of ints relating to the turn ids that ground the order. This is the turn where the order was made, and maybe include directly preceeding turns where the reason was mentioned.
5. Quote the reason verbatim from the text.
6. Only list *new* or repeat orders. Do not list things that the patient is to continue doing that are mentioned in passing. Do not list previous exams.
7. The above point (6.) is very important. Orders that have a "continue" status (e.g. "we will continue with xanax XXmg") are NOT considered valid orders, since we dont need to place a specific order in EHR for instance.
8. Lab orders are fine-grained, i.e. each test is one order.
9. If the doctor suggests an over the counter medication (e.g. pain killers), we also count that as an order - UNLESS the patient is already taking it (remember rule 6!)
10. Only count orders that are actually made, not ones that are conditional (if your pain persists, maybe we...).

Make sure your output is a list (i.e. starts and ends with square brackets) of dicts, separated by commas.

Examples:
{examplestr}
"""

#You are scored using these metrics: description_Rouge1_f1, reason_Rouge1_f1, order_type_Strict_f1 and provenance_MultiLabel_f1



In [None]:
# (report label, key in the metrics dict, metric-name prefix) for each scored
# field, in the order they are printed and stored.
METRIC_FIELDS = (
    ('Desc', 'description', 'Rouge1'),
    ('Reas', 'reason', 'Rouge1'),
    ('OTyp', 'order_type', 'Strict'),
    ('Prov', 'provenance', 'MultiLabel'),
)

# Dev samples, minus any whose id is also used as a few-shot example.
data_to_use = [_ for _ in data['dev'] if _['id'] not in train_examples_predicted_turns]#data['train'][:]
all_preds = {}           # sample id -> parsed list of predicted orders
all_metrics = {}         # sample id -> metrics dict returned by evaluate_sample
metrics_per_sample = {}  # sample id -> [4 F1 scores, seconds, #preds, #targets, #tokens]
for idx, curr_sample in enumerate(tqdm(data_to_use)):
    curr_transcript = shorten_transcript(curr_sample['transcript'])

    user_prompt = f"Please process this transcript.\nTRANSCRIPTSTART\n{curr_transcript}\nTRANSCRIPTEND\nPlease follow all the rules and instructions carefully. Reply only with the valid response in the desired format."
    # Hand-written ChatML prompt; the assistant turn is pre-filled with an empty
    # <think></think> block and an opening '[' so the model continues a list.
    # NOTE(review): `\\nothink` inserts the literal text "\nothink" into the user
    # turn. Qwen3 documents its soft switch as "/no_think" — confirm intent
    # before changing, since it alters the prompt.
    prompt=f"""
    <|im_start|>system{system_prompt}<|im_end|>
    <|im_start|>user\n{user_prompt} \\nothink<|im_end|>
    <|im_start|>assistant\n<think>\n</think>\n["""
    start = time.time()
    response = llm.create_completion(
        prompt=prompt,
        max_tokens=int(4096*0.5),
        temperature=0., top_k=1, top_p=1.0,
        stop=[]
    )
    curr_time = time.time()-start

    # Re-attach the '[' that was part of the prompt, not of the completion.
    pred = '['+response['choices'][0]['text']
    if "</think>" in pred:
        # The model produced a reasoning trace anyway: show it, keep the rest.
        thought_trace = pred.split('</think>')[0]
        print(thought_trace)
        pred = pred.split('</think>')[-1]

    # Parse the answer. Only JSON decoding errors are treated as "bad output";
    # anything else (including KeyboardInterrupt) propagates.
    try:
        pred = json.loads(pred)
    except json.JSONDecodeError:
        # Repair attempts, most likely first. With the '[' already in place the
        # realistic damage is a missing closing bracket; wrapping in a second
        # pair of brackets is kept for answers that lost the opening one.
        repairs = (
            (pred + ']', 'a closing square bracket'),
            ('[' + pred + ']', 'square brackets'),
        )
        for repaired, what in repairs:
            try:
                pred = json.loads(repaired)
            except json.JSONDecodeError:
                continue
            print(f'{idx=} was fixed with adding {what}')
            break
        else:
            print(f'{idx} is not proper json.\n{pred}')
            pred = []
    if isinstance(pred, dict):
        print(f'{idx=} was a dict, we made it a list')
        pred = [pred]
    elif not isinstance(pred, list):
        # Valid JSON that is not a list (number, string, null, ...) cannot be a
        # list of orders and would break len(pred) below.
        print(f'{idx=} was not a list of orders, we dropped it')
        pred = []

    all_preds[curr_sample['id']] = pred

    # evaluate_sample is called with {id: orders} dicts for truth and prediction.
    t = {curr_sample['id']: curr_sample['expected_orders']}
    p = {curr_sample['id']: pred}
    metrics = evaluate_sample(t, p)
    # One "<label>F1:<f1> (<recall>/<precision>)" entry per scored field.
    metricstr = "  ".join(
        f"{label}F1:{metrics[field][kind + '_f1']:.3f} "
        f"({metrics[field][kind + '_recall']:.3f}/{metrics[field][kind + '_precision']:.3f})"
        for label, field, kind in METRIC_FIELDS
    )

    num_tokens = response['usage']['total_tokens']
    num_targets = len(curr_sample['expected_orders'])
    num_preds = len(pred)
    print(f'{idx=}: {curr_time:.4f}s #targets={num_targets} #preds={num_preds} #tokens={num_tokens} {curr_sample["id"]}')

    print(metricstr)
    all_metrics[curr_sample['id']] = metrics
    metrics_per_sample[curr_sample['id']] = [
        *(metrics[field][kind + '_f1'] for _, field, kind in METRIC_FIELDS),
        curr_time,
        num_preds,
        num_targets,
        num_tokens,
    ]
    # print()
    # print()

print('Done')
# File stem derived from the model file name: directory and '.gguf' removed,
# 'Qwen' abbreviated to 'qw'.
model_file = mname.rsplit('/', 1)[-1]
mname_short = model_file.partition('.gguf')[0].replace('Qwen','qw')
savestr = f'dev_ALL_{mname_short}_shortenedexs'
# Output names used for earlier runs:
# # 'dev_ALL_qw14BUDQ4KXL_v5dot6_ex_specdec'
# with open('data/dev_ALL_qw8BQ5KM_v5dot7noex.json', 'w') as f:
#     json.dump(all_preds, f)
with open(f'data/{savestr}.json', 'w') as f:
    json.dump(all_preds, f)

import pandas as pd
# One row per sample; column order mirrors the lists stored in metrics_per_sample.
score_columns = ['descF1', 'reasonF1', 'ordertypeF1', 'provF1']
df = pd.DataFrame(metrics_per_sample).T
df.columns = score_columns + ['processingtime', 'len_pred', 'len_target', 'total_tokens']
# Per-sample average over the four F1 scores only.
df['avg'] = df[score_columns].mean(axis=1)
column_means = df.mean(axis=0)
print(column_means.round(3))
df.round(3)

In [None]:
# Show the file stem under which the evaluation cell saved its predictions.
savestr