In [None]:
import os

# Move to the project root: three directories above the kernel's *initial*
# working directory (presumably so the `src.` package imported below resolves).
# The resolved path is cached in a global so re-running this cell does not
# climb another three levels each time.
if "_PROJECT_ROOT" not in globals():
    _PROJECT_ROOT = os.path.abspath("../../../")
os.chdir(_PROJECT_ROOT)

import asyncio
from textwrap import dedent
from typing import List, Literal, Union

from langchain.tools import BaseTool
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, MessagesPlaceholder
from openai import OpenAI
from pydantic import BaseModel, Field

from src.initialization import credential_init

# Project helper; presumably loads the API credentials the model clients need —
# see src/initialization to confirm.
credential_init()

## Generate

In [None]:
from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
from langchain.memory import ChatMessageHistory
from langchain_google_genai import ChatGoogleGenerativeAI

# Conversation buffer for the manual (non-LangGraph) generate/reflect walkthrough.
chat_history = ChatMessageHistory()

# Writer persona.
# NOTE(review): this prompt asks for "5-paragraph essays", while the request used
# throughout this notebook is a TikTok ad script — consider rewording it for scripts.
system_message = SystemMessage(content=dedent("""\
              You are an essay assistant tasked with writing excellent 5-paragraph essays.
              Generate the best essay possible for the user's request.
              If the user provides critique, respond with a revised version of your previous attempts.
              """))


# System prompt first, then the whole conversation injected under the "messages" key.
chat_prompt_template = ChatPromptTemplate.from_messages(
    [
        system_message,
        MessagesPlaceholder(variable_name="messages"),
    ]
)

# Gemini model shared by the writer and (later) the reviewer pipeline.
# Alternative worth trying: "gemini-2.5-flash-lite".
MODEL_NAME = "gemini-2.0-flash"

model = ChatGoogleGenerativeAI(
    model=MODEL_NAME,
    temperature=0,  # reduce run-to-run variation so drafts are easier to compare
    max_tokens=None,
    timeout=None,
    max_retries=6,
    disable_streaming=False,
)

# Prompt -> Gemini -> plain string.
generate_pipeline = chat_prompt_template | model | StrOutputParser()

In [None]:
# ChatGoogleGenerativeAI?

In [None]:
# Seed the conversation with the ad-script request as a user (Human) message.
chat_history.add_message(
    HumanMessage(content="生成一個投放在Tiktok上的冰淇淋廣告劇本。目標群眾為8-15歲的小孩")
)

In [None]:
# First draft: run the writer pipeline over the conversation collected so far.
generation_input = {"messages": chat_history.messages}
essay = generate_pipeline.invoke(generation_input)


In [None]:
# Show the generated script (StrOutputParser makes it a plain str).
essay

## Reflect

In [None]:
# Reviewer persona: a senior ad-placement consultant for food ads on social
# platforms, who critiques the submitted script and proposes improvements.
# Rebinding `system_message` / `chat_prompt_template` does not affect
# `generate_pipeline`, which already holds its own template object.
system_message = SystemMessage(content=dedent("""\
    你是一個資深的廣告投放顧問，擅長於在社群軟體投放食品類廣告。你會根據送來的劇本給予建議並提出改善的方法。
    """))

# Reviewer system prompt followed by the conversation under the "messages" key.
chat_prompt_template = ChatPromptTemplate.from_messages(
    [
        system_message,
        MessagesPlaceholder(variable_name="messages"),
    ]
)

# Same Gemini `model` as the writer, with a different persona.
reflect_pipeline = chat_prompt_template | model | StrOutputParser()

建立新的 `chat_history`，並將之前生成的劇本以使用者訊息 (`HumanMessage`) 的形式加入，交給 `reflect_pipeline` 審閱

In [None]:
# Fresh history for the reviewer: the draft is submitted as a single user message.
chat_history = ChatMessageHistory(messages=[HumanMessage(content=essay)])

先檢查要送給 `reflect_pipeline` 的訊息，再生成反饋

