diff --git a/py/noxfile.py b/py/noxfile.py index 6edec440..6087325a 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -64,6 +64,7 @@ def _pinned_python_version(): "agno", "agentscope", "anthropic", + "cohere", "dspy", "openai", "openai-agents", @@ -103,6 +104,7 @@ def _pinned_python_version(): AUTOEVALS_VERSIONS = (LATEST, "0.0.129") GENAI_VERSIONS = (LATEST,) +COHERE_VERSIONS = (LATEST, "5.10.0") DSPY_VERSIONS = (LATEST,) GOOGLE_ADK_VERSIONS = (LATEST, "1.14.1") LANGCHAIN_VERSIONS = (LATEST, "0.3.28") @@ -306,6 +308,15 @@ def test_litellm(session, version): _run_core_tests(session) +@nox.session() +@nox.parametrize("version", COHERE_VERSIONS, ids=COHERE_VERSIONS) +def test_cohere(session, version): + _install_test_deps(session) + _install(session, "cohere", version) + _run_tests(session, f"{INTEGRATION_DIR}/cohere/test_cohere.py") + _run_core_tests(session) + + @nox.session() @nox.parametrize("version", DSPY_VERSIONS, ids=DSPY_VERSIONS) def test_dspy(session, version): diff --git a/py/src/braintrust/auto.py b/py/src/braintrust/auto.py index e34d7625..04399e93 100644 --- a/py/src/braintrust/auto.py +++ b/py/src/braintrust/auto.py @@ -13,6 +13,7 @@ AgnoIntegration, AnthropicIntegration, ClaudeAgentSDKIntegration, + CohereIntegration, DSPyIntegration, GoogleGenAIIntegration, LangChainIntegration, @@ -53,6 +54,7 @@ def auto_instrument( agno: bool = True, agentscope: bool = True, claude_agent_sdk: bool = True, + cohere: bool = True, dspy: bool = True, adk: bool = True, langchain: bool = True, @@ -78,6 +80,7 @@ def auto_instrument( agno: Enable Agno instrumentation (default: True) agentscope: Enable AgentScope instrumentation (default: True) claude_agent_sdk: Enable Claude Agent SDK instrumentation (default: True) + cohere: Enable Cohere instrumentation (default: True) dspy: Enable DSPy instrumentation (default: True) adk: Enable Google ADK instrumentation (default: True) langchain: Enable LangChain instrumentation (default: True) @@ -149,6 +152,8 @@ def auto_instrument( results["agentscope"] = _instrument_integration(AgentScopeIntegration) if claude_agent_sdk: results["claude_agent_sdk"] = _instrument_integration(ClaudeAgentSDKIntegration) + if cohere: + results["cohere"] = _instrument_integration(CohereIntegration) if dspy: results["dspy"] = _instrument_integration(DSPyIntegration) if adk: diff --git a/py/src/braintrust/conftest.py b/py/src/braintrust/conftest.py index ee94881c..1383eaf5 100644 --- a/py/src/braintrust/conftest.py +++ b/py/src/braintrust/conftest.py @@ -153,6 +153,8 @@ def setup_braintrust(): os.environ.setdefault("OPENAI_API_KEY", "sk-test-dummy-api-key-for-vcr-tests") os.environ.setdefault("ANTHROPIC_API_KEY", "sk-ant-test-dummy-api-key-for-vcr-tests") os.environ.setdefault("MISTRAL_API_KEY", "mistral-test-dummy-api-key-for-vcr-tests") + os.environ.setdefault("CO_API_KEY", os.getenv("COHERE_API_KEY", "cohere-test-dummy-api-key-for-vcr-tests")) + os.environ.setdefault("COHERE_API_KEY", os.getenv("CO_API_KEY", "cohere-test-dummy-api-key-for-vcr-tests")) @pytest.fixture(autouse=True) diff --git a/py/src/braintrust/integrations/__init__.py b/py/src/braintrust/integrations/__init__.py index d17a0146..462c212e 100644 --- a/py/src/braintrust/integrations/__init__.py +++ b/py/src/braintrust/integrations/__init__.py @@ -3,6 +3,7 @@ from .agno import AgnoIntegration from .anthropic import AnthropicIntegration from .claude_agent_sdk import ClaudeAgentSDKIntegration +from .cohere import CohereIntegration from .dspy import DSPyIntegration from .google_genai import GoogleGenAIIntegration from .langchain import LangChainIntegration @@ -20,6 +21,7 @@ "AgnoIntegration", "AnthropicIntegration", "ClaudeAgentSDKIntegration", + "CohereIntegration", "DSPyIntegration", "GoogleGenAIIntegration", "LiteLLMIntegration", diff --git a/py/src/braintrust/integrations/auto_test_scripts/test_auto_cohere.py b/py/src/braintrust/integrations/auto_test_scripts/test_auto_cohere.py new file mode 100644 index 00000000..88316c5a --- /dev/null +++ b/py/src/braintrust/integrations/auto_test_scripts/test_auto_cohere.py @@ -0,0 +1,36 @@ +"""Test auto_instrument for Cohere.""" + +import os +from pathlib import Path + +from braintrust.auto import auto_instrument +from braintrust.wrappers.test_utils import autoinstrument_test_context +from cohere import ClientV2 + + +results = auto_instrument() +assert results.get("cohere") == True + +results2 = auto_instrument() +assert results2.get("cohere") == True + +COHERE_CASSETTES_DIR = Path(__file__).resolve().parent.parent / "cohere" / "cassettes" + +with autoinstrument_test_context("test_auto_cohere", cassettes_dir=COHERE_CASSETTES_DIR) as memory_logger: + client = ClientV2(api_key=os.environ.get("CO_API_KEY")) + response = client.chat( + model="command-a-03-2025", + messages=[{"role": "user", "content": "What is 2+2? Reply with just the number."}], + max_tokens=10, + ) + assert response.message.content[0].text == "4" + + spans = memory_logger.pop() + assert len(spans) == 1, f"Expected 1 span, got {len(spans)}" + span = spans[0] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert span["metadata"]["model"] == "command-a-03-2025" + assert span["output"]["content"][0]["text"] == "4" + +print("SUCCESS") diff --git a/py/src/braintrust/integrations/cohere/__init__.py b/py/src/braintrust/integrations/cohere/__init__.py new file mode 100644 index 00000000..a9f5ea4c --- /dev/null +++ b/py/src/braintrust/integrations/cohere/__init__.py @@ -0,0 +1,10 @@ +"""Braintrust integration for the Cohere Python SDK.""" + +from .integration import CohereIntegration +from .tracing import wrap_cohere + + +__all__ = [ + "CohereIntegration", + "wrap_cohere", +] diff --git a/py/src/braintrust/integrations/cohere/cassettes/test_auto_cohere.yaml b/py/src/braintrust/integrations/cohere/cassettes/test_auto_cohere.yaml new file mode 100644 index 00000000..768835ba --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/test_auto_cohere.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + body: '{"model":"command-a-03-2025","messages":[{"role":"user","content":"What + is 2+2? Reply with just the number."}],"max_tokens":10,"stream":false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '142' + Host: + - api.cohere.com + User-Agent: + - cohere/5.21.1 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.13.3 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 5.21.1 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: '{"id":"f76fd244-eba2-4270-8c13-5d02e4c727a6","message":{"role":"assistant","content":[{"type":"text","text":"4"}]},"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":13,"output_tokens":1},"tokens":{"input_tokens":508,"output_tokens":4},"cached_tokens":0}}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '271' + content-type: + - application/json + date: + - Tue, 07 Apr 2026 22:59:12 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + num_chars: + - '2618' + num_tokens: + - '14' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - fafcd67dcebe77d4b9dcda7c83bd75fe + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '211' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '14' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/cassettes/test_cohere_integration_setup_creates_spans.yaml b/py/src/braintrust/integrations/cohere/cassettes/test_cohere_integration_setup_creates_spans.yaml new file mode 100644 index 00000000..95517091 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/test_cohere_integration_setup_creates_spans.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + body: '{"model":"command-r-plus-08-2024","messages":[{"role":"user","content":"What + is 4+4? Reply with just the number."}],"max_tokens":10,"stream":false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '147' + Host: + - api.cohere.com + User-Agent: + - cohere/5.21.1 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.13.3 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 5.21.1 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: '{"id":"824f231b-f146-428c-a1e5-a1af328ccf67","message":{"role":"assistant","content":[{"type":"text","text":"8"}]},"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":13,"output_tokens":1},"tokens":{"input_tokens":214,"output_tokens":1},"cached_tokens":208}}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '273' + content-type: + - application/json + date: + - Tue, 07 Apr 2026 22:59:11 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + num_chars: + - '1218' + num_tokens: + - '14' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - ce3bf0635eb69af4cff4ac45e07ec311 + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '129' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '15' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_client_v2_chat_sync.yaml b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_client_v2_chat_sync.yaml new file mode 100644 index 00000000..19026bb3 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_client_v2_chat_sync.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + body: '{"model":"command-r-plus-08-2024","messages":[{"role":"user","content":"What + is 5+5? Reply with just the number."}],"max_tokens":10,"stream":false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '147' + Host: + - api.cohere.com + User-Agent: + - cohere/5.21.1 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.13.3 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 5.21.1 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: '{"id":"f1599186-b136-4697-9133-0c3de526f296","message":{"role":"assistant","content":[{"type":"text","text":"10"}]},"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":13,"output_tokens":2},"tokens":{"input_tokens":214,"output_tokens":2},"cached_tokens":144}}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '274' + content-type: + - application/json + date: + - Tue, 07 Apr 2026 22:59:10 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + num_chars: + - '1218' + num_tokens: + - '15' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - 129505e7d514afa6f3968c7647a1862e + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '158' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '16' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_nested_v2_chat_sync.yaml b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_nested_v2_chat_sync.yaml new file mode 100644 index 00000000..8212ae0c --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_nested_v2_chat_sync.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + body: '{"model":"command-r-plus-08-2024","messages":[{"role":"user","content":"What + is 2+2? Reply with just the number."}],"max_tokens":10,"stream":false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '147' + Host: + - api.cohere.com + User-Agent: + - cohere/5.21.1 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.13.3 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 5.21.1 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: '{"id":"6adb7836-72c3-4a6c-b554-a5121cbcc6fe","message":{"role":"assistant","content":[{"type":"text","text":"4"}]},"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":13,"output_tokens":1},"tokens":{"input_tokens":214,"output_tokens":1},"cached_tokens":208}}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '273' + content-type: + - application/json + date: + - Tue, 07 Apr 2026 22:59:09 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + num_chars: + - '1218' + num_tokens: + - '14' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - 4ab549420f7ca67b4dc3abbf5ff702bd + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '148' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '17' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_chat_stream_async.yaml b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_chat_stream_async.yaml new file mode 100644 index 00000000..cdcfd795 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_chat_stream_async.yaml @@ -0,0 +1,106 @@ +interactions: +- request: + body: '{"model":"command-r-plus-08-2024","messages":[{"role":"user","content":"What + is 8+8? Reply with just the number."}],"max_tokens":10,"stream":true}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '146' + Host: + - api.cohere.com + User-Agent: + - cohere/5.21.1 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.13.3 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 5.21.1 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: 'event: message-start + + data: {"id":"b2fed6fe-ce1d-4982-9ea0-016fab7cd960","type":"message-start","delta":{"message":{"role":"assistant","content":[],"tool_plan":"","tool_calls":[],"citations":[]}}} + + + event: content-start + + data: {"type":"content-start","index":0,"delta":{"message":{"content":{"type":"text","text":""}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"1"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"6"}}}} + + + event: content-end + + data: {"type":"content-end","index":0} + + + event: message-end + + data: {"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":13,"output_tokens":2},"tokens":{"input_tokens":214,"output_tokens":2},"cached_tokens":192}}} + + + data: [DONE] + + + ' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-type: + - text/event-stream + date: + - Tue, 07 Apr 2026 22:59:10 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + pragma: + - no-cache + server: + - envoy + vary: + - Origin + x-accel-expires: + - '0' + x-debug-trace-id: + - f79b417fbea3860786b2d54195c2f820 + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '100' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '16' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_embed.yaml b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_embed.yaml new file mode 100644 index 00000000..adf23771 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_embed.yaml @@ -0,0 +1,78 @@ +interactions: +- request: + body: '{"texts":["braintrust tracing"],"model":"embed-english-v3.0","input_type":"search_query","embedding_types":["float"]}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '117' + Host: + - api.cohere.com + User-Agent: + - cohere/5.21.1 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.13.3 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 5.21.1 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/embed + response: + body: + string: '{"id":"d3ef8ebc-7152-4741-a922-eaae55536a6a","texts":["braintrust tracing"],"embeddings":{"float":[[-0.009063721,-0.034820557,-0.022445679,-0.03353882,-0.055541992,0.006187439,-0.040649414,0.009559631,-0.033721924,0.044128418,0.015777588,0.014099121,0.012794495,-0.025863647,-0.006362915,-0.012413025,0.022842407,-0.03842163,0.017791748,-0.0030059814,0.036315918,-0.02508545,0.056243896,-0.022720337,0.029800415,0.014503479,-0.03414917,0.0009975433,-0.010238647,0.00012850761,0.013084412,0.008277893,0.010391235,0.011116028,-0.014968872,-0.026977539,-0.031433105,0.009490967,0.0079574585,-0.014839172,0.018371582,-0.008476257,0.04824829,-0.0043945312,-0.005432129,0.004310608,0.02355957,-0.038604736,0.01977539,0.022903442,-0.014511108,0.029083252,-0.014228821,0.011802673,-0.00045514107,0.0030231476,-0.08874512,-0.005619049,-0.014511108,0.08294678,0.0793457,0.006259918,-0.009239197,0.010360718,0.00083494186,-0.010940552,0.015113831,0.019897461,-0.034301758,-0.046813965,0.018569946,-0.0143585205,0.027862549,-0.015007019,-0.04006958,-0.014450073,0.012809753,0.0014944077,0.023803711,-0.0178833,-0.008468628,-0.045288086,-0.023773193,0.034088135,-0.021438599,0.03829956,-0.005897522,-0.0065956116,0.052734375,0.049621582,0.002450943,-0.010894775,0.07623291,0.088012695,0.012084961,-0.0057640076,-0.04168701,0.0018901825,0.019638062,-0.008903503,-0.018173218,0.0072364807,-0.08959961,-0.00014150143,-0.0129776,0.004585266,-0.028945923,0.0138549805,0.032714844,0.0368042,-0.0126571655,-0.012649536,-0.02330017,-0.057739258,-0.0009841919,0.042266846,-0.018157959,0.014572144,0.027145386,-0.0287323,-0.023010254,0.0037899017,-0.024337769,-0.007575989,0.052246094,-0.007949829,-0.008338928,0.028030396,0.046936035,-0.0028533936,-0.024414062,-0.0037899017,0.030853271,0.028656006,0.010299683,-0.01084137,-0.0058288574,-0.0078048706,-0.02079773,0.041107178,0.05291748,-0.025466919,0.023162842,0.034973145,0.046936035,0.009399414,0.05319214,0.02017212,-0.03955078,-0.032836914,0.05215454,0.02848816,0.019943237,0.0037784576,-0.02671814,-0.037628174,-0.055664062,-0.010116577,0.05291748,-0.009880066,0.027069092,0.020507812,-0.0012254715,0.028167725,-0.005973816,-0.037841797,0.04348755,0.019729614,-0.0033740997,0.00015854836,-0.005996704,-0.017730713,0.0019102097,0.0017080307,-0.010620117,0.026931763,0.023391724,0.030685425,0.028396606,-0.03616333,-0.032836914,-0.059509277,0.044128418,-0.025604248,-0.020828247,0.023422241,0.016998291,0.06732178,0.017181396,0.008338928,-0.00737381,-0.011001587,-0.016418457,0.0066986084,-0.0034694672,-0.017333984,0.04626465,0.039093018,0.020050049,-0.040496826,0.05734253,-0.0014915466,-0.022705078,-0.019119263,-0.062805176,-0.05105591,-0.0022945404,-0.007896423,0.05001831,0.0031795502,-0.008033752,-0.007293701,-0.015899658,0.011039734,-0.020370483,0.062438965,-0.015670776,-0.017837524,-0.0035057068,-0.0064353943,-0.007534027,0.022064209,-0.041168213,0.008094788,-0.016967773,0.06768799,-0.044158936,0.0473938,-0.014533997,0.022277832,0.05807495,0.025726318,-0.030258179,-0.011352539,-0.035614014,0.012374878,-0.0046081543,0.05532837,0.008255005,-0.009811401,0.021438599,0.034484863,-0.02420044,-0.0362854,0.00258255,-0.035308838,0.031341553,-0.03869629,-0.025756836,-0.011253357,0.021606445,-0.0052833557,0.0385437,0.022476196,-0.014511108,0.02923584,-0.053100586,-0.03213501,0.020263672,-0.026626587,0.002319336,-0.030334473,0.0036182404,-0.031463623,-0.035614014,-0.046875,-0.023834229,-0.012214661,-0.019134521,-0.023117065,0.046142578,-0.004886627,-0.012176514,0.018447876,0.06317139,-0.055511475,-0.033935547,-0.009300232,-0.00038838387,-0.009414673,0.004184723,0.039489746,-0.023040771,-0.023208618,0.018157959,0.048431396,-0.011161804,-0.014533997,-0.08148193,0.013771057,0.050628662,0.04840088,-0.0032634735,-0.03262329,0.004634857,0.017181396,-0.033294678,-0.017166138,0.02166748,-0.028625488,0.0435791,0.07122803,-0.0012235641,-0.00016987324,0.0025310516,0.05215454,0.0390625,-0.011833191,-0.016616821,-0.022338867,0.010559082,-0.043640137,-0.010864258,-0.046295166,-0.05508423,-0.03414917,-0.08129883,-0.039031982,-0.113464355,-0.030151367,-0.031311035,0.011451721,-0.0011997223,0.016662598,-0.0076828003,0.004737854,0.002128601,0.009689331,-0.041290283,0.030975342,0.03050232,0.01576233,0.0010585785,0.027252197,-0.011062622,0.017730713,0.02670288,-0.014587402,-0.0026779175,0.014915466,0.047698975,0.014091492,0.011581421,0.00016391277,-0.083618164,-0.024490356,-0.028411865,0.0034980774,0.012184143,-0.06604004,0.06317139,-0.0064048767,-0.013702393,0.01234436,0.028381348,-0.06829834,-0.019104004,0.02468872,-0.026489258,-0.03591919,0.049682617,0.022247314,0.015945435,-0.0418396,-0.047027588,-0.05404663,0.0129852295,-0.006843567,0.08343506,0.076660156,-0.0047416687,0.043914795,-0.02027893,0.002948761,0.053955078,-0.06100464,0.033203125,-0.02923584,-0.041290283,-0.0052948,-0.02217102,0.013328552,-0.012329102,0.020721436,0.023971558,0.043945312,0.043151855,0.026947021,0.018112183,0.00040197372,-0.032928467,-0.028793335,0.020095825,-0.020324707,0.025741577,0.013511658,-0.008102417,0.00070142746,-0.014709473,-0.021499634,0.043701172,-0.009849548,-0.04559326,-0.007217407,0.006416321,-0.1262207,0.064819336,-0.03540039,0.025253296,0.04043579,-0.00390625,0.0007133484,-0.010475159,0.0047836304,0.006904602,0.005874634,0.0340271,0.0028972626,0.004108429,-0.040161133,0.02671814,0.003364563,0.013580322,0.0016202927,0.029388428,0.025848389,-0.011726379,-0.034362793,0.02168274,0.032318115,0.015075684,-0.0068473816,-0.0030593872,-0.01939392,-0.0524292,-0.009338379,-0.024505615,-0.013633728,0.009239197,-0.049438477,-0.07965088,-0.012886047,-0.014984131,-0.01335907,0.024978638,-0.025436401,-0.018447876,-0.0680542,-0.009757996,0.014045715,-0.019805908,-0.019363403,-0.00039172173,0.046905518,-0.019378662,0.13220215,0.03967285,0.028518677,0.04095459,0.00080013275,0.01423645,0.0014295578,-0.027923584,0.0074272156,0.014335632,0.023162842,0.0049476624,0.013900757,-0.01309967,-0.011955261,0.018508911,-0.03543091,0.04928589,0.010406494,0.020553589,0.026519775,-0.00472641,0.006313324,-0.057769775,-0.0069999695,0.013633728,0.0037078857,0.011299133,-0.015419006,-0.006450653,0.007965088,-0.041992188,0.0060310364,-0.013618469,-0.026763916,-0.04751587,-0.029220581,0.0021038055,-0.026229858,-0.028259277,0.010025024,0.004180908,0.0069465637,-0.03050232,-0.09777832,0.032104492,0.009414673,-0.0463562,0.06213379,-0.05822754,0.02444458,-0.03466797,-0.032165527,-0.012489319,-0.11395264,-0.0435791,-0.059936523,0.05239868,-0.009994507,0.006439209,0.0289917,0.03277588,-0.05142212,0.04269409,0.0008764267,0.013900757,0.069885254,-0.05392456,0.018035889,0.061676025,-0.022750854,-0.069885254,0.031402588,0.010437012,-0.010665894,0.016906738,-0.0021400452,-0.0054855347,0.008476257,0.004283905,0.021133423,-0.01084137,-0.068359375,-0.0010480881,0.03756714,0.024032593,-0.027252197,-0.04586792,0.04168701,-0.0035743713,-0.014472961,0.008132935,0.019851685,0.005378723,-0.00995636,-0.009239197,0.0047340393,-0.023834229,-0.041137695,-0.08428955,0.09539795,0.011154175,0.021362305,-0.016601562,0.044158936,0.009857178,-0.018600464,-0.0025634766,-0.011405945,-0.0051651,0.0058403015,0.027359009,-0.02394104,0.014282227,-0.014862061,-0.008811951,-0.014633179,0.030410767,0.003194809,0.015838623,-0.015930176,-0.032104492,0.014541626,-0.037902832,0.017532349,0.017501831,0.030731201,-0.027511597,0.02217102,0.008773804,0.011688232,0.021362305,0.017059326,0.024551392,0.0072135925,-0.0055999756,-0.017929077,-0.014472961,-0.017578125,-0.0025730133,-0.016021729,0.006591797,0.04196167,0.083618164,-0.012435913,0.0065956116,0.03933716,0.03829956,-0.028259277,-0.062805176,0.014472961,-0.017868042,0.023803711,0.005470276,0.007369995,-0.01461792,-0.03439331,0.0231781,0.0062179565,0.0029697418,0.007865906,0.008346558,0.015396118,-0.023620605,-0.016479492,0.029022217,-0.0074310303,0.033721924,0.018829346,-0.008277893,-0.040039062,0.041015625,0.005290985,-0.010437012,-0.000957489,-0.025543213,0.0011520386,-0.07299805,-0.08496094,-0.025054932,0.06414795,-0.023529053,-0.02746582,0.024887085,-0.0079422,-0.0037631989,0.003376007,-0.060821533,-0.027435303,0.008125305,-0.02368164,-0.032348633,-0.013587952,-0.02053833,0.029724121,0.042297363,0.008460999,-0.054595947,-0.007347107,0.017608643,-0.0032196045,-0.019973755,-0.0018987656,0.019363403,-0.024978638,-0.016799927,0.038970947,-0.030090332,0.005886078,-0.019073486,0.0067100525,0.06665039,0.040924072,-0.0010881424,0.014961243,0.001162529,0.01802063,0.024124146,-0.0055274963,0.026473999,-0.028945923,-0.022415161,0.00044822693,-0.03225708,-0.00440979,0.016235352,0.035003662,0.024795532,-0.00365448,-0.030288696,0.036346436,-0.024475098,-0.0015201569,0.041870117,0.0076942444,-0.0073165894,0.0028800964,0.021911621,-0.008216858,-0.0000074505806,0.037384033,-0.019119263,-0.021377563,-0.064208984,-0.034423828,-0.014564514,-0.0025730133,0.028839111,-0.003255844,-0.015029907,0.056488037,0.01663208,-0.010513306,0.002046585,0.0143966675,0.0077590942,-0.009460449,0.052947998,0.0050697327,-0.0056419373,-0.02607727,-0.016983032,-0.08050537,-0.10876465,0.079833984,0.01890564,-0.0056610107,0.01802063,-0.008811951,0.028442383,-0.02267456,-0.00044608116,0.02923584,0.012382507,0.0030536652,0.0028324127,0.06222534,0.014404297,-0.015464783,-0.0059661865,0.0012321472,-0.0044403076,0.014060974,-0.009513855,-0.044677734,-0.024276733,-0.030899048,0.031173706,-0.021484375,-0.0035820007,-0.0064430237,0.0024604797,0.03488159,-0.024475098,-0.025344849,0.03164673,-0.010520935,-0.027038574,0.019760132,0.057678223,0.043029785,-0.0073013306,-0.07891846,-0.048980713,0.0017671585,-0.030426025,-0.025680542,-0.014404297,0.054107666,0.009559631,0.066589355,0.07269287,0.013786316,0.026626587,-0.005558014,-0.025756836,0.01020813,0.0033416748,0.042053223,-0.0019359589,-0.0036315918,0.017929077,0.019424438,-0.023986816,-0.037963867,-0.015197754,0.028717041,0.05117798,0.02885437,-0.03741455,-0.015350342,0.049316406,0.03262329,0.010375977,0.084472656,0.03729248,-0.0018062592,0.03265381,0.005748749,0.032684326,0.019424438,-0.0068359375,0.006252289,0.014602661,-0.017486572,-0.016220093,-0.0021781921,-0.041625977,0.018234253,-0.023086548,0.031311035,-0.00207901,0.025939941,0.034301758,-0.020019531,0.035858154,0.024459839,-0.037261963,0.0036373138,-0.113464355,0.04550171,0.09399414,0.00434494,-0.0022888184,-0.0023975372,0.00069379807,0.03881836,0.010734558,0.0231781,0.011917114,-0.004753113,0.048065186,-0.0033664703,-0.04071045,0.040100098,0.04623413,-0.009056091,0.015823364,0.015655518,0.03225708,0.018844604,0.022842407,-0.054229736,0.0134887695,0.023239136,0.0012168884,0.016662598,-0.016220093,-0.0001872778,-0.033081055,-0.009269714,0.017166138,-0.025543213,0.00058317184,0.053955078,0.020004272,-0.041900635,0.040252686,0.024597168,-0.018569946,-0.026916504,-0.003118515,-0.03378296,-0.036956787,0.0007176399,-0.0036811829,0.0057411194,-0.010726929,-0.0050582886,-0.008331299,-0.024002075,-0.005935669,-0.014045715,-0.02670288,-0.05441284,0.00868988,0.03253174,-0.019470215,0.012077332,-0.015716553,-0.023498535,0.046722412,-0.00223732,0.014701843,0.046569824,0.02130127,0.0059509277,0.039245605,-0.038238525,0.002483368,-0.000320673,-0.0039863586,0.024597168,-0.033325195,0.0038280487,-0.0063056946,-0.038848877,0.026031494,-0.0055656433,0.013366699,0.02255249,-0.027236938,0.016662598,-0.026382446,0.016571045,-0.045410156,-0.0032367706,-0.009536743,0.028411865,-0.020843506,0.0005159378,-0.020065308,0.0014324188,-0.001033783,-0.018493652,0.015701294,0.043151855,0.0014209747,0.059936523,-0.010025024,-0.0619812,-0.022003174,0.01725769,0.027862549,-0.026245117,-0.046875,-0.0019245148,0.014213562,0.025146484,0.00017046928,0.007987976,0.0004911423,-0.023239136,-0.019332886,0.0115356445,0.024932861,-0.010871887,0.019638062,-0.048065186,0.03125,-0.026992798,-0.034423828,0.020874023,0.03781128,-0.021972656,-0.018920898,-0.016494751,-0.058166504,-0.043060303,-0.022399902,0.0071754456,0.021484375,-0.01625061,0.0060768127,0.016693115,0.025009155,-0.03717041,-0.017196655,-0.0284729,0.042266846,0.0390625,-0.04055786,-0.025756836,0.015975952,0.01737976,0.005683899,0.014930725,-0.014312744,-0.033172607,0.007320404,0.0043945312,-0.007408142,0.021087646,-0.058563232,-0.05758667,0.024673462,-0.04815674,0.0047683716,-0.091552734,-0.018310547,0.022079468,0.010032654,-0.030838013,0.01928711,-0.0055618286,-0.015357971,0.04107666,-0.025848389,-0.015174866,0.055389404,-0.020843506,-0.0084991455,-0.0027122498,-0.018325806,-0.027389526,-0.021972656,-0.015052795,-0.04067993,0.009254456,-0.014015198,-0.020248413,0.0093688965,0.021850586,0.00982666,0.02331543,-0.010917664,-0.03414917,-0.0050239563,-0.036102295,0.032104492,-0.08239746,-0.025863647,-0.035614014,-0.019088745,0.011932373,-0.004131317,0.035369873,-0.017211914,-0.009117126,-0.0013742447,-0.006378174,0.017227173,-0.0096206665,0.04055786,0.005619049,-0.0052452087,-0.062805176,-0.015434265,0.0077056885]]},"meta":{"api_version":{"version":"2"},"billed_units":{"input_tokens":3}},"response_type":"embeddings_by_type"}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '12970' + content-type: + - application/json + date: + - Tue, 07 Apr 2026 22:59:10 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + num_chars: + - '18' + num_tokens: + - '3' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - 697b2b0a29abbda3a97ef5521ce33db2 + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '391' + x-trial-endpoint-call-limit: + - '100' + x-trial-endpoint-call-remaining: + - '99' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_rerank.yaml b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_rerank.yaml new file mode 100644 index 00000000..d8f41103 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/test_wrap_cohere_v2_rerank.yaml @@ -0,0 +1,76 @@ +interactions: +- request: + body: '{"model":"rerank-v3.5","query":"What is the capital of the United States?","documents":["Carson + City is the capital city of Nevada.","Washington, D.C. is the capital of the + United States."],"top_n":2}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '200' + Host: + - api.cohere.com + User-Agent: + - cohere/5.21.1 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.13.3 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 5.21.1 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/rerank + response: + body: + string: '{"id":"d59e58e1-d871-435f-89df-1d9ed2b16580","results":[{"index":1,"relevance_score":0.9065293},{"index":0,"relevance_score":0.15399423}],"meta":{"api_version":{"version":"2"},"billed_units":{"search_units":1}}}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '211' + content-type: + - application/json + date: + - Tue, 07 Apr 2026 22:59:11 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - da2cf47610759fcc304c2ae152bda9c1 + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '60' + x-trial-endpoint-call-limit: + - '10' + x-trial-endpoint-call-remaining: + - '9' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/integration.py b/py/src/braintrust/integrations/cohere/integration.py new file mode 100644 index 00000000..7ffd34cf --- /dev/null +++ b/py/src/braintrust/integrations/cohere/integration.py @@ -0,0 +1,18 @@ +"""Cohere integration orchestration.""" + +from braintrust.integrations.base import BaseIntegration + +from .patchers import V2ChatPatcher, V2EmbedPatcher, V2RerankPatcher + + +class CohereIntegration(BaseIntegration): + """Braintrust instrumentation for the Cohere Python SDK.""" + + name = "cohere" + import_names = ("cohere",) + min_version = "5.10.0" + patchers = ( + V2ChatPatcher, + V2EmbedPatcher, + V2RerankPatcher, + ) diff --git a/py/src/braintrust/integrations/cohere/patchers.py b/py/src/braintrust/integrations/cohere/patchers.py new file mode 100644 index 00000000..2ec217d4 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/patchers.py @@ -0,0 +1,90 @@ +"""Cohere patchers.""" + +from braintrust.integrations.base import CompositeFunctionWrapperPatcher, FunctionWrapperPatcher + +from .tracing import ( + _v2_chat_async_wrapper, + _v2_chat_stream_async_wrapper, + _v2_chat_stream_wrapper, + _v2_chat_wrapper, + _v2_embed_async_wrapper, + _v2_embed_wrapper, + _v2_rerank_async_wrapper, + _v2_rerank_wrapper, +) + + +class _V2ChatPatcher(FunctionWrapperPatcher): + name = "cohere.v2.chat" + target_module = "cohere.v2.client" + target_path = "V2Client.chat" + wrapper = _v2_chat_wrapper + + +class _V2ChatAsyncPatcher(FunctionWrapperPatcher): + name = "cohere.v2.chat_async" + target_module = "cohere.v2.client" + target_path = "AsyncV2Client.chat" + wrapper = _v2_chat_async_wrapper + + +class _V2ChatStreamPatcher(FunctionWrapperPatcher): + name = "cohere.v2.chat_stream" + target_module = "cohere.v2.client" + target_path = "V2Client.chat_stream" + wrapper = _v2_chat_stream_wrapper + + +class _V2ChatStreamAsyncPatcher(FunctionWrapperPatcher): + name = "cohere.v2.chat_stream_async" + target_module = "cohere.v2.client" + target_path = "AsyncV2Client.chat_stream" + wrapper = _v2_chat_stream_async_wrapper + + +class V2ChatPatcher(CompositeFunctionWrapperPatcher): + name = "cohere.v2_chat" + sub_patchers = ( + _V2ChatPatcher, + _V2ChatAsyncPatcher, + _V2ChatStreamPatcher, + _V2ChatStreamAsyncPatcher, + ) + + +class _V2EmbedPatcher(FunctionWrapperPatcher): + name = "cohere.v2.embed" + target_module = "cohere.v2.client" + target_path = "V2Client.embed" + wrapper = _v2_embed_wrapper + + +class _V2EmbedAsyncPatcher(FunctionWrapperPatcher): + name = "cohere.v2.embed_async" + target_module = "cohere.v2.client" + target_path = "AsyncV2Client.embed" + wrapper = _v2_embed_async_wrapper + + +class V2EmbedPatcher(CompositeFunctionWrapperPatcher): + name = "cohere.v2_embed" + sub_patchers = (_V2EmbedPatcher, _V2EmbedAsyncPatcher) + + +class _V2RerankPatcher(FunctionWrapperPatcher): + name = "cohere.v2.rerank" + target_module = "cohere.v2.client" + target_path = "V2Client.rerank" + wrapper = _v2_rerank_wrapper + + +class _V2RerankAsyncPatcher(FunctionWrapperPatcher): + name = "cohere.v2.rerank_async" + target_module = "cohere.v2.client" + target_path = "AsyncV2Client.rerank" + wrapper = _v2_rerank_async_wrapper + + +class V2RerankPatcher(CompositeFunctionWrapperPatcher): + name = "cohere.v2_rerank" + sub_patchers = (_V2RerankPatcher, _V2RerankAsyncPatcher) diff --git a/py/src/braintrust/integrations/cohere/test_cohere.py b/py/src/braintrust/integrations/cohere/test_cohere.py new file mode 100644 index 00000000..3447f26e --- /dev/null +++ b/py/src/braintrust/integrations/cohere/test_cohere.py @@ -0,0 +1,299 @@ +import inspect +import os +import time +from pathlib import Path + +import pytest +from braintrust import logger +from braintrust.integrations.cohere import CohereIntegration, wrap_cohere +from braintrust.integrations.cohere.tracing import _v2_chat_async_wrapper, _v2_chat_wrapper +from braintrust.test_helpers import init_test_logger +from braintrust.wrappers.test_utils import assert_metrics_are_valid, verify_autoinstrument_script + + +pytest.importorskip("cohere") +from cohere import AsyncClientV2, Client, ClientV2 +from cohere.v2.client import V2Client + + +PROJECT_NAME = "test-cohere-sdk" +CHAT_MODEL = "command-r-plus-08-2024" +EMBED_MODEL = "embed-english-v3.0" +RERANK_MODEL = "rerank-v3.5" + + +@pytest.fixture(scope="module") +def vcr_cassette_dir(): + return str(Path(__file__).resolve().parent / "cassettes") + + +@pytest.fixture +def memory_logger(): + init_test_logger(PROJECT_NAME) + with logger._internal_with_memory_background_logger() as bgl: + yield bgl + + +def _get_client(): + return Client(api_key=os.environ.get("CO_API_KEY")) + + +def _get_client_v2(): + return ClientV2(api_key=os.environ.get("CO_API_KEY")) + + +async def _get_async_client_v2(): + return AsyncClientV2(api_key=os.environ.get("CO_API_KEY")) + + +@pytest.mark.vcr +def test_wrap_cohere_nested_v2_chat_sync(memory_logger): + assert not memory_logger.pop() + + client = wrap_cohere(_get_client()) + start = time.time() + response = client.v2.chat( + model=CHAT_MODEL, + messages=[{"role": "user", "content": "What is 2+2? Reply with just the number."}], + max_tokens=10, + ) + end = time.time() + + assert response.message.content[0].text == "4" + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["input"]["messages"] == [{"role": "user", "content": "What is 2+2? Reply with just the number."}] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert span["metadata"]["model"] == CHAT_MODEL + assert span["output"]["content"][0]["text"] == "4" + assert_metrics_are_valid(span["metrics"], start, end) + + +@pytest.mark.vcr +def test_wrap_cohere_client_v2_chat_sync(memory_logger): + assert not memory_logger.pop() + + client = wrap_cohere(_get_client_v2()) + start = time.time() + response = client.chat( + model=CHAT_MODEL, + messages=[{"role": "user", "content": "What is 5+5? Reply with just the number."}], + max_tokens=10, + ) + end = time.time() + + assert response.message.content[0].text == "10" + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert span["metadata"]["model"] == CHAT_MODEL + assert span["output"]["content"][0]["text"] == "10" + assert_metrics_are_valid(span["metrics"], start, end) + + +@pytest.mark.vcr +@pytest.mark.asyncio +async def test_wrap_cohere_v2_chat_stream_async(memory_logger): + assert not memory_logger.pop() + + client = wrap_cohere(await _get_async_client_v2()) + start = time.time() + stream = client.chat_stream( + model=CHAT_MODEL, + messages=[{"role": "user", "content": "What is 8+8? Reply with just the number."}], + max_tokens=10, + ) + chunks = [] + async for chunk in stream: + chunks.append(chunk) + end = time.time() + + assert chunks + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert span["metadata"]["model"] == CHAT_MODEL + assert span["metadata"]["stream"] == True + assert span["metrics"]["time_to_first_token"] >= 0 + assert span["output"]["content"][0]["text"] == "16" + assert_metrics_are_valid(span["metrics"], start, end) + + +@pytest.mark.vcr +def test_wrap_cohere_v2_embed(memory_logger): + assert not memory_logger.pop() + + client = wrap_cohere(_get_client_v2()) + start = time.time() + response = client.embed( + model=EMBED_MODEL, + texts=["braintrust tracing"], + input_type="search_query", + embedding_types=["float"], + ) + end = time.time() + + assert response.embeddings.float_ + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert span["metadata"]["model"] == EMBED_MODEL + assert span["output"]["embeddings_count"] == 1 + assert span["output"]["embedding_length"] == len(response.embeddings.float_[0]) + assert span["output"]["embedding_types"] == ["float"] + assert span["metrics"]["prompt_tokens"] > 0 + assert span["metrics"]["tokens"] > 0 + assert start <= span["metrics"]["start"] <= span["metrics"]["end"] <= end + + +@pytest.mark.vcr +def test_wrap_cohere_v2_rerank(memory_logger): + assert not memory_logger.pop() + + client = wrap_cohere(_get_client_v2()) + start = time.time() + response = client.rerank( + model=RERANK_MODEL, + query="What is the capital of the United States?", + documents=[ + "Carson City is the capital city of Nevada.", + "Washington, D.C. is the capital of the United States.", + ], + top_n=2, + ) + end = time.time() + + assert response.results[0].index == 1 + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert span["metadata"]["model"] == RERANK_MODEL + assert span["output"][0]["index"] == 1 + assert span["metrics"]["billed_search_units"] > 0 + assert start <= span["metrics"]["start"] <= span["metrics"]["end"] <= end + + +@pytest.mark.vcr +def test_cohere_integration_setup_creates_spans(memory_logger, monkeypatch): + assert not memory_logger.pop() + + original_chat = inspect.getattr_static(V2Client, "chat") + original_embed = inspect.getattr_static(V2Client, "embed") + original_rerank = inspect.getattr_static(V2Client, "rerank") + + assert CohereIntegration.setup() + client = _get_client_v2() + start = time.time() + response = client.chat( + model=CHAT_MODEL, + messages=[{"role": "user", "content": "What is 4+4? Reply with just the number."}], + max_tokens=10, + ) + end = time.time() + + monkeypatch.setattr(V2Client, "chat", original_chat) + monkeypatch.setattr(V2Client, "embed", original_embed) + monkeypatch.setattr(V2Client, "rerank", original_rerank) + + assert response.message.content[0].text == "8" + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert span["output"]["content"][0]["text"] == "8" + assert_metrics_are_valid(span["metrics"], start, end) + + +def test_cohere_integration_setup_is_idempotent(monkeypatch): + first_chat = inspect.getattr_static(V2Client, "chat") + first_embed = inspect.getattr_static(V2Client, "embed") + first_rerank = inspect.getattr_static(V2Client, "rerank") + + assert CohereIntegration.setup() + patched_chat = inspect.getattr_static(V2Client, "chat") + patched_embed = inspect.getattr_static(V2Client, "embed") + patched_rerank = inspect.getattr_static(V2Client, "rerank") + + assert CohereIntegration.setup() + assert inspect.getattr_static(V2Client, "chat") is patched_chat + assert inspect.getattr_static(V2Client, "embed") is patched_embed + assert inspect.getattr_static(V2Client, "rerank") is patched_rerank + + monkeypatch.setattr(V2Client, "chat", first_chat) + monkeypatch.setattr(V2Client, "embed", first_embed) + monkeypatch.setattr(V2Client, "rerank", first_rerank) + + +def test_v2_chat_wrapper_logs_errors(memory_logger): + assert not memory_logger.pop() + + def fail(*args, **kwargs): + raise RuntimeError("sync boom") + + with pytest.raises(RuntimeError, match="sync boom"): + _v2_chat_wrapper( + fail, + None, + (), + { + "model": CHAT_MODEL, + "messages": [{"role": "user", "content": "hello"}], + }, + ) + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["input"]["messages"] == [{"role": "user", "content": "hello"}] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert "sync boom" in span["error"] + + +@pytest.mark.asyncio +async def test_v2_chat_async_wrapper_logs_errors(memory_logger): + assert not memory_logger.pop() + + async def fail(*args, **kwargs): + raise RuntimeError("async boom") + + with pytest.raises(RuntimeError, match="async boom"): + await _v2_chat_async_wrapper( + fail, + None, + (), + { + "model": CHAT_MODEL, + "messages": [{"role": "user", "content": "hello"}], + }, + ) + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["input"]["messages"] == [{"role": "user", "content": "hello"}] + assert span["metadata"]["provider"] == "cohere" + assert span["metadata"]["api_version"] == "2" + assert "async boom" in span["error"] + + +class TestAutoInstrumentCohere: + def test_auto_instrument_cohere(self): + verify_autoinstrument_script("test_auto_cohere.py") diff --git a/py/src/braintrust/integrations/cohere/tracing.py b/py/src/braintrust/integrations/cohere/tracing.py new file mode 100644 index 00000000..6b37d36f --- /dev/null +++ b/py/src/braintrust/integrations/cohere/tracing.py @@ -0,0 +1,680 @@ +"""Cohere-specific tracing helpers.""" + +import time +from collections.abc import AsyncIterator, Iterator +from numbers import Real +from typing import Any + +from braintrust.bt_json import bt_safe_deep_copy +from braintrust.logger import start_span +from braintrust.span_types import SpanTypeAttribute + + +_V2_CHAT_METADATA_KEYS = ( + "model", + "documents", + "tools", + "citation_options", + "response_format", + "safety_mode", + "max_tokens", + "stop_sequences", + "temperature", + "seed", + "frequency_penalty", + "presence_penalty", + "k", + "p", +) +_V2_EMBED_METADATA_KEYS = ( + "model", + "input_type", + "embedding_types", + "truncate", + "images", +) +_V2_RERANK_METADATA_KEYS = ( + "model", + "top_n", + "max_tokens_per_doc", +) + + +def sanitize_cohere_logged_value(value: Any) -> Any: + if hasattr(value, "model_dump"): + try: + value = value.model_dump(mode="json", by_alias=True) + except TypeError: + value = value.model_dump() + + safe = bt_safe_deep_copy(value) + + if callable(safe): + return "[Function]" + if isinstance(safe, list): + return [sanitize_cohere_logged_value(item) for item in safe] + if isinstance(safe, tuple): + return [sanitize_cohere_logged_value(item) for item in safe] + if isinstance(safe, dict): + sanitized = {} + for key, entry in safe.items(): + if entry is None: + continue + sanitized[key] = sanitize_cohere_logged_value(entry) + return sanitized + return safe + + +def _is_supported_metric_value(value: Any) -> bool: + return isinstance(value, Real) and not isinstance(value, bool) + + +def _timing_metrics(start_time: float, first_token_time: float | None = None) -> dict[str, float]: + end_time = time.time() + metrics = { + "start": start_time, + "end": end_time, + "duration": end_time - start_time, + } + if first_token_time is not None: + metrics["time_to_first_token"] = first_token_time - start_time + return metrics + + +def _normalize_usage_tokens(data: dict[str, Any], metrics: dict[str, float]) -> None: + input_tokens = data.get("input_tokens") + output_tokens = data.get("output_tokens") + + if _is_supported_metric_value(input_tokens): + metrics["prompt_tokens"] = float(input_tokens) + if _is_supported_metric_value(output_tokens): + metrics["completion_tokens"] = float(output_tokens) + + +def _metrics_from_usage_like(usage_or_meta: Any) -> dict[str, float]: + data = sanitize_cohere_logged_value(usage_or_meta) + if not isinstance(data, dict): + return {} + + metrics: dict[str, float] = {} + + tokens = data.get("tokens") + if isinstance(tokens, dict): + _normalize_usage_tokens(tokens, metrics) + + billed_units = data.get("billed_units") + if isinstance(billed_units, dict): + if "prompt_tokens" not in metrics or "completion_tokens" not in metrics: + _normalize_usage_tokens(billed_units, metrics) + for key, value in billed_units.items(): + if _is_supported_metric_value(value): + metrics[f"billed_{key}"] = float(value) + + cached_tokens = data.get("cached_tokens") + if _is_supported_metric_value(cached_tokens): + metrics["cached_tokens"] = float(cached_tokens) + + if "tokens" not in metrics: + if "prompt_tokens" in metrics and "completion_tokens" in metrics: + metrics["tokens"] = metrics["prompt_tokens"] + metrics["completion_tokens"] + elif "prompt_tokens" in metrics: + metrics["tokens"] = metrics["prompt_tokens"] + elif "completion_tokens" in metrics: + metrics["tokens"] = metrics["completion_tokens"] + + return metrics + + +def _merge_metrics(start_time: float, usage_or_meta: Any, first_token_time: float | None = None) -> dict[str, float]: + return { + **_timing_metrics(start_time, first_token_time), + **_metrics_from_usage_like(usage_or_meta), + } + + +def _build_metadata( + kwargs: dict[str, Any], + keys: tuple[str, ...], + *, + stream: bool | None = None, +) -> dict[str, Any]: + metadata = { + "provider": "cohere", + "api_version": "2", + } + + for key in keys: + value = kwargs.get(key) + if value is None: + continue + metadata[key] = sanitize_cohere_logged_value(value) + + if stream is not None: + metadata["stream"] = stream + + return metadata + + +def _chat_input(kwargs: dict[str, Any]) -> dict[str, Any]: + span_input = { + "messages": sanitize_cohere_logged_value(kwargs.get("messages")), + } + for key in ("documents", "tools"): + value = kwargs.get(key) + if value is not None: + span_input[key] = sanitize_cohere_logged_value(value) + return span_input + + +def _embed_input(kwargs: dict[str, Any]) -> Any: + span_input = {} + for key in ("texts", "images", "inputs"): + value = kwargs.get(key) + if value is not None: + span_input[key] = sanitize_cohere_logged_value(value) + if len(span_input) == 1: + return next(iter(span_input.values())) + return span_input + + +def _rerank_input(kwargs: dict[str, Any]) -> dict[str, Any]: + return { + "query": kwargs.get("query"), + "documents": sanitize_cohere_logged_value(kwargs.get("documents")), + } + + +def _start_span(name: str, span_input: Any, metadata: dict[str, Any]): + return start_span( + name=name, + type=SpanTypeAttribute.LLM, + input=sanitize_cohere_logged_value(span_input), + metadata=metadata, + ) + + +def _response_metadata(response: Any, *, finish_reason: Any | None = None) -> dict[str, Any]: + data = sanitize_cohere_logged_value(response) + if not isinstance(data, dict): + return {} + + metadata = {} + if data.get("id") is not None: + metadata["id"] = data["id"] + if data.get("finish_reason") is not None: + metadata["finish_reason"] = data["finish_reason"] + elif finish_reason is not None: + metadata["finish_reason"] = finish_reason + + meta = data.get("meta") + if isinstance(meta, dict): + warnings = meta.get("warnings") + if warnings: + metadata["warnings"] = warnings + + return metadata + + +def _chat_output(response: Any) -> Any: + data = sanitize_cohere_logged_value(response) + if not isinstance(data, dict): + return data + return data.get("message") or data + + +def _embed_output(response: Any) -> dict[str, Any]: + data = sanitize_cohere_logged_value(response) + if not isinstance(data, dict): + return {"embeddings_count": 0, "embedding_length": None} + + output = { + "embeddings_count": 0, + "embedding_length": None, + } + embeddings = data.get("embeddings") + if isinstance(embeddings, dict): + embedding_types = [] + for key, value in embeddings.items(): + if not value: + continue + embedding_types.append(key.rstrip("_")) + if output["embeddings_count"] == 0 and isinstance(value, list): + output["embeddings_count"] = len(value) + first = value[0] if value else None + if isinstance(first, list): + output["embedding_length"] = len(first) + if embedding_types: + output["embedding_types"] = embedding_types + return output + + +def _rerank_output(response: Any) -> Any: + data = sanitize_cohere_logged_value(response) + if not isinstance(data, dict): + return data + return data.get("results") or data + + +def _log_and_end( + span: Any, + *, + output: Any = None, + metrics: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, +): + event = {} + if output is not None: + event["output"] = output + if metrics: + event["metrics"] = metrics + if metadata: + event["metadata"] = metadata + if event: + span.log(**event) + span.end() + + +def _log_error_and_end(span: Any, error: Exception): + span.log(error=error) + span.end() + + +def _call_with_error_logging(span: Any, wrapped: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + try: + return wrapped(*args, **kwargs) + except Exception as error: + _log_error_and_end(span, error) + raise + + +async def _call_async_with_error_logging( + span: Any, + wrapped: Any, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + try: + return await wrapped(*args, **kwargs) + except Exception as error: + _log_error_and_end(span, error) + raise + + +def _finalize_response( + span: Any, + *, + output: Any, + response: Any, + request_metadata: dict[str, Any], + start_time: float, + usage_or_meta: Any, +): + _log_and_end( + span, + output=output, + metrics=_merge_metrics(start_time, usage_or_meta), + metadata={ + **request_metadata, + **_response_metadata(response), + }, + ) + + +def _append_text_content(content_by_index: dict[int, dict[str, Any]], index: int, delta: dict[str, Any]) -> None: + item = content_by_index.setdefault(index, {"type": "text", "text": ""}) + if delta.get("type") is not None: + item["type"] = delta["type"] + text = delta.get("text") + if isinstance(text, str): + item["text"] = f"{item.get('text', '')}{text}" + + +def _merge_tool_call(target: dict[str, Any], delta: dict[str, Any]) -> None: + for key in ("id", "type"): + value = delta.get(key) + if value is not None: + target[key] = value + + function = delta.get("function") + if not isinstance(function, dict): + return + + target_function = target.setdefault("function", {"name": "", "arguments": ""}) + name = function.get("name") + if isinstance(name, str): + target_function["name"] = f"{target_function.get('name', '')}{name}" + arguments = function.get("arguments") + if isinstance(arguments, str): + target_function["arguments"] = f"{target_function.get('arguments', '')}{arguments}" + + +def _aggregate_chat_stream(chunks: list[Any]) -> tuple[dict[str, Any], Any, dict[str, Any]]: + message = {"role": "assistant", "content": []} + content_by_index: dict[int, dict[str, Any]] = {} + tool_calls: dict[int, dict[str, Any]] = {} + response_id = None + finish_reason = None + usage = None + + for chunk in chunks: + data = sanitize_cohere_logged_value(chunk) + if not isinstance(data, dict): + continue + + chunk_type = data.get("type") + if chunk_type == "message-start": + response_id = data.get("id") or response_id + delta_message = (data.get("delta") or {}).get("message") or {} + role = delta_message.get("role") + if role is not None: + message["role"] = role + elif chunk_type == "content-start": + index = int(data.get("index", 0) or 0) + content = ((data.get("delta") or {}).get("message") or {}).get("content") or {} + if isinstance(content, dict): + _append_text_content(content_by_index, index, content) + elif chunk_type == "content-delta": + index = int(data.get("index", 0) or 0) + content = ((data.get("delta") or {}).get("message") or {}).get("content") or {} + if isinstance(content, dict): + _append_text_content(content_by_index, index, content) + elif chunk_type == "tool-plan-delta": + delta_message = (data.get("delta") or {}).get("message") or {} + tool_plan = delta_message.get("tool_plan") + if isinstance(tool_plan, str): + message["tool_plan"] = f"{message.get('tool_plan', '')}{tool_plan}" + elif chunk_type in ("tool-call-start", "tool-call-delta"): + index = int(data.get("index", 0) or 0) + tool_call = tool_calls.setdefault(index, {"function": {"name": "", "arguments": ""}}) + delta_tool_call = ((data.get("delta") or {}).get("message") or {}).get("tool_calls") or {} + if isinstance(delta_tool_call, dict): + _merge_tool_call(tool_call, delta_tool_call) + elif chunk_type == "message-end": + delta = data.get("delta") or {} + finish_reason = delta.get("finish_reason") or finish_reason + usage = delta.get("usage") or usage + + if content_by_index: + message["content"] = [content_by_index[index] for index in sorted(content_by_index)] + if tool_calls: + message["tool_calls"] = [tool_calls[index] for index in sorted(tool_calls)] + + metadata = {} + if response_id is not None: + metadata["id"] = response_id + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + + return message, usage, metadata + + +class _TracedSyncChatStream: + def __init__(self, stream: Any, span: Any, metadata: dict[str, Any], start_time: float): + self._stream = stream + self._span = span + self._metadata = metadata + self._start_time = start_time + self._first_token_time = None + self._items: list[Any] = [] + self._closed = False + + def __iter__(self) -> Iterator[Any]: + return self + + def __next__(self) -> Any: + try: + item = next(self._stream) + except StopIteration: + self._finalize() + raise + except Exception as error: + self._finalize(error=error) + raise + + if self._first_token_time is None and getattr(item, "type", None) in ( + "content-delta", + "tool-plan-delta", + "tool-call-start", + "tool-call-delta", + ): + self._first_token_time = time.time() + self._items.append(item) + return item + + def _finalize(self, *, error: Exception | None = None): + if self._closed: + return + self._closed = True + + if error is not None: + _log_error_and_end(self._span, error) + return + + output, usage, response_metadata = _aggregate_chat_stream(self._items) + _log_and_end( + self._span, + output=output, + metrics=_merge_metrics(self._start_time, usage, self._first_token_time), + metadata={**self._metadata, **response_metadata}, + ) + + +class _TracedAsyncChatStream: + def __init__(self, stream: Any, span: Any, metadata: dict[str, Any], start_time: float): + self._stream = stream + self._span = span + self._metadata = metadata + self._start_time = start_time + self._first_token_time = None + self._items: list[Any] = [] + self._closed = False + + def __aiter__(self) -> AsyncIterator[Any]: + return self + + async def __anext__(self) -> Any: + try: + item = await self._stream.__anext__() + except StopAsyncIteration: + self._finalize() + raise + except Exception as error: + self._finalize(error=error) + raise + + if self._first_token_time is None and getattr(item, "type", None) in ( + "content-delta", + "tool-plan-delta", + "tool-call-start", + "tool-call-delta", + ): + self._first_token_time = time.time() + self._items.append(item) + return item + + def _finalize(self, *, error: Exception | None = None): + if self._closed: + return + self._closed = True + + if error is not None: + _log_error_and_end(self._span, error) + return + + output, usage, response_metadata = _aggregate_chat_stream(self._items) + _log_and_end( + self._span, + output=output, + metrics=_merge_metrics(self._start_time, usage, self._first_token_time), + metadata={**self._metadata, **response_metadata}, + ) + + +def _v2_chat_wrapper(wrapped, instance, args, kwargs): + request_metadata = _build_metadata(kwargs, _V2_CHAT_METADATA_KEYS) + span = _start_span("cohere.chat", _chat_input(kwargs), request_metadata) + start_time = time.time() + result = _call_with_error_logging(span, wrapped, args, kwargs) + _finalize_response( + span, + output=_chat_output(result), + response=result, + request_metadata=request_metadata, + start_time=start_time, + usage_or_meta=getattr(result, "usage", None), + ) + return result + + +async def _v2_chat_async_wrapper(wrapped, instance, args, kwargs): + request_metadata = _build_metadata(kwargs, _V2_CHAT_METADATA_KEYS) + span = _start_span("cohere.chat", _chat_input(kwargs), request_metadata) + start_time = time.time() + result = await _call_async_with_error_logging(span, wrapped, args, kwargs) + _finalize_response( + span, + output=_chat_output(result), + response=result, + request_metadata=request_metadata, + start_time=start_time, + usage_or_meta=getattr(result, "usage", None), + ) + return result + + +def _v2_chat_stream_wrapper(wrapped, instance, args, kwargs): + request_metadata = _build_metadata(kwargs, _V2_CHAT_METADATA_KEYS, stream=True) + span = _start_span("cohere.chat", _chat_input(kwargs), request_metadata) + start_time = time.time() + result = _call_with_error_logging(span, wrapped, args, kwargs) + return _TracedSyncChatStream(result, span, request_metadata, start_time) + + +def _v2_chat_stream_async_wrapper(wrapped, instance, args, kwargs): + request_metadata = _build_metadata(kwargs, _V2_CHAT_METADATA_KEYS, stream=True) + span = _start_span("cohere.chat", _chat_input(kwargs), request_metadata) + start_time = time.time() + try: + result = wrapped(*args, **kwargs) + except Exception as error: + _log_error_and_end(span, error) + raise + return _TracedAsyncChatStream(result, span, request_metadata, start_time) + + +def _v2_embed_wrapper(wrapped, instance, args, kwargs): + request_metadata = _build_metadata(kwargs, _V2_EMBED_METADATA_KEYS) + span = _start_span("cohere.embed", _embed_input(kwargs), request_metadata) + start_time = time.time() + result = _call_with_error_logging(span, wrapped, args, kwargs) + _finalize_response( + span, + output=_embed_output(result), + response=result, + request_metadata=request_metadata, + start_time=start_time, + usage_or_meta=getattr(result, "meta", None), + ) + return result + + +async def _v2_embed_async_wrapper(wrapped, instance, args, kwargs): + request_metadata = _build_metadata(kwargs, _V2_EMBED_METADATA_KEYS) + span = _start_span("cohere.embed", _embed_input(kwargs), request_metadata) + start_time = time.time() + result = await _call_async_with_error_logging(span, wrapped, args, kwargs) + _finalize_response( + span, + output=_embed_output(result), + response=result, + request_metadata=request_metadata, + start_time=start_time, + usage_or_meta=getattr(result, "meta", None), + ) + return result + + +def _v2_rerank_wrapper(wrapped, instance, args, kwargs): + request_metadata = _build_metadata(kwargs, _V2_RERANK_METADATA_KEYS) + span = _start_span("cohere.rerank", _rerank_input(kwargs), request_metadata) + start_time = time.time() + result = _call_with_error_logging(span, wrapped, args, kwargs) + _finalize_response( + span, + output=_rerank_output(result), + response=result, + request_metadata=request_metadata, + start_time=start_time, + usage_or_meta=getattr(result, "meta", None), + ) + return result + + +async def _v2_rerank_async_wrapper(wrapped, instance, args, kwargs): + request_metadata = _build_metadata(kwargs, _V2_RERANK_METADATA_KEYS) + span = _start_span("cohere.rerank", _rerank_input(kwargs), request_metadata) + start_time = time.time() + result = await _call_async_with_error_logging(span, wrapped, args, kwargs) + _finalize_response( + span, + output=_rerank_output(result), + response=result, + request_metadata=request_metadata, + start_time=start_time, + usage_or_meta=getattr(result, "meta", None), + ) + return result + + +def _is_v2_client_instance(client: Any) -> bool: + return any( + base.__name__ in {"ClientV2", "AsyncClientV2", "V2Client", "AsyncV2Client"} for base in type(client).__mro__ + ) + + +def _is_async_v2_client_instance(client: Any) -> bool: + return any(base.__name__ in {"AsyncClientV2", "AsyncV2Client"} for base in type(client).__mro__) + + +def _wrap_v2_target(target: Any) -> None: + from .patchers import ( + _V2ChatAsyncPatcher, + _V2ChatPatcher, + _V2ChatStreamAsyncPatcher, + _V2ChatStreamPatcher, + _V2EmbedAsyncPatcher, + _V2EmbedPatcher, + _V2RerankAsyncPatcher, + _V2RerankPatcher, + ) + + if _is_async_v2_client_instance(target): + patchers = ( + _V2ChatAsyncPatcher, + _V2ChatStreamAsyncPatcher, + _V2EmbedAsyncPatcher, + _V2RerankAsyncPatcher, + ) + else: + patchers = ( + _V2ChatPatcher, + _V2ChatStreamPatcher, + _V2EmbedPatcher, + _V2RerankPatcher, + ) + + for patcher in patchers: + patcher.wrap_target(target) + + +def wrap_cohere(client: Any) -> Any: + """Wrap a single Cohere client or Cohere V2 client instance for tracing.""" + if _is_v2_client_instance(client): + _wrap_v2_target(client) + return client + + nested_v2 = getattr(client, "v2", None) + if nested_v2 is not None: + _wrap_v2_target(nested_v2) + + return client