In [31]:
from langchain_openai.chat_models import ChatOpenAI
from langchain_core.prompts import SystemMessagePromptTemplate, ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel
import os
from dotenv import load_dotenv
# Load variables from a .env file into the process environment so that the
# OPENAI_KEY lookup further down can find the API key.
load_dotenv()

True

In [32]:
# Structured result that the LLM's JSON reply is parsed into (used as the
# pydantic_object of the PydanticOutputParser below).
# NOTE(review): documented with comments rather than a class docstring,
# because pydantic copies a model's docstring into its JSON schema, which
# would change the format instructions sent to the LLM.
class SQL_Query(BaseModel):
    query: str  # the generated SQL statement
    is_malicious: bool  # presumably flags the query as an injection payload -- TODO confirm intent
    table_count: int  # presumably the number of tables the query touches -- TODO confirm intent


In [49]:
# Chat model used by the chain built later in this notebook.
# The API key comes from the OPENAI_KEY environment variable; os.getenv
# yields None when that variable is unset.
model = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0.25,
    api_key=os.getenv("OPENAI_KEY"),
)

In [50]:
# System prompt for the SQL-injection test generator.
# Placeholders filled when the chain is invoked:
#   {table}               - name of the database table
#   {columns}             - the table's column names
#   {format_instructions} - output-format text produced by the output parser
template = """
You are a cybersecurity assistant that generates SQL queries to test whether a specific system is vulnerable to SQL injection attacks.

You will be provided with a database table named {table} and a list of column names:
{columns}

Only use this information to generate a SQL query to test whether the application using this database is vulnerable to SQL injection attacks.
The query should range from very simple to very complex.  Include examples with joins, code execution and unions.

{format_instructions}
"""

In [51]:
# Parser that turns the model's JSON reply into a SQL_Query instance; it also
# supplies the format-instruction text embedded in the prompt.
json_output_parser = PydanticOutputParser(pydantic_object=SQL_Query)

In [52]:
# Inspect the exact instruction text that will fill the prompt's
# {format_instructions} placeholder.
print(json_output_parser.get_format_instructions())

The output should be formatted as a JSON instance that conforms to the JSON schema below.

As an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}
the object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.

Here is the output schema:
```
{"properties": {"query": {"title": "Query", "type": "string"}, "is_malicious": {"title": "Is Malicious", "type": "boolean"}, "table_count": {"title": "Table Count", "type": "integer"}}, "required": ["query", "is_malicious", "table_count"]}
```


In [53]:
# Wrap the raw template string as a system-role message template; its
# {table}, {columns} and {format_instructions} placeholders stay unfilled here.
system_message = SystemMessagePromptTemplate.from_template(template)
# Build the chat prompt from that single system message (no human message is added).
chat_prompt = ChatPromptTemplate.from_messages([system_message])

In [70]:
# Compose the pipeline: prompt -> chat model -> pydantic output parser.
chain = chat_prompt | model | json_output_parser

# Run it once against the example `users` table; each key fills the
# same-named placeholder in the prompt template.
result = chain.invoke(
    {
        "table": "users",
        "columns": ["id", "name", "email", "password", "phone_number"],
        "format_instructions": json_output_parser.get_format_instructions(),
    }
)

In [72]:
# Show the parsed SQL_Query instance returned by the chain.
print(result)

query='SELECT * FROM users WHERE id = 1' is_malicious=True table_count=1
