# Setup OpenAI

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#### Remember to set LIB (and LIB_ALIAS) before running the notebook.
LIB = "ehrapy"
LIB_ALIAS = "ehrapy"
#### Note: the prompt text also mentions `scanpy`; that has to be changed manually.

mode_index = 'randomseed' # 'similarseed' or 'randomseed'
# Selects how the few-shot examples are chosen.
# 'similarseed': retrieve the training queries most similar to the input query.
# Note that this mode tends to retrieve 5 shot queries for the same API,
# because queries for the same API are always similar to each other.
# 'randomseed': sample the shot examples instead of retrieving them by similarity.
oracle_index = 'noncorrected' # 'noncorrected' or 'corrected'
# Whether to correct the retrieved API list. Retrieved APIs are distinct from
# retrieved queries: the retrieved APIs form the candidate list GPT selects from.
# In 'corrected' mode the ground-truth API is inserted into the retrieved API list
# (when it is missing) and the last retrieved API is dropped to make room.
retrieved_index = 'retrieved' # 'nonretrieved' or 'retrieved'
# 'retrieved': give GPT a filtered list of retrieved candidate APIs.
# 'nonretrieved': give GPT the whole API list to select from.


In [None]:
import gpt_interface, random, bz2, json, re, os
from tqdm import auto as tqdm
import logging
logging.basicConfig(level=logging.CRITICAL)  # only CRITICAL messages are shown, which effectively silences logging
mode = 'openai'
k_shot = 5
# Configure the OpenAI client from the local secrets file.
secrets = gpt_interface.setup_openai('../configs/secrets.json', mode=mode)
# load data: the annotated inquiries for the selected library
with open(f'../data/standard_process/{LIB}/API_inquiry_annotate.json', 'r') as f:
    data = json.load(f)

# load the query ids that belong to the val and test splits
with open(f"../data/standard_process/{LIB}/API_instruction_testval_query_ids.json", 'r') as file:
    files_ids = json.load(file)

# Split the inquiries by query id; every record is reduced to its query text and gold API name.
test = [dict(query=row['query'], gold=row['api_name']) for row in [i for i in data if i['query_id'] in files_ids['test']]]
val = [dict(query=row['query'], gold=row['api_name']) for row in [i for i in data if i['query_id'] in files_ids['val']]]
train_remain = [dict(query=row['query'], gold=row['api_name']) for row in [i for i in data if (i['query_id'] not in files_ids['val']) and (i['query_id'] not in files_ids['test'])]]

print('train:', len(train_remain), 'val:', len(val), 'test:', len(test))

# add K-shot: `shuffled` holds the same records as `train_remain`, shuffled with a fixed seed (0)
shuffled = [dict(query=row['query'], gold=row['api_name']) for row in [i for i in data if i['query_id'] not in files_ids['val'] and i['query_id'] not in files_ids['test']]]
random.Random(0).shuffle(shuffled)
print(len(shuffled))
print(len(data))
# Sanity check: the three splits must partition the data.
assert len(data)==len(test)+len(val)+len(shuffled)
train = shuffled[:k_shot]
# all-apis: API names and their JSON descriptions, read from API_init.json
from utils import get_all_api_json
all_apis, all_apis_json = get_all_api_json(f"../data/standard_process/{LIB}/API_init.json", mode='full')
len(all_apis), len(all_apis_json)
# load API_init: mapping from API name to its metadata (the 'api_type' field is used for filtering below)
with open(f'../data/standard_process/{LIB}/API_init.json', 'r') as f:
    API_init = json.load(f)


