In [1]:
import json

from program import Program
from lm_poller import LMPoller
from text_preprocessing import extract_triplets, extract_relations
from lm_response_evaluator import extract_code, get_response_similarity, evaluate_response

In [2]:
# Load the WebNLG training split (../res/webnlg/train.json) into memory.
with open('../res/webnlg/train.json', mode='r', encoding='utf-8') as train_file:
    data = json.loads(train_file.read())

# Peek at the first example's input, then report how many examples were loaded.
examples = data['data']
print(examples[0]['in'], len(examples), sep='\n')

Aarhus Airport | city served | Aarhus, Denmark
35426


In [3]:
# Group the examples by input text: each distinct 'in' string maps to the set
# of every 'out' verbalisation recorded for it.
data_dict = {}
for example in data['data']:
    # setdefault creates the empty set the first time an input is seen.
    data_dict.setdefault(example['in'], set()).add(example['out'])

In [4]:
# Distinct input strings, in the order they were first inserted into data_dict.
keys = list(data_dict)

# Show one grouped entry (input text followed by its set of outputs) as a sanity check.
example_id = 0
example_key = keys[example_id]
print(f'{example_key}:\n {data_dict[example_key]}')

Aarhus Airport | city served | Aarhus, Denmark:
 {'Aarhus Airport serves the city of Aarhus, Denmark.', 'The Aarhus is the airport of Aarhus, Denmark.'}


In [5]:
# Location handed to Program; presumably where the generated program is
# written — TODO confirm against the Program class.
program_output_path = '../res/output'

# Builder that accumulates the generated rules, and the client used to query the language model.
program_gen = Program(program_output_path)
lm = LMPoller()

In [6]:
from logging import getLogger

# Logger registered under the name 'lm_poller'; log_error below reports through it.
logger = getLogger('lm_poller')


def log_error(message):
    """Emit ``message`` at ERROR level through the 'lm_poller' logger.

    Args:
        message: Text (or any object accepted by ``Logger.error``) to record.

    Returns:
        None.
    """
    logger.error(message)

# Maximum number of repair queries sent to the language model per input.
MAX_LLM_FIX_QUERY = 5
# Backward-compatible alias for the original (misspelled) constant name.
MAX_LLM_FIX_QUYERY = MAX_LLM_FIX_QUERY
# Minimum similarity between the generated output and the reference text for
# a rule to be accepted.
MIN_SIMILARITY = 0.5


def _needs_fix(output, errors, reference_text):
    """Tell whether the last evaluation still requires a repair query.

    Args:
        output: Output produced by running the generated code.
        errors: Errors reported by ``evaluate_response``; ``None`` means no errors.
        reference_text: Reference sentence the output is compared against.

    Returns:
        True when errors were reported or the similarity between ``output``
        and ``reference_text`` is below ``MIN_SIMILARITY``.
    """
    # Short-circuit: similarity is only computed when no errors were reported.
    return errors is not None or get_response_similarity(output, reference_text) < MIN_SIMILARITY


# TODO: data_dict keys should be relations not input text
for i, key in enumerate(keys[20:22]):
    relations = extract_relations(key)
    # NOTE(review): the if-statement is added before a rule is known to exist;
    # when generation fails below no matching add_rule call follows — confirm
    # that Program tolerates this.
    program_gen.add_rule_if_stmt(relations)
    triplets = extract_triplets(key)
    # data_dict values are sets, whose iteration order is arbitrary; min() picks
    # the same reference on every run so results are reproducible.
    reference_text = min(data_dict[key])

    response = lm.query_lm(triplets, reference_text, relations)
    # Persist the raw response and the extracted code for later inspection.
    # Explicit UTF-8 matches how the training data is read and avoids
    # platform-dependent default encodings.
    with open(f'../res/lama_responses/response_{i}.txt', 'w', encoding='utf-8') as f:
        f.write(response)
    extracted_code = extract_code(response)
    with open(f'../res/lama_responses/code_{i}.py', 'w', encoding='utf-8') as f:
        f.write(extracted_code)

    output, errors = evaluate_response(triplets, extracted_code, reference_text)
    print(f'Output: {output}, Errors: {errors}')
    print(f'similarity: {get_response_similarity(output, reference_text)}')
    print(f'\noutput: {output}\n\nreference: {reference_text}')

    # Ask the model to repair its answer until it is acceptable or the query
    # budget is exhausted.
    fix_query_count = 0
    needs_fix = _needs_fix(output, errors, reference_text)
    while needs_fix and fix_query_count < MAX_LLM_FIX_QUERY:
        # One query per attempt; the earlier extra call discarded its result.
        response = lm.fix_query(output, errors, reference_text)
        extracted_code = extract_code(response)
        # Same argument list as the initial evaluation (triplets included).
        output, errors = evaluate_response(triplets, extracted_code, reference_text)
        fix_query_count += 1
        needs_fix = _needs_fix(output, errors, reference_text)

    # Decide on the final evaluation result rather than on the attempt count,
    # so a repair that succeeds on the last allowed attempt is still kept.
    if not needs_fix:
        program_gen.add_rule(extracted_code)
    else:
        log_error(f'Failed to generate rule for {key} after {MAX_LLM_FIX_QUERY} attempts. Skipping...')


program_gen.add_print_stmt()
program_gen.write_program()

Output: The runway length of Abilene Regional Airport is 1121.0., Errors: None
similarity: 1.0

output: The runway length of Abilene Regional Airport is 1121.0.

reference: The runway length of Abilene Regional Airport is 1121.0.
Output: The runway length of Abilene Regional Airport is 2194.0., Errors: None
similarity: 1.0

output: The runway length of Abilene Regional Airport is 2194.0.

reference: The runway length of Abilene Regional Airport is 2194.0.
