In [1]:
from dotenv import load_dotenv
import dspy
import chromadb
import chromadb.utils.embedding_functions as embedding_functions
from dspy.retrieve.chromadb_rm import ChromadbRM
import os
import uuid
import pandas as pd
import sqlite3

load_dotenv("../.env")

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
turbo = dspy.OpenAI(model="gpt-3.5-turbo")

In [3]:
CHROMADB_DATA = "./chromadb_data"

openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)

chroma_client = chromadb.PersistentClient(
    path=CHROMADB_DATA,
)

# if collection does not exist create it
collection = chroma_client.get_or_create_collection(
    name="datasets_info", embedding_function=openai_ef
)

In [4]:
def split_text(text: str, chunks: int, overlap: int):
    return [
        text[i : i + chunks - overlap]
        for i in range(0, len(text), chunks - overlap)
    ]

In [5]:
with open("./data/schema.sql") as f:
    schema = f.read()

docs = split_text(schema, 1000, 300)
ids = [uuid.uuid4().hex for _ in range(len(docs))]

collection.add(
    ids=ids, documents=docs, metadatas=[{"name": "sql schema"}] * len(docs)
)

rm = ChromadbRM(
    collection_name="datasets_info",
    persist_directory=CHROMADB_DATA,
    embedding_function=openai_ef,
)

In [6]:
print(rm("Artists"))

[{'id': '10299b7e4a8942919323593d718ef4d9', 'score': 1.3780223629133594, 'long_text': 'BEGIN TRANSACTION;\nCREATE TABLE [Album]\n(\n    [AlbumId] INTEGER  NOT NULL,\n    [Title] NVARCHAR(160)  NOT NULL,\n    [ArtistId] INTEGER  NOT NULL,\n    CONSTRAINT [PK_Album] PRIMARY KEY  ([AlbumId]),\n    FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n);\nCREATE TABLE [Artist]\n(\n    [ArtistId] INTEGER  NOT NULL,\n    [Name] NVARCHAR(120),\n    CONSTRAINT [PK_Artist] PRIMARY KEY  ([ArtistId])\n);\nCREATE TABLE [Customer]\n(\n    [CustomerId] INTEGER  NOT NULL,\n    [FirstName] NVARCHAR(40)  NOT NULL,\n    [LastName] NVARCHAR(20)  NOT NULL,\n    [Company] NVARCHAR(80),\n    [Address] NVARCHAR(70),\n    [City] NVARCHAR(40),\n    [State] NVARCHAR(40),\n    [C', 'metadatas': {'name': 'sql schema'}}, {'id': '6b612a7509184b618cfc15d0a04aca35', 'score': 1.412078388951942, 'long_text': 'BEGIN TRANSACTION;\nCREATE TABLE [Album]\n(\n    [AlbumId] I

In [41]:
class TextToSQLAnswer(dspy.Signature):
    """Convert natural language text to SQL the using a database schema."""

    question: str = dspy.InputField()
    context: str = dspy.InputField()
    sql: str = dspy.OutputField(desc="sql string, no code fences or premeable")

In [44]:
class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()

        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_sql = dspy.ChainOfThought(TextToSQLAnswer)

    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_sql(context=context, question=question)
        return dspy.Prediction(context=context, sql=prediction.sql)


with dspy.context(lm=turbo, rm=rm):
    my_question = "Who are the top 10 selling artists and how much did each of them sold?"
    answer = RAG().forward(my_question)
    print(answer.sql)

SELECT Artist.Name, SUM(InvoiceLine.UnitPrice * InvoiceLine.Quantity) AS TotalSales
FROM Artist
JOIN Album ON Artist.ArtistId = Album.ArtistId
JOIN Track ON Album.AlbumId = Track.AlbumId
JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId
JOIN Invoice ON InvoiceLine.InvoiceId = Invoice.InvoiceId
GROUP BY Artist.Name
ORDER BY TotalSales DESC
LIMIT 10;


In [45]:
conn = sqlite3.connect("./data/Chinook.db")
pd.read_sql(answer.sql, conn)

Unnamed: 0,Name,TotalSales
0,Iron Maiden,138.6
1,U2,105.93
2,Metallica,90.09
3,Led Zeppelin,86.13
4,Lost,81.59
5,The Office,49.75
6,Os Paralamas Do Sucesso,44.55
7,Deep Purple,43.56
8,Faith No More,41.58
9,Eric Clapton,39.6
