In [0]:
%pip install langgraph-checkpoint-postgres langchain-core

In [0]:
# Restart the Python process right after the %pip install cell above.
# NOTE(review): presumably so the freshly installed packages are importable
# in the cells that follow - confirm against the Databricks library docs.
dbutils.library.restartPython()

In [0]:
# Databricks notebook source
import json
import os
import time
from typing import Any
from urllib.parse import quote, quote_plus

from langgraph.checkpoint.postgres import PostgresSaver
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType

# Lakebase is hosted in prod, so this host is the same for every env_name.
POSTGRES_HOST = (
    "instance-0b95e886-17ee-4296-9752-6cdf30c0739b.database.azuredatabricks.net"
)
SCOPE = "databricks-secrets"

# Tenant and environment this notebook run targets.
env_name = "test"
client = "edg"

# Default service credential name, derived from the environment and client.
os.environ.update(
    DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME=(
        f"{env_name}-connector-{client}-data-service"
    )
)


def create_pg_url(user: str, password: str, host: str, db: str) -> str:
    """Build a PostgreSQL connection URI on port 5432 with ``sslmode=require``.

    Args:
        user: Role name; percent-encoded before being embedded in the URI.
        password: Password; percent-encoded before being embedded in the URI.
        host: Server host name, inserted verbatim.
        db: Database name, inserted verbatim.

    Returns:
        ``postgresql://<user>:<password>@<host>:5432/<db>?sslmode=require``
    """
    # quote(..., safe="") instead of quote_plus(): quote_plus form-encodes a
    # space as "+", but a connection URI is only percent-decoded, so a space
    # in a credential would be sent as a literal "+". Every other character
    # is encoded exactly as quote_plus would encode it.
    enc_user = quote(user, safe="")
    enc_password = quote(password, safe="")
    return (
        f"postgresql://{enc_user}:{enc_password}@{host}:5432/{db}"
        "?sslmode=require"
    )


def _to_json(x: Any) -> str:
    return json.dumps(x, default=str, ensure_ascii=False)


def _split_ai_content(content):
    """
    Returns: (visible_text, thinking_text, has_tool_call_parts)
    content may be:
      - str
      - list of {"type": "...", ...}
    """
    if content is None:
        return ("", "", False)

    if isinstance(content, str):
        return (content, "", False)

    visible = []
    thinking = []
    has_tool_parts = False

    if isinstance(content, list):
        for part in content:
            if not isinstance(part, dict):
                continue
            ptype = part.get("type")
            if ptype == "text":
                visible.append(part.get("text", ""))
            elif ptype == "reasoning":
                thinking.append(part.get("text") or _to_json(part))
            elif ptype in ("function_call", "tool_call"):
                has_tool_parts = True

    return (
        "".join(visible).strip(),
        "\n".join([t for t in thinking if t]).strip(),
        has_tool_parts,
    )


