In [1]:
import gc
import torch
import time

from pydantic import BaseModel, Field
from typing import Literal, List, Dict, Any, Tuple
from IPython.display import Image, display

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_classic.retrievers.document_compressors import CrossEncoderReranker
from langgraph.graph import END, StateGraph

In [2]:
# PDF loader for the investment prospectus (투자설명서) that serves as the RAG corpus.
loader = PyPDFLoader(file_path='./data/투자설명서.pdf')

In [3]:
# Dense bge-m3 embeddings computed on the GPU; the small encode batch
# (8) was presumably chosen to bound GPU memory — TODO confirm.
embeddings = HuggingFaceEmbeddings(
    model='BAAI/bge-m3',
    model_kwargs=dict(device='cuda'),
    encode_kwargs=dict(batch_size=8),
)

In [4]:
# Read the PDF page by page, then re-chunk the concatenated text semantically so
# chunk boundaries follow meaning rather than page breaks. The raw pages keep
# their own name (`pages`) instead of being overwritten by the chunks.
pages = loader.load()
full_text = '\n\n'.join(page.page_content for page in pages)
text_splitter = SemanticChunker(embeddings=embeddings)
# `metadatas` attaches (a copy of) this dict to every chunk produced from the text,
# replacing the manual post-hoc loop.
docs = text_splitter.create_documents([full_text], metadatas=[{'source': '투자설명서.pdf'}])
# Fail here with a clear message rather than later inside FAISS.from_documents
# (e.g. when the PDF has no extractable text layer).
assert docs, 'SemanticChunker produced no chunks; check that text was extracted from the PDF'
print(len(docs))

243


In [5]:
# Free memory before the next GPU-heavy step (building the FAISS index).
# gc.collect() runs first so that unreachable Python objects, and any CUDA
# tensors they may still reference, are released; empty_cache() then returns
# PyTorch's cached-but-unused GPU memory.
gc.collect()
torch.cuda.empty_cache()
# NOTE(review): unexplained fixed 3 s pause — confirm whether it is still needed.
time.sleep(3)

In [6]:
# Embed every chunk into a dense FAISS index and persist it to disk.
persist_dir = './data/faiss_index_dense'
faiss_store = FAISS.from_documents(documents=docs, embedding=embeddings)
faiss_store.save_local(folder_path=persist_dir)

