In [None]:
# # install dependencies
# !pip install datasets dashscope openai requests retrying numpy func_timeout bert_score transformers

In [None]:
from datasets import load_dataset
import json

# # load from local files in the repository
# qas = load_dataset("../JJQA","qa")["train"]
# songs = load_dataset("../JJQA","song")["train"]
# song_index=json.loads(load_dataset("hobeter/JJQA","song_index")["train"]["dic"][0])[0]

# Fetch the three JJQA configurations ("qa", "song", "song_index") from the
# Hugging Face Hub and keep the "train" split of each.
_JJQA_HUB_REPO = "hobeter/JJQA"

qas = load_dataset(_JJQA_HUB_REPO, "qa")["train"]
songs = load_dataset(_JJQA_HUB_REPO, "song")["train"]

# The "dic" column of song_index holds JSON text; the first element of the
# decoded first row is what the baseline loop indexes with qa["song_id"] to
# find a position in `songs`.
_song_index_rows = load_dataset(_JJQA_HUB_REPO, "song_index")["train"]["dic"]
song_index = json.loads(_song_index_rows[0])[0]

In [None]:
import os

# model & mode settings
# `mode` options: without_info, with_whole_song, with_rf.
# Only used when assistant_flag is False.
mode="with_rf"
# `model` options: ernie-turbo, chatglm2_6b_32k, qwen-turbo, baichuan2-7b-chat-v1,
# gpt-4, gpt-3.5-turbo, gpt-4-1106-preview
model="gpt-3.5-turbo"

# Whether to use the Assistants API for retrieval. Only available on the openai
# platform; when True the `mode` setting is ignored.
assistant_flag=False
# File uploaded for retrieval when the Assistants API is used.
retrieval_file_path="../dataset/hf_song.json"


# API keys.
# Each key is read from the environment when the variable is set, so real keys
# never have to be saved inside the notebook. When the variable is absent the
# original placeholder is used, which can still be edited in place.
# NOTE: a set environment variable takes precedence over the in-notebook default.
qianfan_api_key=os.environ.get("QIANFAN_AK","your_qianfan_api_key")
qianfan_secret_key=os.environ.get("QIANFAN_SK","your_qianfan_secret_key")

dashscope_key=os.environ.get("DASHSCOPE_API_KEY","your_dashscope_key")

openai_key=os.environ.get("OPENAI_API_KEY","your_openai_key")

# Seconds to wait for one openai chat completion before it is treated as timed
# out. The Assistants API polling loop allows wait_time*12*5 seconds.
wait_time=5

# proxy settings (only applied to the openai client)
proxy_flag=False
proxy_url="http://127.0.0.1:7890"

In [None]:
# Hand-written prompt fragments (Chinese). The baseline loop concatenates them
# around the question text; the runtime strings must stay exactly as written.

# System role: "a QA system based on JJ Lin's lyrics".
system_prompt = "你是一个基于林俊杰歌词的问答系统。"

# Answer-format instruction: ignore changes of grammatical person, answer with a
# short phrase, do not explain the question, do not repeat a whole lyric line.
prompt = (
    "不要考虑人称的变化，仅仅用简洁的短语回答以下相关问题，"
    "不要说明问题！不要直接重复整句歌词！\n"
)
# Bridge placed after supplied lyrics: "based on the lyric information above,".
info_prompt = "请根据以上的歌词信息，"

# Header placed before the selected reference lines (mode "with_rf").
rf_prompt = "已知一些相关歌词：\n"
# Header placed before the full lyric text (mode "with_whole_song").
song_prompt = "已知一首歌曲的歌词：\n"

# Instructions given to the assistant when it is created for Assistants API runs.
assistant_prompt = (
    "你是基于歌词的问答系统，请根据知识库中的相关信息回答问题。"
    "仅仅用简洁的短语回答，不要结合问题以完整的句子回答！"
)

In [None]:
import dashscope
import openai
import requests
import json
from http import HTTPStatus
import retrying
from func_timeout import func_set_timeout
import httpx
import time

