# Educational RLHF Framework - Test Notebook

**Purpose:** Test each component separately before running the full system.

**Test order:**
1. ✅ API Connection & Basic Generation
2. ✅ Fixed ArmoRM Component
3. ✅ Policy Configuration
4. ✅ Single RLHF Iteration
5. ✅ Full System Integration

## 🔧 Setup & Dependencies

In [1]:
# Install required packages
!pip install google-generativeai python-dotenv numpy pandas matplotlib 





In [2]:
import os
import json
import asyncio
import numpy as np
import pandas as pd
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
import google.generativeai as genai
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

print("📦 Imports successful!")
print(f"🔗 Google GenerativeAI version: {genai.__version__}")

📦 Imports successful!
🔗 Google GenerativeAI version: 0.8.5


## 🔑 API Keys Configuration

In [13]:
import os
import google.generativeai as genai

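# Read API keys from the environment instead of hard-coding them in the notebook.
# The variable names FLASH_API_KEY / PRO_API_KEY are assumed to be set in your
# shell or in the .env file loaded by load_dotenv() above.
FLASH_API_KEY = os.getenv("FLASH_API_KEY", "")
PRO_API_KEY = os.getenv("PRO_API_KEY", "")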

# Validate keys
if not FLASH_API_KEY:
    print("⚠️  Please set FLASH_API_KEY properly.")
else:
    print(f"✅ Flash API key: {FLASH_API_KEY[:10]}...")

if not PRO_API_KEY:
    print("⚠️  Please set PRO_API_KEY properly.")
else:
    print(f"✅ Pro API key: {PRO_API_KEY[:10]}...")

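# Configure Gemini. Both models below share this one client configuration,
# so every request is sent with FLASH_API_KEY.
genai.configure(api_key=FLASH_API_KEY)
flash_model = genai.GenerativeModel("gemini-1.5-flash")
# Note: despite its name, pro_model points at gemini-2.5-flash, not a Pro model.
pro_model   = genai.GenerativeModel("gemini-2.5-flash")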
print("\n🎯 API configuration complete!")


✅ Flash API key: AIzaSyBINk...
✅ Pro API key: AIzaSyC45g...

🎯 API configuration complete!
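
Keys are easier to keep out of version control when they are read from the environment and only a short prefix is ever printed. A minimal sketch, assuming two hypothetical helpers (`load_api_key` and `mask_key` are not part of the notebook) and a plain dict standing in for the real environment:

```python
import os
from typing import Mapping, Optional


def load_api_key(name: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Return the key stored under `name`, or raise if it is missing or blank."""
    env = os.environ if env is None else env
    key = (env.get(name) or "").strip()
    if not key:
        raise RuntimeError(f"{name} is not set; add it to your environment or .env file")
    return key


def mask_key(key: str, visible: int = 4) -> str:
    """Show only the first few characters so logs never contain a full key."""
    return key[:visible] + "..." if len(key) > visible else "***"


# A fake environment for illustration; the value is not a real key.
fake_env = {"FLASH_API_KEY": "example-key-123"}
print(mask_key(load_api_key("FLASH_API_KEY", fake_env)))
```

Failing fast with a `RuntimeError` surfaces a missing key in this cell rather than as an authentication error in a later test.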


## 🧪 Test 1: Basic API Connection

In [14]:
from google.generativeai.types import HarmCategory, HarmBlockThreshold

safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
]


async def test_basic_generation():
    """Test if models can generate basic responses"""
    
    test_prompt = "What is machine learning? Give a brief answer."
    
    try:
        # Test Flash model
        print("Testing Gemini Flash...")
        flash_response = await flash_model.generate_content_async(
            test_prompt,
            generation_config=genai.GenerationConfig(
                temperature=0.7,
                max_output_tokens=2048,
                candidate_count=1
            ),
            safety_settings={
                "HARASSMENT": "block_none",
                "HATE": "block_none"
            }
        )
        
        if flash_response.candidates:
            candidate = flash_response.candidates[0]
            if candidate.content and candidate.content.parts:
                flash_text = candidate.content.parts[0].text
                print(f"✅ Flash response: {flash_text[:100]}...")
            else:
                print(f"❌ Flash response missing parts or content. Full candidate: {candidate}")
                return False
        else:
            print("❌ Flash response has no candidates.")
            print(f"🔍 Full Flash response: {flash_response}")
            return False
            
    except Exception as e:
        print(f"❌ Flash error: {e}")
        return False
    
    try:
        # Test Pro model (for ArmoRM)
        print("\nTesting Gemini Pro...")
        pro_response = await pro_model.generate_content_async(
            test_prompt,
            generation_config=genai.GenerationConfig(
                temperature=0.1,
                max_output_tokens=2048
            ),
            safety_settings=safety_settings
        )
                
        if pro_response.candidates:
            candidate = pro_response.candidates[0]
            if candidate.content and candidate.content.parts:
                pro_text = candidate.content.parts[0].text
                print(f"✅ Pro response: {pro_text[:100]}...")
            else:
                print(f"❌ Pro response missing parts or content. Full candidate: {candidate}")
                return False
        else:
            print("❌ Pro response has no candidates.")
            print(f"🔍 Full Pro response: {pro_response}")
            return False
            
    except Exception as e:
        print(f"❌ Pro error: {e}")
        return False
    
    return True


# Run the API test (top-level await works in Jupyter/IPython cells)
api_test_result = await test_basic_generation()
print(f"\n🎯 API Test Result: {'PASSED' if api_test_result else 'FAILED'}")


Testing Gemini Flash...
✅ Flash response: Machine learning is a type of artificial intelligence that allows software applications to become mo...

Testing Gemini Pro...
✅ Pro response: Machine learning is a subset of artificial intelligence that enables computer systems to learn from ...

🎯 API Test Result: PASSED
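
The candidate/content/parts checks in Test 1 recur in every later test, so they can be factored into one helper. The sketch below is an assumption, not part of the notebook: `extract_text` is a hypothetical name, and `SimpleNamespace` objects stand in for real API responses so the logic can be checked offline.

```python
from types import SimpleNamespace
from typing import Any, Optional


def extract_text(response: Any) -> Optional[str]:
    """Join the text of every part in the first candidate, or return None."""
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        return None
    content = getattr(candidates[0], "content", None)
    parts = getattr(content, "parts", None) or []
    text = "".join(getattr(p, "text", "") or "" for p in parts).strip()
    return text or None


# Stand-in objects that mimic the response shape used above (no API call needed).
ok = SimpleNamespace(candidates=[SimpleNamespace(
    content=SimpleNamespace(parts=[SimpleNamespace(text="Machine learning is ...")]))])
blocked = SimpleNamespace(candidates=[])
no_parts = SimpleNamespace(candidates=[SimpleNamespace(content=SimpleNamespace(parts=[]))])

print(extract_text(ok))
print(extract_text(blocked))
print(extract_text(no_parts))
```

Returning `None` for every empty case lets callers treat "no candidates" and "no parts" the same way, for example by retrying.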


## 🧪 Test 2: Fixed ArmoRM Implementation

In [18]:
from dataclasses import dataclass, field
from typing import Dict, Any
import asyncio
import json
import numpy as np
import google.generativeai as genai
from google.generativeai import GenerationConfig

@dataclass
class RewardComponents:
    helpfulness: float = 0.0
    correctness: float = 0.0
    clarity: float = 0.0

    def to_dict(self) -> Dict[str, float]:
        return {
            "helpfulness": self.helpfulness,
            "correctness": self.correctness,
            "clarity": self.clarity
        }

@dataclass
class ArmoRMResult:
    prompt: str
    response: str
    reward_components: RewardComponents
    composite_score: float
    reasoning: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

class TestArmoRM:
    def __init__(self, pro_model):
        self.pro_model = pro_model
        self.weights = {
            "helpfulness": 0.4,
            "correctness": 0.4,
            "clarity": 0.2
        }
        self.safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT",       "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH",      "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT","threshold": "BLOCK_NONE"},
        ]

    def _create_eval_prompt(self, prompt: str, response: str) -> str:
        return (
            f"Evaluate this response (0-10) for helpfulness, correctness, clarity.\n\n"
            f"QUESTION: {prompt}\n\n"
            f"RESPONSE: {response}\n\n"
            "Return ONLY JSON: {\"helpfulness\":7.5, \"correctness\":8.0, \"clarity\":7.0}"
        )

    def _clean_json(self, text: str) -> str:
        s, e = text.find('{'), text.rfind('}') + 1
        return text[s:e] if s != -1 and e > s else '{"helpfulness":5,"correctness":5,"clarity":5}'

    def _validate_scores(self, scores: Dict) -> Dict[str, float]:
        out = {}
        for k in self.weights:
            try:
                v = float(scores.get(k, 5.0))
                out[k] = max(0.0, min(10.0, v))
            except (ValueError, TypeError):
                out[k] = 5.0
        return out

    def _compute_composite(self, comp: RewardComponents) -> float:
        d = comp.to_dict()
        return round(sum(self.weights[k] * d[k] for k in d), 2)

    def _create_fallback_result(self, prompt: str, response: str) -> ArmoRMResult:
        length_score = min(8.0, 4.0 + len(response.split()) / 50)
        ex = 1.0 if "example" in response.lower() else 0.0
        code = 1.0 if "```" in response else 0.0
        base = 5.0 + 0.5 * ex + 0.5 * code
        comps = RewardComponents(
            helpfulness=min(10.0, base + length_score / 10),
            correctness=base,
            clarity=min(10.0, base + 0.3)
        )
        return ArmoRMResult(
            prompt=prompt,
            response=response,
            reward_components=comps,
            composite_score=self._compute_composite(comps),
            reasoning="Fallback heuristic evaluation",
            metadata={"method": "fallback"}
        )

    async def evaluate_response(self, prompt: str, response: str) -> ArmoRMResult:
        eval_prompt = self._create_eval_prompt(prompt, response)

        for attempt in range(3):
            try:
                result = await self.pro_model.generate_content_async(
                    eval_prompt,
                    generation_config=GenerationConfig(
                        temperature=0.1,
                        max_output_tokens=2048,
                        candidate_count=1
                    ),
                    safety_settings=self.safety_settings
                )

                cands = getattr(result, "candidates", []) or []
                if not cands:
                    print(f"Attempt {attempt+1}: no candidates, retry")
                    await asyncio.sleep(1)
                    continue

                content = cands[0].content

                # ✅ Read the text from .text, falling back to .parts[0].text
                raw = getattr(content, "text", None)
                if not raw and hasattr(content, "parts") and content.parts:
                    raw = content.parts[0].text

                if not raw:
                    print(f"Attempt {attempt+1}: no usable content, retry")
                    await asyncio.sleep(1)
                    continue

                raw = raw.strip()
                print(f"[Armo][Attempt {attempt+1}] Raw JSON: {raw[:80]}...")

                cleaned = self._clean_json(raw)
                scores  = json.loads(cleaned)
                valid   = self._validate_scores(scores)

                comps     = RewardComponents(**valid)
                composite = self._compute_composite(comps)

                return ArmoRMResult(
                    prompt=prompt,
                    response=response,
                    reward_components=comps,
                    composite_score=composite,
                    reasoning=f"api ok #{attempt+1}",
                    metadata={"attempt": attempt+1, "method": "api"}
                )

            except Exception as e:
                print(f"Attempt {attempt+1}: error - {e}")
                await asyncio.sleep(1)

        return self._create_fallback_result(prompt, response)

In [19]:
# Test ArmoRM
async def test_armo_rm():
    """Test ArmoRM evaluation"""
    
    print("🧪 Testing ArmoRM Evaluation...")
    
    # Initialize ArmoRM
    armo = TestArmoRM(pro_model)
    
    # Test cases
    test_cases = [
        {
            "prompt": "What is LangChain?",
            "response": "LangChain is a framework for building applications with large language models. It provides tools for chaining operations, managing prompts, and integrating with various data sources. For example, you can use it to create chatbots or document analysis systems."
        },
        {
            "prompt": "How to create a simple chatbot?", 
            "response": "Use LangChain components to build a chatbot. Import the necessary modules and configure your model."
        }
    ]
    
    results = []
    
    for i, test_case in enumerate(test_cases):
        print(f"\n--- Test Case {i+1} ---")
        print(f"Prompt: {test_case['prompt']}")
        print(f"Response: {test_case['response'][:80]}...")
        
        # Evaluate
        result = await armo.evaluate_response(
            test_case['prompt'],
            test_case['response']
        )
        
        results.append(result)
        
        # Show results
        print(f"Composite Score: {result.composite_score:.1f}")
        print(f"Components: {result.reward_components.to_dict()}")
        print(f"Method: {result.metadata.get('method', 'unknown')}")
        
        # Small delay
        await asyncio.sleep(2)
    
    # Summary
    avg_score = np.mean([r.composite_score for r in results])
    print(f"\n🎯 ArmoRM Test Results:")
    print(f"   Average Score: {avg_score:.1f}")
    print(f"   Tests Completed: {len(results)}")
    
    return results

# Run ArmoRM test
if api_test_result:
    armo_results = await test_armo_rm()
    armo_test_passed = len(armo_results) > 0 and all(r.composite_score > 0 for r in armo_results)
    print(f"\n🎯 ArmoRM Test: {'PASSED' if armo_test_passed else 'FAILED'}")
else:
    print("⏭️ Skipping ArmoRM test (API test failed)")
    armo_test_passed = False

🧪 Testing ArmoRM Evaluation...

--- Test Case 1 ---
Prompt: What is LangChain?
Response: LangChain is a framework for building applications with large language models. I...
[Armo][Attempt 1] Raw JSON: ```json
{"helpfulness":7.5, "correctness":8.0, "clarity":7.0}
```...
Composite Score: 7.6
Components: {'helpfulness': 7.5, 'correctness': 8.0, 'clarity': 7.0}
Method: api

--- Test Case 2 ---
Prompt: How to create a simple chatbot?
Response: Use LangChain components to build a chatbot. Import the necessary modules and co...
[Armo][Attempt 1] Raw JSON: ```json
{
  "helpfulness": 6.0,
  "correctness": 9.0,
  "clarity": 8.0
}
```...
Composite Score: 7.6
Components: {'helpfulness': 6.0, 'correctness': 9.0, 'clarity': 8.0}
Method: api

🎯 ArmoRM Test Results:
   Average Score: 7.6
   Tests Completed: 2

🎯 ArmoRM Test: PASSED
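
The raw outputs above arrive wrapped in a Markdown code fence, and the composite score is a weighted sum of the three components. The following sketch reproduces both steps offline. `parse_scores` and `composite` are hypothetical names; unlike `_validate_scores` in the notebook, this version raises on a missing key instead of defaulting to 5.0.

```python
import json
import re
from typing import Dict

WEIGHTS = {"helpfulness": 0.4, "correctness": 0.4, "clarity": 0.2}


def parse_scores(raw: str) -> Dict[str, float]:
    """Extract the outermost {...} object from raw model text and clamp scores to 0-10."""
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in model output")
    data = json.loads(match.group(0))
    return {k: max(0.0, min(10.0, float(data[k]))) for k in WEIGHTS}


def composite(scores: Dict[str, float]) -> float:
    """Weighted sum of the three components, rounded to two decimals."""
    return round(sum(WEIGHTS[k] * scores[k] for k in WEIGHTS), 2)


# Rebuild the fenced output seen in Test Case 1 (fence built programmatically).
fence = "`" * 3
raw = fence + 'json\n{"helpfulness":7.5, "correctness":8.0, "clarity":7.0}\n' + fence
scores = parse_scores(raw)
print(scores, composite(scores))
```

For Test Case 1 this gives 0.4 × 7.5 + 0.4 × 8.0 + 0.2 × 7.0 = 7.6, matching the composite score printed above. Test Case 2 reaches the same 7.6 from different components (6.0, 9.0, 8.0), which is why the two cases look identical at one decimal place.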


## 🧪 Test 3: Policy Configuration

In [20]:
# Policy Configuration Test
@dataclass
class PolicyConfig:
    """Fixed policy configuration"""
    model_name: str = "gemini-1.5-flash"
    temperature: float = 0.7
    top_p: float = 0.9
    max_output_tokens: int = 2048
    policy_type: str = "main"
    system_prompt: str = ""
    
    def to_generation_config(self):
        """Convert to GenerationConfig without thinking_budget"""
        return genai.GenerationConfig(
            temperature=self.temperature,
            top_p=self.top_p,
            max_output_tokens=self.max_output_tokens,
            candidate_count=1
        )

async def test_policy_generation():
    """Test policy-based content generation"""
    
    print("🧪 Testing Policy Generation...")
    
    # Create test policies
    policies = {
        "conservative": PolicyConfig(
            temperature=0.3,
            top_p=0.8,
            max_output_tokens=2048,
            system_prompt="You are a helpful and precise AI tutor. Give accurate, concise answers.",
            policy_type="conservative"
        ),
        "creative": PolicyConfig(
            temperature=0.9,
            top_p=0.95,
            max_output_tokens=2048,
            system_prompt="You are an engaging AI tutor. Use examples and creative explanations.",
            policy_type="creative"
        )
    }
    
    test_prompt = "Explain what machine learning is"
    results = {}
    
    for policy_name, policy in policies.items():
        print(f"\n--- Testing {policy_name} policy ---")
        
        try:
            # Create full prompt
            full_prompt = f"{policy.system_prompt}\n\nUser: {test_prompt}\n\nAssistant:"
            
            # Generate response
            response = await flash_model.generate_content_async(
                full_prompt,
                generation_config=policy.to_generation_config(),
                safety_settings={
                    "HARASSMENT": "block_none",
                    "HATE": "block_none"
                }
            )
            
            # Check response
            if (response.candidates and response.candidates[0].content
                    and response.candidates[0].content.parts):
                text = response.candidates[0].content.parts[0].text.strip()
                results[policy_name] = {
                    "success": True,
                    "response": text,
                    "length": len(text.split()),
                    "temperature": policy.temperature
                }
                print(f"✅ Success! Response length: {len(text.split())} words")
                print(f"Preview: {text[:100]}...")
            else:
                results[policy_name] = {
                    "success": False,
                    "error": "Empty response",
                    "finish_reason": response.candidates[0].finish_reason if response.candidates else "No candidates"
                }
                print(f"❌ Empty response. Finish reason: {results[policy_name]['finish_reason']}")
                
        except Exception as e:
            results[policy_name] = {
                "success": False,
                "error": str(e)
            }
            print(f"❌ Error: {e}")
        
        # Delay between policies
        await asyncio.sleep(1)
    
    # Summary
    successful_policies = sum(1 for r in results.values() if r.get("success"))
    print(f"\n🎯 Policy Test Results:")
    print(f"   Successful Policies: {successful_policies}/{len(policies)}")
    
    return results

# Run policy test
if api_test_result:
    policy_results = await test_policy_generation()
    policy_test_passed = sum(1 for r in policy_results.values() if r.get("success")) > 0
    print(f"\n🎯 Policy Test: {'PASSED' if policy_test_passed else 'FAILED'}")
else:
    print("⏭️ Skipping policy test (API test failed)")
    policy_test_passed = False

🧪 Testing Policy Generation...

--- Testing conservative policy ---
✅ Success! Response length: 44 words
Preview: Machine learning is a type of artificial intelligence that allows software applications to become mo...

--- Testing creative policy ---
✅ Success! Response length: 441 words
Preview: Hey there!  Ready to dive into the amazing world of machine learning?  Forget robots taking over the...

🎯 Policy Test Results:
   Successful Policies: 2/2

🎯 Policy Test: PASSED
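
`PolicyConfig` accepts any numeric value, so a typo such as `temperature=7` would only surface as an API error. One possible guard is a `__post_init__` range check. The sketch below uses a hypothetical `CheckedPolicy` class, and the limits in it are assumptions; check the model documentation for the real bounds.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CheckedPolicy:
    """Hypothetical variant of PolicyConfig that rejects out-of-range values."""
    temperature: float = 0.7
    top_p: float = 0.9
    max_output_tokens: int = 2048

    def __post_init__(self) -> None:
        # Assumed limits, not taken from the notebook or the API reference.
        if not 0.0 <= self.temperature <= 2.0:
            raise ValueError(f"temperature out of range: {self.temperature}")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p out of range: {self.top_p}")
        if self.max_output_tokens <= 0:
            raise ValueError("max_output_tokens must be positive")


conservative = CheckedPolicy(temperature=0.3, top_p=0.8)
print(conservative)

try:
    CheckedPolicy(temperature=3.5)
except ValueError as err:
    print(f"rejected: {err}")
```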


## 🧪 Test 4: Single RLHF Iteration

In [21]:
# Single RLHF Iteration Test
async def test_single_rlhf_iteration():
    """Test a single RLHF data collection iteration"""
    
    print("🧪 Testing Single RLHF Iteration...")
    
    if not (api_test_result and armo_test_passed and policy_test_passed):
        print("⏭️ Skipping RLHF test (prerequisites failed)")
        return False, []  # match the (success, data) tuple unpacked by the caller
    
    # Initialize components
    armo = TestArmoRM(pro_model)
    
    # Create two policies for comparison
    policy1 = PolicyConfig(
        temperature=0.5,
        system_prompt="You are a precise AI tutor for LangChain.",
        policy_type="main"
    )
    
    policy2 = PolicyConfig(
        temperature=0.8,
        system_prompt="You are a creative AI tutor for LangChain. Use examples.",
        policy_type="enhancer"
    )
    
    # Test prompts
    test_prompts = [
        "What is LangChain?",
        "How to create a chatbot with LangChain?",
        "Explain LangChain chains"
    ]
    
    iteration_data = []
    
    async def sample_from_policy(policy, prompt):
        """Sample response from policy with error handling"""
        full_prompt = f"{policy.system_prompt}\n\nUser: {prompt}\n\nAssistant:"
        
        try:
            response = await flash_model.generate_content_async(
                full_prompt,
                generation_config=policy.to_generation_config(),
                safety_settings={"HARASSMENT": "block_none", "HATE": "block_none"}
            )
            
            if (response.candidates and response.candidates[0].content
                    and response.candidates[0].content.parts):
                return response.candidates[0].content.parts[0].text.strip()
            else:
                return f"EMPTY_RESPONSE (finish_reason: {response.candidates[0].finish_reason if response.candidates else 'no_candidates'})"
                
        except Exception as e:
            return f"ERROR: {str(e)}"
    
    # Collect data for 3 tuples
    for i, prompt in enumerate(test_prompts):
        print(f"\n--- Collecting tuple {i+1}: {prompt} ---")
        
        # Generate responses from both policies
        print("Sampling from policy 1...")
        response1 = await sample_from_policy(policy1, prompt)
        
        await asyncio.sleep(1)  # Rate limiting
        
        print("Sampling from policy 2...")
        response2 = await sample_from_policy(policy2, prompt)
        
        await asyncio.sleep(1)  # Rate limiting
        
        # Show responses
        print(f"Response 1: {response1[:60]}...")
        print(f"Response 2: {response2[:60]}...")
        
        # Evaluate responses with ArmoRM
        print("Evaluating responses...")
        
        eval1 = await armo.evaluate_response(prompt, response1)
        await asyncio.sleep(2)  # Rate limiting
        
        eval2 = await armo.evaluate_response(prompt, response2)
        await asyncio.sleep(2)  # Rate limiting
        
        # Determine preference (a tie falls through to response 2 with strength 0.0)
        if eval1.composite_score > eval2.composite_score:
            preference = 1
            chosen = response1
            rejected = response2
            strength = eval1.composite_score - eval2.composite_score
        else:
            preference = 2
            chosen = response2
            rejected = response1
            strength = eval2.composite_score - eval1.composite_score
        
        # Store data
        tuple_data = {
            'prompt': prompt,
            'response_1': response1,
            'response_2': response2,
            'preference': preference,
            'preference_strength': strength,
            'scores': {
                'response_1': eval1.composite_score,
                'response_2': eval2.composite_score
            },
            'evaluations': {
                'response_1': eval1,
                'response_2': eval2
            }
        }
        
        iteration_data.append(tuple_data)
        
        print(f"Preference: Response {preference} (strength: {strength:.2f})")
        print(f"Scores: R1={eval1.composite_score:.1f}, R2={eval2.composite_score:.1f}")
    
    # Summary
    print(f"\n🎯 Single RLHF Iteration Results:")
    print(f"   Tuples Collected: {len(iteration_data)}")
    
    # Analyze preferences
    avg_strength = np.mean([d['preference_strength'] for d in iteration_data])
    policy1_wins = sum(1 for d in iteration_data if d['preference'] == 1)
    policy2_wins = sum(1 for d in iteration_data if d['preference'] == 2)
    
    print(f"   Average Preference Strength: {avg_strength:.2f}")
    print(f"   Policy 1 Wins: {policy1_wins}")
    print(f"   Policy 2 Wins: {policy2_wins}")
    
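    # Check for errors: failed samples are returned as strings that start
    # with "ERROR:" or "EMPTY_RESPONSE", so match on the prefix only
    failure_prefixes = ("ERROR:", "EMPTY_RESPONSE")
    errors = sum(1 for d in iteration_data
                 if d['response_1'].startswith(failure_prefixes)
                 or d['response_2'].startswith(failure_prefixes))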
    
    print(f"   Responses with Errors: {errors}")
    
    success = len(iteration_data) == 3 and errors == 0
    return success, iteration_data

# Run single iteration test
rlhf_success, rlhf_data = await test_single_rlhf_iteration()
print(f"\n🎯 RLHF Iteration Test: {'PASSED' if rlhf_success else 'FAILED'}")

🧪 Testing Single RLHF Iteration...

--- Collecting tuple 1: What is LangChain? ---
Sampling from policy 1...
Sampling from policy 2...
Response 1: LangChain is a framework for developing applications powered...
Response 2: User: What is LangChain?

Assistant:

Hey there! LangChain i...
Evaluating responses...
[Armo][Attempt 1] Raw JSON: ```json
{"helpfulness":7.5, "correctness":8.0, "clarity":7.0}
```...
Attempt 1: no usable content, retry
[Armo][Attempt 2] Raw JSON: ```json
{"helpfulness":7.5, "correctness":8.0, "clarity":7.0}
```...
Preference: Response 2 (strength: 0.00)
Scores: R1=7.6, R2=7.6

--- Collecting tuple 2: How to create a chatbot with LangChain? ---
Sampling from policy 1...
Sampling from policy 2...
Response 1: Creating a chatbot with LangChain involves several steps, an...
Response 2: Let's build a chatbot with LangChain!  We'll explore a few a...
Evaluating responses...
Attempt 1: no usable content, retry
[Armo][Attempt 2] Raw JSON: ```json
{"helpfulness":7.5, "correc
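
In tuple 1 above both responses scored 7.6, yet the pair was still recorded as "Preference: Response 2 (strength: 0.00)". A pair with no score gap carries no preference signal, so one option is to skip it. A minimal sketch, assuming a hypothetical helper `pick_preference` and a hypothetical `min_margin` threshold (neither is in the notebook):

```python
from typing import Optional, Tuple


def pick_preference(score_1: float, score_2: float,
                    min_margin: float = 0.25) -> Tuple[Optional[int], float]:
    """Return (preferred response index or None, absolute score gap)."""
    gap = round(abs(score_1 - score_2), 2)
    if gap < min_margin:
        return None, gap  # too close to call: the caller can drop this pair
    return (1 if score_1 > score_2 else 2), gap


print(pick_preference(7.6, 7.6))
print(pick_preference(9.2, 8.7))
```

The threshold value is a judgment call. Set `min_margin=0.0` to keep every pair except exact ties only if the strict comparison is also adjusted, since a gap of 0.0 is never below 0.0.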

## 🧪 Test 5: Full System Integration (Mini)

In [22]:
# Mini Full System Test
class MiniRLHFSystem:
    """Minimal RLHF system for testing"""
    
    def __init__(self, flash_model, pro_model):
        self.flash_model = flash_model
        self.pro_model = pro_model
        self.armo = TestArmoRM(pro_model)
        self.historical_data = []
    
    def create_reference_policy(self):
        """Create reference policy"""
        return PolicyConfig(
            temperature=0.7,
            system_prompt="You are a helpful AI tutor for LangChain.",
            policy_type="reference"
        )
    
    def create_policy_pair(self, iteration: int):
        """Create main and enhancer policies"""
        
        # Main policy (exploitation)
        main_policy = PolicyConfig(
            temperature=max(0.5, 0.7 - 0.1 * iteration),
            system_prompt="You are an expert AI tutor for LangChain. Provide clear, accurate explanations.",
            policy_type="main"
        )
        
        # Enhancer policy (exploration)
        enhancer_policy = PolicyConfig(
            temperature=min(1.0, main_policy.temperature + 0.2),
            system_prompt="You are a creative AI tutor for LangChain. Use examples and engaging explanations.",
            policy_type="enhancer"
        )
        
        return main_policy, enhancer_policy
    
    async def sample_from_policy(self, policy, prompt):
        """Sample response from policy"""
        full_prompt = f"{policy.system_prompt}\n\nUser: {prompt}\n\nAssistant:"
        
        try:
            response = await self.flash_model.generate_content_async(
                full_prompt,
                generation_config=policy.to_generation_config(),
                safety_settings={"HARASSMENT": "block_none", "HATE": "block_none"}
            )
            
            if (response.candidates and response.candidates[0].content
                    and response.candidates[0].content.parts):
                return response.candidates[0].content.parts[0].text.strip()
            else:
                return "EMPTY_RESPONSE"
                
        except Exception as e:
            return f"ERROR: {str(e)}"
    
    async def collect_preferences(self, policy_pair, prompts, num_samples):
        """Collect preference data"""
        
        main_policy, enhancer_policy = policy_pair
        data = []
        
        for i in range(num_samples):
            # Sample prompt
            prompt = prompts[i % len(prompts)]
            
            print(f"Collecting sample {i+1}/{num_samples}: {prompt[:30]}...")
            
            # Generate responses
            response1 = await self.sample_from_policy(main_policy, prompt)
            await asyncio.sleep(1)
            
            response2 = await self.sample_from_policy(enhancer_policy, prompt)
            await asyncio.sleep(1)
            
            # Evaluate
            eval1 = await self.armo.evaluate_response(prompt, response1)
            await asyncio.sleep(2)
            
            eval2 = await self.armo.evaluate_response(prompt, response2)
            await asyncio.sleep(2)
            
            # Determine preference (a tie falls through to response 2 with strength 0.0)
            if eval1.composite_score > eval2.composite_score:
                preference = 1
                strength = eval1.composite_score - eval2.composite_score
            else:
                preference = 2
                strength = eval2.composite_score - eval1.composite_score
            
            data.append({
                'prompt': prompt,
                'response_1': response1,
                'response_2': response2,
                'preference': preference,
                'strength': strength,
                'scores': [eval1.composite_score, eval2.composite_score]
            })
            
            print(f"  Preference: {preference} (strength: {strength:.2f})")
        
        return data
    
    async def run_mini_rlhf(self, prompts, iterations=2, samples_per_iteration=3):
        """Run mini RLHF training"""
        
        print(f"🚀 Starting Mini RLHF Training")
        print(f"   Iterations: {iterations}")
        print(f"   Samples per iteration: {samples_per_iteration}")
        print(f"   Prompts: {len(prompts)}")
        
        all_iterations = []
        
        for t in range(iterations):
            print(f"\n=== Iteration {t+1}/{iterations} ===")
            
            # Create policy pair
            policy_pair = self.create_policy_pair(t)
            
            # Collect data
            iteration_data = await self.collect_preferences(
                policy_pair, prompts, samples_per_iteration
            )
            
            # Update historical data
            self.historical_data.extend(iteration_data)
            
            # Analyze iteration
            avg_score = np.mean([np.mean(d['scores']) for d in iteration_data])
            avg_strength = np.mean([d['strength'] for d in iteration_data])
            
            iteration_summary = {
                'iteration': t,
                'data': iteration_data,
                'avg_score': avg_score,
                'avg_strength': avg_strength,
                'total_samples': len(iteration_data)
            }
            
            all_iterations.append(iteration_summary)
            
            print(f"Iteration {t+1} complete:")
            print(f"  Average Score: {avg_score:.2f}")
            print(f"  Average Strength: {avg_strength:.2f}")
            print(f"  Total Historical Data: {len(self.historical_data)}")
        
        return all_iterations

async def test_full_system():
    """Test full system integration"""
    
    print("🧪 Testing Full System Integration...")
    
    if not rlhf_success:
        print("⏭️ Skipping full system test (RLHF iteration failed)")
        return False, None  # match the (success, results) tuple unpacked by the caller
    
    # Initialize mini system
    mini_rlhf = MiniRLHFSystem(flash_model, pro_model)
    
    # Test prompts
    test_prompts = [
        "What is LangChain and why is it useful?",
        "How do you create a simple chatbot with LangChain?",
        "Explain LangChain chains with an example",
        "What are LangChain agents?"
    ]
    
    try:
        # Run mini RLHF
        results = await mini_rlhf.run_mini_rlhf(
            prompts=test_prompts,
            iterations=2,
            samples_per_iteration=2  # Very small for testing
        )
        
        # Analyze results
        total_samples = sum(len(r['data']) for r in results)
        final_avg_score = results[-1]['avg_score'] if results else 0
        
        print(f"\n🎯 Full System Test Results:")
        print(f"   Iterations Completed: {len(results)}")
        print(f"   Total Samples Collected: {total_samples}")
        print(f"   Final Average Score: {final_avg_score:.2f}")
        
        # Check for success
        success = (
            len(results) == 2 and  # Completed both iterations
            total_samples >= 4 and  # Collected expected samples
            final_avg_score > 0  # Got valid scores
        )
        
        return success, results
        
    except Exception as e:
        print(f"❌ Full system test error: {e}")
        return False, None

# Run full system test
full_system_success, full_results = await test_full_system()
print(f"\n🎯 Full System Test: {'PASSED' if full_system_success else 'FAILED'}")

🧪 Testing Full System Integration...
🚀 Starting Mini RLHF Training
   Iterations: 2
   Samples per iteration: 2
   Prompts: 4

=== Iteration 1/2 ===
Collecting sample 1/2: What is LangChain and why is i...
[Armo][Attempt 1] Raw JSON: ```json
{"helpfulness":9.0, "correctness":9.5, "clarity":9.0}
```...
[Armo][Attempt 1] Raw JSON: ```json
{"helpfulness":8.5, "correctness":9.0, "clarity":8.5}
```...
  Preference: 1 (strength: 0.50)
Collecting sample 2/2: How do you create a simple cha...
Attempt 1: no usable content, retry
[Armo][Attempt 2] Raw JSON: ```json...
Attempt 1: no usable content, retry
Attempt 2: no usable content, retry
Attempt 3: no usable content, retry
  Preference: 2 (strength: 1.38)
Iteration 1 complete:
  Average Score: 7.32
  Average Strength: 0.94
  Total Historical Data: 2

=== Iteration 2/2 ===
Collecting sample 1/2: What is LangChain and why is i...
Attempt 1: no usable content, retry
[Armo][Attempt 2] Raw JSON: ```json
{
  "helpfulness": 9.5,
  "correctness": 10.0,

## 📊 Overall Test Summary

In [None]:
# Test Summary
test_results = {
    "API Connection": api_test_result,
    "ArmoRM Evaluation": armo_test_passed,
    "Policy Generation": policy_test_passed,
    "RLHF Iteration": rlhf_success,
    "Full System": full_system_success
}

print("\n" + "="*50)
print("📊 EDUCATIONAL RLHF FRAMEWORK - TEST SUMMARY")
print("="*50)

for test_name, result in test_results.items():
    status = "✅ PASSED" if result else "❌ FAILED"
    print(f"{test_name:.<30} {status}")

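# Coerce to bool so a None or other non-boolean result cannot break the count
passed_tests = sum(bool(result) for result in test_results.values())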
total_tests = len(test_results)

print(f"\nOverall: {passed_tests}/{total_tests} tests passed ({passed_tests/total_tests*100:.1f}%)")

if passed_tests == total_tests:
    print("\n🎉 ALL TESTS PASSED! System is ready for production.")
elif passed_tests >= 3:
    print("\n⚠️  PARTIAL SUCCESS. Core components working, some issues remain.")
else:
    print("\n❌ MAJOR ISSUES. Please check API keys and network connectivity.")

# Recommendations
print("\n📝 RECOMMENDATIONS:")

if not api_test_result:
    print("• Check API keys and network connectivity")
    print("• Verify API quotas and billing")

if not armo_test_passed:
    print("• ArmoRM may need simpler prompts")
    print("• Consider increasing retry delays")

if not policy_test_passed:
    print("• Check safety settings")
    print("• Review system prompts for policy violations")

if not rlhf_success:
    print("• Reduce batch sizes further")
    print("• Increase delays between API calls")

if not full_system_success:
    print("• Consider running with even smaller parameters")
    print("• Check rate limiting and quotas")

if passed_tests == total_tests:
    print("\n🚀 NEXT STEPS:")
    print("• Scale up batch sizes gradually")
    print("• Add more sophisticated prompts")
    print("• Implement production monitoring")
    print("• Deploy full Educational RLHF system")

print("\n" + "="*50)

## 🔍 Diagnostic Information

In [None]:
# Diagnostic Information
import sys  # required for sys.version below; not imported in the setup cell
print("🔍 DIAGNOSTIC INFORMATION")
print("="*30)

print(f"Environment:")
print(f"  Python version: {sys.version.split()[0]}")
print(f"  Google GenerativeAI: {genai.__version__}")
print(f"  NumPy: {np.__version__}")

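def key_status(name, placeholder):
    """Return 'Set' only if the key is defined, non-empty, and not the placeholder."""
    value = globals().get(name)  # avoids a NameError when the key was never defined
    return 'Set' if value and value != placeholder else 'Not Set'

print("\nAPI Configuration:")
print(f"  Flash API Key: {key_status('FLASH_API_KEY', 'your_flash_api_key_here')}")
print(f"  Pro API Key: {key_status('PRO_API_KEY', 'your_pro_api_key_here')}")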

if 'armo_results' in globals() and armo_results:
    print(f"\nArmoRM Results:")
    for i, result in enumerate(armo_results):
        print(f"  Test {i+1}: Score {result.composite_score:.1f} (Method: {result.metadata.get('method', 'unknown')})")

if 'policy_results' in globals() and policy_results:
    print(f"\nPolicy Results:")
    for name, result in policy_results.items():
        if result.get('success'):
            print(f"  {name}: ✅ {result.get('length', 0)} words")
        else:
            print(f"  {name}: ❌ {result.get('error', 'Unknown error')}")

if 'rlhf_data' in globals() and rlhf_data:
    print(f"\nRLHF Iteration Data:")
    for i, data in enumerate(rlhf_data):
        print(f"  Tuple {i+1}: Preference {data['preference']} (Strength: {data['preference_strength']:.2f})")

if 'full_results' in globals() and full_results:
    print(f"\nFull System Results:")
    for i, iteration in enumerate(full_results):
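        # Fall back to counting 'data' (as the full system test does) if 'total_samples' is absent
        n_samples = iteration.get('total_samples', len(iteration.get('data', [])))
        print(f"  Iteration {i+1}: {n_samples} samples, Avg Score: {iteration['avg_score']:.2f}")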

print("\n" + "="*30)

## 🎯 Ready for Production?

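**If all tests passed**, the components work end to end at this small scale (2 iterations, 2 samples each), and your Educational RLHF Framework is ready to be scaled up toward production deployment.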

**Next steps:**
1. Save this notebook as a reference
2. Use the fixed components in your main application
3. Gradually scale up batch sizes and iterations
4. Implement production monitoring and logging
5. Deploy with proper error handling and retry mechanisms
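
Step 5 mentions retry mechanisms. The following is a minimal sketch of one way to do this, assuming any failed call is safe to repeat. `call_with_retry` is a hypothetical helper, not one of the notebook's components, and the default attempt count and delay are illustrative rather than tuned to any API quota.

```python
import asyncio
import random
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def call_with_retry(
    make_call: Callable[[], Awaitable[T]],
    max_attempts: int = 3,
    base_delay: float = 1.0,
) -> T:
    """Await make_call(), retrying with exponential backoff on any exception."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return await make_call()
        except Exception as exc:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error to the caller
            # Double the wait on each failure and add up to 10% random jitter
            delay = base_delay * 2 ** (attempt - 1)
            delay += random.uniform(0, 0.1 * delay)
            print(f"Attempt {attempt} failed ({exc}); retrying in {delay:.2f}s")
            await asyncio.sleep(delay)
    raise AssertionError("unreachable")
```

In a notebook cell this could wrap the call from the full system test, for example `results = await call_with_retry(lambda: mini_rlhf.run_mini_rlhf(prompts=test_prompts, iterations=2, samples_per_iteration=2))`. Passing a zero-argument callable rather than a coroutine object matters here, because a coroutine can only be awaited once.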

**If some tests failed**, review the diagnostic information above and:
1. Check API keys and quotas
2. Verify network connectivity
3. Adjust safety settings if needed
4. Reduce batch sizes and increase delays
5. Simplify prompts if necessary
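
Item 4 in the list above (smaller batches, longer delays) could be sketched as follows. `run_throttled` is a hypothetical helper and `worker` stands in for whatever coroutine makes the API call; the default batch size and delay are placeholders to adjust against your own quota.

```python
import asyncio
from typing import Awaitable, Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


async def run_throttled(
    items: Sequence[T],
    worker: Callable[[T], Awaitable[R]],
    batch_size: int = 2,
    delay_seconds: float = 2.0,
) -> List[R]:
    """Process items in small batches, pausing between consecutive batches."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    results: List[R] = []
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        # Items inside one batch run concurrently; gather preserves input order
        results.extend(await asyncio.gather(*(worker(item) for item in batch)))
        # Pause before the next batch, but not after the final one
        if start + batch_size < len(items):
            await asyncio.sleep(delay_seconds)
    return results
```

Setting `batch_size=1` makes the calls strictly sequential, which is the most conservative setting when rate limits are the suspected cause of failures.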