def decode_checkpoint_tuple(tup):
    """Flatten one CheckpointTuple into a list of event-row dicts.

    One row is emitted per entry in the checkpoint's ``messages`` channel,
    plus one extra ``tool_call`` row for every entry in an AIMessage's
    ``tool_calls``.

    Args:
        tup: Object exposing ``checkpoint``, ``metadata`` and ``config``
            mappings; any of them may be None.

    Returns:
        list[dict]: rows in message order. ``seq`` is a 1-based counter over
        all rows of this checkpoint (tool_call rows included). Empty when the
        checkpoint has no messages.
    """
    cp = tup.checkpoint or {}
    meta = tup.metadata or {}
    cfg = (tup.config or {}).get("configurable", {}) or {}

    thread_id = cfg.get("thread_id")
    checkpoint_id = cp.get("id")
    checkpoint_ts = cp.get("ts")
    step = meta.get("step")
    thread_user = meta.get("username") or cfg.get("user_id") or None

    messages = (cp.get("channel_values") or {}).get("messages") or []

    events = []
    seq = 0

    for i, msg in enumerate(messages):
        seq += 1
        # Messages without an id get a synthetic one built from their position.
        msg_id = (
            getattr(msg, "id", None)
            or f"{thread_id}:{checkpoint_id}:{i}:{type(msg).__name__}"
        )

        # Per-message defaults; each branch below overrides what it knows.
        visible_text = ""
        thinking_text = ""
        tool_name = None
        tool_call_id = None
        tool_args_json = None
        parent_ai_message_id = None
        model = None
        input_tokens = output_tokens = total_tokens = None

        if isinstance(msg, HumanMessage):
            event_type = "human"
            role = "human"
            name = getattr(msg, "name", None) or thread_user
            # Content may be a list of parts rather than a str, so extract the
            # text through the shared splitter instead of calling .strip() on
            # the raw content (which raises AttributeError for a list).
            human_text, _, _ = _split_ai_content(getattr(msg, "content", None))
            visible_text = human_text.strip()

        elif isinstance(msg, AIMessage):
            event_type = "ai"
            role = "ai"
            name = getattr(msg, "name", None) or "assistant"

            rm = getattr(msg, "response_metadata", {}) or {}
            um = getattr(msg, "usage_metadata", {}) or {}
            model = rm.get("model_name") or rm.get("model")

            input_tokens = um.get("input_tokens")
            output_tokens = um.get("output_tokens")
            total_tokens = um.get("total_tokens")

            visible_text, thinking_text, _ = _split_ai_content(
                getattr(msg, "content", None)
            )

        elif isinstance(msg, ToolMessage):
            event_type = "tool_result"
            role = "tool"
            name = getattr(msg, "name", None)
            tool_name = getattr(msg, "name", None)
            tool_call_id = getattr(msg, "tool_call_id", None)

        else:
            event_type = "other"
            role = "other"
            name = getattr(msg, "name", None)

        events.append(
            {
                "thread_id": thread_id,
                "thread_user": thread_user,
                "checkpoint_id": checkpoint_id,
                "checkpoint_ts": checkpoint_ts,
                "step": step,
                "seq": seq,
                "event_type": event_type,
                "role": role,
                "message_type": type(msg).__name__,
                "message_id": msg_id,
                "name": name,
                "visible_text": visible_text,
                "thinking_text": thinking_text,
                "tool_name": tool_name,
                "tool_call_id": tool_call_id,
                "tool_args_json": tool_args_json,
                "model": model,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": total_tokens,
                "content_json": _to_json(getattr(msg, "content", None)),
                "parent_ai_message_id": parent_ai_message_id,
            }
        )

        # Add tool_call rows from AIMessage.tool_calls
        if isinstance(msg, AIMessage):
            for tc in getattr(msg, "tool_calls", None) or []:
                seq += 1
                events.append(
                    {
                        "thread_id": thread_id,
                        "thread_user": thread_user,
                        "checkpoint_id": checkpoint_id,
                        "checkpoint_ts": checkpoint_ts,
                        "step": step,
                        "seq": seq,
                        "event_type": "tool_call",
                        "role": "tool",
                        "message_type": "ToolCall",
                        "message_id": tc.get("id")
                        or f"{thread_id}:{msg_id}:toolcall:{seq}",
                        "name": tc.get("name"),
                        "visible_text": "",
                        "thinking_text": "",
                        "tool_name": tc.get("name"),
                        "tool_call_id": tc.get("id"),
                        "tool_args_json": _to_json(tc.get("args")),
                        # NOTE(review): model and token counts are copied from
                        # the parent AI message, so summing tokens over all
                        # rows counts them once per tool call as well.
                        "model": model,
                        "input_tokens": input_tokens,
                        "output_tokens": output_tokens,
                        "total_tokens": total_tokens,
                        "content_json": "null",
                        "parent_ai_message_id": msg_id,
                    }
                )

    return events


def iter_latest_checkpoints_per_thread(saver, limit=None):
    """Yield the first checkpoint tuple seen for each distinct thread_id.

    ``limit`` is forwarded unchanged to ``saver.list``, so it caps the number
    of checkpoints scanned, not the number of threads yielded. Tuples with no
    (or an empty) thread_id are skipped.

    NOTE(review): "latest" assumes ``saver.list`` returns newest checkpoints
    first - confirm against the saver implementation.
    """
    yielded_threads = set()
    for checkpoint_tuple in saver.list(None, limit=limit):
        configurable = (checkpoint_tuple.config or {}).get("configurable", {}) or {}
        thread_id = configurable.get("thread_id")
        if thread_id and thread_id not in yielded_threads:
            yielded_threads.add(thread_id)
            yield checkpoint_tuple


