<a href="https://colab.research.google.com/github/bjdzliu/ai_lab/blob/main/langchain/OutputParser.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install --quiet langchain langchain-openai

In [None]:
!pip3 install cohere openai typing-extensions


In [5]:
from google.colab import userdata

# Fetch the OpenAI key stored under the name OPENAI_API_KEY in Colab's secrets.
apikey = userdata.get("OPENAI_API_KEY")

In [6]:
import langchain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain.output_parsers import PydanticOutputParser


# Deterministic sampling for reproducible extraction results.
temperature = 0

# gpt-4 is requested explicitly. (The original note says ChatOpenAI would
# otherwise default to gpt-3.5-turbo -- verify against the installed version.)
llm = ChatOpenAI(
    model="gpt-4",
    openai_api_key=apikey,
    temperature=temperature,
)


In [7]:
from langchain_core.pydantic_v1  import BaseModel, Field, validator
from typing import List, Dict

# Define the output object.
#
# The expected output format is declared as a Pydantic model (Date). The
# PydanticOutputParser built from it turns these fields and their descriptions
# into the format instructions placed in the prompt. It also validates the
# JSON the LLM returns against this model.
#
# NOTE: this documentation is kept in comments rather than a class docstring
# on purpose. Pydantic copies a class docstring into the model's JSON schema
# ("description"), which would change the format instructions sent to the LLM.
class Date(BaseModel):
    year: int = Field(description="Year")
    month: int = Field(description="Month")
    day: int = Field(description="Day")
    era: str = Field(description="BC or AD")

    @validator("month")
    def valid_month(cls, value):
        """Reject months outside 1-12."""
        if not 1 <= value <= 12:
            raise ValueError("月份必须在1-12之间")
        return value

    @validator("day")
    def valid_day(cls, value):
        """Reject days outside 1-31; per-month limits are checked by valid_date."""
        if not 1 <= value <= 31:
            raise ValueError("日期必须在1-31日之间")
        return value

    # Deliberately NOT pre=True. A pre-validator sees the raw input, so a
    # string day such as "30" would hit a str/int comparison (and a confusing
    # TypeError-based validation error) instead of first being coerced to 30.
    # As a regular validator it runs after int coercion and after valid_day,
    # because pydantic v1 applies a field's validators in definition order.
    @validator("day")
    def valid_date(cls, day, values):
        """Reject days that do not exist in the given month of the given year.

        ``values`` holds the fields declared before ``day`` that validated
        successfully. If year or month is missing or invalid, the cross-field
        check is skipped and their own errors are reported instead.
        """
        year = values.get("year")
        month = values.get("month")
        if year is None or month is None:
            return day

        if month == 2:
            if cls.is_leap_year(year):
                if day > 29:
                    raise ValueError(">29")
            elif day > 28:
                raise ValueError(">28")
        elif month in (4, 6, 9, 11) and day > 30:
            raise ValueError(f"{month} 30")
        return day

    @staticmethod
    def is_leap_year(year):
        """Return True if *year* is a leap year under the Gregorian rules."""
        return year % 400 == 0 or (year % 4 == 0 and year % 100 != 0)



In [None]:
# Sanity-check the Date model on a hand-written value before involving the LLM.
selfdate = {"year": 2023, "month": 4, "day": 3, "era": "AD"}
selfdate_result = Date(**selfdate)
# Show the validated model (not the input dict) so the output reflects what Date produced.
print(selfdate_result)

In [None]:
# Root-module `from langchain import PromptTemplate` is deprecated; import
# from langchain_core, which this notebook already depends on.
from langchain_core.prompts import PromptTemplate

# Build an OutputParser from the Pydantic model definition (Date).
parser = PydanticOutputParser(pydantic_object=Date)

# Prompt text (Chinese): "Extract the date from the user input. <format
# instructions> User input: <query>".
template = """提取用户输入中的日期。
{format_instructions}
用户输入:
{query}"""

# Computed once and reused for both the prompt and the printout below.
format_instructions = parser.get_format_instructions()

prompt = PromptTemplate(
    template=template,
    input_variables=["query"],
    # Take the output description straight from the OutputParser and
    # pre-assign it to the template's {format_instructions} variable.
    partial_variables={"format_instructions": format_instructions},
)

print("====Format Instruction=====")
print(format_instructions)


query = "2023年四月6日天气晴..."
model_input = prompt.format_prompt(query=query)

print("====Prompt=====")
print(model_input)

# model_input is a langchain_core.prompt_values.StringPromptValue;
# to_messages() converts it to a list of chat messages for the chat model.
# Use .invoke(): calling the model directly (llm(...)) is deprecated.
output = llm.invoke(model_input.to_messages())

print("====模型原始输出=====")
print(output)

print("====Parse后的输出=====")
# Parse the model's text reply into a validated Date instance.
date = parser.parse(output.content)
print(date)