In [1]:
import os
# Redirect the TRANSFORMERS_CACHE and HF_HOME environment variables to a
# user-specific cache directory.
# NOTE(review): hard-coded absolute path; this only works on the original
# author's machine. Consider making it configurable.
os.environ['TRANSFORMERS_CACHE'] = '/cafe/u/dmendo/.cache'
os.environ['HF_HOME'] = '/cafe/u/dmendo/.cache'

In [2]:
# Set TOKENIZERS_PARALLELISM to 'false' for this process.
# NOTE(review): presumably meant to disable tokenizer-level parallelism in a
# library imported later -- confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [3]:
from utils import *
import openai
from tqdm import tqdm
openai.api_key = "INSERT_KEY"
import itertools
from matplotlib import pyplot as plt
import context_retriever
import re
import pandas as pd
import json
import numpy as np
import copy

import spot_utils



In [4]:
import ast

In [5]:
#duplicate from synthTL
def get_literal_node_id(node):
    """Return the literal identifier of ``node``: its raw ``assert_text``."""
    literal_id = node.assert_text
    return literal_id

def get_abstract_node_id(node,ret_var_map=False):
    """Return the abstract identifier of ``node``.

    The texts of the node's direct children (``node.dcmp_dict``) are collected
    into a ``{variable: text}`` mapping and handed, together with the node's
    own text, to ``get_abstract_node_id_from_dcmpdict``. ``ret_var_map`` is
    forwarded unchanged.
    """
    child_texts = {}
    for var_name, child_node in node.dcmp_dict.items():
        child_texts[var_name] = child_node.assert_text
    return get_abstract_node_id_from_dcmpdict(
        node.assert_text,
        child_texts,
        ret_var_map=ret_var_map,
    )

def remove_duplicate_fromlist_keeporder(seq):
    """Return the items of ``seq`` as a list with repeats dropped.

    Only the first occurrence of each item is kept, and the relative order of
    those first occurrences is preserved. Items must be hashable.
    """
    already_seen = set()
    unique_items = []
    for item in seq:
        if item in already_seen:
            continue
        already_seen.add(item)
        unique_items.append(item)
    return unique_items

def get_abstract_node_id_from_dcmpdict(assert_text,dcmp_dict,ret_var_map=False):
    """Abstract ``assert_text`` by replacing each sub-phrase with a placeholder.

    For every ``variable -> sub-phrase`` pair in ``dcmp_dict`` (in dict order),
    all occurrences of the sub-phrase in the running text are replaced by
    ``_<variable>-SYMBOL_``. The result is then canonicalised by
    ``get_abstract_node_id_from_abstext``, whose return value is returned.
    """
    abstracted = assert_text
    for var_name in dcmp_dict:
        placeholder = "_"+var_name+"-SYMBOL_"
        abstracted = abstracted.replace(dcmp_dict[var_name],placeholder)
    return get_abstract_node_id_from_abstext(abstracted,ret_var_map=ret_var_map)

def normalize_dcmp_dict(assert_text,dcmp_dict):
    """Re-key ``dcmp_dict`` with the canonical variable names.

    The variable map returned by ``get_abstract_node_id_from_dcmpdict`` relates
    ``_<old>_`` to ``_<new>_``; the surrounding underscores are stripped from
    both sides and each sub-phrase is stored under its new name. Variables
    that do not appear in the variable map are not carried over.
    """
    _, var_map = get_abstract_node_id_from_dcmpdict(assert_text,dcmp_dict,ret_var_map=True)
    return {
        new_var[1:-1]: dcmp_dict[prev_var[1:-1]]  # [1:-1] strips the underscores
        for prev_var, new_var in var_map.items()
    }

def get_abstract_node_id_from_abstext(abs_text,ret_var_map=False,input_mode=False):
    """Rename the placeholders in ``abs_text`` to ``_SYMBOL0_``, ``_SYMBOL1_``, ...

    Placeholders are found with a regular expression (the ``-SYMBOL_`` suffix
    form by default, the ``SYMBOL_`` suffix form when ``input_mode`` is true),
    de-duplicated in order of first appearance, and replaced one after the
    other by their numbered canonical name.

    Returns the rewritten text, or ``(text, var_map)`` when ``ret_var_map`` is
    true, where ``var_map`` maps each found placeholder (with ``-SYMBOL_``
    replaced by ``_``) to its canonical name.
    """
    if input_mode:
        placeholder_pattern = '_[a-zA-Z0-9_]*SYMBOL_'
    else:
        placeholder_pattern = '_[a-zA-Z0-9_]*-SYMBOL_'
    found_vars = remove_duplicate_fromlist_keeporder(re.findall(placeholder_pattern,abs_text))
    canonical_vars = ["_SYMBOL"+str(idx)+"_" for idx in range(len(found_vars))]

    node_id = abs_text
    for found, canonical in zip(found_vars,canonical_vars):
        node_id = node_id.replace(found,canonical)

    if ret_var_map:
        var_map = {
            found.replace("-SYMBOL_","_"): canonical
            for found, canonical in zip(found_vars,canonical_vars)
        }
        return node_id, var_map
    return node_id
        
def get_unique_node_id(node):
    """Return the ``;``-joined texts on the path from the root down to ``node``."""
    path_texts = []
    cursor = node
    # Walk upwards collecting texts, then flip to get root-first order.
    while cursor is not None:
        path_texts.append(cursor.assert_text)
        cursor = cursor.parent
    return ";".join(reversed(path_texts))