In [7]:
# Reload the index saved to `persist_dir`. The docstore is stored with pickle,
# hence `allow_dangerous_deserialization=True`; that is acceptable only because
# this notebook wrote the files itself — never enable it for untrusted indexes.
vectorstore = FAISS.load_local(
    folder_path=persist_dir,
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
vectorstore

<langchain_community.vectorstores.faiss.FAISS at 0x267107a1e50>

In [8]:
# Free memory again after building/reloading the FAISS index and before the
# cross-encoder reranker is loaded onto the GPU: collect unreachable Python
# objects first, then return PyTorch's cached-but-unused GPU memory.
gc.collect()
torch.cuda.empty_cache()
# NOTE(review): unexplained fixed 3 s pause — confirm whether it is still needed.
time.sleep(3)

In [9]:
# Grader / query-rewriter model. gpt-5 family reasoning models accept only the
# default sampling temperature: an explicit temperature=0 is either rejected by
# the API or silently dropped by langchain-openai, so none is passed. Expect
# some run-to-run variation in grading and rewriting.
llm_eval = ChatOpenAI(model='gpt-5-nano')
# Answer-generation model; temperature=0 for low-variance answers.
llm_gen = ChatOpenAI(model='gpt-4o-mini', temperature=0)

In [10]:
# Cross-encoder reranker (a Korean reranker, per its name) on the GPU: it
# rescores (question, chunk) pairs and keeps the 3 highest-scoring chunks.
cross_encoder = HuggingFaceCrossEncoder(model_name='Dongjin-kr/ko-reranker', model_kwargs={'device':'cuda'})
reranker = CrossEncoderReranker(model=cross_encoder, top_n=3)
# Give the reranker a wider candidate pool than it returns. With the retriever's
# default (k=4) it could discard at most one chunk, so reranking would barely
# change which documents reach grading.
base_retriever = vectorstore.as_retriever(search_kwargs={'k': 10})

In [11]:
# Two-stage retrieval: `base_retriever` fetches candidates from FAISS, then
# the cross-encoder `reranker` rescores them and keeps only its top results.
compression_retriever = ContextualCompressionRetriever(
    base_retriever=base_retriever,
    base_compressor=reranker,
)

In [12]:
def retrieve(state):
    """Fetch reranked documents for the current question.

    Args:
        state: graph state; reads ``state['question']`` and, when present,
            ``state['retry_count']``.

    Returns:
        dict with the retrieved ``documents``, the unchanged ``question`` and
        ``retry_count`` carried over from the state (0 on the first pass).
    """
    print('---Retrieve---')

    question = state['question']
    documents = compression_retriever.invoke(question)

    # Carry the counter forward instead of resetting it: `transform_query`
    # increments it, and a hard reset here would undo that increment whenever
    # the graph loops transform_query -> retrieve (wiring not visible here),
    # so a retry limit could never trigger.
    retry_count = state.get('retry_count', 0)

    return {'documents':documents, 'question':question, 'retry_count':retry_count}

In [13]:
def grade_documents(state):
    """Keep only the retrieved documents that the LLM grader judges relevant.

    Args:
        state: graph state; reads ``state['question']`` and
            ``state['documents']`` (objects exposing ``page_content``).

    Returns:
        dict with ``documents`` reduced to those graded "yes" (original order
        preserved) and the unchanged ``question``.
    """
    print('---Check Relevance---')

    question = state['question']
    documents = state['documents']

    class Grade(BaseModel):
        # Literal constrains the structured output to exactly "yes" / "no".
        binary_score: Literal['yes', 'no'] = Field(description='문서가 질문과 관련이 있으면 "yes", 아니면 "no"')

    structured_llm_grader = llm_eval.with_structured_output(Grade)

    system = '''당신은 제공된 연관 문서가 주어진 질문과 관련이 있는지, 그리고 질문에 답하는 데 유용한 정보를 제공하는지 판단하는 것입니다.
    철저하게 검증하여 문서가 질문의 키워드나 의미를 포함하고 있다면 "yes"를, 아니라면 "no"를 출력하세요.'''
    grade_prompt = ChatPromptTemplate.from_messages(
        [('system', system), ('human', '질문: {question}\n\n문서: {document}')]
    )
    retrieval_grader = grade_prompt | structured_llm_grader

    filtered_docs = []
    for doc in documents:
        score = retrieval_grader.invoke({'question':question, 'document':doc.page_content})
        # Normalise before comparing so "Yes" / " yes " are not silently
        # rejected; a missing result (None) counts as "not relevant".
        verdict = str(getattr(score, 'binary_score', '')).strip().lower()
        if verdict == 'yes':
            print(' -- 문서 채택: (관련성 있음)')
            filtered_docs.append(doc)
        else:
            print(' -- 문서 기각: (관련성 없음)')

    return {'documents':filtered_docs, 'question':question}

In [14]:
def transform_query(state):
    """Rewrite the question into a form better suited to retrieval.

    Args:
        state: graph state; reads ``state['question']``,
            ``state['documents']`` and, when present, ``state['retry_count']``.

    Returns:
        dict with the unchanged ``documents``, the rewritten ``question``
        (falls back to the original if the model returns nothing usable) and
        ``retry_count`` incremented by one.
    """
    print('---Transform Query---')

    question = state['question']
    documents = state['documents']
    retry_count = state.get('retry_count', 0) + 1

    # The last instruction asks for the rewritten question only, without any
    # explanation, because the output is used verbatim as the next retrieval query.
    system = '''당신은 사용자의 질문을 검색에 더 최적화된 형태로 다듬는 전문가입니다.
    원래 질문의 의도를 유지하면서, 더 좋은 문서를 찾을 수 있도록 질문을 수정하세요.
    설명 없이 수정된 질문 한 문장만 출력하세요.'''
    retry_prompt = ChatPromptTemplate.from_messages(
        [('system', system), ('human', '원본 질문: {question}')]
    )
    question_rewriter = retry_prompt | llm_eval | StrOutputParser()

    better_question = question_rewriter.invoke({'question':question}).strip()
    # An empty rewrite would make the next retrieval meaningless; keep the original.
    if not better_question:
        better_question = question

    print(f' -- 수정된 질문: {better_question}')

    return {'documents':documents, 'question':better_question, 'retry_count':retry_count}

In [15]:
def generate(state):
    """Answer the question using the current documents as context.

    Args:
        state: graph state; reads ``state['question']`` and
            ``state['documents']`` (objects exposing ``page_content``).

    Returns:
        dict with the unchanged ``documents`` and ``question`` plus the model's
        answer under ``generation``.
    """
    print('---Generate---')

    question, documents = state['question'], state['documents']

    template = '''다음 문서들을 바탕으로 질문에 답변하세요.
        문서: {context}
        질문: {question}
        답변:'''
    answer_chain = ChatPromptTemplate.from_template(template) | llm_gen | StrOutputParser()

    # Documents are separated by a blank line inside the prompt.
    context = '\n\n'.join([d.page_content for d in documents])
    generation = answer_chain.invoke({'context': context, 'question': question})

    return {'documents': documents, 'question': question, 'generation': generation}