Skip to content

Commit

Permalink
BedrockEncoder: Fix for aurelio-labs#293
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtismassey committed May 21, 2024
1 parent a8544d8 commit 5a748d0
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions semantic_router/encoders/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tiktoken
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger


class BedrockEncoder(BaseEncoder):
Expand Down Expand Up @@ -69,12 +70,14 @@ def __init__(
"""

super().__init__(name=name, score_threshold=score_threshold)
self.access_key_id = self.get_env_variable("access_key_id", access_key_id)
self.access_key_id = self.get_env_variable("AWS_ACCESS_KEY_ID", access_key_id)
self.secret_access_key = self.get_env_variable(
"secret_access_key", secret_access_key
"AWS_SECRET_ACCESS_KEY", secret_access_key
)
self.session_token = self.get_env_variable("AWS_SESSION_TOKEN", session_token)
self.region = self.get_env_variable("AWS_REGION", region, default="us-west-1")
self.region = self.get_env_variable(
"AWS_DEFAULT_REGION", region, default="us-west-1"
)

self.input_type = input_type

Expand Down Expand Up @@ -116,22 +119,24 @@ def _initialize_client(
"`pip install boto3`"
)

access_key_id = access_key_id or os.getenv("access_key_id")
aws_secret_key = secret_access_key or os.getenv("secret_access_key")
region = region or os.getenv("AWS_REGION", "us-west-2")
access_key_id = access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
aws_secret_key = secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
region = region or os.getenv("AWS_DEFAULT_REGION", "us-west-2")

if access_key_id is None:
raise ValueError("AWS access key ID cannot be 'None'.")

if aws_secret_key is None:
raise ValueError("AWS secret access key cannot be 'None'.")

session = boto3.Session(
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
)
try:
bedrock_client = boto3.client(
bedrock_client = session.client(
"bedrock-runtime",
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
region_name=region,
)
except Exception as err:
Expand All @@ -155,6 +160,8 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
ValueError: If the Bedrock Platform client is not initialized or if the
API call fails.
"""
from botocore.exceptions import ClientError

if self.client is None:
raise ValueError("Bedrock client is not initialised.")
try:
Expand Down Expand Up @@ -224,6 +231,21 @@ def chunk_strings(strings, MAX_WORDS=20):
else:
raise ValueError("Unknown model name")
return embeddings
except ClientError as error:
if error.response["Error"]["Code"] == "ExpiredTokenException":
logger.warning("Session token has expired. Retrying initialisation.")
try:
self.session_token = os.getenv("AWS_SESSION_TOKEN")
self.client = self._initialize_client(
self.access_key_id,
self.secret_access_key,
self.session_token,
self.region,
)
except Exception as e:
raise ValueError(
f"Bedrock client failed to reinitialise. Error: {e}"
) from e
except Exception as e:
raise ValueError(f"Bedrock call failed. Error: {e}") from e

Expand All @@ -246,5 +268,7 @@ def get_env_variable(var_name, provided_value, default=None):
return provided_value
value = os.getenv(var_name, default)
if value is None:
if var_name == "AWS_SESSION_TOKEN":
return None
raise ValueError(f"No {var_name} provided")
return value

0 comments on commit 5a748d0

Please sign in to comment.