In [1]:
import os
import dspy
import json
from typing import Literal
import uuid
import datetime
from typing import Optional
from langsmith import Client
from dotenv import load_dotenv
from app.services.dspy_optimizer import DspyOptimizer
from app.websocket.manager import connection_manager
from app.db.database import get_db, init_db
from app.models.item import ItemInDB
from app.models.node import NodeInDB
from bson import ObjectId

load_dotenv()

True

In [2]:
class LangSmithCallback:
    """Mirror DSPy module and LM calls as runs in a LangSmith project.

    Every ``on_*_start`` hook creates a run through ``client.create_run`` and
    the matching ``on_*_end`` hook closes it through ``client.update_run``.
    Module runs are nested via ``run_stack``; LM runs are tracked per
    ``call_id`` so that overlapping LM calls each close their own run.

    NOTE(review): the hooks read the keyword arguments ``instance``,
    ``inputs``, ``outputs`` and ``exception`` (plus a few legacy aliases).
    Confirm these names against the callback interface of the installed
    DSPy version.
    """

    def __init__(
        self, project_name: str = "dspy-tracing", client: Optional["Client"] = None
    ):
        """
        Initialize the LangSmith callback

        Args:
            project_name: Name of the LangSmith project
            client: Optional LangSmith client instance; a default ``Client()``
                is created when omitted
        """
        self.client = client if client is not None else Client()
        self.project_name = project_name
        self.run_stack = []  # Stack of open module runs (innermost last)
        self.current_module_run = None  # Innermost open module run, if any
        self.current_lm_run = None  # Most recently started, still-open LM run
        self._lm_runs = {}  # call_id -> run info for every open LM run

    @staticmethod
    def _utcnow():
        """Return the current timezone-aware UTC timestamp."""
        return datetime.datetime.now(datetime.timezone.utc)

    @classmethod
    def _jsonable(cls, value):
        """Reduce ``value`` to JSON-friendly primitives.

        Dicts, lists and tuples are converted recursively; any other
        non-primitive object is replaced by ``str(value)`` so that live
        objects are never handed to the LangSmith client.
        """
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        if isinstance(value, dict):
            return {str(key): cls._jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._jsonable(item) for item in value]
        return str(value)

    def on_module_start(self, call_id, **kwargs):
        """Open a "chain" run for a DSPy module invocation.

        Args:
            call_id: Identifier of the invocation (not used for module runs).
            **kwargs: May contain ``module`` or ``instance`` (the module
                object) and ``inputs``; any other entries are recorded
                alongside the inputs.
        """
        module = kwargs.get("module")
        if module is None:
            module = kwargs.get("instance")
        inputs = kwargs.get("inputs", {})

        run_id = str(uuid.uuid4())
        start_time = self._utcnow()
        module_name = type(module).__name__ if module is not None else "Unknown"

        # The parent is whatever module run is open *before* this one is pushed.
        parent_run_id = self.run_stack[-1]["id"] if self.run_stack else None

        run_info = {
            "id": run_id,
            "start_time": start_time,
            "type": "module",
            "module_name": module_name,
        }
        self.run_stack.append(run_info)
        self.current_module_run = run_info

        # Only the stringified inputs plus sanitized extras are sent: the raw
        # module object and the raw inputs must not overwrite the "inputs" key.
        run_inputs = {"inputs": str(inputs)}
        for key, value in kwargs.items():
            if key not in ("module", "instance", "inputs"):
                run_inputs[key] = self._jsonable(value)

        self.client.create_run(
            id=run_id,
            project_name=self.project_name,
            name=f"dspy_{module_name}",
            run_type="chain",
            inputs=run_inputs,
            start_time=start_time,
            parent_run_id=parent_run_id,
        )

    def on_module_end(self, call_id, **kwargs):
        """Close the innermost open module run.

        Args:
            call_id: Identifier of the invocation (not used for module runs).
            **kwargs: May contain ``outputs`` (or ``results``) and
                ``exception``.
        """
        if not self.run_stack:
            return

        run_info = self.run_stack.pop()
        end_time = self._utcnow()
        outputs = kwargs.get("outputs", kwargs.get("results", {}))
        exception = kwargs.get("exception")

        output_data = {
            "outputs": str(outputs) if outputs else None,
        }
        if exception:
            output_data["error"] = str(exception)

        self.client.update_run(
            run_id=run_info["id"],
            outputs=output_data,
            end_time=end_time,
            error=str(exception) if exception else None,
        )

        # The enclosing module (if any) becomes the current one again.
        self.current_module_run = self.run_stack[-1] if self.run_stack else None

    def on_lm_start(self, call_id, **kwargs):
        """Open an "llm" run for a language-model call.

        The model name is taken from ``model_name``/``model`` when given,
        otherwise from the ``model`` attribute of ``instance``. The prompt is
        taken from ``prompt``/``messages`` when given, otherwise from the
        same keys of the ``inputs`` dict.

        Args:
            call_id: Identifier used to pair this call with ``on_lm_end``.
            **kwargs: Call details; everything except the raw ``instance``
                is recorded in the run inputs.
        """
        instance = kwargs.get("instance")
        payload = kwargs.get("inputs")
        if not isinstance(payload, dict):
            payload = {}

        model_name = kwargs.get("model_name", kwargs.get("model"))
        if model_name is None:
            model_name = getattr(instance, "model", None) or "unknown"

        prompt = kwargs.get("prompt", kwargs.get("messages"))
        if prompt is None:
            prompt = payload.get("prompt")
        if prompt is None:
            prompt = payload.get("messages")
        if prompt is None:
            prompt = ""

        temperature = kwargs.get("temperature")
        max_tokens = kwargs.get("max_tokens")

        run_id = str(uuid.uuid4())
        start_time = self._utcnow()

        run_info = {
            "id": run_id,
            "start_time": start_time,
            "type": "lm",
            "model": model_name,
        }
        # Keyed by call_id so overlapping LM calls do not clobber each other.
        self._lm_runs[call_id] = run_info
        self.current_lm_run = run_info

        inputs = {
            "prompt": str(prompt),
            "model": model_name,
        }
        if temperature is not None:
            inputs["temperature"] = temperature
        if max_tokens is not None:
            inputs["max_tokens"] = max_tokens

        # Record remaining kwargs, sanitized; the live LM object is dropped.
        handled = ("call_id", "model_name", "prompt", "temperature", "max_tokens", "instance")
        for key, value in kwargs.items():
            if key not in handled:
                inputs[key] = self._jsonable(value)

        self.client.create_run(
            id=run_id,
            project_name=self.project_name,
            name="dspy_lm_call",
            run_type="llm",
            inputs=inputs,
            start_time=start_time,
            parent_run_id=self.current_module_run["id"] if self.current_module_run else None,
        )

    def on_lm_end(self, call_id, **kwargs):
        """Close the LM run that was opened for ``call_id``.

        Does nothing when no run is open for ``call_id``. A supplied
        ``exception`` is reported as the run's error.

        Args:
            call_id: Identifier passed to the matching ``on_lm_start``.
            **kwargs: May contain ``response``/``outputs``/``completions``,
                ``usage`` and ``exception``; other entries are recorded in
                the run outputs.
        """
        run_info = self._lm_runs.pop(call_id, None)
        if run_info is None:
            return

        end_time = self._utcnow()

        response = kwargs.get("response", kwargs.get("outputs", kwargs.get("completions", "")))
        usage = kwargs.get("usage")
        exception = kwargs.get("exception")

        outputs = {
            "response": str(response) if response else None,
        }
        if usage:
            outputs["usage"] = self._jsonable(usage)
        if exception is not None:
            outputs["error"] = str(exception)

        handled = ("call_id", "response", "outputs", "completions", "usage", "exception")
        for key, value in kwargs.items():
            if key not in handled:
                outputs[key] = self._jsonable(value)

        self.client.update_run(
            run_id=run_info["id"],
            outputs=outputs,
            end_time=end_time,
            error=str(exception) if exception is not None else None,
        )

        # Only clear the convenience pointer if it still refers to this run.
        if self.current_lm_run is run_info:
            self.current_lm_run = None

    def on_adapter_format_start(self, call_id, **kwargs):
        """Called when adapter formatting starts; intentionally a no-op."""
        pass

    def on_adapter_format_end(self, call_id, **kwargs):
        """Called when adapter formatting ends; intentionally a no-op."""
        pass

    def on_adapter_parse_start(self, call_id, **kwargs):
        """Called when adapter parsing starts; intentionally a no-op."""
        pass

    def on_adapter_parse_end(self, call_id, **kwargs):
        """Called when adapter parsing ends; intentionally a no-op."""
        pass


