In [None]:
# Install the required packages: the DashVector and DashScope SDKs used below.
# NOTE(review): transformers_stream_generator and python-dotenv are not used by
# any visible cell — confirm they are still needed.
# NOTE(review): versions are unpinned; pin them (and prefer %pip so the install
# targets this kernel's environment) for reproducible runs.
!pip install dashvector dashscope
!pip install transformers_stream_generator python-dotenv

In [None]:
# Prepare the news corpus used as the knowledge source. The indexing cell reads
# 'CEC-Corpus/raw corpus/allSourceText' from this checkout.
# NOTE(review): the URL is routed through the third-party ghproxy.com mirror;
# use the embedded github.com URL directly where GitHub is reachable.
!git clone https://ghproxy.com/https://github.com/shijiebei2009/CEC-Corpus.git

In [2]:
import dashscope
import os
from dashscope import TextEmbedding
from dashvector import Client, Doc

# API keys are read from the environment so they are never committed together
# with the notebook; the placeholder strings are only a visible fallback.
# [Note: get your DashScope API key here first: https://dashscope.console.aliyun.com/apiKey]
dashscope.api_key = os.environ.get('DASHSCOPE_API_KEY', 'YOUR-DASHSCOPE-API-KEY')

# Initialize the DashVector client (this needs a DashVector key, not a DashScope one).
# [Note: get your DashVector API key here first: https://dashvector.console.aliyun.com/cn-hangzhou/api-key]
dashvector_client = Client(
    api_key=os.environ.get('DASHVECTOR_API_KEY', 'YOUR-DASHVECTOR-API-KEY'))

# Name of the collection that stores the news embeddings.
collection_name = 'news_embeddings'

# Drop any existing collection of the same name so a re-run starts from a clean index.
dashvector_client.delete(collection_name)

# Create the collection with dimension 1536 to match the text_embedding_v1
# vectors produced by generate_embeddings.
rsp = dashvector_client.create(collection_name, 1536)
if not rsp:
    # Fail here rather than on the first upsert, with the service's own error.
    raise RuntimeError(f'Failed to create DashVector collection {collection_name!r}: {rsp}')
collection = dashvector_client.get(collection_name)


In [3]:
def prepare_data_from_dir(path, size):
    """Yield the contents of the files in ``path`` in batches of at most ``size``.

    Files are read as UTF-8 text and visited in sorted-name order, because
    ``os.listdir`` order is arbitrary and the indexing loop assigns sequential
    ids from this order. Entries that are not regular files (e.g.
    sub-directories) are skipped.

    Args:
        path: Directory containing the text files.
        size: Maximum number of documents per yielded batch.

    Yields:
        list[str]: A batch of file contents. Every batch except possibly the
        last has exactly ``size`` items.
    """
    batch_docs = []
    for name in sorted(os.listdir(path)):
        file_path = os.path.join(path, name)
        if not os.path.isfile(file_path):
            continue
        # Read and close the file before yielding, so no handle stays open
        # while the consumer processes the batch.
        with open(file_path, 'r', encoding='utf-8') as f:
            batch_docs.append(f.read())
        if len(batch_docs) == size:
            yield batch_docs
            # Start a fresh list so the yielded batch is never mutated afterwards.
            batch_docs = []

    if batch_docs:
        yield batch_docs

