In [1]:
import concurrent.futures
import json
import os
import time

import tiktoken
from openai import APITimeoutError, AzureOpenAI
from requests.exceptions import ReadTimeout

In [2]:
def extract_infos(json_file_path) -> dict:
    """Load the settings stored in a JSON file.

    Args:
        json_file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; relying on the platform default
    # (e.g. cp1252 on Windows) can mis-decode or reject non-ASCII content.
    with open(json_file_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def init_client(infos: dict):
    """Create an Azure OpenAI client from connection settings.

    Args:
        infos: Mapping providing the ``azure_endpoint``, ``api_key`` and
            ``api_version`` entries.

    Returns:
        A new ``AzureOpenAI`` client built from those three values.
    """
    return AzureOpenAI(
        azure_endpoint=infos['azure_endpoint'],
        api_key=infos['api_key'],
        api_version=infos['api_version'],
    )

In [3]:
def num_tokens_from_string(string: str) -> int:
    """Count the tokens of ``string`` using tiktoken's encoding for "gpt-4"."""
    gpt4_encoding = tiktoken.encoding_for_model("gpt-4")
    return len(gpt4_encoding.encode(string))

def message(role, content) -> dict:
    """Build one chat message entry holding its role and its content."""
    return dict(role=role, content=content)

def read_file(absolute_path):
    """Return the full text content of a file.

    Args:
        absolute_path: Path of the text file to read.

    Returns:
        The file content as a single string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform default encoding (e.g. cp1252 on Windows).
    with open(absolute_path, encoding='utf-8') as file:
        return file.read()

def load_tables_from_json(json_file):
    """Load the extracted tables stored in a JSON file.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default
    # encoding, which may fail on non-ASCII table content.
    with open(json_file, 'r', encoding='utf-8') as file:
        return json.load(file)

def check_processed_tables(json_file_path: str, tables_directory_path: str):
    """Flag tables that already have an answer file as processed.

    Files in ``tables_directory_path`` named ``<article_id>_<table_index>.txt``
    are matched against the JSON data; for every match, ``processed`` is set
    to ``True`` on the corresponding table entry. The JSON file is then
    rewritten in place (indented by 4 spaces).

    Args:
        json_file_path: JSON file mapping article ids to lists of table entries.
        tables_directory_path: Directory scanned for ``.txt`` files.
    """
    data = load_tables_from_json(json_file_path)

    for file_name in os.listdir(tables_directory_path):
        stem, extension = os.path.splitext(file_name)
        if extension != '.txt':
            continue

        # Split on the LAST underscore so that article ids which themselves
        # contain underscores are still recognised.
        article_id, separator, index_text = stem.rpartition('_')
        if not separator:
            continue

        try:
            table_index = int(index_text)
        except ValueError:
            # Unrelated file such as "notes_final.txt": ignore it instead of
            # aborting the whole scan.
            continue

        article = data.get(article_id)
        if article is not None and 0 <= table_index < len(article):
            article[table_index]['processed'] = True

    with open(json_file_path, 'w', encoding='utf-8') as json_file:
        json.dump(data, json_file, indent=4)

def build_messages(file_name, messages_file_path, html_table, output_prompts_folder):
    """Assemble the chat prompt for one table and store a copy on disk.

    Five message texts are read from the files listed in
    ``messages_file_path`` (keys ``system_1``, ``user_1``, ``assistant``,
    ``user_2`` and ``system_2``); ``html_table`` is appended on a new line to
    the ``user_2`` text. The resulting message list is dumped as JSON to
    ``<output_prompts_folder>/<file_name>.txt``.

    Returns:
        A ``(messages, input_tokens)`` tuple: the list of message dicts and
        the token count of the five message texts concatenated together.
    """
    first_system = read_file(messages_file_path['system_1'])
    first_user = read_file(messages_file_path['user_1'])
    assistant_turn = read_file(messages_file_path['assistant'])
    second_user = read_file(messages_file_path['user_2']) + '\n' + html_table
    second_system = read_file(messages_file_path['system_2'])

    conversation = (
        ("system", first_system),
        ("user", first_user),
        ("assistant", assistant_turn),
        ("user", second_user),
        ("system", second_system),
    )
    messages_dict = [message(role, text) for role, text in conversation]

    # Keep a copy of the exact prompt so the request can be replicated.
    prompt_path = os.path.join(output_prompts_folder, file_name + '.txt')
    with open(prompt_path, "w") as text_file:
        text_file.write(json.dumps(messages_dict))
    print(f"\t Saved prompt at: {prompt_path}")

    # Token count of all message texts, in conversation order.
    input_tokens = num_tokens_from_string("".join(text for _, text in conversation))

    return messages_dict, input_tokens

