From c125cbbc540119b044ca30b5dfd1914c88c7d117 Mon Sep 17 00:00:00 2001
From: paxiaatucsdedu
Date: Mon, 20 Oct 2025 14:52:55 -0700
Subject: [PATCH 1/4] Add total_tokens to AIMessage

Add total_tokens to AIMessage in the Cohere provider and the generic provider.
---
 libs/oci/PR_DESCRIPTION.md           | 48 +++++++++++++++++++
 .../chat_models/oci_generative_ai.py | 10 ++++
 2 files changed, 58 insertions(+)
 create mode 100644 libs/oci/PR_DESCRIPTION.md

diff --git a/libs/oci/PR_DESCRIPTION.md b/libs/oci/PR_DESCRIPTION.md
new file mode 100644
index 0000000..6664b53
--- /dev/null
+++ b/libs/oci/PR_DESCRIPTION.md
@@ -0,0 +1,48 @@
+# Add Token Usage to AIMessage Response
+
+## Summary
+Adds `total_tokens` to `AIMessage.additional_kwargs` for non-streaming chat responses, enabling users to track token consumption when using `ChatOCIGenAI`.
+
+## Problem
+When using `ChatOCIGenAI.invoke()`, token usage information (prompt_tokens, completion_tokens, total_tokens) from the OCI Generative AI API was not accessible in the `AIMessage` response, even though the raw OCI API returns this data.
+
+## Solution
+Extract token usage from the OCI API response and add `total_tokens` to `additional_kwargs` in non-streaming mode.
+
+### Changes Made
+**File:** `langchain_oci/chat_models/oci_generative_ai.py`
+
+1. **CohereProvider.chat_generation_info()** (lines 246-248)
+   - Extract `usage.total_tokens` from `response.data.chat_response.usage`
+   - Add to `generation_info["total_tokens"]`
+
+2. **GenericProvider.chat_generation_info()** (lines 611-613)
+   - Same extraction for Meta/Llama models
+
+## Usage
+
+### Before
+```python
+response = chat.invoke("What is the capital of France?")
+# No way to access token usage
+```
+
+### After
+```python
+response = chat.invoke("What is the capital of France?")
+print(response.additional_kwargs.get('total_tokens'))  # 26
+```
+
+## Limitations
+- **Streaming mode:** Token usage is NOT available when `is_stream=True` because the OCI Generative AI streaming API does not include usage statistics in stream events.
+- **Non-streaming only:** Use `is_stream=False` to get token usage information.
+
+## Testing
+Tested with:
+- ✅ Cohere Command-R models (`cohere.command-r-plus-08-2024`)
+- ✅ Meta Llama models (`meta.llama-3.3-70b-instruct`)
+- ✅ Non-streaming mode (`is_stream=False`)
+- ❌ Streaming mode (not supported by OCI API)
+
+## Backward Compatibility
+✅ Fully backward compatible - existing code continues to work unchanged.
diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
index c4d8dc3..9362c60 100644
--- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -242,6 +242,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "is_search_required": response.data.chat_response.is_search_required,
             "finish_reason": response.data.chat_response.finish_reason,
         }
+
+        # Include token usage if available
+        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
+            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
+
         # Include tool calls if available
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
@@ -602,6 +607,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "finish_reason": response.data.chat_response.choices[0].finish_reason,
             "time_created": str(response.data.chat_response.time_created),
         }
+
+        # Include token usage if available
+        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
+            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
+
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
                 self.chat_tool_calls(response)

From 61c05d6f9a3d21d1346f6b96facda6cdd31af5dc Mon Sep 17 00:00:00 2001
From: paxiaatucsdedu
Date: Mon, 20 Oct 2025 14:53:31 -0700
Subject: [PATCH 2/4] Revert "Add total_tokens to AIMessage"

This reverts commit c125cbbc540119b044ca30b5dfd1914c88c7d117.
---
 libs/oci/PR_DESCRIPTION.md           | 48 -------------------
 .../chat_models/oci_generative_ai.py | 10 ----
 2 files changed, 58 deletions(-)
 delete mode 100644 libs/oci/PR_DESCRIPTION.md

diff --git a/libs/oci/PR_DESCRIPTION.md b/libs/oci/PR_DESCRIPTION.md
deleted file mode 100644
index 6664b53..0000000
--- a/libs/oci/PR_DESCRIPTION.md
+++ /dev/null
@@ -1,48 +0,0 @@
-# Add Token Usage to AIMessage Response
-
-## Summary
-Adds `total_tokens` to `AIMessage.additional_kwargs` for non-streaming chat responses, enabling users to track token consumption when using `ChatOCIGenAI`.
-
-## Problem
-When using `ChatOCIGenAI.invoke()`, token usage information (prompt_tokens, completion_tokens, total_tokens) from the OCI Generative AI API was not accessible in the `AIMessage` response, even though the raw OCI API returns this data.
-
-## Solution
-Extract token usage from the OCI API response and add `total_tokens` to `additional_kwargs` in non-streaming mode.
-
-### Changes Made
-**File:** `langchain_oci/chat_models/oci_generative_ai.py`
-
-1. **CohereProvider.chat_generation_info()** (lines 246-248)
-   - Extract `usage.total_tokens` from `response.data.chat_response.usage`
-   - Add to `generation_info["total_tokens"]`
-
-2. **GenericProvider.chat_generation_info()** (lines 611-613)
-   - Same extraction for Meta/Llama models
-
-## Usage
-
-### Before
-```python
-response = chat.invoke("What is the capital of France?")
-# No way to access token usage
-```
-
-### After
-```python
-response = chat.invoke("What is the capital of France?")
-print(response.additional_kwargs.get('total_tokens'))  # 26
-```
-
-## Limitations
-- **Streaming mode:** Token usage is NOT available when `is_stream=True` because the OCI Generative AI streaming API does not include usage statistics in stream events.
-- **Non-streaming only:** Use `is_stream=False` to get token usage information.
-
-## Testing
-Tested with:
-- ✅ Cohere Command-R models (`cohere.command-r-plus-08-2024`)
-- ✅ Meta Llama models (`meta.llama-3.3-70b-instruct`)
-- ✅ Non-streaming mode (`is_stream=False`)
-- ❌ Streaming mode (not supported by OCI API)
-
-## Backward Compatibility
-✅ Fully backward compatible - existing code continues to work unchanged.
diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
index 9362c60..c4d8dc3 100644
--- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -242,11 +242,6 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "is_search_required": response.data.chat_response.is_search_required,
             "finish_reason": response.data.chat_response.finish_reason,
         }
-
-        # Include token usage if available
-        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
-            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
-
         # Include tool calls if available
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
@@ -607,11 +602,6 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "finish_reason": response.data.chat_response.choices[0].finish_reason,
             "time_created": str(response.data.chat_response.time_created),
         }
-
-        # Include token usage if available
-        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
-            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
-
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
                 self.chat_tool_calls(response)

From b96e40cce9e3c88df34cded453e46cd89fc5b9c7 Mon Sep 17 00:00:00 2001
From: paxiaatucsdedu
Date: Mon, 20 Oct 2025 15:01:50 -0700
Subject: [PATCH 3/4] Add total_tokens to AIMessage

Add total_tokens to AIMessage in the Cohere provider and the generic provider.
---
 .../oci/langchain_oci/chat_models/oci_generative_ai.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
index c4d8dc3..9362c60 100644
--- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -242,6 +242,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "is_search_required": response.data.chat_response.is_search_required,
             "finish_reason": response.data.chat_response.finish_reason,
         }
+
+        # Include token usage if available
+        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
+            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
+
         # Include tool calls if available
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
@@ -602,6 +607,11 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "finish_reason": response.data.chat_response.choices[0].finish_reason,
             "time_created": str(response.data.chat_response.time_created),
         }
+
+        # Include token usage if available
+        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
+            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
+
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
                 self.chat_tool_calls(response)
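
The guard added in PATCH 3/4 only records `total_tokens` when the response actually carries a usage object. Below is a minimal, self-contained sketch of that pattern; the `SimpleNamespace` stand-in for `response.data.chat_response` is illustrative only and not part of the patch.

```python
from types import SimpleNamespace
from typing import Any, Dict

# Illustrative stand-in for response.data.chat_response; the real object is an
# OCI SDK model class, but only attribute access matters for the guard.
chat_response = SimpleNamespace(
    finish_reason="COMPLETE",
    usage=SimpleNamespace(total_tokens=26),
)

generation_info: Dict[str, Any] = {"finish_reason": chat_response.finish_reason}

# Same check as the patch: add total_tokens only when a truthy usage object
# exists, so responses without usage never raise AttributeError.
if hasattr(chat_response, "usage") and chat_response.usage:
    generation_info["total_tokens"] = chat_response.usage.total_tokens

print(generation_info)  # {'finish_reason': 'COMPLETE', 'total_tokens': 26}
```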
From 01400bc844ac4d99d31d50f6c3d5c686f4b12c51 Mon Sep 17 00:00:00 2001
From: paxiaatucsdedu
Date: Tue, 21 Oct 2025 10:01:38 -0700
Subject: [PATCH 4/4] Add unit test for total_tokens

Add unit test for total_tokens
---
 .../chat_models/test_oci_generative_ai.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
index 1a6649a..b7e19da 100644
--- a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
+++ b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
@@ -55,6 +55,11 @@ def mocked_response(*args):  # type: ignore[no-untyped-def]
                                     "citations": None,
                                     "documents": None,
                                     "tool_calls": None,
+                                    "usage": MockResponseDict(
+                                        {
+                                            "total_tokens": 50,
+                                        }
+                                    ),
                                 }
                             ),
                             "model_id": "cohere.command-r-16k",
@@ -116,6 +121,11 @@ def mocked_response(*args):  # type: ignore[no-untyped-def]
                                         )
                                     ],
                                     "time_created": "2025-08-14T10:00:01.100000+00:00",
+                                    "usage": MockResponseDict(
+                                        {
+                                            "total_tokens": 75,
+                                        }
+                                    ),
                                 }
                             ),
                             "model_id": "meta.llama-3.3-70b-instruct",
@@ -141,6 +151,13 @@ def mocked_response(*args):  # type: ignore[no-untyped-def]
     expected = "Assistant chat reply."
     actual = llm.invoke(messages, temperature=0.2)
     assert actual.content == expected
+
+    # Test total_tokens in additional_kwargs
+    assert "total_tokens" in actual.additional_kwargs
+    if provider == "cohere":
+        assert actual.additional_kwargs["total_tokens"] == 50
+    elif provider == "meta":
+        assert actual.additional_kwargs["total_tokens"] == 75
 
 
 @pytest.mark.requires("oci")
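
With the series applied, the new field is read from `additional_kwargs` on the returned `AIMessage`. A minimal usage sketch follows, assuming the top-level export `from langchain_oci import ChatOCIGenAI`; the service endpoint, compartment OCID, and model ID are placeholders, and token usage is only populated in non-streaming mode.

```python
from langchain_oci import ChatOCIGenAI  # assumed top-level export

chat = ChatOCIGenAI(
    model_id="cohere.command-r-plus-08-2024",  # any served chat model
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",  # placeholder region endpoint
    compartment_id="ocid1.compartment.oc1..example",  # placeholder OCID
    is_stream=False,  # usage is not reported for streaming responses
)

response = chat.invoke("What is the capital of France?")
print(response.content)
print(response.additional_kwargs.get("total_tokens"))  # e.g. 26
```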