def create_decompose_dict(fname):
    """Build the decomposition cache from the spreadsheet ``fname``.

    Each row's 'Decomposition' cell is parsed as JSON. Entries whose value is
    empty or equal to the row's 'Natural language' text are dropped, the
    remainder is normalised with ``normalize_dcmp_dict`` and stored under the
    row's 'Natural language' text.
    """
    sheet = pd.read_excel(fname)
    result = {}
    for _, row in sheet.iterrows():
        raw_dcmp = json.loads(row['Decomposition'])
        kept = {
            var: phrase
            for var, phrase in raw_dcmp.items()
            if phrase != "" and phrase != row['Natural language']
        }
        result[row['Natural language']] = normalize_dcmp_dict(row['Natural language'],kept)
    return result

def create_translate_dict(fname):
    """Build the translation cache from the spreadsheet ``fname``.

    Every row contributes ``'Natural language' -> 'LTL'``. When the row also
    has 'Decomposed Natural language' and 'Template' fields, the decomposed
    text is canonicalised (input-mode placeholders), the same variable
    renaming is applied to the template, and the pair is stored as
    ``canonical decomposed text -> renamed template``.
    """
    sheet = pd.read_excel(fname)
    result = {}
    for _, row in sheet.iterrows():
        result[row['Natural language']] = row['LTL']
        has_template = 'Decomposed Natural language' in row and 'Template' in row
        if not has_template:
            continue
        assert "SYMBOL_" in row['Decomposed Natural language']
        abs_node_id, var_map = get_abstract_node_id_from_abstext(
            row['Decomposed Natural language'],
            ret_var_map=True,
            input_mode=True,
        )
        template = row['Template']
        for prev_var, new_var in var_map.items():
            template = template.replace(prev_var,new_var)
        result[abs_node_id] = template
        #translate_dict[row['Decomposed Natural language']] = row['Template']
    return result

class Node:
    """One node of the decomposition tree of a natural-language assertion.

    Attributes:
        assert_text: natural-language text covered by this node.
        translation: fully substituted translation, or None if not translated.
        translation_type: the ``t_type`` of the last ``translate`` call.
        template_translation: translation that still refers to children via
            ``_<var>_`` placeholders (None for 'regular' translations).
        parent: parent Node, or None for the root.
        dcmp_dict: mapping ``variable name -> child Node``.
        translate_fewshots / decompose_fewshots: few-shot lists; children
            created by ``set_dcmp_dict`` receive the same list objects.
    """

    def __init__(self,assert_text,parent=None,translate_fewshots=None,decompose_fewshots=None):
        """Create a node with no children and no translation.

        The few-shot arguments default to a fresh empty list per node. They
        previously defaulted to a shared mutable ``[]``, so appending to one
        node's list silently changed every other default-constructed node.
        """
        self.assert_text = assert_text
        self.translation = None
        self.translation_type = None
        self.template_translation = None
        self.parent = parent
        self.dcmp_dict = {}
        self.translate_fewshots = [] if translate_fewshots is None else translate_fewshots
        self.decompose_fewshots = [] if decompose_fewshots is None else decompose_fewshots

    def set_dcmp_dict(self,dcmp_str_dict,force_new=True):
        """Replace this node's children with nodes built from ``dcmp_str_dict``.

        ``dcmp_str_dict`` maps variable names to sub-phrases. It is normalised
        with ``normalize_dcmp_dict`` before the children are created. When
        ``force_new`` is false and the mapping equals the texts of the current
        children, nothing is changed.
        """
        if not force_new and dcmp_str_dict == dict((k,v.assert_text) for k,v in self.dcmp_dict.items()):
            return
        dcmp_str_dict = normalize_dcmp_dict(self.assert_text,dcmp_str_dict)
        self.dcmp_dict = {}
        for dcmp_var,dcmp_str in dcmp_str_dict.items():
            new_node = Node(assert_text=dcmp_str,
                            parent=self,
                            translate_fewshots=self.translate_fewshots,
                            decompose_fewshots=self.decompose_fewshots
                        )
            self.dcmp_dict[dcmp_var] = new_node

    def decompose(self,mode='LLM',**kwargs):
        """Decompose this node into children and return ``self``.

        ``mode`` is 'LLM' (calls ``decompose_LLM`` with ``kwargs``) or 'cache'
        (calls ``get_decompose_cache``). Any other mode raises AssertionError.
        """
        if mode == 'LLM':
            dcmp_str_dict = decompose_LLM(self,**kwargs)
        elif mode == 'cache':
            dcmp_str_dict = get_decompose_cache(self)
        else:
            assert False, "decomposition mode not found! " + mode
        self.set_dcmp_dict(dcmp_str_dict)
        return self

    def translate(self,mode='LLM',t_type='regular',**kwargs):
        """Translate this node and store the result on the node.

        ``mode`` selects the source: 'LLM' (``translate_LLM``), 'cache'
        (``get_translate_cache``) or 'NoRun' (reuse what is already stored).
        ``t_type`` is 'regular' or 'template'. For 'template', every
        ``_<var>_`` placeholder whose child already has a translation is
        replaced by that translation in parentheses.

        Returns ``(prompt, pred)``; both are None unless ``mode`` is 'LLM'.
        Raises AssertionError for an unknown mode or translation type.
        """
        self.translation_type = t_type
        cur_prompt = None
        pred = None
        if mode == 'LLM':
            cur_output,cur_prompt,pred = translate_LLM(self,t_type=t_type,**kwargs)
        elif mode == 'cache':
            cur_output = get_translate_cache(self,t_type=t_type)
        elif mode == 'NoRun' and t_type == 'template':
            cur_output = self.template_translation
        elif mode == 'NoRun' and t_type == 'regular':
            cur_output = self.translation
        else:
            assert False, "translation mode not found! " + mode

        if cur_output is None:
            cur_template_translation = None
            cur_translation = None
        elif t_type == 'template':
            cur_template_translation = cur_output
            cur_translation = cur_output
            for dcmp_var,dcmp_node in self.dcmp_dict.items():
                if dcmp_node.translation is not None:
                    cur_translation = cur_translation.replace("_"+dcmp_var+"_","("+dcmp_node.translation+")")
        elif t_type == 'regular':
            cur_template_translation = None
            cur_translation = cur_output
        else:
            # Previously fell through and failed with UnboundLocalError below.
            assert False, "translation type not found! " + str(t_type)
        self.translation = cur_translation
        self.template_translation = cur_template_translation
        return cur_prompt, pred

    def check(self,DUT_formula,ret_trace=False,ret_trace_formula=False):
        """Check this node's contribution against ``DUT_formula``.

        Returns ``(ok, trace)``. ``ok`` is False when the node has no
        translation, when the translation or the node's conjuncts are not
        well-formed, or when the containment check fails. ``trace`` is only
        non-None when the containment check fails and ``ret_trace`` is true,
        in which case it is the result of ``spot_utils.get_counter_example``.
        """
        if self.translation is None:
            return False, None
        elif not spot_utils.check_wellformed(self.translation):
            return False, None

        cur_conjuct = get_conjucts_for_node(self,debug=False)
        if not spot_utils.check_wellformed(cur_conjuct):
            #print("WARNING: this node is not used by the final translation!")
            return False, None

        if spot_utils.check_formula_contains_formula(cur_conjuct,DUT_formula,use_contains_split=True):
            return True, None
        else:
            if ret_trace:
                #print("ATTEMPTING TO GET COUNTER EXAMPLE")
                trace = spot_utils.get_counter_example(DUT_formula,cur_conjuct,ret_trace_formula=ret_trace_formula)
                #print("FINISHED")
                return False, trace
            else:
                return False, None

    def clear(self):
        """Forget the translation state and drop all children."""
        self.translation = None
        self.translation_type = None
        self.template_translation = None
        self.dcmp_dict = {}

