In [12]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


In [13]:
import os
from dotenv import load_dotenv

load_dotenv()

os.environ["COHERE_API_KEY"]=os.getenv("COHERE_API_KEY")

In [14]:
from langchain_cohere import ChatCohere

llm = ChatCohere(model="command-r-plus")

In [None]:
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")


table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser

table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

In [None]:
system = """Return the names of any SQL tables that are relevant to the user question.
The tables are:

Music
Business
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)

category_chain = prompt | llm_with_tools | output_parser
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

IMPLEMENTATION of SCHEMA RETREIVAL APPROACH


In [15]:
from langchain_chroma import Chroma
from langchain_cohere.embeddings import CohereEmbeddings
from langchain_core.documents import Document
from langchain.chains.sql_database.query import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_core.prompts import PromptTemplate
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings




def setup_vector_store(db):
    """Store table schemas in ChromaDB with semantic search capabilities"""
    # Get all table DDLs
    ddls = [
        Document(
            page_content=db.get_table_info([table_name]),
            metadata={"table_name": table_name}
        )
        for table_name in db.get_usable_table_names()
    ]

    # Initialize embeddings with Cohere model
    embeddings = CohereEmbeddings(model="embed-english-v3.0")

    
    # Create vector store
    return Chroma.from_documents(
        documents=ddls,
        embedding=embeddings,
        collection_name="schema_documents",
        persist_directory="./chroma_db"
    )

def create_semantic_sql_chain(llm, db):
    """Create SQL generation chain with semantic schema retrieval"""
    # Setup vector store
    vector_store = setup_vector_store(db)
    
    # Create retriever for relevant schemas
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    
    # Formatting function for retrieved documents
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)
    template = '''Given an input question, first create a syntactically correct SQL query to run, then look at the results of the query and return the answer.
    Use the following format:

    Question: "Question here"
    SQLQuery: "SQL Query to run"
    SQLResult: "Result of the SQLQuery"
    Answer: "Final answer here"

    DO NOT include any markdown formatting, code blocks, or explanation text such as ```sql in start and ``` in end of the SQL query.
    ONLY return the SQL query itself, nothing else.
    The query should be executable and return the results asked in the question.
    You may return up to {top_k} results.

    Only use the following tables:

    {table_info}.

    Question: {input}'''
    prompt = PromptTemplate.from_template(template)
    # Base SQL query chain
    query_chain = create_sql_query_chain(llm, db,prompt)
    
    # Full chain with semantic retrieval
    return RunnablePassthrough.assign(
        table_info=itemgetter("question") | retriever | format_docs
    ) | query_chain

# Example usage:
chain = create_semantic_sql_chain(llm, db)
result = chain.invoke({"question": "Which artist has the most tracks in the Rock genre, and how many tracks do they have? Include only artists with more than 100 rock tracks."})
print(result)
print(db.run(result))

Failed to multipart ingest runs: langsmith.utils.LangSmithError: Failed to POST https://api.smith.langchain.com/runs/multipart in LangSmith API. HTTPError('403 Client Error: Forbidden for url: https://api.smith.langchain.com/runs/multipart', '{"detail":"Forbidden"}')
Failed to send compressed multipart ingest: langsmith.utils.LangSmithError: Failed to POST https://api.smith.langchain.com/runs/multipart in LangSmith API. HTTPError('403 Client Error: Forbidden for url: https://api.smith.langchain.com/runs/multipart', '{"detail":"Forbidden"}')


SELECT a.Name, COUNT(t.TrackId) AS TrackCount
FROM Artist a
JOIN Album al ON a.ArtistId = al.ArtistId
JOIN Track t ON al.AlbumId = t.AlbumId
JOIN Genre g ON t.GenreId = g.GenreId
WHERE g.Name = 'Rock'
GROUP BY a.Name
HAVING COUNT(t.TrackId) > 100;
[('Led Zeppelin', 114), ('U2', 112)]


Failed to send compressed multipart ingest: langsmith.utils.LangSmithError: Failed to POST https://api.smith.langchain.com/runs/multipart in LangSmith API. HTTPError('403 Client Error: Forbidden for url: https://api.smith.langchain.com/runs/multipart', '{"detail":"Forbidden"}')


Approaching with better chunking and embedding strategies 

In [30]:
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import CohereEmbeddings
from langchain_chroma import Chroma

In [31]:
# Load SQL file
loader = TextLoader("chinook.sql", encoding="utf-8")
documents = loader.load()


In [18]:
# Custom SQL text splitter
sql_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=64,
    separators=[
        "\n\nCREATE TABLE",  # Split at table definitions
        "\n\nALTER TABLE",   # Split at constraints
        "\n\nINSERT INTO",   # Split data inserts
        ";\n",               # Split at SQL statement ends
        "\n",                # Split at newlines
        " "                  # Split at spaces as last resort
    ],
    is_separator_regex=False  # Treat separators as literal strings
)

chunks = sql_splitter.split_documents(documents)

In [19]:
# Merge broken table definitions
combined_chunks = []
current_table = None

for chunk in chunks:
    content = chunk.page_content
    if "CREATE TABLE" in content:
        if current_table:
            combined_chunks.append(current_table)
        current_table = chunk
    elif current_table:
        current_table.page_content += "\n" + content
    else:
        combined_chunks.append(chunk)

if current_table:
    combined_chunks.append(current_table)

chunks = combined_chunks

In [20]:
from langchain_cohere import CohereEmbeddings
import re
import os

embedding_model = CohereEmbeddings(
    model="embed-english-v3.0"
)

from langchain_chroma import Chroma



In [21]:
from langchain_core.documents import Document
from langchain_community.vectorstores.utils import filter_complex_metadata

def process_chunks(db):
    ddls = []
    relationships = {}
    
    for table_name in db.get_usable_table_names():
        ddl = db.get_table_info([table_name])
        
        # Create document with initial metadata
        chunk = Document(
            page_content=ddl,
            metadata={
                "object_type": "table_definition",
                "table_name": table_name,
                "related_tables": ", ".join(re.findall(r'REFERENCES\s+"?(\w+)"?', ddl)),
                "is_central_table": int(len(re.findall(r'REFERENCES\s+"?(\w+)"?', ddl)) > 2)
            }
        )
        
        # Filter metadata values to Chroma-compatible types
        chunk.metadata = {k: v for k, v in chunk.metadata.items() 
                         if isinstance(v, (str, int, float, bool))}
        
        ddls.append(chunk)
        relationships[table_name] = chunk.metadata["related_tables"].split(", ")
    
    return ddls, relationships

# Create vector store with pre-filtered documents
chunks, rel_graph = process_chunks(db)
vector_store = Chroma.from_documents(
    documents=chunks,
    embedding=embedding_model,
    collection_name="cohere_schema",
    persist_directory="./chroma_cohere",
    collection_metadata={"hnsw:space": "cosine"}
)

In [22]:
class CohereRelationshipRetriever:
    def __init__(self, vector_store, relationship_graph):
        # Initialize with Chroma-compatible parameters
        self.base_retriever = vector_store.as_retriever(
            search_kwargs={"k": 10}  # Adjust k as needed
        )
        self.relationship_graph = relationship_graph
        
    def find_related_chunks(self, table_name):
        return [c for c in chunks if c.metadata["table_name"] in self.relationship_graph.get(table_name, [])]
    
    def get_relevant_documents(self, query):
        # Get base results with scores using similarity_search_with_score
        docs_scores = self.base_retriever.vectorstore.similarity_search_with_score(query, k=15)
        
        # Filter by score manually
        filtered_docs = [doc for doc, score in docs_scores if score < 0.28]
        
        # Expand relationships
        expanded = []
        seen = set()
        
        for doc in filtered_docs:
            tbl = doc.metadata.get("table_name")
            if tbl and tbl not in seen:
                expanded.append(doc)
                seen.add(tbl)
                # Add 1st-degree relationships
                expanded += [d for d in self.find_related_chunks(tbl) 
                            if d.metadata["table_name"] not in seen]
                
        return sorted(expanded, 
                    key=lambda x: -len(x.metadata["related_tables"]))[:7]
retriever = CohereRelationshipRetriever(vector_store, rel_graph)


In [27]:
results = retriever.get_relevant_documents(
    "Find customers who purchased tracks from more than 3 genres, "
    "along with their total spending and favorite artist"
)



In [29]:
print(results)

[]
