Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/cohere/client.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import asyncio
import logging
import os
import typing
from concurrent.futures import ThreadPoolExecutor
from tokenizers import Tokenizer # type: ignore
import logging

import httpx

from cohere.types.detokenize_response import DetokenizeResponse
from cohere.types.tokenize_response import TokenizeResponse

from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
from .base_client import BaseCohere, AsyncBaseCohere, OMIT
from . import EmbeddingType, EmbedInputType, EmbedRequestTruncate, EmbedResponse
from .base_client import OMIT, AsyncBaseCohere, BaseCohere
from .config import embed_batch_size
from .core import RequestOptions
from .environment import ClientEnvironment
from .manually_maintained.cache import CacheMixin
from .manually_maintained import tokenizers as local_tokenizers
from .manually_maintained.cache import CacheMixin
from .overrides import run_overrides
from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils
from .utils import AsyncSdkUtils, SyncSdkUtils, async_wait, merge_embed_responses, wait
from tokenizers import Tokenizer # type: ignore

from cohere.types.detokenize_response import DetokenizeResponse
from cohere.types.tokenize_response import TokenizeResponse

logger = logging.getLogger(__name__)
run_overrides()
Expand Down Expand Up @@ -202,7 +201,7 @@ def embed(
request_options=request_options,
)

textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

responses = [
Expand Down Expand Up @@ -394,7 +393,7 @@ async def embed(
request_options=request_options,
)

textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

responses = typing.cast(
Expand Down Expand Up @@ -516,4 +515,7 @@ def _get_api_key_from_environment() -> typing.Optional[str]:
Retrieves the Cohere API key from specific environment variables.
CO_API_KEY is preferred (and documented) COHERE_API_KEY is accepted (but not documented).
"""
return os.getenv("CO_API_KEY", os.getenv("COHERE_API_KEY"))
api_key = os.environ.get("CO_API_KEY")
if api_key is not None:
return api_key
return os.environ.get("COHERE_API_KEY")