def get_root(node):
    """Follow ``parent`` links from ``node`` and return the topmost node."""
    current = node
    while current.parent is not None:
        current = current.parent
    return current

def get_node_id(node):
    """Deprecated: always fails with ``AssertionError("deprecated")``.

    The statements after the assertion are the former implementation, which
    replaced each child's text in the node text by ``_<var>_``.
    """
    assert False, "deprecated"
    abstracted = node.assert_text
    for var_name in node.dcmp_dict:
        child_text = node.dcmp_dict[var_name].assert_text
        abstracted = abstracted.replace(child_text,"_"+var_name+"_")
    return abstracted

def get_node_assert_text(node):
    """Deprecated: always fails with ``AssertionError("deprecated")``."""
    assert False, "deprecated"
    text = node.assert_text
    return text

def get_node_translation(node):
    """Return the template translation for inner nodes, the plain one for leaves."""
    has_children = len(node.dcmp_dict) != 0
    return node.template_translation if has_children else node.translation

def find_descendant_by_id(node,targ_node_id,get_node_id_func):
    """Depth-first, pre-order search for the first node with the given id.

    ``get_node_id_func`` computes the id of a node. The search starts at
    ``node`` itself and visits children in ``dcmp_dict`` order. Returns the
    matching node, or None when there is no match.
    """
    pending = [node]
    while pending:
        candidate = pending.pop()
        if get_node_id_func(candidate) == targ_node_id:
            return candidate
        # Push children reversed so that the first child is visited next.
        pending.extend(reversed(list(candidate.dcmp_dict.values())))
    return None

def find_descendant(node,targ_node_id):
    """Depth-first, pre-order search for the node whose unique id matches.

    Ids are computed with ``get_unique_node_id``. The search includes ``node``
    itself and visits children in ``dcmp_dict`` order. Returns the matching
    node, or None when there is no match.
    """
    pending = [node]
    while pending:
        candidate = pending.pop()
        if get_unique_node_id(candidate) == targ_node_id:
            return candidate
        # Push children reversed so that the first child is visited next.
        pending.extend(reversed(list(candidate.dcmp_dict.values())))
    return None

def copy_graph(node):
    """Return a recursive copy of the subtree rooted at ``node``.

    New Node objects are created for the whole subtree. Translation fields and
    the few-shot lists are copied by reference. The copy of ``node`` itself
    has no parent; copied children point at their copied parent.
    """
    clone = Node(assert_text=node.assert_text)
    carried_fields = (
        "translation",
        "translation_type",
        "template_translation",
        "translate_fewshots",
        "decompose_fewshots",
    )
    for field in carried_fields:
        setattr(clone,field,getattr(node,field))
    for var_name, child in node.dcmp_dict.items():
        child_clone = copy_graph(child)
        clone.dcmp_dict[var_name] = child_clone
        child_clone.parent = clone
    return clone

