In [1]:
from copy import deepcopy
from utils import find_reverse, random_choose, parse_response, strip_end
import numpy as np
import json
import os

In [2]:
# Ground-truth API names whose arguments eval_pred_file verifies by
# case-insensitive string comparison; any API not listed here is judged by the
# chat model instead.
# NOTE(review): entries look like "<tool>%%<api>" with a few bare names
# (e.g. 'get_today_date') -- presumably they must match the keys of the
# results JSON exactly; confirm against the dataset.
string_match_APIs = [
    'Global Email V4%%Global Email V4',
    'BART%%Advisory information',
    'JAK_API%%Ben 10',
    'colegiosantaana%%Disciplina-1',
    'MikuAPI%%getRandomImage',
    'Handball Data%%Daily Match List-Scheduled',
    'Token API%%generate',
    'Numbers Translator%%Numbers Translator',
    'Password Generator API%%Password of length 50',
    'Places%%Geographic coordinates by placename',
    'pizzaallapala%%Get Producto Promo',
    'Password Generator API%%Base',
    'JAK_API%%Brawl Stars',
    'siteDomain%%language list',
    'Fluximmo%%get_portail_api',
    'NumbersToLetters%%Convertir cantidad a letra Moneda MXN Ingles',
    'thailand%%thai4',
    'colegiosantaana%%Mensajes-1',
    'Soccer Data%%Tournament List',
    'Marvel Vs Capcom 2%%All Characters',
    'JAK_API%%Miraculous',
    'NumbersToLetters%%Convertir cantidad a letra Moneda MXN Español',
    'F1 Latest News%%GET recent F1 news from all sources',
    'get_today_date',
    'Numbers%%Get math fact',
    'forecast_weather',
    'Car database%%Makes',
    'colegiosantaana%%Evaluaciones-1',
    'JAK_API%%Genshin Impact',
    'Football Dolphin%%Head to head statistics',
    'add_date',
]

In [3]:
from my_llm import chat_my, visualize_messages, get_chat_completion_my
# Model name that eval_pred_file hands to chat_my (model=...) when it asks the
# chat model to judge whether predicted API arguments are correct.
model_ckpts = 'gpt-3.5-turbo-16k-0613'

In [4]:
# Prompt sent to the judge model; the four slots are
# (gt API name, gt arguments, predicted API name, predicted arguments).
_JUDGE_PROMPT = (
    "Your task is to judge whether an API call is correct with respect to the given ground truth API call. Note that the API call doesn't have to be exactly the"
    " same as the ground truth; it only needs to be semantically correct. It should not miss any important details in the arguments.\n\n"
    "The ground truth API call is:\nAPI name: {}\nAPI arguments: {}\n\n"
    "The API call that you need to verify the correctness is:\nAPI name: {}\nAPI arguments: {}\n\n"
    "Now say your judgment. Your response should always start with \"Yes.\" or \"No.\" indicating whether it's correct.\nYour response:"
)


def _string_args_match(gt_dict, model_dict):
    """Return True when every ground-truth argument is reproduced by the model.

    Values are compared as stripped, lower-cased strings. Keys that only the
    model supplied are ignored. If the model's arguments decoded to something
    other than a JSON object (e.g. a bare number or a list), no ground-truth
    key can be matched, so the result is False instead of a TypeError.
    """
    model_is_object = isinstance(model_dict, dict)
    for key, val in gt_dict.items():
        if not model_is_object or key not in model_dict:
            return False
        if str(model_dict[key]).strip().lower() != str(val).strip().lower():
            return False
    return True


def _judge_args_with_chat(gt_api, gt_dict, model_dict, visualize):
    """Ask the chat model whether the predicted arguments are semantically correct; return 0 or 1."""
    messages = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]
    jud = chat_my(messages, _JUDGE_PROMPT.format(gt_api, json.dumps(gt_dict), gt_api, json.dumps(model_dict)),
                  temp=0.0, stop="Observation:", visualize=visualize, max_tokens=256, model=model_ckpts)[-1]['content']
    # NOTE(review): the verdict is "correct" unless the reply contains "No."
    # anywhere, so a reply that starts with neither "Yes." nor "No." also
    # counts as correct. Kept as-is so previously reported numbers stay comparable.
    return int("No." not in jud)


