In [None]:
!pip install --quiet langchain langchain_community langchain_aws gradio

In [None]:
import os
import sys

# Make the parent directory importable so the `common` package can be resolved.
ROOT_PATH = os.path.abspath("../")
# Guard against duplicate entries when this cell is executed more than once.
if ROOT_PATH not in sys.path:
    sys.path.append(ROOT_PATH)

In [None]:
import pandas as pd
import os
import json
import pandas as pd
from functools import wraps
from PIL import Image
from io import BytesIO

import boto3
import sagemaker
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from IPython.display import display, HTML

from common.aws.embedding import BedrockEmbedding
from common.aws.claude import BedrockClaude
from common.utils.images import encode_image_base64, encode_image_base64_from_file, display_image

In [None]:
# Read back the OpenSearch settings stored in oss_policies_info.json.
with open("oss_policies_info.json", "r") as f:
    saved_data = json.load(f)

# Endpoint and index used by the search cells below.
host, index_name = saved_data["opensearch_host"], saved_data["opensearch_index_name"]

# Vector store and policy entries stored in the same file.
vector_store_name = saved_data["vector_store_name"]
encryption_policy, network_policy, access_policy = (
    saved_data["encryption_policy"],
    saved_data["network_policy"],
    saved_data["access_policy"],
)

# Echo the loaded values so the reader can confirm which resources are in use.
print(f"""OpenSearch Host: {host}\n \
        Index Name    : {index_name}""")

print(f"""Vector Store Name: {vector_store_name}\n \
        Encryption Policy: {encryption_policy}\n \
        Network Policy   : {network_policy} \n \
        Access Policy    : {access_policy}""")

In [None]:
## Initialize boto3 session ##
boto3_session = boto3.session.Session(region_name='us-west-2')
print(f'The notebook will use aws services hosted in {boto3_session.region_name} region')

# Initialize boto3 clients for the required AWS services
sts_client = boto3_session.client('sts')
s3_client = boto3_session.client('s3')
opensearchservice_client = boto3_session.client('opensearchserverless')

# Request signer for the 'aoss' service. The credentials are taken from the
# same session as the clients above so that every call in this notebook
# shares one session configuration.
service = 'aoss'
credentials = boto3_session.get_credentials()
awsauth = AWSV4SignerAuth(credentials, boto3_session.region_name, service)

# Initialize a SageMaker role ARN
sagemaker_role_arn = sagemaker.get_execution_role()

# Embedding helper used to turn text / image queries into vectors
bedrock_embedding = BedrockEmbedding(region=boto3_session.region_name)

# OpenSearch client for the host loaded from oss_policies_info.json
oss_client = OpenSearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    timeout=600
)

## 검색 결과 시각화를 위한 함수

In [None]:
def _fetch_thumbnail_base64(image_url):
    """Return the base64-encoded image for ``image_url``, or '' when it cannot be loaded."""
    if not image_url:
        # No key to look up; skip the S3 round trip entirely.
        return ''

    try:
        img_res = s3_client.get_object(
            Bucket="amazon-berkeley-objects",
            Key=f"images/small/{image_url}"
        )
        img = Image.open(BytesIO(img_res['Body'].read()))
    except (s3_client.exceptions.ClientError, OSError) as exc:
        # A missing object or an undecodable image should not abort the whole table.
        print(f"Could not load thumbnail for '{image_url}': {exc}")
        return ''

    return encode_image_base64(img) or ''


def show_search_results(documents):
    """Render search hits as an HTML table with one thumbnail per item.

    Args:
        documents: Iterable of hit dicts. For each hit, ``_score`` and
            ``_source`` are read; from ``_source`` the ``description`` and the
            ``metadata`` dict (``item_name``, ``item_id``, ``image_url``).

    Returns:
        None. The table is displayed in the notebook output.
    """
    from html import escape  # local import keeps this cell self-contained

    rows = []
    for doc in documents:
        source = doc.get('_source', {})
        metadata = source.get('metadata', {})
        image_url = metadata.get('image_url', '')

        img_base64 = _fetch_thumbnail_base64(image_url)

        # The table is rendered with escape=False so the <img> tag survives;
        # every text field is therefore escaped individually.
        rows.append({
            'thumbnail': f'<img src="data:image/png;base64,{img_base64}" width="50" height="50">' if img_base64 else '',
            'item_name': escape(str(metadata.get('item_name', ''))),
            'item_id': escape(str(metadata.get('item_id', ''))),
            'image_url': escape(str(image_url)),
            'description': escape(str(source.get('description', ''))),
            'score': doc.get('_score', 0)
        })

    df = pd.DataFrame(rows)
    display(HTML(df.to_html(escape=False, index=False)))


