# Database Explorer

This notebook helps you explore a database using LangChain. It should be able to connect to any SQLAlchemy-supported database, read its structure, show you some of the underlying data, and let you ask questions about it.

## Enter a DB URL and your OpenAI API key

In [3]:
import os

# db_url is a SQLAlchemy-compliant connection string. It is read from the
# DATABASE_URL environment variable when set (connection strings can embed
# passwords); otherwise fill it in here.
# see https://docs.sqlalchemy.org/en/20/core/engines.html for more details
db_url = os.environ.get("DATABASE_URL", "<enter a connection string>")

# The API key is read from the OPENAI_API_KEY environment variable when set;
# otherwise fill it in here. api_key is passed explicitly to OpenAI(...) in the
# cells below, so a hard-coded placeholder would override any environment key.
api_key = os.environ.get("OPENAI_API_KEY", "<enter an api key>")

## Select the table you wish to view

In [4]:
import pandas as pd
from sqlalchemy import create_engine, MetaData, select, text
from ipywidgets import widgets

# Number of sample rows shown per table.
PREVIEW_ROWS = 5

engine = create_engine(db_url)
metadata = MetaData()
metadata.reflect(bind=engine)

# Fetch a small preview of every reflected table. The query is built with
# SQLAlchemy Core rather than an f-string so the dialect supplies its own
# identifier quoting and row-limit syntax: a hand-written
# `select * from "t" limit 5` is not valid on every backend (e.g. MySQL's
# default quoting, SQL Server has no LIMIT). One connection is reused for all
# tables, and connect() is used instead of begin() since nothing is committed.
tabs = []
names = []
with engine.connect() as connection:
    for tbl in metadata.sorted_tables:
        preview_result = connection.execute(select(tbl).limit(PREVIEW_ROWS))
        names.append(tbl.name)
        tabs.append(
            pd.DataFrame(preview_result.fetchall(), columns=list(preview_result.keys()))
        )

df_by_name = dict(zip(names, tabs))
table_names = widgets.Dropdown(options=names)
output_table = widgets.Output()


def show_table(name):
    """Replace the contents of ``output_table`` with the preview of table ``name``."""
    output_table.clear_output()
    with output_table:
        display(df_by_name[name])


def table_name_eventhandler(change):
    """Dropdown ``value`` observer: show the preview of the newly selected table."""
    show_table(change.new)


table_names.observe(table_name_eventhandler, names="value")

# observe() only fires on later changes, so render the initial selection now;
# otherwise the output area stays empty until a different table is picked.
if table_names.value is not None:
    show_table(table_names.value)

display(table_names)

Dropdown(options=('city', 'country', 'countrylanguage'), value='city')

In [5]:
# Show the preview area; its contents are replaced whenever the table dropdown changes.
display(output_table)

Output()

## Ask your question

In [6]:
# Natural-language question passed to the SQL chain in the next cell.
question = "which country has the most cities?"

In [7]:
from langchain import OpenAI, SQLDatabase
from langchain.chains import SQLDatabaseSequentialChain

# LangChain wrapper around the database given by db_url.
db = SQLDatabase.from_uri(db_url)

# temperature=0 asks the model for its least random completion, keeping the
# generated SQL as stable between runs as the API allows.
llm = OpenAI(openai_api_key=api_key, temperature=0)

# return_intermediate_steps=True exposes the generated SQL and the raw query
# output, which the next cell reads from result["intermediate_steps"].
db_chain = SQLDatabaseSequentialChain.from_llm(
    llm=llm,
    database=db,
    return_intermediate_steps=True,
    verbose=False,
)

result = db_chain(question)

In [8]:
import json
import ast
import sqlparse

# With return_intermediate_steps=True the chain result's "intermediate_steps"
# is indexed here for the generated SQL ([0]) and, further below, for the raw
# query output ([1]).
# NOTE(review): this layout depends on the installed LangChain version — confirm.
sql_script = result["intermediate_steps"][0]


