In [None]:
import os
import numpy as np
import torch
import json
from tqdm import tqdm

In [None]:
# Load every JSON input used by this notebook. All files are read with an
# explicit UTF-8 encoding so parsing does not depend on the platform default.
with open("ft_datasets/tool_data_train_STE_full.json", encoding='utf-8') as f:
    train_data = json.load(f)

with open("ft_datasets/tool_test.json", encoding='utf-8') as f:
    test_data = json.load(f)

with open("ft_datasets/api2neighbors.json", "r", encoding='utf-8') as f:
    api2neighbors = json.load(f)

with open("ft_datasets/API_descriptions.json", encoding='utf-8') as f:
    API_descriptions = json.load(f)

# The API names are the top-level keys of the already-loaded test set, so the
# test file does not need to be opened a second time.
all_apis = list(test_data.keys())

# Keep only the first training item seen for each distinct query.
# `train_queries[k]` is the query of `train_items[k]`; both lists preserve the
# original order of `train_data`.
# Membership is tracked in a set so the pass is linear instead of rescanning
# the growing list for every item (assumes queries are hashable, i.e. strings
# as loaded from JSON -- TODO confirm).
train_items, train_queries = [], []
seen_queries = set()
for item in train_data:
    query = item['query']
    if query not in seen_queries:
        seen_queries.add(query)
        train_queries.append(query)
        train_items.append(item)

In [None]:
from sentence_transformers import SentenceTransformer, util

# Sentence encoder shared by the training-query and test-query embeddings.
EMBEDDING_MODEL_NAME = 'sentence-transformers/paraphrase-mpnet-base-v2'
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Embed all de-duplicated training queries once; row k corresponds to
# train_queries[k] (and therefore to train_items[k]).
train_query_embeddings = model.encode(
    train_queries,
    convert_to_tensor=True,
)

In [None]:
# Number of most-similar training examples attached to each test query.
num_examples_retrieve = 8

# For every test query, rank all training queries by cosine similarity and
# store the top `num_examples_retrieve` training items under item['demo'].
# Items are dicts mutated in place, so test_data is updated without explicit
# write-backs.
for key in test_data:
    print(key)
    examples = test_data[key]
    for item in tqdm(examples):
        test_query_embedding = model.encode([item['query']], convert_to_tensor=True)
        # cos_sim returns a (1, num_train) matrix here; take its single row.
        cosine_scores = util.cos_sim(test_query_embedding, train_query_embeddings)[0]
        # Sort on the tensor instead of building and sorting a Python list of
        # per-element tensors. stable=True keeps equally-scored training
        # queries in their original order, matching a stable descending sort.
        ranked_indices = torch.sort(cosine_scores, descending=True, stable=True).indices
        top_indices = ranked_indices[:num_examples_retrieve].tolist()
        item['demo'] = [train_items[index] for index in top_indices]

In [None]:
# Persist the test set, now carrying the retrieved demonstrations per item.
demo_output_path = "ft_datasets/tool_data_test_with_demo.json"
with open(demo_output_path, "w", encoding='utf-8') as out_file:
    json.dump(test_data, out_file)

In [None]:
# Build the zero-shot dialog for every test item and save it.
# NOTE(review): `prompt_template` is not defined in any cell shown in this
# notebook; it must already exist in the kernel before this cell runs. It is
# formatted with the keyword arguments `api_descriptions` and `api_names`.
# NOTE(review): test_data is mutated in place, so any fields added to the
# items by earlier cells are written to the output file as well.
for key in test_data:
    # oracle tool retriever: the candidate tools for `key` come from api2neighbors
    tool_list = api2neighbors[key]
    api_descriptions = "\n\n".join(
        "API_name: {}\nDescription: {}".format(API_name, API_descriptions[API_name])
        for API_name in tool_list
    )

    # The instruction part depends only on the tool list, so it is built once
    # per key rather than once per example.
    base_prompt = prompt_template.format(api_descriptions=api_descriptions, api_names="\n".join(tool_list))

    for item in test_data[key]:
        prompt = base_prompt + "\n\nUser Query: " + item['query']

        # Items are dicts, so assigning the key updates test_data in place.
        item['dialog_history'] = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]

with open("ft_datasets/tool_test_OTR.json", "w", encoding='utf-8') as f:
    json.dump(test_data, f)

In [None]:
# Build the few-shot dialog (tool list + retrieved demonstrations) for every
# test item and save it.
# Requires each item to carry a 'demo' list, which the retrieval cell earlier
# in this notebook stores on the items of test_data.
# NOTE(review): `prompt_template` is not defined in any cell shown in this
# notebook; it must already exist in the kernel before this cell runs.
# NOTE(review): item['dialog_history'] is overwritten here, replacing any
# value set by a previous cell.
for key in test_data:
    tool_list = api2neighbors[key]
    api_descriptions = "\n\n".join(
        "API_name: {}\nDescription: {}".format(API_name, API_descriptions[API_name])
        for API_name in tool_list
    )

    # The instruction part depends only on the tool list, so it is built once
    # per key rather than once per example.
    base_prompt = prompt_template.format(api_descriptions=api_descriptions, api_names="\n".join(tool_list))

    for item in test_data[key]:
        # demonstration: one "User Query / Action / Action Input" entry per
        # retrieved example, with "---" between entries. The last entry is
        # followed directly by the "Now it's your turn." line.
        demo_text = "---\n".join(
            "User Query: {}\nAction: {}\nAction Input: {}\n".format(demo['query'], demo['action'], demo['action_input'])
            for demo in item['demo']
        )

        prompt = base_prompt + "\n\nBelow are some examples:\n\n" + \
            demo_text + \
            "Now it's your turn.\n\nUser Query: " + item['query']

        # Items are dicts, so assigning the key updates test_data in place.
        item['dialog_history'] = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]

with open("ft_datasets/tool_test_OTR_DR.json", "w", encoding='utf-8') as f:
    json.dump(test_data, f)