In [None]:
import os

In [None]:
from typing import TypedDict, Annotated, List

In [None]:
# from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.pydantic_v1 import BaseModel

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, END

from langchain import hub

In [None]:
from tavily import TavilyClient

In [None]:
from IPython.display import Markdown, display

In [None]:
# Run-wide settings, gathered in one cell so they can be tuned in one place.
LANGUAGE = 'english'  # passed as input_language to the translation prompt
TAVILY_MAX_RESULTS = 5  # max_results sent with every Tavily search
MAX_QUERIES = 5  # interpolated into RESEARCH_PLAN_PROMPT as the query cap
# Read from the environment; raises KeyError when TAVILY_API_KEY is not set.
TAVILY_API_KEY = os.environ["TAVILY_API_KEY"]

In [None]:
# Shared Tavily client; get_content() uses it to run the research searches.
tavily = TavilyClient(api_key=TAVILY_API_KEY)


In [None]:
class AgentState(TypedDict):
    """Shared state passed between the graph nodes.

    Each node returns a partial dict whose keys are fields of this state.
    """
    # Topic/request supplied by the caller; sent as the human message when
    # planning and when generating search queries.
    task: str
    # Outline produced by plan_node.
    plan: str
    # Article text produced by draft_node.
    draft: str
    # Translated draft produced by translate_node.
    translation: str
    # Language the draft is translated into (output_language of the
    # translation prompt).
    target_language: str
    # Search queries produced by research_plan_node.
    queries: List[str]
    # Not read or written by any node visible in this notebook.
    critique: str
    # Search-result texts gathered by research_plan_node.
    content: List[str]
    # Counter incremented by draft_node each time a draft is written.
    revision_number: int
    # Supplied in the run parameters; not read by any node visible in this
    # notebook.
    max_revisions: int

In [None]:
# Checkpointer handed to builder.compile() below, opened on the ":memory:"
# connection string.
# NOTE(review): newer langgraph releases appear to return a context manager
# from from_conn_string — confirm against the installed version.
memory = SqliteSaver.from_conn_string(":memory:")


In [None]:
# Single chat model instance used by all of the get_* helpers below.
model = ChatOpenAI(model='gpt-4o-mini', temperature=0, verbose=True)

In [None]:
# Pull the translation prompt from the hub; get_translation() formats it with
# input_language, output_language and text.
translation_prompt = hub.pull("haha918301/translator") 

In [None]:
# Show the pulled prompt's template text for inspection.
translation_prompt.template

In [None]:
# Prompt texts used by the graph nodes. A trailing backslash inside a
# triple-quoted string joins the next line without a newline, so every
# continued line must end with a space before the backslash.

# System prompt for get_plan(): asks for a high-level article outline.
PLAN_PROMPT = """You are an expert writer specialised in unbiased information verification in the context of disinformation, conspiracy theories and propaganda. \
You are tasked with writing a high level outline of an expository article. \
Write such an outline for the user provided topic. Give an outline of the expository article along with any relevant notes \
or instructions for the sections."""

# System prompt for get_queries(); the query cap comes from MAX_QUERIES.
RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
be used when writing the following expository article. Generate a list of search queries that will gather \
any relevant information. Only generate {max_queries} queries max.""".format(max_queries=MAX_QUERIES)

# Prompt pulled from the hub; formatted in get_translation().
TRANSLATION_PROMPT = translation_prompt

# System prompt for get_draft(); {content} is filled with the joined search results.
DRAFT_PROMPT = """You are an expository article assistant tasked with writing an excellent 5-paragraph expository article. \
Generate the best unbiased expository article possible for the user's request and the initial outline. \
If the user provides critique, respond with a revised version of your previous attempts. \
Utilize only the information below; don't add any new information, make up or guess anything, or change the topic:

------

{content}"""

# Not referenced by any node visible in this notebook.
REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
Generate critique and recommendations for the user's submission. \
Provide detailed recommendations, including requests for length, depth, style, etc."""



# Not referenced by any node visible in this notebook.
RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
be used when making any requested revisions (as outlined below). \
Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""

In [None]:
def get_plan(state):
    """Ask the chat model for an article outline covering ``state['task']``.

    PLAN_PROMPT is sent as the system message and the task as the human
    message.

    Returns:
        The ``content`` of the model's reply.
    """
    system_msg = SystemMessage(content=PLAN_PROMPT)
    human_msg = HumanMessage(content=state['task'])
    return model.invoke([system_msg, human_msg]).content

def plan_node(state: AgentState):
    """Graph node: build the outline and return it as the ``plan`` state update."""
    return {'plan': get_plan(state)}

In [None]:
def get_queries(state: AgentState):
    """Ask the chat model for the search queries to research ``state['task']``.

    RESEARCH_PLAN_PROMPT is the system message and the task is the human
    message; the reply is requested as structured output.

    Returns:
        The ``queries`` list from the structured reply.
    """

    # Structured-output schema: the reply must carry a list of query strings.
    class Queries(BaseModel):
        queries: List[str]

    prompt_messages = [
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ]
    structured_model = model.with_structured_output(Queries)
    return structured_model.invoke(prompt_messages).queries