In [None]:
# Inspect what the reviewer will receive: the draft wrapped as one user message.
chat_history.messages

In [None]:
# Ask the reviewer persona to critique the draft held in `chat_history`.
reflection_input = {"messages": chat_history.messages}
reflection = reflect_pipeline.invoke(reflection_input)

print(reflection)  # print() renders the critique's line breaks instead of a quoted repr

將生成的反饋加入 `chat_history`

In [None]:
# Rebuild the conversation from the *writer's* point of view before looping back
# to `generate_pipeline`: the original request, the writer's own draft (as an AI
# message), then the reviewer's critique (as user feedback).
# The history built for the reviewer held only the draft as a user message, which
# would drop the request and mislabel who wrote the draft. Rebuilding from
# scratch also makes this cell safe to re-run (no duplicated critiques).
chat_history = ChatMessageHistory()
chat_history.add_user_message("生成一個投放在Tiktok上的冰淇淋廣告劇本。目標群眾為8-15歲的小孩")
chat_history.add_ai_message(essay)
chat_history.add_user_message(reflection)

然後回到 `generate_pipeline`，並重複整個流程（生成 → 反思 → 再生成）。

## Langgraph Workflow

## 💬 Chat History 結構 (ChatMessageHistory Structure)

在 **Generation** 階段後，`chat_history`（在 LangGraph 流程中即 `state["messages"]`）的結構如下：

- **HumanMessage**: 使用者指令，例如：  
  > 生成一個投放在 TikTok 上的冰淇淋廣告劇本。目標群眾為 8–15 歲的小孩  
- **AIMessage**: 模型生成的回覆內容 (`<生成的內容>`)

---

## 🔄 Reflection 階段 — 為什麼要交換角色

在 **Reflection Agent** 中，AI 會對自己剛才生成的內容進行「反思 (reflection)」。  
此時，我們希望模型 **以「使用者」的角度重新審視自己剛才的輸出**。

因此，在反思過程中，我們需要將 `chat_history` 中的訊息角色進行對調：

| 原本類型 | 在 Reflection 中變為 |
|-----------|-----------------------|
| `AIMessage` | `HumanMessage` |
| `HumanMessage` | `AIMessage` |

這樣做的目的，是讓模型把自己先前生成的回答 (`AIMessage`) 視為「使用者輸入」，  
並根據這個內容進行反思或修正。

---

## 🧩 範例

**Generation 後:**
```python
[
    HumanMessage(content="生成一個投放在 TikTok 上的冰淇淋廣告劇本。目標群眾為 8–15 歲的小孩"),
    AIMessage(content="<生成的內容>")
]
```

在使用 `reflect_pipeline` 時，我們要讓輸入變為（第一則使用者指令會被略過，AI 生成的內容改以 `HumanMessage` 呈現）：

```python
[
    HumanMessage(content="<生成的內容>"),
]
```

並將 `reflect_pipeline` 的輸出包裝成 `HumanMessage`，方便在 Generation 階段直接作為使用者的回饋使用。


In [None]:
import time

from typing import Annotated, List, Sequence
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import InMemorySaver
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage
from langchain.memory import ChatMessageHistory


# Maximum number of drafts generation_node may produce before the graph stops.
MAX_ITERATION = 2

class State(TypedDict):
    """Graph state: the running conversation."""
    # `add_messages` reducer: node updates are merged into the existing list
    # rather than replacing it.
    messages: Annotated[list, add_messages]


async def generation_node(state: State) -> dict:
    """Write (or revise) the ad script from the full conversation so far.

    Returns a partial state update holding one new AIMessage (the draft).
    """
    result = await generate_pipeline.ainvoke({"messages": state["messages"]})

    return {"messages": AIMessage(content=result)}


async def reflection_node(state: State) -> dict:
    """Critique the latest draft from the reviewer's side of the conversation.

    Roles are swapped (AI drafts become user input; earlier critiques become the
    reviewer's own replies), and the original request at index 0 is skipped, so
    `reflect_pipeline` sees the drafts as submissions to review. The critique is
    returned as a HumanMessage so the next generation step treats it as user
    feedback.
    """
    cls_map = {"ai": HumanMessage, "human": AIMessage}

    messages = [cls_map[msg.type](content=msg.content) for msg in state["messages"][1:]]

    result = await reflect_pipeline.ainvoke({"messages": messages})

    return {"messages": HumanMessage(content=result)}


