In [1]:
import os
import json
import requests
import datetime
import pytz
import torch
from utils import DataTool
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Runtime environment for this notebook:
#   - expose only GPU 0 to PyTorch/CUDA;
#   - send outbound HTTP and HTTPS traffic through the corporate proxy.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not been
# initialised yet in this process (torch is already imported above), so keep this
# before any CUDA call. Consider reading the proxy from the environment instead
# of hard-coding it here.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    'http_proxy': 'http://proxygate2.ctripcorp.com:8080',
    'https_proxy': 'http://proxygate2.ctripcorp.com:8080',
})

# Chinese weekday labels, indexed by datetime.weekday() (0 = Monday ... 6 = Sunday).
weekdays = [
    "周一", "周二", "周三", "周四", "周五", "周六", "周日",
]

# POI-related keywords (price, ticket, address, location, discount, policy, fare,
# opening hours). get_retrieval switches to the POI-policy source when a query
# contains any of them. The misspelt name is kept because other code refers to it.
poi_keywors_list = [
    "价格", "门票", "地址", "位置", "优惠", "政策", "票价", "开放时间",
]

def get_retrieval(query, url="http://general.retrieval.ctripcorp.com/general_retrieval", districts="阳朔", req_id=1, topk=5, source="lvp", timeout=30):
    """Query the general retrieval service and format the top hits as prompt context.

    If ``query`` contains any keyword from ``poi_keywors_list`` (price, ticket,
    address, opening hours, ...), the request is routed to the ``"poipolicy"``
    source and only the top 2 hits are kept. Otherwise ``source`` and ``topk``
    are used as given.

    Args:
        query: Query text sent to the retrieval service.
        url: Retrieval endpoint.
        districts: Currently unused (no district filter is sent); kept for
            backward compatibility.
        req_id: Value sent as ``reqId``.
        topk: Number of hits to include in the returned text.
        source: Retrieval source used when no POI keyword matches.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A string with a fixed Chinese header, then ``str()`` of each hit's
        ``source`` field (one per line), then the current Asia/Shanghai time and
        Chinese weekday label.

    Raises:
        requests.HTTPError: if the service answers with a 4xx/5xx status.
        requests.Timeout: if no response arrives within ``timeout`` seconds.
    """
    # POI questions (prices, tickets, opening hours, ...) are answered from the
    # policy source; only the best couple of hits are useful there.
    if any(keyword in query for keyword in poi_keywors_list):
        topk = 2
        source = "poipolicy"

    data = {
        "reqId": req_id,
        "query": query,
        "source": source,
        "return_fields": ["title", "content", "publish_time"],
        # Request at least the historical 10 hits, and more if the caller wants
        # more than 10 (presumably topK caps the result count server-side).
        "topK": max(topk, 10),
    }

    shanghai_tz = pytz.timezone('Asia/Shanghai')
    current_time = datetime.datetime.now(shanghai_tz)
    weekday = weekdays[current_time.weekday()]

    resp = requests.post(url, json=data, timeout=timeout)
    resp.raise_for_status()
    results = resp.json()['results']

    hits = '\n'.join(str(item['source']) for item in results[:topk])
    return (f'API查询到的关于{query}的可能相关的信息如下：' + hits +
            f'\n当前的时间是：{current_time.strftime("%Y-%m-%d %H:%M")} ({weekday})')


# Local checkpoint of the SFT'd Baichuan2 model used by this notebook.
model_path = '/data/share_user/zzd/ckpt/rlhf_baichuan2/0126-f3/sft_hf'


def init_model(path=None, torch_dtype=None):
    """Load the fine-tuned causal LM and its tokenizer for inference.

    Args:
        path: Checkpoint directory; defaults to the module-level ``model_path``.
        torch_dtype: Dtype for the weights. Defaults to ``torch.float16`` when
            CUDA is available and ``torch.float32`` otherwise. fp16 weights that
            land on CPU failed in this notebook with
            ``"addmm_impl_cpu_" not implemented for 'Half'``.

    Returns:
        ``(model, tokenizer)``:
            - ``model`` is in eval mode, with its generation config loaded from
              ``path``.
            - ``tokenizer`` is the slow (``use_fast=False``) tokenizer, set to
              left padding.
    """
    if path is None:
        path = model_path
    if torch_dtype is None:
        # Half precision only where a GPU will run the matmuls; fall back to fp32
        # on CPU-only machines instead of crashing at generation time.
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    print("init actor model ...")
    model = AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map='auto'
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False, trust_remote_code=True)
    # Decoder-only generation expects prompts padded on the left.
    tokenizer.padding_side = 'left'

    return model, tokenizer

# Load the model and tokenizer once; later cells reuse these notebook globals.
model, tokenizer = init_model()
# DataTool (project-local, from utils) is given the tokenizer and the checkpoint's
# configured max_new_tokens, presumably as its generation budget. TODO confirm in utils.
datatool = DataTool(tokenizer, model.generation_config.max_new_tokens)

init actor model ...


Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


In [2]:
# Demo: answer a ticket-price question with retrieved context.
# `query` is the user's wording and goes into the prompt. `requery` is the form
# sent to retrieval; it contains "门票"/"价格", so get_retrieval uses the POI source.
query = '拙政园门票多少钱呀'
requery = '拙政园门票价格'

