In [1]:
import json
import os

from program import Program
from lm_poller import LMPoller
from text_preprocessing import extract_triplets, extract_relations
from lm_response_evaluator import extract_code, get_response_similarity, evaluate_response

In [2]:
# Load the WebNLG training split. Later cells iterate over data['data'] and read
# each sample's 'in' and 'out' fields.
train_path = '../res/webnlg/train.json'
with open(train_path, mode='r', encoding='utf-8') as train_file:
    data = json.load(train_file)

In [3]:
# Bucket every training sample under the tuple of relations found in its input,
# preserving first-seen order of both the signatures and the samples within each.
# NOTE(review): the key is an ordered tuple, so the same relations in a different
# order form a separate bucket, while later code treats a key as set(key) — confirm intended.
data_dict = {}
for entry in data['data']:
    signature = tuple(extract_relations(entry['in']))
    data_dict.setdefault(signature, []).append(entry)

In [4]:
# Indexable list of relation signatures, in data_dict insertion order.
keys = [*data_dict]
# Position of the signature to inspect while debugging (default: the last one),
# e.g. data_dict[keys[example_id]].
example_id = -1

In [5]:
from logging import getLogger

# Logger named 'lm_poller' (presumably shared with the lm_poller module — confirm).
logger = getLogger('lm_poller')

# Maximum number of LM fix round-trips per relation signature before it is skipped.
# The spelling of this name is kept because later cells reference it.
MAX_LLM_FIX_QUYERY = 5


def log_error(message):
    """Report `message` at ERROR level through the 'lm_poller' logger; returns None."""
    return logger.error(message)

In [6]:
# Range of relation-signature keys (indices into `keys`) to generate rules for in this run.
FIRST_KEY, LAST_KEY = 20, 22
# Minimum similarity between the generated output and the reference text for a rule to be accepted.
SIMILARITY_THRESHOLD = 0.5
# Where raw LM responses and the first extracted code for each key are saved.
RESPONSES_DIR = '../res/lama_responses'


def _write_text(path, text):
    """Write `text` to `path` as UTF-8 (LM output may contain non-ASCII characters)."""
    with open(path, 'w', encoding='utf-8') as fh:
        fh.write(text)


def _is_acceptable(output, errors, reference_text):
    """Return True when the code ran without errors and its output is similar enough to the reference."""
    # Short-circuits so similarity is not computed on a failed run.
    return errors is None and get_response_similarity(output, reference_text) >= SIMILARITY_THRESHOLD


program_gen = Program('../out/program')
lm = LMPoller()
os.makedirs(RESPONSES_DIR, exist_ok=True)

for i, key in enumerate(keys[FIRST_KEY:LAST_KEY]):
    relations = set(key)
    # Use the first sample of this relation signature as the generation target.
    sample = data_dict[key][0]
    sample_in = sample['in']
    reference_text = sample['out']
    triplets = extract_triplets(sample_in)

    response = lm.query_lm(triplets, reference_text, relations)
    _write_text(f'{RESPONSES_DIR}/response_{i}.txt', response)
    extracted_code = extract_code(response, relations)
    _write_text(f'{RESPONSES_DIR}/code_{i}.py', extracted_code)

    output, errors = evaluate_response(triplets, extracted_code, reference_text, relations)
    print(f'Errors: {errors}')
    print(f'similarity: {get_response_similarity(output, reference_text)}')
    print(f'output: {output}\nreference: {reference_text}\n\n')

    fix_query_count = 0
    while not _is_acceptable(output, errors, reference_text) and fix_query_count < MAX_LLM_FIX_QUYERY:
        response = lm.fix_query(output, errors)
        extracted_code = extract_code(response, relations)
        # Pass `relations` exactly as the initial evaluation does.
        output, errors = evaluate_response(triplets, extracted_code, reference_text, relations)
        fix_query_count += 1

    # Decide on the final result, not the attempt counter: a fix that succeeds
    # on the last allowed attempt is still a success.
    if _is_acceptable(output, errors, reference_text):
        # NOTE(review): the guard is emitted only when a rule follows it, so a skipped
        # key does not leave an empty `if` in the generated program — assumes
        # add_rule_if_stmt/add_rule are meant to be paired; confirm against Program.
        program_gen.add_rule_if_stmt(set(key))
        program_gen.add_rule(extracted_code)
    else:
        log_error(f'Failed to generate rule for {key} after {MAX_LLM_FIX_QUYERY} attempts. Skipping...')

program_gen.add_print_stmt()
program_gen.write_program()

Errors: None
similarity: 1.0
output: Alpena County Regional Airport owner is Alpena County, Michigan.
reference: Alpena County Regional Airport owner is Alpena County, Michigan.


Errors: None
similarity: 0.7532467532467533
output: The 1st runway length metres of the Amsterdam Airport Schiphol is 3800 metres.
reference: The length of the first runway at Amsterdam Airport Schiphol is 3800 metres.