def get_conjucts_for_node(node,verbose=False,ret_list=False,debug=True,omit_trivial=True,depth=None):
    """Return the conjuncts of the root translation that involve ``node``.

    The node is temporarily turned into a leaf whose translation is the
    placeholder ``specialAP``, the tree is re-assembled from stored
    translations, and the root translation is split into conjuncts with
    ``spot_utils.get_conjucts``. Conjuncts containing the placeholder are
    joined with ``" && "`` and the placeholder is replaced by the node's real
    translation in parentheses.

    The node's original state is restored in a ``finally`` block, so an
    exception while re-assembling or splitting no longer leaves the tree
    holding the placeholder.

    Args:
        node: node to isolate; it and the tree root must already be translated.
        verbose: unused.
        ret_list: when true, return a list of conjuncts instead of one string.
        debug: when true, assert that all conjuncts together are equivalent to
            the original root translation.
        omit_trivial: with ``ret_list``, drop conjuncts equivalent to ``"1"``.
        depth: forwarded to ``spot_utils.get_conjucts``.
    """
    tmp_dcmp_dict = node.dcmp_dict
    tmp_template_translation = node.template_translation
    tmp_translation = node.translation
    cur_root = get_root(node)
    all_translation = cur_root.translation

    node_identifier = "specialAP"
    assert node_identifier not in cur_root.translation, cur_root.translation
    node.dcmp_dict = {}
    node.template_translation = None
    node.translation = node_identifier
    try:
        dfs_translate(cur_root,mode='NoRun',t_type='template')
        all_conjucts = spot_utils.get_conjucts(cur_root.translation,depth=depth)
    finally:
        # Always undo the temporary rewrite and rebuild the ancestors' translations.
        node.dcmp_dict = tmp_dcmp_dict
        node.template_translation = tmp_template_translation
        node.translation = tmp_translation
        dfs_translate(cur_root,mode='NoRun',t_type='template')

    abstract_conjucts_list = [prop for prop in all_conjucts if node_identifier in prop]
    conjucts_for_node = " && ".join(abstract_conjucts_list)
    res = conjucts_for_node.replace(node_identifier,"("+tmp_translation+")")
    assert not debug or spot_utils.check_equivalent(" && ".join(all_conjucts).replace(node_identifier,"("+tmp_translation+")"),all_translation)
    if not ret_list:
        return res
    else:
        conjucts_for_node_list = [clause for clause in spot_utils.get_conjucts(res) if not omit_trivial or not spot_utils.check_equivalent("1",clause)]
        return conjucts_for_node_list

def get_disjuncts_for_node(node,verbose=False,ret_list=False):
    """Return the disjuncts of the root translation that involve ``node``.

    The node is temporarily turned into a leaf whose translation is the
    placeholder ``specialAP``, the tree is re-assembled from stored
    translations, and the root translation is split with
    ``spot_utils.get_disjuncts``. Disjuncts containing the placeholder are
    joined with ``" | "`` and the placeholder is replaced by the node's real
    translation in parentheses.

    The node's original state is restored in a ``finally`` block, so an
    exception while re-assembling or splitting no longer leaves the tree
    holding the placeholder.

    Args:
        node: node to isolate; it and the tree root must already be translated.
        verbose: unused.
        ret_list: when true, return a list of disjuncts (dropping those
            equivalent to ``"0"``), or ``[]`` when no disjunct involves the node.
    """
    tmp_dcmp_dict = node.dcmp_dict
    tmp_template_translation = node.template_translation
    tmp_translation = node.translation
    cur_root = get_root(node)
    all_translation = cur_root.translation

    node_identifier = "specialAP"
    assert node_identifier not in cur_root.translation, cur_root.translation
    node.dcmp_dict = {}
    node.template_translation = None
    node.translation = node_identifier
    try:
        dfs_translate(cur_root,mode='NoRun',t_type='template')
        all_disjuncts = spot_utils.get_disjuncts(cur_root.translation)
    finally:
        # Always undo the temporary rewrite and rebuild the ancestors' translations.
        node.dcmp_dict = tmp_dcmp_dict
        node.template_translation = tmp_template_translation
        node.translation = tmp_translation
        dfs_translate(cur_root,mode='NoRun',t_type='template')

    abstract_disjuncts_list = [prop for prop in all_disjuncts if node_identifier in prop]
    disjuncts_for_node = " | ".join(abstract_disjuncts_list)
    res = disjuncts_for_node.replace(node_identifier,"("+tmp_translation+")")

    try:
        is_equal = spot_utils.check_equivalent(" | ".join(all_disjuncts).replace(node_identifier,"("+tmp_translation+")"),all_translation)
    except Exception:
        # Best effort: a failing sanity check must not abort the caller, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        is_equal = True
        print("WARNING: exception thrown when verifying get_disjuncts_for_node")
    assert is_equal
    if not ret_list:
        return res
    else:
        if len(abstract_disjuncts_list) > 0:
            disjuncts_for_node_list = [clause for clause in spot_utils.get_disjuncts(res) if not spot_utils.check_equivalent("0",clause)]
            return disjuncts_for_node_list
        else:
            return []

def get_all_ancestors(node,inclusive=True):
    """List the ancestors of ``node`` from nearest to the root.

    ``node`` itself is the first element when ``inclusive`` is true.
    """
    lineage = [node] if inclusive else []
    ancestor = node.parent
    while ancestor is not None:
        lineage.append(ancestor)
        ancestor = ancestor.parent
    return lineage

def get_all_descendants(node,inclusive=True):
    """List the subtree below ``node`` in depth-first pre-order.

    ``node`` itself is the first element when ``inclusive`` is true; all
    deeper nodes are always included.
    """
    collected = [node] if inclusive else []
    for child in node.dcmp_dict.values():
        collected.extend(get_all_descendants(child,inclusive=True))
    return collected

