In [1]:
import os
# Turn on LangChain tracing and group every run from this notebook under one project name.
os.environ.update({
    "LANGCHAIN_TRACING_V2": "true",
    "LANGCHAIN_PROJECT": 'college-information-llm',
})

-----------------Documents Retriever-------------------------

In [2]:
from knowledgebase import TXTKnowledgeBase
# Number of knowledge-base chunks the RAG branch retrieves per question.
SEARCH_DOCS_NUM = 2

# Presumably opens the persisted vector DB built from the text files under
# 'lxbd' (inferred from the names -- confirm in knowledgebase.py). Despite the
# method name, the result is used as a vector store: it is wrapped with
# .as_retriever() right below.
kb = TXTKnowledgeBase(txt_source_folder_path='lxbd')
vector = kb.return_retriever_from_persistant_vector_db()
documents_retriever = vector.as_retriever(search_kwargs=dict(k=SEARCH_DOCS_NUM))

------------------LLM & Data Source Router---------------------------

In [3]:
from langchain_openai import ChatOpenAI
# One shared chat model for every chain below (router, RAG answer,
# college-name extraction, college-info extraction).
llm = ChatOpenAI(
    model="gpt-3.5-turbo-0125",
    temperature=0,  # keep routing/extraction output as repeatable as possible
)

In [4]:
### Router

from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field



# Data model
# NOTE: with_structured_output() builds the model-facing schema from this class,
# so the (Chinese) class docstring and the Field description below are prompt
# text sent to the LLM, not just developer docs.
# Docstring gist: "choose the most relevant data source for the user's query".
class RouteQuery(BaseModel):
    """基于用户的查询词条选择最相关的资料来源"""

    # Branch to take: "vectorstore" (general study-in-the-US material) or
    # "database" (data about one specific college) -- see the system prompt.
    datasource: Literal["vectorstore","database"] = Field(
        ...,
        description="基于用户的问题选择vectorstore或者database.",
    )


# LLM with function call: replies are parsed into RouteQuery instances.
structured_llm_router = llm.with_structured_output(RouteQuery)

# Prompt. English gist: "You are an expert at choosing a route between
# vectorstore and database. vectorstore holds general material on studying in
# the US (university rankings, applications, transferring, ...). database holds
# material on one specific university (its ranking, admission rate, tuition,
# living costs, majors, crime rate, ...). Pick vectorstore for general
# study-in-the-US questions and database for questions about one US university."
system = """你是一位选择路径的专家，你需要基于用户的提问选择是使用vectorstore还是database.
vectorstore包含了关于总体的在美国留学相关的资料，比如美国大学排名，美国留学申请，美国转学等等.
database包含了特定某一所大学的相关资料，比如这所大学的排名、录取率、学费、生活费、专业设置、犯罪率等等.
如果用户的问题是美国留学相关但是不针对某一所大学的问题，请选择vectorstore，如果是针对某一所美国大学的问题，请选择database."""
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)

question_router = route_prompt | structured_llm_router

# Manual smoke tests, kept inside an inert string so "Run All" makes no API
# calls; the string is simply the cell's displayed value.
"""
print(
    question_router.invoke({"question": "怎么钓鱼？"})
)
print(
    question_router.invoke({"question": "美国大学排名？"})
)
print(
    question_router.invoke({"question": "哈佛学费"})
)
"""

'\nprint(\n    question_router.invoke({"question": "怎么钓鱼？"})\n)\nprint(\n    question_router.invoke({"question": "美国大学排名？"})\n)\nprint(\n    question_router.invoke({"question": "哈佛学费"})\n)\n'

--------------RAG Answer Generator--------------------

In [5]:
### Generate

from langchain_core.output_parsers import StrOutputParser