def log_query_and_results(func):
    """Decorator that echoes the query, runs ``func`` and renders what it returns.

    Only ``text`` and ``image`` supplied as keyword arguments are echoed;
    positional arguments are not inspected. The wrapped function's return
    value is passed to ``show_search_results`` and then returned unchanged.
    """
    @wraps(func)
    def logged(*args, **kwargs):
        # dict.get already yields None for absent keys
        echoed_text = kwargs.get('text')
        echoed_image = kwargs.get('image')

        print('======= Query ========')
        print(echoed_text)
        if echoed_image:
            display_image(echoed_image)

        hits = func(*args, **kwargs)

        print('======= Results ========')
        show_search_results(hits)

        return hits

    return logged


## Vector 검색 테스트

In [None]:
""" 
Function for semantic search capability using knn on input query prompt.
"""
@log_query_and_results
def find_similar_items(oss_client,
                       bedrock_embedding: BedrockEmbedding,
                       index_name: str,
                       k: int = 5,
                       text: str = None,
                       image: str = None):
    """Run a k-NN search on the ``image_vector`` field using a multimodal query embedding.

    Pass ``text`` and ``image`` as keyword arguments: the logging decorator
    only echoes them when they are given by keyword.

    Args:
        oss_client: Client whose ``search(index=..., body=...)`` method is called.
        bedrock_embedding: Provides ``embedding_multimodal(text=..., image=...)``
            to build the query vector.
        index_name: Name of the index to search.
        k: Number of neighbours requested; also used as the result ``size``.
        text: Optional query text.
        image: Optional query image (the cells in this notebook pass a
            base64-encoded string).

    Returns:
        The list found under ``["hits"]["hits"]`` in the search response.
    """
    query_emb = bedrock_embedding.embedding_multimodal(text=text, image=image)

    body = {
        "size": k,
        # Keep the stored vector out of each returned hit's _source.
        "_source": {
            "excludes": ["image_vector"],
        },
        "query": {
            "knn": {
                "image_vector": {
                    "vector": query_emb,
                    "k": k,
                }
            }
        },
    }

    res = oss_client.search(index=index_name, body=body)
    return res["hits"]["hits"]

### 텍스트 검색

In [None]:
# Text-only query
query_text = "leather sofa"
docs = find_similar_items(oss_client, bedrock_embedding, index_name=index_name, text=query_text)

In [None]:
# Text-only query
query_text = "glass cup"
docs = find_similar_items(oss_client, bedrock_embedding, index_name=index_name, text=query_text)

### 이미지 검색

In [None]:
# base64 image
query_image = encode_image_base64_from_file(file_path="./sample/rug.jpg", format="JPEG")

docs = find_similar_items(
    oss_client,
    bedrock_embedding,
    index_name=index_name,
    image=query_image,
)

In [None]:
# base64 image
query_image = encode_image_base64_from_file(file_path="./sample/table.jpg", format="JPEG")

docs = find_similar_items(
    oss_client,
    bedrock_embedding,
    index_name=index_name,
    image=query_image,
)

### Multimodal LLM을 통해 이미지를 해석하고 텍스트로 이미지 검색

In [None]:
# Create the Claude wrapper and show the image that is about to be captioned.
# query_image comes from the most recently executed image-search cell
# (./sample/table.jpg in a top-to-bottom run).
claude = BedrockClaude()
display_image(query_image)

# Instruction prompt asking the model for a detailed textual description of the image.
query_text = """
You are an expert at analyzing images in great detail. Your task is to carefully examine the provided \
image and generate a detailed, accurate textual description capturing all of the important elements and \
context present in the image. Pay close attention to any numbers, data, or quantitative information visible, \
and be sure to include those numerical values along with their semantic meaning in your description. \
Thoroughly read and interpret the entire image before providing your detailed caption describing the \
image content in text format. Strive for a truthful and precise representation of what is depicted"""

# Send the prompt together with the image; the returned text is reused as the
# search query in the following cell.
gen_text = claude.invoke_llm_response(text=query_text, image=query_image)
print(gen_text)

In [None]:
# Search with the caption generated by the model instead of the image itself
docs = find_similar_items(oss_client, bedrock_embedding, index_name=index_name, text=gen_text)

In [None]:
import asyncio
import gradio as gr
import base64
import json
from langchain.prompts import PromptTemplate
from langchain.callbacks import AsyncIteratorCallbackHandler
from common.utils.images import encode_image_base64_from_file, encode_image_base64
from common.aws.claude import BedrockClaude


# Prompt that asks the model for an English search keyword, returned as a JSON
# object with the key "keyword". Placeholders: {conversations}, {question}.
PRODUCT_KEYWORD_SUGGESTION_TEMPLATE = '''
You are an expert in suggesting product search terms. You suggest search terms in english for products that users want through conversations and questions with users.
The output should be formatted as a JSON instance. Just answer json without any other explanation. Format the response as a JSON object with a key: "keyword"

Here are the conversation between an Assistant and a user:
<conversations>
{conversations}
</conversations>

Here is a question from user:
<question>
{question}
</question>
'''

