In [1]:
!pip install -q -e ../python

In [4]:
import json

from explainprompt.logger import Logger, MemoryLogHandler
from explainprompt.model import (
    Prompt, 
    Section, 
    ModelResponse)

from explainprompt.jupyter import ExplainPromptWidget

from langchain import PromptTemplate, OpenAI, LLMChain
from dotenv import load_dotenv

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
    Run,
    TracerSession,
)
from typing import Any, Dict, Optional, Union

from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Union
from uuid import UUID

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, TracerSession
from langchain.load.dump import dumpd
from langchain.schema.messages import BaseMessage

# Load variables from a local `.env` file into the process environment.
# Presumably this supplies the API key needed by `OpenAI()` further down — confirm.
load_dotenv()

True

In [53]:
class ExplainPromptTracer(BaseTracer):
    """LangChain tracer that records LLM prompts and responses for ExplainPrompt.

    Prompts sent to an LLM or chat model are logged as ``Prompt`` objects and
    generated completions as ``ModelResponse`` objects on ``self.logger``
    (a ``MemoryLogHandler``). Chain, tool and retriever events are ignored.
    """

    def __init__(
        self,
        example_id: Optional[Union[UUID, str]] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the tracer.

        Args:
            example_id: Optional example identifier; a string is parsed into a
                ``UUID``. Stored on the instance but not otherwise used here.
            tags: Optional tags stored on the instance (defaults to ``[]``).
            **kwargs: Forwarded unchanged to ``BaseTracer.__init__``.
        """
        super().__init__(**kwargs)
        self.session: Optional[TracerSession] = None
        self.example_id = (
            UUID(example_id) if isinstance(example_id, str) else example_id
        )
        self.tags = tags or []
        # Collects every Prompt / ModelResponse logged by the hooks below.
        self.logger = MemoryLogHandler()

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        tags: Optional[List[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Start a trace for a chat-model call.

        Builds an ``llm`` ``Run`` whose ``inputs["messages"]`` holds one list of
        ``dumpd``-serialized messages per prompt in the batch, registers it with
        ``_start_trace`` and hands it to ``_on_chat_model_start``.
        """
        parent_run_id_ = str(parent_run_id) if parent_run_id else None
        execution_order = self._get_execution_order(parent_run_id_)
        start_time = datetime.utcnow()
        if metadata:
            kwargs.update({"metadata": metadata})
        chat_model_run = Run(
            id=run_id,
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
            extra=kwargs,
            events=[{"name": "start", "time": start_time}],
            start_time=start_time,
            execution_order=execution_order,
            child_execution_order=execution_order,
            run_type="llm",
            tags=tags,
        )
        self._start_trace(chat_model_run)
        self._on_chat_model_start(chat_model_run)

    def _persist_run(self, run: Run) -> None:
        """No-op: runs are not persisted; the log on ``self.logger`` is the only output."""
        return

    def _on_llm_start(self, run: Run) -> None:
        """Log each prompt string of an LLM run as a ``Prompt``."""
        for prompt in run.inputs.get("prompts", []):
            self.logger.log_prompt(
                Prompt(label="Prompt", sections=[Section(content=prompt)])
            )

    def _on_chat_model_start(self, run: Run) -> None:
        """Log each batch of chat messages as one ``Prompt`` (one section per message).

        Without this, chat-model runs would have their responses logged by
        ``_on_llm_end`` but no matching prompt.
        """
        for batch in run.inputs.get("messages", []):
            sections = [
                Section(content=self._message_content(message)) for message in batch
            ]
            self.logger.log_prompt(Prompt(label="Prompt", sections=sections))

    @staticmethod
    def _message_content(message: Any) -> str:
        """Return the text content of one serialized chat message."""
        if isinstance(message, dict):
            # Assumes LangChain's dumpd() layout, i.e. {"lc": ..., "kwargs": {"content": ...}}
            # — confirm for the installed LangChain version.
            content = message.get("kwargs", {}).get("content", "")
        else:
            content = getattr(message, "content", message)
        return content if isinstance(content, str) else str(content)

    def _on_llm_end(self, run: Run) -> None:
        """Log every generated completion of an LLM run as a ``ModelResponse``.

        ``run.outputs["generations"]`` holds one list of candidate dicts per
        prompt, each with a ``"text"`` entry. All candidates are logged (not just
        the first), and an empty candidate list is skipped instead of raising.
        """
        outputs = run.outputs or {}
        for candidates in outputs.get("generations", []):
            for candidate in candidates:
                self.logger.log_response(
                    ModelResponse(sections=[Section(content=candidate["text"])])
                )

    def _on_llm_error(self, run: Run) -> None:
        """Ignored: LLM errors are not logged."""
        return

    def _on_chain_start(self, run: Run) -> None:
        """Ignored: chain runs are not logged."""
        return

    def _on_chain_end(self, run: Run) -> None:
        """Ignored: chain runs are not logged."""
        return

    def _on_chain_error(self, run: Run) -> None:
        """Ignored: chain errors are not logged."""
        return

    def _on_tool_start(self, run: Run) -> None:
        """Ignored: tool runs are not logged."""
        return

    def _on_tool_end(self, run: Run) -> None:
        """Ignored: tool runs are not logged."""
        return

    def _on_tool_error(self, run: Run) -> None:
        """Ignored: tool errors are not logged."""
        return

    def _on_retriever_start(self, run: Run) -> None:
        """Ignored: retriever runs are not logged."""
        return

    def _on_retriever_end(self, run: Run) -> None:
        """Ignored: retriever runs are not logged."""
        return

    def _on_retriever_error(self, run: Run) -> None:
        """Ignored: retriever errors are not logged."""
        return


In [54]:
# A fresh tracer each time this cell runs, so its log holds only this chain's events.
tracer = ExplainPromptTracer()

prompt_template = "What is a good name for a company that makes {product}?"

# temperature=0 keeps completions as repeatable as the API allows.
llm = OpenAI(temperature=0)
llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate.from_template(prompt_template),
)

llm_chain.run("colorful socks", callbacks=[tracer])

# NOTE(review): `_messages` is a private attribute of MemoryLogHandler;
# switch to a public accessor if the library provides one.
messages = [
    message.model_dump(exclude_none=True)
    for message in tracer.logger._messages
]

print(json.dumps(messages, indent=2))

[
  {
    "type": "prompt",
    "label": "Prompt",
    "sections": [
      {
        "content": "What is a good name for a company that makes colorful socks?"
      }
    ]
  },
  {
    "type": "response",
    "sections": [
      {
        "content": "\n\nSocktastic!"
      }
    ]
  }
]


In [50]:
# Render the logged prompt/response trajectory. Reads `messages`, so run this
# cell only after the tracer cell above has produced it.
widget_payload = {"trajectory": messages}

widget = ExplainPromptWidget()
widget.data = json.dumps(widget_payload)
widget.theme = 'dark'
widget

ExplainPromptWidget(data='{"trajectory": [{"type": "response", "sections": [{"content": "colorful socks"}]}, {…