def _column_label(token):
    """Return a display name for one select-list item (alias, column/function name, or its SQL text)."""
    name = token.get_name() if hasattr(token, "get_name") else None
    return name or str(token).strip()


def get_query_columns(sql):
    """Best-effort extraction of the output column names of a SELECT query.

    Parses ``sql`` with sqlparse, inspects the first statement, and returns a
    label for each item of its select list, in order: the alias if present,
    otherwise the column/function name, otherwise the item's SQL text.

    Args:
        sql: SQL text; only the first statement is inspected.

    Returns:
        A list of column labels. The list is empty when there is no statement,
        no SELECT, or the select list has no named items (e.g. ``SELECT *``),
        so callers must not assume one label per result column.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return []

    column_identifiers = []
    in_select = False
    for token in statements[0].tokens:
        if isinstance(token, sqlparse.sql.Comment):
            continue
        if not in_select:
            in_select = str(token).lower() == "select"
            continue
        # The select list ends at FROM; stopping here prevents the table name
        # from being taken as a column (e.g. in "SELECT * FROM city").
        if token.is_keyword and token.normalized == "FROM":
            break
        # Skip whitespace, DISTINCT, wildcards and other ungrouped tokens.
        if token.ttype is not None:
            continue
        if isinstance(token, sqlparse.sql.IdentifierList):
            column_identifiers.extend(token.get_identifiers())
        else:
            # A single selected item: an Identifier, a Function, etc.
            column_identifiers.append(token)
        break

    return [_column_label(identifier) for identifier in column_identifiers]


cols = get_query_columns(sql_script)

# The raw query output arrives as text and is parsed with ast.literal_eval,
# which only evaluates Python literals (no arbitrary code). Presumably it is
# the repr of a list of row tuples — confirm against the LangChain version.
# Reprs of non-literal values (e.g. Decimal, dates) or empty output make
# literal_eval raise, so fall back to printing the raw text in that case.
raw_rows = result["intermediate_steps"][1]
try:
    data = ast.literal_eval(raw_rows) if raw_rows.strip() else []
except (ValueError, SyntaxError):
    data = None

# The LLM output can carry leading whitespace; strip it and pad the labels to
# a common width so the two lines align.
print(f"{'Natural Language Query:':<29}{result['query']}")
print(f"{'Sql Query:':<29}{sql_script.strip()}")

if data is None:
    print(raw_rows)
else:
    # Label the columns only when the parser found one name per column (it
    # cannot for e.g. SELECT *); otherwise keep pandas' positional labels
    # instead of raising on the length mismatch.
    width = len(data[0]) if data else len(cols)
    display(pd.DataFrame(data, columns=cols if len(cols) == width else None))
print(result["result"].strip())

Natural Language Query:      which country has the most cities?
Sql Query:                   SELECT countrycode, COUNT(*) AS num_cities FROM city GROUP BY countrycode ORDER BY num_cities DESC LIMIT 5;


Unnamed: 0,countrycode,num_cities
0,CHN,363
1,IND,341
2,USA,274
3,BRA,250
4,JPN,248


 China has the most cities with 363.


## Here's another prompt, derived from the generated SQL; iterate on it until you get the exact answer you expect.

In [9]:
from langchain.prompts import PromptTemplate
from langchain import OpenAI

# Ask the model to turn the generated SQL back into a specific natural-language
# prompt. Edit the template and re-run this cell to iterate.
prompt = PromptTemplate(
    input_variables=["sql_script"],
    template="""given the sql "{sql_script}" write a specific prompt,""",
)

# The generated SQL can carry leading whitespace from the LLM; strip it so the
# quoted SQL inside the prompt is clean.
the_prompt = prompt.format(sql_script=sql_script.strip())

llm = OpenAI(temperature=0.0, openai_api_key=api_key)

# Strip the blank lines the completion tends to start with before printing.
print(llm(the_prompt).strip())



Which five countries have the most cities?