def dfs_decompose(node,mode='LLM',max_depth=None,**kwargs):
    """Decompose ``node`` and, recursively, the children this produces.

    ``max_depth`` limits how many levels are decomposed: 0 leaves ``node``
    untouched, None means no limit. ``mode`` and ``kwargs`` are passed to
    ``node.decompose``. Returns ``node``.
    """
    if max_depth is not None and max_depth == 0:
        return node
    node.decompose(mode,**kwargs)
    child_budget = None if max_depth is None else max_depth-1
    for child in node.dcmp_dict.values():
        dfs_decompose(child,mode=mode,max_depth=child_budget,**kwargs)
    return node

def dfs_translate(node,mode='LLM',t_type='regular',**kwargs):
    """Translate the subtree rooted at ``node`` bottom-up.

    Children are translated before their parent. Nodes without children are
    always translated with ``t_type='regular'``; nodes with children use the
    requested ``t_type``.
    """
    for child in node.dcmp_dict.values():
        dfs_translate(child,mode=mode,t_type=t_type,**kwargs)
    is_leaf = len(node.dcmp_dict.items()) == 0
    effective_type = 'regular' if is_leaf else t_type
    node.translate(mode,t_type=effective_type,**kwargs)

def get_translate_cache(node,t_type='regular'):
    """Look up the cached translation of ``node`` in the global ``translate_dict``.

    For ``t_type='regular'`` the key is the node's text; a missing key prints
    a warning and yields None. For ``t_type='template'`` the key is the node's
    abstract id and a missing key raises AssertionError. Any other ``t_type``
    raises AssertionError.
    """
    if t_type == 'regular':
        lookup_key = node.assert_text
    elif t_type == 'template':
        lookup_key = get_abstract_node_id(node)
        assert lookup_key in translate_dict, lookup_key
    else:
        assert False
    if lookup_key not in translate_dict:
        print("WARNING:",lookup_key,"not in cached translations!")
        return None
    return translate_dict[lookup_key]

def get_decompose_cache(node):
    """Look up the cached decomposition of ``node`` in the global ``decompose_dict``.

    Entries whose sub-phrase is empty or identical to the node's own text are
    filtered out. A node that is not cached prints a warning and yields ``{}``.
    """
    if node.assert_text not in decompose_dict:
        print("WARNING:",node.assert_text,"not in cached decompositions!")
        return {}
    cached = decompose_dict[node.assert_text]
    return {
        var: phrase
        for var, phrase in cached.items()
        if phrase != "" and phrase != node.assert_text
    }

def get_max_depth(node):
    """Return the number of levels in the subtree rooted at ``node`` (a leaf is 1)."""
    child_depths = [get_max_depth(child) for child in node.dcmp_dict.values()]
    return 1 + max(child_depths, default=0)

def get_nodes_by_depth(cur_graph,depth,include_leaves=False):
    """Collect the nodes exactly ``depth`` levels below ``cur_graph``.

    With ``include_leaves`` true, childless nodes reached before ``depth``
    levels are exhausted are collected as well. Results follow ``dcmp_dict``
    order.
    """
    stop_here = depth == 0 or (include_leaves and len(cur_graph.dcmp_dict) == 0)
    if stop_here:
        return [cur_graph]
    return [
        found
        for child in cur_graph.dcmp_dict.values()
        for found in get_nodes_by_depth(child,depth-1,include_leaves=include_leaves)
    ]

In [6]:
import Levenshtein
import spot
def helper_normalize_symbols(formula):
    """Shorten multi-character operators to one character and drop whitespace.

    ``&&`` -> ``&``, ``||`` -> ``|``, ``<->`` -> ``=``, ``->`` -> ``>``, applied
    in that order (``<->`` has to be handled before ``->``).
    """
    replacements = (
        ("&&", "&"),
        ("||", "|"),
        ("<->", "="),
        ("->", ">"),
    )
    for long_form, short_form in replacements:
        formula = formula.replace(long_form, short_form)
    return "".join(formula.split())
    
def get_levenshtein_distance_nonformed(formula_a,formula_b):
    """Levenshtein distance between two formulas that may not be well-formed.

    Both formulas are normalised through ``spot.formula`` when that succeeds
    for both; otherwise the raw strings are used. Every ``_..._`` placeholder
    found in either input is then replaced by a single ASCII letter (the same
    letter in both formulas), operators are shortened with
    ``helper_normalize_symbols``, and the edit distance of the two resulting
    strings is returned.

    Placeholders are processed longest-first (ties broken alphabetically) and
    letters are assigned in alphabet order. Both previously depended on set
    iteration order, so the result could differ between runs.

    Raises AssertionError when there are more placeholders than ASCII letters.
    """
    import string  # local import: the notebook does not import `string` explicitly

    try:
        f_a = str(spot.formula(formula_a))
        f_b = str(spot.formula(formula_b))
    except Exception:
        # Not parseable: compare the raw strings instead.
        f_a = formula_a
        f_b = formula_b

    placeholder_pattern = '_[a-zA-Z0-9_]*_'
    var_set = set(re.findall(placeholder_pattern,formula_a))
    var_set.update(re.findall(placeholder_pattern,formula_b))

    repl_list = list(string.ascii_lowercase+string.ascii_uppercase)
    assert len(var_set) <= len(repl_list)

    ordered_vars = sorted(var_set, key=lambda var: (-len(var), var))
    for letter, var in zip(repl_list, ordered_vars):
        f_a = f_a.replace(var,letter)
        f_b = f_b.replace(var,letter)
    f_a = helper_normalize_symbols(f_a)
    f_b = helper_normalize_symbols(f_b)
    return Levenshtein.distance(f_a,f_b)

