In [None]:
import json
import os

import dotenv

from langchain_community.cache import SQLiteCache
from langchain_core.globals import set_llm_cache
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

from tqdm import tqdm

dotenv.load_dotenv()

In [None]:
# Load the test set and the example file it is cross-checked against.
with open('data/Final_TestSet/Final_TestSet.json', 'r', encoding='utf-8') as f:
    dataset_init=json.load(f)
with open('data/Final_Example.json', 'r', encoding='utf-8') as f:
    preliminary_example=json.load(f)

# Check that the two dataset files are consistent (same ID and question at
# every index). The length is checked first so that a short example file is
# reported clearly instead of surfacing as an IndexError inside the loop.
assert len(preliminary_example) >= len(dataset_init), (
    f"example file has {len(preliminary_example)} entries, "
    f"test set has {len(dataset_init)}"
)
for i, sample in enumerate(dataset_init):
    example = preliminary_example[i]
    assert sample["ID"] == example["ID"], (
        f"ID mismatch at index {i}: {sample['ID']!r} != {example['ID']!r}"
    )
    assert sample["question"] == example["question"], (
        f"question mismatch at index {i} (ID {sample['ID']!r})"
    )

print("样本数量：",len(dataset_init))
# print("问题类型：",",".join(set([item["problem_type"] for item in dataset_init])))


# Window of samples processed by the rest of the notebook: dataset_init[FROM:TO].
FROM=0
TO=FROM+100
dataset=dataset_init[FROM:TO]

In [None]:
# Chat model client for "gpt-4o"; the API key and base URL are read from the
# WLAI_API_KEY / WLAI_BASE_URL environment variables.
gpt4o=ChatOpenAI(
    api_key=os.getenv("WLAI_API_KEY"),
    base_url=os.getenv("WLAI_BASE_URL"),
    model="gpt-4o",
)

# Smoke-test call. It runs BEFORE the cache is installed on the next line.
# NOTE(review): presumably intentional so the connectivity check always hits the API — confirm.
gpt4o.invoke("hello")
# Install a SQLite-backed LLM cache stored in ".langchain.db".
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

## 预处理


### 修改文件位置
1. 构建prompt
2. 修改题目中文件名位置

In [None]:
from gpt4o import *
# Build each sample's prompt from the template registered for its problem
# type, then replace every file name found in the prompt with the value
# returned by add_path for the test-set data directory.
for sample in dataset:
    prompt = d_template[sample["problem_type"]].format(sample["question"])
    for name in extract_filenames(prompt):
        prompt = prompt.replace(name, add_path(name, data_path / 'Final_TestSet/data'))
    sample["content"] = prompt

### RAG：添加文档信息
1. 翻译
2. 从翻译提取函数和库
3. 查询

如何衡量：暂时不做




In [None]:
from tool.model import translate_prompt

# Translate every question. Results are already cached, so the whole
# dataset is translated in one batch.
# NOTE(review): return_exceptions=True presumably puts the exception object in
# the result list when a call fails, so "translation" may not always be a
# string — confirm downstream handling.
translation_runnable = translate_prompt | gpt4o | StrOutputParser()
translation_inputs = [{"text": sample["question"]} for sample in dataset]
translation_list = translation_runnable.batch(
    translation_inputs,
    config={"max_concurrency":1},
    return_exceptions=True,
)
for sample, translated in zip(dataset, translation_list):
    sample["translation"] = translated

In [None]:
from tool.model import extract_runnable

# Extract the functions and libraries mentioned in each translation.
extract_inputs = [{"text": sample["translation"]} for sample in dataset[:]]
extract_list = extract_runnable.batch(
    extract_inputs,
    config={"max_concurrency":5},
    return_exceptions=True,
)
for sample, extracted in zip(dataset, extract_list):
    sample["func_extract"] = extracted

In [None]:
from tool.rag_tool import search_documents_by_help_function

# Look up documentation for every extracted (function, module) pair and store
# the wrapped results per sample.
for i, keywords in tqdm(enumerate(extract_list), total=len(extract_list)):
    # Wrap a non-list extraction result in a list so it can be iterated below.
    keywords = keywords if type(keywords) is list else [keywords]
    rag_docs = []
    for kw in keywords:
        doc = search_documents_by_help_function(kw["function_name"], kw["module_name"])
        # The lookup can come back empty; skip it rather than failing on
        # str + None or adding an empty <api doc> block to the prompt.
        if not doc:
            continue
        rag_docs.append("<api doc>\n"+doc+"\n</api doc>")
    dataset[i]["rag_infos"] = rag_docs

