# Demo-app에 Streamlit 애플리케이션 배포

#### 유사 이미지 검색에는 `test_image_1.png` / `test_image_2.png` 파일을 사용해 보세요.

In [None]:
%%writefile image_search_lib.py
import base64
import json
import os
from functools import lru_cache
from io import BytesIO

import boto3
from langchain.vectorstores import FAISS


@lru_cache(maxsize=None)
def _get_bedrock_client():
    """Create the Bedrock runtime client once and reuse it on every call.

    BWB_PROFILE_NAME, BWB_REGION_NAME and BWB_ENDPOINT_URL are read on the
    first call only; later changes to those environment variables are not
    picked up by the cached client.
    """
    session = boto3.Session(profile_name=os.environ.get("BWB_PROFILE_NAME"))
    return session.client(
        service_name="bedrock-runtime",
        region_name=os.environ.get("BWB_REGION_NAME"),
        endpoint_url=os.environ.get("BWB_ENDPOINT_URL"),
    )


def get_multimodal_vector(input_image_base64=None, input_text=None):
    """Call Bedrock and return the embedding for an image, a text, or both.

    Args:
        input_image_base64: Base64-encoded image string, or None.
        input_text: Text to embed, or None.

    Returns:
        The ``embedding`` value from the model response (None if the
        response has no such key).

    Raises:
        ValueError: If neither a non-empty text nor a non-empty image is given.
    """
    request_body = {}
    if input_text:
        request_body["inputText"] = input_text
    if input_image_base64:
        request_body["inputImage"] = input_image_base64

    # Without either field there is nothing to embed; fail fast here instead
    # of sending an empty request to the service.
    if not request_body:
        raise ValueError(
            "get_multimodal_vector requires input_text and/or input_image_base64"
        )

    response = _get_bedrock_client().invoke_model(
        body=json.dumps(request_body),
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json",
    )

    response_body = json.loads(response.get("body").read())
    return response_body.get("embedding")

def get_vector_from_file(file_path):
    """Read the image stored at ``file_path`` and return its Bedrock embedding."""
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()

    encoded = base64.b64encode(raw_bytes).decode('utf8')
    return get_multimodal_vector(input_image_base64=encoded)


def get_image_vectors_from_directory(path):
    """Embed every file in ``path`` and return a list of ``(file_path, vector)`` tuples.

    ``path`` is joined onto the directory that contains this module (an
    absolute ``path`` is used as-is by ``os.path.join``).

    Entries are visited in sorted order so the resulting index is identical
    on every run. Sub-directories are skipped. Hidden files (names starting
    with ".", e.g. ``.DS_Store``) are skipped too: they are not images and
    would presumably be rejected by the embedding model, aborting indexing.
    """
    full_path = os.path.join(os.path.dirname(__file__), path)

    items = []
    for name in sorted(os.listdir(full_path)):
        if name.startswith("."):
            continue
        file_path = os.path.join(full_path, name)
        if os.path.isfile(file_path):
            items.append((file_path, get_vector_from_file(file_path)))

    return items

def get_index():
    """Build and return the in-memory FAISS vector store used by the app.

    Each image under the ``images`` directory next to this module becomes one
    entry: an empty text, its embedding vector, and ``{"image_path": ...}``
    metadata so search hits can be mapped back to files.
    """
    # 'images' is resolved relative to the directory holding this module.
    pairs = get_image_vectors_from_directory("images")

    text_embeddings = []
    metadatas = []
    for image_path, vector in pairs:
        text_embeddings.append(("", vector))
        metadatas.append({"image_path": image_path})

    return FAISS.from_embeddings(
        text_embeddings=text_embeddings,
        embedding=None,
        metadatas=metadatas,
    )

def get_base64_from_bytes(image_bytes):
    """Return ``image_bytes`` encoded as a base64 string.

    Args:
        image_bytes: Any bytes-like object (``bytes``, ``bytearray``, ...).

    Returns:
        The base64 text decoded as UTF-8.
    """
    # b64encode accepts bytes-like objects directly, so no BytesIO wrapper
    # (and the extra copy of the data it makes) is needed.
    return base64.b64encode(image_bytes).decode("utf-8")

def get_similarity_search_results(index, search_term=None, search_image=None, k=4):
    """Return the indexed images most similar to a text query and/or a query image.

    Args:
        index: The vector store returned by ``get_index``; its documents are
            expected to carry an ``image_path`` metadata entry.
        search_term: Text query, or None.
        search_image: Raw bytes of the query image, or None.
        k: Maximum number of results to return. Defaults to 4, which matches
            LangChain FAISS's own default for ``similarity_search_by_vector``.

    Returns:
        A list of ``BytesIO`` objects holding the bytes of each matching file.

    Raises:
        ValueError: If neither a search term nor a search image is provided.
    """
    # Reject empty queries before any remote call is made.
    if not search_term and not search_image:
        raise ValueError("Provide a search term and/or a search image.")

    search_image_base64 = get_base64_from_bytes(search_image) if search_image else None
    search_vector = get_multimodal_vector(
        input_text=search_term, input_image_base64=search_image_base64
    )

    results = index.similarity_search_by_vector(embedding=search_vector, k=k)

    # Load each hit's file into memory so the caller can render it directly.
    results_images = []
    for doc in results:
        with open(doc.metadata["image_path"], "rb") as f:
            results_images.append(BytesIO(f.read()))

    return results_images


In [None]:
%%writefile ../demo-app.py
import streamlit as st 
import Embedding.image_search_lib as glib 


st.set_page_config(page_title="Image Search", layout="wide")
st.title("Image Search")

# Build the image index only once per session; later reruns reuse the stored copy.
if "vector_index" not in st.session_state:
    with st.spinner("Indexing images..."):
        st.session_state["vector_index"] = glib.get_index()


search_images_tab, find_similar_images_tab = st.tabs(["Image search", "Find similar images"])

# Tab 1: find images matching a text query.
with search_images_tab:
    search_col_1, search_col_2 = st.columns(2)

    with search_col_1:
        input_text = st.text_input("Search for:")
        search_button = st.button("Search", type="primary")

    with search_col_2:
        if search_button:
            if not input_text.strip():
                # An empty query has nothing to embed; ask for input instead of
                # letting the backend call fail.
                st.warning("Enter a search term first.")
            else:
                st.subheader("Results")
                with st.spinner("Searching..."):
                    response_content = glib.get_similarity_search_results(
                        index=st.session_state.vector_index, search_term=input_text
                    )

                    for res in response_content:
                        st.image(res, width=250)

# Tab 2: find images similar to an uploaded PNG/JPG.
with find_similar_images_tab:
    find_col_1, find_col_2 = st.columns(2)

    with find_col_1:
        uploaded_file = st.file_uploader("Select an image", type=['png', 'jpg'])

        if uploaded_file:
            # Preview the query image.
            uploaded_image_preview = uploaded_file.getvalue()
            st.image(uploaded_image_preview)

        find_button = st.button("Find", type="primary")  # primary-styled button

    with find_col_2:
        if find_button:
            if uploaded_file is None:
                # Nothing uploaded yet: uploaded_file is None, so .getvalue()
                # would raise AttributeError.
                st.warning("Select an image first.")
            else:
                st.subheader("Results")
                with st.spinner("Finding..."):
                    response_content = glib.get_similarity_search_results(
                        index=st.session_state.vector_index,
                        search_image=uploaded_file.getvalue(),
                    )

                    for res in response_content:
                        st.image(res, width=250)
    