In [None]:
from sentence_transformers import SentenceTransformer, util
class ToolRetriever:
    """Embedding-based retriever over an API corpus and a pool of example queries.

    On construction the corpus TSV is loaded, every document is turned into a
    compressed description string, and both the corpus and the example queries
    are embedded with a SentenceTransformer model.

    The class relies on these notebook-level names: ``pd``, ``device``,
    ``LIB_ALIAS``, ``API_init`` and ``process_retrieval_document_query_version``.
    """

    def __init__(self, corpus_tsv_path = "", model_path="",shuffled_data=None):
        """Build the retrieval corpus immediately.

        Args:
            corpus_tsv_path: Path to the tab-separated corpus file.
            model_path: Path or name of the SentenceTransformer model.
            shuffled_data: List of ``{'query': ..., 'gold': ...}`` dicts used as
                the few-shot example pool. Defaults to an empty list.
        """
        # Use a fresh list per instance instead of a shared mutable default.
        if shuffled_data is None:
            shuffled_data = []
        self.build_retrieval_corpus(corpus_tsv_path, model_path,shuffled_data)

    def build_retrieval_corpus(self, corpus_tsv_path, model_path,shuffled_data):
        """Load the corpus, create the embedder, and embed corpus and example queries."""
        print("Building corpus...")
        self.corpus_tsv_path = corpus_tsv_path
        self.model_path = model_path
        documents_df = pd.read_csv(self.corpus_tsv_path, sep='\t')
        corpus, self.corpus2tool = process_retrieval_document_query_version(documents_df)
        corpus_ids = list(corpus.keys())
        corpus = [corpus[cid] for cid in corpus_ids]
        self.corpus = corpus
        self.embedder = SentenceTransformer(self.model_path, device=device)
        self.corpus_embeddings = self.embedder.encode(self.corpus, convert_to_tensor=True)
        self.shuffled_data = shuffled_data
        self.shuffled_queries = [item['query'] for item in shuffled_data]
        self.shuffled_query_embeddings = self.embedder.encode(self.shuffled_queries, convert_to_tensor=True)

    def retrieving(self, query, top_k):
        """Return the API names of the ``top_k`` corpus entries most similar to ``query``."""
        query_embedding = self.embedder.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(query_embedding, self.corpus_embeddings, top_k=top_k, score_function=util.cos_sim)
        # semantic_search returns one hit list per query; only one query was encoded.
        retrieved_apis = [self.corpus2tool[self.corpus[hit['corpus_id']]] for hit in hits[0]]
        return retrieved_apis[:top_k]

    def retrieve_similar_queries(self, query, shot_k=5):
        """Format the example queries most similar to ``query`` as few-shot text.

        ``shot_k + 10`` candidates are retrieved so that enough remain after
        composite/class API examples are filtered out; at most ``shot_k`` are kept.

        Returns:
            The concatenation of "\\nInstruction: <query>\\nFunction: [<gold>]"
            for each kept example (an empty string if none survive the filter).
        """
        query_embedding = self.embedder.encode(query, convert_to_tensor=True)
        # Take the hit list of the single query *before* filtering: the outer
        # result is a list of per-query lists, not a list of hit dicts.
        hits = util.semantic_search(query_embedding, self.shuffled_query_embeddings, top_k=shot_k+10, score_function=util.cos_sim)[0]
        examples = [self.shuffled_data[hit['corpus_id']] for hit in hits]
        # filter class/composite API; unknown API names are dropped rather than raising KeyError
        examples = [
            example for example in examples
            if example['gold'].startswith(LIB_ALIAS)
            and example['gold'] in API_init
            and API_init[example['gold']]['api_type'] not in ['class', 'unknown']
        ]
        examples = examples[:shot_k]
        return ''.join(
            "\nInstruction: " + example['query'] + "\nFunction: [" + example['gold'] + "]"
            for example in examples
        )

def process_retrieval_document_query_version(documents_df):
    """Build the retrieval corpus and its reverse lookup from a documents table.

    Each row's ``document_content`` is parsed as JSON and compressed into a
    single description string.

    Returns:
        A pair ``(ir_corpus, corpus2tool)``: ``ir_corpus`` maps ``docid`` to the
        compressed string, and ``corpus2tool`` maps the compressed string back to
        the API name (the text before the first ``(`` of the first
        ``api_calling`` entry).
    """
    ir_corpus, corpus2tool = {}, {}
    for record in documents_df.itertuples():
        api_doc = json.loads(record.document_content)
        # The compressed string serves both as corpus text and as lookup key.
        compressed = compress_api_str_from_list_query_version(api_doc)
        ir_corpus[record.docid] = compressed
        corpus2tool[compressed] = api_doc['api_calling'][0].split('(')[0]
    return ir_corpus, corpus2tool

def compress_api_str_from_list_query_version(api):
    """Flatten an API record into one comma-separated description string.

    The string contains, in order: the API name (text before the first ``(`` of
    the first ``api_calling`` entry), the first line of ``api_description``, and
    the JSON dumps of ``required_parameters``, ``optional_parameters`` and
    ``Returns``.
    """
    name, _, _ = api['api_calling'][0].partition('(')
    summary, _, _ = api['api_description'].partition('\n')
    pieces = [
        name,
        summary,
        "required_params: " + json.dumps(api['required_parameters']),
        "optional_params: " + json.dumps(api['optional_parameters']),
        "return_schema: " + json.dumps(api['Returns']),
    ]
    return ", ".join(pieces)