def should_continue(state: State):
    """Route after each generation: stop after MAX_ITERATION drafts, else reflect.

    Counts AI messages (one per generation_node run) rather than the raw message
    total, so MAX_ITERATION is the number of drafts produced.
    """
    n_drafts = sum(1 for msg in state["messages"] if isinstance(msg, AIMessage))
    if n_drafts >= MAX_ITERATION:
        return END
    return "reflection_node"


workflow = StateGraph(State)

workflow.add_node("generation_node", generation_node)
workflow.add_node("reflection_node", reflection_node)

# START -> generate -> (reflect -> generate)* -> END
workflow.add_edge(START, "generation_node")
workflow.add_edge("reflection_node", "generation_node")
workflow.add_conditional_edges("generation_node", should_continue, [END, "reflection_node"])

# In-memory checkpointer so the finished thread can be inspected with app.get_state(config).
memory = InMemorySaver()
app = workflow.compile(checkpointer=memory)

config = {"configurable": {"thread_id": "1"}}

In [None]:
from IPython.display import Image, display

# Render the compiled reflection workflow as a Mermaid PNG.
graph_png = app.get_graph(xray=True).draw_mermaid_png()
display(Image(graph_png))

In [None]:
# result = app.invoke({"messages": [HumanMessage(content="生成一個投放在Tiktok上的冰淇淋廣告劇本。目標群眾為8-15歲的小孩")]}, config)

In [None]:
async for event in app.astream(
    {
        "messages": [
            HumanMessage(
                content="生成一個投放在Tiktok上的冰淇淋廣告劇本。目標群眾為8-15歲的小孩"
            )
        ],
    },
    config,
):
    print(event)
    print("---")

取得歷史紀錄

In [None]:
# Latest checkpoint of the reflection workflow's thread (thread_id "1" from `config`).
state = app.get_state(config)

In [None]:
# messages[0] is the request passed to astream; index 1 should be the first draft
# from generation_node (assumes a fresh thread — streaming again on thread "1"
# appends to the same history).
state.values['messages'][1]

Dummy example：以最小的圖示範 checkpoint 與 reducer 的運作

In [None]:
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import InMemorySaver
from typing import Annotated
from typing_extensions import TypedDict
from operator import add


class State(TypedDict):
    # No reducer: a node's write replaces the stored value.
    foo: str
    # `add` reducer: each node's list is concatenated onto the stored list.
    bar: Annotated[list[str], add]


def node_a(state: State):
    return dict(foo="a", bar=["a"])


def node_b(state: State):
    return dict(foo="b", bar=["b"])


# Linear graph: START -> node_a -> node_b -> END.
workflow = StateGraph(State)
workflow.add_node("node_a", node_a)
workflow.add_node("node_b", node_b)
workflow.add_edge(START, "node_a")
workflow.add_edge("node_a", "node_b")
workflow.add_edge("node_b", END)

# Checkpoints:
# The state of a thread at a particular point in time is called a checkpoint.
# A checkpoint is a snapshot of the graph state saved at each superstep and is
# represented by a StateSnapshot object.
checkpointer = InMemorySaver()
graph = workflow.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}
graph.invoke({"foo": ""}, config)

In [None]:
# Latest checkpoint (a StateSnapshot) of the dummy graph's thread "1".
# Left as the last expression so IPython's pretty-printer lays the nested
# snapshot out readably, instead of print()'s single-line repr.
state = graph.get_state(config)
state

In [None]:
# Invoke again on the same thread ("1"). The checkpointer keeps this thread's state,
# so presumably `bar` (declared with the `add` reducer) keeps growing while `foo`
# is overwritten — compare with the snapshot shown above.
graph.invoke({"foo": "c"}, config)