In [None]:
import torch
# Force CPU execution: replacing torch.cuda.is_available with a stub that always
# returns False affects this cell and any later code that consults the same attribute.
torch.cuda.is_available = lambda: False
# Resolve the device once; with the override above this always yields the CPU device.
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import LangchainEmbedding

# Load embedding
def load_embedding(embedding_source:str = "huggingface",
                   model_path:str = "../models/all-mpnet-base-v2"):
  """Build the embedding model used to vectorise documents and queries.

  Args:
    embedding_source: "openai" selects ``OpenAIEmbeddings``; any other value
      selects a HuggingFace model (the historical fallback is kept so existing
      callers behave the same).
    model_path: Value passed to ``HuggingFaceEmbeddings`` as ``model_name``.
      Ignored when ``embedding_source`` is "openai".

  Returns:
    An ``OpenAIEmbeddings`` or ``HuggingFaceEmbeddings`` instance.
  """
  if embedding_source == "openai":
    return OpenAIEmbeddings()
  return HuggingFaceEmbeddings(model_name=model_path)

In [None]:
# Instantiate the embedding model with load_embedding()'s default source ("huggingface").
embedding = load_embedding()
# Smoke test: embed a short Chinese query (the name of Shanghai Maritime University).
query_result = embedding.embed_query("上海海事大学")
# Show the length of the returned embedding (presumably the vector dimensionality).
len(query_result)

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from typing import List
from llama_index.readers.schema.base import Document

def load_documents(data_dir="./data/", glob="**/*.txt", chunk_size=1000, chunk_overlap=20):
  """Load text files from a directory and split them into chunks.

  Args:
    data_dir: Directory handed to ``DirectoryLoader``.
    glob: Glob pattern handed to ``DirectoryLoader`` to select files.
    chunk_size: ``chunk_size`` passed to ``CharacterTextSplitter``.
    chunk_overlap: ``chunk_overlap`` passed to ``CharacterTextSplitter``.

  Returns:
    The result of ``CharacterTextSplitter.split_documents`` on the loaded
    documents. The defaults reproduce the previously hard-coded settings.
  """
  loader = DirectoryLoader(data_dir, glob)
  documents = loader.load()
  text_splitter = CharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
  )
  return text_splitter.split_documents(documents)

In [None]:
# Load and chunk the corpus with load_documents(), then print how many chunks were produced.
docs = load_documents()
print(len(docs))

In [None]:
from langchain.vectorstores import FAISS, Chroma

# Build the FAISS vector store from the chunked documents, embedding them with `embedding`.
# `index` is a notebook-level global that get_similiar_docs reads.
index = FAISS.from_documents(docs, embedding)

def get_similiar_docs(query, k=3, score=False):
  """Retrieve the `k` entries of the global `index` most similar to `query`.

  When `score` is true the lookup goes through
  ``index.similarity_search_with_score``; otherwise through
  ``index.similarity_search``. The result of that call is returned unchanged.
  """
  search = index.similarity_search_with_score if score else index.similarity_search
  return search(query, k=k)


In [None]:
# Retrieval smoke test with scores enabled. The query asks which rules and regulations
# cover leave reports for leading cadres travelling outside Shanghai.
similar_docs = get_similiar_docs("领导干部离沪外出请假报告相关的规章制度有哪些？", score=True)
# Display the retrieved results as the cell output.
similar_docs

In [None]:
from langchain import PromptTemplate
# load prompt
def _read_prompt(path):
  """Return the full text of a prompt template file.

  The file is decoded explicitly as UTF-8 so the result does not depend on the
  platform's locale encoding (assumes the prompt files are stored as UTF-8 —
  TODO confirm).
  """
  with open(path, "r", encoding="utf-8") as f:
    return f.read()

