# Writer Process

---

## 1. Set API Keys

In [1]:
import os

# Load API Keys
from config.secret_keys import OPENAI_API_KEY, TAVILY_API_KEY, USER_AGENT

# Publish the secrets loaded from config.secret_keys as process environment
# variables (presumably read by the OpenAI / Tavily / web-loader clients used
# in the cells below — confirm per library).
os.environ.update(
    {
        "OPENAI_API_KEY": OPENAI_API_KEY,
        "TAVILY_API_KEY": TAVILY_API_KEY,
        "USER_AGENT": USER_AGENT,
    }
)

## 2. Import Modules

In [23]:
# Construct Vector DB / Create retriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

# Retrieval Grader
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

# Writer
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# define graph state
from typing import List
from typing_extensions import TypedDict

from langchain_core.documents import Document

## 3. Construct Vector DB 

In [12]:
# Source articles: Korean-language stock-market news pages from mk.co.kr.
urls = [
    "https://www.mk.co.kr/news/stock/11209083",  # title: "Returning foreign investors lift KOSPI for a rare rally; KOSDAQ reclaims 700"
    "https://www.mk.co.kr/news/stock/11209254",  # title: "Value-up ETFs post negative returns for a second month as the market stalls"
    "https://www.mk.co.kr/news/stock/11209229",  # title: "Korean retail investors bought 1 trillion won in a month... Tesla breaks below $400"
]

# One loader per URL; each load() call yields a list of Documents.
docs = [WebBaseLoader(url).load() for url in urls]
# Flatten the per-URL lists into a single list of Documents.
docs_list = [document for loaded_docs in docs for document in loaded_docs]

# Split into ~300-character chunks that overlap by 50 characters.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
doc_splits = text_splitter.split_documents(docs_list)

# Embed the chunks with OpenAI embeddings and index them in the "rag-chroma"
# Chroma collection (no persist_directory is passed here).
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    embedding=OpenAIEmbeddings(),
    collection_name="rag-chroma",
)

# Default similarity retriever over the collection.
retriever = vectorstore.as_retriever()

## 4. Define Agents

### 4-1. Retrieval Grader

In [14]:
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    # 'yes' / 'no' relevance verdict filled in by the retrieval-grader LLM via
    # with_structured_output; the Field description is part of the schema the
    # model is given.
    relevance_score: str = Field(description="Document are relevant to the question, 'yes' or 'no'")

In [None]:
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
)
# Parse the model's reply into a GradeDocuments instance (relevance_score).
structured_llm = llm.with_structured_output(GradeDocuments)

system_prompt = """
    You are a grader assessing relevance of a retrieved document to a user question. \n 

    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n

    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n

    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.
"""

# The human turn grades ONE document (template variable `document`) against
# the user question.
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "Retrieved document : \n\n {document} \n\n User question : {question}"),
    ]
)

# Input: {"document": ..., "question": ...}  ->  GradeDocuments
retrieval_grader = grade_prompt | structured_llm

### 4-2. Writer

In [16]:
# RAG answer prompt pulled from the LangChain Hub; it is filled with the
# "context" and "question" keys by the writer node.
write_prompt = hub.pull("rlm/rag-prompt")


def format_docs(docs):
    """Return the page_content of every document, separated by blank lines."""
    page_texts = (document.page_content for document in docs)
    return "\n\n".join(page_texts)


llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# prompt -> chat model -> plain-string answer
writer = write_prompt | llm | StrOutputParser()



### 4-3. Hallucination Grader

In [18]:
class GradeHallucination(BaseModel):
    """
    Binary score for hallucination present in generation answer.
    """

    # Despite the field name, 'yes' means the answer IS grounded in the facts
    # (i.e. no hallucination), as the description below tells the model.
    hallucination_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )

In [21]:
# System instructions: judge whether the generation is supported by the facts;
# 'Yes' means grounded.
system_prompt = (
    "\n"
    "    You are a grader assessing whether an LLM generation is grounded in"
    " / supported by a set of retrieved facts. \n \n"
    "\n"
    "    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is"
    " grounded in / supported by the set of facts.\n"
)
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "Set of facts : \n\n {documents} \n\n LLM generation : {generation}"),
    ]
)

