# Llama3 | Zero-shot Prompting

In [None]:
import os
import torch
import json
import re
import time
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): CUDA_VISIBLE_DEVICES is set after `import torch`; this only
# takes effect if CUDA has not been initialised yet — confirm when running.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = "cuda" # the device to load the model onto

df = pd.read_csv('test_set_url')

# Tokenizer and weights are loaded from one shared path so the two cannot
# drift apart (a mismatched path would load a different or missing checkpoint).
MODEL_PATH = "Llama-3-70B-Instruct-ft-url"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16)


def llm_generation_zero_shot(ins_text):
    """Ask the local Llama model to classify one report, zero-shot.

    The chat has three turns:
    1. A user turn with the classification instruction.
    2. A canned assistant acknowledgement.
    3. The report text as the final user turn.

    Args:
        ins_text: The report text to classify.

    Returns:
        The decoded completion, with prompt tokens removed and special tokens
        skipped. Sampling is enabled (temperature 0.6, top-p 0.9), so repeated
        calls on the same text may return different answers.
    """
    instruction = "You are a medical assistant. Please classify the input report and respond in JSON format {'Classification': 'Negation error, Left/Right error, Interval Change error, Transcription error, or No error'}. Note that the input report may belong to multiple errors."
    conversation = [
        {"role": "user", "content": instruction},
        {"role": "assistant", "content": "Sure! Please give the input report."},
        {"role": "user", "content": ins_text},
    ]

    prompt_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    model.to(device)
    generated = model.generate(
        prompt_ids,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # generate() echoes the prompt; decode only the newly produced tokens.
    prompt_len = prompt_ids.shape[-1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)



# Error labels the parser searches for, in the order they are recorded.
_ERROR_LABELS = (
    'negation error',
    'left/right error',
    'interval change error',
    'transcription error',
)


def _findings_section(report):
    """Return *report* from its first case-insensitive "findings" onward.

    The whole report is returned when "findings" does not occur.
    """
    start_index = report.lower().find("findings")
    return report[max(start_index, 0):]


def _parse_error_labels(llm_output):
    """Return the error labels mentioned in *llm_output*, or ['no error'].

    Matching is a case-insensitive substring search. A reply such as
    "no negation error" therefore still counts as a negation error.
    """
    text = llm_output.lower()
    return [label for label in _ERROR_LABELS if label in text] or ['no error']


# `results` holds two predictions per row, interleaved:
#   results[2*i]   -> the clean report
#   results[2*i+1] -> the error-injected report
# The metric cells below index it that way.
results = []
for i, item in df.iterrows():
    print(i+1)

    llm_result_wo_error = llm_generation_zero_shot(
        _findings_section(item["Synthetic Report Processed"])
    ).lower().strip()
    print('llm_result_wo_error:', llm_result_wo_error)
    tmp_result = _parse_error_labels(llm_result_wo_error)
    results.append(tmp_result)
    print('------')
    print(tmp_result)
    print('------')

    llm_result_w_error = llm_generation_zero_shot(
        _findings_section(item["Synthetic Report With Errors Processed"])
    ).lower().strip()
    print('llm_result_w_error:', llm_result_w_error)
    tmp_result = _parse_error_labels(llm_result_w_error)
    results.append(tmp_result)
    print('++++++')
    print(tmp_result)
    print('++++++')


In [None]:
# Negation error
# Binary precision/recall/F1 for one label. Each gold row contributes two
# decisions:
#   - the clean report results[2*i]: scored as a negative, so predicting the
#     label there is a false positive;
#   - the error-injected report results[2*i+1]: scored against the
#     'Error Type(s) Processed' column.
label_gs = 'negation error'    # spelling searched for in the gold column
label_pred = 'negation error'  # spelling emitted by the parsing loop
tp = 0
fp = 0
fn = 0
tn = 0

for i, error_type_gs in enumerate(df['Error Type(s) Processed']):
    no_error_pre = results[2*i]
    error_type_pre = results[2*i+1]

    # Clean report: any hit for the label is a false positive.
    if label_pred in no_error_pre:
        fp += 1
    else:
        tn += 1

    # Error-injected report. Gold values are presumably strings read from the
    # CSV, so `in` is a substring test — confirm.
    in_gold = label_gs in error_type_gs
    in_pred = label_pred in error_type_pre
    if in_gold and in_pred:
        tp += 1
    elif in_gold:
        fn += 1
    elif in_pred:
        fp += 1
    else:
        tn += 1

# Empty denominators (no predicted / no gold positives) give 0.0 instead of
# raising ZeroDivisionError.
P = tp/(tp+fp) if tp+fp else 0.0
R = tp/(tp+fn) if tp+fn else 0.0
F1 = 2*P*R/(P+R) if P+R else 0.0
neg_F1 = round(F1, 3)

print("P:", P)
print("R:", R)
print("F1:", F1)
print("F1 (rounded):", neg_F1)

In [None]:
# Left/right error
# Binary precision/recall/F1 for one label. Each gold row contributes two
# decisions:
#   - the clean report results[2*i]: scored as a negative, so predicting the
#     label there is a false positive;
#   - the error-injected report results[2*i+1]: scored against the
#     'Error Type(s) Processed' column.
# NOTE(review): the gold column is matched as 'left / right error' (spaces
# around the slash), while predictions use 'left/right error'. Presumably this
# mirrors the gold column's spelling — confirm, otherwise recall is always 0.
label_gs = 'left / right error'  # spelling searched for in the gold column
label_pred = 'left/right error'  # spelling emitted by the parsing loop
tp = 0
fp = 0
fn = 0
tn = 0

for i, error_type_gs in enumerate(df['Error Type(s) Processed']):
    no_error_pre = results[2*i]
    error_type_pre = results[2*i+1]

    # Clean report: any hit for the label is a false positive.
    if label_pred in no_error_pre:
        fp += 1
    else:
        tn += 1

    # Error-injected report. Gold values are presumably strings read from the
    # CSV, so `in` is a substring test — confirm.
    in_gold = label_gs in error_type_gs
    in_pred = label_pred in error_type_pre
    if in_gold and in_pred:
        tp += 1
    elif in_gold:
        fn += 1
    elif in_pred:
        fp += 1
    else:
        tn += 1

# Empty denominators (no predicted / no gold positives) give 0.0 instead of
# raising ZeroDivisionError.
P = tp/(tp+fp) if tp+fp else 0.0
R = tp/(tp+fn) if tp+fn else 0.0
F1 = 2*P*R/(P+R) if P+R else 0.0
lr_F1 = round(F1, 3)

print("P:", P)
print("R:", R)
print("F1:", F1)
print("F1 (rounded):", lr_F1)

In [None]:
# Interval change error
# Binary precision/recall/F1 for one label. Each gold row contributes two
# decisions:
#   - the clean report results[2*i]: scored as a negative, so predicting the
#     label there is a false positive;
#   - the error-injected report results[2*i+1]: scored against the
#     'Error Type(s) Processed' column.
label_gs = 'interval change error'    # spelling searched for in the gold column
label_pred = 'interval change error'  # spelling emitted by the parsing loop
tp = 0
fp = 0
fn = 0
tn = 0

for i, error_type_gs in enumerate(df['Error Type(s) Processed']):
    no_error_pre = results[2*i]
    error_type_pre = results[2*i+1]

    # Clean report: any hit for the label is a false positive.
    if label_pred in no_error_pre:
        fp += 1
    else:
        tn += 1

    # Error-injected report. Gold values are presumably strings read from the
    # CSV, so `in` is a substring test — confirm.
    in_gold = label_gs in error_type_gs
    in_pred = label_pred in error_type_pre
    if in_gold and in_pred:
        tp += 1
    elif in_gold:
        fn += 1
    elif in_pred:
        fp += 1
    else:
        tn += 1

# Empty denominators (no predicted / no gold positives) give 0.0 instead of
# raising ZeroDivisionError.
P = tp/(tp+fp) if tp+fp else 0.0
R = tp/(tp+fn) if tp+fn else 0.0
F1 = 2*P*R/(P+R) if P+R else 0.0
int_F1 = round(F1, 3)

print("P:", P)
print("R:", R)
print("F1:", F1)
print("F1 (rounded):", int_F1)

In [None]:
# Transcription error
# Binary precision/recall/F1 for one label. Each gold row contributes two
# decisions:
#   - the clean report results[2*i]: scored as a negative, so predicting the
#     label there is a false positive;
#   - the error-injected report results[2*i+1]: scored against the
#     'Error Type(s) Processed' column.
label_gs = 'transcription error'    # spelling searched for in the gold column
label_pred = 'transcription error'  # spelling emitted by the parsing loop
tp = 0
fp = 0
fn = 0
tn = 0

for i, error_type_gs in enumerate(df['Error Type(s) Processed']):
    no_error_pre = results[2*i]
    error_type_pre = results[2*i+1]

    # Clean report: any hit for the label is a false positive.
    if label_pred in no_error_pre:
        fp += 1
    else:
        tn += 1

    # Error-injected report. Gold values are presumably strings read from the
    # CSV, so `in` is a substring test — confirm.
    in_gold = label_gs in error_type_gs
    in_pred = label_pred in error_type_pre
    if in_gold and in_pred:
        tp += 1
    elif in_gold:
        fn += 1
    elif in_pred:
        fp += 1
    else:
        tn += 1

# Empty denominators (no predicted / no gold positives) give 0.0 instead of
# raising ZeroDivisionError.
P = tp/(tp+fp) if tp+fp else 0.0
R = tp/(tp+fn) if tp+fn else 0.0
F1 = 2*P*R/(P+R) if P+R else 0.0
tran_F1 = round(F1, 3)

print("P:", P)
print("R:", R)
print("F1:", F1)
print("F1 (rounded):", tran_F1)

In [None]:
# Unweighted mean of the four per-label F1 scores. The inputs are the values
# already rounded to 3 decimals, not the raw F1s.
macro_F1 = sum((neg_F1, lr_F1, int_F1, tran_F1)) / 4
print(macro_F1)

# GPT4 | Zero-shot Prompting

In [None]:
import os
import json
import re
import time
import pandas as pd
from openai import AzureOpenAI

# Carried over from the Llama section; this section imports no torch and
# makes only remote API calls.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Placeholder credentials only. setdefault keeps a real key/endpoint that is
# already exported in the environment instead of overwriting it.
# Never commit a real key here.
os.environ.setdefault("AZURE_OPENAI_KEY", "AZURE_OPENAI_KEY_number")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_ENDPOINT_url")

df = pd.read_csv('test_set_url')


import functools


@functools.lru_cache(maxsize=4)
def _azure_client(api_base, api_key, api_version, deployment_name):
    """Build the Azure OpenAI client for one configuration and reuse it on later calls."""
    return AzureOpenAI(
        api_key=api_key,
        api_version=api_version,
        base_url=f"{api_base}/openai/deployments/{deployment_name}"
    )


def llm_zero_shot_generation(ins_text):
    """Ask the Azure GPT-4 deployment to classify one report, zero-shot.

    Endpoint and key are read from the environment on every call. The client
    is cached per (endpoint, key, version, deployment), so the loop over the
    test set reuses one client instead of constructing one per request.

    Args:
        ins_text: The report text to classify.

    Returns:
        The content of the first choice's message.
        NOTE(review): the SDK types this as optional (it can be None, e.g.
        when content is filtered); callers that call .lower() on it would
        then fail.
    """
    api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
    api_key = os.getenv("AZURE_OPENAI_KEY")
    deployment_name = 'gpt-4-1106-preview'
    api_version = '2023-12-01-preview'

    client = _azure_client(api_base, api_key, api_version, deployment_name)

    response = client.chat.completions.create(
        model=deployment_name,
        messages=[
            { "role": "system", "content": "Please classify the input report and respond in JSON format {'Classification': 'Negation error, Left/Right error, Interval Change error, Transcription error, or No error'}. Note that the input report may belong to multiple errors." },
            { "role": "user", "content": [
                {
                    "type": "text",
                    "text": ins_text
                },
            ] }
        ],
        max_tokens=1000
    )
    return response.choices[0].message.content


# Error labels the parser searches for, in the order they are recorded.
_ERROR_LABELS = (
    'negation error',
    'left/right error',
    'interval change error',
    'transcription error',
)


def _findings_section(report):
    """Return *report* from its first case-insensitive "findings" onward.

    The whole report is returned when "findings" does not occur.
    """
    start_index = report.lower().find("findings")
    return report[max(start_index, 0):]


def _parse_error_labels(llm_output):
    """Return the error labels mentioned in *llm_output*, or ['no error'].

    Matching is a case-insensitive substring search. A reply such as
    "no negation error" therefore still counts as a negation error.
    """
    text = llm_output.lower()
    return [label for label in _ERROR_LABELS if label in text] or ['no error']


# `results` holds two predictions per row, interleaved:
#   results[2*i]   -> the clean report
#   results[2*i+1] -> the error-injected report
# The metric cells below index it that way.
results = []

for i, item in df.iterrows():
    print(i+1)
    for column in ("Synthetic Report Processed", "Synthetic Report With Errors Processed"):
        reply = llm_zero_shot_generation(_findings_section(item[column]))
        results.append(_parse_error_labels(reply))


In [None]:
# Negation error
# Binary precision/recall/F1 for one label. Each gold row contributes two
# decisions:
#   - the clean report results[2*i]: scored as a negative, so predicting the
#     label there is a false positive;
#   - the error-injected report results[2*i+1]: scored against the
#     'Error Type(s) Processed' column.
label_gs = 'negation error'    # spelling searched for in the gold column
label_pred = 'negation error'  # spelling emitted by the parsing loop
tp = 0
fp = 0
fn = 0
tn = 0

for i, error_type_gs in enumerate(df['Error Type(s) Processed']):
    no_error_pre = results[2*i]
    error_type_pre = results[2*i+1]

    # Clean report: any hit for the label is a false positive.
    if label_pred in no_error_pre:
        fp += 1
    else:
        tn += 1

    # Error-injected report. Gold values are presumably strings read from the
    # CSV, so `in` is a substring test — confirm.
    in_gold = label_gs in error_type_gs
    in_pred = label_pred in error_type_pre
    if in_gold and in_pred:
        tp += 1
    elif in_gold:
        fn += 1
    elif in_pred:
        fp += 1
    else:
        tn += 1

# Empty denominators (no predicted / no gold positives) give 0.0 instead of
# raising ZeroDivisionError.
P = tp/(tp+fp) if tp+fp else 0.0
R = tp/(tp+fn) if tp+fn else 0.0
F1 = 2*P*R/(P+R) if P+R else 0.0
neg_F1 = round(F1, 3)

print("P:", P)
print("R:", R)
print("F1:", F1)
print("F1 (rounded):", neg_F1)

In [None]:
# Left/right error
# Binary precision/recall/F1 for one label. Each gold row contributes two
# decisions:
#   - the clean report results[2*i]: scored as a negative, so predicting the
#     label there is a false positive;
#   - the error-injected report results[2*i+1]: scored against the
#     'Error Type(s) Processed' column.
# NOTE(review): the gold column is matched as 'left / right error' (spaces
# around the slash), while predictions use 'left/right error'. Presumably this
# mirrors the gold column's spelling — confirm, otherwise recall is always 0.
label_gs = 'left / right error'  # spelling searched for in the gold column
label_pred = 'left/right error'  # spelling emitted by the parsing loop
tp = 0
fp = 0
fn = 0
tn = 0

for i, error_type_gs in enumerate(df['Error Type(s) Processed']):
    no_error_pre = results[2*i]
    error_type_pre = results[2*i+1]

    # Clean report: any hit for the label is a false positive.
    if label_pred in no_error_pre:
        fp += 1
    else:
        tn += 1

    # Error-injected report. Gold values are presumably strings read from the
    # CSV, so `in` is a substring test — confirm.
    in_gold = label_gs in error_type_gs
    in_pred = label_pred in error_type_pre
    if in_gold and in_pred:
        tp += 1
    elif in_gold:
        fn += 1
    elif in_pred:
        fp += 1
    else:
        tn += 1

# Empty denominators (no predicted / no gold positives) give 0.0 instead of
# raising ZeroDivisionError.
P = tp/(tp+fp) if tp+fp else 0.0
R = tp/(tp+fn) if tp+fn else 0.0
F1 = 2*P*R/(P+R) if P+R else 0.0
lr_F1 = round(F1, 3)

print("P:", P)
print("R:", R)
print("F1:", F1)
print("F1 (rounded):", lr_F1)

In [None]:
# Interval change error
# Binary precision/recall/F1 for one label. Each gold row contributes two
# decisions:
#   - the clean report results[2*i]: scored as a negative, so predicting the
#     label there is a false positive;
#   - the error-injected report results[2*i+1]: scored against the
#     'Error Type(s) Processed' column.
label_gs = 'interval change error'    # spelling searched for in the gold column
label_pred = 'interval change error'  # spelling emitted by the parsing loop
tp = 0
fp = 0
fn = 0
tn = 0

for i, error_type_gs in enumerate(df['Error Type(s) Processed']):
    no_error_pre = results[2*i]
    error_type_pre = results[2*i+1]

    # Clean report: any hit for the label is a false positive.
    if label_pred in no_error_pre:
        fp += 1
    else:
        tn += 1

    # Error-injected report. Gold values are presumably strings read from the
    # CSV, so `in` is a substring test — confirm.
    in_gold = label_gs in error_type_gs
    in_pred = label_pred in error_type_pre
    if in_gold and in_pred:
        tp += 1
    elif in_gold:
        fn += 1
    elif in_pred:
        fp += 1
    else:
        tn += 1

# Empty denominators (no predicted / no gold positives) give 0.0 instead of
# raising ZeroDivisionError.
P = tp/(tp+fp) if tp+fp else 0.0
R = tp/(tp+fn) if tp+fn else 0.0
F1 = 2*P*R/(P+R) if P+R else 0.0
int_F1 = round(F1, 3)

print("P:", P)
print("R:", R)
print("F1:", F1)
print("F1 (rounded):", int_F1)

In [None]:
# Transcription error
# Binary precision/recall/F1 for one label. Each gold row contributes two
# decisions:
#   - the clean report results[2*i]: scored as a negative, so predicting the
#     label there is a false positive;
#   - the error-injected report results[2*i+1]: scored against the
#     'Error Type(s) Processed' column.
label_gs = 'transcription error'    # spelling searched for in the gold column
label_pred = 'transcription error'  # spelling emitted by the parsing loop
tp = 0
fp = 0
fn = 0
tn = 0

for i, error_type_gs in enumerate(df['Error Type(s) Processed']):
    no_error_pre = results[2*i]
    error_type_pre = results[2*i+1]

    # Clean report: any hit for the label is a false positive.
    if label_pred in no_error_pre:
        fp += 1
    else:
        tn += 1

    # Error-injected report. Gold values are presumably strings read from the
    # CSV, so `in` is a substring test — confirm.
    in_gold = label_gs in error_type_gs
    in_pred = label_pred in error_type_pre
    if in_gold and in_pred:
        tp += 1
    elif in_gold:
        fn += 1
    elif in_pred:
        fp += 1
    else:
        tn += 1

# Empty denominators (no predicted / no gold positives) give 0.0 instead of
# raising ZeroDivisionError.
P = tp/(tp+fp) if tp+fp else 0.0
R = tp/(tp+fn) if tp+fn else 0.0
F1 = 2*P*R/(P+R) if P+R else 0.0
tran_F1 = round(F1, 3)

print("P:", P)
print("R:", R)
print("F1:", F1)
print("F1 (rounded):", tran_F1)

In [None]:
# Unweighted mean of the four per-label F1 scores. The inputs are the values
# already rounded to 3 decimals, not the raw F1s.
macro_F1 = sum((neg_F1, lr_F1, int_F1, tran_F1)) / 4
print(macro_F1)