# for LLM calling
class LLM():
    """Single-prompt interface over the chat back ends used in this notebook.

    The constructor reads the notebook-level settings (``model``, the API keys,
    ``proxy_flag``/``proxy_url``, ``assistant_flag``, ``retrieval_file_path`` and
    ``assistant_prompt``) and prepares whichever platform serves ``model``:
    "qianfan", "dashscope" or "openai". ``call`` sends one user message and
    returns the stripped reply text; ``free`` releases Assistants API resources.
    """

    # Common prefix of the qianfan chat endpoints and the per-model last segment.
    _QIANFAN_CHAT_URL="https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
    _QIANFAN_ENDPOINTS={
        "ernie-turbo":"eb-instant",
        "chatglm2_6b_32k":"chatglm2_6b_32k",
    }

    def __init__(self) -> None:
        """Select the platform for the global ``model`` and build its client.

        When ``assistant_flag`` is True on the openai platform, the retrieval
        file is uploaded and an assistant named "JJQA-bot" is created.

        Raises:
            ValueError: if ``model`` is not one of the supported model names.
            RuntimeError: if no qianfan access token could be obtained.
        """
        self.model=model
        self.platform=None
        if(model in self._QIANFAN_ENDPOINTS):
            self.platform="qianfan"
            self.qianfan_api_key=qianfan_api_key
            self.qianfan_secret_key=qianfan_secret_key
            self.url=(
                self._QIANFAN_CHAT_URL
                + self._QIANFAN_ENDPOINTS[model]
                + "?access_token="
                + self.get_qianfan_access_token()
            )
        elif(model in ["qwen-turbo","baichuan2-7b-chat-v1"]):
            self.platform="dashscope"
            self.dashScope_key=dashscope_key
            dashscope.api_key=self.dashScope_key
        elif(model in ["gpt-4","gpt-3.5-turbo","gpt-4-1106-preview"]):
            self.platform="openai"
            self.openai_key=openai_key

            if(proxy_flag==True):
                self.client = openai.OpenAI(
                    api_key=self.openai_key,
                    http_client=httpx.Client(
                        proxies=proxy_url,
                    ),
                )
            else:
                self.client = openai.OpenAI(
                    api_key=self.openai_key,
                )

            if(assistant_flag==True):
                # Close the local file handle as soon as the upload returns.
                with open(retrieval_file_path, "rb") as retrieval_file:
                    self.file = self.client.files.create(
                        file=retrieval_file,
                        purpose='assistants'
                    )
                self.assistant = self.client.beta.assistants.create(
                    name="JJQA-bot",
                    instructions=assistant_prompt,
                    tools=[{"type": "retrieval"}],
                    model=self.model,
                    file_ids=[self.file.id],
                )
                # Accumulated text of every assistant exchange; rewritten to the
                # log file after each successful call.
                self.assistant_logs=""

        else:
            raise ValueError(f"model {model} not supported")

    def get_qianfan_access_token(self):
        """Request an OAuth access token for the qianfan platform.

        Returns:
            The ``access_token`` field of the JSON response.

        Raises:
            RuntimeError: if the response carries no access token. Only the
                ``error``/``error_description`` fields are reported so that no
                credential ends up in the message.
        """
        url = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="+self.qianfan_api_key+"&client_secret="+self.qianfan_secret_key

        payload = json.dumps("")
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }

        response = requests.request("POST", url, headers=headers, data=payload)
        body = response.json()
        token = body.get("access_token")
        if(not token):
            # Fail here with a readable message instead of a TypeError later
            # when the token is concatenated into the endpoint URL.
            raise RuntimeError(
                f"failed to obtain qianfan access token: error={body.get('error')}, "
                f"description={body.get('error_description')}"
            )
        return token

    def assistant_postprocess(self,message):
        """Split an assistant message into a clean answer and an annotated log.

        Args:
            message: an Assistants API message whose ``content`` holds exactly
                one text item with ``value`` and ``annotations``.

        Returns:
            ``(res_s, log_s, n)``: the answer with annotation markers removed,
            the answer with markers replaced by ``[index]`` footnotes followed
            by the citation list, and the number of annotations.
        """
        assert len(message.content)==1,f"message content length error"
        message_content=message.content[0].text
        annotations = message_content.annotations
        citations = []

        tmp=message_content.value
        # Normalise a trailing ASCII full stop to the Chinese one; endswith is
        # safe on an empty reply, unlike indexing tmp[-1].
        if(tmp.endswith(".")):
            tmp=tmp[:-1]+"。"

        res_s=tmp
        log_s=tmp

        # Iterate over the annotations and add footnotes
        for index, annotation in enumerate(annotations):
            # Replace the text with a footnote
            res_s=res_s.replace(annotation.text,'')
            log_s=log_s.replace(annotation.text, f' [{index}]')

            # Gather citations based on annotation attributes
            if (file_citation := getattr(annotation, 'file_citation', None)):
                cited_file = self.client.files.retrieve(file_citation.file_id)
                citations.append(f'[{index}] {file_citation.quote} from {cited_file.filename}')
            elif (file_path := getattr(annotation, 'file_path', None)):
                cited_file = self.client.files.retrieve(file_path.file_id)
                citations.append(f'[{index}] Click <here> to download {cited_file.filename}')
                # Note: File download functionality not implemented above for brevity

        # Add footnotes to the end of the message before displaying to user
        log_s += '\n' + '\n'.join(citations)

        res_s=res_s.strip()
        log_s=log_s.strip()
        if(len(citations)!=0):
            assert res_s!=log_s,f"assistant postprocess error"
        return res_s,log_s,len(annotations)

    @retrying.retry(stop_max_attempt_number=5,wait_fixed=1*1000,)# retry when a error occurs
    def call(self,input_t):
        """Send ``input_t`` as a single user message and return the reply text.

        The decorator retries a failing call, up to 5 attempts in total with a
        fixed wait of 1000 ms between attempts.

        Args:
            input_t: the complete prompt text.

        Returns:
            The stripped reply. On dashscope, the literal string
            "inappropriate error" when the request is rejected with a message
            containing "inappropriate".

        Raises:
            RuntimeError: when the platform reports an error, the openai call
                fails or times out, or the assistant run does not complete.
        """
        if(self.platform=="qianfan"):
            return self._call_qianfan(input_t)
        elif(self.platform=="dashscope"):
            return self._call_dashscope(input_t)
        elif(self.platform=="openai"):
            if(assistant_flag==False):
                return self._call_openai_chat(input_t)
            else:
                return self._call_openai_assistant(input_t)

    def _call_qianfan(self,input_t):
        """POST one user message to the qianfan chat endpoint in ``self.url``."""
        # The same payload is used for every qianfan model; no system prompt is sent.
        payload = json.dumps({
            "messages": [
                {
                    "role": "user",
                    "content": input_t
                }
            ],
        })
        headers = {
            'Content-Type': 'application/json'
        }
        response = requests.request("POST", self.url, headers=headers, data=payload)
        response_dic=json.loads(response.text)

        if("error_code" in response_dic):
            raise RuntimeError(f"error_code:{response_dic['error_code']},error_msg:{response_dic['error_msg']}")

        return response_dic["result"].strip()

    def _call_dashscope(self,input_t):
        """Send one user message through ``dashscope.Generation.call``."""
        messages = [
            {'role': 'user', 'content': input_t},
        ]
        response = dashscope.Generation.call(
            model=self.model,
            messages=messages,
            result_format='message',  # set the result to be "message" format.
        )
        if response.status_code != HTTPStatus.OK:
            if("inappropriate" in response.message):  # calling fails due to the safety system of dashscope
                # Returned as an answer rather than raised, so it is not retried.
                return "inappropriate error"
            raise RuntimeError('Request id: %s, Status code: %s, error code: %s, error message: %s' % (response.request_id, response.status_code,response.code, response.message))

        return response.output.choices[0]['message']['content'].strip()

    def _call_openai_chat(self,input_t):
        """Run one chat completion, limited to ``wait_time`` seconds."""

        @func_set_timeout(wait_time)
        def openai_chat(model,input_t):
            return self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "user", "content": input_t}
                ],
            )

        try:
            response = openai_chat(self.model,input_t)
            return response.choices[0].message.content.strip()
        except (KeyboardInterrupt, SystemExit):
            # Never relabel an interrupt as an API failure.
            raise
        except BaseException as err:
            # NOTE(review): func_timeout's timeout exception is understood to
            # derive from BaseException, so `except Exception` would miss it;
            # confirm before narrowing this clause.
            # Only the exception type is printed: the timeout message may embed
            # the whole prompt.
            print(f"openai chat call failed: {type(err).__name__}")
            raise RuntimeError(f"openai chat call failed: {type(err).__name__}") from err

    def _call_openai_assistant(self,input_t):
        """Ask the assistant in a fresh thread and return the cleaned answer.

        The thread is deleted exactly once, whether the exchange succeeds or
        fails. Every successful exchange is appended to ``self.assistant_logs``
        and the whole log is rewritten to the log file.
        """

        @func_set_timeout(wait_time*12*5)
        def openai_assistant_check(thread_id,run_id):
            # Poll once per second until the run completes.
            while(1):
                run = self.client.beta.threads.runs.retrieve(
                    thread_id=thread_id,
                    run_id=run_id,
                )
                if(run.status=="completed"):
                    return
                # NOTE(review): statuses taken from the Assistants API run
                # lifecycle; a run in one of them will not complete, so stop
                # now instead of polling until the timeout.
                if(run.status in ("failed","cancelled","expired")):
                    raise RuntimeError(f"assistant run ended with status {run.status}")
                time.sleep(1.0)

        thread = self.client.beta.threads.create()
        try:
            self.client.beta.threads.messages.create(
                thread_id=thread.id,
                role="user",
                content=input_t,
            )

            # NOTE(review): run-level `instructions` presumably override the
            # assistant's own instructions (assistant_prompt); confirm this is
            # intended.
            run = self.client.beta.threads.runs.create(
                thread_id=thread.id,
                assistant_id=self.assistant.id,
                instructions=prompt,
            )

            openai_assistant_check(thread.id,run.id)
            messages = self.client.beta.threads.messages.list(thread_id=thread.id)
            assert len(messages.data)==2,f"message.data length error"
            # The reply is read from the first entry of the listing.
            res_s,log_s,annos_l=self.assistant_postprocess(messages.data[0])
            tmp_s="q:\n"+input_t+"\nannos:"+str(annos_l)+"\nlogs:\n"+log_s+"\na:\n"+res_s+"\n\n"
            self.assistant_logs=self.assistant_logs+tmp_s

            # utf-8 is explicit because the log holds Chinese text; the results
            # file saved later in the notebook uses the same encoding.
            with open(f"as_{model}_{retrieval_file_path.split('/')[-1].split('.')[0].strip()}_logs.txt","w",encoding='utf-8') as f:
                f.write(self.assistant_logs)

            return res_s
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as err:
            # BaseException for the same reason as in _call_openai_chat.
            print(f"error: {type(err).__name__}")
            raise RuntimeError(f"assistant call failed: {type(err).__name__}") from err
        finally:
            self.client.beta.threads.delete(thread.id)

    def free(self):
        """Detach the retrieval file from the assistant and delete the assistant.

        Does nothing unless the Assistants API was set up, i.e. ``assistant_flag``
        is True and the platform is openai.
        """
        if(assistant_flag==True and self.platform=="openai"):
            self.client.beta.assistants.files.delete(
                assistant_id=self.assistant.id,
                file_id=self.file.id
            )
            self.client.beta.assistants.delete(self.assistant.id)

