In [None]:
# Install dependencies
!pip install gradio torch torchvision torchaudio scipy matplotlib sympy

# Create directories
!mkdir -p core data search ui utils


In [None]:
%%writefile core/grammar.py
import numpy as np
from scipy.special import gamma as scipy_gamma, gammaln
import math

# Supported operators mapped to their arity (number of child arguments).
# Insertion order matters: it fixes the token IDs assigned in VOCABULARY below.
OPERATORS = {
    # --- Stage 0: pure arithmetic ---
    '+': 2, '-': 2, '*': 2, '/': 2,
    # --- Stage 1: powers ---
    'pow': 2, 'sqrt': 1,
    # --- Stage 2: trigonometry ---
    'sin': 1, 'cos': 1, 'tan': 1,
    # --- Stage 3: transcendental ---
    'exp': 1, 'log': 1,
    # --- Stage 4: advanced ---
    'abs': 1, 'neg': 1, 'sign': 1, 'floor': 1, 'ceil': 1, 'mod': 2,
    'gamma': 1,
    'lgamma': 1,  # log-gamma (mirrors the C++ GP engine)
}

# Curriculum control: each stage unlocks these operators on top of all earlier stages.
_STAGE_ADDITIONS = (
    ('+', '-', '*', '/'),
    ('pow', 'sqrt'),
    ('sin', 'cos', 'tan'),
    ('exp', 'log'),
)
OPERATOR_STAGES = {
    stage: [op for group in _STAGE_ADDITIONS[: stage + 1] for op in group]
    for stage in range(len(_STAGE_ADDITIONS))
}
OPERATOR_STAGES[4] = list(OPERATORS.keys())  # final stage: every operator

# Terminal tokens: the single input variable and the constant pool.
VARIABLES = ['x']
CONSTANTS = ['C', '0', '1', '2', '3', '5', '10', 'pi', 'e']  # 'C' = learnable-constant placeholder

# Full vocabulary: operators first, then variables, then constants.
VOCABULARY = [*OPERATORS, *VARIABLES, *CONSTANTS]
TOKEN_TO_ID = dict(zip(VOCABULARY, range(len(VOCABULARY))))
ID_TO_TOKEN = dict(enumerate(VOCABULARY))

# Special sequence markers (not part of VOCABULARY / TOKEN_TO_ID).
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN = '<SOS>', '<EOS>', '<PAD>'

class Node:
    """A single expression-tree node: an operator/terminal token plus its children."""

    def __init__(self, value, children=None):
        self.value = value
        # Keep the caller's list when a non-empty one is given; otherwise start empty.
        self.children = children or []

    def __repr__(self):
        """Lisp-style prefix rendering, e.g. ``(+ x (sin x))``."""
        if self.children:
            inner = " ".join(str(child) for child in self.children)
            return f"({self.value} {inner})"
        return str(self.value)

    def to_infix(self):
        """Render the subtree as a human-readable infix string.

        Unary operators use call syntax (``sin(x)``). Binary operators are fully
        parenthesised, with ``pow`` printed as ``^`` and ``mod`` as ``%``.
        Nodes with more than two children fall back to the bare value.
        """
        arity = len(self.children)
        if arity == 0:
            return str(self.value)
        if arity == 1:
            return f"{self.value}({self.children[0].to_infix()})"
        if arity == 2:
            symbol = {'pow': '^', 'mod': '%'}.get(self.value, self.value)
            left, right = (child.to_infix() for child in self.children)
            return f"({left} {symbol} {right})"
        return str(self.value)

    def count_constants(self):
        """Count the 'C' (learnable constant) placeholders in this subtree."""
        own = 1 if self.value == 'C' else 0
        return own + sum(child.count_constants() for child in self.children)

    def get_constant_positions(self, path=None):
        """Return the child-index paths (root = []) to every 'C' node, in pre-order.

        These paths are what the optimizer uses to address individual constants.
        """
        here = [] if path is None else path
        found = [here.copy()] if self.value == 'C' else []
        for index, child in enumerate(self.children):
            found.extend(child.get_constant_positions(here + [index]))
        return found


import ast

class ExpressionTree:
    """An expression tree built from a prefix (Polish-notation) token list.

    Malformed token lists do not raise: the instance is created with
    ``is_valid = False`` and ``root = None``, and ``evaluate`` returns all-NaN.
    """

    def __init__(self, token_list):
        """
        Parses a list of tokens in Pre-order traversal (Prefix notation)
        Example: ['+', 'x', 'sin', 'x'] -> x + sin(x)
        """
        self.tokens = token_list
        try:
            self.root, remaining = self._build_tree(token_list)
            if remaining:
                raise ValueError("Tokens remained after building tree")
            self.is_valid = True
        except Exception:
            # Deliberate best-effort: any malformed input yields an invalid tree.
            self.root = None
            self.is_valid = False

    @classmethod
    def from_infix(cls, infix_str):
        """
        Creates an ExpressionTree from a standard infix string (e.g. "sin(x) + x^2").
        Uses Python's ast to parse.

        Accepted syntax: ``+ - * / %``, ``^`` or ``**`` for power, unary minus,
        calls to any named operator in OPERATORS with the matching number of
        arguments (e.g. ``sin(x)``, ``neg(x)``, ``pow(x, 2)``), and the C++
        engine's postfix factorial (``(...)!`` or ``x!``), which becomes ``gamma(...)``.
        On failure the parse error is printed and an invalid tree is returned.
        """
        processed_str = cls._expand_factorials(infix_str)
        # '^' must become '**' BEFORE parsing. Python's '^' is bitwise XOR, which
        # binds looser than '+' and '*', so "sin(x) + x^2" would otherwise parse
        # as "(sin(x) + x) ^ 2". '**' has conventional power precedence and is
        # right-associative.
        processed_str = processed_str.replace('^', '**')
        try:
            tree = ast.parse(processed_str, mode='eval')
            tokens = cls._ast_to_prefix(tree.body)
            return cls(tokens)
        except Exception as e:
            print(f"Error parsing infix: {e} | Original: {infix_str} | Processed: {processed_str}")
            return cls([]) # Invalid

    @staticmethod
    def _expand_factorials(infix_str):
        """Rewrite postfix factorials ``(...)!`` / ``x!`` as ``gamma(...)`` calls.

        Each loop iteration removes exactly one '!', so the loop always terminates.
        """
        processed_str = infix_str
        while '!' in processed_str:
            idx = processed_str.find('!')
            if idx > 0 and processed_str[idx - 1] == ')':
                # Walk backwards to the '(' matching the ')' just before '!'.
                paren_count = 1
                start = idx - 2
                while start >= 0 and paren_count > 0:
                    if processed_str[start] == ')':
                        paren_count += 1
                    elif processed_str[start] == '(':
                        paren_count -= 1
                    start -= 1
                start += 1  # now at the matching '(' (or 0 if unbalanced)
                content = processed_str[start:idx]  # includes both parens
                processed_str = processed_str[:start] + "gamma" + content + processed_str[idx + 1:]
                continue

            # Bare operand such as "x!" or "3!": wrap the preceding name/number.
            start = idx
            while start > 0 and (processed_str[start - 1].isalnum() or processed_str[start - 1] in '_.'):
                start -= 1
            if start < idx:
                operand = processed_str[start:idx]
                processed_str = processed_str[:start] + "gamma(" + operand + ")" + processed_str[idx + 1:]
            else:
                # Nothing recognisable precedes '!': drop it.
                processed_str = processed_str[:idx] + processed_str[idx + 1:]
        return processed_str

    @staticmethod
    def _ast_to_prefix(node):
        """Convert a Python AST expression node into a prefix token list.

        Raises ValueError for any construct the grammar cannot represent.
        """
        if isinstance(node, ast.BinOp):
            # BitXor is kept for hand-built ASTs; from_infix rewrites '^' to '**'.
            op_map = {
                ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/',
                ast.BitXor: 'pow', ast.Pow: 'pow', ast.Mod: 'mod'
            }
            op_type = type(node.op)
            if op_type in op_map:
                return [op_map[op_type]] + ExpressionTree._ast_to_prefix(node.left) + ExpressionTree._ast_to_prefix(node.right)

        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                # Collapse negative numeric literals ("-5") into a single token.
                if isinstance(node.operand, ast.Constant) and isinstance(node.operand.value, (int, float)):
                    return [str(-node.operand.value)]
                return ['neg'] + ExpressionTree._ast_to_prefix(node.operand)
            if isinstance(node.op, ast.UAdd):
                return ExpressionTree._ast_to_prefix(node.operand)  # unary '+' is a no-op

        elif isinstance(node, ast.Call):
            # Any named grammar operator may be written in call form, provided the
            # argument count matches its arity. This includes 'neg' and 'sign',
            # which Node.to_infix renders as calls, and 'pow' / 'mod'.
            func_id = node.func.id if isinstance(node.func, ast.Name) else None
            if (func_id in OPERATORS and not node.keywords
                    and len(node.args) == OPERATORS[func_id]):
                tokens = [func_id]
                for arg in node.args:
                    tokens.extend(ExpressionTree._ast_to_prefix(arg))
                return tokens

        elif isinstance(node, ast.Name):
            return [node.id]

        elif isinstance(node, ast.Constant):  # Python 3.8+
            return [str(node.value)]

        elif type(node).__name__ == 'Num':  # Python < 3.8 numeric literal
            # Matched by class name so that ast.Num (deprecated in 3.8, removed
            # in 3.14) is never accessed on newer interpreters.
            return [str(node.n)]

        raise ValueError(f"Unsupported AST node: {node}")

    def _build_tree(self, tokens):
        """Recursively consume one subtree from the front of ``tokens``.

        Returns ``(node, remaining_tokens)``. Raises ValueError on an empty list
        or an unknown token.
        """
        if not tokens:
            raise ValueError("Empty token list")

        token = tokens[0]
        remaining = tokens[1:]

        if token in OPERATORS:
            children = []
            for _ in range(OPERATORS[token]):
                child, remaining = self._build_tree(remaining)
                children.append(child)
            return Node(token, children), remaining
        if token in VARIABLES or token in CONSTANTS:
            return Node(token), remaining

        # Otherwise the token must be a numeric literal such as '-5' or '2.5'.
        try:
            float(token)
        except (TypeError, ValueError):
            raise ValueError(f"Unknown token: {token}") from None
        return Node(token), remaining

    def evaluate(self, x_values, constants=None):
        """
        Evaluates the expression tree for a given array of x values.
        constants: optional dict mapping path tuples to constant values
        Returns a numpy array of results.

        Operators are "protected" variants (see _eval_node). An invalid tree
        evaluates to all-NaN.
        """
        # Ensure x_values is a numpy array
        if not isinstance(x_values, np.ndarray):
            x_values = np.array(x_values, dtype=np.float64)

        if not self.is_valid:
            return np.full_like(x_values, np.nan, dtype=np.float64)
        return self._eval_node(self.root, x_values, constants, path=[])

    def _eval_node(self, node, x, constants=None, path=None):
        """Recursively evaluate ``node`` element-wise over ``x`` as float64.

        ``path`` is the list of child indices from the root. As a tuple, it keys
        ``constants`` so that each 'C' placeholder can take its own optimised
        value. Unmatched 'C' placeholders default to 1.0, and unknown operators
        yield zeros.
        """
        val = node.value

        if val == 'x':
            return x.astype(np.float64)
        if val == 'pi':
            return np.full_like(x, np.pi, dtype=np.float64)
        if val == 'e':
            return np.full_like(x, np.e, dtype=np.float64)
        if val == 'C':
            # Check if we have an optimized constant for this position
            if constants is not None and tuple(path) in constants:
                return np.full_like(x, constants[tuple(path)], dtype=np.float64)
            return np.full_like(x, 1.0, dtype=np.float64)  # Default constant = 1

        # Numeric literal tokens ('2', '-5', '0.5', ...).
        try:
            literal = float(val)
        except (TypeError, ValueError):
            literal = None
        if literal is not None:
            return np.full_like(x, literal, dtype=np.float64)

        # Recursive evaluation of children (path extended by child index).
        args = [
            self._eval_node(c, x, constants, path + [i] if path is not None else None)
            for i, c in enumerate(node.children)
        ]

        # Operators (warnings suppressed; protected forms avoid most NaN/inf).
        with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
            if val == '+': return args[0] + args[1]
            if val == '-': return args[0] - args[1]
            if val == '*': return args[0] * args[1]
            if val == '/':
                # Protected division: 0 wherever the denominator is exactly 0.
                return np.divide(args[0], args[1], out=np.zeros_like(x, dtype=np.float64), where=args[1]!=0)
            if val == 'pow':
                # Safe power: (|base| + 1e-10) ** clip(exponent, -10, 10).
                # Note: the sign of a negative base is discarded.
                return np.power(np.abs(args[0]) + 1e-10, np.clip(args[1], -10, 10))
            if val == 'mod':
                return np.mod(args[0], args[1] + 1e-10)
            if val == 'sin': return np.sin(args[0])
            if val == 'cos': return np.cos(args[0])
            if val == 'tan': return np.tan(args[0])
            if val == 'exp':
                return np.exp(np.clip(args[0], -100, 100))
            if val == 'log':
                return np.log(np.abs(args[0]) + 1e-10)
            if val == 'sqrt':
                return np.sqrt(np.abs(args[0]))
            if val == 'abs':
                return np.abs(args[0])
            if val == 'floor':
                return np.floor(args[0])
            if val == 'ceil':
                return np.ceil(args[0])
            if val == 'gamma':
                # Match C++ Protected Gamma/Factorial: tgamma(|x| + 1)
                # This ensures consistent evaluation for formulas from C++ engine (which uses !)
                arg = np.abs(args[0]) + 1.0
                clipped = np.clip(arg, 0.1, 50) # Clip upper bound to avoid overflow
                return scipy_gamma(clipped)
            if val == 'lgamma':
                # Protected lgamma: lgamma(|x| + 1)
                arg = np.abs(args[0]) + 1.0
                # gammaln is safe for large positive numbers, so less aggressive clipping needed for overflow,
                # but we clip for consistency and to avoid extremely large outputs if followed by exp
                clipped = np.clip(arg, 0.1, 1000)
                return gammaln(clipped)
            if val == 'neg':
                return -args[0]
            if val == 'sign':
                return np.sign(args[0])

        return np.zeros_like(x, dtype=np.float64)

    def get_infix(self):
        """Infix string of the tree, or "Invalid" when parsing failed."""
        if not self.is_valid:
            return "Invalid"
        return self.root.to_infix()

    def count_constants(self):
        """Number of 'C' placeholders in the tree (0 for an invalid tree)."""
        if not self.is_valid:
            return 0
        return self.root.count_constants()

import sympy

def simplify_formula(formula_str):
    """
    Simplifies a mathematical formula using SymPy.

    SymPy printer names in the result are mapped back to this grammar's names
    (``Abs`` -> ``abs``, ``ceiling`` -> ``ceil``, ``Mod`` -> ``mod``,
    ``E`` -> ``e``). Powers come back as ``**``.

    Returns ``formula_str`` unchanged when SymPy cannot parse or simplify it,
    or when the simplified form contains values the grammar cannot express
    (``zoo``, ``oo``, ``nan``, or the imaginary unit ``I``).

    Caveats:
      * SymPy applies identities of the *mathematical* functions. ExpressionTree
        evaluates protected variants instead (e.g. ``gamma(z)`` as
        ``Gamma(|z| + 1)``, ``sqrt`` of ``|z|``, and division by zero giving 0).
        The simplified formula is therefore not guaranteed to evaluate identically.
      * ``sympy.sympify`` evaluates its input with ``eval``. Only pass strings
        produced by this project (model / GP engine), never untrusted input.
    """
    import re

    try:
        # SymPy spells the power function 'Pow'. '^' needs no handling here:
        # sympify treats it as exponentiation by default (convert_xor=True).
        s_str = formula_str.replace("pow(", "Pow(")
        expr = sympy.sympify(s_str)
        simplified = sympy.simplify(expr)
        final_str = str(simplified)
    except Exception:
        # Best-effort: unknown syntax, parse errors, etc. keep the original.
        return formula_str

    # Reject results with no representation in the grammar.
    if re.search(r"\b(zoo|oo|nan|I)\b", final_str):
        return formula_str

    # Map SymPy printer names back to grammar operator/constant names.
    for sympy_name, grammar_name in (("Abs", "abs"), ("ceiling", "ceil"), ("Mod", "mod"), ("E", "e")):
        final_str = re.sub(rf"\b{sympy_name}\b", grammar_name, final_str)
    return final_str


In [None]:
%%writefile core/model.py
import torch
import torch.nn as nn
import numpy as np

class AlphaSymbolicModel(nn.Module):
    """Encoder-decoder Transformer with a policy head and a quantile value head.

    The encoder embeds the (x, y) sample points of a regression problem. The
    decoder reads the formula-token prefix and attends to that encoding.
    """

    def __init__(self, vocab_size, d_model=128, nhead=4, num_encoder_layers=2, num_decoder_layers=2, max_seq_len=50, num_quantiles=3):
        """
        Args:
            vocab_size: number of token IDs (embedding rows and policy outputs).
            d_model: hidden size (PyTorch attention requires d_model % nhead == 0).
            nhead: attention heads in every encoder/decoder layer.
            num_encoder_layers / num_decoder_layers: Transformer depths.
            max_seq_len: longest formula prefix the positional encoding supports.
            num_quantiles: outputs of the value head (default 3, used for the
                0.25 / 0.50 / 0.75 quantiles).
        """
        super(AlphaSymbolicModel, self).__init__()

        self.d_model = d_model

        # 1. Point Encoder: each (x, y) pair -> d_model. The points are a set, so
        #    no positional encoding is applied on the encoder side.
        self.point_embedding = nn.Linear(2, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.problem_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # 2. Formula Decoder: token embeddings + sinusoidal positions.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.formula_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # 3. Heads
        self.policy_head = nn.Linear(d_model, vocab_size)
        self.value_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, num_quantiles),  # default quantiles: 0.25, 0.50, 0.75
        )

    def forward(self, x_values, y_values, formula_input, formula_mask=None):
        """
        x_values: [batch, num_points]
        y_values: [batch, num_points]
        formula_input: [batch, seq_len] (Token IDs)
        formula_mask: Optional mask for the decoder (causal mask by default)

        Returns:
            logits: [batch, seq_len, vocab_size]
            value:  [batch, num_quantiles], predicted from the last position's state
        """
        # -- Problem Encoding --
        points = torch.stack([x_values, y_values], dim=2)  # [batch, num_points, 2]
        memory = self.problem_encoder(self.point_embedding(points))  # [batch, num_points, d_model]

        # -- Formula Decoding --
        tgt = self.pos_encoder(self.token_embedding(formula_input))  # [batch, seq_len, d_model]
        if formula_mask is None:
            seq_len = formula_input.size(1)
            formula_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(formula_input.device)
        output = self.formula_decoder(tgt, memory, tgt_mask=formula_mask)

        # -- Heads --
        logits = self.policy_head(output)  # [batch, seq_len, vocab_size]
        # NOTE(review): the value is read from the LAST position. If callers
        # right-pad formula_input, that position may be padding — confirm.
        value = self.value_head(output[:, -1, :])  # [batch, num_quantiles]

        return logits, value

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, added to the input embeddings.

    Even feature columns hold sin(pos * w_k) and odd columns hold cos(pos * w_k),
    with w_k = 10000^(-2k / d_model).
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so trim
        # div_term to match. For even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Buffer, not a parameter: follows .to()/state_dict but is never trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x`` ([batch, seq_len, d_model])."""
        return x + self.pe[:, :x.size(1), :]

if __name__ == "__main__":
    # Smoke test: run one forward pass on random data and verify the output shapes.
    vocab_size = 20
    model = AlphaSymbolicModel(vocab_size=vocab_size, d_model=32)

    # Dummy data
    bs = 2
    points = 10
    seq_len = 5
    x = torch.randn(bs, points)
    y = torch.randn(bs, points)

    # Formula input (start token + some tokens)
    seq = torch.randint(0, vocab_size, (bs, seq_len))

    logits, value = model(x, y, seq)

    print("Logits shape:", logits.shape)
    print("Value shape:", value.shape)
    assert logits.shape == (bs, seq_len, vocab_size), logits.shape
    # The value head predicts 3 quantiles (0.25 / 0.50 / 0.75) per sample.
    assert value.shape == (bs, 3), value.shape
    print("Smoke test passed.")


In [None]:
%%writefile core/environment.py
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from core.grammar import VOCABULARY, OPERATORS, TOKEN_TO_ID, ExpressionTree
from data.synthetic_data import DataGenerator

class SymbolicEnv(gym.Env):
    """Gymnasium environment in which an agent writes a formula one prefix token at a time.

    Each episode samples a regression problem (x, y) from DataGenerator. The
    episode terminates once the token prefix forms a complete expression tree,
    which is then scored against the problem.
    """

    def __init__(self, max_length=50, num_points=10):
        """
        Args:
            max_length: maximum number of tokens before the episode is truncated.
            num_points: number of (x, y) samples per problem. It sets both the
                observation shapes and the sample count requested from
                DataGenerator, which keeps the two consistent.
        """
        super(SymbolicEnv, self).__init__()

        self.vocab_size = len(VOCABULARY)
        self.max_length = max_length
        self.num_points = num_points
        self.vocab = VOCABULARY

        # Action space: Choose a token from the vocabulary
        self.action_space = spaces.Discrete(self.vocab_size)

        # Observation space: padded token-ID sequence plus the problem's x / y samples.
        # NOTE(review): padding uses ID 0, which is also a real token
        # (VOCABULARY[0]), so padding cannot be told apart from it by value alone.
        self.observation_space = spaces.Dict({
            "sequence": spaces.Box(low=0, high=self.vocab_size, shape=(max_length,), dtype=np.int32),
            "x": spaces.Box(low=-np.inf, high=np.inf, shape=(num_points,), dtype=np.float32),
            "y": spaces.Box(low=-np.inf, high=np.inf, shape=(num_points,), dtype=np.float32)
        })

        self.data_gen = DataGenerator(max_depth=4)
        self.current_problem = None
        self.current_sequence = []
        self.open_branches = 0  # slots still to fill; reset() sets it to 1 (the root)

    def reset(self, seed=None, options=None):
        """Sample a new problem and clear the formula. Returns ``(observation, info)``."""
        super().reset(seed=seed)

        # NOTE(review): DataGenerator draws formulas from the global `random`
        # module rather than self.np_random, so `seed` alone does not make the
        # sampled problem reproducible.
        batch = self.data_gen.generate_batch(1, point_count=self.num_points)
        self.current_problem = batch[0]

        self.current_sequence = []
        self.open_branches = 1 # Start expecting a root node

        return self._get_obs(), {}

    def step(self, action_id):
        """Append the token for ``action_id``; score the formula once it is complete.

        Returns ``(obs, reward, terminated, truncated, info)``:
          * tree complete -> terminated, reward from _calculate_reward();
          * open-branch count below zero -> terminated, reward -100.0;
          * max_length reached with branches still open -> truncated, reward -10.0;
          * otherwise reward 0.0.
        """
        token = self.vocab[action_id]
        self.current_sequence.append(token)

        # Prefix bookkeeping: an operator fills one open slot and opens `arity`
        # new ones; a terminal just fills one slot.
        if token in OPERATORS:
            self.open_branches += OPERATORS[token] - 1
        else:
            self.open_branches -= 1

        term = False
        trunc = False
        reward = 0.0

        if self.open_branches == 0:
            term = True
            reward = self._calculate_reward()
        elif self.open_branches < 0:
            # Should not happen if we mask actions, but for safety
            term = True
            reward = -100.0 # Syntax error penalty
        elif len(self.current_sequence) >= self.max_length:
            trunc = True
            reward = -10.0 # Incomplete penalty

        return self._get_obs(), reward, term, trunc, {}

    def _get_obs(self):
        """Current observation: zero-padded token IDs plus float32 x / y samples."""
        seq_ids = [TOKEN_TO_ID[t] for t in self.current_sequence]
        padded_seq = np.zeros(self.max_length, dtype=np.int32)
        padded_seq[:len(seq_ids)] = seq_ids

        return {
            "sequence": padded_seq,
            "x": self.current_problem['x'].astype(np.float32),
            "y": self.current_problem['y'].astype(np.float32)
        }

    def _calculate_reward(self):
        """Score the finished formula on the current problem.

        Returns -RMSE for a valid tree, -1000.0 when the RMSE is NaN or inf,
        and -100.0 for an invalid tree or an evaluation error.
        """
        try:
            tree = ExpressionTree(self.current_sequence)
            if not tree.is_valid:
                return -100.0

            y_pred = tree.evaluate(self.current_problem['x'])

            # Root Mean Squared Error (RMSE)
            mse = np.mean((y_pred - self.current_problem['y'])**2)
            rmse = np.sqrt(mse)

            if np.isnan(rmse) or np.isinf(rmse):
                return -1000.0

            # Maximising reward == minimising RMSE.
            return -rmse

        except Exception:
            return -100.0