# Temperature-0 model whose reply is parsed into a GradeHallucination.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm = llm.with_structured_output(GradeHallucination)

# Input: {"documents": ..., "generation": ...}  ->  GradeHallucination
hallucination_grader = hallucination_prompt | structured_llm

### 4-4. Answer Grader

In [19]:
class AnswerGrader(BaseModel):
    """Binary Score to assess answer address question."""

    # 'yes' / 'no' verdict on whether the generation resolves the question,
    # filled in by the answer-grader LLM via with_structured_output.
    answer_score: str = Field(description="Answer address the question, 'yes' or 'no'")

In [20]:
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
)
# Parse the model's reply into an AnswerGrader instance (answer_score).
structured_llm = llm.with_structured_output(AnswerGrader)

system_prompt = """
    You are a grader assessing whether an answer addresses / resolves a question \n 
    
    Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question.
"""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "User question : \n\n {question} \n\n LLM generation : {generation}")
    ]
)

# Input: {"question": ..., "generation": ...}  ->  AnswerGrader
answer_grader = answer_prompt | structured_llm

### 4-5. Query Rewriter

In [22]:
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
)

system_prompt = """
    You are a question re-writer that converts an input question to a better version that is optimized for vectorstore retrieval. \n
    
    Look at the input and try to reason about the underlying semantic intent / meaning.
"""
rewrite_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        (
            "human",
            "Here is the initial question : \n\n {question} \n Formulate an improved question."
        )
    ]
)

# Input: {"question": ...}. No output parser is attached, so this returns the
# chat model's message object; callers read the rewritten text from `.content`.
query_rewriter = rewrite_prompt | llm

## 5. Construct Graph

### 5-1. Define Graph State

In [24]:
class State(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user question being answered
        generation: the LLM-generated answer text
        documents: retrieved LangChain ``Document`` objects (what the vector-store
            retriever returns), not plain strings
    """

    question: str
    generation: str
    documents: List[Document]

### 5-2. Define Nodes

In [31]:
# Retriever
# NOTE: the node function below is named `retriever`, which rebinds the
# module-level `retriever = vectorstore.as_retriever()` from section 3. Inside
# the node, `retriever.<method>` would resolve to this function and raise
# AttributeError. Keep a retriever under an unshadowed name instead. Building it
# from `vectorstore` (rather than aliasing `retriever`) keeps re-running this
# cell safe.
retriever_runnable = vectorstore.as_retriever()


def retriever(state):
    """
    Retrieve documents relevant to the question in the graph state.

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        state (dict): Partial state update carrying the original "question" and a
        new "documents" key with the retrieved documents.
    """
    print("[Graph Log] RETRIEVE ...")
    question = state["question"]

    # `invoke` is the Runnable entry point; `get_relevant_documents` is
    # deprecated in recent langchain-core releases.
    documents = retriever_runnable.invoke(question)

    return {
        "documents": documents,
        "question": question,
    }

# writer
# NOTE: the node function below is named `writer`, which rebinds the module-level
# `writer = write_prompt | llm | StrOutputParser()` chain from section 4-2. Inside
# the node, `writer.invoke(...)` would hit this function and raise AttributeError.
# Rebuild the same chain under an unshadowed name. The model is constructed
# explicitly because `llm` is reassigned by every agent cell.
writer_chain = (
    write_prompt
    | ChatOpenAI(model="gpt-4o-mini", temperature=0)
    | StrOutputParser()
)


def writer(state):
    """
    Generate an answer from the question and the retrieved documents.

    Args:
        state (dict): The current graph state; must contain "question" and
            "documents".

    Returns:
        state (dict): Partial state update carrying "question" and "documents"
        unchanged plus a new "generation" key with the answer string.
    """

    print("[Graph Log] WRITE ...")

    question = state["question"]
    documents = state["documents"]

    # Render the documents as plain text for the prompt's {context} slot.
    # Passing the list itself would stringify the Document objects, metadata
    # included.
    generation = writer_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )

    return {
        "documents": documents,
        "question": question,
        "generation": generation,
    }

def 