In [23]:

import asyncio
from dataclasses import asdict
from dbtsl.asyncio import AsyncSemanticLayerClient
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models import ChatOpenAI
from langchain.schema import Document
from langchain_groq import ChatGroq
from pydantic import BaseModel, Field
from typing import List, Optional
import json
import os
import nest_asyncio
nest_asyncio.apply() 

In [47]:
class RetrievedMetric(BaseModel):
    """A metric retrieved from the semantic layer"""
    # Identifier of the metric; the retriever fills it from the "id" metadata
    # of a stored metric document (which is set to the metric's name).
    id: str
    # Name of the metric as reported by the semantic layer.
    name: str
    # Human-readable description of the metric.
    description: str
    # Metric type as a plain string (the example output in this notebook
    # shows values such as 'SIMPLE', 'CUMULATIVE' and 'DERIVED').
    type: str

class RetrievedDimension(BaseModel):
    """A dimension retrieved from the semantic layer"""
    # Identifier of the (metric, dimension) pair; the retriever stores it as
    # "<metric name>_<dimension name>".
    id: str
    # Name of the dimension as reported by the semantic layer.
    name: str
    # Human-readable description of the dimension (empty string when missing).
    description: str
    # Name of the metric this dimension entry was indexed under.
    metric_id: str

class RetrievalResult(BaseModel):
    """Results from the retriever to be passed to query construction"""
    # Metrics returned by the similarity search over the metric store.
    metrics: List[RetrievedMetric] = Field(
        description="Metrics found relevant to the query"
    )
    # Dimensions returned by the similarity search over the dimension store;
    # the retriever restricts these to dimensions of the metrics above.
    dimensions: List[RetrievedDimension] = Field(
        description="Dimensions found relevant to the query"
    )

class SemanticLayerRetriever:
    """Find semantic-layer metrics and dimensions relevant to a text query.

    Metric and dimension definitions are fetched through
    ``AsyncSemanticLayerClient`` and indexed in two persistent Chroma
    collections ("metrics" and "dimensions"). ``retrieve`` then runs a
    similarity search over both collections and returns the matches as a
    ``RetrievalResult``.
    """

    # Collection names; each collection persists under
    # "<persist_directory>/<collection name>".
    _METRIC_COLLECTION = "metrics"
    _DIMENSION_COLLECTION = "dimensions"

    def __init__(
        self,
        environment_id: int,
        auth_token: str,
        host: str = "semantic-layer.cloud.getdbt.com",
        persist_directory: str = "./chroma_db",
        llm: Optional[BaseChatModel] = None,
        embeddings: Optional[OpenAIEmbeddings] = None
    ):
        """Create the semantic-layer client and open both vector stores.

        Args:
            environment_id: Environment id passed to the semantic-layer client.
            auth_token: Token passed to the semantic-layer client.
            host: Semantic-layer host name.
            persist_directory: Root directory for the Chroma collections.
            llm: Chat model kept on ``self.llm``; defaults to
                ``ChatOpenAI(temperature=0)``.
            embeddings: Embedding function used by both stores; defaults to
                ``OpenAIEmbeddings()``.
        """
        self.client = AsyncSemanticLayerClient(
            environment_id=environment_id,
            auth_token=auth_token,
            host=host
        )
        self.persist_directory = persist_directory
        self.embeddings = embeddings or OpenAIEmbeddings()
        self.llm = llm or ChatOpenAI(temperature=0)

        # Initialize stores
        self.metric_store = self._open_store(self._METRIC_COLLECTION)
        self.dimension_store = self._open_store(self._DIMENSION_COLLECTION)

    def _open_store(self, collection_name: str) -> Chroma:
        """Open (or create) the persistent Chroma collection with this name."""
        return Chroma(
            collection_name=collection_name,
            embedding_function=self.embeddings,
            persist_directory=f"{self.persist_directory}/{collection_name}"
        )

    @staticmethod
    def _metric_document(metric) -> Document:
        """Build the indexed document for one metric."""
        # Chroma metadata values must be str/int/float/bool, so a missing
        # description becomes "" and the enum type becomes its string value.
        description = metric.description or ""
        metric_type = getattr(metric.type, "value", metric.type)
        return Document(
            page_content=f"{metric.name}: {description}",
            metadata={
                "id": metric.name,
                "name": metric.name,
                "description": description,
                "type": str(metric_type),
            }
        )

    @staticmethod
    def _dimension_document(metric, dim) -> Document:
        """Build the indexed document for one dimension of ``metric``."""
        # Normalise once so the embedded text never contains the word "None".
        description = dim.description or ""
        return Document(
            page_content=f"{dim.name}: {description}",
            metadata={
                "id": f"{metric.name}_{dim.name}",
                "name": dim.name,
                "description": description,
                "metric_id": metric.name,
            }
        )

    async def refresh_stores(self):
        """Fetch latest metrics and rebuild stores"""
        metric_docs = []
        dimension_docs = []

        # Build every document before touching the stores: if fetching or
        # conversion fails, the existing collections are left intact.
        async with self.client.session():
            metrics = await self.client.metrics()
            for metric in metrics:
                metric_docs.append(self._metric_document(metric))
                dimension_docs.extend(
                    self._dimension_document(metric, dim)
                    for dim in metric.dimensions
                )

        # Clear existing stores
        self.metric_store.delete_collection()
        self.dimension_store.delete_collection()

        # Recreate collections with new documents
        self.metric_store = self._open_store(self._METRIC_COLLECTION)
        self.dimension_store = self._open_store(self._DIMENSION_COLLECTION)

        if metric_docs:
            self.metric_store.add_documents(metric_docs)
        if dimension_docs:
            self.dimension_store.add_documents(dimension_docs)

    async def retrieve(
        self, 
        query: str, 
        k_metrics: int = 3,
        k_dimensions: int = 8
    ) -> RetrievalResult:
        """
        Retrieve relevant metrics and dimensions from the semantic layer.
        Returns structured data that can be used by a query construction prompt.

        Args:
            query: Natural-language question to search with.
            k_metrics: Maximum number of metrics to return.
            k_dimensions: Maximum number of dimensions to return.

        Returns:
            A ``RetrievalResult``; both lists are empty when no metric matches.
        """
        # Get relevant metrics
        metric_docs = self.metric_store.similarity_search(query, k=k_metrics)
        if not metric_docs:
            # An empty "$in" list is not a valid Chroma filter, and without
            # metrics there are no dimensions to offer anyway.
            return RetrievalResult(metrics=[], dimensions=[])

        metric_ids = [doc.metadata["id"] for doc in metric_docs]

        # Get relevant dimensions, filtered by retrieved metrics
        dimension_docs = self.dimension_store.similarity_search(
            query,
            k=k_dimensions,
            filter={"metric_id": {"$in": metric_ids}}
        )

        return RetrievalResult(
            metrics=[
                RetrievedMetric(
                    id=doc.metadata["id"],
                    name=doc.metadata["name"],
                    description=doc.metadata.get("description") or "",
                    type=doc.metadata["type"],
                ) for doc in metric_docs
            ],
            dimensions=[
                RetrievedDimension(
                    id=doc.metadata["id"],
                    name=doc.metadata["name"],
                    description=doc.metadata.get("description") or "",
                    metric_id=doc.metadata["metric_id"],
                ) for doc in dimension_docs
            ],
        )


