In [1]:
import chromadb
import duckdb
import pandas as pd
import plotly.express as px
from sentence_transformers import SentenceTransformer
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
conn = duckdb.connect(database=":memory:")
conn.execute(
    """
    CREATE TABLE sales (date DATE, product TEXT, revenue FLOAT, quantity INT);
"""
)

sample_data = [
    ("2024-01-01", "Product A", 1000, 50),
    ("2024-01-02", "Product B", 1500, 75),
    ("2024-01-03", "Product A", 1200, 60),
    ("2024-01-04", "Product C", 800, 40),
]
conn.executemany("INSERT INTO sales VALUES (?, ?, ?, ?)", sample_data)

<duckdb.duckdb.DuckDBPyConnection at 0x7fd21c2fea70>

In [3]:
chroma_client = chromadb.PersistentClient(path="chroma_db")
collection = chroma_client.get_or_create_collection(name="metadata")

In [4]:
# Data dictionary

metadata = [
    {"column": "date", "description": "Date of the sale"},
    {"column": "product", "description": "Product name"},
    {"column": "revenue", "description": "Revenue generated from sales"},
    {"column": "quantity", "description": "Quantity of products sold"},
]

embedder = SentenceTransformer("all-MiniLM-L6-v2")
for item in metadata:
    embedding = embedder.encode(item["description"]).tolist()
    collection.add(
        ids=[item["column"]], embeddings=[embedding], metadatas=[item]
    )

Add of existing embedding ID: date
Add of existing embedding ID: product
Add of existing embedding ID: revenue
Add of existing embedding ID: quantity
Insert of existing embedding ID: date
Add of existing embedding ID: date
Insert of existing embedding ID: product
Add of existing embedding ID: product
Insert of existing embedding ID: revenue
Add of existing embedding ID: revenue
Insert of existing embedding ID: quantity
Add of existing embedding ID: quantity


In [5]:
user_question = "Show me total revenue per product"
query_embedding = embedder.encode(user_question).tolist()
results = collection.query(query_embeddings=[query_embedding], n_results=2)

relevant_columns = [res["column"] for res in results["metadatas"][0]]
print("Relevant columns detected:", relevant_columns)

Relevant columns detected: ['revenue', 'quantity']


In [6]:
aggregation_model = pipeline(
    "zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
)
chart_type_model = pipeline(
    "zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
)


def determine_aggregation(question):
    labels = ["SUM", "AVG", "COUNT"]
    prediction = aggregation_model(question, candidate_labels=labels)
    return prediction["labels"][0] if prediction["scores"][0] > 0.5 else None


def determine_chart_type(question):
    labels = ["bar chart", "line chart", "scatter plot"]
    prediction = chart_type_model(question, candidate_labels=labels)
    return (
        prediction["labels"][0]
        if prediction["scores"][0] > 0.5
        else "bar chart"
    )


aggregation = determine_aggregation(user_question)
chart_type = determine_chart_type(user_question)

print(aggregation, chart_type)

Device set to use cpu
Device set to use cpu


SUM scatter plot


In [7]:
if aggregation and "revenue" in relevant_columns:
    df = conn.execute(
        f"SELECT product, {aggregation}(revenue) as revenue FROM sales GROUP BY product"
    ).fetchdf()
    if chart_type == "bar chart":
        fig = px.bar(
            df, x="product", y="revenue", title="Total Revenue Per Product"
        )
    elif chart_type == "line chart":
        fig = px.line(
            df, x="product", y="revenue", title="Total Revenue Per Product"
        )
    else:
        fig = px.scatter(
            df, x="product", y="revenue", title="Total Revenue Per Product"
        )
else:
    df = conn.execute(
        f"SELECT {', '.join(relevant_columns)} FROM sales"
    ).fetchdf()
    fig = px.line(df, x="date", y="revenue", title="Revenue Trends Over Time")

# Step 5: Generate visualization
fig.show()