# [0] 필요한 모듈 설치
 - 이 작업은 처음에 한 번만 수행하면 됩니다.

In [None]:
!pip install --quiet opensearch-py requests-aws4auth boto3 botocore awscli s3fs sagemaker

In [None]:
## Restart the kernel so that the freshly installed libraries get loaded
from IPython.display import HTML, display

# Only the value of a cell's LAST expression is rendered automatically, so a
# bare HTML(...) on an earlier line is silently discarded and its <script>
# never reaches the browser. Both snippets are therefore passed to display().
# NOTE(review): the script relies on the `Jupyter` JavaScript object being
# available in the front end -- confirm for the notebook UI in use.
display(HTML("<script>Jupyter.notebook.kernel.restart()</script>"))
display(HTML("<h3>Kernel Restart complete</h3>"))

---

In [None]:
import os
import sys

# Put the parent directory of the working directory on the import path
# (presumably where the `common` package imported further down lives).
ROOT_PATH = os.path.abspath("../")
# Guarded so that re-running this cell does not pile up duplicate entries.
if ROOT_PATH not in sys.path:
    sys.path.append(ROOT_PATH)

In [None]:
import pandas as pd
import re
import secrets
import json
import time
import base64
import numpy as np
from PIL import Image
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor, as_completed

import boto3
import sagemaker
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, helpers
from opensearchpy.exceptions import NotFoundError

from common.dataset.dataloader import DataLoader, LanguageTag
from common.aws.embedding import BedrockEmbedding
from common.utils.images import encode_image_base64, display_image

In [None]:
## Initialize the boto3 session (pinned to the us-west-2 region) ##
boto3_session = boto3.session.Session(region_name='us-west-2')
print(f'The notebook will use aws services hosted in {boto3_session.region_name} region')

# Initialize boto3 clients for the required AWS services from that session
sts_client = boto3_session.client('sts')
s3_client = boto3_session.client('s3')
opensearchservice_client = boto3_session.client('opensearchserverless')

# Resolve the SageMaker execution role ARN; it is passed to
# create_policies_in_oss below as a principal of the data-access policy.
sagemaker_role_arn = sagemaker.get_execution_role()

# Embedding helper from common.aws.embedding, created for the session's region
bedrock_embedding = BedrockEmbedding(region=boto3_session.region_name)

# [1] 리소스 준비

## [1-1] OpenSearchServerless 생성에 필요한 policy 생성

In [None]:
## Create Encryption, Network and access policies for OpenSearch serverless 
def create_policies_in_oss(opensearchservice_client, role_arn, suffix=5):
    """Create the encryption, network and data-access policies for a collection.

    All resource names are derived from ``suffix``. Each of the three creation
    calls is best-effort: any exception (for example because the policy
    already exists from an earlier run) is printed and the function carries on.

    Args:
        opensearchservice_client: boto3 ``opensearchserverless`` client used
            for the ``create_security_policy`` / ``create_access_policy`` calls.
        role_arn: ARN granted access in the data policy, in addition to the
            caller identity reported by STS.
        suffix: Value appended to every generated name. Defaults to 5, the
            value that used to be hard-coded.

    Returns:
        Tuple ``(vector_store_name, encryption_policy_name,
        network_policy_name, access_policy_name)``. These are the generated
        names, returned whether or not the creation calls succeeded.
    """
    vector_store_name = f'mm-image-collection-{suffix}'
    encryption_policy_name = f"mm-sample-sp-{suffix}"
    network_policy_name = f"mm-sample-np-{suffix}"
    access_policy_name = f'mm-sample-ap-{suffix}'
    # ARN of whoever is running the notebook; becomes a data-policy principal.
    identity = boto3.client('sts').get_caller_identity()['Arn']

    collection_resource = 'collection/' + vector_store_name
    index_resource = 'index/' + vector_store_name + '/*'

    # Encryption policy: use an AWS-owned key for the collection.
    _create_policy_best_effort(
        opensearchservice_client.create_security_policy,
        name=encryption_policy_name,
        document={
            'Rules': [{'Resource': [collection_resource],
                       'ResourceType': 'collection'}],
            'AWSOwnedKey': True,
        },
        policy_type='encryption',
    )

    # Network policy: the collection is reachable from public networks.
    _create_policy_best_effort(
        opensearchservice_client.create_security_policy,
        name=network_policy_name,
        document=[
            {'Rules': [{'Resource': [collection_resource],
                        'ResourceType': 'collection'}],
             'AllowFromPublic': True},
        ],
        policy_type='network',
    )

    # Data-access policy: collection- and index-level permissions for the
    # caller identity and the supplied role.
    _create_policy_best_effort(
        opensearchservice_client.create_access_policy,
        name=access_policy_name,
        document=[
            {
                'Rules': [
                    {
                        'Resource': [collection_resource],
                        'Permission': [
                            'aoss:CreateCollectionItems',
                            'aoss:DeleteCollectionItems',
                            'aoss:UpdateCollectionItems',
                            'aoss:DescribeCollectionItems'],
                        'ResourceType': 'collection',
                    },
                    {
                        'Resource': [index_resource],
                        'Permission': [
                            'aoss:CreateIndex',
                            'aoss:DeleteIndex',
                            'aoss:UpdateIndex',
                            'aoss:DescribeIndex',
                            'aoss:ReadDocument',
                            'aoss:WriteDocument'],
                        'ResourceType': 'index',
                    }],
                'Principal': [identity, role_arn],
                'Description': 'Easy data policy',
            },
        ],
        policy_type='data',
    )

    return vector_store_name, encryption_policy_name, network_policy_name, access_policy_name


