Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion langfuse/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class CallbackHandler(BaseCallbackHandler):
version: Optional[str] = None
session_id: Optional[str] = None
user_id: Optional[str] = None
trace_name: Optional[str] = None
_task_manager: TaskManager

def __init__(
Expand All @@ -56,6 +57,7 @@ def __init__(
] = None,
session_id: Optional[str] = None,
user_id: Optional[str] = None,
trace_name: Optional[str] = None,
release: Optional[str] = None,
version: Optional[str] = None,
threads: Optional[int] = None,
Expand Down Expand Up @@ -130,6 +132,7 @@ def __init__(
self.runs = {}
self.session_id = session_id
self.user_id = user_id
self.trace_name = trace_name
self._task_manager = self.langfuse.task_manager

else:
Expand Down Expand Up @@ -281,7 +284,7 @@ def __generate_trace_and_parent(
if self.trace is None and self.langfuse is not None:
trace = self.langfuse.trace(
id=str(run_id),
name=class_name,
name=self.trace_name if self.trace_name is not None else class_name,
metadata=self.__join_tags_and_metadata(tags, metadata),
version=self.version,
session_id=self.session_id,
Expand Down
7 changes: 6 additions & 1 deletion tests/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def test_next_span_id_from_trace_simple_chain():

def test_callback_simple_chain():
api = get_api()
handler = CallbackHandler(debug=False)
handler = CallbackHandler(
debug=False, trace_name="test-trace-name", session_id="100", user_id="200"
)

llm = ChatOpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"))
template = """You are a playwright. Given the title of play, it is your job to write a synopsis for that title.
Expand All @@ -328,6 +330,9 @@ def test_callback_simple_chain():
)[0]
assert trace.input == root_observation.input
assert trace.output == root_observation.output
assert trace.name == "test-trace-name"
assert trace.session_id == "100"
assert trace.user_id == "200"

for observation in trace.observations:
if observation.type == "GENERATION":
Expand Down