# [Beta] Text-to-SQL with PGVector

This notebook shows how to perform text-to-SQL with pgvector, which lets us combine semantic search and structured querying, *all* within SQL!

In principle, this enables more expressive queries than semantic search combined with metadata filters.

**NOTE**: This is a beta feature, so interfaces might change. In the meantime, we hope you find it useful!

## Setup Data

### Load Documents

Load the Lyft 2021 10-K document.

In [None]:
from llama_hub.file.pdf.base import PDFReader

In [None]:
reader = PDFReader()

In [None]:
docs = reader.load_data("../data/10k/lyft_2021.pdf")

In [None]:
from llama_index.node_parser import SimpleNodeParser

node_parser = SimpleNodeParser.from_defaults()
nodes = node_parser.get_nodes_from_documents(docs)

In [None]:
print(nodes[8].get_content(metadata_mode="all"))

### Insert data into Postgres + PGVector

Make sure you have all the necessary dependencies installed! 

In [None]:
!pip install psycopg2-binary pgvector asyncpg "sqlalchemy[asyncio]" greenlet

In [None]:
from pgvector.sqlalchemy import Vector
from sqlalchemy import insert, create_engine, String, text, Integer
from sqlalchemy.orm import declarative_base, mapped_column

#### Establish Connection

In [None]:
engine = create_engine("postgresql+psycopg2://localhost/postgres")
with engine.connect() as conn:
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    conn.commit()

#### Define Table Schema 

Define the table as a Python class. Note that we store the `page_label`, `file_name`, `text`, and `embedding` for each chunk.

In [None]:
Base = declarative_base()


class SECTextChunk(Base):
    __tablename__ = "sec_text_chunk"

    id = mapped_column(Integer, primary_key=True)
    page_label = mapped_column(Integer)
    file_name = mapped_column(String)
    text = mapped_column(String)
    embedding = mapped_column(Vector(384))

In [None]:
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)

#### Generate embedding for each Node with a sentence_transformers model

In [None]:
# get embeddings for each row
from llama_index.embeddings import HuggingFaceEmbedding

embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")

for node in nodes:
    text_embedding = embed_model.get_text_embedding(node.get_content())
    node.embedding = text_embedding
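
The `Vector(384)` column in the schema only accepts vectors with exactly 384 dimensions, which is the output size of `BAAI/bge-small-en`. If you swap in a different embedding model, the column size has to change with it. A minimal sketch of a guard you could run before inserting, where `check_embedding_dim` and `EXPECTED_DIM` are hypothetical names introduced for illustration (they are not part of llama_index or pgvector):

```python
from typing import Sequence

EXPECTED_DIM = 384  # assumption: mirrors Vector(384) in the table schema


def check_embedding_dim(
    embedding: Sequence[float], expected_dim: int = EXPECTED_DIM
) -> None:
    """Raise ValueError if the embedding length differs from the column size."""
    if len(embedding) != expected_dim:
        raise ValueError(
            f"expected {expected_dim} dimensions, got {len(embedding)}"
        )


# Stand-in vector for illustration; in the notebook this would be node.embedding.
check_embedding_dim([0.0] * 384)
```

Failing early in Python tends to give a clearer error than waiting for the database to reject the row.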

#### Insert into Database

In [None]:
# insert all rows inside a single transaction (committed when the block exits)
with engine.begin() as connection:
    for node in nodes:
        row_dict = {
            "text": node.get_content(),
            "embedding": node.embedding,
            **node.metadata,
        }
        connection.execute(insert(SECTextChunk).values(**row_dict))
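
Because each row is built with `**node.metadata`, every metadata key has to correspond to a column of `SECTextChunk`; SQLAlchemy is expected to reject column names it does not know. A minimal sketch of a more defensive row builder, assuming you only want the two metadata columns from the schema. `build_row` and `METADATA_COLUMNS` are hypothetical names, not llama_index APIs:

```python
from typing import Any, Dict, Sequence

# Assumption: the only metadata keys with a matching column in the schema.
METADATA_COLUMNS = ("page_label", "file_name")


def build_row(
    text: str, embedding: Sequence[float], metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """Build an insertable row, dropping metadata keys that have no column."""
    row: Dict[str, Any] = {"text": text, "embedding": list(embedding)}
    for key in METADATA_COLUMNS:
        if key in metadata:
            row[key] = metadata[key]
    return row


# Hypothetical node data; "extra" stands in for a key the table does not have.
row = build_row(
    "some chunk",
    [0.1, 0.2],
    {"page_label": "6", "file_name": "lyft_2021.pdf", "extra": 1},
)
print(sorted(row))
```

The resulting dictionary can be passed to `insert(SECTextChunk).values(**row)` in place of `row_dict`.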

## Define PGVectorSQLQueryEngine

Now that we've loaded the data into the database, we're ready to set up our query engine.

### Define Prompt

We create a modified version of our default text-to-SQL prompt to inject awareness of the pgvector syntax.
We also prompt it with some few-shot examples of how to use the syntax (`<->`).

**NOTE**: This prompt is included by default in the `PGVectorSQLQueryEngine`; we include it here mostly for visibility!

In [None]:
from llama_index.prompts import PromptTemplate

text_to_sql_tmpl = """\
Given an input question, first create a syntactically correct {dialect} \
query to run, then look at the results of the query and return the answer. \
You can order the results by a relevant column to return the most \
interesting examples in the database.

Pay attention to use only the column names that you can see in the schema \
description. Be careful to not query for columns that do not exist. \
Pay attention to which column is in which table. Also, qualify column names \
with the table name when needed. 

IMPORTANT NOTE: you can use specialized pgvector syntax (`<->`) to do nearest \
neighbors/semantic search to a given vector from an embeddings column in the table. \
The embeddings value for a given row typically represents the semantic meaning of that row. \
The vector represents an embedding representation \
of the question, given below. Do NOT fill in the vector values directly, but rather specify a \
`[query_vector]` placeholder. For instance, some select statement examples below \
(the name of the embeddings column is `embedding`):
SELECT * FROM items ORDER BY embedding <-> '[query_vector]' LIMIT 5;
SELECT * FROM items WHERE id != 1 ORDER BY embedding <-> (SELECT embedding FROM items WHERE id = 1) LIMIT 5;
SELECT * FROM items WHERE embedding <-> '[query_vector]' < 5;

You are required to use the following format, \
each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}


Question: {query_str}
SQLQuery: \
"""
text_to_sql_prompt = PromptTemplate(text_to_sql_tmpl)
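
The `<->` operator used in the few-shot examples is pgvector's L2 (Euclidean) distance, so `ORDER BY embedding <-> ... LIMIT k` returns the `k` rows closest to the query vector. A minimal pure-Python sketch of what that ordering computes, using toy 2-dimensional vectors; `l2_distance` and `nearest` are hypothetical helper names, and the real work happens inside Postgres:

```python
import math
from typing import List, Sequence, Tuple


def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean (L2) distance, the quantity pgvector's `<->` computes."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest(
    rows: List[Tuple[int, List[float]]], query: Sequence[float], limit: int
) -> List[int]:
    """Emulate: SELECT id FROM items ORDER BY embedding <-> query LIMIT limit."""
    ranked = sorted(rows, key=lambda row: l2_distance(row[1], query))
    return [row_id for row_id, _ in ranked[:limit]]


# Toy 2-dimensional "embeddings"; the ones in this notebook have 384 dimensions.
items = [(1, [0.0, 0.0]), (2, [1.0, 1.0]), (3, [3.0, 4.0])]
print(nearest(items, query=[0.9, 1.2], limit=2))  # → [2, 1]
```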

### Setup LLM, Embedding Model, and Misc.

Besides the LLM and embedding model, note that we also add annotations on the table itself. These help the LLM
understand the column schema (e.g. by telling it what the embedding column represents), so that it can better
decide between tabular querying and semantic search.

In [None]:
from llama_index import ServiceContext, SQLDatabase
from llama_index.llms import OpenAI
from llama_index.query_engine import PGVectorSQLQueryEngine

sql_database = SQLDatabase(engine, include_tables=["sec_text_chunk"])

llm = OpenAI(model="gpt-4")
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)

table_desc = """\
This table represents text chunks from an SEC filing. Each row contains the following columns:

id: id of row
page_label: page number 
file_name: top-level file name
text: the full text of the chunk
embedding: the embeddings representing the text chunk

For most queries you should perform semantic search against the `embedding` column values, since \
that encodes the meaning of the text.

"""

context_query_kwargs = {"sec_text_chunk": table_desc}

### Define Query Engine

In [None]:
query_engine = PGVectorSQLQueryEngine(
    sql_database=sql_database,
    text_to_sql_prompt=text_to_sql_prompt,
    service_context=service_context,
    context_query_kwargs=context_query_kwargs,
)

## Run Some Queries

Now we're ready to run some queries.

In [None]:
response = query_engine.query(
    "Can you tell me about the risk factors described in page 6?",
)

In [None]:
print(str(response))

The text on page 6 discusses the impact of the COVID-19 pandemic on the business. It mentions that the pandemic has affected communities in the United States, Canada, and globally. It has also led to significant disruptions in the business, including a decrease in the number of riders and drivers, reduced hours of operation, and increased costs. The text also discusses the company's transportation network, which offers riders seamless, personalized, and on-demand access to a variety of mobility options.


In [None]:
print(response.metadata["sql_query"])

In [None]:
response = query_engine.query(
    "Tell me more about Lyft's real estate operating leases",
)

In [None]:
print(str(response))

Lyft's lease arrangements include vehicle rental agreements that are accounted for as operating leases. These leases do not meet any specific criteria that would categorize them otherwise. The company's leasehold improvements are amortized on a straight-line basis over the shorter of the term of the lease or the useful life of the assets.


In [None]:
print(response.metadata["sql_query"][:300])

SELECT * FROM sec_text_chunk WHERE text LIKE '%Lyft%' AND text LIKE '%real estate%' AND text LIKE '%operating leases%' ORDER BY embedding <-> '[-0.06691089272499084, -0.41431307792663574, 0.2750679850578308, 0.19374045729637146, 0.08942480385303497, -0.16577985882759094, 0.399348646402359, 0.3634052
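
The SQL printed above already contains concrete vector values where the prompt asked the LLM to emit `[query_vector]`: the placeholder has to be replaced with the embedding of the question before the query runs. A minimal sketch of that substitution step, assuming a plain string replacement. `fill_query_vector` and `PLACEHOLDER` are hypothetical names, and this is not necessarily how `PGVectorSQLQueryEngine` implements it:

```python
from typing import Sequence

PLACEHOLDER = "[query_vector]"  # the placeholder named in the prompt


def fill_query_vector(sql: str, query_embedding: Sequence[float]) -> str:
    """Replace the placeholder with a pgvector literal such as [0.1, 0.2]."""
    vector_literal = "[" + ", ".join(str(float(v)) for v in query_embedding) + "]"
    return sql.replace(PLACEHOLDER, vector_literal)


# Hypothetical LLM output and a toy 2-dimensional query embedding.
template_sql = (
    "SELECT * FROM sec_text_chunk ORDER BY embedding <-> '[query_vector]' LIMIT 5"
)
print(fill_query_vector(template_sql, [0.25, -0.5]))
```

Keeping the vector out of the LLM output avoids asking the model to write hundreds of floating-point values by hand.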


In [None]:
# look at the returned result
print(response.metadata["result"])

[(157, 93, 'lyft_2021.pdf', "Leases that do not meet any of the above criteria are accounted for as operating leases.Lessor\nThe\n Company's lease arrangements include vehicle re ... (4356 characters truncated) ...  realized. Leasehold improvements are amortized on a straight-line basis over the shorter of the term of the lease, or the useful life of the assets.", '[0.16887704,-0.22762142,0.040292107,0.2951868,0.034039058,-0.092776,0.23275128,0.12367551,0.17209437,-0.08910224,0.30044347,0.1590553,0.21984532,-0.1 ... (4111 characters truncated) ... 0.24707487,0.10685501,0.42726353,-0.16156487,-0.2705381,-0.15468368,0.100748956,-0.19910589,-0.06634029,-0.7986131,-0.14139938,0.55980897,0.31352338]')]


In [None]:
# structured query
response = query_engine.query(
    "Tell me about the max page number in this table",
)

In [None]:
print(str(response))

The maximum page number in this table is 238.


In [None]:
print(response.metadata["sql_query"][:300])

SELECT MAX(page_label) FROM sec_text_chunk;