In [4]:
def send_request(client, prompt: dict, max_tokens = 6000):
    """Send a streamed chat completion request and collect the answer.

    The raw server-sent-event lines of the response are parsed by hand; the
    answer is echoed to stdout one completed line at a time while streaming.

    Args:
        client: Client exposing ``chat.completions.with_streaming_response``.
        prompt: The chat messages to send (passed as ``messages``).
        max_tokens: Upper bound on generated tokens. This argument used to be
            ignored in favour of a hard-coded 6000; it is now honoured and its
            default is 6000 so that existing calls send the same value.

    Returns:
        A tuple ``(answer, output_tokens, request_time, stream)`` where
        ``answer`` is the concatenated content, ``output_tokens`` is the
        number of non-empty content chunks received (an approximation of the
        token count), ``request_time`` is the elapsed time in seconds and
        ``stream`` is the raw response text, one received line per line.
    """
    sse_prefix = 'data: '
    start_time = time.perf_counter()

    answer_parts = []
    pending_line = ''
    output_tokens = 0
    raw_lines = []

    with client.chat.completions.with_streaming_response.create(
        model="gpt-4-32k",  # model = "deployment_name".
        max_tokens=max_tokens,
        temperature=0,
        stream=True,
        messages=prompt,
    ) as response:
        for line in response.iter_lines():
            raw_lines.append(line)

            if not line:
                continue

            # Strip only the leading event prefix: a blanket replace() would
            # also delete "data: " occurring inside the generated text.
            if line.startswith(sse_prefix):
                line = line[len(sse_prefix):]
            if line == '[DONE]':
                break

            choices = json.loads(line)['choices']
            if not choices or choices[0] is None:
                continue

            # A delta may be empty or lack 'content' (e.g. a role-only chunk).
            delta = choices[0]['delta']
            current_token = delta.get('content') if delta else None
            if current_token is None:
                continue

            if current_token:
                output_tokens += 1
            answer_parts.append(current_token)
            pending_line += current_token
            if '\n' in current_token:
                print(pending_line)
                pending_line = ''

    # Echo whatever is left when the answer does not end with a newline.
    if pending_line:
        print(pending_line)

    request_time = time.perf_counter() - start_time
    answer = ''.join(answer_parts)
    stream = ''.join(raw_line + '\n' for raw_line in raw_lines)

    return answer, output_tokens, request_time, stream

def save_answer_and_stats(answer, input_tokens, output_tokens, request_time, stream, file_name, output_answers_folder):
    """Write the answer to ``<output_answers_folder>/<file_name>.txt``.

    Characters outside ASCII are dropped from the answer before writing.

    ``input_tokens``, ``output_tokens``, ``request_time`` and ``stream`` are
    accepted but currently unused: no statistics are written by this function.
    """
    answer_path = os.path.join(output_answers_folder, file_name + '.txt')
    ascii_answer = answer.encode('ascii', 'ignore').decode()

    with open(answer_path, "w") as text_file:
        text_file.write(ascii_answer)
    print(f"\t Saved answer at: {answer_path}")

In [5]:
def extract_claims(client, article_table, file_name, messages_file_paths, output_prompts_folder, output_answers_folder):
    """Build the prompt for one table, query the model and save the answer.

    The request is attempted at most twice; a further attempt is made only
    after a timeout. When every attempt times out nothing is saved and the
    function returns normally.

    Args:
        client: Client handed to ``send_request``.
        article_table: Table entry; its ``'table'`` text is sent to the model
            after dropping non-ASCII characters.
        file_name: Base name (without extension) for the prompt/answer files.
        messages_file_paths: Paths of the prompt message files.
        output_prompts_folder: Folder receiving the saved prompt.
        output_answers_folder: Folder receiving the saved answer.
    """
    table_html = article_table['table'].encode('ascii', 'ignore').decode()

    prompt, input_tokens = build_messages(file_name, messages_file_paths, table_html, output_prompts_folder)
    print(f"Sending request for: [{file_name}]")

    max_attempts = 2
    for attempt in range(1, max_attempts + 1):
        try:
            answer, output_tokens, request_time, stream = send_request(client, prompt)
            break
        # The openai v1 client reports request timeouts as APITimeoutError;
        # requests' ReadTimeout is kept for backward compatibility.
        except (ReadTimeout, APITimeoutError) as error:
            print(f"{type(error).__name__} occurred for [{file_name}] (attempt {attempt}/{max_attempts}).")
    else:
        print(f"All retry attempts failed for [{file_name}]; no answer was saved.")
        return

    save_answer_and_stats(answer, input_tokens, output_tokens, request_time, stream, file_name, output_answers_folder)