def eval_pred_file(file_name, key_output='model_output', is_parsed=False, visualize=False):
    """Annotate every prediction in a results file with evaluation flags and save it back.

    The file is a JSON object mapping each ground-truth API name to a list of
    example dicts. Each example is updated in place and the whole dataset is
    written back to ``file_name`` once all examples have been processed.

    Flags written per example:
        no_call:      1 if the parsed response has ``finish`` set, else 0.
        err:          1 if the response failed to parse or its ``action_input``
                      could not be decoded as JSON, else 0.
        api_match:    (only when err == 0 and no_call == 0) 1 if the predicted
                      action equals the ground-truth API name, else 0.
        args_correct: (only when api_match == 1) 1 if the arguments are judged
                      correct -- by string comparison for APIs listed in
                      ``string_match_APIs``, otherwise by the chat model.

    Args:
        file_name: path of the JSON results file (read, then overwritten).
        key_output: example key holding the raw model output; used only when
            ``is_parsed`` is False.
        is_parsed: when True, read the parsed response from
            ``item['parsed_result']`` instead of calling ``parse_response``.
        visualize: forwarded to ``chat_my``.
    """
    with open(file_name, "r", encoding='utf-8') as f:
        dataset = json.load(f)

    for gt_api, examples in dataset.items():
        # Each item is mutated in place, so no write-back into the list is needed.
        for item in examples:
            item['no_call'] = 0
            if is_parsed:
                parsed = item['parsed_result']
            else:
                res = item[key_output].strip()
                parsed = parse_response(res, API_name_list=list(dataset.keys()), api_descriptions="XXX", proc_toolken=True, ground_API=True)

            if parsed['finish']:
                # The model answered without calling any API.
                item['err'] = 0
                item['no_call'] = 1
                continue

            if not parsed['parse_successful']:
                item['err'] = 1
                continue

            try:
                model_dict = json.loads(parsed['action_input'])
            except Exception:
                # Any failure to decode the arguments counts as a malformed
                # call. `Exception` (not a bare except) lets KeyboardInterrupt
                # and SystemExit stop a long evaluation run.
                item['err'] = 1
                continue

            item['err'] = 0

            if parsed['action'] != gt_api:
                item['api_match'] = 0
                continue

            item['api_match'] = 1
            gt_dict = json.loads(item['action_input'])

            # Check semantic correctness of the arguments for this API call.
            if gt_api in string_match_APIs:
                item['args_correct'] = int(_string_args_match(gt_dict, model_dict))
            else:
                item['args_correct'] = _judge_args_with_chat(gt_api, gt_dict, model_dict, visualize)

    with open(file_name, "w", encoding='utf-8') as f:
        json.dump(dataset, f)
        
        
def eval_batch(file_name, key_list=None):
    """Print aggregate metrics for a results set already annotated by eval_pred_file.

    Args:
        file_name: path (``str`` or ``os.PathLike``) of the annotated JSON
            results file, or the already-loaded dict mapping API name to a
            list of example dicts.
        key_list: optional container of API names; when given, only examples
            under those names are counted. ``None`` counts every API.

    Prints three percentages, each rounded to 3 decimals:
        % wellformed: examples that are neither ``no_call`` nor ``err``,
                      over all counted examples.
        % api match:  examples with ``api_match`` set, over wellformed examples.
        % correct:    sum of ``args_correct`` for api-matched examples,
                      over all counted examples.

    A percentage whose denominator is zero (no example selected, or none
    wellformed) is printed as 0.0 instead of raising ZeroDivisionError.

    Returns:
        None. Results are only printed, so a notebook cell ending in this
        call shows no ``Out[...]`` value.
    """
    if isinstance(file_name, (str, os.PathLike)):
        with open(file_name, "r", encoding='utf-8') as f:
            dataset_evaled = json.load(f)
    else:
        dataset_evaled = file_name

    total = 0
    non_err = 0
    api_match = 0
    correct = 0

    for key, examples in dataset_evaled.items():
        if key_list is not None and key not in key_list:
            continue

        for item in examples:
            total += 1

            # Examples without an API call, or with a malformed call, count
            # toward the total but are not wellformed.
            if item['no_call'] or item['err']:
                continue

            non_err += 1

            if item['api_match']:
                api_match += 1
                correct += item['args_correct']

    wellformed_pct = 100*(non_err/total) if total else 0.0
    api_match_pct = 100*api_match/non_err if non_err else 0.0
    correct_pct = 100*correct/total if total else 0.0

    print("% wellformed:", round(wellformed_pct, 3))
    print("% api match:", round(api_match_pct, 3))
    print("% correct:", round(correct_pct, 3))