if __name__ == "__main__":
    # Manual demo: build the formula "x + x" (prefix: + x x) step by step.
    demo_env = SymbolicEnv()
    observation, _ = demo_env.reset()
    print("Initial Observation Keys:", observation.keys())

    total = 0
    for symbol in ['+', 'x', 'x']:
        observation, step_reward, done, truncated, _ = demo_env.step(TOKEN_TO_ID[symbol])
        print(f"Action: {symbol}, Reward: {step_reward}, Term: {done}, Branches: {demo_env.open_branches}")
        total += step_reward
        if done:
            break

    print(f"Total Reward: {total}")


In [None]:
%%writefile core/loss.py

import torch
import torch.nn as nn

class QuantileLoss(nn.Module):
    """
    Quantile Loss (Pinball Loss) for multiple quantiles.

    The result is the SUM over quantiles of the batch-mean pinball loss.

    Args:
        quantiles (sequence of float): Quantiles to estimate, one per column of
            ``preds`` (default (0.25, 0.5, 0.75)).
    """
    def __init__(self, quantiles=(0.25, 0.5, 0.75)):
        super().__init__()
        # Copy into a fresh list so instances never share a mutable default.
        self.quantiles = list(quantiles)

    def forward(self, preds, target):
        """
        preds: [batch, num_quantiles] - Predicted values for each quantile
        target: [batch] or [batch, 1] - True scalar target

        Returns a 0-dim tensor. Raises ValueError if ``preds`` has fewer columns
        than there are quantiles.
        """
        if target.dim() == 1:
            target = target.unsqueeze(1)

        num_q = len(self.quantiles)
        if preds.size(1) < num_q:
            raise ValueError(
                f"preds has {preds.size(1)} columns but {num_q} quantiles were configured"
            )

        q = torch.tensor(self.quantiles, dtype=preds.dtype, device=preds.device)  # [Q]
        error = target - preds[:, :num_q]  # [batch, Q]; target broadcasts over quantiles
        # Pinball loss: max(q * error, (q - 1) * error) == error * (q - 1[error < 0])
        pinball = torch.max(q * error, (q - 1) * error)
        return pinball.mean(dim=0).sum()


In [None]:
%%writefile core/gp_bridge.py
import os
import subprocess
import tempfile
import re
import time
from typing import List, Optional

class GPEngine:
    """Wrapper that runs the external C++ genetic-programming (GP) binary.

    Data and seed formulas are passed through temporary text files, and the
    best formula is scraped from the binary's stdout.
    """

    def __init__(self, binary_path=None):
        """
        Args:
            binary_path: explicit path to the GP executable. When omitted,
                ``<base>/Code/build/Release/`` and then ``<base>/Code/build/``
                are searched. ``<base>`` is the parent of the directory that
                contains ``core/`` (three ``dirname`` calls up from this file).
                The Windows ``.exe`` name is tried before the extensionless
                name used by Linux/macOS builds.
        """
        if binary_path is not None:
            self.binary_path = binary_path
            return

        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        possible_paths = [
            os.path.join(base_dir, "Code", "build", *subdir, name)
            for name in ("SymbolicRegressionGP.exe", "SymbolicRegressionGP")
            for subdir in (("Release",), ())
        ]
        # If nothing exists, keep the first candidate so the "not found" message is informative.
        self.binary_path = next((p for p in possible_paths if os.path.isfile(p)), possible_paths[0])

    def run(self, x_values: List[float], y_values: List[float], seeds: Optional[List[str]] = None, timeout_sec: int = 10) -> Optional[str]:
        """
        Runs the C++ GP Engine with the given data and seeds.
        Returns the best formula found as a string, or None if failed.

        Args:
            x_values / y_values: sample points, written space-separated on
                line 1 / line 2 of the data file.
            seeds: optional seed formulas (infix strings), one per line.
            timeout_sec: wall-clock limit. On timeout, the best formula printed
                so far (if any) is still returned.
        """
        if not os.path.exists(self.binary_path):
            print(f"[Error] GP Binary not found at: {self.binary_path}")
            return None

        temp_paths = []  # every temp file created here is removed in `finally`
        try:
            seed_file_path = self._write_temp_lines(seeds or [])
            temp_paths.append(seed_file_path)
            data_file_path = self._write_temp_lines([
                " ".join(map(str, x_values)),
                " ".join(map(str, y_values)),
            ])
            temp_paths.append(data_file_path)

            cmd = [self.binary_path, "--seed", seed_file_path, "--data", data_file_path]
            print(f"Running GP Engine: {' '.join(cmd)}")

            start_time = time.time()
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_sec)
            output = result.stdout

            best_formula = self._extract_formula(output)
            print(f"GP Engine finished in {time.time() - start_time:.2f}s")

            if best_formula is None:
                print(f"[DEBUG] GP Engine Output (Stdout):\n{output}")
                print(f"[DEBUG] GP Engine Output (Stderr):\n{result.stderr}")

            return best_formula

        except subprocess.TimeoutExpired as e:
            print(f"GP Engine timed out after {timeout_sec}s.")
            # TimeoutExpired.stdout/stderr are bytes (or None) even when text=True,
            # so they are decoded before scanning/printing.
            best_formula = self._extract_formula(e.stdout)
            if best_formula:
                print(f"Recovered best formula from timeout: {best_formula}")
                return best_formula

            stderr_text = self._as_text(e.stderr)
            if stderr_text:
                print(f"GP Engine Timeout Stderr: {stderr_text}")
            return None

        except Exception as e:
            print(f"GP Engine failed: {e}")
            if hasattr(e, 'stderr') and e.stderr:
                print(f"Stderr: {e.stderr}")
            return None
        finally:
            for path in temp_paths:
                if os.path.exists(path):
                    os.unlink(path)

    @staticmethod
    def _as_text(data) -> str:
        """Return captured process output as str (bytes are decoded; None -> '')."""
        if data is None:
            return ""
        if isinstance(data, bytes):
            return data.decode("utf-8", errors="replace")
        return data

    @staticmethod
    def _extract_formula(output) -> Optional[str]:
        """Scrape the best formula from GP stdout (str, bytes or None).

        The last non-empty "formula:" line (case-insensitive) wins, except that
        the first non-empty "final formula:" line is returned immediately.
        """
        best_formula = None
        for line in GPEngine._as_text(output).splitlines():
            line_lower = line.lower()
            idx = line_lower.find("formula:")
            if idx == -1:
                continue
            formula_part = line[idx + len("formula:"):].strip()
            if formula_part:
                best_formula = formula_part
                if "final formula:" in line_lower:
                    break  # Final Formula is the best, stop looking
        return best_formula

    @staticmethod
    def _write_temp_lines(lines) -> str:
        """Write each item of ``lines`` plus a newline to a new temp file and return its path.

        The caller owns (and must delete) the file. It is removed here only if
        writing fails. The file is closed before returning, so another process
        can open it on Windows.
        """
        handle = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt')
        try:
            with handle:
                for line in lines:
                    handle.write(line + "\n")
        except BaseException:
            os.unlink(handle.name)
            raise
        return handle.name

if __name__ == "__main__":
    # Manual check: recover y = x^2 + 2 from four points, seeded with two partial guesses.
    gp = GPEngine()
    xs = [1, 2, 3, 4]
    ys = [v * v + 2 for v in xs]  # x^2 + 2 -> [3, 6, 11, 18]
    seed_formulas = ["(x * x)", "(x + 2)"]

    print("Testing GPEngine...")
    found = gp.run(xs, ys, seed_formulas)
    print(f"Result: {found}")


In [None]:
%%writefile core/__init__.py


In [None]:
%%writefile data/synthetic_data.py
import numpy as np
import random
from core.grammar import VOCABULARY, OPERATORS, VARIABLES, CONSTANTS, ExpressionTree
from data.augmentation import augment_formula_tokens

class DataGenerator:
    """
    Synthetic symbolic-regression data: random formulas in prefix notation plus
    sampled (x, y) points.

    ``generate_batch`` samples unstructured random expression trees;
    ``generate_inverse_batch`` builds structured formulas (polynomials,
    trig/exp/log wrappers, compositions), maps out-of-vocabulary numbers to
    vocabulary tokens or 'C', and optionally applies ``augment_formula_tokens``.
    """

    def __init__(self, max_depth=5, population_size=1000, allowed_operators=None):
        """
        Args:
            max_depth: depth limit for random trees; structured-formula complexity
                in ``generate_inverse_batch`` is drawn from 1..max(1, max_depth - 1).
            population_size: stored on the instance; not used by this class.
            allowed_operators: optional operator names to restrict generation to.
                Names missing from OPERATORS are dropped; None or empty means all.
        """
        self.max_depth = max_depth
        self.population_size = population_size
        self.vocab = VOCABULARY
        # Pre-computed terminal list (variables followed by constants).
        self.terminals = VARIABLES + CONSTANTS
        if allowed_operators:
            self.operators = [op for op in allowed_operators if op in OPERATORS]
        else:
            self.operators = list(OPERATORS.keys())

    @staticmethod
    def _as_token_list(node):
        """Return ``node`` as a new token list, wrapping a single token string in a list."""
        return list(node) if isinstance(node, list) else [node]

    @staticmethod
    def _random_leaf():
        """Random terminal: 'x' (40%), 'C' (30%) or a numeric constant token (30%)."""
        r = random.random()
        if r < 0.4:
            return ['x']
        if r < 0.7:
            return ['C']
        return [random.choice([c for c in CONSTANTS if c != 'C'])]

    @staticmethod
    def _sample_constants(tree):
        """Map every 'C' position of ``tree`` (as a tuple) to uniform(-20, 20), or to 1.0 about 10% of the time."""
        constant_vals = {}
        for pos in tree.root.get_constant_positions():
            constant_vals[tuple(pos)] = random.uniform(-20, 20) if random.random() > 0.1 else 1.0
        return constant_vals

    def _map_to_vocab(self, tokens):
        """
        Replace tokens that are not in the vocabulary.

        A numeric token within 0.01 of an integer whose string form is in the
        vocabulary becomes that integer token (e.g. '2.0' -> '2'); any other
        out-of-vocabulary token becomes the 'C' placeholder.
        """
        mapped = []
        for t in tokens:
            if t in self.vocab:
                mapped.append(t)
                continue
            try:
                val = float(t)
                rounded = str(int(round(val)))
                is_vocab_int = abs(val - round(val)) < 0.01 and rounded in self.vocab
            except (TypeError, ValueError, OverflowError):
                # Non-numeric, NaN or infinite tokens cannot map to an integer token.
                is_vocab_int = False
            mapped.append(rounded if is_vocab_int else 'C')
        return mapped

    def generate_random_tree(self, max_depth, current_depth=0):
        """
        Sample a random expression tree as a flat prefix token list.

        At ``max_depth`` a terminal is forced ('x' or any CONSTANTS token, 50/50).
        Below it, an operator from ``self.operators`` is chosen with a fixed 70%
        probability at every depth (not depth-dependent); otherwise a terminal
        from ``_random_leaf``. With no usable operators a terminal is always
        returned instead of failing in ``random.choice([])``.
        """
        if current_depth >= max_depth:
            if random.random() < 0.5:
                return ['x']
            return [random.choice(CONSTANTS)]

        if self.operators and random.random() < 0.7:
            op = random.choice(self.operators)
            tokens = [op]
            for _ in range(OPERATORS[op]):
                tokens.extend(self.generate_random_tree(max_depth, current_depth + 1))
            return tokens
        return self._random_leaf()

    def generate_batch(self, batch_size, point_count=10, x_range=(-10, 10), max_attempts=None):
        """
        Generate random formulas together with sampled (x, y) data.

        Each accepted sample has sorted x drawn uniformly from ``x_range``, 'C'
        placeholders set by ``_sample_constants``, and y that is finite with
        max |y| <= 1e6. Formulas without 'x' are rejected 90% of the time, and
        nearly constant outputs (std < 1e-6) are rejected 90% of the time.

        Args:
            batch_size: number of samples wanted.
            point_count: number of x points per sample.
            x_range: (low, high) bounds for x.
            max_attempts: optional cap on candidate formulas tried. None (the
                default) keeps trying until the batch is full, which can loop
                forever for very restrictive operator sets.

        Returns:
            list of dicts with keys 'tokens', 'infix', 'x', 'y'. The list is
            shorter than ``batch_size`` only when ``max_attempts`` was reached.
        """
        data = []
        attempts = 0

        while len(data) < batch_size:
            if max_attempts is not None and attempts >= max_attempts:
                break
            attempts += 1

            tokens = self.generate_random_tree(self.max_depth)
            tree = ExpressionTree(tokens)
            if not tree.is_valid:
                continue

            # Require 'x' in the formula 90% of the time.
            if 'x' not in tokens and random.random() < 0.9:
                continue

            x_values = np.random.uniform(x_range[0], x_range[1], point_count)
            x_values.sort()  # sorted x for cleaner visualization/learning

            y_values = tree.evaluate(x_values, constants=self._sample_constants(tree))

            if not np.all(np.isfinite(y_values)):
                continue
            if np.max(np.abs(y_values)) > 1e6:  # reject huge magnitudes
                continue
            # Keep only ~10% of flat outputs; interesting curves are preferred.
            if np.std(y_values) < 1e-6 and random.random() > 0.1:
                continue

            data.append({
                'tokens': tokens,
                'infix': tree.get_infix(),
                'x': x_values,
                'y': y_values
            })

        return data

    def generate_structured_tree(self, complexity=1, input_node='x'):
        """
        Recursively build a structured, human-like formula as a flat prefix token list.

        The structure kinds considered are limited by ``self.operators``:
        'arithmetic' (+, -, *), 'poly' (a*x + b or a*x^k + b), 'trig' (sin/cos),
        'exp_log' (exp/log) and, with more than 4 operators and complexity > 1,
        'composition'. ``input_node`` is a token string or a token list used as
        the argument of the generated structure.

        Args:
            complexity: recursion budget; <= 0 yields a random terminal.
            input_node: token or token list the structure is applied to.

        Returns:
            A new flat token list on every path.
        """
        if complexity <= 0:
            return self._random_leaf()

        available_structures = []
        if any(op in self.operators for op in ['+', '-', '*']):
            available_structures.append('arithmetic')
        if 'pow' in self.operators:
            available_structures.append('poly')
        if 'sin' in self.operators or 'cos' in self.operators:
            available_structures.append('trig')
        if 'exp' in self.operators or 'log' in self.operators:
            available_structures.append('exp_log')
        if len(self.operators) > 4 and complexity > 1:
            available_structures.append('composition')

        base = self._as_token_list(input_node)
        if not available_structures:
            return base

        choice = random.choice(available_structures)

        if choice == 'poly':
            # a*x + b or a*x^power + b, with a in 1..5 and b in -5..5.
            a = str(random.randint(1, 5))
            b = str(random.randint(-5, 5))
            power = random.choice(['1', '2', '3'])
            if power == '1':
                term = ['*', a] + base
            else:
                term = ['*', a] + ['pow'] + base + [power]
            return ['+'] + term + [b]

        if choice == 'trig':
            ops = [op for op in ['sin', 'cos'] if op in self.operators]
            if not ops:
                return base
            return [random.choice(ops)] + base

        if choice == 'exp_log':
            ops = [op for op in ['exp', 'log'] if op in self.operators]
            if not ops:
                return base
            return [random.choice(ops)] + base

        if choice == 'arithmetic':
            left = self.generate_structured_tree(complexity - 1, input_node)
            right = self.generate_structured_tree(complexity - 1, input_node)
            ops = [op for op in ['+', '-', '*'] if op in self.operators]
            if not ops:
                return base
            return [random.choice(ops)] + left + right

        if choice == 'composition':
            inner = self.generate_structured_tree(complexity - 1, input_node)
            return self.generate_structured_tree(1, inner)

        return base

    def generate_inverse_batch(self, batch_size, point_count=10, x_range=(-5, 5)):
        """
        Generate structured formulas (see ``generate_structured_tree``) with sampled data.

        Out-of-vocabulary numbers are mapped via ``_map_to_vocab``. About 30% of
        formulas are passed through ``augment_formula_tokens``. The x grid is evenly spaced:
        [-2, 2] when 'exp' or 'pow' occur, [0.1, 5] when 'log' or 'sqrt' occur,
        otherwise ``x_range``. Samples are rejected when invalid, missing 'x'
        (90% of the time), longer than 30 tokens, non-finite, max |y| > 1e4, or
        std(y) < 0.01. At most ``batch_size * 5`` candidates are tried.

        Returns:
            list of dicts with keys 'tokens', 'infix', 'x', 'y' (possibly fewer
            than ``batch_size``).
        """
        data = []
        attempts = 0

        while len(data) < batch_size and attempts < batch_size * 5:
            attempts += 1
            complexity = random.randint(1, max(1, self.max_depth - 1))

            try:
                tokens = self.generate_structured_tree(complexity, 'x')
                # Keep small integers that exist in the vocabulary; others become 'C'.
                final_tokens = self._map_to_vocab(tokens)

                if random.random() < 0.3:
                    final_tokens = augment_formula_tokens(final_tokens)

                tree = ExpressionTree(final_tokens)
                if not tree.is_valid:
                    continue
                if 'x' not in final_tokens and random.random() < 0.9:
                    continue
                if len(final_tokens) > 30:
                    continue

                # Exp/pow grow fast and log/sqrt need positive inputs, so narrow x for them.
                if 'exp' in final_tokens or 'pow' in final_tokens:
                    x_safe = np.linspace(-2, 2, point_count)
                elif 'log' in final_tokens or 'sqrt' in final_tokens:
                    x_safe = np.linspace(0.1, 5, point_count)
                else:
                    x_safe = np.linspace(x_range[0], x_range[1], point_count)

                y_values = tree.evaluate(x_safe, constants=self._sample_constants(tree))

                if not np.all(np.isfinite(y_values)):
                    continue
                if np.max(np.abs(y_values)) > 1e4:
                    continue
                if np.std(y_values) < 0.01:  # too flat
                    continue

                data.append({
                    'tokens': final_tokens,
                    'infix': tree.get_infix(),
                    'x': x_safe,
                    'y': y_values
                })
            except Exception:
                # Best-effort generation: any failure just discards this candidate.
                continue

        return data

# Quick test if run directly
if __name__ == "__main__":
    # Smoke test: print a handful of random formulas and their first outputs.
    generator = DataGenerator(max_depth=4)
    separator = "-" * 20
    for sample in generator.generate_batch(5):
        print(f"Formula: {sample['infix']}")
        print(f"Tokens: {sample['tokens']}")
        print(f"Y sample: {sample['y'][:3]}...")
        print(separator)


In [None]:
%%writefile data/benchmark_data.py
import numpy as np

# Standard Benchmark Problems
# Levels: 1 (Easy), 2 (Medium), 3 (Hard)

# Each problem: 'id' (lookup key), 'name' (display label), 'formula_str'
# (human-readable ground truth), 'lambda' (vectorized numpy ground truth),
# 'domain' (x interval), 'points' (grid size) and 'level' (1-3 difficulty).
BENCHMARK_SUITE = [
    # --- Level 1: Polynomials & Basic Arithmetic ---
    {
        'id': 'p1',
        'name': 'Lineal',
        'formula_str': '2.5 * x + 1.0',
        'lambda': lambda x: 2.5 * x + 1.0,
        'domain': (-10, 10),
        'points': 20,
        'level': 1
    },
    {
        'id': 'p2',
        'name': 'Cuadratica Simple',
        'formula_str': 'x * x',
        'lambda': lambda x: x**2,
        'domain': (-5, 5),
        'points': 20,
        'level': 1
    },
    {
        'id': 'p3',
        'name': 'Polinomio Cubico',
        'formula_str': 'x**3 + x**2',
        'lambda': lambda x: x**3 + x**2,
        'domain': (-3, 3),
        'points': 20,
        'level': 1
    },

    # --- Level 2: Trigonometric & Transcendental ---
    {
        'id': 'p4',
        'name': 'Seno Basico',
        'formula_str': 'sin(x)',
        'lambda': lambda x: np.sin(x),
        'domain': (-np.pi, np.pi),
        'points': 30,
        'level': 2
    },
    {
        'id': 'p5',
        'name': 'Coseno Desplazado',
        'formula_str': 'cos(x) + 1',
        'lambda': lambda x: np.cos(x) + 1,
        'domain': (-np.pi, np.pi),
        'points': 30,
        'level': 2
    },
    {
        'id': 'p6',
        'name': 'Exponencial Simple',
        'formula_str': 'exp(x)',
        'lambda': lambda x: np.exp(x),
        'domain': (-2, 2),  # small domain to avoid explosion
        'points': 20,
        'level': 2
    },

    # --- Level 3: Physics / Complex ---
    {
        'id': 'p7',
        'name': 'Damped Oscillation',
        'formula_str': 'exp(-x) * sin(2*x)',
        'lambda': lambda x: np.exp(-x) * np.sin(2*x),
        'domain': (0, 4),
        'points': 40,
        'level': 3
    },
    {
        'id': 'p8',
        'name': 'Gaussian',
        'formula_str': 'exp(-x**2)',
        'lambda': lambda x: np.exp(-x**2),
        'domain': (-3, 3),
        'points': 30,
        'level': 3
    },
    {
        'id': 'p9',
        # x^3 + x^2 + x is Nguyen-1 in the standard Nguyen suite
        # (Nguyen-3 is x^5 + x^4 + x^3 + x^2 + x).
        'name': 'Nguyen-1 (x^3 + x^2 + x)',
        'formula_str': 'x**3 + x**2 + x',
        'lambda': lambda x: x**3 + x**2 + x,
        'domain': (-2, 2),
        'points': 20,
        'level': 3
    },
    {
        'id': 'p10',
        'name': 'Rational Function',
        'formula_str': 'x / (1 + x**2)',
        'lambda': lambda x: x / (1 + x**2),
        'domain': (-4, 4),
        'points': 30,
        'level': 3
    }
]

def get_benchmark_data(problem_id, suite=None):
    """
    Look up a benchmark problem by id and sample it on an evenly spaced grid.

    Args:
        problem_id: value of the problem's 'id' field (e.g. 'p1').
        suite: optional list of problem dicts to search; defaults to
            BENCHMARK_SUITE.

    Returns:
        (x, y, problem): x is np.linspace over problem['domain'] with
        problem['points'] samples, y = problem['lambda'](x), and problem is the
        matching dict itself. Returns (None, None, None) when no problem has
        that id.
    """
    problems = BENCHMARK_SUITE if suite is None else suite
    problem = next((p for p in problems if p['id'] == problem_id), None)
    if problem is None:
        return None, None, None
    x = np.linspace(problem['domain'][0], problem['domain'][1], problem['points'])
    return x, problem['lambda'](x), problem


In [None]:
%%writefile data/augmentation.py

import random
from core.grammar import OPERATORS

def augment_formula_tokens(tokens, operators=None):
    """
    Return a randomly transformed but mathematically equivalent prefix formula.

    Acts as data augmentation for symbolic regression. Transformations are
    applied bottom-up (children before parents):
    1. Commutativity: every binary '+' or '*' swaps its operands with
       probability 0.5, e.g. [+ a b] -> [+ b a].
    2. Doubling: a '+' whose two operands are token-identical becomes
       [* a 2] with probability 0.3, e.g. [+ x x] -> [* x 2].

    Args:
        tokens: list of tokens in prefix notation. It is not modified.
        operators: optional mapping operator -> arity used to parse ``tokens``;
            defaults to the grammar's OPERATORS. Every other token is a terminal.

    Returns:
        A new token list. Empty input gives []. Input that is not exactly one
        well-formed prefix expression (missing operands or trailing tokens) is
        returned unchanged as a copy, because augmenting a partial parse would
        silently change the formula.
    """
    if not tokens:
        return []

    arity_of = OPERATORS if operators is None else operators

    def parse_prefix(pos):
        """Parse the subtree starting at tokens[pos]; return (node, next_pos), node None if tokens run out."""
        if pos >= len(tokens):
            return None, pos
        value = tokens[pos]
        pos += 1
        children = []
        for _ in range(arity_of.get(value, 0)):
            child, pos = parse_prefix(pos)
            if child is None:
                return None, pos
            children.append(child)
        return {'val': value, 'children': children}, pos

    def flatten(node):
        """Serialize a parsed node back into prefix tokens."""
        res = [node['val']]
        for child in node['children']:
            res.extend(flatten(child))
        return res

    def augment_recursive(node):
        """Augment children first, then maybe rewrite this node."""
        for i in range(len(node['children'])):
            node['children'][i] = augment_recursive(node['children'][i])

        val = node['val']
        children = node['children']

        # Commutativity of + and *.
        if val in ['+', '*'] and len(children) == 2:
            if random.random() < 0.5:
                node['children'] = [children[1], children[0]]

        # (+ a a) -> (* a 2) when both operands serialize identically.
        if val == '+' and len(children) == 2:
            if flatten(children[0]) == flatten(children[1]):
                if random.random() < 0.3:
                    return {'val': '*', 'children': [children[0], {'val': '2', 'children': []}]}

        return node

    try:
        tree, consumed = parse_prefix(0)
    except (RecursionError, TypeError):
        return list(tokens)  # too deep or unhashable tokens: fail safe

    if tree is None or consumed != len(tokens):
        return list(tokens)

    return flatten(augment_recursive(tree))

