diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
index c4d8dc3..9362c60 100644
--- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -242,6 +242,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "is_search_required": response.data.chat_response.is_search_required,
             "finish_reason": response.data.chat_response.finish_reason,
         }
+
+        # Include token usage if available
+        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
+            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
+
         # Include tool calls if available
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
@@ -602,6 +607,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "finish_reason": response.data.chat_response.choices[0].finish_reason,
             "time_created": str(response.data.chat_response.time_created),
         }
+
+        # Include token usage if available
+        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
+            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
+
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
                 self.chat_tool_calls(response)
diff --git a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
index 1a6649a..b7e19da 100644
--- a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
+++ b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
@@ -55,6 +55,11 @@ def mocked_response(*args):  # type: ignore[no-untyped-def]
                         "citations": None,
                         "documents": None,
                         "tool_calls": None,
+                        "usage": MockResponseDict(
+                            {
+                                "total_tokens": 50,
+                            }
+                        ),
                     }
                 ),
                 "model_id": "cohere.command-r-16k",
@@ -116,6 +121,11 @@ def mocked_response(*args):  # type: ignore[no-untyped-def]
                             )
                         ],
                         "time_created": "2025-08-14T10:00:01.100000+00:00",
+                        "usage": MockResponseDict(
+                            {
+                                "total_tokens": 75,
+                            }
+                        ),
                     }
                 ),
                 "model_id": "meta.llama-3.3-70b-instruct",
@@ -141,6 +151,13 @@ def mocked_response(*args):  # type: ignore[no-untyped-def]
     expected = "Assistant chat reply."
     actual = llm.invoke(messages, temperature=0.2)
     assert actual.content == expected
+
+    # Test total_tokens in additional_kwargs
+    assert "total_tokens" in actual.additional_kwargs
+    if provider == "cohere":
+        assert actual.additional_kwargs["total_tokens"] == 50
+    elif provider == "meta":
+        assert actual.additional_kwargs["total_tokens"] == 75
 
 
 @pytest.mark.requires("oci")
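For reference, a minimal sketch of how a caller would read the new field once this patch is applied, mirroring the assertions in the updated unit test. The `model_id`, `service_endpoint`, and compartment OCID below are illustrative placeholders, not values taken from this PR:

```python
from langchain_oci.chat_models import ChatOCIGenAI

# Placeholder configuration; substitute real OCI values.
llm = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",  # hypothetical OCID
)

reply = llm.invoke("Summarize OCI Generative AI in one sentence.")

# With this change, total token usage is surfaced alongside the other
# generation metadata (finish_reason, tool_calls, ...) when the OCI
# response carries a usage object.
print(reply.additional_kwargs.get("total_tokens"))
```

The `hasattr(...) and ...usage` guard in the patch means `total_tokens` is simply absent for responses without usage data, so callers should prefer `additional_kwargs.get("total_tokens")` over direct indexing.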