In [1]:
import os
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
# LangChain tracing configuration, read from the process environment:
# switch tracing on and file the runs under this project name.
os.environ.update(
    LANGCHAIN_TRACING_V2="true",
    LANGCHAIN_PROJECT="college-information-llm",
)
import pandas as pd
import csv
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field

In [2]:
llm = ChatOpenAI(model="gpt-4o")  # chat model shared by the agent and the college_data chains

In [3]:
from knowledgebase import TXTKnowledgeBase

In [4]:
kb = TXTKnowledgeBase(txt_source_folder_path="lxbd")  # .txt sources live in ./lxbd

In [5]:
#kb.initiate_documents()

In [6]:
# Despite the helper's name, its result is used as a vector store here
# (`.as_retriever` is called on it); search returns the top 4 chunks.
vector = kb.return_retriever_from_persistant_vector_db()
retriever = vector.as_retriever(search_kwargs=dict(k=4))

In [7]:
from langchain.tools.retriever import create_retriever_tool

In [8]:
# Wrap the study-abroad knowledge-base retriever as an agent tool. The agent's
# LLM picks tools by `name` and `description` (description: "search and return
# information about studying in the US").
retriever_tool = create_retriever_tool(
    retriever=retriever,
    name='search_international_students_related_information',
    description='搜索并返回关于在美国留学相关的信息',
)

Additional tools

In [64]:
from langchain_core.tools import tool
from langchain.chains import create_history_aware_retriever
from langchain_core.prompts import MessagesPlaceholder
from colleges import CollegesData

# College lookup index: take the vector store from CollegesData and expose it as
# a retriever for the `college_data` tool defined below.
cd = CollegesData()
college_vector = cd.return_colleges_vector_from_db()
college_retriever = college_vector.as_retriever()

class College_Info(BaseModel):
    # Structured answer expected from the LLM; its JSON schema is injected into
    # the answer prompt through `format_instructions`.
    cname:str=Field(description='学校中文全名')  # full Chinese name
    ename:str=Field(description='学校英文全名')  # full English name
    postid:str=Field(description='学校的postid')
    unitid:str=Field(description='学校的unitid')
    # One or more of: ranking, admission rate, number admitted, number of
    # majors, tuition, crime rate.
    data_type:str=Field(description='数据种类，可以是排名、录取率、录取人数、专业数、学费、犯罪率这几种中的一种或多种')


_college_parser = JsonOutputParser(pydantic_object=College_Info)

# Step 1 prompt: turn the conversation + user input into the school's full
# Chinese and English name, which becomes the retrieval query.
_college_query_prompt = ChatPromptTemplate.from_messages([
    MessagesPlaceholder(variable_name="chat_history"),
    ('user', "{input}"),
    ("user", '基于聊天内容与用户输入，生成一个可以用于查询内容的学校全名，包含中文名与英文名，只回答学校全名，除此以外不回答任何内容'),
])

# Step 2 prompt: answer from the retrieved college records ({context}) in the
# JSON format described by {format_instructions}.
_college_answer_prompt = ChatPromptTemplate.from_messages([
    ("system", "基于下面内容及用户输入，根据format_instructions输出，数据种类必须完全符合类型中的一种或几种\n\n{context}\n{format_instructions}"),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "{input}"),
])

# None of the above depends on the tool's input, so the chain is assembled once
# when this cell runs instead of being rebuilt on every tool call.
_college_chain = create_retrieval_chain(
    create_history_aware_retriever(llm, college_retriever, _college_query_prompt),
    create_stuff_documents_chain(llm, _college_answer_prompt),
)


@tool
def college_data(msg:str):
    """搜索美国大学数据相关的内容"""
    # The docstring above ("search content related to US college data") is the
    # tool description the agent's LLM sees, so it is kept verbatim.
    #
    # Returns the raw retrieval-chain output dict; the model's answer text is
    # under 'answer'.
    response = _college_chain.invoke({
        'chat_history': [{'role': 'ai', 'content': 'how can I help you?'}],
        'input': msg,
        'format_instructions': _college_parser.get_format_instructions(),
    })
    # Diagnostic only: print any ```json blocks found in the answer. Its result
    # is not returned, so a malformed block must not make the tool call fail.
    try:
        print(extract_json(response['answer']))
    except Exception as exc:
        print(f'college_data: could not parse JSON in answer: {exc!r}')
    return response