if __name__ == "__main__":
    # Manual check: print each sample formula next to one random augmentation.
    samples = [
        ['+', 'x', 'y'],            # (+ x y)
        ['*', '+', 'a', 'b', 'c'],  # (* (+ a b) c)
        ['+', 'x', 'x'],            # (+ x x)
    ]
    for original in samples:
        print(f"Original: {original} -> Aug: {augment_formula_tokens(original)}")


In [None]:
%%writefile data/__init__.py


In [None]:
%%writefile search/mcts.py
import math
import numpy as np
import torch
import copy
from core.grammar import VOCABULARY, TOKEN_TO_ID, OPERATORS, ExpressionTree, VARIABLES
from utils.optimize_constants import optimize_constants

class MCTSNode:
    """
    One node of the formula search tree: a partial prefix token sequence plus
    its visit statistics, PUCT prior and virtual-loss bookkeeping.
    """

    def __init__(self, tokens, parent=None, prior=0.0):
        """
        Args:
            tokens: token sequence leading to this node.
            parent: parent node, or None for the root.
            prior: policy probability of reaching this node from its parent.
        """
        self.tokens = tokens
        self.parent = parent
        self.prior = prior
        self.children = {}
        self.is_expanded = False

        # Real statistics from backpropagation.
        self.visit_count = 0
        self.value_sum = 0.0

        # Temporary penalty used while several leaves are selected in one batch.
        self.virtual_loss = 0.0
        self.virtual_visits = 0

    @property
    def value(self):
        """Mean value including virtual visits, with virtual loss subtracted; 0.0 if never visited."""
        total_visits = self.visit_count + self.virtual_visits
        if not total_visits:
            return 0.0
        return (self.value_sum - self.virtual_loss) / total_visits

    def ucb_score(self, c_puct=1.0):
        """PUCT score: value + c_puct * prior * sqrt(parent visits) / (1 + visits); 0.0 for the root."""
        if self.parent is None:
            return 0.0
        own_visits = self.visit_count + self.virtual_visits
        parent_visits = self.parent.visit_count + self.parent.virtual_visits
        exploration = c_puct * self.prior * math.sqrt(parent_visits) / (1 + own_visits)
        return self.value + exploration

    @property
    def complexity(self):
        """Formula complexity measured as the number of tokens."""
        return len(self.tokens)