In [5]:
# Print the aggregate metrics for the annotated results in toolLLMv2.json.
eval_batch("saved_results/toolLLMv2.json")

% wellformed: 98.133
% api match: 49.049
% correct: 37.333


In [6]:
# Metrics for the three llama-7Bf result files, separated by a "----" line.
llama_7b_results = (
    "saved_results/llama-7Bf.json",
    "saved_results/llama-7Bf_ICL.json",
    "saved_results/llama-7Bf_FT.json",
)
for position, result_path in enumerate(llama_7b_results):
    if position > 0:
        print("----")
    eval_batch(result_path)

% wellformed: 34.533
% api match: 40.154
% correct: 10.667
----
% wellformed: 58.267
% api match: 86.728
% correct: 41.733
----
% wellformed: 99.2
% api match: 94.892
% correct: 73.333


In [7]:
# Metrics for the three llama-13Bf result files, separated by a "----" line.
llama_13b_results = (
    "saved_results/llama-13Bf.json",
    "saved_results/llama-13Bf_ICL.json",
    "saved_results/llama-13Bf_FT.json",
)
for position, result_path in enumerate(llama_13b_results):
    if position > 0:
        print("----")
    eval_batch(result_path)

% wellformed: 79.333
% api match: 53.613
% correct: 32.667
----
% wellformed: 87.467
% api match: 86.585
% correct: 62.933
----
% wellformed: 98.933
% api match: 95.148
% correct: 74.267


In [8]:
# Metrics for the three mistral result files, separated by a "----" line.
mistral_results = (
    "saved_results/mistral.json",
    "saved_results/mistral_ICL.json",
    "saved_results/mistral_FT.json",
)
for position, result_path in enumerate(mistral_results):
    if position > 0:
        print("----")
    eval_batch(result_path)

% wellformed: 61.733
% api match: 69.33
% correct: 30.133
----
% wellformed: 69.867
% api match: 88.359
% correct: 47.867
----
% wellformed: 99.067
% api match: 95.828
% correct: 76.8


In [9]:
# Metrics for the two gpt35 result files, separated by a "---" line.
gpt35_results = (
    "saved_results/gpt35.json",
    "saved_results/gpt35_ICL.json",
)
for position, result_path in enumerate(gpt35_results):
    if position > 0:
        print("---")
    eval_batch(result_path)

% wellformed: 96.933
% api match: 77.579
% correct: 60.533
---
% wellformed: 97.6
% api match: 90.847
% correct: 75.6


In [10]:
# Metrics for the two gpt4 result files, separated by a "---" line.
gpt4_results = (
    "saved_results/gpt4.json",
    "saved_results/gpt4_ICL.json",
)
for position, result_path in enumerate(gpt4_results):
    if position > 0:
        print("---")
    eval_batch(result_path)

% wellformed: 96.133
% api match: 78.086
% correct: 60.8
---
% wellformed: 97.733
% api match: 92.769
% correct: 76.267


In [11]:
# Ablations: metrics for each llama-7Bf_FT ablation results file
# (the file-name suffix names the variant), separated by a "--" line.
ablation_results = (
    "saved_results/llama-7Bf_FT_no_exec.json",
    "saved_results/llama-7Bf_FT_no_STM.json",
    "saved_results/llama-7Bf_FT_no_LTM.json",
    "saved_results/llama-7Bf_FT_no_reflection.json",
)
for position, result_path in enumerate(ablation_results):
    if position > 0:
        print("--")
    eval_batch(result_path)

% wellformed: 89.867
% api match: 79.377
% correct: 50.533
--
% wellformed: 99.733
% api match: 70.588
% correct: 53.867
--
% wellformed: 98.667
% api match: 79.865
% correct: 59.733
--
% wellformed: 99.333
% api match: 81.745
% correct: 60.133


In [12]:
# Load the batch definitions; the cells below pass batches[i] to eval_batch as
# its key_list to restrict the metrics to one batch.
# The encoding is explicit so decoding does not depend on the platform's
# default locale encoding (matches how eval_pred_file opens its JSON files).
with open("tool_metadata/CL_batches.json", "r", encoding="utf-8") as f:
    batches = json.load(f)

