In [None]:
import json
import os

import dotenv

from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

from tqdm import tqdm

dotenv.load_dotenv()

In [None]:
# Load the test set and the example file that accompanies it.
with open('data/Final_TestSet/Final_TestSet.json', 'r', encoding='utf-8') as f:
    dataset_init = json.load(f)
with open('data/Final_Example.json', 'r', encoding='utf-8') as f:
    preliminary_example = json.load(f)

# Check that both files describe the same samples, entry by entry
# (same ID and same question at every position of the test set).
for idx, sample in enumerate(dataset_init):
    reference = preliminary_example[idx]
    assert sample["ID"] == reference["ID"]
    assert sample["question"] == reference["question"]

print("样本数量：", len(dataset_init))
# print("问题类型：",",".join(set([item["problem_type"] for item in dataset_init])))


# Window of the test set that the rest of the notebook works on.
FROM = 0
TO = FROM + 512
dataset = dataset_init[FROM:TO]

In [None]:
# Chat model shared by every LLM step in this notebook. The API key and endpoint
# are read from the environment (dotenv.load_dotenv() runs in the imports cell).
gpt4o=ChatOpenAI(
    api_key=os.getenv("API_KEY"),
    base_url=os.getenv("BASE_URL"),
    model="gpt-4o",
)

# Bare request whose result is discarded — presumably a connectivity smoke test.
# NOTE(review): on a fresh kernel this runs before the cache below is installed,
# so it is sent to the API rather than served from cache; confirm that is intended.
gpt4o.invoke("hello")
set_llm_cache(SQLiteCache(database_path=".langchain.db")) # install a global SQLite-backed LLM cache stored in .langchain.db

## 预处理
### 1. 翻译

In [None]:
from tool.langchain_tool import translate_prompt

# Translate every question. An LLM cache is installed, so translating the whole
# dataset again on a re-run is cheap.
translation_runnable = translate_prompt | gpt4o | StrOutputParser()
translation_list = translation_runnable.batch(
    [{"text": item["question"]} for item in dataset],
    config={"max_concurrency": 1},
    return_exceptions=True,
)

# With return_exceptions=True a failed call comes back as an exception object in
# the result list instead of being raised. Storing such an object as a
# "translation" would only blow up later (string concatenation, JSON dump) with
# an unrelated error, so fail here with the offending sample ID instead.
# Successful translations are already cached, so simply re-run the cell.
failed_translations = [
    (dataset[i]["ID"], result)
    for i, result in enumerate(translation_list)
    if isinstance(result, BaseException)
]
if failed_translations:
    first_id, first_error = failed_translations[0]
    raise RuntimeError(
        f"{len(failed_translations)} translation(s) failed, "
        f"e.g. ID {first_id}: {first_error!r}"
    ) from first_error

for i in range(0, len(dataset)):
    dataset[i]["translation"] = translation_list[i]

### 2. 抽取题目中的函数和类

In [None]:
from tool.langchain_tool import extract_runnable

# Extract the functions and libraries mentioned in each translated question.
# The extraction is repeated several times and the rounds are merged, so that a
# miss (or a failed call) in one round does not lose the item.
tmp_extract_dict = {}
for i in tqdm(range(5)):
    # The round number is prepended so every round sends a different text.
    # NOTE(review): presumably this is what keeps the LLM cache from returning
    # the same answer for every round — confirm.
    tmp_extract_list = extract_runnable.batch(
        [{"text": str(i) + "\n" + item["translation"]} for item in dataset[:]],
        config={"max_concurrency": 5},
        return_exceptions=True,
    )
    tmp_extract_dict[i] = tmp_extract_list

# Merge all rounds into round 0 (round 0 results first, then round 1, ...).
# Because of return_exceptions=True a failed call is an exception object, not a
# list; calling .extend on it / with it would crash, so failed rounds are skipped.
merged_rounds = []
never_extracted = []
for j in range(len(tmp_extract_dict[0])):
    combined = []
    any_success = False
    for i in range(len(tmp_extract_dict)):
        result = tmp_extract_dict[i][j]
        if isinstance(result, BaseException):
            continue
        any_success = True
        combined.extend(result)
    if not any_success:
        never_extracted.append(j)
    merged_rounds.append(combined)
tmp_extract_dict[0] = merged_rounds

if never_extracted:
    # Make total failures visible instead of silently producing empty lists.
    print(f"extraction failed in every round for dataset indices: {never_extracted}")

def remove_duplicates(lst):
    """Return the items of ``lst`` with duplicates removed, keeping first occurrences.

    Items may be unhashable (e.g. dicts), so each one is reduced to a hashable
    key. Dicts are keyed by their items sorted by key, which makes two dicts
    with the same content but a different key order count as duplicates. Any
    other item is keyed by ``str(item)``. Nested dicts are not normalised.

    Args:
        lst: iterable of items to de-duplicate.

    Returns:
        A new list holding the first occurrence of every distinct item, in the
        original order.
    """
    def _key(d):
        if isinstance(d, dict):
            # Sort by str(key) so the key no longer depends on insertion order.
            return ("dict", str(sorted(d.items(), key=lambda kv: str(kv[0]))))
        return ("other", str(d))

    seen = set()
    result = []
    for d in lst:
        key = _key(d)
        if key not in seen:
            seen.add(key)
            result.append(d)
    return result