class MCTS:
    """
    Batched Monte-Carlo Tree Search over prefix-notation formulas.

    Nodes are partial token sequences. A policy/value network supplies child
    priors and leaf values; complete formulas get a reward from a
    constant-optimized fit (accuracy) plus a length term (parsimony). Several
    leaves are selected per network call, using virtual loss to spread the
    selections across the tree.
    """

    def __init__(self, model, device, grammar=None, c_puct=1.0, n_simulations=100, max_simulations=None, max_depth=50, complexity_lambda=0.1, max_len=200, batch_size=8):
        """
        Args:
            model: network called as ``model(x_batch, y_batch, token_ids)`` that
                returns ``(logits, value_preds)``.
            device: torch device for inference tensors.
            grammar: stored on the instance; not used by this class.
            c_puct: PUCT exploration constant.
            n_simulations: default simulation budget for ``search``.
            max_simulations: backwards-compatible alias that overrides
                ``n_simulations`` when not None.
            max_depth: maximum selection depth; formulas longer than
                ``2 * max_depth`` tokens are not treated as complete.
            complexity_lambda: weight of the parsimony term of the reward.
            max_len: length scale L of the parsimony term ``exp(-len / L)``.
            batch_size: number of leaf selections per network call.
        """
        self.model = model
        self.device = device
        self.grammar = grammar
        self.c_puct = c_puct
        self.n_simulations = max_simulations if max_simulations is not None else n_simulations
        self.max_depth = max_depth
        self.complexity_lambda = complexity_lambda
        self.max_len = max_len
        self.min_value = -float('inf')
        self.max_value = float('inf')
        self.vocab_size = len(VOCABULARY)
        self.sos_id = self.vocab_size  # SOS id is the first id past the vocabulary
        self.batch_size = batch_size

        # Pareto front of complete formulas: dicts with 'tokens', 'rmse', 'complexity', 'formula'.
        self.pareto_front = []

        # Virtual loss added to each node on a selection path (typically 1-3).
        self.v_loss_const = 3.0

    def search(self, x_values, y_values, num_simulations=None):
        """
        Run batched MCTS and return the best formula found.

        Each iteration selects up to ``batch_size`` leaves (PUCT + virtual loss),
        expands/evaluates them with one network call, backpropagates the values,
        and then removes the virtual loss of every descent made in that
        iteration. Complete formulas are scored with ``_score_complete_formula``
        instead of the network value.

        Args:
            x_values, y_values: the data points to fit.
            num_simulations: optional override of the simulation budget.

        Returns:
            dict with 'tokens', 'formula', 'rmse' (best finite RMSE; None, None,
            inf when no complete formula was scored), 'root' and 'pareto_front'.
        """
        self.pareto_front = []  # reset for this search
        root = MCTSNode(tokens=[])
        self._expand_batch([root], x_values, y_values)

        best_rmse = float('inf')
        best_formula = None
        best_tokens = None

        limit = num_simulations if num_simulations is not None else self.n_simulations
        num_batches = max(1, (limit + self.batch_size - 1) // self.batch_size)

        for _ in range(num_batches):
            leaves, paths = self._select_leaves(root)

            if leaves:
                values = self._expand_batch(leaves, x_values, y_values)
                for node, value in zip(leaves, values):
                    if self._is_complete_tree(node.tokens):
                        try:
                            value, rmse, infix = self._score_complete_formula(node.tokens, x_values, y_values)
                            # NaN/inf errors compare False against everything and would
                            # pollute best/Pareto tracking, so only finite RMSEs are recorded.
                            if np.isfinite(rmse):
                                if rmse < best_rmse:
                                    best_rmse = rmse
                                    best_tokens = node.tokens
                                    best_formula = infix
                                self._update_pareto_front(node.tokens, rmse, len(node.tokens), infix)
                        except Exception:
                            value = 0.0  # invalid formula gets zero reward
                    self._backpropagate(node, value)

            # Undo the virtual loss of every descent of this batch, including duplicate
            # selections and descents that stopped on an already expanded node.
            for path in paths:
                self._revert_virtual_loss(path)

            if not leaves and root.visit_count > limit:
                break

        return {
            'tokens': best_tokens,
            'formula': best_formula,
            'rmse': best_rmse,
            'root': root,
            'pareto_front': self.pareto_front
        }

    def _select_leaves(self, root):
        """
        Perform ``batch_size`` PUCT descents from ``root``, applying virtual loss.

        Returns:
            (leaves, paths): leaves are the distinct unexpanded nodes reached
            within ``max_depth``; paths holds, for every descent, the nodes that
            received virtual loss, so each can be reverted exactly once.
        """
        leaves = []
        paths = []
        for _ in range(self.batch_size):
            node = root
            path = []
            while node.is_expanded and node.children and len(path) < self.max_depth:
                node = max(node.children.values(), key=lambda n: n.ucb_score(self.c_puct))
                # Discourage re-selecting the same branch within this batch.
                node.virtual_loss += self.v_loss_const
                node.virtual_visits += 1
                path.append(node)
            paths.append(path)
            if len(path) < self.max_depth and not node.is_expanded and node not in leaves:
                leaves.append(node)
        return leaves, paths

    def _revert_virtual_loss(self, path):
        """Remove one unit of virtual loss from every node of a selection path."""
        for node in path:
            node.virtual_loss -= self.v_loss_const
            node.virtual_visits -= 1

    def _backpropagate(self, node, value):
        """Add one visit and ``value`` to ``node`` and all of its ancestors."""
        curr = node
        while curr is not None:
            curr.visit_count += 1
            curr.value_sum += value
            curr = curr.parent

    def _score_complete_formula(self, tokens, x_values, y_values):
        """
        Fit the constants of a complete formula and compute its MCTS reward.

        reward = (1 / (1 + NMSE) + lambda * exp(-len / L)) / (1 + lambda), where
        NMSE = MSE / Var(y) (Var(y) replaced by 1 when below 1e-9; NMSE replaced
        by 1e9 when NaN/inf), lambda = complexity_lambda and L = max_len. The
        reward is 0.0 when the prediction's shape differs from y's shape.

        Returns:
            (reward, rmse, infix) where rmse is the value returned by
            ``optimize_constants``.

        Raises:
            Whatever ExpressionTree / optimize_constants / evaluate raise for an
            invalid formula; ``search`` treats that as reward 0.
        """
        tree = ExpressionTree(tokens)
        optimized_constants, rmse = optimize_constants(tree, x_values, y_values)
        y_pred = tree.evaluate(x_values, constants=optimized_constants)
        # asarray so list inputs don't fail on `.shape` (which would zero every reward).
        y_true = np.asarray(y_values)

        if np.shape(y_pred) != y_true.shape:
            reward = 0.0
        else:
            mse = np.mean((y_pred - y_true) ** 2)
            var_y = np.var(y_true)
            if var_y < 1e-9:
                var_y = 1.0  # avoid division by zero for constant targets
            nmse = mse / var_y
            if np.isnan(nmse) or np.isinf(nmse):
                nmse = 1e9
            r_acc = 1.0 / (1.0 + nmse)
            r_cplx = self.complexity_lambda * np.exp(-len(tokens) / self.max_len)
            # The theoretical maximum is 1 + lambda; normalize into [0, 1].
            reward = (r_acc + r_cplx) / (1.0 + self.complexity_lambda)

        infix = ExpressionTree(tokens).get_infix()
        return reward, rmse, infix

    @staticmethod
    def _dominates(a, b):
        """True when solution ``a`` is no worse than ``b`` in rmse and complexity and strictly better in one."""
        return (a['rmse'] <= b['rmse'] and a['complexity'] <= b['complexity'] and
                (a['rmse'] < b['rmse'] or a['complexity'] < b['complexity']))

    def _update_pareto_front(self, tokens, rmse, complexity, formula_str):
        """
        Insert a solution into the (rmse, complexity) Pareto front.

        The candidate is dropped if an existing solution dominates it; otherwise
        every existing solution it dominates is removed, it is added, and the
        front is re-sorted by ascending RMSE.
        """
        candidate = {'tokens': tokens, 'rmse': rmse, 'complexity': complexity, 'formula': formula_str}

        if any(self._dominates(existing, candidate) for existing in self.pareto_front):
            return

        self.pareto_front = [e for e in self.pareto_front if not self._dominates(candidate, e)]
        self.pareto_front.append(candidate)
        self.pareto_front.sort(key=lambda s: s['rmse'])

    @staticmethod
    def _gather_last_logits(logits, lengths):
        """
        Return, for each row i, the logits at the position of its last real input token.

        Input sequences are right-padded to max(lengths), so ``logits[:, -1]``
        would read the prediction after the padding for every row but the
        longest. If ``logits`` has more positions than the padded input, the
        extra positions are assumed to be prepended (NOTE(review): confirm
        against the model) and are skipped.

        Args:
            logits: tensor (batch, positions, vocab_or_more).
            lengths: unpadded input length of each row (including SOS).

        Returns:
            tensor (batch, vocab_or_more).
        """
        offset = logits.shape[1] - max(lengths)
        last_pos = torch.tensor([offset + n - 1 for n in lengths], dtype=torch.long, device=logits.device)
        rows = torch.arange(logits.shape[0], device=logits.device)
        return logits[rows, last_pos]

    def _expand_batch(self, nodes, x_values, y_values):
        """
        Evaluate ``nodes`` with one network call, attach their children, and return their values.

        For each node the value is column 1 of ``value_preds`` clipped to
        [0, 1] (assumed to be the median of pessimistic/median/optimistic
        estimates — TODO confirm against the model). One child is created per
        grammar-valid next token, with the policy probability as prior, and the
        node is marked expanded.

        Returns:
            list of float values, one per node ([] for no nodes).
        """
        if not nodes:
            return []

        batch_size = len(nodes)
        x_tensor = torch.tensor(x_values, dtype=torch.float32).unsqueeze(0).to(self.device)
        y_tensor = torch.tensor(y_values, dtype=torch.float32).unsqueeze(0).to(self.device)
        x_batch = x_tensor.repeat(batch_size, 1, 1).squeeze(1)  # [batch, points]
        y_batch = y_tensor.repeat(batch_size, 1, 1).squeeze(1)  # [batch, points]

        seqs = [[self.sos_id] + [TOKEN_TO_ID[t] for t in n.tokens] for n in nodes]
        lengths = [len(s) for s in seqs]
        # Right-pad with SOS; each row's prediction is read at its own last position below.
        input_tensor = torch.full((batch_size, max(lengths)), self.sos_id, dtype=torch.long).to(self.device)
        for i, s in enumerate(seqs):
            input_tensor[i, :len(s)] = torch.tensor(s, dtype=torch.long)

        with torch.no_grad():
            logits, value_preds = self.model(x_batch, y_batch, input_tensor)

        next_logits = self._gather_last_logits(logits, lengths)[:, :self.vocab_size]
        probs_batch = torch.softmax(next_logits, dim=-1).cpu().numpy()
        value_preds = value_preds.cpu().numpy()

        values = []
        for i, node in enumerate(nodes):
            values.append(float(np.clip(value_preds[i, 1], 0.0, 1.0)))

            node_probs = probs_batch[i]
            for idx in self._get_valid_next_tokens(node.tokens):
                token = VOCABULARY[idx]
                node.children[token] = MCTSNode(tokens=node.tokens + [token], parent=node, prior=node_probs[idx])

            node.is_expanded = True

        return values

    def _get_valid_next_tokens(self, tokens):
        """Return all vocabulary ids while the prefix expression still has open argument slots, else []."""
        open_slots = 1
        for t in tokens:
            if t in OPERATORS:
                open_slots += OPERATORS[t] - 1
            else:
                open_slots -= 1

        if open_slots <= 0:
            return []
        return list(range(self.vocab_size))

    def _is_complete_tree(self, tokens):
        """True when ``tokens`` form a valid ExpressionTree no longer than ``2 * max_depth`` tokens."""
        if not tokens:
            return False
        try:
            tree = ExpressionTree(tokens)
            if len(tokens) > self.max_depth * 2:
                return False
            return tree.is_valid
        except Exception:
            return False

    def _evaluate_formula(self, tokens, x, y):
        """RMSE of ``tokens`` after constant optimization, or 1e9 if evaluation fails."""
        try:
            tree = ExpressionTree(tokens)
            _, rmse = optimize_constants(tree, x, y)
            return rmse
        except Exception:
            return 1e9

    def get_training_examples(self, root):
        """
        Extract training examples from a searched tree.

        Walks the tree breadth-first from ``root``, skipping never-visited nodes.
        Every visited node with visited children yields
        ``(state_tokens, policy_probs, value_target)``: state_tokens is the
        node's token-string list, policy_probs is its children's visit-count
        distribution indexed by vocabulary id, and value_target is the node's
        mean backed-up value (value_sum / visit_count).
        """
        from collections import deque

        examples = []
        queue = deque([root])

        while queue:
            node = queue.popleft()
            if node.visit_count < 1:
                continue

            # Children are keyed by token string; map them to vocabulary ids.
            counts = {}
            total_visits = 0
            for token, child in node.children.items():
                if token in TOKEN_TO_ID:
                    counts[TOKEN_TO_ID[token]] = child.visit_count
                    total_visits += child.visit_count
                    queue.append(child)

            if not counts or total_visits == 0:
                continue

            probs = np.zeros(self.vocab_size, dtype=np.float32)
            for tid, count in counts.items():
                probs[tid] = count / total_visits

            value_target = node.value_sum / node.visit_count
            examples.append((node.tokens, probs, value_target))

        return examples


In [None]:
%%writefile search/beam_search.py
"""
Beam Search for AlphaSymbolic.
Explores multiple formula candidates in parallel, keeping top-K at each step.
"""
import torch
import numpy as np
from core.grammar import VOCABULARY, OPERATORS, TOKEN_TO_ID, ExpressionTree, OPERATOR_STAGES
from utils.optimize_constants import optimize_constants

class BeamSearch:
    """
    Beam search over prefix-notation token sequences guided by the policy model.

    At each step every active beam is extended with its top-k next tokens, the
    candidates are pruned globally to ``beam_width`` by cumulative log-probability,
    and beams whose open-slot count reaches zero are moved to the completed set.
    Completed formulas are finally scored by fitting their constants to the data.
    """

    def __init__(self, model, device, beam_width=10, max_length=30, curriculum_stage=None,
                 repetition_penalty=1.5):
        """
        Args:
            model: Callable ``model(x, y, seq) -> (logits, _)``; logits are indexed
                below as ``[batch, seq_len, vocab]``.
            device: torch device on which input tensors are created.
            beam_width: Beams kept after each global prune; also the per-beam
                top-k (capped at the vocabulary size).
            max_length: Maximum number of generated tokens.
            curriculum_stage: Optional key into OPERATOR_STAGES. When given, only
                that stage's operators (plus terminals) may be generated.
            repetition_penalty: Factor applied to the log-probability of a beam's
                previous token. Log-probs are <= 0, so values > 1 discourage
                immediate repeats; 1.0 disables the penalty.
        """
        self.model = model
        self.device = device
        self.beam_width = beam_width
        self.max_length = max_length
        self.repetition_penalty = repetition_penalty
        self.vocab_size = len(VOCABULARY)
        self.sos_id = self.vocab_size  # SOS uses the id right after the vocabulary

        # Additive logit mask: 0 for allowed tokens, -inf for disallowed ones.
        self.token_mask = None
        if curriculum_stage is not None:
            allowed_ops = OPERATOR_STAGES.get(curriculum_stage, list(OPERATORS.keys()))
            allowed_tokens = set(['x', 'C', '0', '1', '2', '3', '5', '10', 'pi', 'e'])
            allowed_tokens.update(allowed_ops)

            mask = torch.full((self.vocab_size,), float('-inf'), device=device)
            for token in allowed_tokens:
                if token in TOKEN_TO_ID:
                    mask[TOKEN_TO_ID[token]] = 0.0
            self.token_mask = mask

    @staticmethod
    def _rmse_sort_key(result):
        """Sort key placing NaN/inf RMSE last (a raw NaN would scramble the sort)."""
        rmse = result['rmse']
        return rmse if np.isfinite(rmse) else float('inf')

    def search(self, x_values, y_values, return_partial=False):
        """
        Beam Search to find the best formula structure.

        Args:
            x_values, y_values: Data points (sequences of numbers of equal length).
            return_partial: If True and no complete formula was found, return the
                most probable unfinished beam as a placeholder result.

        Returns:
            List of dicts with keys 'tokens', 'log_prob', 'rmse', 'constants' and
            'formula', sorted by RMSE ascending (non-finite RMSE last).
        """
        x_tensor = torch.tensor(x_values, dtype=torch.float32).unsqueeze(0).to(self.device)  # [1, points]
        y_tensor = torch.tensor(y_values, dtype=torch.float32).unsqueeze(0).to(self.device)  # [1, points]

        # 'open' = operand slots still to be filled; a beam is complete at 0.
        beams = [{'seq': [], 'log_prob': 0.0, 'open': 1}]
        completed = []

        for _ in range(self.max_length):
            active_beams = [b for b in beams if b['open'] > 0]
            if not active_beams:
                break

            batch_size = len(active_beams)
            x_batch = x_tensor.expand(batch_size, -1)
            y_batch = y_tensor.expand(batch_size, -1)

            # Decoder input: SOS followed by the tokens generated so far.
            seqs = [[self.sos_id] + [TOKEN_TO_ID[t] for t in b['seq']] for b in active_beams]
            input_tensor = torch.tensor(seqs, dtype=torch.long).to(self.device)

            with torch.no_grad():
                logits, _ = self.model(x_batch, y_batch, input_tensor)
                # Next-token logits; columns past the vocabulary (SOS) are dropped.
                last_token_logits = logits[:, -1, :self.vocab_size]
                if self.token_mask is not None:
                    last_token_logits = last_token_logits + self.token_mask
                log_probs = torch.log_softmax(last_token_logits, dim=-1)  # [batch, vocab]

            # Repetition penalty: scaling a (negative) log-prob by a factor > 1
            # lowers the chance of repeating the previous token (curbs "////"
            # loops). Masked (-inf) entries are left alone.
            if self.repetition_penalty != 1.0:
                for i, beam in enumerate(active_beams):
                    if not beam['seq']:
                        continue
                    last_id = TOKEN_TO_ID.get(beam['seq'][-1])
                    if last_id is not None and log_probs[i, last_id] > -1e9:
                        log_probs[i, last_id] *= self.repetition_penalty

            # Expand each beam with its own top-k, then prune globally.
            k_per_beam = min(self.beam_width, self.vocab_size)
            topk_scores, topk_indices = torch.topk(log_probs, k_per_beam, dim=-1)
            topk_scores = topk_scores.cpu().numpy()
            topk_indices = topk_indices.cpu().numpy()

            all_candidates = []
            for i, beam in enumerate(active_beams):
                for score, idx in zip(topk_scores[i], topk_indices[i]):
                    # If k exceeds the number of allowed tokens, topk also returns
                    # masked tokens scored -inf; they must never enter a beam.
                    if not np.isfinite(score):
                        continue
                    token = VOCABULARY[idx]
                    new_open = beam['open'] + OPERATORS.get(token, 0) - 1
                    if new_open < 0:
                        continue
                    all_candidates.append({
                        'seq': beam['seq'] + [token],
                        'log_prob': beam['log_prob'] + score,
                        'open': new_open,
                    })

            all_candidates.sort(key=lambda c: c['log_prob'], reverse=True)
            beams = []
            for cand in all_candidates[:self.beam_width]:
                if cand['open'] == 0:
                    completed.append(cand)
                else:
                    beams.append(cand)

        # Score every syntactically valid completed formula.
        scored_results = []
        for beam in completed:
            tree = ExpressionTree(beam['seq'])
            if not tree.is_valid:
                continue
            constants, rmse = optimize_constants(tree, x_values, y_values)
            scored_results.append({
                'tokens': beam['seq'],
                'log_prob': beam['log_prob'],
                'rmse': rmse,
                'constants': constants,
                'formula': tree.get_infix()
            })

        scored_results.sort(key=self._rmse_sort_key)

        # Optionally surface the most probable unfinished beam (beams stay sorted
        # by log_prob because they were taken from the sorted candidate list).
        if not scored_results and return_partial and beams:
            best_beam = beams[0]
            scored_results.append({
                'tokens': best_beam['seq'],
                'log_prob': best_beam['log_prob'],
                'rmse': float('inf'),
                'constants': {},
                'formula': "Partial: " + " ".join(best_beam['seq']) + "..."
            })

        return scored_results


def beam_solve(target_x, target_y, model, device, beam_width=20, max_length=25):
    """
    Run one BeamSearch over (target_x, target_y).

    Returns:
        The complete list of scored results (kept whole for Pareto analysis),
        or None when the search produced no valid formula.
    """
    found = BeamSearch(model, device, beam_width=beam_width, max_length=max_length).search(target_x, target_y)
    return found if found else None


if __name__ == "__main__":
    # Smoke test: beam search on y = 2x + 3 with saved (or random) weights.
    from core.model import AlphaSymbolicModel

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    VOCAB_SIZE = len(VOCABULARY)

    # +1 output slot for the SOS token id (== VOCAB_SIZE).
    model = AlphaSymbolicModel(vocab_size=VOCAB_SIZE + 1, d_model=64).to(DEVICE)
    try:
        model.load_state_dict(torch.load("alpha_symbolic_model.pth", map_location=DEVICE, weights_only=True))
    except Exception as exc:  # missing/incompatible checkpoint: best-effort fallback
        print(f"Model not found, using random weights ({exc})")
    model.eval()

    x_test = np.linspace(-5, 5, 20).astype(np.float64)
    y_test = 2 * x_test + 3

    print("Running Beam Search...")
    # beam_solve returns None (not an empty list) when nothing valid is found.
    results = beam_solve(x_test, y_test, model, DEVICE, beam_width=10) or []

    print(f"\nFound {len(results)} valid formulas:")
    for i, r in enumerate(results[:5]):
        print(f"  {i+1}. {r['formula']} (RMSE: {r['rmse']:.4f})")


In [None]:
%%writefile search/hybrid_search.py
import time
import torch
import numpy as np
from typing import List, Dict, Any, Optional

from core.gp_bridge import GPEngine
from search.beam_search import BeamSearch, beam_solve

def _unique_seeds(neural_results: List[Dict[str, Any]]) -> List[str]:
    """Distinct infix formulas from beam-search results, in rank order, without 'Partial' placeholders."""
    formulas = (res['formula'] for res in neural_results)
    return list(dict.fromkeys(f for f in formulas if not f.startswith("Partial")))


def hybrid_solve(
    x_values: np.ndarray,
    y_values: np.ndarray,
    model: torch.nn.Module,
    device: torch.device,
    beam_width: int = 50,
    gp_timeout: int = 10,
    gp_binary_path: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Solves Symbolic Regression using a Hybrid Neuro-Evolutionary approach.

    Phase 1: Neural Beam Search (The Brain)
             - Rapidly scans the search space.
             - Generates diverse, high-likelihood formula skeletons.

    Phase 2: Genetic Programming Refinement (The Muscle)
             - Takes the unique skeletons from Phase 1 as seeds (none if
               Phase 1 found nothing, i.e. pure GP).
             - Runs the GP engine for `gp_timeout` seconds.

    Args:
        x_values, y_values: Data points (arrays or sequences).
        model, device: Policy model and its device, passed to beam search.
        beam_width: Beam width of Phase 1.
        gp_timeout: Seconds given to the GP engine.
        gp_binary_path: Optional path to the GP engine binary.

    Returns:
        Dict with 'formula' (GP infix string), 'rmse' (always 0.0 placeholder:
        GP does not report it in structured form, callers re-evaluate),
        'source' and 'time' (seconds); or None if GP returned no result.
    """
    print("--- Starting Alpha-GP Hybrid Search ---")
    start_time = time.time()

    # Phase 1: neural beam search produces the GP seeds.
    print(f"[Phase 1] Neural Beam Search (Width={beam_width})...")
    neural_results = beam_solve(x_values, y_values, model, device, beam_width=beam_width)

    if neural_results:
        print(f"[Phase 1] Found {len(neural_results)} candidates.")
        seeds = _unique_seeds(neural_results)
        print(f"[Phase 1] Generated {len(seeds)} unique seeds for GP.")
        if seeds:
            print(f"Top Seed: {seeds[0]}")
    else:
        print("[Phase 1] No valid candidates found (Beam Search failed).")
        print("[Phase 1] Falling back to pure GP (Random Initialization).")
        seeds = []

    # Phase 2: GP refinement. The engine takes plain Python lists.
    print(f"[Phase 2] GPU Genetic Refinement (Timeout={gp_timeout}s)...")
    gp_engine = GPEngine(binary_path=gp_binary_path)
    x_list = x_values.tolist() if hasattr(x_values, 'tolist') else list(x_values)
    y_list = y_values.tolist() if hasattr(y_values, 'tolist') else list(y_values)
    gp_result_str = gp_engine.run(x_list, y_list, seeds, timeout_sec=gp_timeout)

    total_time = time.time() - start_time

    if not gp_result_str:
        print("--- Hybrid Search Failed (GP did not return valid result) ---")
        return None

    print(f"--- Hybrid Search Completed in {total_time:.2f}s ---")
    print(f"Best Formula: {gp_result_str}")
    return {
        'formula': gp_result_str,
        'rmse': 0.0,  # Placeholder; evaluated downstream (e.g. by the UI)
        'source': 'Alpha-GP Hybrid',
        'time': total_time
    }

if __name__ == "__main__":
    # Smoke test with a random-logit mock model (no trained weights needed).
    from core.grammar import VOCABULARY

    class MockModel(torch.nn.Module):
        """Returns random logits shaped like the real model's output."""
        def forward(self, x, y, seq):
            bs, seq_len = seq.shape
            # One logit per vocabulary token plus the SOS id, matching the real
            # model (vocab_size=len(VOCABULARY) + 1). A narrower output would
            # leave the trailing vocabulary tokens unreachable for beam search.
            vocab = len(VOCABULARY) + 1
            return torch.randn(bs, seq_len, vocab), None

    print("Testing Hybrid Search...")
    x = np.linspace(-5, 5, 10)
    y = x**2
    try:
        res = hybrid_solve(x, y, MockModel(), torch.device("cpu"), beam_width=5)
        print(res)
    except Exception as e:
        print(f"Test failed: {e}")


In [None]:
%%writefile search/pareto.py
"""
Pareto Front Manager for AlphaSymbolic.
Maintains a set of non-dominated solutions (accuracy vs complexity).
"""
import numpy as np
from core.grammar import ExpressionTree

class ParetoSolution:
    """One candidate formula scored on two minimized objectives: RMSE and complexity."""

    def __init__(self, tokens, rmse, complexity, formula_str, constants=None):
        self.tokens = tokens
        self.rmse = rmse  # Lower is better
        self.complexity = complexity  # Lower is better (number of nodes)
        self.formula = formula_str
        self.constants = constants or {}

    @staticmethod
    def _comparable_rmse(value):
        """Map NaN to +inf so an unevaluable formula ranks as the worst possible fit."""
        return float('inf') if np.isnan(value) else value

    def dominates(self, other):
        """
        Returns True if self Pareto-dominates other.

        Self dominates other when it is at least as good in both objectives and
        strictly better in at least one. NaN RMSE counts as +inf; otherwise a NaN
        solution (every comparison False) could never be dominated.
        """
        mine = self._comparable_rmse(self.rmse)
        theirs = self._comparable_rmse(other.rmse)
        at_least_as_good = (mine <= theirs) and (self.complexity <= other.complexity)
        strictly_better = (mine < theirs) or (self.complexity < other.complexity)
        return at_least_as_good and strictly_better

    def __repr__(self):
        return f"ParetoSolution(rmse={self.rmse:.4f}, complexity={self.complexity}, formula='{self.formula}')"


class ParetoFront:
    """Bounded set of mutually non-dominated solutions (RMSE vs. complexity)."""

    def __init__(self, max_size=50):
        self.solutions = []
        self.max_size = max_size

    @staticmethod
    def _rmse_or_inf(solution):
        """RMSE used for ranking; NaN becomes +inf so it can never win a min()/sort."""
        rmse = solution.rmse
        return float('inf') if np.isnan(rmse) else rmse

    def add(self, solution):
        """
        Attempts to add a solution to the Pareto front.
        Returns True if added, False if dominated.
        """
        if any(existing.dominates(solution) for existing in self.solutions):
            return False

        # Drop whatever the newcomer dominates, then insert it.
        self.solutions = [s for s in self.solutions if not solution.dominates(s)]
        self.solutions.append(solution)

        # Enforce max size, keeping the best by a combined score.
        if len(self.solutions) > self.max_size:
            self.solutions.sort(key=lambda s: self._rmse_or_inf(s) + 0.01 * s.complexity)
            self.solutions = self.solutions[:self.max_size]

        return True

    def add_from_results(self, results_list):
        """
        Add multiple results from beam search or MCTS.

        results_list: list of dicts with 'tokens', 'rmse', 'formula' and optionally
        'constants'. Complexity is the token count.

        Returns:
            Number of results that entered the front.
        """
        added = 0
        for r in results_list:
            sol = ParetoSolution(
                tokens=r['tokens'],
                rmse=r['rmse'],
                complexity=len(r['tokens']),  # Simple complexity = token count
                formula_str=r['formula'],
                constants=r.get('constants', {})
            )
            if self.add(sol):
                added += 1
        return added

    def get_best_by_rmse(self):
        """Returns the solution with lowest RMSE (NaN ranks last), or None if empty."""
        if not self.solutions:
            return None
        return min(self.solutions, key=self._rmse_or_inf)

    def get_simplest(self):
        """Returns the solution with lowest complexity, or None if empty."""
        if not self.solutions:
            return None
        return min(self.solutions, key=lambda s: s.complexity)

    def get_balanced(self, alpha=0.5):
        """
        Returns a balanced solution.

        Both objectives are min-max normalized over the solutions with a finite
        RMSE and combined as ``alpha * rmse + (1 - alpha) * complexity``. If no
        solution has a finite RMSE, the simplest one is returned.

        alpha: weight for RMSE (1-alpha for complexity)
        """
        if not self.solutions:
            return None

        finite = [s for s in self.solutions if np.isfinite(s.rmse)]
        if not finite:
            return min(self.solutions, key=lambda s: s.complexity)

        def _lo_and_width(values):
            # A zero width (all values equal) would divide by zero; any positive
            # width works then because every numerator is 0.
            lo = min(values)
            width = max(values) - lo
            return lo, (width if width > 0 else 1.0)

        rmse_lo, rmse_width = _lo_and_width([s.rmse for s in finite])
        comp_lo, comp_width = _lo_and_width([s.complexity for s in finite])

        def score(s):
            norm_rmse = (s.rmse - rmse_lo) / rmse_width
            norm_comp = (s.complexity - comp_lo) / comp_width
            return alpha * norm_rmse + (1 - alpha) * norm_comp

        return min(finite, key=score)

    def summary(self):
        """Print a summary of the Pareto front."""
        print(f"\n=== Pareto Front ({len(self.solutions)} solutions) ===")
        for i, sol in enumerate(sorted(self.solutions, key=self._rmse_or_inf)[:10]):
            print(f"  {i+1}. RMSE={sol.rmse:.6f}, Nodes={sol.complexity}, Formula: {sol.formula}")


# Quick test
if __name__ == "__main__":
    # Demo: push a handful of hand-made candidates through a fresh front.
    demo_front = ParetoFront()

    candidates = [
        ParetoSolution(['x'], 10.0, 1, "x"),
        ParetoSolution(['+', 'x', '1'], 5.0, 3, "(x + 1)"),
        ParetoSolution(['*', '2', 'x'], 3.0, 3, "(2 * x)"),
        ParetoSolution(['+', '*', '2', 'x', '3'], 0.5, 5, "((2 * x) + 3)"),
        ParetoSolution(['+', '*', '*', '2', 'x', 'x', '+', 'x', '1'], 0.1, 9, "complicated"),
    ]

    for candidate in candidates:
        was_added = demo_front.add(candidate)
        print(f"Added {candidate.formula}: {was_added}")

    demo_front.summary()

    best_fit = demo_front.get_best_by_rmse()
    simplest = demo_front.get_simplest()
    balanced = demo_front.get_balanced()
    print(f"\nBest by RMSE: {best_fit}")
    print(f"Simplest: {simplest}")
    print(f"Balanced: {balanced}")


In [None]:
%%writefile search/__init__.py


In [None]:
%%writefile ui/app_core.py
"""
Core state and model management for AlphaSymbolic Gradio App.
"""
import torch
import os
from core.model import AlphaSymbolicModel
from core.grammar import VOCABULARY

from collections import deque
import time

# Global state shared by the UI modules.
MODEL = None   # set by load_model(); get_model() loads it on first use
DEVICE = None  # torch.device, resolved lazily via get_device()
TRAINING_STATUS = dict(running=False, epoch=0, loss=0, message="Listo")
STOP_TRAINING = False  # Flag to request training stop

def request_stop_training():
    """Raise the module-level stop flag and return a status message for the UI."""
    global STOP_TRAINING
    STOP_TRAINING = True
    status_message = "⏹️ Deteniendo entrenamiento..."
    return status_message

def should_stop_training():
    """Report whether a stop has been requested (the current STOP_TRAINING value)."""
    stop_requested = STOP_TRAINING
    return stop_requested

def reset_stop_flag():
    """Clear the stop flag; intended to be called when a training run starts."""
    global STOP_TRAINING
    STOP_TRAINING = bool(0)

# Hall of Shame: bounded buffer holding only the 20 most recent training failures.
# Entry layout: {'time': str, 'target': str, 'predicted': str, 'loss': float, 'stage': str}
TRAINING_ERRORS = deque([], maxlen=20)

def add_training_error(target, predicted, loss, stage):
    """Append one timestamped failure record to the TRAINING_ERRORS buffer."""
    record = dict(
        time=time.strftime("%H:%M:%S"),
        target=target,
        predicted=predicted,
        loss=float(loss),
        stage=stage,
    )
    TRAINING_ERRORS.append(record)

def get_training_errors():
    """Return a list copy of the Hall of Shame entries, in insertion order."""
    return [*TRAINING_ERRORS]

# Transformer hyper-parameters per model size.
MODEL_PRESETS = {
    'lite': dict(d_model=128, nhead=4, num_encoder_layers=3, num_decoder_layers=3),
    'pro': dict(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6),
}
CURRENT_PRESET = 'lite'  # key into MODEL_PRESETS

def get_device(force_cpu=False):
    """Pick the best available device in the order CUDA > MPS > CPU (CPU if forced)."""
    if not force_cpu:
        if torch.cuda.is_available():
            return torch.device("cuda")
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
    return torch.device("cpu")

def set_device(use_gpu=True):
    """Switch the active device, moving an already-built model along, and return its description."""
    global DEVICE, MODEL
    target = get_device(force_cpu=not use_gpu)

    model_must_move = MODEL is not None and DEVICE != target
    if model_must_move:
        MODEL = MODEL.to(target)

    DEVICE = target
    return get_device_info()

def get_device_info():
    """Human-readable label for the active device, resolving DEVICE on first use."""
    global DEVICE
    if DEVICE is None:
        DEVICE = get_device()

    kind = DEVICE.type
    if kind == "cuda":
        return f"CUDA ({torch.cuda.get_device_name(0)})"
    if kind == "mps":
        return "MPS (Apple Silicon)"
    return "CPU"

def _build_model(config):
    """Instantiate a freshly initialized AlphaSymbolicModel for a preset config on DEVICE."""
    return AlphaSymbolicModel(
        vocab_size=len(VOCABULARY) + 1,  # +1 for the SOS token id
        d_model=config['d_model'],
        nhead=config['nhead'],
        num_encoder_layers=config['num_encoder_layers'],
        num_decoder_layers=config['num_decoder_layers']
    ).to(DEVICE)


def load_model(force_reload=False, preset_name=None):
    """
    (Re)build the model for the current preset and load its checkpoint if present.

    Args:
        force_reload: Currently unused; the model is always rebuilt.
        preset_name: Optional MODEL_PRESETS key to switch to before loading.

    Returns:
        (status_message, device_info) for the UI.

    Raises:
        KeyError: if ``preset_name`` is not in MODEL_PRESETS. The current preset
            is left unchanged in that case.
    """
    global MODEL, DEVICE, CURRENT_PRESET

    if preset_name:
        # Validate before switching: assigning an unknown name would make every
        # later call fail on the MODEL_PRESETS lookup.
        if preset_name not in MODEL_PRESETS:
            raise KeyError(preset_name)
        CURRENT_PRESET = preset_name

    if DEVICE is None:
        DEVICE = get_device()

    config = MODEL_PRESETS[CURRENT_PRESET]

    print(f"Loading Model [{CURRENT_PRESET.upper()}]...")
    MODEL = _build_model(config)

    filename = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
    status = f"Nuevo modelo ({CURRENT_PRESET})"  # Default status

    if os.path.exists(filename):
        try:
            state_dict = torch.load(filename, map_location=DEVICE, weights_only=True)

            # A checkpoint containing NaN/Inf weights is treated as corrupt.
            has_nans = any(
                torch.isnan(v).any() or torch.isinf(v).any()
                for v in state_dict.values()
            )

            if has_nans:
                print(f"⚠️ Modelo corrupto detectado (NaNs) en {filename}. Eliminando y esperando reinicio.")
                try:
                    os.remove(filename)
                    print("✅ Archivo corrupto eliminado.")
                except OSError as e:
                    print(f"Error al eliminar archivo: {e}")
                status = "⚠️ Modelo corrupto eliminado y reiniciado"
            else:
                MODEL.load_state_dict(state_dict)
                MODEL.eval()
                status = f"Modelo cargado ({CURRENT_PRESET})"

        except RuntimeError as e:
            print(f"⚠️ Error de compatibilidad ({e}). Iniciando modelo fresco.")
            # load_state_dict copies every compatible tensor before raising on
            # mismatches, so the model may be half-loaded; rebuild it so it is
            # really fresh.
            MODEL = _build_model(config)
            status = f"Nuevo modelo ({CURRENT_PRESET})"
        except Exception as e:
            print(f"Error cargando: {e}")
            status = "Sin modelo pre-entrenado"

    return status, get_device_info()

def get_model():
    """Return ``(MODEL, DEVICE)``, building/loading the model on the first call."""
    global MODEL, DEVICE
    if MODEL is not None:
        return MODEL, DEVICE
    load_model()
    return MODEL, DEVICE

def save_model():
    """
    Save the current model's weights to ``alpha_symbolic_model_<preset>.pth``.

    The state dict is written to a temporary file next to the target and then
    moved into place with ``os.replace``. An interrupted save therefore cannot
    leave a truncated checkpoint that ``load_model`` would later fail to read.
    Does nothing when no model has been created yet.
    """
    if MODEL is None:
        return
    filename = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
    tmp_filename = filename + ".tmp"
    try:
        torch.save(MODEL.state_dict(), tmp_filename)
        os.replace(tmp_filename, filename)  # atomic swap on the same filesystem
    finally:
        # Only left behind if the save or the replace failed.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)


In [None]:
%%writefile ui/app_search.py
"""
Search/Solve functions for AlphaSymbolic Gradio App.
Supports both Beam Search and MCTS.
"""
import numpy as np
import matplotlib.pyplot as plt
import time
import gradio as gr

from core.grammar import ExpressionTree
from search.beam_search import BeamSearch
from search.mcts import MCTS
from search.hybrid_search import hybrid_solve
from utils.simplify import simplify_tree
from search.pareto import ParetoFront
from utils.detect_pattern import detect_pattern
from utils.optimize_constants import optimize_constants, substitute_constants
from ui.app_core import get_model


def parse_data(x_str, y_str):
    """
    Parse comma-separated input strings into float64 arrays.

    Blank fields (e.g. from a trailing comma) are ignored.

    Returns:
        (x, y, None) on success, or (None, None, error_message) if a value is not
        a number, no values were given, or X and Y differ in length.
    """
    def _to_array(text):
        fields = [v.strip() for v in text.split(',')]
        return np.array([float(v) for v in fields if v], dtype=np.float64)

    try:
        x = _to_array(x_str)
        y = _to_array(y_str)
    except Exception as e:
        return None, None, f"Error: {str(e)}"

    if len(x) == 0 or len(y) == 0:
        return None, None, "Error: no se proporcionaron datos"
    if len(x) != len(y):
        return None, None, "Error: X e Y deben tener igual longitud"
    return x, y, None


def create_fit_plot(x, y, y_pred, formula):
    """
    Create a dark-themed plot of the data points against the fitted curve.

    Args:
        x, y: Data points (array-like, same length).
        y_pred: Predictions at ``x``; a scalar is broadcast to the shape of ``x``.
        formula: Currently unused (kept for API compatibility).

    Returns:
        The matplotlib Figure. It is detached from pyplot's figure registry, so
        repeated calls from the UI do not accumulate open figures.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    y_pred = np.broadcast_to(np.asarray(y_pred, dtype=np.float64), x.shape)

    fig, ax = plt.subplots(figsize=(8, 5), facecolor='#1a1a2e')
    ax.set_facecolor('#1a1a2e')

    ax.scatter(x, y, color='#00d4ff', s=100, label='Datos Reales', zorder=3, edgecolors='white', linewidth=1)

    # Sort by x so the prediction is drawn as a continuous curve.
    sort_idx = np.argsort(x)
    ax.plot(x[sort_idx], y_pred[sort_idx], color='#ff6b6b', linewidth=3, label='Prediccion', zorder=2)

    ax.set_xlabel('X', color='white', fontsize=12)
    ax.set_ylabel('Y', color='white', fontsize=12)
    ax.set_title('Ajuste de la Formula', color='white', fontsize=14, fontweight='bold')
    ax.legend(facecolor='#16213e', edgecolor='#00d4ff', labelcolor='white')
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.2, color='white')

    for spine in ax.spines.values():
        spine.set_color('#00d4ff')

    fig.tight_layout()
    # Unregister from pyplot (the figure object itself remains renderable);
    # otherwise every call leaks an open figure.
    plt.close(fig)
    return fig


def _run_search(x, y, beam_width, search_method, model, device, progress):
    """Dispatch to the selected search method and return a (possibly empty) list of result dicts."""
    if search_method == "Alpha-GP Hybrid":
        progress(0.4, desc="Fase 1: Neural Beam Search...")
        # gp_timeout raised to 30s to allow convergence on complex problems.
        hybrid_res = hybrid_solve(x, y, model, device, beam_width=int(beam_width), gp_timeout=30)
        if not hybrid_res:
            return []
        progress(0.9, desc="Procesando resultados GP...")
        # The GP string is fully instantiated (numbers baked in, no 'C'
        # placeholders), so it is re-parsed and only evaluated, not re-optimized.
        tree = ExpressionTree.from_infix(hybrid_res['formula'])
        if not tree.is_valid:
            return []
        y_pred_check = tree.evaluate(x)
        rmse_check = np.sqrt(np.mean((y_pred_check - y)**2))
        return [{
            'tokens': tree.tokens,
            'formula': tree.get_infix(),
            'rmse': rmse_check,
            'constants': {}  # Constants are baked into the formula string
        }]

    if search_method == "Beam Search":
        searcher = BeamSearch(model, device, beam_width=int(beam_width), max_length=25)
        return searcher.search(x, y)

    # MCTS
    mcts = MCTS(model, device, max_simulations=int(beam_width) * 10)
    result = mcts.search(x, y)
    if not (result and result.get('tokens')):
        return []
    tokens = result['tokens']
    tree = ExpressionTree(tokens)
    if not tree.is_valid:
        return []
    constants, rmse = optimize_constants(tree, x, y)
    return [{
        'tokens': tokens,
        'formula': tree.get_infix(),
        'rmse': rmse,
        'constants': constants
    }]


def _predictions_table_html(x, y, y_pred):
    """HTML table of X, prediction, real value and |delta| for the first 50 points."""
    html = '<table style="width: 100%; border-collapse: collapse; background: #1a1a2e; border-radius: 10px; overflow: hidden;">'
    html += '<tr style="background: #16213e;"><th style="padding: 10px; color: #00d4ff;">X</th><th style="color: #00d4ff;">Pred</th><th style="color: #00d4ff;">Real</th><th style="color: #00d4ff;">Delta</th></tr>'
    for i in range(min(50, len(x))):
        delta = abs(y_pred[i] - y[i])
        # green < 0.1 <= amber < 1 <= red (NaN also ends up red)
        color = "#4ade80" if delta < 0.1 else "#fbbf24" if delta < 1 else "#ef4444"
        html += f'<tr style="border-bottom: 1px solid #333;"><td style="padding: 8px; color: white; text-align: center;">{x[i]:.2f}</td><td style="color: white; text-align: center;">{y_pred[i]:.4f}</td><td style="color: white; text-align: center;">{y[i]:.4f}</td><td style="color: {color}; text-align: center; font-weight: bold;">{delta:.4f}</td></tr>'
    html += '</table>'
    return html


def _alternatives_html(solutions):
    """HTML list of up to four Pareto solutions, best RMSE first (the first is highlighted)."""
    # The front is unordered, so sort it; otherwise the highlighted entry would
    # not necessarily be the best fit. Non-finite RMSE goes last.
    ranked = sorted(solutions, key=lambda s: s.rmse if np.isfinite(s.rmse) else float('inf'))
    html = '<div style="background: #1a1a2e; padding: 15px; border-radius: 10px;">'
    html += '<h4 style="color: #00d4ff; margin-top: 0;">Alternativas</h4>'
    for i, sol in enumerate(ranked[:4]):
        html += f'<div style="padding: 5px 10px; margin: 5px 0; background: #0f0f23; border-radius: 5px; border-left: 3px solid {"#00d4ff" if i == 0 else "#666"};"><code style="color: {"#ff6b6b" if i == 0 else "#888"};">{sol.formula}</code> <span style="color: #666; font-size: 12px;">RMSE: {sol.rmse:.4f}</span></div>'
    html += '</div>'
    return html


def solve_formula(x_str, y_str, beam_width, search_method, progress=gr.Progress()):
    """
    Main solving function with search method selection.

    Args:
        x_str, y_str: Comma-separated data values.
        beam_width: Beam width (also scales MCTS simulations: beam_width * 10).
        search_method: "Alpha-GP Hybrid", "Beam Search"; anything else runs MCTS.
        progress: Gradio progress tracker.

    Returns:
        (result_html, figure, predictions_html, alternatives_html, simplified_formula).
        On failure: (message, None, "", "", "").
    """
    x, y, error = parse_data(x_str, y_str)
    if error:
        return error, None, "", "", ""

    MODEL, DEVICE = get_model()

    progress(0.1, desc=f"Analizando patron... [{DEVICE.type.upper()}]")
    pattern = detect_pattern(x, y)

    progress(0.3, desc=f"Buscando formulas ({search_method})... [{DEVICE.type.upper()}]")
    start_time = time.time()
    results = _run_search(x, y, beam_width, search_method, MODEL, DEVICE, progress)
    search_time = time.time() - start_time

    if not results:
        return "No se encontraron formulas validas", None, "", "", ""

    progress(0.7, desc="Optimizando constantes...")
    pareto = ParetoFront()
    pareto.add_from_results(results)
    best = pareto.get_best_by_rmse()

    if not best:
        return "Error en optimizacion", None, "", "", ""

    progress(0.9, desc="Simplificando...")
    tree = ExpressionTree(best.tokens)
    simplified = simplify_tree(tree)
    y_pred = tree.evaluate(x, constants=best.constants)
    # Defensive: guarantee one prediction per data point for the plot and table.
    y_pred = np.broadcast_to(np.asarray(y_pred, dtype=np.float64), x.shape)

    # Substitute constants for display (best effort; falls back to simplified).
    substituted_formula = simplified
    if best.constants:
        try:
            positions = tree.root.get_constant_positions()
            # Raw infix keeps the 'C' positions aligned with `positions`.
            raw_infix = tree.get_infix()
            substituted_formula = substitute_constants(raw_infix, best.constants, positions)
        except Exception:
            substituted_formula = simplified

    fig = create_fit_plot(x, y, y_pred, simplified)

    # Format results
    result_html = f"""
    <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #00d4ff;">
        <h2 style="color: #00d4ff; margin: 0; font-size: 24px;">Formula Encontrada</h2>
        <div style="background: #0f0f23; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #ff6b6b;">
            <code style="color: #ff6b6b; font-size: 28px; font-weight: bold;">{substituted_formula}</code>
        </div>
        <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 10px;">
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">RMSE</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{best.rmse:.6f}</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Nodos</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{best.complexity}</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Tiempo</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{search_time:.2f}s</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Metodo</span><br>
                <span style="color: #4ade80; font-size: 16px; font-weight: bold;">{search_method}</span>
            </div>
        </div>
        <div style="margin-top: 15px; padding: 10px; background: #0f0f23; border-radius: 8px;">
            <span style="color: #888;">Patron:</span> 
            <span style="color: #ffd93d;">{pattern['type']}</span> 
            <span style="color: #666;">({pattern['confidence']:.0%})</span>
            <span style="color: #888; margin-left: 20px;">Device:</span>
            <span style="color: #4ade80;">{DEVICE.type.upper()}</span>
        </div>
    """

    # Add constants if any, numbered C_1, C_2, ... in sorted-key order.
    if best.constants:
        sorted_items = sorted(best.constants.items(), key=lambda x: str(x[0]))
        clean_consts = [f"C_{i+1}: {v:.4f}" for i, (_, v) in enumerate(sorted_items)]
        const_str = "  |  ".join(clean_consts)

        result_html += f"""
        <div style="margin-top: 10px; padding: 10px; background: #0f0f23; border-radius: 8px; border-left: 3px solid #ffd93d;">
            <span style="color: #888;">Constantes:</span>
            <span style="color: #fff; font-family: monospace; margin-left: 10px;">{const_str}</span>
        </div>
        """

    result_html += "</div>"

    pred_html = _predictions_table_html(x, y, y_pred)
    alt_html = _alternatives_html(pareto.solutions)

    return result_html, fig, pred_html, alt_html, simplified


def generate_example(tipo):
    """
    Return a demo dataset as two comma-separated strings ``(x_text, y_text)``.

    Supported ``tipo`` values:
      * ``"lineal"``     -> y = 2x + 3 on 10 points in [1, 10]
      * ``"cuadratico"`` -> y = x^2 + 1 on 11 points in [-5, 5]
      * ``"trig"``       -> y = sin(x) on 20 points in [0, 6.28]
      * ``"exp"``        -> y = 2 * exp(0.5 x) on 15 points in [0, 5]

    Any other value falls back to the linear example. X values are formatted
    with 2 decimals, Y values with 4 decimals.
    """
    def linear():
        grid = np.linspace(1, 10, 10)
        return grid, 2 * grid + 3

    def quadratic():
        grid = np.linspace(-5, 5, 11)
        return grid, grid**2 + 1

    def trig():
        grid = np.linspace(0, 6.28, 20)
        return grid, np.sin(grid)

    def exponential():
        grid = np.linspace(0, 5, 15)
        return grid, 2 * np.exp(0.5 * grid)

    builders = (
        ("lineal", linear),
        ("cuadratico", quadratic),
        ("trig", trig),
        ("exp", exponential),
    )
    # First matching name wins; unknown names use the linear example.
    build = next((fn for name, fn in builders if tipo == name), linear)

    xs, ys = build()
    x_text = ", ".join(f"{value:.2f}" for value in xs)
    y_text = ", ".join(f"{value:.4f}" for value in ys)
    return x_text, y_text


In [None]:
%%writefile ui/app_training.py
"""
Training functions for AlphaSymbolic Gradio App.
With proper data normalization.
"""
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
from collections import deque
import random
import time

from core.grammar import VOCABULARY, TOKEN_TO_ID, OPERATORS, OPERATOR_STAGES
from data.synthetic_data import DataGenerator
from ui.app_core import get_model, save_model, TRAINING_STATUS, add_training_error, should_stop_training, reset_stop_flag
from core.loss import QuantileLoss
from search.hybrid_search import hybrid_solve
from core.grammar import ExpressionTree, simplify_formula


def get_allowed_token_mask(stage, vocab_size, device):
    """
    Build a 0/1 mask over token logits for a curriculum stage.

    Args:
        stage: Key into ``OPERATOR_STAGES``. Unknown stages fall back to every
            operator in ``OPERATORS``.
        vocab_size: Number of grammar tokens. The mask has ``vocab_size + 1``
            entries so the trailing SOS slot (index ``vocab_size``) is included.
        device: Torch device on which the mask tensor is allocated.

    Returns:
        A float tensor of shape ``(vocab_size + 1,)`` with 1.0 for allowed token
        ids and 0.0 for disallowed ones. This function does not apply the mask;
        callers decide how (e.g. multiply, or map 0.0 entries to -inf).
    """
    allowed_ops = OPERATOR_STAGES.get(stage, list(OPERATORS.keys()))

    # Terminals (variables and constants) are always allowed. They are derived
    # from the vocabulary (every token that is not an operator) instead of a
    # hard-coded list, so the mask stays in sync with the grammar definition.
    allowed_tokens = {token for token in VOCABULARY if token not in OPERATORS}
    allowed_tokens.update(allowed_ops)

    mask = torch.zeros(vocab_size + 1, device=device)  # +1 for SOS token
    for token in allowed_tokens:
        if token in TOKEN_TO_ID:
            mask[TOKEN_TO_ID[token]] = 1.0
    mask[vocab_size] = 1.0  # SOS always allowed

    return mask


def normalize_batch(x_list, y_list):
    """
    Min-max rescale every X and Y sample independently into [-1, 1].

    Each array is mapped with ``2 * (v - min) / (max - min) - 1``. An array whose
    span (max - min) is not greater than 1e-6 is replaced by ``np.zeros_like``
    of itself, which avoids dividing by a (near-)zero range.

    Args:
        x_list: Iterable of numpy arrays holding the X values of each sample.
        y_list: Iterable of numpy arrays holding the Y values, paired with x_list.

    Returns:
        Tuple ``(normalized_x, normalized_y)`` of lists. Inputs are paired with
        ``zip``, so the output length matches the shorter input.
    """
    def _rescale(values):
        lo, hi = values.min(), values.max()
        span = hi - lo
        if not span > 1e-6:
            return np.zeros_like(values)
        return 2 * (values - lo) / span - 1

    pairs = [(_rescale(xs), _rescale(ys)) for xs, ys in zip(x_list, y_list)]
    normalized_x = [px for px, _ in pairs]
    normalized_y = [py for _, py in pairs]
    return normalized_x, normalized_y


def train_basic(epochs, batch_size, point_count=10, progress=gr.Progress()):
    """
    Basic supervised training with synthetic data.

    Each epoch draws one batch that is half inverse problems
    (``generate_inverse_batch``) and half random problems (``generate_batch``)
    from a single ``DataGenerator(max_depth=4)``. X/Y are normalized per sample,
    and the model takes one teacher-forced cross-entropy step on its token logits.

    Args:
        epochs: Number of epochs; each epoch performs at most one optimizer step.
        batch_size: Total samples per batch, split evenly between both sources.
        point_count: Number of (x, y) points per generated sample.
        progress: Gradio progress tracker.

    Returns:
        ``(result_html_or_message, figure_or_None)``.

    Training stops early when ``should_stop_training()`` reports a request.
    ``TRAINING_STATUS["running"]`` is always cleared and the model is always
    put back into eval mode on exit, including when an error occurs.
    """
    global TRAINING_STATUS

    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True
    reset_stop_flag()  # Discard stop requests left over from an earlier run
    MODEL = None

    try:
        MODEL, DEVICE = get_model()

        MODEL.train()
        n_epochs = int(epochs)
        n_total = int(batch_size)
        n_points = int(point_count)

        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=1e-4, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=1e-6)
        # -1 marks padded target positions, which must not contribute to the loss.
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # SOS uses the extra logit slot after the vocabulary

        data_gen = DataGenerator(max_depth=4)
        losses = []
        epochs_run = 0

        for epoch in range(n_epochs):
            if should_stop_training():
                print("⏹️ Basic training stopped by user")
                break
            epochs_run = epoch + 1
            progress(epochs_run / n_epochs, desc=f"Epoca {epochs_run}/{n_epochs} [{DEVICE.type.upper()}]")

            # Mix of inverse (known formulas) + random data (AlphaTensor-style)
            half_batch = n_total // 2
            batch_inverse = data_gen.generate_inverse_batch(half_batch, point_count=n_points)
            batch_random = data_gen.generate_batch(n_total - half_batch, point_count=n_points)
            batch = batch_inverse + batch_random
            if len(batch) < 2:
                continue

            x_list, y_list = normalize_batch([d['x'] for d in batch], [d['y'] for d in batch])
            token_lists = [[TOKEN_TO_ID[t] for t in d['tokens']] for d in batch]

            # Teacher forcing: decoder input is [SOS, t1..tn], target is [t1..tn, pad].
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)
            for i, seq in enumerate(token_lists):
                decoder_input[i, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)

            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            logits, _ = MODEL(x_tensor, y_tensor, decoder_input)
            loss = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))

            # Skip numerically broken batches instead of corrupting the weights.
            if torch.isnan(loss) or torch.isinf(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())

        save_model()

        if not losses:
            return "Error: No se pudo calcular loss (revisar datos)", None

        fig = create_loss_plot(losses, "Entrenamiento Basico")

        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #4ade80;">
            <h2 style="color: #4ade80; margin: 0;">Entrenamiento Completado</h2>
            <p style="color: white;">Epocas: {epochs_run} | Loss Final: {losses[-1]:.4f}</p>
            <p style="color: #00d4ff;">Dispositivo: {DEVICE.type.upper()}</p>
        </div>
        """
        return result, fig

    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        # Always leave the shared model in inference mode and release the lock.
        if MODEL is not None:
            MODEL.eval()
        TRAINING_STATUS["running"] = False


def _curriculum_depth_and_ratio(progress_pct):
    """Map training progress to ``(max_depth, inverse_data_ratio)``.

    Three stages. Each stage is split into two equal halves so the depth steps up
    once inside every stage:
      * [0.0, 0.5): depth 2 -> 3, 80% inverse data
      * [0.5, 0.8): depth 3 -> 4, 50% inverse data
      * [0.8, 1.0): depth 4 -> 5, 20% inverse data
    """
    if progress_pct < 0.5:
        start, end, base_depth, inverse_ratio = 0.0, 0.5, 2, 0.8
    elif progress_pct < 0.8:
        start, end, base_depth, inverse_ratio = 0.5, 0.8, 3, 0.5
    else:
        start, end, base_depth, inverse_ratio = 0.8, 1.0, 4, 0.2
    # Fraction of the way through the current stage, clamped so depth never
    # exceeds base_depth + 1.
    stage_frac = min(max((progress_pct - start) / (end - start), 0.0), 0.999)
    return base_depth + int(stage_frac * 2), inverse_ratio


def train_curriculum(epochs, batch_size, point_count=10, progress=gr.Progress()):
    """
    Curriculum learning: supervised training that starts simple and gets harder.

    Progress through the run selects the expression depth and the share of inverse
    (known-formula) data via ``_curriculum_depth_and_ratio``:
      * depth 2->3 with 80% inverse data
      * then 3->4 with 50% inverse data
      * then 4->5 with 20% inverse data

    Loss = token cross-entropy + 0.5 * MSE of the value head against 1.0. The
    supervised formulas are treated as perfect solutions.

    Args:
        epochs: Number of epochs; each epoch performs at most one optimizer step.
        batch_size: Total samples per batch.
        point_count: Number of (x, y) points per generated sample.
        progress: Gradio progress tracker.

    Returns:
        ``(result_html_or_message, figure_or_None)``.

    Training stops early when ``should_stop_training()`` reports a request.
    ``TRAINING_STATUS["running"]`` is always cleared and the model is always
    put back into eval mode on exit, including when an error occurs.
    """
    global TRAINING_STATUS

    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True
    reset_stop_flag()  # Discard stop requests left over from an earlier run
    MODEL = None

    try:
        MODEL, DEVICE = get_model()

        MODEL.train()
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)  # Lower LR
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
        # -1 marks padded target positions, which must not contribute to the loss.
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # SOS uses the extra logit slot after the vocabulary
        n_epochs = int(epochs)
        n_total = int(batch_size)
        n_points = int(point_count)

        losses = []
        epochs_run = 0
        max_depth_used = 0
        data_gen = None
        data_gen_depth = None

        for epoch in range(n_epochs):
            if should_stop_training():
                print("⏹️ Curriculum training stopped by user")
                break
            epochs_run = epoch + 1

            current_depth, inverse_ratio = _curriculum_depth_and_ratio(epoch / n_epochs)
            max_depth_used = max(max_depth_used, current_depth)

            progress(epochs_run / n_epochs, desc=f"Epoca {epochs_run}/{n_epochs} (prof: {current_depth}, inv: {inverse_ratio:.0%}) [{DEVICE.type.upper()}]")

            # Reuse the generator while the depth is unchanged, as train_basic does.
            if current_depth != data_gen_depth:
                data_gen = DataGenerator(max_depth=current_depth)
                data_gen_depth = current_depth

            # Mix inverse + random based on curriculum stage
            n_inverse = int(n_total * inverse_ratio)
            n_random = n_total - n_inverse

            batch_inverse = data_gen.generate_inverse_batch(n_inverse, point_count=n_points) if n_inverse > 0 else []
            batch_random = data_gen.generate_batch(n_random, point_count=n_points) if n_random > 0 else []
            batch = batch_inverse + batch_random
            if len(batch) < 2:
                continue

            x_list, y_list = normalize_batch([d['x'] for d in batch], [d['y'] for d in batch])
            token_lists = [[TOKEN_TO_ID[t] for t in d['tokens']] for d in batch]

            # Teacher forcing: decoder input is [SOS, t1..tn], target is [t1..tn, pad].
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)
            for i, seq in enumerate(token_lists):
                decoder_input[i, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)

            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            logits, value_pred = MODEL(x_tensor, y_tensor, decoder_input)

            # Policy loss over next-token predictions
            loss_policy = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))

            # Value loss: supervised formulas are "perfect" solutions, so the target is 1.0
            value_targets = torch.ones_like(value_pred)
            loss_value = torch.nn.functional.mse_loss(value_pred, value_targets)

            loss = loss_policy + 0.5 * loss_value

            # Skip numerically broken batches instead of corrupting the weights.
            if torch.isnan(loss) or torch.isinf(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())

        save_model()

        if not losses:
            return "Error: No se pudo calcular loss", None

        fig = create_loss_plot(losses, "Curriculum Learning")

        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #00d4ff;">
            <h2 style="color: #00d4ff; margin: 0;">Curriculum Learning Completado</h2>
            <p style="color: white;">Epocas: {epochs_run} | Loss Final: {losses[-1]:.4f}</p>
            <p style="color: #888;">Profundidad maxima: {max_depth_used} | Dispositivo: {DEVICE.type.upper()}</p>
        </div>
        """
        return result, fig

    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        # Always leave the shared model in inference mode and release the lock.
        if MODEL is not None:
            MODEL.eval()
        TRAINING_STATUS["running"] = False


def _selfplay_score_candidates(model, device, candidates, sos_id, vocab_size, chunk_size=32):
    """
    Estimate how hard each candidate problem is for the current model.

    Runs a teacher-forced forward pass without gradients over ``candidates`` in
    chunks of ``chunk_size``. Tokens missing from ``TOKEN_TO_ID`` are mapped to
    the ``'C'`` id.

    Returns:
        List of ``(mean_token_cross_entropy, problem)`` pairs in input order.
    """
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
    scored = []
    with torch.no_grad():
        for start in range(0, len(candidates), chunk_size):
            chunk = candidates[start:start + chunk_size]

            x_list, y_list = normalize_batch([d['x'] for d in chunk], [d['y'] for d in chunk])
            token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in chunk]
            max_len = max(len(s) for s in token_lists)

            dec_in = torch.full((len(chunk), max_len + 1), sos_id, dtype=torch.long).to(device)
            targets = torch.full((len(chunk), max_len + 1), -1, dtype=torch.long).to(device)
            for j, seq in enumerate(token_lists):
                dec_in[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)

            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(device)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(device)

            logits, _ = model(x_tensor, y_tensor, dec_in)

            # Per-position losses reshaped to [batch, seq], then averaged over
            # the non-padded positions of each sample.
            raw_losses = loss_fn(logits.view(-1, vocab_size + 1), targets.view(-1))
            raw_losses = raw_losses.view(len(chunk), -1)
            mask = (targets != -1)
            sample_losses = (raw_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

            for j, loss_val in enumerate(sample_losses):
                scored.append((loss_val.item(), chunk[j]))
    return scored


def _selfplay_log_hardest_failures(model, device, hard_problems, curriculum_stage, stage_name, top_k=3):
    """
    Record the model's greedy prediction for the ``top_k`` hardest problems.

    ``hard_problems`` holds ``(loss, problem)`` pairs sorted hardest-first.
    Each problem is decoded with a width-1 ``BeamSearch`` restricted to
    ``curriculum_stage``. The target, prediction, loss and stage are reported
    through ``add_training_error``. Errors are printed and never propagated, so
    this diagnostic cannot interrupt training.
    """
    try:
        top_failures = hard_problems[:top_k]
        x_fail = [p[1]['x'].astype(np.float64) for p in top_failures]
        y_fail = [p[1]['y'].astype(np.float64) for p in top_failures]
        target_formulas = [p[1]['infix'] for p in top_failures]
        fail_losses = [p[0] for p in top_failures]

        from search.beam_search import BeamSearch
        # Beam width 1 == greedy decoding (fast), masked to the curriculum stage.
        bs = BeamSearch(model, device, beam_width=1, max_length=20, curriculum_stage=curriculum_stage)

        for i in range(len(top_failures)):
            try:
                # return_partial=True shows what the model produced even if decoding fails.
                res = bs.search(x_fail[i], y_fail[i], return_partial=True)
                pred_formula = res[0]['formula'] if res else "Search Empty (No Tokens)"

                # Truncate degenerate, looping outputs (e.g. "10 / / / / / /").
                if len(pred_formula) > 20:
                    if pred_formula.count('/') > 10 and pred_formula.endswith('/ .'):
                        pred_formula = pred_formula[:20] + " ... [Loop Detected]"
                    elif " / / / " in pred_formula:
                        pred_formula = pred_formula.split(" / / / ")[0] + " ... [Loop Detected]"

                add_training_error(
                    target=target_formulas[i],
                    predicted=pred_formula,
                    loss=fail_losses[i],
                    stage=stage_name
                )
            except Exception as e:
                print(f"HoS Inner Error: {e}")
                add_training_error(
                    target=target_formulas[i],
                    predicted=f"CRASH: {str(e)[:20]}",
                    loss=fail_losses[i],
                    stage=stage_name
                )
    except Exception as e:
        import traceback
        print(f"HoS Outer Error: {e}")
        traceback.print_exc()


def train_self_play(iterations, problems_per_iter, point_count=10, progress=gr.Progress()):
    """
    AlphaZero-style self-play loop with hard-example mining and an adaptive curriculum.

    Each iteration:
      1. Generates ``3 * problems_per_iter`` inverse problems at the current
         curriculum level and scores them by the model's teacher-forced loss.
      2. Keeps the 70% hardest plus a random 30% drawn from the rest.
      3. Logs greedy predictions for the 3 hardest problems (Hall of Shame).
      4. Solves every selected problem with MCTS and stores (tokens, policy,
         value) examples in a replay buffer (max 20000 entries).
      5. Once the buffer holds >= 64 examples, trains on random buffer samples
         with KL-divergence policy loss + quantile value loss.

    The curriculum advances one level when more than 20 problems have been
    solved *at the current level* and the mean RMSE of the last 20 of them is
    below 0.1.

    Args:
        iterations: Number of self-play iterations.
        problems_per_iter: Problems solved with MCTS per iteration.
        point_count: Number of (x, y) points per generated problem.
        progress: Gradio progress tracker.

    Returns:
        ``(result_html_or_message, figure_or_None)``.

    Training stops early when ``should_stop_training()`` reports a request.
    ``TRAINING_STATUS["running"]`` is always cleared and the model is always
    put back into eval mode on exit, including when an error occurs.
    """
    global TRAINING_STATUS

    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True
    reset_stop_flag()  # Reset stop flag at start
    MODEL = None

    try:
        MODEL, DEVICE = get_model()

        from search.mcts import MCTS

        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)
        # Halve the LR when the recent training loss plateaus.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15, min_lr=1e-6)

        # Policy: KL divergence against the MCTS policy distributions.
        # Value: quantile loss (core.loss.QuantileLoss).
        kl_loss = torch.nn.KLDivLoss(reduction='batchmean')
        quantile_loss_fn = QuantileLoss()

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # SOS uses the extra logit slot after the vocabulary
        n_iterations = int(iterations)
        n_problems = int(problems_per_iter)
        n_points = int(point_count)

        replay_buffer = deque(maxlen=20000)

        # Adaptive curriculum levels; STAGE_NAMES is index-aligned with them.
        CURRICULUM_LEVELS = [
            {'depth': 1, 'ops': ['+', '-', '*', '/']},
            {'depth': 2, 'ops': ['+', '-', '*', '/']},
            {'depth': 3, 'ops': ['+', '-', '*', '/', 'pow', 'sqrt']},
            {'depth': 4, 'ops': ['+', '-', '*', '/', 'pow', 'sqrt', 'sin', 'cos']},
            {'depth': 5, 'ops': None},  # All operators
        ]
        STAGE_NAMES = ["Arithmetic", "Polynomials", "Trigonometry", "Advanced", "Complex"]
        curriculum_stage = 0
        # Index into `rmses` where the current level began. Graduation only looks at
        # problems solved at this level, so easy-level RMSEs cannot trigger
        # several level-ups in a row.
        stage_rmse_start = 0
        stage_info = CURRICULUM_LEVELS[curriculum_stage]
        data_gen = DataGenerator(max_depth=stage_info['depth'], allowed_operators=stage_info['ops'])

        # Batch 64 keeps GPU batches reasonably full without long CPU waits.
        searcher = MCTS(MODEL, DEVICE, max_simulations=500, complexity_lambda=0.1, batch_size=64)

        rmses = []
        losses = []

        start_time = time.time()

        for iteration in range(n_iterations):
            if should_stop_training():
                print("⏹️ Training stopped by user")
                break

            # ETA from the average duration of completed iterations.
            if iteration > 0:
                elapsed = time.time() - start_time
                eta_seconds = (n_iterations - iteration) * (elapsed / iteration)
                if eta_seconds > 3600:
                    eta_str = f"{eta_seconds/3600:.1f}h"
                elif eta_seconds > 60:
                    eta_str = f"{eta_seconds/60:.0f}m"
                else:
                    eta_str = f"{eta_seconds:.0f}s"
            else:
                eta_str = "Calculando..."

            # --- ADAPTIVE CURRICULUM ---
            stage_rmses = rmses[stage_rmse_start:]
            recent_rmse = np.mean(stage_rmses[-20:]) if len(stage_rmses) >= 20 else 1.0

            # Graduation condition: RMSE < 0.1 stable at the current level
            if len(stage_rmses) > 20 and recent_rmse < 0.1 and curriculum_stage < len(CURRICULUM_LEVELS) - 1:
                curriculum_stage += 1
                stage_rmse_start = len(rmses)
                stage_info = CURRICULUM_LEVELS[curriculum_stage]
                data_gen = DataGenerator(max_depth=stage_info['depth'], allowed_operators=stage_info['ops'])
                print(f"*** Curriculum Level Up! Stage {curriculum_stage} ({stage_info['depth']}, {stage_info['ops']}) ***")

            stage_name = STAGE_NAMES[curriculum_stage]

            curr_lr_disp = optimizer.param_groups[0]['lr']
            msg = f"Iter {iteration+1}/{n_iterations} [{stage_name}] RMSE:{recent_rmse:.3f} LR:{curr_lr_disp:.1e} | ETA: {eta_str}"
            progress((iteration + 1) / n_iterations, desc=msg)

            # --- ACTIVE LEARNING / HARD MINING ---
            MODEL.eval()

            # Generate a pool 3x larger than needed, then keep the "hard" ones.
            candidates = data_gen.generate_inverse_batch(n_problems * 3, point_count=n_points)
            if not candidates:
                continue

            hard_problems = _selfplay_score_candidates(MODEL, DEVICE, candidates, SOS_ID, VOCAB_SIZE)
            hard_problems.sort(key=lambda item: item[0], reverse=True)

            # Mix hardest (70%) + random rest (30%) to limit catastrophic forgetting.
            n_hard = int(n_problems * 0.7)
            n_random = n_problems - n_hard
            selected_hard = [p[1] for p in hard_problems[:n_hard]]
            remaining_pool = [p[1] for p in hard_problems[n_hard:]]
            selected_random = random.sample(remaining_pool, min(n_random, len(remaining_pool))) if remaining_pool else []
            selected_problems = selected_hard + selected_random

            avg_pool_loss = np.mean([p[0] for p in hard_problems])
            top_loss = np.mean([p[0] for p in hard_problems[:n_hard]]) if n_hard > 0 else 0
            print(f"Active Learning: Pool Loss {avg_pool_loss:.3f} -> Selected Mix (Hard:{top_loss:.3f})")

            # --- HALL OF SHAME CAPTURE ---
            _selfplay_log_hardest_failures(MODEL, DEVICE, hard_problems, curriculum_stage, stage_name)

            # --- MCTS SOLVE ---
            for prob in selected_problems:
                x_data = prob['x'].astype(np.float64)
                y_data = prob['y'].astype(np.float64)

                try:
                    # MCTS explores variations of the known solution and yields the
                    # policy distribution the network learns from.
                    result = searcher.search(x_data, y_data)

                    if 'root' in result:
                        for (tokens, policy, value) in searcher.get_training_examples(result['root']):
                            replay_buffer.append({
                                'x': x_data, 'y': y_data,
                                'tokens': tokens,
                                'policy': policy,
                                'value': value
                            })

                    if result.get('tokens'):
                        rmses.append(result['rmse'])

                except Exception as e:
                    print(f"Self-play error: {e}")
                    continue

            # --- TRAINING PHASE ---
            if len(replay_buffer) >= 64:
                MODEL.train()

                train_batch_size = 128 if len(replay_buffer) >= 128 else 64
                # About one buffer's worth of samples, clamped to 10..50 steps.
                steps = max(10, min(50, len(replay_buffer) // train_batch_size))
                # The buffer does not change during these steps: snapshot it once
                # instead of copying the (up to 20000-entry) deque every step.
                buffer_snapshot = list(replay_buffer)

                for _ in range(steps):
                    batch = random.sample(buffer_snapshot, min(train_batch_size, len(buffer_snapshot)))

                    x_list, y_list = normalize_batch([exp['x'] for exp in batch], [exp['y'] for exp in batch])
                    token_lists = [[TOKEN_TO_ID[t] for t in exp['tokens']] for exp in batch]
                    policy_targets = [exp['policy'] for exp in batch]
                    value_targets_list = [exp['value'] for exp in batch]

                    max_len = max(len(s) for s in token_lists)
                    decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
                    for i, seq in enumerate(token_lists):
                        decoder_input[i, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)

                    policy_target_tensor = torch.tensor(np.array(policy_targets), dtype=torch.float32).to(DEVICE)
                    value_target_tensor = torch.tensor(np.array(value_targets_list), dtype=torch.float32).unsqueeze(1).to(DEVICE)

                    x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
                    y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
                    decoder_input = decoder_input.to(DEVICE)

                    optimizer.zero_grad()
                    logits, value_pred = MODEL(x_tensor, y_tensor, decoder_input)

                    # Policy loss (KL divergence). For an MCTS state [T1..Tn] the
                    # decoder input is [SOS, T1..Tn]; the logits at position n
                    # (= len(seq)) are the prediction for the next token.
                    # The SOS slot is excluded from the distribution.
                    last_logits = torch.stack([
                        logits[i, len(seq), :VOCAB_SIZE] for i, seq in enumerate(token_lists)
                    ])
                    log_probs = torch.nn.functional.log_softmax(last_logits, dim=1)
                    loss_policy = kl_loss(log_probs, policy_target_tensor)

                    # Value loss (quantile)
                    loss_value = quantile_loss_fn(value_pred, value_target_tensor)

                    loss = loss_policy + loss_value

                    if not (torch.isnan(loss) or torch.isinf(loss)):
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                        optimizer.step()
                        losses.append(loss.item())

            # Step the plateau scheduler on the recent training loss
            if losses:
                scheduler.step(np.mean(losses[-10:]))

            # Periodic save
            if (iteration + 1) % 10 == 0:
                save_model()

        save_model()

        fig = create_selfplay_plot(losses, rmses)

        avg_rmse = np.mean(rmses[-50:]) if rmses else 0
        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #ff6b6b;">
            <h2 style="color: #ff6b6b; margin: 0;">Self-Play Completado</h2>
            <p style="color: white;">Iteraciones: {int(iterations)} | Problemas: {len(rmses)}</p>
            <p style="color: #888;">RMSE Promedio: {avg_rmse:.4f} | Dispositivo: {DEVICE.type.upper()}</p>
        </div>
        """
        return result, fig

    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        # Always leave the shared model in inference mode and release the lock.
        if MODEL is not None:
            MODEL.eval()
        TRAINING_STATUS["running"] = False


