# PyDough LLM Demo

This notebook showcases how an LLM can generate PyDough queries from natural language instructions. The goal is to demonstrate how AI can automate complex data analysis, making querying faster, more intuitive, and accessible without needing deep technical expertise.

Each example highlights different capabilities, including aggregations, filtering, ranking, and calculations across multiple collections.

---
## Setup & Basic Usage of the API

First, we import the client.

In [1]:
from llm import LLMClient

We can declare global definitions that will be useful for any question that needs them.

In [2]:
definitions = [
    "Total Order Value is defined as the sum of extended_price * (1 - discount).",
    "Aggregate Revenue is defined as the sum of LineItem_ExtendedPrice minus the sum of LineItem_Discount.",
    "Average Revenue per Ship Date is defined as the sum of revenue divided by the count of distinct ship dates.",
    "Partial Revenue is defined as quantity * extended_price * (1 - discount).",
    "Profit is defined as revenue minus cost."
]
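To make the first two definitions concrete, here is a small worked example in plain Python (not PyDough). The line items are made-up numbers, not TPC-H data. It also shows that "Total Order Value" and "Aggregate Revenue", as worded above, are different quantities.

```python
# Illustrative line items as (extended_price, discount). Values are made up.
line_items = [
    (100.0, 0.25),
    (200.0, 0.0),
    (50.0, 0.5),
]

# Total Order Value: sum of extended_price * (1 - discount).
total_order_value = sum(price * (1 - discount) for price, discount in line_items)

# Aggregate Revenue: sum of extended prices minus the sum of discounts.
aggregate_revenue = sum(p for p, _ in line_items) - sum(d for _, d in line_items)

print(total_order_value)  # 75 + 200 + 25
print(aggregate_revenue)  # 350 - 0.75
```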

Then we initialize the client.

In [3]:
client = LLMClient(definitions=definitions)

Use the `ask()` method to send a question to the model.

It returns a `result` object with the following attributes:

- `code`: The PyDough query generated by the LLM.
- `full_explanation`: An explanation of how the model solved the query.
- `df`: The dataframe containing the query results.
- `exception`: Stores any errors encountered while executing the query.
- `original_question`: The natural language question input by the user.
- `sql`: The SQL equivalent of the generated PyDough query.
- `base_prompt`: The initial instruction given to the LLM to generate the query.
- `cheat_sheet`: A reference guide and example queries to help the LLM structure responses.
- `knowledge_graph`: The metadata structure that informs the LLM about available collections and relationships.

---
## First query

The `ask()` method sends a single, standalone question (no discourse mode).

In [4]:
result = client.ask("Give me the name of all the suppliers from the United States")

RAW RESPONSE TEXT:
 candidates=[Candidate(content=Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=None, inline_data=None, text='Okay, let\'s break this down step by step to generate the PyDough query.\n\n1.  **Identify the Goal:** We need to find the names of suppliers located in the "UNITED STATES".\n\n2.  **Identify the Starting Collection:** The core information we need is about suppliers, so the `suppliers` collection is the logical starting point.\n\n3.  **Identify the Filtering Condition:** We only want suppliers whose associated nation has the name "UNITED STATES".\n    *   Looking at the `suppliers` collection schema, each supplier has a `nation` relationship that links to the corresponding `nations` record.\n    *   The `nations` collection has a `name` field.\n    *   Therefore, the condition is `nation.name == "UNITED STATES"`.\n\n4.  **Identify the Required Output:

After that, we can inspect any of the attributes of the result.

First, we check whether an **exception** was raised while executing the query.

In [5]:
print(result.exception) 

None


We can also ask for the **PyDough code** directly.

In [6]:
print(result.code)

# Step 1: Start with the 'suppliers' collection.
# Step 2: Filter suppliers WHERE the name of their associated nation is "UNITED STATES".
# Step 3: Select the 'name' of the filtered suppliers.
us_suppliers = suppliers.WHERE(nation.name == "UNITED STATES").CALCULATE(
    supplier_name=name
)


And if we want to compare, we can get the **equivalent SQL query** created by PyDough.

In [7]:
print(result.sql)

SELECT
  name AS supplier_name
FROM (
  SELECT
    s_name AS name,
    s_nationkey AS nation_key
  FROM main.SUPPLIER
)
INNER JOIN (
  SELECT
    key
  FROM (
    SELECT
      n_name AS name,
      n_nationkey AS key
    FROM main.NATION
  )
  WHERE
    name = 'UNITED STATES'
)
  ON nation_key = key


If we want to visually check, analyze, or edit the resulting **dataframe**, we can do that too. There is a dedicated section for this later in the document.

In [None]:
result.df

We can also get an **explanation** of how the model solved the query.

In the future, we will offer an English-only explanation to learn more about what the query is about, but for now, you can view a combined PyDough and English explanation generated by the LLM.


In [None]:
print(result.full_explanation)

We can also check the original natural language **question** that was asked.

In [None]:
print(result.original_question)

We also have a reference guide or **cheat_sheet** with example queries to help structure responses.

In [None]:
print(result.cheat_sheet)

Next, we try a new example, one that raises an **exception**.

---

## Query correction (early preview)

To try to correct a response with execution problems in PyDough, we can use the `correct()` method. 

In [None]:
result = client.ask("For each of the 5 largest part sizes, find the part of that size with the highest retail price")

print(result.full_explanation)

If accessing the dataframe raises an error, returns nothing, or yields an empty dataframe, there may have been a PyDough exception. We can check this by running:

In [None]:
print(result.exception)
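The checks described above can be bundled into a small helper. This is a minimal sketch, not part of `LLMClient`: `describe_outcome` is a hypothetical name, and it relies only on the `exception` and `df` attributes listed earlier. The `SimpleNamespace` objects are stand-ins for real results, with a plain list standing in for a dataframe (`len()` works on both).

```python
from types import SimpleNamespace


def describe_outcome(result):
    """Classify a result as 'exception', 'no dataframe', 'empty', or 'ok'."""
    if getattr(result, "exception", None):
        return "exception"
    df = getattr(result, "df", None)
    if df is None:
        return "no dataframe"
    if len(df) == 0:
        return "empty"
    return "ok"


# Stand-ins for real results; only `exception` and `df` are inspected.
failed = SimpleNamespace(exception="example error text", df=None)
succeeded = SimpleNamespace(exception=None, df=[{"supplier_name": "example"}])

print(describe_outcome(failed))
print(describe_outcome(succeeded))
```

Checking the outcome before calling `result.df.head()` avoids an `AttributeError` when `df` is `None`.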

You can try to fix the error using the `correct()` method. We assign the corrected result to a new variable.

In [None]:
corrected_result = client.correct(result)

To see how the model tried to solve the issue, print the full explanation of the `corrected_result`.

In [None]:
print(corrected_result.full_explanation)
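Since a single correction may not always succeed, one option is to wrap `ask()` and `correct()` in a bounded retry loop. The sketch below assumes only what is shown above: `ask(question)` and `correct(result)` both return an object with an `exception` attribute. `ask_with_correction` and `StubClient` are hypothetical names used for illustration; the stub stands in for `LLMClient` so the example runs without a model.

```python
from types import SimpleNamespace


def ask_with_correction(client, question, max_corrections=2):
    """Ask once, then call correct() while the result still has an exception."""
    result = client.ask(question)
    attempts = 0
    while result.exception is not None and attempts < max_corrections:
        result = client.correct(result)
        attempts += 1
    return result, attempts


class StubClient:
    """Hypothetical stand-in for LLMClient: fails first, succeeds after one correction."""

    def ask(self, question):
        return SimpleNamespace(original_question=question, exception="boom", df=None)

    def correct(self, result):
        return SimpleNamespace(
            original_question=result.original_question, exception=None, df=[]
        )


final, used = ask_with_correction(StubClient(), "example question")
print(final.exception, used)
```

Capping the number of corrections keeps the loop from running indefinitely when the model cannot repair the query.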

---
## Post-processing dataframe

Once the DataFrame is generated, it can be analysed or manipulated just like any other dataframe, using **any** Python package for analysis, visualization, or transformation. 

This is a planned future feature, and one of the core differentiators of PyDough!

As a first example, in this section we manipulate the resulting dataframe with pandas and matplotlib to create a histogram and a box plot of the order prices.

In [None]:
result = client.ask("Give me all the order prices, name the column total_price.")

result.df

In [None]:
import matplotlib.pyplot as plt

df = result.df

plt.figure(figsize=(10, 6))
plt.hist(df["total_price"], bins=20, edgecolor="black", alpha=0.7)

# Configure labels and title
plt.xlabel("Total Price")
plt.ylabel("Frequency")
plt.title("Histogram of Order Prices in TPCH")

# Show the plot
plt.show()

In [None]:
# Create a box plot for order prices
plt.figure(figsize=(8, 6))
plt.boxplot(df["total_price"], vert=False, patch_artist=True)

# Configure labels and title
plt.xlabel("Total Price")
plt.title("Box Plot of Order Prices in TPCH")

# Show the plot
plt.show()

---
## Query Quality and Verification

In the future we will be able to take the PyDough code, the user question, the knowledge graph and the generated data frame to perform semantic tests that will help assure the quality of the results. Some ideas that are under evaluation: 

1. John's graph reconstruction / deconstruction / verification ideas.
2. LLM-based rules that allow us to confirm certain properties of the DataFrame:
   a. If the user requires a top 5 result, then the resulting DataFrame must have 5 rows.
3. Program slicing ideas:
   a. If each cell can be traced to a transformation, look at the transformation and match it against the query (this is connected to John's ideas).
4. Ensemble ideas: many variations of prompts and LLMs could help us compare multiple results.
5. Talk about why Greg is flabbergasted.
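The "top 5" rule from the list above can be prototyped without an LLM in the loop. The sketch below swaps the LLM for a plain regular expression, so it only recognizes the literal phrasing "top N"; both function names are hypothetical.

```python
import re


def expected_row_count(question):
    """Return N if the question asks for a 'top N' result, else None."""
    match = re.search(r"\btop\s+(\d+)\b", question, flags=re.IGNORECASE)
    return int(match.group(1)) if match else None


def check_top_n(question, row_count):
    """Return True/False when the rule applies, None when there is no 'top N'."""
    expected = expected_row_count(question)
    if expected is None:
        return None
    return row_count == expected


print(check_top_n("Show the top 5 suppliers by revenue", 5))
print(check_top_n("Show the top 5 suppliers by revenue", 3))
print(check_top_n("List all suppliers", 10))
```

With a real result, `row_count` would be `len(result.df)`. A stricter or looser rule (for example, allowing fewer rows when the data has fewer than N matches) is a design choice left open here.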

Here we show a simple demonstration of taking the resulting PyDough code, and then generating structured data from it so that you can compare the variables it used, against the original variables in our Knowledge Graph.

In [None]:
import ast
import json
from itertools import combinations
import networkx as nx
import matplotlib.pyplot as plt

class DSLGraphBuilder:
    class VariableNameExtractor(ast.NodeVisitor):
        def __init__(self):
            self.names = set()
        def visit_Name(self, node):
            # Ignore uppercase identifiers (like COUNT or YEAR)
            if not node.id.isupper():
                self.names.add(node.id)
            self.generic_visit(node)
    
    @staticmethod
    def extract_calculate_rhs_names(code):
        """
        Parse a DSL snippet and extract variable names from the right-hand side 
        of CALCULATE keyword assignments.
        """
        tree = ast.parse(code)
        names = set()
        for node in ast.walk(tree):
            if (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr == "CALCULATE"
            ):
                for kw in node.keywords:
                    extractor = DSLGraphBuilder.VariableNameExtractor()
                    extractor.visit(kw.value)
                    names.update(extractor.names)
        return names
    
    @classmethod
    def build_graph_from_snippets(cls, snippets):
        """
        Build a co-occurrence graph where each variable (from CALCULATE calls) 
        is a node, and an edge connects any two variables that appear together in a snippet.
        """
        all_nodes = set()
        snippet_vars = []  # list of sets of variable names per snippet
        for code in snippets:
            vars_in_snippet = cls.extract_calculate_rhs_names(code)
            snippet_vars.append(vars_in_snippet)
            all_nodes.update(vars_in_snippet)
        
        nodes = [{"id": var, "label": var} for var in sorted(all_nodes)]
        edges = []
        added_edges = set()
        for vars_in_snippet in snippet_vars:
            for a, b in combinations(sorted(vars_in_snippet), 2):
                key = tuple(sorted((a, b)))
                if key not in added_edges:
                    added_edges.add(key)
                    edges.append({"source": key[0], "target": key[1]})
        return {"nodes": nodes, "edges": edges}
    
    @staticmethod
    def visualize_graph(graph):
        """
        Visualize the graph using NetworkX and Matplotlib.
        """
        G = nx.Graph()
        for node in graph["nodes"]:
            G.add_node(node["id"])
        for edge in graph["edges"]:
            G.add_edge(edge["source"], edge["target"])
        
        # Use spring layout for a visually appealing graph
        pos = nx.spring_layout(G)
        
        plt.figure(figsize=(8, 6))
        nx.draw_networkx_nodes(G, pos, node_size=500, node_color="lightblue", edgecolors="black")
        nx.draw_networkx_edges(G, pos)
        nx.draw_networkx_labels(G, pos, font_size=10)
        
        plt.title("Co-occurrence Graph of Variables", fontsize=14)
        plt.axis("off")
        plt.tight_layout()
        plt.show()

Now, this can be used to analyse the PyDough code that was first generated.

In [None]:
# Using the DSLGraphBuilder defined in a previous cell
builder = DSLGraphBuilder()

# Define your DSL code snippets to analyse

snippet1 = '''# Find all suppliers from the United States
us_suppliers = suppliers.WHERE(nation.name == "UNITED STATES").CALCULATE(
    supplier_name=name
)
'''
snippet2 = '''# Identify orders in 1995
orders_1995 = orders.WHERE(YEAR(order_date) == 1995)
european_customers = customers.CALCULATE(
    customer_name=name,
    account_balance=acctbal,
    num_orders_1995=COUNT(orders.WHERE(YEAR(order_date) == 1995))
).WHERE(
    (nation.region.name == "EUROPE") &
    (acctbal > 700) &
    (num_orders_1995 > 0)
)
'''

snippets = [snippet1, snippet2]

# Build the graph structure from your snippets
graph = builder.build_graph_from_snippets(snippets)

# Print out the graph JSON structure if desired
import json
print(json.dumps(graph, indent=4))

# Visualize the graph using matplotlib and networkx
builder.visualize_graph(graph)

---
That generates a graphical view of the variables and their co-occurrence, plus a JSON structure with the same information, which can be compared against the original JSON knowledge graph for verification.
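One simple form of that comparison is a set difference between the extracted variable names and the names the knowledge graph defines. The sketch below assumes the `{"nodes": [...], "edges": [...]}` shape produced by `build_graph_from_snippets`. The flat `known_names` set is a hypothetical simplification, since the real knowledge graph format is not shown here, and the misspelled `acct_balance` node is made up to trigger a mismatch.

```python
def unknown_variables(graph, known_names):
    """Return sorted node ids from the graph that are not in known_names."""
    return sorted(
        node["id"] for node in graph["nodes"] if node["id"] not in known_names
    )


# Example graph in the same shape as the builder's output (values are made up).
example_graph = {
    "nodes": [
        {"id": "acct_balance", "label": "acct_balance"},
        {"id": "name", "label": "name"},
        {"id": "orders", "label": "orders"},
    ],
    "edges": [{"source": "name", "target": "orders"}],
}

# Hypothetical flat set of names taken from a knowledge graph.
known_names = {"name", "acctbal", "orders", "order_date"}

print(unknown_variables(example_graph, known_names))
```

Any name reported here is a candidate for a hallucinated or misspelled field in the generated PyDough code.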

NB: This is only a first exploratory step; as we develop better ideas on verification, we will be able to add to and extend this segment, which is why it is an entire workstream in our roadmap.

Now, on to more test cases!

---

## Test Cases

Let's now look at a few more test cases and see how PyDough and PyDough LLM do.

### Customer Segmentation

### 1. Find the names of all customers and the number of orders placed in 1995 in Europe.

- Demonstrates simple filtering, counting, and sorting while being business-relevant for regional market analysis. 
- Adds a second filtering layer by including account balance and order activity, making it more dynamic.

In [8]:
query = "Find the names of all customers and the number of their orders placed in 1995 in Europe."

result = client.ask(query)

print(result.full_explanation)
result.df.head()

RAW RESPONSE TEXT:
 candidates=[Candidate(content=Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=None, inline_data=None, text='Okay, let\'s break down the process step-by-step to find the names of European customers and the count of their orders placed in 1995 using PyDough.\n\n1.  **Identify the Starting Collection:** We are interested in information about *customers*, so we\'ll start with the `customers` collection.\n\n2.  **Identify the Filtering Condition (Customer Location):** We only want customers from "EUROPE".\n    *   Customers are linked to `nations`.\n    *   Nations are linked to `regions`.\n    *   Regions have a `name`.\n    *   So, the condition is: `nation.region.name == "EUROPE"`.\n\n3.  **Identify the Required Output Fields:**\n    *   Customer Name: This is the `name` field of the `customers` collection.\n    *   Number of 1995 Orders: This requires calcul

AttributeError: 'NoneType' object has no attribute 'head'

In [9]:
print(result.exception)

Traceback (most recent call last):
  File "/home/adriel/TextoToPydough/text2pydough/workbench/AAraya/chattest/llm.py", line 203, in get_pydough_sql
    exec(transformed_source, {}, local_env)
  File "<string>", line 3, in <module>
  File "/home/adriel/anaconda3/envs/aisuite_deepseek/lib/python3.12/site-packages/pydough/unqualified/unqualified_node.py", line 138, in __call__
    raise PyDoughUnqualifiedException(
pydough.unqualified.errors.PyDoughUnqualifiedException: PyDough nodes orders.WHERE((orderdate.YEAR == 1995)).COUNT is not callable. Did you mean to use a function?

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/adriel/TextoToPydough/text2pydough/workbench/AAraya/chattest/llm.py", line 308, in ask
    pydough_sql = self.get_pydough_sql(extracted_code)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adriel/TextoToPydough/text2pydough/workbench/AAraya/chattest/llm.py", line 209, in get_pyd

**Follow up**: Now, give me the ones who have an account balance greater than $700 and placed at least one order in that same year. Sorted in descending order by the number of orders.

In [None]:
result2 = client.discourse(result, """Now, give me the ones who have an account balance greater than $700 and placed at least one order in that same year. 
Sorted in descending order by the number of orders.""")

print(result2.full_explanation)
result2.df.head()

### 2. List customers who ordered in 1996 but not in 1997, with a total spent of over $1000.

- Showcases PyDough’s HAS() and HASNOT() functions, helping analyze customer retention and spending patterns. 
- Also incorporates a time-based calculation.

In [None]:
query = "List customers who ordered in 1996 but not in 1997 with a total spent of over $1000."

result = client.ask(query)

print(result.full_explanation)
result.df.head()

**Follow up**: Now, include the number of months since the last order and sort by total spent, highest first.

In [None]:
result2 = client.discourse(result,
"Now, include the number of months since the last order and sort by total spent, highest first.")

print(result2.full_explanation)
result2.df.head()

### Sales Performance

### 3. Find the region with the highest total order value in 1996.

- The total order value is the potential revenue, defined as the sum of extended_price * (1 - discount).
- Introduces an explicit calculation within the query, so the revenue figure follows a stated definition.

In [None]:
client.add_definition("Revenue is defined as the sum of quantity * extended_price * (1 - discount) * (1 + tax).")

In [None]:
query = "Find the region with the highest revenue in 1996."

result = client.ask(query)

print(result.full_explanation)
result.df.head()

**Follow up**: Can you compare it now year over year **in that region**?

In [None]:
result2 = client.discourse(result, "Can you compare it now year over year in that region?")

print(result2.full_explanation)
result2.df.head()

### Product Trends

### 4. Which 10 customers purchased the highest quantity of products during 1998?

- Highlights ranking queries, customer segmentation, and purchasing trends. 

In [None]:
query = "Which 10 customers purchased the highest quantity of products during 1998?"

result = client.ask(query)

print(result.full_explanation)
result.df

**Follow up**: Now return the sum of only the products that have "green" in the product name.

In [None]:
result2 = client.discourse(result, """Now take into account only the products that have "green" in their name""")

print(result2.full_explanation)
result2.df