text = get_retrieval(requery)
# Equivalent to print(text) followed by print('\n\n').
print(text, end='\n' * 4)

prompt = f'你是一个优秀的旅行助手，可以根据用户问题给出合理的建议。用户的问题是：{query}，请结合以下查询到的参考信息进行回答:\n'
messages = [dict(role='user', content=prompt + text)]
input_ids = datatool.build_chat_input(model, tokenizer, messages, model.generation_config.max_new_tokens)
response = datatool.chat(model, input_ids, stream=False, device=model.device)
print(response)

API查询到的关于拙政园门票价格的可能相关的信息如下：{'publish_time': '', 'title': '拙政园', 'content': '拙政园 地址: 苏州市姑苏区东北街178号\n拙政园：\n景点门票，门票\n成人票，年龄19周岁（含）~59周岁（含）（按生日计算年龄，以出行日期为准），价格：70起\n成人票，年龄18周岁（含）~59周岁（含）（按生日计算年龄，以出行日期为准），价格：70起\n优待票，6周岁（不含6周岁）~18周岁（含18周岁）未成年人、全日制大学本科及以下学历在校学生凭有效身份证件实行半票，价格：35起\n优待票，儿童：6周岁（不含）-18周岁（不含）；老人：60周岁（含）-70周岁（不含），凭有效证件；学生：本科及以下全日制在校学生（不含函授、成人教育及短期培训生、交流生和研究生），须凭本人有效学生证和有效证件，价格：35起\n老人票，年龄60周岁（含）~69周岁（含）（按生日计算年龄，以出行日期为准），价格：35起\n#!\n免费人群预约票\n优待票，身高1.4米（含）以下或年龄6周岁（含）以下的儿童、年龄70周岁（含）以上的老人、现役军人、退役军人、残疾人等优抚对象、苏州通•转转卡用户（预约成功后，凭短信二维码及本人有效证件验证入园），优待票， 儿童：身高在1.4米（含）以下或年龄在6周岁（含）以下；老人：年龄在70周岁（含）以上，凭有效证件，免票人群适用范围见预约须知中的适用条件优待票(免费人群)适用人群：符合国家、江苏省、苏州市园林景区门票免费游览优惠政策的游客（如：70周岁（含）以上的老年人、6周岁以下或1.4米以下儿童、现役军人、残疾人等等优抚对象）和苏州园林年卡用户苏州通•转转卡用户入园方式：在线预约成功，持有效证件现场核验通过后，凭二维码或身份证验证入园优待票(外宾免费人群) 适用人群：港澳游客、外国宾客：身高在1.4米（含）以下或6周岁（含）以下的儿童、70周岁（含）以上的老年人、苏州园林年卡用户、苏州通•转转卡用户入园方式：在线预约成功，持有效证件现场核验通过后，凭二维码或身份证验证入园（符合免票人群的预约，适用人群详见预约须知中的适用条件）\n'}
{'publish_time': '', 'title': '拙政问雅', 'content': '拙政问雅 地址: 江苏省苏州市姑苏区东北街178号拙政园内

RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

In [None]:
# NOTE(review): this cell redefines get_retrieval, which the first cell already
# defines; whichever cell ran last wins. Keep the two in sync or delete one.
def get_retrieval(query, url="http://general.retrieval.ctripcorp.com/general_retrieval", districts="阳朔", req_id=1, topk=5, source="lvp", timeout=30):
    """Query the general retrieval service and format the top hits as prompt context.

    If ``query`` contains any keyword from ``poi_keywors_list`` (price, ticket,
    address, opening hours, ...), the request is routed to the ``"poipolicy"``
    source and only the top 2 hits are kept. Otherwise ``source`` and ``topk``
    are used as given.

    Args:
        query: Query text sent to the retrieval service.
        url: Retrieval endpoint.
        districts: Currently unused (no district filter is sent); kept for
            backward compatibility.
        req_id: Value sent as ``reqId``.
        topk: Number of hits to include in the returned text.
        source: Retrieval source used when no POI keyword matches.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A string with a fixed Chinese header, then ``str()`` of each hit's
        ``source`` field (one per line), then the current Asia/Shanghai time and
        Chinese weekday label.

    Raises:
        requests.HTTPError: if the service answers with a 4xx/5xx status.
        requests.Timeout: if no response arrives within ``timeout`` seconds.
    """
    # POI questions (prices, tickets, opening hours, ...) are answered from the
    # policy source; only the best couple of hits are useful there.
    if any(keyword in query for keyword in poi_keywors_list):
        topk = 2
        source = "poipolicy"

    data = {
        "reqId": req_id,
        "query": query,
        "source": source,
        "return_fields": ["title", "content", "publish_time"],
        # Request at least the historical 10 hits, and more if the caller wants
        # more than 10 (presumably topK caps the result count server-side).
        "topK": max(topk, 10),
    }

    shanghai_tz = pytz.timezone('Asia/Shanghai')
    current_time = datetime.datetime.now(shanghai_tz)
    weekday = weekdays[current_time.weekday()]

    resp = requests.post(url, json=data, timeout=timeout)
    resp.raise_for_status()
    results = resp.json()['results']

    hits = '\n'.join(str(item['source']) for item in results[:topk])
    return (f'API查询到的关于{query}的可能相关的信息如下：' + hits +
            f'\n当前的时间是：{current_time.strftime("%Y-%m-%d %H:%M")} ({weekday})')