In [None]:
import tqdm

baseline_res={}

llm=LLM()

td=tqdm.tqdm(enumerate(qas),total=len(qas))


def _build_input_text(qa,lyric):
    """Assemble the prompt for one question according to the run settings.

    Args:
        qa: one QA record; reads "q", plus "rf" in mode "with_rf".
        lyric: full lyric text of the record's song.

    Returns:
        The text to send to the model.

    Raises:
        ValueError: if ``mode`` is not a supported value (only checked when
            ``assistant_flag`` is not True).
    """
    question=qa["q"]+"？"
    if(assistant_flag==True):
        # The assistant retrieves lyrics itself, so only the question is sent.
        return question
    if(mode=="without_info"):
        return prompt+question
    if(mode=="with_rf"):
        # "rf" is a whitespace-separated list of line indices into the lyric.
        # split() without an argument tolerates repeated spaces and an empty value.
        rf=[int(_) for _ in qa["rf"].split()]
        lyrics=lyric.strip().split("\n")
        selected="".join(lyrics[line]+"\n" for line in rf)
        return rf_prompt+selected+"\n\n"+info_prompt+prompt+question
    if(mode=="with_whole_song"):
        return song_prompt+"\n"+lyric+"\n\n"+info_prompt+prompt+question
    raise ValueError(f"mode {mode} not supported")


#run the baseline
try:
    for ind,qa in td:
        q=qa["q"]
        a=qa["a"]+"。"
        song=songs[song_index[qa["song_id"]]]
        lyric=song["lyric"]

        input_t=_build_input_text(qa,lyric)
        res=llm.call(input_t)

        baseline_res[qa["id"]]={
            "q":q+"？",
            "pred":res.strip(),
            "label":a
        }
finally:
    # Release the assistant even when a call fails part-way through the run.
    llm.free()

In [None]:
import os

# Result file name: "<model>_<mode>_dic.json" for plain chat runs, otherwise
# "as_<model>_<retrieval file stem>_dic.json" for Assistants API runs.
file_name = (
    f"{model}_{mode}_dic.json"
    if assistant_flag == False
    else f"as_{model}_{retrieval_file_path.split('/')[-1].split('.')[0].strip()}_dic.json"
)

# Never overwrite an earlier run: keep prefixing "new_" until the name is free.
while True:
    if not os.path.exists(file_name):
        break
    file_name = "new_" + file_name

#save results
with open(file_name, "w", encoding='utf-8') as f:
    json.dump(baseline_res, f, indent=2, ensure_ascii=False)

print(file_name)