def get_content(state, queries):
    """Run every query through Tavily and collect the result texts.

    Texts already stored under ``state['content']`` are kept first, in their
    existing order; new texts are appended in the order the searches return
    them, skipping any text that has been seen before.

    Args:
        state: Mapping holding the agent state. ``content`` may be missing,
            ``None`` or empty.
        queries: Iterable of search query strings.

    Returns:
        A new list of texts. The collection stored in ``state`` is never
        mutated.
    """
    # Copy into a list: the state schema declares List[str], and a list keeps
    # a stable order for the prompt that get_draft() builds from it.
    content = list(state.get('content') or [])
    seen = set(content)
    for q in queries:
        response = tavily.search(query=q, max_results=TAVILY_MAX_RESULTS)
        for r in response['results']:
            text = r['content']
            if text not in seen:
                seen.add(text)
                content.append(text)

    return content

def research_plan_node(state: AgentState):
    """Graph node: generate search queries, run them, and return both as state updates."""
    search_queries = get_queries(state)
    gathered = get_content(state, search_queries)
    return {"content": gathered, "queries": search_queries}

In [None]:
def get_draft(state: AgentState):
    """Write the article draft from the task, the plan and the gathered content.

    The texts in ``state['content']`` are joined with a ``<<<<<<>>>>>>``
    separator and inserted into DRAFT_PROMPT (system message); the task and
    the plan form the human message.

    Returns:
        The ``content`` of the model's reply.
    """
    joined_sources = "<<<<<<>>>>>>".join(state['content'] or [])
    request_text = f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}"
    system_text = DRAFT_PROMPT.format(content=joined_sources)

    reply = model.invoke([
        SystemMessage(content=system_text),
        HumanMessage(content=request_text),
    ])
    return reply.content

def draft_node(state: AgentState):
    """Graph node: write the draft and advance the revision counter.

    The counter starts from ``state['revision_number']`` (1 when absent) and
    is increased by one.
    """
    new_draft = get_draft(state)
    return {
        "draft": new_draft,
        "revision_number": state.get("revision_number", 1) + 1,
    }



In [None]:
def get_translation(state: AgentState):
    """Translate ``state['draft']`` from LANGUAGE into ``state['target_language']``.

    TRANSLATION_PROMPT is formatted with both languages and the draft text,
    then sent to the chat model as a single human message.

    Returns:
        The ``content`` of the model's reply.
    """
    prompt_text = TRANSLATION_PROMPT.format(
        input_language=LANGUAGE,
        output_language=state['target_language'],
        text=state['draft'],
    )
    reply = model.invoke([HumanMessage(content=prompt_text)])
    return reply.content

def translate_node(state: AgentState):
    """Graph node: translate the draft and return it as the ``translation`` state update."""
    return {"translation": get_translation(state)}

In [None]:
# Graph whose state schema is AgentState; each node is one of the *_node
# functions defined above, registered under the given name.
builder = StateGraph(AgentState)
builder.add_node("planner_node", plan_node)
builder.add_node("research_plan_node", research_plan_node)
builder.add_node("draft_node", draft_node)
builder.add_node("translate_node", translate_node)


# Execution starts with the planner.
builder.set_entry_point("planner_node")


In [None]:
# Linear pipeline: plan -> research -> draft -> translate -> END.
builder.add_edge("planner_node", "research_plan_node")
builder.add_edge("research_plan_node", "draft_node")
builder.add_edge("draft_node", "translate_node")
# Finish explicitly after translation so the last node is not left without
# an outgoing edge.
builder.add_edge("translate_node", END)


In [None]:
# Compile with the checkpointer so the state of a run can be read back later
# through graph.get_state(thread).
graph = builder.compile(checkpointer=memory)

In [None]:
from IPython.display import Image

# Image(graph.get_graph().draw_png())

In [None]:
# Counter used to derive a fresh thread_id for each run of the next cell.
thread_num = 0

In [None]:
# Bump the counter so every run of this cell gets its own thread_id.
thread_num += 1
# TODO: placeholder — replace with the topic the article should cover.
task = "TODO"

# Identifies this run; reused by the cells below via graph.get_state(thread).
thread = {"configurable": {"thread_id": str(thread_num)}}
# Initial state for the run.
params = {
    'task': task,
    'target_language': "czech",
    "max_revisions": 2,
    "revision_number": 1,   
}

# Stream the graph and print each emitted item as it arrives.
for s in graph.stream(params, thread, debug=True):
    print(s)

In [None]:
# Search queries stored in the state of the latest run.
graph.get_state(thread).values['queries']

In [None]:
# Search-result texts stored in the state of the latest run.
graph.get_state(thread).values['content']


In [None]:
# Render the draft as Markdown.
display(Markdown(graph.get_state(thread).values['draft']))

In [None]:
# Render the translated draft as Markdown.
display(Markdown(graph.get_state(thread).values['translation']))