Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions examples/transformers/rag/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,25 @@

from langchain.embeddings.base import Embeddings

from mindnlp.sentence import SentenceTransformer
from sentence_transformers import SentenceTransformer


class EmbeddingsFunAdapter(Embeddings):
    """Adapt a sentence-transformers model to the LangChain ``Embeddings`` interface.

    Wraps a ``SentenceTransformer`` so it can be used by LangChain vector
    stores (e.g. ``FAISS.from_texts``), which call ``embed_documents`` and
    ``embed_query``.
    """

    def __init__(self, embed_model):
        """Load the sentence-transformers model named by *embed_model*.

        Args:
            embed_model: model id or local path accepted by SentenceTransformer.
        """
        self.embed_model = embed_model
        self.embedding_model = SentenceTransformer(model_name_or_path=self.embed_model)

    def encode_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* and return plain Python lists of floats.

        Newlines are flattened to spaces first, since embedded newlines can
        degrade embedding quality for some models.
        """
        texts = [t.replace("\n", " ") for t in texts]
        embeddings = self.embedding_model.encode(texts)
        # Bug fix: the original wrote `embeddings[i] = embedding.tolist()`,
        # which assigns back into the numpy ndarray returned by encode() —
        # numpy coerces the list back to an array row, so the function still
        # returned an ndarray, not List[List[float]]. Build a real list instead.
        return [embedding.tolist() for embedding in embeddings]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents (LangChain Embeddings API)."""
        return self.encode_texts(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (LangChain Embeddings API)."""
        return self.encode_texts([text])[0]
172 changes: 172 additions & 0 deletions examples/transformers/rag/newchat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


import argparse
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import CharacterTextSplitter

import mindnlp
from embedding import EmbeddingsFunAdapter
from text import TextLoader
from threading import Thread

import mindspore
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

def load_knowledge_base(file_name):
    """Build a FAISS vector store from a plain-text knowledge file.

    Args:
        file_name: path to the UTF-8 text file to index.

    Returns:
        A FAISS vector store over newline-split chunks of the file.
    """
    print(f"正在加载知识库文件: {file_name}")
    raw_text = TextLoader(file_name).load()
    # Split on newlines into chunks of at most 256 characters, no overlap.
    splitter = CharacterTextSplitter(separator='\n', chunk_size=256, chunk_overlap=0)
    split_docs = splitter.split_text(raw_text)
    print(f"文档已切分为 {len(split_docs)} 个片段")

    embeddings = EmbeddingsFunAdapter("Qwen/Qwen3-Embedding-0.6B")
    faiss = FAISS.from_texts(split_docs, embeddings)
    print("FAISS 向量数据库构建完成。")
    return faiss


def load_model_and_tokenizer():
    """Load the DeepSeek-R1 distilled 1.5B chat model and its tokenizer.

    Returns:
        (tokenizer, model) tuple ready for generation.
    """
    print("正在加载模型")
    model_id = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, mirror='modelscope', trust_remote_code=True)
    # bfloat16 weights on device 0; downloaded via the modelscope mirror.
    model = AutoModelForCausalLM.from_pretrained(model_id, ms_dtype=mindspore.bfloat16, mirror='modelscope', device_map=0)
    print("模型加载完成。")
    return tokenizer, model


def retrieve_knowledge(faiss, query):
    """Return the text of the single most similar indexed document for *query*.

    Args:
        faiss: a vector store exposing ``similarity_search(query, k)``.
        query: the user question to search with.

    Returns:
        The ``page_content`` of the top hit, or "" when nothing is found.
        (Robustness fix: the original indexed ``docs[0]`` unconditionally,
        raising IndexError on an empty result list.)
    """
    docs = faiss.similarity_search(query, k=1)
    if not docs:
        return ""
    return docs[0].page_content

def generate_answer(tokenizer, model, query, knowledge=None):
    """Generate and stream-print an answer to *query*, optionally grounded on *knowledge*.

    When *knowledge* is truthy it is prepended to the question, so the model
    answers with the retrieved context in view. Tokens are printed to stdout
    as they are produced; the full answer is also returned as a string.
    """
    if knowledge:
        input_text = knowledge + "\n\n" + query
    else:
        input_text = query

    messages = [
        {"role": "user", "content": input_text}
    ]

    # Build the prompt with tokenizer.apply_chat_template
    try:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        # Fall back to a hand-built ChatML-style prompt when the tokenizer
        # has no chat template.
        print(f"⚠️ apply_chat_template 失败,使用手动拼接: {e}")
        prompt = f"<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n"

    # Tokenize (truncated to the model's 8192-token context window)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=8192).to(model.device)

    # Create the streamer
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,  # skip the prompt tokens
        skip_special_tokens=True  # do not emit special tokens
    )

    # Start the generation thread; generate() blocks, so it must run in a
    # worker thread while this thread consumes the streamer.
    def generate():
        model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=512,
            temperature=0.001,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    thread = Thread(target=generate)
    thread.start()

    # Print the generated text in real time as tokens arrive
    print("回答: ", end="", flush=True)
    generated_text = ""
    for new_text in streamer:
        print(new_text, end="", flush=True)
        generated_text += new_text
    print()  # newline

    return generated_text.strip()



def rag_pipeline(faiss, tokenizer, model, query, use_rag=True):
    """Answer *query*, optionally augmenting generation with retrieved context.

    Returns:
        (answer, knowledge) — *knowledge* is "" when retrieval is disabled.
    """
    if not use_rag:
        # Retrieval disabled: generate directly from the question alone.
        return generate_answer(tokenizer, model, query, ""), ""
    knowledge = retrieve_knowledge(faiss, query)
    answer = generate_answer(tokenizer, model, query, knowledge)
    return answer, knowledge


def main():
    """Interactive RAG demo: load a knowledge file, then answer questions in a loop."""
    parser = argparse.ArgumentParser(description="RAG Demo - Command Line Version")
    parser.add_argument("filename", help="知识库文本文件路径")
    args = parser.parse_args()

    # Load the knowledge base and the model once, up front.
    faiss_db = load_knowledge_base(args.filename)
    tokenizer, model = load_model_and_tokenizer()

    print("\n" + "=" * 60)
    print("RAG系统已准备就绪!")
    print("输入 'quit' 或 'exit' 退出程序。")
    print("=" * 60)

    while True:
        try:
            # Read the user's question.
            query = input("\n请输入您的问题: ").strip()
            if query.lower() in ['quit', 'exit', 'bye']:
                print("再见!")
                break
            if not query:
                print("问题不能为空,请重新输入。")
                continue

            # Ask whether to use retrieval. The input is lowercased first, so
            # only lowercase negatives need checking (the original also listed
            # 'N'/'NO', which were unreachable after .lower()).
            use_rag_input = input("是否启用检索增强 (RAG)? [Y/n]: ").strip().lower()
            use_rag = use_rag_input not in ('n', 'no')

            # RAG flow: generate_answer streams the reply to stdout itself,
            # so its return value is not needed here.
            if use_rag:
                print("正在检索知识库...")
                knowledge = retrieve_knowledge(faiss_db, query)
                print(f"检索到的知识:\n{knowledge}")
                generate_answer(tokenizer, model, query, knowledge)
            else:
                print("直接生成回答(无检索)...")
                generate_answer(tokenizer, model, query)

        # Also catch EOFError so a closed stdin (e.g. piped input ending)
        # exits cleanly instead of crashing.
        except (KeyboardInterrupt, EOFError):
            print("\n\n程序被用户中断,再见!")
            break


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions examples/transformers/rag/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#### Install dependencies

```
pip install mindnlp langchain langchain-community faiss-cpu
pip install -r requirements.txt
```

### Download knowledge file
Expand All @@ -16,5 +16,5 @@ wget https://raw.githubusercontent.com/limchiahooi/nlp-chinese/master/%E8%A5%BF%
### Run RAG Demo

```
streamlit run startup.py xiyouji.txt
python newchat.py xiyouji.txt
```
7 changes: 7 additions & 0 deletions examples/transformers/rag/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
protobuf==3.20.3
streamlit
langchain
langchain-community
faiss-cpu
transformers==4.55.4
sentence-transformers