def run(connection_data: dict, messages_file_paths: dict, articles_tables: dict, output_prompts_folder, output_answers_folder, num_threads):
    """Extract claims from every unprocessed table using a thread pool.

    One client per worker thread is created and the clients are assigned to
    the submitted tables in round-robin order. Tables whose ``processed``
    flag is truthy are skipped; a missing flag counts as not processed.

    Args:
        connection_data: Settings handed to ``init_client``.
        messages_file_paths: Paths of the prompt message files.
        articles_tables: Mapping of article id to its list of table entries.
        output_prompts_folder: Folder receiving the saved prompts.
        output_answers_folder: Folder receiving the saved answers.
        num_threads: Number of worker threads (and of clients).
    """
    clients = [init_client(connection_data) for _ in range(num_threads)]

    # future -> file name, so that failures can be attributed to a table.
    submitted = {}
    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
            for article_id, article_tables in articles_tables.items():
                for index, article_table in enumerate(article_tables):
                    if article_table.get('processed'):
                        continue

                    file_name = f"{article_id}_{index}"
                    future = executor.submit(
                        extract_claims,
                        clients[len(submitted) % num_threads],
                        article_table,
                        file_name,
                        messages_file_paths,
                        output_prompts_folder,
                        output_answers_folder
                    )
                    submitted[future] = file_name

            # An exception raised inside a worker stays hidden in its future
            # unless it is inspected, so report every failure explicitly.
            for future in concurrent.futures.as_completed(submitted):
                error = future.exception()
                if error is not None:
                    print(f"Request for [{submitted[future]}] failed: {error!r}")
    finally:
        # Release the clients even if submitting the work raised.
        for client in clients:
            client.close()

    return

In [6]:
def check_path(path):
    """Make sure the directory ``path`` exists, creating parents as needed.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of testing existence first.
    os.makedirs(path, exist_ok=True)

# Connection settings are read from a local JSON file.
connection_infos = extract_infos('private.json')

# Prompt message files: one text file per message, all in the same folder.
msgs_base_path = 'messages/CS'
msgs_file_paths = {
    part: f'{msgs_base_path}/{part}.txt'
    for part in ('system_1', 'system_2', 'user_1', 'user_2', 'assistant')
}

# Output folders of this experiment; created when they do not exist yet.
answers_folder = 'experiments/answers/test/13'
prompts_folder = 'experiments/prompts/test/13'
check_path(answers_folder)
check_path(prompts_folder)

In [None]:
# output_stats_folder = 'stats/ER/'
# stats_file = 'stats_1.xlsx'

tables_file_path = 'experiments/extracted_tables/extraction_test.json'
NUM_THREADS = 5  # number of worker threads (one client each)

# Refresh the 'processed' flags on disk BEFORE loading the tables: loading
# first leaves a stale in-memory copy, so tables that already have an answer
# file would be sent to the model again.
check_processed_tables(tables_file_path, answers_folder)
tables = load_tables_from_json(tables_file_path)
run(connection_infos, msgs_file_paths, tables, prompts_folder, answers_folder, NUM_THREADS)

	 Saved prompt at: experiments/prompts/test/12\2301.04770_0.txt
	 Saved prompt at: experiments/prompts/test/12\2301.04770_1.txt
	 Saved prompt at: experiments/prompts/test/12\2307.01231_2.txt
	 Saved prompt at: experiments/prompts/test/12\2307.01231_5.txt
	 Saved prompt at: experiments/prompts/test/12\2301.02962_1.txt
Sending request for: [2301.04770_1]
Sending request for: [2301.04770_0]Sending request for: [2307.01231_5]Sending request for: [2307.01231_2]


Sending request for: [2301.02962_1]
<{<Models, RoBERTa>, <Dataset, DBLP>}, F1 Score, 95.81>

<{<Models, RoBERTa>, <Dataset, iTunes>}, F1 Score, 72.73>

<{<Data Type, Structured>, <Dataset, Amazon-Google>, <Domain, software>, <Size, 11,460>, <# Positive, 1,167>, <# Attr., 3>}>

<{<Data set, RLdata>, <EPregime, PY>}, Precision, 0.896>

<{<Data set, RLdata>, <EPregime, PY>}, Recall, 0.961>

<{<Models, RoBERTa>, <Dataset, Amazon>}, F1 Score, 61.76>

<{<Models, RoBERTa>, <Dataset, Google>}, F1 Score, 73.12>

<{<Models, RoBERTa>, <Datas