# SCARL Framework Implementation

In [27]:
"""
SCARL: Self-Corrective Agentic Reinforcement Learning Framework
Implementation based on my research paper proposing a novel RL framework
for self-corrective NLP agents.
"""

import numpy as np
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from enum import Enum
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque


class ActionType(Enum):
    """Which policy produced an action: the task policy or the corrective one."""
    # Emitted by the primary (task-level) policy.
    PRIMARY = "primary"
    # Emitted by the corrective policy when R_meta falls below threshold.
    CORRECTIVE = "corrective"


class CorrectiveAction(Enum):
    """The closed set of self-correction moves the corrective policy can pick.

    Declaration order matters: policy output index i maps to the i-th member.
    """
    REPLAN = "replan"      # restart with a new approach
    SEARCH = "search"      # consult an external knowledge source
    CRITIQUE = "critique"  # self-critique the current output
    ASK = "ask"            # ask the user for clarification


@dataclass
class State:
    """
    Augmented SCARL state, S_t = {S_context, S_memory, S_internal}.

    ``memory`` defaults to a fresh empty list and ``internal`` to a fresh dict
    holding ``r_meta``, ``correction_history`` and ``confidence``; explicitly
    passed containers are stored as-is (not copied).
    """
    context: Any  # Current context (e.g., conversation history, task state)
    memory: List[Any]  # History of past actions and states
    internal: Dict[str, Any]  # Internal reflection state

    def __init__(self, context: Any, memory: Optional[List] = None, internal: Optional[Dict] = None):
        if memory is None:
            memory = []
        if internal is None:
            # Built per call so instances never share mutable defaults.
            internal = {
                'r_meta': 0.0,
                'correction_history': [],
                'confidence': 1.0,
            }
        self.context = context
        self.memory = memory
        self.internal = internal


@dataclass
class Action:
    """A single agent decision.

    ``content`` is the payload handed to the environment; ``corrective_type``
    is only set for actions produced by the corrective policy.
    """
    action_type: ActionType  # which policy produced this action
    content: Any  # environment-facing payload
    corrective_type: Optional[CorrectiveAction] = None  # None for primary actions


class MetaRewardGenerator(nn.Module):
    """
    Produces the intrinsic meta-reward R_meta for a state embedding.

    A 3-layer MLP ending in a sigmoid, so every output lies in (0, 1);
    larger values are interpreted as higher confidence in the trajectory.
    """
    def __init__(self, state_dim: int, hidden_dim: int = 128):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squash to (0, 1)
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state_embedding: torch.Tensor) -> torch.Tensor:
        """Map ``[..., state_dim]`` embeddings to ``[..., 1]`` R_meta scores."""
        score = self.network(state_embedding)
        return score


class PrimaryPolicy(nn.Module):
    """
    Primary policy (π_P), trained to maximize R_ext.

    An MLP whose softmax head yields a distribution over ``action_dim``
    task-level actions.
    """
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state_embedding: torch.Tensor) -> torch.Tensor:
        """Return action probabilities over the last dimension."""
        return self.network(state_embedding)

    def select_action(self, state_embedding: torch.Tensor) -> int:
        """Draw one action index from π_P for a single (unbatched) embedding."""
        drawn = torch.multinomial(self.forward(state_embedding), 1)
        return drawn.item()


class CorrectivePolicy(nn.Module):
    """
    Corrective policy (π_C), trained to maximize R_ext after a low R_meta.

    Outputs a softmax over the members of ``CorrectiveAction`` in
    declaration order (REPLAN, SEARCH, CRITIQUE, ASK).
    """
    def __init__(self, state_dim: int, hidden_dim: int = 256):
        super().__init__()
        n_choices = len(CorrectiveAction)
        self.network = nn.Sequential(*[
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_choices),
            nn.Softmax(dim=-1),
        ])

    def forward(self, state_embedding: torch.Tensor) -> torch.Tensor:
        """Return probabilities over corrective actions (last dimension)."""
        return self.network(state_embedding)

    def select_action(self, state_embedding: torch.Tensor) -> CorrectiveAction:
        """Sample one corrective action for a single (unbatched) embedding."""
        idx = torch.multinomial(self.forward(state_embedding), 1).item()
        choices = list(CorrectiveAction)
        return choices[idx]


