In [1]:
# from dotenv import load_dotenv, find_dotenv
# load_dotenv(find_dotenv())
import os

# Use the local file-based tracking store at ../mlruns (relative to the working directory).
# Alternative SQLite backend: "sqlite:///../mlflow.db"
os.environ.update(MLFLOW_TRACKING_URI="file:../mlruns")

In [2]:
from __future__ import annotations

from typing import Iterable, Optional, Dict, Any, Tuple
import time
import pandas as pd
import mlflow


def load_user_traces_df(
    user_id: str,
    *,
    start_ms: Optional[int] = None,
    end_ms: Optional[int] = None,
    order_by: Optional[list[str]] = None,
    extract_fields: Optional[list[str]] = None,
) -> pd.DataFrame:
    """
    Query MLflow traces for a given user and return a DataFrame (one row per trace).

    A ``session_id`` column is appended. It holds the value of the
    ``mlflow.trace.session`` key, taken from the ``tags`` column first and then,
    if absent there, from whichever metadata column this MLflow version produced
    (``trace_metadata`` or ``request_metadata``). It is ``None`` when no session
    is recorded.

    Parameters
    ----------
    user_id : str
        The value of tag `mlflow.trace.user` to filter.
    start_ms : Optional[int]
        Filter lower bound for `timestamp_ms` (inclusive).
    end_ms : Optional[int]
        Filter upper bound for `timestamp_ms` (inclusive).
    order_by : Optional[list[str]]
        Ordering keys, e.g., ["timestamp_ms ASC"]. Defaults to oldest first.
    extract_fields : Optional[list[str]]
        Optional span fields to expand into columns
        (works with return_type='pandas' only).

    Returns
    -------
    pd.DataFrame
        The search result plus a ``session_id`` column. Other columns differ by MLflow version:
        - MLflow 2.x: ['trace_id','trace','timestamp_ms','status','execution_time_ms',
                       'request','response','request_metadata','spans','tags', ...]
        - MLflow 3.x: ['trace_id','trace','client_request_id','state','request_time',
                       'execution_duration','inputs','outputs','trace_metadata','tags', ...]
    """
    # NOTE(review): user_id is interpolated verbatim; a value containing a single
    # quote will break the filter string. Also, if traces record the user as
    # trace *metadata* rather than a tag, this tag filter matches nothing —
    # verify against how traces are logged for your MLflow version.
    filters = [f"tag.mlflow.trace.user = '{user_id}'"]
    if start_ms is not None:
        filters.append(f"timestamp_ms >= {start_ms}")
    if end_ms is not None:
        filters.append(f"timestamp_ms <= {end_ms}")
    filter_string = " AND ".join(filters)

    df = mlflow.search_traces(  # returns DataFrame by default
        filter_string=filter_string,
        order_by=order_by or ["timestamp_ms ASC"],
        extract_fields=extract_fields,
    )

    def _get_session(*mappings: Any) -> Optional[str]:
        """Return the first `mlflow.trace.session` value found among dict-like mappings."""
        for mapping in mappings:
            if isinstance(mapping, dict):
                sid = mapping.get("mlflow.trace.session")
                if sid is not None:
                    return sid
        return None

    df = df.copy()
    # Only consult the tag/metadata columns this MLflow version actually produced;
    # a missing `tags` column (e.g. an empty result) must not raise KeyError.
    meta_cols = [c for c in ("tags", "trace_metadata", "request_metadata") if c in df.columns]
    if meta_cols:
        sessions = [_get_session(*vals) for vals in zip(*(df[c] for c in meta_cols))]
    else:
        sessions = [None] * len(df)
    df["session_id"] = pd.Series(sessions, index=df.index, dtype=object)
    return df


def group_traces_by_session(df: pd.DataFrame) -> Dict[Optional[str], pd.DataFrame]:
    """
    Group the traces DataFrame by ``session_id``.

    Traces with no session (missing tag) are collected under the key ``None``.
    Within each group, rows are sorted by ``timestamp_ms`` (MLflow 2.x) or
    ``request_time`` (MLflow 3.x), whichever column is present. If neither is
    present, the group's original row order is kept.
    """
    time_col = next((c for c in ("timestamp_ms", "request_time") if c in df.columns), None)
    groups: Dict[Optional[str], pd.DataFrame] = {}
    for sid, g in df.groupby("session_id", dropna=False):
        # groupby(dropna=False) reports a missing key as NaN; expose it as None
        # so the mapping matches its Optional[str] key type.
        key = None if pd.isna(sid) else sid
        groups[key] = g.sort_values(time_col) if time_col is not None else g
    return groups