# De-duplicate the merged extraction results of every sample.
extract_list = [remove_duplicates(candidates) for candidates in tmp_extract_dict[0]]

# Attach each sample's extracted functions/classes to its dataset entry.
for idx, functions in enumerate(extract_list):
    # print(idx+1, functions)
    dataset[idx]["func_extract"] = functions

# extract_list=extract_runnable.batch([{"text":item["translation"]} for item in dataset[:]], config={"max_concurrency":5}, return_exceptions=True)
# for i in range(len(extract_list)):
#     # print(i+1,extract_list[i])
#     dataset[i]["func_extract"]=extract_list[i]

### 3. 根据搜索到的类型和搜索到的文档，搜索出对应的函数文档

In [None]:
from tool.rag_tool import search_documents_by_help_function

# For every sample, look up the help() documentation of each extracted
# function/class and store the wrapped docs under "rag_infos".
for i, keywords in tqdm(enumerate(extract_list), total=len(extract_list)):
    keywords = keywords if type(keywords) is list else [keywords]
    # A dict is used as an insertion-ordered set: duplicates are dropped but the
    # order stays stable. A plain set of strings iterates in an order that can
    # change between interpreter runs, which would make the prompt built from
    # these docs differ from run to run; a list is also JSON-serialisable.
    api_docs = {}
    for kw in keywords:
        # print(kw)
        doc = search_documents_by_help_function(
            kw["function_name"].split(".")[-1],                 # last dotted component
            kw["module_name"].lower().strip().split(".")[0],    # first dotted component
        )
        api_docs["<api doc>\n" + doc + "\n</api doc>"] = None
    dataset[i]["rag_infos"] = list(api_docs)

### 4. 根据题目，搜索向量数据库的相关函数/类，然后获取相关函数/类的文档 

In [None]:

from tool.rag_tool import search_documents_in_mutil_keywords

# Keys that start with this prefix carry a method name after it.
METHOD_KEY_PREFIX = "Field List > Methods > "

# For every question, query the vector store (10 is passed as the third
# argument), fetch the help() text of every hit and store one descriptive line
# per hit in "func_bk".
for item in tqdm(dataset[:]):
    candidate_docs = []
    hits = search_documents_in_mutil_keywords([], item["question"], 10)
    for doc_json, _, _ in hits:
        # If several keys match, the last one wins.
        method_name = ""
        for key in doc_json:
            if str(key).startswith(METHOD_KEY_PREFIX):
                method_name = key[len(METHOD_KEY_PREFIX):].strip()

        # The section id is stored under either of two spellings.
        section_id = doc_json["Section_id"] if "Section_id" in doc_json else doc_json["Section ID"]
        if method_name != "":
            # A method entry: the section id is used as the class name.
            function_name, class_name = method_name, section_id
        else:
            # No method key: the section id itself is used as the function name.
            function_name, class_name = section_id, ""
        package_name = doc_json["module"]

        help_doc = search_documents_by_help_function(function_name, package_name, contain_key=class_name)
        if len(help_doc) > 15000:
            # Very long help text is condensed by the LLM before it goes into the prompt.
            init_len = len(help_doc)
            summarizer = gpt4o | StrOutputParser()
            help_doc = summarizer.invoke(f"Below is the documentation generated by the help() function. Extract the main information and reduce the word count to 500 words\n{help_doc}")
            after_len = len(help_doc)
            print(f"{package_name}.{function_name} {init_len} -> {after_len}")

        # repr() escapes newlines and quotes; [1:-1] drops the surrounding quote characters.
        escaped_doc = repr(help_doc)[1:-1]
        candidate_docs.append(f"function:{function_name}, class:{class_name}, package:{package_name}, doc:'{escaped_doc}'")
    item["func_bk"] = candidate_docs
    

### 5. 把所有信息构建成prompt

In [None]:
from tool.autogen_tool import *
# Assemble the prompt of every sample.
# Order: vector-search docs (optional) - extracted-function docs (required) - question.
for sample in dataset:
    optional_docs = "\n".join(sample["func_bk"])
    required_docs = "\n".join(sample["rag_infos"])

    # The question is wrapped in the template registered for its problem type.
    question = d_template[sample["problem_type"]].format(sample["question"])
    # filenames = extract_filenames(question)
    # for filename in filenames:
    #     question = question.replace(filename, add_path(filename, data_path / 'Final_TestSet/data'))

    sample["content"] = (
        "\n\nThe following functions can be used optionally:\n" + optional_docs
        + "\n\nThe following function must be used:\n" + required_docs
        + "\n\n\n" + question + "\n\n"
    )


