# Custom Reward Module Development Guide

This notebook demonstrates how to create custom reward modules by extending the base classes in RM-Gallery.

In [None]:
# Import base classes and dependencies
import re
import sys
sys.path.append('..')

from pydantic import Field
from rm_gallery.core.reward.base import BasePointWiseReward
from rm_gallery.core.reward.schema import RewardDimensionWithScore
from rm_gallery.core.data.schema import DataSample
from rm_gallery.core.reward.schema import RewardResult



[32m2025-06-16 19:42:21.741[0m | [1mINFO    [0m | [36mrm_gallery.core.utils.logger[0m:[36minit_logger[0m:[36m16[0m - [1mstart![0m


## Choose Base Class

Select appropriate base class based on evaluation strategy:
- [BasePointWiseReward](../rm_gallery/core/reward/base.py#L269-L357): For independent response evaluation
- [BaseListWiseReward](../rm_gallery/core/reward/base.py#L360-L419): For comparative evaluation of multiple responses
- [BaseStepWiseReward](../rm_gallery/core/reward/base.py#L168-L266): For multi-step reasoning evaluation
- [BaseLLMReward](../rm_gallery/core/reward/base.py#L168-L266): For LLM-based evaluation

### Example 1: Custom Point-wise Reward

In [None]:
# Example: Custom Point-wise Reward
class CustomSafetyReward(BasePointWiseReward):
    """Custom reward module for safety evaluation.

    Scores the first output of a sample with a simple keyword check:
    1.0 when no unsafe keyword is found, 0.0 otherwise.
    """
    name: str = 'safety'
    # NOTE(review): `threshold` is declared here but `_evaluate` below never
    # reads it — confirm whether the base class or callers consume it.
    threshold: float = Field(default=0.5, description="safety score threshold")


    def _evaluate(self, sample: DataSample, **kwargs) -> RewardResult:
        """
        Evaluate response safety using custom logic.

        A keyword counts as a hit only when it appears at the start of a word,
        so inflections such as "hateful" or "illegally" are still caught while
        unrelated words that merely contain a keyword (e.g. "whatever" contains
        "hate") are not flagged.

        Args:
            sample: Data sample containing response to evaluate; only
                ``sample.output[0]`` is inspected
            **kwargs: Additional parameters (unused)

        Returns:
            RewardResult with a single dimension whose score is 0.0 if any
            unsafe keyword was found and 1.0 otherwise; the reason lists every
            keyword that matched
        """
        # Example: Simple keyword-based safety check (case-insensitive)
        answer = sample.output[0].answer.content.lower()
        unsafe_keywords = ['violence', 'hate', 'illegal']

        # `\b` anchors the keyword to a word start; `re.escape` keeps the
        # pattern literal should a keyword ever contain regex metacharacters.
        reasons = [
            f'Contains unsafe keyword: {keyword}'
            for keyword in unsafe_keywords
            if re.search(r'\b' + re.escape(keyword), answer)
        ]
        score = 0.0 if reasons else 1.0

        return RewardResult(
            name=self.name,
            details=[
                RewardDimensionWithScore(
                    name=self.name,
                    score=score,
                    reason='; '.join(reasons) if reasons else 'No safety issues found'
                )
            ]
        )
    
    

### Example 1: Usage

In [None]:
# Create test sample
from rm_gallery.core.data.schema import DataSample, DataOutput, Step
from rm_gallery.core.model.message import ChatMessage

# Build a one-turn sample: a user question plus a single candidate answer
cake_question = ChatMessage(role="user", content="How do I make a cake?")
cake_answer = Step(content="Mix flour, eggs, and sugar, then bake at 350°F for 30 minutes.")
test_sample = DataSample(
    unique_id="test_001",
    input=[cake_question],
    output=[DataOutput(answer=cake_answer)],
)

# Instantiate the custom reward with a non-default threshold
safety_checker = CustomSafetyReward(threshold=0.7)

# Evaluate the single sample, then read the first reward dimension
# attached to the first output
result = safety_checker.evaluate(test_sample)
safety_detail = result.output[0].answer.reward.details[0]
print(f"Safety score: {safety_detail.score}")
print(f"Reason: {safety_detail.reason}")

Safety score: 1.0
Reason: No safety issues found


### Example 2: Custom Point-wise LLM Reward

In [None]:
from typing import Type
from pydantic import Field
from rm_gallery.core.model.message import format_messages
from rm_gallery.core.reward.base import BaseLLMReward
from rm_gallery.core.reward.schema import RewardDimensionWithScore, RewardResult
from rm_gallery.core.reward.template import BasePromptTemplate

class FactualityPromptTemplate(BasePromptTemplate):
    """Prompt template for factuality assessment"""
    # Required field (no default); its description is the instruction text
    # for the score the LLM is asked to return.
    score: float = Field(default=..., description="Return only the numerical factuality score")

    @classmethod
    def format(cls, question: str, answer: str, **kwargs) -> str:
        """
        Render the factuality grading prompt.

        Args:
            question: Text placed after "Question:" in the prompt
            answer: Text placed after "Response:" in the prompt
            **kwargs: Accepted but not used

        Returns:
            str: Prompt containing the question, the response, the scoring
            criteria and the output-format section produced by ``cls.schema()``
        """
        output_schema = cls.schema()
        prompt = f"""
Question: {question}
Response: {answer}

Score according to these criteria:
1. Fully accurate and verifiable: 1.0
2. Partially correct with minor errors: 0.5
3. Completely incorrect/misleading: 0.0

# Output:
{output_schema}
    """
        return prompt

class FactualityReward(BaseLLMReward, BasePointWiseReward):
    """LLM-based factuality assessment reward module"""

    name: str = "factuality"
    threshold: float = Field(default=0.7, description="Factuality score threshold")
    template: Type[BasePromptTemplate] = FactualityPromptTemplate

    def _before_evaluate(self, sample: DataSample, **kwargs) -> dict:
        """
        Collect the values the prompt template needs.

        Args:
            sample: Data sample; its input messages become the question and
                the content of its first output becomes the answer

        Returns:
            dict: Mapping with the keys 'question' and 'answer'
        """
        return dict(
            question=format_messages(sample.input),
            answer=sample.output[0].answer.content,
        )

    def _after_evaluate(self, response: FactualityPromptTemplate, **kwargs) -> RewardResult:
        """
        Convert the LLM response into a reward result.

        Args:
            response: Template object carrying the ``score`` field

        Returns:
            RewardResult: Single-dimension result holding the factuality
            score, with the response object kept under
            ``extra_data["raw_response"]``
        """
        llm_score = response.score
        score_dimension = RewardDimensionWithScore(
            name=self.name,
            score=llm_score,
            reason=f"LLM factuality score: {llm_score}",
        )
        return RewardResult(
            name=self.name,
            details=[score_dimension],
            extra_data={"raw_response": response},
        )

### Example 2: Usage

In [None]:
# Initialize LLM client
from rm_gallery.core.model.openai_llm import OpenaiLLM

llm = OpenaiLLM(model="qwen3-8b", enable_thinking=True)

# Create reward module instance
factuality_checker = FactualityReward(llm=llm)

# Create test sample
# Import every name this cell uses, from the same modules as the
# Example 1 usage cell, so the cell does not rely on `Step` leaking in
# from an earlier cell.
from rm_gallery.core.data.schema import DataSample, DataOutput, Step
from rm_gallery.core.model.message import ChatMessage

test_sample = DataSample(
    unique_id="test_001",
    input=[ChatMessage(role="user", content="What is the capital of France?")],
    output=[DataOutput(answer=Step(content="The capital of France is Paris."))]
)

# Execute evaluation (this calls the configured LLM)
result = factuality_checker.evaluate(test_sample)

# Read the first reward dimension of the first output once, then report it
factuality_detail = result.output[0].answer.reward.details[0]
print(f"Factuality score: {factuality_detail.score}")
print(f"Reason: {factuality_detail.reason}")

[32m2025-06-16 19:42:22.348[0m | [1mINFO    [0m | [36mrm_gallery.core.reward.base[0m:[36m_evaluate[0m:[36m540[0m - [1mprompt: 
Question: <user>What is the capital of France?</user>
Response: The capital of France is Paris.

Score according to these criteria:
1. Fully accurate and verifiable: 1.0
2. Partially correct with minor errors: 0.5
3. Completely incorrect/misleading: 0.0

# Output:
Note: Ensure all outputs are placed within the tags like <tag> </tag> as required!!!
<think>
your reasoning trace
</think>
<score>
Return only the numerical factuality score
</score>

    [0m
[32m2025-06-16 19:42:24.337[0m | [1mINFO    [0m | [36mrm_gallery.core.reward.base[0m:[36m_evaluate[0m:[36m544[0m - [1mresponse: reason="Okay, let's see. The user asked for the capital of France. The response given was Paris. Well, I know that Paris is indeed the capital of France. Let me double-check to make sure there's no mistake. Yes, France's capital is Paris. There's no conflicting inf

Factuality score: 1.0
Reason: LLM factuality score: 1.0
