In [1]:
from langchain.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI

from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, validator, root_validator
from typing import List

import openai
import os
openai.organization = os.getenv('OPENAI_ORG')
openai.api_key = os.getenv('OPENAI_API_KEY')

import sqlvalidator

from rich import print
import ast

In [2]:
# Settings for the OpenAI completion LLM wrapper used by the cells below.
model_name, temperature = "text-davinci-003", 0.0
model = OpenAI(temperature=temperature, model_name=model_name)

In [3]:
class CustomObj(BaseModel):
    # Structured LLM answer carrying a single SQL statement.
    # NOTE: deliberately documented with comments rather than a class
    # docstring -- pydantic copies a model docstring into the generated JSON
    # schema, which would alter the format instructions sent to the LLM.

    sql: str = Field(description="sql")

    @root_validator()
    def verify_sql(cls, values):
        """Reject the object unless ``sql`` passes sqlvalidator's validity check.

        Args:
            values: Mapping of field names to already-validated field values.

        Returns:
            The unchanged ``values`` mapping when the SQL is accepted.

        Raises:
            ValueError: If ``sql`` is missing, not a string, blank, cannot be
                parsed, or is reported invalid by sqlvalidator.
        """
        sql = values.get("sql")
        # ``values.get`` yields None when the key is not present; never hand
        # a non-string or blank statement to the SQL parser.
        if not isinstance(sql, str) or not sql.strip():
            raise ValueError("The sql code is invalid.")
        try:
            # parse() only builds the query object; the actual validation is
            # performed by is_valid(), so it must be called explicitly.
            sql_is_valid = sqlvalidator.parse(sql).is_valid()
        except Exception as exc:
            # Catch Exception (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate; keep the cause for debugging.
            raise ValueError("The sql code is invalid.") from exc
        if not sql_is_valid:
            raise ValueError("The sql code is invalid.")
        return values

In [4]:
# Build an output parser bound to the CustomObj model, then ask it for the
# format-instruction text that gets injected into the prompt template.
parser = PydanticOutputParser(pydantic_object=CustomObj)
format_instructions = parser.get_format_instructions()
# Bare expression: the notebook shows the instruction string as the cell output.
format_instructions

'The output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"sql": {"title": "Sql", "description": "sql", "type": "string"}}, "required": ["sql"]}\n```'

In [5]:
# Prompt fragment wrapping a natural-language task; the {description}
# placeholder is filled in with str.format below.
query_prefix = (
    "\n"
    "Generate a valid SQL query for the following natural language instruction:\n"
    "\n"
    "Problem Description:\n"
    "{description}\n"
)

# The concrete task to translate into SQL.
task_description = 'Select the name of the employee who has the highest salary.'
query = query_prefix.format(description=task_description)
query

'\nGenerate a valid SQL query for the following natural language instruction:\n\nProblem Description:\nSelect the name of the employee who has the highest salary.\n'

In [6]:
# Prompt template: the format instructions are pre-bound as a partial
# variable, so only ``query`` has to be supplied when formatting.
prompt = PromptTemplate(
    input_variables=["query"],
    partial_variables=dict(format_instructions=format_instructions),
    template="Answer the user query.\n{format_instructions}\n{query}\n",
)
prompt

PromptTemplate(input_variables=['query'], output_parser=None, partial_variables={'format_instructions': 'The output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"sql": {"title": "Sql", "description": "sql", "type": "string"}}, "required": ["sql"]}\n```'}, template='Answer the user query.\n{format_instructions}\n{query}\n', template_format='f-string', validate_template=True)

In [7]:
# Fill the template with the query, then flatten the result to plain text.
prompt_value = prompt.format_prompt(query=query)
input_prompt = prompt_value.to_string()
input_prompt

'Answer the user query.\nThe output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"sql": {"title": "Sql", "description": "sql", "type": "string"}}, "required": ["sql"]}\n```\n\nGenerate a valid SQL query for the following natural language instruction:\n\nProblem Description:\nSelect the name of the employee who has the highest salary.\n\n'

In [8]:
# Send the rendered prompt to the LLM and keep what it returns.
output = model(input_prompt)
# Parse the completion with the parser that was built for CustomObj.
custom_obj = parser.parse(output)

In [9]:
# `print` here is rich's print (imported in the first cell), not the builtin.
print(custom_obj.sql)