# Prompt that asks the model to recommend a product from the retrieved
# information. Placeholders: {information}, {question}.
PRODUCT_SEARCH_TEMPLATE = """
You are an expert who finds the products users want. Based on the user's question in <question>, refer to <information> and recommend a product to the user.
If there is no accurate information, answer that the product does not exist.

Information related to what the user requested is here:
<information>
{information}
</information>

Here is a question from user:
<question>
{question}
</question>
"""

# Claude wrapper used by the chat handler below (rebinds the name created in an earlier cell).
claude = BedrockClaude()


def add_message(history, message):
    """Append the submitted files and text to the chat history and lock the input box.

    Args:
        history: List of chat messages; extended in place.
        message: Mapping with a ``"files"`` sequence and a ``"text"`` entry.

    Returns:
        A tuple of the updated history and a non-interactive, cleared
        ``gr.MultimodalTextbox``.
    """
    # One user message per uploaded file, in submission order
    history.extend(
        {"role": "user", "content": {"path": file_path}}
        for file_path in message["files"]
    )

    # The text message is appended after the files
    text = message["text"]
    if text is not None:
        history.append({"role": "user", "content": text})

    return history, gr.MultimodalTextbox(value=None, interactive=False)

async def bot(history: list):
    """Stream an assistant reply that recommends products for the latest user turn.

    The handler (1) takes the last message as the question, (2) picks up an
    image uploaded since the previous assistant turn, (3) asks the model for a
    search keyword, (4) retrieves similar items, and (5) streams the model's
    answer built from those items.

    Args:
        history: Chat history as a list of dicts with ``"role"`` and
            ``"content"``. It is mutated in place: an assistant message is
            appended and filled chunk by chunk.

    Yields:
        The history after each streamed chunk has been appended.
    """
    chat = claude.get_chat_model()

    # Last message is the question; when it is empty, fall back to a default
    # request (Korean for "please recommend related products").
    last_message = history[-1]["content"]
    last_message = "관련 상품 추천해줘" if len(last_message) == 0 else last_message

    # Look backwards for an image sent in the current user turn; stop at the
    # previous assistant reply so older uploads are not reused.
    # NOTE(review): assumes file messages arrive as a tuple whose first element
    # is the file path - confirm against the installed Gradio version.
    image_path = None
    for msg in reversed(history):
        if msg["role"] == "assistant":
            break
        if msg["role"] == "user" and isinstance(msg["content"], tuple):
            image_path = msg["content"][0]
            break
    image = encode_image_base64_from_file(image_path) if image_path is not None else None

    # Ask the model for a search keyword. This step is best-effort: on any
    # failure (model call, invalid JSON, unexpected JSON shape) the search
    # continues with an empty keyword, but the reason is reported instead of
    # being silently discarded.
    search_keyword = ""
    try:
        keyword_prompt = PromptTemplate(
            template=PRODUCT_KEYWORD_SUGGESTION_TEMPLATE,
            input_variables=["question", "conversations"],
        ).format(question=last_message, conversations="")

        res = claude.invoke_llm_response(text=keyword_prompt, image=image)
        search_keyword = json.loads(res).get('keyword', '')
    except Exception as exc:
        print(f"Keyword extraction failed; continuing without a keyword: {exc}")

    print(search_keyword)

    docs = find_similar_items(
        oss_client,
        bedrock_embedding,
        index_name=index_name,
        text=search_keyword,
        image=image
    )

    # Build the answer prompt. PRODUCT_SEARCH_TEMPLATE only contains the
    # {question} and {information} placeholders, so only those are supplied.
    text = PromptTemplate(
        template=PRODUCT_SEARCH_TEMPLATE,
        input_variables=["question", "information"],
    ).format(question=last_message, information=docs)

    # Stream the model output into a new assistant message.
    prompt = claude.get_prompt(text=text, image=image)
    history.append({"role": "assistant", "content": ""})
    async for chunk in chat.astream(prompt):
        history[-1]["content"] += chunk.content
        yield history

# Chat UI: a Chatbot in "messages" mode plus a multimodal (text + file) input box.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages")

    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_count="single", # multiple
        placeholder="Enter message or upload file...",
        show_label=False,
    )

    # Event chain on submit:
    #   1. add_message updates the chatbot history and the input box,
    #   2. bot receives the chatbot history and streams its updates back into it,
    #   3. the input box is made interactive again.
    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    )
    bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])


# NOTE(review): share=True requests a public share link from Gradio - confirm
# that exposing this demo outside the notebook environment is intended.
if __name__ == "__main__":
    demo.launch(share=True)