Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,15 @@

import httpx
from httpx import URL, SyncByteStream, ByteStream
from tokenizers import Tokenizer # type: ignore

from . import GenerateStreamedResponse, Generation, \
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
ApiMetaBilledUnits
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore


try:
import boto3 # type: ignore
from botocore.auth import SigV4Auth # type: ignore
from botocore.awsrequest import AWSRequest # type: ignore
AWS_DEPS_AVAILABLE = True
except ImportError:
AWS_DEPS_AVAILABLE = False

class AwsClient(Client):
def __init__(
self,
Expand All @@ -33,8 +25,6 @@ def __init__(
timeout: typing.Optional[float] = None,
service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
):
if not AWS_DEPS_AVAILABLE:
raise ImportError("AWS dependencies not available. Please install boto3 and botocore.")
Client.__init__(
self,
base_url="https://api.cohere.com", # this url is unused for BedrockClient
Expand Down Expand Up @@ -183,14 +173,14 @@ def map_request_to_bedrock(
aws_session_token: typing.Optional[str] = None,
aws_region: typing.Optional[str] = None,
) -> EventHook:
session = boto3.Session(
session = lazy_boto3().Session(
region_name=aws_region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
aws_session_token=aws_session_token,
)
credentials = session.get_credentials()
signer = SigV4Auth(credentials, service, session.region_name)
signer = lazy_botocore().auth.SigV4Auth(credentials, service, session.region_name)

def _event_hook(request: httpx.Request) -> None:
headers = request.headers.copy()
Expand Down Expand Up @@ -220,7 +210,7 @@ def _event_hook(request: httpx.Request) -> None:
request._content = new_body
headers["content-length"] = str(len(new_body))

aws_request = AWSRequest(
aws_request = lazy_botocore().awsrequest.AWSRequest(
method=request.method,
url=url,
headers=headers,
Expand Down
68 changes: 28 additions & 40 deletions src/cohere/manually_maintained/cohere_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,7 @@
from .summary import Summary
from .mode import Mode
import typing

# Try to import sagemaker and related modules
try:
import sagemaker as sage
from sagemaker.s3 import S3Downloader, S3Uploader, parse_s3_url
import boto3
from botocore.exceptions import (
ClientError, EndpointConnectionError, ParamValidationError)
AWS_DEPS_AVAILABLE = True
except ImportError:
AWS_DEPS_AVAILABLE = False
from ..lazy_aws_deps import lazy_boto3, lazy_botocore, lazy_sagemaker

class Client:
def __init__(
Expand All @@ -37,21 +27,19 @@ def __init__(
By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
`aws configure set region us-west-2` or override it with `region_name` parameter.
"""
if not AWS_DEPS_AVAILABLE:
raise CohereError("AWS dependencies not available. Please install boto3 and sagemaker.")
self._client = boto3.client("sagemaker-runtime", region_name=aws_region)
self._service_client = boto3.client("sagemaker", region_name=aws_region)
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
if os.environ.get('AWS_DEFAULT_REGION') is None:
os.environ['AWS_DEFAULT_REGION'] = aws_region
self._sess = sage.Session(sagemaker_client=self._service_client)
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
self.mode = Mode.SAGEMAKER



def _does_endpoint_exist(self, endpoint_name: str) -> bool:
try:
self._service_client.describe_endpoint(EndpointName=endpoint_name)
except ClientError:
except lazy_botocore().ClientError:
return False
return True

Expand Down Expand Up @@ -87,7 +75,7 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str:
# Links of all fine-tuned models in s3_models_dir. Their format should be .tar.gz
s3_tar_models = [
s3_path
for s3_path in S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)
for s3_path in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)
if (
s3_path.endswith(".tar.gz") # only .tar.gz files
and (s3_path.split("/")[-1] != "models.tar.gz") # exclude the .tar.gz file we are creating
Expand All @@ -109,7 +97,7 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str:
# Download and extract all fine-tuned models
for s3_tar_model in s3_tar_models:
print(f"Adding fine-tuned model: {s3_tar_model}")
S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess)
lazy_sagemaker().s3.S3Downloader.download(s3_tar_model, local_tar_models_dir, sagemaker_session=self._sess)
with tarfile.open(os.path.join(local_tar_models_dir, s3_tar_model.split("/")[-1])) as tar:
tar.extractall(local_models_dir)

Expand All @@ -120,10 +108,10 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str:

# Upload the new tarfile containing all models to s3
# Very important to remove the trailing slash from s3_models_dir otherwise it just doesn't upload
model_tar_s3 = S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess)
model_tar_s3 = lazy_sagemaker().s3.S3Uploader.upload(model_tar, s3_models_dir[:-1], sagemaker_session=self._sess)

# sanity check
assert s3_models_dir + "models.tar.gz" in S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)
assert s3_models_dir + "models.tar.gz" in lazy_sagemaker().s3.S3Downloader.list(s3_models_dir, sagemaker_session=self._sess)

return model_tar_s3

Expand Down Expand Up @@ -180,17 +168,17 @@ def create_endpoint(
# Otherwise it might block deployment
try:
self._service_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
except ClientError:
except lazy_botocore().ClientError:
pass

if role is None:
try:
role = sage.get_execution_role()
role = lazy_sagemaker().get_execution_role()
except ValueError:
print("Using default role: 'ServiceRoleSagemaker'.")
role = "ServiceRoleSagemaker"

model = sage.ModelPackage(
model = lazy_sagemaker().ModelPackage(
role=role,
model_data=model_data,
sagemaker_session=self._sess, # makes sure the right region is used
Expand All @@ -204,7 +192,7 @@ def create_endpoint(
endpoint_name=endpoint_name,
**validation_params
)
except ParamValidationError:
except lazy_botocore().ParamValidationError:
# For at least some versions of python 3.6, SageMaker SDK does not support the validation_params
model.deploy(n_instances, instance_type, endpoint_name=endpoint_name)
self.connect_to_endpoint(endpoint_name)
Expand Down Expand Up @@ -366,7 +354,7 @@ def _sagemaker_chat(self, json_params: Dict[str, Any], variant: str) :
else:
result = self._client.invoke_endpoint(**params)
return Chat.from_dict(json.loads(result['Body'].read().decode()))
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -398,7 +386,7 @@ def _bedrock_chat(self, json_params: Dict[str, Any], model_id: str) :
result = self._client.invoke_model(**params)
return Chat.from_dict(
json.loads(result['body'].read().decode()))
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -473,7 +461,7 @@ def _sagemaker_generations(self, json_params: Dict[str, Any], variant: str) :
result = self._client.invoke_endpoint(**params)
return Generations(
json.loads(result['Body'].read().decode())['generations'])
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand All @@ -498,7 +486,7 @@ def _bedrock_generations(self, json_params: Dict[str, Any], model_id: str) :
result = self._client.invoke_model(**params)
return Generations(
json.loads(result['body'].read().decode())['generations'])
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -546,7 +534,7 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
try:
result = self._client.invoke_endpoint(**params)
response = json.loads(result['Body'].read().decode())
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand All @@ -567,7 +555,7 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
try:
result = self._client.invoke_model(**params)
response = json.loads(result['body'].read().decode())
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -631,7 +619,7 @@ def rerank(self,
reranking = Reranking(response)
for rank in reranking.results:
rank.document = parsed_docs[rank.index]
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand All @@ -658,7 +646,7 @@ def classify(self, input: List[str], name: str) -> Classifications:
try:
result = self._client.invoke_endpoint(**params)
response = json.loads(result["Body"].read().decode())
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down Expand Up @@ -705,13 +693,13 @@ def create_finetune(

if role is None:
try:
role = sage.get_execution_role()
role = lazy_sagemaker().get_execution_role()
except ValueError:
print("Using default role: 'ServiceRoleSagemaker'.")
role = "ServiceRoleSagemaker"

training_parameters.update({"name": name})
estimator = sage.algorithm.AlgorithmEstimator(
estimator = lazy_sagemaker().algorithm.AlgorithmEstimator(
algorithm_arn=arn,
role=role,
instance_count=1,
Expand All @@ -734,7 +722,7 @@ def create_finetune(

current_filepath = f"{s3_models_dir}{job_name}/output/model.tar.gz"

s3_resource = boto3.resource("s3")
s3_resource = lazy_boto3().resource("s3")

# Copy new model to root of output_model_dir
bucket, old_key = parse_s3_url(current_filepath)
Expand Down Expand Up @@ -774,14 +762,14 @@ def export_finetune(

if role is None:
try:
role = sage.get_execution_role()
role = lazy_sagemaker().get_execution_role()
except ValueError:
print("Using default role: 'ServiceRoleSagemaker'.")
role = "ServiceRoleSagemaker"

export_parameters = {"name": name}

estimator = sage.algorithm.AlgorithmEstimator(
estimator = lazy_sagemaker().algorithm.AlgorithmEstimator(
algorithm_arn=arn,
role=role,
instance_count=1,
Expand All @@ -800,7 +788,7 @@ def export_finetune(
job_name = estimator.latest_training_job.name
current_filepath = f"{s3_output_dir}{job_name}/output/model.tar.gz"

s3_resource = boto3.resource("s3")
s3_resource = lazy_boto3().resource("s3")

# Copy the exported TensorRT-LLM engine to the root of s3_output_dir
bucket, old_key = parse_s3_url(current_filepath)
Expand Down Expand Up @@ -940,7 +928,7 @@ def summarize(
result = self._client.invoke_endpoint(**params)
response = json.loads(result['Body'].read().decode())
summary = Summary(response)
except EndpointConnectionError as e:
except lazy_botocore().EndpointConnectionError as e:
raise CohereError(str(e))
except Exception as e:
# TODO should be client error - distinct type from CohereError?
Expand Down
23 changes: 23 additions & 0 deletions src/cohere/manually_maintained/lazy_aws_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@


def lazy_sagemaker():
try:
import sagemaker as sage # type: ignore
return sage
except ImportError:
raise ImportError("Sagemaker not available. Please install sagemaker.")

def lazy_boto3():
try:
import boto3 # type: ignore
return boto3
except ImportError:
raise ImportError("Boto3 not available. Please install lazy_boto3().")

def lazy_botocore():
try:
import botocore # type: ignore
return botocore
except ImportError:
raise ImportError("Botocore not available. Please install botocore.")

Loading