def _create_policy_best_effort(create_call, name, document, policy_type):
    """Run one policy-creation call; print any exception instead of raising."""
    try:
        # Serialisation stays inside the try so that a non-serialisable
        # document is reported the same way as an API failure.
        create_call(name=name, policy=json.dumps(document), type=policy_type)
    except Exception as ex:
        print(ex)

In [None]:
# Create the three policies and keep the generated resource names around for
# the following cells.
(
    vector_store_name,
    encryption_policy,
    network_policy,
    access_policy,
) = create_policies_in_oss(
    opensearchservice_client=opensearchservice_client,
    role_arn=sagemaker_role_arn,
)

# Summary of the names that were just generated.
print(f"""Vector Store Name: {vector_store_name}\n \
        Encryption Policy: {encryption_policy}\n \
        Network Policy   : {network_policy} \n \
        Access Policy    : {access_policy}""")

## [1-2] OpenSearchServerless에 Collection 생성
 - 이 작업은 약 10분 소요됩니다.

In [None]:
# Create the vector store collection, or reuse the one that already exists.
try:
    vs_collection = opensearchservice_client.create_collection(name=vector_store_name, type='VECTORSEARCH')
    vs_collection_id = vs_collection['createCollectionDetail']['id']
    print(f"Created collection in OpenSearch -> {vs_collection}\n")
except Exception as create_error:
    # `except Exception` rather than a bare `except:` so that interrupting the
    # kernel is not swallowed. Creation failed -- fall back to a collection of
    # the same name, if one is listed.
    response = opensearchservice_client.list_collections()
    existing_collection = next(
        (summary for summary in response['collectionSummaries']
         if summary['name'] == vector_store_name),
        None,
    )
    if existing_collection is None or 'id' not in existing_collection:
        # Without a real id, the host name built in the next cells would be
        # undefined or bogus, so fail here and surface the original error.
        raise RuntimeError(
            f"Could not create collection '{vector_store_name}' and no existing "
            f"collection with that name was found: {create_error}"
        ) from create_error

    print(f"Collection '{vector_store_name}' already exists.")
    vs_collection = existing_collection
    vs_collection_id = existing_collection['id']

In [None]:
%%time 