import hashlib, random
from utils import correct_pred, get_sampled_shuffled, generate_seed

import pandas as pd
# Device string handed to SentenceTransformer inside ToolRetriever (first CUDA device).
device = 'cuda:0'
# Build the retriever from the library's corpus and fine-tuned model; `shuffled` is the few-shot example pool.
retriever = ToolRetriever(corpus_tsv_path=f"../data/standard_process/{LIB}/retriever_train_data/corpus.tsv", model_path=f"../hugging_models/retriever_model_finetuned/{LIB}/assigned/",shuffled_data=shuffled)

# Query to API selection

## K-shot

Here, GPT does not see a candidate list of APIs; it simply tries to name the correct function from memory.

In [None]:
from utils import get_generate_prompt
# Prompt template for the generation task; run_gpt fills in lib_name, query and similar_queries.
prompt = get_generate_prompt()
print(prompt)

def run_gpt(test, gpt_model, prompt, dout, mode, max_tokens=20,title=""):
    """Ask GPT to name the API for each query without showing candidate APIs.

    Examples whose gold API is composite (does not start with ``LIB_ALIAS`` or
    is missing from ``API_init``) or is of type 'class'/'unknown' are skipped
    and left unmodified. Every evaluated example is updated in place with the
    keys 'pred', 'correct' and 'prompt'. The whole ``test`` list, including
    skipped examples, is then written to ``<dout>/<title>.json``.

    Reads the notebook-level names ``LIB_ALIAS``, ``API_init``, ``mode_index``,
    ``retriever`` and ``shuffled``.

    Args:
        test: List of ``{'query': ..., 'gold': ...}`` dicts (mutated in place).
        gpt_model: Model name passed to ``gpt_interface.query_openai``.
        prompt: Template with ``lib_name``, ``query`` and ``similar_queries`` fields.
        dout: Output directory (must already exist).
        mode: Backend mode passed to ``gpt_interface.query_openai``.
        max_tokens: Completion token limit.
        title: Output file name without the ``.json`` suffix.

    Raises:
        NotImplementedError: If ``mode_index`` is not 'similarseed' or 'randomseed'.
        AssertionError: If fewer than 5 usable shot examples are available.
    """
    correct = []
    for ex in (pbar := tqdm.tqdm(test)):
        if (not ex['gold'].startswith(LIB_ALIAS)) or (ex['gold'] not in API_init):# do not test composite API
            print('filters composite API')
            continue
        if API_init[ex['gold']]['api_type'] in ['class', 'unknown']: # do not test class type API query
            print('filters class API')
            continue
        if mode_index == 'similarseed':
            similar_queries = retriever.retrieve_similar_queries(ex['query'],shot_k=5)
        elif mode_index == 'randomseed':
            shot_k = 5
            # Oversample so that enough examples remain after filtering.
            sampled_shuffled = get_sampled_shuffled(ex['gold'], shuffled, num_samples=shot_k+10)
            # Keep only simple (non-composite, non-class) APIs; the membership
            # check prevents a KeyError for names missing from API_init.
            sampled_shuffled = [
                hit for hit in sampled_shuffled
                if hit['gold'].startswith(LIB_ALIAS)
                and hit['gold'] in API_init
                and API_init[hit['gold']]['api_type'] not in ['class', 'unknown']
            ]
            sampled_shuffled = sampled_shuffled[:shot_k]
            assert len(sampled_shuffled)==shot_k, 'not enough usable few-shot examples after filtering'
            similar_queries = "".join(["\nInstruction: " + shot['query'] + "\nFunction: [" + shot['gold']+"]" for shot in sampled_shuffled])
        else:
            raise NotImplementedError
        # do not provide retrieved part; format the prompt once and reuse it
        full_prompt = prompt.format(lib_name=LIB_ALIAS, query=ex['query'],similar_queries=similar_queries)
        p = gpt_interface.query_openai(full_prompt, mode=mode, model=gpt_model, max_tokens=max_tokens)
        # The answer may be a bracketed, comma-separated list; normalize every entry.
        p = p.replace('[','').replace(']','')
        p = [correct_pred(part.strip(), LIB_ALIAS) for part in p.split(',')]
        print('==>ask: ')
        print(full_prompt)
        print(f'==>answer: {p}')
        ex['pred'] = p
        # Only the first predicted API is compared against the gold label.
        ex['correct'] = c = ex['pred'][0] == ex['gold']
        ex['prompt'] = full_prompt
        correct.append(c)
        pbar.set_description('correct: {}'.format(sum(correct)/len(correct)))
    out_path = os.path.join(dout, '{}.json'.format(title))
    with open(out_path, 'wt') as f:
        print(f'save to {out_path}')
        json.dump(test, f, indent=2)