In [None]:
await init_db()
db = get_db()

request = {
    "taxonomy_id": "68913c1c1c58af74cb5d53b8",  # customer feedback
    "node_id": "68913c501c58af74cb5d53b9",  # Product Quality Issues
    "current_user": "688a2f73d286129d6167cc93", # bonsense
}

nodes_collection = db[f"nodes_{request['taxonomy_id']}"]
node = await nodes_collection.find_one({"_id": ObjectId(request["node_id"])})
if not node:
    raise ValueError("Node not found")

node = NodeInDB(**node)

items = node.items
if items is None:
    raise ValueError("Node has no items")

item_ids_to_optimize = [
    ObjectId(item.item_id) for item in items if item.is_verified
]
print(f"item_ids_to_optimize: {len(item_ids_to_optimize)}")
if len(item_ids_to_optimize) < 1:
    raise ValueError(
        f"Need at least 10 verified items and found {len(item_ids_to_optimize)}"
    )

user_items_collection = db[f"items_{str(request['current_user'])}"]
items_to_optimize = await user_items_collection.find(
    {"_id": {"$in": item_ids_to_optimize}}
).to_list(length=None)
items_to_optimize = [ItemInDB(**item) for item in items_to_optimize]

trainset = []
for item in items_to_optimize:
    trainset.append(
        dspy.Example(review=item.content, category=node.label).with_inputs("review")
    )

