<a href="https://colab.research.google.com/github/choki0715/lecture/blob/master/rag_openai_poc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings, ChatOpenAI  # OpenAI로 변경
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
import streamlit as st
import tempfile

# Configure the OpenAI API key.
# setdefault keeps a key that is already exported in the environment instead
# of overwriting it with the placeholder below; the placeholder is only used
# when the variable is not set at all.
os.environ.setdefault('OPENAI_API_KEY', "")  # replace "" with a real API key if none is exported

###########################################################################
# Streamlit UI: page layout and the centered gray header banners.
st.set_page_config(layout="wide")

# Banner lines in display order. Each entry is raw HTML, which is why it is
# rendered with unsafe_allow_html=True below.
_banner_html = (
    "<h1 style='text-align: center; color: gray;'>인구감소 대응 정책을 위한 RAG System</h1>",
    "<h2 style='text-align: center; color: gray;'> Chat with PDFs that you upload </h2>",
    "<h5 style='text-align: center; color: gray;'>OpenAI GPT 기반 (빠름)</h5>",
    "<h3 style='text-align: center; color: gray;'> AIDENTIFY Inc.</h3>",
)
for _banner_line in _banner_html:
    st.markdown(_banner_line, unsafe_allow_html=True)

# End of Streamlit UI header
###########################################################################

def document_data(query, chat_history, vectorstore):
    """Answer ``query`` against ``vectorstore`` with a conversational retrieval chain.

    Args:
        query: The user's question text.
        chat_history: Earlier turns of the conversation, forwarded to the
            chain under the "chat_history" key. The caller in this file
            passes a list of ``(question, answer)`` tuples.
        vectorstore: Object exposing ``as_retriever()``; the caller in this
            file passes a FAISS store.

    Returns:
        The chain's result mapping. The caller in this file reads its
        "answer" entry.
    """
    # OpenAI chat model ("gpt-4o-mini" is a cheaper alternative).
    llm = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0.7,
    )

    # Deliberately no `memory=` argument: a memory object attached to the
    # chain supplies its own "chat_history" value, which takes precedence over
    # the one passed in the inputs. A fresh, empty memory created on every
    # call therefore discarded the history the caller keeps and passes in.
    qna = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
    )

    # invoke() is the supported replacement for calling the chain object directly.
    return qna.invoke({"question": query, "chat_history": chat_history})

def _upload_signature(uploaded_files):
    """Return a hashable fingerprint (file name + SHA-256 of content) of the uploads."""
    import hashlib  # local import: only this helper needs it

    return tuple(
        (uploaded.name, hashlib.sha256(uploaded.getvalue()).hexdigest())
        for uploaded in uploaded_files
    )


def _build_vectorstore(uploaded_files):
    """Parse the uploaded PDFs, split them into chunks and embed them into FAISS.

    Returns:
        The FAISS vector store, or ``None`` when the PDFs produced no text
        chunks (building an index from an empty list would fail).
    """
    docs = []
    # PyPDFLoader is given a file path, so each upload is written to disk
    # first. The directory is removed as soon as parsing is done.
    with tempfile.TemporaryDirectory() as temp_dir:
        for uploaded in uploaded_files:
            # basename() guards against path components in the client-supplied name.
            temp_filepath = os.path.join(temp_dir, os.path.basename(uploaded.name))
            with open(temp_filepath, "wb") as f:
                f.write(uploaded.getvalue())
            docs.extend(PyPDFLoader(temp_filepath).load())

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
        separators=["\n\n", "\n", " ", ""],
    )
    chunks = text_splitter.split_documents(documents=docs)
    if not chunks:
        return None

    # OpenAI embeddings; "text-embedding-3-small" is the small, inexpensive model.
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    return FAISS.from_documents(chunks, embedding=embeddings)


if __name__ == '__main__':
    uploaded_files = st.file_uploader("Choose a PDF file", type="pdf", accept_multiple_files=True)

    for i, uploaded_file in enumerate(uploaded_files):
        st.write("filename:", i, uploaded_file.name)

    vectorstore = None
    if uploaded_files:
        # Streamlit re-runs this script on every interaction, including each
        # chat question. Cache the index in session state so the PDFs are only
        # re-parsed and re-embedded when the set of uploaded files changes.
        signature = _upload_signature(uploaded_files)
        if st.session_state.get("vectorstore_signature") != signature:
            with st.spinner("🔨 벡터 데이터베이스를 구축하는 중..."):
                st.session_state["vectorstore"] = _build_vectorstore(uploaded_files)
            # Stored only after a successful build so a failed build is retried.
            st.session_state["vectorstore_signature"] = signature
        vectorstore = st.session_state["vectorstore"]
        if vectorstore is None:
            # User-facing warning: no text could be extracted from the PDFs.
            st.warning("⚠️ PDF에서 텍스트를 추출하지 못했습니다. 다른 파일을 업로드해 주세요.")

    if vectorstore is not None:
        st.success("✅ PDF 업로드 및 벡터 데이터베이스 구축 완료!")

        st.header(':blue[질문을 입력해 주세요]', divider='rainbow')
        prompt = st.chat_input("Enter your questions here")

        # Conversation state that must survive Streamlit reruns.
        for state_key in ("user_prompt_history", "chat_answers_history", "chat_history"):
            if state_key not in st.session_state:
                st.session_state[state_key] = []

        if prompt:
            with st.spinner("🤔 답변을 생성하는 중..."):
                output = document_data(
                    query=prompt,
                    chat_history=st.session_state["chat_history"],
                    vectorstore=vectorstore,
                )

                st.session_state["chat_answers_history"].append(output['answer'])
                st.session_state["user_prompt_history"].append(prompt)
                st.session_state["chat_history"].append((prompt, output['answer']))

        # Replay the whole conversation, oldest turn first.
        for question, answer in zip(st.session_state["user_prompt_history"],
                                    st.session_state["chat_answers_history"]):
            st.chat_message("user").write(question)
            st.chat_message("assistant").write(answer)