In [4]:
def prepare_data_from_file(path, size, chunk_size=12):
    """Split a text file into chunks of non-blank lines and yield them in batches.

    Whitespace-only lines are skipped. Every ``chunk_size`` remaining lines are
    concatenated (keeping their newlines) into one chunk. A final chunk with
    fewer than ``chunk_size`` lines is still emitted, so the tail of the file
    is not lost.

    Args:
        path: Path to a UTF-8 text file.
        size: Maximum number of chunks per yielded batch.
        chunk_size: Number of non-blank lines per chunk (default 12).

    Yields:
        list[str]: A batch of chunks. Every batch except possibly the last
        has exactly ``size`` items.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')
    batch_docs = []
    doc = ''
    count = 0
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            doc += line
            count += 1
            if count == chunk_size:
                batch_docs.append(doc)
                doc = ''
                count = 0
                if len(batch_docs) == size:
                    yield batch_docs
                    batch_docs = []

    # Keep the trailing partial chunk instead of silently dropping it.
    if doc:
        batch_docs.append(doc)
    if batch_docs:
        yield batch_docs

In [5]:
def generate_embeddings(docs):
    """Embed text with DashScope's ``text_embedding_v1`` model.

    Args:
        docs: A single string, or a list of strings embedded in one request.

    Returns:
        For a list input, a list of embedding vectors aligned with ``docs``
        (``[]`` for an empty list). For a single string, that string's vector.

    Raises:
        RuntimeError: If the DashScope call does not return HTTP 200.
    """
    from http import HTTPStatus

    if isinstance(docs, list) and not docs:
        # Nothing to embed; skip the network round-trip.
        return []
    rsp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v1,
                             input=docs)
    if rsp.status_code != HTTPStatus.OK:
        # Surface the API's error code/message instead of failing later on a
        # missing output.
        raise RuntimeError(f'DashScope embedding failed: {rsp.code} {rsp.message}')
    # Order by text_index so vector i always belongs to docs[i]. The sort is
    # stable, so records without the key keep the response order.
    records = sorted(rsp.output['embeddings'], key=lambda r: r.get('text_index', 0))
    embeddings = [record['embedding'] for record in records]
    return embeddings if isinstance(docs, list) else embeddings[0]

In [6]:
# Create embeddings and insert them into DashVector.
# Note: this may take a while (up to 5 mins) to run.

# Next document id to assign (named so it does not shadow the builtin `id`).
next_id = 0
dir_name = 'CEC-Corpus/raw corpus/allSourceText'

# Fetch the collection handle used for the upserts below.
collection = dashvector_client.get(collection_name)

# NOTE(review): the embedding API caps how many texts one request may carry —
# check the DashScope limit before raising this.
batch_size = 20

# Iterate the generator directly instead of materializing the whole corpus first.
for news in prepare_data_from_dir(dir_name, batch_size):
    # Consecutive integer ids across batches; DashVector doc ids are passed as strings.
    ids = range(next_id, next_id + len(news))
    next_id += len(news)
    # Embed the raw documents of this batch.
    vectors = generate_embeddings(news)
    # Upsert each vector with its raw text so queries can return the text via output_fields.
    ret = collection.upsert(
        [
            Doc(id=str(doc_id), vector=vector, fields={"raw": doc})
            for doc_id, doc, vector in zip(ids, news, vectors)
        ]
    )
    print(ret)
    if not ret:
        # Stop instead of silently continuing with a partially indexed corpus.
        raise RuntimeError(f'Upsert failed for ids {ids.start}-{ids.stop - 1}: {ret}')


{"code": 0, "message": "Success", "requests_id": "caec6c1e-ce9f-4a70-b080-4c9d455aecc3"}
{"code": 0, "message": "Success", "requests_id": "94282933-35df-4e4a-9b31-f4e1e52909d3"}
{"code": 0, "message": "Success", "requests_id": "8c44b3b7-c742-4232-880f-bf708ef44ff4"}
{"code": 0, "message": "Success", "requests_id": "861912f1-6a64-48aa-9a16-9eb2bd383b78"}
{"code": 0, "message": "Success", "requests_id": "a8cab2c9-c8a0-4810-9205-7e5718e6500b"}
{"code": 0, "message": "Success", "requests_id": "7a6b30e8-2391-4505-bb7a-74f4257aeb55"}
{"code": 0, "message": "Success", "requests_id": "230f7af3-503c-4911-a9d9-7f889130eaa3"}
{"code": 0, "message": "Success", "requests_id": "e94b9b95-36b3-4ed3-8b94-a4582fc4543d"}
{"code": 0, "message": "Success", "requests_id": "22774357-9e21-4b36-8928-b9f3fb83fc50"}
{"code": 0, "message": "Success", "requests_id": "3c8d2b3e-6877-44db-8207-faffafabbe2b"}
{"code": 0, "message": "Success", "requests_id": "7718ffff-8d66-4b8a-9dd5-cde9bf758b38"}
{"code": 0, "message"

In [7]:
# Check the collection status: re-fetch the handle and print its stats
# (e.g. total_doc_count and index_completeness) to confirm the upserts landed.
collection = dashvector_client.get(collection_name)
rsp = collection.stats()
print(rsp)

{"code": 0, "message": "Success", "requests_id": "cbb508ac-0d40-464a-b6fe-1351e99f0637", "output": {"total_doc_count": 332, "index_completeness": 1.0, "partitions": {"default": {"total_doc_count": 332}}}}


In [8]:
def search_relevant_context(question, topk=1, client=dashvector_client):
    """Recall the raw text of the ``topk`` stored documents most similar to ``question``.

    Args:
        question: Natural-language query. It is embedded with the same
            ``generate_embeddings`` helper used at indexing time.
        topk: Number of documents to recall.
        client: DashVector client. The default is bound to the notebook-wide
            ``dashvector_client`` when this function is defined.

    Returns:
        The recalled documents' ``raw`` fields joined with newlines.

    Raises:
        RuntimeError: If the DashVector query fails.
    """
    collection = client.get(collection_name)

    # Recall the top-k most similar docs, asking only for the stored raw text.
    rsp = collection.query(generate_embeddings(question), output_fields=['raw'],
                           topk=topk)
    if not rsp:
        raise RuntimeError(f'DashVector query failed: {rsp}')
    # Separate documents so the end of one does not run into the start of the
    # next inside the prompt.
    return "\n".join(item.fields['raw'] for item in rsp.output)

In [9]:
import dashscope
import textwrap
from dashscope import Generation

# Prompt template and generation call for the vectorDB-enhanced LLM.
def answer_question(model_name, question, context):
    """Ask ``model_name`` to answer ``question`` based on ``context``.

    The context is fenced in triple back-ticks, and the (Chinese) instruction
    asks the model to answer from the fenced content.

    Args:
        model_name: DashScope generation model name, e.g. ``'qwen-7b-chat-v1'``.
        question: The user question.
        context: Reference text that grounds the answer. Pass ``''`` to query
            the bare model.

    Returns:
        The generated answer text (``response.output['text']``).

    Raises:
        RuntimeError: If the DashScope call does not return HTTP 200.
    """
    from http import HTTPStatus

    # Built line by line so no stray quote or source-code indentation leaks
    # into the prompt.
    prompt = (
        '请基于```内的内容回答问题。\n'
        '```\n'
        f'{context}\n'
        '```\n'
        f'我的问题是：{question}。\n'
    )
    # Model-specific prompt formats:
    #   ziya:                      prompt = f'<human>:{prompt}\n<bot>:'
    #   qwen / ChatGLM / baichuan: use the plain prompt as-is.
    response = Generation.call(
        model=model_name,
        prompt=prompt,
        # ChatGLM additionally requires a chat history, e.g. history=[]
    )
    if response.status_code != HTTPStatus.OK:
        raise RuntimeError(
            f'DashScope generation failed: {response.code} {response.message}')
    return response.output['text']

In [10]:
# Baseline: ask the plain LLM with an empty context (no vectorDB retrieval).
model_name = 'qwen-7b-chat-v1'
question = '海南安定追尾事故，发生在哪里？原因是什么？人员伤亡情况如何？'
answer = answer_question(model_name, question, context='')
print(f'问题: {question}\n回答: {textwrap.fill(answer, width=50)}')

问题: 海南安定追尾事故，发生在哪里？原因是什么？人员伤亡情况如何？
回答: 很抱歉，我无法提供关于该事故的最新信息。请您查阅可靠的新闻来源以获取最新信息。


In [11]:
# Same question, now grounded with the top-2 documents recalled from DashVector.
context = search_relevant_context(question, topk=2)
answer = answer_question(model_name, question, context)
# qwen / ziya / baichuan: the answer is plain text, wrapped to 50 columns for display.
print(f'问题: {question}\n回答: {textwrap.fill(answer, width=50)}')
# ChatGLM variant reads the 'response' field instead:
# print(f'问题: {question}\n' f"""回答：{answer['response']}""")

问题: 海南安定追尾事故，发生在哪里？原因是什么？人员伤亡情况如何？
回答: 海南安定追尾事故发生在海南省定安县境内，环岛东线高速公路海口往三亚方向53公里处。原因是琼AB711
9小轿车驾驶人追尾所致。该事故造成小轿车人员5人当场死亡，其中一人为未成年人。