# Append the collected documentation to each sample's prompt.
for sample in dataset:
    sample["content"] = sample["content"]+"\n\n"+"\n".join(sample["rag_infos"])

### 添加目标

In [None]:
from tqdm import tqdm
from tool.model import cal_prompt, draw_prompt, tof_prompt

def get_goals(text:str, problem_type:str):
    """Ask the model for one goal per sub-type of a problem.

    Args:
        text: The question text, passed to the prompt as "question".
        problem_type: Either a single type ("calculations", "True/False",
            "draw") or a string starting with "multi", in which case
            characters [6:-1] are split on ", " to obtain the sub-types.

    Returns:
        A list with one model answer per (sub-)type, in order.

    Raises:
        Exception: If a (sub-)type is not one of the three known types.
    """
    def _prompt_for(kind):
        # Map a single problem type to its prompt template.
        if kind == "calculations":
            return cal_prompt
        if kind == "True/False":
            return tof_prompt
        if kind == "draw":
            return draw_prompt
        raise Exception("unknown problem type")

    if problem_type.startswith("multi"):
        kinds = problem_type[6:-1].split(", ")
    else:
        kinds = [problem_type]

    answers = []
    for kind in kinds:
        chain = _prompt_for(kind) | gpt4o | StrOutputParser()
        answers.append(chain.invoke({"question": text}))
    return answers
        
        
# Only samples whose problem type starts with "multi" get goals generated and
# appended to their prompt.
for sample in tqdm(dataset):
    if not sample["problem_type"].startswith("multi"):
        continue
    goals = get_goals(sample["question"], sample["problem_type"])
    sample["goals"] = goals
    sample["content"] = sample["content"]+"\n\n"+"\nwe need to answer following question：\n"+"\n".join(goals)
    

## 运行
### 运行agent

In [None]:
from autogen import Cache

def run(item: dict,cache_seed=2):
    """Run the executor/writer agent chat for one sample and record the result.

    Args:
        item: Sample dict; its "content" entry is sent as the opening message.
        cache_seed: Seed passed to the autogen disk cache.

    Returns:
        The same ``item`` dict, updated in place with:
        "code"   - the last Python snippet found in the most recent chat
                   message that contains one ("" if none is found);
        "answer" - the chat summary (its 'content' entry when the summary
                   is a dict).
    """
    content = item["content"]

    # Use DiskCache as cache
    with Cache.disk(cache_path_root="./autogen_cache",cache_seed=cache_seed) as cache:
        chat_result = code_executor_agent.initiate_chat(
            code_writer_agent,
            message=content,
            summary_method='reflection_with_llm',
            summary_args=dict(summary_prompt='only return the code output'),
            cache=cache,
            # silent=True,
        )

    # Walk the history from newest to oldest. The message at index 0 is
    # deliberately not inspected, matching the original index range.
    code = ""
    for message in reversed(chat_result.chat_history[1:]):
        snippets = extract_python_code(message['content'])
        if len(snippets) > 0:
            code = snippets[-1]
            break

    answer = chat_result.summary
    if isinstance(answer, dict):
        answer = answer['content']
    item["code"] = code
    item["answer"] = answer
    # item['chat_history']=chat_result.chat_history
    return item

# NOTE(review): only the single sample at index 92 is run here, while the
# storage cell writes the whole dataset — presumably a leftover debugging
# slice; confirm before a full run.
selected = dataset[92:93]
for item in tqdm(selected):
    run(item)


### voting

In [None]:
# for item in [dataset[i-1] for i in [3,6,9,10,12,14,19,21,23,24,28,30,37,40]]:
#     temp_answer=[]
#     for seed in tqdm(range(1,17)):
#         item=run(item,seed)
#         code,answer=item["code"],item["answer"]
#         temp_answer.append(answer)
#     prompt=f"""从下面的不同人表达中，直接返回大部分人想表达的内容，不附带其他信息：\n"""+"\n".join(temp_answer)
#     print(item["ID"],(gpt4o|StrOutputParser()).invoke(prompt))
        
        

## 存储

In [None]:
# Serialise before opening the file: open(..., 'w') truncates immediately, so
# a json.dumps failure inside the with-block would wipe the previous results.
serialized = json.dumps(dataset, indent=4, ensure_ascii=False)
with open('data/SMP_240913_check_1.json', 'w', encoding='utf-8') as f:
    f.write(serialized)

----