class StateEncoder(nn.Module):
    """
    Encodes the augmented state into a fixed-dimensional embedding.

    The embedding is assembled from three parts and then mixed by a final
    linear layer back to ``embedding_dim``:
      - context  -> ``embedding_dim // 2`` (linear projection)
      - internal -> ``embedding_dim // 4`` (10 scalar features, linear)
      - memory   -> ``embedding_dim // 4`` (LSTM over the last 10 entries)

    Assumes ``state.context`` and each ``state.memory`` entry are 1-D float
    tensors of size ``context_dim`` — TODO confirm against callers.
    """
    def __init__(self, context_dim: int, embedding_dim: int = 256):
        super().__init__()
        self.embedding_dim = embedding_dim
        context_out = embedding_dim // 2
        internal_out = embedding_dim // 4
        self.memory_out = embedding_dim // 4
        self.context_encoder = nn.Linear(context_dim, context_out)
        self.internal_encoder = nn.Linear(10, internal_out)  # Fixed internal features
        self.memory_encoder = nn.LSTM(context_dim, self.memory_out, batch_first=True)
        # Sum of the parts; equals embedding_dim whenever it is divisible by 4,
        # and stays correct (instead of crashing) when it is not.
        self.combiner = nn.Linear(context_out + internal_out + self.memory_out, embedding_dim)

    def forward(self, state: "State") -> torch.Tensor:
        """Encode a single state into a 1-D embedding of size ``embedding_dim``."""
        # Encode context (simplified - assumes tensor input)
        context_emb = torch.relu(self.context_encoder(state.context))

        # Encode internal state: 3 real features padded with zeros to 10.
        internal_features = torch.tensor([
            state.internal.get('r_meta', 0.0),
            state.internal.get('confidence', 1.0),
            len(state.internal.get('correction_history', [])),
            # Additional internal features can be added
        ] + [0.0] * 7, dtype=torch.float32)
        internal_emb = torch.relu(self.internal_encoder(internal_features))

        # Encode memory (simplified): only the 10 most recent entries.
        if state.memory:
            memory_tensor = torch.stack(state.memory[-10:])
            _, (h_n, _) = self.memory_encoder(memory_tensor.unsqueeze(0))
            # h_n is [num_layers=1, batch=1, memory_out]; take the last layer's
            # hidden state for the single batch item so the result is 1-D and
            # can be concatenated with the other 1-D parts.
            memory_emb = h_n[-1, 0]
        else:
            memory_emb = context_emb.new_zeros(self.memory_out)

        # Combine all embeddings
        combined = torch.cat([context_emb, internal_emb, memory_emb])
        return torch.relu(self.combiner(combined))