In [54]:
# Initial setup and refresh (run this when you start or when metrics change)
async def initialize_stores(environment_id: int = 218762):
    """Create a ``SemanticLayerRetriever`` and rebuild its vector stores.

    Args:
        environment_id: Environment id passed to the retriever. Defaults to
            the environment this notebook was written against.

    Returns:
        The retriever, with its stores freshly rebuilt.

    Raises:
        RuntimeError: If the ``DBT_CLOUD_SERVICE_TOKEN`` environment variable
            is unset or empty.
    """
    # Fail fast with a clear message instead of handing None to the client.
    auth_token = os.getenv("DBT_CLOUD_SERVICE_TOKEN")
    if not auth_token:
        raise RuntimeError(
            "DBT_CLOUD_SERVICE_TOKEN is not set; export a dbt Cloud service "
            "token before initializing the retriever."
        )

    retriever = SemanticLayerRetriever(
        environment_id=environment_id,
        auth_token=auth_token,
        llm=ChatGroq(model_name="llama-3.3-70b-specdec", temperature=0)
    )
    await retriever.refresh_stores()
    return retriever

# Per-query entry point: call once for every question you want to look up.
async def get_semantic_layer_results(retriever: SemanticLayerRetriever, query: str):
    """Run ``retriever.retrieve`` for ``query`` and return its result.

    The retriever's default limits for metrics and dimensions are used.
    """
    retrieval = await retriever.retrieve(query)
    return retrieval

In [55]:
# Build the retriever and rebuild both vector stores. Top-level `await`
# works here because the notebook kernel runs cells inside an event loop.
retriever = await initialize_stores()



In [57]:
# Example lookup; the cell output is the resulting RetrievalResult.
# NOTE(review): in the recorded output the same dimension appears once per
# retrieved metric, because dimensions are indexed per (metric, dimension).
await get_semantic_layer_results(retriever, "What is total ARR by sales director in 2023?")

RetrievalResult(metrics=[RetrievedMetric(id='total_revenue', name='total_revenue', description='The total revenue for the business.', type='SIMPLE'), RetrievedMetric(id='cumulative_revenue_total', name='cumulative_revenue_total', description='The cumulative revenue for the business', type='CUMULATIVE'), RetrievedMetric(id='total_profit', name='total_profit', description='The total profit for the business', type='DERIVED')], dimensions=[RetrievedDimension(id='cumulative_revenue_total_customer__customer_balance_segment', name='customer__customer_balance_segment', description='Bucketing customers by their account balance', metric_id='cumulative_revenue_total'), RetrievedDimension(id='total_profit_customer__customer_balance_segment', name='customer__customer_balance_segment', description='Bucketing customers by their account balance', metric_id='total_profit'), RetrievedDimension(id='total_revenue_customer__customer_balance_segment', name='customer__customer_balance_segment', description='

In [60]:
# Fetch the raw metric definitions straight from the retriever's client so
# they can be inspected below; binds the notebook-level name `metrics`.
async with retriever.client.session():
    metrics = await retriever.client.metrics()

In [63]:
# Show the first fetched metric, including its nested dimensions.
metrics[0]

Metric(name='total_revenue', description='The total revenue for the business.', type=<MetricType.SIMPLE: 'SIMPLE'>, dimensions=[Dimension(name='customer__customer_balance_segment', qualified_name='customer__customer_balance_segment', description='Bucketing customers by their account balance', type=<DimensionType.CATEGORICAL: 'CATEGORICAL'>, label=None, is_partition=False, expr="case\n  when account_balance < 0 then 'Bad Debt'\n  when account_balance < 2500 then 'Low'\n  when account_balance < 7500 then 'Medium'\n  else 'High'\nend", queryable_granularities=[]), Dimension(name='customer__customer_id', qualified_name='customer__customer_id', description='The ID of the customer', type=<DimensionType.CATEGORICAL: 'CATEGORICAL'>, label=None, is_partition=False, expr='customer_key', queryable_granularities=[]), Dimension(name='customer__customer_market_segment', qualified_name='customer__customer_market_segment', description='The market segment the customer belongs to', type=<DimensionType.C