sibling_nodes = await nodes_collection.find(
    {"parent_node_id": node.parent_node_id}
).to_list(length=None)
sibling_nodes = [NodeInDB(**sibling_node) for sibling_node in sibling_nodes]
categories_labels = [sibling_node.label for sibling_node in sibling_nodes] + [
    node.label
]
print(f"{len(categories_labels)} categories_labels: {categories_labels}")

Database indexes created
Connected to MongoDB
item_ids_to_optimize: 7
trainset: 7


In [None]:
# Language model handed to the optimizer.
optimizer_lm = dspy.LM("openai/gpt-4.1-nano-2025-04-14")
# Callback instance configured for the "dspy-test" LangSmith project.
tracing_callbacks = [LangSmithCallback(project_name="dspy-test")]

# Assemble the optimizer from the categories and training set built above.
dspy_optimizer = DspyOptimizer(
    lm=optimizer_lm,
    connection_manager=connection_manager,
    user_id=str(request["current_user"]),
    categories=categories_labels,
    trainset=trainset,
    callbacks=tracing_callbacks,
)


In [None]:
# Run the optimizer and keep the value it returns; the predict cell passes it
# on as `compiled_module_id`.
module_id = await dspy_optimizer.compile()
print(module_id)

In [None]:
# Classify a sample review using the module identified by `module_id`
# (assigned by the compile cell). The call is the last expression of the cell
# so the notebook displays its result.
await dspy_optimizer.predict(
    review="I am a happy person",
    compiled_module_id=module_id,
)

Prediction(
    category='Service & Delivery Experience'
)