class SCARLAgent:
    """
    Main SCARL Agent implementing the self-corrective framework.

    Key Features:
    - Dual policy system (primary + corrective)
    - Meta-reward based policy switching
    - Online introspection and error correction

    The state encoder is not registered with any optimizer here, so it is used
    as a fixed feature extractor and its outputs are computed without autograd.
    """
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        lambda_introspection: float = 0.3,
        r_meta_threshold: float = 0.5,
        gamma: float = 0.99,
        learning_rate: float = 1e-4
    ):
        """
        Initialize SCARL agent.

        Args:
            state_dim: Dimension of state encoding; also used as the encoder's
                context-input size and output-embedding size.
            action_dim: Number of primary actions
            lambda_introspection: Weight for meta-reward (λ in R_total = R_ext + λR_meta)
            r_meta_threshold: R_meta values strictly below this trigger the corrective policy
            gamma: Discount factor
            learning_rate: Learning rate for the policies, meta-reward generator
                and value network
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lambda_introspection = lambda_introspection
        self.r_meta_threshold = r_meta_threshold
        self.gamma = gamma

        # Initialize networks. Every head consumes vectors of size state_dim,
        # so the encoder must emit embeddings of that size (its default is 256).
        self.state_encoder = StateEncoder(state_dim, embedding_dim=state_dim)
        self.meta_reward_gen = MetaRewardGenerator(state_dim)
        self.primary_policy = PrimaryPolicy(state_dim, action_dim)
        self.corrective_policy = CorrectivePolicy(state_dim)

        # Value network (critic) for actor-critic training
        self.value_network = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Optimizers
        self.optimizer_primary = optim.Adam(self.primary_policy.parameters(), lr=learning_rate)
        self.optimizer_corrective = optim.Adam(self.corrective_policy.parameters(), lr=learning_rate)
        self.optimizer_meta = optim.Adam(self.meta_reward_gen.parameters(), lr=learning_rate)
        self.optimizer_value = optim.Adam(self.value_network.parameters(), lr=learning_rate)

        # Experience replay buffer (oldest transitions dropped beyond 10000)
        self.experience_buffer = deque(maxlen=10000)

    def select_action(self, state: State) -> Tuple[Action, float]:
        """
        Main action selection with policy switching logic.

        Side effects: writes the computed R_meta to ``state.internal['r_meta']``
        and, when the corrective policy fires, appends a record to
        ``state.internal['correction_history']``.

        Returns:
            (selected_action, r_meta_score)
        """
        # Acting only samples from the networks, so no autograd graph is needed.
        with torch.no_grad():
            state_emb = self.state_encoder(state)
            r_meta = self.meta_reward_gen(state_emb).item()
            state.internal['r_meta'] = r_meta

            if r_meta < self.r_meta_threshold:
                # Low confidence: hand control to the corrective policy
                corrective_action = self.corrective_policy.select_action(state_emb)
                state.internal['correction_history'].append({
                    'timestep': len(state.memory),
                    'r_meta': r_meta,
                    'action': corrective_action
                })
                return Action(
                    action_type=ActionType.CORRECTIVE,
                    content=corrective_action.value,
                    corrective_type=corrective_action
                ), r_meta

            primary_action_idx = self.primary_policy.select_action(state_emb)

        return Action(
            action_type=ActionType.PRIMARY,
            content=primary_action_idx,
            corrective_type=None
        ), r_meta

    def compute_total_reward(self, r_ext: float, r_meta: float) -> float:
        """
        Compute composite reward: R_total = R_ext + λR_meta
        (works element-wise on tensors as well as on floats).
        """
        return r_ext + self.lambda_introspection * r_meta

    def store_experience(
        self,
        state: State,
        action: Action,
        r_ext: float,
        r_meta: float,
        next_state: State,
        done: bool
    ):
        """Append one transition to the replay buffer."""
        self.experience_buffer.append({
            'state': state,
            'action': action,
            'r_ext': r_ext,
            'r_meta': r_meta,
            'next_state': next_state,
            'done': done
        })

    def train_step(self, batch_size: int = 32) -> Dict[str, float]:
        """
        Perform one training step on a uniformly sampled replay batch.

        Uses one-step advantage actor-critic updates: a TD(0) critic target
        ``R_total + gamma * V(s') * (1 - done)`` and advantage-weighted
        log-likelihood for whichever policy produced each action.
        NOTE(review): transitions come from a replay buffer but no importance
        correction is applied, so the policy-gradient terms are off-policy.

        Returns:
            Dictionary of loss values; empty if the buffer holds fewer than
            ``batch_size`` transitions. Policy-loss keys appear only when the
            batch contains actions of that type.
        """
        if len(self.experience_buffer) < batch_size:
            return {}

        # Sample batch
        indices = np.random.choice(len(self.experience_buffer), batch_size, replace=False)
        batch = [self.experience_buffer[i] for i in indices]

        losses: Dict[str, float] = {}

        # Prepare batch data
        states = [exp['state'] for exp in batch]
        actions = [exp['action'] for exp in batch]
        r_exts = torch.tensor([exp['r_ext'] for exp in batch], dtype=torch.float32)
        r_metas = torch.tensor([exp['r_meta'] for exp in batch], dtype=torch.float32)
        next_states = [exp['next_state'] for exp in batch]
        dones = torch.tensor([exp['done'] for exp in batch], dtype=torch.float32)

        # The encoder has no optimizer, so encode without autograd. Otherwise
        # the value/primary/corrective/meta losses would share one encoder
        # graph and the second backward() would hit an already-freed graph.
        with torch.no_grad():
            state_embs = torch.stack([self.state_encoder(s) for s in states])
            next_state_embs = torch.stack([self.state_encoder(s) for s in next_states])
            # Bootstrapped targets are constants for the critic update.
            next_values = self.value_network(next_state_embs).squeeze(-1)

        # squeeze(-1) (not squeeze()) keeps a batch dim even when batch_size == 1
        values = self.value_network(state_embs).squeeze(-1)

        # Compute advantages (TD error)
        r_totals = self.compute_total_reward(r_exts, r_metas)
        targets = r_totals + self.gamma * next_values * (1 - dones)
        advantages = (targets - values).detach()

        # Update value network
        value_loss = nn.MSELoss()(values, targets)
        self.optimizer_value.zero_grad()
        value_loss.backward()
        self.optimizer_value.step()
        losses['value_loss'] = value_loss.item()

        # Update primary policy (only for primary actions)
        primary_mask = torch.tensor([a.action_type == ActionType.PRIMARY for a in actions])
        if primary_mask.any():
            primary_actions = torch.tensor([
                a.content for a in actions if a.action_type == ActionType.PRIMARY
            ], dtype=torch.long)
            action_probs = self.primary_policy(state_embs[primary_mask])
            chosen = action_probs.gather(1, primary_actions.unsqueeze(1)).squeeze(1)
            log_probs = torch.log(chosen + 1e-8)  # epsilon guards log(0)
            policy_loss = -(log_probs * advantages[primary_mask]).mean()

            self.optimizer_primary.zero_grad()
            policy_loss.backward()
            self.optimizer_primary.step()
            losses['primary_policy_loss'] = policy_loss.item()

        # Update corrective policy (only for corrective actions)
        corrective_mask = torch.tensor([a.action_type == ActionType.CORRECTIVE for a in actions])
        if corrective_mask.any():
            corrective_order = list(CorrectiveAction)
            corrective_actions = torch.tensor([
                corrective_order.index(a.corrective_type)
                for a in actions if a.action_type == ActionType.CORRECTIVE
            ], dtype=torch.long)
            action_probs = self.corrective_policy(state_embs[corrective_mask])
            chosen = action_probs.gather(1, corrective_actions.unsqueeze(1)).squeeze(1)
            log_probs = torch.log(chosen + 1e-8)
            corrective_loss = -(log_probs * advantages[corrective_mask]).mean()

            self.optimizer_corrective.zero_grad()
            corrective_loss.backward()
            self.optimizer_corrective.step()
            losses['corrective_policy_loss'] = corrective_loss.item()

        # Regress R_meta onto the transition's immediate R_ext.
        # NOTE(review): the sigmoid output is in (0, 1) while R_ext may be
        # negative, so negative targets can only be approached, not matched.
        predicted_meta = self.meta_reward_gen(state_embs).squeeze(-1)
        meta_loss = nn.MSELoss()(predicted_meta, r_exts)

        self.optimizer_meta.zero_grad()
        meta_loss.backward()
        self.optimizer_meta.step()
        losses['meta_reward_loss'] = meta_loss.item()

        return losses

    def set_introspection_weight(self, lambda_val: float):
        """Adjust the introspection weight (λ) to control caution vs. speed."""
        self.lambda_introspection = lambda_val

    def set_meta_threshold(self, threshold: float):
        """Adjust the R_meta threshold for triggering corrections."""
        self.r_meta_threshold = threshold


class SCARLTrainer:
    """
    Training coordinator for SCARL agent.
    Handles episode management and training loops.

    Per-episode totals are appended to ``episode_rewards`` (sum of R_ext) and
    ``episode_corrections`` (number of corrective actions).
    """
    def __init__(self, agent: "SCARLAgent", environment: Any):
        self.agent = agent
        self.environment = environment
        self.episode_rewards = []
        self.episode_corrections = []

    def train_episode(
        self,
        max_steps: int = 100,
        train_every: int = 4,
        batch_size: int = 32,
    ) -> Dict[str, float]:
        """
        Run one training episode.

        Args:
            max_steps: Upper bound on environment steps; 0 runs no steps.
            train_every: Call ``agent.train_step`` on every step whose index is
                a multiple of this value (the first step is index 0).
            batch_size: Batch size passed to ``agent.train_step``.

        Returns:
            Episode statistics. ``avg_r_meta`` is 0.0 when no step was taken.
        """
        state = self.environment.reset()
        total_r_ext = 0.0
        total_r_meta = 0.0
        num_corrections = 0
        # Tracked explicitly so the stats are defined even when max_steps == 0
        # (the loop variable would otherwise never be bound).
        steps_taken = 0

        for step in range(max_steps):
            action, r_meta = self.agent.select_action(state)
            next_state, r_ext, done, _info = self.environment.step(action)

            if action.action_type == ActionType.CORRECTIVE:
                num_corrections += 1

            self.agent.store_experience(state, action, r_ext, r_meta, next_state, done)

            if step % train_every == 0:
                self.agent.train_step(batch_size=batch_size)

            total_r_ext += r_ext
            total_r_meta += r_meta
            state = next_state
            steps_taken = step + 1

            if done:
                break

        stats = {
            'total_r_ext': total_r_ext,
            'avg_r_meta': total_r_meta / steps_taken if steps_taken else 0.0,
            'num_corrections': num_corrections,
            'episode_length': steps_taken
        }

        self.episode_rewards.append(total_r_ext)
        self.episode_corrections.append(num_corrections)

        return stats


# Demonstration
if __name__ == "__main__":
    print("SCARL Framework Implementation")
    print("=" * 50)

    # Demo agent; lambda and threshold match the constructor defaults.
    agent = SCARLAgent(
        state_dim=256,
        action_dim=10,
        lambda_introspection=0.3,
        r_meta_threshold=0.5,
    )

    config_report = "\n".join([
        "Agent initialized with:",
        f"  - Lambda (introspection weight): {agent.lambda_introspection}",
        f"  - R_meta threshold: {agent.r_meta_threshold}",
        f"  - State dimension: {agent.state_dim}",
        f"  - Action dimension: {agent.action_dim}",
    ])
    # end="\n\n" reproduces the trailing blank line of the report.
    print(config_report, end="\n\n")

    # Smoke-test action selection on a random 256-d context vector.
    dummy_context = torch.randn(256)
    dummy_state = State(context=dummy_context)

    print("Testing action selection:")
    action, r_meta = agent.select_action(dummy_state)
    print(f"  - Selected action type: {action.action_type.value}")
    print(f"  - R_meta score: {r_meta:.3f}")
    if action.action_type == ActionType.CORRECTIVE:
        print(f"  - Corrective action: {action.corrective_type.value}")

    print("\nSCARL Framework ready for training!")

SCARL Framework Implementation
Agent initialized with:
  - Lambda (introspection weight): 0.3
  - R_meta threshold: 0.5
  - State dimension: 256
  - Action dimension: 10

Testing action selection:
  - Selected action type: primary
  - R_meta score: 0.522

SCARL Framework ready for training!


In [29]:
"""
SCARL Integration with NLP Environments
Demonstrates how to integrate the SCARL framework with real NLP tasks
using language models and text-based environments.
"""

import torch
import torch.nn as nn
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
import numpy as np
from transformers import AutoTokenizer, AutoModel
import re


class NLPStateEncoder(nn.Module):
    """
    Encodes NLP-specific state information using a pre-trained language model.
    Handles text context, conversation history, and internal agent state.

    Output layout before the final mixing layer:
      - context  -> ``embedding_dim // 2`` ([CLS] embedding, projected)
      - internal -> ``embedding_dim // 4`` (10 scalar features)
      - history  -> ``embedding_dim // 4`` (LSTM over the last 5 turns)
    """
    def __init__(self, model_name: str = "bert-base-uncased", embedding_dim: int = 256):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.language_model = AutoModel.from_pretrained(model_name)
        self.embedding_dim = embedding_dim

        # Freeze LM parameters for efficiency (optional)
        for param in self.language_model.parameters():
            param.requires_grad = False

        lm_hidden_size = self.language_model.config.hidden_size
        context_out = embedding_dim // 2
        internal_out = embedding_dim // 4
        self.history_out = embedding_dim // 4

        # Project LM outputs to embedding dimension
        self.context_projection = nn.Linear(lm_hidden_size, context_out)
        self.history_lstm = nn.LSTM(lm_hidden_size, self.history_out, batch_first=True)
        self.internal_encoder = nn.Linear(10, internal_out)
        # Sum of the parts (== embedding_dim when divisible by 4); stays
        # shape-correct for other sizes instead of failing at the matmul.
        self.combiner = nn.Linear(context_out + internal_out + self.history_out, embedding_dim)

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode text with the LM and return its [CLS] embedding, shape [hidden_size]."""
        # Tokenizer output lives on CPU; move it to wherever the LM is.
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(self.language_model.device)
        with torch.no_grad():
            outputs = self.language_model(**inputs)
        return outputs.last_hidden_state[0, 0, :]

    def forward(self, state: 'NLPState') -> torch.Tensor:
        """Encode NLP state into a 1-D embedding of size ``embedding_dim``."""
        # Current context: [hidden_size] -> [embedding_dim // 2]
        context_emb = torch.relu(self.context_projection(self.encode_text(state.context_text)))

        # Conversation history: last 5 turns through the LSTM
        if state.history:
            history_embs = torch.stack([self.encode_text(turn) for turn in state.history[-5:]])
            # LSTM input [1, num_turns, hidden_size]; h_n is [1, 1, history_out]
            _, (h_n, _) = self.history_lstm(history_embs.unsqueeze(0))
            history_emb = h_n[-1, 0]
        else:
            history_emb = context_emb.new_zeros(self.history_out)

        # Internal state: 5 real features padded with zeros to 10, created on
        # the same device as the other parts so torch.cat does not fail.
        internal_features = torch.tensor([
            state.internal.get('r_meta', 0.0),
            state.internal.get('confidence', 1.0),
            len(state.internal.get('correction_history', [])),
            float(state.internal.get('last_correction_successful', False)),
            state.internal.get('turns_since_correction', 0),
        ] + [0.0] * 5, dtype=torch.float32, device=context_emb.device)
        internal_emb = torch.relu(self.internal_encoder(internal_features))

        combined = torch.cat([context_emb, internal_emb, history_emb])
        return torch.relu(self.combiner(combined))


@dataclass
class NLPState:
    """
    Text-oriented state representation for the NLP environments.

    This is a standalone dataclass (it does not subclass ``State``). It
    carries the current text context, recent history, the task goal, the text
    generated so far, and an ``internal`` reflection dict. Omitted
    ``history``/``internal`` arguments get fresh containers per instance.
    """
    context_text: str  # Current text context
    history: List[str]  # Conversation/interaction history
    task_goal: str  # Task description
    generated_text: str  # Text generated so far
    internal: Dict[str, Any]

    def __init__(
        self,
        context_text: str,
        task_goal: str,
        history: Optional[List[str]] = None,
        generated_text: str = "",
        internal: Optional[Dict] = None
    ):
        self.context_text = context_text
        self.task_goal = task_goal
        if history is None:
            history = []
        self.history = history
        self.generated_text = generated_text
        if internal is None:
            internal = dict(
                r_meta=0.0,
                correction_history=[],
                confidence=1.0,
                last_correction_successful=False,
                turns_since_correction=0,
            )
        self.internal = internal


class NLPEnvironment:
    """
    Interface for text-based RL tasks.

    Subclasses must override ``reset`` and ``step`` (the base versions raise
    NotImplementedError); ``render`` is an optional hook that does nothing.
    """
    def __init__(self, task_description: str):
        self.task_description = task_description
        # Not read by this base class; available for subclass bookkeeping.
        self.current_state = None
        self.episode_history = []

    def reset(self) -> NLPState:
        """Begin a new episode and return its initial state (subclass hook)."""
        raise NotImplementedError

    def step(self, action: Any) -> Tuple[NLPState, float, bool, Dict]:
        """Apply ``action`` and return (next_state, reward, done, info) (subclass hook)."""
        raise NotImplementedError

    def render(self):
        """Display the current state; a no-op by default."""
        return None


class DialogueEnvironment(NLPEnvironment):
    """
    Environment for dialogue/conversation tasks.
    Goal: Generate helpful, accurate, and safe responses.

    R_ext is heuristic: primary actions are scored by ``_evaluate_response``;
    corrective actions earn small fixed rewards. Episodes end after
    ``max_turns`` calls to ``step``.
    """
    def __init__(
        self,
        user_query: str,
        knowledge_base: Optional[Dict[str, str]] = None,
        max_turns: int = 10
    ):
        super().__init__("Generate helpful dialogue response")
        self.user_query = user_query
        self.knowledge_base = knowledge_base or {}
        self.max_turns = max_turns
        self.turn_count = 0
        self.conversation_history = []

    def reset(self) -> NLPState:
        """Initialize dialogue session with the user query as the only turn."""
        self.turn_count = 0
        self.conversation_history = [f"User: {self.user_query}"]

        return NLPState(
            context_text=self.user_query,
            task_goal="Generate accurate and helpful response",
            history=self.conversation_history.copy(),
            generated_text=""
        )

    def step(self, action: Any) -> Tuple[NLPState, float, bool, Dict]:
        """
        Process agent's action (text generation or correction).

        Primary actions append ``"Agent: <text>"`` to the conversation and are
        scored by ``_evaluate_response``. Corrective actions leave the
        conversation unchanged. Non-``Action`` inputs count as a turn with
        zero reward.

        Returns:
            (next_state, r_ext, done, info)
        """
        self.turn_count += 1
        info = {}
        r_ext = 0.0

        if isinstance(action, Action):
            if action.action_type == ActionType.PRIMARY:
                # SCARLAgent's primary policy emits an integer index as content;
                # coerce to text so the string-based reward heuristics don't
                # raise TypeError on len()/lower().
                generated_text = str(action.content)
                self.conversation_history.append(f"Agent: {generated_text}")

                r_ext = self._evaluate_response(generated_text)
                info['response'] = generated_text

            elif action.action_type == ActionType.CORRECTIVE:
                if action.corrective_type == CorrectiveAction.SEARCH:
                    # NOTE(review): results are only surfaced via `info`; they are
                    # not added to the history, so the next state doesn't see them.
                    info['search_results'] = self._search_knowledge_base(self.user_query)
                    r_ext = 0.1  # Small reward for seeking information

                elif action.corrective_type == CorrectiveAction.REPLAN:
                    info['action'] = 'replanning'
                    r_ext = 0.0  # Neutral reward

                elif action.corrective_type == CorrectiveAction.CRITIQUE:
                    info['critique'] = action.content
                    r_ext = 0.05  # Small reward for self-reflection

                elif action.corrective_type == CorrectiveAction.ASK:
                    info['clarification_request'] = action.content
                    r_ext = 0.15  # Reward for seeking clarity

        done = self.turn_count >= self.max_turns

        next_state = NLPState(
            context_text=self.conversation_history[-1] if self.conversation_history else "",
            task_goal=self.task_description,
            history=self.conversation_history.copy(),
            generated_text=info.get('response', '')
        )

        return next_state, r_ext, done, info

    def _evaluate_response(self, response: str) -> float:
        """
        Evaluate response quality (R_ext), clipped to [-1, 1].
        In practice, this would use human feedback or a reward model.
        """
        reward = 0.0

        # Simple heuristics (replace with actual reward model)
        if len(response) > 10:  # Non-trivial response
            reward += 0.3

        if len(response) < 500:  # Not too verbose
            reward += 0.2

        # Query/response word overlap. Tokenize on word characters so trailing
        # punctuation ("python?") doesn't prevent a match with "python".
        lowered = response.lower()
        query_words = set(re.findall(r"\w+", self.user_query.lower()))
        response_words = set(re.findall(r"\w+", lowered))
        overlap = len(query_words & response_words) / max(len(query_words), 1)
        reward += overlap * 0.3

        # Penalize unsafe content (simple substring check)
        unsafe_patterns = ['hack', 'steal', 'illegal']
        if any(pattern in lowered for pattern in unsafe_patterns):
            reward -= 0.5

        # np.clip returns a numpy scalar; return a plain float as annotated.
        return float(np.clip(reward, -1.0, 1.0))

    def _search_knowledge_base(self, query: str) -> str:
        """Return the value of the first key that occurs (case-insensitively) in ``query``."""
        lowered_query = query.lower()
        for key, value in self.knowledge_base.items():
            if key.lower() in lowered_query:
                return value
        return "No relevant information found."


class SummarizationEnvironment(NLPEnvironment):
    """
    Environment for text summarization tasks.
    Goal: Generate accurate, concise summaries.

    A primary action submits a summary and ends the episode. REPLAN clears
    the current summary; CRITIQUE returns a rule-based critique. SEARCH and
    ASK are accepted but have no effect (zero reward).
    """
    def __init__(self, source_text: str, max_summary_length: int = 150):
        super().__init__("Generate concise and accurate summary")
        self.source_text = source_text
        self.max_summary_length = max_summary_length  # measured in characters
        self.current_summary = ""

    def reset(self) -> NLPState:
        """Initialize summarization task."""
        self.current_summary = ""

        return NLPState(
            context_text=self.source_text,
            task_goal="Generate accurate summary",
            history=[],
            generated_text=""
        )

    def step(self, action: Any) -> Tuple[NLPState, float, bool, Dict]:
        """Process summarization action; returns (next_state, r_ext, done, info)."""
        info = {}
        done = False
        r_ext = 0.0

        if isinstance(action, Action):
            if action.action_type == ActionType.PRIMARY:
                # SCARLAgent's primary policy emits an integer index as content;
                # coerce to text so len()/lower() in the scorer don't raise.
                summary_text = str(action.content)
                self.current_summary = summary_text

                r_ext = self._evaluate_summary(summary_text)
                info['summary'] = summary_text
                done = True  # Single-step task

            elif action.action_type == ActionType.CORRECTIVE:
                if action.corrective_type == CorrectiveAction.REPLAN:
                    # Start over with new approach
                    self.current_summary = ""
                    r_ext = 0.0

                elif action.corrective_type == CorrectiveAction.CRITIQUE:
                    info['critique'] = self._generate_critique()
                    r_ext = 0.1

        next_state = NLPState(
            context_text=self.source_text,
            task_goal=self.task_description,
            history=[self.current_summary] if self.current_summary else [],
            generated_text=self.current_summary
        )

        return next_state, r_ext, done, info

    def _evaluate_summary(self, summary: str) -> float:
        """
        Evaluate summary quality with ROUGE-like heuristics, clipped to [-1, 1].
        In practice, use actual ROUGE or learned reward model.
        """
        reward = 0.0

        # Length penalty (characters)
        if len(summary) <= self.max_summary_length:
            reward += 0.3
        else:
            reward -= 0.2

        # Unigram coverage of source vocabulary (simplified ROUGE-1 recall).
        # Word-character tokenization keeps punctuation from blocking matches.
        source_words = set(re.findall(r"\w+", self.source_text.lower()))
        summary_words = set(re.findall(r"\w+", summary.lower()))
        overlap = len(source_words & summary_words) / max(len(source_words), 1)
        reward += overlap * 0.7

        # np.clip returns a numpy scalar; return a plain float as annotated.
        return float(np.clip(reward, -1.0, 1.0))

    def _generate_critique(self) -> str:
        """Generate self-critique of current summary."""
        if not self.current_summary:
            return "Summary is empty. Need to generate content."
        if len(self.current_summary) > self.max_summary_length:
            return "Summary is too long. Need to condense."
        return "Summary structure looks reasonable. Check for key details."


class NLPMetaRewardGenerator(nn.Module):
    """
    NLP-specific meta-reward (R_meta) generator.

    A small MLP (Linear -> ReLU -> Dropout -> Linear -> ReLU -> Linear ->
    Sigmoid) that maps a state embedding to a score in (0, 1), intended to
    reflect text quality, coherence and safety.
    """
    def __init__(self, state_dim: int, hidden_dim: int = 128):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squash the score into (0, 1)
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state_embedding: torch.Tensor) -> torch.Tensor:
        """
        Return the R_meta score for a state embedding.

        A 1-D input of shape [state_dim] is treated as a batch of one and the
        result is a 0-dim (scalar) tensor. A 2-D input [B, state_dim] yields
        shape [B]. Note that a bare ``squeeze()`` is applied, so a [1, state_dim]
        input also collapses to a 0-dim tensor.
        """
        needs_batch_dim = state_embedding.dim() == 1
        batch = state_embedding.unsqueeze(0) if needs_batch_dim else state_embedding
        scores = self.network(batch)  # [B, 1]
        return scores.squeeze()