import os
# Results of the generation task are stored under "<LIB>/<k_shot>-shot-generate".
folder_name = "{}/{}-shot-generate".format(LIB,k_shot)
print(f'makedir for {folder_name}')
os.makedirs(folder_name, exist_ok=True)

In [None]:
# Evaluate gpt-3.5-turbo on the val split (saved with the "trainsample" suffix) ...
title = f'gpt-3.5-turbo-0125-trainsample'
run_gpt(val, 'gpt-3.5-turbo-0125', prompt, folder_name, mode,title=title)
# ... and on the test split.
title = f'gpt-3.5-turbo-0125-test'
run_gpt(test, 'gpt-3.5-turbo-0125', prompt, folder_name, mode,title=title)

# Classification
Here, GPT sees a list of candidate APIs and tries to pick the correct one from it.

In [None]:
def run_gpt_new(test, gpt_model, prompt, dout, mode, max_tokens=20,top_k=3,title=""):
    """Ask GPT to pick the API for each query from a list of candidate APIs.

    Examples whose gold API is composite or of type 'class'/'unknown' are
    skipped and left unmodified. Every evaluated example is updated in place
    with the keys 'pred', 'correct', 'retrieved_apis' and 'prompt'. GPT is only
    queried when the gold API is among the candidates; otherwise the example is
    recorded as incorrect with ``pred = None``. The whole ``test`` list is then
    written to ``<dout>/<title>.json``.

    Reads the notebook-level names ``LIB_ALIAS``, ``API_init``, ``all_apis_json``,
    ``mode_index``, ``retrieved_index``, ``oracle_index``, ``retriever`` and
    ``shuffled``.

    Args:
        test: List of ``{'query': ..., 'gold': ...}`` dicts (mutated in place).
        gpt_model: Model name passed to ``gpt_interface.query_openai``.
        prompt: Template with ``query``, ``retrieved_apis`` and ``similar_queries`` fields.
        dout: Output directory (must already exist).
        mode: Backend mode passed to ``gpt_interface.query_openai``.
        max_tokens: Completion token limit.
        top_k: Number of candidate APIs shown per query in 'retrieved' mode.
        title: Output file name without the ``.json`` suffix.

    Raises:
        NotImplementedError: For an unsupported ``mode_index``,
            ``retrieved_index`` or ``oracle_index`` value.
        AssertionError: If fewer than ``top_k`` eligible APIs are retrieved.
    """
    def _is_eligible(api):
        # A candidate must be a simple library API: right prefix, known, not class/unknown.
        return api.startswith(LIB_ALIAS) and (api in API_init) and (API_init[api]['api_type'] not in ['class', 'unknown'])

    def _describe(idx, api):
        # One numbered candidate line as it appears in the prompt.
        return f"{idx}:" + api + ", description: " + all_apis_json[api].replace('\n',' ') + "\n"

    correct = []
    for ex in (pbar := tqdm.tqdm(test)):
        # ignore compositeAPI and classAPI query
        if (not ex['gold'].startswith(LIB_ALIAS)) or (ex['gold'] not in API_init):# do not test composite API
            print('filters composite API')
            continue
        if API_init[ex['gold']]['api_type'] in ['class', 'unknown']: # do not test class type API query
            print('filters class API')
            continue
        if mode_index == 'similarseed':
            similar_queries = retriever.retrieve_similar_queries(ex['query'],shot_k=5)
        elif mode_index == 'randomseed':
            shot_k = 5
            # Oversample, then filter composite/class API examples.
            sampled_shuffled = get_sampled_shuffled(ex['gold'], shuffled, num_samples=shot_k+10)
            sampled_shuffled = [hit for hit in sampled_shuffled if _is_eligible(hit['gold'])][:shot_k]
            similar_queries = ""
            for shot in sampled_shuffled:
                shot_candidates = [api for api in retriever.retrieving(shot['query'], top_k=top_k+10) if _is_eligible(api)]
                shot_candidates = shot_candidates[:top_k]
                # For the k-shot in-context examples the ground-truth API must be among
                # the candidates, otherwise GPT won't understand the task. This is not
                # cheating, because it is done only for the in-context examples.
                if shot['gold'] not in shot_candidates:
                    shot_candidates = shot_candidates[:top_k-1]+[shot['gold']]
                # Shuffle so the gold API does not always sit in the same position.
                random.shuffle(shot_candidates)
                function_candidates = "".join(_describe(idx, api) for idx, api in enumerate(shot_candidates) if idx<top_k)
                similar_queries += "function candidates:\n" + function_candidates+"Instruction: " + shot['query'] + "\nFunction: [" + shot['gold']+"]"+ "\n---\n"
        else:
            raise NotImplementedError
        if retrieved_index=='retrieved':
            # Over-retrieve, then keep the first top_k eligible candidates.
            retrieved_api_list = [api for api in retriever.retrieving(ex['query'], top_k=top_k+10) if _is_eligible(api)]
            retrieved_api_list = retrieved_api_list[:top_k]
            assert len(retrieved_api_list)==top_k
            if oracle_index=='corrected':
                # Oracle mode: force the gold API into the list, dropping the last candidate.
                if ex['gold'] not in retrieved_api_list:
                    retrieved_api_list = [ex['gold']] + retrieved_api_list[:-1]
                assert ex['gold'] in retrieved_api_list
            elif oracle_index=='noncorrected':
                pass
            else:
                raise NotImplementedError
            retrieved_apis = "".join(_describe(idx, api) for idx, api in enumerate(retrieved_api_list))
        elif retrieved_index=='nonretrieved':
            # Show every eligible API; the candidate list must be defined here too,
            # because it is consulted and stored below.
            retrieved_api_list = []
            retrieved_apis = ""
            for idx, api in enumerate(all_apis_json):
                if not _is_eligible(api):
                    continue
                retrieved_api_list.append(api)
                retrieved_apis += _describe(idx, api)
        else:
            raise NotImplementedError
        full_prompt = prompt.format(query=ex['query'],retrieved_apis=retrieved_apis,similar_queries=similar_queries)
        if ex['gold'] in retrieved_api_list:# retriever correct
            p = gpt_interface.query_openai(full_prompt, mode=mode, model=gpt_model, max_tokens=max_tokens)
            # The answer may be a bracketed, comma-separated list; normalize every entry.
            # LIB_ALIAS is used (as in run_gpt) because gold names carry the alias prefix.
            p = p.replace('[','').replace(']','')
            p = [correct_pred(part.strip(), LIB_ALIAS) for part in p.split(',')]
            print('==>ask: ')
            print(full_prompt)
            print(f'==>answer: {p}')
            ex['pred'] = p
            # Only the first predicted API is compared against the gold label.
            ex['correct'] = c = ex['gold']==ex['pred'][0]
        else:# retriever wrong
            ex['pred'] = None
            ex['correct'] = c = False
        ex['retrieved_apis'] = retrieved_api_list
        ex['prompt'] = full_prompt
        correct.append(c)
        pbar.set_description('correct: {}'.format(sum(correct)/len(correct)))
    with open(os.path.join(dout, '{}.json'.format(title)), 'wt') as f:
        json.dump(test, f, indent=2)