In [7]:
def read_nl2spec_fewshots(fname):
    """Load few-shot examples from the spreadsheet ``fname``.

    Each row's "sub-translations" cell is parsed with ``ast.literal_eval``.
    The last element of every sub-translation and the row's "LTL" cell must be
    well-formed according to ``spot_utils.check_wellformed``, otherwise
    AssertionError is raised. The returned records are copies of the rows in
    which "sub-translations" is a list of ``(phrase, formula)`` pairs with the
    formula normalised through ``spot.formula``.
    """
    records = pd.read_excel(fname).to_dict('records')
    fewshots = []
    for record in records:
        parsed_subs = ast.literal_eval(record["sub-translations"])
        for sub in parsed_subs:
            #print(subtranslation)
            assert spot_utils.check_wellformed(sub[-1])
        assert spot_utils.check_wellformed(record["LTL"])
        normalized = record.copy()
        normalized["sub-translations"] = [
            (sub[0],str(spot.formula(sub[1]))) for sub in parsed_subs
        ]
        fewshots.append(normalized)
    return fewshots
# Few-shot examples that are appended to the nl2spec prompt in a later cell.
decompose_fewshots = read_nl2spec_fewshots('nl2spec_fewshots.xlsx')

In [8]:
# Fixed instruction text that opens every nl2spec prompt; the few-shot
# examples are appended below.
nl2spec_prefix_prompt = \
"""
You are a Linear Temporal Logic (LTL) expert. Your answers always need to follow the following output format. 
Decompose the following natural language sentences into phrases that can be independently translated to an LTL formula.
Remember that X means "next", U means "until", G means "globally", F means "finally", which means GF means "infinitely often".
The formula should only contain atomic propositions or operators ||, &&, !, ->, <->, X, U, G, F.

The following explain each field of the translation:
Natural language: the natural language phrase to be decomposed into sub-translations and translated to LTL. The Natural language field must contain all of the input natural language text for the current translation.
sub-translations: a list of tuples which map sub-phrases of the Natural language and to the phrase's corresponding formalization in LTL.
LTL: the LTL formula which captures the intended formalization of the entire Natural language. The sub-translations should be used to compose the full LTL formula. Note that the LTL formula should be well-formed.
"""

# Append each few-shot record using its Python dict repr (str(ex)).
# NOTE(review): re-running this cell without re-running the assignment above
# is not possible (same cell), but the examples are serialised with str(), not
# json.dumps, so they use single-quoted Python syntax rather than JSON.
nl2spec_prefix_prompt += "\nAccount for the following example translations:\n"
for ex in decompose_fewshots:
    nl2spec_prefix_prompt += "\n" +  str(ex) + "\n"

In [74]:
# Natural-language specification under test. The commented lines are
# alternative inputs; switch the matching spreadsheet in the next cell as well.
#test_str = open("specs/amba_master_godhal.txt").read()
#test_str = open("specs/amba_slave_godhal.txt").read()
# Use a context manager so the file handle is closed after reading.
with open("specs/amba_arbiter_godhal.txt") as spec_file:
    test_str = spec_file.read()
#test_str = "if the memory is empty and a read transfer is attempted, then the slave shall send an ERROR response."
#test_str = 'G1 When the slave is not selected by the decoder, HREADY signal shall be high.\nG2 When the slave is not selected by the decoder, HRESP shall be OKAY.\nG3 When no AHB transaction is taking place, HRESP shall be OKAY.\nG4 RD and WR signal cannot be high simultaneously.\nG5 If memory is full and write transfer is attempted, then the slave shall send an ERROR response. Similarly, if the memory is empty and a read transfer is attempted, then the slave shall send an ERROR response.\nG6 When slave is involved in a transfer, HWRITE is used to decide values of WR and RD.\nG7 When slave is involved in any transfer, signal HADDR is used to decide ADDR.\nG8 When slave is involved in write transfer, signal HWDATA is used to decide DI.\nG9 When slave is involved in read transfer, signal DO is used to decide HRDATA.'

In [75]:
# Build the module-level caches read by get_decompose_cache and
# get_translate_cache. Both are built from the same spreadsheet; the
# commented pairs select the master / slave spreadsheets instead.
decompose_dict = create_decompose_dict('rawcontext_decomposition-arbiter_v3.xlsx')
translate_dict = create_translate_dict('rawcontext_decomposition-arbiter_v3.xlsx')
#decompose_dict = create_decompose_dict('rawcontext_decomposition-master.xlsx')
#translate_dict = create_translate_dict('rawcontext_decomposition-master.xlsx')
#decompose_dict = create_decompose_dict('rawcontext_decomposition-slave.xlsx')
#translate_dict = create_translate_dict('rawcontext_decomposition-slave.xlsx')

In [76]:
%%time
# Reference formula: the cached translation of the whole specification text.
formula_DUT = translate_dict[test_str]
# Build the decomposition tree and translate it using only the caches.
cur_graph = Node(test_str)
dfs_decompose(cur_graph,mode='cache')
dfs_translate(cur_graph,mode='cache',t_type='template')
#assert spot_utils.check_equivalent(cur_graph.translation,formula_DUT)

CPU times: user 3.32 ms, sys: 0 ns, total: 3.32 ms
Wall time: 3.19 ms


In [77]:
# Number of levels in the decomposition tree (a single node counts as 1).
get_max_depth(cur_graph)

6

In [78]:
# Total number of nodes in the tree, root included.
len(get_all_descendants(cur_graph))

63

In [79]:
# Number of nodes returned for the deepest level when childless nodes on
# shallower levels are included as well.
len(get_nodes_by_depth(cur_graph,depth=get_max_depth(cur_graph)-1,include_leaves=True))

41

In [15]:
# Variables of the reference formula; query_nl2spec reads this global and
# lists it in the prompt as the allowed variables.
cur_var_list = spot_utils.get_variables(formula_DUT)