### 添加目标
对于multi类问题，这里为其添加额外的目标说明，以保证其输出

In [None]:
from tqdm import tqdm
from tool.langchain_tool import cal_prompt, draw_prompt, tof_prompt

def get_goals(text:str, problem_type:str):
    """Ask the LLM for one goal description per sub-type of a problem.

    Args:
        text: the question text, passed to the prompt as ``question``.
        problem_type: either a single type (``"calculations"``, ``"True/False"``
            or ``"draw"``) or a string starting with ``"multi"``. For the latter
            the sub-types are read from ``problem_type[6:-1]`` split on ``", "``
            (assumes a shape like ``multi(a, b)`` — TODO confirm).

    Returns:
        A list with one LLM-generated goal string per sub-type, in order.

    Raises:
        Exception: if a (sub-)type is not one of the three known types.
    """
    if problem_type.startswith("multi"):
        sub_types = problem_type[6:-1].split(", ")
    else:
        sub_types = [problem_type]

    collected = []
    for sub_type in sub_types:
        # Pick the prompt that matches this sub-type.
        if sub_type == "calculations":
            chosen_prompt = cal_prompt
        elif sub_type == "True/False":
            chosen_prompt = tof_prompt
        elif sub_type == "draw":
            chosen_prompt = draw_prompt
        else:
            raise Exception("unknown problem type")

        chain = chosen_prompt | gpt4o | StrOutputParser()
        collected.append(chain.invoke({"question": text}))
    return collected
        
        
# Only "multi" problems get explicit goals appended to their prompt.
for sample in tqdm(dataset):
    problem_type = sample["problem_type"]
    if not problem_type.startswith("multi"):
        continue
    goals = get_goals(sample["question"], problem_type)
    # print(sample["ID"], goals)
    sample["goals"] = goals
    sample["content"] += "\n\n" + "\nwe need to answer following question：\n" + "\n".join(goals)
    

In [None]:
# Persist only the ID and the final prompt of every sample.
id_and_content = [
    {"ID": sample["ID"], "content": sample["content"]}
    for sample in dataset
]
with open('data/id_and_content.json', 'w', encoding='utf-8') as out_file:
    json.dump(id_and_content, out_file)

In [None]:
# Deliberate stop: raising here keeps a "Run All" from continuing into the
# cells that follow, which the notebook labels as test code.
raise Exception("stop")

----
## 以下为测试代码

## 运行
### 运行agent

In [None]:
from autogen import Cache

from tool.autogen_tool import *

def run(item: dict,cache_seed=1):
    """Run the executor/writer agent chat for one sample and record the result.

    The sample's ``content`` is sent as the opening message from
    ``code_executor_agent`` to ``code_writer_agent`` using a disk cache under
    ``./autogen_cache``.

    Args:
        item: sample dict; must contain ``"content"``. It is modified in place.
        cache_seed: seed passed to the disk cache.

    Returns:
        The same ``item``, with ``"code"`` (last Python snippet found in the
        most recent message that contains one, or ``""``) and ``"answer"``
        (the chat summary) added.
    """
    content = item["content"]

    # Use DiskCache as cache
    with Cache.disk(cache_path_root="./autogen_cache", cache_seed=cache_seed) as cache:
        chat_result = code_executor_agent.initiate_chat(
            code_writer_agent,
            message=content,
            summary_method='reflection_with_llm',
            summary_args={"summary_prompt": 'only return the code output, if no output just return "done!"'},
            cache=cache,
            # silent=True,
        )

    # Walk the history from newest to oldest, skipping the first entry, and
    # keep the last snippet of the first message that contains any code.
    code = ""
    for message in reversed(chat_result.chat_history[1:]):
        snippets = extract_python_code(message['content'])
        if len(snippets) > 0:
            code = snippets[-1]
            break

    # The summary may be a plain value or a dict carrying the text under 'content'.
    answer = chat_result.summary
    if isinstance(answer, dict):
        answer = answer['content']

    item.update(code=code, answer=answer)
    # item['chat_history']=chat_result.chat_history
    return item


# Test run on a single sample (index 487 of the current dataset window).
for sample in tqdm(dataset[487:488]):
    run(sample)


## 存储

In [None]:
# Sets cannot be written as JSON, so turn any set-valued "rag_infos" into a list.
for sample in dataset:
    rag_infos = sample['rag_infos']
    if type(rag_infos) is set:
        sample['rag_infos'] = list(rag_infos)

In [None]:
# Write the whole dataset as indented JSON, keeping non-ASCII text unescaped.
with open('data/SMP_240917_check_1.json', 'w', encoding='utf-8') as out_file:
    serialized = json.dumps(dataset, indent=4, ensure_ascii=False)
    out_file.write(serialized)