import os
# Results of the classification task are stored under "<LIB>/<k_shot>-shot-classify".
folder_name = "{}/{}-shot-classify".format(LIB,k_shot)
os.makedirs(folder_name, exist_ok=True)

In [None]:
from utils import get_retrieved_prompt, get_nonretrieved_prompt

# Choose the prompt template that matches the retrieval setting, then evaluate
# gpt-3.5-turbo on the val split ("trainsample" suffix) and on the test split.
# Note: if retrieved_index has any other value, neither branch runs and `prompt`
# keeps its previous value.
if retrieved_index == 'retrieved':
    print('get_retrieved_prompt')
    prompt = get_retrieved_prompt()
    top_k = 3
    title = f'gpt-3.5-turbo-0125-topk-{top_k}-trainsample'
    run_gpt_new(val, 'gpt-3.5-turbo-0125', prompt, folder_name, mode,top_k=top_k,title=title)
    title = f'gpt-3.5-turbo-0125-topk-{top_k}-test'
    run_gpt_new(test, 'gpt-3.5-turbo-0125', prompt, folder_name, mode,top_k=top_k,title=title)
elif retrieved_index == 'nonretrieved':
    print('get_nonretrieved_prompt')
    prompt = get_nonretrieved_prompt()
    top_k = 3
    # The retrieval setting is embedded in the file name only in this branch.
    title = f'gpt-3.5-turbo-0125-topk-{top_k}-{retrieved_index}-trainsample'
    run_gpt_new(val, 'gpt-3.5-turbo-0125', prompt, folder_name, mode,top_k=top_k,title=title)
    title = f'gpt-3.5-turbo-0125-topk-{top_k}-{retrieved_index}-test'
    run_gpt_new(test, 'gpt-3.5-turbo-0125', prompt, folder_name, mode,top_k=top_k,title=title)