def wait_for_collection_creation(opensearchservice_client, collection_name, timeout=600, interval=60):
    """Poll ``list_collections`` until the named collection becomes ACTIVE.

    Args:
        opensearchservice_client: Client exposing ``list_collections()``.
        collection_name: Name of the collection to wait for.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between polls.

    Returns:
        True once the collection reports status ``ACTIVE``. False when it
        reports ``FAILED`` (polling stops at once, since waiting longer cannot
        help) or when ``timeout`` elapses first.
    """
    start_time = time.time()

    while (time.time() - start_time) < timeout:
        try:
            # Fetch the list of collections and pick out the status of ours.
            response = opensearchservice_client.list_collections()
            status = next(
                (collection['status']
                 for collection in response['collectionSummaries']
                 if collection['name'] == collection_name),
                None,
            )

            if status == 'ACTIVE':
                print(f"Collection '{collection_name}' is active.")
                return True
            if status == 'FAILED':
                print(f"Collection '{collection_name}' status: {status}. Giving up.")
                return False
            if status is None:
                print(f"Collection '{collection_name}' is not listed yet. Waiting...")
            else:
                print(f"Collection '{collection_name}' status: {status}. Waiting...")

        except Exception as e:
            # Any error from the call or an unexpected response shape lands
            # here; polling simply continues on the next iteration.
            print(f"Could not check collection '{collection_name}': {e}")

        # Never sleep past the deadline.
        remaining = timeout - (time.time() - start_time)
        if remaining > 0:
            time.sleep(min(interval, remaining))

    print(f"Timeout reached: Collection '{collection_name}' was not created in {timeout} seconds.")
    return False

wait_for_collection_creation(opensearchservice_client, vector_store_name)

In [None]:
# Host name of the created collection: <collection id>.<region>.aoss.amazonaws.com
host = '.'.join((vs_collection_id, boto3_session.region_name, 'aoss.amazonaws.com'))
print(f"OpenSearch Host: {host}")

## [1-3] OpenSearch Index 생성
 - 이 셀을 여러 번 실행하면 index가 이미 존재한다는 오류 메시지가 출력됩니다. 예외는 셀 안에서 처리되므로 노트북 실행은 중단되지 않습니다.

In [None]:
# SigV4 signing for the 'aoss' service. Credentials are taken from the same
# boto3 session that the rest of the notebook uses.
service = 'aoss'
credentials = boto3_session.get_credentials()
awsauth = AWSV4SignerAuth(credentials, boto3_session.region_name, service)

# Schema used to create the index.
# k-NN is enabled so that the K most similar vectors can be searched.
# The image vector dimension is defined as 1024.
index_name = "fireup-image-mm-index"
index_body = {
   "settings": {
      "index.knn": "true"
   },
   "mappings": {
      "properties": {
         "image_vector": {
            "type": "knn_vector",
            "dimension": 1024
         },
         "metadata": {
             "properties": {
               "item_name": {"type": "text"},
               "item_id" : {"type": "text"},
               # Named `image_url` to match the key that
               # insert_document_in_aoss actually writes; the previous
               # `img_path` entry never matched any stored field.
               "image_url": {"type": "text"},
               "description": {"type": "text"},
             }
         }
      }
   }
}

# Client for the collection endpoint (HTTPS on port 443, certificates verified).
oss_client = OpenSearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    timeout=600
)

# Best-effort creation: any failure is reported rather than raised, so
# re-running the cell does not abort the notebook.
try:
    response = oss_client.indices.create(index=index_name, body=index_body)
    print(f"Response received for the create index -> {response}")
except Exception as e:
    print(f"Encountered error while creating index={index_name}, exception={e}")

In [None]:
%%time

def wait_for_index_creation(oss_client, index_name, timeout=600, interval=10):
    """Poll ``oss_client.indices.exists`` until the index shows up.

    Args:
        oss_client: Client whose ``indices.exists(index=...)`` is queried.
        index_name: Name of the index to wait for.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep after each unsuccessful check.

    Returns:
        True as soon as the index exists, False if ``timeout`` elapses first.
        Exceptions raised by the existence check are printed and treated as
        "not there yet".
    """
    began_at = time.time()

    while True:
        if time.time() - began_at >= timeout:
            break

        try:
            found = oss_client.indices.exists(index=index_name)
        except Exception as err:
            print(err)
            found = False

        if found:
            print(f"Index '{index_name}' has been created.")
            return True

        print(f"Index '{index_name}' not created yet. Waiting for {interval} seconds...")
        time.sleep(interval)

    print(f"Timeout reached: Index '{index_name}' was not created in {timeout} seconds.")
    return False

# index가 정상적으로 생성됨을 확인하세요
wait_for_index_creation(oss_client, index_name)

---

# [2] 데이터셋 준비

