## Training an email search agent with reinforcement learning using LangGraph

### Email Search Agent with LangGraph

In this notebook, we will use LangGraph together with ART to train the ART•E agent from scratch!

This implementation shows how to integrate LangGraph's agent framework with ART's training capabilities.

Starting from Qwen 2.5 7B as the base model (the code below loads `Qwen/Qwen2.5-7B-Instruct`) and using LangGraph's ReAct agent pattern, we will train the model to search emails and answer questions about them.

We will build an agentic environment, define a rollout with LangGraph, and run a training loop.

We will also learn how to use RULER to judge the quality of the agent's answers.


### RULER

RULER is a robust technique for evaluating the quality of an agent's answers and for training the agent to produce better completions.

To learn more about RULER, see the following link:
[RULER](https://art.openpipe.ai/fundamentals/ruler)

In [1]:
#@title 💿 Installation

# Portions adapted from Unsloth Notebooks (https://github.com/unslothai/notebooks)
# Copyright (c) Unsloth contributors.
# License: GNU LGPL v3.0.
# Modifications by OpenPipe:
# - switched to uv
# - changed vllm/triton pinning logic
# - added litellm/protobuf pins
# See /licenses/LGPL-3.0.txt and /licenses/GPL-3.0.txt for full text.

# %%capture
import os

if "COLAB_" not in "".join(os.environ.keys()):
    !uv pip install openpipe-art[backend,langgraph]==0.4.11 langchain-core langgraph langchain_openai tenacity datasets --prerelease allow --no-cache-dir
else:
    try:
        import numpy

        get_numpy = f"numpy=={numpy.__version__}"
    except:
        get_numpy = "numpy"
    try:
        import subprocess

        is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except:
        is_t4 = False
    get_vllm, get_triton = (
        ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm", "triton")
    )
    !uv pip install --upgrade \
        openpipe-art[backend,langgraph]==0.4.11 langchain-core langgraph langchain_openai tenacity datasets protobuf==5.29.5 {get_vllm} {get_numpy} --prerelease allow --no-cache-dir
    !uv pip install -qqq {get_triton}

Using Python 3.11.13 environment at: /home/khw/miniconda3/envs/openpipe
Audited 6 packages in 82ms


<a name="Environment-Variables"></a>

### Environment Variables

**OpenAI (used for RULER judge model)**

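Our RULER reward function queries a third-party model to judge the quality of the agent's performance. Any model supported by LiteLLM works. This example uses OpenAI's o4-mini model, so you need to set the `OPENAI_API_KEY` environment variable. The correctness judge in the rollout cell calls `openai/gpt-4.1`, which relies on the same key.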



**Weights & Biases (optional)**

Later in the notebook, we will create a model that automatically logs metrics to Weights & Biases and chat completions to Weave. To enable this, provide your Weights & Biases API key as an environment variable.


In [None]:
import os

from dotenv import load_dotenv

load_dotenv()

# Required
# os.environ["OPENAI_API_KEY"] = "OPENAI_API_KEY"

if not os.environ.get("OPENAI_API_KEY"):
    raise ValueError(
        "OPENAI_API_KEY is required for RULER functionality when using openai/o4-mini."
    )

# Optional
# os.environ["WANDB_API_KEY"] = "YOUR_API_KEY"

if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set. We'll skip logging metrics to Weights & Biases.")

<a name="Environment"></a>

### Email Search Environment

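ART lets an agent learn by interacting with its environment. In this example, we use **LangGraph's tool integration** to create an environment in which the agent __can search emails and answer questions about them__.

The agent has access to the following three tools: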
1. `search_inbox_tool` - Search for emails by keywords
2. `read_email_tool` - Read a specific email by message ID
3. `return_final_answer_tool` - Return the final answer with source email IDs
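
As a rough illustration of how keyword search can work, the sketch below builds an FTS5 query string in the same spirit as the search code in the next cell (each keyword is quoted, embedded double quotes are doubled, and terms are implicitly ANDed) and runs it against a tiny in-memory table. The `docs` table, its rows, and the helper name `build_fts_query` are made up for this example, and it assumes that your Python's SQLite build includes FTS5.

```python
import sqlite3


def build_fts_query(keywords: list[str]) -> str:
    """Quote each keyword and join with spaces (FTS5 treats this as AND)."""
    return " ".join('"' + k.replace('"', '""') + '"' for k in keywords)


query = build_fts_query(["schedule", 'the "big" meeting'])
print(query)

# Hypothetical two-row table, used only for demonstration.
matches = []
try:
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE VIRTUAL TABLE docs USING fts5(subject, body)")
    conn.executemany(
        "INSERT INTO docs (subject, body) VALUES (?, ?)",
        [
            ("Weekly schedule", "In the office Monday, vacation Friday."),
            ("Lunch", "Pizza on Thursday?"),
        ],
    )
    matches = conn.execute(
        "SELECT subject FROM docs WHERE docs MATCH ?",
        (build_fts_query(["schedule", "vacation"]),),
    ).fetchall()
    conn.close()
except sqlite3.OperationalError as exc:
    # Raised when this SQLite build was compiled without FTS5.
    print(f"FTS5 unavailable: {exc}")

print(matches)
```

Quoting every keyword keeps characters that FTS5 would otherwise treat as query syntax from changing the meaning of the search.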

In [3]:
#@title Email Search Code

import os
import random
import sqlite3
from dataclasses import asdict, dataclass
from datetime import datetime
from textwrap import dedent
from typing import List, Literal, Optional

from datasets import Dataset, Features, Sequence, Value, load_dataset
from pydantic import BaseModel, Field
from tqdm import tqdm


# Email and Scenario data models
class Email(BaseModel):
    message_id: str
    date: str  # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    subject: Optional[str] = None
    from_address: Optional[str] = None
    to_addresses: List[str] = []  # Populated from recipients table
    cc_addresses: List[str] = []  # Populated from recipients table
    bcc_addresses: List[str] = []  # Populated from recipients table
    body: Optional[str] = None
    file_name: Optional[str] = None


class Scenario(BaseModel):
    id: int
    question: str
    answer: str
    message_ids: List[str]  # message_ids (strings) of referenced emails
    how_realistic: float
    inbox_address: str
    query_date: str
    split: Literal["train", "test"]


@dataclass
class SearchResult:
    message_id: str
    snippet: str


class FinalAnswer(BaseModel):
    answer: str
    source_ids: list[str]


# Database configuration
DB_PATH = "./enron_emails.db"
EMAIL_DATASET_REPO_ID = "corbt/enron-emails"  # Hugging Face dataset repo ID
SCENARIO_DATASET_REPO_ID = "corbt/enron_emails_sample_questions"

# Global database connection
db_conn = None

# Load the Hugging Face dataset and build a SQLite database from it. The LLM agent
# will query this database to answer questions correctly.
def create_email_database():
    """Create the email database from Hugging Face dataset"""
    print("Creating email database from Hugging Face dataset...")
    print(
        "This will download and process the full Enron email dataset - this may take several minutes..."
    )

    # Database schema
    SQL_CREATE_TABLES = """
    DROP TABLE IF EXISTS recipients;
    DROP TABLE IF EXISTS emails_fts;
    DROP TABLE IF EXISTS emails;

    CREATE TABLE emails (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        message_id TEXT UNIQUE,
        subject TEXT,
        from_address TEXT,
        date TEXT,
        body TEXT,
        file_name TEXT
    );

    CREATE TABLE recipients (
        email_id TEXT,
        recipient_address TEXT,
        recipient_type TEXT
    );
    """

    SQL_CREATE_INDEXES_TRIGGERS = """
    CREATE INDEX idx_emails_from ON emails(from_address);
    CREATE INDEX idx_emails_date ON emails(date);
    CREATE INDEX idx_emails_message_id ON emails(message_id);
    CREATE INDEX idx_recipients_address ON recipients(recipient_address);
    CREATE INDEX idx_recipients_type ON recipients(recipient_type);
    CREATE INDEX idx_recipients_email_id ON recipients(email_id);
    CREATE INDEX idx_recipients_address_email ON recipients(recipient_address, email_id);

    CREATE VIRTUAL TABLE emails_fts USING fts5(
        subject,
        body,
        content='emails',
        content_rowid='id'
    );

    CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
        INSERT INTO emails_fts (rowid, subject, body)
        VALUES (new.id, new.subject, new.body);
    END;

    CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
        DELETE FROM emails_fts WHERE rowid=old.id;
    END;

    CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
        UPDATE emails_fts SET subject=new.subject, body=new.body WHERE rowid=old.id;
    END;
    """

    # Create database
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_TABLES)
    conn.commit()

    # Load dataset
    print("Loading full email dataset...")
    expected_features = Features(
        {
            "message_id": Value("string"),
            "subject": Value("string"),
            "from": Value("string"),
            "to": Sequence(Value("string")),
            "cc": Sequence(Value("string")),
            "bcc": Sequence(Value("string")),
            "date": Value("timestamp[us]"),
            "body": Value("string"),
            "file_name": Value("string"),
        }
    )

    dataset = load_dataset(
        EMAIL_DATASET_REPO_ID, features=expected_features, split="train"
    )
    print(f"Dataset contains {len(dataset)} total emails")

    # Populate database with ALL emails (not limited to 1000)
    print("Populating database with all emails...")
    conn.execute("PRAGMA synchronous = OFF;")
    conn.execute("PRAGMA journal_mode = MEMORY;")
    conn.execute("BEGIN TRANSACTION;")

    record_count = 0
    skipped_count = 0
    duplicate_count = 0
    processed_emails = set()  # Track (subject, body, from) tuples for deduplication

    for email_data in tqdm(dataset, desc="Inserting emails"):
        message_id = email_data["message_id"]
        subject = email_data["subject"]
        from_address = email_data["from"]
        date_obj: datetime = email_data["date"]
        body = email_data["body"]
        file_name = email_data["file_name"]
        to_list = [str(addr) for addr in email_data["to"] if addr]
        cc_list = [str(addr) for addr in email_data["cc"] if addr]
        bcc_list = [str(addr) for addr in email_data["bcc"] if addr]

        # Apply the same filters as the original project
        total_recipients = len(to_list) + len(cc_list) + len(bcc_list)

        # Filter out very long emails and those with too many recipients
        if len(body) > 5000:
            skipped_count += 1
            continue

        if total_recipients > 30:
            skipped_count += 1
            continue

        # Deduplication check (same as original project)
        email_key = (subject, body, from_address)
        if email_key in processed_emails:
            duplicate_count += 1
            continue
        else:
            processed_emails.add(email_key)

        date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")

        cursor.execute(
            """
            INSERT INTO emails (message_id, subject, from_address, date, body, file_name)
            VALUES (?, ?, ?, ?, ?, ?)
        """,
            (message_id, subject, from_address, date_str, body, file_name),
        )

        # Insert recipients
        recipient_data = []
        for addr in to_list:
            recipient_data.append((message_id, addr, "to"))
        for addr in cc_list:
            recipient_data.append((message_id, addr, "cc"))
        for addr in bcc_list:
            recipient_data.append((message_id, addr, "bcc"))

        if recipient_data:
            cursor.executemany(
                """
                INSERT INTO recipients (email_id, recipient_address, recipient_type)
                VALUES (?, ?, ?)
            """,
                recipient_data,
            )

        record_count += 1

    conn.commit()

    # Create indexes and triggers
    print("Creating indexes and FTS...")
    cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
    # Use a single-quoted SQL string literal; double quotes denote identifiers in SQL.
    cursor.execute("INSERT INTO emails_fts(emails_fts) VALUES('rebuild')")
    conn.commit()

    print(f"Successfully created database with {record_count} emails.")
    print(f"Skipped {skipped_count} emails due to length/recipient limits.")
    print(f"Skipped {duplicate_count} duplicate emails.")
    return conn


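def get_db_connection():
    """Return the shared database connection, creating the database on first use."""
    # Cache the connection in the module-level variable so that every tool call
    # does not reopen the database file.
    global db_conn
    if db_conn is None:
        if os.path.exists(DB_PATH):
            print(f"Loading existing database from {DB_PATH}")
        else:
            # Build the database, then reopen it so the shared connection
            # can be used from other threads.
            create_email_database().close()
        db_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
    return db_conn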


def search_emails(
    inbox: str,
    keywords: List[str],
    from_addr: Optional[str] = None,
    to_addr: Optional[str] = None,
    sent_after: Optional[str] = None,
    sent_before: Optional[str] = None,
    max_results: int = 10,
) -> List[SearchResult]:
    """Search the email database based on keywords and filters"""
    conn = get_db_connection()
    cursor = conn.cursor()

    where_clauses: List[str] = []
    params: List[str | int] = []

    if not keywords:
        raise ValueError("No keywords provided for search.")

    if max_results > 10:
        raise ValueError("max_results must be less than or equal to 10.")

    # FTS5 default is AND, so just join keywords. Escape quotes for safety.
    fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
    where_clauses.append("fts.emails_fts MATCH ?")
    params.append(fts_query)

    # Inbox filter
    where_clauses.append("""
        (e.from_address = ? OR EXISTS (
            SELECT 1 FROM recipients r_inbox
            WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.message_id
        ))
    """)
    params.extend([inbox, inbox])

    if from_addr:
        where_clauses.append("e.from_address = ?")
        params.append(from_addr)

    if to_addr:
        where_clauses.append("""
            EXISTS (
                SELECT 1 FROM recipients r_to
                WHERE r_to.recipient_address = ? AND r_to.email_id = e.message_id
            )
        """)
        params.append(to_addr)

    if sent_after:
        where_clauses.append("e.date >= ?")
        params.append(f"{sent_after} 00:00:00")

    if sent_before:
        where_clauses.append("e.date < ?")
        params.append(f"{sent_before} 00:00:00")

    sql = f"""
        SELECT
            e.message_id,
            snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
        FROM
            emails e JOIN emails_fts fts ON e.id = fts.rowid
        WHERE
            {" AND ".join(where_clauses)}
        ORDER BY
            e.date DESC
        LIMIT ?;
    """
    params.append(max_results)

    cursor.execute(sql, params)
    results = cursor.fetchall()

    return [SearchResult(message_id=row[0], snippet=row[1]) for row in results]


def read_email(message_id: str) -> Optional[Email]:
    """Retrieve a single email by its message_id"""
    conn = get_db_connection()
    cursor = conn.cursor()

    # Get email details
    cursor.execute(
        "SELECT message_id, date, subject, from_address, body, file_name FROM emails WHERE message_id = ?",
        (message_id,),
    )
    email_row = cursor.fetchone()

    if not email_row:
        return None

    msg_id, date, subject, from_addr, body, file_name = email_row

    # Get recipients
    cursor.execute(
        "SELECT recipient_address, recipient_type FROM recipients WHERE email_id = ?",
        (message_id,),
    )
    recipient_rows = cursor.fetchall()

    to_addresses = []
    cc_addresses = []
    bcc_addresses = []

    for addr, type_val in recipient_rows:
        if type_val.lower() == "to":
            to_addresses.append(addr)
        elif type_val.lower() == "cc":
            cc_addresses.append(addr)
        elif type_val.lower() == "bcc":
            bcc_addresses.append(addr)

    return Email(
        message_id=msg_id,
        date=date,
        subject=subject,
        from_address=from_addr,
        to_addresses=to_addresses,
        cc_addresses=cc_addresses,
        bcc_addresses=bcc_addresses,
        body=body,
        file_name=file_name,
    )


def load_training_scenarios(
    split: Literal["train", "test"] = "train",
    limit: Optional[int] = None,
    max_messages: Optional[int] = 1,
    shuffle: bool = False,
    seed: Optional[int] = None,
) -> List[Scenario]:
    """Load training scenarios from Hugging Face dataset"""
    print(f"Loading {split} scenarios from Hugging Face...")
    dataset: Dataset = load_dataset(SCENARIO_DATASET_REPO_ID, split=split)

    if max_messages is not None:
        dataset = dataset.filter(lambda x: len(x["message_ids"]) <= max_messages)

    if shuffle or (seed is not None):
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()

    # Convert each row to a Scenario object
    scenarios = [Scenario(**row, split=split) for row in dataset]

    if max_messages is not None:
        scenarios = [s for s in scenarios if len(s.message_ids) <= max_messages]

    if shuffle:
        if seed is not None:
            rng = random.Random(seed)
            rng.shuffle(scenarios)
        else:
            random.shuffle(scenarios)

    if limit is not None:
        scenarios = scenarios[:limit]

    print(f"Loaded {len(scenarios)} scenarios.")
    return scenarios


# Load training scenarios
training_scenarios = load_training_scenarios(
    split="train", limit=50, max_messages=1, shuffle=True, seed=42
)

print("Email search environment created with full Enron dataset!")
print(
    f"Database contains the complete email dataset, loaded {len(training_scenarios)} training scenarios."
)

# print first scenario
print("\nSample scenario")
print("id:", training_scenarios[0].id)
print("question:", training_scenarios[0].question)
print("answer:", training_scenarios[0].answer)
print("message_ids:", training_scenarios[0].message_ids)
print("how_realistic:", training_scenarios[0].how_realistic)
print("inbox_address:", training_scenarios[0].inbox_address)
print("query_date:", training_scenarios[0].query_date)
print("split:", training_scenarios[0].split)



Loading train scenarios from Hugging Face...
Loaded 50 scenarios.
Email search environment created with full Enron dataset!
Database contains the complete email dataset, loaded 50 training scenarios.

Sample scenario
id: 3296
question: Who can I contact for Power Operations when Sally is in London?
answer: Stacey White (x31870) and Leslie Reeves (x37962).
message_ids: ['<6033065.1075856098960.JavaMail.evans@thyme>']
how_realistic: 0.699999988079071
inbox_address: louise.kitchen@enron.com
query_date: 2001-01-25
split: train


### Creating a Model

Now that we have defined the rules of the environment, we can create a model that learns to search emails effectively.

This example uses the Qwen 2.5 7B model.



In [None]:
import art
from art.local import LocalBackend

random.seed(42)

# Declare the model
model = art.TrainableModel(
    name="email-agent-langgraph-001",
    project="email-search-agent-langgraph",
    base_model="Qwen/Qwen2.5-7B-Instruct",
)

# To run on a T4, we need to override some config defaults.
model._internal_config = art.dev.InternalModelConfig(
    init_args=art.dev.InitArgs(
        max_seq_length=8192,
    ),
    engine_args=art.dev.EngineArgs(
        enforce_eager=True,
        gpu_memory_utilization=0.8,
    ),
)

# Initialize the server
backend = LocalBackend(
    # Normally we don't want to run the server in-process, but for the output
    # to show up properly on Google Colab we'll enable this.
    in_process=True,
    path="./.art",
)

# Register the model with the local Backend (sets up logging, inference, and training)
await model.register(backend)

<a name="Rollout"></a>

### Defining a Rollout with LangGraph


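A rollout is a single episode in which the agent performs a task.

In this example, the rollout is handled by LangGraph's ReAct agent.

The rollout function presents the agent with an **email search** scenario, and the LangGraph agent uses the **available tools** to _search the emails_ and answer the question.

Once the agent provides a final answer, the `correct` metric is computed according to whether the answer is right.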
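
The tools in the next cell are defined inside the rollout function so that they can read the scenario and record the final answer for scoring. Here is a minimal sketch of that closure pattern in plain Python, with no LangGraph involved. `run_episode` and `fake_agent` are hypothetical stand-ins, and the exact-match check replaces the LLM judge only to keep the example self-contained.

```python
from typing import Callable, Optional


def run_episode(reference_answer: str, agent: Callable[[Callable[[str], str]], None]) -> dict:
    """Run one episode and report whether the captured answer is correct."""
    final_answer: Optional[str] = None

    def return_final_answer(answer: str) -> str:
        # `nonlocal` lets the tool write to the enclosing episode's variable.
        nonlocal final_answer
        final_answer = answer
        return answer

    agent(return_final_answer)  # the agent decides whether to call the tool

    metrics = {}
    if final_answer is not None:
        # Stand-in for the LLM judge: exact string match.
        metrics["correct"] = float(final_answer == reference_answer)
    return {"final_answer": final_answer, "metrics": metrics}


def fake_agent(tool: Callable[[str], str]) -> None:
    """Hypothetical agent that answers immediately."""
    tool("Stacey White")


answered = run_episode("Stacey White", fake_agent)
silent = run_episode("Stacey White", lambda tool: None)
print(answered)
print(silent)
```

When the agent never calls the final-answer tool, no `correct` metric is recorded, which mirrors the `if final_answer:` check in the rollout.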



In [5]:
import uuid

import weave
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from litellm import acompletion
from tenacity import retry, stop_after_attempt
from art.langgraph import init_chat_model

import art

if os.getenv("WANDB_API_KEY", ""):
    weave.init(model.project, settings={"print_call_link": False})

MAX_TURNS = 20

class CorrectnessJudgeResponse(BaseModel):
    reasoning: str = Field(description="Explanation of the reasoning process.")
    accept: bool = Field(description="Whether the AI answer should be accepted.")


@retry(stop=stop_after_attempt(3))
async def judge_correctness(
    scenario: Scenario, answer: str
) -> CorrectnessJudgeResponse:
    system_prompt = dedent(
        """
        You are given a question, the reference answer (labelled **Reference answer**), and an answer generated by an AI assistant (labelled **AI answer**).

        Your task is to decide whether the AI answer is correct and should be accepted. You should accept the answer if it contains the relevant information from the reference answer. You should not accept the answer if it is missing information relevant to the question, or if it contradicts the reference answer.
        """
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": (
                f"Question: {scenario.question}\n"
                f"Reference answer: {scenario.answer}\n"
                f"AI answer: {answer}"
            ),
        },
    ]

    response = await acompletion(
        model="openai/gpt-4.1",
        messages=messages,
        response_format=CorrectnessJudgeResponse,
    )

    first_choice = response.choices[0]
    raw_content = first_choice.message.content or "{}"

    try:
        return CorrectnessJudgeResponse.model_validate_json(raw_content)
    except Exception as e:
        return CorrectnessJudgeResponse(
            reasoning=f"Parse error: {e}\nRaw: {raw_content}", accept=False
        )


class ProjectTrajectory(art.Trajectory):
    final_answer: FinalAnswer | None = None


class EmailScenario(BaseModel):
    step: int
    scenario: Scenario


@weave.op
async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory:
    scenario = email_scenario.scenario

    traj = ProjectTrajectory(
        reward=0.0,
        messages_and_choices=[],
        metadata={
            "scenario_id": scenario.id,
            "step": email_scenario.step,
        },
    )

    system_prompt = dedent(
        f"""
        You are an email search agent. You are given a user query and a list of tools you can use to search the user's email. Use the tools to search the user's emails and find the answer to the user's query. You may take up to {MAX_TURNS} turns to find the answer, so if your first search doesn't find the answer, you can try with different keywords.

        User's email address is {scenario.inbox_address}
        Today's date is {scenario.query_date}

        When you have found the answer, use the return_final_answer_tool to provide your final answer along with the source message IDs.
        """
    )

    # Store final answer in trajectory
    final_answer = None

    # Define tools inside the rollout function to access local variables
    @tool
    def search_inbox_tool(keywords: list[str]) -> list[dict]:
        """Search the inbox for emails matching the given keywords and return
        a list of dictionaries so the LLM can easily consume them."""
        results = search_emails(
            inbox=scenario.inbox_address,
            keywords=keywords,
            sent_before=scenario.query_date,
        )
        return [asdict(result) for result in results]

    @tool
    def read_email_tool(message_id: str) -> dict | None:
        """Read a specific email by message ID."""
        email = read_email(message_id)
        if email:
            return email.model_dump()
        return None

    @tool
    def return_final_answer_tool(answer: str, reference_message_ids: list[str]) -> dict:
        """Return the final answer and the message IDs of the emails that were used to generate the answer."""
        nonlocal final_answer
        final_answer = FinalAnswer(answer=answer, source_ids=reference_message_ids)
        return final_answer.model_dump()

    # Create LangGraph tools
    tools = [search_inbox_tool, read_email_tool, return_final_answer_tool]

    chat_model = init_chat_model(model.name, temperature=1.0)

    # Create the LangGraph ReAct agent
    react_agent = create_react_agent(chat_model, tools)

    try:
        # Run the agent
        config = {
            "configurable": {"thread_id": str(uuid.uuid4())},
            "recursion_limit": MAX_TURNS,
        }

        await react_agent.ainvoke(
            {
                "messages": [
                    SystemMessage(content=system_prompt),
                    HumanMessage(content=scenario.question),
                ]
            },
            config=config,
        )

        # Check if we got a final answer
        if final_answer:
            traj.final_answer = final_answer
            # Score the trajectory
            correctness_judge_response = await judge_correctness(
                scenario, traj.final_answer.answer
            )
            traj.metrics["correct"] = float(correctness_judge_response.accept)

    except Exception as e:
        print(f"Error running LangGraph agent: {e}")
        # Add error information to trajectory
        traj.messages_and_choices.append(
            {"role": "assistant", "content": f"Error: {str(e)}"}
        )

    return traj


print("LangGraph rollout function defined!")

LangGraph rollout function defined!


<a name="ruler"></a>

### How RULER works

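**RULER** builds on two key insights:

1. **Relative scoring is easier than absolute scoring**
    : It is easier for an LLM to rank several solutions relative to one another than to score each of them in isolation.
    
2. **GRPO only needs relative scores**
    : GRPO normalizes scores within each group, so only the relative ranking matters, not the absolute values.
    

#### Process

1. Generate N trajectories for a given scenario.
2. Pass all N trajectories to RULER.
3. RULER deduplicates common prefixes (for example, identical system messages).
4. An LLM judge scores each trajectory from 0 to 1 according to how well it achieved the goal.
5. These scores are used directly as rewards in GRPO training.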
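
To make the second insight concrete, the sketch below applies a group-wise normalization (subtract the group mean, divide by the group standard deviation) to two sets of judge scores that differ only by a shift and a rescale. This is the commonly described form of GRPO's advantage; ART's actual implementation may differ in its details, and `group_advantages` is a name invented for this example.

```python
from statistics import fmean, pstdev


def group_advantages(rewards: list[float]) -> list[float]:
    """Normalize rewards within one group: (r - mean) / std."""
    mean = fmean(rewards)
    std = pstdev(rewards)
    if std == 0:
        # All trajectories scored the same, so there is no learning signal.
        return [0.0 for _ in rewards]
    return [(r - mean) / std for r in rewards]


advantages_a = group_advantages([1.0, 0.5, 0.0])
advantages_b = group_advantages([0.9, 0.7, 0.5])  # same ranking, different scale
print(advantages_a)
print(advantages_b)
```

Both calls print approximately the same advantages, because shifting or rescaling every score in a group leaves the normalized values unchanged.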


To learn more about **RULER**, check out the [RULER docs](https://art.openpipe.ai/fundamentals/ruler).










In [6]:
#@title Sample RULER evaluation

import art
from art.rewards import ruler_score_group

# Test RULER with a simple example
base_messages = [
    {"role": "system", "content": "You count numbers using numeric symbols."},
    {"role": "user", "content": "Count to 10."},
]

good_trajectory = art.Trajectory(
    messages_and_choices=[
        *base_messages,
        {"role": "assistant", "content": "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"},
    ],
    reward=0,
)

mediocre_trajectory = art.Trajectory(
    messages_and_choices=[
        *base_messages,
        {
            "role": "assistant",
            "content": "one, two, three, four, five, six, seven, eight, nine, ten",
        },
    ],
    reward=0,
)

bad_trajectory = art.Trajectory(
    messages_and_choices=[
        *base_messages,
        {"role": "assistant", "content": "a, b, c, d, e, f, g, h, i, j"},
    ],
    reward=0,
)

sample_group = art.TrajectoryGroup(
    trajectories=[
        good_trajectory,
        mediocre_trajectory,
        bad_trajectory,
    ]
)

judged_group = await ruler_score_group(sample_group, "openai/o4-mini", debug=True)
assert judged_group is not None

# Display rankings
sorted_trajectories = sorted(
    judged_group.trajectories, key=lambda t: t.reward, reverse=True
)
for rank, traj in enumerate(sorted_trajectories, 1):
    messages = traj.messages()
    print(f"\nRank {rank}: Score {traj.reward:.3f}")
    print(f"  Response: {messages[-1]['content'][:50]}...")


Rank 1: Score 1.000
  Response: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10...

Rank 2: Score 0.500
  Response: one, two, three, four, five, six, seven, eight, ni...

Rank 3: Score 0.000
  Response: a, b, c, d, e, f, g, h, i, j...


<a name="Loop"></a>

### Training Loop with LangGraph

The training loop is where the magic happens. For each of the steps defined below, the rollout function will be called multiple times in parallel using LangGraph's ReAct agent. Each scenario will produce a group of trajectories (`rollouts_per_group` of them), which will be used to update the model.
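
As a quick back-of-the-envelope check on the configuration used in the next cell, the arithmetic below counts rollouts per step and steps per epoch. It assumes that every scenario is visited once per epoch and that the 50 training scenarios loaded earlier are used. The variable names are just for illustration.

```python
import math

num_scenarios = 50       # training scenarios loaded earlier
groups_per_step = 2      # scenarios (trajectory groups) per training step
rollouts_per_group = 4   # rollouts sampled for each scenario

rollouts_per_step = groups_per_step * rollouts_per_group
steps_per_epoch = math.ceil(num_scenarios / groups_per_step)

print(f"rollouts per step: {rollouts_per_step}")
print(f"steps per epoch:   {steps_per_epoch}")
```

Under these assumptions, with `max_steps` set to 20 the loop would stop before completing one full pass over the training scenarios.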

The `gather` step will wait for all of the trajectories to be generated, then it will use RULER to assign relative scores to each trajectory.

Our notebook will then delete all but the most recent checkpoint and train the model on the scored trajectories.
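
The gather step can be sketched with plain `asyncio`. This is not ART's `gather_trajectory_groups`: `fake_rollout` and `gather_groups` are hypothetical stand-ins that only show how several groups of rollouts can run concurrently and still be collected group by group.

```python
import asyncio
import random


async def fake_rollout(scenario: str, attempt: int) -> dict:
    """Hypothetical rollout: pretend to work, then return a trajectory-like dict."""
    await asyncio.sleep(random.uniform(0.0, 0.01))
    return {"scenario": scenario, "attempt": attempt, "reward": 0.0}


async def gather_groups(scenarios: list[str], rollouts_per_group: int) -> list[list[dict]]:
    """Run every rollout concurrently, keeping results grouped by scenario."""
    group_tasks = [
        asyncio.gather(*(fake_rollout(s, i) for i in range(rollouts_per_group)))
        for s in scenarios
    ]
    # The outer gather preserves the order of the scenarios.
    return [list(group) for group in await asyncio.gather(*group_tasks)]


demo_groups = asyncio.run(gather_groups(["scenario-a", "scenario-b"], rollouts_per_group=4))
for group in demo_groups:
    print(group[0]["scenario"], len(group))
```

Inside a notebook, where an event loop is already running, replace `asyncio.run(...)` with `await gather_groups(...)`.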

In [None]:
# Training configuration
from art.utils import iterate_dataset
from art.langgraph import wrap_rollout

training_config = {
    "groups_per_step": 2,
    "num_epochs": 20,
    "rollouts_per_group": 4,
    "learning_rate": 1e-5,
    "max_steps": 20,
}

# Use iterate_dataset with real training scenarios (similar to train.py)
training_iterator = iterate_dataset(
    training_scenarios,  # Use real scenarios from Hugging Face
    groups_per_step=training_config["groups_per_step"],
    num_epochs=training_config["num_epochs"],
    initial_step=await model.get_step(),
)

for batch in training_iterator:
    print(
        f"Training step {batch.step}, epoch {batch.epoch}, epoch step {batch.epoch_step}"
    )
    print(f"Batch contains {len(batch.items)} scenarios")

    # Create trajectory groups for this batch (similar to train.py)
    groups = []
    for scenario in batch.items:
        groups.append(
            art.TrajectoryGroup(
                (
                    wrap_rollout(model, rollout)(
                        model, EmailScenario(step=batch.step, scenario=scenario)
                    )
                    for _ in range(training_config["rollouts_per_group"])
                )
            )
        )
    print(groups[0])
    # Gather all trajectory groups
    finished_groups = await art.gather_trajectory_groups(
        groups,
        pbar_desc="gather",
        max_exceptions=training_config["rollouts_per_group"] * len(batch.items),
    )

    judged_groups = []
    for group in finished_groups:
        # Use RULER to assign relative scores to each trajectory
        judged_group = await ruler_score_group(group, "openai/o4-mini", debug=True)
        judged_groups.append(judged_group)

    await model.delete_checkpoints()
    await model.train(
        judged_groups,
        config=art.TrainConfig(learning_rate=training_config["learning_rate"]),
        # Lowering the logprob_calculation_chunk_size is a memory saving measure
        # to allow longer sequences (up to 8192 tokens) to be processed on a T4.
        _config={"logprob_calculation_chunk_size": 8},
    )

    print(f"Completed training step {batch.step}")

    # Stop after max_steps for demo purposes (adjust as needed)
    if batch.step >= training_config["max_steps"]:
        break

### Using the Model

We have now trained an agent with LangGraph to search emails and answer questions!

It is time to use your model outside the training loop.


See the code below for a demo of the model we just trained.



In [9]:
#@title Loading/inference code

# Test the trained model using the rollout function
# This avoids memory issues and uses the same inference path as training

print("Testing the trained LangGraph model with a real scenario...\n")


# Use a scenario from our training set
test_scenario = training_scenarios[1]

print(f"Test scenario ID: {test_scenario.id}")
print(f"Question: {test_scenario.question}")
print(f"Expected answer: {test_scenario.answer}")
print(f"Reference message IDs: {test_scenario.message_ids}")
print(f"Inbox: {test_scenario.inbox_address}")
print(f"Query date: {test_scenario.query_date}")
print("-" * 50)

# Run the rollout function with the trained model
test_email_scenario = EmailScenario.model_validate(
    {"step": 0, "scenario": test_scenario.model_dump()}
)
result_trajectory = await wrap_rollout(model, rollout)(model, test_email_scenario)

print("LangGraph Agent's trajectory:")
print("-" * 20)

# Display the conversation
messages = result_trajectory.messages()
for i, msg in enumerate(messages):
    role = msg.get("role", "unknown")
    content = msg.get("content", "")
    tool_calls = msg.get("tool_calls", [])

    if role == "system":
        print(
            f"[SYSTEM]: {content[:100]}..."
            if len(content) > 100
            else f"[SYSTEM]: {content}"
        )
    elif role == "user":
        print(f"[USER]: {content}")
    elif role == "assistant":
        if tool_calls:
            print(f"[ASSISTANT]: {tool_calls}")
        if content:
            print(f"[ASSISTANT]: {content}")
    elif role == "tool":
        tool_name = msg.get("name", "unknown_tool")
        print(
            f"[TOOL - {tool_name}]: {content[:200]}..."
            if len(content) > 200
            else f"[TOOL - {tool_name}]: {content}"
        )

    print()

print("-" * 50)
if result_trajectory.final_answer:
    print(f"Agent's Final Answer: {result_trajectory.final_answer.answer}")
    print(f"Source IDs Used: {result_trajectory.final_answer.source_ids}")
else:
    print("No final answer provided by the agent")

print(f"\nExpected Answer: {test_scenario.answer}")
print(f"Expected Source IDs: {test_scenario.message_ids}")

print("\n🎉 LangGraph email search agent testing completed!")
print(
    "The agent used LangGraph's ReAct pattern with the same inference path as during training."
)

Testing the trained LangGraph model with a real scenario...

Test scenario ID: 3751
Question: What is my schedule for the week of August 28th?
Expected answer: Monday--in office; Tuesday--in office; Wednesday--in office; Thursday--vacation; Friday--vacation.
Reference message IDs: ['<28971276.1075842953941.JavaMail.evans@thyme>']
Inbox: jeff.dasovich@enron.com
Query date: 2000-09-20
--------------------------------------------------


/tmp/ipykernel_288833/1452329291.py:133: LangGraphDeprecatedSinceV10: create_react_agent has been moved to langchain.agents. Please update your import to 'from langchain.agents import create_agent'. Deprecated in LangGraph V1.0 to be removed in V2.0.
  react_agent = create_react_agent(chat_model, tools)


Loading existing database from ./enron_emails.db
Loading existing database from ./enron_emails.db
Loading existing database from ./enron_emails.db
Loading existing database from ./enron_emails.db
Loading existing database from ./enron_emails.db
LangGraph Agent's trajectory:
--------------------
[SYSTEM]: 
You are an email search agent. You are given a user query and a list of tools you can use to search...

[USER]: What is my schedule for the week of August 28th?

[ASSISTANT]: [{'id': 'chatcmpl-tool-46d34658dbb743978da665a239ef61fe', 'type': 'function', 'function': {'name': 'search_inbox_tool', 'arguments': '{"keywords": ["schedule", "August 28th"]}'}}]

[TOOL - unknown_tool]: 

[ASSISTANT]: [{'id': 'chatcmpl-tool-d0ad457ef20a45388d5cc2324e7cc073', 'type': 'function', 'function': {'name': 'search_inbox_tool', 'arguments': '{"keywords": ["week", "August 28th", " agenda", "plan"]}'}}]

[TOOL - unknown_tool]: 

[ASSISTANT]: [{'id': 'chatcmpl-tool-36cfeb1bbe404d56aefc5da06d8dcead', 'type':

<div class="align-center">
<a href="https://github.com/openpipe/art"><img src="https://github.com/openpipe/art/raw/main/assets/ART_pill.png" height="50"></a>
<a href="https://discord.gg/zbBHRUpwf4"><img src="https://github.com/openpipe/art/raw/main/assets/Discord.png" height="50"></a>
<a href="https://art.openpipe.ai"><img src="https://github.com/openpipe/art/raw/main/assets/Documentation_pill.png" height="50"></a>

Questions? Join the Discord and ask away! For feature requests or to leave a star, visit our [GitHub](https://github.com/openpipe/art).

</div>