def run(
    client: str,
    env_name: str,
    spark: SparkSession | None = None,
    *,
    limit: int | None = None,
    dest_table: str = "alex_feng.goodwork_test.dest_table",
    postgres_user: str = "alex.feng@databricks.com",
    postgres_db: str = "databricks_postgres",
    secret_scope: str = "alex-feng",
    secret_key: str = "postgres-test",
) -> str:
    """
    Extract the latest checkpoint per thread, flatten each into event rows,
    and overwrite ``dest_table`` (a Delta table) with the result.

    Args:
        client: Client identifier. Currently only echoed in the log output;
            it does not influence the source or the destination.
        env_name: Environment name. Currently only echoed in the log output.
        spark: Session to use; when None the active session is used, or a
            new one is created.
        limit: Forwarded to the checkpoint listing (caps checkpoints scanned).
        dest_table: Fully-qualified destination table name.
        postgres_user: Role used to connect to POSTGRES_HOST.
        postgres_db: Database to connect to.
        secret_scope: Secret scope holding the Postgres password.
        secret_key: Key of the Postgres password inside ``secret_scope``.

    Returns:
        The fully-qualified destination table name (``dest_table``).
    """
    # if not passed, get/create a session (works in jobs + local)
    if spark is None:
        spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()

    # `dbutils` is not imported in this file; it is expected to be provided
    # by the notebook runtime.
    postgres_pwd = dbutils.secrets.get(scope=secret_scope, key=secret_key)

    postgres_url = create_pg_url(
        user=postgres_user,
        password=postgres_pwd,
        host=POSTGRES_HOST,
        db=postgres_db,
    )

    print("Starting New Run")
    print("env_name", env_name, "client", client)
    print("postgres host", POSTGRES_HOST)
    print("postgres db", postgres_db)
    print("postgres user", postgres_user)
    print("limit", limit)

    events = []
    with PostgresSaver.from_conn_string(postgres_url) as saver:
        # NOTE(review): setup() runs on every export even though this job
        # only lists checkpoints - confirm it is required here.
        saver.setup()

        for tup in iter_latest_checkpoints_per_thread(saver, limit=limit):
            print("*** processing new conversation thread ***")
            events.extend(decode_checkpoint_tuple(tup))

    print("Completed PostgresSaver")

    # Single source of truth for column names, order and types.
    schema = StructType(
        [
            StructField("thread_id", StringType(), True),
            StructField("thread_user", StringType(), True),
            StructField("checkpoint_id", StringType(), True),
            StructField("checkpoint_ts", StringType(), True),
            StructField("step", IntegerType(), True),
            StructField("seq", IntegerType(), True),
            StructField("event_type", StringType(), True),
            StructField("role", StringType(), True),
            StructField("message_type", StringType(), True),
            StructField("message_id", StringType(), True),
            StructField("name", StringType(), True),
            StructField("visible_text", StringType(), True),
            StructField("thinking_text", StringType(), True),
            StructField("tool_name", StringType(), True),
            StructField("tool_call_id", StringType(), True),
            StructField("tool_args_json", StringType(), True),
            StructField("model", StringType(), True),
            StructField("input_tokens", LongType(), True),
            StructField("output_tokens", LongType(), True),
            StructField("total_tokens", LongType(), True),
            StructField("content_json", StringType(), True),
            StructField("parent_ai_message_id", StringType(), True),
        ]
    )

    # The explicit schema keeps this working when `events` is empty.
    sdf = spark.createDataFrame(events, schema=schema).select(*schema.fieldNames())

    print("Writing Table Now")
    (
        sdf.write.format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .saveAsTable(dest_table)
    )

    return dest_table


# COMMAND ----------
if __name__ == "__main__":
    # Notebook entry point: announce the target, then run the export using
    # the module-level client/env_name and the `spark` name already in scope.
    print(f"Starting Environment {env_name} (client={client})")
    run(env_name=env_name, client=client, spark=spark)
 

In [0]:
# Diagnostic scan: walk every checkpoint and report how many were scanned and
# how many distinct thread_ids they belong to, with elapsed time.

# `postgres_url` is only a local variable inside run(), so it does not exist
# at notebook scope on a fresh kernel. Build the connection string here with
# the same settings run() uses by default.
postgres_url = create_pg_url(
    user="alex.feng@databricks.com",
    password=dbutils.secrets.get(scope="alex-feng", key="postgres-test"),
    host=POSTGRES_HOST,
    db="databricks_postgres",
)

t0 = time.time()
scanned = 0
unique = 0
seen = set()


with PostgresSaver.from_conn_string(postgres_url) as saver:
    saver.setup()

    for tup in saver.list(None, limit=None):
        scanned += 1
        cfg = (tup.config or {}).get("configurable", {}) or {}
        tid = cfg.get("thread_id")

        if tid and tid not in seen:
            seen.add(tid)
            unique += 1

        # Progress line every 5000 checkpoints so long scans show activity.
        if scanned % 5000 == 0:
            print(
                "scanned",
                scanned,
                "unique_threads",
                unique,
                "elapsed_s",
                round(time.time() - t0, 1),
            )

print(
    "FINAL scanned",
    scanned,
    "unique_threads",
    unique,
    "elapsed_s",
    round(time.time() - t0, 1),
)
 