In [16]:
def query_nl2spec(nl_phrase,cur_subtranslations,prefix_prompt,num_try=1):
    """Ask the LLM for an LTL translation of ``nl_phrase`` and its sub-translations.

    The prompt is ``prefix_prompt`` followed by the allowed variables (read
    from the global ``cur_var_list``), an instruction line, and the dict repr
    of the query (including ``cur_subtranslations`` when non-empty). The model
    is queried ``num_try`` times with the same prompt.

    Returns ``(sub_translations, translation)``: the 2-element sub-translation
    entries of all parseable answers as tuples, concatenated, and the "LTL"
    field of the first answer that provides a non-empty one (``""`` if none).

    Answers that are not a JSON object are skipped. Previously a failed parse
    either raised NameError (first try) or silently reused the previous try's
    parsed answer, adding its sub-translations a second time.
    """
    total_new_subtranslations = []
    new_translation = ""

    cur_prompt = prefix_prompt
    cur_prompt += "\nThe LTL translation may only contain a subset of the following variables: "+str(cur_var_list)
    cur_prompt += "\nProvide the LTL formula and sub-translations for the following natural language phrase (the translation must only use LTL operators and the format must be in JSON as shown in previous examples):\n"

    if len(cur_subtranslations) > 0:
        cur_obj = {"Natural language":nl_phrase,"sub-translations":list(cur_subtranslations)}
    else:
        cur_obj = {"Natural language":nl_phrase}
    cur_prompt += "\n" + str(cur_obj)
    for _ in range(num_try):
        response = get_inference_response(cur_prompt,
                                          model="gpt-4-0125-preview"
                                          #model="gpt-3.5-turbo-0125"
                                         )
        pred = response["choices"][0]["message"]["content"]

        #print(cur_prompt)
        #print(pred)
        try:
            json_pred = json.loads(pred)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; TypeError covers non-text content.
            json_pred = None
        if not isinstance(json_pred, dict):
            # Nothing usable in this answer; never fall back to an earlier one.
            continue

        try:
            new_subtranslations = [tuple(entry) for entry in json_pred["sub-translations"] if len(entry) == 2]
        except (KeyError, TypeError):
            new_subtranslations = []
        total_new_subtranslations += new_subtranslations

        if new_translation == "" and "LTL" in json_pred:
            new_translation = json_pred["LTL"]
    return total_new_subtranslations, new_translation

In [17]:
def query_oracle_new_subtranslations(subtranslations,correct_subtranslations,bad_subtranslations):
    """Ask the user to judge every sub-translation that has not been judged yet.

    Entries already contained in ``correct_subtranslations`` or
    ``bad_subtranslations`` are skipped. For the others the phrase and its
    translation are printed and the user is prompted until the answer contains
    "y" or "n" ("y" wins when both occur). Accepted entries are added to
    ``correct_subtranslations``, rejected ones to ``bad_subtranslations``;
    both sets are modified in place.

    Returns the list of newly accepted entries, in input order.

    KeyboardInterrupt and EOFError now propagate. The former bare ``except``
    swallowed them, which made the prompt loop impossible to interrupt and
    endless when stdin was closed.
    """
    new_good_subtranslations = []
    for entry in subtranslations:
        if entry in correct_subtranslations or entry in bad_subtranslations:
            continue
        print("sub-translation found:")
        print(entry[0])
        print()
        print("translation:")
        print(entry[1])
        print()

        response = ""
        while "y" not in response and "n" not in response:
            try:
                response = input("is this a correct translation? y/n")
            except EOFError:
                # stdin is closed; retrying can never succeed.
                raise
            except Exception:
                response = ""
                print("failed to get input")
        if "y" in response:
            correct_subtranslations.add(entry)
            new_good_subtranslations.append(entry)
        else:
            # The loop only ends on "y" or "n", so this is the "n" case.
            bad_subtranslations.add(entry)
    return new_good_subtranslations

In [20]:
# Main experiment loop. Each iteration hands the LLM the cached translations
# of a smaller top part of the decomposition tree as given sub-translations
# and records which of them were not already known (the "oracle" entries).
# The loop stops early once the LLM's full translation is well-formed and
# equivalent to formula_DUT.
num_try = 3

# Entries that had to be supplied from the cache because they were not yet in
# all_translation_set, and every (phrase, translation) pair seen so far.
oracle_translation_set = set()
all_translation_set = set()

# NOTE(review): correct_subtranslations is never filled in this cell, and
# bad_subtranslations / prev_subtranslations are assigned but not read here.
correct_subtranslations = set()
bad_subtranslations = set()
cur_subtranslations = []
# NOTE(review): the loop variable `iter` shadows the builtin of the same name
# for the rest of the notebook session.
for iter in tqdm(range(get_max_depth(cur_graph))):
    prev_subtranslations = cur_subtranslations.copy()
    
    # Nodes on the first (max_depth - iter) levels; childless nodes on
    # shallower levels are included by include_leaves=True.
    abs_node_set = set()
    for i in range(get_max_depth(cur_graph)-iter):
        abs_node_set.update(get_nodes_by_depth(cur_graph,depth=i,include_leaves=True))
    cur_best_subtranslations = []
    for node in abs_node_set:
        # (abstract node id, spot-normalised translation of the node)
        cur_best_subtranslations.append((get_abstract_node_id(node),str(spot.formula(get_node_translation(node)))))
        for dcmp_node in node.dcmp_dict.values():
            # Children outside the selected set contribute their literal text
            # and full translation.
            if dcmp_node not in abs_node_set:
                cur_best_subtranslations.append( (dcmp_node.assert_text,str(spot.formula(dcmp_node.translation))) )
    if iter > 0:
        # From the second iteration on, count entries not seen before as oracle help.
        for entry in cur_best_subtranslations:
            if entry not in all_translation_set:
                oracle_translation_set.add(entry)
    #print(len(abs_node_set))
    #print(len(cur_best_subtranslations))
    print(iter,len(oracle_translation_set))
    cur_subtranslations = list(correct_subtranslations) + cur_best_subtranslations
    all_translation_set.update(set(cur_subtranslations))
    new_subtranslations, new_translation = query_nl2spec(test_str,cur_subtranslations,nl2spec_prefix_prompt,num_try=num_try)
    all_translation_set.update(set(new_subtranslations))
    all_translation_set.add((test_str,new_translation))
    if spot_utils.check_wellformed(new_translation) and spot_utils.check_equivalent(new_translation,formula_DUT,use_contains_split=True):
        break

  0%|                                                     | 0/5 [00:00<?, ?it/s]