# Prompt for the RAG branch. Named rag_prompt rather than `prompt` because later
# cells build their own prompts; sharing one global name meant it was silently
# rebound cell after cell.
# System-message gist: "Answer only from the retrieved study-abroad handbook
# material below; if you don't know, answer '不知道' (don't know). Where an
# agency or study-abroad service is relevant, suggest 美国续航教育's services at
# the end of the answer."
rag_prompt = ChatPromptTemplate.from_messages([
    ('system',"""你是一位回答问题的助手，你只能使用下面美国续航教育留学宝典的资料回答问题，如果你不知道问题的答案，请回答'不知道'，取回的资料如下
    \n\n{context}\n\n
    如果在回答中遇到机构或者需要留学相关的服务，尽量在回答最后提示可以寻求美国续航教育的服务"""),
    ('human',"{question}")
])


# Post-processing
def format_docs(docs):
    """Join retrieved documents' text into a single context string.

    Args:
        docs: Iterable of objects exposing ``page_content`` (retriever results).

    Returns:
        The ``page_content`` values separated by blank lines.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Chain: prompt -> chat model -> plain string.
rag_chain = rag_prompt | llm | StrOutputParser()

"""
# Test Run
question = "美国大学申请流程"
docs = documents_retriever.invoke(question)
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)
"""

'\n# Test Run\nquestion = "美国大学申请流程"\ndocs = retriever.invoke(question)\ngeneration = rag_chain.invoke({"context": docs, "question": question})\nprint(generation)\n'

---------------------database------------------

In [6]:
from colleges import CollegesData
# Retriever over college records. The text it returns is what College_Info's
# fields (postid, unitid, ...) are extracted from in get_college_info.
college_data = CollegesData()
college_vector = college_data.return_colleges_vector_from_db()
college_retriever = college_vector.as_retriever()

# Extract one US college's full name from the question, formatted as
# '中文全名（英文全名）' (Chinese full name (English full name)).
# Named college_name_prompt so this cell does not rebind the shared `prompt`.
college_name_prompt = ChatPromptTemplate.from_messages([
    ('system',"你是一位了解美国高等院校的专家，你需要根据用户的问题提取出一所美国高等院校的全名，包括中文名和英文名，输出格式为'中文全名（英文全名）'"),
    ('human',"{question}")
])
college_name_chain = college_name_prompt | llm | StrOutputParser()
# Example: college_name_chain.invoke({'question': '普林斯顿排名'})

In [7]:
from langchain_core.output_parsers import JsonOutputParser
# Structured record the LLM fills in for one college. The (Chinese) Field
# descriptions are part of the schema sent to the model, so they are prompt
# text. Adding a class docstring would likewise change that schema, so this
# model is documented with comments only.
class College_Info(BaseModel):
    cname: str = Field(description='学校中文全名')  # Chinese full name
    ename: str = Field(description='学校英文全名')  # English full name
    # Used as the `name` query parameter of the ranking endpoint (college_data_plot).
    postid: str = Field(description='学校的postid')
    # Extracted but not used by any node in this notebook.
    unitid: str = Field(description='学校的unitid')
    # One or more of: ranking, admission rate, admitted count, number of majors,
    # tuition, crime rate; defaults to ranking otherwise.
    data_type: str = Field(description='数据种类，可以是排名、录取率、录取人数、专业数、学费、犯罪率这几种中的一种或多种，如果不在这几种类型中使用排名作为默认选项')

college_info_structured_output = llm.with_structured_output(College_Info)
# Named college_info_prompt so this cell does not rebind the shared `prompt`.
# Gist: "Using the college information below and the user's question, output
# the college information in the required format."
college_info_prompt = ChatPromptTemplate.from_messages([
    ("system", "基于下面学校信息内容及用户的问题，按照格式输出学校信息回答"),
    ("human", "用户问题如下：{question}，学校信息内容如下：{context}"),
])
college_info_chain = college_info_prompt | college_info_structured_output

"""
college_info_chain.invoke({
    'question':'普林斯顿排名',
    'context':college_retriever.invoke('普林斯顿大学')
})
"""

"\ncollege_info_chain.invoke({\n    'question':'普林斯顿排名',\n    'context':college_retriever.invoke('普林斯顿大学')\n})\n"

In [8]:
import pandas as pd
def plot_college_data(postid=8413, unitid=186131,
                      base_url='https://www.forwardpathway.com/d3v7/dataphp/school_database/ranking_admin_20231213.php'):
    """Fetch a college's yearly ranking series and draw it as a line plot.

    Args:
        postid: College id sent as the ``name`` query parameter of the
            ranking endpoint.
        unitid: Currently unused; kept so existing calls keep working.
        base_url: Ranking endpoint; ``?name=<postid>`` is appended to it.

    Returns:
        The DataFrame parsed from the endpoint's JSON, plotted with ``year``
        on the x axis and ``rank`` on the y axis.
    """
    from urllib.parse import urlencode

    # URL-encode the id instead of concatenating it, so values containing
    # spaces, '&', '#', ... cannot break or alter the query string.
    url = f"{base_url}?{urlencode({'name': postid})}"
    college_df = pd.read_json(url)
    college_df.plot(x='year', y='rank', title=f'Ranking history (postid={postid})',
                    xlabel='year', ylabel='rank')
    return college_df

In [9]:
#####Graph state
from typing_extensions import TypedDict
from typing import List


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Each node returns a partial dict using these keys; LangGraph applies it to
    the running state.

    Attributes:
        question: The user's question; every node passes it through unchanged.
        generation: The RAG answer written by ``generate``. NOTE(review): that
            node stores the result of ``rag_chain.stream(...)`` -- an iterator
            of text chunks -- not a ``str`` as annotated; confirm intended type.
        documents: Results of ``documents_retriever.invoke`` written by
            ``retrieve``. NOTE(review): annotated ``List[str]`` but presumably
            LangChain ``Document`` objects -- confirm.
        college_info: ``College_Info`` record extracted by ``get_college_info``.
        data_type: College data type. Not written by any node in this notebook.
        plot_type: College data plot type, can be line plot, bar plot, scatter
            plot or tree plot. Not written by any node in this notebook.
        data: Ranking table (DataFrame) fetched by ``college_data_plot``.
    """

    question: str
    generation: str  # see NOTE(review) above: actually a chunk iterator
    documents: List[str]  # see NOTE(review) above: presumably Document objects
    college_info:College_Info
    data_type:str  # never set by the current nodes
    plot_type:str  # never set by the current nodes
    data:pd.DataFrame