class TextGeneratorWrapper:
    """
    Wrapper for Hugging Face causal language models (GPT-2, etc.).

    The model and tokenizer are loaded lazily via ``transformers`` on the first
    call to :meth:`load` or :meth:`generate`, so constructing the wrapper is
    cheap and does not import ``transformers``.
    """
    def __init__(self, model_name: str = "gpt2"):
        self.model_name = model_name
        self._model = None
        self._tokenizer = None

    def load(self):
        """Load the model and tokenizer on first use; later calls do nothing."""
        if self._model is None:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            print(f"Loading {self.model_name}...")
            self._model = AutoModelForCausalLM.from_pretrained(self.model_name)
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Models such as GPT-2 ship without a pad token; fall back to EOS
            # only when the tokenizer does not already define one.
            if self._tokenizer.pad_token is None:
                self._tokenizer.pad_token = self._tokenizer.eos_token
            # For a causal LM the end of the prompt (the instruction the model
            # continues from) matters most, so drop the oldest tokens instead.
            self._tokenizer.truncation_side = "left"
            print("Model loaded successfully!")

    def generate(self, prompt: str, max_new_tokens: int = 100, temperature: float = 0.7) -> str:
        """
        Sample a continuation of ``prompt``.

        Args:
            prompt: Input text; truncated to at most 512 tokens.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (sampling is always enabled).

        Returns:
            Only the newly generated text (prompt excluded), whitespace-stripped.
        """
        self.load()

        inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        prompt_length = inputs.input_ids.shape[-1]
        outputs = self._model.generate(
            inputs.input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=self._tokenizer.eos_token_id,
            attention_mask=inputs.attention_mask,
        )

        # Decoder-only generate() returns prompt + continuation. Slice off the
        # prompt tokens instead of string-matching the decoded text: decoding
        # may not reproduce the prompt verbatim (tokenizer cleanup, truncation).
        continuation = outputs[0][prompt_length:]
        return self._tokenizer.decode(continuation, skip_special_tokens=True).strip()


