# Interacting with SQL Databases Using Langchain's SQL Agents

## Libraries and Settings

In [None]:
# Libraries
import os
import json
import pandas as pd
from sqlalchemy import create_engine, text
import matplotlib.pyplot as plt

from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType

# Read OpenAI API key
# api_key is bound up front so that it always exists after this cell runs;
# None means "could not be loaded" and avoids a NameError in later cells.
api_key = None
try:
    with open('/workspace/credentials.json', encoding='utf-8') as f:
        credentials = json.load(f)
    api_key = credentials['openai']['api_key']
except (OSError, ValueError, KeyError, TypeError) as err:
    # Only the expected failures are handled here:
    #   OSError    - file missing or unreadable
    #   ValueError - invalid JSON (json.JSONDecodeError) or undecodable bytes
    #   KeyError / TypeError - JSON does not have the openai -> api_key layout
    # Anything else (e.g. KeyboardInterrupt) propagates instead of being hidden.
    print("Please provide your OpenAI API key in the credentials.json file.")
    print(f"Reason: {type(err).__name__}: {err}")

# Settings
# Install an "ignore" filter so warnings are not shown in the cell outputs
import warnings
warnings.filterwarnings(action="ignore")

# Print the directory the notebook is currently running in
print(os.getcwd())

## Read apartment data to data frame

In [None]:
# Load the prepared apartment data (comma-separated CSV) into a data frame
apartments_csv = '/workspace/apartments_data_prepared.csv'
df = pd.read_csv(apartments_csv, sep=',')

# Preview the first five rows
df.head(n=5)

## Write data to database

In [None]:
# Create connection
engine = create_engine('postgresql://pguser:geheim@db:5432/postgres')

try:
    # Write the data frame to the table 'apartment_table';
    # if_exists='replace' overwrites a table of that name if it already exists
    df.to_sql('apartment_table', engine, if_exists='replace')
finally:
    # Dispose the engine even when the write fails, so that no pooled
    # connection is left open
    engine.dispose()

## List tables in the database

In [None]:
# Create a connection
engine = create_engine('postgresql://pguser:geheim@db:5432/postgres')

# Query returning the names of all tables in the 'public' schema
tables_query = text("""SELECT table_name
                       FROM information_schema.tables
                       WHERE table_schema = 'public'""")

try:
    # Open a connection (closed again when the with-block is left)
    with engine.connect() as connection:

        # Execute the query
        result = connection.execute(tables_query)

        # Fetch and print the results (first column = table name)
        for row in result:
            print(row[0])
finally:
    # Dispose the engine even when connecting or querying fails
    engine.dispose()

## Make standard SQL query to select data

In [None]:
# Create a connection
engine = create_engine('postgresql://pguser:geheim@db:5432/postgres')

try:
    # Read address, rooms, area and price of all apartments with
    # price >= 1000 from the table into a data frame
    df_sub = pd.read_sql_query('''SELECT
                                  address_raw,
                                  rooms,
                                  area,
                                  price
                                  FROM apartment_table
                                  WHERE price >= 1000''',
                               con=engine)
finally:
    # Dispose the engine even when the query fails
    engine.dispose()

# Show the data
df_sub.head()

## Count number of apartments in table

In [None]:
# Create a connection
engine = create_engine('postgresql://pguser:geheim@db:5432/postgres')

# Count number of apartments (rows) in the table
count_query = '''SELECT COUNT(*) as apartment_count
                 FROM apartment_table'''

try:
    # Execute the query and fetch the result as a one-row data frame
    result = pd.read_sql_query(count_query, con=engine)
finally:
    # Dispose the engine even when the query fails
    engine.dispose()

# Show the count
print("Number of apartments:", result['apartment_count'].iloc[0])

## Calculate average price of selected apartments

In [None]:
# Create a connection
engine = create_engine('postgresql://pguser:geheim@db:5432/postgres')

# Query to get the average price of apartments with 3.5 rooms and >= 100m² living area
query = '''SELECT AVG(price) as average_price
           FROM apartment_table
           WHERE rooms = 3.5 
           AND area >= 100'''

try:
    # Execute the query and fetch the result as a one-row data frame
    result = pd.read_sql_query(query, con=engine)
finally:
    # Dispose the engine even when the query fails
    engine.dispose()

# Show the average price
print("Average price of apartments with 3.5 rooms and >= 100m² living area:", result['average_price'].iloc[0])

## Plot apartment prices

In [None]:
# Plot histogram of the prices in df_sub (the apartments selected with price >= 1000)
fig = plt.figure(figsize=(7, 4))

# The values returned by plt.hist are not needed, so they are not stored
plt.hist(x=df_sub['price'],
         bins=20,
         color='#5DADE2',
         alpha=1.00,
         rwidth=0.95)

# Grid on both axes, with the horizontal (y) grid lines at alpha 0.75
plt.grid(True)
# Plain tick numbers instead of scientific notation
plt.ticklabel_format(style='plain')
plt.grid(axis='y', alpha=0.75)

# Set labels
plt.xlabel('price', fontsize=10, labelpad=10)
plt.ylabel('Frequency', fontsize=10, labelpad=10)
plt.title('Histogram of apartment prices', fontsize=12, pad=10)

# Set fontsize of tick labels (done once, after plotting; the earlier
# fontsize=14 setting was overridden by this one anyway)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

plt.show()

## Use an SQL Database Agent with LangChain to query the database

### Instantiate the LLM

In [None]:
# Initialize the OpenAI language model with the API key read earlier;
# temperature is set to 0 and verbose output is switched on
llm = OpenAI(
    openai_api_key=api_key,
    temperature=0,
    verbose=True,
)

### Create the SQL agent executor

In [None]:
# Create the SQL database connection used by LangChain
db = SQLDatabase.from_uri('postgresql://pguser:geheim@db:5432/postgres')

# Toolkit built from the database connection and the LLM
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# Create the SQL agent executor from the LLM and the toolkit
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=sql_toolkit,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

### Let the agent query the database (1st example)

In [None]:
# Define the question, then pass it to the agent
question = """Describe the table apartment_table"""
agent_executor.run(question)

### Let the agent query the database (2nd example)

In [None]:
# Define the question, then pass it to the agent
question = """How many unique apartments are in the table apartment_table?"""
agent_executor.run(question)

### Let the agent query the database (3rd example)

In [None]:
# Define the question (same filter as the manual average-price query), then pass it to the agent
question = """What is the average price of apartments with 3.5 rooms 
    and >= 100m2 living area in the table apartment_table?"""
agent_executor.run(question)

### Jupyter notebook --footer info-- (please always provide this at the end of each notebook)

In [None]:
import os
import platform
import socket
from platform import python_version
from datetime import datetime

# Footer: print OS name, platform and release, current timestamp and
# Python version, framed by two separator lines
separator = '-----------------------------------'
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

print(separator)
print(os.name.upper())
print(f"{platform.system()} | {platform.release()}")
print(f"Datetime: {timestamp}")
print(f"Python Version: {python_version()}")
print(separator)