------------------Graph Flow------------------------------

In [16]:
import matplotlib.pyplot as plt
def retrieve(state):
    """Graph node: fetch knowledge-base documents relevant to the question.

    Returns a partial state update carrying ``documents`` and the unchanged
    ``question``.
    """
    print("---RETRIEVE---")
    query = state["question"]
    return {"documents": documents_retriever.invoke(query), "question": query}

def generate(state):
    """Graph node: answer the question from the retrieved documents (RAG).

    The documents are joined into plain text with ``format_docs`` before going
    into the prompt's ``{context}`` slot. Passing the list directly would
    render each document's Python repr (metadata included) into the prompt.

    Returns:
        Partial state with ``documents``, ``question`` and ``generation``.
        NOTE: ``generation`` is the iterator returned by ``rag_chain.stream``
        (text chunks), not a finished ``str``; presumably nothing is generated
        until a consumer iterates it.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation (streamed; whoever reads state["generation"] consumes it).
    generation = rag_chain.stream({"context": format_docs(documents), "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def route_question(state):
    """Conditional entry point: choose the branch for the user's question.

    Returns:
        ``"vectorstore"`` (general RAG branch) or ``"database"`` (single-college
        branch), matching the keys of the entry-point path map.

    Raises:
        ValueError: if the router produced no usable ``datasource``, rather
            than returning ``None``, which the path map cannot route.
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})
    # getattr guards against the structured-output parser yielding no object.
    datasource = getattr(source, "datasource", None)
    if datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"
    if datasource == "database":
        print("---ROUTE QUESTION TO DATABASE----")
        return "database"
    raise ValueError(f"Router returned an unexpected datasource: {source!r}")