def preview_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a slim view of a traces DataFrame for quick inspection.

    Keeps, in this order and only when present: ``trace_id``, ``session_id``,
    a time column (``timestamp_ms`` for MLflow 2.x, ``request_time`` for 3.x),
    a status column (``status`` / ``state``), and the payload columns
    (``request`` / ``response`` and/or ``inputs`` / ``outputs``).
    Columns missing from ``df`` are silently skipped.
    """
    candidates = (
        "trace_id",
        "session_id",
        "timestamp_ms",   # time (2.x)
        "request_time",   # time (3.x)
        "status",
        "state",
        "request",
        "response",
        "inputs",
        "outputs",
    )
    return df[[c for c in candidates if c in df.columns]]

In [3]:
# Fetch this user's traces from the last 7 days and preview them session by session.
USER_ID = "daviddwlee84"
LOOKBACK_MS = 7 * 24 * 60 * 60 * 1000  # 7 days in milliseconds

now = int(time.time() * 1000)
df = load_user_traces_df(USER_ID, start_ms=now - LOOKBACK_MS)
by_session = group_traces_by_session(df)

for sid, g in by_session.items():
    header = f"\n=== Session: {sid} ({len(g)} traces) ==="
    print(header)
    print(preview_columns(g).head(10))

In [4]:
# Display the full traces DataFrame loaded above; its columns depend on the MLflow version.
df

Unnamed: 0,trace_id,trace,client_request_id,state,request_time,execution_duration,request,response,trace_metadata,tags,spans,assessments,session_id


In [5]:
from __future__ import annotations

from typing import Dict, List, Optional, Iterable
import mlflow


def load_user_traces_as_list(
    user_id: str,
    *,
    start_ms: Optional[int] = None,
    end_ms: Optional[int] = None,
    order_by: Optional[list[str]] = None,
) -> List["mlflow.entities.Trace"]:
    """
    Fetch a user's traces as a list of ``Trace`` objects instead of a DataFrame.

    Relies on ``mlflow.search_traces(return_type="list")`` (MLflow >= 2.21.1).
    ``start_ms`` / ``end_ms`` are inclusive bounds on ``timestamp_ms``; when
    ``order_by`` is not given, results come back oldest first.
    """
    bounds = ((">=", start_ms), ("<=", end_ms))
    clauses = [f"tag.mlflow.trace.user = '{user_id}'"]
    clauses.extend(f"timestamp_ms {op} {value}" for op, value in bounds if value is not None)

    return mlflow.search_traces(
        filter_string=" AND ".join(clauses),
        order_by=order_by if order_by else ["timestamp_ms ASC"],
        return_type="list",  # requires MLflow 2.21.1+
    )


def session_key_from_trace(t: "mlflow.entities.Trace") -> Optional[str]:
    """
    Get the session id (the ``mlflow.trace.session`` key) recorded on a Trace.

    Sources are checked in order, and the first one containing the key wins:
    ``t.info.tags``, ``t.info.trace_metadata``, ``t.info.request_metadata``,
    then ``t.tags``. Any of these may be missing or ``None``. Returns ``None``
    when no session id is found.
    """
    info = getattr(t, "info", None)
    # Attributes are read lazily, so later (possibly alias/deprecated) fields are
    # touched only when earlier ones lack the key.
    sources = ((info, "tags"), (info, "trace_metadata"), (info, "request_metadata"), (t, "tags"))
    for owner, attr in sources:
        mapping = getattr(owner, attr, None)
        if mapping:  # skips None and empty mappings
            sid = mapping.get("mlflow.trace.session")
            if sid is not None:
                return sid
    return None


def group_trace_list_by_session(
    traces: List["mlflow.entities.Trace"],
) -> Dict[Optional[str], List["mlflow.entities.Trace"]]:
    """
    Bucket Trace objects by session id and keep input order within each bucket.

    Traces without a session id go under the key ``None``. Keys appear in the
    order in which their session is first encountered.
    """
    buckets: Dict[Optional[str], List["mlflow.entities.Trace"]] = {}
    for trace in traces:
        key = session_key_from_trace(trace)
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(trace)
    return buckets

In [6]:
# Print a compact per-session summary built directly from the Trace objects.
traces = load_user_traces_as_list("daviddwlee84")
by_session = group_trace_list_by_session(traces)

for sid, items in by_session.items():
    print(f"\n--- Session {sid} ---")
    for trace in items:
        # The payload may be exposed as request/response or inputs/outputs
        # depending on the MLflow version, so try both spellings.
        trace_data = getattr(trace, "data", None)
        request = getattr(trace_data, "request", None) or getattr(trace_data, "inputs", None)
        response = getattr(trace_data, "response", None) or getattr(trace_data, "outputs", None)
        trace_info = getattr(trace, "info", None)
        timestamp = getattr(trace_info, "timestamp_ms", None) or getattr(trace, "timestamp_ms", None)
        print(f"[{timestamp}] inputs={str(request)[:80]} ... -> outputs={str(response)[:80]} ...")