In [None]:
# Same classification evaluation with GPT-4; reuses `prompt` and `folder_name` from the cells above.
top_k = 3
title = f'gpt-4-0125-preview-topk-{top_k}-trainsample'
run_gpt_new(val, 'gpt-4-0125-preview', prompt, folder_name, mode,top_k=top_k,title=title)
title = f'gpt-4-0125-preview-topk-{top_k}-test'
run_gpt_new(test, 'gpt-4-0125-preview', prompt, folder_name, mode, top_k=top_k,title=title)

### ambiguous pair

In [39]:
import pandas as pd
import json, re, os, glob
import matplotlib.pyplot as plt
import numpy as np
from utils import extract_and_print_adjusted, plot_figure, is_pair_in_merged_pairs, find_similar_two_pairs, correct_entries
results = []
import glob
def get_json_from_local(LIB, LIB_ALIAS):
    """Summarize the saved result files of one library into a DataFrame.

    Every ``*/*/*.json`` file (relative to the working directory) whose path
    contains ``"<LIB>/"`` is loaded. Entries without a 'correct' key (the
    skipped composite/class API queries) are dropped, and the remaining entries
    are passed through ``correct_entries``. Two accuracies are computed per
    file: the plain accuracy, and a "filtered" accuracy that ignores entries
    whose (gold, first prediction) pair is listed in the ambiguous pairs
    returned by ``find_similar_two_pairs``.

    Args:
        LIB: Library directory name (also used to locate ``API_init.json``).
        LIB_ALIAS: Library alias passed to ``correct_entries``.

    Returns:
        A DataFrame with one row per result file and the columns lib,
        model_name, accuracy, total, retrieved_in_gold_total,
        filtered_accuracy, filtered_c, filtered_total, top_k, pred_mean,
        test_val and retrieval_status.

    Raises:
        AssertionError: If a classification file's accuracy does not equal
            retriever accuracy times in-candidate accuracy (within 0.01), or
            if an entry without a prediction is marked correct.
    """
    merged_pairs = find_similar_two_pairs(f'../data/standard_process/{LIB}/API_init.json')
    # Raw string: '\d' in a normal string literal is an invalid escape sequence.
    topk_pattern = re.compile(r'-topk-(\d+)')
    results = []
    for fname in glob.glob("*/*/*.json"):
        if f'{LIB}/' not in fname: # filter corresponding LIBs
            continue
        with open(fname, 'r') as file:
            res = json.load(file)
        # filter compositeAPI and classAPI queries (they never received a 'correct' key)
        res = [i for i in res if 'correct' in i]
        res = correct_entries(res, LIB_ALIAS)
        original_correct = [ex['correct'] for ex in res]
        original_accuracy = sum(original_correct) / len(original_correct) if res else 0
        if 'generate' in fname:
            # Generation results have no candidate list to check against.
            retrieved_in_gold = []
        else:
            retrieved_in_gold = [ex for ex in res if (ex['gold'] in ex['retrieved_apis'])]
            retrieved_in_gold_correct = [ex['correct'] for ex in retrieved_in_gold]
            retrieved_in_gold_accuracy = sum(retrieved_in_gold_correct) / len(retrieved_in_gold_correct) if retrieved_in_gold_correct else 0
            retriever_accuracy = len(retrieved_in_gold)/len(res) if res else 0
            # Sanity check: an entry can only be correct if its gold API was a candidate.
            assert abs(retriever_accuracy*retrieved_in_gold_accuracy-original_accuracy)<0.01
        # Drop entries whose gold/prediction pair is a known ambiguous pair.
        filtered_res = []
        for item in res:
            if item['pred'] is None:
                assert not item['correct']
                filtered_res.append(item)
            elif not is_pair_in_merged_pairs(item['gold'], item['pred'][0], merged_pairs):
                filtered_res.append(item)
        filtered_correct = [ex['correct'] for ex in filtered_res]
        filtered_c = [i for i in filtered_correct if i]
        filtered_accuracy = sum(filtered_correct) / len(filtered_res) if filtered_res else 0
        parent_dir = os.path.dirname(fname)
        stem = os.path.basename(fname).replace('.json', '')
        match = topk_pattern.search(os.path.basename(fname))
        top_k = int(match.group(1)) if match else '-'
        pred_count = [len(i['pred']) for i in res if i['pred'] is not None]
        model_name = "gpt-4" if stem.startswith('gpt-4') else "gpt-3.5"
        if stem.endswith('trainsample'):
            test_val = 'synthetic_val'
        elif stem.endswith('trainsub2'):
            # Always assign a label so a previous file's value is never reused.
            test_val = 'synthetic_train_sub2'
        elif stem.endswith('new_val'):
            test_val = 'new_val'
        else:
            test_val = 'human annotate'
        if 'nonretrieved' in os.path.basename(fname):
            retrieval_status = "nonretrieved"
        else:
            retrieval_status = "retrieved"
        results.append(dict(
            lib = parent_dir.split('/')[0],
            model_name=model_name,
            accuracy=original_accuracy,
            total=len(res),
            retrieved_in_gold_total = len(retrieved_in_gold),
            filtered_accuracy=filtered_accuracy,
            filtered_c = len(filtered_c),
            filtered_total = len(filtered_res),
            top_k=top_k,
            # Guard against files in which no entry has a prediction.
            pred_mean = sum(pred_count)/len(pred_count) if pred_count else 0,
            test_val=test_val,
            retrieval_status=retrieval_status
        ))
    results = pd.DataFrame(results)
    results = results.sort_values(by=['lib', 'model_name', 'top_k','test_val', 'retrieval_status'])
    df = results.copy()
    # Rows without a top-k value (top_k == "-") are sorted first.
    df['has_dash'] = df['top_k'] == "-"
    model_name_order_original = ["gpt-3.5", "gpt-4"]
    df['model_name'] = pd.Categorical(df['model_name'], categories=model_name_order_original, ordered=True)
    df_sorted_again = df.sort_values(by=['has_dash', 'model_name', 'retrieval_status', 'test_val'], ascending=[False, True,True, False])
    return df_sorted_again.drop('has_dash', axis=1)