In [13]:
# Per-batch metrics for the llama-7Bf_FT_flan results, then the overall metrics.
flan_results_path = "saved_results/llama-7Bf_FT_flan.json"
for batch_idx in range(4):
    print("batch", batch_idx)
    eval_batch(flan_results_path, batches[batch_idx])
    print("--")
eval_batch(flan_results_path)

batch 0
% wellformed: 100.0
% api match: 92.222
% correct: 73.333
--
batch 1
% wellformed: 100.0
% api match: 98.974
% correct: 87.179
--
batch 2
% wellformed: 100.0
% api match: 92.778
% correct: 68.333
--
batch 3
% wellformed: 95.897
% api match: 96.791
% correct: 67.179
--
% wellformed: 98.933
% api match: 95.283
% correct: 74.133


In [14]:
# For each CL round file: metrics for every batch index up to and including
# the round number, followed by the metrics over the whole file.
for cl_round in range(4):
    round_results_path = "saved_results/CL_round_{}.json".format(cl_round)
    print("round:", cl_round)
    print("=====")
    for batch_idx in range(cl_round + 1):
        print("batch", batch_idx)
        eval_batch(round_results_path, batches[batch_idx])
        print("--")
    print("total:")
    eval_batch(round_results_path)
    print("=============================")

round: 0
=====
batch 0
% wellformed: 99.444
% api match: 98.883
% correct: 80.556
--
total:
% wellformed: 99.444
% api match: 98.883
% correct: 80.556
round: 1
=====
batch 0
% wellformed: 97.778
% api match: 94.886
% correct: 76.111
--
batch 1
% wellformed: 100.0
% api match: 96.923
% correct: 84.103
--
total:
% wellformed: 98.933
% api match: 95.957
% correct: 80.267
round: 2
=====
batch 0
% wellformed: 100.0
% api match: 87.222
% correct: 70.556
--
batch 1
% wellformed: 100.0
% api match: 97.436
% correct: 84.103
--
batch 2
% wellformed: 100.0
% api match: 91.667
% correct: 65.556
--
total:
% wellformed: 100.0
% api match: 92.252
% correct: 73.694
round: 3
=====
batch 0
% wellformed: 100.0
% api match: 82.778
% correct: 65.0
--
batch 1
% wellformed: 100.0
% api match: 97.436
% correct: 88.718
--
batch 2
% wellformed: 100.0
% api match: 87.222
% correct: 66.111
--
batch 3
% wellformed: 92.308
% api match: 97.222
% correct: 70.256
--
total:
% wellformed: 98.0
% api match: 91.293
% corr

In [15]:
# Same per-batch / total report as for the CL round files, but for the
# "_no_replay" result files, which exist for rounds 1 to 3 only.
for cl_round in range(1, 4):
    round_results_path = "saved_results/CL_round_{}_no_replay.json".format(cl_round)
    print("round:", cl_round)
    print("=====")
    for batch_idx in range(cl_round + 1):
        print("batch", batch_idx)
        eval_batch(round_results_path, batches[batch_idx])
        print("--")
    print("total:")
    eval_batch(round_results_path)
    print("=============================")

round: 1
=====
batch 0
% wellformed: 98.889
% api match: 2.247
% correct: 1.667
--
batch 1
% wellformed: 100.0
% api match: 96.923
% correct: 87.692
--
total:
% wellformed: 99.467
% api match: 51.743
% correct: 46.4
round: 2
=====
batch 0
% wellformed: 97.778
% api match: 0.0
% correct: 0.0
--
batch 1
% wellformed: 99.487
% api match: 63.918
% correct: 56.923
--
batch 2
% wellformed: 100.0
% api match: 92.778
% correct: 68.889
--
total:
% wellformed: 99.099
% api match: 52.909
% correct: 42.342
round: 3
=====
batch 0
% wellformed: 100.0
% api match: 0.0
% correct: 0.0
--
batch 1
% wellformed: 99.487
% api match: 47.423
% correct: 38.462
--
batch 2
% wellformed: 100.0
% api match: 34.444
% correct: 25.0
--
batch 3
% wellformed: 96.923
% api match: 96.296
% correct: 71.795
--
total:
% wellformed: 99.067
% api match: 45.222
% correct: 34.667