class SCARLNLPAgent:
    """
    SCARL agent specifically configured for NLP tasks.

    Combines an ``NLPStateEncoder``, an ``NLPMetaRewardGenerator`` (R_meta) and
    the base SCARL ``PrimaryPolicy`` / ``CorrectivePolicy``. Each step the
    encoded state is scored; an R_meta strictly below ``r_meta_threshold``
    triggers a corrective action, otherwise text is generated as the primary
    action.
    """
    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        embedding_dim: int = 256,
        lambda_introspection: float = 0.3,
        r_meta_threshold: float = 0.5,
        use_real_generator: bool = False,
        generator_model: str = "gpt2"
    ):
        """
        Args:
            model_name: Encoder model name passed to ``NLPStateEncoder``.
            embedding_dim: Size of the state embedding fed to the meta-reward
                generator and both policies.
            lambda_introspection: Stored introspection weight (not used inside
                this class).
            r_meta_threshold: R_meta values below this trigger correction.
            use_real_generator: If True, generate text with a
                ``TextGeneratorWrapper``; otherwise use a placeholder string.
            generator_model: Model name for the text generator.
        """
        self.embedding_dim = embedding_dim
        self.lambda_introspection = lambda_introspection
        self.r_meta_threshold = r_meta_threshold
        self.use_real_generator = use_real_generator

        # NLP-specific components
        self.state_encoder = NLPStateEncoder(model_name, embedding_dim)
        self.meta_reward_gen = NLPMetaRewardGenerator(embedding_dim)

        # Policies (using base SCARL policies)
        self.primary_policy = PrimaryPolicy(embedding_dim, action_dim=256)
        self.corrective_policy = CorrectivePolicy(embedding_dim)

        # Text generation model (optional; loaded lazily by the wrapper)
        if use_real_generator:
            self.text_generator = TextGeneratorWrapper(generator_model)
        else:
            self.text_generator = None

    def select_action(self, state: NLPState) -> Tuple[Any, float]:
        """
        Choose a primary or corrective action for ``state``.

        Side effects: writes ``r_meta`` and ``turns_since_correction`` into
        ``state.internal`` and, when a correction is triggered, appends an
        entry to ``state.internal['correction_history']``.

        Returns:
            ``(action, r_meta)``: the chosen ``Action`` and the meta-reward as
            a Python float.
        """
        # Encode state - presumably a 1-D tensor [embedding_dim]
        state_emb = self.state_encoder(state)

        with torch.no_grad():
            # .item() works for both 0-dim and single-element tensors.
            r_meta = float(self.meta_reward_gen(state_emb).item())

        internal = state.internal
        internal['r_meta'] = r_meta
        # Use .get/.setdefault: State's default internal dict has no
        # 'turns_since_correction', and custom dicts may omit either key.
        internal['turns_since_correction'] = internal.get('turns_since_correction', 0) + 1

        if r_meta < self.r_meta_threshold:
            # Low confidence - trigger correction
            corrective_action = self.corrective_policy.select_action(state_emb.unsqueeze(0))
            internal.setdefault('correction_history', []).append({
                'turn': len(state.history or []),
                'r_meta': r_meta,
                'action': corrective_action.value
            })
            internal['turns_since_correction'] = 0

            return Action(
                action_type=ActionType.CORRECTIVE,
                content=self._execute_corrective_action(corrective_action, state),
                corrective_type=corrective_action
            ), r_meta

        # High confidence - proceed with primary action
        generated_text = self._generate_text(state)
        return Action(
            action_type=ActionType.PRIMARY,
            content=generated_text,
            corrective_type=None
        ), r_meta

    def _generate_text(self, state: NLPState) -> str:
        """
        Generate a response for ``state``.

        Uses the real LLM wrapper when configured, otherwise returns a
        placeholder built from the first 50 characters of the context.
        """
        if self.use_real_generator and self.text_generator:
            prompt = f"Context: {state.context_text}\n\nProvide a helpful response:"
            return self.text_generator.generate(prompt, max_new_tokens=100)
        # Placeholder for demo
        return f"Generated response based on: {state.context_text[:50]}..."

    def _execute_corrective_action(
        self,
        action: 'CorrectiveAction',
        state: NLPState
    ) -> str:
        """Return the textual content describing the given corrective action."""
        if action == CorrectiveAction.SEARCH:
            return f"Searching for information about: {state.task_goal}"
        elif action == CorrectiveAction.REPLAN:
            return "Replanning approach to task"
        elif action == CorrectiveAction.CRITIQUE:
            # generated_text may be unset (None) before anything was produced.
            draft = state.generated_text or ""
            return f"Critiquing current approach: {draft[:50]}..."
        elif action == CorrectiveAction.ASK:
            return "Could you provide more details about what you're looking for?"
        return ""