def create_loss_plot(losses, title):
    """
    Plot a single loss curve on a dark-themed matplotlib figure.

    Args:
        losses: Sequence of loss values, plotted against their index.
        title: Figure title.

    Returns:
        The created matplotlib figure (8x4 inches).
    """
    background, accent, text = '#1a1a2e', '#00d4ff', 'white'

    fig, ax = plt.subplots(figsize=(8, 4), facecolor=background)
    ax.set_facecolor(background)
    ax.plot(losses, color=accent, linewidth=2)

    ax.set_title(title, color=text, fontweight='bold')
    ax.set_xlabel('Epoca', color=text)
    ax.set_ylabel('Loss', color=text)
    ax.tick_params(colors=text)
    ax.grid(True, alpha=0.2)

    for spine in ax.spines.values():
        spine.set_color(accent)

    plt.tight_layout()
    return fig


def create_selfplay_plot(losses, rmses):
    """
    Build a side-by-side dark-themed summary figure for a self-play run.

    Left panel: the training-loss series (if any). Right panel: the per-problem
    RMSE series drawn faintly, plus a 10-point moving average once more than
    10 values exist.

    Returns:
        The created matplotlib figure (12x4 inches).
    """
    background = '#1a1a2e'
    fig, (loss_ax, rmse_ax) = plt.subplots(1, 2, figsize=(12, 4), facecolor=background)

    def _decorate(ax, xlabel, ylabel, title):
        ax.set_xlabel(xlabel, color='white')
        ax.set_ylabel(ylabel, color='white')
        ax.set_title(title, color='white', fontweight='bold')
        ax.tick_params(colors='white')
        ax.grid(True, alpha=0.2)

    loss_ax.set_facecolor(background)
    if losses:
        loss_ax.plot(losses, color='#00d4ff', linewidth=2)
    _decorate(loss_ax, 'Step', 'Loss', 'Policy Loss')

    rmse_ax.set_facecolor(background)
    if rmses:
        rmse_ax.plot(rmses, color='#ff6b6b', linewidth=1, alpha=0.5)
        window = 10
        if len(rmses) > window:
            # 'valid' convolution yields len(rmses) - window + 1 points, aligned
            # to the last sample of each window.
            smoothed = np.convolve(rmses, np.ones(window) / window, mode='valid')
            rmse_ax.plot(range(window - 1, len(rmses)), smoothed, color='#ff6b6b', linewidth=2)
    _decorate(rmse_ax, 'Problema', 'RMSE', 'RMSE')

    for ax in (loss_ax, rmse_ax):
        for spine in ax.spines.values():
            spine.set_color('#00d4ff')

    plt.tight_layout()
    return fig

