### Basic RAG retriever using Langchain and Anthropic model

In [1]:
import bs4
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_anthropic import ChatAnthropic
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase

import os
from dotenv import load_dotenv
load_dotenv()
os.environ["LANGCHAIN_API_KEY"]= os.getenv("LANGCHAIN_API_KEY")
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620",temperature=0)
os.environ["LANGCHAIN_PROJECT"] = "Trials"

USER_AGENT environment variable not set, consider setting it to identify your requests.


In [2]:
db = SQLDatabase.from_uri("sqlite:////Users/main/Desktop/database/llm.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT count(*) FROM restaurants LIMIT 10;")

sqlite
['restaurants', 'restaurants_madrid']


'[(1191,)]'

In [3]:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many restaurants are in the restaurants_madrid table, just give me the query, nothing else"})
response

'SELECT COUNT(*) FROM restaurants_madrid'

In [4]:
db.run(response)

'[(2380,)]'

In [5]:
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [21]:
from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, verbose=True)

In [22]:
agent_executor.invoke(
    "List the total amount of restaurantst?"
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: 
[0m[38;5;200m[1;3mrestaurants, restaurants_madrid[0m[32;1m[1;3mBased on the observation, we have two tables: "restaurants" and "restaurants_madrid". To get the total amount of restaurants, we should count the number of entries in the "restaurants" table, as it likely contains all restaurants.

Let's check the schema of the "restaurants" table to ensure we have the correct information.

Action: sql_db_schema
Action Input: restaurants
[0m[33;1m[1;3m
CREATE TABLE restaurants (
	"index" INTEGER, 
	business_status TEXT, 
	icon TEXT, 
	icon_background_color TEXT, 
	icon_mask_base_uri TEXT, 
	name TEXT, 
	place_id TEXT, 
	rating REAL, 
	reference TEXT, 
	scope TEXT, 
	user_ratings_total REAL, 
	vicinity TEXT, 
	"geometry.location.lat" REAL, 
	"geometry.location.lng" REAL, 
	"geometry.viewport.northeast.lat" REAL, 
	"geometry.viewport.northeast.lng" REAL, 
	"geometry.viewport.so

{'input': 'List the total amount of restaurantst?',
 'output': 'The total number of restaurants in the database is 1,191.'}

In [23]:
examples = [
    {
        "input": "List all operational places.",
        "query": "SELECT * FROM Places WHERE Status = 'OPERATIONAL';",
    },
    {
        "input": "Find all places with a rating of 4.5 or higher.",
        "query": "SELECT * FROM Places WHERE Rating >= 4.5;",
    },
    {
        "input": "Get the details of the place with the highest rating.",
        "query": "SELECT * FROM Places ORDER BY Rating DESC LIMIT 1;",
    },
    {
        "input": "List all places that are temporarily closed.",
        "query": "SELECT * FROM Places WHERE Status = 'CLOSED_TEMPORARILY';",
    },
    {
        "input": "Find the total number of places in Madrid.",
        "query": "SELECT COUNT(*) FROM Places WHERE Address LIKE '%Madrid%';",
    },
    {
        "input": "List all places located on Calle de Recoletos.",
        "query": "SELECT * FROM Places WHERE Address LIKE '%Calle de Recoletos%';",
    },
    {
        "input": "Find the place with the highest price range.",
        "query": "SELECT * FROM Places ORDER BY PriceRange DESC LIMIT 1;",
    },
    {
        "input": "How many places are there with a rating of 4.4 or lower?",
        "query": "SELECT COUNT(*) FROM Places WHERE Rating <= 4.4;",
    },
    {
        "input": "List all places with a price range between 600 and 2000.",
        "query": "SELECT * FROM Places WHERE PriceRange BETWEEN 600 AND 2000;",
    },
    {
        "input": "Find the average rating of all operational places.",
        "query": "SELECT AVG(Rating) FROM Places WHERE Status = 'OPERATIONAL';",
    },
]

In [28]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_community.embeddings import OllamaEmbeddings

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OllamaEmbeddings(model='nomic-embed-text'),
    Chroma,
    k=5,
    input_keys=["input"],
)

In [30]:
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix="",
)

In [31]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
prompt = FewShotPromptTemplate(
    examples=examples[:5],
    example_prompt=example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specificed, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)

In [32]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)