# Example integration demonstration
if __name__ == "__main__":
    # Demo: run one step of each NLP task with a single shared agent.
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print("SCARL NLP Integration Example")
    print(heavy_rule)

    # --- Example 1: a single dialogue turn ---------------------------------
    print("\n1. DIALOGUE TASK EXAMPLE")
    print(light_rule)

    dialogue_env = DialogueEnvironment(
        user_query="Can you explain what Python is used for?",
        knowledge_base={
            "python": "Python is a high-level programming language known for simplicity.",
            "machine learning": "ML is a subset of AI focused on learning from data."
        },
        max_turns=5
    )

    # use_real_generator=True: the generator model is loaded on the first
    # primary (text-generating) action.
    agent = SCARLNLPAgent(
        lambda_introspection=0.3,
        r_meta_threshold=0.5,
        use_real_generator=True
    )

    dialogue_state = dialogue_env.reset()
    print(f"User Query: {dialogue_state.context_text}")
    print(f"Task Goal: {dialogue_state.task_goal}")

    dialogue_action, dialogue_r_meta = agent.select_action(dialogue_state)
    _, dialogue_reward, _, _ = dialogue_env.step(dialogue_action)

    print(f"\nAgent Action Type: {dialogue_action.action_type.value}")
    print(f"R_meta Score: {dialogue_r_meta:.3f}")
    print(f"External Reward: {dialogue_reward:.3f}")
    chosen_fix = dialogue_action.corrective_type
    if chosen_fix:
        print(f"Corrective Action: {chosen_fix.value}")

    # --- Example 2: action selection for summarization ---------------------
    print("\n\n2. SUMMARIZATION TASK EXAMPLE")
    print(light_rule)

    source_text = """
    Artificial intelligence has revolutionized many industries. Machine learning,
    a subset of AI, enables computers to learn from data without explicit programming.
    Deep learning, using neural networks, has achieved remarkable results in image
    recognition, natural language processing, and game playing.
    """

    summ_env = SummarizationEnvironment(source_text=source_text, max_summary_length=100)
    summ_state = summ_env.reset()
    print(f"Source Text Length: {len(source_text)} characters")
    print(f"Task: {summ_state.task_goal}")

    summ_action, summ_r_meta = agent.select_action(summ_state)
    print(f"\nR_meta Score: {summ_r_meta:.3f}")
    print(f"Action Type: {summ_action.action_type.value}")

    print("\n" + heavy_rule)
    print("Integration complete! Ready for training.")

SCARL NLP Integration Example

1. DIALOGUE TASK EXAMPLE
------------------------------------------------------------
User Query: Can you explain what Python is used for?
Task Goal: Generate accurate and helpful response
Loading gpt2...
Model loaded successfully!

Agent Action Type: primary
R_meta Score: 0.505
External Reward: 0.575

------------------------------------------------------------
To use real GPT-2 generation, create agent with:
  agent = SCARLNLPAgent(use_real_generator=True)
This will load GPT-2 on first text generation call.


2. SUMMARIZATION TASK EXAMPLE
------------------------------------------------------------
Source Text Length: 321 characters
Task: Generate accurate summary

R_meta Score: 0.506
Action Type: primary

Integration complete! Ready for training.

Next steps:
1. Replace text generation placeholder with actual LLM (GPT, T5, etc.)
2. Implement proper reward models for task evaluation
3. Train meta-reward generator on diverse NLP tasks
4. Collect self-cor