In [65]:
import json
import re
from typing import List
def extract_json(msg)->List[dict]:
    """Parse every ```json fenced block in an LLM reply.

    Args:
        msg: Text of the model's answer.

    Returns:
        One parsed JSON value per ```json ... ``` block, in order of
        appearance; an empty list when the text contains no such block.

    Raises:
        ValueError: if any fenced block is not valid JSON (the original
            JSONDecodeError is chained as the cause).
    """
    pattern = r"```json(.*?)```"
    # DOTALL so a block may span several lines; the lazy `.*?` keeps each block
    # separate when the reply contains more than one.
    matches = re.findall(pattern, msg, re.DOTALL)
    try:
        return [json.loads(match.strip()) for match in matches]
    except json.JSONDecodeError as exc:
        # Previously referenced an undefined `message`, which raised NameError
        # instead of this ValueError.
        raise ValueError(f"Failed to parse: {msg}") from exc

In [66]:
tools = [retriever_tool, college_data]  # everything the agent is allowed to call

In [67]:
# Agent prompt used instead of the "hwchase17/openai-tools-agent" hub prompt
# (formerly fetched with `hub.pull`). The system message says "do not change
# the output content format".
prompt = ChatPromptTemplate.from_messages(
    messages=[
        ("system", "不要改变输出内容格式"),
        ("placeholder", "{chat_history}"),  # optional earlier turns
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),  # tool calls/results of the current run
    ],
)

In [86]:
from langchain.agents import AgentExecutor, create_openai_tools_agent

# Tool-calling agent plus the executor that drives it; the executor also
# returns the intermediate (agent action, tool observation) steps.
agent = create_openai_tools_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    return_intermediate_steps=True,
)

In [None]:
# Sample query: "studying abroad at Princeton University".
result = agent_executor.invoke(input={"input": "普林斯顿大学留学"})

In [None]:
# Display the executor output: the final answer plus, because the executor was
# built with return_intermediate_steps=True, the intermediate tool steps.
result

In [2]:
# Build the path portably: the former literal '.\\colleges\\colleges.csv' only
# resolved on Windows; elsewhere the backslashes are part of the file name.
colleges = pd.read_csv(os.path.join('.', 'colleges', 'colleges.csv'))

In [15]:
# Embedding model for the college index. NOTE(review): presumably this must match
# the model CollegesData uses when loading the saved index — confirm.
embedding = OpenAIEmbeddings(model="text-embedding-3-small")

In [35]:
# One searchable text line per college, e.g.
# '中文名：普林斯顿大学，英文名：Princeton University，类型：综合大学，postid：8413，...'
row_txt_format='中文名：{c}，英文名：{e}，类型：{t}，postid：{p}，unitid：{u}，所在州：{s}'
# type code -> label: 1 comprehensive university, 2 liberal-arts college,
# 3 community college
college_types={1:'综合大学',2:'文理学院',3:'社区大学'}
txts = [
    row_txt_format.format(c=cname, e=ename, t=college_types[ctype], p=postid, u=unitid, s=state)
    for cname, ename, ctype, postid, unitid, state in zip(
        colleges['cname'],
        colleges['name'],
        colleges['type'],
        colleges['postid'],
        colleges['unitid'],
        colleges['state'],
    )
]

In [36]:
vectors = FAISS.from_texts(texts=txts, embedding=embedding)  # embed each college line into a FAISS index

In [41]:
# Persist the index to ./vector. NOTE(review): presumably reloaded by
# CollegesData.return_colleges_vector_from_db — confirm the name matches.
vectors.save_local('vector', index_name='colleges-data-vector')

In [37]:
# NOTE: rebinds `retriever`, which earlier named the knowledge-base retriever.
# retriever_tool keeps the object it was created with; only the name changes.
retriever = vectors.as_retriever()

In [40]:
# Spot-check the college index: query 'princeton' and display the top matches
# (the recorded output lists Princeton University first).
retriever.invoke('princeton')

[Document(page_content='中文名：普林斯顿大学，英文名：Princeton University，类型：综合大学，postid：8413，unitid：186131，所在州：NJ'),
 Document(page_content='中文名：普林西庇亚学院，英文名：Principia College，类型：文理学院，postid：56413，unitid：148016，所在州：IL'),
 Document(page_content='中文名：迪堡大学，英文名：DePauw University，类型：文理学院，postid：37722，unitid：150400，所在州：IN'),
 Document(page_content='中文名：布朗大学，英文名：Brown University，类型：综合大学，postid：9427，unitid：217156，所在州：RI')]