Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement: add parameter boto3_session for AWS DynamoDB cross account use cases #10326

Merged
merged 3 commits into from
Sep 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions libs/langchain/langchain/memory/chat_message_histories/dynamodb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import logging
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional

from langchain.schema import (
BaseChatMessageHistory,
Expand All @@ -11,6 +13,9 @@
messages_to_dict,
)

if TYPE_CHECKING:
from boto3.session import Session

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -42,21 +47,34 @@ def __init__(
endpoint_url: Optional[str] = None,
primary_key_name: str = "SessionId",
key: Optional[Dict[str, str]] = None,
boto3_session: Optional[Session] = None,
):
import boto3

if endpoint_url:
client = boto3.resource("dynamodb", endpoint_url=endpoint_url)
if boto3_session:
client = boto3_session.resource("dynamodb")
else:
client = boto3.resource("dynamodb")
try:
import boto3
except ImportError as e:
raise ImportError(
"Unable to import boto3, please install with `pip install boto3`."
) from e
if endpoint_url:
client = boto3.resource("dynamodb", endpoint_url=endpoint_url)
else:
client = boto3.resource("dynamodb")
self.table = client.Table(table_name)
self.session_id = session_id
self.key: Dict = key or {primary_key_name: session_id}

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from DynamoDB"""
from botocore.exceptions import ClientError
try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e

response = None
try:
Expand All @@ -77,7 +95,12 @@ def messages(self) -> List[BaseMessage]: # type: ignore

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in DynamoDB"""
from botocore.exceptions import ClientError
try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e

messages = messages_to_dict(self.messages)
_message = _message_to_dict(message)
Expand All @@ -90,7 +113,12 @@ def add_message(self, message: BaseMessage) -> None:

def clear(self) -> None:
"""Clear session memory from DynamoDB"""
from botocore.exceptions import ClientError
try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e

try:
self.table.delete_item(self.key)
Expand Down