def get_college_info(state):
    """Graph node for the database branch: identify the college and extract its record.

    First asks the LLM for the college's full name. It then retrieves college
    text with that name and has the LLM fill a ``College_Info`` from it.

    Returns a partial state with ``college_info`` and the unchanged ``question``.
    """
    print("---COLLEGE NAME---")
    user_question = state["question"]
    name = college_name_chain.invoke({'question': user_question})
    matches = college_retriever.invoke(name)
    info = college_info_chain.invoke({'question': user_question, 'context': matches})
    return {'college_info': info, 'question': user_question}

def college_data_plot(state):
    """Graph node: download the ranking history for the extracted college.

    Despite its name it draws nothing; it stores the fetched table in the
    state's ``data`` key.

    Returns:
        Partial state with ``college_info``, ``question`` and ``data`` (the
        DataFrame parsed from the ranking endpoint's JSON).
    """
    from urllib.parse import urlencode

    question = state['question']
    college_info = state['college_info']
    # postid comes from LLM structured output, so URL-encode it instead of
    # concatenating raw text into the query string.
    query = urlencode({'name': college_info.postid})
    college_df = pd.read_json(
        'https://www.forwardpathway.com/d3v7/dataphp/school_database/ranking_admin_20231213.php?' + query
    )
    return {'college_info': college_info, 'question': question, 'data': college_df}

----------------------Build Graph--------------------------------

In [17]:
from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Nodes: the RAG branch (retrieve -> generate) and the database branch
# (database -> college_data_plot). Registered before any edge refers to them.
_nodes = (
    ('retrieve', retrieve),
    ('database', get_college_info),
    ('generate', generate),
    ('college_data_plot', college_data_plot),
)
for _name, _action in _nodes:
    workflow.add_node(_name, _action)

# route_question's return value picks the node the graph starts at.
workflow.set_conditional_entry_point(
    route_question,
    {"vectorstore": "retrieve", "database": "database"},
)

# Each branch runs straight through to END.
_edges = (
    ("retrieve", "generate"),
    ("generate", END),
    ("database", "college_data_plot"),
    ("college_data_plot", END),
)
for _src, _dst in _edges:
    workflow.add_edge(_src, _dst)

# Drop the helper names so the notebook namespace matches the explicit version.
del _nodes, _name, _action, _edges, _src, _dst

app = workflow.compile()

In [18]:
from pprint import pprint

# Run
# Run the graph on a single-college question and report each node as it finishes.
inputs = {
    "question": "哈佛大学排名"
}

for output in app.stream(inputs):
    # Each streamed item maps node name -> that node's state update.
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("-------------------------------")
    # Check every update in this step, not just whichever one the inner loop
    # variable happened to end on.
    for value in output.values():
        if isinstance(value, dict) and 'generation' in value:
            # generate() stores the chunk iterator from rag_chain.stream();
            # join it into readable text (also a no-op for a plain str).
            pprint("".join(value["generation"]))



---ROUTE QUESTION---
---ROUTE QUESTION TO DATABASE----
---COLLEGE NAME---
"Node 'database':"
'-------------------------------'
"Node 'college_data_plot':"
'-------------------------------'


In [52]:
# A general study-in-the-US question (not about one college), so the router is
# expected to send it down the RAG branch.
inputs = dict(question="美国大学排名")
# NOTE(review): response["generation"] holds the chunk iterator produced by
# rag_chain.stream() inside generate(), not a finished string.
response = app.invoke(inputs)

---ROUTE QUESTION---
---ROUTE QUESTION TO RAG---
---RETRIEVE---
---GENERATE---