def train_supervised(iterations, batch_size=128, point_count=10, progress=gr.Progress()):
    """
    Massive supervised pre-training (warm-up).

    Focus: syntax, basic arithmetic, and overcoming "collapse to constant".
    Speed: high (no MCTS, just random generation + cross-entropy with
    teacher forcing on Stage-0 expressions of depth 1).

    Args:
        iterations: Number of optimisation steps (cast to int; UI sliders may
            pass floats).
        batch_size: Expressions generated per step.
        point_count: Sample points generated per expression.
        progress: Gradio progress tracker.

    Returns:
        (html_summary, matplotlib_figure) on success, or (message, None) when
        training is already running or an error occurred.
    """
    global TRAINING_STATUS

    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True
    reset_stop_flag()  # Clear any stop request left over from a previous run

    n_iters = int(iterations)
    MODEL = None

    try:
        MODEL, DEVICE = get_model()

        MODEL.train()
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=1e-4, weight_decay=0.01)
        # Slower decay: T_max = 2 * iterations keeps the LR higher for longer.
        # max(1, ...) avoids a degenerate T_max of 0 when iterations is 0.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(1, int(iterations * 2)), eta_min=1e-6
        )
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # SOS uses the extra logit slot past the vocabulary

        # Start extremely simple (depth 1: x+1, x*x, etc.)
        allowed_ops = OPERATOR_STAGES[0]
        data_gen = DataGenerator(max_depth=1, allowed_operators=allowed_ops)
        allowed_mask = get_allowed_token_mask(0, VOCAB_SIZE, DEVICE)  # Stage 0 mask
        losses = []

        start_time = time.time()

        for i in range(n_iters):
            if should_stop_training():
                print("⏹️ Pre-training stopped by user")
                break

            # ETA from the average duration of the i completed iterations.
            elapsed = time.time() - start_time
            if i > 0 and elapsed > 0:
                eta_str = f"{(n_iters - i) * elapsed / i:.0f}s"
            else:
                eta_str = "..."

            current_lr = optimizer.param_groups[0]['lr']
            recent_loss = np.mean(losses[-50:]) if losses else 0
            msg = f"Iter {i+1}/{n_iters} Loss:{recent_loss:.3f} LR:{current_lr:.1e} ETA:{eta_str}"
            progress((i + 1) / n_iters, desc=msg)

            # Generate a random batch (no search involved, so this is fast)
            batch = data_gen.generate_batch(int(batch_size), point_count=int(point_count))
            if not batch:
                continue

            x_list = [d['x'] for d in batch]
            y_list = [d['y'] for d in batch]
            x_list, y_list = normalize_batch(x_list, y_list)

            # Tokens missing from TOKEN_TO_ID fall back to the 'C' placeholder id.
            token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in batch]

            # Teacher forcing: decoder input rows are [SOS, t0..tn-1] (padded with SOS),
            # target rows are [t0..tn-1] padded with -1 (ignored by the loss).
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)
            for j, seq in enumerate(token_lists):
                seq_t = torch.tensor(seq, dtype=torch.long)
                decoder_input[j, 1:len(seq) + 1] = seq_t
                targets[j, :len(seq)] = seq_t

            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            logits, _ = MODEL(x_tensor, y_tensor, decoder_input)

            # Apply the Stage-0 mask to bridge pre-training with the curriculum.
            # -1e4 (instead of -1e9) keeps the masked logits numerically stable.
            logits = logits + (1 - allowed_mask.view(1, 1, -1)) * -1e4

            loss = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))

            # Skip the whole update on a nan/inf loss.
            if torch.isfinite(loss):
                loss.backward()
                torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                losses.append(loss.item())

            if (i + 1) % 100 == 0:
                save_model()

        save_model()

        fig = create_loss_plot(losses, "Pre-Entrenamiento Supervisado")

        # No finite loss may have been recorded (stopped immediately, 0 iterations,
        # or every step was non-finite); don't crash on losses[-1].
        final_loss = f"{losses[-1]:.4f}" if losses else "N/A"

        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #ffd93d;">
            <h2 style="color: #ffd93d; margin: 0;">Escuela Primaria (Warmup) Completada</h2>
            <p style="color: white;">Iteraciones: {int(iterations)} | Loss Final: {final_loss}</p>
            <p style="color: #888;">El modelo ha aprendido sintaxis basica.</p>
        </div>
        """
        return result, fig

    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        # Always release the training lock and leave the model in inference
        # mode, whether training finished, was stopped, or raised.
        if MODEL is not None:
            MODEL.eval()
        TRAINING_STATUS["running"] = False


def train_hybrid_feedback_loop(iterations, problems_per_iter=10, gp_timeout=10, progress=gr.Progress()):
    """
    Teacher-student distillation loop.

    Each iteration:
      1. Hard mining: score a pool of generated problems with the model and keep
         those whose mean token cross-entropy exceeds a threshold, hardest first.
      2. Teacher: run ``hybrid_solve`` (beam-seeded GP) on up to
         ``problems_per_iter`` of them; solutions with RMSE < 0.01 go to a
         replay buffer.
      3. Student: once the buffer holds more than 10 examples, take a few
         supervised steps on random samples from it.

    Args:
        iterations: Number of outer iterations (cast to int).
        problems_per_iter: Max hard problems handed to the GP teacher per iteration.
        gp_timeout: Passed through to ``hybrid_solve``.
        progress: Gradio progress tracker.

    Returns:
        (html_summary, matplotlib_figure) on success, or (message, None) when
        training is already running or a fatal error occurred.
    """
    global TRAINING_STATUS

    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True
    reset_stop_flag()

    n_iters = int(iterations)
    pool_size = 50             # candidate problems generated per iteration for hard mining
    hard_loss_threshold = 0.5  # mean token cross-entropy above which the model is "confused"
    success_rmse = 0.01        # teacher solutions must fit at least this well to be kept
    student_steps = 5          # gradient steps on the replay buffer per iteration
    MODEL = None

    try:
        MODEL, DEVICE = get_model()

        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
        per_token_ce = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # SOS uses the extra logit slot past the vocabulary

        def encode(items):
            """Build teacher-forcing tensors (x, y, decoder_input, targets) on DEVICE.

            decoder_input rows are [SOS, t0..tn-1] padded with SOS; target rows
            are [t0..tn-1] padded with -1 (ignored by the losses).
            """
            x_list = [d['x'] for d in items]
            y_list = [d['y'] for d in items]
            x_list, y_list = normalize_batch(x_list, y_list)
            # Tokens missing from TOKEN_TO_ID fall back to the 'C' placeholder id.
            token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in items]
            max_len = max(len(s) for s in token_lists)
            dec_in = torch.full((len(items), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(items), max_len + 1), -1, dtype=torch.long)
            for j, seq in enumerate(token_lists):
                seq_t = torch.tensor(seq, dtype=torch.long)
                dec_in[j, 1:len(seq) + 1] = seq_t
                targets[j, :len(seq)] = seq_t
            x_t = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_t = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            return x_t, y_t, dec_in.to(DEVICE), targets.to(DEVICE)

        # Replay buffer for "gold standard" examples found by the GP teacher.
        replay_buffer = deque(maxlen=5000)

        data_gen = DataGenerator(max_depth=3)

        losses = []
        gp_successes = 0
        gp_attempts = 0

        start_time = time.time()

        for iteration in range(n_iters):
            if should_stop_training():
                print("⏹️ Feedback Loop stopped")
                break

            # ETA from the average duration of the `iteration` completed iterations.
            elapsed = time.time() - start_time
            eta_str = f"{(n_iters - iteration) * elapsed / iteration:.0f}s" if iteration > 0 else "..."
            recent_loss = np.mean(losses[-10:]) if losses else 0
            progress((iteration + 1) / n_iters,
                     desc=f"Iter {iteration+1}/{n_iters} | GP Success: {gp_successes}/{gp_attempts} | Loss: {recent_loss:.3f} | ETA: {eta_str}")

            # --- PHASE 1: HARD MINING ---
            MODEL.eval()

            candidates = data_gen.generate_inverse_batch(pool_size, point_count=10)
            # The generator may return fewer (or zero) items; size everything by
            # the actual number of candidates.
            if not candidates:
                continue

            with torch.no_grad():
                x_t, y_t, dec_in, targets = encode(candidates)
                logits, _ = MODEL(x_t, y_t, dec_in)

                token_losses = per_token_ce(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))
                token_losses = token_losses.view(len(candidates), -1)

                # Mean cross-entropy over each sequence's real (non-padding) tokens.
                mask = (targets != -1)
                sample_losses = ((token_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)).tolist()

            # Keep the problems the model is confused by, hardest first, and cap how
            # many go to the (slow) GP teacher this iteration.
            hard = [(loss, prob) for loss, prob in zip(sample_losses, candidates) if loss > hard_loss_threshold]
            hard.sort(key=lambda pair: pair[0], reverse=True)
            problems_to_solve = [prob for _, prob in hard[:int(problems_per_iter)]]

            if not problems_to_solve:
                continue

            # --- PHASE 2: TEACHER SOLVES (GP) ---
            print(f"Iter {iteration}: Asking Teacher to solve {len(problems_to_solve)} hard problems...")

            for prob in problems_to_solve:
                gp_attempts += 1
                try:
                    # Quick hybrid search; the model is passed so beam search can seed the GP.
                    res = hybrid_solve(
                        prob['x'],
                        prob['y'],
                        MODEL,
                        DEVICE,
                        beam_width=10,     # Faster beam
                        gp_timeout=gp_timeout,
                        gp_binary_path=None
                    )

                    if res and res.get('formula') and res.get('rmse', 1e6) < success_rmse:
                        gp_successes += 1

                        # Convert the teacher's infix formula into training tokens.
                        try:
                            tree = ExpressionTree.from_infix(res['formula'])
                            replay_buffer.append({
                                'x': prob['x'],
                                'y': prob['y'],
                                'tokens': tree.tokens,
                                'source': 'GP_Teacher'
                            })
                        except Exception as e:
                            print(f"Failed to tokenize GP result: {e}")

                except Exception as e:
                    print(f"GP Hybrid Error: {e}")

            # --- PHASE 3: STUDENT TRAINS (NN) ---
            if len(replay_buffer) > 10:
                MODEL.train()
                batch_size_train = min(len(replay_buffer), 64)
                buffer_items = list(replay_buffer)  # snapshot; the buffer is not modified below
                step_losses = []

                # Several steps per iteration to enforce learning.
                for _ in range(student_steps):
                    batch = random.sample(buffer_items, batch_size_train)
                    x_t, y_t, dec_in, targets = encode(batch)

                    optimizer.zero_grad()
                    logits, value_pred = MODEL(x_t, y_t, dec_in)

                    # Policy loss: plain supervised cross-entropy on the teacher's tokens.
                    loss_ce = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))
                    # Value loss: teacher solutions are treated as correct (target 1.0).
                    loss_value = torch.nn.functional.mse_loss(value_pred, torch.ones_like(value_pred))
                    loss = loss_ce + 0.1 * loss_value

                    # Never step the optimizer on a nan/inf loss.
                    if not torch.isfinite(loss):
                        continue

                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                    optimizer.step()
                    step_losses.append(loss.item())

                losses.extend(step_losses)
                # Only feed the plateau scheduler when this iteration produced a real loss.
                if step_losses:
                    scheduler.step(np.mean(losses[-10:]))

            if (iteration + 1) % 5 == 0:
                save_model()

        save_model()

        fig = create_loss_plot(losses, "Feedback Loop Loss")

        result_html = f"""
        <div style="background: linear-gradient(135deg, #2c3e50 0%, #000000 100%); padding: 20px; border-radius: 15px; border: 2px solid #f1c40f;">
            <h2 style="color: #f1c40f; margin: 0;">Feedback Loop Completado</h2>
            <p style="color: white;">Iteraciones: {n_iters} | GP Success: {gp_successes}/{gp_attempts}</p>
            <p style="color: #bbb;">Nuevos Ejemplos Generados: {len(replay_buffer)}</p>
        </div>
        """
        return result_html, fig

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error CRITICO: {str(e)}", None
    finally:
        # Always release the training lock and leave the model in inference mode.
        if MODEL is not None:
            MODEL.eval()
        TRAINING_STATUS["running"] = False


In [None]:
%%writefile ui/app_benchmark.py
import gradio as gr
from utils.benchmark_comparison import run_comparison_benchmark
from ui.app_core import get_model, DEVICE

def get_benchmark_tab():
    """Create the benchmark ("IQ Test") tab and wire its run button.

    Intended to be called while building a ``gr.Blocks`` layout. The tab lets
    the user choose search methods and a GP timeout, runs
    ``run_comparison_benchmark`` on the currently loaded model, and shows one
    summary card per method plus a per-problem results table.
    """
    with gr.Tab("🥇 Benchmark (IQ Test)"):
        gr.Markdown("### Evaluar Inteligencia del Modelo (Comparativa)")
        gr.Markdown("Ejecuta una batería de **10 problemas estándar** comparando diferentes métodos de búsqueda.")

        with gr.Row():
            methods_chk = gr.CheckboxGroup(
                choices=["beam", "mcts", "hybrid"],
                value=["hybrid"],
                label="Métodos a Evaluar",
                info="Selecciona uno o más métodos para comparar."
            )
            timeout_slider = gr.Slider(
                minimum=5,
                maximum=60,
                value=30,
                step=5,
                label="Timeout GP (s)",
                info="Tiempo máximo para Beta-GP por problema."
            )

        run_btn = gr.Button("🚀 Iniciar Benchmark Comparativo", variant="primary")

        # NOTE(review): not connected to any event; kept so the layout is unchanged.
        progress_bar = gr.HTML("")

        # Results area
        summary_html = gr.HTML("Resultados aparecerán aquí...")

        results_df = gr.Dataframe(
            headers=["Problema", "Nivel", "Método", "Formula", "RMSE", "Tiempo", "Estado"],
            label="Resultados Detallados",
            interactive=False
        )

        def run_bench(selected_methods, gp_timeout, progress=gr.Progress()):
            """Run the comparison benchmark and return ``(summary_html, table_rows)``.

            Every return path yields exactly two values, matching the two output
            components (summary_html, results_df) wired in ``run_btn.click``.
            """
            model_obj, device_obj = get_model()
            if not model_obj:
                return "<div>⚠️ Error: Modelo no cargado. Ve a la pestaña 'Config' y carga un modelo.</div>", []

            if not selected_methods:
                return "<div>⚠️ Error: Selecciona al menos un método.</div>", []

            progress(0, desc="Iniciando Benchmark...")

            # Run comparison
            try:
                result_data = run_comparison_benchmark(
                    model_obj,
                    device_obj,
                    methods=selected_methods,
                    gp_timeout=gp_timeout,
                    beam_width=50,
                    progress_callback=lambda p, desc: progress(p, desc=desc)
                )
            except Exception as e:
                import traceback
                traceback.print_exc()
                return f"<div>❌ Error en Benchmark: {e}</div>", []

            results = result_data['results']
            summary_dict = result_data['summary']

            # One table row per (problem, method) result.
            rows = []
            for r in results:
                status_icon = "✅" if r['success'] else "❌"
                rmse_val = f"{r['rmse']:.5f}" if r['rmse'] < 1e6 else "> 10^6"
                rows.append([
                    r['problem_name'],
                    r['level'],
                    r['method'].upper(),
                    r['formula'],
                    rmse_val,
                    f"{r['time']:.2f}s",
                    status_icon
                ])

            # Summary cards
            html_content = "<div style='display: flex; gap: 20px; flex-wrap: wrap; justify-content: center;'>"

            # Declare a winner only when comparing several methods: most problems
            # solved wins, lower average RMSE breaks ties.
            winner_method = None
            if len(selected_methods) > 1 and summary_dict:
                winner_method = max(summary_dict.items(), key=lambda x: (x[1]['solved'], -x[1]['avg_rmse']))[0]

            for method, stats in summary_dict.items():
                is_winner = (method == winner_method)
                border_color = "#4CAF50" if is_winner else ("#FF9800" if stats['score'] > 50 else "#F44336")
                bg_color = "#1b3a24" if is_winner else "#1e1e2f"  # dark green tint for the winner

                trophy = "🏆 GANADOR" if is_winner else ""

                html_content += f"""
                <div style="background: {bg_color}; padding: 15px; border-radius: 10px; border: 2px solid {border_color}; min-width: 200px; text-align: center;">
                    <h2 style="color: {border_color}; margin: 0 0 10px 0;">{method.upper()} {trophy}</h2>
                    <div style="font-size: 24px; font-weight: bold; margin-bottom: 5px;">{stats['solved']} / {stats['total']}</div>
                    <div style="color: #ccc; font-size: 14px;">Resueltos</div>
                    <hr style="border-color: #444; margin: 10px 0;">
                    <div style="font-size: 14px;">Nota: <b>{stats['score']:.1f}%</b></div>
                    <div style="font-size: 14px;">Tiempo Avg: <b>{stats['avg_time']:.2f}s</b></div>
                </div>
                """
            html_content += "</div>"

            return html_content, rows

        run_btn.click(run_bench, inputs=[methods_chk, timeout_slider], outputs=[summary_html, results_df])


In [None]:
%%writefile ui/__init__.py


In [None]:
%%writefile utils/optimize_constants.py
"""
Constant Optimization Module for AlphaSymbolic.
Uses scipy.optimize to find optimal values for 'C' placeholders.
"""
import numpy as np
from scipy.optimize import minimize
from core.grammar import ExpressionTree

def optimize_constants(tree, x_data, y_data, method='L-BFGS-B', bound=1000.0, maxiter=1000):
    """
    Fit the values of an ExpressionTree's 'C' placeholders to data.

    Minimises the mean squared error between
    ``tree.evaluate(x_data, constants=...)`` and ``y_data`` with
    ``scipy.optimize.minimize``, starting from all-ones.

    Args:
        tree: ExpressionTree-like object exposing ``is_valid``,
            ``root.get_constant_positions()`` and ``evaluate(x, constants=None)``.
        x_data: numpy array of x values.
        y_data: numpy array of target y values.
        method: ``scipy.optimize.minimize`` method ('L-BFGS-B', 'SLSQP',
            'Nelder-Mead', ...). Box bounds are only passed for 'L-BFGS-B'
            and 'SLSQP'.
        bound: Constants are constrained to [-bound, bound] (bounded methods only).
        maxiter: Maximum number of optimiser iterations.

    Returns:
        dict: mapping of constant-position tuples to optimised values.
        float: final RMSE. ``inf`` when the tree is invalid, the optimiser
            raises, or the result is non-finite (never reported as 0).
    """
    if not tree.is_valid:
        return {}, float('inf')

    # Get positions of all constants
    positions = tree.root.get_constant_positions()
    n_constants = len(positions)

    if n_constants == 0:
        # No constants to optimise, just evaluate
        y_pred = tree.evaluate(x_data)
        mse = np.mean((y_pred - y_data) ** 2)
        # A nan/inf prediction must not be reported as a valid score.
        return {}, (np.sqrt(mse) if np.isfinite(mse) else float('inf'))

    def objective(params):
        """MSE of the tree's predictions for the given constant values.

        Non-finite predictions or a non-finite MSE map to a large finite
        penalty (1e10) so the optimiser can keep going.
        """
        constants = {tuple(pos): params[i] for i, pos in enumerate(positions)}
        y_pred = tree.evaluate(x_data, constants=constants)

        if not np.all(np.isfinite(y_pred)):
            return 1e10

        # Clip huge values to prevent overflow when squaring.
        y_pred = np.clip(y_pred, -1e9, 1e9)

        mse = np.mean((y_pred - y_data) ** 2)
        return mse if np.isfinite(mse) else 1e10

    x0 = np.ones(n_constants)
    bounds = [(-bound, bound)] * n_constants

    try:
        result = minimize(
            objective,
            x0,
            method=method,
            bounds=bounds if method in ['L-BFGS-B', 'SLSQP'] else None,
            options={'maxiter': maxiter, 'disp': False}
        )

        optimized_constants = {tuple(pos): result.x[i] for i, pos in enumerate(positions)}

        fun = float(result.fun)
        if not np.isfinite(fun):
            # nan > 0 is False, so without this check a nan objective would be
            # reported as a perfect 0.0 RMSE.
            final_rmse = float('inf')
        else:
            final_rmse = np.sqrt(fun) if fun > 0 else 0.0

        return optimized_constants, final_rmse

    except Exception:
        return {}, float('inf')

def substitute_constants(infix_str, constants_dict, positions):
    """
    Replace the 'C' placeholders in an infix string with fitted values.

    The k-th 'C' (left to right) in ``infix_str`` is paired with ``positions[k]``.
    This assumes ``positions`` lists constants in the same left-to-right order in
    which they appear in the infix rendering (TODO confirm against
    ``get_constant_positions``). A placeholder whose position is missing from
    ``constants_dict`` is left as 'C', and the placeholders after it still get
    their own values.

    Values within 1e-6 of an integer are written as integers, others with four
    decimals; non-finite values are written as 'inf'/'nan'.

    Args:
        infix_str: Formula text containing 'C' placeholders.
        constants_dict: Mapping of position tuples to values (e.g. from
            ``optimize_constants``).
        positions: Sequence of constant positions (lists or tuples).

    Returns:
        str: The formula with values substituted.
    """
    def _fmt(val):
        # round() raises on inf/nan, so only integer-snap finite values.
        if np.isfinite(val) and abs(val - round(val)) < 1e-6:
            return str(int(round(val)))
        return f"{val:.4f}"

    pieces = infix_str.split('C')
    out = [pieces[0]]
    for k, tail in enumerate(pieces[1:]):
        key = tuple(positions[k]) if k < len(positions) else None
        if key is not None and key in constants_dict:
            out.append(_fmt(constants_dict[key]))
        else:
            out.append('C')
        out.append(tail)
    return ''.join(out)


# Quick test
if __name__ == "__main__":
    # Smoke test: the structure C*x + C should be fitted to y = 2x + 3.
    x_test = np.arange(1, 6, dtype=np.float64)
    y_test = 2 * x_test + 3

    demo_tree = ExpressionTree(['+', '*', 'C', 'x', 'C'])  # prefix tokens for C*x + C

    print(f"Formula structure: {demo_tree.get_infix()}")
    print(f"Target: y = 2x + 3")

    fitted, final_rmse = optimize_constants(demo_tree, x_test, y_test)
    print(f"Optimized constants: {fitted}")
    print(f"Final RMSE: {final_rmse:.6f}")

    # Re-evaluate with the fitted constants to eyeball the fit.
    print(f"Predictions: {demo_tree.evaluate(x_test, constants=fitted)}")
    print(f"Targets: {y_test}")


In [None]:
%%writefile utils/detect_pattern.py
"""
Target Pattern Detection for AlphaSymbolic.
Analyzes target Y values to detect patterns (polynomial, exponential, periodic, etc.)
and suggests initial search biases.
"""
import numpy as np
from scipy import stats
from scipy.fft import fft
from core.grammar import ExpressionTree

def detect_pattern(x_values, y_values):
    """
    Heuristically classify the shape of (x, y) data.

    Each candidate family is scored independently and the highest finite score
    wins:
      - linear, quadratic, exponential (all y > 0), power (all x > 0 and y > 0),
        factorial (positive integer x): R^2 of a least-squares fit (done in log
        space for exponential/power, and as y vs gamma(x+1) for factorial).
      - periodic: share of the strongest non-DC FFT bin within the first half of
        the spectrum. It is only computed when that band holds at least two
        bins; with a single bin the ratio would trivially be 1.0.
    The periodic score is a ratio rather than an R^2, so comparing it with the
    other families is a heuristic.

    Args:
        x_values: Sequence of x values.
        y_values: Sequence of y values (same length).

    Returns:
        dict with 'type' (best family or 'unknown'), 'confidence' (its score),
        'suggested_ops' (token hints for that family) and 'details'
        (per-family fit parameters). Fewer than 3 points returns 'unknown'.
    """
    x = np.array(x_values, dtype=np.float64)
    y = np.array(y_values, dtype=np.float64)

    results = {
        'type': 'unknown',
        'confidence': 0.0,
        'suggested_ops': [],
        'details': {}
    }

    if len(x) < 3:
        return results

    scores = {}

    # 1. Linear (y = ax + b). linregress can raise, e.g. when all x are identical.
    try:
        slope, intercept, r_value, _, _ = stats.linregress(x, y)
        scores['linear'] = r_value ** 2
        results['details']['linear'] = {
            'slope': slope,
            'intercept': intercept,
            'r_squared': r_value ** 2
        }
    except Exception:
        pass

    # 2. Quadratic (y = ax^2 + bx + c)
    try:
        coeffs = np.polyfit(x, y, 2)
        y_pred = np.polyval(coeffs, x)
        ss_res = np.sum((y - y_pred) ** 2)
        ss_tot = np.sum((y - np.mean(y)) ** 2)
        r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
        scores['quadratic'] = r2
        results['details']['quadratic'] = {
            'coefficients': coeffs.tolist(),
            'r_squared': r2
        }
    except Exception:
        pass

    # 3. Exponential (y = a * e^(bx)); log-linear fit, only for positive y.
    if np.all(y > 0):
        try:
            log_y = np.log(y)
            slope, intercept, r_value, _, _ = stats.linregress(x, log_y)
            scores['exponential'] = r_value ** 2
            results['details']['exponential'] = {
                'a': np.exp(intercept),
                'b': slope,
                'r_squared': r_value ** 2
            }
        except Exception:
            pass

    # 4. Periodic: fraction of spectral magnitude in the dominant non-DC bin.
    try:
        fft_vals = np.abs(fft(y - np.mean(y)))
        half = len(fft_vals) // 2
        if half > 2:  # need at least two candidate bins (indices 1..half-1)
            band = fft_vals[1:half]
            max_idx = int(np.argmax(band)) + 1
            max_power = fft_vals[max_idx]
            total_power = np.sum(band)

            if total_power > 0:
                periodicity = max_power / total_power
                scores['periodic'] = periodicity
                results['details']['periodic'] = {
                    'dominant_freq_idx': int(max_idx),
                    'periodicity_score': periodicity
                }
    except Exception:
        pass

    # 5. Power law (y = a * x^b); log-log fit, only for positive x and y.
    if np.all(x > 0) and np.all(y > 0):
        try:
            log_x = np.log(x)
            log_y = np.log(y)
            slope, intercept, r_value, _, _ = stats.linregress(log_x, log_y)
            scores['power'] = r_value ** 2
            results['details']['power'] = {
                'a': np.exp(intercept),
                'b': slope,
                'r_squared': r_value ** 2
            }
        except Exception:
            pass

    # 6. Factorial/gamma (positive integer x): linear fit of y against gamma(x + 1) = x!.
    if np.all(x > 0) and np.all(x == np.floor(x)):
        try:
            from scipy.special import gamma
            x_int = x.astype(int)
            y_gamma = gamma(x_int + 1)

            if not np.any(np.isinf(y_gamma)):
                slope, intercept, r_value, _, _ = stats.linregress(y_gamma, y)
                scores['factorial'] = r_value ** 2
                results['details']['factorial'] = {
                    'r_squared': r_value ** 2
                }
        except Exception:
            pass

    # Pick the best pattern, ignoring nan/inf scores from degenerate fits
    # (nan would make max() depend on insertion order).
    finite_scores = {name: score for name, score in scores.items() if np.isfinite(score)}
    if finite_scores:
        best_type, best_score = max(finite_scores.items(), key=lambda kv: kv[1])
        results['type'] = best_type
        results['confidence'] = best_score

        # Suggest operators based on pattern
        op_suggestions = {
            'linear': ['+', '-', '*', 'x', 'C'],
            'quadratic': ['pow', '+', '*', 'x', 'C', '2'],
            'exponential': ['exp', '*', '+', 'x', 'C'],
            'periodic': ['sin', 'cos', '*', '+', 'x', 'C'],
            'power': ['pow', '*', 'x', 'C'],
            'factorial': ['gamma', '*', '+', 'x', 'C']
        }
        results['suggested_ops'] = op_suggestions.get(best_type, [])

    return results


def summarize_pattern(result):
    """Print a short human-readable report of a ``detect_pattern`` result.

    Args:
        result: dict with 'type', 'confidence', 'suggested_ops' and 'details' keys.
    """
    print(f"\n=== Pattern Detection ===")
    detected = result['type']
    print(f"Detected Type: {detected} (confidence: {result['confidence']:.2%})")
    print("Suggested Operators: " + ", ".join(result['suggested_ops']))

    detail_map = result['details']
    if detected in detail_map:
        print(f"Details: {detail_map[detected]}")


if __name__ == "__main__":
    # Smoke-test detect_pattern on one synthetic example per family.

    # Linear: y = 2x + 3 (plus small Gaussian noise)
    print("\n--- Test: Linear ---")
    x_lin = np.linspace(0, 10, 20)
    y_lin = 2 * x_lin + 3 + np.random.normal(0, 0.1, 20)
    summarize_pattern(detect_pattern(x_lin, y_lin))

    # Quadratic: y = x^2 + 1
    print("\n--- Test: Quadratic ---")
    x_quad = np.linspace(-5, 5, 20)
    summarize_pattern(detect_pattern(x_quad, x_quad**2 + 1))

    # Exponential: y = 2 * e^(0.5x)
    print("\n--- Test: Exponential ---")
    x_exp = np.linspace(0, 5, 20)
    summarize_pattern(detect_pattern(x_exp, 2 * np.exp(0.5 * x_exp)))

    # Periodic: y = sin(x)
    print("\n--- Test: Periodic ---")
    x_sin = np.linspace(0, 4*np.pi, 50)
    summarize_pattern(detect_pattern(x_sin, np.sin(x_sin)))


In [None]:
%%writefile utils/benchmark_runner.py
import torch
import numpy as np
import time
import traceback
from search.mcts import MCTS
from data.benchmark_data import BENCHMARK_SUITE, get_benchmark_data
from utils.optimize_constants import optimize_constants

def run_benchmark_suite(model, device, progress_callback=None, success_threshold=0.05):
    """
    Run every problem in BENCHMARK_SUITE through MCTS and score the results.

    Args:
        model: Loaded AlphaSymbolic model.
        device: Torch device.
        progress_callback: Optional ``fn(fraction, message)`` used to update the UI.
        success_threshold: A problem counts as solved when the best RMSE is
            strictly below this value (default 0.05).

    Returns:
        results: list of per-problem dicts with keys id, name, level, rmse,
            time, status, found_formula, is_solved.
        summary: dict with total, solved, score (percent; 0.0 for an empty
            suite) and avg_time.
    """
    results = []

    # One MCTS instance is reused for every problem; 500 simulations is a
    # speed/accuracy compromise for benchmarking.
    mcts = MCTS(model, device, max_simulations=500, batch_size=32)

    total = len(BENCHMARK_SUITE)
    solved_count = 0

    for i, problem in enumerate(BENCHMARK_SUITE):
        if progress_callback:
            progress_callback(i / total, f"Testing: {problem['name']}...")

        x, y, _ = get_benchmark_data(problem['id'])

        start_time = time.time()

        try:
            search_result = mcts.search(x, y)
            # RMSE is the ground truth for success; there is no symbolic-equivalence check.
            rmse = search_result['rmse']
            is_solved = rmse < success_threshold

            elapsed = time.time() - start_time

            if is_solved:
                solved_count += 1
                status = "✅ SOLVED"
            else:
                status = "❌ FAILED"

            results.append({
                'id': problem['id'],
                'name': problem['name'],
                'level': problem['level'],
                'rmse': rmse,
                'time': elapsed,
                'status': status,
                'found_formula': search_result.get('formula', '???'),
                'is_solved': is_solved
            })

        except Exception:
            print(f"Error in benchmark {problem['name']}:")
            traceback.print_exc()
            results.append({
                'id': problem['id'],
                'name': problem['name'],
                'level': problem['level'],
                'rmse': 1e9,
                # Record the time spent before the failure so avg_time is not understated.
                'time': time.time() - start_time,
                'status': "⚠️ ERROR",
                'found_formula': "Error",
                'is_solved': False
            })

    if progress_callback:
        progress_callback(1.0, "Done!")

    # Guard against an empty suite (division by zero).
    score = (solved_count / total) * 100 if total else 0.0
    summary = {
        'total': total,
        'solved': solved_count,
        'score': score,
        'avg_time': np.mean([r['time'] for r in results]) if results else 0
    }

    return results, summary


In [None]:
%%writefile utils/benchmark_comparison.py
"""
Comparative Benchmark: Beam Search vs MCTS vs Alpha-GP Hybrid
Runs all three search methods on the standard benchmark suite and compares performance.
"""
import torch
import numpy as np
import time
import traceback
from typing import List, Dict, Callable, Optional

from search.mcts import MCTS
from search.beam_search import BeamSearch
from search.hybrid_search import hybrid_solve
from data.benchmark_data import BENCHMARK_SUITE, get_benchmark_data
from core.grammar import ExpressionTree
from utils.optimize_constants import optimize_constants


def _finite_rmse(value) -> float:
    """Return ``value`` as a float; missing, non-numeric, NaN or infinite values become the 1e9 failure sentinel."""
    try:
        value = float(value)
    except (TypeError, ValueError):
        return 1e9
    return value if np.isfinite(value) else 1e9


def run_single_problem(
    x: np.ndarray, 
    y: np.ndarray, 
    method: str, 
    model, 
    device,
    timeout_sec: int = 30,
    beam_width: int = 50,
    success_threshold: float = 0.05
) -> Dict:
    """
    Run one search method on one (x, y) problem and score the best formula found.

    Args:
        x: Input samples (converted with ``.tolist()`` before being handed to the searchers).
        y: Target values, aligned with ``x``.
        method: One of ``"beam"``, ``"mcts"`` or ``"hybrid"``; anything else yields
            an ``'Unknown Method'`` record.
        model: Model passed through to the searcher.
        device: Device passed through to the searcher.
        timeout_sec: GP time budget; only used by ``"hybrid"``.
        beam_width: Beam width for ``"beam"`` and ``"hybrid"``.
        success_threshold: A result counts as solved when its RMSE is strictly
            below this value (default 0.05, the previously hard-coded cut-off).

    Returns:
        dict with keys ``formula``, ``rmse``, ``time`` (seconds) and ``success``.
        ``rmse`` is always finite: missing values, NaN, infinities and failures
        are reported as the 1e9 sentinel. Exceptions raised by a searcher are
        caught, printed with a traceback, and reported as ``formula='Error'``.
    """
    start_time = time.time()

    def _package(formula, raw_rmse, elapsed):
        # Single place where RMSE is sanitised and success is decided.
        rmse = _finite_rmse(raw_rmse)
        return {
            'formula': formula,
            'rmse': rmse,
            'time': elapsed,
            'success': rmse < success_threshold
        }

    try:
        if method == "beam":
            searcher = BeamSearch(model, device, beam_width=beam_width)
            # BeamSearch expects list-like input and returns a list of results sorted by RMSE
            results_list = searcher.search(x.tolist(), y.tolist())
            elapsed = time.time() - start_time
            if not results_list:
                return {'formula': 'No Result', 'rmse': 1e9, 'time': elapsed, 'success': False}
            best = results_list[0]  # Best result (sorted by RMSE)
            return _package(best.get('formula', 'N/A'), best.get('rmse', 1e9), elapsed)

        elif method == "mcts":
            mcts = MCTS(model, device, max_simulations=500, batch_size=32)
            # MCTS expects list-like input
            result = mcts.search(x.tolist(), y.tolist())
            elapsed = time.time() - start_time
            return _package(result.get('formula', 'N/A'), result.get('rmse', 1e9), elapsed)

        elif method == "hybrid":
            result = hybrid_solve(
                model=model,
                device=device,
                x_values=x.tolist(),
                y_values=y.tolist(),
                beam_width=beam_width,
                gp_timeout=timeout_sec
            )
            # Timing covers the search only; re-scoring below is excluded.
            elapsed = time.time() - start_time

            formula = result.get('formula')
            rmse = 1e9
            if formula:
                # Re-score the hybrid formula on (x, y).
                try:
                    tree = ExpressionTree.from_infix(formula)
                    if tree.is_valid:
                        preds = tree.evaluate(x)
                        rmse = np.sqrt(np.mean((preds - y) ** 2))
                except Exception:
                    # An unparseable or unevaluable formula is a failed attempt, not a crash.
                    rmse = 1e9

            return _package(formula or 'Failed', rmse, elapsed)

        else:
            return {'formula': 'Unknown Method', 'rmse': 1e9, 'time': 0, 'success': False}

    except Exception as e:
        print(f"[ERROR] Method {method} failed: {e}")
        traceback.print_exc()
        return {'formula': 'Error', 'rmse': 1e9, 'time': time.time() - start_time, 'success': False}


def run_comparison_benchmark(
    model, 
    device, 
    methods: Optional[List[str]] = None,
    gp_timeout: int = 30,
    beam_width: int = 50,
    progress_callback: Optional[Callable] = None
) -> Dict:
    """
    Run every requested search method on every problem in ``BENCHMARK_SUITE``.

    Args:
        model: Model forwarded to ``run_single_problem``.
        device: Device forwarded to ``run_single_problem``.
        methods: Method names to compare; defaults to ``["beam", "mcts", "hybrid"]``.
        gp_timeout: GP time budget forwarded to ``run_single_problem``.
        beam_width: Beam width forwarded to ``run_single_problem``.
        progress_callback: Optional ``callback(fraction, message)`` called before
            each (problem, method) run and once with ``1.0`` at the end.

    Returns:
        Dict with ``'results'`` (one record per problem and method) and ``'summary'``
        (per-method ``solved``, ``total``, ``score`` in percent, ``avg_time`` over
        all problems, and ``avg_rmse``). ``avg_rmse`` averages only results whose
        RMSE is below 1e6; it is 1e9 when a method produced no such result.
    """
    if methods is None:
        methods = ["beam", "mcts", "hybrid"]

    results = []
    # 'rmse_count' counts the results that contributed to 'total_rmse'.
    method_stats = {m: {'solved': 0, 'total_time': 0, 'total_rmse': 0, 'rmse_count': 0} for m in methods}
    
    total_steps = len(BENCHMARK_SUITE) * len(methods)
    current_step = 0
    
    for problem in BENCHMARK_SUITE:
        x, y, _ = get_benchmark_data(problem['id'])
        
        for method in methods:
            current_step += 1
            
            if progress_callback:
                progress_callback(
                    current_step / total_steps, 
                    f"[{method.upper()}] {problem['name']}..."
                )
            
            result = run_single_problem(x, y, method, model, device, gp_timeout, beam_width)
            
            results.append({
                'problem_id': problem['id'],
                'problem_name': problem['name'],
                'level': problem['level'],
                'method': method,
                'formula': result['formula'],
                'rmse': result['rmse'],
                'time': result['time'],
                'success': result['success']
            })
            
            # Update stats
            stats = method_stats[method]
            stats['total_time'] += result['time']
            # Failure sentinels (1e9) and NaN fail this test and are excluded from the mean.
            if result['rmse'] < 1e6:
                stats['total_rmse'] += result['rmse']
                stats['rmse_count'] += 1
            if result['success']:
                stats['solved'] += 1
    
    # Compute summary
    num_problems = len(BENCHMARK_SUITE)
    # An empty suite must not trigger ZeroDivisionError; scores are then 0.
    denom = num_problems if num_problems else 1
    summary = {}
    for method in methods:
        stats = method_stats[method]
        summary[method] = {
            'solved': stats['solved'],
            'total': num_problems,
            'score': (stats['solved'] / denom) * 100,
            'avg_time': stats['total_time'] / denom,
            # Mean over usable RMSEs only. Counting failures as 0 would make a method
            # that fails everything look best in tie-breaks.
            'avg_rmse': stats['total_rmse'] / stats['rmse_count'] if stats['rmse_count'] else 1e9
        }
    
    if progress_callback:
        progress_callback(1.0, "Benchmark Complete!")
    
    return {'results': results, 'summary': summary}


def format_comparison_table(results: List[Dict]) -> str:
    """
    Render benchmark records as a fixed-width, human-readable text table.

    Records are grouped by ``problem_id`` in first-seen order. Each group
    contributes one row per method, followed by a dashed separator line. Names
    are cut to 24 characters and formulas to 40. An RMSE of 1e6 or more (or NaN)
    is shown as ``FAILED``.

    Args:
        results: Records with keys ``problem_id``, ``problem_name``, ``level``,
            ``method``, ``rmse``, ``time``, ``success`` and ``formula``.

    Returns:
        The table as one newline-joined string.
    """
    heavy_rule = "=" * 100
    light_rule = "-" * 100

    # problem_id -> {'name', 'level', 'methods': {method -> row data}}
    grouped = {}
    for record in results:
        key = record['problem_id']
        entry = grouped.get(key)
        if entry is None:
            entry = {'name': record['problem_name'], 'level': record['level'], 'methods': {}}
            grouped[key] = entry
        entry['methods'][record['method']] = {
            'rmse': record['rmse'],
            'time': record['time'],
            'success': record['success'],
            'formula': record['formula'],
        }

    lines = [
        heavy_rule,
        f"{'Problem':<25} | {'Method':<8} | {'RMSE':<12} | {'Time':<8} | {'Status':<10} | Formula",
        heavy_rule,
    ]

    for entry in grouped.values():
        short_name = entry['name'][:24]
        for method_name, row in entry['methods'].items():
            if row['rmse'] < 1e6:
                rmse_text = f"{row['rmse']:.6f}"
            else:
                rmse_text = "FAILED"
            time_text = f"{row['time']:.2f}s"
            status_text = "[OK]" if row['success'] else "[FAIL]"
            formula_text = row['formula'][:40] if row['formula'] else "N/A"
            lines.append(
                f"{short_name:<25} | {method_name:<8} | {rmse_text:<12} | "
                f"{time_text:<8} | {status_text:<10} | {formula_text}"
            )
        lines.append(light_rule)

    return "\n".join(lines)


def print_summary(summary: Dict):
    """
    Print a per-method comparison table and announce the best method.

    The winner is the method with the most solved problems. Ties go to the lower
    ``avg_rmse``; if that also ties, the first such method in ``summary`` order wins.

    Args:
        summary: Mapping of method name -> stats dict with keys ``solved``,
            ``total``, ``score``, ``avg_time`` and ``avg_rmse`` (the ``'summary'``
            produced by ``run_comparison_benchmark``). An empty mapping prints the
            table frame and a "no winner" note instead of raising.
    """
    header = f"{'Method':<12} | {'Solved':<10} | {'Score':<10} | {'Avg Time':<10} | {'Avg RMSE':<12}"
    # Size the rules to the header so they span every column (fixed 60-char rules
    # were narrower than the rows they framed).
    rule_width = len(header)

    print("\n" + "=" * rule_width)
    print("BENCHMARK SUMMARY - Method Comparison")
    print("=" * rule_width)
    print(header)
    print("-" * rule_width)

    for method, stats in summary.items():
        solved_str = f"{stats['solved']}/{stats['total']}"
        score_str = f"{stats['score']:.1f}%"
        time_str = f"{stats['avg_time']:.2f}s"
        rmse_str = f"{stats['avg_rmse']:.6f}"
        print(f"{method.upper():<12} | {solved_str:<10} | {score_str:<10} | {time_str:<10} | {rmse_str:<12}")

    print("=" * rule_width)

    if not summary:
        # max() over an empty mapping would raise ValueError.
        print("\nNo methods were benchmarked; no winner to report.")
        return

    # Determine winner: most solved, then lowest average RMSE.
    best_name, best_stats = max(
        summary.items(),
        key=lambda item: (item[1]['solved'], -item[1]['avg_rmse'])
    )
    print(f"\n*** WINNER: {best_name.upper()} with {best_stats['solved']}/{best_stats['total']} problems solved! ***")


if __name__ == "__main__":
    # Standalone entry point: load the model, run every method on the suite and print the report.
    # NOTE(review): the module-level imports (search.*, data.*, core.*, utils.*) run before this
    # sys.path tweak, so they only resolve when the repository root is already importable
    # (e.g. `python -m utils.benchmark_comparison` run from the repo root).
    import sys
    sys.path.insert(0, '.')
    
    from ui.app_core import load_model, get_model
    
    print("Loading model...")
    load_model()
    model, device = get_model()
    
    if model is None:
        print("Error: No model loaded!")
        # sys.exit instead of the site-injected exit() builtin, which is missing
        # under `python -S` and in some embedded interpreters.
        sys.exit(1)
    
    print("Running comparison benchmark...")
    result = run_comparison_benchmark(
        model, 
        device, 
        methods=["beam", "mcts", "hybrid"],
        gp_timeout=30,
        beam_width=50
    )
    
    print(format_comparison_table(result['results']))
    print_summary(result['summary'])


In [None]:
%%writefile utils/simplify.py
"""
Algebraic Simplification Module for AlphaSymbolic.
Uses SymPy for symbolic math simplification.
"""
import sympy as sp
from core.grammar import Node, ExpressionTree, OPERATORS

# SymPy symbol for x
x_sym = sp.Symbol('x')

# Named terminal tokens -> SymPy atoms ('C' stays symbolic: it is the learnable-constant placeholder).
_SYMPY_TERMINALS = {
    'x': x_sym,
    'pi': sp.pi,
    'e': sp.E,
    'C': sp.Symbol('C'),
}

# Binary operator tokens -> builder taking (left, right).
_SYMPY_BINARY = {
    '+': lambda a, b: a + b,
    '-': lambda a, b: a - b,
    '*': lambda a, b: a * b,
    '/': lambda a, b: a / b,
    'pow': sp.Pow,
    'mod': sp.Mod,
}

# Unary operator tokens -> builder taking one argument.
_SYMPY_UNARY = {
    'sin': sp.sin,
    'cos': sp.cos,
    'tan': sp.tan,
    'exp': sp.exp,
    'log': sp.log,
    'sqrt': sp.sqrt,
    'abs': sp.Abs,
    'sign': sp.sign,
    'floor': sp.floor,
    'ceil': sp.ceiling,
    'gamma': sp.gamma,
    'lgamma': sp.loggamma,  # SymPy's log-gamma
    'neg': lambda a: -a,
}


def _sympy_number(val):
    """
    Return ``val`` as a SymPy number, or None when it is not numeric.

    Integer literals (e.g. the grammar constants '0', '1', '2') become exact
    ``sp.Integer`` so SymPy can cancel them (``x*1 -> x`` rather than ``1.0*x``).
    Other numeric values become ``sp.Float``.
    """
    if isinstance(val, (str, int)):
        try:
            return sp.Integer(int(val))
        except (TypeError, ValueError):
            pass
    try:
        return sp.Float(float(val))
    except (TypeError, ValueError):
        return None


def tree_to_sympy(node):
    """
    Recursively convert an ExpressionTree ``Node`` into a SymPy expression.

    Handling by node value:
      - ``None`` node -> ``0``.
      - Named terminals ``x``, ``pi``, ``e`` and ``C`` map to their SymPy atoms.
      - Numeric literals become Integer or Float.
      - Every operator of the grammar, including ``sign``, maps to its SymPy
        counterpart.
      - Any other token is kept symbolic: a ``Symbol`` for a leaf, an undefined
        ``Function`` applied to its arguments otherwise. It is never silently
        replaced by 0.

    Malformed nodes (too few children for their operator) raise ``IndexError``.
    """
    if node is None:
        return sp.Integer(0)
    
    val = node.value
    
    # Terminals
    if isinstance(val, str) and val in _SYMPY_TERMINALS:
        return _SYMPY_TERMINALS[val]
    
    number = _sympy_number(val)
    if number is not None:
        return number
    
    # Operators
    args = [tree_to_sympy(child) for child in node.children]
    
    if val in _SYMPY_BINARY:
        return _SYMPY_BINARY[val](args[0], args[1])
    if val in _SYMPY_UNARY:
        return _SYMPY_UNARY[val](args[0])
    
    # Unknown token: keep it symbolic instead of collapsing the subtree to 0.
    if not args:
        return sp.Symbol(str(val))
    return sp.Function(str(val))(*args)

def sympy_to_infix(expr):
    """
    Render a SymPy expression as plain text.

    This is exactly ``str(expr)``, so it also accepts any non-SymPy object.
    """
    rendered = str(expr)
    return rendered

def simplify_tree(tree):
    """
    Simplify an ExpressionTree with SymPy and return the result as an infix string.

    Returns:
        ``"Invalid"`` when ``tree.is_valid`` is false. Otherwise the simplified
        expression's string form. Falls back to the tree's original infix string
        when conversion or simplification raises, and when the simplified result
        contains non-real or non-finite artifacts (zoo, oo, -oo, nan, I).
    """
    if not tree.is_valid:
        return "Invalid"
    
    original_infix = tree.get_infix()
    
    try:
        simplified = sp.simplify(tree_to_sympy(tree.root))
        
        # Reject SymPy artifacts no real-valued formula should contain:
        # zoo = complex infinity, oo / -oo = infinity, nan, I = imaginary unit.
        # Inspecting the expression tree (instead of substrings of its text) also
        # catches a bare "I" or "2*I", which a "I*" substring test misses.
        if simplified.has(sp.zoo, sp.oo, -sp.oo, sp.nan, sp.I):
            return original_infix  # Fall back to original
        
        return str(simplified)
    except Exception:
        # Best effort: if conversion or simplification fails, return the original.
        return original_infix

def simplify_infix(infix_str):
    """
    Simplify an infix formula string with SymPy.

    Returns the simplified expression as a string. Returns ``infix_str``
    unchanged when parsing or simplification fails, or when the result contains
    non-real or non-finite artifacts (zoo, oo, -oo, nan, I), matching
    ``simplify_tree``'s validation.

    Security: ``sp.sympify`` evaluates string input with ``eval``. Do not pass
    untrusted user text here without sanitising it first.
    """
    try:
        simplified = sp.simplify(sp.sympify(infix_str))
        if simplified.has(sp.zoo, sp.oo, -sp.oo, sp.nan, sp.I):
            return infix_str
        return str(simplified)
    except Exception:
        # Best effort: keep the caller's string when SymPy cannot handle it
        # (without swallowing KeyboardInterrupt/SystemExit like a bare except).
        return infix_str

# Quick test
if __name__ == "__main__":
    # Smoke test: print each token sequence before and after simplification.
    demo_cases = [
        ['+', 'x', '0'],  # x + 0 -> expected x
        ['*', 'x', '1'],  # x * 1 -> expected x
        ['-', 'x', 'x'],  # x - x -> expected 0
    ]
    for case_tokens in demo_cases:
        case_tree = ExpressionTree(case_tokens)
        print(f"Original: {case_tree.get_infix()}")
        print(f"Simplified: {simplify_tree(case_tree)}")


In [None]:
%%writefile utils/__init__.py


In [None]:
%%writefile app.py
"""
AlphaSymbolic - Gradio Web Interface
With GPU/CPU toggle and search method selection.
"""
import gradio as gr
import torch

from ui.app_core import load_model, get_device, get_device_info, set_device, get_training_errors, request_stop_training
from ui.app_training import train_basic, train_curriculum, train_self_play, train_supervised, train_hybrid_feedback_loop
from ui.app_search import solve_formula, generate_example
from ui.app_benchmark import get_benchmark_tab


def toggle_device(use_gpu):
    """
    Switch the compute device and return an HTML badge describing the result.

    The badge accent is green for CUDA, amber for MPS and grey otherwise, based
    on the description string returned by ``set_device``.
    """
    info = set_device(use_gpu)
    if "CUDA" in info:
        accent = "#4ade80"
    elif "MPS" in info:
        accent = "#fbbf24"
    else:
        accent = "#888"
    return (
        '<div style="padding: 10px; background: #0f0f23; border-radius: 8px; '
        f'border-left: 3px solid {accent};">'
        f'<span style="color: {accent}; font-weight: bold;">{info}</span></div>'
    )


def create_app():
    """
    Build the AlphaSymbolic Gradio UI and return it without launching it.

    Layout, top to bottom:
      - header banner;
      - model-preset selector;
      - tabs for formula search, training, benchmark and info;
      - footer.

    Gradio places components in the order they are created inside the nested
    ``with`` blocks, so statement order here is layout order.

    Event handlers are wired to helpers imported from ``ui.app_core``,
    ``ui.app_training``, ``ui.app_search`` and ``ui.app_benchmark``. Each
    ``.click``/``.change`` call maps its input components, in order, onto the
    handler's positional parameters, and the handler's return values onto the
    listed outputs.

    Returns:
        gr.Blocks: the assembled app. The module-level code calls ``launch``.
    """
    
    with gr.Blocks(title="AlphaSymbolic") as demo:
        
        # Header
        # Device badge color: green for CUDA, amber for MPS, grey otherwise
        # (same rule as toggle_device).
        device_info = get_device_info()
        device_color = "#4ade80" if "CUDA" in device_info else "#fbbf24" if "MPS" in device_info else "#888"
        
        gr.HTML(f"""
        <div style="text-align: center; padding: 20px; background: linear-gradient(90deg, #00d4ff22, transparent, #ff6b6b22); border-radius: 15px; margin-bottom: 20px;">
            <h1 style="color: #00d4ff; font-size: 42px; margin: 0;">AlphaSymbolic</h1>
            <p style="color: #888; font-size: 18px; margin: 5px 0;">Deep Reinforcement Learning para Regresion Simbolica</p>
        </div>
        """)
        
        # System Controls
        # Preset selector; changing it reloads the model through on_model_change.
        with gr.Row():
            with gr.Column(scale=1):
                model_selector = gr.Dropdown(choices=["lite", "pro"], value="lite", label="Arquitectura (Cerebro)", interactive=True)
            with gr.Column(scale=3):
                # NOTE(review): the initial status text is hard-coded for the "lite" preset,
                # not taken from the model actually loaded at startup.
                model_status = gr.Textbox(label="Estado del Modelo", value="Lite (Laptop Optimized) - Vocabulario Extendido", interactive=False)
        
        def on_model_change(preset):
            """Reload the model for `preset`; only the status string from load_model is shown."""
            status, _ = load_model(preset_name=preset)
            return status

        model_selector.change(on_model_change, model_selector, model_status)
        
        with gr.Tabs():
            # TAB 1: Search
            with gr.Tab("Buscar Formula"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML('<h3 style="color: #00d4ff;">Datos de Entrada</h3>')
                        # X/Y values are entered as free text (placeholders show comma-separated numbers).
                        x_input = gr.Textbox(label="Valores X", placeholder="1, 2, 3, 4, 5...", lines=2)
                        y_input = gr.Textbox(label="Valores Y", placeholder="5, 7, 9, 11, 13...", lines=2)
                        
                        with gr.Row():
                            search_method = gr.Radio(
                                choices=["Beam Search", "MCTS", "Alpha-GP Hybrid"],
                                value="Alpha-GP Hybrid",
                                label="Metodo de Busqueda"
                            )
                        
                        # Passed to solve_formula as its third input; the label says it is the
                        # beam width or the simulation count depending on the method.
                        beam_slider = gr.Slider(5, 500, value=50, step=5, label="Beam Width / Simulaciones")
                        
                        solve_btn = gr.Button("Buscar Formula", variant="primary", size="lg")
                        
                        # Example buttons fill the X/Y textboxes via generate_example(<kind>).
                        with gr.Row():
                            gr.Button("Lineal", size="sm").click(lambda: generate_example("lineal"), outputs=[x_input, y_input])
                            gr.Button("Cuadratico", size="sm").click(lambda: generate_example("cuadratico"), outputs=[x_input, y_input])
                            gr.Button("Seno", size="sm").click(lambda: generate_example("trig"), outputs=[x_input, y_input])
                            gr.Button("Exponencial", size="sm").click(lambda: generate_example("exp"), outputs=[x_input, y_input])
                    
                    with gr.Column(scale=2):
                        result_html = gr.HTML(label="Resultado")
                        plot_output = gr.Plot(label="Visualizacion")
                
                with gr.Row():
                    pred_html = gr.HTML(label="Predicciones")
                    alt_html = gr.HTML(label="Alternativas")
                
                # Hidden sink for solve_formula's 5th output (presumably the raw formula string).
                raw_formula = gr.Textbox(visible=False)
                
                # solve_formula(x_text, y_text, slider_value, method) must return 5 values, one per output below.
                solve_btn.click(solve_formula, [x_input, y_input, beam_slider, search_method], 
                               [result_html, plot_output, pred_html, alt_html, raw_formula])
            
            # TAB 2: Training
            with gr.Tab("Entrenar Modelo"):
                with gr.Row():
                    gr.HTML("""
                    <div style="background: #16213e; padding: 20px; border-radius: 10px; flex: 1;">
                        <h3 style="color: #ffd93d; margin: 0;">Centro de Entrenamiento</h3>
                    </div>
                    """)
                    with gr.Column():
                        # Checked by default only when CUDA is available (an MPS-only machine starts unchecked).
                        use_gpu = gr.Checkbox(label="Usar GPU", value=torch.cuda.is_available())
                        device_display = gr.HTML(value=f'<div style="padding: 10px; background: #0f0f23; border-radius: 8px; border-left: 3px solid {device_color};"><span style="color: {device_color}; font-weight: bold;">{device_info}</span></div>')
                        use_gpu.change(toggle_device, [use_gpu], [device_display])
                    with gr.Column():
                        delete_model_btn = gr.Button("🗑️ Borrar Modelo", variant="secondary", size="sm")
                        delete_status = gr.HTML()
                        
                        def delete_model_action():
                            """Delete the saved checkpoint of the active preset and return a status HTML snippet."""
                            import os
                            # Imported at call time, so the preset value is read when the button is clicked.
                            from ui.app_core import CURRENT_PRESET
                            # Relative path: resolved against the process's current working directory.
                            filename = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
                            if os.path.exists(filename):
                                # Only the file is removed; the model already in memory is untouched.
                                os.remove(filename)
                                return f'<div style="color: #4ade80; padding: 5px;">✅ Modelo [{CURRENT_PRESET}] eliminado. Reinicia la app para usar pesos nuevos.</div>'
                            return f'<div style="color: #888; padding: 5px;">No hay modelo [{CURRENT_PRESET}] guardado.</div>'
                        
                        delete_model_btn.click(delete_model_action, outputs=[delete_status])
                        
                        # NOTE(review): request_stop_training presumably sets a flag the training loops poll — confirm in ui.app_core.
                        stop_train_btn = gr.Button("⏹️ Detener Entrenamiento", variant="stop", size="sm")
                        stop_status = gr.HTML()
                        stop_train_btn.click(request_stop_training, outputs=[stop_status])
                
                with gr.Tabs():
                    # Basic: train_basic(epochs, batch_size, points_per_formula) -> (html, plot)
                    with gr.Tab("Basico"):
                        gr.HTML('<p style="color: #888;">Entrenamiento rapido con datos sinteticos</p>')
                        with gr.Row():
                            with gr.Column():
                                epochs_basic = gr.Slider(10, 500, value=100, step=10, label="Epocas")
                                batch_basic = gr.Slider(16, 128, value=32, step=16, label="Batch Size")
                                points_basic = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_basic_btn = gr.Button("Entrenar Basico", variant="primary")
                            with gr.Column():
                                result_basic = gr.HTML()
                                plot_basic = gr.Plot()
                        train_basic_btn.click(train_basic, [epochs_basic, batch_basic, points_basic], [result_basic, plot_basic])
                    
                    # Curriculum: train_curriculum(epochs, batch_size, points_per_formula) -> (html, plot)
                    with gr.Tab("Curriculum"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
                            <p style="color: #00d4ff; margin: 0;"><strong>Curriculum Learning</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">Empieza con formulas simples y aumenta la dificultad.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                epochs_curriculum = gr.Slider(50, 2000, value=200, step=50, label="Epocas")
                                batch_curriculum = gr.Slider(16, 128, value=64, step=16, label="Batch Size")
                                points_curriculum = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_curriculum_btn = gr.Button("Entrenar Curriculum", variant="primary")
                            with gr.Column():
                                result_curriculum = gr.HTML()
                                plot_curriculum = gr.Plot()
                        train_curriculum_btn.click(train_curriculum, [epochs_curriculum, batch_curriculum, points_curriculum], [result_curriculum, plot_curriculum])
                    
                    # Self-Play: train_self_play(iterations, problems_per_iter, points_per_formula) -> (html, plot)
                    with gr.Tab("Self-Play"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px; border-left: 3px solid #ff6b6b;">
                            <p style="color: #ff6b6b; margin: 0;"><strong>AlphaZero Self-Play</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">El modelo resuelve problemas y aprende de sus exitos.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                iterations_sp = gr.Slider(10, 1000, value=100, step=10, label="Iteraciones")
                                problems_sp = gr.Slider(5, 200, value=10, step=5, label="Problemas/Iter")
                                points_sp = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_sp_btn = gr.Button("Iniciar Self-Play", variant="primary")
                            with gr.Column():
                                result_sp = gr.HTML()
                                plot_sp = gr.Plot()
                        train_sp_btn.click(train_self_play, [iterations_sp, problems_sp, points_sp], [result_sp, plot_sp])
                
                    # Feedback Loop (Teacher-Student)
                    # train_hybrid_feedback_loop(cycles, hard_problems_per_cycle, teacher_timeout_s) -> (html, plot)
                    with gr.Tab("Feedback Loop (Hybrid)"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px; border-left: 3px solid #f1c40f;">
                            <p style="color: #f1c40f; margin: 0;"><strong>Teacher-Student Feedback Loop</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">El modelo (Estudiante) intenta resolver problemas. Si falla, el Alpha-GP (Maestro) interviene y añade la solución al dataset.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                iterations_fb = gr.Slider(5, 500, value=20, step=5, label="Ciclos")
                                problems_fb = gr.Slider(5, 50, value=10, step=5, label="Problemas Difíciles / Ciclo")
                                timeout_fb = gr.Slider(5, 30, value=10, step=5, label="Timeout Maestro (s)")
                                train_fb_btn = gr.Button("Iniciar Feedback Loop", variant="primary")
                            with gr.Column():
                                result_fb = gr.HTML()
                                plot_fb = gr.Plot()
                        train_fb_btn.click(train_hybrid_feedback_loop, [iterations_fb, problems_fb, timeout_fb], [result_fb, plot_fb])
                
                # --- PRE-TRAINING (Warmup) ---
                # train_supervised(iterations) -> (html, plot)
                with gr.Accordion("🎓 Escuela Primaria (Pre-Entrenamiento)", open=False):
                    gr.Markdown("Entrenamiento masivo supervisado de alta velocidad para aprender sintaxis basica. **Recomendado al inicio.**")
                    with gr.Row():
                        with gr.Column():
                            epochs_pre = gr.Slider(100, 10000, value=2000, step=100, label="Iteraciones Rápidas")
                            train_pre_btn = gr.Button("Iniciar Pre-Entrenamiento", variant="primary")
                        with gr.Column():
                            result_pre = gr.HTML()
                            plot_pre = gr.Plot()
                    train_pre_btn.click(train_supervised, [epochs_pre], [result_pre, plot_pre])

                # --- HALL OF SHAME (Error Analysis) ---
                with gr.Accordion("🕵️‍♂️ Hall of Shame (Analisis de Errores)", open=False):
                    gr.Markdown("Aquí se muestran los problemas donde el modelo falló drásticamente hoy.")
                    error_table = gr.DataFrame(
                        headers=["Time", "Target Formula", "Predicted", "Loss", "Stage"],
                        datatype=["str", "str", "str", "number", "str"],
                        interactive=False
                    )
                    refresh_errors_btn = gr.Button("🔄 Actualizar Errores", size="sm")
                    
                    def update_errors():
                        """Convert recorded training errors into rows matching error_table's headers."""
                        errors = get_training_errors()
                        # Reverse to show newest first (assumes errors are recorded oldest-first).
                        data = [[
                            e['time'], e['target'], e['predicted'], round(e['loss'], 2), e['stage']
                        ] for e in reversed(errors)]
                        return data
                    
                    refresh_errors_btn.click(update_errors, outputs=[error_table])
            
            # TAB 3: Benchmark
            # NOTE(review): get_benchmark_tab presumably creates its own gr.Tab inside this Tabs context — confirm in ui.app_benchmark.
            get_benchmark_tab()

            # TAB 4: Info
            with gr.Tab("Informacion"):
                # Snapshot taken when the UI is built; this static HTML does not update after toggle_device.
                device_info_current = get_device_info()
                device_color_current = "#4ade80" if "CUDA" in device_info_current else "#fbbf24" if "MPS" in device_info_current else "#888"
                
                gr.HTML(f"""
                <div style="background: #1a1a2e; padding: 30px; border-radius: 15px;">
                    <h2 style="color: #00d4ff;">Que es AlphaSymbolic?</h2>
                    <p style="color: #ccc; line-height: 1.8;">
                        Sistema de <strong style="color: #ff6b6b;">regresion simbolica</strong> 
                        basado en <strong style="color: #00d4ff;">Deep Learning</strong> y 
                        <strong style="color: #ffd93d;">Monte Carlo Tree Search</strong>.
                    </p>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Dispositivo Actual</h3>
                    <p style="color: {device_color_current}; font-size: 20px;">{device_info_current}</p>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Metodos de Busqueda</h3>
                    <ul style="color: #ccc;">
                        <li><strong>Beam Search:</strong> Explora multiples candidatos en paralelo (rapido)</li>
                        <li><strong>MCTS:</strong> Monte Carlo Tree Search (mas preciso, lento)</li>
                        <li><strong>Alpha-GP Hybrid:</strong> Fusiona Neural Search con Algoritmo Genetico GPU (Extremo)</li>
                    </ul>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Operadores</h3>
                    <div style="display: flex; flex-wrap: wrap; gap: 10px; margin: 15px 0;">
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">+</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">-</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">*</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">/</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ff6b6b;">sin</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ff6b6b;">cos</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ffd93d;">exp</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ffd93d;">log</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #4ade80;">pow</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #4ade80;">sqrt</span>
                    </div>
                </div>
                """)
        
        # Footer
        gr.HTML("""
        <div style="text-align: center; padding: 20px; color: #666; margin-top: 30px;">
            <p>Powered by PyTorch - SymPy - Scipy - Gradio</p>
        </div>
        """)
    
    return demo



# --- Global Initialization for Hot Reloading ---
# Importing this module has side effects: it loads the model and builds the UI.
print("Iniciando AlphaSymbolic (Global Init)...")
# Load model once at module level so 'gradio app.py' works
status_init, device_info_init = load_model() 
print(f"   {status_init} | {device_info_init}")

# Create the app instance globally
# NOTE(review): `gradio app.py` reload mode presumably discovers the app through this
# module-level name `demo`; keep the name stable.
demo = create_app()

if __name__ == "__main__":
    print("Abriendo navegador...")
    # Launch with auto-reload compatibility if run directly (though proper reload needs 'gradio app.py')
    # NOTE(review): share=True requests a public Gradio share link and inbrowser=True tries to open
    # a local browser tab — confirm a public link is intended outside notebook/Colab use.
    demo.launch(share=True, inbrowser=True)


In [None]:
# Run the application
# `!` hands the command to the shell. The cell stays busy while `python app.py` runs
# (presumably until the Gradio server is stopped — confirm for your runtime).
!python app.py
