# AI Database Agent for YugabyteDB

Learn new tech by building a simple AI-powered SQL agent for YugabyteDB.

This agent uses LangChain to build a flow that takes a user's question in plain English and uses an LLM (Large Language Model) to generate a SQL query. The agent executes the query on your database and then calls the LLM again, either to phrase the result as a human would or to convert it into a JSON object for your downstream APIs.
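
The flow described above can be sketched in a few lines of plain Python. This is a minimal, self-contained sketch, not how LangChain implements it: the hypothetical `fake_llm_generate_sql` function stands in for the LLM with a hard-coded question-to-SQL lookup, an in-memory SQLite database with three placeholder rows stands in for YugabyteDB, and the "human" answer is a fixed sentence template rather than a second LLM call.

```python
import json
import sqlite3


def fake_llm_generate_sql(question: str) -> str:
    """Hypothetical stand-in for the LLM: a hard-coded question-to-SQL lookup."""
    canned = {
        "How many movies are in the database?":
            "SELECT COUNT(*) AS movie_count FROM movie",
    }
    return canned[question]


def answer(conn: sqlite3.Connection, question: str, as_json: bool = False) -> str:
    # Step 1: the "LLM" turns the plain-English question into SQL.
    sql = fake_llm_generate_sql(question)
    # Step 2: execute the SQL and keep each row as a column-name -> value dict.
    cur = conn.execute(sql)
    columns = [d[0] for d in cur.description]
    rows = [dict(zip(columns, row)) for row in cur.fetchall()]
    # Step 3: return JSON for downstream APIs, or a human-style sentence
    # (a real agent would ask the LLM to phrase this answer).
    if as_json:
        return json.dumps(rows)
    return f"The answer is {rows[0][columns[0]]}."


def demo_connection() -> sqlite3.Connection:
    # In-memory SQLite database with a placeholder three-row "movie" table.
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE movie (id INTEGER PRIMARY KEY, title TEXT)")
    conn.executemany("INSERT INTO movie (title) VALUES (?)",
                     [("Movie A",), ("Movie B",), ("Movie C",)])
    return conn


conn = demo_connection()
print(answer(conn, "How many movies are in the database?"))
# prints: The answer is 3.
print(answer(conn, "How many movies are in the database?", as_json=True))
# prints: [{"movie_count": 3}]
```

Swapping the lookup for a real LLM call and the SQLite connection for a YugabyteDB connection gives the overall shape of the agent built in this notebook; LangChain's `SQLDatabaseChain` handles those steps for you.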

## Prerequisites

* [Docker](https://www.docker.com)
* Python and pip
* [OpenAI API key](https://platform.openai.com)

## Set Up Environment

Use pip to install the required modules:

In [None]:
! pip install psycopg2 langchain langchain_community langchain_openai langchain_experimental

Start a YugabyteDB node in Docker:

In [None]:
# Remove data left over from a previous run (if any), then recreate the data directory
! rm -rf ~/yb_docker_data
! mkdir -p ~/yb_docker_data

! docker network create yb-network
! docker run -d --name yugabytedb-node1 --net yb-network \
    -p 15433:15433 -p 5433:5433 \
    -v ~/yb_docker_data/node1:/home/yugabyte/yb_data --restart unless-stopped \
    yugabytedb/yugabyte:latest \
    bin/yugabyted start --base_dir=/home/yugabyte/yb_data --background=false
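
The `yugabyted` node can take a while to start after `docker run` returns, so loading the dataset immediately may fail. One option, sketched here under the assumption that the default `yugabyte`/`yugabyte` credentials and `yugabyte` database are in effect (the same ones used in the connection URI later in this notebook), is to retry a `psycopg2` connection until it succeeds. `retry` and `wait_for_yugabyte` are hypothetical helpers, not part of YugabyteDB or psycopg2.

```python
import time


def retry(connect, attempts: int = 30, delay: float = 2.0):
    """Call connect() until it succeeds; re-raise the last error if all attempts fail."""
    for attempt in range(1, attempts + 1):
        try:
            return connect()
        except Exception:
            if attempt == attempts:
                raise
            time.sleep(delay)  # give the node more time to start


def wait_for_yugabyte():
    # Imported lazily so that retry() can be used without psycopg2 installed.
    import psycopg2

    # Assumed defaults: YSQL on localhost:5433, user/password/database "yugabyte".
    conn = retry(lambda: psycopg2.connect(
        host="localhost", port=5433,
        user="yugabyte", password="yugabyte", dbname="yugabyte",
    ))
    conn.close()
    print("YugabyteDB is accepting YSQL connections.")
```

Run `wait_for_yugabyte()` in a cell before loading the dataset. The defaults (30 attempts, 2 seconds apart) are arbitrary; adjust them to your machine.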

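Load the sample movie dataset. The `schema.sql`, `movie_data.sql`, and `user_data.sql` files must be in the notebook's working directory: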

In [None]:
# Copy the schema and data files to the container
! docker cp ./schema.sql yugabytedb-node1:/home
! docker cp ./movie_data.sql yugabytedb-node1:/home
! docker cp ./user_data.sql yugabytedb-node1:/home

# Load the dataset into the database
# No -it flags: notebook shell commands have no interactive terminal attached
! docker exec yugabytedb-node1 bin/ysqlsh -h yugabytedb-node1 -f /home/schema.sql
! docker exec yugabytedb-node1 bin/ysqlsh -h yugabytedb-node1 -f /home/movie_data.sql
! docker exec yugabytedb-node1 bin/ysqlsh -h yugabytedb-node1 -f /home/user_data.sql
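
To confirm that the scripts loaded data, you can count the rows in each table. The hypothetical `count_rows` helper below is a minimal sketch that works with any DB-API 2.0 connection; with psycopg2 you would pass `psycopg2.connect(host="localhost", port=5433, user="yugabyte", password="yugabyte", dbname="yugabyte")` together with the table names `movie`, `user_account`, and `user_library` that the agent uses. The expected counts depend on the contents of the SQL files, so none are shown here.

```python
import re


def count_rows(conn, tables):
    """Return {table: row_count} for each table, using any DB-API 2.0 connection."""
    counts = {}
    cur = conn.cursor()
    for table in tables:
        # Table names cannot be bound as query parameters, so accept only
        # plain identifiers before formatting them into the SQL text.
        if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", table):
            raise ValueError(f"unexpected table name: {table!r}")
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        counts[table] = cur.fetchone()[0]
    cur.close()
    return counts
```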

## Provide OpenAI API Key

Set your OpenAI API key in the `OPENAI_API_KEY` environment variable, then run the code snippet below. If the variable is not set, the snippet prompts you to enter the key:

In [None]:
import os
from getpass import getpass

openai_key = os.getenv('OPENAI_API_KEY')

# Fall back to an interactive prompt when the environment variable is not set
if openai_key is None:
    openai_key = getpass('Provide your OpenAI API key: ')

if not openai_key:
    raise ValueError('No OpenAI API key provided. Please set the OPENAI_API_KEY environment variable or provide it when prompted.')

print('OpenAI API key set.')

## Configure SQL Agent for YugabyteDB

Prepare a prompt template that defines the SQL agent's behavior and clarifies the task:

In [None]:
def prepare_agent_prompt(input_text):
    # Wrap the user's question in instructions that steer the SQL chain
    agent_prompt = f"""
    Query the database using PostgreSQL syntax.

    It is not necessary to search on all columns, only those necessary for a query.

    Generate a PostgreSQL query using the input: {input_text}.

    Respond like a human would.
    """

    return agent_prompt
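
Because the prompt in `prepare_agent_prompt` is a triple-quoted string inside an indented function, each of its lines starts with the function's indentation, and that whitespace is sent to the model. If you prefer a tidier prompt, one option is to dedent the template once with Python's `textwrap.dedent` and fill in the question with `str.format`. `prepare_agent_prompt_dedented` is a hypothetical name; this sketch only changes the whitespace the model receives, not the instructions.

```python
import textwrap

# Dedent once at import time; strip() drops the leading and trailing blank lines.
_AGENT_PROMPT_TEMPLATE = textwrap.dedent("""
    Query the database using PostgreSQL syntax.

    It is not necessary to search on all columns, only those necessary for a query.

    Generate a PostgreSQL query using the input: {input_text}.

    Respond like a human would.
    """).strip()


def prepare_agent_prompt_dedented(input_text: str) -> str:
    # Formatting after dedenting means a multi-line question cannot
    # change how much indentation is removed from the template.
    return _AGENT_PROMPT_TEMPLATE.format(input_text=input_text)
```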

Initialize the OpenAI LLM, connect LangChain to YugabyteDB, and create the SQL database chain:

In [None]:
import psycopg2  # PostgreSQL driver used by SQLAlchemy under the hood
from langchain_community.utilities import SQLDatabase
from langchain_openai import OpenAI
from langchain_experimental.sql import SQLDatabaseChain

# Initialize the OpenAI LLM
llm = OpenAI(
    api_key=openai_key,
    temperature=0,  # sampling randomness: 0 gives the most deterministic output, higher values give more varied output
    max_tokens=-1,  # maximum tokens to generate; -1 returns as many as the prompt and the model's context size allow
)

# Connect LangChain to YugabyteDB's PostgreSQL-compatible YSQL API (port 5433)
# and expose only the tables the agent needs
database = SQLDatabase.from_uri(
    "postgresql+psycopg2://yugabyte:yugabyte@localhost:5433/yugabyte",
    include_tables=["movie", "user_account", "user_library"],
)

# Build the chain: question -> SQL (checked by the LLM) -> query results -> answer
db_chain = SQLDatabaseChain.from_llm(
    llm,
    db=database,
    verbose=True,  # print the generated SQL and the query results
    use_query_checker=True,  # let the LLM check and fix the generated SQL before running it
    return_intermediate_steps=True,  # include the intermediate steps in the result
)
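
The connection URI passed to `SQLDatabase.from_uri` is a SQLAlchemy URL: the `postgresql+psycopg2` dialect works because YugabyteDB's YSQL API is PostgreSQL-compatible, and `5433` is the YSQL port published by the `docker run` command. If your password contains characters such as `@` or `:`, they must be URL-encoded, which SQLAlchemy's documentation suggests doing with `urllib.parse.quote_plus`. The hypothetical `yugabyte_uri` helper below is a minimal sketch that defaults to the values used in this notebook.

```python
from urllib.parse import quote_plus


def yugabyte_uri(user="yugabyte", password="yugabyte", host="localhost",
                 port=5433, dbname="yugabyte"):
    # quote_plus escapes characters such as '@' and ':' that would otherwise
    # be parsed as URI delimiters inside the credentials.
    return (f"postgresql+psycopg2://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}:{port}/{dbname}")
```

For example, `SQLDatabase.from_uri(yugabyte_uri(), include_tables=["movie", "user_account", "user_library"])` builds the same URI as the call above.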

Experiment with the SQL agent by running the code snippet below with each of the following questions, one at a time:

1st set:
* How many movies are in the database?
* What are the three most popular movies?
* Find the name and rank of the movie with the highest rating.

2nd set:
* How many sci-fi movies are in the database?
* Find five movies with the highest revenue in the action genre. Return the revenue in dollar format.

3rd set:
* What is the most popular genre?
* I want to know three most popular genres of movies.

4th set:
* List five studios that have produced the most movies.
* What movies have users added the most to their library?

5th set:
* Find the studios whose movies are added the most to users' libraries.

In [None]:
user_prompt = "I want to know three most popular genres of movies."

agent_prompt = prepare_agent_prompt(user_prompt)

try:
    result = db_chain.invoke(agent_prompt)

    print(f"Answer: {result['result']}")
except (Exception, psycopg2.Error) as error:
    print(error)
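
Instead of editing `user_prompt` for every question, you can run a whole set in one cell. The hypothetical `ask_all` helper below is a minimal sketch: it takes any object with an `invoke` method that returns a dict with a `"result"` key (such as `db_chain`) plus a prompt builder (such as `prepare_agent_prompt`), and records an error message for a failing question instead of stopping the loop.

```python
def ask_all(chain, questions, build_prompt):
    """Run each question through chain.invoke and collect answers or error messages."""
    answers = {}
    for question in questions:
        try:
            result = chain.invoke(build_prompt(question))
            answers[question] = result["result"]
        except Exception as error:  # keep going when one question fails
            answers[question] = f"Error: {error}"
    return answers
```

For example, `ask_all(db_chain, ["How many movies are in the database?", "How many sci-fi movies are in the database?"], prepare_agent_prompt)` runs the second set of questions in one go; each question is still a separate chain run with its own LLM calls.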