- [Amazon Berkeley Objects Dataset](https://amazon-berkeley-objects.s3.amazonaws.com/index.html)에서 이미지와 해당 메타데이터를 추출하여 데이터셋을 준비합니다.
- `item_name`이 영문으로 된 데이터만 추출합니다. 이는 0번째 데이터셋 기준으로 1,655개입니다.
- 데이터를 로드하는 데 약 10초 소요됩니다.

## [2-1] 데이터셋 로드

In [None]:
%%time

# Build the project's DataLoader with index=0 and the English language tag
# (the notes above describe this as the English-only subset of dataset 0).
loader = DataLoader(index=0, language=LanguageTag.ENG)
# Bare expression: as the last statement of the cell its value is shown as output.
loader.dataset

In [None]:
# Randomly pick one of the IDs contained in the dataset
item_id = loader.get_random_id()
print(f"Select a random item ID in dataset: {item_id}")

# Fetch the item for that ID together with its image, as a sample of the loaded data
item, img = loader.get_item(item_id=item_id)
base64img = encode_image_base64(img)

# Pass the encoded image to the display helper, then dump the item as pretty-printed JSON
display_image(base64img)
print(json.dumps(item, indent=4, ensure_ascii=False))

## [2-2] Embedding 데이터 삽입

- 이 작업은 약 10분 소요됩니다.

In [None]:
def insert_document_in_aoss(index_name, item, image):
    """Index one item and its image vector through the global ``oss_client``.

    Args:
        index_name: Target index.
        item: Mapping providing ``item_name``, ``item_id``, ``img_path`` and
            ``bullet_point``.
        image: Value stored under ``image_vector``.
    """
    metadata = {
        "item_name": item["item_name"],
        "item_id": item["item_id"],
        # The item's `img_path` is stored under the `image_url` key.
        "image_url": item['img_path'],
        # The item's `bullet_point` is stored as the description.
        "description": item["bullet_point"],
    }
    oss_client.index(
        index=index_name,
        body={"image_vector": image, "metadata": metadata},
    )

In [None]:
%%time

all_id_list = loader.get_id_list()


def insert_documents(item_id):
    """Embed the image of one dataset item and store it in the index.

    Looks the item up through the global ``loader``. If either the item or
    its image is missing, a message is printed and nothing is stored.

    Args:
        item_id: ID of the dataset item to process.
    """
    item, img = loader.get_item(item_id=item_id)

    if item is not None and img is not None:
        # Embed the base64-encoded image, then index it with the item metadata.
        encoded_image = encode_image_base64(img)
        image_embedding = bedrock_embedding.embedding_multimodal(image=encoded_image)
        insert_document_in_aoss(index_name, item, image_embedding)
    else:
        print(f"Skipping item_id {item_id} due to missing data.")

# Use a ThreadPoolExecutor to process the items in parallel.
with ThreadPoolExecutor(max_workers=5) as executor:
    # Map every future back to its item ID so that a failure can be attributed
    # to a concrete item (and retried) instead of being an anonymous error.
    futures = {
        executor.submit(insert_documents, pending_id): pending_id
        for pending_id in all_id_list
    }
    failed_item_ids = []
    for future in as_completed(futures):
        try:
            future.result()
        except Exception as e:
            failed_item_ids.append(futures[future])
            print(f"Error in future (item_id={futures[future]}): {e}")

# One-line summary so that partial failures are not lost in the output.
if failed_item_ids:
    print(f"{len(failed_item_ids)} of {len(futures)} items failed to insert: {failed_item_ids}")

In [None]:
def get_document_count(index_name):
    """Return the number of documents in ``index_name``, or None on error.

    Queries the global ``oss_client``. The count is also printed; any
    exception is printed and turned into a ``None`` result.
    """
    try:
        doc_count = oss_client.count(index=index_name)['count']
    except Exception as err:
        print(f"Error fetching document count: {err}")
        return None
    else:
        print(f"Total number of documents in index '{index_name}': {doc_count}")
        return doc_count

document_count = get_document_count(index_name)
print(document_count)

### 사용하는 리소스에 대한 정보 저장
 - 이후 생성한 리소스를 정리하기 위해 리소스들의 정보를 저장합니다.

In [None]:
# Names of everything created above, persisted so that the resources can be
# cleaned up later.
data_to_save = dict(
    opensearch_host=host,
    opensearch_index_name=index_name,
    vector_store_name=vector_store_name,
    encryption_policy=encryption_policy,
    network_policy=network_policy,
    access_policy=access_policy,
)

# Written as a single JSON object into the current working directory.
with open("oss_policies_info.json", "w") as f:
    f.write(json.dumps(data_to_save))