In [None]:
# Unconditional stop: this cell always raises, which halts a "Run All" here.
# NOTE(review): presumably the cells below are scratch work meant to be run by hand — confirm.
raise Exception("stop")

In [None]:
# Load a result file for manual inspection.
# NOTE(review): this file name differs from the one written by the storage
# cell (SMP_240913_check_1.json) — confirm it is the intended file.
with open('data/SMP_240905_check_1.json', 'r', encoding='utf-8') as f:
    tmp_dataset = json.loads(f.read())

In [None]:
# Print the problem type of every sample, one per line.
for sample in dataset:
    print(sample["problem_type"])

In [None]:
# Show one saved record; tmp_id is a 1-based position, i its 0-based index.
tmp_id=50
i=tmp_id-1
record = tmp_dataset[i]
separator = "\n---\n"
print(
    record["ID"], record["problem_type"], separator,
    record["translation"], separator,
    record['answer'], separator,
    record["code"], separator,
    record["question"],
)

In [None]:
from tool.rag_tool import search_documents

def _is_empty_value(value):
    """Return True for None and for empty str/list/dict/set/frozenset/tuple."""
    if value is None:
        return True
    return isinstance(value, (str, list, dict, set, frozenset, tuple)) and len(value) == 0


def remove_empty_values(d):
    """Recursively remove all empty values from a dict.

    Empty means None or an empty string, list, dict, set or tuple. Values
    such as 0 and False are kept. Nested dicts and lists are cleaned first,
    so a container that becomes empty after cleaning is removed as well.

    Args:
        d: The value to clean. Dicts and lists are rebuilt; anything else is
            returned unchanged.

    Returns:
        A new dict (or list) without empty values; the input is not mutated.
    """
    if isinstance(d, dict):
        cleaned = {}
        for key, value in d.items():
            value = remove_empty_values(value)
            # Test emptiness after recursion so emptied containers are dropped too.
            if not _is_empty_value(value):
                cleaned[key] = value
        return cleaned
    if isinstance(d, list):
        cleaned_items = (remove_empty_values(element) for element in d)
        return [element for element in cleaned_items if not _is_empty_value(element)]
    return d


# Once the accumulated text reaches this many characters, no more docs are added.
INFO_LIMIT=3000

for i,key_work in tqdm(enumerate(extract_list), total=len(extract_list)):
    # Wrap a non-list extraction result in a list so it can be iterated below.
    key_work = key_work if type(key_work) is list else [key_work]
    question = dataset[i]["question"]
    infos = ""

    for item in key_work:
        if item["function_name"] != "":
            module,function = item['module_name'],item['function_name']
            # Keep at most the first two hits for an extracted function.
            api_docs = search_documents(function,module,question)[:2]
        else:
            # Nothing was extracted: fall back to querying with the whole question.
            api_docs = search_documents(method_description=question)

        for doc in api_docs:
            # infos only grows, so once the limit is reached the rest can be skipped.
            if len(infos) >= INFO_LIMIT:
                break
            # Results starting with "no" are appended verbatim (presumably a
            # "not found" message — confirm); others are parsed as JSON,
            # stripped of empty values and re-serialised.
            if not doc.startswith("no"):
                doc = json.dumps(remove_empty_values(json.loads(doc)))
            infos = infos + "\n\n"+doc
    dataset[i]["rag_infos"] = infos

# Append the collected documentation to each sample's prompt.
for sample in dataset:
    sample["content"] = sample["content"]+"\n\n"+sample["rag_infos"]

In [None]:
# NOTE(review): this expects every element of "rag_infos" to be indexable with
# numeric entries at positions 1 and 2; the cells in this notebook store
# strings there — confirm which producer this cell was written for.
for idx, sample in enumerate(dataset):
    infos = sample["rag_infos"]
    print()
    scores = [round(t[1]+t[2],2) for t in infos]
    print(idx+1, extract_list[idx], scores)
        

In [None]:
from tool.rag_tool import search_documents_by_help_function

# For every extracted entry, look up the bare function name (text after the
# last ".") in the top-level module (lower-cased, stripped, text before the
# first ".") and print the length of the returned doc, or None if nothing
# came back.
for idx in range(len(dataset)):
    for entry in extract_list[idx]:
        func_name = entry["function_name"].split(".")[-1]
        module_root = str(entry["module_name"]).lower().strip().split(".")[0]
        print(idx, func_name, module_root, end=" ")
        found = search_documents_by_help_function(func_name, module_root)
        print(len(found) if found else None)
            