# Raw template texts; the chat_* templates are loaded here but not used in this cell.
template_quest = _read_prompt("prompts/question_prompt.txt")
chat_reduce_template = _read_prompt("prompts/chat_reduce_prompt.txt")
template = _read_prompt("prompts/combine_prompt.txt")
chat_combine_template = _read_prompt("prompts/chat_combine_prompt.txt")

# Combine-step prompt: rendered with jinja2 from `summaries` and `question`.
c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template,
                          template_format="jinja2")

# Per-document question prompt: rendered with jinja2 from `context` and `question`.
q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
                          template_format="jinja2")

In [None]:
from langchain.llms import OpenAI
from langchain.chains.question_answering import load_qa_chain
import os
from utils.model import load_model, load_moss_moon
from utils.customllm import CustomLLM

# SECURITY: the OpenAI API key must be supplied through the environment (for
# example `export OPENAI_API_KEY=...` before starting Jupyter), never as a
# literal in the notebook. Any key that was previously committed here should be
# treated as leaked and revoked.
# The active `llm` below is the local CustomLLM, so the key only matters for
# the OpenAI-backed options (the commented-out OpenAI llm, OpenAI embeddings).
if not os.environ.get("OPENAI_API_KEY"):
  print("OPENAI_API_KEY is not set; OpenAI-backed models and embeddings will be unavailable.")
# os.environ['HTTPS_PROXY']="http://10.81.38.5:8443"
# model_name = "text-davinci-003"
model_name = "gpt-3.5-turbo"
# model_name = "gpt-4"

# llm = OpenAI(model_name=model_name)

# Local model: a base checkpoint plus a LoRA adapter path, both handed to load_model.
base_model = "../models/llama-7b-hf"
lora_model_path = "../models/chinese-alpaca-lora-7b"

model, tokenizer = load_model(base_model, lora_model_path)
#model, tokenizer = load_moss_moon()
# Wrap the local model so it can be passed to LangChain as `llm`; `device` comes from the first cell.
llm = CustomLLM(model, tokenizer, device)



In [None]:
# Alternative configuration (disabled): map_reduce chain using the custom c_prompt / q_prompt.
# chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=c_prompt, question_prompt=q_prompt)
# Active configuration: a "stuff" QA chain over `llm`, with no custom prompts passed.
chain = load_qa_chain(llm, chain_type="stuff")

In [None]:
def get_answer(query):
  """Answer `query` using the retrieved documents and the global QA `chain`.

  Fetches context with ``get_similiar_docs(query)`` (default `k`, no scores)
  and returns whatever ``chain.run`` produces for those documents and the
  question.
  """
  context_docs = get_similiar_docs(query)
  return chain.run(input_documents=context_docs, question=query)

In [None]:
# Demo query: in which year was the Shanghai Advanced Institute of International Shipping founded?
print(get_answer("上海高级国际航运学院是哪一年成立的？"))

In [None]:
# Demo query: how many graduates does Shanghai Maritime University have?
print(get_answer("上海海事大学有多少毕业生？"))

In [None]:
# Demo query: how many doctoral programs does Shanghai Maritime University have?
print(get_answer("上海海事大学有几个博士点？"))

In [None]:
# Demo query: how many master's programs does Shanghai Maritime University have?
print(get_answer("上海海事大学有多少个硕士点？"))

In [None]:
# Demo query: does Shanghai Maritime University have a School of Marxism?
print(get_answer("上海海事大学有马克思主义学院吗？"))

In [None]:
# Demo query: which department is in charge of notices and announcements?
print(get_answer("通知公告的主管部门是？"))

In [None]:
# Demo query: which rules and regulations cover leave reports for travel outside Shanghai?
print(get_answer("离沪外出请假报告相关的规章制度有哪些？"))

In [None]:
# Demo query: what is the contact number for the informatization special-project application?
print(get_answer("信息化专项申报的联系方式是什么号码？"))

In [None]:
# Demo query: what events take place on April 19, 2023?
print(get_answer("2023年4月19日有什么活动？"))