In [None]:
# Summarize and plot the saved results for scanpy.
LIB = "scanpy"
LIB_ALIAS = "scanpy"
df = get_json_from_local(LIB, LIB_ALIAS)
cluster_data = extract_and_print_adjusted(df)
plot_figure(cluster_data, LIB)

In [None]:
# Summarize and plot the saved results for squidpy.
LIB = "squidpy"
LIB_ALIAS = "squidpy"
df = get_json_from_local(LIB, LIB_ALIAS)
cluster_data = extract_and_print_adjusted(df)
plot_figure(cluster_data, LIB)

In [None]:
# Summarize and plot the saved results for ehrapy.
LIB = "ehrapy"
LIB_ALIAS = "ehrapy"
df = get_json_from_local(LIB, LIB_ALIAS)
cluster_data = extract_and_print_adjusted(df)
plot_figure(cluster_data, LIB)

In [None]:
# Summarize and plot the saved results for snapatac2.
LIB = "snapatac2"
LIB_ALIAS = "snapatac2"
df = get_json_from_local(LIB, LIB_ALIAS)
cluster_data = extract_and_print_adjusted(df)
plot_figure(cluster_data, LIB)

In [None]:
# Summarize and plot the saved results for scanpy_subset; here the directory
# name (LIB) differs from the alias used in API names (LIB_ALIAS).
LIB = "scanpy_subset"
LIB_ALIAS = "scanpy"
df = get_json_from_local(LIB, LIB_ALIAS)
cluster_data = extract_and_print_adjusted(df)
plot_figure(cluster_data, LIB)