0 0


 20%|████████▊                                   | 1/5 [06:52<27:29, 412.26s/it]

1 0


 40%|█████████████████▌                          | 2/5 [12:43<18:48, 376.09s/it]

2 2
OpenAI API returned an API Error: HTTP code 502 from API (<html>
<head><title>502 Bad Gateway</title></head>
<body>
<center><h1>502 Bad Gateway</h1></center>
<hr><center>cloudflare</center>
</body>
</html>
)


 60%|██████████████████████████▍                 | 3/5 [20:02<13:29, 404.84s/it]

3 15


 80%|███████████████████████████████████▏        | 4/5 [29:39<07:52, 472.82s/it]

4 17


100%|████████████████████████████████████████████| 5/5 [35:39<00:00, 427.88s/it]


In [21]:
# If the last LLM translation is not well-formed or not equivalent to the
# reference, record the reference translation of the whole text in both sets.
# NOTE(review): this check calls check_equivalent without
# use_contains_split=True, unlike the loop's break condition -- confirm that
# the difference is intended.
if not (spot_utils.check_wellformed(new_translation) and spot_utils.check_equivalent(new_translation,formula_DUT)):
    oracle_translation_set.add((test_str,formula_DUT))
    all_translation_set.add((test_str,formula_DUT))
    print("LLM fail!")

LLM fail!


In [97]:
# Edit distance between the constant formula "0" and the spot-normalised
# reference formula.
get_levenshtein_distance_nonformed("0",str(spot.formula(formula_DUT)))

490

In [92]:
# Reload the saved results of this experiment; the save call is kept
# commented out.
# NOTE(review): load_nl2spec_data / save_nl2spec_data are defined in a later
# cell of this notebook, so this cell fails on a fresh top-to-bottom run.
# After loading, both names hold lists of lists rather than sets of tuples.
exp_name = "godhalarbiter_gpt4"
#save_nl2spec_data(exp_name,all_translation_set,oracle_translation_set)
all_translation_set,oracle_translation_set = load_nl2spec_data(exp_name)

In [93]:
# Number of oracle-provided translations.
len(oracle_translation_set)

23

In [94]:
# Number of all recorded translations.
len(all_translation_set)

94

In [95]:
# Edit distance from the constant formula "0" to the translation (last
# element) of every recorded entry, then sum / mean / std / max.
res_list = []
for entry in all_translation_set:
    res_list.append(get_levenshtein_distance_nonformed("0",entry[-1]))
print(np.sum(res_list))
print(np.mean(res_list))
print(np.std(res_list))
print(np.max(res_list))

4254
45.255319148936174
110.97651046938704
490


In [96]:
# Same statistics as the previous cell, restricted to the oracle entries.
# NOTE(review): duplicates the previous cell except for the input collection;
# a shared helper would remove the copy.
res_list = []
for entry in oracle_translation_set:
    res_list.append(get_levenshtein_distance_nonformed("0",entry[-1]))
print(np.sum(res_list))
print(np.mean(res_list))
print(np.std(res_list))
print(np.max(res_list))

1374
59.73913043478261
113.34833474978008
490


In [27]:
def save_nl2spec_data(exp_name,all_translation_set,all_oracle_translation_set):
    """Write both translation collections to JSON files named after ``exp_name``.

    ``all_translation_set`` goes to ``<exp_name>_nl2spec_alltranslation.json``
    and ``all_oracle_translation_set`` to
    ``<exp_name>_nl2spec_oracletranslation.json``. Each collection is written
    as a JSON list, so tuples are stored as lists.

    The oracle file used to be written from the global
    ``oracle_translation_set`` instead of the ``all_oracle_translation_set``
    argument.
    """
    with open(exp_name+'_nl2spec_alltranslation.json', 'w') as f:
        json.dump(list(all_translation_set), f)
    with open(exp_name+'_nl2spec_oracletranslation.json', 'w') as f:
        json.dump(list(all_oracle_translation_set), f)
def load_nl2spec_data(exp_name):
    """Read back the two JSON files written by ``save_nl2spec_data``.

    Returns ``(all_translations, oracle_translations)`` exactly as parsed from
    JSON, i.e. lists of lists rather than sets of tuples. Both files are now
    closed after reading.
    """
    with open(exp_name+'_nl2spec_alltranslation.json', 'r') as f:
        all_translation_set = json.load(f)
    with open(exp_name+'_nl2spec_oracletranslation.json', 'r') as f:
        oracle_translation_set = json.load(f)
    return all_translation_set,oracle_translation_set