In [None]:
# AlphaSymbolic - Unified Hybrid System
# -------------------------------------
# Instructions:
# 1. Runtime -> Change runtime type -> T4 GPU
# 2. Mount Google Drive to PERSIST models
# 3. Run All

try:
    from google.colab import drive
    drive.mount('/content/drive')
    os.makedirs('/content/drive/MyDrive/AlphaSymbolic_Models', exist_ok=True)
    print("✅ Google Drive mounted correctly")
except Exception as e:
    print("⚠️ Google Drive NOT mounted. Models will be LOST after session ends.")

!nvidia-smi

# Install dependencies
!pip install gradio torch torchvision torchaudio scipy matplotlib sympy

# Create Directory Structure
import os
os.makedirs('Code/src', exist_ok=True)
os.makedirs('Code/build', exist_ok=True)
os.makedirs('AlphaSymbolic', exist_ok=True)
os.makedirs('AlphaSymbolic/data/benchmarks', exist_ok=True)
directories = ['core', 'data', 'search', 'ui', 'utils', 'models', 'results', 'tools', 'logs', 'notebooks']
for d in directories:
    os.makedirs(os.path.join('AlphaSymbolic', d), exist_ok=True)


In [None]:
%%writefile Code/src/AdvancedFeatures.cpp
#include "AdvancedFeatures.h"
#include "GradientOptimizer.h"
#include "Globals.h"
#include "GeneticOperators.h"
#include "Fitness.h"
#include "ExpressionTree.h" // Necesario para tree_to_string en la simplificación
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iostream>
#include <unordered_map>
#include <vector>

//---------------------------------
// EvolutionParameters
//---------------------------------
// Builds a parameter set initialised from the project-wide baseline constants.
EvolutionParameters EvolutionParameters::create_default() {
    EvolutionParameters defaults;
    defaults.mutation_rate = BASE_MUTATION_RATE;
    defaults.elite_percentage = BASE_ELITE_PERCENTAGE;
    defaults.tournament_size = DEFAULT_TOURNAMENT_SIZE;
    defaults.crossover_rate = DEFAULT_CROSSOVER_RATE;
    return defaults;
}

// Randomly perturbs mutation_rate, elite_percentage, tournament_size and
// crossover_rate, then clamps each one to its allowed range.
//
// stagnation_counter scales the size of the perturbation ("aggression"):
//   * above half of STAGNATION_LIMIT_ISLAND  -> factor grows from 1.0, capped at 2.0
//   * between 1 and a quarter of the limit   -> factor shrinks towards 0.5
//   * exactly 0                              -> very conservative factor of 0.2
//   * anything else                          -> factor stays at 1.0
// The same aggression factor also widens the upper clamp bounds for
// mutation_rate, elite_percentage and tournament_size.
void EvolutionParameters::mutate(int stagnation_counter) {
    auto& rng = get_rng();
    double aggression_factor = 1.0;
    if (stagnation_counter > STAGNATION_LIMIT_ISLAND / 2) {
        // Significant stagnation: scale up linearly past the half-limit mark.
        aggression_factor = 1.0 + (static_cast<double>(stagnation_counter - STAGNATION_LIMIT_ISLAND / 2) / (STAGNATION_LIMIT_ISLAND / 2.0)) * 0.5;
        aggression_factor = std::min(aggression_factor, 2.0); // Cap the maximum aggression
    } else if (stagnation_counter < STAGNATION_LIMIT_ISLAND / 4 && stagnation_counter > 0) {
        // Little (but non-zero) stagnation: scale down, never below 0.5.
        aggression_factor = 1.0 - (static_cast<double>(STAGNATION_LIMIT_ISLAND / 4 - stagnation_counter) / (STAGNATION_LIMIT_ISLAND / 4.0)) * 0.5;
        aggression_factor = std::max(aggression_factor, 0.5); // Floor the minimum aggression
    } else if (stagnation_counter == 0) {
        // No stagnation at all: only make very small changes.
        aggression_factor = 0.2;
    }

    std::uniform_real_distribution<double> base_rate_change(-0.05, 0.05);
    std::uniform_int_distribution<int> base_tourney_change(-2, 2);

    double rate_change_val = base_rate_change(rng) * aggression_factor;
    int tourney_change_val = static_cast<int>(std::round(base_tourney_change(rng) * aggression_factor));

    // Under high aggression, a zero tournament change gets ONE re-roll; if the
    // re-roll is non-zero, step by one in its direction. The re-roll is drawn
    // once and reused: every call to the distribution yields a fresh value, so
    // testing one draw and taking the sign of another would decouple the two
    // and bias the step towards -1.
    if (aggression_factor > 1.0 && tourney_change_val == 0) {
        const int reroll = base_tourney_change(rng);
        if (reroll != 0) {
            tourney_change_val = (reroll > 0) ? 1 : -1;
        }
    }

    // Dynamic bounds: the upper limits widen with the aggression factor.
    double min_mutation = 0.05;
    double max_mutation_base = 0.5;
    double max_mutation = min_mutation + (max_mutation_base - min_mutation) * (1.0 + aggression_factor / 2.0);

    double min_elite = 0.02;
    double max_elite_base = 0.25;
    double max_elite = min_elite + (max_elite_base - min_elite) * (1.0 + aggression_factor / 2.0);

    int min_tournament = 3;
    int max_tournament_base = 30;
    int max_tournament = min_tournament + static_cast<int>((max_tournament_base - min_tournament) * (1.0 + aggression_factor / 2.0));

    // Apply the changes and keep every parameter inside its bounds.
    // Note: the three rate parameters all receive the same rate_change_val.
    mutation_rate = std::clamp(mutation_rate + rate_change_val, min_mutation, max_mutation);
    elite_percentage = std::clamp(elite_percentage + rate_change_val, min_elite, max_elite);
    tournament_size = std::clamp(tournament_size + tourney_change_val, min_tournament, max_tournament);
    crossover_rate = std::clamp(crossover_rate + rate_change_val, 0.5, 0.95);
}

//---------------------------------
// PatternMemory
//---------------------------------
void PatternMemory::record_success(const NodePtr& tree, double fitness) {
    std::string pattern = extract_pattern(tree);
    if (pattern.empty() || pattern.length() > 50 || pattern.length() < 3 || pattern == "N") return;
    auto it = patterns.find(pattern);
    if (it == patterns.end()) {
        patterns[pattern] = {pattern, fitness, 1, (fitness < INF ? 1.0 : 0.0)};
    } else {
        auto& p = it->second;
        p.uses++;
        double improvement = (fitness < p.best_fitness && p.best_fitness < INF) ? 1.0 : 0.0;
        p.success_rate = ((p.success_rate * (p.uses - 1)) + improvement) / p.uses;
        p.best_fitness = std::min(p.best_fitness, fitness);
    }
}

// Picks one stored pattern at random, weighted by success rate and best fitness,
// and hands it to parse_pattern(). Returns nullptr when no pattern qualifies.
NodePtr PatternMemory::suggest_pattern_based_tree(int max_depth) {
    if (patterns.empty()) return nullptr;

    // Parallel lists: eligible pattern strings and their selection weights.
    std::vector<std::string> eligible;
    std::vector<double> weights;
    for (const auto& entry : patterns) {
        const auto& info = entry.second;
        if (!(info.uses >= PATTERN_MEM_MIN_USES)) continue;
        const bool promising = info.success_rate > 0.1 || info.best_fitness < PATTERN_RECORD_FITNESS_THRESHOLD;
        if (!promising) continue;
        eligible.push_back(entry.first);
        weights.push_back(info.success_rate + (1.0 / (1.0 + info.best_fitness)));
    }
    if (eligible.empty()) return nullptr;

    std::discrete_distribution<> picker(weights.begin(), weights.end());
    const int chosen = picker(get_rng());
    return parse_pattern(eligible[chosen], max_depth);
}

// Serialises the shape of a tree: "N" for a null child, "#" for a constant,
// "x" for a variable, "(<left><op><right>)" for an operator, "?" otherwise.
std::string PatternMemory::extract_pattern(const NodePtr& node) {
    if (!node) return "N";
    if (node->type == NodeType::Constant) return "#";
    if (node->type == NodeType::Variable) return "x";
    if (node->type != NodeType::Operator) return "?";

    std::string signature = "(";
    signature += extract_pattern(node->left);
    signature += node->op;
    signature += extract_pattern(node->right);
    signature += ")";
    return signature;
}

// Placeholder parser: only the leaf patterns "#", "x" and "N" are rebuilt
// literally; every other pattern falls back to a random tree of max_depth.
NodePtr PatternMemory::parse_pattern(const std::string& pattern, int max_depth) {
    if (pattern == "x") return std::make_shared<Node>(NodeType::Variable);
    if (pattern == "N") return nullptr;
    if (pattern != "#") {
        // Composite (or unknown) pattern: not parsed yet, use the fallback.
        return generate_random_tree(max_depth);
    }

    // "#": a constant leaf with a freshly drawn value.
    auto leaf = std::make_shared<Node>(NodeType::Constant);
    if (FORCE_INTEGER_CONSTANTS) {
        std::uniform_int_distribution<int> int_range(CONSTANT_INT_MIN_VALUE, CONSTANT_INT_MAX_VALUE);
        leaf->value = static_cast<double>(int_range(get_rng()));
    } else {
        std::uniform_real_distribution<double> real_range(CONSTANT_MIN_VALUE, CONSTANT_MAX_VALUE);
        leaf->value = real_range(get_rng());
    }
    // Snap values that are effectively zero to exactly 0.0.
    if (std::fabs(leaf->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) {
        leaf->value = 0.0;
    }
    return leaf;
}

//---------------------------------
// Pareto Optimizer
//---------------------------------
// Wraps a tree with its two objective values; a new solution starts non-dominated.
ParetoSolution::ParetoSolution(NodePtr t, double acc, double complexity_val)
    : tree(std::move(t)),
      accuracy(acc),
      complexity(complexity_val),
      dominated(false) {
}

// True when this solution is no worse than `other` on both objectives and
// strictly better on at least one (smaller values are better for both).
bool ParetoSolution::dominates(const ParetoSolution& other) const {
    // Must not be worse on either objective.
    if (!(accuracy <= other.accuracy)) return false;
    if (!(complexity <= other.complexity)) return false;
    // ...and must win outright on at least one of them.
    return (accuracy < other.accuracy) || (complexity < other.complexity);
}

void ParetoOptimizer::update(const std::vector<Individual>& population, const std::vector<double>& targets, const std::vector<std::vector<double>>& x_values) {
    std::vector<ParetoSolution> candidates = pareto_front;
    for (const auto& ind : population) {
        if (ind.tree && ind.fitness_valid && ind.fitness < INF) {
            candidates.emplace_back(ind.tree, ind.fitness, static_cast<double>(tree_size(ind.tree)));
        }
    }
    for (auto& sol1 : candidates) {
        sol1.dominated = false;
        for (const auto& sol2 : candidates) {
            if (&sol1 == &sol2) continue;
            if (sol2.dominates(sol1)) { sol1.dominated = true; break; }
        }
    }
    pareto_front.clear();
    std::copy_if(candidates.begin(), candidates.end(), std::back_inserter(pareto_front),
                 [](const auto& sol) { return !sol.dominated; });
    if (pareto_front.size() > PARETO_MAX_FRONT_SIZE) {
        std::sort(pareto_front.begin(), pareto_front.end(), [](const auto& a, const auto& b){ return a.accuracy < b.accuracy; });
        pareto_front.resize(PARETO_MAX_FRONT_SIZE);
    }
}

std::vector<NodePtr> ParetoOptimizer::get_pareto_solutions() {
    std::vector<NodePtr> result;
    result.reserve(pareto_front.size());
    std::transform(pareto_front.begin(), pareto_front.end(), std::back_inserter(result),
                   [](const auto& sol) { return sol.tree; });
    return result;
}

//---------------------------------
// Domain Constraints
//---------------------------------
bool DomainConstraints::is_valid_recursive(const NodePtr& node) {
     if (!node) return true;
     if (node->type == NodeType::Operator) {
         if (node->op == '/' && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return false;
         if (node->op == '^') { // Solo chequear 0^negativo/0
              if (node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE &&
                  node->right && node->right->type == NodeType::Constant && node->right->value <= SIMPLIFY_NEAR_ZERO_TOLERANCE) {
                      return false;
              }
         }
         if ((node->op == '*' || node->op == '/') && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value - 1.0) < SIMPLIFY_NEAR_ONE_TOLERANCE) return false;
         if ((node->op == '+' || node->op == '-') && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return false;
         if (node->op == '*' && node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value - 1.0) < SIMPLIFY_NEAR_ONE_TOLERANCE) return false;
         if (node->op == '+' && node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return false;
         if (!is_valid_recursive(node->left) || !is_valid_recursive(node->right)) return false;
     }
     return true;
 }

// Public entry point for the static validity check of a whole tree.
bool DomainConstraints::is_valid(const NodePtr& tree) {
    const bool passes_static_checks = is_valid_recursive(tree);
    return passes_static_checks;
}

// Simplifies an operator subtree bottom-up and returns its (possibly new) root.
// Non-operator and null nodes are returned unchanged. Children are simplified
// first and written back into `node`, so the input tree is modified in place.
//
// Rules, in order:
//   1. Constant folding (binary op on two constants, or unary op on a constant).
//   2. Repair of nodes with missing children.
//   3. Algebraic identities (x+0, 0+x, x-0, x*1, 1*x, x/1, x*0, x^1, x^0).
//   4. A constant zero divisor is replaced by 1.
//   5. X/X -> 1 and X-X -> 0, where equality is textual (tree_to_string).
NodePtr DomainConstraints::simplify_recursive(NodePtr node) {
    if (!node || node->type != NodeType::Operator) return node;
    node->left = simplify_recursive(node->left);
    node->right = simplify_recursive(node->right);

    // Unary operator codes. This list must stay in sync with the arity test in
    // evaluate_tree() (ExpressionTree.cpp). If a unary code is missing here, a
    // node such as op(x) with a null right child is treated as a broken binary
    // node below and replaced by its operand, silently dropping the operator.
    const bool is_unary = (node->op == 's' || node->op == 'c' || node->op == 'l' || node->op == 'e' ||
                           node->op == '!' || node->op == '_' || node->op == 'g' || node->op == 't' ||
                           node->op == 'q' || node->op == 'a' || node->op == 'n' || node->op == 'u' ||
                           node->op == 'S' || node->op == 'C' || node->op == 'T');

    // --- 1. Constant folding (highest priority) ---
    const bool left_is_const = (node->left && node->left->type == NodeType::Constant);
    const bool right_is_const = (node->right && node->right->type == NodeType::Constant);

    if ((left_is_const && right_is_const) || (is_unary && left_is_const)) {
        try {
            // No variables are involved, so the variable vector content is irrelevant.
            double result = evaluate_tree(node, std::vector<double>{0.0});
            if (!std::isnan(result) && !std::isinf(result)) {
                auto cn = std::make_shared<Node>(NodeType::Constant);
                cn->value = FORCE_INTEGER_CONSTANTS ? std::round(result) : result;
                if (std::fabs(cn->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) cn->value = 0.0;
                return cn;
            }
        } catch (const std::exception&) {
            // Folding is best-effort: keep the node if evaluation throws.
        }
    }

    // --- 2. Missing children ---
    if (node->left && !node->right) {
        if (is_unary) return node; // Normal shape for a unary operator.
        return node->left;         // Broken binary node "A op <null>" collapses to A.
    }
    if (!node->left && node->right) return node->right;
    if (!node->left && !node->right) {
        auto cn = std::make_shared<Node>(NodeType::Constant);
        cn->value = 1.0;
        return cn;
    }

    // --- 3. Algebraic identities ---
    // x + 0, x - 0 -> x
    if ((node->op == '+' || node->op == '-') && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return node->left;
    // 0 + x -> x
    if (node->op == '+' && node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return node->right;
    // x * 1, x / 1 -> x
    if ((node->op == '*' || node->op == '/') && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value - 1.0) < SIMPLIFY_NEAR_ONE_TOLERANCE) return node->left;
    // 1 * x -> x
    if (node->op == '*' && node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value - 1.0) < SIMPLIFY_NEAR_ONE_TOLERANCE) return node->right;
    // 0 * x, x * 0 -> 0
    if (node->op == '*' && ((node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) || (node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE))) {
        auto z = std::make_shared<Node>(NodeType::Constant);
        z->value = 0.0;
        return z;
    }
    // A ^ 1 -> A
    if (node->op == '^' && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value - 1.0) < SIMPLIFY_NEAR_ONE_TOLERANCE) return node->left;
    // A ^ 0 -> 1
    if (node->op == '^' && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) {
        auto o = std::make_shared<Node>(NodeType::Constant);
        o->value = 1.0;
        return o;
    }

    // --- 4. Repair division by a constant zero: the divisor becomes 1 ---
    if (node->op == '/' && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) node->right->value = 1.0;

    // --- 5. Structural cancellations ---
    // X / X -> 1, unless the divisor is a constant zero (avoids 0/0).
    if (node->op == '/' && node->left && node->right) {
        if (tree_to_string(node->left) == tree_to_string(node->right)) {
            if (node->right->type != NodeType::Constant || std::fabs(node->right->value) >= SIMPLIFY_NEAR_ZERO_TOLERANCE) {
                auto one = std::make_shared<Node>(NodeType::Constant);
                one->value = 1.0;
                return one;
            }
        }
    }

    // X - X -> 0
    if (node->op == '-' && node->left && node->right) {
        if (tree_to_string(node->left) == tree_to_string(node->right)) {
            auto zero = std::make_shared<Node>(NodeType::Constant);
            zero->value = 0.0;
            return zero;
        }
    }
    // Constant exponents are intentionally no longer clamped here.

    return node;
}

// Returns a simplified copy of `tree`; the argument itself is cloned first and
// therefore left untouched. A null tree yields nullptr.
NodePtr DomainConstraints::fix_or_simplify(NodePtr tree) {
    if (tree == nullptr) {
        return nullptr;
    }
    NodePtr working_copy = clone_tree(tree);
    return simplify_recursive(working_copy);
}

//---------------------------------
// Local Improvement
//---------------------------------
// Stochastic hill climbing over the constant leaves of `tree` (modified in place).
// For up to 20 steps: pick one constant at random, add N(0, 0.5) noise (rounded
// when FORCE_INTEGER_CONSTANTS is set), and keep the change only if fitness
// strictly improves. Stops early once fitness drops below EXACT_SOLUTION_THRESHOLD.
// d_targets / d_x_values are forwarded to evaluate_fitness only in GPU builds.
void optimize_constants(NodePtr& tree, const std::vector<double>& targets, const std::vector<std::vector<double>>& x_values, double* d_targets, double* d_x_values) {
    if (!tree) return;

    // Step 1: collect raw pointers to every constant node (iterative DFS,
    // visiting the left child before the right one).
    std::vector<Node*> tunables;
    std::vector<Node*> pending{tree.get()};
    while (!pending.empty()) {
        Node* current = pending.back();
        pending.pop_back();
        if (current == nullptr) continue;
        if (current->type == NodeType::Constant) {
            tunables.push_back(current);
        } else if (current->type == NodeType::Operator) {
            pending.push_back(current->right.get());
            pending.push_back(current->left.get());
        }
    }
    if (tunables.empty()) return;

    // Fitness of the tree in its current state (build-dependent call).
    const auto score = [&]() -> double {
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
        return evaluate_fitness(tree, targets, x_values, d_targets, d_x_values);
#else
        return evaluate_fitness(tree, targets, x_values);
#endif
    };

    // Step 2: hill climbing.
    const int max_steps = 20; // Keep the local search cheap.
    auto& rng = get_rng();
    double best_fitness = score();
    std::normal_distribution<double> jitter(0.0, 0.5);

    for (int step = 0; step < max_steps; ++step) {
        const int slot = std::uniform_int_distribution<int>(0, tunables.size() - 1)(rng);
        Node* chosen = tunables[slot];
        const double saved_value = chosen->value;

        chosen->value += jitter(rng);
        if (FORCE_INTEGER_CONSTANTS) chosen->value = std::round(chosen->value);

        const double trial_fitness = score();
        if (trial_fitness < best_fitness) {
            best_fitness = trial_fitness;   // Accept the move.
        } else {
            chosen->value = saved_value;    // Reject: restore the constant.
        }

        if (best_fitness < EXACT_SOLUTION_THRESHOLD) break;
    }
}

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
std::pair<NodePtr, double> try_local_improvement(const NodePtr& tree, double current_fitness, const std::vector<double>& targets, const std::vector<std::vector<double>>& x_values, int attempts, double* d_targets, double* d_x_values) {
    // 1. First, try to optimize constants of the CURRENT tree using GRADIENT DESCENT
    NodePtr optimized_tree = clone_tree(tree);
    // Use Gradient Optimization (Adam) - much more precise than Hill Climbing
    optimize_constants_gradient(optimized_tree, targets, x_values, 0.05, 30);
    
    #ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
        double optimized_fitness = evaluate_fitness(optimized_tree, targets, x_values, d_targets, d_x_values);
    #else
        double optimized_fitness = evaluate_fitness(optimized_tree, targets, x_values);
    #endif
    
    NodePtr best_neighbor = (optimized_fitness < current_fitness) ? optimized_tree : tree;
    double best_neighbor_fitness = (optimized_fitness < current_fitness) ? optimized_fitness : current_fitness;

    if (best_neighbor_fitness >= INF) return {best_neighbor, best_neighbor_fitness};

    // 2. Structural Search (as before)
    for (int i = 0; i < attempts; ++i) {
        NodePtr neighbor = mutate_tree(best_neighbor, 1.0, 2); // Mutate the BEST so far
        neighbor = DomainConstraints::fix_or_simplify(neighbor);
        if (!neighbor) continue;
        
        // Also optimize constants of structural neighbor?
        // Maybe too expensive. Let's do a quick random constant tweak.
        // optimize_constants(neighbor, targets, x_values, d_targets, d_x_values); 
        
        double neighbor_fitness = evaluate_fitness(neighbor, targets, x_values, d_targets, d_x_values);
        if (neighbor_fitness < best_neighbor_fitness) {
            best_neighbor = neighbor;
            best_neighbor_fitness = neighbor_fitness;
        }
    }
    return {best_neighbor, best_neighbor_fitness};
}
#else
std::pair<NodePtr, double> try_local_improvement(const NodePtr& tree, double current_fitness, const std::vector<double>& targets, const std::vector<std::vector<double>>& x_values, int attempts) {
    // 1. First, try to optimize constants of the CURRENT tree using GRADIENT DESCENT
    NodePtr optimized_tree = clone_tree(tree);
    optimize_constants_gradient(optimized_tree, targets, x_values, 0.05, 30);
    double optimized_fitness = evaluate_fitness(optimized_tree, targets, x_values);
    
    NodePtr best_neighbor = (optimized_fitness < current_fitness) ? optimized_tree : tree;
    double best_neighbor_fitness = (optimized_fitness < current_fitness) ? optimized_fitness : current_fitness;

    if (best_neighbor_fitness >= INF) return {best_neighbor, best_neighbor_fitness};

    for (int i = 0; i < attempts; ++i) {
        NodePtr neighbor = mutate_tree(best_neighbor, 1.0, 2);
        neighbor = DomainConstraints::fix_or_simplify(neighbor);
        if (!neighbor) continue;
        double neighbor_fitness = evaluate_fitness(neighbor, targets, x_values);
        if (neighbor_fitness < best_neighbor_fitness) {
            best_neighbor = neighbor;
            best_neighbor_fitness = neighbor_fitness;
        }
    }
    return {best_neighbor, best_neighbor_fitness};
}
#endif

//---------------------------------
// Target Pattern Detection
//---------------------------------
// Classifies a target sequence by index order. Returns {label, parameter}:
//   {"arithmetic", common difference}  consecutive differences agree within 1e-6
//   {"geometric",  common ratio}       consecutive ratios agree within 1e-6
//   {"constant_zero", 0.0}             first value ~0, every value within 1e-9
//                                      of 0, and the arithmetic test failed
//   {"none", 0.0}                      fewer than 3 values, or nothing matched
// The arithmetic test runs first, so an all-zero (or any constant) sequence is
// reported as "arithmetic" with difference 0.
std::pair<std::string, double> detect_target_pattern(const std::vector<double>& targets) {
    const double zero_eps = 1e-9;
    const double match_eps = 1e-6;
    const std::size_t count = targets.size();
    if (count < 3) return {"none", 0.0};

    // Arithmetic progression?
    const double step = targets[1] - targets[0];
    bool constant_step = true;
    for (std::size_t k = 2; k < count && constant_step; ++k) {
        const double delta = targets[k] - targets[k - 1];
        if (std::fabs(delta - step) > match_eps) constant_step = false;
    }
    if (constant_step) return {"arithmetic", step};

    // A (near-)zero first term rules out a geometric progression.
    if (std::fabs(targets[0]) < zero_eps) {
        for (double value : targets) {
            if (std::fabs(value) > zero_eps) return {"none", 0.0};
        }
        return {"constant_zero", 0.0};
    }

    // Geometric progression? Once a term is ~0, all later terms must be ~0 too.
    const double ratio = targets[1] / targets[0];
    for (std::size_t k = 2; k < count; ++k) {
        const double previous = targets[k - 1];
        const double current = targets[k];
        if (std::fabs(previous) < zero_eps) {
            if (std::fabs(current) > zero_eps) return {"none", 0.0};
        } else if (std::fabs((current / previous) - ratio) > match_eps) {
            return {"none", 0.0};
        }
    }
    return {"geometric", ratio};
}

//---------------------------------
// Generate Pattern Based Tree
//---------------------------------
// Builds a seed tree for a detected target pattern, anchored on the first data
// point (RAW_TARGETS[0] and X_VALUES[0][0], or x = 0 when that row is empty):
//   "arithmetic"    -> (a - d*x0) + d * x
//   "geometric"     -> (a / r^x0) * r ^ x
//   "constant_zero" -> 0
// Constants are rounded when FORCE_INTEGER_CONSTANTS is set and snapped to 0.0
// when within SIMPLIFY_NEAR_ZERO_TOLERANCE. Returns nullptr when no data is
// loaded, the pattern is unknown, or the geometric ratio is degenerate.
NodePtr generate_pattern_based_tree(const std::string& pattern_type, double pattern_value) {
    if (X_VALUES.empty() || RAW_TARGETS.empty()) return nullptr;

    const double first_target = RAW_TARGETS[0];
    const double first_x = X_VALUES[0].empty() ? 0.0 : X_VALUES[0][0];

    // Constant leaf with the project's rounding / zero-snapping rules applied.
    const auto make_constant = [](double raw) {
        if (FORCE_INTEGER_CONSTANTS) raw = std::round(raw);
        auto leaf = std::make_shared<Node>(NodeType::Constant);
        leaf->value = (std::fabs(raw) < SIMPLIFY_NEAR_ZERO_TOLERANCE) ? 0.0 : raw;
        return leaf;
    };
    // Binary operator node over two existing subtrees.
    const auto make_binary = [](char symbol, NodePtr lhs, NodePtr rhs) {
        auto parent = std::make_shared<Node>(NodeType::Operator);
        parent->op = symbol;
        parent->left = lhs;
        parent->right = rhs;
        return parent;
    };

    if (pattern_type == "constant_zero") {
        auto zero = std::make_shared<Node>(NodeType::Constant);
        zero->value = 0.0;
        return zero;
    }

    if (pattern_type == "arithmetic") {
        const double slope = pattern_value;
        auto intercept = make_constant(first_target - slope * first_x);
        auto slope_term = make_binary('*', make_constant(slope), std::make_shared<Node>(NodeType::Variable));
        return DomainConstraints::fix_or_simplify(make_binary('+', intercept, slope_term));
    }

    if (pattern_type == "geometric") {
        const double ratio = pattern_value;
        if (std::fabs(ratio) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return nullptr;
        const double ratio_at_x0 = std::pow(ratio, first_x);
        if (std::fabs(ratio_at_x0) < 1e-100) return nullptr;

        auto scale = make_constant(first_target / ratio_at_x0);
        auto power_term = make_binary('^', make_constant(ratio), std::make_shared<Node>(NodeType::Variable));
        return DomainConstraints::fix_or_simplify(make_binary('*', scale, power_term));
    }

    return nullptr; // No pattern tree generated
}


In [None]:
%%writefile Code/src/AdvancedFeatures.h
#ifndef ADVANCEDFEATURES_H
#define ADVANCEDFEATURES_H

#include "ExpressionTree.h"
#include "Globals.h" // Incluir Globals.h para INF
#include <vector>
#include <string>
#include <map>
#include <set>
#include <utility> // Para std::pair
#include <unordered_map>

// Meta-evolution: parameters that can adapt while the search is running.
// Plain aggregate; create_default() brace-initialises it, so the member order
// below is part of the contract.
struct EvolutionParameters {
    double mutation_rate;    // Current mutation rate
    double elite_percentage; // Current elite percentage
    int tournament_size;     // Current tournament size
    double crossover_rate;   // Current crossover rate

    // Creates a parameter set with the default (initial) values.
    static EvolutionParameters create_default();

    // Adapts (mutates) the parameters slightly.
    // The stagnation counter is used to scale how strong the change is.
    void mutate(int stagnation_counter);
};

// Pattern memory: stores structural patterns of trees together with fitness
// statistics, so that patterns which performed well can be suggested again.
class PatternMemory {
    struct PatternInfo {
        std::string pattern_str; // String representation of the pattern
        double best_fitness = INF; // Best fitness seen for this pattern
        int uses = 0;             // Number of times the pattern was recorded
        double success_rate = 0.0; // Estimated success rate
    };
    std::unordered_map<std::string, PatternInfo> patterns; // Pattern string -> statistics
    // Minimum number of uses before a pattern may be suggested.
    // NOTE(review): suggest_pattern_based_tree() in AdvancedFeatures.cpp filters
    // with the global PATTERN_MEM_MIN_USES, not with this member.
    int min_uses_for_suggestion = 3;

public:
    // Records an observation of a tree (via its pattern) and its fitness.
    void record_success(const NodePtr& tree, double fitness);
    // Suggests a tree based on the stored patterns; may return nullptr.
    NodePtr suggest_pattern_based_tree(int max_depth);

private:
    // Extracts the structural (string) representation of a tree.
    std::string extract_pattern(const NodePtr& tree);
    // Tries to build a tree from a pattern string (simplified implementation).
    NodePtr parse_pattern(const std::string& pattern, int max_depth);
};


// Pareto optimisation: one candidate on the accuracy/complexity trade-off front.
// dominates() treats smaller values as better for both objectives.
struct ParetoSolution {
    NodePtr tree = nullptr;   // Tree of the solution
    double accuracy = INF;    // Objective 1: accuracy (fitness value)
    double complexity = INF;  // Objective 2: complexity (size)
    bool dominated = false;   // Flag: is this solution dominated by another one?

    // Default constructor (needed when the type is stored in containers)
    ParetoSolution() = default;
    // Main constructor
    ParetoSolution(NodePtr t, double acc, double complexity_val);

    // Checks whether this solution dominates `other`.
    bool dominates(const ParetoSolution& other) const;
};

// Maintains the set of non-dominated solutions (the Pareto front).
class ParetoOptimizer {
    std::vector<ParetoSolution> pareto_front; // Solutions currently on the front
    // Optional limit for the size of the front.
    // NOTE(review): update() in AdvancedFeatures.cpp truncates using the global
    // PARETO_MAX_FRONT_SIZE, not this member.
    size_t max_front_size = 50;

public:
    // Updates the Pareto front with individuals from the current population.
    void update(const std::vector<struct Individual>& population, // Uses the Individual struct
                const std::vector<double>& targets,
                const std::vector<std::vector<double>>& x_values);

    // Returns the trees (NodePtr) of the solutions on the current front.
    std::vector<NodePtr> get_pareto_solutions();

    // Returns a const reference to the complete Pareto front.
    const std::vector<ParetoSolution>& get_pareto_front() const { return pareto_front; }
};


// Domain constraints: validates trees and fixes/simplifies problematic ones.
class DomainConstraints {
public:
    // Checks whether a tree satisfies the basic static validity rules.
    static bool is_valid(const NodePtr& tree);

    // Tries to simplify/fix a tree; works on a clone and returns the modified copy.
    static NodePtr fix_or_simplify(NodePtr tree);

private:
    // Recursive helper for simplification.
    static NodePtr simplify_recursive(NodePtr node);
    // Recursive helper for static validation.
    static bool is_valid_recursive(const NodePtr& node);
};

// Local search: tries to improve a given solution by exploring nearby neighbours.
// Returns the best tree found together with its fitness.
// The GPU build takes two extra pointers (d_targets, d_x_values) and has no
// default for `attempts`; the CPU build defaults `attempts` to 10.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
std::pair<NodePtr, double> try_local_improvement(const NodePtr& tree,
                                                  double current_fitness,
                                                  const std::vector<double>& targets,
                                                  const std::vector<std::vector<double>>& x_values,
                                                  int attempts,
                                                  double* d_targets, double* d_x_values);
#else
std::pair<NodePtr, double> try_local_improvement(const NodePtr& tree,
                                                  double current_fitness,
                                                  const std::vector<double>& targets,
                                                  const std::vector<std::vector<double>>& x_values,
                                                  int attempts = 10);
#endif


// Pattern detection on the target data.
// detect_target_pattern returns {pattern label, pattern parameter}; the pair is
// meant to be passed on to generate_pattern_based_tree, which may return nullptr.
std::pair<std::string, double> detect_target_pattern(const std::vector<double>& targets);
NodePtr generate_pattern_based_tree(const std::string& pattern_type, double pattern_value);


#endif // ADVANCEDFEATURES_H


In [None]:
%%writefile Code/src/ExpressionTree.cpp
#include "ExpressionTree.h"
#include "Globals.h"
#include <cmath>
#include <limits>
#include <stdexcept>
#include <vector>
#include <iostream>
#include <iomanip>
#include <string>
#include <sstream>
#include <stack>
#include <unordered_map>
#include <algorithm> // Para std::remove_if
#include <cctype>    // Para isdigit, isspace
#include <thread>    // Para thread_local RNG

// --- Helper function for formatting constants ---
// Formats a constant for display/serialisation.
//   * Values within SIMPLIFY_NEAR_ZERO_TOLERANCE of an integer are printed as
//     that integer, as long as it fits comfortably in a long long.
//   * Otherwise, magnitudes >= 1e6 or <= 1e-6 use scientific notation with
//     8 digits after the decimal point, printed verbatim.
//   * Everything else uses fixed notation with 8 decimals, with trailing zeros
//     (and a dangling '.') removed.
std::string format_constant(double val) {
    const double nearest = std::round(val);
    // Converting a double outside the long long range is undefined behaviour,
    // so very large integral values fall through to scientific notation.
    const bool fits_long_long = std::fabs(nearest) < 9.0e18;
    if (fits_long_long && std::fabs(val - nearest) < SIMPLIFY_NEAR_ZERO_TOLERANCE) {
        return std::to_string(static_cast<long long>(nearest));
    }

    std::ostringstream oss;
    if (std::fabs(val) >= 1e6 || std::fabs(val) <= 1e-6) { // Adjustable thresholds
        // Scientific output is returned as-is. Stripping trailing '0'
        // characters here would eat into the exponent ("1.5e+10" -> "1.5e+1")
        // and change the value.
        oss << std::scientific << std::setprecision(8) << val;
        return oss.str();
    }

    oss << std::fixed << std::setprecision(8) << val;
    std::string s = oss.str();
    // Fixed notation only: drop trailing zeros and a dangling decimal point.
    if (s.find('.') != std::string::npos) {
        s.erase(s.find_last_not_of('0') + 1, std::string::npos);
        if (!s.empty() && s.back() == '.') s.pop_back();
    }
    return s.empty() ? "0" : s;
}

// --- evaluate_tree ---
double evaluate_tree(const NodePtr& node, const std::vector<double>& vars) {
    if (!node) return std::nan("");
    switch (node->type) {
        case NodeType::Constant: return node->value;
        case NodeType::Variable: 
            if (node->var_index >= 0 && node->var_index < vars.size()) {
                return vars[node->var_index];
            }
            return std::nan(""); // Index out of bounds
        case NodeType::Operator: {
            // Determine arity
            bool is_unary = (node->op == 's' || node->op == 'c' || node->op == 'l' || node->op == 'e' || node->op == '!' || node->op == '_' || node->op == 'g' || node->op == 't' || node->op == 'q' || node->op == 'a' || node->op == 'n' || node->op == 'u' || node->op == 'S' || node->op == 'C' || node->op == 'T');

            double leftVal = evaluate_tree(node->left, vars);
            double rightVal = 0.0;
            if (!is_unary) {
                rightVal = evaluate_tree(node->right, vars);
            }

            if (std::isnan(leftVal)) return std::nan("");
            if (!is_unary && std::isnan(rightVal)) return std::nan("");

            double result = std::nan("");
            try {
                switch (node->op) {
                    case '+': result = leftVal + rightVal; break;
                    case '-': result = leftVal - rightVal; break;
                    case '*': result = leftVal * rightVal; break;
                    case '/':
                        if (std::fabs(rightVal) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return INF;
                        result = leftVal / rightVal;
                        break;
                    case '^':
                        if (leftVal == 0.0 && rightVal == 0.0) result = 1.0;
                        else if (leftVal == 0.0 && rightVal < 0.0) return INF;
                        else if (leftVal < 0.0 && std::fabs(rightVal - std::round(rightVal)) > SIMPLIFY_NEAR_ZERO_TOLERANCE) return INF;
                        else result = std::pow(leftVal, rightVal);
                        break;
                    case '%':
                        if (std::fabs(rightVal) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return INF;
                        result = std::fmod(leftVal, rightVal);
                        break;
                    case 's': result = std::sin(leftVal); break;
                    case 'c': result = std::cos(leftVal); break;
                    case 't': result = std::tan(leftVal); break;
                    case 'q': 
                        // Protected Sqrt: sqrt(|x|)
                        result = std::sqrt(std::abs(leftVal)); 
                        break;
                    case 'a': result = std::abs(leftVal); break;
                    case 'n': result = (leftVal > 0) ? 1.0 : ((leftVal < 0) ? -1.0 : 0.0); break;
                    case 'l': 
                        // Protected Log: log(|x|)
                        if (std::abs(leftVal) <= 1e-9) return INF; 
                        result = std::log(std::abs(leftVal)); 
                        break;
                    case 'e': 
                        if (leftVal > 700.0) return INF; // Overflow check
                        result = std::exp(leftVal); 
                        break;
                    case '!': 
                        // Protected Factorial/Gamma: tgamma(|x|+1)
                        if (std::abs(leftVal) > 170.0) return INF; 
                        result = std::tgamma(std::abs(leftVal) + 1.0); 
                        break;
                    case '_': result = std::floor(leftVal); break;
                    case 'u': result = std::ceil(leftVal); break; // 'u' for ceil (up)
                    case 'g':
                        result = std::lgamma(std::abs(leftVal) + 1.0); 
                        break;
                    case 'S': 
                        // Protected Asin: asin(clip(x, -1, 1))
                        result = std::asin(std::max(-1.0, std::min(1.0, leftVal))); 
                        break;
                    case 'C': 
                        // Protected Acos: acos(clip(x, -1, 1))
                        result = std::acos(std::max(-1.0, std::min(1.0, leftVal))); 
                        break;
                    case 'T': result = std::atan(leftVal); break;
                    default: return std::nan("");
                }
            } catch (const std::exception& e) { return INF; }
            if (std::isinf(result)) return INF;
            if (std::isnan(result)) return std::nan("");
            return result;
        }
        default: return std::nan("");
    }
}

// Convenience overload for single variable case
double evaluate_tree(const NodePtr& node, double val) {
    return evaluate_tree(node, std::vector<double>{val});
}

// --- tree_to_string ---
std::string tree_to_string(const NodePtr& node) {
     if (!node) return "NULL";
     switch (node->type) {
        case NodeType::Constant: return format_constant(node->value);
        case NodeType::Variable: return "x" + std::to_string(node->var_index);
        case NodeType::Operator: {
            NodePtr left_node = node->left;
            std::string left_str = tree_to_string(left_node);
            
            // Check arity
            bool is_unary = (node->op == 's' || node->op == 'c' || node->op == 'l' || node->op == 'e' || node->op == '!' || node->op == '_' || node->op == 'g' || node->op == 't' || node->op == 'q' || node->op == 'a' || node->op == 'n' || node->op == 'u' || node->op == 'S' || node->op == 'C' || node->op == 'T');

            if (is_unary) {
                switch(node->op) {
                    case 's': return "sin(" + left_str + ")";
                    case 'c': return "cos(" + left_str + ")";
                    case 't': return "tan(" + left_str + ")";
                    case 'q': return "sqrt(" + left_str + ")";
                    case 'a': return "abs(" + left_str + ")";
                    case 'n': return "sign(" + left_str + ")";
                    case 'l': return "log(" + left_str + ")";
                    case 'e': return "exp(" + left_str + ")";
                    case '!': return "(" + left_str + ")!"; // Postfix for factorial
                    case '_': return "floor(" + left_str + ")";
                    case 'u': return "ceil(" + left_str + ")";
                    case 'g': return "lgamma(" + left_str + ")";
                    case 'S': return "asin(" + left_str + ")";
                    case 'C': return "acos(" + left_str + ")";
                    case 'T': return "atan(" + left_str + ")";
                    default: return "op(" + left_str + ")";
                }
            }

            NodePtr right_node = node->right;
            std::string right_str = tree_to_string(right_node);
            char current_op = node->op;
            bool right_is_neg_const = (right_node && right_node->type == NodeType::Constant && right_node->value < 0.0);
            if (right_is_neg_const) {
                double abs_right_val = std::fabs(right_node->value);
                std::string abs_right_str = format_constant(abs_right_val);
                if (node->op == '+') { current_op = '-'; right_str = abs_right_str; }
                else if (node->op == '-') { current_op = '+'; right_str = abs_right_str; }
            }
            // Simplificar impresión de (0-A) a (-A)
            if (left_node && left_node->type == NodeType::Constant && left_node->value == 0.0 && current_op == '-') {
                 return "(-" + right_str + ")";
            }
            return "(" + left_str + current_op + right_str + ")";
        }
        default: return "?";
    }
}

// --- tree_size ---
// Counts the nodes of a tree: 0 for a null node, 1 for a leaf (Constant or
// Variable), and 1 plus the sizes of both children for an Operator.
int tree_size(const NodePtr& node) {
    if (!node) return 0;

    switch (node->type) {
        case NodeType::Constant:
        case NodeType::Variable:
            return 1;
        case NodeType::Operator:
            return 1 + tree_size(node->left) + tree_size(node->right);
    }
    return 0; // unrecognised node type
}

// --- clone_tree ---
// Returns a deep copy of the tree rooted at `node` (nullptr for a null node).
// Every field is copied and both children are cloned recursively, so the copy
// shares no nodes with the original.
NodePtr clone_tree(const NodePtr& node) {
    if (!node) {
        return nullptr;
    }

    NodePtr copy = std::make_shared<Node>(node->type);
    copy->value     = node->value;
    copy->var_index = node->var_index;
    copy->op        = node->op;
    copy->left      = clone_tree(node->left);
    copy->right     = clone_tree(node->right);
    return copy;
}

// --- collect_node_ptrs ---
// Appends to `vec`, in pre-order, the address of every non-null NodePtr slot in the
// tree (the slot passed in first, then the left and right subtrees of operators).
// Null slots are not recorded and the children of non-operator nodes are not visited.
void collect_node_ptrs(NodePtr& node, std::vector<NodePtr*>& vec) {
    if (node) {
        vec.push_back(&node);
        if (node->type != NodeType::Operator) {
            return;
        }
        collect_node_ptrs(node->left, vec);
        collect_node_ptrs(node->right, vec);
    }
}

// --- get_rng ---
// === OPTIMIZACIÓN: RNG thread-local para evitar contención en OpenMP ===
// Returns the calling thread's random engine. Each thread lazily creates its own
// std::mt19937, seeded from std::random_device XOR-ed with a hash of the thread id,
// so threads do not share (or contend on) a single generator.
std::mt19937& get_rng() {
    thread_local std::mt19937 local_rng = [] {
        const unsigned entropy = std::random_device{}();
        const unsigned thread_bits =
            static_cast<unsigned>(std::hash<std::thread::id>{}(std::this_thread::get_id()));
        return std::mt19937(entropy ^ thread_bits);
    }();
    return local_rng;
}

// --- get_tree_depth ---
// Returns the depth of a tree: 0 for a null node, 1 for a leaf, and 1 plus the
// deeper of the two subtrees for an Operator.
int get_tree_depth(const NodePtr& node) {
    if (!node) return 0;

    if (node->type == NodeType::Operator) {
        const int left_depth = get_tree_depth(node->left);
        const int right_depth = get_tree_depth(node->right);
        return 1 + (left_depth > right_depth ? left_depth : right_depth);
    }
    return 1;
}

// --- trim_tree ---
// Truncates a tree in place so that no operator sits at or below `max_depth`.
// Operators found once the remaining depth budget is <= 1 are turned into Variable
// leaves: their children are dropped and `op` is cleared, while `value` and
// `var_index` keep whatever they held. Leaves and null nodes are left untouched.
void trim_tree(NodePtr& node, int max_depth) {
    // Only operator nodes can be trimmed or descended into.
    if (!node || node->type != NodeType::Operator) return;

    if (max_depth > 1) {
        // Still within budget: keep this operator and trim its subtrees.
        trim_tree(node->left, max_depth - 1);
        trim_tree(node->right, max_depth - 1);
        return;
    }

    // Depth limit reached: collapse the operator into a variable terminal.
    node->type = NodeType::Variable;
    node->op = 0;
    node->left = nullptr;
    node->right = nullptr;
}


// ============================================================
// --- Parser de Fórmulas desde String (v4 - Parser Corregido) ---
// ============================================================

// Helper para obtener precedencia de operadores
// Returns the binding strength of a binary operator character:
// 3 for '^', 2 for '*', '/' and '%', 1 for '+' and '-', and 0 for anything else.
int get_precedence(char op) {
    if (op == '^') return 3;
    if (op == '*' || op == '/' || op == '%') return 2;
    if (op == '+' || op == '-') return 1;
    return 0;
}

// Helper para aplicar un operador binario
// Builds an Operator node `left <op> right`.
// Note the parameter order: the RIGHT operand comes first, matching the order in
// which operands are popped from a stack.
// Throws std::runtime_error when either operand is null.
NodePtr apply_binary_operation(NodePtr right, NodePtr left, char op) {
    const bool missing_operand = !left || !right;
    if (missing_operand) {
        throw std::runtime_error("Error al aplicar operación binaria '" + std::string(1, op) + "': operandos insuficientes.");
    }

    NodePtr parent = std::make_shared<Node>(NodeType::Operator);
    parent->left = left;
    parent->right = right;
    parent->op = op;
    return parent;
}

// Función principal para parsear la fórmula
// Parses an infix formula into an expression tree (operand/operator stack parser).
//
// Accepted syntax (whitespace is ignored):
//   - numbers: "12", "3.5", ".5" and scientific notation such as "1.5e-07" or "2E6"
//     (an 'e'/'E' is only taken as an exponent when digits follow it, so "2e" and
//     "2e+x" still mean 2 * e ...)
//   - named constants "pi" and "e", and the placeholder "C" (parsed as constant 1.0)
//   - variables "x" (same as x0) or "x<digits>"
//   - binary operators + - * / % ^ ('^' is right-associative) and parentheses
//   - unary '+' (ignored) and unary '-' (built as the subtree 0 - operand); unary
//     minus binds tighter than * / % but looser than ^, so "x*-5" is x*(-5),
//     "2^-1" is 2^(-1) and "-x^2" is -(x^2)
//   - function calls name(arg) for the names in the table below; the argument is
//     parsed recursively
//   - implicit multiplication between adjacent operands ("2x", "3(x+1)", "x sin(x)")
//
// Returns the root of the parsed tree.
// Throws std::runtime_error on any syntax error.
NodePtr parse_formula_string(const std::string& formula_raw) {
    // Remove all whitespace. The unsigned char cast keeps std::isspace well defined
    // for bytes outside the ASCII range.
    std::string formula;
    formula.reserve(formula_raw.size());
    for (const char ch : formula_raw) {
        if (!std::isspace(static_cast<unsigned char>(ch))) formula.push_back(ch);
    }
    if (formula.empty()) throw std::runtime_error("La fórmula está vacía.");

    // Function name -> operator code. Built once instead of once per token.
    static const std::unordered_map<std::string, char> func_map = {
        {"sin", 's'}, {"cos", 'c'}, {"tan", 't'},
        {"asin", 'S'}, {"acos", 'C'}, {"atan", 'T'},
        {"log", 'l'}, {"exp", 'e'}, {"sqrt", 'q'},
        {"floor", '_'}, {"ceil", 'u'}, {"abs", 'a'}, {"sign", 'n'},
        {"gamma", '!'}, {"lgamma", 'g'}, {"g", 'g'}
    };

    // Internal marker for a pending unary minus on the operator stack. It is never
    // accepted from the input text (a literal '~' is still an unknown token).
    const char UNARY_MINUS = '~';
    const std::size_t n = formula.length();

    std::stack<NodePtr> operand_stack;
    std::stack<char> operator_stack;
    bool last_token_was_operand = false;

    const auto is_digit_at = [&](std::size_t pos) {
        return pos < n && std::isdigit(static_cast<unsigned char>(formula[pos])) != 0;
    };

    // Binary precedences are doubled so that unary minus fits strictly between
    // the multiplicative operators (4) and '^' (6).
    const auto precedence_of = [&](char op) {
        return (op == UNARY_MINUS) ? 5 : 2 * get_precedence(op);
    };

    const auto make_constant = [](double v) {
        auto constant = std::make_shared<Node>(NodeType::Constant);
        constant->value = v;
        return constant;
    };

    // Pops the top operator and applies it to the operand stack.
    const auto reduce_top = [&]() {
        const char top_op = operator_stack.top();
        operator_stack.pop();

        if (top_op == UNARY_MINUS) {
            if (operand_stack.empty()) throw std::runtime_error("Operandos insuficientes para operador '-'.");
            NodePtr operand = operand_stack.top(); operand_stack.pop();
            // -A is represented as (0 - A), which tree_to_string prints as "(-A)".
            operand_stack.push(apply_binary_operation(operand, make_constant(0.0), '-'));
            return;
        }

        if (operand_stack.size() < 2) throw std::runtime_error("Operandos insuficientes para operador '" + std::string(1, top_op) + "'.");
        NodePtr right = operand_stack.top(); operand_stack.pop();
        NodePtr left = operand_stack.top(); operand_stack.pop();
        operand_stack.push(apply_binary_operation(right, left, top_op));
    };

    // Applies stacked operators (down to the nearest '(') that bind at least as
    // tightly as the incoming one; for a right-associative incoming operator only
    // strictly tighter ones are applied.
    const auto reduce_while = [&](int incoming_precedence, bool right_associative) {
        while (!operator_stack.empty() && operator_stack.top() != '(') {
            const int top_precedence = precedence_of(operator_stack.top());
            const bool should_reduce = right_associative ? (top_precedence > incoming_precedence)
                                                         : (top_precedence >= incoming_precedence);
            if (!should_reduce) break;
            reduce_top();
        }
    };

    // Inserts the '*' implied by two adjacent operands, e.g. "2x" or "3(x+1)".
    const auto insert_implicit_multiplication = [&]() {
        if (!last_token_was_operand) return;
        reduce_while(precedence_of('*'), false);
        operator_stack.push('*');
        last_token_was_operand = false;
    };

    const auto push_operand = [&](const NodePtr& operand) {
        insert_implicit_multiplication();
        operand_stack.push(operand);
        last_token_was_operand = true;
    };

    std::size_t i = 0;
    while (i < n) {
        const char token = formula[i];

        // --- A. Numbers ---
        if (is_digit_at(i) || (token == '.' && is_digit_at(i + 1))) {
            std::string num_str;
            if (token == '.') num_str += '0';
            num_str += token;
            ++i;
            while (i < n && (is_digit_at(i) || (formula[i] == '.' && num_str.find('.') == std::string::npos))) {
                num_str += formula[i];
                ++i;
            }
            // Optional exponent: e/E, optional sign, at least one digit. Without the
            // digits the 'e' is left alone and later read as the constant e.
            if (i < n && (formula[i] == 'e' || formula[i] == 'E')) {
                std::size_t k = i + 1;
                if (k < n && (formula[k] == '+' || formula[k] == '-')) ++k;
                if (is_digit_at(k)) {
                    while (is_digit_at(k)) ++k;
                    num_str.append(formula, i, k - i);
                    i = k;
                }
            }

            double value = 0.0;
            try {
                value = std::stod(num_str);
            } catch (const std::invalid_argument& e) {
                throw std::runtime_error("Número inválido (formato): '" + num_str + "' - " + e.what());
            } catch (const std::out_of_range& e) {
                throw std::runtime_error("Número inválido (rango): '" + num_str + "' - " + e.what());
            }
            push_operand(make_constant(value));
            continue;
        }

        // --- B. Named constants and functions ---
        if (token == 'C') {
            // Placeholder constant; parsed as a plain Constant node with value 1.0.
            push_operand(make_constant(1.0));
            ++i;
            continue;
        }
        if (i + 1 < n && formula.compare(i, 2, "pi") == 0) {
            push_operand(make_constant(3.14159265359));
            i += 2;
            continue;
        }
        if (token == 'e' && (i + 1 >= n || formula[i + 1] != 'x')) { // not the start of "exp"
            push_operand(make_constant(2.71828182846));
            ++i;
            continue;
        }

        // A function name only matches when '(' follows it immediately, so at most
        // one table entry can match at a given position and lookup order is irrelevant.
        bool matched_func = false;
        for (const auto& entry : func_map) {
            const std::string& func_name = entry.first;
            const std::size_t after_name = i + func_name.length();
            if (after_name >= n || formula[after_name] != '(') continue;
            if (formula.compare(i, func_name.length(), func_name) != 0) continue;

            // Locate the parenthesis that closes the argument list.
            int paren_count = 1;
            const std::size_t arg_start = after_name + 1;
            std::size_t j = arg_start;
            while (j < n && paren_count > 0) {
                if (formula[j] == '(') paren_count++;
                else if (formula[j] == ')') paren_count--;
                j++;
            }
            if (paren_count != 0) {
                throw std::runtime_error("Paréntesis sin cerrar en función '" + func_name + "'.");
            }
            const std::size_t arg_end = j - 1; // position of the closing ')'

            // The argument is a complete sub-formula: parse it recursively.
            NodePtr arg_tree = parse_formula_string(formula.substr(arg_start, arg_end - arg_start));

            auto func_node = std::make_shared<Node>(NodeType::Operator);
            func_node->op = entry.second;
            func_node->left = arg_tree;
            func_node->right = nullptr;

            push_operand(func_node);
            i = j; // continue right after the closing ')'
            matched_func = true;
            break;
        }
        if (matched_func) continue;

        // --- C. Variable 'x' with optional index (x, x0, x1, x10, ...) ---
        if (token == 'x') {
            auto node = std::make_shared<Node>(NodeType::Variable);
            node->var_index = 0; // plain "x" means x0
            ++i;
            if (is_digit_at(i)) {
                std::string idx_str;
                while (is_digit_at(i)) {
                    idx_str += formula[i];
                    ++i;
                }
                try {
                    node->var_index = std::stoi(idx_str);
                } catch (...) {
                    node->var_index = 0; // Fallback when the index does not fit in an int
                }
            }
            push_operand(node);
            continue;
        }

        // --- D. Opening parenthesis ---
        if (token == '(') {
            insert_implicit_multiplication();
            operator_stack.push('(');
            last_token_was_operand = false;
            ++i;
            continue;
        }

        // --- E. Closing parenthesis ---
        if (token == ')') {
            if (!last_token_was_operand) {
                if (!operator_stack.empty() && operator_stack.top() == '(') throw std::runtime_error("Paréntesis vacíos '()' encontrados.");
                else throw std::runtime_error("Se esperaba un operando antes de ')'.");
            }
            reduce_while(0, false); // apply everything back to the matching '('
            if (operator_stack.empty()) throw std::runtime_error("Paréntesis ')' sin correspondiente '('.");
            operator_stack.pop(); // discard '('
            last_token_was_operand = true;
            ++i;
            continue;
        }

        // --- F. Operators (+ - * / ^ %) ---
        if (std::string("+-*/^%").find(token) != std::string::npos) {
            if (!last_token_was_operand) {
                // Prefix position: only unary '+' and '-' are valid here.
                if (token == '+') { // unary plus changes nothing
                    ++i;
                    continue;
                }
                if (token == '-') {
                    // A prefix operator has no left operand yet, so nothing already
                    // on the stack may be reduced before it is pushed.
                    operator_stack.push(UNARY_MINUS);
                    ++i;
                    continue;
                }
                throw std::runtime_error("Operador binario '" + std::string(1, token) + "' inesperado. Se esperaba operando.");
            }

            reduce_while(precedence_of(token), token == '^');
            operator_stack.push(token);
            last_token_was_operand = false; // an operand must follow an operator
            ++i;
            continue;
        }

        // --- G. Unknown token ---
        throw std::runtime_error("Token desconocido en la fórmula: '" + std::string(1, token) + "'");
    }

    // --- H. Apply whatever is left on the operator stack ---
    while (!operator_stack.empty()) {
        if (operator_stack.top() == '(') throw std::runtime_error("Paréntesis '(' sin cerrar al final.");
        reduce_top();
    }

    // Exactly one operand (the root of the tree) must remain.
    if (operand_stack.size() != 1) {
         if (operand_stack.empty() && formula.length() > 0) throw std::runtime_error("Error: No se generó ningún resultado del parseo. Fórmula inválida?");
         else if (operand_stack.size() > 1) throw std::runtime_error("Error en la estructura final (operandos restantes: " + std::to_string(operand_stack.size()) + "). Verifique operadores.");
         else throw std::runtime_error("Error desconocido al finalizar el parseo.");
    }

    return operand_stack.top();
}


In [None]:
%%writefile Code/src/ExpressionTree.h
#ifndef EXPRESSIONTREE_H
#define EXPRESSIONTREE_H

#include <memory>
#include <string>
#include <vector>
#include <stdexcept> // Para std::runtime_error

// Forward declaration
struct Node;

// Use shared_ptr for automatic memory management
using NodePtr = std::shared_ptr<Node>;

enum class NodeType { Constant, Variable, Operator };

// One node of an expression tree. `type` selects which of the other fields are
// meaningful; unused fields keep their defaults.
struct Node {
    NodeType type;
    double value = 0.0;             // If type == Constant: the numeric value
    int var_index = 0;              // If type == Variable: index of the variable (0 for x0, 1 for x1...)
    char op = 0;                    // If type == Operator: operator code. Binary: '+', '-', '*', '/', '^', '%'.
                                    // Unary (left child only): 's' sin, 'c' cos, 't' tan, 'S' asin, 'C' acos,
                                    // 'T' atan, 'l' log, 'e' exp, 'q' sqrt, 'a' abs, 'n' sign, '_' floor,
                                    // 'u' ceil, '!' factorial (gamma), 'g' lgamma.
    NodePtr left = nullptr;         // Children (for Operators); unary operators use only `left`
    NodePtr right = nullptr;

    // Convenience constructor: sets the node type (Constant by default)
    Node(NodeType t = NodeType::Constant) : type(t) {}
};

// Core Tree Functions

// Evaluates the tree using `vars` as the variable values (a Variable node reads
// vars[var_index]). Returns NaN for a null or malformed tree.
double evaluate_tree(const NodePtr& node, const std::vector<double>& vars);

// Convenience overload for single variable case
double evaluate_tree(const NodePtr& node, double val);
// Renders the tree as an infix string ("NULL" for a null node).
std::string tree_to_string(const NodePtr& node);
// Number of nodes in the tree (0 for a null node).
int tree_size(const NodePtr& node);
// Deep copy of the tree (nullptr for a null node).
NodePtr clone_tree(const NodePtr& node);
// Depth of the tree: 0 for a null node, 1 for a single leaf.
int get_tree_depth(const NodePtr& node);
// Trims the tree in place, turning operators at the depth limit into Variable leaves.
void trim_tree(NodePtr& node, int max_depth);

// Helper for mutation/crossover: appends the address of every non-null node slot
// in the tree to `vec`.
void collect_node_ptrs(NodePtr& node, std::vector<NodePtr*>& vec);

// --- Parse a formula from a string ---
// Parses a formula written in simple infix notation (with parentheses).
// Throws std::runtime_error on a syntax error.
NodePtr parse_formula_string(const std::string& formula);


#endif // EXPRESSIONTREE_H


In [None]:
%%writefile Code/src/Fitness.cpp
#include "Fitness.h"
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
#include "FitnessGPU.cuh" // Include for GPU fitness evaluation
#endif
#include "Globals.h" // Necesario para constantes globales e INF
#include "ExpressionTree.h" // Necesario para tree_to_string
#include <cmath>
#include <limits>
#include <vector>
#include <numeric>
#include <iostream> // Para std::cerr en caso de error futuro
#include <iomanip>  // Para std::fixed/scientific si se necesita en errores

// Calculates the raw fitness using global parameters.
// This function will now dispatch to GPU if USE_GPU_ACCELERATION is enabled.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
// Raw (complexity-free) error of `tree` against the targets, GPU-enabled build.
//   tree:       expression to score
//   targets:    expected output per sample
//   x_values:   variable values per sample (one inner vector per sample)
//   d_targets / d_x_values: buffers handed to evaluate_fitness_gpu; when either is
//                           null the error is computed on the CPU instead
// Lower is better. Returns INF when the inputs are inconsistent or empty, when any
// prediction is NaN/Inf, or when the resulting error is not a valid non-negative
// number.
double calculate_raw_fitness(const NodePtr& tree,
                             const std::vector<double>& targets,
                             const std::vector<std::vector<double>>& x_values,
                             double* d_targets, double* d_x_values) {
    // Both buffers present: delegate to the GPU evaluation.
    const bool have_device_buffers = (d_targets != nullptr) && (d_x_values != nullptr);
    if (have_device_buffers) {
        return evaluate_fitness_gpu(tree, targets, x_values, d_targets, d_x_values);
    }

    // --- CPU fallback ---
    const size_t sample_count = x_values.size();
    if (sample_count != targets.size() || x_values.empty()) return INF;

    double weighted_sq_sum = 0.0;
    double weight_sum = 0.0;
    bool within_threshold = true;

    for (size_t idx = 0; idx < sample_count; ++idx) {
        const double predicted = evaluate_tree(tree, x_values[idx]);
        // A single failed evaluation invalidates the whole candidate.
        if (std::isnan(predicted) || std::isinf(predicted)) return INF;

        const double err = predicted - targets[idx];
        if (std::fabs(err) >= FITNESS_PRECISION_THRESHOLD) within_threshold = false;

        // Optional exponential weight that grows with the sample index.
        double sample_weight = 1.0;
        if (USE_WEIGHTED_FITNESS) {
            sample_weight = std::exp(static_cast<double>(idx) * WEIGHTED_FITNESS_EXPONENT);
        }
        weight_sum += sample_weight;
        weighted_sq_sum += (err * err) * sample_weight;
    }

    // Rescale the weighted sum so it is comparable to an unweighted sum over
    // sample_count points.
    if (USE_WEIGHTED_FITNESS && weight_sum > 0.0) {
        weighted_sq_sum = weighted_sq_sum / weight_sum * sample_count;
    }

    // RMSE when enabled, otherwise the (rescaled) sum of squared errors.
    double raw_error = weighted_sq_sum;
    if (USE_RMSE_FITNESS && sample_count > 0) {
        raw_error = std::sqrt(weighted_sq_sum / static_cast<double>(sample_count));
    }

    if (std::isnan(raw_error) || std::isinf(raw_error) || raw_error < 0) {
        return INF;
    }

    // Bonus factor when every sample was within the precision threshold.
    if (within_threshold) {
        raw_error *= FITNESS_PRECISION_BONUS;
    }
    return raw_error;
}
#else
// Raw (complexity-free) error of `tree` against the targets, CPU-only build.
//   tree:     expression to score
//   targets:  expected output per sample
//   x_values: variable values per sample (one inner vector per sample)
// With USE_RMSE_FITNESS the result is the weighted RMSE; otherwise it is the
// weighted sum of |error|^FITNESS_ORIGINAL_POWER. Lower is better. Returns INF when
// the inputs are inconsistent or empty, when any prediction is NaN/Inf, when the
// accumulated error overflows, or when the result is not a valid non-negative number.
double calculate_raw_fitness(const NodePtr& tree,
                             const std::vector<double>& targets,
                             const std::vector<std::vector<double>>& x_values) {
    const size_t sample_count = x_values.size();
    if (sample_count != targets.size() || x_values.empty()) return INF;

    double powered_error_sum = 0.0; // only accumulated when USE_RMSE_FITNESS is false
    double weighted_sq_sum = 0.0;
    double weight_sum = 0.0;        // normalises the weighted error
    bool within_threshold = true;

    for (size_t idx = 0; idx < sample_count; ++idx) {
        const double predicted = evaluate_tree(tree, x_values[idx]);

        // A single failed evaluation (INF or NaN) invalidates the whole candidate.
        if (std::isnan(predicted) || std::isinf(predicted)) return INF;

        const double err = predicted - targets[idx];
        const double abs_err = std::fabs(err);
        if (abs_err >= FITNESS_PRECISION_THRESHOLD) within_threshold = false;

        // Weighted fitness: an exponential weight makes later samples (higher
        // index) count far more than earlier ones.
        double sample_weight = 1.0;
        if (USE_WEIGHTED_FITNESS) {
            sample_weight = std::exp(static_cast<double>(idx) * WEIGHTED_FITNESS_EXPONENT);
        }
        weight_sum += sample_weight;

        if (!USE_RMSE_FITNESS) {
            powered_error_sum += std::pow(abs_err, FITNESS_ORIGINAL_POWER) * sample_weight;
        }
        weighted_sq_sum += (err * err) * sample_weight;

        // Stop as soon as either accumulator overflows.
        const bool overflowed = std::isinf(weighted_sq_sum) ||
                                (powered_error_sum >= INF / 10.0 && !USE_RMSE_FITNESS);
        if (overflowed) return INF;
    }

    // Select the raw error metric.
    double raw_error;
    if (USE_RMSE_FITNESS) {
        if (sample_count == 0 || weight_sum == 0.0) return INF;
        // Weighted MSE: normalised by the sum of weights, not by the sample count.
        const double mse = weighted_sq_sum / weight_sum;
        const bool mse_invalid = std::isinf(mse) || std::isnan(mse) || mse < 0;
        if (mse_invalid) {
            raw_error = INF;
        } else {
            raw_error = std::sqrt(mse);
        }
    } else {
        raw_error = powered_error_sum;
    }

    if (std::isnan(raw_error) || std::isinf(raw_error) || raw_error < 0) {
        return INF;
    }

    // Bonus factor when every sample was within the precision threshold.
    if (within_threshold) {
        raw_error *= FITNESS_PRECISION_BONUS;
    }

    return raw_error; // complexity penalty is applied by the caller
}
#endif // USE_GPU_ACCELERATION_DEFINED_BY_CMAKE

// Computes the final fitness from global parameters: the raw error scaled by a
// multiplicative, size-based complexity penalty. Lower is better; INF marks an
// unusable tree. The signature (and the raw-fitness call) depends on whether the
// GPU build flag is defined.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
double evaluate_fitness(const NodePtr& tree,
                        const std::vector<double>& targets,
                        const std::vector<std::vector<double>>& x_values,
                        double* d_targets, double* d_x_values) {
    const double base_error = calculate_raw_fitness(tree, targets, x_values, d_targets, d_x_values);
#else
double evaluate_fitness(const NodePtr& tree,
                        const std::vector<double>& targets,
                        const std::vector<std::vector<double>>& x_values) {
    const double base_error = calculate_raw_fitness(tree, targets, x_values);
#endif

    // A raw error within one order of magnitude of INF is treated as infinite.
    if (base_error >= INF / 10.0) {
        return INF;
    }

    // Complexity penalty: every node in the tree adds COMPLEXITY_PENALTY_FACTOR
    // to the multiplier applied to the raw error.
    const double node_count = static_cast<double>(tree_size(tree));
    const double penalized = base_error * (1.0 + node_count * COMPLEXITY_PENALTY_FACTOR);

    // Final sanity check: NaN, Inf and negative values all collapse to INF.
    const bool usable = !std::isnan(penalized) && !std::isinf(penalized) && !(penalized < 0);
    return usable ? penalized : INF;
}


In [None]:
%%writefile Code/src/Fitness.h
#ifndef FITNESS_H
#define FITNESS_H

#include "ExpressionTree.h"
#include <vector>

// Calculates raw fitness based on target matching, i.e. the error before any
// complexity penalty is applied.
// Lower is better. Returns INF if evaluation results in NaN/Inf.
//   tree     - expression tree to score
//   targets  - target values the tree's output is compared against
//   x_values - input values; presumably one inner vector of variables per sample
//              (TODO confirm against the implementation)
//   d_targets, d_x_values - extra pointers that exist only when
//              USE_GPU_ACCELERATION_DEFINED_BY_CMAKE is defined; presumably
//              device-side copies of targets / x_values (TODO confirm with caller)
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
double calculate_raw_fitness(const NodePtr& tree,
                             const std::vector<double>& targets,
                             const std::vector<std::vector<double>>& x_values,
                             double* d_targets, double* d_x_values);
#else
double calculate_raw_fitness(const NodePtr& tree,
                             const std::vector<double>& targets,
                             const std::vector<std::vector<double>>& x_values);
#endif

// Calculates final fitness including complexity penalty.
// The implementation multiplies the raw fitness by
// (1 + tree_size(tree) * COMPLEXITY_PENALTY_FACTOR) and returns INF when the raw
// fitness is >= INF / 10 or the penalized value is NaN, Inf or negative.
// Lower is better.
//   tree, targets, x_values - forwarded unchanged to calculate_raw_fitness
//   d_targets, d_x_values   - only present when USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
//                             is defined; also forwarded to calculate_raw_fitness
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
double evaluate_fitness(const NodePtr& tree,
                        const std::vector<double>& targets,
                        const std::vector<std::vector<double>>& x_values,
                        double* d_targets, double* d_x_values);
#else
double evaluate_fitness(const NodePtr& tree,
                        const std::vector<double>& targets,
                        const std::vector<std::vector<double>>& x_values);
#endif

#endif // FITNESS_H


In [None]:
%%writefile Code/src/FitnessGPU.cu
#include "FitnessGPU.cuh"
#include "Globals.h"
#include <cuda_runtime.h>
#include <math.h>
#include <vector>
#include <iostream>

// Helper function to linearize the tree into a post-order array
void linearize_tree(const NodePtr& node, std::vector<LinearGpuNode>& linear_tree) {
    if (!node) {
        return;
    }
    linearize_tree(node->left, linear_tree);
    linearize_tree(node->right, linear_tree);
    // Include var_index in the struct initialization
    linear_tree.push_back({node->type, node->value, node->var_index, node->op});
}

#if USE_GPU_ACCELERATION_DEFINED_BY_CMAKE

// Large finite value the kernels in this file use as a failure sentinel instead of Inf
#define GPU_MAX_DOUBLE 1e308

// --- WEIGHTED FITNESS: constants for CUDA ---
// These must match the corresponding values in Globals.h (kept in sync by hand).
// Original author's rationale: CUDA device code cannot access C++ const globals,
// so #define is used instead. NOTE(review): not verified here; confirm before relying on it.
#define GPU_USE_WEIGHTED_FITNESS true
#define GPU_WEIGHTED_FITNESS_EXPONENT 0.25

// Single Tree Evaluation Kernel (Updated for Multivariable)
//
// Every thread whose global index is < num_points interprets the post-order program
// in d_linear_tree for that sample and stores the sample's squared error in
// d_raw_fitness_results[idx]. When GPU_USE_WEIGHTED_FITNESS is set the squared error
// is multiplied by exp(idx * GPU_WEIGHTED_FITNESS_EXPONENT). Samples whose evaluation
// is malformed, overflows the evaluation stack, or yields NaN/Inf store GPU_MAX_DOUBLE.
//
// Launch requirement: dynamic shared memory of tree_size * sizeof(LinearGpuNode) bytes.
//
// Parameters:
//   d_linear_tree         - post-order node array of the tree (tree_size entries)
//   d_targets             - target value per sample
//   d_x_values            - flattened inputs, indexed [sample * num_vars + variable]
//   num_points / num_vars - number of samples / variables per sample
//   d_raw_fitness_results - output, one value per sample
__global__ void calculate_raw_fitness_kernel(const LinearGpuNode* d_linear_tree,
                                             int tree_size,
                                             const double* d_targets,
                                             const double* d_x_values, // Flattened [num_points * num_vars]
                                             size_t num_points,
                                             int num_vars,
                                             double* d_raw_fitness_results) {
    // Capacity of the per-thread evaluation stack. Only leaf nodes can grow the stack
    // beyond its previous maximum, so only leaf pushes are bounds-checked below.
    const int kStackCapacity = 64;

    // Shared memory optimization: Load tree into shared memory
    extern __shared__ LinearGpuNode s_linear_tree[];

    // Cooperative load: the threads of the block copy the tree in strides of blockDim.x
    for (int i = threadIdx.x; i < tree_size; i += blockDim.x) {
        s_linear_tree[i] = d_linear_tree[i];
    }
    __syncthreads();

    // size_t arithmetic so that idx * num_vars cannot overflow a 32-bit int
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (idx < num_points) {
        double stack[kStackCapacity];
        int stack_top = -1;
        bool stack_overflow = false;

        for (int i = 0; i < tree_size; ++i) {
            LinearGpuNode node = s_linear_tree[i]; // Access from shared memory
            if (node.type == NodeType::Constant) {
                if (stack_top + 1 >= kStackCapacity) { stack_overflow = true; break; }
                stack[++stack_top] = node.value;
            } else if (node.type == NodeType::Variable) {
                if (stack_top + 1 >= kStackCapacity) { stack_overflow = true; break; }
                // Access correct variable for this sample
                int var_idx = node.var_index;
                if (var_idx < 0 || var_idx >= num_vars) var_idx = 0; // Safety fallback
                stack[++stack_top] = d_x_values[idx * num_vars + var_idx];
            } else if (node.type == NodeType::Operator) {
                bool is_unary = (node.op == 's' || node.op == 'c' || node.op == 'l' || node.op == 'e' || node.op == '!' || node.op == '_' || node.op == 'g');
                double result = 0.0;

                if (is_unary) {
                     if (stack_top < 0) {
                         // Missing operand: push the sentinel so the sample is penalized
                         result = GPU_MAX_DOUBLE;
                     } else {
                         double val = stack[stack_top--];
                         switch (node.op) {
                            case 's': result = sin(val); break;
                            case 'c': result = cos(val); break;
                            case 'l': result = (val <= 1e-9) ? GPU_MAX_DOUBLE : log(val); break;
                            case 'e': result = (val > 700.0) ? GPU_MAX_DOUBLE : exp(val); break;
                            case '!': result = (val < 0 || val > 170.0) ? GPU_MAX_DOUBLE : tgamma(val + 1.0); break;
                            case '_': result = floor(val); break;
                            case 'g': result = (val <= -1.0) ? GPU_MAX_DOUBLE : lgamma(val + 1.0); break;
                            default: result = NAN; break;
                         }
                     }
                     stack[++stack_top] = result;
                } else {
                    if (stack_top < 1) {
                        // Fewer than two operands: push the sentinel without popping
                        result = GPU_MAX_DOUBLE;
                        stack[++stack_top] = result; // Push error
                    } else {
                        double right = stack[stack_top--];
                        double left = stack[stack_top--];
                        switch (node.op) {
                            case '+': result = left + right; break;
                            case '-': result = left - right; break;
                            case '*': result = left * right; break;
                            case '/':
                                if (fabs(right) < 1e-9) { // Avoid division by zero
                                    result = GPU_MAX_DOUBLE;
                                } else {
                                    result = left / right;
                                }
                                break;
                            case '^': result = pow(left, right); break;
                            case '%':
                                if (fabs(right) < 1e-9) result = GPU_MAX_DOUBLE;
                                else result = fmod(left, right);
                                break;
                            default: result = NAN; break;
                        }
                        stack[++stack_top] = result;
                    }
                }
            }
        }

        // A well-formed program leaves exactly one value on the stack
        double predicted_val = (!stack_overflow && stack_top == 0) ? stack[0] : NAN;

        if (isnan(predicted_val) || isinf(predicted_val)) {
            d_raw_fitness_results[idx] = GPU_MAX_DOUBLE;
        } else {
            double diff = predicted_val - d_targets[idx];
            double sq_error = diff * diff;
            // --- WEIGHTED FITNESS: Apply exponential weight (later samples weigh more) ---
            if (GPU_USE_WEIGHTED_FITNESS) {
                double weight = exp((double)idx * GPU_WEIGHTED_FITNESS_EXPONENT);
                sq_error *= weight;
            }
            d_raw_fitness_results[idx] = sq_error;
        }
    }
}

// Block-level sum reduction performed in place on d_data.
//
// Each block loads up to blockDim.x consecutive elements (indices >= N count as 0),
// adds them up in dynamic shared memory (blockDim.x doubles must be supplied at
// launch) and writes the block total to d_data[blockIdx.x].
//
// The halving loop only covers every element when blockDim.x is a power of two.
// NOTE(review): with more than one block, a block's write to d_data[blockIdx.x]
// lands inside the range block 0 reads, and nothing here orders the two, so
// multi-block launches look racy. Confirm with the callers.
__global__ void reduce_sum_kernel(double* d_data, int N) {
    extern __shared__ double sdata[];

    const unsigned int lane = threadIdx.x;
    const unsigned int i = blockIdx.x * blockDim.x + lane;

    // Stage this thread's element (or 0.0 past the end of the data)
    if (i < N) {
        sdata[lane] = d_data[i];
    } else {
        sdata[lane] = 0.0;
    }
    __syncthreads();

    // Tree reduction: each round folds the upper half of the active range onto the lower half
    unsigned int stride = blockDim.x / 2;
    while (stride > 0) {
        if (lane < stride) {
            sdata[lane] += sdata[lane + stride];
        }
        __syncthreads();
        stride >>= 1;
    }

    // Thread 0 publishes the block total
    if (lane == 0) {
        d_data[blockIdx.x] = sdata[0];
    }
}


// --- New Batch Kernel (Updated for Multivariable) ---
//
// One thread per tree: thread idx interprets tree idx (nodes
// d_all_nodes[d_offsets[idx] .. + d_sizes[idx])) on every sample and writes a single
// error value to d_results[idx]:
//   - valid tree, weighted mode: (sum of weight * squared error) / (sum of weights)
//     * num_points, with weight = exp(p * GPU_WEIGHTED_FITNESS_EXPONENT);
//   - valid tree, unweighted mode: the plain sum of squared errors;
//   - invalid tree (malformed program, evaluation-stack overflow, NaN/Inf prediction,
//     or an empty dataset): exactly GPU_MAX_DOUBLE.
// The sentinel is written untouched; it is never passed through the weight
// normalization, which would otherwise turn it into Inf or an ordinary-looking value.
__global__ void evaluate_population_kernel(const LinearGpuNode* d_all_nodes,
                                           const int* d_offsets,
                                           const int* d_sizes,
                                           int pop_size,
                                           const double* d_targets,
                                           const double* d_x_values,
                                           int num_points,
                                           int num_vars,
                                           double* d_results) {
    // Capacity of the per-thread evaluation stack; only leaf pushes can exceed it.
    const int kStackCapacity = 64;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < pop_size) {
        int offset = d_offsets[idx];
        int size = d_sizes[idx];
        double sum_sq_error = 0.0;
        double total_weight = 0.0; // Used to normalize the weighted fitness
        bool valid = (num_points > 0); // No data means no meaningful score

        for (int p = 0; p < num_points; ++p) {
            double stack[kStackCapacity];
            int stack_top = -1;

            // Simple interpreter
            for (int i = 0; i < size; ++i) {
                LinearGpuNode node = d_all_nodes[offset + i];
                if (node.type == NodeType::Constant) {
                    if (stack_top + 1 >= kStackCapacity) { valid = false; break; }
                    stack[++stack_top] = node.value;
                } else if (node.type == NodeType::Variable) {
                    if (stack_top + 1 >= kStackCapacity) { valid = false; break; }
                    int var_idx = node.var_index;
                    if (var_idx < 0 || var_idx >= num_vars) var_idx = 0; // Safety fallback
                    stack[++stack_top] = d_x_values[p * num_vars + var_idx];
                } else if (node.type == NodeType::Operator) {
                    bool is_unary = (node.op == 's' || node.op == 'c' || node.op == 'l' || node.op == 'e' || node.op == '!' || node.op == '_' || node.op == 'g');

                    if (is_unary) {
                        if (stack_top < 0) { valid = false; break; }
                        double val = stack[stack_top--];
                        double result = 0.0;
                        switch (node.op) {
                            case 's': result = sin(val); break;
                            case 'c': result = cos(val); break;
                            case 'l': result = (val <= 1e-9) ? GPU_MAX_DOUBLE : log(val); break;
                            case 'e': result = (val > 700.0) ? GPU_MAX_DOUBLE : exp(val); break;
                            case '!': result = (val < 0 || val > 170.0) ? GPU_MAX_DOUBLE : tgamma(val + 1.0); break;
                            case '_': result = floor(val); break;
                            case 'g': result = (val <= -1.0) ? GPU_MAX_DOUBLE : lgamma(val + 1.0); break;
                            default: result = NAN; break;
                        }
                        stack[++stack_top] = result;
                    } else {
                        // A binary operator needs two operands
                        if (stack_top < 1) { valid = false; break; }

                        double right = stack[stack_top--];
                        double left = stack[stack_top--];
                        double result;
                        switch (node.op) {
                            case '+': result = left + right; break;
                            case '-': result = left - right; break;
                            case '*': result = left * right; break;
                            case '/':
                                if (fabs(right) < 1e-9) {
                                    result = GPU_MAX_DOUBLE;
                                } else {
                                    result = left / right;
                                }
                                break;
                            case '^': result = pow(left, right); break;
                            case '%':
                                if (fabs(right) < 1e-9) result = GPU_MAX_DOUBLE;
                                else result = fmod(left, right);
                                break;
                            default: result = NAN; break;
                        }
                        stack[++stack_top] = result;
                    }
                }
            }

            // A well-formed program leaves exactly one value on the stack
            if (!valid || stack_top != 0) {
                valid = false;
                break;
            }

            double predicted_val = stack[0];
            if (isnan(predicted_val) || isinf(predicted_val)) {
                valid = false;
                break;
            }

            double diff = predicted_val - d_targets[p];
            double sq_error = diff * diff;

            // --- WEIGHTED FITNESS: exponential weight, later samples weigh more ---
            double weight = 1.0;
            if (GPU_USE_WEIGHTED_FITNESS) {
                weight = exp((double)p * GPU_WEIGHTED_FITNESS_EXPONENT);
            }
            total_weight += weight;
            sum_sq_error += sq_error * weight;
        }

        if (!valid) {
            // Emit the failure sentinel as-is; it must not be rescaled.
            sum_sq_error = GPU_MAX_DOUBLE;
        } else if (GPU_USE_WEIGHTED_FITNESS && total_weight > 0.0) {
            // Normalize by the sum of weights (weighted MSE), then scale back by num_points
            sum_sq_error = sum_sq_error / total_weight * num_points;
        }
        d_results[idx] = sum_sq_error;
    }
}


// Host-side wrapper: evaluates a single tree on the GPU and returns its final
// fitness (raw error times the complexity penalty). Lower is better.
//
// Parameters:
//   tree                  - expression tree to score
//   targets / x_values    - host copies of the dataset; only their sizes are used here
//   d_targets, d_x_values - pointers handed to the kernel as the target array and the
//                           flattened [sample * num_vars + variable] input array
//
// Returns INF when the dataset is empty or inconsistent, the tree is empty, any CUDA
// call fails (e.g. the tree does not fit in shared memory), or any sample evaluates
// to an invalid value.
//
// Differences from a naive "reduce on the GPU" implementation, all deliberate:
//   - The per-sample errors are summed on the host. reduce_sum_kernel reduces in
//     place, and with more than one block a block's output slot lies inside the range
//     another block may still be reading, so its result is not reliable beyond one block.
//   - A single invalid sample (the kernel's GPU_MAX_DOUBLE sentinel) makes the whole
//     tree INF, instead of being averaged into a large-but-finite fitness.
//   - In weighted mode the sum is normalized by the sum of weights, exactly as
//     evaluate_all_populations_kernel does, rather than by the sample count.
double evaluate_fitness_gpu(NodePtr tree,
                            const std::vector<double>& targets,
                            const std::vector<std::vector<double>>& x_values,
                            double* d_targets, double* d_x_values) {
    if (x_values.size() != targets.size() || x_values.empty()) return INF;

    // Linearize the tree into the post-order program the kernel interprets
    std::vector<LinearGpuNode> h_linear_tree;
    linearize_tree(tree, h_linear_tree);
    const int tree_size = static_cast<int>(h_linear_tree.size());

    if (tree_size == 0) {
        return INF;
    }

    const size_t num_points = x_values.size();
    const int num_vars = static_cast<int>(x_values[0].size());

    const size_t tree_bytes = static_cast<size_t>(tree_size) * sizeof(LinearGpuNode);
    const size_t result_bytes = num_points * sizeof(double);

    LinearGpuNode* d_linear_tree = nullptr;
    double* d_raw_fitness_results = nullptr; // One weighted squared error per sample
    std::vector<double> h_point_errors(num_points);

    const bool tree_allocated = cudaMalloc((void**)&d_linear_tree, tree_bytes) == cudaSuccess;
    const bool results_allocated = cudaMalloc((void**)&d_raw_fitness_results, result_bytes) == cudaSuccess;

    bool ok = tree_allocated && results_allocated;
    ok = ok && cudaMemcpy(d_linear_tree, h_linear_tree.data(), tree_bytes, cudaMemcpyHostToDevice) == cudaSuccess;

    if (ok) {
        const int threadsPerBlock = 256;
        const int blocksPerGrid = static_cast<int>((num_points + threadsPerBlock - 1) / threadsPerBlock);

        // The kernel caches the whole tree in dynamic shared memory
        calculate_raw_fitness_kernel<<<blocksPerGrid, threadsPerBlock, tree_bytes>>>(
            d_linear_tree, tree_size, d_targets, d_x_values, num_points, num_vars, d_raw_fitness_results
        );
        // cudaGetLastError reports launch failures, cudaDeviceSynchronize execution failures
        ok = cudaGetLastError() == cudaSuccess && cudaDeviceSynchronize() == cudaSuccess;
    }

    ok = ok && cudaMemcpy(h_point_errors.data(), d_raw_fitness_results, result_bytes, cudaMemcpyDeviceToHost) == cudaSuccess;

    if (tree_allocated) cudaFree(d_linear_tree);
    if (results_allocated) cudaFree(d_raw_fitness_results);

    if (!ok) {
        (void)cudaGetLastError(); // Clear the recorded error so later calls start clean
        return INF;
    }

    // Host-side sum. The kernel already multiplied each squared error by its weight,
    // so only the matching weight total has to be rebuilt here.
    double sum_sq_error_gpu = 0.0;
    double total_weight = 0.0;
    for (size_t i = 0; i < num_points; ++i) {
        const double point_error = h_point_errors[i];
        // Rejects NaN, Inf and the GPU_MAX_DOUBLE failure sentinel in one test
        if (!(point_error < GPU_MAX_DOUBLE)) {
            return INF;
        }
        sum_sq_error_gpu += point_error;
        total_weight += GPU_USE_WEIGHTED_FITNESS
                            ? exp(static_cast<double>(i) * GPU_WEIGHTED_FITNESS_EXPONENT)
                            : 1.0;
    }

    if (isinf(sum_sq_error_gpu) || isnan(sum_sq_error_gpu) ||
        isinf(total_weight) || !(total_weight > 0.0)) {
        return INF;
    }

    // Same normalization as evaluate_all_populations_kernel: weighted mean squared
    // error, scaled back by the sample count.
    const double point_count = static_cast<double>(num_points);
    if (GPU_USE_WEIGHTED_FITNESS) {
        sum_sq_error_gpu = sum_sq_error_gpu / total_weight * point_count;
    }

    double raw_fitness;
    if (USE_RMSE_FITNESS) {
        raw_fitness = sqrt(sum_sq_error_gpu / point_count);
    } else {
        raw_fitness = sum_sq_error_gpu;
    }

    // Multiplicative complexity penalty based on the tree's node count
    const double complexity = static_cast<double>(::tree_size(tree));
    const double penalty = complexity * COMPLEXITY_PENALTY_FACTOR;
    const double final_fitness = raw_fitness * (1.0 + penalty);

    if (isnan(final_fitness) || isinf(final_fitness) || final_fitness < 0) {
        return INF;
    }

    return final_fitness;
}

void evaluate_population_gpu(const std::vector<LinearGpuNode>& all_nodes,
                             const std::vector<int>& tree_offsets,
                             const std::vector<int>& tree_sizes,
                             const std::vector<double>& targets,
                             const std::vector<std::vector<double>>& x_values,
                             std::vector<double>& results,
                             double* d_targets, double* d_x_values,
                             void*& d_nodes_ptr, size_t& d_nodes_cap,
                             void*& d_offsets_ptr, void*& d_sizes_ptr, void*& d_results_ptr, size_t& d_pop_cap) {
    
    int pop_size = tree_offsets.size();
    if (pop_size == 0) return;

    size_t total_nodes = all_nodes.size();
    int num_points = x_values.size();
    int num_vars = (num_points > 0) ? x_values[0].size() : 0;

    // Buffer Management for Nodes
    if (total_nodes > d_nodes_cap) {
        if (d_nodes_ptr) cudaFree(d_nodes_ptr);
        size_t new_cap = total_nodes * 1.5; // Growth factor
        cudaMalloc(&d_nodes_ptr, new_cap * sizeof(LinearGpuNode));
        d_nodes_cap = new_cap;
    }

    // Buffer Management for Population Arrays
    if (pop_size > d_pop_cap) {
        if (d_offsets_ptr) cudaFree(d_offsets_ptr);
        if (d_sizes_ptr) cudaFree(d_sizes_ptr);
        if (d_results_ptr) cudaFree(d_results_ptr);
        
        size_t new_cap = pop_size * 1.5;
        cudaMalloc(&d_offsets_ptr, new_cap * sizeof(int));
        cudaMalloc(&d_sizes_ptr, new_cap * sizeof(int));
        cudaMalloc(&d_results_ptr, new_cap * sizeof(double));
        d_pop_cap = new_cap;
    }

    LinearGpuNode* d_all_nodes = (LinearGpuNode*)d_nodes_ptr;
    int* d_offsets = (int*)d_offsets_ptr;
    int* d_sizes = (int*)d_sizes_ptr;
    double* d_results = (double*)d_results_ptr;

    cudaMemcpy(d_all_nodes, all_nodes.data(), total_nodes * sizeof(LinearGpuNode), cudaMemcpyHostToDevice);
    cudaMemcpy(d_offsets, tree_offsets.data(), pop_size * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_sizes, tree_sizes.data(), pop_size * sizeof(int), cudaMemcpyHostToDevice);

    int threadsPerBlock = 256;
    int blocksPerGrid = (pop_size + threadsPerBlock - 1) / threadsPerBlock;

    evaluate_population_kernel<<<blocksPerGrid, threadsPerBlock>>>(
        d_all_nodes, d_offsets, d_sizes, pop_size, d_targets, d_x_values, num_points, num_vars, d_results
    );

    // Synchronize and copy back
    cudaDeviceSynchronize();
    
    cudaMemcpy(results.data(), d_results, pop_size * sizeof(double), cudaMemcpyDeviceToHost);
}

// ============================================================
// GLOBAL BATCH EVALUATION - Maximum GPU Utilization
// ============================================================

// Prepares `buffers` for global batch evaluation: creates its CUDA stream and
// pre-allocates the device arrays. The capacities are starting points only; the
// evaluation routine that uses these buffers grows them when a batch does not fit.
void init_global_gpu_buffers(GlobalGpuBuffers& buffers) {
    // Starting sizes: room for ~50,000 trees of ~30 nodes each, with slack on the tree count
    const size_t node_slots = 1500000;
    const size_t tree_slots = 60000;

    // Stream used for the asynchronous copies and kernel launches
    cudaStream_t eval_stream;
    cudaStreamCreate(&eval_stream);
    buffers.cuda_stream = (void*)eval_stream;

    buffers.d_nodes_capacity = node_slots;
    buffers.d_pop_capacity = tree_slots;

    // Node storage, then the three per-tree arrays
    cudaMalloc(&buffers.d_nodes, node_slots * sizeof(LinearGpuNode));
    cudaMalloc(&buffers.d_offsets, tree_slots * sizeof(int));
    cudaMalloc(&buffers.d_sizes, tree_slots * sizeof(int));
    cudaMalloc(&buffers.d_results, tree_slots * sizeof(double));
}

// Releases everything init_global_gpu_buffers created: destroys the stream, frees
// the four device arrays, and resets pointers and capacities so the struct can be
// cleaned up again without double-freeing.
void cleanup_global_gpu_buffers(GlobalGpuBuffers& buffers) {
    if (buffers.cuda_stream != nullptr) {
        cudaStreamDestroy((cudaStream_t)buffers.cuda_stream);
        buffers.cuda_stream = nullptr;
    }

    if (buffers.d_nodes != nullptr) {
        cudaFree(buffers.d_nodes);
        buffers.d_nodes = nullptr;
    }
    if (buffers.d_offsets != nullptr) {
        cudaFree(buffers.d_offsets);
        buffers.d_offsets = nullptr;
    }
    if (buffers.d_sizes != nullptr) {
        cudaFree(buffers.d_sizes);
        buffers.d_sizes = nullptr;
    }
    if (buffers.d_results != nullptr) {
        cudaFree(buffers.d_results);
        buffers.d_results = nullptr;
    }

    // No storage remains, so advertise zero capacity
    buffers.d_nodes_capacity = 0;
    buffers.d_pop_capacity = 0;
}

// Optimized kernel: Process one tree per thread
// REMOVED SHARED MEMORY FOR DATA to support arbitrary dataset sizes and multivariable
//
// Thread idx interprets tree idx (nodes d_all_nodes[d_offsets[idx] .. + d_sizes[idx]))
// on every sample and writes the tree's FINAL fitness to d_results[idx]:
//   raw   = weighted MSE * num_points (weighted mode) or the plain sum of squared errors
//   raw   = sqrt(raw / num_points) when use_rmse is set
//   final = raw * (1 + size * complexity_penalty_factor)
// Invalid trees (malformed program, evaluation-stack overflow, NaN/Inf prediction,
// an accumulated error >= 1e300, or an empty dataset) get exactly GPU_MAX_DOUBLE.
__global__ void evaluate_all_populations_kernel(
    const LinearGpuNode* __restrict__ d_all_nodes,
    const int* __restrict__ d_offsets,
    const int* __restrict__ d_sizes,
    int total_trees,
    const double* __restrict__ d_targets,
    const double* __restrict__ d_x_values,
    int num_points,
    int num_vars,
    double* __restrict__ d_results,
    double complexity_penalty_factor,
    bool use_rmse)
{
    // Capacity of the per-thread evaluation stack; only leaf pushes can exceed it.
    const int kStackCapacity = 64;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < total_trees) {
        int offset = d_offsets[idx];
        int size = d_sizes[idx];
        double sum_sq_error = 0.0;
        double total_weight = 0.0;
        bool valid = (num_points > 0); // No data means no meaningful score

        // Evaluate tree on all data points
        for (int p = 0; p < num_points && valid; ++p) {
            double stack[kStackCapacity];
            int stack_top = -1;

            // Interpret the linearized tree
            for (int i = 0; i < size && valid; ++i) {
                LinearGpuNode node = d_all_nodes[offset + i];

                if (node.type == NodeType::Constant) {
                    if (stack_top + 1 >= kStackCapacity) { valid = false; break; }
                    stack[++stack_top] = node.value;
                } else if (node.type == NodeType::Variable) {
                    if (stack_top + 1 >= kStackCapacity) { valid = false; break; }
                    // Access global memory directly
                    int var_idx = node.var_index;
                    if (var_idx < 0 || var_idx >= num_vars) var_idx = 0; // Safety fallback
                    stack[++stack_top] = d_x_values[p * num_vars + var_idx];
                } else if (node.type == NodeType::Operator) {
                    bool is_unary = (node.op == 's' || node.op == 'c' || node.op == 'l' ||
                                    node.op == 'e' || node.op == '!' || node.op == '_' || node.op == 'g');

                    if (is_unary) {
                        if (stack_top < 0) { valid = false; break; }
                        double val = stack[stack_top--];
                        double result = 0.0;
                        switch (node.op) {
                            case 's': result = sin(val); break;
                            case 'c': result = cos(val); break;
                            case 'l': result = (val <= 1e-9) ? GPU_MAX_DOUBLE : log(val); break;
                            case 'e': result = (val > 700.0) ? GPU_MAX_DOUBLE : exp(val); break;
                            case '!': result = (val < 0 || val > 170.0) ? GPU_MAX_DOUBLE : tgamma(val + 1.0); break;
                            case '_': result = floor(val); break;
                            case 'g': result = (val <= -1.0) ? GPU_MAX_DOUBLE : lgamma(val + 1.0); break;
                            default: result = NAN; break;
                        }
                        stack[++stack_top] = result;
                    } else {
                        // A binary operator needs two operands
                        if (stack_top < 1) { valid = false; break; }
                        double right = stack[stack_top--];
                        double left = stack[stack_top--];
                        double result;
                        switch (node.op) {
                            case '+': result = left + right; break;
                            case '-': result = left - right; break;
                            case '*': result = left * right; break;
                            case '/': result = (fabs(right) < 1e-9) ? GPU_MAX_DOUBLE : left / right; break;
                            case '^': result = pow(left, right); break;
                            case '%': result = (fabs(right) < 1e-9) ? GPU_MAX_DOUBLE : fmod(left, right); break;
                            default: result = NAN; break;
                        }
                        stack[++stack_top] = result;
                    }
                }
            }

            // A well-formed program leaves exactly one value on the stack
            if (!valid || stack_top != 0) {
                sum_sq_error = GPU_MAX_DOUBLE;
                valid = false;
                break;
            }

            double predicted_val = stack[0];
            if (isnan(predicted_val) || isinf(predicted_val)) {
                sum_sq_error = GPU_MAX_DOUBLE;
                valid = false;
                break;
            }

            double diff = predicted_val - d_targets[p];
            double sq_error = diff * diff;

            // Weighted fitness: later samples weigh exponentially more
            double weight = 1.0;
            if (GPU_USE_WEIGHTED_FITNESS) {
                weight = exp((double)p * GPU_WEIGHTED_FITNESS_EXPONENT);
            }
            total_weight += weight;
            sum_sq_error += sq_error * weight;
        }

        // Calculate final fitness with complexity penalty ON GPU
        double raw_fitness = GPU_MAX_DOUBLE;
        if (valid && sum_sq_error < 1e300) {
            if (GPU_USE_WEIGHTED_FITNESS && total_weight > 0.0) {
                // Weighted MSE, scaled back by the sample count
                sum_sq_error = sum_sq_error / total_weight * num_points;
            }

            if (use_rmse && num_points > 0) {
                double mse = sum_sq_error / num_points;
                raw_fitness = sqrt(mse);
            } else {
                raw_fitness = sum_sq_error;
            }

            // Apply complexity penalty; the node count of the linearized tree is the complexity
            double complexity = (double)size;
            double penalty = complexity * complexity_penalty_factor;
            raw_fitness = raw_fitness * (1.0 + penalty);
        }

        d_results[idx] = raw_fitness;
    }
}

void evaluate_all_populations_gpu(
    const std::vector<LinearGpuNode>& all_nodes,
    const std::vector<int>& tree_offsets,
    const std::vector<int>& tree_sizes,
    const std::vector<int>& tree_complexities,
    int total_trees,
    const std::vector<double>& targets,
    const std::vector<std::vector<double>>& x_values,
    std::vector<double>& results,
    double* d_targets, double* d_x_values,
    GlobalGpuBuffers& buffers)
{
    if (total_trees == 0) return;
    
    cudaStream_t stream = (cudaStream_t)buffers.cuda_stream;
    size_t total_nodes = all_nodes.size();
    int num_points = x_values.size();
    int num_vars = (num_points > 0) ? x_values[0].size() : 0;

    // Dynamic buffer resizing with growth factor
    if (total_nodes > buffers.d_nodes_capacity) {
        if (buffers.d_nodes) cudaFree(buffers.d_nodes);
        size_t new_cap = total_nodes * 1.5;
        cudaMalloc(&buffers.d_nodes, new_cap * sizeof(LinearGpuNode));
        buffers.d_nodes_capacity = new_cap;
    }

    if ((size_t)total_trees > buffers.d_pop_capacity) {
        if (buffers.d_offsets) cudaFree(buffers.d_offsets);
        if (buffers.d_sizes) cudaFree(buffers.d_sizes);
        if (buffers.d_results) cudaFree(buffers.d_results);
        
        size_t new_cap = total_trees * 1.5;
        cudaMalloc(&buffers.d_offsets, new_cap * sizeof(int));
        cudaMalloc(&buffers.d_sizes, new_cap * sizeof(int));
        cudaMalloc(&buffers.d_results, new_cap * sizeof(double));
        buffers.d_pop_capacity = new_cap;
    }

    LinearGpuNode* d_all_nodes = (LinearGpuNode*)buffers.d_nodes;
    int* d_offsets = (int*)buffers.d_offsets;
    int* d_sizes = (int*)buffers.d_sizes;
    double* d_results = (double*)buffers.d_results;

    // Async memory transfers using CUDA stream
    cudaMemcpyAsync(d_all_nodes, all_nodes.data(), total_nodes * sizeof(LinearGpuNode), 
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_offsets, tree_offsets.data(), total_trees * sizeof(int), 
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_sizes, tree_sizes.data(), total_trees * sizeof(int), 
                    cudaMemcpyHostToDevice, stream);

    // Optimized kernel launch configuration for RTX 3050
    // RTX 3050 has 20 SMs, each can handle 2048 threads max
    // For 50k trees, we want maximum occupancy
    int threadsPerBlock = 256;
    int blocksPerGrid = (total_trees + threadsPerBlock - 1) / threadsPerBlock;

    // Launch kernel on stream
    evaluate_all_populations_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
        d_all_nodes, d_offsets, d_sizes, total_trees,
        d_targets, d_x_values, num_points, num_vars, d_results,
        COMPLEXITY_PENALTY_FACTOR, USE_RMSE_FITNESS
    );

    // Async copy results back
    cudaMemcpyAsync(results.data(), d_results, total_trees * sizeof(double), 
                    cudaMemcpyDeviceToHost, stream);
    
    // Synchronize stream
    cudaStreamSynchronize(stream);
}

// ============================================================
// DOUBLE-BUFFERED GPU IMPLEMENTATION
// ============================================================

// Sets up the double-buffered GPU state: two CUDA streams, two complete sets of
// device arrays (one per buffer), and one pinned host array for the results.
// Buffer 0 is selected as the current buffer.
void init_double_buffered_gpu(DoubleBufferedGpu& db) {
    // Starting sizes: ~50k trees of ~30 nodes on average, with slack on the tree count
    const size_t node_slots = 1500000;
    const size_t tree_slots = 60000;

    // One stream per buffer so work on the two buffers can overlap
    for (int s = 0; s < 2; ++s) {
        cudaStreamCreate((cudaStream_t*)&db.streams[s]);
    }

    // Identical device allocations for both buffers
    for (int slot = 0; slot < 2; ++slot) {
        cudaMalloc(&db.d_nodes[slot], node_slots * sizeof(LinearGpuNode));
        cudaMalloc(&db.d_offsets[slot], tree_slots * sizeof(int));
        cudaMalloc(&db.d_sizes[slot], tree_slots * sizeof(int));
        cudaMalloc(&db.d_results[slot], tree_slots * sizeof(double));

        db.d_nodes_capacity[slot] = node_slots;
        db.d_pop_capacity[slot] = tree_slots;
    }

    // Pinned (page-locked) host memory for the result transfers
    cudaMallocHost(&db.h_pinned_results, tree_slots * sizeof(double));
    db.h_pinned_capacity = tree_slots;

    db.current_buffer = 0;
}

// Releases everything init_double_buffered_gpu created: both streams, both sets of
// device arrays and the pinned host array. Pointers are nulled and, matching
// cleanup_global_gpu_buffers, all capacities are reset to 0. Without that reset the
// struct would still advertise its old capacities, and launch_evaluation_async, which
// reallocates only when a batch exceeds the recorded capacity, would copy into the
// freed (null) buffers.
void cleanup_double_buffered_gpu(DoubleBufferedGpu& db) {
    for (int i = 0; i < 2; ++i) {
        if (db.streams[i]) {
            cudaStreamDestroy((cudaStream_t)db.streams[i]);
            db.streams[i] = nullptr;
        }
        if (db.d_nodes[i]) { cudaFree(db.d_nodes[i]); db.d_nodes[i] = nullptr; }
        if (db.d_offsets[i]) { cudaFree(db.d_offsets[i]); db.d_offsets[i] = nullptr; }
        if (db.d_sizes[i]) { cudaFree(db.d_sizes[i]); db.d_sizes[i] = nullptr; }
        if (db.d_results[i]) { cudaFree(db.d_results[i]); db.d_results[i] = nullptr; }

        // No device storage remains for this buffer
        db.d_nodes_capacity[i] = 0;
        db.d_pop_capacity[i] = 0;
    }

    if (db.h_pinned_results) {
        cudaFreeHost(db.h_pinned_results);
        db.h_pinned_results = nullptr;
    }
    // No pinned storage remains either
    db.h_pinned_capacity = 0;
}

// Uploads a batch of linearised trees and enqueues their evaluation on the
// stream of the currently selected buffer set, followed by an asynchronous
// copy of the results into db.h_pinned_results. Does NOT synchronise; call
// retrieve_results_sync() to wait for and collect the results.
//
// @param all_nodes     Concatenated nodes of all trees.
// @param tree_offsets  Start index of each tree inside all_nodes.
// @param tree_sizes    Node count of each tree.
// @param total_trees   Number of trees to evaluate; values <= 0 are a no-op.
// @param d_targets     Device pointer forwarded to the kernel.
// @param d_x_values    Device pointer forwarded to the kernel.
// @param num_points    Forwarded to the kernel.
// @param num_vars      Forwarded to the kernel.
// @param db            Double-buffer state; buffers grow on demand (x1.5).
//
// Failure handling: if the host vectors are shorter than total_trees or a
// device (re)allocation fails, nothing is enqueued and the first total_trees
// entries of the pinned result buffer are set to INF so the caller sees
// "invalid" fitness values rather than stale data from a previous batch. If
// the pinned buffer itself cannot be allocated the function returns without
// doing anything.
//
// NOTE(review): both buffer sets share the single pinned result buffer, so two
// launches must not be in flight at the same time.
void launch_evaluation_async(
    const std::vector<LinearGpuNode>& all_nodes,
    const std::vector<int>& tree_offsets,
    const std::vector<int>& tree_sizes,
    int total_trees,
    double* d_targets, double* d_x_values,
    int num_points,
    int num_vars,
    DoubleBufferedGpu& db)
{
    // Also rejects negative counts, which would wrap around as size_t below.
    if (total_trees <= 0) return;

    const size_t tree_count = static_cast<size_t>(total_trees);
    const size_t total_nodes = all_nodes.size();
    const int buf = db.current_buffer;
    cudaStream_t stream = (cudaStream_t)db.streams[buf];

    // Allocates `bytes` of device memory into `slot`; on failure the slot is
    // nulled and the (non-sticky) CUDA error state is cleared.
    auto alloc_device = [](void*& slot, size_t bytes) -> bool {
        void* fresh = nullptr;
        if (cudaMalloc(&fresh, bytes) != cudaSuccess) {
            (void)cudaGetLastError();
            slot = nullptr;
            return false;
        }
        slot = fresh;
        return true;
    };

    // Ensure pinned results buffer is large enough. Done first so that the
    // failure paths below have somewhere to report to.
    if (tree_count > db.h_pinned_capacity) {
        if (db.h_pinned_results) cudaFreeHost(db.h_pinned_results);
        db.h_pinned_results = nullptr;
        db.h_pinned_capacity = 0;

        const size_t new_cap = tree_count + tree_count / 2;
        void* pinned = nullptr;
        if (cudaMallocHost(&pinned, new_cap * sizeof(double)) == cudaSuccess) {
            db.h_pinned_results = pinned;
            db.h_pinned_capacity = new_cap;
        } else {
            (void)cudaGetLastError();
        }
    }

    double* h_results = (double*)db.h_pinned_results;
    if (h_results == nullptr) return;  // No place to deliver results to.

    // Marks the whole batch as invalid for retrieve_results_sync().
    auto report_failure = [&]() {
        for (size_t k = 0; k < tree_count; ++k) h_results[k] = INF;
    };

    // The copies below read total_trees entries from these vectors.
    if (tree_offsets.size() < tree_count || tree_sizes.size() < tree_count) {
        report_failure();
        return;
    }

    // Ensure the node buffer is large enough.
    if (total_nodes > db.d_nodes_capacity[buf]) {
        cudaFree(db.d_nodes[buf]);
        db.d_nodes_capacity[buf] = 0;

        const size_t new_cap = total_nodes + total_nodes / 2;
        if (!alloc_device(db.d_nodes[buf], new_cap * sizeof(LinearGpuNode))) {
            report_failure();
            return;
        }
        db.d_nodes_capacity[buf] = new_cap;
    }

    // Ensure the per-tree buffers are large enough (they share one capacity).
    if (tree_count > db.d_pop_capacity[buf]) {
        cudaFree(db.d_offsets[buf]);  db.d_offsets[buf] = nullptr;
        cudaFree(db.d_sizes[buf]);    db.d_sizes[buf] = nullptr;
        cudaFree(db.d_results[buf]);  db.d_results[buf] = nullptr;
        db.d_pop_capacity[buf] = 0;

        const size_t new_cap = tree_count + tree_count / 2;
        const bool offsets_ok = alloc_device(db.d_offsets[buf], new_cap * sizeof(int));
        const bool sizes_ok = alloc_device(db.d_sizes[buf], new_cap * sizeof(int));
        const bool results_ok = alloc_device(db.d_results[buf], new_cap * sizeof(double));
        if (!(offsets_ok && sizes_ok && results_ok)) {
            // cudaFree(nullptr) is a no-op, so freeing all three is safe.
            cudaFree(db.d_offsets[buf]);  db.d_offsets[buf] = nullptr;
            cudaFree(db.d_sizes[buf]);    db.d_sizes[buf] = nullptr;
            cudaFree(db.d_results[buf]);  db.d_results[buf] = nullptr;
            report_failure();
            return;
        }
        db.d_pop_capacity[buf] = new_cap;
    }

    LinearGpuNode* d_nodes = (LinearGpuNode*)db.d_nodes[buf];
    int* d_offsets = (int*)db.d_offsets[buf];
    int* d_sizes = (int*)db.d_sizes[buf];
    double* d_results = (double*)db.d_results[buf];

    // Never hand a null buffer to the kernel.
    if (d_offsets == nullptr || d_sizes == nullptr || d_results == nullptr ||
        (total_nodes > 0 && d_nodes == nullptr)) {
        report_failure();
        return;
    }

    // Async transfers
    cudaMemcpyAsync(d_nodes, all_nodes.data(), total_nodes * sizeof(LinearGpuNode),
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_offsets, tree_offsets.data(), tree_count * sizeof(int),
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_sizes, tree_sizes.data(), tree_count * sizeof(int),
                    cudaMemcpyHostToDevice, stream);

    // Launch kernel: one thread per tree.
    int threadsPerBlock = 256;
    int blocksPerGrid = (total_trees + threadsPerBlock - 1) / threadsPerBlock;

    evaluate_all_populations_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
        d_nodes, d_offsets, d_sizes, total_trees,
        d_targets, d_x_values, num_points, num_vars, d_results,
        COMPLEXITY_PENALTY_FACTOR, USE_RMSE_FITNESS
    );

    // Async copy results to pinned memory
    cudaMemcpyAsync(h_results, d_results, tree_count * sizeof(double),
                    cudaMemcpyDeviceToHost, stream);

    // DO NOT SYNC HERE - let CPU do other work
}

// Waits for the work enqueued by launch_evaluation_async() on the current
// buffer's stream, copies the results out of the pinned host buffer and then
// switches db.current_buffer to the other buffer set.
//
// @param results      Resized to total_trees and filled with the results.
// @param total_trees  Number of results to fetch; values <= 0 yield an empty
//                     vector.
// @param db           Double-buffer state used by the matching launch.
//
// If the stream reports an error, or the pinned buffer is missing or smaller
// than total_trees, every entry of `results` is set to INF instead of copying
// stale or out-of-bounds memory.
void retrieve_results_sync(
    std::vector<double>& results,
    int total_trees,
    DoubleBufferedGpu& db)
{
    const int buf = db.current_buffer;
    cudaStream_t stream = (cudaStream_t)db.streams[buf];

    // Wait for this stream to complete; the status also reflects failures of
    // the asynchronous work that was queued on it.
    const cudaError_t sync_status = cudaStreamSynchronize(stream);

    const size_t tree_count = (total_trees > 0) ? static_cast<size_t>(total_trees) : 0;
    const bool results_usable = sync_status == cudaSuccess &&
                                db.h_pinned_results != nullptr &&
                                tree_count <= db.h_pinned_capacity;

    if (results_usable) {
        // Copy from pinned memory to results vector (plain host-to-host copy).
        results.resize(tree_count);
        if (tree_count > 0) {
            memcpy(results.data(), db.h_pinned_results, tree_count * sizeof(double));
        }
    } else {
        results.assign(tree_count, INF);
    }

    // Switch to other buffer for next generation
    db.current_buffer = 1 - buf;
}

#endif // USE_GPU_ACCELERATION_DEFINED_BY_CMAKE


In [None]:
%%writefile Code/src/FitnessGPU.cuh
#ifndef FITNESS_GPU_CUH
#define FITNESS_GPU_CUH

#include <vector>
#include <memory> // For NodePtr in the host-side wrapper
#include "ExpressionTree.h" // For NodeType enum and original Node structure (host-side)
#include "Globals.h" // For INF, USE_RMSE_FITNESS, COMPLEXITY_PENALTY_FACTOR etc.

// Forward declaration for host-side NodePtr
struct Node;
using NodePtr = std::shared_ptr<Node>;

// A simplified node structure for the linearized tree on the GPU.
// Arrays of this struct are copied to the device as raw bytes
// (count * sizeof(LinearGpuNode)), so the member order and types below are part
// of the host/device contract and must not be changed on one side only.
struct LinearGpuNode {
    NodeType type;   // Node kind (NodeType comes from ExpressionTree.h).
    double value;    // presumably the numeric payload of constant nodes -- TODO confirm
    int var_index;   // presumably the input-variable index of variable nodes -- TODO confirm
    char op;         // presumably the operator symbol of operator nodes -- TODO confirm
};

// Helper function to linearize the tree into a post-order array.
// Callers in GeneticAlgorithm.cpp record linear_tree.size() before and after
// the call and treat the difference as the tree's node count, i.e. they rely on
// the nodes being appended to whatever linear_tree already contains.
void linearize_tree(const NodePtr& node, std::vector<LinearGpuNode>& linear_tree);

// Host-side wrapper for launching CUDA kernel
#if USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
// Evaluates a single tree on the GPU.
// d_targets / d_x_values are the device-side copies of targets / x_values
// (GeneticAlgorithm allocates them with cudaMalloc in its constructor).
// NOTE(review): the meaning of the returned value (raw error vs. penalised
// fitness) is defined by the implementation, which is not visible here.
double evaluate_fitness_gpu(NodePtr tree,
                            const std::vector<double>& targets,
                            const std::vector<std::vector<double>>& x_values,
                            double* d_targets, double* d_x_values);

// Batch evaluation function with persistent buffers.
// all_nodes holds the concatenated linearized trees; tree_offsets[k] and
// tree_sizes[k] locate tree k inside it, and results receives one value per
// tree. GeneticAlgorithm::evaluate_population treats each result as a sum of
// squared errors and applies RMSE / complexity penalty on the host.
// The void*& / size_t& parameters are caller-owned buffer handles and
// capacities passed by reference -- presumably so the implementation can grow
// the buffers and report the new pointer/capacity back; TODO confirm.
void evaluate_population_gpu(const std::vector<LinearGpuNode>& all_nodes,
                             const std::vector<int>& tree_offsets,
                             const std::vector<int>& tree_sizes,
                             const std::vector<double>& targets,
                             const std::vector<std::vector<double>>& x_values,
                             std::vector<double>& results,
                             double* d_targets, double* d_x_values,
                             void*& d_nodes_ptr, size_t& d_nodes_cap,
                             void*& d_offsets_ptr, void*& d_sizes_ptr, void*& d_results_ptr, size_t& d_pop_cap);

// ============================================================
// GLOBAL BATCH EVALUATION - Evaluates ALL islands in ONE kernel call
// ============================================================
// Persistent GPU buffers for global batch (managed by GeneticAlgorithm).
// Set up by init_global_gpu_buffers() and released by
// cleanup_global_gpu_buffers(). Handles are kept as void* so this header does
// not need the CUDA headers.
// NOTE(review): element types and capacity units presumably mirror
// DoubleBufferedGpu (nodes / ints / doubles, counted in elements) -- confirm
// against the implementation.
struct GlobalGpuBuffers {
    void* d_nodes = nullptr;      // device buffer for the linearized nodes
    void* d_offsets = nullptr;    // device buffer for per-tree offsets
    void* d_sizes = nullptr;      // device buffer for per-tree sizes
    void* d_results = nullptr;    // device buffer for per-tree results
    size_t d_nodes_capacity = 0;  // capacity recorded for d_nodes
    size_t d_pop_capacity = 0;    // capacity recorded for the per-tree buffers
    void* cuda_stream = nullptr; // cudaStream_t
};

// ============================================================
// DOUBLE-BUFFERED GPU EVALUATION - Maximum overlap of CPU/GPU work
// ============================================================
// State for the ping-pong evaluation path. Index [i] of each array belongs to
// buffer set i. launch_evaluation_async() works on the set selected by
// current_buffer and retrieve_results_sync() flips current_buffer afterwards.
// All handles are stored as void* and cast to their CUDA types by the
// implementation.
struct DoubleBufferedGpu {
    // Two sets of device buffers for ping-pong operation
    void* d_nodes[2] = {nullptr, nullptr};    // used as LinearGpuNode*
    void* d_offsets[2] = {nullptr, nullptr};  // used as int*
    void* d_sizes[2] = {nullptr, nullptr};    // used as int*
    void* d_results[2] = {nullptr, nullptr};  // used as double*
    size_t d_nodes_capacity[2] = {0, 0};      // capacity of d_nodes[i], in nodes
    size_t d_pop_capacity[2] = {0, 0};        // shared capacity of offsets/sizes/results, in trees
    
    // Two streams for overlapped execution
    void* streams[2] = {nullptr, nullptr};    // used as cudaStream_t
    
    // Current buffer index (0 or 1)
    int current_buffer = 0;
    
    // Host-side pinned memory for faster transfers.
    // NOTE: there is only ONE pinned result buffer; the device-to-host copy of
    // either buffer set writes into it, so results of one launch must be
    // retrieved before the next launch completes.
    void* h_pinned_results = nullptr;         // used as double*, allocated with cudaMallocHost
    size_t h_pinned_capacity = 0;             // capacity of h_pinned_results, in doubles
};

// Initialize double-buffered GPU resources: creates both streams and
// pre-allocates both device buffer sets plus the pinned host result buffer.
void init_double_buffered_gpu(DoubleBufferedGpu& db);

// Cleanup double-buffered GPU resources (streams, device buffers, pinned
// buffer). Null handles are skipped.
void cleanup_double_buffered_gpu(DoubleBufferedGpu& db);

// Async launch - starts GPU work without waiting (CPU can do other work).
// Uploads the first total_trees entries of tree_offsets / tree_sizes and all of
// all_nodes, runs the evaluation kernel on the current buffer's stream and
// queues the copy of the results into db.h_pinned_results.
// d_targets / d_x_values are device pointers; num_points and num_vars are
// forwarded to the kernel. Returns immediately when total_trees is 0.
void launch_evaluation_async(
    const std::vector<LinearGpuNode>& all_nodes,
    const std::vector<int>& tree_offsets,
    const std::vector<int>& tree_sizes,
    int total_trees,
    double* d_targets, double* d_x_values,
    int num_points,
    int num_vars,
    DoubleBufferedGpu& db);

// Wait for GPU work to complete and retrieve results.
// Synchronizes the current buffer's stream, copies total_trees doubles from the
// pinned buffer into results, then switches db.current_buffer to the other set.
// Must be paired with the preceding launch_evaluation_async() call.
void retrieve_results_sync(
    std::vector<double>& results,
    int total_trees,
    DoubleBufferedGpu& db);

// Initialize global GPU buffers and CUDA stream
void init_global_gpu_buffers(GlobalGpuBuffers& buffers);

// Cleanup global GPU buffers
void cleanup_global_gpu_buffers(GlobalGpuBuffers& buffers);

// Evaluate ALL trees from ALL islands in a single GPU batch call (maximum GPU utilization).
// all_nodes is the concatenation of every linearized tree; tree_offsets[k] /
// tree_sizes[k] locate tree k inside it and results receives one value per tree.
// d_targets / d_x_values are the device-side copies of targets / x_values.
// NOTE(review): whether the values written to results already include the
// complexity penalty is decided by the implementation -- confirm before
// post-processing them.
void evaluate_all_populations_gpu(
    const std::vector<LinearGpuNode>& all_nodes,
    const std::vector<int>& tree_offsets,
    const std::vector<int>& tree_sizes,
    const std::vector<int>& tree_complexities, // For complexity penalty
    int total_trees,
    const std::vector<double>& targets,
    const std::vector<std::vector<double>>& x_values,
    std::vector<double>& results,
    double* d_targets, double* d_x_values,
    GlobalGpuBuffers& buffers);
#endif

#endif // FITNESS_GPU_CUH


In [None]:
%%writefile Code/src/GeneticAlgorithm.cpp
#include "GeneticAlgorithm.h"
#include "Globals.h"
#include "Fitness.h"
#include "AdvancedFeatures.h" // Incluir este para DomainConstraints::
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
#include "FitnessGPU.cuh"     // Para funciones de GPU
#include <cuda_runtime.h>     // Para CUDA runtime
#endif
#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
#include <omp.h>
#include <iomanip>
#include <iterator>
#include <chrono>
#include <unordered_set>

// --- Constructor ---
// Sets up a complete run:
//   1. normalises the island count / per-island population,
//   2. creates the islands,
//   3. overwrites population slots with the parsed seed formulas,
//   4. (GPU builds) uploads targets and x_values to the device and initialises
//      the persistent GPU buffers, falling back to the CPU on any failure,
//   5. evaluates (and simplifies) the whole initial population and records the
//      initial best individual.
//
// @param targets_ref   Target values, one per sample.
// @param x_values_ref  Input samples; each row holds the feature values of one sample.
// @param total_pop     Requested total population (rounded to a multiple of the island count).
// @param gens          Number of generations to run.
// @param seeds         Formula strings to inject; unparsable ones are skipped with a warning.
// @param n_islands     Requested number of islands (values <= 0 become 1).
//
// Exceptions thrown while creating an Island are logged and re-thrown.
GeneticAlgorithm::GeneticAlgorithm(const std::vector<double>& targets_ref,
                                     const std::vector<std::vector<double>>& x_values_ref,
                                     int total_pop,
                                     int gens,
                                     const std::vector<std::string>& seeds,
                                     int n_islands)
    : targets(targets_ref),
      x_values(x_values_ref),
      total_population_size(total_pop),
      generations(gens),
      num_islands(n_islands),
      overall_best_fitness(INF),
      last_overall_best_fitness(INF),
      generation_last_improvement(0)
{
    // Validate and adjust the number of islands and the population per island.
    if (this->num_islands <= 0) this->num_islands = 1;
    pop_per_island = this->total_population_size / this->num_islands;
    if (pop_per_island < MIN_POP_PER_ISLAND) {
        pop_per_island = MIN_POP_PER_ISLAND;
        this->num_islands = this->total_population_size / pop_per_island;
        if (this->num_islands == 0) this->num_islands = 1;
        std::cerr << "Warning: Adjusted number of islands to " << this->num_islands
                  << " for minimum population size per island (" << pop_per_island <<")." << std::endl;
    }
    this->total_population_size = this->num_islands * pop_per_island;
    std::cout << "Info: Running with " << this->num_islands << " islands, "
              << pop_per_island << " individuals per island." << std::endl;

    // Create the islands.
    islands.reserve(this->num_islands);
    for (int i = 0; i < this->num_islands; ++i) {
        try {
            islands.push_back(std::make_unique<Island>(i, pop_per_island));
        }
        catch (const std::exception& e) { std::cerr << "[ERROR] Creating Island " << i << ": " << e.what() << std::endl; throw; }
        catch (...) { std::cerr << "[ERROR] Unknown exception creating island " << i << std::endl; throw; }
    }

    // --- INJECT SEEDS ---
    if (!seeds.empty()) {
        std::cout << "Info: Injecting " << seeds.size() << " seed formulas into population..." << std::endl;
        size_t seed_idx = 0;

        // Seeds are written into consecutive population slots: island 0 is
        // filled first and the remaining seeds spill over into the following
        // islands. A seed that fails to parse (or parses to null) still
        // consumes its slot, which then keeps its original individual.
        for (int i = 0; i < this->num_islands && seed_idx < seeds.size(); ++i) {
            auto& population = islands[i]->population;
            for (size_t j = 0; j < population.size() && seed_idx < seeds.size(); ++j) {
                const std::string& seed = seeds[seed_idx];
                ++seed_idx; // Every seed is attempted exactly once.

                try {
                    NodePtr parsed_tree = parse_formula_string(seed);
                    if (parsed_tree) {
                        population[j].tree = std::move(parsed_tree);
                    }
                } catch (const std::exception& e) {
                    std::cerr << "[Warning] Failed to parse seed formula: " << seed << " | Error: " << e.what() << std::endl;
                }
            }
        }
    }

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    // evaluate_all_islands() and the destructor test these pointers, so put
    // them in a known state before anything can fail or be skipped.
    d_targets = nullptr;
    d_x_values = nullptr;

    bool gpu_init_failed = false;
    if (!FORCE_CPU_MODE) {
        // Frees whatever part of the dataset made it to the device.
        auto release_dataset_buffers = [this]() {
            if (d_targets) { cudaFree(d_targets); d_targets = nullptr; }
            if (d_x_values) { cudaFree(d_x_values); d_x_values = nullptr; }
        };

        const size_t targets_size = targets.size() * sizeof(double);
        // Multivariable: flatten X values [NUM_SAMPLES * NUM_FEATURES]
        const size_t n_samples = x_values.size();
        const size_t n_features = (n_samples > 0) ? x_values[0].size() : 0;
        const size_t x_values_size = n_samples * n_features * sizeof(double);

        // Linearize (row-major: all features of sample 0, then sample 1, ...).
        std::vector<double> flattened_x;
        flattened_x.reserve(n_samples * n_features);
        for (const auto& row : x_values) {
            flattened_x.insert(flattened_x.end(), row.begin(), row.end());
        }

        if (flattened_x.size() != n_samples * n_features) {
            // Rows of different length: the upload below would read past the
            // end of flattened_x (or upload a mis-strided matrix).
            std::cerr << "[WARNING] x_values rows have inconsistent lengths; cannot upload dataset to GPU." << std::endl;
            std::cerr << "[INFO] Falling back to CPU mode." << std::endl;
            gpu_init_failed = true;
        } else {
            cudaError_t err_t = cudaMalloc(&d_targets, targets_size);
            cudaError_t err_x = cudaMalloc(&d_x_values, x_values_size);

            if (err_t != cudaSuccess || err_x != cudaSuccess) {
                std::cerr << "[WARNING] CUDA memory allocation failed: "
                          << cudaGetErrorString(err_t) << " | " << cudaGetErrorString(err_x) << std::endl;
                std::cerr << "[INFO] Falling back to CPU mode." << std::endl;
                gpu_init_failed = true;
                // Clean up any partial allocation; a failed call owns nothing.
                if (err_t != cudaSuccess) d_targets = nullptr;
                if (err_x != cudaSuccess) d_x_values = nullptr;
                release_dataset_buffers();
            } else {
                cudaError_t copy_t = cudaMemcpy(d_targets, targets.data(), targets_size, cudaMemcpyHostToDevice);
                cudaError_t copy_x = cudaMemcpy(d_x_values, flattened_x.data(), x_values_size, cudaMemcpyHostToDevice);

                if (copy_t != cudaSuccess || copy_x != cudaSuccess) {
                    // Without the dataset on the device every GPU result would be garbage.
                    std::cerr << "[WARNING] CUDA host-to-device copy failed: "
                              << cudaGetErrorString(copy_t) << " | " << cudaGetErrorString(copy_x) << std::endl;
                    std::cerr << "[INFO] Falling back to CPU mode." << std::endl;
                    gpu_init_failed = true;
                    release_dataset_buffers();
                } else {
                    // Initialize global GPU buffers for batch evaluation of ALL islands
                    init_global_gpu_buffers(global_gpu_buffers);

                    // Initialize double-buffered GPU for async pipelining
                    init_double_buffered_gpu(double_buffer_gpu);

                    std::cout << "GPU buffers initialized for global batch evaluation (max "
                              << total_population_size << " trees in single kernel call)" << std::endl;
                    std::cout << "Double-buffered GPU enabled for async CPU/GPU overlap" << std::endl;
                }
            }
        }
    }

    if (FORCE_CPU_MODE || gpu_init_failed) {
        std::cout << "Using CPU for all evaluations" << std::endl;
    }
#endif

    // Initial evaluation of the WHOLE population (including injected seeds).
    // evaluate_all_islands() simplifies and evaluates every individual.
    std::cout << "Evaluating initial population (simplifying all)..." << std::endl;
    evaluate_all_islands(); // Use new global batch evaluation

    // Find the initial global best (serial scan).
    overall_best_fitness = INF;
    overall_best_tree = nullptr;
    int initial_best_island = -1;
    int initial_best_idx = -1;

    for (size_t i = 0; i < islands.size(); ++i) {
        const auto& population = islands[i]->population;
        for (size_t j = 0; j < population.size(); ++j) {
            const auto& ind = population[j];
            if (ind.tree && ind.fitness_valid && ind.fitness < overall_best_fitness) {
                overall_best_fitness = ind.fitness;
                initial_best_island = static_cast<int>(i);
                initial_best_idx = static_cast<int>(j);
            }
        }
    }
    if (initial_best_island != -1 && initial_best_idx != -1) {
        overall_best_tree = clone_tree(islands[initial_best_island]->population[initial_best_idx].tree);
    }

    last_overall_best_fitness = overall_best_fitness;
    generation_last_improvement = 0;
    std::cout << "Initial best fitness: " << std::scientific << overall_best_fitness << std::fixed << std::endl;
    if (overall_best_tree) {
        std::cout << "Initial best formula size: " << tree_size(overall_best_tree) << std::endl;
        std::cout << "Initial best formula: " << tree_to_string(overall_best_tree) << std::endl;
        // Tell the user whether the best individual sits in slot 0 of its
        // island, which is reported as the (now simplified) injected formula
        // when USE_INITIAL_FORMULA is set.
        if (USE_INITIAL_FORMULA && initial_best_island != -1 && initial_best_idx == 0) {
            std::cout << "   (Note: Initial best is the (simplified) injected formula from Island " << initial_best_island << ")" << std::endl;
        } else if (initial_best_island != -1) {
            std::cout << "   (Note: Initial best found in Island " << initial_best_island << ", Index " << initial_best_idx << ")" << std::endl;
        }
    } else { std::cout << "No valid initial solution found (all fitness INF?)." << std::endl; }
    std::cout << "----------------------------------------" << std::endl;
}

// Releases the GPU resources acquired by the constructor.
// Host-side members need no manual cleanup: `islands` stores std::unique_ptr
// objects and NodePtr is a std::shared_ptr<Node> (see FitnessGPU.cuh), so the
// islands and `overall_best_tree` are destroyed automatically.
GeneticAlgorithm::~GeneticAlgorithm() {
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    // The constructor only touches the device when FORCE_CPU_MODE is false,
    // so there is nothing to release otherwise.
    if (FORCE_CPU_MODE) {
        return;
    }

    // Double-buffered evaluation resources first, then the global batch buffers.
    cleanup_double_buffered_gpu(double_buffer_gpu);
    cleanup_global_gpu_buffers(global_gpu_buffers);

    // Device copies of the dataset.
    if (d_targets != nullptr) {
        cudaFree(d_targets);
        d_targets = nullptr;
    }
    if (d_x_values != nullptr) {
        cudaFree(d_x_values);
        d_x_values = nullptr;
    }
#endif
}

// Simplifies and evaluates every individual of a single island.
//
// After the call each individual has fitness_valid == true; individuals
// without a tree (or whose tree linearizes to zero nodes) get fitness INF.
//
// GPU builds evaluate the island as one batch through evaluate_population_gpu()
// using the island's persistent device buffers. Like evaluate_all_islands(),
// the GPU is only used when FORCE_CPU_MODE is false and the device dataset
// exists (d_targets != nullptr); otherwise the CPU evaluator is used, so a
// failed GPU initialisation never results in a launch with null device pointers.
//
// @param island  Island whose population is simplified and evaluated in place.
void GeneticAlgorithm::evaluate_population(Island& island) {
    int pop_size = static_cast<int>(island.population.size());
    if (pop_size == 0) return;

    // 1. Simplify trees (CPU Parallel)
    // We do this first so we only send simplified trees to GPU
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < pop_size; ++i) {
        Individual& ind = island.population[i];
        if (ind.tree) {
            ind.tree = DomainConstraints::fix_or_simplify(ind.tree);
        }
    }

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    // Runtime guard (same condition as evaluate_all_islands): without a device
    // dataset the batch call below would dereference null device pointers.
    if (FORCE_CPU_MODE || d_targets == nullptr) {
        #pragma omp parallel for schedule(dynamic)
        for (int i = 0; i < pop_size; ++i) {
            Individual& ind = island.population[i];
            if (ind.tree) {
                // Pass GPU pointers (even if null) to match the GPU-build signature
                ind.fitness = evaluate_fitness(ind.tree, targets, x_values, d_targets, d_x_values);
            } else {
                ind.fitness = INF;
            }
            ind.fitness_valid = true;
        }
        return;
    }

    // 2. Prepare for Batch GPU Evaluation
    std::vector<LinearGpuNode> all_nodes;
    std::vector<int> tree_offsets;
    std::vector<int> tree_sizes;

    // Reserve memory to avoid reallocations (Optimization)
    // Assuming average tree size is around 20-30 nodes.
    all_nodes.reserve(static_cast<size_t>(pop_size) * 30);
    tree_offsets.reserve(pop_size);
    tree_sizes.reserve(pop_size);

    // Maps batch position -> population index, because null/empty trees are
    // left out of the batch.
    std::vector<int> valid_indices;
    valid_indices.reserve(pop_size);

    for (int i = 0; i < pop_size; ++i) {
        Individual& ind = island.population[i];
        int size = 0;
        int start_offset = static_cast<int>(all_nodes.size());
        if (ind.tree) {
            linearize_tree(ind.tree, all_nodes);
            size = static_cast<int>(all_nodes.size()) - start_offset;
        }

        if (size > 0) {
            tree_offsets.push_back(start_offset);
            tree_sizes.push_back(size);
            valid_indices.push_back(i);
        } else {
            // Nothing to evaluate for this individual.
            ind.fitness = INF;
            ind.fitness_valid = true;
        }
    }

    if (valid_indices.empty()) return;

    // 3. call GPU Batch (d_targets and d_x_values already exist)
    std::vector<double> raw_results(valid_indices.size());
    evaluate_population_gpu(all_nodes, tree_offsets, tree_sizes, targets, x_values, raw_results, d_targets, d_x_values,
                            island.d_nodes, island.d_nodes_capacity,
                            island.d_offsets, island.d_sizes, island.d_results, island.d_pop_capacity);

    // 4. Process results: each raw result is treated as a sum of squared errors.
    for (size_t k = 0; k < valid_indices.size(); ++k) {
        int idx = valid_indices[k];
        double sum_sq_error = raw_results[k];
        double raw_fitness = INF;

        // Check for validity
        if (!std::isnan(sum_sq_error) && !std::isinf(sum_sq_error) && sum_sq_error < 1e300) { // 1e300 as safety threshold
            if (USE_RMSE_FITNESS) {
                if (x_values.size() > 0) {
                    double mse = sum_sq_error / x_values.size();
                    raw_fitness = sqrt(mse);
                }
            } else {
                raw_fitness = sum_sq_error;
            }
        }

        if (raw_fitness >= INF/2) {
            island.population[idx].fitness = INF;
        } else {
            // Complexity Penalty
            double complexity = static_cast<double>(tree_sizes[k]); // We already have the linear size
            double penalty = complexity * COMPLEXITY_PENALTY_FACTOR;
            island.population[idx].fitness = raw_fitness * (1.0 + penalty);
        }
        island.population[idx].fitness_valid = true;
    }

#else
    // CPU Fallback (Parallel)
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < pop_size; ++i) {
        Individual& ind = island.population[i];
        if (ind.tree) {
            ind.fitness = evaluate_fitness(ind.tree, targets, x_values);
            ind.fitness_valid = true;
        } else {
            ind.fitness = INF;
            ind.fitness_valid = true;
        }
    }
#endif
}


// ============================================================
// GLOBAL BATCH EVALUATION - Evaluates ALL islands in ONE GPU kernel call
// ============================================================
// Simplifies and evaluates every individual of every island.
//
// After the call each individual has fitness_valid == true; individuals
// without a tree (or whose tree linearizes to zero nodes) get fitness INF.
//
// GPU builds with a usable device dataset (FORCE_CPU_MODE false and
// d_targets != nullptr) linearize all trees in parallel, merge them into one
// batch and evaluate it through launch_evaluation_async() /
// retrieve_results_sync(); the returned values are stored as the final fitness
// (only NaN / inf / >= 1e300 are mapped to INF). Every other configuration
// uses the parallel CPU evaluator.
void GeneticAlgorithm::evaluate_all_islands() {
    int total_trees = 0;
    for (const auto& island : islands) {
        total_trees += static_cast<int>(island->population.size());
    }
    if (total_trees == 0) return;

    // Step 1: Simplify ALL trees (CPU).
    // Only the outer loop over islands is parallelised (collapse(2) is not
    // supported by MSVC OpenMP 2.0); each thread processes the whole
    // population of the island it picked up.
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
        for (int j = 0; j < static_cast<int>(islands[i]->population.size()); ++j) {
            Individual& ind = islands[i]->population[j];
            if (ind.tree) {
                ind.tree = DomainConstraints::fix_or_simplify(ind.tree);
            }
        }
    }

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    // Runtime check: if FORCE_CPU_MODE is true or GPU init failed (d_targets == nullptr), use CPU
    if (!FORCE_CPU_MODE && d_targets != nullptr) {
        // Step 2: Linearize ALL trees from ALL islands into single buffer
        // OPTIMIZATION: Parallel linearization using OpenMP

        // Flat index -> (island, individual), so one loop can cover everything.
        std::vector<std::pair<int, int>> index_mapping(total_trees);
        int tree_idx = 0;
        for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
            for (int j = 0; j < static_cast<int>(islands[i]->population.size()); ++j) {
                index_mapping[tree_idx] = {i, j};
                tree_idx++;
            }
        }

        // Parallel linearization into per-thread buffers
        int num_threads = omp_get_max_threads();
        std::vector<std::vector<LinearGpuNode>> thread_nodes(num_threads);
        std::vector<std::vector<int>> thread_offsets(num_threads);
        std::vector<std::vector<int>> thread_sizes(num_threads);
        std::vector<std::vector<std::pair<int, int>>> thread_mappings(num_threads);

        // Pre-allocate per-thread buffers (30 nodes per tree is a size guess).
        int trees_per_thread = (total_trees + num_threads - 1) / num_threads;
        for (int t = 0; t < num_threads; ++t) {
            thread_nodes[t].reserve(static_cast<size_t>(trees_per_thread) * 30);
            thread_offsets[t].reserve(trees_per_thread);
            thread_sizes[t].reserve(trees_per_thread);
            thread_mappings[t].reserve(trees_per_thread);
        }

        #pragma omp parallel
        {
            int tid = omp_get_thread_num();
            auto& local_nodes = thread_nodes[tid];
            auto& local_offsets = thread_offsets[tid];
            auto& local_sizes = thread_sizes[tid];
            auto& local_mappings = thread_mappings[tid];

            // Each thread appends only to its own buffers, so no locking is needed.
            #pragma omp for schedule(static)
            for (int t = 0; t < total_trees; ++t) {
                int i = index_mapping[t].first;
                int j = index_mapping[t].second;
                Individual& ind = islands[i]->population[j];

                if (ind.tree) {
                    int start_offset = static_cast<int>(local_nodes.size());
                    linearize_tree(ind.tree, local_nodes);
                    int size = static_cast<int>(local_nodes.size()) - start_offset;

                    if (size > 0) {
                        local_offsets.push_back(start_offset);
                        local_sizes.push_back(size);
                        local_mappings.push_back({i, j});
                    } else {
                        ind.fitness = INF;
                        ind.fitness_valid = true;
                    }
                } else {
                    ind.fitness = INF;
                    ind.fitness_valid = true;
                }
            }
        }

        // Merge thread-local buffers into global buffers
        std::vector<LinearGpuNode> all_nodes;
        std::vector<int> tree_offsets;
        std::vector<int> tree_sizes;
        std::vector<std::pair<int, int>> result_mapping;

        size_t total_node_count = 0;
        size_t total_valid_trees = 0;
        for (int t = 0; t < num_threads; ++t) {
            total_node_count += thread_nodes[t].size();
            total_valid_trees += thread_mappings[t].size();
        }

        all_nodes.reserve(total_node_count);
        tree_offsets.reserve(total_valid_trees);
        tree_sizes.reserve(total_valid_trees);
        result_mapping.reserve(total_valid_trees);

        for (int t = 0; t < num_threads; ++t) {
            // Thread-local offsets are relative to that thread's node buffer;
            // shift them by where the buffer lands in the merged array.
            int offset_adjustment = static_cast<int>(all_nodes.size());

            // Copy nodes
            all_nodes.insert(all_nodes.end(), thread_nodes[t].begin(), thread_nodes[t].end());

            // Adjust offsets and copy
            for (size_t k = 0; k < thread_offsets[t].size(); ++k) {
                tree_offsets.push_back(thread_offsets[t][k] + offset_adjustment);
                tree_sizes.push_back(thread_sizes[t][k]);
                result_mapping.push_back(thread_mappings[t][k]);
            }
        }

        if (result_mapping.empty()) return;

        int valid_trees = static_cast<int>(result_mapping.size());
        int num_points = static_cast<int>(x_values.size());
        int num_vars = (num_points > 0) ? static_cast<int>(x_values[0].size()) : 0;

        // Step 3: Enqueue the GPU evaluation (the call itself does not block).
        launch_evaluation_async(
            all_nodes, tree_offsets, tree_sizes,
            valid_trees, d_targets, d_x_values, num_points,
            num_vars,
            double_buffer_gpu
        );

        // Step 4: Wait for GPU results (this is where we sync)
        std::vector<double> results;
        retrieve_results_sync(results, valid_trees, double_buffer_gpu);

        // Step 5: Distribute results back to islands
        for (size_t k = 0; k < static_cast<size_t>(valid_trees); ++k) {
            int island_idx = result_mapping[k].first;
            int ind_idx = result_mapping[k].second;
            double fitness = results[k];

            // Validate result
            if (std::isnan(fitness) || std::isinf(fitness) || fitness >= 1e300) {
                fitness = INF;
            }

            islands[island_idx]->population[ind_idx].fitness = fitness;
            islands[island_idx]->population[ind_idx].fitness_valid = true;
        }

    } else {
        // FORCE_CPU_MODE is true OR GPU init failed: Use CPU
        #pragma omp parallel for schedule(dynamic)
        for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
            for (int j = 0; j < static_cast<int>(islands[i]->population.size()); ++j) {
                Individual& ind = islands[i]->population[j];
                if (ind.tree) {
                    // Pass GPU pointers (even if null) to match signature
                    ind.fitness = evaluate_fitness(ind.tree, targets, x_values, d_targets, d_x_values);
                    ind.fitness_valid = true;
                } else {
                    ind.fitness = INF;
                    ind.fitness_valid = true;
                }
            }
        }
    }

#else
    // CPU Fallback: CUDA not available, use parallel CPU evaluation
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
        for (int j = 0; j < static_cast<int>(islands[i]->population.size()); ++j) {
            Individual& ind = islands[i]->population[j];
            if (ind.tree) {
                ind.fitness = evaluate_fitness(ind.tree, targets, x_values);
                ind.fitness_valid = true;
            } else {
                ind.fitness = INF;
                ind.fitness_valid = true;
            }
        }
    }
#endif
}


// --- evolve_island ---
// (Unchanged)
void GeneticAlgorithm::evolve_island(Island& island, int current_generation) {
    int current_pop_size = island.population.size(); if (current_pop_size == 0) return;
    auto best_it = std::min_element(island.population.begin(), island.population.end(),
        [](const Individual& a, const Individual& b) {
            if (!a.tree || !a.fitness_valid) return false;
            if (!b.tree || !b.fitness_valid) return true;
            return a.fitness < b.fitness;
        });
    double current_best_fitness = INF;
    int best_idx = -1;
    if (best_it != island.population.end() && best_it->tree && best_it->fitness_valid) {
        best_idx = std::distance(island.population.begin(), best_it);
        current_best_fitness = best_it->fitness;
    }
    island.fitness_history.push_back(current_best_fitness);
    if (best_idx != -1 && current_best_fitness < INF) {
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
         auto local_search_result = try_local_improvement(island.population[best_idx].tree, island.population[best_idx].fitness, targets, x_values, LOCAL_SEARCH_ATTEMPTS, d_targets, d_x_values);
#else
         auto local_search_result = try_local_improvement(island.population[best_idx].tree, island.population[best_idx].fitness, targets, x_values, LOCAL_SEARCH_ATTEMPTS);
#endif
         if (local_search_result.first && local_search_result.second < island.population[best_idx].fitness) {
             island.population[best_idx].tree = local_search_result.first;
             island.population[best_idx].fitness = local_search_result.second;
             island.population[best_idx].fitness_valid = true;
             current_best_fitness = local_search_result.second;
         }
    }
    if (current_best_fitness < island.best_fitness - FITNESS_EQUALITY_TOLERANCE) {
        island.best_fitness = current_best_fitness;
        island.stagnation_counter = 0;
    } else if (current_best_fitness < INF) {
        island.stagnation_counter++;
    }
    island.pareto_optimizer.update(island.population, targets, x_values);
    for(const auto& ind : island.population) {
        if(ind.tree && ind.fitness_valid && ind.fitness < PATTERN_RECORD_FITNESS_THRESHOLD) {
            island.pattern_memory.record_success(ind.tree, ind.fitness);
        }
    }
    std::vector<Individual> next_generation;
    next_generation.reserve(current_pop_size);
    int elite_count = std::max(1, static_cast<int>(current_pop_size * island.params.elite_percentage));
    if (elite_count > 0 && elite_count <= current_pop_size) {
        std::partial_sort(island.population.begin(), island.population.begin() + elite_count, island.population.end());
        int added_elites = 0;
        for (int i = 0; i < elite_count && i < island.population.size(); ++i) {
             if (island.population[i].tree && island.population[i].fitness_valid) {
                 next_generation.emplace_back(clone_tree(island.population[i].tree));
                 next_generation.back().fitness = island.population[i].fitness;
                 next_generation.back().fitness_valid = true;
                 added_elites++;
             }
        }
        elite_count = added_elites;
    } else { elite_count = 0; }
    int random_injection_count = 0;
    if (island.stagnation_counter > STAGNATION_LIMIT_ISLAND / 2) {
        random_injection_count = static_cast<int>(current_pop_size * STAGNATION_RANDOM_INJECT_PERCENT);
        for(int i = 0; i < random_injection_count && next_generation.size() < current_pop_size; ++i) {
             NodePtr random_tree = generate_random_tree(MAX_TREE_DEPTH_INITIAL);
             if (random_tree) next_generation.emplace_back(std::move(random_tree));
        }
    }
    int pattern_injection_count = 0;
    
    // --- ISLAND CATACLYSM ---
    // If enabled, triggers a hard reset if stagnation persists.
    if (USE_ISLAND_CATACLYSM && island.stagnation_counter >= STAGNATION_LIMIT_ISLAND) {
        // Keep only top 1 elite (already in next_generation[0] if elite_count > 0)
        // Or if we need to enforce better elitism during cataclysm:
        
        int survivors = 1; // Only the absolute best one survives
        // Resize to survivors
        if (next_generation.size() > survivors) next_generation.resize(survivors);
        
        // Fill the rest with completely random trees
        int to_fill = current_pop_size - next_generation.size();
        for(int i=0; i<to_fill; ++i) {
             NodePtr random_tree = generate_random_tree(MAX_TREE_DEPTH_INITIAL);
             if (random_tree) next_generation.emplace_back(std::move(random_tree));
        }
        
        island.stagnation_counter = 0; // Reset counter
        // Optional: Pattern injection could also happen here, but random is better for total diversity.
    }
    // Only do standard injections if we didn't just nuke everything
    else {
        if (random_injection_count == 0 && current_generation % PATTERN_INJECT_INTERVAL == 0) {
            pattern_injection_count = static_cast<int>(current_pop_size * PATTERN_INJECT_PERCENT);
            for (int i = 0; i < pattern_injection_count && next_generation.size() < current_pop_size; ++i) {
                NodePtr pt = island.pattern_memory.suggest_pattern_based_tree(MAX_TREE_DEPTH_INITIAL);
                if (pt) { next_generation.emplace_back(std::move(pt)); }
                else {
                     NodePtr random_tree = generate_random_tree(MAX_TREE_DEPTH_INITIAL);
                     if (random_tree) next_generation.emplace_back(std::move(random_tree));
                }
            }
        }
    }
    auto& rng = get_rng();
    std::uniform_real_distribution<double> prob_dist(0.0, 1.0);
    // >>> Parallel Parent Selection Loop with Uniqueness Check <<<
    
    // 1. Initialize uniqueness set with survivors (elites/injected)
    std::unordered_set<std::string> unique_signatures;
    if (PREVENT_DUPLICATES) {
        for (const auto& ind : next_generation) {
            if (ind.tree) {
                unique_signatures.insert(tree_to_string(ind.tree));
            }
        }
    }

    // 2. Fill the rest of the population
    int fail_safe_counter = 0;
    while (next_generation.size() < current_pop_size) {
        int needed = current_pop_size - next_generation.size();
        
        // Generate candidates in parallel
        std::vector<Individual> candidates(needed);
        
        #pragma omp parallel for schedule(dynamic)
        for (int i = 0; i < needed; ++i) {
            // Thread-local RNG
            auto& rng = get_rng(); 
            
            Individual offspring;
            // Use distribution defined outside or create new one? 
            // Better create local to avoid shared state issues if not const
            std::uniform_real_distribution<double> local_prob_dist(0.0, 1.0);

            if (local_prob_dist(rng) < island.params.crossover_rate) {
                Individual p1, p2;
                if (USE_LEXICASE_SELECTION) {
                    p1 = lexicase_selection(island.population, targets, x_values);
                    p2 = lexicase_selection(island.population, targets, x_values);
                } else {
                    p1 = tournament_selection(island.population, island.params.tournament_size);
                    p2 = tournament_selection(island.population, island.params.tournament_size);
                }
                // Use Semantic Crossover for better diversity
                offspring = semantic_crossover(p1, p2, x_values);
            } else {
                Individual p1;
                if (USE_LEXICASE_SELECTION) {
                    p1 = lexicase_selection(island.population, targets, x_values);
                } else {
                    p1 = tournament_selection(island.population, island.params.tournament_size);
                }
                if (p1.tree) p1.tree = clone_tree(p1.tree); 
                mutate(p1, island.params.mutation_rate);
                offspring = std::move(p1);
            }
            
            candidates[i] = std::move(offspring);
        }
        
        // Filter and add unique candidates (Serial)
        int added_this_round = 0;
        for (auto& cand : candidates) {
            if (next_generation.size() >= current_pop_size) break;
            
            bool is_valid_to_add = true;
            if (PREVENT_DUPLICATES && cand.tree) {
                std::string sig = tree_to_string(cand.tree);
                if (unique_signatures.find(sig) != unique_signatures.end()) {
                    is_valid_to_add = false; 
                } else {
                    unique_signatures.insert(sig);
                }
            }
            
            if (is_valid_to_add) {
                next_generation.emplace_back(std::move(cand));
                added_this_round++;
            }
        }
        
        // Deadlock prevention
        if (added_this_round == 0) {
            fail_safe_counter++;
            if (fail_safe_counter > DUPLICATE_RETRIES) {
                // Fill remaining with random trees
                int remaining = current_pop_size - next_generation.size();
                for (int k = 0; k < remaining; ++k) {
                    NodePtr random_tree = generate_random_tree(MAX_TREE_DEPTH_INITIAL);
                    if (random_tree) next_generation.emplace_back(std::move(random_tree));
                }
                break; // Exit loop
            }
        } else {
            fail_safe_counter = 0; // Reset if we made progress
        }
    }
     if (next_generation.size() > current_pop_size) next_generation.resize(current_pop_size);
    island.population = std::move(next_generation);
    if (current_generation > 0 && current_generation % PARAM_MUTATE_INTERVAL == 0) island.params.mutate(island.stagnation_counter);
}

// --- migrate ---
// (Sin cambios)
void GeneticAlgorithm::migrate() {
    if (num_islands <= 1) return;
    int current_pop_per_island = islands.empty() ? 0 : islands[0]->population.size();
    if (current_pop_per_island == 0) return;
    int num_migrants = std::min(MIGRATION_SIZE, current_pop_per_island / 5);
    if (num_migrants <= 0) return;
    std::vector<std::vector<Individual>> outgoing_migrants(num_islands);
    #pragma omp parallel for
    for (int i = 0; i < num_islands; ++i) {
        Island& src = *islands[i];
        if (src.population.size() < num_migrants) continue;
        std::partial_sort(src.population.begin(), src.population.begin() + num_migrants, src.population.end());
        outgoing_migrants[i].reserve(num_migrants);
        int migrants_selected = 0;
        for (int j = 0; j < src.population.size() && migrants_selected < num_migrants; ++j) {
             if (src.population[j].tree && src.population[j].fitness_valid) {
                 Individual migrant_copy;
                 migrant_copy.tree = clone_tree(src.population[j].tree);
                 migrant_copy.fitness = src.population[j].fitness;
                 migrant_copy.fitness_valid = true;
                 outgoing_migrants[i].push_back(std::move(migrant_copy));
                 migrants_selected++;
             }
        }
    }
    for (int dest_idx = 0; dest_idx < num_islands; ++dest_idx) {
        int src_idx = (dest_idx + num_islands - 1) % num_islands;
        Island& dest = *islands[dest_idx];
        const auto& migrants_to_receive = outgoing_migrants[src_idx];
        if (migrants_to_receive.empty() || dest.population.empty()) continue;
        int replace_count = std::min((int)migrants_to_receive.size(), (int)dest.population.size());
        if (replace_count <= 0) continue;
        std::partial_sort(dest.population.begin(), dest.population.end() - replace_count, dest.population.end());
        int migrant_idx = 0;
        for (int i = 0; i < replace_count; ++i) {
            int replace_idx = dest.population.size() - 1 - i;
            if (migrant_idx < migrants_to_receive.size()) {
                 dest.population[replace_idx] = std::move(migrants_to_receive[migrant_idx++]);
                 dest.population[replace_idx].fitness_valid = false; // Marcar para reevaluar
            }
        }
    }
}


// --- run ---
// Main evolutionary loop.
//
// Per generation: evaluate every island, evolve every island (in parallel), pick the best valid
// individual across all islands, update the global best, optionally migrate, and report progress.
// The loop stops early when
//   * the global best has not improved for GLOBAL_STAGNATION_LIMIT generations, or
//   * the global best fitness drops below EXACT_SOLUTION_THRESHOLD.
//
// Returns the best tree found (nullptr if no valid individual was ever seen).
// All progress is written to std::cout; the "Formula:" / "Final Formula:" lines are flushed
// immediately so that whatever captures stdout sees them promptly.
NodePtr GeneticAlgorithm::run() {
    std::cout << "Starting Genetic Algorithm..." << std::endl;
    auto start_time = std::chrono::high_resolution_clock::now();

    // Prints one "x=(...): Pred=..., Target=..., Diff=..." row per sample for overall_best_tree.
    // Callers set the stream precision beforehand. Shared by the "new best" and final reports.
    auto print_prediction_rows = [this]() {
        for (size_t j = 0; j < x_values.size(); ++j) {
            double val = evaluate_tree(overall_best_tree, x_values[j]);
            // Missing targets and non-finite predictions are reported as NaN rather than skipped.
            double target_val = (j < targets.size()) ? targets[j] : std::nan("");
            double diff = (!std::isnan(val) && !std::isnan(target_val)) ? std::fabs(val - target_val) : std::nan("");
            std::cout << "  x=(";
            for (size_t v = 0; v < x_values[j].size(); ++v) std::cout << (v>0?",":"") << x_values[j][v];
            std::cout << "): Pred=" << std::setw(12) << val
                      << ", Target=" << std::setw(12) << target_val
                      << ", Diff=" << std::setw(12) << diff << std::endl;
        }
    };

    for (int gen = 0; gen < generations; ++gen) {
        // [DEBUG] Trace execution
        if (gen == 0) { std::cout << "[DEBUG] Entering main loop, gen=0" << std::endl; std::cout.flush(); }

        // 1. Evaluate ALL islands in ONE GPU kernel call (maximum GPU utilization)
        evaluate_all_islands();

        // 2. Evolve Islands (Parallel Island Loop)
        // Genetic operators (crossover, mutation) are CPU-bound and independent per island.
        // Signed bound, matching the loop in evaluate_all_islands.
        const int island_count = static_cast<int>(islands.size());
        #pragma omp parallel for
        for (int i = 0; i < island_count; ++i) {
            evolve_island(*islands[i], gen);
        }

        // 3. Best valid individual of this generation across all islands.
        double current_gen_best_fitness = INF;
        int best_island_idx = -1;
        int best_ind_idx = -1;
        for (int i = 0; i < island_count; ++i) {
            const auto& pop = islands[i]->population;
            for (int j = 0; j < static_cast<int>(pop.size()); ++j) {
                const auto& ind = pop[j];
                if (ind.tree && ind.fitness_valid && ind.fitness < current_gen_best_fitness) {
                    current_gen_best_fitness = ind.fitness;
                    best_island_idx = i; best_ind_idx = j;
                }
            }
        }

        // 4. Global best bookkeeping / global stagnation check.
        if (best_island_idx != -1 && current_gen_best_fitness < overall_best_fitness) {
            overall_best_fitness = current_gen_best_fitness;
            overall_best_tree = clone_tree(islands[best_island_idx]->population[best_ind_idx].tree);
            std::cout << "\n========================================" << std::endl;
            std::cout << "New Global Best Found (Gen " << gen + 1 << ", Island " << best_island_idx << ")" << std::endl;
            std::cout << "Fitness: " << std::fixed << std::setprecision(8) << overall_best_fitness << std::endl;
            std::cout << "Size: " << tree_size(overall_best_tree) << std::endl;
            std::cout << "Formula: " << tree_to_string(overall_best_tree) << std::endl;
            std::cout.flush(); // Ensure Formula: line is captured
            std::cout << "Predictions vs Targets:" << std::endl;
            std::cout << std::fixed << std::setprecision(4);
            if (overall_best_tree && !x_values.empty()) {
                print_prediction_rows();
            } else {
                std::cout << "  (No data or no valid tree to show predictions)" << std::endl;
            }
            std::cout << "========================================" << std::endl;
            last_overall_best_fitness = overall_best_fitness;
            generation_last_improvement = gen;
        } else if (overall_best_fitness < INF && (gen - generation_last_improvement) >= GLOBAL_STAGNATION_LIMIT) {
            std::cout << "\n========================================" << std::endl;
            std::cout << "TERMINATION: Global best fitness hasn't improved for " << GLOBAL_STAGNATION_LIMIT << " generations." << std::endl;
            std::cout << "Stopping at Generation " << gen + 1 << "." << std::endl;
            std::cout << "========================================" << std::endl;
            break;
        }

        // 5. Periodic migration.
        if ((gen + 1) % MIGRATION_INTERVAL == 0 && num_islands > 1) {
            migrate();
            // Re-evaluate after migration using global batch
            evaluate_all_islands();
        }

        // 6. Early exit once the solution is good enough.
        if (overall_best_fitness < EXACT_SOLUTION_THRESHOLD) {
            std::cout << "\n========================================" << std::endl;
            std::cout << "Solution found meeting criteria at Generation " << gen + 1 << "!" << std::endl;
            std::cout << "Final Fitness: " << std::fixed << std::setprecision(8) << overall_best_fitness << std::endl;
            if (overall_best_tree) {
                std::cout << "Final Formula Size: " << tree_size(overall_best_tree) << std::endl;
                std::cout << "Final Formula: " << tree_to_string(overall_best_tree) << std::endl;
                std::cout.flush(); // Ensure Final Formula: line is captured
            }
            std::cout << "========================================" << std::endl;
            std::cout.flush(); // Ensure flush
            break;
        }

        // 7. Periodic progress report (always printed for the last scheduled generation).
        if ((gen + 1) % PROGRESS_REPORT_INTERVAL == 0 || gen == generations - 1) {
            auto current_time = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> elapsed = current_time - start_time;
            std::cout << "\n--- Generation " << gen + 1 << "/" << generations
                      << " (Elapsed: " << std::fixed << std::setprecision(2) << elapsed.count() << "s) ---" << std::endl;
            std::cout << "Overall Best Fitness: " << std::scientific << overall_best_fitness << std::fixed << std::endl;
            if (overall_best_tree) { std::cout << "Best Formula Size: " << tree_size(overall_best_tree) << std::endl; }
            else { std::cout << "Best Formula Size: N/A" << std::endl; }
            std::cout << "(Last improvement at gen: " << generation_last_improvement + 1 << ")" << std::endl;
        }
    }

    // Final summary.
    auto end_time = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> total_elapsed = end_time - start_time;
    std::cout << "\n========================================" << std::endl;
    std::cout << "Evolution Finished!" << std::endl;
    std::cout << "Total Time: " << std::fixed << std::setprecision(2) << total_elapsed.count() << " seconds" << std::endl;
    std::cout << "Final Best Fitness: " << std::fixed << std::setprecision(8) << overall_best_fitness << std::endl;
    if (overall_best_tree) {
        std::cout << "Final Best Formula Size: " << tree_size(overall_best_tree) << std::endl;
        std::cout << "Final Formula: " << tree_to_string(overall_best_tree) << std::endl;
        std::cout.flush(); // Ensure Final Formula: line is captured
        std::cout << "--- Final Verification ---" << std::endl;
        // Recompute the fitness from scratch as a sanity check on the stored value.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
        double final_check_fitness = evaluate_fitness(overall_best_tree, targets, x_values, d_targets, d_x_values);
#else
        double final_check_fitness = evaluate_fitness(overall_best_tree, targets, x_values);
#endif
        std::cout << "Recalculated Fitness: " << std::fixed << std::setprecision(8) << final_check_fitness << std::endl;
        std::cout << std::fixed << std::setprecision(4);
        print_prediction_rows();
    } else {
        std::cout << "No valid solution found." << std::endl;
    }
    std::cout << "========================================" << std::endl;
    return overall_best_tree;
}


In [None]:
%%writefile Code/src/GeneticAlgorithm.h
// ============================================================
// Archivo: src/GeneticAlgorithm.h
// ============================================================
#ifndef GENETICALGORITHM_H
#define GENETICALGORITHM_H

#include "ExpressionTree.h"
#include "GeneticOperators.h"
#include "AdvancedFeatures.h"
#include "Globals.h" // Incluir Globals.h para INF, NUM_ISLANDS, etc.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
#include "FitnessGPU.cuh" // For GlobalGpuBuffers definition
#endif
#include <vector>
#include <string>
#include <memory> // Para std::unique_ptr

// Island-model genetic algorithm: evolves expression trees on several islands and keeps track of
// the best tree found across all of them.
//
// NOTE: `targets` and `x_values` are stored as references, so the vectors passed to the
// constructor must outlive the GeneticAlgorithm object.
class GeneticAlgorithm {
    // Internal structure representing a single island
    struct Island {
        std::vector<Individual> population; // Population of the island
        EvolutionParameters params;         // Evolution parameters owned by this island
        PatternMemory pattern_memory;       // Pattern memory of the island
        ParetoOptimizer pareto_optimizer;   // Pareto optimizer of the island
        int stagnation_counter = 0;         // Local stagnation counter of the island
        double best_fitness = INF;          // Best fitness ever seen on this island
        std::vector<double> fitness_history;// Fitness history (optional)
        int id;                             // Island identifier

        // Island constructor: builds the initial population and starts from default parameters
        explicit Island(int island_id, int pop_size) : id(island_id) {
             population = create_initial_population(pop_size); // Create the initial population
             params = EvolutionParameters::create_default();   // Use default parameters
        }

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
        // Persistent GPU buffers
        void* d_nodes = nullptr;
        void* d_offsets = nullptr;
        void* d_sizes = nullptr;
        void* d_results = nullptr;
        size_t d_nodes_capacity = 0;
        size_t d_pop_capacity = 0;

        // Intentionally empty: the buffers above are NOT released here.
        ~Island() {
            // We cannot easily call cudaFree here because this header might be included
            // where cuda_runtime.h is not. However, we can trust the OS/driver to clean up
            // or we should add a cleanup function.
            // For now, we will rely on GeneticAlgorithm destructor or explicit cleanup if possible.
            // But since Island is unique_ptr, we can't easily add a destructor that calls cudaFree 
            // without including cuda_runtime.
            // Optimization: Let's rely on OS cleanup at exit, OR add a cleanup method called by GA.
        }
#endif
    };

    // Main members of the GeneticAlgorithm class
    std::vector<std::unique_ptr<Island>> islands; // Vector of unique pointers to the islands
    const std::vector<double>& targets;           // Reference to the target data
    const std::vector<std::vector<double>>& x_values;          // Reference to the x values [samples][features]
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    double* d_targets = nullptr;                  // Pointer to the target data on the GPU
    double* d_x_values = nullptr;                 // Pointer to the x values on the GPU
    GlobalGpuBuffers global_gpu_buffers;          // Global buffers for batch evaluation of ALL islands
    DoubleBufferedGpu double_buffer_gpu;          // Double-buffered GPU for async overlap
#endif
    int total_population_size;                    // Total population size
    int generations;                              // Maximum number of generations
    int num_islands;                              // Number of islands

    // Global best tracking
    NodePtr overall_best_tree = nullptr;          // Best tree found globally
    double overall_best_fitness = INF;            // Best fitness found globally

    // --- NEW: Global stagnation tracking ---
    int generation_last_improvement = 0;          // Generation in which overall_best_fitness last improved
    double last_overall_best_fitness = INF;       // Value of overall_best_fitness at the last improvement
    // -------------------------------------------------

    int pop_per_island;                           // Computed population per island

public:
    // Constructor
    GeneticAlgorithm(const std::vector<double>& targets_ref,
                       const std::vector<std::vector<double>>& x_values_ref,
                       int total_pop,
                       int gens,
                       const std::vector<std::string>& seeds = {}, // Optional: Initial population seeds
                       int n_islands = NUM_ISLANDS); // Defaults to the value from Globals.h
    ~GeneticAlgorithm(); // Destructor; frees GPU memory

    // Runs the genetic algorithm
    NodePtr run();

private:
    // Internal helper functions
    void evaluate_population(Island& island); // Evaluates the fitness of one island (legacy)
    void evaluate_all_islands(); // Evaluates ALL islands in ONE GPU batch call (optimized)
    void evolve_island(Island& island, int current_generation); // Evolves one island by one generation
    void migrate(); // Performs migration between islands
    void update_overall_best(const Island& island); // Updates the global best
};


#endif // GENETICALGORITHM_H


In [None]:
%%writefile Code/src/GeneticOperators.cpp
#include "GeneticOperators.h"
#include "Globals.h"
#include "Fitness.h"
#include "AdvancedFeatures.h"
#include <vector>
#include <cmath>
#include <algorithm>
#include <numeric>  // For std::iota
#include <map>
#include <stdexcept>
#include <set>
#include <iostream> // Para mensajes de error/info

// Builds a random expression tree (all operators enabled, exponents unrestricted).
//
// A node becomes a terminal when current_depth has reached max_depth, or otherwise with a
// probability that grows linearly from 0.2 at the root to 1.0 at max_depth. Terminals are a
// variable (probability TERMINAL_VS_VARIABLE_PROB) or a constant. Inner nodes draw their operator
// from OPERATOR_WEIGHTS; the weight order must line up with the `ops` table below.
//
// Parameters:
//   max_depth     - depth at which only terminals are produced.
//   current_depth - depth of the node being generated (recursion bookkeeping).
// Returns: the root of the generated subtree.
NodePtr generate_random_tree(int max_depth, int current_depth) {
    auto& rng = get_rng();
    std::uniform_real_distribution<double> prob_dist(0.0, 1.0);

    // Creates a leaf: either a randomly indexed variable or a random constant.
    auto make_terminal = [&]() -> NodePtr {
        if (prob_dist(rng) < TERMINAL_VS_VARIABLE_PROB) {
            auto variable = std::make_shared<Node>(NodeType::Variable);
            std::uniform_int_distribution<int> index_dist(0, NUM_VARIABLES - 1);
            variable->var_index = index_dist(rng);
            return variable;
        }
        auto constant = std::make_shared<Node>(NodeType::Constant);
        if (FORCE_INTEGER_CONSTANTS) {
            std::uniform_int_distribution<int> int_dist(CONSTANT_INT_MIN_VALUE, CONSTANT_INT_MAX_VALUE);
            constant->value = static_cast<double>(int_dist(rng));
        } else {
            std::uniform_real_distribution<double> real_dist(CONSTANT_MIN_VALUE, CONSTANT_MAX_VALUE);
            constant->value = real_dist(rng);
        }
        // Snap near-zero constants to an exact zero.
        if (std::fabs(constant->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) constant->value = 0.0;
        return constant;
    };

    // Terminal case. The depth test comes first so no random number is consumed at the limit.
    const double terminal_prob = 0.2 + 0.8 * (static_cast<double>(current_depth) / max_depth);
    if (current_depth >= max_depth || prob_dist(rng) < terminal_prob) {
        return make_terminal();
    }

    // Operator case.
    auto node = std::make_shared<Node>(NodeType::Operator);
    // Operator symbols, indexed by the draw from OPERATOR_WEIGHTS (see Globals.h).
    const std::vector<char> ops = {'+', '-', '*', '/', '^', '%', 's', 'c', 'l', 'e', '!', '_', 'g', 'S', 'C', 'T'};
    std::discrete_distribution<int> op_dist(OPERATOR_WEIGHTS.begin(), OPERATOR_WEIGHTS.end());
    node->op = ops[op_dist(rng)];

    // Unary operators only use the left child.
    bool is_unary = false;
    switch (node->op) {
        case 's': case 'c': case 'l': case 'e': case '!':
        case '_': case 'g': case 'S': case 'C': case 'T':
            is_unary = true;
            break;
        default:
            break;
    }

    // Generate children recursively (left first, then right for binary operators).
    node->left = generate_random_tree(max_depth, current_depth + 1);
    if (is_unary) {
        node->right = nullptr;
    } else {
        node->right = generate_random_tree(max_depth, current_depth + 1);
    }

    // Fallback for null children.
    if (!node->left) node->left = make_terminal();
    if (!is_unary && !node->right) node->right = make_terminal();

    // --- Special handling for the power operator '^' ---
    if (node->op == '^') {
        const bool base_is_constant = (node->left->type == NodeType::Constant);
        const bool exponent_is_constant = (node->right->type == NodeType::Constant);

        if (base_is_constant && std::fabs(node->left->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) {
            // Rule 1: avoid 0^0 and 0^negative by switching to a safe operator.
            if (exponent_is_constant && node->right->value <= SIMPLIFY_NEAR_ZERO_TOLERANCE) {
                const std::vector<char> safe_ops = {'+', '-', '*'};
                std::uniform_int_distribution<int> safe_op_dist(0, safe_ops.size() - 1);
                node->op = safe_ops[safe_op_dist(rng)];
            }
        } else if (base_is_constant && node->left->value < 0.0) {
            // Rule 2: a negative base with a non-integer exponent gets an integer exponent instead.
            const bool exponent_is_fractional = exponent_is_constant &&
                std::fabs(node->right->value - std::round(node->right->value)) > SIMPLIFY_NEAR_ZERO_TOLERANCE;
            if (exponent_is_fractional) {
                std::uniform_int_distribution<int> int_exp_dist(-3, 3);
                node->right = std::make_shared<Node>(NodeType::Constant);
                node->right->value = static_cast<double>(int_exp_dist(rng));
            }
        }
    }
    return node;
}

// --- Creates the initial population (optionally seeded with a user-supplied formula) ---
// === OPTIMIZATION: parallelized with OpenMP ===
//
// When USE_INITIAL_FORMULA is set and INITIAL_FORMULA_STRING parses successfully, the parsed tree
// occupies slot 0; every other slot receives a random tree whose depth limit is drawn uniformly
// from [min(3, MAX_TREE_DEPTH_INITIAL), MAX_TREE_DEPTH_INITIAL]. A slot that still has no tree
// after 10 attempts falls back to the constant 0.
//
// Parameters:
//   population_size - number of individuals to create.
// Returns: a vector with population_size individuals.
std::vector<Individual> create_initial_population(int population_size) {
    std::vector<Individual> population;
    population.resize(population_size); // Pre-allocate all slots for parallel access

    // --- Initial formula injection ---
    int start_index = 0;
    if (USE_INITIAL_FORMULA && !INITIAL_FORMULA_STRING.empty() && population_size > 0) {
        try {
            NodePtr initial_tree = parse_formula_string(INITIAL_FORMULA_STRING);
            if (initial_tree) {
                population[0] = Individual(std::move(initial_tree));
                start_index = 1;
                std::cout << "[INFO] Injected initial formula: " << INITIAL_FORMULA_STRING << std::endl;
            } else {
                 std::cerr << "[WARNING] Parsing initial formula returned null. Skipping injection." << std::endl;
            }
        } catch (const std::exception& e) {
            // A bad seed formula must not abort the run; continue with a fully random population.
            std::cerr << "[ERROR] Failed to parse initial formula '" << INITIAL_FORMULA_STRING
                      << "': " << e.what() << ". Skipping injection." << std::endl;
        }
    }
    // -----------------------------------------

    // std::uniform_int_distribution(a, b) requires a <= b. Clamp the lower bound so that a
    // configured MAX_TREE_DEPTH_INITIAL below 3 does not cause undefined behaviour.
    // For MAX_TREE_DEPTH_INITIAL >= 3 the range is the original [3, MAX_TREE_DEPTH_INITIAL].
    const int max_initial_depth = static_cast<int>(MAX_TREE_DEPTH_INITIAL);
    const int min_initial_depth = std::min(3, max_initial_depth);

    // === OPTIMIZATION: parallel loop for tree generation ===
    #pragma omp parallel for schedule(dynamic, 100)
    for (int i = start_index; i < population_size; ++i) {
        // Each thread uses its own RNG (thread_local inside get_rng)
        auto& rng = get_rng();
        std::uniform_int_distribution<int> depth_dist(min_initial_depth, max_initial_depth);

        NodePtr random_tree = nullptr;
        int attempts = 0;
        const int max_attempts = 10;
        while (!random_tree && attempts < max_attempts) {
            random_tree = generate_random_tree(depth_dist(rng));
            attempts++;
        }
        if (random_tree) {
            population[i] = Individual(std::move(random_tree));
        } else {
            // Fallback: create a simple constant
            auto fallback_node = std::make_shared<Node>(NodeType::Constant);
            fallback_node->value = 0.0;
            population[i] = Individual(std::move(fallback_node));
        }
    }
    return population;
}

// --- Tournament selection with parsimony pressure ---
// Draws `tournament_size` individuals uniformly at random (with replacement) and
// returns a copy of the one with the lowest fitness. When two contenders'
// fitness values differ by less than FITNESS_EQUALITY_TOLERANCE, the one with
// the smaller tree wins.
//
// Parameters:
//   population      - individuals to sample from; must not be empty.
//   tournament_size - number of draws; values <= 0 are treated as 1 and values
//                     above the population size are clamped to it.
// Returns: a copy of the winner. If no usable individual (non-null tree and
//          valid fitness) is found within the retry budget, population[0] is
//          returned as a best-effort fallback.
// Throws:  std::runtime_error when the population is empty.
Individual tournament_selection(const std::vector<Individual>& population, int tournament_size) {
    if (population.empty()) throw std::runtime_error("Cannot perform tournament selection on empty population.");

    const int pop_count = static_cast<int>(population.size());
    if (tournament_size <= 0) tournament_size = 1;
    tournament_size = std::min(tournament_size, pop_count);

    std::uniform_int_distribution<int> dist(0, pop_count - 1);
    auto& rng = get_rng();

    // An individual can only compete when it has a tree and an up-to-date fitness.
    auto is_usable = [](const Individual& ind) -> bool {
        return ind.tree && ind.fitness_valid;
    };

    // Seed the tournament with a usable individual, retrying a bounded number of times.
    const int max_attempts = std::min(pop_count * 2, 100);
    const Individual* best_in_tournament = nullptr;
    int attempts = 0;
    do {
        best_in_tournament = &population[dist(rng)];
        ++attempts;
    } while (!is_usable(*best_in_tournament) && attempts < max_attempts);

    // The population is known to be non-empty here, so falling back to the first
    // individual is always possible (the former "else throw" branch was unreachable).
    if (!is_usable(*best_in_tournament)) return population[0];

    for (int i = 1; i < tournament_size; ++i) {
        const Individual& contender = population[dist(rng)];
        if (!is_usable(contender)) continue;

        if (contender.fitness < best_in_tournament->fitness) {
            best_in_tournament = &contender;
        } else if (std::fabs(contender.fitness - best_in_tournament->fitness) < FITNESS_EQUALITY_TOLERANCE) {
            // Parsimony tie-break: prefer the smaller expression tree.
            if (tree_size(contender.tree) < tree_size(best_in_tournament->tree)) {
                best_in_tournament = &contender;
            }
        }
    }
    return *best_in_tournament;
}

// --- Epsilon-Lexicase Selection Implementation ---
// Lazily fills `ind.errors` with the absolute error of `ind.tree` on every
// training case. Does nothing when the cache is already populated or when the
// individual has no tree.
//
// Parameters:
//   ind      - individual whose per-case error cache is filled in place.
//   targets  - expected output for each case.
//   x_values - input row for each case (x_values[i] pairs with targets[i]).
//
// A case whose evaluation is NaN/Inf, or that has no matching input row
// (x_values shorter than targets), is recorded as INF instead of reading past
// the end of x_values.
void ensure_errors_computed(Individual& ind, const std::vector<double>& targets, const std::vector<std::vector<double>>& x_values) {
    if (!ind.errors.empty()) return; // Already computed
    if (!ind.tree) return;

    // Build into a local vector so that ind.errors is never left partially
    // filled (a non-empty cache is treated as "fully computed" above).
    std::vector<double> case_errors;
    case_errors.reserve(targets.size());
    for (size_t i = 0; i < targets.size(); ++i) {
        if (i >= x_values.size()) {
            case_errors.push_back(INF); // No input row for this target
            continue;
        }
        const double val = evaluate_tree(ind.tree, x_values[i]);
        if (std::isnan(val) || std::isinf(val)) {
            case_errors.push_back(INF);
        } else {
            // Use ABSOLUTE error for lexicase
            case_errors.push_back(std::fabs(val - targets[i]));
        }
    }
    ind.errors.swap(case_errors);
}

// Epsilon-lexicase selection over a random pool ("tournament lexicase").
//
// A pool of up to 100 usable individuals (non-null tree, valid fitness) is
// drawn with replacement. Training cases are then visited in random order; at
// each case only the candidates whose error is within epsilon of the best error
// on that case survive, where epsilon = max(0.1 * best_error, 1e-5).
//
// Parameters:
//   population - individuals to sample from; their `errors` caches may be
//                filled as a side effect (hence the non-const reference).
//   targets    - expected output for each case.
//   x_values   - input row for each case.
// Returns: a copy of the single survivor, or a random pick among the remaining
//          candidates. Falls back to population[0] when the pool contains no
//          usable individual, and to a random individual if filtering empties
//          the candidate set.
// Throws:  std::runtime_error when the population is empty (same contract as
//          tournament_selection; previously this was undefined behaviour).
Individual lexicase_selection(std::vector<Individual>& population, const std::vector<double>& targets, const std::vector<std::vector<double>>& x_values) {
    if (population.empty()) throw std::runtime_error("Cannot perform lexicase selection on empty population.");

    auto& rng = get_rng();

    // 1. Initial candidates: a random pool rather than the whole population,
    //    because full lexicase is slow on very large populations.
    const int pool_size = 100;
    std::vector<Individual*> candidates;
    candidates.reserve(pool_size);

    std::uniform_int_distribution<int> dist(0, static_cast<int>(population.size()) - 1);
    for (int i = 0; i < pool_size; ++i) {
        Individual& ind = population[dist(rng)];
        if (ind.tree && ind.fitness_valid) candidates.push_back(&ind);
    }

    if (candidates.empty()) return population[0]; // No usable individual was drawn

    // 2. Shuffle test cases
    std::vector<int> cases(targets.size());
    std::iota(cases.begin(), cases.end(), 0);
    std::shuffle(cases.begin(), cases.end(), rng);

    // 3. Filter loop
    for (int case_idx : cases) {
        const size_t case_pos = static_cast<size_t>(case_idx);

        // Best error on this case among the current candidates (errors are computed lazily).
        double min_error = INF;
        for (Individual* cand : candidates) {
            ensure_errors_computed(*cand, targets, x_values);
            if (case_pos < cand->errors.size() && cand->errors[case_pos] < min_error) {
                min_error = cand->errors[case_pos];
            }
        }

        // Simple dynamic epsilon: 10% of the best error, never below 1e-5.
        const double epsilon = std::max(min_error * 0.1, 1e-5);

        // Keep only the candidates that are "close enough" to the best one.
        std::vector<Individual*> next_candidates;
        next_candidates.reserve(candidates.size());
        for (Individual* cand : candidates) {
            if (case_pos < cand->errors.size() && cand->errors[case_pos] <= min_error + epsilon) {
                next_candidates.push_back(cand);
            }
        }

        candidates = std::move(next_candidates);
        if (candidates.empty()) break; // Only possible when no candidate has an error for this case
        if (candidates.size() == 1) return *candidates[0];
    }

    // If multiple remain, pick one at random
    if (candidates.empty()) return population[dist(rng)];
    std::uniform_int_distribution<int> pick(0, static_cast<int>(candidates.size()) - 1);
    return *candidates[pick(rng)];
}

// Mutates a tree (exponents are unrestricted in OperatorChange).
//
// The input is always cloned first; the original tree is never modified. With
// probability `mutation_rate` one node of the clone is picked uniformly at
// random and one of five mutation kinds (chosen uniformly) is applied to it.
//
// Parameters:
//   tree          - tree to mutate; may be null.
//   mutation_rate - probability that a mutation is applied at all.
//   max_depth     - depth passed to generate_random_tree for subtree replacements.
// Returns: the (possibly mutated) clone, or nullptr when the clone is null.
//          When USE_HARD_DEPTH_LIMIT is set the result is trimmed to
//          MAX_TREE_DEPTH_HARD_LIMIT after a mutation has been applied.
NodePtr mutate_tree(const NodePtr& tree, double mutation_rate, int max_depth) {
    auto& rng = get_rng();
    std::uniform_real_distribution<double> prob(0.0, 1.0);
    auto new_tree = clone_tree(tree); // Always clone first
    if (!new_tree) return nullptr; // A null original yields a null clone

    if (prob(rng) >= mutation_rate) return new_tree; // No mutation this time

    std::vector<NodePtr*> nodes;
    collect_node_ptrs(new_tree, nodes);
    if (nodes.empty()) return new_tree; // Nothing to mutate

    std::uniform_int_distribution<int> node_dist(0, static_cast<int>(nodes.size()) - 1);
    NodePtr* node_to_mutate_ptr = nodes[node_dist(rng)];
    if (!node_to_mutate_ptr || !(*node_to_mutate_ptr)) return new_tree; // Unexpected null slot or node

    const std::vector<MutationType> mutation_types = {
        MutationType::ConstantChange, MutationType::OperatorChange,
        MutationType::SubtreeReplace, MutationType::NodeInsertion,
        MutationType::NodeDeletion
    };
    std::uniform_int_distribution<int> type_dist(0, static_cast<int>(mutation_types.size()) - 1);
    const MutationType mut_type = mutation_types[type_dist(rng)];

    NodePtr& current_node_ptr_ref = *node_to_mutate_ptr;
    Node& current_node = *current_node_ptr_ref;

    // Random replacement subtree (used by several cases); falls back to the
    // constant 1.0 when generation fails.
    auto generate_replacement = [&](int depth) -> NodePtr {
        NodePtr replacement = generate_random_tree(depth);
        if (!replacement) {
            replacement = std::make_shared<Node>(NodeType::Constant);
            replacement->value = 1.0;
        }
        return replacement;
    };

    // Operator alphabet, indexed by the value drawn from OPERATOR_WEIGHTS.
    // NOTE(review): assumes OPERATOR_WEIGHTS has one weight per entry of this
    // list, in the same order, with 0.0 for disabled operators - TODO confirm.
    const std::vector<char> all_ops = {'+', '-', '*', '/', '^', '%', 's', 'c', 'l', 'e', '!', '_', 'g', 'S', 'C', 'T'};

    // Single source of truth for operator arity. OperatorChange previously used
    // a shorter list for the old operator (without 'S', 'C', 'T') than for the
    // new one, so turning one of those into a binary operator left the right
    // child null.
    auto is_unary_op = [](char op) -> bool {
        switch (op) {
            case 's': case 'c': case 'l': case 'e': case '!':
            case '_': case 'g': case 'S': case 'C': case 'T':
                return true;
            default:
                return false;
        }
    };

    switch (mut_type) {
        case MutationType::ConstantChange:
            if (current_node.type == NodeType::Constant) {
                // Perturb the constant: scale by [0.8, 1.2] and shift by [-1, 1].
                double change_factor = std::uniform_real_distribution<double>(0.8, 1.2)(rng);
                double add_factor = std::uniform_real_distribution<double>(-1.0, 1.0)(rng);
                current_node.value = current_node.value * change_factor + add_factor;
                if (FORCE_INTEGER_CONSTANTS) current_node.value = std::round(current_node.value);
                if (std::fabs(current_node.value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) current_node.value = 0.0;
            } else {
                // Not a constant: replace it with a small random subtree (depth 1).
                *node_to_mutate_ptr = generate_replacement(1);
            }
            break;

        case MutationType::OperatorChange:
            if (current_node.type == NodeType::Operator) {
                // Pick the new operator from the global weighted distribution so
                // that heavily weighted operators are chosen more often.
                std::discrete_distribution<int> op_dist(OPERATOR_WEIGHTS.begin(), OPERATOR_WEIGHTS.end());

                // At least two enabled operators are needed to be able to change.
                int enabled_count = 0;
                for (double w : OPERATOR_WEIGHTS) if (w > 0.0) enabled_count++;

                if (enabled_count > 1) {
                    const int max_op_attempts = 20;
                    int attempts = 0;
                    int new_op_idx = -1;
                    do {
                        new_op_idx = op_dist(rng);
                        attempts++;
                    } while (all_ops[new_op_idx] == current_node.op && attempts < max_op_attempts);

                    // Test the outcome itself rather than the attempt counter, so a
                    // different operator found on the very last attempt is not discarded.
                    if (all_ops[new_op_idx] != current_node.op) {
                        const char old_op = current_node.op;
                        const char new_op = all_ops[new_op_idx];

                        const bool was_unary = is_unary_op(old_op);
                        const bool now_unary = is_unary_op(new_op);

                        if (now_unary) {
                            if (!was_unary) current_node.right = nullptr; // Drop the now-unused operand
                        } else if (was_unary || !current_node.right) {
                            // A binary operator needs a right operand.
                            current_node.right = generate_replacement(1);
                        }
                        current_node.op = new_op;
                    }
                }
            } else {
                // Not an operator: replace it with a random subtree.
                *node_to_mutate_ptr = generate_replacement(max_depth);
            }
            break;

        case MutationType::SubtreeReplace:
            *node_to_mutate_ptr = generate_replacement(max_depth);
            break;

        case MutationType::NodeInsertion:
            {
                // Insert a new operator above the selected node, which becomes its left child.
                auto new_op_node = std::make_shared<Node>(NodeType::Operator);
                std::discrete_distribution<int> op_dist(OPERATOR_WEIGHTS.begin(), OPERATOR_WEIGHTS.end());

                int new_op_idx = op_dist(rng); // Zero-weight (disabled) operators are never drawn
                new_op_node->op = all_ops[new_op_idx];

                new_op_node->left = current_node_ptr_ref;

                if (!is_unary_op(new_op_node->op)) {
                    // Binary operator: the right operand is a fresh constant or variable.
                    if (prob(rng) < MUTATE_INSERT_CONST_PROB) {
                        auto right_child = std::make_shared<Node>(NodeType::Constant);
                        if (FORCE_INTEGER_CONSTANTS) {
                            std::uniform_int_distribution<int> cv(MUTATE_INSERT_CONST_INT_MIN, MUTATE_INSERT_CONST_INT_MAX);
                            right_child->value = static_cast<double>(cv(rng));
                        } else {
                            std::uniform_real_distribution<double> cv(MUTATE_INSERT_CONST_FLOAT_MIN, MUTATE_INSERT_CONST_FLOAT_MAX);
                            right_child->value = cv(rng);
                        }
                        if (std::fabs(right_child->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) right_child->value = 0.0;
                        new_op_node->right = right_child;
                    } else {
                        auto var_node = std::make_shared<Node>(NodeType::Variable);
                        std::uniform_int_distribution<int> var_dist(0, NUM_VARIABLES - 1);
                        var_node->var_index = var_dist(rng);
                        new_op_node->right = var_node;
                    }
                } else {
                    new_op_node->right = nullptr;
                }

                *node_to_mutate_ptr = new_op_node;
            }
            break;

        case MutationType::NodeDeletion:
            {
                // Never delete the root when it is the only node.
                if (node_to_mutate_ptr == &new_tree && nodes.size() == 1) return new_tree;

                if (current_node.type == NodeType::Operator) {
                    // Replace the operator with one of its children (random when both exist).
                    NodePtr replacement = nullptr;
                    const bool has_left = (current_node.left != nullptr);
                    const bool has_right = (current_node.right != nullptr);

                    if (has_left && has_right) {
                        replacement = (prob(rng) < 0.5) ? current_node.left : current_node.right;
                    } else if (has_left) {
                        replacement = current_node.left;
                    } else if (has_right) {
                        replacement = current_node.right;
                    }
                    // An operator without children is unexpected: use a terminal instead.
                    if (!replacement) replacement = generate_replacement(0);

                    // `replacement` keeps the child alive while the operator node is released.
                    *node_to_mutate_ptr = replacement;
                } else {
                    // Terminal: swap it for another random terminal (unless it is a lone root).
                    if (node_to_mutate_ptr != &new_tree || nodes.size() > 1) {
                        *node_to_mutate_ptr = generate_replacement(0);
                    }
                }
            }
            break;

        default: // Unexpected mutation kind: replace the node to stay safe
            *node_to_mutate_ptr = generate_replacement(max_depth);
            break;
    }
    if (USE_HARD_DEPTH_LIMIT) trim_tree(new_tree, MAX_TREE_DEPTH_HARD_LIMIT); // Enforce hard limit after mutation
    return new_tree;
}

// Crossover between two individuals.
// Both parents are cloned, a random subtree is exchanged between the clones and
// the first clone is returned as the child; the second clone is discarded.
// When USE_HARD_DEPTH_LIMIT is set the child is trimmed to
// MAX_TREE_DEPTH_HARD_LIMIT before being wrapped in a new Individual.
Individual crossover(const Individual& parent1, const Individual& parent2) {
    NodePtr offspring = clone_tree(parent1.tree);
    NodePtr donor = clone_tree(parent2.tree);

    crossover_trees(offspring, donor);

    if (USE_HARD_DEPTH_LIMIT) {
        trim_tree(offspring, MAX_TREE_DEPTH_HARD_LIMIT); // Enforce hard limit
    }

    Individual child(offspring);
    return child;
}

// Semantic crossover: tries to produce a child whose outputs differ from both
// parents on a small random sample of the data.
//
// Parameters:
//   p1, p2   - parents.
//   x_values - data rows; up to 10 of them are sampled (with replacement).
//   attempts - maximum number of candidate children to generate.
// Returns: the first valid child that differs from every parent by more than
//          1e-4 (summed absolute difference over the sample); otherwise the
//          most different valid child seen; otherwise a plain crossover.
//          With empty x_values this is a plain crossover.
//
// A parent whose own outputs contain NaN/Inf has no usable semantics, so any
// valid child counts as different from it. Previously that parent's difference
// stayed at 0, which made the early return impossible and always kept the
// first valid child.
Individual semantic_crossover(const Individual& p1, const Individual& p2, const std::vector<std::vector<double>>& x_values, int attempts) {
    if (x_values.empty()) return crossover(p1, p2); // Fallback

    auto& rng = get_rng();

    // 1. Select the semantic sample (subset of the data)
    const int sample_size = std::min(static_cast<int>(x_values.size()), 10);
    std::vector<int> indices(sample_size);
    std::uniform_int_distribution<int> idx_dist(0, static_cast<int>(x_values.size()) - 1);
    for (int i = 0; i < sample_size; ++i) indices[i] = idx_dist(rng);

    // 2. Compute the parents' semantics on the sample
    std::vector<double> sem_p1(sample_size);
    std::vector<double> sem_p2(sample_size);
    bool p1_valid = true, p2_valid = true;

    for (int i = 0; i < sample_size; ++i) {
        sem_p1[i] = evaluate_tree(p1.tree, x_values[indices[i]]);
        sem_p2[i] = evaluate_tree(p2.tree, x_values[indices[i]]);
        if (std::isnan(sem_p1[i]) || std::isinf(sem_p1[i])) p1_valid = false;
        if (std::isnan(sem_p2[i]) || std::isinf(sem_p2[i])) p2_valid = false;
    }

    const double sensitivity = 1e-4;
    Individual best_child;
    double max_diversity = -1.0;

    for (int k = 0; k < attempts; ++k) {
        Individual child = crossover(p1, p2);

        // Compare the child's semantics with each valid parent.
        bool c_valid = true;
        double diff_p1 = 0.0;
        double diff_p2 = 0.0;

        for (int i = 0; i < sample_size; ++i) {
            const double out = evaluate_tree(child.tree, x_values[indices[i]]);
            if (std::isnan(out) || std::isinf(out)) {
                c_valid = false;
                break;
            }
            if (p1_valid) diff_p1 += std::fabs(out - sem_p1[i]);
            if (p2_valid) diff_p2 += std::fabs(out - sem_p2[i]);
        }

        if (!c_valid) continue; // Skip bad children

        const bool differs_from_p1 = !p1_valid || diff_p1 > sensitivity;
        const bool differs_from_p2 = !p2_valid || diff_p2 > sensitivity;
        if (differs_from_p1 && differs_from_p2) {
            return child; // Found a semantically unique child!
        }

        // Keep the "most different" valid child as fallback. At least one parent
        // is valid here, and only valid parents take part in the comparison.
        double diversity;
        if (p1_valid && p2_valid) diversity = std::min(diff_p1, diff_p2);
        else if (p1_valid)        diversity = diff_p1;
        else                      diversity = diff_p2;

        if (diversity > max_diversity) {
            max_diversity = diversity;
            best_child = std::move(child);
        }
    }

    // Return the best child found, or a plain crossover if every attempt was invalid
    if (best_child.tree) return best_child;
    return crossover(p1, p2);
}

// Mutates an individual in place.
// The tree is replaced by the result of mutate_tree (using
// MAX_TREE_DEPTH_MUTATION as the replacement depth) and every value cached for
// the old tree is invalidated: the fitness flag and the per-case error cache.
// Without clearing `errors`, ensure_errors_computed would treat the previous
// tree's residuals as already computed for the new tree.
void mutate(Individual& individual, double mutation_rate) {
    individual.tree = mutate_tree(individual.tree, mutation_rate, MAX_TREE_DEPTH_MUTATION);
    individual.fitness_valid = false; // Fitness no longer matches the tree
    individual.errors.clear();        // Per-case errors no longer match the tree either
}

// Subtree crossover.
// Picks one node slot uniformly at random in each tree and exchanges the
// subtrees rooted there, modifying both trees in place. Nothing happens when
// either tree is null or exposes no node slots.
void crossover_trees(NodePtr& tree1, NodePtr& tree2) {
    if (!(tree1 && tree2)) return;

    std::vector<NodePtr*> slots_a;
    std::vector<NodePtr*> slots_b;
    collect_node_ptrs(tree1, slots_a);
    collect_node_ptrs(tree2, slots_b);

    const bool nothing_to_swap = slots_a.empty() || slots_b.empty();
    if (nothing_to_swap) return;

    auto& rng = get_rng();
    std::uniform_int_distribution<int> pick_a(0, slots_a.size()-1);
    std::uniform_int_distribution<int> pick_b(0, slots_b.size()-1);

    // Draw the crossover point of tree1 first, then the one of tree2.
    NodePtr* cut_a = slots_a[pick_a(rng)];
    NodePtr* cut_b = slots_b[pick_b(rng)];

    // Exchange the subtrees by swapping the owning pointers.
    std::swap(*cut_a, *cut_b);
}

// Simplifies a tree in place.
// When USE_SIMPLIFICATION is off the tree is left untouched; otherwise it is
// replaced by the result of DomainConstraints::fix_or_simplify.
void simplify_tree(NodePtr& tree) {
    if (!USE_SIMPLIFICATION) return;
    tree = DomainConstraints::fix_or_simplify(tree);
}


In [None]:
%%writefile Code/src/GeneticOperators.h
// ============================================================
// Archivo: src/GeneticOperators.h
// ============================================================
#ifndef GENETICOPERATORS_H
#define GENETICOPERATORS_H

#include "ExpressionTree.h"
#include "Globals.h" // Incluir Globals.h para INF
#include <vector>
#include <memory> // Para std::move

// One member of the population: an expression tree plus the values cached for it.
struct Individual {
    NodePtr tree;                // Owning pointer to the expression tree
    double fitness = INF;        // Cached fitness (lower is better); starts at INF
    std::vector<double> errors;  // Cache of per-case errors for Lexicase Selection
    bool fitness_valid = false;  // True only while `fitness` matches `tree`

    // Default constructor: no tree, fitness not valid.
    Individual() = default;
    // Builds an individual that takes over the given tree pointer.
    explicit Individual(NodePtr t) : tree(std::move(t)) {}

    // Ordering used for sorting: individuals with a valid fitness come first,
    // ordered by ascending fitness. An individual with an invalid fitness is
    // never "less than" anything, so two invalid individuals compare as equal.
    bool operator<(const Individual& other) const {
        return fitness_valid && (!other.fitness_valid || fitness < other.fitness);
    }
};


// --- Genetic operator functions ---

// Generates a random expression tree up to a maximum depth.
NodePtr generate_random_tree(int max_depth, int current_depth = 0);

// Creates the initial population of individuals.
std::vector<Individual> create_initial_population(int population_size);

// Selects an individual by tournament selection with parsimony pressure.
Individual tournament_selection(const std::vector<Individual>& population, int tournament_size);

// Selects an individual by Epsilon-Lexicase Selection.
// Takes the population by non-const reference because per-case errors are
// cached on the individuals while selecting.
Individual lexicase_selection(std::vector<Individual>& population, const std::vector<double>& targets, const std::vector<std::vector<double>>& x_values);

// Performs crossover between two individuals and returns a new individual.
Individual crossover(const Individual& parent1, const Individual& parent2);

// Semantic crossover: tries to produce a child that is functionally different from its parents.
// `attempts` is the maximum number of candidate children generated.
Individual semantic_crossover(const Individual& p1, const Individual& p2, const std::vector<std::vector<double>>& x_values, int attempts = 5);

// Mutates an individual in place.
void mutate(Individual& individual, double mutation_rate);

// Simplifies a tree in place.
void simplify_tree(NodePtr& tree);

// Available mutation kinds.
enum class MutationType {
    ConstantChange,
    OperatorChange,
    SubtreeReplace,
    NodeInsertion,
    NodeDeletion // Added: removes a node
    // Simplification (handled by DomainConstraints)
};

// Mutates a tree by applying one of the mutation kinds with a given probability.
// Returns a new tree (a clone, possibly mutated); the input is not modified.
NodePtr mutate_tree(const NodePtr& tree, double mutation_rate, int max_depth);

// Performs crossover between two parent trees, modifying both in place.
void crossover_trees(NodePtr& tree1, NodePtr& tree2);


#endif // GENETICOPERATORS_H


In [None]:
%%writefile Code/src/main.cpp
#include "Globals.h" // Necesario para las constantes globales
#include "GeneticAlgorithm.h"
#include "Fitness.h" // Para evaluate_fitness
#include "ExpressionTree.h" // Para tree_to_string si se necesita aquí
#include <iostream>
#include <vector>
#include <memory> // Para shared_ptr
#include <iomanip> // Para std::setprecision
#include <omp.h>   // Para configuración de OpenMP

#include <fstream> // Para leer archivo
#include <string>
#include <sstream>

// Definition of the NUM_VARIABLES global (presumably declared `extern` in
// Globals.h - TODO confirm). Starts at 1; main() overwrites it once the
// training data has been loaded.
int NUM_VARIABLES = 1;

// Program entry point.
//
// Command line:
//   --seed / -s <file>  optional file with one seed formula per line.
//   --data / -d <file>  optional data file: one line per input variable followed
//                       by a final line with the targets; values are separated by
//                       spaces and/or commas. Lines without any numeric value are
//                       ignored.
// Without --data the built-in RAW_TARGETS / X_VALUES from Globals.h are used
// (log-transformed when USE_LOG_TRANSFORMATION is set).
//
// Exit codes: 0 success, 1 data error or no valid solution found,
//             2 std::exception caught, 3 unknown exception caught.
int main(int argc, char* argv[]) {
    // === OpenMP: explicit thread configuration ===
    int num_threads = omp_get_max_threads();
    omp_set_num_threads(num_threads);
    std::cout << "[OpenMP] Using " << num_threads << " threads" << std::endl;

    // Fixed 6-digit float output, flushed after every write
    // (important when stdout is captured by a parent process).
    std::cout << std::unitbuf << std::fixed << std::setprecision(6);

    std::vector<std::string> seed_formulas;
    std::string seed_file_path;
    std::string data_file_path;

    // Parse arguments; an option given without a value is ignored.
    for (int i = 1; i < argc; ++i) {
        const std::string arg = argv[i];
        const bool has_value = (i + 1 < argc);
        if ((arg == "--seed" || arg == "-s") && has_value) {
            seed_file_path = argv[++i];
        } else if ((arg == "--data" || arg == "-d") && has_value) {
            data_file_path = argv[++i];
        }
    }

    // Removes the trailing '\r' left by std::getline on files with CRLF line endings.
    auto strip_cr = [](std::string& text) {
        while (!text.empty() && text.back() == '\r') text.pop_back();
    };

    if (!seed_file_path.empty()) {
        std::cout << "Loading seeds from: " << seed_file_path << std::endl;
        std::ifstream file(seed_file_path);
        if (file.is_open()) {
            std::string line;
            while (std::getline(file, line)) {
                strip_cr(line);
                if (!line.empty()) {
                    seed_formulas.push_back(line);
                }
            }
            file.close();
            std::cout << "Loaded " << seed_formulas.size() << " formulas." << std::endl;
        } else {
            // A missing seed file is not fatal: the run continues without seeds.
            std::cerr << "[Error] Could not open seed file: " << seed_file_path << std::endl;
        }
    }

    std::vector<double> targets;
    // One row per sample, one column per variable.
    std::vector<std::vector<double>> final_x_values;

    if (!data_file_path.empty()) {
        std::cout << "Loading data from: " << data_file_path << std::endl;
        std::ifstream dfile(data_file_path);
        if (!dfile.is_open()) {
            std::cerr << "[Error] Could not open data file: " << data_file_path << std::endl;
            return 1;
        }

        // Parses every number on a line. Commas are turned into blanks first so
        // that any mix of ",", ", " and " , " separators is accepted.
        auto parse_line = [](std::string text) {
            for (char& ch : text) {
                if (ch == ',') ch = ' ';
            }
            std::vector<double> vals;
            std::stringstream ss(text);
            double val;
            while (ss >> val) vals.push_back(val);
            return vals;
        };

        std::string line;
        std::vector<std::vector<double>> all_lines;
        while (std::getline(dfile, line)) {
            const std::vector<double> row = parse_line(line);
            // Skip blank / whitespace-only / non-numeric lines so that a stray
            // trailing line cannot be mistaken for the target row.
            if (!row.empty()) all_lines.push_back(row);
        }
        dfile.close();

        if (all_lines.size() < 2) {
            std::cerr << "[Error] Insufficient data in file (Need at least 1 feature line and 1 target line)." << std::endl;
            return 1;
        }

        targets = all_lines.back();
        all_lines.pop_back(); // Now all_lines contains only features as rows

        const size_t n_samples = targets.size();
        const size_t n_vars = all_lines.size();
        NUM_VARIABLES = static_cast<int>(n_vars);

        // Missing values are padded with 0.0 below; make that visible.
        for (size_t v = 0; v < n_vars; ++v) {
            if (all_lines[v].size() != n_samples) {
                std::cerr << "[Warning] Feature line " << (v + 1) << " has " << all_lines[v].size()
                          << " values but there are " << n_samples
                          << " targets; missing values are treated as 0." << std::endl;
            }
        }

        // Transpose: from [n_vars][n_samples] to [n_samples][n_vars]
        final_x_values.clear();
        final_x_values.reserve(n_samples);
        for (size_t s = 0; s < n_samples; ++s) {
            std::vector<double> sample_vars;
            sample_vars.reserve(n_vars);
            for (size_t v = 0; v < n_vars; ++v) {
                sample_vars.push_back(s < all_lines[v].size() ? all_lines[v][s] : 0.0);
            }
            final_x_values.push_back(sample_vars);
        }

        std::cout << "Loaded " << final_x_values.size() << " data points with " << NUM_VARIABLES << " variables from file." << std::endl;
    } else {
        // Fallback to the data compiled in through Globals.h.
        // Every target needs an input row; otherwise the loops below would read
        // past the end of X_VALUES.
        if (X_VALUES.size() < RAW_TARGETS.size()) {
            std::cerr << "[Error] Built-in data mismatch: " << RAW_TARGETS.size() << " targets but only "
                      << X_VALUES.size() << " input rows." << std::endl;
            return 1;
        }

        if (USE_LOG_TRANSFORMATION) {
            std::cout << "Info: Log Transformation is ON (Target = ln(Q(N)))." << std::endl;
            // Only strictly positive targets have a logarithm; the others are dropped
            // together with their input row.
            for (size_t i = 0; i < RAW_TARGETS.size(); ++i) {
                if (RAW_TARGETS[i] > 0) {
                    targets.push_back(std::log(RAW_TARGETS[i]));
                    final_x_values.push_back(X_VALUES[i]);
                }
            }
        } else {
            std::cout << "Info: Log Transformation is OFF." << std::endl;
            targets = RAW_TARGETS;
            final_x_values = X_VALUES;
        }

        // Update NUM_VARIABLES based on data (1 when there is no data)
        NUM_VARIABLES = final_x_values.empty() ? 1 : static_cast<int>(final_x_values[0].size());
    }

    std::cout << "Target Function Points (Effective):" << std::endl;
    std::cout << "NUM_VARIABLES set to: " << NUM_VARIABLES << std::endl;
    // Print the target points
    for (size_t i = 0; i < targets.size() && i < final_x_values.size(); ++i) {
        const std::vector<double>& row = final_x_values[i];
        std::cout << "  f(";
        for (size_t v = 0; v < row.size(); ++v) {
            if (v > 0) std::cout << ", ";
            std::cout << row[v];
        }
        std::cout << ") = " << targets[i] << std::endl;
    }
    std::cout << "----------------------------------------" << std::endl;
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    std::cout << "Info: Running with GPU acceleration." << std::endl;
#else
    std::cout << "Info: Running with CPU acceleration." << std::endl;
#endif
    std::cout << "----------------------------------------" << std::endl;
    std::cout << "Parameters:" << std::endl;
    // Print the global parameters defined in Globals.h
    std::cout << "  Total Population: " << TOTAL_POPULATION_SIZE << std::endl;
    std::cout << "  Generations: " << GENERATIONS << std::endl;
    std::cout << "  Islands: " << NUM_ISLANDS << std::endl;
    std::cout << "  Migration Interval: " << MIGRATION_INTERVAL << std::endl;
    std::cout << "  Migration Size: " << MIGRATION_SIZE << std::endl;
    std::cout << "  Mutation Rate (Initial): " << BASE_MUTATION_RATE << std::endl;
    std::cout << "  Elite Percentage (Initial): " << BASE_ELITE_PERCENTAGE << std::endl;
    std::cout << "----------------------------------------" << std::endl;

    try {
        // Build the genetic algorithm from the data and the main parameters.
        GeneticAlgorithm ga(targets, final_x_values, TOTAL_POPULATION_SIZE, GENERATIONS, seed_formulas);

        // run() is expected to drive the generation loop and return the best tree found.
        NodePtr best_solution_tree = ga.run();

        // A null tree means that no valid solution was found.
        if (!best_solution_tree) {
            std::cerr << "\nFailed to find any valid solution." << std::endl;
            return 1;
        }
    } catch (const std::exception& e) {
        std::cerr << "[CRITICAL ERROR] Exception caught in main: " << e.what() << std::endl;
        return 2;
    } catch (...) {
        std::cerr << "[CRITICAL ERROR] Unknown exception caught in main." << std::endl;
        return 3;
    }

    return 0; // Success
}


In [None]:
%%writefile Code/src/TestOperators.cpp

#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>
#include <string>
#include <cassert>
#include "ExpressionTree.h"
#include "AdvancedFeatures.h"
#include "Fitness.h"
#include "Globals.h"

// Define global variable for tests
int NUM_VARIABLES = 1;

// --- Helper Macros for Testing ---
// Checks that val1 is within `tol` of val2. On failure, logs the source line,
// both values and their difference to std::cerr and makes the enclosing
// function `return false`, so it can only be used inside functions returning
// bool. The comparison is written as !(diff <= tol) so that a NaN difference
// counts as a failure; with `diff > tol` a NaN result silently passed.
// Arguments may be evaluated more than once: pass side-effect-free expressions.
#define ASSERT_NEAR(val1, val2, tol) \
    if (!(std::fabs((val1) - (val2)) <= (tol))) { \
        std::cerr << "[FAIL] Line " << __LINE__ << ": Expected " << (val2) << ", got " << (val1) << " (diff: " << std::fabs((val1)-(val2)) << ")" << std::endl; \
        return false; \
    }

// Checks that `cond` is true. On failure, logs the source line and the text of
// the condition to std::cerr and makes the enclosing function `return false`,
// so it can only be used inside functions returning bool.
#define ASSERT_TRUE(cond) \
    if (!(cond)) { \
        std::cerr << "[FAIL] Line " << __LINE__ << ": Condition failed: " #cond << std::endl; \
        return false; \
    }

// Checks that `val` is "infinite": it passes when val equals the project's INF
// constant or when std::isinf(val) is true. Otherwise it logs the source line
// and the value to std::cerr and makes the enclosing function `return false`,
// so it can only be used inside functions returning bool.
// `val` may be evaluated more than once: pass a side-effect-free expression.
#define ASSERT_INF(val) \
    if ((val) != INF && !std::isinf(val)) { \
        std::cerr << "[FAIL] Line " << __LINE__ << ": Expected INF, got " << (val) << std::endl; \
        return false; \
    }

// =============================
// BINARY OPERATORS
// =============================
bool test_binary_operators() {
    std::cout << "Testing Binary Operators..." << std::endl;
    
    NodePtr root; double val;

    // Addition
    root = parse_formula_string("2+3");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 5.0, 1e-9);

    // Subtraction
    root = parse_formula_string("10-4");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 6.0, 1e-9);

    // Multiplication
    root = parse_formula_string("3*4");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 12.0, 1e-9);

    // Division
    root = parse_formula_string("10/2");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 5.0, 1e-9);

    // Power
    root = parse_formula_string("2^3");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 8.0, 1e-9);

    // Modulo
    root = parse_formula_string("10%3");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 1.0, 1e-9);

    // Negative numbers
    root = parse_formula_string("-5+3");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, -2.0, 1e-9);

    // Operator precedence: 2+3*4 = 14
    root = parse_formula_string("2+3*4");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 14.0, 1e-9);

    // Parentheses override: (2+3)*4 = 20
    root = parse_formula_string("(2+3)*4");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 20.0, 1e-9);

    std::cout << "  -> Binary Operators Passed" << std::endl;
    return true;
}

// =============================
// UNARY OPERATORS
// =============================
bool test_unary_operators() {
    std::cout << "Testing Unary Operators..." << std::endl;
    NodePtr root; double val;

    // sin
    root = parse_formula_string("sin(0)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 0.0, 1e-9);
    
    root = parse_formula_string("sin(1.5708)"); // pi/2
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 1.0, 1e-4);

    // cos
    root = parse_formula_string("cos(0)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 1.0, 1e-9);

    // log (protected: log(|x|))
    root = parse_formula_string("log(2.7182818)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 1.0, 1e-5);

    // exp
    root = parse_formula_string("exp(1)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 2.7182818, 1e-5);

    // floor
    root = parse_formula_string("floor(2.9)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 2.0, 1e-9);
    
    root = parse_formula_string("floor(-2.1)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, -3.0, 1e-9);

    // lgamma: Implementation is lgamma(|x|+1) => ln(|x|!)
    // lgamma(3) -> lgamma(4) = ln(3!) = ln(6) = 1.791759
    root = parse_formula_string("lgamma(3)"); 
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 1.791759, 1e-4);

    // g(x) alias
    root = parse_formula_string("g(3)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 1.791759, 1e-4);

    // Factorial (!): Implementation is tgamma(|x|+1) = |x|!
    // gamma(4) = 3! = 6
    root = parse_formula_string("gamma(3)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, 6.0, 1e-4);

    std::cout << "  -> Unary Operators Passed" << std::endl;
    return true;
}

// =============================
// EDGE CASES (Protection)
// =============================
bool test_edge_cases() {
    std::cout << "Testing Edge Cases..." << std::endl;
    NodePtr root; double val;

    // Division by zero -> INF
    root = parse_formula_string("1/0");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_INF(val);

    // Modulo by zero -> INF
    root = parse_formula_string("5%0");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_INF(val);

    // log(0) -> INF (protected)
    root = parse_formula_string("log(0)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_INF(val);

    // exp(800) -> INF (overflow)
    root = parse_formula_string("exp(800)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_INF(val);

    // 0^(-1) -> INF
    root = parse_formula_string("0^(-1)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_INF(val);

    // Factorial of large number -> INF
    root = parse_formula_string("gamma(200)");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_INF(val);

    // Negative base with non-integer exp -> INF (complex result)
    root = parse_formula_string("(-2)^0.5");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_INF(val);

    std::cout << "  -> Edge Cases Passed" << std::endl;
    return true;
}

// =============================
// SIMPLIFICATION RULES
// =============================
// Checks the algebraic rewrite rules applied by
// DomainConstraints::fix_or_simplify: identities (x-x, x/x, x+0, x*1, x*0,
// x^1, x^0), constant folding for binary and unary operators, and that a
// unary operator applied to a variable is left in place.
// Returns true when every check passes; the ASSERT_* macros return false early.
bool test_simplification() {
    std::cout << "Testing Simplification..." << std::endl;
    NodePtr root, simplified; double val; std::string str;

    // x - x -> 0
    root = parse_formula_string("x-x");
    simplified = DomainConstraints::fix_or_simplify(root);
    val = evaluate_tree(simplified, std::vector<double>{5.0});
    ASSERT_NEAR(val, 0.0, 1e-9);

    // x / x -> 1
    root = parse_formula_string("x/x");
    simplified = DomainConstraints::fix_or_simplify(root);
    val = evaluate_tree(simplified, std::vector<double>{5.0});
    ASSERT_NEAR(val, 1.0, 1e-9);

    // Constant Folding: 2+3 -> a single constant node with value 5.
    // The input is passed as a one-element vector like every other call in
    // this file (the original passed a bare double here).
    root = parse_formula_string("2+3");
    simplified = DomainConstraints::fix_or_simplify(root);
    val = evaluate_tree(simplified, std::vector<double>{0.0});
    ASSERT_NEAR(val, 5.0, 1e-9);
    ASSERT_TRUE(tree_size(simplified) == 1);

    // x + 0 -> x (the variable is expected to print as "x0")
    root = parse_formula_string("x+0");
    simplified = DomainConstraints::fix_or_simplify(root);
    str = tree_to_string(simplified);
    ASSERT_TRUE(str == "x0");

    // x * 1 -> x
    root = parse_formula_string("x*1");
    simplified = DomainConstraints::fix_or_simplify(root);
    str = tree_to_string(simplified);
    ASSERT_TRUE(str == "x0");

    // x * 0 -> 0
    root = parse_formula_string("x*0");
    simplified = DomainConstraints::fix_or_simplify(root);
    val = evaluate_tree(simplified, std::vector<double>{100.0});
    ASSERT_NEAR(val, 0.0, 1e-9);

    // x^1 -> x
    root = parse_formula_string("x^1");
    simplified = DomainConstraints::fix_or_simplify(root);
    str = tree_to_string(simplified);
    ASSERT_TRUE(str == "x0");

    // x^0 -> 1
    root = parse_formula_string("x^0");
    simplified = DomainConstraints::fix_or_simplify(root);
    val = evaluate_tree(simplified, std::vector<double>{100.0});
    ASSERT_NEAR(val, 1.0, 1e-9);

    // Unary Operator Preservation: lgamma(x) should NOT simplify to x.
    // Either spelling of the operator ("lgamma" or the alias "g(") is accepted.
    root = parse_formula_string("lgamma(x)");
    simplified = DomainConstraints::fix_or_simplify(root);
    str = tree_to_string(simplified);
    ASSERT_TRUE(str.find("lgamma") != std::string::npos || str.find("g(") != std::string::npos);

    // Unary Constant Folding: lgamma(3) -> single constant node, ln(6) = 1.791759
    root = parse_formula_string("lgamma(3)");
    simplified = DomainConstraints::fix_or_simplify(root);
    val = evaluate_tree(simplified, std::vector<double>{0.0});
    ASSERT_NEAR(val, 1.791759, 1e-4);
    ASSERT_TRUE(tree_size(simplified) == 1);

    // sin(0) -> 0 (constant folding down to a single node)
    root = parse_formula_string("sin(0)");
    simplified = DomainConstraints::fix_or_simplify(root);
    val = evaluate_tree(simplified, std::vector<double>{0.0});
    ASSERT_NEAR(val, 0.0, 1e-9);
    ASSERT_TRUE(tree_size(simplified) == 1);

    std::cout << "  -> Simplification Passed" << std::endl;
    return true;
}

// =============================
// COMPLEX PARSING
// =============================
bool test_complex_parsing() {
    std::cout << "Testing Complex Parsing..." << std::endl;
    NodePtr root; double val;

    // Nested functions: sin(cos(0)) = sin(1) ≈ 0.8415
    root = parse_formula_string("sin(cos(0))");
    val = evaluate_tree(root, std::vector<double>{0.0});
    ASSERT_NEAR(val, std::sin(1.0), 1e-4);

    // Mixed: lgamma(x+1) at x=3 -> lgamma(4) = lgamma(5) = ln(4!) = ln(24) ≈ 3.178
    root = parse_formula_string("lgamma(x+1)");
    val = evaluate_tree(root, std::vector<double>{3.0});
    ASSERT_NEAR(val, std::lgamma(5.0), 1e-4);

    // Formula from project: (g(x)-((x*909613)/1000000))+0.24423 at x=4
    root = parse_formula_string("(g(x)-((x*909613)/1000000))+0.24423");
    val = evaluate_tree(root, std::vector<double>{4.0});
    double expected = std::lgamma(5.0) - (4.0 * 909613.0 / 1000000.0) + 0.24423;
    ASSERT_NEAR(val, expected, 1e-4);

    // Implicit multiplication: 2x at x=3 -> 6
    root = parse_formula_string("2x");
    val = evaluate_tree(root, std::vector<double>{3.0});
    ASSERT_NEAR(val, 6.0, 1e-9);

    // Deep nesting: exp(log(x)) at x=5 -> 5
    root = parse_formula_string("exp(log(x))");
    val = evaluate_tree(root, std::vector<double>{5.0});
    ASSERT_NEAR(val, 5.0, 1e-4);

    // Chained operations: x^2+2*x+1 at x=3 -> 16
    root = parse_formula_string("x^2+2*x+1");
    val = evaluate_tree(root, std::vector<double>{3.0});
    ASSERT_NEAR(val, 16.0, 1e-9);

    std::cout << "  -> Complex Parsing Passed" << std::endl;
    return true;
}

// =============================
// FITNESS CALCULATION
// =============================
bool test_fitness_calc() {
    std::cout << "Testing Fitness Calculation..." << std::endl;

    std::vector<double> targets = {1.0, 2.0, 3.0};
    std::vector<std::vector<double>> x_values = {{1.0}, {2.0}, {3.0}};

    // Perfect solution: x
    NodePtr solution = parse_formula_string("x0");
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    double fitness = evaluate_fitness(solution, targets, x_values, (double*)nullptr, (double*)nullptr);
#else
    double fitness = evaluate_fitness(solution, targets, x_values);
#endif
    ASSERT_NEAR(fitness, 0.0, 1e-9);

    // Imperfect solution: x+1 (errors: 1, 1, 1)
    NodePtr imperfect = parse_formula_string("x+1");
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    fitness = evaluate_fitness(imperfect, targets, x_values, (double*)nullptr, (double*)nullptr);
#else
    fitness = evaluate_fitness(imperfect, targets, x_values);
#endif
    ASSERT_TRUE(fitness > 0.001);

    // Very bad solution: constant 100
    NodePtr bad = parse_formula_string("100");
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    double bad_fitness = evaluate_fitness(bad, targets, x_values, (double*)nullptr, (double*)nullptr);
#else
    double bad_fitness = evaluate_fitness(bad, targets, x_values);
#endif
    ASSERT_TRUE(bad_fitness > fitness); // Worse than x+1

    std::cout << "  -> Fitness Calculation Passed" << std::endl;
    return true;
}

// =============================
// TREE UTILITIES
// =============================
// Checks the tree helper functions: tree_size (node count), tree_to_string
// (printed form mentions every leaf) and clone_tree (equal printed form,
// distinct root object).
// Returns true when every check passes; ASSERT_TRUE returns false early.
bool test_tree_utilities() {
    std::cout << "Testing Tree Utilities..." << std::endl;
    
    // tree_size counts operator and leaf nodes alike
    NodePtr root = parse_formula_string("x+1");
    ASSERT_TRUE(tree_size(root) == 3); // +, x, 1

    root = parse_formula_string("lgamma(x)");
    ASSERT_TRUE(tree_size(root) == 2); // lgamma, x

    // tree_to_string: the printed form must contain every leaf of the input.
    // (This is a containment check only; the string is not parsed back.)
    root = parse_formula_string("(x+1)*2");
    std::string str = tree_to_string(root);
    ASSERT_TRUE(str.find("x") != std::string::npos);
    ASSERT_TRUE(str.find("1") != std::string::npos);
    ASSERT_TRUE(str.find("2") != std::string::npos);

    // clone_tree: same printed form, but the root must be a different object
    NodePtr cloned = clone_tree(root);
    ASSERT_TRUE(tree_to_string(cloned) == tree_to_string(root));
    ASSERT_TRUE(cloned.get() != root.get()); // Different pointers

    std::cout << "  -> Tree Utilities Passed" << std::endl;
    return true;
}

// =============================
// MAIN
// =============================
// Runs every test suite in order and reports the overall result.
// Returns 0 when all suites pass and 1 otherwise.
int main() {
    std::cout << "=======================================" << std::endl;
    std::cout << " Running Comprehensive Operator Tests " << std::endl;
    std::cout << "=======================================" << std::endl;

    // Single list of suites: the count printed below is derived from it, so
    // adding or removing a suite cannot leave a stale hard-coded number behind.
    using SuiteFn = bool (*)();
    const SuiteFn suites[] = {
        test_binary_operators,
        test_unary_operators,
        test_edge_cases,
        test_simplification,
        test_complex_parsing,
        test_fitness_calc,
        test_tree_utilities,
    };

    int suites_run = 0;
    bool all_passed = true;
    for (SuiteFn suite : suites) {
        // Call the suite unconditionally so a failure does not hide later ones.
        const bool suite_passed = suite();
        all_passed = all_passed && suite_passed;
        ++suites_run;
    }

    std::cout << "=======================================" << std::endl;
    if (all_passed) {
        std::cout << "ALL TESTS PASSED (" << suites_run << " test suites)" << std::endl;
        return 0;
    } else {
        std::cerr << "SOME TESTS FAILED" << std::endl;
        return 1;
    }
}


In [None]:
%%writefile Code/src/Globals.h
#ifndef GLOBALS_H
#define GLOBALS_H

#include <vector>
#include <random>
#include <string>
#include <limits>
#include <cmath>

// ============================================================
//                  GLOBAL PARAMETERS
// ============================================================

// ----------------------------------------
// Problem Data (Symbolic Regression)
// ----------------------------------------
// RAW_TARGETS holds the raw, untransformed target values. Per the original
// notes these are Q(N) with the entries for N < 4 filtered out, and the
// working TARGETS vector is generated from them at runtime (optionally
// log-transformed, see USE_LOG_TRANSFORMATION below).
const std::vector<double> RAW_TARGETS = {2, 10, 4, 40, 92, 352, 724, 2680, 14200, 73712, 365596, 2279184, 14772512, 95815104, 666090624, 4968057848, 39029188884};

// X_VALUES is a vector<vector<double>> to support multivariable problems.
// Each row below carries three values; presumably one row per sample and one
// column per input variable (x0, x1, x2) - confirm against the evaluator.
// NOTE(review): this table has 26 rows (first column 1..26) while RAW_TARGETS
// has 17 entries; confirm how the runtime loader pairs or replaces them.
const std::vector<std::vector<double>> X_VALUES = {
    {1, 1, 1},   // 1
    {2, 2, 0},   // 2
    {3, 3, 1},   // 3
    {4, 4, 0},   // 4
    {5, 5, 1},   // 5
    {6, 0, 0},   // 6
    {7, 1, 1},   // 7
    {8, 2, 0},   // 8
    {9, 3, 1},   // 9
    {10, 4, 0},  // 10
    {11, 5, 1},  // 11
    {12, 0, 0},  // 12
    {13, 1, 1},  // 13
    {14, 2, 0},  // 14
    {15, 3, 1},  // 15
    {16, 4, 0},  // 16
    {17, 5, 1},  // 17
    {18, 0, 0},  // 18
    {19, 1, 1},  // 19
    {20, 2, 0},  // 20
    {21, 3, 1},  // 21
    {22, 4, 0},  // 22
    {23, 5, 1},  // 23
    {24, 0, 0},  // 24
    {25, 1, 1},  // 25
    {26, 2, 0}   // 26
};

// Declared here, defined once per executable (original note: Globals.cpp or
// main.cpp; the TestOperators binary defines it in TestOperators.cpp).
extern int NUM_VARIABLES;

// Flag that enables the automatic logarithmic transformation of the targets
const bool USE_LOG_TRANSFORMATION = false;

// ----------------------------------------
// General Genetic Algorithm Configuration
// ----------------------------------------
// Controls whether GPU acceleration is used.
// FORCE_CPU_MODE: when true, the CPU path is used even if CUDA is available
// (useful for comparing performance).
const bool FORCE_CPU_MODE = true;  // true = force CPU mode; set to false to allow the GPU path

// USE_GPU_ACCELERATION_DEFINED_BY_CMAKE is defined by CMake when CUDA is available.
// Even then, FORCE_CPU_MODE == true overrides it and the CPU is used.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
const bool USE_GPU_ACCELERATION = !FORCE_CPU_MODE;
#else
const bool USE_GPU_ACCELERATION = false;
#endif
// Original tuning notes: population size and generation count were first
// raised to keep the GPU busy (sized for a 4 GB RTX 3050), then
// OPTIMIZED for Hybrid Search: a smaller population converges faster within
// short timeouts.
const int TOTAL_POPULATION_SIZE = 5000; // Reduced from 50000 for fast convergence
const int GENERATIONS = 50000;           // Reduced (the timeout dominates anyway)
const int NUM_ISLANDS = 5;               // Fewer islands = more focus per island
const int MIN_POP_PER_ISLAND = 10;        

// --- Initial Formula ---
// Optional seed formula in infix form over the variables x0, x1 and x2.
const bool USE_INITIAL_FORMULA = false; // Set to 'true' to inject the formula
const std::string INITIAL_FORMULA_STRING = "log((x1+exp((((((1.28237193+((x0+2.59195138)+8.54688985))*x0)+(log((((x2/-0.99681346)-(x0-8.00219939))/(0.35461932-x2)))+(x0+(88.95319019/((x0+x0)+x0)))))-x1)/((exp(exp(((exp(x2)*(1.39925709/x0))^exp(x0))))+0.76703064)*6.05423753)))))";

// ----------------------------------------
// Island Model Parameters
// ----------------------------------------
// Original tuning note: interval and size were raised so that islands do more
// work in parallel before exchanging individuals, reducing communication
// overhead and maximising GPU processing.
const int MIGRATION_INTERVAL = 100; // Increased to allow more work per island between migrations
const int MIGRATION_SIZE = 50;      // Increased for a more substantial migration

// ----------------------------------------
// Initial Tree Generation Parameters
// ----------------------------------------
const int MAX_TREE_DEPTH_INITIAL = 8; // Reduced for simpler, faster initial formulas
const double TERMINAL_VS_VARIABLE_PROB = 0.75;
// Ranges for randomly generated constants (floating-point and integer variants)
const double CONSTANT_MIN_VALUE = -10.0;
const double CONSTANT_MAX_VALUE = 10.0;
const int CONSTANT_INT_MIN_VALUE = -10;
const int CONSTANT_INT_MAX_VALUE = 10;
const bool USE_HARD_DEPTH_LIMIT = true; // Toggle for hard depth limit
const int MAX_TREE_DEPTH_HARD_LIMIT = 12; // Hard limit to prevent bloat
// Order: +, -, *, /, ^, %, s, c, l, e, !, _, g, S, C, T
// ----------------------------------------
// Genetic Operator Parameters (Operator Configuration)
// ----------------------------------------
// One enable flag per operator; the trailing comment is the operator's token.
const bool USE_OP_PLUS     = true; // +
const bool USE_OP_MINUS    = true; // -
const bool USE_OP_MULT     = true; // *
const bool USE_OP_DIV      = true; // /
const bool USE_OP_POW      = true; // ^
const bool USE_OP_MOD      = false; // % (DISABLED)
const bool USE_OP_SIN      = false; // s (DISABLED)
const bool USE_OP_COS      = false; // c (DISABLED)
const bool USE_OP_LOG      = true; // l
const bool USE_OP_EXP      = true; // e
const bool USE_OP_FACT     = false; // ! (DISABLED - using lgamma instead)
const bool USE_OP_FLOOR    = false; // _ (DISABLED)
const bool USE_OP_GAMMA    = true; // g
const bool USE_OP_ASIN     = false; // S (DISABLED)
const bool USE_OP_ACOS     = false; // C (DISABLED)
const bool USE_OP_ATAN     = false; // T (DISABLED)

// Order: +, -, *, /, ^, %, s, c, l, e, !, _, g, S, C, T (16 entries)
// Each base weight is multiplied by its flag (0 or 1) to enable/disable the operator.
const std::vector<double> OPERATOR_WEIGHTS = {
    0.20 * (USE_OP_PLUS  ? 1.0 : 0.0), // +
    0.20 * (USE_OP_MINUS ? 1.0 : 0.0), // -
    0.20 * (USE_OP_MULT  ? 1.0 : 0.0), // *
    0.15 * (USE_OP_DIV   ? 1.0 : 0.0), // /
    0.10 * (USE_OP_POW   ? 1.0 : 0.0), // ^
    0.02 * (USE_OP_MOD   ? 1.0 : 0.0), // %
    0.10 * (USE_OP_SIN   ? 1.0 : 0.0), // s
    0.10 * (USE_OP_COS   ? 1.0 : 0.0), // c
    0.05 * (USE_OP_LOG   ? 1.0 : 0.0), // l
    0.05 * (USE_OP_EXP   ? 1.0 : 0.0), // e
    0.01 * (USE_OP_FACT  ? 1.0 : 0.0), // !
    0.01 * (USE_OP_FLOOR ? 1.0 : 0.0), // _
    0.01 * (USE_OP_GAMMA ? 1.0 : 0.0), // g
    0.01 * (USE_OP_ASIN  ? 1.0 : 0.0), // S
    0.01 * (USE_OP_ACOS  ? 1.0 : 0.0), // C
    0.01 * (USE_OP_ATAN  ? 1.0 : 0.0)  // T
};

// ----------------------------------------
// Genetic Operator Parameters (Mutation, Crossover, Selection)
// ----------------------------------------
// The first four values are the defaults returned by
// EvolutionParameters::create_default() in AdvancedFeatures.cpp.
const double BASE_MUTATION_RATE = 0.30;
const double BASE_ELITE_PERCENTAGE = 0.15;
const double DEFAULT_CROSSOVER_RATE = 0.85;
const int DEFAULT_TOURNAMENT_SIZE = 30;
const int MAX_TREE_DEPTH_MUTATION = 8; // Slight increase to allow complexity
// Settings for the constant-inserting mutation: probability plus the integer
// and floating-point ranges (usage is not visible from this header).
const double MUTATE_INSERT_CONST_PROB = 0.6;
const int MUTATE_INSERT_CONST_INT_MIN = 1;
const int MUTATE_INSERT_CONST_INT_MAX = 5;
const double MUTATE_INSERT_CONST_FLOAT_MIN = 0.5;
const double MUTATE_INSERT_CONST_FLOAT_MAX = 5.0;

// ----------------------------------------
// Fitness and Evaluation Parameters
// ----------------------------------------
// Original tuning note: the complexity penalty was lowered slightly so that
// more complex (and more GPU-intensive) formulas can be favoured.
// MODIFIED: adjusted to be less aggressive and to allow multivariable formulas.
const double COMPLEXITY_PENALTY_FACTOR = 0.01; // Was 0.05. Reduced to 0.01.
const bool USE_RMSE_FITNESS = true;
const double FITNESS_ORIGINAL_POWER = 1.3;
const double FITNESS_PRECISION_THRESHOLD = 0.001;
const double FITNESS_PRECISION_BONUS = 0.0001;
const double FITNESS_EQUALITY_TOLERANCE = 1e-9;
const double EXACT_SOLUTION_THRESHOLD = 1e-8;

// ----------------------------------------
// Weighted Fitness
// ----------------------------------------
// Enables weighted fitness to penalise errors at high values of N heavily.
// Original note: this eliminates parabolas that fail at N=20 but still give a
// good overall average.
const bool USE_WEIGHTED_FITNESS = false;
// Weight type: "quadratic" uses i*i, "exponential" uses exp(i*WEIGHTED_FITNESS_EXPONENT)
// Exponent for the exponential weight (more aggressive). Use 0.2-0.3 for small datasets.
const double WEIGHTED_FITNESS_EXPONENT = 0.25;

// ----------------------------------------
// Advanced Feature Parameters
// ----------------------------------------
// EvolutionParameters::mutate() in AdvancedFeatures.cpp compares its
// stagnation counter against STAGNATION_LIMIT_ISLAND / 2.
const int STAGNATION_LIMIT_ISLAND = 50;
// Lowered from 5000 to allow faster early termination in Hybrid Search mode.
// If best fitness doesn't improve for N generations, terminate early.
const int GLOBAL_STAGNATION_LIMIT = 100; // Reduced to terminate sooner when there is no improvement
const double STAGNATION_RANDOM_INJECT_PERCENT = 0.1;
const int PARAM_MUTATE_INTERVAL = 50;
const double PATTERN_RECORD_FITNESS_THRESHOLD = 10.0;
const int PATTERN_MEM_MIN_USES = 3;
const int PATTERN_INJECT_INTERVAL = 10;
const double PATTERN_INJECT_PERCENT = 0.05;
const size_t PARETO_MAX_FRONT_SIZE = 50;
const double SIMPLIFY_NEAR_ZERO_TOLERANCE = 1e-9;
const double SIMPLIFY_NEAR_ONE_TOLERANCE = 1e-9;
const int LOCAL_SEARCH_ATTEMPTS = 30;
// Simplification Toggle
const bool USE_SIMPLIFICATION = true;
// Anti-Stagnation: Island Cataclysm (Hard Reset)
const bool USE_ISLAND_CATACLYSM = true;
// Selection Strategy: Epsilon-Lexicase Selection (Replaces Tournament)
const bool USE_LEXICASE_SELECTION = true;

// ----------------------------------------
// Other Parameters
// ----------------------------------------
const int PROGRESS_REPORT_INTERVAL = 100;
// Additional optimisation note from the original author:
// forced integer constants are disabled to allow more flexibility in the
// constants that are generated and mutated, which may lead to better
// solutions and keep the GPU busy with a wider range of values.
const bool FORCE_INTEGER_CONSTANTS = false; // Kept false for more flexibility

// ----------------------------------------
// Duplicate Control
// ----------------------------------------
const bool PREVENT_DUPLICATES = true; // Enables the uniqueness check
const int DUPLICATE_RETRIES = 10;     // Attempts to generate a unique individual before giving up


// ============================================================
//                  GLOBAL UTILITIES
// ============================================================
// Declaration only: returns a reference to a std::mt19937 engine. The
// definition is not in this header.
std::mt19937& get_rng();
// Positive infinity as a named constant; TestOperators.cpp compares
// evaluation results against it for invalid operations.
const double INF = std::numeric_limits<double>::infinity();

#endif // GLOBALS_H


In [None]:
%%writefile Code/CMakeLists.txt
cmake_minimum_required(VERSION 3.18)

# Initially only enable CXX. CUDA is enabled further down if a compiler is found.
project(SymbolicRegressionGP LANGUAGES CXX)

# Fix for "PTX was compiled with an unsupported toolchain" on Colab T4:
# generate SASS (binary) for the actual GPU instead of relying on PTX JIT.
# The special value `native` is only understood by CMake >= 3.24, while this
# file allows 3.18; older versions fall back to the T4's compute capability
# (7.5). A value supplied with -DCMAKE_CUDA_ARCHITECTURES=... is respected.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
    if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.24)
        set(CMAKE_CUDA_ARCHITECTURES native)
    else()
        set(CMAKE_CUDA_ARCHITECTURES 75)
    endif()
endif()

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Find OpenMP (Required for CPU parallelism)
find_package(OpenMP REQUIRED COMPONENTS CXX)

# --- CUDA Detection ---
include(CheckLanguage)
check_language(CUDA)

set(SOURCES
    src/main.cpp
    src/ExpressionTree.cpp
    src/Fitness.cpp
    src/GeneticOperators.cpp
    src/AdvancedFeatures.cpp
    src/GeneticAlgorithm.cpp
    src/GradientOptimizer.cpp
)

if(CMAKE_CUDA_COMPILER)
    message(STATUS "CUDA compiler found: ${CMAKE_CUDA_COMPILER}")
    enable_language(CUDA)

    # Suppress the warning about FindCUDA being deprecated (CMP0146).
    # The policy only exists in CMake >= 3.27; setting an unknown policy is a
    # hard error, so guard it for the older versions this file still allows.
    if(POLICY CMP0146)
        cmake_policy(SET CMP0146 OLD)
    endif()
    find_package(CUDA REQUIRED)

    # Add the .cu files to the sources only if CUDA is present
    list(APPEND SOURCES src/FitnessGPU.cu src/GradientOptimizerGPU.cu)

    # Create Executable
    add_executable(SymbolicRegressionGP ${SOURCES})

    # Define the macro that enables the GPU code paths in C++
    target_compile_definitions(SymbolicRegressionGP PUBLIC "USE_GPU_ACCELERATION_DEFINED_BY_CMAKE")

    # Include CUDA dirs
    target_include_directories(SymbolicRegressionGP PUBLIC ${CUDA_INCLUDE_DIRS})

    # Link CUDA libraries
    # Try the modern target first, fall back to the legacy variable
    if(TARGET CUDA::cudart)
        target_link_libraries(SymbolicRegressionGP PUBLIC CUDA::cudart)
    else()
        target_link_libraries(SymbolicRegressionGP PUBLIC ${CUDA_LIBRARIES})
    endif()

    set(CMAKE_CUDA_PROPAGATE_HOST_FLAGS ON)

    message(STATUS "Build type: GPU Accelerated")
else()
    message(STATUS "CUDA compiler NOT found. Building for CPU only.")

    # Create Executable (without the .cu files)
    add_executable(SymbolicRegressionGP ${SOURCES})

    message(STATUS "Build type: CPU Only")
endif()

# Common settings
target_include_directories(SymbolicRegressionGP PUBLIC src)

# Link the required libraries.
# The legacy ${CUDA_LIBRARIES} variable is used as a compatibility method
# because the modern CUDA::cudart target was not being found on this system.
# In a CPU-only build the variable is empty and contributes nothing.
target_link_libraries(SymbolicRegressionGP PUBLIC
    # Link against OpenMP for C++
    OpenMP::OpenMP_CXX

    # Link against the CUDA libraries (compatibility method)
    ${CUDA_LIBRARIES}
)

# Intended to pass the host C++ compiler flags (such as /std:c++17) on to nvcc.
set(CMAKE_CUDA_PROPAGATE_HOST_FLAGS ON)


# === OPTIMISATION: aggressive compile flags for GCC/Clang (Colab) ===
# NOTE: for MSVC+CUDA, CMake is left to handle optimisation automatically
#       because manual flags conflict with nvcc.
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
    target_compile_options(SymbolicRegressionGP PRIVATE
        $<$<COMPILE_LANGUAGE:CXX>:-O3>              # Maximum optimisation
        $<$<COMPILE_LANGUAGE:CXX>:-march=native>    # Optimise for the current CPU
        $<$<COMPILE_LANGUAGE:CXX>:-ffast-math>      # Fast math
        # -ffast-math implies -ffinite-math-only, which lets the compiler assume
        # that NaN/Inf never occur. The engine uses INF as a sentinel and relies
        # on isinf/isnan checks, so that single assumption is switched back off.
        $<$<COMPILE_LANGUAGE:CXX>:-fno-finite-math-only>
        $<$<COMPILE_LANGUAGE:CXX>:-funroll-loops>   # Unroll loops
        $<$<COMPILE_LANGUAGE:CXX>:-ftree-vectorize> # Force vectorisation
    )
    # Link-time optimization (LTO) for Release
    if(CMAKE_BUILD_TYPE STREQUAL "Release")
        set_property(TARGET SymbolicRegressionGP PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
    endif()
endif()
# --- Test Suite Executable (Only build if TestOperators.cpp exists) ---
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/src/TestOperators.cpp")
    add_executable(TestOperators 
        src/TestOperators.cpp
        src/ExpressionTree.cpp
        src/Fitness.cpp
        src/GeneticOperators.cpp
        src/AdvancedFeatures.cpp
        src/GeneticAlgorithm.cpp
        src/GradientOptimizer.cpp
    )

    if(CMAKE_CUDA_COMPILER)
        target_compile_definitions(TestOperators PUBLIC "USE_GPU_ACCELERATION_DEFINED_BY_CMAKE")
        target_sources(TestOperators PRIVATE src/FitnessGPU.cu src/GradientOptimizerGPU.cu)
        target_include_directories(TestOperators PUBLIC ${CUDA_INCLUDE_DIRS})
        if(TARGET CUDA::cudart)
            target_link_libraries(TestOperators PUBLIC CUDA::cudart)
        else()
            target_link_libraries(TestOperators PUBLIC ${CUDA_LIBRARIES})
        endif()
    endif()

    target_include_directories(TestOperators PUBLIC src)
    target_link_libraries(TestOperators PUBLIC OpenMP::OpenMP_CXX)
    message(STATUS "TestOperators target: ENABLED")
else()
    message(STATUS "TestOperators target: SKIPPED (src/TestOperators.cpp not found)")
endif()


In [None]:
# Compile C++ Engine
%cd Code
!cmake -B build -S . -DCMAKE_BUILD_TYPE=Release
!cmake --build build -j $(nproc)
import os
if not os.path.exists('build/SymbolicRegressionGP') and not os.path.exists('build/Release/SymbolicRegressionGP'):
    print('BUILD FAILURE? Binary not found in expected locations. Listing build dir:')
    !ls -R build
%cd ..

In [None]:
%%writefile AlphaSymbolic/core/grammar.py
import numpy as np
from scipy.special import gamma as scipy_gamma, gammaln
import math

# Supported operators and their arity (number of arguments)
# Organized by curriculum stage for progressive unlocking
OPERATORS = {
    # === STAGE 0: Pure Arithmetic ===
    '+': 2,
    '-': 2,
    '*': 2,
    '/': 2,
    
    # === STAGE 1: Powers ===
    'pow': 2,
    'sqrt': 1,
    
    # === STAGE 2: Trigonometry ===
    'sin': 1,
    'cos': 1,
    'tan': 1,
    'asin': 1,
    'acos': 1,
    'atan': 1,
    
    # === STAGE 3: Transcendental ===
    'exp': 1,
    'log': 1,
    
    # === STAGE 4: Advanced ===
    'abs': 1,
    'neg': 1,
    'sign': 1,
    'floor': 1,
    'ceil': 1,
    'mod': 2,
    '%': 2,     # Alias for mod
    'gamma': 1,
    'lgamma': 1,
    
    # === C++ / GPU Specific Aliases ===
    'e': 1,     # Alias for exp
    '!': 1,     # Alias for gamma/factorial
    'g': 1,     # Alias for lgamma
    '_': 1,     # Alias for floor
    'S': 1,     # Alias for asin
    'C': 1,     # Alias for acos
    'T': 1,     # Alias for atan
}

# Operator groups for curriculum control: each stage lists every operator
# available at that stage (stages are cumulative).
OPERATOR_STAGES = {
    0: ['+', '-', '*', '/'],
    1: ['+', '-', '*', '/', 'pow', 'sqrt'],
    2: ['+', '-', '*', '/', 'pow', 'sqrt', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan'],
    3: ['+', '-', '*', '/', 'pow', 'sqrt', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'exp', 'log'],
    4: list(OPERATORS.keys()),  # All operators
}

# Terminal tokens
VARIABLES = ['x' + str(i) for i in range(10)] # x0, x1, ..., x9
# 'C' is a placeholder for learnable constants
CONSTANTS = ['C', '0', '1', '2', '3', '5', '10', 'pi', 'e']

# Full Vocabulary
# NOTE: 'C' and 'e' occur twice in VOCABULARY - once as operator aliases
# (acos / exp) and once as constants. TOKEN_TO_ID therefore keeps only the
# later (constant) index for them, exactly as before.
VOCABULARY = list(OPERATORS.keys()) + VARIABLES + CONSTANTS
TOKEN_TO_ID = {token: i for i, token in enumerate(VOCABULARY)}
# Built from VOCABULARY positions rather than by inverting TOKEN_TO_ID, so that
# every index in range(len(VOCABULARY)) decodes to a token. Inverting
# TOKEN_TO_ID left the operator slots of 'e' and 'C' without an entry, so
# decoding one of those ids raised KeyError. All previously present entries
# are unchanged.
ID_TO_TOKEN = dict(enumerate(VOCABULARY))

# Special token for start of sequence
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'
PAD_TOKEN = '<PAD>'

class Node:
    """One node of an expression tree.

    `value` is the token (operator name, variable, or constant) and
    `children` the list of operand nodes; leaves have an empty list.
    """

    # Short operator tokens (the C++ / GPU aliases) and the readable function
    # name that to_infix() prints for them.
    _UNARY_INFIX_NAMES = {
        '!': 'gamma',   # Use gamma for !
        '_': 'floor',
        'g': 'lgamma',
        'S': 'asin',
        'C': 'acos',
        'T': 'atan',
        'e': 'exp',
    }

    def __init__(self, value, children=None):
        self.value = value
        self.children = children if children else []

    def __repr__(self):
        """Return the token for a leaf, else a prefix form '(op child child ...)'."""
        if not self.children:
            return str(self.value)
        return f"({self.value} " + " ".join([str(c) for c in self.children]) + ")"
    
    def to_infix(self):
        """Render the subtree as an infix string.

        Leaves print as their token. Binary nodes print as '(left op right)',
        with 'pow' shown as '^' and 'mod' / '%' shown as '%'. Any other node is
        printed as 'name(first_child)', where the short alias tokens are
        replaced by their readable names (e.g. 'g' -> 'lgamma').

        The alias replacement used to sit after the unary return statement and
        was therefore never applied to unary nodes; it is now applied to them.
        """
        if not self.children:
            return str(self.value)
        
        op = self.value
        if len(self.children) == 2:
            left = self.children[0].to_infix()
            right = self.children[1].to_infix()
            if op == 'pow':
                return f"({left} ^ {right})"
            if op == 'mod' or op == '%':
                return f"({left} % {right})"
            return f"({left} {op} {right})"
        
        # Unary node (or unexpected arity: only the first child is rendered,
        # as before). Unknown tokens are printed unchanged.
        name = self._UNARY_INFIX_NAMES.get(op, op)
        return f"{name}({self.children[0].to_infix()})"
    
    def count_constants(self):
        """Count the number of 'C' placeholders in the tree."""
        count = 1 if self.value == 'C' else 0
        for child in self.children:
            count += child.count_constants()
        return count
    
    def get_constant_positions(self, path=None):
        """Returns a list of paths to all 'C' nodes for optimization.

        Each path is the list of child indices leading from this node to a
        node whose value is 'C'; the path of this node itself is [].
        """
        if path is None:
            path = []
        positions = []
        if self.value == 'C':
            positions.append(path.copy())
        for i, child in enumerate(self.children):
            positions.extend(child.get_constant_positions(path + [i]))
        return positions


import ast

class ExpressionTree:
    """
    Expression tree built from prefix-notation tokens and evaluated with NumPy.

    Attributes:
        tokens:   the token list given to the constructor (kept even when
                  parsing fails).
        root:     root Node of the parsed tree, or None when parsing failed.
        is_valid: True only when the token list formed exactly one complete
                  tree with no tokens left over.
    """

    def __init__(self, token_list):
        """
        Parses a list of tokens in Pre-order traversal (Prefix notation)
        Example: ['+', 'x', 'sin', 'x'] -> x + sin(x)

        Never raises on malformed input: any failure while building the tree
        leaves ``root = None`` and ``is_valid = False``.
        """
        self.tokens = token_list
        try:
            self.root, remaining = self._build_tree(token_list)
            if remaining:
                raise ValueError("Tokens remained after building tree")
            self.is_valid = True
        except Exception:
            self.root = None
            self.is_valid = False

    @staticmethod
    def _rewrite_factorials(infix_str):
        """
        Rewrite every '!' in ``infix_str`` into a ``gamma(...)`` call.

        Handled forms, checked in this order for the left-most '!':
            ``( ... )!``   -> ``gamma( ... )``      (postfix on a group)
            ``name(...)!`` -> ``gamma(name(...))``  (postfix on a call)
            ``!( ... )``   -> ``gamma( ... )``      (prefix usage)
            ``atom!``      -> ``gamma(atom)``       (postfix on a name/number)
            anything else  -> the '!' is replaced by the word ``gamma``

        Each pass removes exactly one '!', so the loop always terminates.
        """
        processed_str = infix_str
        while '!' in processed_str:
            idx = processed_str.find('!')
            if idx > 0 and processed_str[idx-1] == ')':
                # Walk backwards to the '(' matching the ')' just before '!'.
                paren_count = 1
                start = idx - 2
                while start >= 0 and paren_count > 0:
                    if processed_str[start] == ')':
                        paren_count += 1
                    elif processed_str[start] == '(':
                        paren_count -= 1
                    start -= 1
                # The loop stops one char before the matching '(' (or at -1
                # when the parentheses are unbalanced).
                start += 1

                # If an identifier sits directly before '(' the group is a
                # call such as sin(x); wrap the whole call instead of gluing
                # "gamma" onto the function name ("singamma(x)").
                name_start = start
                while name_start > 0 and (processed_str[name_start-1].isalnum()
                                          or processed_str[name_start-1] == '_'):
                    name_start -= 1

                if name_start < start:
                    processed_str = (processed_str[:name_start] + "gamma("
                                     + processed_str[name_start:idx] + ")"
                                     + processed_str[idx+1:])
                else:
                    # Content already includes its own parentheses: ( ... )
                    content = processed_str[start:idx]
                    processed_str = processed_str[:start] + "gamma" + content + processed_str[idx+1:]
            elif idx < len(processed_str) - 1 and processed_str[idx+1] == '(':
                # Prefix usage !(...) -> gamma(...)
                processed_str = processed_str[:idx] + "gamma" + processed_str[idx+1:]
            else:
                # Postfix on a bare atom (x0!, 3!, 2.5!) -> gamma(atom)
                atom_start = idx
                while atom_start > 0 and (processed_str[atom_start-1].isalnum()
                                          or processed_str[atom_start-1] in '_.'):
                    atom_start -= 1
                if atom_start < idx:
                    processed_str = (processed_str[:atom_start] + "gamma("
                                     + processed_str[atom_start:idx] + ")"
                                     + processed_str[idx+1:])
                else:
                    # Fallback: treat '!' as a function-like token.
                    processed_str = processed_str.replace('!', 'gamma', 1)
        return processed_str

    @classmethod
    def from_infix(cls, infix_str):
        """
        Creates an ExpressionTree from a standard infix string (e.g. "sin(x) + x^2").
        Uses Python's ast to parse.

        Pre-processing applied before parsing:
            1. '!' (factorial) is rewritten to ``gamma(...)`` calls.
            2. The infix words `` pow `` / `` mod `` become ``^`` / ``%``.
            3. ``^`` becomes ``**``. Python parses ``^`` as bitwise XOR, which
               binds looser than ``+`` and ``*``; without this step
               "sin(x) + x^2" would be read as "(sin(x) + x)^2".

        Returns:
            The parsed tree. On any parse error a message is printed and an
            invalid tree (``is_valid == False``, empty token list) is returned.
        """
        # 1. Factorial notation -> gamma(...)
        processed_str = cls._rewrite_factorials(infix_str)

        # 1b. Handle 'pow' and 'mod' keywords if they leak in
        processed_str = processed_str.replace(' pow ', ' ^ ')
        processed_str = processed_str.replace(' mod ', ' % ')

        # 2. Give power its mathematical precedence.
        processed_str = processed_str.replace('^', '**')

        try:
            tree = ast.parse(processed_str, mode='eval')
            tokens = cls._ast_to_prefix(tree.body)
            return cls(tokens)
        except Exception as e:
            print(f"Error parsing infix: {e} | Original: {infix_str} | Processed: {processed_str}")
            return cls([]) # Invalid

    @staticmethod
    def _ast_to_prefix(node):
        """
        Convert a Python AST expression node into a list of prefix tokens.

        Supported nodes: binary ``+ - * / ** ^ %``, unary ``+``/``-``, calls
        to the known function names, names and constants. Long function names
        are mapped to the short engine tokens (asin->S, acos->C, atan->T,
        exp->e, gamma->!, floor->_, lgamma->g) and the name ``x`` becomes
        ``x0``.

        Raises:
            ValueError: for any unsupported node, operator or function.
        """
        if isinstance(node, ast.BinOp):
            # Map operators (BitXor is kept for callers that pass a raw '^' AST)
            op_map = {
                ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/',
                ast.BitXor: 'pow', ast.Pow: 'pow', ast.Mod: 'mod'
            }
            op_type = type(node.op)
            if op_type in op_map:
                return [op_map[op_type]] + ExpressionTree._ast_to_prefix(node.left) + ExpressionTree._ast_to_prefix(node.right)

        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                # Collapse a negated numeric literal into one token: "-5"
                if isinstance(node.operand, ast.Constant) and isinstance(node.operand.value, (int, float)):
                    return [str(-node.operand.value)]
                return ['neg'] + ExpressionTree._ast_to_prefix(node.operand)
            elif isinstance(node.op, ast.UAdd):
                # Unary plus is a no-op
                return ExpressionTree._ast_to_prefix(node.operand)

        elif isinstance(node, ast.Call):
            # Functions like sin(x). A callee that is not a plain name
            # (e.g. an attribute) has no 'id' and falls through to the
            # ValueError below.
            func_id = getattr(node.func, 'id', None)
            # Allow both standard names and GPU short tokens
            allowed = ('sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'exp', 'log',
                       'sqrt', 'abs', 'floor', 'ceil', 'gamma', 'lgamma', 'sign', 'neg',
                       'S', 'C', 'T', 'e', 'g', '_')
            # Long name -> short token used by the engine
            short_tokens = {
                'asin': 'S', 'acos': 'C', 'atan': 'T', 'exp': 'e',
                'gamma': '!', 'floor': '_', 'lgamma': 'g',
            }
            if func_id in allowed:
                tokens = [short_tokens.get(func_id, func_id)]
                for arg in node.args:
                    tokens.extend(ExpressionTree._ast_to_prefix(arg))
                return tokens

        elif isinstance(node, ast.Name):
            # 'x' is an alias for the first variable
            if node.id == 'x':
                return ['x0']
            return [node.id]

        elif isinstance(node, ast.Constant): # Python 3.8+
            return [str(node.value)]
        elif type(node).__name__ == 'Num': # Older python; avoids touching the deprecated ast.Num
            return [str(node.n)]

        raise ValueError(f"Unsupported AST node: {node}")


    def _build_tree(self, tokens):
        """
        Recursively consume prefix tokens and build one subtree.

        Returns:
            (node, remaining_tokens)

        Raises:
            ValueError: if the tokens run out or a token is neither an
                operator, a variable, a constant nor a float literal.
        """
        if not tokens:
            raise ValueError("Empty token list")

        token = tokens[0]
        remaining = tokens[1:]

        if token in OPERATORS:
            arity = OPERATORS[token]
            children = []
            for _ in range(arity):
                child, remaining = self._build_tree(remaining)
                children.append(child)
            return Node(token, children), remaining

        if token in VARIABLES or token in CONSTANTS:
            return Node(token), remaining

        # Last chance: a numeric literal such as "3.14" or "-2"
        try:
            float(token)
        except (TypeError, ValueError):
            raise ValueError(f"Unknown token: {token}")
        return Node(token), remaining

    def evaluate(self, x_values, constants=None):
        """
        Evaluates the expression tree for a given input.

        x_values:
            - dict mapping variable names ('x0', 'x1', ...) to arrays, used as is
            - numpy array of shape (N,): single variable x0
            - numpy array of shape (N_samples, M_features): column i becomes 'x{i}'
            - any other sequence: converted to a float64 array and used as x0
        constants: optional dict mapping path tuples to constant values
        Returns a numpy array of results (all NaN when the tree is invalid).

        Raises:
            ValueError: for numpy inputs with more than two dimensions.
        """
        if isinstance(x_values, np.ndarray):
            if x_values.ndim == 1:
                # Single variable x -> x0
                x_values = {'x0': x_values}
            elif x_values.ndim == 2:
                # Standard ML layout (samples, features): one entry per column
                x_values = {f'x{i}': x_values[:, i] for i in range(x_values.shape[1])}
            else:
                raise ValueError(f"Unsupported input shape: {x_values.shape}")
        elif not isinstance(x_values, dict):
            x_values = {'x0': np.array(x_values, dtype=np.float64)}

        # Determine sample size from first key
        n_samples = len(next(iter(x_values.values())))

        if not self.is_valid:
            return np.full(n_samples, np.nan, dtype=np.float64)
        return self._eval_node(self.root, x_values, constants, path=[])

    def _eval_node(self, node, x, constants=None, path=None):
        """
        Evaluate one node (and, recursively, its children).

        Args:
            node: node exposing ``value`` and ``children``.
            x: dict of variable name -> numpy array.
            constants: optional dict mapping ``tuple(path)`` to the value of
                the 'C' placeholder at that path.
            path: list of child indices from the root to ``node``.

        Returns:
            float64 array with one value per sample. Operators are
            "protected" (safe division, |x| for log/sqrt/pow bases, clipped
            exponents and gamma arguments). Unknown operators evaluate to
            zeros.
        """
        val = node.value

        # Check for variable
        if val in x:
            return x[val].astype(np.float64)
        if val == 'x': # Backward compatibility
             if 'x0' in x: return x['x0'].astype(np.float64)
             # Fallback if x was passed as key 'x'
             if 'x' in x: return x['x'].astype(np.float64)
             raise ValueError("Variable 'x' not found in input.")
        # Get sample size from a variable
        n_samples = len(next(iter(x.values())))

        if val == 'pi':
            return np.full(n_samples, np.pi, dtype=np.float64)
        if val == 'e' and not node.children:
            return np.full(n_samples, np.e, dtype=np.float64)
        # 'C' is both the constant placeholder (leaf) and the short token for
        # acos (unary). Only a leaf is a placeholder; a 'C' with a child must
        # reach the acos branch below, mirroring how 'e' is handled above.
        if val == 'C' and not node.children:
            # Check if we have an optimized constant for this position
            if constants is not None and tuple(path) in constants:
                return np.full(n_samples, constants[tuple(path)], dtype=np.float64)
            return np.full(n_samples, 1.0, dtype=np.float64)  # Default constant = 1

        # Check for numeric constants
        try:
            return np.full(n_samples, float(val), dtype=np.float64)
        except (TypeError, ValueError):
            pass

        # Recursive evaluation
        args = []
        for i, c in enumerate(node.children):
            args.append(self._eval_node(c, x, constants, path + [i] if path is not None else None))

        # Operators
        with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
            if val == '+': return args[0] + args[1]
            if val == '-': return args[0] - args[1]
            if val == '*': return args[0] * args[1]
            if val == '/':
                # Protected division: 0 wherever the denominator is 0
                return np.divide(args[0], args[1], out=np.zeros(n_samples, dtype=np.float64), where=args[1]!=0)
            if val == 'pow':
                # Safe power
                return np.power(np.abs(args[0]) + 1e-10, np.clip(args[1], -10, 10))
            if val == 'mod':
                return np.mod(args[0], args[1] + 1e-10)
            if val == 'sin': return np.sin(args[0])
            if val == 'cos': return np.cos(args[0])
            if val == 'tan': return np.tan(args[0])

            if val == 'asin' or val == 'S':
                # Protected asin: asin(clip(x, -1, 1))
                return np.arcsin(np.clip(args[0], -1 + 1e-7, 1 - 1e-7))
            if val == 'acos' or val == 'C':
                # Protected acos: acos(clip(x, -1, 1))
                return np.arccos(np.clip(args[0], -1 + 1e-7, 1 - 1e-7))
            if val == 'atan' or val == 'T': return np.arctan(args[0])

            if val == 'exp' or val == 'e':
                return np.exp(np.clip(args[0], -100, 100))
            if val == 'log':
                return np.log(np.abs(args[0]) + 1e-10)
            if val == 'sqrt':
                return np.sqrt(np.abs(args[0]))
            if val == 'abs':
                return np.abs(args[0])
            if val == 'floor' or val == '_':
                return np.floor(args[0])
            if val == 'ceil':
                return np.ceil(args[0])

            if val == 'gamma' or val == '!':
                # Match C++ Protected Gamma/Factorial: tgamma(|x| + 1)
                # This ensures consistent evaluation for formulas from C++ engine (which uses !)
                arg = np.abs(args[0]) + 1.0
                clipped = np.clip(arg, 0.1, 50) # Clip upper bound to avoid overflow
                return scipy_gamma(clipped)
            if val == 'lgamma' or val == 'g':
                # Protected lgamma: lgamma(|x| + 1)
                arg = np.abs(args[0]) + 1.0
                # gammaln is safe for large positive numbers, so less aggressive clipping needed for overflow,
                # but we clip for consistency and to avoid extremely large outputs if followed by exp
                clipped = np.clip(arg, 0.1, 1000)
                return gammaln(clipped)
            if val == 'neg':
                return -args[0]
            if val == 'sign':
                return np.sign(args[0])

        return np.zeros(n_samples, dtype=np.float64)

    def get_infix(self):
        """Return the infix string of the tree, or "Invalid" when parsing failed."""
        if not self.is_valid:
            return "Invalid"
        return self.root.to_infix()

    def count_constants(self):
        """Return the root's count of 'C' nodes (0 for an invalid tree)."""
        if not self.is_valid:
            return 0
        return self.root.count_constants()

    @staticmethod
    def generate_random(max_depth=4, num_variables=1, p_terminal=0.3):
        """
        Generates a random valid ExpressionTree.

        Args:
            max_depth: depth at which only terminals are generated.
            num_variables: variables x0 .. x{num_variables-1} may be used.
            p_terminal: probability of stopping early with a terminal once
                the depth is greater than 1.

        Returns:
            The tree obtained by round-tripping the random tree through
            to_infix / from_infix; if that round trip does not produce a
            valid tree, the trivial tree "x0" is returned instead.
        """
        import random

        valid_vars = ['x' + str(i) for i in range(num_variables)]
        # Use module-level OPERATORS
        ops = list(OPERATORS.keys())

        # Weighted operators (matching C++ OPERATOR_WEIGHTS approximately)
        # +, -, *, / : High weight (0.2, 0.2, 0.2, 0.15)
        # ^ : Medium (0.1)
        # sin, cos: Medium (0.1)
        # log, exp: Low (0.05)
        # others: Very Low (0.01)
        def _weight(op):
            if op in ('+', '-', '*'):
                return 20
            if op == '/':
                return 15
            # Power is spelled 'pow' by to_infix / _ast_to_prefix / _eval_node,
            # so it is listed next to '^' to actually receive its weight.
            if op in ('^', 'pow', 'sin', 'cos'):
                return 10
            # Likewise exp may appear as the short token 'e'.
            if op in ('log', 'exp', 'e'):
                return 5
            return 1 # !, _, g, S, C, T

        op_weights = [_weight(op) for op in ops]

        def _gen(depth):
            if depth >= max_depth or (depth > 1 and random.random() < p_terminal):
                # Terminal
                if random.random() < 0.75: # Variable (0.75 matches C++ config)
                     return Node(random.choice(valid_vars))
                else: # Constant
                     # Half literal numbers, half 'C' placeholders that can be
                     # optimized later.
                     if random.random() < 0.5:
                         val = random.uniform(-5.0, 5.0)
                         return Node(f"{val:.4f}")
                     else:
                         return Node('C')
            else:
                # Operator (Weighted)
                op = random.choices(ops, weights=op_weights, k=1)[0]
                arity = OPERATORS[op]
                if arity == 1:
                    return Node(op, [_gen(depth+1)])
                else:
                    return Node(op, [_gen(depth+1), _gen(depth+1)])

        root = _gen(0)
        # Verify valid infix generation. from_infix reports failure through
        # is_valid rather than by raising, so the flag must be checked for
        # the fallback to ever trigger.
        try:
            tree = ExpressionTree.from_infix(root.to_infix())
        except Exception:
            tree = None
        if tree is None or not tree.is_valid:
            # Fallback for very unlucky generation
            return ExpressionTree.from_infix("x0")
        return tree


import sympy

def simplify_formula(formula_str):
    """
    Simplify a formula string with SymPy, best effort.

    The text is adapted for SymPy (``pow(`` is spelled ``Pow(``), parsed with
    ``sympy.sympify``, simplified with ``sympy.simplify`` and printed back
    with ``str``; powers therefore come back in SymPy's ``**`` notation.

    Returns:
        The simplified formula as a string, or ``formula_str`` unchanged if
        any step fails (for example because of an unknown function).

    NOTE(review): ``sympy.sympify`` evaluates its input string; only pass
    formulas produced by the engine, never untrusted text.
    """
    try:
        # SymPy's power function is capitalised.
        sympy_source = formula_str.replace("pow(", "Pow(")
        return str(sympy.simplify(sympy.sympify(sympy_source)))
    except Exception:
        # Simplification is optional: fall back to the input.
        return formula_str


In [None]:
%%writefile AlphaSymbolic/core/model.py
import torch
import torch.nn as nn
import numpy as np

class AlphaSymbolicModel(nn.Module):
    """
    Encoder-decoder Transformer that reads (x, y) sample points and scores
    formula tokens.

    Sub-modules:
        point_embedding / problem_encoder: embed and encode the sample points.
        token_embedding / pos_encoder / formula_decoder: decode the formula
            prefix while attending to the encoded points.
        policy_head: per-position logits over the vocabulary.
        value_head: three outputs per sample (quantiles 0.25, 0.50, 0.75).
    """

    def __init__(self, vocab_size, d_model=128, nhead=4, num_encoder_layers=2, num_decoder_layers=2, max_seq_len=256, input_dim=2):
        """
        Args:
            vocab_size: number of formula tokens.
            d_model: embedding width shared by encoder and decoder.
            nhead: attention heads per Transformer layer.
            num_encoder_layers / num_decoder_layers: stack depths.
            max_seq_len: longest formula supported by the positional encoding.
            input_dim: width of one sample point (variables plus y); narrower
                points are zero-padded to this width in forward().
        """
        super().__init__()

        self.d_model = d_model
        self.input_dim = input_dim

        # 1. Point encoder. The points form a set, so no positional encoding
        #    is added; the Transformer simply sees them as a sequence.
        self.point_embedding = nn.Linear(input_dim, d_model)
        self.problem_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_encoder_layers,
        )

        # 2. Formula decoder.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        self.formula_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_decoder_layers,
        )

        # 3. Output heads.
        self.policy_head = nn.Linear(d_model, vocab_size)
        self.value_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 3),  # Quantiles: 0.25, 0.50, 0.75
        )

    def forward(self, x_values, y_values, formula_input, formula_mask=None):
        """
        Args:
            x_values: [batch, num_points] or [batch, num_points, n_vars]
            y_values: [batch, num_points] or [batch, num_points, 1]
            formula_input: [batch, seq_len] token IDs
            formula_mask: optional decoder mask; a causal mask is generated
                when omitted.

        Returns:
            logits: [batch, seq_len, vocab_size]
            value:  [batch, 3], computed from the decoder state of the last
                    position in the sequence.
        """
        # -- Problem encoding --
        # 2-D inputs get a trailing feature axis so x and y can be joined.
        if x_values.dim() == 2:
            x_values = x_values.unsqueeze(-1)
        if y_values.dim() == 2:
            y_values = y_values.unsqueeze(-1)
        points = torch.cat([x_values, y_values], dim=-1)  # [batch, num_points, n_vars + 1]

        # Zero-pad the feature axis up to input_dim.
        missing = self.input_dim - points.shape[-1]
        if missing > 0:
            zeros = torch.zeros(points.shape[0], points.shape[1], missing, device=points.device)
            points = torch.cat([points, zeros], dim=-1)

        # Encoded points act as the decoder memory: [batch, num_points, d_model]
        memory = self.problem_encoder(self.point_embedding(points))

        # -- Formula decoding --
        tgt = self.pos_encoder(self.token_embedding(formula_input))  # [batch, seq_len, d_model]

        if formula_mask is None:
            # Causal mask: position i may only attend to positions <= i.
            formula_mask = nn.Transformer.generate_square_subsequent_mask(
                formula_input.size(1)
            ).to(formula_input.device)

        decoded = self.formula_decoder(tgt, memory, tgt_mask=formula_mask)

        # -- Heads --
        logits = self.policy_head(decoded)          # [batch, seq_len, vocab_size]
        value = self.value_head(decoded[:, -1, :])  # [batch, 3]

        return logits, value

class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding added to a batch-first sequence.

    Even feature columns hold sin(position * frequency) and odd columns the
    matching cos(...). The table is stored as the non-trainable buffer ``pe``
    of shape [1, max_len, d_model].
    """

    def __init__(self, d_model, max_len=5000):
        """
        Args:
            d_model: feature width of the embeddings (odd values supported).
            max_len: number of positions precomputed in the table.
        """
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the cosine block is trimmed to fit (a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len, d_model] with seq_len <= max_len.

        Returns:
            x plus the encoding of positions 0 .. seq_len-1.
        """
        x = x + self.pe[:, :x.size(1), :]
        return x

if __name__ == "__main__":
    # Smoke test: run one forward pass on random data and verify the output
    # shapes instead of only printing them.
    vocab_size = 20
    model = AlphaSymbolicModel(vocab_size=vocab_size, d_model=32)

    # Dummy data
    bs = 2
    points = 10
    seq_len = 5
    x = torch.randn(bs, points)
    y = torch.randn(bs, points)

    # Formula input (start token + some tokens)
    seq = torch.randint(0, vocab_size, (bs, seq_len))

    logits, value = model(x, y, seq)

    print("Logits shape:", logits.shape) # Should be [2, 5, 20]
    print("Value shape:", value.shape)   # Should be [2, 3] (one column per quantile)

    assert logits.shape == (bs, seq_len, vocab_size), f"Unexpected logits shape: {tuple(logits.shape)}"
    assert value.shape == (bs, 3), f"Unexpected value shape: {tuple(value.shape)}"
    print("Smoke test passed.")


In [None]:
%%writefile AlphaSymbolic/core/environment.py
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from core.grammar import VOCABULARY, OPERATORS, TOKEN_TO_ID, ExpressionTree
from data.synthetic_data import DataGenerator

class SymbolicEnv(gym.Env):
    """
    Gymnasium environment in which an agent writes a formula token by token.

    The formula is built in prefix order. ``open_branches`` tracks how many
    subtrees still have to be filled: an operator of arity k adds k - 1, any
    other token removes 1. The episode terminates when the count reaches 0
    (formula complete, reward = -RMSE against the sampled problem) or drops
    below 0 (syntax error), and is truncated when ``max_length`` tokens have
    been emitted without completing the formula.
    """

    def __init__(self, max_length=50):
        """
        Args:
            max_length: maximum number of tokens per episode; also the length
                of the padded "sequence" observation.
        """
        super().__init__()

        self.vocab_size = len(VOCABULARY)
        self.max_length = max_length
        self.vocab = VOCABULARY

        # One action per vocabulary token.
        self.action_space = spaces.Discrete(self.vocab_size)

        # Dictionary observation: padded token ids plus the 10 sample points
        # of the current problem.
        self.observation_space = spaces.Dict({
            "sequence": spaces.Box(low=0, high=self.vocab_size, shape=(max_length,), dtype=np.int32),
            "x": spaces.Box(low=-np.inf, high=np.inf, shape=(10,), dtype=np.float32),
            "y": spaces.Box(low=-np.inf, high=np.inf, shape=(10,), dtype=np.float32)
        })

        self.data_gen = DataGenerator(max_depth=4)
        self.current_problem = None
        self.current_sequence = []
        self.open_branches = 0

    def reset(self, seed=None, options=None):
        """
        Start a new episode on a freshly generated problem.

        Returns:
            (observation, info) with an empty info dict.
        """
        super().reset(seed=seed)

        # A new (X, Y) problem per episode; a fixed dataset could be sampled
        # here instead.
        self.current_problem = self.data_gen.generate_batch(1, point_count=10)[0]

        self.current_sequence = []
        self.open_branches = 1  # the root node is still missing

        return self._get_obs(), {}

    def step(self, action_id):
        """
        Append the token selected by ``action_id`` to the formula.

        Returns:
            (observation, reward, terminated, truncated, info)

            reward is the result of _calculate_reward() when the formula is
            complete, -100.0 on a syntax error (more operands than open
            slots), -10.0 when the length limit is hit first, else 0.0.
        """
        token = self.vocab[action_id]
        self.current_sequence.append(token)

        # Operators open (arity - 1) further slots; terminals close one.
        if token in OPERATORS:
            self.open_branches += OPERATORS[token] - 1
        else:
            self.open_branches -= 1

        terminated, truncated, reward = False, False, 0.0

        if self.open_branches == 0:
            # Tree is complete, evaluate it.
            terminated = True
            reward = self._calculate_reward()
        elif self.open_branches < 0:
            # Should not happen if actions are masked; kept for safety.
            terminated = True
            reward = -100.0
        elif len(self.current_sequence) >= self.max_length:
            truncated = True
            reward = -10.0

        return self._get_obs(), reward, terminated, truncated, {}

    def _get_obs(self):
        """Build the observation dict: zero-padded token ids, x and y as float32."""
        token_ids = [TOKEN_TO_ID[tok] for tok in self.current_sequence]
        sequence = np.zeros(self.max_length, dtype=np.int32)
        sequence[:len(token_ids)] = token_ids

        problem = self.current_problem
        return {
            "sequence": sequence,
            "x": problem['x'].astype(np.float32),
            "y": problem['y'].astype(np.float32)
        }

    def _calculate_reward(self):
        """
        Score the finished formula against the current problem.

        Returns:
            -RMSE for a valid formula, -1000.0 when the RMSE is NaN or
            infinite, and -100.0 when the formula is invalid or anything
            raises during evaluation.
        """
        try:
            tree = ExpressionTree(self.current_sequence)
            if not tree.is_valid:
                return -100.0

            y_pred = tree.evaluate(self.current_problem['x'])
            residual = y_pred - self.current_problem['y']
            rmse = np.sqrt(np.mean(residual**2))

            if np.isnan(rmse) or np.isinf(rmse):
                return -1000.0

            # Maximizing the reward minimizes the RMSE.
            return -rmse

        except Exception:
            return -100.0

if __name__ == "__main__":
    # Manual demo: reset the environment, then feed the prefix tokens of
    # "x + x" one at a time and print what each step returns.
    env = SymbolicEnv()
    obs, _ = env.reset()
    print("Initial Observation Keys:", obs.keys())

    # Prefix: + x x
    actions = ['+', 'x', 'x']
    tot_reward = 0
    for tok in actions:
        obs, reward, term, trunc, _ = env.step(TOKEN_TO_ID[tok])
        print(f"Action: {tok}, Reward: {reward}, Term: {term}, Branches: {env.open_branches}")
        tot_reward += reward
        if term:
            break

    print(f"Total Reward: {tot_reward}")


In [None]:
%%writefile AlphaSymbolic/core/loss.py

import torch
import torch.nn as nn

class QuantileLoss(nn.Module):
    """
    Quantile Loss (Pinball Loss) for multiple quantiles.

    Args:
        quantiles (list): List of quantiles to estimate (e.g. [0.25, 0.5, 0.75]).
            The sequence is copied, so neither the shared default list nor a
            list owned by the caller is aliased by the instance.
    """
    def __init__(self, quantiles=[0.25, 0.5, 0.75]):
        super().__init__()
        # Copy: the default argument is one list object shared by every call.
        self.quantiles = list(quantiles)

    def forward(self, preds, target):
        """
        preds: [batch, num_quantiles] - Predicted values for each quantile
        target: [batch] or [batch, 1] - True scalar target

        Returns the sum over quantiles of the batch-mean pinball loss.

        Raises:
            ValueError: if preds has fewer columns than there are quantiles
                (the missing columns would otherwise average over an empty
                slice and silently turn the loss into NaN).
        """
        # Ensure target is a column so it broadcasts against preds[:, i:i+1]
        if target.dim() == 1:
            target = target.unsqueeze(1)

        if preds.dim() >= 2 and preds.size(1) < len(self.quantiles):
            raise ValueError(
                f"preds has {preds.size(1)} columns but {len(self.quantiles)} quantiles are configured"
            )

        loss = 0
        for i, q in enumerate(self.quantiles):
            error = target - preds[:, i:i+1]
            # Pinball loss: max(q * error, (q - 1) * error)
            # Equivalent to: error * (q - I(error < 0))
            loss += torch.max(q * error, (q - 1) * error).mean()

        return loss


In [None]:
%%writefile AlphaSymbolic/core/gp_bridge.py
import os
import subprocess
import tempfile
import re
import time
import sys
from typing import List, Optional
import numpy as np

# Windows only: change the process error mode so that crash / critical-error
# dialog boxes are not shown for child processes. Best effort: any failure
# (e.g. ctypes unavailable) is ignored.
if sys.platform == 'win32':
    try:
        import ctypes
        SEM_FAILCRITICALERRORS = 0x0001
        SEM_NOGPFAULTERRORBOX = 0x0002
        SEM_NOOPENFILEERRORBOX = 0x8000
        # SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX
        ctypes.windll.kernel32.SetErrorMode(
            SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX
        )
    except Exception:
        pass  # Silently ignore if ctypes fails

class GPEngine:
    """
    Thin wrapper that runs the external C++ GP binary (SymbolicRegressionGP)
    as a subprocess and extracts the best formula from its standard output.

    The binary is invoked as ``<binary> --seed <seed_file> --data <data_file>``
    where both files are temporary text files written by :meth:`run`.
    """

    # A 2-D input with more rows than columns and fewer columns than this is
    # assumed to be laid out as (samples, features) and is transposed.
    _MAX_FEATURES_FOR_TRANSPOSE = 50

    def __init__(self, binary_path=None):
        """
        Args:
            binary_path: Explicit path to the GP executable. When None, the
                project root is located by walking up from this file looking
                for a "Code" directory, and a list of OS-specific candidate
                locations under ``Code/build`` is probed. If none exists a
                warning is printed and the first candidate is kept so that
                :meth:`run` can report the missing binary.
        """
        if binary_path is not None:
            self.binary_path = binary_path
            return

        search_paths = self._candidate_paths(self._find_project_root())

        self.binary_path = next((p for p in search_paths if os.path.exists(p)), None)
        if self.binary_path is None:
            print("[Warning] GP Binary not found. Checked locations:")
            for p in search_paths:
                print(f" - {p}")
            # Fall back to the most likely location for the current OS
            self.binary_path = search_paths[0]

    @staticmethod
    def _find_project_root():
        """Return the nearest ancestor (up to 5 levels) containing "Code", else two levels above this file's directory."""
        current_dir = os.path.dirname(os.path.abspath(__file__))

        d = current_dir
        for _ in range(5):
            if os.path.exists(os.path.join(d, "Code")):
                return d
            parent = os.path.dirname(d)
            if parent == d:
                # Reached the filesystem root
                break
            d = parent

        return os.path.dirname(os.path.dirname(current_dir))

    @staticmethod
    def _candidate_paths(base_dir):
        """Return the ordered list of binary locations to probe for the current OS."""
        build_dir = os.path.join(base_dir, "Code", "build")
        release_dir = os.path.join(build_dir, "Release")
        if os.name == 'nt':
            return [
                os.path.join(release_dir, "SymbolicRegressionGP.exe"),
                os.path.join(build_dir, "SymbolicRegressionGP.exe"),
                # Fallback
                os.path.join(release_dir, "SymbolicRegressionGP"),
            ]
        # Linux/Mac (Colab): prefer the extension-less binary
        return [
            os.path.join(build_dir, "SymbolicRegressionGP"),
            os.path.join(release_dir, "SymbolicRegressionGP"),
            # Fallbacks
            os.path.join(release_dir, "SymbolicRegressionGP.exe"),
            os.path.join(build_dir, "SymbolicRegressionGP.exe"),
        ]

    @classmethod
    def _to_feature_rows(cls, x_values):
        """
        Normalise ``x_values`` to a list of per-feature sequences.

        Supported inputs:
          * dict ``{'x0': ..., 'x1': ...}`` - ordered by the integer after the
            first character of the key (keys without one sort as 0);
          * 1-D numpy array or flat list - a single feature;
          * 2-D numpy array or list of lists - treated as (samples, features)
            and transposed when it has more rows than columns and fewer than
            50 columns, otherwise taken as (features, samples).
        Any other type yields no feature rows.
        """
        limit = cls._MAX_FEATURES_FOR_TRANSPOSE

        if isinstance(x_values, dict):
            ordered_keys = sorted(x_values.keys(), key=lambda k: int(k[1:]) if k[1:].isdigit() else 0)
            return [x_values[k] for k in ordered_keys]

        if isinstance(x_values, np.ndarray):
            if x_values.ndim == 1:
                return [x_values]
            n_rows, n_cols = x_values.shape[0], x_values.shape[1]
            if n_rows > n_cols and n_cols < limit:
                # (samples, features) -> one row per feature
                return [x_values[:, i] for i in range(n_cols)]
            return [x_values[i] for i in range(n_rows)]

        if isinstance(x_values, list):
            if len(x_values) > 0 and isinstance(x_values[0], list):
                n_rows, n_cols = len(x_values), len(x_values[0])
                if n_rows > n_cols and n_cols < limit:
                    return list(map(list, zip(*x_values)))
                return x_values
            return [x_values]

        return []

    @staticmethod
    def _parse_best_formula(output):
        """
        Extract the best formula from the engine's stdout.

        Every line containing "formula:" (case-insensitive) with non-empty text
        after it is a candidate; later candidates replace earlier ones, and the
        scan stops at the first "final formula:" line. Returns None when no
        candidate is found.
        """
        marker = "formula:"
        best_formula = None
        for line in output.splitlines():
            line_lower = line.lower()
            idx = line_lower.find(marker)
            if idx == -1:
                continue
            formula_part = line[idx + len(marker):].strip()
            if not formula_part:
                continue
            best_formula = formula_part
            if "final formula:" in line_lower:
                break
        return best_formula

    @staticmethod
    def _write_temp_file(lines, created_paths):
        """Write ``lines`` (newline-terminated) to a new temp file and return its path.

        The path is appended to ``created_paths`` before anything is written so
        the caller can delete the file even if writing fails part-way.
        """
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as handle:
            created_paths.append(handle.name)
            for line in lines:
                handle.write(line + "\n")
        return handle.name

    def _execute(self, seed_file_path, data_file_path, timeout_sec):
        """Run the binary on the prepared files and return the parsed formula (or None)."""
        try:
            cmd = [self.binary_path, "--seed", seed_file_path, "--data", data_file_path]

            # Windows-specific: run the child without a console window
            startupinfo = None
            creationflags = 0
            if os.name == 'nt':
                startupinfo = subprocess.STARTUPINFO()
                startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
                startupinfo.wShowWindow = subprocess.SW_HIDE
                creationflags = subprocess.CREATE_NO_WINDOW

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=timeout_sec,
                startupinfo=startupinfo,
                creationflags=creationflags
            )
            return self._parse_best_formula(result.stdout)

        except subprocess.TimeoutExpired as e:
            # Recover whatever was printed before the timeout. The captured
            # output may be bytes here even though text=True was requested.
            output = e.stdout if e.stdout else ""
            if isinstance(output, bytes):
                output = output.decode('utf-8', errors='ignore')
            return self._parse_best_formula(output)

        except Exception:
            # Best effort: any other failure to launch/read the engine means "no formula"
            return None

    def run(self, x_values, y_values: List[float], seeds: Optional[List[str]] = None, timeout_sec: int = 10) -> Optional[str]:
        """
        Runs the C++ GP Engine with the given data and seeds.

        Args:
            x_values: Input features; see :meth:`_to_feature_rows` for the
                accepted layouts.
            y_values: Target values, one per sample.
            seeds: Optional seed formulas, written one per line to the seed
                file. None is treated as "no seeds".
            timeout_sec: Maximum run time of the subprocess in seconds. On
                timeout the output captured so far is still parsed.

        Returns:
            The best formula found as a string, or None if the binary is
            missing, fails, or prints no formula.

        The data file has one space-separated line per feature followed by a
        final line with the targets. Both temp files are removed before
        returning, including when writing them raises.
        """
        if not os.path.exists(self.binary_path):
            print(f"[Error] GP Binary not found at: {self.binary_path}")
            return None

        if seeds is None:
            seeds = []

        temp_paths = []
        try:
            seed_file_path = self._write_temp_file(seeds, temp_paths)

            data_lines = [" ".join(map(str, feature_vals)) for feature_vals in self._to_feature_rows(x_values)]
            data_lines.append(" ".join(map(str, y_values)))
            data_file_path = self._write_temp_file(data_lines, temp_paths)

            return self._execute(seed_file_path, data_file_path, timeout_sec)
        finally:
            # Cleanup (also reached when preparing the input files failed)
            for path in temp_paths:
                if os.path.exists(path):
                    os.unlink(path)

if __name__ == "__main__":
    # Manual smoke test: hand the engine samples of y = x^2 + 2 plus two seed
    # formulas and print whatever formula it reports.
    demo_engine = GPEngine()
    sample_x = [1, 2, 3, 4]
    sample_y = [value * value + 2 for value in sample_x]
    seed_formulas = ["(x * x)", "(x + 2)"]

    print("Testing GPEngine...")
    print(f"Result: {demo_engine.run(sample_x, sample_y, seed_formulas)}")


In [None]:
%%writefile AlphaSymbolic/core/gpu_engine.py

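# DEPRECATED: Use core.gpu instead.
# NOTE(review): the legacy classes below use torch, np, List, Tuple, OPERATORS
# and ExpressionTree, but this module's header imports none of them - confirm
# they are brought into scope elsewhere, or trim this file to the re-export.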
from core.gpu import TensorGeneticEngine

# Token vocabulary convention: every token is mapped to an integer ID.
# ID 0 is reserved for padding / "no token".
PAD_ID = 0
# IDs 1..N are assigned by GPUGrammar below: terminals first, then operators.

class GPUGrammar:
    """
    Integer vocabulary for RPN programs.

    ID ``PAD_ID`` is the padding token; terminals (variables and constants)
    receive the next IDs in order, followed by every key of ``OPERATORS``.

    Attributes:
        token_to_id / id_to_token: the two directions of the vocabulary.
        active_variables: variable names available for ``num_variables``.
        terminals: variables followed by 'C', '1', '2', '3', '5', 'pi', 'e'.
        operators: operator names, in ``OPERATORS`` key order.
        vocab_size: number of IDs handed out, padding included.
        op_ids: operator name -> ID.
        arity: operator ID -> value stored in ``OPERATORS`` for that operator.
    """
    def __init__(self, num_variables=1):
        self.token_to_id = {'<PAD>': PAD_ID}
        self.id_to_token = {PAD_ID: '<PAD>'}
        self.next_id = 1

        # Variables: x0..x{n-1} for several inputs, both spellings 'x' and
        # 'x0' for a single input, and just 'x0' otherwise.
        if num_variables > 1:
            self.active_variables = [f'x{i}' for i in range(num_variables)]
        elif num_variables == 1:
            self.active_variables = ['x', 'x0']
        else:
            self.active_variables = ['x0']

        self.terminals = self.active_variables + ['C', '1', '2', '3', '5', 'pi', 'e']
        self._assign_ids(self.terminals)

        self.operators = list(OPERATORS.keys())
        self._assign_ids(self.operators)

        self.vocab_size = self.next_id

        # Lookup tables used by the evaluation loop
        self.op_ids = {}
        self.arity = {}
        for op in self.operators:
            op_id = self.token_to_id[op]
            self.op_ids[op] = op_id
            self.arity[op_id] = OPERATORS[op]

    def _assign_ids(self, tokens):
        """Give each token the next free ID, recording both lookup directions."""
        for token in tokens:
            self.token_to_id[token] = self.next_id
            self.id_to_token[self.next_id] = token
            self.next_id += 1

class TensorGeneticEngine:
    def __init__(self, device: torch.device = None, pop_size=10000, max_len=30, num_variables=1, max_constants=5, n_islands=5):
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.grammar = GPUGrammar(num_variables)
        
        # Adjust pop_size to be divisible by n_islands
        self.n_islands = n_islands
        if pop_size % n_islands != 0:
            pop_size = (pop_size // n_islands) * n_islands
            
        self.pop_size = pop_size
        self.island_size = pop_size // n_islands
        self.max_len = max_len
        self.num_variables = num_variables
        self.max_constants = max_constants
        
        # Pre-allocate memory for random generation
        self.terminal_ids = torch.tensor([self.grammar.token_to_id[t] for t in self.grammar.terminals], device=self.device)
        self.operator_ids = torch.tensor([self.grammar.token_to_id[op] for op in self.grammar.operators], device=self.device)
        
        # --- Pre-compute Arity Masks for Safe Mutation ---
        self.token_arity = torch.zeros(self.grammar.vocab_size + 1, dtype=torch.long, device=self.device)
        self.arity_0_ids = []
        self.arity_1_ids = []
        self.arity_2_ids = []
        
        # Terminals (0)
        for t in self.grammar.terminals:
            tid = self.grammar.token_to_id[t]
            self.token_arity[tid] = 0
            self.arity_0_ids.append(tid)
            
        # Operators (1 or 2)
        for op in self.grammar.operators:
            tid = self.grammar.token_to_id[op]
            arity = OPERATORS[op]
            self.token_arity[tid] = arity
            if arity == 1: self.arity_1_ids.append(tid)
            elif arity == 2: self.arity_2_ids.append(tid)
            
        self.arity_0_ids = torch.tensor(self.arity_0_ids, device=self.device)
        self.arity_1_ids = torch.tensor(self.arity_1_ids, device=self.device)
        self.arity_2_ids = torch.tensor(self.arity_2_ids, device=self.device)

    def optimize_constants(self, population: torch.Tensor, constants: torch.Tensor, x: torch.Tensor, y_target: torch.Tensor, steps=10, lr=0.1):
        """
        Refine constants using Gradient Descent.
        population: [K, L]
        constants: [K, MaxConstants]
        """
        # Clone constants to leaf tensor with grad
        optimized_consts = constants.clone().detach().requires_grad_(True)
        optimizer = torch.optim.Adam([optimized_consts], lr=lr)
        
        best_mse = torch.full((population.shape[0],), float('inf'), device=self.device)
        best_consts = constants.clone().detach() # Fallback
        
        for _ in range(steps):
            optimizer.zero_grad()
            
            # Forward (Differentiable)
            mse, _ = self.evaluate_differentiable(population, optimized_consts, x, y_target)
            
            # Mask NaNs (invalid formulas don't train)
            valid_mask = ~torch.isnan(mse)
            if not valid_mask.any(): break
            
            # Keep best known constants per individual
            improved = (mse < best_mse) & valid_mask
            if improved.any():
                best_mse[improved] = mse[improved].detach()
                best_consts[improved] = optimized_consts[improved].detach()
            
            # Loss = Sum(MSE_valid)
            loss = mse[valid_mask].sum()
            
            if not loss.requires_grad:
                # This happens if no individual uses 'C' (graph disconnected)
                break
                
            loss.backward()
            optimizer.step()
            
        return best_consts, best_mse

    def infix_to_rpn(self, formulas: List[str]) -> torch.Tensor:
        """
        Converts a list of infix strings to a padded RPN tensor [B, L].
        """
        batch_rpn = []
        for f in formulas:
            try:
                # Use shared ExpressionTree to parse infix -> tree -> postfix(ish)
                # But ExpressionTree is prefix. We need Postfix for stack eval.
                # Let's do a simple recursive implementation here or leverage parsed tree.
                tree = ExpressionTree.from_infix(f)
                if not tree.is_valid:
                    batch_rpn.append([PAD_ID]*self.max_len)
                    continue
                
                # Conversion: Tree -> Postfix
                rpn_tokens = []
                def traverse(node):
                    if not node: return
                    for child in node.children:
                        traverse(child)
                    rpn_tokens.append(node.value)
                
                traverse(tree.root)
                
                # Convert to IDs
                ids = [self.grammar.token_to_id.get(t, PAD_ID) for t in rpn_tokens]
                # Pad/Truncate
                if len(ids) > self.max_len:
                    ids = ids[:self.max_len]
                else:
                    ids = ids + [PAD_ID] * (self.max_len - len(ids))
                batch_rpn.append(ids)
            except:
                batch_rpn.append([PAD_ID]*self.max_len)
                
        if not batch_rpn:
             return torch.empty((0, self.max_len), device=self.device, dtype=torch.long)
        return torch.tensor(batch_rpn, device=self.device, dtype=torch.long)

    def rpn_to_infix(self, rpn_tensor: torch.Tensor) -> str:
        """
        Decodes a single RPN tensor row back to infix string.
        """
        ids = rpn_tensor.squeeze().cpu().numpy()
        stack = []
        
        for id in ids:
            if id == PAD_ID: continue
            token = self.grammar.id_to_token.get(id, '?')
            
            if token in OPERATORS:
                arity = OPERATORS[token]
            if token in OPERATORS:
                arity = OPERATORS[token]
                if len(stack) < arity: 
                    # Skip invalid op, just like GPU engine does
                    continue
                
                args = [stack.pop() for _ in range(arity)]
                args.reverse()
                
                # Infix string construction
                if arity == 2:
                    if token == 'pow': elem = f"pow({args[0]}, {args[1]})"
                    else: elem = f"({args[0]} {token} {args[1]})"
                else:
                    elem = f"{token}({args[0]})"
                stack.append(elem)
            else:
                stack.append(token)
                
        if len(stack) >= 1:
            return stack[-1]
        # print(f"DEBUG: RPN Decode Failed. IDs: {ids} Stack: {stack}")
        return "Invalid"

    def evaluate_batch(self, population: torch.Tensor, x: torch.Tensor, y_target: torch.Tensor) -> torch.Tensor:
        """
        Evaluates the RPN population on the GPU.
        population: [PopSize, MaxLen] (Integers)
        x: [DataSize] (Floats)
        y_target: [DataSize] (Floats)
        
        Returns: RMSE per individual [PopSize]
        """
        B, L = population.shape
        D = x.shape[0]
        MAX_STACK = 10
        
        # Stack: [B, D, MAX_STACK]
        # We need D because each data point evaluates differently.
        # But wait, the structure is the same. Just the values differ.
        # We can treat B*D as the batch dimension for the stack operations to simplify?
        # PopSize=10k, Data=20 -> 200k items. Easy for GPU.
        
        # Reshape inputs for "Batch of Data Points"
        # Effective Batch Size = B * D
        eff_B = B * D
        
        # Expand population to match data: [B, 1, L] -> [B, D, L] -> [B*D, L]
        pop_expanded = population.unsqueeze(1).expand(-1, D, -1).reshape(eff_B, L)
        
        # Expand x to match population: [D] -> [1, D] -> [B, D] -> [B*D]
        x_expanded = x.unsqueeze(0).expand(B, -1).reshape(eff_B)
        
        # Stack tensor: [EffectiveBatch, StackDepth]
        stack = torch.zeros(eff_B, MAX_STACK, device=self.device, dtype=torch.float32)
        sp = torch.zeros(eff_B, device=self.device, dtype=torch.long) # Stack pointer (next empty slot)
        
        # DEBUG
        # print(f"DEBUG: Eval Batch B={B} L={L} EffB={eff_B}")
        
        # Constants lookup (naive)
        pi_val = torch.tensor(np.pi, device=self.device)
        e_val = torch.tensor(np.e, device=self.device)
        
        # Precompute IDs for speed
        id_x = self.grammar.token_to_id.get('x', -100)
        id_x0 = self.grammar.token_to_id.get('x0', -100)
        id_C = self.grammar.token_to_id.get('C', -100)
        id_pi = self.grammar.token_to_id.get('pi', -100)
        id_e = self.grammar.token_to_id.get('e', -100)
        
        # Binary Ops
        op_add = self.grammar.token_to_id.get('+', -100)
        op_sub = self.grammar.token_to_id.get('-', -100)
        op_mul = self.grammar.token_to_id.get('*', -100)
        op_div = self.grammar.token_to_id.get('/', -100)
        op_pow = self.grammar.token_to_id.get('pow', -100)
        
        # Unary Ops
        # Unary Ops
        op_sin = self.grammar.token_to_id.get('sin', -100)
        op_cos = self.grammar.token_to_id.get('cos', -100)
        op_tan = self.grammar.token_to_id.get('tan', -100)
        
        op_asin = self.grammar.token_to_id.get('asin', -100)
        op_acos = self.grammar.token_to_id.get('acos', -100)
        op_atan = self.grammar.token_to_id.get('atan', -100)
        
        op_exp = self.grammar.token_to_id.get('exp', -100)
        op_log = self.grammar.token_to_id.get('log', -100)
        op_sqrt = self.grammar.token_to_id.get('sqrt', -100)
        op_abs = self.grammar.token_to_id.get('abs', -100)
        op_neg = self.grammar.token_to_id.get('neg', -100)

        # Loop over RPN tokens
        for i in range(L):
            token = pop_expanded[:, i] # [EffectiveBatch]
            
            # Mask: Is this row active? (Not PAD)
            # PAD=0. If PAD, we do nothing (stack remains same)
            active_mask = (token != PAD_ID)
            if not active_mask.any(): continue
            
            # 1. Handle Operands (Push)
            # -------------------------
            # We calculate "value to push" for everyone, then apply.
            push_vals = torch.zeros(eff_B, device=self.device)
            is_operand = torch.zeros(eff_B, dtype=torch.bool, device=self.device)
            
            # x
            mask = (token == id_x) | (token == id_x0)
            if mask.any():
                push_vals[mask] = x_expanded[mask]
                is_operand = is_operand | mask
                
            # Constants
            mask = (token == id_pi)
            if mask.any():
                push_vals[mask] = pi_val
                is_operand = is_operand | mask
            
            mask = (token == id_e)
            if mask.any():
                push_vals[mask] = e_val
                is_operand = is_operand | mask
                
            mask = (token == id_C)
            if mask.any():
                push_vals[mask] = 1.0 # Default C=1.0 for GPU Search (optimization is hard here)
                is_operand = is_operand | mask
                
            # Numeric Literals (1..5)
            # (Assuming ids mapped sequentially or we map individually)
            # Simpler: Check range if mapped sequentially, or just discrete checks
            for val_str in ['1', '2', '3', '5']:
                vid = self.grammar.token_to_id.get(val_str, -999)
                mask = (token == vid)
                if mask.any():
                    push_vals[mask] = float(val_str)
                    is_operand = is_operand | mask

            # Apply Push
            if is_operand.any():
                # stack[b, sp[b]] = val
                # Safe scatter
                safe_sp = torch.clamp(sp, 0, MAX_STACK-1)
                stack.scatter_(1, safe_sp.unsqueeze(1), push_vals.unsqueeze(1))
                # Increment SP
                sp = sp + is_operand.long()


            # 2. Handle Binary Ops (Pop 2, Push 1)
            # ------------------------------------
            is_binary = (token == op_add) | (token == op_sub) | (token == op_mul) | (token == op_div) | (token == op_pow)
            
            if is_binary.any():
                # We need at least 2 items. If sp < 2, it's invalid.
                valid_op = is_binary & (sp >= 2)
                
                if valid_op.any():
                    # Calculate indices safely (clamp to valid range [0, 9] even if invalid row)
                    # We will mask out the result later, so garbage input is fine, but SEGV isn't.
                    safe_sp_minus_1 = torch.clamp(sp - 1, 0, MAX_STACK - 1)
                    safe_sp_minus_2 = torch.clamp(sp - 2, 0, MAX_STACK - 1)
                    
                    # Pop B (Top)
                    idx_b = safe_sp_minus_1.unsqueeze(1)
                    val_b = stack.gather(1, idx_b).squeeze(1)
                    
                    # Pop A (Second)
                    idx_a = safe_sp_minus_2.unsqueeze(1)
                    val_a = stack.gather(1, idx_a).squeeze(1)
                    
                    res = torch.zeros_like(val_a)
                    
                    # Compute
                    mask = (token == op_add) & valid_op
                    if mask.any(): res[mask] = val_a[mask] + val_b[mask]
                    
                    mask = (token == op_sub) & valid_op
                    if mask.any(): res[mask] = val_a[mask] - val_b[mask]
                    
                    mask = (token == op_mul) & valid_op
                    if mask.any(): res[mask] = val_a[mask] * val_b[mask]
                    
                    mask = (token == op_div) & valid_op
                    if mask.any(): 
                        # Protected Div
                        denom = val_b[mask]
                        denom = torch.where(denom.abs() < 1e-6, torch.tensor(1.0, device=self.device), denom)
                        res[mask] = val_a[mask] / denom
                        
                    mask = (token == op_pow) & valid_op
                    if mask.any():
                        # Protected Pow
                        base = val_a[mask].abs() + 1e-6
                        expon = torch.clamp(val_b[mask], -10, 10)
                        res[mask] = torch.pow(base, expon)
                    
                    # Push Result (at pos sp-2)
                    write_val = res
                    # Write pos must be valid too
                    write_pos = torch.clamp(sp - 2, 0, MAX_STACK-1)
                    
                    # Blend: Only update if valid_op
                    current_at_pos = stack.gather(1, write_pos.unsqueeze(1)).squeeze(1)
                    final_write_val = torch.where(valid_op, write_val, current_at_pos)
                    
                    stack.scatter_(1, write_pos.unsqueeze(1), final_write_val.unsqueeze(1))
                    
                    # Decrement SP by 1 (Pop 2, Push 1 = Net -1)
                    sp = sp - valid_op.long()


            # 3. Handle Unary Ops (Pop 1, Push 1)
            # -----------------------------------
            is_unary = (token == op_sin) | (token == op_cos) | (token == op_tan) | \
                       (token == op_asin) | (token == op_acos) | (token == op_atan) | \
                       (token == op_exp) | (token == op_log) | \
                       (token == op_sqrt) | (token == op_abs) | (token == op_neg)
                       
            if is_unary.any():
                valid_op = is_unary & (sp >= 1)
                
                if valid_op.any():
                    # Index safety
                    safe_sp_minus_1 = torch.clamp(sp - 1, 0, MAX_STACK - 1)
                    
                    # Peek Top (at sp-1)
                    idx_a = safe_sp_minus_1.unsqueeze(1)
                    val_a = stack.gather(1, idx_a).squeeze(1)
                    
                    res = torch.zeros_like(val_a)
                    
                    mask = (token == op_sin) & valid_op
                    if mask.any(): res[mask] = torch.sin(val_a[mask])
                    
                    mask = (token == op_cos) & valid_op
                    if mask.any(): res[mask] = torch.cos(val_a[mask])
                    
                    mask = (token == op_tan) & valid_op
                    if mask.any(): res[mask] = torch.tan(val_a[mask])
                    
                    mask = (token == op_asin) & valid_op
                    if mask.any():
                        # Clamp for safety
                        clamped = torch.clamp(val_a[mask], -0.999, 0.999) 
                        res[mask] = torch.asin(clamped)
                        
                    mask = (token == op_acos) & valid_op
                    if mask.any():
                        clamped = torch.clamp(val_a[mask], -0.999, 0.999)
                        res[mask] = torch.acos(clamped)
                        
                    mask = (token == op_atan) & valid_op
                    if mask.any(): res[mask] = torch.atan(val_a[mask])
                    
                    mask = (token == op_exp) & valid_op
                    if mask.any(): res[mask] = torch.exp(torch.clamp(val_a[mask], -20, 20))
                    
                    mask = (token == op_log) & valid_op
                    if mask.any(): res[mask] = torch.log(val_a[mask].abs() + 1e-6)
                    
                    mask = (token == op_sqrt) & valid_op
                    if mask.any(): res[mask] = torch.sqrt(val_a[mask].abs())
                    
                    mask = (token == op_abs) & valid_op
                    if mask.any(): res[mask] = torch.abs(val_a[mask])
                    
                    mask = (token == op_neg) & valid_op
                    if mask.any(): res[mask] = -val_a[mask]
                    
                    # Overwrite Top
                    write_pos = safe_sp_minus_1
                    current_at_pos = stack.gather(1, write_pos.unsqueeze(1)).squeeze(1)
                    final_write_val = torch.where(valid_op, res, current_at_pos)
                    
                    stack.scatter_(1, write_pos.unsqueeze(1), final_write_val.unsqueeze(1))
                    
                    # SP stays same

        
        # End of Loop
        # Result is at stack[0] (if valid)
        # Check validity: sp should be 1
        
        is_valid = (sp == 1)
        
        # Extract result
        # final_preds: [EffectiveBatch]
        final_preds = stack[:, 0]
        
        # For invalid, set to NaN or huge error
        final_preds = torch.where(is_valid, final_preds, torch.tensor(float('nan'), device=self.device))
        
        # Reshape back to [B, D]
        # [eff_B] -> [B, D]
        preds_matrix = final_preds.view(B, D)
        
        # Compute RMSE
        # y_target: [D] -> [1, D] -> [B, D]
        target_matrix = y_target.unsqueeze(0).expand(B, -1)
        
        # MSE: mean over dim 1 (Data points)
        diff = preds_matrix - target_matrix
        mse = torch.mean(diff**2, dim=1) # [B]
        
        # Handle NaNs (invalid formulas)
        mse = torch.where(torch.isnan(mse), torch.tensor(1e9, device=self.device), mse)
        rmse = torch.sqrt(mse)
        
        return rmse

    def evaluate_differentiable(self, population: torch.Tensor, constants: torch.Tensor, x: torch.Tensor, y_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Autograd-compatible evaluation of RPN formulas for constant optimization.
        Only run this on a small subset (e.g. Top-K) due to memory cost of tracing.

        Every (formula, data point) pair is evaluated as one row of a batched
        stack machine of depth MAX_STACK (10). All stack updates build new
        tensors with torch.where instead of writing in place.

        population: [B, L] token ids (Long expected)
        constants: [B, MaxConstants] (Float, Requires Grad); the k-th 'C' token of a
                   formula reads column k (clamped to the last column)
        x: [D]
        y_target: [D]

        Returns: (MSE [B], Predictions [B, D])
            NOTE: the first element is the mean squared error; no square root
            is applied here. Rows whose final stack pointer is not 1 get NaN
            predictions, so their MSE is NaN as well; NaNs are not masked here.
        """
        B, L = population.shape
        D = x.shape[0]
        MAX_STACK = 10
        eff_B = B * D

        # Reshape inputs: replicate each formula once per data point -> [B*D, L]
        pop_expanded = population.unsqueeze(1).expand(-1, D, -1).reshape(eff_B, L)
        x_expanded = x.unsqueeze(0).expand(B, -1).reshape(eff_B)

        # Expand constants: [B, K] -> [B, D, K] -> [B*D, K]
        constants_expanded = constants.unsqueeze(1).expand(-1, D, -1).reshape(eff_B, -1)

        # Initial state (functional, no in-place updates later on)
        stack = torch.zeros(eff_B, MAX_STACK, device=self.device, dtype=torch.float32)
        sp = torch.zeros(eff_B, device=self.device, dtype=torch.long)  # number of values currently on each row's stack
        c_ptr = torch.zeros(eff_B, device=self.device, dtype=torch.long) # index of the next constant column to consume

        # Fixed constants
        pi_val = torch.tensor(np.pi, device=self.device)
        e_val = torch.tensor(np.e, device=self.device)

        # Token ids. Tokens missing from the grammar map to -100, which the
        # comparisons below are expected never to match.
        id_x = self.grammar.token_to_id.get('x', -100)
        id_x0 = self.grammar.token_to_id.get('x0', -100)
        id_C = self.grammar.token_to_id.get('C', -100)
        id_pi = self.grammar.token_to_id.get('pi', -100)
        id_e = self.grammar.token_to_id.get('e', -100)

        # Binary Ops
        op_add = self.grammar.token_to_id.get('+', -100)
        op_sub = self.grammar.token_to_id.get('-', -100)
        op_mul = self.grammar.token_to_id.get('*', -100)
        op_div = self.grammar.token_to_id.get('/', -100)
        op_pow = self.grammar.token_to_id.get('pow', -100)

        # Unary Ops
        op_sin = self.grammar.token_to_id.get('sin', -100)
        op_cos = self.grammar.token_to_id.get('cos', -100)
        op_tan = self.grammar.token_to_id.get('tan', -100)
        op_asin = self.grammar.token_to_id.get('asin', -100)
        op_acos = self.grammar.token_to_id.get('acos', -100)
        op_atan = self.grammar.token_to_id.get('atan', -100)
        op_exp = self.grammar.token_to_id.get('exp', -100)
        op_log = self.grammar.token_to_id.get('log', -100)
        op_sqrt = self.grammar.token_to_id.get('sqrt', -100)
        op_abs = self.grammar.token_to_id.get('abs', -100)
        op_neg = self.grammar.token_to_id.get('neg', -100)

        import torch.nn.functional as F

        # Process the formulas one token position at a time, all rows in parallel.
        for i in range(L):
            token = pop_expanded[:, i]
            # active_mask is only used to skip positions where every row holds PAD_ID.
            active_mask = (token != PAD_ID)
            if not active_mask.any(): continue

            # --- 1. Push Operands ---
            push_vals = torch.zeros(eff_B, device=self.device)
            is_operand = torch.zeros(eff_B, dtype=torch.bool, device=self.device)

            # Variable: both 'x' and 'x0' read the same input value
            mask = (token == id_x) | (token == id_x0)
            if mask.any():
                push_vals = torch.where(mask, x_expanded, push_vals)
                is_operand = is_operand | mask

            # Learnable Constants 'C'
            mask = (token == id_C)
            if mask.any():
                # Clamp the pointer so formulas with more 'C' tokens than
                # available columns keep reading the last column.
                safe_ptr = torch.clamp(c_ptr, 0, constants_expanded.shape[1]-1)

                # gather needs a [rows, 1] index: constants[row, ptr]
                val_c = torch.gather(constants_expanded, 1, safe_ptr.unsqueeze(1)).squeeze(1)

                push_vals = torch.where(mask, val_c, push_vals)
                is_operand = is_operand | mask

                # Advance the pointer only on rows that consumed a constant
                c_ptr = c_ptr + mask.long()

            # Fixed Constants
            mask = (token == id_pi)
            if mask.any():
                push_vals = torch.where(mask, pi_val, push_vals)
                is_operand = is_operand | mask

            mask = (token == id_e)
            if mask.any():
                push_vals = torch.where(mask, e_val, push_vals)
                is_operand = is_operand | mask

            # Numeric literals that exist as tokens
            for val_str in ['1', '2', '3', '5']:
                vid = self.grammar.token_to_id.get(val_str, -999)
                mask = (token == vid)
                if mask.any():
                    push_vals = torch.where(mask, torch.tensor(float(val_str), device=self.device), push_vals)
                    is_operand = is_operand | mask

            # Update Stack (Functional)
            if is_operand.any():
                # One-hot row mask selecting slot `sp` (clamped into the stack)
                # target_mask: [rows, MAX_STACK]
                target_mask = F.one_hot(torch.clamp(sp, 0, MAX_STACK-1), num_classes=MAX_STACK).bool()

                # Only rows that push an operand are allowed to change their slot.
                update_mask = target_mask & is_operand.unsqueeze(1) # [rows, MAX_STACK]

                # Broadcast the pushed value across the stack width
                vals_expanded = push_vals.unsqueeze(1).expand(-1, MAX_STACK)

                stack = torch.where(update_mask, vals_expanded, stack)
                # sp itself is not clamped: a row that overflows ends with sp != 1
                # and is reported as invalid (NaN) at the end.
                sp = sp + is_operand.long()

            # --- 2. Binary Ops ---
            is_binary = (token == op_add) | (token == op_sub) | (token == op_mul) | (token == op_div) | (token == op_pow)
            # A binary operator is ignored on rows holding fewer than two values.
            valid_op = is_binary & (sp >= 2)

            if valid_op.any():
                sp_1 = torch.clamp(sp - 1, 0, MAX_STACK-1)
                sp_2 = torch.clamp(sp - 2, 0, MAX_STACK-1)

                # Read operands via one-hot multiply + sum so gradients flow to the stack
                idx_b = F.one_hot(sp_1, MAX_STACK).bool()
                val_b = (stack * idx_b).sum(dim=1) # top of stack (right operand)

                idx_a = F.one_hot(sp_2, MAX_STACK).bool()
                val_a = (stack * idx_a).sum(dim=1) # second from top (left operand)

                res = torch.zeros_like(val_a)

                mask = (token == op_add) & valid_op
                if mask.any(): res = torch.where(mask, val_a + val_b, res)

                mask = (token == op_sub) & valid_op
                if mask.any(): res = torch.where(mask, val_a - val_b, res)

                mask = (token == op_mul) & valid_op
                if mask.any(): res = torch.where(mask, val_a * val_b, res)

                mask = (token == op_div) & valid_op
                if mask.any():
                    # Protected division: near-zero denominators are replaced by 1.0
                    denom = torch.where(val_b.abs() < 1e-6, torch.tensor(1.0, device=self.device), val_b)
                    res = torch.where(mask, val_a / denom, res)

                mask = (token == op_pow) & valid_op
                if mask.any():
                    # Protected power: strictly positive base, exponent limited to [-10, 10]
                    base = val_a.abs() + 1e-6
                    expon = torch.clamp(val_b, -10, 10)
                    res = torch.where(mask, torch.pow(base, expon), res)

                # Write the result into slot sp-2 and pop one value
                write_pos = sp_2
                target_mask = F.one_hot(write_pos, MAX_STACK).bool()
                update_mask = target_mask & valid_op.unsqueeze(1)
                vals_expanded = res.unsqueeze(1).expand(-1, MAX_STACK)

                stack = torch.where(update_mask, vals_expanded, stack)
                sp = sp - valid_op.long()

            # --- 3. Unary Ops ---
            is_unary = (token == op_sin) | (token == op_cos) | (token == op_tan) | \
                       (token == op_asin) | (token == op_acos) | (token == op_atan) | \
                       (token == op_exp) | (token == op_log) | \
                       (token == op_sqrt) | (token == op_abs) | (token == op_neg)
            # A unary operator is ignored on rows with an empty stack.
            valid_op = is_unary & (sp >= 1)

            if valid_op.any():
                sp_1 = torch.clamp(sp - 1, 0, MAX_STACK-1)
                idx_a = F.one_hot(sp_1, MAX_STACK).bool()
                val_a = (stack * idx_a).sum(dim=1)

                res = torch.zeros_like(val_a)

                mask = (token == op_sin) & valid_op
                if mask.any(): res = torch.where(mask, torch.sin(val_a), res)

                mask = (token == op_cos) & valid_op
                if mask.any(): res = torch.where(mask, torch.cos(val_a), res)

                mask = (token == op_tan) & valid_op
                if mask.any(): res = torch.where(mask, torch.tan(val_a), res)

                mask = (token == op_asin) & valid_op
                if mask.any():
                    # Keep the argument strictly inside (-1, 1)
                    clamped = torch.clamp(val_a, -0.999, 0.999)
                    res = torch.where(mask, torch.asin(clamped), res)

                mask = (token == op_acos) & valid_op
                if mask.any():
                    clamped = torch.clamp(val_a, -0.999, 0.999)
                    res = torch.where(mask, torch.acos(clamped), res)

                mask = (token == op_atan) & valid_op
                if mask.any(): res = torch.where(mask, torch.atan(val_a), res)

                # exp: argument limited to [-20, 20]
                mask = (token == op_exp) & valid_op
                if mask.any(): res = torch.where(mask, torch.exp(torch.clamp(val_a, -20, 20)), res)

                # log: evaluated as log(|a| + 1e-6)
                mask = (token == op_log) & valid_op
                if mask.any(): res = torch.where(mask, torch.log(val_a.abs() + 1e-6), res)

                # sqrt: evaluated as sqrt(|a|)
                mask = (token == op_sqrt) & valid_op
                if mask.any(): res = torch.where(mask, torch.sqrt(val_a.abs()), res)

                mask = (token == op_abs) & valid_op
                if mask.any(): res = torch.where(mask, torch.abs(val_a), res)

                mask = (token == op_neg) & valid_op
                if mask.any(): res = torch.where(mask, -val_a, res)

                # Overwrite the top of the stack; the stack pointer is unchanged
                target_mask = F.one_hot(sp_1, MAX_STACK).bool()
                update_mask = target_mask & valid_op.unsqueeze(1)
                vals_expanded = res.unsqueeze(1).expand(-1, MAX_STACK)

                stack = torch.where(update_mask, vals_expanded, stack)

        # Final extract: a well-formed formula leaves exactly one value, in slot 0
        is_valid = (sp == 1)
        final_preds = stack[:, 0]
        final_preds = torch.where(is_valid, final_preds, torch.tensor(float('nan'), device=self.device))

        # [B*D] -> [B, D]
        preds_matrix = final_preds.view(B, D)

        # Loss: mean squared error over the data points of each formula
        target_matrix = y_target.unsqueeze(0).expand(B, -1)
        mse = torch.mean((preds_matrix - target_matrix)**2, dim=1) # [B]

        # NaN rows (invalid formulas) are returned as-is; nothing here masks
        # them, so callers must handle NaN losses before backpropagating.

        return mse, preds_matrix

    def run(self, x_data: List[float], y_data: List[float], seeds: List[str], timeout_sec=10) -> Optional[str]:
        """
        Main entry point (placeholder with no implementation).

        NOTE(review): this definition is dead code. A second `run` (taking
        `x_values`, `y_targets`, `seeds`, `timeout_sec` and `callback`) is
        defined further down in this same scope and replaces this one when
        the enclosing body is executed. This stub contains nothing besides
        its docstring. Candidate for deletion.
        """
    def rpn_to_infix(self, rpn_tensor: torch.Tensor, constants: torch.Tensor = None) -> str:
        """
        Decode an RPN token-id tensor into an infix string (CPU-style formatting).

        Args:
            rpn_tensor: Token ids of ONE formula. Either a 1-D tensor [L] or any
                tensor that flattens to a single row, e.g. the [1, L] tensors
                built with `unsqueeze(0)` by `run`. Decoding stops at the
                first PAD_ID.
            constants: Optional values substituted, in order, for each 'C'
                token. When missing or exhausted, 'C' is rendered as 1.

        Returns:
            The infix expression, or the string "Invalid" when an operator
            lacks operands or the tokens do not reduce to exactly one value.
        """
        import math

        vocab = self.grammar.id_to_token
        stack = []
        const_idx = 0

        # Single-character unary tokens and the function name they print as.
        # Any other unary token is printed under its own name.
        unary_names = {
            's': 'sin', 'c': 'cos', 'l': 'log', 'e': 'exp',
            'q': 'sqrt', 'a': 'abs', 'n': 'sign', '_': 'floor',
        }

        def format_const(val):
            # Match C++ format_constant
            if not math.isfinite(val):
                # round() raises on NaN/inf; emit 'nan' / 'inf' / '-inf' instead of crashing.
                return str(val)
            if abs(val - round(val)) < 1e-9:
                return str(int(round(val)))
            if abs(val) >= 1e6 or abs(val) <= 1e-6:
                return f"{val:.8e}"
            s = f"{val:.8f}"
            s = s.rstrip('0').rstrip('.')
            return s if s else "0"

        # Flatten first: iterating a [1, L] tensor directly yields one whole row,
        # on which .item() raises. tolist() also transfers the ids in one go.
        for token_id in rpn_tensor.reshape(-1).tolist():
            if token_id == PAD_ID: break

            token = vocab.get(token_id, "")

            if token in self.grammar.OPERATORS:
                arity = self.grammar.token_arity.get(token, 2)
                if arity == 1:
                    if not stack: return "Invalid"
                    a = stack.pop()
                    if token == '!':
                        # Factorial is the only postfix unary operator
                        stack.append(f"({a})!")
                    else:
                        stack.append(f"{unary_names.get(token, token)}({a})")
                else: # Binary
                    if len(stack) < 2: return "Invalid"
                    b = stack.pop()
                    a = stack.pop()

                    # Handle A + (-B) -> (A - B)
                    # Handle 0 - B -> (-B)
                    if token == '+' and b.startswith("-") and not b.startswith("(-"):
                        stack.append(f"({a} - {b[1:]})")
                    elif token == '-' and a == "0":
                        stack.append(f"(-{b})")
                    else:
                        stack.append(f"({a} {token} {b})")
            elif token == 'C':
                val = 1.0
                if constants is not None and const_idx < len(constants):
                    val = constants[const_idx].item()
                    const_idx += 1
                stack.append(format_const(val))
            elif token.startswith('x'):
                # A bare 'x' is printed as 'x0'; 'x0', 'x1', ... are kept as-is
                if token == 'x': stack.append("x0")
                else: stack.append(token)
            else:
                stack.append(str(token))

        if len(stack) == 1:
            return stack[0]
        return "Invalid"


    def run(self, x_values: List[float], y_targets: List[float], seeds: List[str], timeout_sec=10, callback=None) -> Optional[str]:
        """
        Main evolutionary loop on GPU.
        callback: function(gen, best_mse, best_rpn, best_consts, is_new_best) -> None
        """
        start_time = time.time()
        
        # 1. Setup Data
        x_t = torch.tensor(x_values, device=self.device, dtype=torch.float32)
        y_t = torch.tensor(y_targets, device=self.device, dtype=torch.float32)
        
        if x_t.ndim > 1: x_t = x_t.flatten() 
        if y_t.ndim > 1: y_t = y_t.flatten()

        # print(f"[GPU Worker] Initializing Tensor Population ({self.pop_size})...")
        
        # --- 0. Target Pattern Detection ("The Sniper") ---
        # Swiftly check for trivial Linear or Geometric patterns
        if x_t.shape[0] > 2:
            try:
                # Prepare X matrix [N, 2] for (slope, intercept)
                X_mat = torch.stack([x_t, torch.ones_like(x_t)], dim=1)
                
                # A. Linear Check (y = mx + c)
                # Solve: X * [m, c] = y
                try:
                    solution = torch.linalg.lstsq(X_mat, y_t).solution
                    m, c = solution[0].item(), solution[1].item()
                    y_pred = m * x_t + c
                    
                    # Check residuals (Relative Error or R2?)
                    # Use Normalized RMSE
                    res_std = torch.std(y_t - y_pred)
                    y_std = torch.std(y_t)
                    if y_std > 1e-9 and (res_std / y_std) < 1e-4:
                        # Found Linear!
                        # print(f"[GPU Sniper] Detected Linear Pattern: {m:.4f}*x + {c:.4f}")
                        if abs(c) < 1e-5: return f"({m:.4f} * x)"
                        return f"(({m:.4f} * x) + {c:.4f})"
                except: pass

                # B. Geometric Check (y = A * e^(Bx) -> log(y) = log(A) + Bx)
                if torch.all(y_t > 0):
                    log_y = torch.log(y_t)
                    solution_g = torch.linalg.lstsq(X_mat, log_y).solution
                    B, log_A = solution_g[0].item(), solution_g[1].item()
                    
                    y_pred_log = B * x_t + log_A
                    res_std_log = torch.std(log_y - y_pred_log)
                    
                    if res_std_log < 1e-4:
                        # Found Geometric!
                        # Formula: exp(log_A + Bx)
                        # print(f"[GPU Sniper] Detected Geometric Pattern.")
                        return f"exp(({B:.4f} * x) + {log_A:.4f})"
            except Exception as e:
                pass

        
        # 2. Initialize Population & Constants
        seed_tensor = self.infix_to_rpn(seeds)
        num_seeds = seed_tensor.shape[0]
        
        population = torch.zeros(self.pop_size, self.max_len, device=self.device, dtype=torch.long)
        pop_constants = torch.randn(self.pop_size, self.max_constants, device=self.device) # Learnable Constants
        
        # Fill seeds
        population[:num_seeds] = seed_tensor
        
        # Fill rest
        if num_seeds > 0:
            remaining = self.pop_size - num_seeds
            src_indices = torch.randint(0, num_seeds, (remaining,), device=self.device)
            population[num_seeds:] = seed_tensor[src_indices]
            
            mutation_mask = torch.rand(remaining, self.max_len, device=self.device) < 0.3
            random_tokens = torch.randint(1, self.grammar.vocab_size, (remaining, self.max_len), device=self.device)
            population[num_seeds:] = torch.where(mutation_mask, random_tokens, population[num_seeds:])
        else:
             print("[GPU Worker] No seeds provided.")
             return None

        best_formula_str = None
        best_rmse = float('inf')
        
        generations = 0
        COMPLEXITY_PENALTY = 0.01
        
        # --- Dynamic Adaptation ("The Thermostat") ---
        stagnation_counter = 0
        current_mutation_rate = 0.15  # Base rate
        current_chaos_rate = 0.01     # Base chaos
        last_improvement_gen = 0

        
        while time.time() - start_time < timeout_sec:
            generations += 1
            
            # A. Evaluate (Standard)
            fitness_rmse = self.evaluate_batch(population, x_t, y_t)
            
            # Calculate Complexity (Length)
            # Penalize longer formulas to encourage simplicity (Occam's Razor)
            lengths = (population != PAD_ID).sum(dim=1).float()
            # fitness = rmse * (1 + penalty * length) + length * epsilon (for 0-rmse ties)
            fitness_penalized = fitness_rmse * (1.0 + COMPLEXITY_PENALTY * lengths) + lengths * 1e-6
            
            # B. Constant Optimization (Elitism)
            # Pick Top 50 candidates based on PENALIZED fitness to refine
            k_opt = 50
            top_vals, top_indices = torch.topk(fitness_penalized, k_opt, largest=False)
            
            # Run Gradient Descent on these constants
            refined_consts, refined_mse = self.optimize_constants(
                population[top_indices], 
                pop_constants[top_indices], 
                x_t, y_t, steps=10, lr=0.1
            )
            
            # Write back results
            pop_constants[top_indices] = refined_consts.detach()
            fitness_rmse[top_indices] = refined_mse.detach()
            
            # Re-calculate penalized fitness for optimized ones
            refined_lengths = lengths[top_indices]
            fitness_penalized[top_indices] = refined_mse.detach() * (1.0 + COMPLEXITY_PENALTY * refined_lengths) + refined_lengths * 1e-6
            
            # --- Algebraic Simplification (The Cleaner) ---
            # Every 5 generations, simplify the top elites to remove clutter (x*1 -> x)
            if generations % 5 == 0:
                try:
                    import sympy
                    # Simplify the optimization candidates (which are already best)
                    # We operate in-place on the population
                    for idx_in_top in range(len(top_indices)):
                        pop_idx = top_indices[idx_in_top]
                        
                        # 1. Decode to string (with optimized constants)
                        rpn = population[pop_idx].unsqueeze(0)
                        consts = pop_constants[pop_idx]
                        expr_str = self.rpn_to_infix(rpn, consts)
                        
                        if expr_str == "Invalid": continue
                        
                        # 2. Simplify with SymPy
                        try:
                            # Parse and simplify
                            sym_expr = sympy.sympify(expr_str)
                            simplified_sym = sympy.simplify(sym_expr)
                            
                            # 3. Re-encode to RPN + Constants
                            new_rpn_ids, new_consts_vals = self.sympy_to_rpn(simplified_sym)
                            
                            # Update if valid and fits
                            if len(new_rpn_ids) <= self.max_len:
                                # Overwrite population
                                population[pop_idx] = torch.tensor(new_rpn_ids + [PAD_ID]*(self.max_len - len(new_rpn_ids)), device=self.device)
                                
                                # Overwrite constants
                                new_c_tensor = torch.zeros(self.max_constants, device=self.device)
                                num_c = min(len(new_consts_vals), self.max_constants)
                                if num_c > 0:
                                    new_c_tensor[:num_c] = torch.tensor(new_consts_vals[:num_c], device=self.device)
                                pop_constants[pop_idx] = new_c_tensor
                                
                                # Note: Fitness needs update? 
                                # Simplification should preserve semantics, so RMSE is same. 
                                # But length might decrease, so fitness improves!
                                # Let's re-eval next generation or now?
                                # For safety, we leave it. It will be re-evaluated next gen or by selection if we updated lengths.
                                lengths[pop_idx] = len(new_rpn_ids) # Approximate update
                                
                        except Exception as e:
                            # print(f"Simplification failed for {expr_str}: {e}")
                            pass
                except ImportError:
                    pass

            # Check Best (based on Raw RMSE, but maybe Length matters for user? stick to RMSE)
            min_rmse, min_idx = torch.min(fitness_rmse, dim=0)
            if min_rmse.item() < best_rmse:
                best_rmse = min_rmse.item()
                best_rpn = population[min_idx].unsqueeze(0)
                best_consts_vec = pop_constants[min_idx]
                best_formula_str = self.rpn_to_infix(best_rpn, best_consts_vec)
                # print(f"[GPU Worker] New Best: {best_formula_str} (RMSE: {best_rmse:.5f})")
                
                if callback:
                    callback(generations, best_rmse, best_rpn, best_consts_vec, True)
                
                # Reset Stagnation
                stagnation_counter = 0
                current_mutation_rate = 0.15
                current_chaos_rate = 0.01
                last_improvement_gen = generations
            else:
                stagnation_counter += 1
                
            if callback and (generations % 100 == 0 or generations == 1) and best_rpn is not None:
                 # Pass current global best
                 callback(generations, best_rmse, best_rpn, best_consts_vec, False) # False = not new best, just update
                 
            # Adaptation Logic


                
            # Adaptation Logic
            if stagnation_counter > 20:
                # Boost Mutation/Chaos incrementally
                current_mutation_rate = min(0.40, current_mutation_rate + 0.02)
                current_chaos_rate = min(0.05, current_chaos_rate + 0.005)
                
            # --- Island Cataclysm (Nuclear Reset) ---
            if stagnation_counter >= 50:
                 # print(f"[GPU Worker] CATACLYSM! Global Stagnation {stagnation_counter}. Resetting population.")
                 # Keep Top 1 (min_idx)
                 # We need to construct a new population where index 0 is best, rest random.
                 
                 # 1. Save Best
                 saved_best_rpn = population[min_idx].clone()
                 saved_best_c = pop_constants[min_idx].clone()
                 
                 # 2. Randomize All
                 population = torch.randint(1, self.grammar.vocab_size, (self.pop_size, self.max_len), device=self.device)
                 pop_constants = torch.randn(self.pop_size, self.max_constants, device=self.device)
                 
                 # 3. Restore Best at 0
                 population[0] = saved_best_rpn
                 pop_constants[0] = saved_best_c
                 
                 # 4. Reset Stats
                 stagnation_counter = 0
                 current_mutation_rate = 0.15
                 current_chaos_rate = 0.01
                 
                 # Force re-eval? Next loop will evaluate.


            
            # C. Island Selection & Tournament (Vectorized)
            # 1. Reshape to [NumIslands, IslandSize]
            view_fit = fitness_penalized.view(self.n_islands, self.island_size)
            view_pop = population.view(self.n_islands, self.island_size, self.max_len)
            view_const = pop_constants.view(self.n_islands, self.island_size, self.max_constants)

            # 2. Elitism per Island
            k_elite_island = max(1, int(self.island_size * 0.1))
            # topk returns indices relative to the island
            elite_vals, elite_local_idx = torch.topk(view_fit, k_elite_island, dim=1, largest=False)

            # Gather Elites
            # Expansion for gather: [Islands, K, L]
            gather_idx_pop = elite_local_idx.unsqueeze(-1).expand(-1, -1, self.max_len)
            elites_pop = torch.gather(view_pop, 1, gather_idx_pop)
            
            gather_idx_c = elite_local_idx.unsqueeze(-1).expand(-1, -1, self.max_constants)
            elites_c = torch.gather(view_const, 1, gather_idx_c)

            # 3. Tournament for Offspring
            num_offspring = self.island_size - k_elite_island
            
            # Generate random pairs of indices [Islands, NumOffspring]
            p1_idx = torch.randint(0, self.island_size, (self.n_islands, num_offspring), device=self.device)
            p2_idx = torch.randint(0, self.island_size, (self.n_islands, num_offspring), device=self.device)
            
            # Compare fitness
            f1 = torch.gather(view_fit, 1, p1_idx)
            f2 = torch.gather(view_fit, 1, p2_idx)
            
            winner_idx = torch.where(f1 < f2, p1_idx, p2_idx)
            
            # Gather Winners
            gather_idx_win_pop = winner_idx.unsqueeze(-1).expand(-1, -1, self.max_len)
            winners_pop = torch.gather(view_pop, 1, gather_idx_win_pop)
            
            gather_idx_win_c = winner_idx.unsqueeze(-1).expand(-1, -1, self.max_constants)
            winners_c = torch.gather(view_const, 1, gather_idx_win_c)
            
            # 4. Migration (Every 10 gens, Ring Topology)
            # Inject neighbor's elites into the worst slots of current offspring
            if generations % 10 == 0 and self.n_islands > 1:
                # Rotate elites: Island i gets elites from i-1 (or i+1 if we roll pos)
                migrants_pop = torch.roll(elites_pop, shifts=1, dims=0)
                migrants_c = torch.roll(elites_c, shifts=1, dims=0)
                
                # Replace last k_elite spots in WINNERS (weakest offspring? actually tournament winners are random quality)
                # But acceptable to just replace.
                if num_offspring >= k_elite_island:
                    winners_pop[:, -k_elite_island:] = migrants_pop
                    winners_c[:, -k_elite_island:] = migrants_c
            
            # D. Mutation (On Offspring Only)
            
            # 1. Safe Arity-Preserving Mutation (Dynamic Rate)
            mask = torch.rand(winners_pop.shape, device=self.device) < current_mutation_rate
            current_arities = self.token_arity[winners_pop]
            
            # Arity 0 -> Arity 0
            if len(self.arity_0_ids) > 0:
                noise_0 = self.arity_0_ids[torch.randint(0, len(self.arity_0_ids), winners_pop.shape, device=self.device)]
                winners_pop = torch.where(mask & (current_arities == 0), noise_0, winners_pop)
                
            # Arity 1 -> Arity 1
            if len(self.arity_1_ids) > 0:
                noise_1 = self.arity_1_ids[torch.randint(0, len(self.arity_1_ids), winners_pop.shape, device=self.device)]
                winners_pop = torch.where(mask & (current_arities == 1), noise_1, winners_pop)
                
            # Arity 2 -> Arity 2
            if len(self.arity_2_ids) > 0:
                noise_2 = self.arity_2_ids[torch.randint(0, len(self.arity_2_ids), winners_pop.shape, device=self.device)]
                winners_pop = torch.where(mask & (current_arities == 2), noise_2, winners_pop)
                
            # 2. Chaos Mutation (Structure changing, Low Rate: 1%)
            chaos_mask = torch.rand(winners_pop.shape, device=self.device) < current_chaos_rate
            chaos_noise = torch.randint(1, self.grammar.vocab_size, winners_pop.shape, device=self.device)
            winners_pop = torch.where(chaos_mask, chaos_noise, winners_pop)
            
            # Constant Mutation
            c_noise = torch.randn_like(winners_c) * 0.1
            winners_c = winners_c + c_noise
            
            # 5. Reconstruct Population
            # Concat [Elites, Offspring] -> [Islands, Size, L]
            next_pop_view = torch.cat([elites_pop, winners_pop], dim=1)
            next_c_view = torch.cat([elites_c, winners_c], dim=1)
            
            # Flatten to [PopSize, L]
            population = next_pop_view.view(self.pop_size, self.max_len)
            pop_constants = next_c_view.view(self.pop_size, self.max_constants)
            
        print(f"[GPU Worker] Finished. Gens: {generations}. Best RMSE: {best_rmse:.5f}")
        return best_formula_str

    def sympy_to_rpn(self, sym_expr) -> Tuple[List[int], List[float]]:
        """
        Convert a SymPy expression into postfix (RPN) token IDs plus constants.

        The expression tree is walked depth-first; operands are emitted before
        their operator. N-ary ``Add``/``Mul`` nodes are flattened into a chain
        of binary operators (``A + B + C`` -> ``A B + C +``).

        Args:
            sym_expr: SymPy expression to convert.

        Returns:
            ``(rpn_ids, constants)`` where ``constants`` holds, in emission
            order, the value of every generic ``'C'`` token in ``rpn_ids``.

        Raises:
            ValueError: for a node type that has no matching grammar token, or
                a string-matched function that is not unary.
            KeyError: if the grammar lacks a token this converter needs
                (e.g. ``'C'``, ``'+'``, ``'pow'``).
        """
        import sympy

        token_ids = self.grammar.token_to_id
        # A named constant token is only used when the grammar explicitly lists
        # it as a terminal; otherwise the numeric value is emitted through 'C'.
        # This avoids mistaking an operator spelled 'e' for Euler's number.
        terminal_names = getattr(self.grammar, 'terminals', ())
        literal_tokens = ('1', '2', '3', '5')
        unary_functions = (
            (sympy.sin, 'sin'),
            (sympy.cos, 'cos'),
            (sympy.tan, 'tan'),
            (sympy.exp, 'exp'),
            (sympy.log, 'log'),
        )

        rpn_ids = []
        constants = []

        def emit_constant(value):
            """Emit a generic constant token and record its value."""
            rpn_ids.append(token_ids['C'])
            constants.append(value)

        def emit_chain(args, op_token):
            """Emit an n-ary node as a left-to-right chain of binary operators."""
            visit(args[0])
            for operand in args[1:]:
                visit(operand)
                rpn_ids.append(token_ids[op_token])

        def visit(node):
            # pi and E are NumberSymbols: their is_Number flag is False, so they
            # must be recognised before the plain-number branch.
            if node == sympy.pi or node == sympy.E:
                name = 'pi' if node == sympy.pi else 'e'
                if name in terminal_names and name in token_ids:
                    rpn_ids.append(token_ids[name])
                else:
                    emit_constant(float(node))
                return

            if node.is_Number:
                val = float(node)
                for literal in literal_tokens:
                    if val == float(literal) and literal in token_ids:
                        rpn_ids.append(token_ids[literal])
                        return
                # Anything else becomes a generic constant.
                emit_constant(val)
                return

            if node.is_Symbol:
                name = str(node)
                if name in token_ids:
                    rpn_ids.append(token_ids[name])
                else:
                    # Unknown variable name: fall back to 'x' (or PAD id 0).
                    rpn_ids.append(token_ids.get('x', 0))
                return

            if isinstance(node, sympy.Add):
                emit_chain(node.args, '+')
                return
            if isinstance(node, sympy.Mul):
                emit_chain(node.args, '*')
                return
            if isinstance(node, sympy.Pow):
                visit(node.base)
                visit(node.exp)
                rpn_ids.append(token_ids['pow'])
                return

            for func_cls, op_token in unary_functions:
                if isinstance(node, func_cls):
                    visit(node.args[0])
                    rpn_ids.append(token_ids[op_token])
                    return

            # Last resort: match the function by its printed name.
            func_name = str(node.func)
            if func_name not in token_ids:
                # Raising lets the caller abort the conversion.
                raise ValueError(f"Unknown node: {node}")
            if len(node.args) != 1:
                # Only unary functions can be emitted here; silently dropping
                # extra arguments would change the meaning of the expression.
                raise ValueError(f"Unsupported arity for '{func_name}': {node}")
            visit(node.args[0])
            rpn_ids.append(token_ids[func_name])

        visit(sym_expr)
        return rpn_ids, constants


In [None]:
%%writefile AlphaSymbolic/core/__init__.py


In [None]:
%%writefile AlphaSymbolic/core/gpu/__init__.py
from .engine import TensorGeneticEngine


In [None]:
%%writefile AlphaSymbolic/core/gpu/benchmark.py
"""
Performance Benchmarking Suite for GPU GP Engine.

Standard benchmark problems for evaluating symbolic regression performance.
"""
import torch
import numpy as np
import time
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass


@dataclass
class BenchmarkResult:
    """Outcome of running one benchmark problem through an engine."""
    problem_name: str  # Key of the problem in BenchmarkSuite.PROBLEMS
    target_formula: str  # Formula string used to generate the target data
    found_formula: Optional[str]  # Value returned by engine.run (stored as-is)
    rmse: float  # RMSE of the found formula on the sample points; inf if not computable
    exact_match: bool  # True when rmse < 1e-6 (set by BenchmarkSuite.run_benchmark)
    time_seconds: float  # Wall-clock duration of the engine.run call
    generations: int  # BenchmarkSuite.run_benchmark currently always stores 0 here


class BenchmarkSuite:
    """
    Standard symbolic regression benchmark problems.

    Includes problems from:
    - Nguyen benchmark suite
    - Keijzer benchmarks
    - Custom problems

    Every completed run is appended to ``self.results`` so that
    ``get_summary`` / ``print_report`` can aggregate across runs.
    """

    # Standard benchmark problems: name -> (formula, x_range, n_points)
    PROBLEMS = {
        # Nguyen benchmarks
        'nguyen-1': ('x^3 + x^2 + x', (-1, 1), 20),
        'nguyen-2': ('x^4 + x^3 + x^2 + x', (-1, 1), 20),
        'nguyen-3': ('x^5 + x^4 + x^3 + x^2 + x', (-1, 1), 20),
        'nguyen-4': ('x^6 + x^5 + x^4 + x^3 + x^2 + x', (-1, 1), 20),
        'nguyen-5': ('sin(x^2)*cos(x) - 1', (-1, 1), 20),
        'nguyen-6': ('sin(x) + sin(x + x^2)', (-1, 1), 20),
        'nguyen-7': ('log(x+1) + log(x^2+1)', (0, 2), 20),
        'nguyen-8': ('sqrt(x)', (0, 4), 20),

        # Keijzer benchmarks (simpler)
        'keijzer-1': ('x^3/5 + x^2/2 - x', (-3, 3), 20),
        'keijzer-4': ('x^3 * exp(-x) * cos(x) * sin(x)', (0, 10), 20),

        # Simple polynomials
        'poly-1': ('x^2', (-5, 5), 20),
        'poly-2': ('x^3 - 2*x', (-3, 3), 20),
        'poly-3': ('2*x^2 + 3*x + 1', (-5, 5), 20),

        # Trigonometric
        'trig-1': ('sin(x)', (-3.14, 3.14), 20),
        'trig-2': ('cos(x)*sin(x)', (-3.14, 3.14), 20),

        # Mixed
        'mixed-1': ('x*sin(x)', (-5, 5), 20),
        'mixed-2': ('sqrt(x)*log(x+1)', (0.1, 10), 20),
    }

    def __init__(self, engine_factory):
        """
        Args:
            engine_factory: Function that creates a TensorGeneticEngine instance
        """
        self.engine_factory = engine_factory
        self.results: List["BenchmarkResult"] = []

    @staticmethod
    def _evaluate_formula(formula: str, x: float) -> float:
        """Evaluate ``formula`` at one point, with ``^`` meaning power.

        NOTE(security): this uses ``eval`` and Python builtins stay reachable,
        so it is NOT a sandbox. Only feed it trusted strings (the PROBLEMS
        table or formulas produced by the engine).
        """
        import math

        namespace = {
            "x": x, "x0": x,
            "sin": math.sin, "cos": math.cos, "tan": math.tan,
            "exp": math.exp, "log": math.log, "sqrt": math.sqrt,
            "abs": abs, "pi": math.pi, "e": math.e,
        }
        return float(eval(formula.replace('^', '**'), namespace))

    def generate_data(self, formula: str, x_range: Tuple[float, float], n_points: int) -> Tuple[List[float], List[float]]:
        """Generate x,y data from a formula string.

        ``n_points`` x values are spaced evenly over ``x_range`` (inclusive).
        A point where the formula cannot be evaluated (domain error, unknown
        name, ...) gets ``y = 0.0``.
        """
        x_vals = np.linspace(x_range[0], x_range[1], n_points).tolist()
        y_vals = []

        for x in x_vals:
            try:
                y_vals.append(self._evaluate_formula(formula, x))
            except Exception:
                # Best effort: keep the sample count stable.
                y_vals.append(0.0)

        return x_vals, y_vals

    def run_benchmark(self, problem_name: str, timeout_sec: float = 10) -> "BenchmarkResult":
        """Run a single benchmark problem.

        Args:
            problem_name: Key in ``PROBLEMS``.
            timeout_sec: Passed through to ``engine.run``.

        Returns:
            The result, which is also appended to ``self.results``.

        Raises:
            ValueError: if ``problem_name`` is not a known problem.
        """
        if problem_name not in self.PROBLEMS:
            raise ValueError(f"Unknown problem: {problem_name}")

        formula, x_range, n_points = self.PROBLEMS[problem_name]
        x_vals, y_vals = self.generate_data(formula, x_range, n_points)

        engine = self.engine_factory()

        start_time = time.perf_counter()
        result = engine.run(x_vals, y_vals, [], timeout_sec=timeout_sec)
        elapsed = time.perf_counter() - start_time

        # Calculate RMSE of found solution; stays infinite if nothing usable.
        rmse = float('inf')
        if result:
            found_y = []
            for x in x_vals:
                try:
                    found_y.append(self._evaluate_formula(result, x))
                except Exception:
                    # An unevaluable point makes the whole error infinite.
                    found_y.append(float('inf'))
            try:
                mse = sum((a - b) ** 2 for a, b in zip(y_vals, found_y)) / len(y_vals)
                rmse = mse ** 0.5
            except ArithmeticError:
                # Overflow while squaring, or zero sample points.
                rmse = float('inf')

        # "Exact" is approximated numerically, not by symbolic comparison.
        exact_match = rmse < 1e-6

        bench_result = BenchmarkResult(
            problem_name=problem_name,
            target_formula=formula,
            found_formula=result,
            rmse=rmse,
            exact_match=exact_match,
            time_seconds=elapsed,
            generations=0  # Would need to track in engine
        )

        self.results.append(bench_result)
        return bench_result

    def run_suite(self, problem_names: Optional[List[str]] = None, timeout_sec: float = 10,
                  callback=None) -> Dict[str, "BenchmarkResult"]:
        """
        Run a suite of benchmark problems.

        Args:
            problem_names: List of problems to run (default: all)
            timeout_sec: Timeout per problem
            callback: Optional progress callback, called with a status string

        Returns:
            Dict mapping problem name to result
        """
        if problem_names is None:
            problem_names = list(self.PROBLEMS.keys())

        results = {}
        total = len(problem_names)
        for position, name in enumerate(problem_names, start=1):
            if callback:
                callback(f"Running {name} ({position}/{total})")

            results[name] = self.run_benchmark(name, timeout_sec)

        return results

    def get_summary(self) -> Dict:
        """Get summary statistics of benchmark results.

        Returns an empty dict when nothing has been run. ``avg_rmse`` averages
        only the finite RMSE values and is NaN when there are none.
        """
        if not self.results:
            return {}

        n_total = len(self.results)
        n_exact = sum(1 for r in self.results if r.exact_match)
        finite_rmse = [r.rmse for r in self.results if r.rmse < float('inf')]
        # Guard the empty case explicitly: np.mean([]) warns and yields NaN.
        avg_rmse = float(np.mean(finite_rmse)) if finite_rmse else float('nan')
        avg_time = float(np.mean([r.time_seconds for r in self.results]))

        return {
            'n_problems': n_total,
            'n_exact_matches': n_exact,
            'success_rate': n_exact / n_total * 100,
            'avg_rmse': avg_rmse,
            'avg_time_seconds': avg_time,
        }

    def print_report(self):
        """Print a formatted benchmark report."""
        print("\n" + "="*70)
        print("GPU GP ENGINE BENCHMARK REPORT")
        print("="*70)

        for result in self.results:
            status = "✓" if result.exact_match else "✗"
            print(f"\n{status} {result.problem_name}")
            print(f"  Target: {result.target_formula}")
            print(f"  Found:  {result.found_formula or 'None'}")
            print(f"  RMSE:   {result.rmse:.6e}")
            print(f"  Time:   {result.time_seconds:.2f}s")

        summary = self.get_summary()
        print("\n" + "-"*70)
        print(f"SUMMARY: {summary.get('n_exact_matches', 0)}/{summary.get('n_problems', 0)} exact matches")
        print(f"Success Rate: {summary.get('success_rate', 0):.1f}%")
        print(f"Average RMSE: {summary.get('avg_rmse', 0):.6e}")
        print(f"Average Time: {summary.get('avg_time_seconds', 0):.2f}s")
        print("="*70)


def create_benchmark_suite(device=None, pop_size=1000):
    """Build a BenchmarkSuite whose engines share the given device and population size.

    Each benchmark run gets a fresh TensorGeneticEngine configured with
    4 islands.
    """
    from . import TensorGeneticEngine

    return BenchmarkSuite(
        lambda: TensorGeneticEngine(device=device, pop_size=pop_size, n_islands=4)
    )


In [None]:
%%writefile AlphaSymbolic/core/gpu/config.py
import math

class GpuGlobals:
    # ============================================================
    #                  PARÁMETROS GLOBALES
    # ============================================================

    # ----------------------------------------
    # Datos del Problema (Regresión Simbólica)
    # ----------------------------------------
    USE_LOG_TRANSFORMATION = True

    # ----------------------------------------
    # Configuración General del Algoritmo Genético
    # ----------------------------------------
    FORCE_CPU_MODE = False # Si es True, usa CPU aunque CUDA esté disponible
    
    # Tamaño de población - MÁXIMO para RTX 3050 (4GB VRAM)
    POP_SIZE = 5000       # Agresivo - usa ~3GB VRAM
    GENERATIONS = 100000  # Más generaciones
    NUM_ISLANDS = 8       # Máxima diversidad
    MIN_POP_PER_ISLAND = 10

    # --- Fórmula Inicial ---
    USE_INITIAL_FORMULA = False
    INITIAL_FORMULA_STRING = "log((x1+exp((((((1.28237193+((x0+2.59195138)+8.54688985))*x0)+(log((((x2/-0.99681346)-(x0-8.00219939))/(0.35461932-x2)))+(x0+(88.95319019/((x0+x0)+x0)))))-x1)/((exp(exp(((exp(x2)*(1.39925709/x0))^exp(x0))))+0.76703064)*6.05423753)))))"

    # ----------------------------------------
    # Parámetros del Modelo de Islas
    # ----------------------------------------
    MIGRATION_INTERVAL = 100
    MIGRATION_SIZE = 50

    # ----------------------------------------
    # Parámetros de Generación Inicial de Árboles
    # ----------------------------------------
    MAX_TREE_DEPTH_INITIAL = 8
    TERMINAL_VS_VARIABLE_PROB = 0.75
    CONSTANT_MIN_VALUE = -10.0
    CONSTANT_MAX_VALUE = 10.0
    CONSTANT_INT_MIN_VALUE = -10
    CONSTANT_INT_MAX_VALUE = 10
    USE_HARD_DEPTH_LIMIT = True
    MAX_TREE_DEPTH_HARD_LIMIT = 30  # MÁXIMO - expresiones muy complejas

    # ----------------------------------------
    # Parámetros de Operadores Genéticos (Configuración de Operadores)
    # ----------------------------------------
    USE_OP_PLUS     = True
    USE_OP_MINUS    = True
    USE_OP_MULT     = True
    USE_OP_DIV      = True
    USE_OP_POW      = True
    USE_OP_MOD      = False
    USE_OP_SIN      = False
    USE_OP_COS      = False
    USE_OP_LOG      = True
    USE_OP_EXP      = True
    USE_OP_FACT     = False
    USE_OP_FLOOR    = False
    USE_OP_GAMMA    = True
    USE_OP_ASIN     = False
    USE_OP_ACOS     = False
    USE_OP_ATAN     = False

    # Pesos de Operadores (Order: +, -, *, /, ^, %, s, c, l, e, !, _, g, S, C, T)
    OPERATOR_WEIGHTS = [
        0.20 * (1.0 if USE_OP_PLUS else 0.0),
        0.20 * (1.0 if USE_OP_MINUS else 0.0),
        0.20 * (1.0 if USE_OP_MULT else 0.0),
        0.15 * (1.0 if USE_OP_DIV else 0.0),
        0.10 * (1.0 if USE_OP_POW else 0.0),
        0.02 * (1.0 if USE_OP_MOD else 0.0),
        0.10 * (1.0 if USE_OP_SIN else 0.0),
        0.10 * (1.0 if USE_OP_COS else 0.0),
        0.05 * (1.0 if USE_OP_LOG else 0.0),
        0.05 * (1.0 if USE_OP_EXP else 0.0),
        0.01 * (1.0 if USE_OP_FACT else 0.0),
        0.01 * (1.0 if USE_OP_FLOOR else 0.0),
        0.01 * (1.0 if USE_OP_GAMMA else 0.0),
        0.01 * (1.0 if USE_OP_ASIN else 0.0),
        0.01 * (1.0 if USE_OP_ACOS else 0.0),
        0.01 * (1.0 if USE_OP_ATAN else 0.0)
    ]

    # ----------------------------------------
    # Parámetros de Operadores Genéticos (Mutación, Cruce, Selección)
    # ----------------------------------------
    BASE_MUTATION_RATE = 0.30
    BASE_ELITE_PERCENTAGE = 0.15
    DEFAULT_CROSSOVER_RATE = 0.60
    DEFAULT_TOURNAMENT_SIZE = 4
    MAX_TREE_DEPTH_MUTATION = 8
    MUTATE_INSERT_CONST_PROB = 0.6
    MUTATE_INSERT_CONST_INT_MIN = 1
    MUTATE_INSERT_CONST_INT_MAX = 5
    MUTATE_INSERT_CONST_FLOAT_MIN = 0.5
    MUTATE_INSERT_CONST_FLOAT_MAX = 5.0

    # ----------------------------------------
    # Parámetros de Fitness y Evaluación
    # ----------------------------------------
    COMPLEXITY_PENALTY = 0.01
    USE_RMSE_FITNESS = True
    FITNESS_ORIGINAL_POWER = 1.3
    FITNESS_PRECISION_THRESHOLD = 0.001
    FITNESS_PRECISION_BONUS = 0.0001
    FITNESS_EQUALITY_TOLERANCE = 1e-9
    EXACT_SOLUTION_THRESHOLD = 1e-8

    # ----------------------------------------
    # Fitness Ponderado (Weighted Fitness)
    # ----------------------------------------
    USE_WEIGHTED_FITNESS = False
    WEIGHTED_FITNESS_EXPONENT = 0.25

    # ----------------------------------------
    # Parámetros de Características Avanzadas
    # ----------------------------------------
    STAGNATION_LIMIT = 50
    GLOBAL_STAGNATION_LIMIT = 100
    STAGNATION_RANDOM_INJECT_PERCENT = 0.1
    PARAM_MUTATE_INTERVAL = 50
    PATTERN_RECORD_FITNESS_THRESHOLD = 10.0
    PATTERN_MEM_MIN_USES = 3
    PATTERN_INJECT_INTERVAL = 10
    PATTERN_INJECT_PERCENT = 0.05
    PARETO_MAX_FRONT_SIZE = 50
    
    SIMPLIFY_NEAR_ZERO_TOLERANCE = 1e-9
    SIMPLIFY_NEAR_ONE_TOLERANCE = 1e-9
    LOCAL_SEARCH_ATTEMPTS = 30
    
    USE_SIMPLIFICATION = True
    USE_ISLAND_CATACLYSM = True
    USE_LEXICASE_SELECTION = True
    USE_PARETO_SELECTION = True  # NSGA-II multi-objective (error vs complexity)
    USE_WEIGHTED_FITNESS = False  # Enable to weight fitness cases (e.g., by difficulty)

    # ----------------------------------------
    # Otros Parámetros
    # ----------------------------------------
    PROGRESS_REPORT_INTERVAL = 100
    FORCE_INTEGER_CONSTANTS = False
    
    # Control de Duplicados
    PREVENT_DUPLICATES = True
    DUPLICATE_RETRIES = 10
    INF = float('inf')


In [None]:
%%writefile AlphaSymbolic/core/gpu/engine.py

import torch
import numpy as np
import time
from typing import List, Tuple, Optional
from core.grammar import OPERATORS, VARIABLES, CONSTANTS, ExpressionTree
from .formatting import format_const
from .sniper import Sniper
from .config import GpuGlobals
from .pareto import ParetoOptimizer
from .pattern_memory import PatternMemory

# SymPy for simplification
try:
    import sympy
    from sympy import symbols, sympify, simplify, nsimplify, Float
    from sympy.parsing.sympy_parser import parse_expr
    SYMPY_AVAILABLE = True
except ImportError:
    SYMPY_AVAILABLE = False

# --- GPU GRAMMAR ENCODING (RPN / Postfix) ---
PAD_ID = 0

class GPUGrammar:
    """Token vocabulary for postfix (RPN) programs.

    IDs are handed out sequentially: PAD first, then terminals, then the
    operators enabled through GpuGlobals, then the always-on operators.
    """

    def __init__(self, num_variables=1):
        self.token_to_id = {'<PAD>': PAD_ID}
        self.id_to_token = {PAD_ID: '<PAD>'}
        self.next_id = 1

        # Terminals: variables followed by the constant placeholder and literals.
        if num_variables > 1:
            self.active_variables = [f'x{i}' for i in range(num_variables)]
        elif num_variables == 1:
            self.active_variables = ['x', 'x0']
        else:
            self.active_variables = ['x0']  # x0 is always supported

        # pi and e are deliberately absent to avoid token collisions.
        self.terminals = self.active_variables + ['C', '1', '2', '3', '5']
        self._register(self.terminals)

        # Operators that can be switched on/off, in vocabulary order.
        # NOTE(review): 'C' (acos) is spelled like the constant terminal 'C';
        # enabling USE_OP_ACOS would overwrite token_to_id['C'] - confirm intent.
        switchable = (
            ('+', GpuGlobals.USE_OP_PLUS),
            ('-', GpuGlobals.USE_OP_MINUS),
            ('*', GpuGlobals.USE_OP_MULT),
            ('/', GpuGlobals.USE_OP_DIV),
            ('pow', GpuGlobals.USE_OP_POW),
            ('%', GpuGlobals.USE_OP_MOD),
            ('sin', GpuGlobals.USE_OP_SIN),
            ('cos', GpuGlobals.USE_OP_COS),
            ('log', GpuGlobals.USE_OP_LOG),
            ('e', GpuGlobals.USE_OP_EXP),
            ('!', GpuGlobals.USE_OP_FACT),
            ('g', GpuGlobals.USE_OP_GAMMA),
            ('S', GpuGlobals.USE_OP_ASIN),
            ('C', GpuGlobals.USE_OP_ACOS),
            ('T', GpuGlobals.USE_OP_ATAN),
        )
        self.operators = [symbol for symbol, enabled in switchable if enabled]
        # These have no toggle and are always registered ('_' is floor;
        # USE_OP_FLOOR is not consulted).
        self.operators.extend(['sqrt', 'abs', 'neg', '_'])
        self._register(self.operators)

        self.vocab_size = self.next_id

        self.op_ids = {op: self.token_to_id[op] for op in self.operators}
        # Arity per operator token, taken from the shared OPERATORS table.
        self.token_arity = {op: OPERATORS[op] for op in self.operators}

    def _register(self, tokens):
        """Assign the next free IDs to ``tokens``, in order."""
        for token in tokens:
            self.token_to_id[token] = self.next_id
            self.id_to_token[self.next_id] = token
            self.next_id += 1

    def get_subtree_span(self, rpn_ids: List[int], root_idx: int) -> Tuple[int, int]:
        """
        Locate the subtree that ends at ``root_idx`` in a postfix sequence.

        Walks backwards from the root, tracking how many operands are still
        owed. Returns the inclusive ``(start, end)`` indices, or ``(-1, -1)``
        when ``root_idx`` is out of range or the sequence runs out before all
        operands are found.
        """
        length = len(rpn_ids)
        pending = 1  # subtrees still to be consumed, the root included
        cursor = root_idx

        while pending > 0:
            if cursor < 0 or cursor >= length:
                return (-1, -1)

            token_id = rpn_ids[cursor]
            if token_id == PAD_ID:
                arity = 0  # padding behaves like a leaf
            else:
                name = self.id_to_token.get(token_id, "")
                arity = self.token_arity.get(name, 0)

            # This token completes one pending subtree and owes `arity` more.
            pending += arity - 1
            cursor -= 1

        return (cursor + 1, root_idx)

class TensorGeneticEngine:
    def __init__(self, device=None, pop_size=None, max_len=30, num_variables=1, max_constants=5, n_islands=None):
        """
        Set up the grammar, population geometry, arity tables and helpers.

        Args:
            device: Torch device to use. When omitted, the CPU is used if
                ``GpuGlobals.FORCE_CPU_MODE`` is set; otherwise CUDA when
                available, else the CPU.
            pop_size: Total population (default ``GpuGlobals.POP_SIZE``). It is
                rounded down to a multiple of ``n_islands``.
            max_len: Length of each RPN program.
            num_variables: Number of input variables known to the grammar.
            max_constants: Constant slots stored per individual.
            n_islands: Number of islands (default ``GpuGlobals.NUM_ISLANDS``).
        """
        if device:
            self.device = device
        elif GpuGlobals.FORCE_CPU_MODE:
            # Honour the config switch that forces the CPU even if CUDA exists.
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Defaults from Globals
        if pop_size is None: pop_size = GpuGlobals.POP_SIZE
        if n_islands is None: n_islands = GpuGlobals.NUM_ISLANDS

        self.grammar = GPUGrammar(num_variables)

        self.n_islands = n_islands
        # Keep islands equally sized by dropping the remainder.
        if pop_size % n_islands != 0:
            pop_size = (pop_size // n_islands) * n_islands

        self.pop_size = pop_size
        self.island_size = pop_size // n_islands
        self.max_len = max_len
        self.num_variables = num_variables
        self.max_constants = max_constants

        # Pre-allocate memory for random generation
        self.terminal_ids = torch.tensor([self.grammar.token_to_id[t] for t in self.grammar.terminals], device=self.device)
        self.operator_ids = torch.tensor([self.grammar.token_to_id[op] for op in self.grammar.operators], device=self.device)

        # --- Pre-compute arity tables for arity-preserving mutation ---
        # token_arity is indexed by token id; unlisted ids (e.g. PAD) stay 0.
        self.token_arity = torch.zeros(self.grammar.vocab_size + 1, dtype=torch.long, device=self.device)
        arity_0, arity_1, arity_2 = [], [], []

        # Terminals have arity 0.
        for t in self.grammar.terminals:
            tid = self.grammar.token_to_id[t]
            self.token_arity[tid] = 0
            arity_0.append(tid)

        # Operators have arity 1 or 2 according to the OPERATORS table.
        for op in self.grammar.operators:
            tid = self.grammar.token_to_id[op]
            arity = OPERATORS[op]
            self.token_arity[tid] = arity
            if arity == 1: arity_1.append(tid)
            elif arity == 2: arity_2.append(tid)

        # Explicit integer dtype so an empty group is still an index tensor
        # (torch.tensor([]) would default to float).
        self.arity_0_ids = torch.tensor(arity_0, dtype=torch.long, device=self.device)
        self.arity_1_ids = torch.tensor(arity_1, dtype=torch.long, device=self.device)
        self.arity_2_ids = torch.tensor(arity_2, dtype=torch.long, device=self.device)

        # The Sniper
        self.sniper = Sniper(self.device)

        # Pareto Optimizer (NSGA-II)
        self.pareto = ParetoOptimizer(self.device, GpuGlobals.PARETO_MAX_FRONT_SIZE)

        # Pattern Memory
        self.pattern_memory = PatternMemory(
            self.device,
            max_patterns=100,
            fitness_threshold=GpuGlobals.PATTERN_RECORD_FITNESS_THRESHOLD,
            min_uses=GpuGlobals.PATTERN_MEM_MIN_USES
        )


    def optimize_constants(self, population: torch.Tensor, constants: torch.Tensor, x: torch.Tensor, y_target: torch.Tensor, steps=10, lr=0.1):
        """
        Tune every individual's constants with a few Adam steps.

        The passed ``constants`` tensor is left untouched; a copy is trained.
        Because a step may make things worse, the best constants seen for each
        individual are tracked and returned.

        Returns:
            (best_constants, best_rmse): the best constants per individual and
            the matching error, computed as sqrt of the best squared RMSE
            (``inf`` for an individual that never produced a non-NaN error).
        """
        trainable = constants.clone().detach().requires_grad_(True)
        adam = torch.optim.Adam([trainable], lr=lr)

        # Per-individual record of the lowest squared error and its constants.
        n_individuals = population.shape[0]
        lowest_sq_err = torch.full((n_individuals,), float('inf'), device=self.device, dtype=torch.float64)
        best_snapshot = constants.clone().detach()

        for _step in range(steps):
            adam.zero_grad()

            # evaluate_batch yields one RMSE per individual; gradients are
            # expected to flow back to `trainable` through it.
            rmse = self.evaluate_batch(population, x, y_target, trainable)

            # Individuals whose error is NaN are ignored for this step.
            usable = torch.isnan(rmse).logical_not()
            if not usable.any():
                break

            # Remember any individual that just improved.
            sq_err = rmse ** 2
            better = usable & (sq_err < lowest_sq_err)
            if better.any():
                lowest_sq_err[better] = sq_err[better].detach()
                best_snapshot[better] = trainable[better].detach()

            # One summed loss optimizes all individuals in parallel.
            total_loss = rmse[usable].sum()

            if not total_loss.requires_grad:
                # No constant influences the result (or the graph was detached).
                break

            total_loss.backward()
            adam.step()

        return best_snapshot, torch.sqrt(lowest_sq_err)

    def local_search(self, population: torch.Tensor, constants: torch.Tensor, 
                     x: torch.Tensor, y: torch.Tensor, 
                     top_k: int = 10, attempts: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Hill climbing: try random mutations on top individuals, keep improvements.

        Every attempt mutates the ORIGINAL individual (not the best mutant so
        far); the best-scoring mutant replaces it only if it beats the original.
        The inputs are not modified; updated copies are returned.

        Args:
            population: [PopSize, L] RPN tensors
            constants: [PopSize, MaxC] constants
            x: Input data
            y: Target data  
            top_k: Number of top individuals to apply local search to. Values
                larger than the population are clamped to the population size.
            attempts: Number of mutation attempts per individual (default: LOCAL_SEARCH_ATTEMPTS)
        
        Returns:
            (improved_population, improved_constants)
        """
        if attempts is None:
            attempts = GpuGlobals.LOCAL_SEARCH_ATTEMPTS

        pop_out = population.clone()
        const_out = constants.clone()

        # torch.topk raises when k exceeds the number of elements, so clamp.
        k = max(0, min(top_k, population.shape[0]))
        if k == 0:
            return pop_out, const_out

        # Lower fitness is better: select the k smallest.
        fitness = self.evaluate_batch(population, x, y, constants)
        _, top_idx = torch.topk(fitness, k, largest=False)

        for idx in top_idx.tolist():
            # Slices keep the batch dimension expected by the batch helpers.
            current_rpn = population[idx:idx+1]
            current_const = constants[idx:idx+1]
            current_fit = fitness[idx].item()

            best_rpn = current_rpn
            best_fit = current_fit

            for _ in range(attempts):
                mutant = self.mutate_population(current_rpn, mutation_rate=0.15)
                mutant_fit = self.evaluate_batch(mutant, x, y, current_const)[0].item()

                if mutant_fit < best_fit:
                    best_rpn = mutant.clone()
                    best_fit = mutant_fit

            # Update if improved
            if best_fit < current_fit:
                pop_out[idx] = best_rpn[0]
                # The structure changed, so re-fit its constants briefly.
                opt_const, _ = self.optimize_constants(best_rpn, current_const.clone(), x, y, steps=5)
                const_out[idx] = opt_const[0]

        return pop_out, const_out

    def simplify_expression(self, rpn_tensor: torch.Tensor, constants: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """
        Simplify an RPN expression using SymPy.

        Pipeline: RPN -> infix string -> SymPy parse -> simplify -> nsimplify
        (rationalize constants) -> string -> RPN. Any failure along the way
        returns the original inputs unchanged.
        
        Args:
            rpn_tensor: [L] tensor of token IDs (single individual)
            constants: [MaxC] tensor of constant values
        
        Returns:
            (simplified_rpn, new_constants, success). When ``success`` is
            False the first two items are the inputs themselves.
        """
        import re

        unchanged = (rpn_tensor, constants, False)
        if not SYMPY_AVAILABLE or not GpuGlobals.USE_SIMPLIFICATION:
            return unchanged

        try:
            # 1. Convert RPN to infix string
            infix = self.rpn_to_infix(rpn_tensor, constants)
            if infix == "Invalid" or not infix:
                return unchanged

            # 2. Prepare SymPy symbols ('x' is an alias of x0)
            sym_vars = {f'x{i}': symbols(f'x{i}') for i in range(self.num_variables)}
            sym_vars['x'] = sym_vars.get('x0', symbols('x0'))

            # 3. Parse to SymPy (translate our operator spellings first)
            expr_str = infix.replace('^', '**').replace('lgamma', 'loggamma')
            try:
                expr = parse_expr(expr_str, local_dict=sym_vars)
            except Exception:
                # Fallback parser
                expr = sympify(expr_str, locals=sym_vars)

            # 4. Simplify, then 5. rationalize constants (e.g., 0.5 -> 1/2)
            simplified = simplify(expr)
            simplified = nsimplify(simplified, tolerance=1e-6, rational=True)

            # 6. Back to a string
            simplified_str = str(simplified)

            # 7. If simplification made it noticeably longer, abort
            if len(simplified_str) > len(infix) * 1.5:
                return unchanged

            # 8. Convert to our format (** -> ^, etc.)
            simplified_str = simplified_str.replace('**', ' ^ ')
            simplified_str = simplified_str.replace('loggamma', 'lgamma')

            # 9. Convert back to RPN
            new_rpn = self.infix_to_rpn([simplified_str])
            if new_rpn.shape[0] == 0 or (new_rpn[0] == PAD_ID).all():
                return unchanged

            # 10. Recover constant values; unfilled slots stay at zero.
            new_consts = torch.zeros(self.max_constants, device=self.device, dtype=torch.float64)

            # Number of 'C' placeholders in the new RPN
            id_C = self.grammar.token_to_id.get('C', -1)
            n_consts = (new_rpn[0] == id_C).sum().item()

            # Numeric literals in textual order. The lookbehind skips digits
            # that belong to an identifier (the "0" in "x0") or that continue
            # a number, so variable indices are not mistaken for constants.
            # NOTE(review): assumes the i-th literal matches the i-th 'C' of
            # infix_to_rpn; literals that map to their own tokens would still
            # shift the alignment - confirm against infix_to_rpn.
            numbers = re.findall(r'(?<![\w.])[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?', simplified_str)
            for i, num in enumerate(numbers[:min(n_consts, self.max_constants)]):
                new_consts[i] = float(num)

            return new_rpn[0], new_consts, True

        except Exception:
            # Simplification is best-effort: keep the original on any failure.
            return unchanged

    def simplify_population(self, population: torch.Tensor, constants: torch.Tensor, top_k: int = None) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Run ``simplify_expression`` over the leading rows of the population.

        Only the first ``top_k`` rows are visited; presumably the caller passes
        a population already ordered best-first (TODO confirm against caller).
        Rows whose simplification fails are left unchanged.

        Args:
            population: [PopSize, L] RPN tensors
            constants: [PopSize, MaxC] constant tensors
            top_k: Number of leading rows to try (default: 10% of the
                population, at least 1)

        Returns:
            (new_population, new_constants, n_simplified). When simplification
            is unavailable or disabled the inputs are returned as-is with 0.
        """
        simplification_enabled = SYMPY_AVAILABLE and GpuGlobals.USE_SIMPLIFICATION
        if not simplification_enabled:
            return population, constants, 0

        total = population.shape[0]
        limit = top_k if top_k is not None else max(1, int(total * 0.1))
        limit = min(limit, total)

        # Work on copies so the caller's tensors are never modified.
        simplified_pop = population.clone()
        simplified_consts = constants.clone()
        successes = 0

        for row in range(limit):
            rpn, consts, ok = self.simplify_expression(population[row], constants[row])
            if not ok:
                continue
            simplified_pop[row] = rpn
            simplified_consts[row] = consts
            successes += 1

        return simplified_pop, simplified_consts, successes

    def migrate_islands(self, population: torch.Tensor, constants: torch.Tensor, fitness: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Ring migration: every island sends its best individuals to the next one.

        Island ``k`` occupies rows ``[k * island_size, (k + 1) * island_size)``.
        Its lowest-fitness rows are copied over the highest-fitness rows of
        island ``(k + 1) % n_islands``. Migrants are always read from the
        input tensors, so one island's migration never feeds another's within
        the same call.

        Args:
            population: [PopSize, L] RPN tensors
            constants: [PopSize, MaxC] constant tensors
            fitness: [PopSize] fitness scores (lower is better)

        Returns:
            (new_population, new_constants). With a single island the inputs
            are returned unchanged (not copied).
        """
        if self.n_islands <= 1:
            return population, constants

        n_islands = self.n_islands
        per_island = self.island_size
        # Never move more than half of an island in a single migration.
        n_migrants = min(GpuGlobals.MIGRATION_SIZE, per_island // 2)

        migrated_pop = population.clone()
        migrated_consts = constants.clone()

        for src in range(n_islands):
            dst = (src + 1) % n_islands
            src_offset = src * per_island
            dst_offset = dst * per_island

            # Emigrants: the lowest-fitness rows of the source island.
            src_scores = fitness[src_offset:src_offset + per_island]
            emigrants = torch.topk(src_scores, n_migrants, largest=False).indices + src_offset

            # Slots to overwrite: the highest-fitness rows of the destination.
            dst_scores = fitness[dst_offset:dst_offset + per_island]
            evicted = torch.topk(dst_scores, n_migrants, largest=True).indices + dst_offset

            migrated_pop[evicted] = population[emigrants]
            migrated_consts[evicted] = constants[emigrants]

        return migrated_pop, migrated_consts

    def deduplicate_population(self, population: torch.Tensor, constants: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Replace repeated individuals with freshly generated random ones.

        Two rows count as duplicates when their token sequences are equal
        after all PAD tokens are dropped; constants are not compared. The
        first occurrence is kept and every later one is overwritten with a
        row from ``_generate_random_population`` and normally distributed
        constants.

        Args:
            population: [PopSize, L] RPN tensors
            constants: [PopSize, MaxC] constant tensors

        Returns:
            (new_population, new_constants, n_replaced). If the feature is
            disabled or nothing is duplicated, the inputs are returned as-is.
        """
        if not GpuGlobals.PREVENT_DUPLICATES:
            return population, constants, 0

        rows = population.cpu().numpy()

        # Signature of a row = tuple of its non-PAD tokens.
        signatures = set()
        repeated_rows = []
        for row_idx, row in enumerate(rows):
            signature = tuple(row[row != PAD_ID].tolist())
            if signature in signatures:
                repeated_rows.append(row_idx)
            else:
                signatures.add(signature)

        n_repeated = len(repeated_rows)
        if n_repeated == 0:
            return population, constants, 0

        # One fresh random individual (and constant vector) per duplicate slot.
        replacement_pop = self._generate_random_population(n_repeated)
        replacement_consts = torch.randn(n_repeated, constants.shape[1], device=self.device, dtype=torch.float64)

        deduped_pop = population.clone()
        deduped_consts = constants.clone()
        # Indexing with a plain Python list selects the duplicate rows.
        deduped_pop[repeated_rows] = replacement_pop
        deduped_consts[repeated_rows] = replacement_consts

        return deduped_pop, deduped_consts, n_repeated

    def tarpeian_control(self, population: torch.Tensor, fitness: torch.Tensor) -> torch.Tensor:
        """
        Tarpeian bloat control: give some oversized individuals a terrible fitness.

        An individual is oversized when its non-PAD token count exceeds 1.5x
        the population mean. Each oversized individual is penalized with
        probability 0.5 by setting its fitness to 1e30.

        Args:
            population: [PopSize, L] RPN tensors
            fitness: [PopSize] current fitness values

        Returns:
            A copy of ``fitness`` with the penalties applied; the input is not
            modified.
        """
        n_individuals = population.shape[0]

        token_counts = (population != PAD_ID).sum(dim=1).float()
        mean_count = token_counts.mean()
        too_long = token_counts > mean_count * 1.5

        # Independent fair coin per individual; only oversized ones can lose it.
        coin_flip = torch.rand(n_individuals, device=self.device) < 0.5
        to_penalize = too_long & coin_flip

        penalized = fitness.clone()
        penalized[to_penalize] = 1e30  # Large but within float32 range
        return penalized

    def shrink_mutation(self, individual: torch.Tensor) -> torch.Tensor:
        """
        Shrink mutation: collapse one randomly chosen subtree into a terminal.

        A random operator token is picked, the span of the subtree rooted at
        it is obtained from ``grammar.get_subtree_span``, and that span is
        replaced by a single terminal drawn from ``self.terminal_ids``. The
        result is re-padded with PAD to ``self.max_len``.

        Returns:
            A new tensor with the mutated individual, or ``individual`` itself
            when it has fewer than 3 non-PAD tokens, contains no operator, or
            the span lookup reports -1.
        """
        tokens = individual.cpu().numpy()
        tokens = tokens[tokens != PAD_ID]
        if len(tokens) < 3:
            return individual

        # Candidate roots: every position that holds an operator token.
        id_to_token = self.grammar.id_to_token
        op_slots = [
            pos for pos, tok_id in enumerate(tokens)
            if id_to_token.get(tok_id, "") in self.grammar.operators
        ]
        if not op_slots:
            return individual

        root = np.random.choice(op_slots)
        span = self.grammar.get_subtree_span(tokens.tolist(), root)
        if span[0] == -1:
            return individual

        # Splice: prefix before the subtree + one terminal + suffix after the root.
        pick = torch.randint(len(self.terminal_ids), (1,))
        terminal = self.terminal_ids[pick].item()
        shrunk = list(tokens[:span[0]])
        shrunk.append(terminal)
        shrunk.extend(tokens[root + 1:])

        shrunk = shrunk[:self.max_len]
        shrunk += [PAD_ID] * (self.max_len - len(shrunk))

        return torch.tensor(shrunk, device=self.device, dtype=individual.dtype)


    def mutate_population(self, population: torch.Tensor, mutation_rate: float) -> torch.Tensor:
        """
        Arity-preserving point mutation over the whole population.

        Each non-PAD token is selected independently with probability
        ``mutation_rate`` and replaced by a token drawn uniformly from the
        pool of the same arity (``arity_0_ids`` / ``arity_1_ids`` /
        ``arity_2_ids``). Swapping like for like keeps the RPN structure
        intact. Tokens whose pool is empty are left unchanged.

        Args:
            population: [PopSize, L] RPN tensors
            mutation_rate: per-token mutation probability

        Returns:
            A new [PopSize, L] tensor; the input is not modified.
        """
        # Positions chosen for mutation (padding is never touched).
        mutate_here = torch.rand_like(population, dtype=torch.float32) < mutation_rate
        mutate_here = mutate_here & (population != PAD_ID)

        # Arity of the token currently at each position (table lookup by id).
        arity_of_token = self.token_arity[population]

        result = population.clone()
        pools = ((0, self.arity_0_ids), (1, self.arity_1_ids), (2, self.arity_2_ids))
        for arity, pool in pools:
            if len(pool) > 0:
                # A full-shape draw is wasteful but keeps everything vectorized.
                draw = torch.randint(0, len(pool), population.shape, device=self.device)
                candidates = pool[draw]
            else:
                candidates = population
            result = torch.where(mutate_here & (arity_of_token == arity), candidates, result)

        return result

    def _get_subtree_ranges(self, population: torch.Tensor) -> torch.Tensor:
        """
        Calculates the start index of the subtree ending at each position.
        Returns tensor [B, L] where value is start_index, or -1 if invalid/padding.
        Optimized for RPN logic on GPU.
        """
        B, L = population.shape
        subtree_starts = torch.full((B, L), -1, device=self.device, dtype=torch.long)
        
        # 1. Map tokens to Arity Change
        # Variables/Consts: +1
        # Binary: -1 (Pop 2 Push 1 -> Net -1)
        # Unary: 0 (Pop 1 Push 1 -> Net 0)
        
        # We need a fast lookup.
        # Construct Arity Table
        # Default 1 (Operand)
        arities = torch.ones_like(population, dtype=torch.long) 
        
        # Binary (-1)
        op_add = self.grammar.token_to_id.get('+', -100); op_sub = self.grammar.token_to_id.get('-', -100)
        op_mul = self.grammar.token_to_id.get('*', -100); op_div = self.grammar.token_to_id.get('/', -100)
        op_pow = self.grammar.token_to_id.get('pow', -100); op_mod = self.grammar.token_to_id.get('%', -100)
        
        mask_bin = (population == op_add) | (population == op_sub) | (population == op_mul) | \
                   (population == op_div) | (population == op_pow) | (population == op_mod)
        arities[mask_bin] = -1
        
        # Unary (0)
        op_sin = self.grammar.token_to_id.get('sin', -100); op_cos = self.grammar.token_to_id.get('cos', -100)
        # ... (add all unary)
        # Simplified: If it's not binary and it's an operator, it's unary?
        # Better: Assume all Ops are <= some ID? No.
        # Explicit list is safer.
        unary_tokens = ['sin','cos','tan','S','C','T','e','log','sqrt','abs','neg','!','_','g']
        unary_ids = [self.grammar.token_to_id.get(t, -999) for t in unary_tokens]
        # Make tensor of unary ids
        # Ideally this table is precomputed. For now, on the fly is ok.
        
        mask_unary = torch.zeros_like(population, dtype=torch.bool)
        for uid in unary_ids:
             mask_unary = mask_unary | (population == uid)
        arities[mask_unary] = 0
        
        # Padding: -999 (Invalid)
        arities[population == PAD_ID] = -999
        
        # 2. To find start of subtree ending at 'end', we scan backwards until cumulative sum is +1.
        # Since scanning backwards is hard in vector, we can scan loops?
        # Max depth isn't huge.
        
        # Alternatively: "Stack Depth at step i".
        # Subtree at 'end' corresponds to the interval [start, end] where
        # Depth(start-1) = D
        # Depth(end) = D + 1
        # And min_depth(start...end) >= D
        
        # Let's compute cumulative sum (Stack Depth Profile)
        # cum_arity[i] is depth AFTER processing token i.
        # arities for this: PAD=0? No, PAD breaks it.
        # Let's mask PAD for cumsum.
        
        safe_arities = arities.clone()
        safe_arities[population == PAD_ID] = 0
        depths = torch.cumsum(safe_arities, dim=1)
        
        # The scan logic is still tricky O(L^2) across batch.
        # For L=30, iterating i from 0 to L is fast.
        
        for i in range(L):
            # If position i is PAD, skip
            is_pad = (population[:, i] == PAD_ID)
            
            # We want to find 'start' such that sum(arities[start...i]) == 1
            # Which means depths[i] - depths[start-1] == 1
            # implies depths[start-1] = depths[i] - 1.
            # And for all k in start...i, depths[k] >= depths[start-1] (validity).
            
            target_depth = depths[:, i] - 1
            
            # Search backwards from i
            # We can vectorize this search over B by iterating j downwards
            current_start = torch.full((B,), -1, device=self.device, dtype=torch.long)
            found = torch.zeros(B, dtype=torch.bool, device=self.device)
            
            # Optimization: Pre-calculate validity masks?
            # Brute force backwards for L=30 is fine.
            for j in range(i, -1, -1):
                # Check condition for batch
                # d[j-1] == target?
                # Actually, depth[j-1] is depth BEFORE processing j.
                # If j=0, depth[-1] = 0.
                
                prev_depth = depths[:, j-1] if j > 0 else torch.zeros(B, device=self.device)
                
                # Match condition: prev_depth == target_depth
                match = (prev_depth == target_depth)
                
                # Check if we violated lower bound in between?
                # Implicitly, if we hit match first time going backwards, it's the minimal subtree.
                # We update 'found' mask.
                
                new_found = match & (~found)
                current_start[new_found] = j
                found = found | new_found
            
            # Store valid starts for this end position 'i'
            # Only valid if not PAD and found
            valid_i = (~is_pad) & found
            subtree_starts[valid_i, i] = current_start[valid_i]
            
        return subtree_starts

    def crossover_population(self, parents: torch.Tensor, crossover_rate: float) -> torch.Tensor:
        """
        Two-child subtree crossover over randomly paired parents.

        The population is shuffled and the first
        ``int(pop_size * 0.5 * crossover_rate)`` consecutive pairs of the
        shuffle exchange one random subtree each. Subtree boundaries come from
        ``_get_subtree_ranges``; the splicing itself is done on CPU with NumPy.
        A child that would not fit in ``self.max_len`` tokens is dropped and
        its parent is kept in that slot.

        Args:
            parents: [PopSize, L] RPN tensors (assumed valid RPN)
            crossover_rate: fraction of the possible pairs to cross

        Returns:
            New [PopSize, L] long tensor on ``self.device``.
        """
        pop_size, _ = parents.shape
        shuffled = torch.randperm(pop_size, device=self.device)

        n_pairs = int(pop_size * 0.5 * crossover_rate)
        if n_pairs == 0:
            return parents.clone()

        # Subtree starts for every row, then move everything to CPU for splicing.
        starts_np = self._get_subtree_ranges(parents).detach().cpu().numpy()
        parents_np = parents.detach().cpu().numpy()
        order = shuffled.detach().cpu().numpy()
        children = parents_np.copy()
        max_len = self.max_len

        def splice(host, host_start, host_end, donor, donor_start, donor_end):
            """Swap host[host_start..host_end] for donor[donor_start..donor_end]; None if too long."""
            gene = np.concatenate([host[:host_start], donor[donor_start:donor_end + 1], host[host_end + 1:]])
            keep = len(gene)
            if keep > max_len:
                # Trailing PAD may account for the excess: measure the real length.
                used = np.where(gene != PAD_ID)[0]
                keep = 0 if len(used) == 0 else used[-1] + 1
                if keep > max_len:
                    return None
            fitted = np.full(max_len, PAD_ID, dtype=host.dtype)
            fitted[:keep] = gene[:keep]
            return fitted

        for pair in range(n_pairs):
            ia = order[2 * pair]
            ib = order[2 * pair + 1]
            parent_a, parent_b = parents_np[ia], parents_np[ib]
            starts_a, starts_b = starts_np[ia], starts_np[ib]

            # Valid crossover roots are the positions with a known subtree start.
            roots_a = np.where(starts_a != -1)[0]
            roots_b = np.where(starts_b != -1)[0]
            if len(roots_a) == 0 or len(roots_b) == 0:
                continue

            end_a = np.random.choice(roots_a)
            end_b = np.random.choice(roots_b)
            begin_a = starts_a[end_a]
            begin_b = starts_b[end_b]

            # Child 1: A with its subtree replaced by B's.
            child_a = splice(parent_a, begin_a, end_a, parent_b, begin_b, end_b)
            if child_a is not None:
                children[ia] = child_a

            # Child 2: B with its subtree replaced by A's (both read the originals).
            child_b = splice(parent_b, begin_b, end_b, parent_a, begin_a, end_a)
            if child_b is not None:
                children[ib] = child_b

        return torch.tensor(children, device=self.device, dtype=torch.long)




    def infix_to_rpn(self, formulas: List[str]) -> torch.Tensor:
        """
        Converts a list of infix strings to a padded RPN tensor [B, L].

        Each formula is parsed with ``ExpressionTree.from_infix`` and its tree
        is emitted in post-order. A formula becomes an all-PAD row (the
        "invalid" marker callers check for) when:

        * parsing raises or the tree reports ``is_valid`` False,
        * the expression needs more than ``self.max_len`` tokens (cutting a
          post-order sequence would drop its root and leave broken RPN), or
        * it contains a token missing from the grammar vocabulary (such a
          token used to become a PAD hole in the middle of the sequence).

        Args:
            formulas: infix expression strings

        Returns:
            [len(formulas), max_len] long tensor on ``self.device``
            (shape [0, max_len] for an empty list).
        """
        invalid_row = [PAD_ID] * self.max_len
        token_to_id = self.grammar.token_to_id

        batch_rpn = []
        for f in formulas:
            ids = None
            try:
                tree = ExpressionTree.from_infix(f)
                if tree.is_valid:
                    rpn_tokens = []

                    def traverse(node):
                        # Post-order walk: children first, then the node itself.
                        if not node:
                            return
                        for child in node.children:
                            traverse(child)
                        rpn_tokens.append(node.value)

                    traverse(tree.root)
                    ids = [token_to_id.get(t) for t in rpn_tokens]
            except Exception:
                # Unparseable formula: fall through to the invalid marker.
                ids = None

            if ids is None or len(ids) > self.max_len or any(i is None for i in ids):
                batch_rpn.append(list(invalid_row))
            else:
                batch_rpn.append(ids + [PAD_ID] * (self.max_len - len(ids)))

        if not batch_rpn:
            return torch.empty((0, self.max_len), device=self.device, dtype=torch.long)
        return torch.tensor(batch_rpn, device=self.device, dtype=torch.long)


    def rpn_to_infix(self, rpn_tensor: torch.Tensor, constants: torch.Tensor = None) -> str:
        """
        Decode one RPN individual into a fully parenthesized infix string.

        Tokens are replayed on a stack of strings. PAD tokens are skipped,
        'C' tokens consume ``constants`` in order (1.0 once they run out or
        when none are given), and 'x' is written as "x0". Single-letter
        operator aliases are expanded to readable function names.

        Args:
            rpn_tensor: token IDs; tensors with more than one dim are flattened
            constants: optional 1-D tensor of constant values

        Returns:
            The infix string, or "Invalid" when an operator lacks operands or
            the stack does not end with exactly one entry.
        """
        if rpn_tensor.ndim > 1:
            rpn_tensor = rpn_tensor.view(-1)

        # Unary aliases -> printed function name; anything else prints as itself.
        unary_names = {
            's': 'sin', 'c': 'cos', 'l': 'log', 'e': 'exp', 'q': 'sqrt',
            'a': 'abs', 'n': 'sign', '_': 'floor', '!': 'gamma', 'g': 'lgamma',
            'S': 'asin', 'C': 'acos', 'T': 'atan',
        }

        id_to_token = self.grammar.id_to_token
        operands = []
        next_const = 0

        for tok_id in rpn_tensor.tolist():
            if tok_id == PAD_ID:
                continue

            token = id_to_token.get(tok_id, "")

            # ---- Operands -------------------------------------------------
            if token not in self.grammar.operators:
                if token == 'C':
                    value = 1.0
                    if constants is not None and next_const < len(constants):
                        value = constants[next_const].item()
                        next_const += 1
                    operands.append(format_const(value))
                elif token.startswith('x'):
                    operands.append("x0" if token == 'x' else token)
                else:
                    operands.append(str(token))
                continue

            # ---- Unary operators ------------------------------------------
            if self.grammar.token_arity.get(token, 2) == 1:
                if not operands:
                    return "Invalid"
                arg = operands.pop()
                operands.append(f"{unary_names.get(token, token)}({arg})")
                continue

            # ---- Binary operators -----------------------------------------
            if len(operands) < 2:
                return "Invalid"
            rhs = operands.pop()
            lhs = operands.pop()

            if token == '+' and rhs.startswith("-"):
                # a + (-b) reads better as a - b
                expr = f"({lhs} - {rhs[1:]})"
            elif token == '-' and lhs == "0":
                # 0 - b is printed as a negation
                expr = f"(-{rhs})"
            elif token == 'pow':
                expr = f"({lhs} ^ {rhs})"
            elif token == 'mod':
                expr = f"({lhs} % {rhs})"
            else:
                expr = f"({lhs} {token} {rhs})"
            operands.append(expr)

        return operands[0] if len(operands) == 1 else "Invalid"
    
    def get_tree_size(self, rpn_tensor: torch.Tensor) -> int:
        """
        Count the tokens of an RPN tensor that are not padding.
        """
        non_pad_mask = rpn_tensor.ne(PAD_ID)
        return int(non_pad_mask.sum())
    


    def _run_vm(self, population: torch.Tensor, x: torch.Tensor, constants: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Internal VM interpreter to evaluate RPN population on the GPU.
        Returns: (final_predictions, stack_pointer)
        """
        B, L = population.shape
        D = x.shape[0]
        MAX_STACK = 10
        eff_B = B * D
        
        pop_expanded = population.unsqueeze(1).expand(-1, D, -1).reshape(eff_B, L)
        const_expanded = None
        if constants is not None:
             const_expanded = constants.unsqueeze(1).expand(-1, D, -1).reshape(eff_B, -1)
             
        if x.ndim == 1:
            x_expanded = x.unsqueeze(0).expand(B, -1).reshape(eff_B, 1)
        else:
            x_expanded = x.unsqueeze(0).expand(B, -1, -1).reshape(eff_B, x.shape[1])
            
        stack = torch.zeros(eff_B, MAX_STACK, device=self.device, dtype=torch.float64)
        sp = torch.zeros(eff_B, device=self.device, dtype=torch.long)
        const_counters = torch.zeros(eff_B, device=self.device, dtype=torch.long)
        
        # NEW: Error tracking
        has_error = torch.zeros(eff_B, dtype=torch.bool, device=self.device)
        
        pi_val = torch.tensor(np.pi, device=self.device, dtype=torch.float64)
        e_val = torch.tensor(np.e, device=self.device, dtype=torch.float64)

        # IDs
        id_C = self.grammar.token_to_id.get('C', -100)
        id_pi = self.grammar.token_to_id.get('pi', -100)
        id_e = self.grammar.token_to_id.get('e', -100)
        
        op_add = self.grammar.token_to_id.get('+', -100); op_sub = self.grammar.token_to_id.get('-', -100)
        op_mul = self.grammar.token_to_id.get('*', -100); op_div = self.grammar.token_to_id.get('/', -100)
        op_pow = self.grammar.token_to_id.get('pow', -100); op_mod = self.grammar.token_to_id.get('%', -100)
        op_sin = self.grammar.token_to_id.get('sin', -100); op_cos = self.grammar.token_to_id.get('cos', -100)
        op_tan = self.grammar.token_to_id.get('tan', -100)
        op_asin = self.grammar.token_to_id.get('S', -100); op_acos = self.grammar.token_to_id.get('C', -100); op_atan = self.grammar.token_to_id.get('T', -100)
        op_exp = self.grammar.token_to_id.get('e', -100); op_log = self.grammar.token_to_id.get('log', -100)
        op_sqrt = self.grammar.token_to_id.get('sqrt', -100); op_abs = self.grammar.token_to_id.get('abs', -100); op_neg = self.grammar.token_to_id.get('neg', -100)
        op_fact = self.grammar.token_to_id.get('!', -100); op_floor = self.grammar.token_to_id.get('_', -100); op_gamma = self.grammar.token_to_id.get('g', -100)
        
        var_ids = [self.grammar.token_to_id.get(v, -100) for v in self.grammar.active_variables]
        id_x_legacy = self.grammar.token_to_id.get('x', -100)

        for i in range(L):
            token = pop_expanded[:, i]
            active_mask = (token != PAD_ID)
            if not active_mask.any(): continue
            
            push_vals = torch.zeros(eff_B, device=self.device, dtype=torch.float64)
            is_operand = torch.zeros(eff_B, dtype=torch.bool, device=self.device)
            
            # Variables
            mask = (token == id_x_legacy)
            if mask.any():
                push_vals[mask] = x_expanded[mask, 0]
                is_operand = is_operand | mask
                
            for v_idx, vid in enumerate(var_ids):
                mask = (token == vid)
                if mask.any():
                    v_col = v_idx if v_idx < x_expanded.shape[1] else 0
                    push_vals[mask] = x_expanded[mask, v_col]
                    is_operand = is_operand | mask
            
            mask = (token == id_pi)
            if mask.any(): push_vals[mask] = pi_val; is_operand = is_operand | mask
            mask = (token == id_e)
            if mask.any(): push_vals[mask] = e_val; is_operand = is_operand | mask
                
            mask = (token == id_C)
            if mask.any():
                if const_expanded is not None:
                     safe_idx = torch.clamp(const_counters, 0, const_expanded.shape[1]-1)
                     c_vals = const_expanded.gather(1, safe_idx.unsqueeze(1)).squeeze(1)
                     push_vals[mask] = c_vals[mask]
                     const_counters[mask] += 1
                else:
                     push_vals[mask] = 1.0 
                is_operand = is_operand | mask
            
            for val_str in ['1', '2', '3', '5']:
                vid = self.grammar.token_to_id.get(val_str, -999)
                mask = (token == vid)
                if mask.any():
                    push_vals[mask] = float(val_str)
                    is_operand = is_operand | mask
                    
            if is_operand.any():
                safe_sp = torch.clamp(sp, 0, MAX_STACK-1)
                stack = stack.scatter(1, safe_sp.unsqueeze(1), push_vals.unsqueeze(1))
                sp = sp + is_operand.long()
                
            # Binary
            is_binary = (token == op_add) | (token == op_sub) | (token == op_mul) | (token == op_div) | (token == op_pow) | (token == op_mod)
            
            enough_stack = (sp >= 2)
            valid_op = is_binary & enough_stack
            
            has_error = has_error | (is_binary & ~enough_stack)
            
            if valid_op.any():
                idx_b = torch.clamp(sp - 1, 0, MAX_STACK - 1).unsqueeze(1); val_b = stack.gather(1, idx_b).squeeze(1)
                idx_a = torch.clamp(sp - 2, 0, MAX_STACK - 1).unsqueeze(1); val_a = stack.gather(1, idx_a).squeeze(1)
                res = torch.zeros_like(val_a)
                
                m = (token == op_add) & valid_op; res[m] = val_a[m] + val_b[m]
                m = (token == op_sub) & valid_op; res[m] = val_a[m] - val_b[m]
                m = (token == op_mul) & valid_op; res[m] = val_a[m] * val_b[m]
                m = (token == op_div) & valid_op
                if m.any(): 
                    d = val_b[m]; bad = d.abs() < 1e-9; sd = torch.where(bad, torch.tensor(1.0, device=self.device, dtype=torch.float64), d)
                    out = val_a[m] / sd; out[bad] = 1e150; res[m] = out
                m = (token == op_mod) & valid_op
                if m.any():
                    d = val_b[m]; bad = d.abs() < 1e-9; sd = torch.where(bad, torch.tensor(1.0, device=self.device, dtype=torch.float64), d)
                    out = torch.fmod(val_a[m], sd); out[bad] = 1e150; res[m] = out
                m = (token == op_pow) & valid_op; 
                if m.any(): res[m] = torch.pow(val_a[m], val_b[m])
                
                wp = torch.clamp(sp - 2, 0, MAX_STACK-1)
                curr = stack.gather(1, wp.unsqueeze(1)).squeeze(1)
                fw = torch.where(valid_op, res, curr)
                stack = stack.scatter(1, wp.unsqueeze(1), fw.unsqueeze(1)); sp = sp - valid_op.long()
                
            # Unary
            is_unary = (token == op_sin) | (token == op_cos) | (token == op_tan) | \
                       (token == op_asin) | (token == op_acos) | (token == op_atan) | \
                       (token == op_exp) | (token == op_log) | \
                       (token == op_sqrt) | (token == op_abs) | (token == op_neg) | \
                       (token == op_fact) | (token == op_floor) | (token == op_gamma)
            
            enough_stack = (sp >= 1)
            valid_op = is_unary & enough_stack
            
            has_error = has_error | (is_unary & ~enough_stack)
            
            if valid_op.any():
                idx_a = torch.clamp(sp - 1, 0, MAX_STACK - 1).unsqueeze(1); val_a = stack.gather(1, idx_a).squeeze(1); res = torch.zeros_like(val_a)
                m = (token == op_sin) & valid_op; res[m] = torch.sin(val_a[m])
                m = (token == op_cos) & valid_op; res[m] = torch.cos(val_a[m])
                m = (token == op_tan) & valid_op; res[m] = torch.tan(val_a[m])
                m = (token == op_log) & valid_op
                if m.any(): 
                    inv = val_a[m]; s = inv > 1e-9; out = torch.full_like(inv, 1e150); out[s] = torch.log(inv[s]); res[m] = out
                m = (token == op_exp) & valid_op
                if m.any(): 
                    inv = val_a[m]; s = inv <= 700.0; out = torch.full_like(inv, 1e150); out[s] = torch.exp(inv[s]); res[m] = out
                m = (token == op_sqrt) & valid_op; res[m] = torch.sqrt(val_a[m].abs())
                m = (token == op_abs) & valid_op; res[m] = torch.abs(val_a[m])
                m = (token == op_neg) & valid_op; res[m] = -val_a[m]
                m = (token == op_asin) & valid_op; res[m] = torch.asin(torch.clamp(val_a[m], -1.0, 1.0))
                m = (token == op_acos) & valid_op; res[m] = torch.acos(torch.clamp(val_a[m], -1.0, 1.0))
                m = (token == op_atan) & valid_op; res[m] = torch.atan(val_a[m])
                m = (token == op_floor) & valid_op; res[m] = torch.floor(val_a[m])
                m = (token == op_fact) & valid_op
                if m.any():
                    inv = val_a[m]; u = (inv < 0) | (inv > 170.0); out = torch.full_like(inv, 1e150)
                    si = inv.clone(); si[u] = 1.0; vc = torch.special.gamma(si + 1.0); out[~u] = vc[~u]; res[m] = out
                m = (token == op_gamma) & valid_op
                if m.any():
                    inv = val_a[m]; u = (inv <= -1.0); out = torch.full_like(inv, 1e150)
                    si = inv.clone(); si[u] = 1.0; vc = torch.special.gammaln(si + 1.0); out[~u] = vc[~u]; res[m] = out

                wp = torch.clamp(sp - 1, 0, MAX_STACK-1); curr = stack.gather(1, wp.unsqueeze(1)).squeeze(1)
                fw = torch.where(valid_op, res, curr); stack = stack.scatter(1, wp.unsqueeze(1), fw.unsqueeze(1))
        return stack[:, 0], sp, has_error

    def evaluate_batch(self, population: torch.Tensor, x: torch.Tensor, y_target: torch.Tensor, constants: torch.Tensor = None) -> torch.Tensor:
        """
        Evaluate every RPN individual against the data set and score it by RMSE.

        The heavy lifting is delegated to ``self._run_vm``; this method only
        sanitises the VM output and reduces it to one score per individual.

        Args:
            population: 2-D tensor of token ids, one row per individual.
            x: input samples; ``x.shape[0]`` is taken as the number of samples.
            y_target: target values, broadcast against every individual.
            constants: optional per-individual constants forwarded to the VM.

        Returns:
            Tensor with one RMSE value per individual (``population.shape[0]`` entries).
        """
        # Unpacking two values keeps the original requirement that population is 2-D.
        B, _ = population.shape
        D = x.shape[0]

        final_preds, sp, has_error = self._run_vm(population, x, constants)

        # A program is usable only if it left exactly one value on the stack
        # and never hit a stack underflow while executing.
        is_valid = (sp == 1) & (~has_error)

        # Invalid programs and NaN/Inf predictions get a 1e300 penalty
        # (the original comment states this is for parity with the C++ engine).
        usable = is_valid & ~torch.isnan(final_preds) & ~torch.isinf(final_preds)
        final_preds = torch.where(usable,
                                  final_preds,
                                  torch.tensor(1e300, device=self.device, dtype=torch.float64))

        preds_matrix = final_preds.view(B, D)
        target_matrix = y_target.unsqueeze(0).expand(B, -1)
        mse = torch.mean((preds_matrix - target_matrix)**2, dim=1)

        # Squaring a penalised prediction overflows to Inf; clamp the MSE to a
        # finite sentinel before the square root so the result stays finite.
        rmse = torch.sqrt(torch.where(torch.isnan(mse) | torch.isinf(mse),
                                      torch.tensor(1e150, device=self.device, dtype=torch.float64),
                                      mse))
        return rmse

    def evaluate_differentiable(self, population: torch.Tensor, constants: torch.Tensor, x: torch.Tensor, y_target: torch.Tensor):
        """
        Score the population and return a placeholder second value.

        No gradients are computed here: the second element of the result is an
        all-zero tensor built from ``zeros_like(x)`` expanded to
        ``population.shape[0]`` rows.

        Args:
            population: 2-D tensor of token ids, one row per individual.
            constants: per-individual constants, forwarded to ``evaluate_batch``
                so that 'C' tokens use them instead of the 1.0 fallback.
            x: input samples.
            y_target: target values.

        Returns:
            Tuple ``(rmse, zeros)`` where ``rmse`` comes from ``evaluate_batch``.
        """
        rmse = self.evaluate_batch(population, x, y_target, constants)
        placeholder = torch.zeros_like(x).expand(population.shape[0], -1)
        return rmse, placeholder

    def evaluate_batch_full(self, population: torch.Tensor, x: torch.Tensor, y_target: torch.Tensor, constants: torch.Tensor = None) -> torch.Tensor:
        """
        Return the full absolute-error matrix, one row per individual and one
        column per sample. Used for Lexicase Selection.

        Args:
            population: 2-D tensor of token ids, one row per individual.
            x: input samples; ``x.shape[0]`` is taken as the number of samples.
            y_target: target values, broadcast against every individual.
            constants: optional per-individual constants forwarded to the VM.

        Returns:
            Tensor of shape ``[population.shape[0], x.shape[0]]`` holding
            ``|prediction - target|``, with 1e300 wherever the value is unusable.
        """
        B, _ = population.shape
        D = x.shape[0]

        # _run_vm returns three values (top of stack, stack pointer, error flag);
        # the error flag must take part in the validity check exactly as it
        # does in evaluate_batch.
        final_preds, sp, has_error = self._run_vm(population, x, constants)
        is_valid = (sp == 1) & (~has_error)

        penalty = torch.tensor(1e300, device=self.device, dtype=torch.float64)

        # Penalize invalid/nan/inf predictions with 1e300
        usable = is_valid & ~torch.isnan(final_preds) & ~torch.isinf(final_preds)
        final_preds = torch.where(usable, final_preds, penalty)

        preds_matrix = final_preds.view(B, D)
        target_matrix = y_target.unsqueeze(0).expand(B, -1)
        abs_err = torch.abs(preds_matrix - target_matrix)

        # Guard against Inf/NaN in abs_err (e.g. pred - target where one is huge)
        abs_err = torch.where(torch.isnan(abs_err) | torch.isinf(abs_err), penalty, abs_err)
        return abs_err

    def compute_case_weights(self, errors: torch.Tensor) -> torch.Tensor:
        """
        Derive a per-case weight from how much the population disagrees on it.

        The weight of a case is its error variance across individuals divided
        by the total variance over all cases (plus 1e-9 to avoid dividing by
        zero), so cases with larger spread receive larger weights.

        Args:
            errors: [PopSize, n_cases] error matrix

        Returns:
            [n_cases] tensor of weights
        """
        # Spread of each case's error over the population (dim 0).
        per_case_spread = errors.var(dim=0)

        # Small offset keeps the division defined when every spread is zero.
        normaliser = per_case_spread.sum() + 1e-9

        return per_case_spread / normaliser

    def weighted_rmse(self, errors: torch.Tensor, weights: torch.Tensor = None) -> torch.Tensor:
        """
        Reduce an absolute-error matrix to one RMSE-style score per individual.

        Args:
            errors: [PopSize, n_cases] absolute error matrix
            weights: [n_cases] optional case weights; when omitted, or when
                GpuGlobals.USE_WEIGHTED_FITNESS is off, a plain mean is used.

        Returns:
            [PopSize] score per individual
        """
        # Check `weights is None` first so the global flag is only consulted
        # when weights were actually supplied.
        use_plain_mean = weights is None or not GpuGlobals.USE_WEIGHTED_FITNESS
        squared = errors ** 2

        if use_plain_mean:
            # Unweighted: root of the mean squared error over cases.
            return torch.sqrt(squared.mean(dim=1))

        # Weighted: root of the weight-scaled sum of squared errors.
        scaled = squared * weights.unsqueeze(0)
        return torch.sqrt(scaled.sum(dim=1))

    def lexicase_selection(self, population: torch.Tensor, errors: torch.Tensor, n_select: int) -> torch.Tensor:
        """
        Pick ``n_select`` parents with tournament-style epsilon-lexicase selection.

        For every parent a random pool of 50 indices (drawn with replacement)
        is filtered case by case, in a fresh random case order: only pool
        members whose error is within a tolerance of the pool's best error on
        that case survive. Filtering stops early once a single candidate is
        left; the parent is then drawn at random from the survivors.

        Args:
            population: tensor indexed by individual along dim 0.
            errors: [PopSize, n_cases] absolute error matrix.
            n_select: number of parents to return.

        Returns:
            The rows of ``population`` belonging to the chosen parents.
        """
        pop_size, n_cases = errors.shape
        pool_size = 50

        # The filtering is sequential per case, so it runs on the CPU with NumPy.
        error_table = errors.detach().cpu().numpy()

        winners = []
        for _ in range(n_select):
            # Candidate pool, then the order in which the cases are examined.
            pool = np.random.randint(0, pop_size, pool_size)
            case_order = np.random.permutation(n_cases)

            for case in case_order:
                pool_errors = error_table[pool, case]
                best = np.min(pool_errors)

                # Tolerance: 10% of the best error, never below 1e-9.
                tolerance = max(best * 0.1, 1e-9)

                pool = pool[pool_errors <= (best + tolerance)]
                if len(pool) == 1:
                    break

            # Ties (or duplicate draws of the same index) are broken at random.
            winners.append(np.random.choice(pool))

        return population[winners]

        

    def _generate_random_population(self, size: int) -> torch.Tensor:
        """
        Helper to generate random RPN population of given size.
        """
        formulas = []
        # Generate full population (slower but ensures diversity)
        for _ in range(size):
            try:
                # Generate random valid tree
                tree = ExpressionTree.generate_random(max_depth=GpuGlobals.MAX_TREE_DEPTH_INITIAL, num_variables=self.num_variables)
                formulas.append(tree.get_infix())
            except:
                formulas.append("x0")
        
        # Convert to RPN
        return self.infix_to_rpn(formulas)

    def initialize_population(self) -> torch.Tensor:
        """
        Build the starting population: ``self.pop_size`` random formulas,
        produced by ``_generate_random_population``.
        """
        initial_population = self._generate_random_population(self.pop_size)
        return initial_population

    def cataclysm_population(self, population: torch.Tensor, constants: torch.Tensor, fitness_rmse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Hard reset of the population: the best 10% (lowest RMSE) survive with
        their constants, everyone else is replaced by freshly generated random
        individuals whose constants start at zero.

        Args:
            population: tensor of individuals, indexed along dim 0.
            constants: per-individual constants, indexed along dim 0.
            fitness_rmse: one RMSE per individual; lower is better.

        Returns:
            ``(new_population, new_constants)`` with the survivors first,
            followed by the replacements.
        """
        total = population.shape[0]
        keep = int(total * 0.10)
        replace = total - keep

        # Ascending argsort puts the lowest-RMSE individuals first.
        ranking = torch.argsort(fitness_rmse)
        survivors = ranking[:keep]

        surviving_pop = population[survivors]
        surviving_consts = constants[survivors]

        # Replacements: random formulas paired with zero-initialised constants.
        fresh_pop = self._generate_random_population(replace)
        fresh_consts = torch.zeros((replace, constants.shape[1]), device=self.device, dtype=torch.float64)

        return (
            torch.cat([surviving_pop, fresh_pop], dim=0),
            torch.cat([surviving_consts, fresh_consts], dim=0),
        )

    def detect_patterns(self, targets: List[float]) -> List[str]:
        """
        Detect simple patterns (arithmetic, geometric, constant) in 1-D targets.

        Only the shape of the sequence is checked; the returned seeds use the
        placeholder constant ``C`` so that the optimizer can fit the actual
        values later.

        Args:
            targets: target values in sample order (a list or array-like).

        Returns:
            List of seed formulas in infix form; empty when fewer than three
            targets are given or no pattern matches.
        """
        # Work on a float array: plain Python lists do not support the
        # element-wise division needed for the ratio test below.
        values = np.asarray(targets, dtype=np.float64).ravel()
        if values.size < 3:
            return []

        seeds = []

        # 1. Arithmetic: constant first differences -> y = C + C*x0
        diffs = np.diff(values)
        if np.allclose(diffs, diffs[0], atol=1e-5):
            seeds.append("(C + (C * x0))")

        # 2. Geometric: constant ratio between neighbours -> y = C * C^x0.
        # Skipped when any target is (near) zero, since ratios are undefined.
        if not np.any(np.abs(values) < 1e-9):
            ratios = values[1:] / values[:-1]
            if np.allclose(ratios, ratios[0], atol=1e-5):
                seeds.append("(C * (C ^ x0))")

        # 3. Constant: every target matches the first one
        if np.allclose(values, values[0], atol=1e-5):
            seeds.append("C")

        return seeds

    def run(self, x_values: List[float], y_targets: List[float], seeds: List[str], timeout_sec=10, callback=None) -> Optional[str]:
        """
        Main evolutionary loop.
        """
        start_time = time.time()
        
        # 1. Setup Data
        if GpuGlobals.USE_LOG_TRANSFORMATION:
            print("Info: Log Transformation is ON (Target = ln(Y)).")
            y_np = np.array(y_targets)
            x_np = np.array(x_values)
            mask = y_np > 1e-9 # Parity with C++ log protection
            if not mask.all():
                print(f"Warning: Filtering out {(~mask).sum()} zero or negative data points for log transformation.")
                y_np = y_np[mask]
                x_np = x_np[mask]
            y_targets = np.log(y_np).tolist()
            x_values = x_np.tolist()

        x_t = torch.tensor(x_values, device=self.device, dtype=torch.float64)
        y_t = torch.tensor(y_targets, device=self.device, dtype=torch.float64)
        
        # Flatten only if strictly 1 variable and input is weirdly shaped?
        # If num_variables > 1, we expect x_t to be [N, Vars]
        if self.num_variables == 1:
             if x_t.ndim > 1: x_t = x_t.flatten()
        
        if y_t.ndim > 1: y_t = y_t.flatten()

        # The Sniper
        sniper_res = self.sniper.run(x_values, y_targets)
        if sniper_res: return sniper_res
        
        print("[GPU Worker] Initializing Tensor Population...")
        
        # 2. Init Population
        population = self.initialize_population()
        # population[:, 0] = torch.randint(...) # No longer needed, as we have valid RPNs

 
        
        # Seeds
        if seeds:
            seed_tensors = self.infix_to_rpn(seeds)
            k_seeds = seed_tensors.shape[0]
            if k_seeds > 0:
                population[:k_seeds] = seed_tensors
        
        # --- Pattern Detection ---
        # Detect standard sequences (Arithmetic, Geometric)
        pattern_seeds = self.detect_patterns(y_targets)
        if pattern_seeds:
            print(f"[GPU Worker] Detected patterns: {pattern_seeds}")
            pat_tensors = self.infix_to_rpn(pattern_seeds)
            k_pats = pat_tensors.shape[0]
            if k_pats > 0:
                 # Insert after user seeds
                 offset = len(seeds) if seeds else 0
                 population[offset:offset+k_pats] = pat_tensors

        pop_constants = torch.randn(self.pop_size, self.max_constants, device=self.device, dtype=torch.float64)
        
        # Stats
        best_rmse = float('inf')
        best_rpn = None
        best_consts_vec = None
        
        stagnation_counter = 0
        current_mutation_rate = GpuGlobals.BASE_MUTATION_RATE
        
        generations = 0
        COMPLEXITY_PENALTY = GpuGlobals.COMPLEXITY_PENALTY
        max_generations = GpuGlobals.GENERATIONS

        # Loop until: fitness ~0, OR max generations, OR timeout
        while generations < max_generations:
            # Check timeout (optional, set timeout_sec=None to disable)
            if timeout_sec and (time.time() - start_time) >= timeout_sec:
                print(f"[GPU] Timeout after {generations} generations")
                break
                
            generations += 1
            

            # Eval (Fast Scan)
            fitness_rmse = self.evaluate_batch(population, x_t, y_t, pop_constants)
            
            # --- Constant Optimization (Top K) ---
            # Optimize top 200 candidates to refine their constants
            k_opt = min(self.pop_size, 200)
            
            # Find candidates (using penalized fitness or raw rmse?)
            # Raw RMSE is better for optimization target
            _, top_idx = torch.topk(fitness_rmse, k_opt, largest=False)
            
            # Extract subset
            opt_pop = population[top_idx]
            opt_consts = pop_constants[top_idx]
            
            # Optimize (Gradient Descent)
            # Use fewer steps to keep speed up? 10 is fine.
            refined_consts, refined_mse = self.optimize_constants(
                opt_pop, opt_consts, x_t, y_t, steps=10, lr=0.1
            )
            
            # Update population constants
            pop_constants[top_idx] = refined_consts
            
            # Update fitness for optimized individuals (optional but good for accurate tracking)
            # refined_mse is actually RMSE from the function
            fitness_rmse[top_idx] = refined_mse
            
            # Re-evaluate penalties? Length doesn't change.
            # But we can just leave it for next gen or update fitness_penalized here.
            # Let's update Penalized so Elitism picks the improved versions instantly.
            # We need lengths for these
            # lengths is [PopSize], so we pick top_idx
            # fitness_penalized[top_idx] = refined_mse * (1.0 + lengths[top_idx] * COMPLEXITY_PENALTY)
            
            # Run Adam
            # refined_consts, refined_rmse = self.optimize_constants(opt_pop, opt_consts, x_t, y_t, steps=15)
            
            # Update Population
            # We must update the original tensors.
            # 1. Update Constants
            # pop_constants[top_idx] = refined_consts
            # 2. Update Fitness Scores (Evaluation was implicit in optimize)
            # But wait, fitness_rmse is [PopSize].
            # We update the scores.
            # fitness_rmse[top_idx] = refined_rmse
            
            # --- Selection ---
            lengths = (population != PAD_ID).sum(dim=1).float()
            fitness_penalized = fitness_rmse * (1.0 + COMPLEXITY_PENALTY * lengths) + lengths * 1e-6
            
            # --- Tarpeian Bloat Control ---
            fitness_penalized = self.tarpeian_control(population, fitness_penalized)
            
            # Select Best
            min_rmse, min_idx = torch.min(fitness_rmse, dim=0)
            if min_rmse.item() < best_rmse:
                best_rmse = min_rmse.item()
                best_rpn = population[min_idx].clone()
                best_consts_vec = pop_constants[min_idx].clone()
                best_island_idx = (min_idx.item() // self.island_size) 
                
                if callback:
                    callback(generations, best_rmse, best_rpn, best_consts_vec, True, best_island_idx)
                
                stagnation_counter = 0
                current_mutation_rate = GpuGlobals.BASE_MUTATION_RATE
            else:
                stagnation_counter += 1
            
            if callback and (generations % GpuGlobals.PROGRESS_REPORT_INTERVAL == 0 or generations == 1) and best_rpn is not None:
                 callback(generations, best_rmse, best_rpn, best_consts_vec, False, -1)

            # --- Island Migration ---
            if self.n_islands > 1 and generations % GpuGlobals.MIGRATION_INTERVAL == 0:
                population, pop_constants = self.migrate_islands(population, pop_constants, fitness_rmse)

            # Cataclysm
            if stagnation_counter >= GpuGlobals.STAGNATION_LIMIT:
                 saved_best_rpn = best_rpn.clone()
                 saved_best_c = best_consts_vec.clone()
                 population = self.initialize_population()
                 pop_constants = torch.randn(self.pop_size, self.max_constants, device=self.device, dtype=torch.float64)
                 population[0] = saved_best_rpn
                 pop_constants[0] = saved_best_c
                 stagnation_counter = 0
                 current_mutation_rate = GpuGlobals.BASE_MUTATION_RATE
                 continue
            
            # --- Dynamic Mutation Rate ---
            if stagnation_counter > 10:
                current_mutation_rate = min(0.4, GpuGlobals.BASE_MUTATION_RATE + (stagnation_counter - 10) * 0.01)
            else:
                current_mutation_rate = GpuGlobals.BASE_MUTATION_RATE


            # --- NEW ADVANCED EVOLUTION STEP ---
            # 1. Elitism
            # 2. Crossover (Lexicase Sel)
            # 3. Mutation (Tournament Sel)
            # 4. Uniqueness Check
            
            next_pop_list = []
            next_const_list = []
            
            # 1. Elitism (Top 5% using Pareto or Fitness)
            k_elite = max(1, int(self.pop_size * 0.05))
            
            if GpuGlobals.USE_PARETO_SELECTION:
                # Use NSGA-II to select elite individuals balancing error vs complexity
                complexity = lengths  # Tree size as complexity
                elite_idx = self.pareto.select(population, fitness_rmse, complexity, k_elite)
            else:
                # Standard fitness-based elitism
                _, elite_idx = torch.topk(fitness_penalized, k_elite, largest=False)
            
            elites = population[elite_idx]
            elites_c = pop_constants[elite_idx]
            next_pop_list.append(elites)
            next_const_list.append(elites_c)
            
            remaining_slots = self.pop_size - k_elite
            
            # 2. Crossover (Using Lexicase if costly or standard if fast?)
            # Lexicase is costly. Let's compute FULL errors only if using Lexicase.
            # Use Lexicase for Crossover Parents (Standard GP practice)
            
            # 2. Crossover Parents (GPU Tournament for Speed)
            # 2. Crossover Parents (GPU Tournament for Speed)
            n_crossover = int(remaining_slots * GpuGlobals.DEFAULT_CROSSOVER_RATE)
            n_mutation = remaining_slots - n_crossover
            
            if n_crossover > 0:
                idx_cross = torch.randint(0, self.pop_size, (n_crossover, GpuGlobals.DEFAULT_TOURNAMENT_SIZE), device=self.device)
                best_in_tourn = torch.argmin(fitness_penalized[idx_cross], dim=1)
                global_idx_cross = idx_cross.gather(1, best_in_tourn.unsqueeze(1)).squeeze(1)
                
                # We need PAIRS of parents. This logic selects N individuals.
                # crossover_population internally shuffles and pairs them.
                parents_cross = population[global_idx_cross]
                consts_cross = pop_constants[global_idx_cross]
                
                # Perform crossover (Vectorized)
                off_cross = self.crossover_population(parents_cross, crossover_rate=1.0) # Rate 1.0 because we already selected size
                next_pop_list.append(off_cross)
                next_const_list.append(consts_cross)
            
            # 3. Mutation Parents (Tournament)
            if n_mutation > 0:
                idx_mut = torch.randint(0, self.pop_size, (n_mutation, GpuGlobals.DEFAULT_TOURNAMENT_SIZE), device=self.device)
                best_in_tourn = torch.argmin(fitness_penalized[idx_mut], dim=1)
                global_idx_mut = idx_mut.gather(1, best_in_tourn.unsqueeze(1)).squeeze(1)
                
                parents_mut = population[global_idx_mut]
                consts_mut = pop_constants[global_idx_mut]
                
                off_mut = self.mutate_population(parents_mut, current_mutation_rate)
                next_pop_list.append(off_mut)
                next_const_list.append(consts_mut)
            
            
            # Concatenate
            next_pop = torch.cat(next_pop_list, dim=0)
            next_c = torch.cat(next_const_list, dim=0)
            
            population = next_pop[:self.pop_size]
            pop_constants = next_c[:self.pop_size]

            # --- Deduplication (Aggressive: Every generation to force diversity) ---
            if GpuGlobals.PREVENT_DUPLICATES and generations % 1 == 0:
                population, pop_constants, n_dups = self.deduplicate_population(population, pop_constants)
                # Silent - only log if many duplicates
                if n_dups > self.pop_size * 0.1:
                    print(f"[GPU] Removed {n_dups} duplicates (Fresh Randoms Injected)")
            
            # Debug: Report Valid Count
            if generations % 5 == 0:
                 valid_cnt = (fitness_rmse < 1e9).sum().item()
                 print(f"[GPU] Gen {generations}: Valid Individuals = {valid_cnt}/{self.pop_size}")
                    
                # TRIGGER CATACLYSM if > 90% are duplicates - REMOVED (Redundant with Fresh Random Injection)
                # if n_dups > self.pop_size * 0.9 and generations > 20: 
                #     print(f"!!! CATACLYSM TRIGGERED (Duplicates: {n_dups}/{self.pop_size}) !!!")
                #     print("!!! Resetting 90% of population with fresh DNA !!!")
                #     population, pop_constants = self.cataclysm_population(population, pop_constants, fitness_rmse)

            # --- Simplification (Reduced Frequency: Every 500 generations) ---
            if GpuGlobals.USE_SIMPLIFICATION and generations % 500 == 0:
                # Simplify top 50 individuals
                population, pop_constants, n_simp = self.simplify_population(population, pop_constants, top_k=50)
                if n_simp > 0 and callback:
                    print(f"[GPU] Simplified {n_simp} expressions")

            # --- Pattern Memory (DISABLED FOR SPEED TEST) ---
            # Record successful subtrees from current population
            # self.pattern_memory.record_subtrees(population, fitness_rmse, self.grammar)
            
            # Inject patterns periodically
            # if generations % GpuGlobals.PATTERN_INJECT_INTERVAL == 0:
            #     population, pop_constants, n_inj = self.pattern_memory.inject_into_population(
            #         population, pop_constants, self.grammar, 
            #         percent=GpuGlobals.PATTERN_INJECT_PERCENT
            #     )

            # --- Local Search (DISABLED FOR SPEED TEST) ---
            # if generations % 100 == 0:
            #     population, pop_constants = self.local_search(
            #         population, pop_constants, x_t, y_t, 
            #         top_k=10, attempts=GpuGlobals.LOCAL_SEARCH_ATTEMPTS
            #     )
            
            if best_rmse < 1e-7:
                 return self.rpn_to_infix(best_rpn, best_consts_vec)
                 
        if best_rpn is not None:
             return self.rpn_to_infix(best_rpn, best_consts_vec)
        return None


In [None]:
%%writefile AlphaSymbolic/core/gpu/ensemble.py
"""
Ensemble / Coevolution Support for GPU GP Engine.

Provides utilities to:
- Run multiple GP engines in parallel
- Combine results from multiple runs
- Share best individuals between runs (coevolution)
"""
import torch
import numpy as np
from typing import List, Tuple, Optional, Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
import time


class EnsembleRunner:
    """
    Runs several GP engines one after another and aggregates their results.

    Supports:
    - Sequential execution of multiple independent runs
    - Hall of Fame aggregation across runs (optionally re-used as seeds)
    - Best solution selection with Pareto consideration
    """

    def __init__(self, engine_factory: Callable, n_runs: int = 5,
                 share_best: bool = True, share_interval: int = 100):
        """
        Args:
            engine_factory: Zero-argument callable that creates a fresh engine.
                The engine must expose ``run(x_values, y_targets, seeds, timeout_sec=...)``.
            n_runs: Number of engines to create and run.
            share_best: Whether Hall-of-Fame formulas are appended to the seeds
                of subsequent runs.
            share_interval: Stored on the instance; not consulted by any method
                of this class.
        """
        self.engine_factory = engine_factory
        self.n_runs = n_runs
        self.share_best = share_best
        self.share_interval = share_interval

        # Hall of Fame: list of (formula_str, rmse, complexity), kept sorted best-first.
        self.hall_of_fame: List[Tuple[str, float, int]] = []
        self.max_hof_size = 20

    def run_single(self, engine, x_values, y_targets, seeds, timeout_sec, run_id) -> Tuple[Optional[str], float, int]:
        """
        Run a single GP engine, never raising.

        Returns:
            (best_formula, best_rmse, run_id). ``best_rmse`` is read from the
            engine's ``best_rmse`` attribute when present and is ``inf`` when the
            engine does not expose one. On any exception the failure is printed
            and ``(None, inf, run_id)`` is returned.
        """
        try:
            result = engine.run(x_values, y_targets, seeds, timeout_sec=timeout_sec)

            # The engine may or may not publish its final error as an attribute.
            rmse = getattr(engine, 'best_rmse', None)
            if rmse is None:
                rmse = float('inf')

            return (result, rmse, run_id)
        except Exception as e:
            print(f"[Ensemble] Run {run_id} failed: {e}")
            return (None, float('inf'), run_id)

    def run_ensemble(self, x_values: List[float], y_targets: List[float],
                     seeds: List[str] = None, timeout_sec: float = 10,
                     callback: Callable = None) -> Optional[str]:
        """
        Run the ensemble of GP engines and return the best result.

        Args:
            x_values: Input data
            y_targets: Target data
            seeds: Optional seed formulas
            timeout_sec: Timeout per run
            callback: Optional progress callback; called with a status string

        Returns:
            Best formula found across all runs, or None if every run failed or
            returned nothing. When no run reports a finite RMSE, the first
            successful formula is returned instead of None.
        """
        if seeds is None:
            seeds = []

        # Create engines
        engines = [self.engine_factory() for _ in range(self.n_runs)]

        best_formula = None
        best_rmse = float('inf')

        # Run sequentially (parallel would require careful GPU memory management)
        for i, engine in enumerate(engines):
            if callback:
                callback(f"Running engine {i+1}/{self.n_runs}")

            # Start from the caller's seeds and add the current Hall of Fame leaders.
            shared_seeds = list(seeds)
            if self.share_best and self.hall_of_fame:
                shared_seeds.extend(f for f, _, _ in self.hall_of_fame[:5])

            result, rmse, _ = self.run_single(engine, x_values, y_targets,
                                              shared_seeds, timeout_sec, i)

            if not result:
                continue

            # Formula string length is used as the complexity proxy.
            self._add_to_hof(result, rmse, len(result))

            # `best_formula is None` matters when the engine exposes no RMSE:
            # rmse is then inf and a strict `<` against the initial inf would
            # discard every successful run.
            if best_formula is None or rmse < best_rmse:
                best_rmse = rmse
                best_formula = result

        if callback:
            callback(f"Ensemble complete. Best RMSE: {best_rmse:.6f}")

        return best_formula

    def _add_to_hof(self, formula: str, rmse: float, complexity: int):
        """Insert a formula into the Hall of Fame unless it is None or already present."""
        if formula is None:
            return

        # Skip exact duplicates (compared by formula string only).
        if any(existing == formula for existing, _, _ in self.hall_of_fame):
            return

        self.hall_of_fame.append((formula, rmse, complexity))

        # Sort by (rmse, complexity) - lexicographic
        self.hall_of_fame.sort(key=lambda entry: (entry[1], entry[2]))

        # Trim to max size
        if len(self.hall_of_fame) > self.max_hof_size:
            self.hall_of_fame = self.hall_of_fame[:self.max_hof_size]

    def get_pareto_front(self) -> List[Tuple[str, float, int]]:
        """
        Get Pareto-optimal solutions from Hall of Fame.

        An entry is dropped when some other entry is no worse in both RMSE and
        complexity and strictly better in at least one of them.

        Returns:
            List of (formula, rmse, complexity) tuples on the Pareto front,
            in Hall-of-Fame order.
        """
        if not self.hall_of_fame:
            return []

        pareto = []
        for formula, rmse, complexity in self.hall_of_fame:
            is_dominated = False
            for other_formula, other_rmse, other_complexity in self.hall_of_fame:
                if other_formula == formula:
                    continue
                # Check if other dominates this
                if other_rmse <= rmse and other_complexity <= complexity:
                    if other_rmse < rmse or other_complexity < complexity:
                        is_dominated = True
                        break

            if not is_dominated:
                pareto.append((formula, rmse, complexity))

        return pareto

    def get_best(self) -> Optional[str]:
        """Return the top Hall-of-Fame formula, or None when the Hall of Fame is empty."""
        if not self.hall_of_fame:
            return None
        return self.hall_of_fame[0][0]


def create_ensemble_runner(device=None, pop_size=1000, n_runs=5):
    """
    Build an EnsembleRunner whose engines are TensorGeneticEngine instances.

    Every engine produced by the runner is constructed with the given device
    and population size and with ``n_islands=4``.

    Args:
        device: Torch device forwarded to each engine
        pop_size: Population size per run
        n_runs: Number of runs

    Returns:
        EnsembleRunner instance
    """
    # Imported here rather than at module level, as in the original layout.
    from . import TensorGeneticEngine

    def build_engine():
        # Closure over the arguments so each call yields an identically configured engine.
        return TensorGeneticEngine(device=device, pop_size=pop_size, n_islands=4)

    runner = EnsembleRunner(build_engine, n_runs=n_runs)
    return runner


In [None]:
%%writefile AlphaSymbolic/core/gpu/formatting.py

def format_const(val: float) -> str:
    """
    Format a constant float to string matching CPU engine rules:
    - Non-finite values -> "nan", "inf" or "-inf" (instead of raising)
    - Integer-like values (within 1e-9 of an integer) -> "3"
    - Extreme values (>=1e6 or <=1e-6 in magnitude) -> Scientific "1.23456789e+09"
    - Normal values -> Fixed "1.23456789", trimmed trailing zeros and dot.

    The integer-like check runs before the extreme-value check, so a large
    whole number is printed as a plain integer.
    """
    import math

    # round() raises ValueError on NaN and OverflowError on infinities, so
    # non-finite constants must be handled before the integer-like test.
    if not math.isfinite(val):
        return str(float(val))

    nearest = round(val)
    if abs(val - nearest) < 1e-9:
        return str(int(nearest))
    if abs(val) >= 1e6 or abs(val) <= 1e-6:
        return f"{val:.8e}"
    s = f"{val:.8f}"
    s = s.rstrip('0').rstrip('.')
    return s if s else "0"


In [None]:
%%writefile AlphaSymbolic/core/gpu/pareto.py
"""
Pareto Optimization (NSGA-II) for GPU GP Engine.

Implements multi-objective optimization balancing:
- Objective 1: Error (RMSE) - minimize
- Objective 2: Complexity (tree size) - minimize
"""
import torch
import numpy as np
from typing import List, Tuple


class ParetoOptimizer:
    """
    NSGA-II style Pareto optimizer for symbolic regression.

    Objectives:
        - fitness: RMSE (lower is better)
        - complexity: number of tokens (lower is better)

    NaN objective values are treated as +inf (worst possible) so that a
    numerically broken individual can never sit on the Pareto front merely
    because every comparison against NaN is False.
    """

    def __init__(self, device: torch.device, max_front_size: int = 50):
        """
        Args:
            device: Device on which returned index/distance tensors are created
            max_front_size: Cap applied by get_pareto_front()
        """
        self.device = device
        self.max_front_size = max_front_size

    @staticmethod
    def _objective_array(values: torch.Tensor) -> np.ndarray:
        """Copy objective values to a float64 NumPy array with NaN replaced by +inf."""
        arr = values.detach().cpu().numpy().astype(np.float64)  # astype copies
        arr[np.isnan(arr)] = np.inf
        return arr

    def dominates(self, obj_a: Tuple[float, float], obj_b: Tuple[float, float]) -> bool:
        """
        Check if solution A dominates solution B (both objectives <= and at least one <).
        """
        a_fit, a_comp = obj_a
        b_fit, b_comp = obj_b

        # A dominates B if A is <= B in all objectives and < in at least one
        at_least_one_better = (a_fit < b_fit) or (a_comp < b_comp)
        not_worse = (a_fit <= b_fit) and (a_comp <= b_comp)

        return not_worse and at_least_one_better

    def non_dominated_sort(self, fitness: torch.Tensor, complexity: torch.Tensor) -> List[List[int]]:
        """
        Perform non-dominated sorting on the population.

        Runs a pairwise comparison in Python, so cost grows quadratically with
        the population size.

        Args:
            fitness: [PopSize] RMSE values (NaN is treated as +inf)
            complexity: [PopSize] tree sizes

        Returns:
            List of fronts, where each front is a list of indices; front 0 is
            the non-dominated set.
        """
        n = fitness.shape[0]
        fitness_cpu = self._objective_array(fitness)
        complexity_cpu = self._objective_array(complexity)

        # For each individual, count how many dominate it
        domination_count = np.zeros(n, dtype=np.int32)
        dominated_by = [[] for _ in range(n)]  # dominated_by[i]: indices that i dominates

        for i in range(n):
            obj_i = (fitness_cpu[i], complexity_cpu[i])
            for j in range(i + 1, n):
                obj_j = (fitness_cpu[j], complexity_cpu[j])

                if self.dominates(obj_i, obj_j):
                    dominated_by[i].append(j)
                    domination_count[j] += 1
                elif self.dominates(obj_j, obj_i):
                    dominated_by[j].append(i)
                    domination_count[i] += 1

        # Build fronts
        fronts = []

        # First front: individuals with domination_count = 0
        current_front = [i for i in range(n) if domination_count[i] == 0]

        while current_front:
            fronts.append(current_front)
            next_front = []

            # Peeling off the current front releases the individuals it dominated.
            for i in current_front:
                for j in dominated_by[i]:
                    domination_count[j] -= 1
                    if domination_count[j] == 0:
                        next_front.append(j)

            current_front = next_front

        return fronts

    def crowding_distance(self, front: List[int], fitness: torch.Tensor, complexity: torch.Tensor) -> torch.Tensor:
        """
        Calculate crowding distance for individuals in a front.

        Args:
            front: List of indices in this front
            fitness: [PopSize] RMSE values
            complexity: [PopSize] tree sizes

        Returns:
            [len(front)] float64 crowding distances. Boundary individuals (and
            every member of a front of size <= 2) get +inf.
        """
        n = len(front)
        if n <= 2:
            # Same dtype as the general path below so callers see one dtype.
            return torch.full((n,), float('inf'), device=self.device, dtype=torch.float64)

        distances = torch.zeros(n, device=self.device, dtype=torch.float64)

        # For each objective
        for obj_vals in [fitness, complexity]:
            # Get values for this front
            front_vals = self._objective_array(obj_vals[front])

            # Sort by objective
            sorted_idx = np.argsort(front_vals)

            # Boundary points get infinite distance
            distances[sorted_idx[0]] = float('inf')
            distances[sorted_idx[-1]] = float('inf')

            # Normalize by range. A degenerate range carries no spacing
            # information; an infinite one would yield inf/inf = NaN below.
            obj_range = front_vals[sorted_idx[-1]] - front_vals[sorted_idx[0]]
            if not np.isfinite(obj_range) or obj_range < 1e-9:
                continue

            # Calculate crowding distance for interior points
            for i in range(1, n - 1):
                distances[sorted_idx[i]] += (front_vals[sorted_idx[i + 1]] - front_vals[sorted_idx[i - 1]]) / obj_range

        return distances

    def select(self, population: torch.Tensor, fitness: torch.Tensor, complexity: torch.Tensor, n_select: int) -> torch.Tensor:
        """
        Select n_select individuals using NSGA-II selection.

        Whole fronts are taken in order; the first front that does not fit is
        truncated by descending crowding distance.

        Args:
            population: [PopSize, L] RPN tensors (not read by this method)
            fitness: [PopSize] RMSE values
            complexity: [PopSize] tree sizes
            n_select: Number of individuals to select

        Returns:
            Long tensor of selected indices; shorter than n_select only when
            the population itself is smaller.
        """
        # Non-dominated sorting
        fronts = self.non_dominated_sort(fitness, complexity)

        selected = []

        for front in fronts:
            if len(selected) + len(front) <= n_select:
                # Add entire front
                selected.extend(front)
            else:
                # Need to select subset using crowding distance
                remaining = n_select - len(selected)

                # Calculate crowding distance
                distances = self.crowding_distance(front, fitness, complexity)

                # Select by highest crowding distance
                _, sorted_idx = torch.sort(distances, descending=True)
                for i in range(remaining):
                    selected.append(front[sorted_idx[i].item()])

                break

        return torch.tensor(selected, device=self.device, dtype=torch.long)

    def get_pareto_front(self, fitness: torch.Tensor, complexity: torch.Tensor) -> List[int]:
        """
        Get indices of individuals in the Pareto front.

        Args:
            fitness: [PopSize] RMSE values
            complexity: [PopSize] tree sizes

        Returns:
            List of indices in the Pareto front, reduced to max_front_size
            entries (most isolated first) when the front is larger.
        """
        fronts = self.non_dominated_sort(fitness, complexity)

        if not fronts:
            return []

        front = fronts[0]

        # Limit size
        if len(front) > self.max_front_size:
            distances = self.crowding_distance(front, fitness, complexity)
            _, sorted_idx = torch.sort(distances, descending=True)
            front = [front[sorted_idx[i].item()] for i in range(self.max_front_size)]

        return front


In [None]:
%%writefile AlphaSymbolic/core/gpu/pattern_memory.py
"""
Pattern Memory System for GPU GP Engine.

Stores successful subtrees/patterns and injects them into the population
to accelerate convergence by reusing proven building blocks.
"""
import torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import defaultdict


class PatternMemory:
    """
    Memory system that stores successful formula patterns (subtrees).

    Patterns with good fitness scores are recorded and can be injected
    into the population to share successful building blocks.

    Token id 0 is treated as padding throughout this class: zeros are stripped
    before subtree extraction and used to pad injected patterns.
    """

    def __init__(self, device: torch.device, max_patterns: int = 100,
                 fitness_threshold: float = 10.0, min_uses: int = 3):
        """
        Args:
            device: Torch device
            max_patterns: Maximum number of patterns to store
            fitness_threshold: Only record patterns from individuals with fitness below this
            min_uses: Minimum uses before a pattern is considered "useful"
        """
        self.device = device
        self.max_patterns = max_patterns
        self.fitness_threshold = fitness_threshold
        self.min_uses = min_uses

        # Pattern storage: token tuple -> (pattern_rpn, count, best_fitness)
        self.patterns: Dict[tuple, Tuple[List[int], int, float]] = {}

        # Usage stats
        self.total_recorded = 0
        self.total_injected = 0

    def record_subtrees(self, population: torch.Tensor, fitness: torch.Tensor,
                        grammar, min_size: int = 3, max_size: int = 10):
        """
        Extract and record successful subtrees from the population.

        Only the first 50 individuals (by index) whose fitness is below
        ``fitness_threshold`` are inspected.

        Args:
            population: [PopSize, L] RPN tensors
            fitness: [PopSize] RMSE values
            grammar: Object providing ``get_subtree_span(tokens, root_idx)``
            min_size: Minimum subtree size to record
            max_size: Maximum subtree size to record
        """
        pop_cpu = population.cpu().numpy()
        fit_cpu = fitness.cpu().numpy()

        # Only look at individuals with good fitness
        good_indices = np.where(fit_cpu < self.fitness_threshold)[0]

        for idx in good_indices[:50]:  # Limit to prevent slowdown
            fit = fit_cpu[idx]
            for subtree in self._extract_subtrees(pop_cpu[idx], grammar, min_size, max_size):
                self._record_pattern(subtree, fit)

    def _extract_subtrees(self, rpn: np.ndarray, grammar, min_size: int, max_size: int) -> List[List[int]]:
        """
        Extract all valid subtrees from an RPN expression.

        Every non-pad position is tried as a subtree root; spans reported by
        the grammar with a start of -1 are skipped, and only subtrees whose
        size lies in [min_size, max_size] are returned.
        """
        subtrees = []

        # Drop padding tokens (id 0)
        non_pad = rpn[rpn != 0]
        if len(non_pad) < min_size:
            return subtrees

        # Convert once; the token list is identical for every root position.
        tokens = non_pad.tolist()

        # Try each position as potential subtree root
        for root_idx in range(len(tokens)):
            span = grammar.get_subtree_span(tokens, root_idx)
            if span[0] == -1:
                continue

            start, end = span
            size = end - start + 1

            if min_size <= size <= max_size:
                subtrees.append(tokens[start:end + 1])

        return subtrees

    def _record_pattern(self, pattern: List[int], fitness: float):
        """
        Record a pattern in memory.

        A known pattern has its use count incremented and its best fitness
        lowered if applicable; a new pattern is inserted with count 1, evicting
        one stored pattern first when the memory is full.
        """
        key = tuple(pattern)

        if key in self.patterns:
            rpn, count, best_fit = self.patterns[key]
            self.patterns[key] = (rpn, count + 1, min(best_fit, fitness))
        else:
            if len(self.patterns) >= self.max_patterns:
                # Evict least used pattern
                self._evict_least_useful()

            self.patterns[key] = (pattern, 1, fitness)
            self.total_recorded += 1

    def _evict_least_useful(self):
        """
        Remove the least useful pattern (lowest count, highest fitness).
        """
        if not self.patterns:
            return

        # Score: higher is worse (low count, high fitness)
        def score(item):
            key, (rpn, count, best_fit) = item
            return -count + best_fit / 100.0

        worst_key = max(self.patterns.items(), key=score)[0]
        del self.patterns[worst_key]

    def get_useful_patterns(self, n: int = 10) -> List[List[int]]:
        """
        Get the top N most useful patterns.

        Args:
            n: Number of patterns to return

        Returns:
            List of RPN patterns (as lists of token IDs), most used first and,
            among equally used ones, lowest best-fitness first.
        """
        # Filter by min_uses
        useful = [(k, v) for k, v in self.patterns.items() if v[1] >= self.min_uses]

        # Sort by usefulness: high count, low fitness
        useful.sort(key=lambda x: (-x[1][1], x[1][2]))

        return [list(k) for k, v in useful[:n]]

    def inject_into_population(self, population: torch.Tensor, constants: torch.Tensor,
                                grammar, percent: float = 0.05) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Inject useful patterns into the population by replacing some individuals.

        Patterns longer than the population's row length are skipped rather
        than truncated, since a cut-off RPN sequence is no longer the recorded
        pattern. Target rows are drawn without replacement, so the returned
        count equals the number of rows actually overwritten.

        Args:
            population: [PopSize, L] RPN tensors
            constants: [PopSize, MaxC] constants
            grammar: Not read by this method
            percent: Fraction of population to replace

        Returns:
            (new_population, new_constants, n_injected). The inputs are
            returned unchanged (not copied) when nothing can be injected.
        """
        patterns = self.get_useful_patterns(20)
        if not patterns:
            return population, constants, 0

        pop_size, max_len = population.shape

        # Keep only patterns that fit into one row.
        patterns = [p for p in patterns if len(p) <= max_len]
        if not patterns:
            return population, constants, 0

        # Inject at random positions (avoid elites at front)
        inject_start = int(pop_size * 0.1)  # Skip first 10% (elites)
        available = pop_size - inject_start

        n_inject = max(1, int(pop_size * percent))
        # Don't inject more than we have variety, nor more than there are free rows.
        n_inject = min(n_inject, len(patterns) * 2, available)
        if n_inject <= 0:
            return population, constants, 0

        pop_out = population.clone()
        const_out = constants.clone()

        # Distinct row indices in [inject_start, pop_size)
        inject_positions = (torch.randperm(available)[:n_inject] + inject_start).tolist()

        for i, pos in enumerate(inject_positions):
            pattern = patterns[i % len(patterns)]

            # Pad pattern to max_len with the pad token (0)
            padded = pattern + [0] * (max_len - len(pattern))

            pop_out[pos] = torch.tensor(padded, device=self.device, dtype=population.dtype)

            # Random constants for the pattern
            const_out[pos] = torch.randn_like(const_out[pos]) * 0.5

        self.total_injected += n_inject
        return pop_out, const_out, n_inject

    def get_stats(self) -> Dict:
        """
        Get pattern memory statistics.

        Returns:
            Dict with the number of stored patterns, cumulative recorded and
            injected counts, and how many stored patterns reach ``min_uses``.
        """
        return {
            'n_patterns': len(self.patterns),
            'total_recorded': self.total_recorded,
            'total_injected': self.total_injected,
            'useful_count': sum(1 for v in self.patterns.values() if v[1] >= self.min_uses)
        }


In [None]:
%%writefile AlphaSymbolic/core/gpu/sniper.py

import torch
import numpy as np
from .formatting import format_const

class Sniper:
    """
    'The Sniper': Special Pattern Detection Unit.

    Tries two closed-form least-squares fits on the raw data -- an affine
    model and an exponential model -- so that trivially simple targets are
    recognised before evolution begins. Every check is best-effort: ordinary
    errors make a check report "no pattern" (None) instead of aborting, but
    KeyboardInterrupt / SystemExit are allowed to propagate.
    """
    def __init__(self, device):
        # Device on which run() allocates its working tensors.
        self.device = device

    @staticmethod
    def _fit_affine(x_t, target_t):
        """
        Least-squares fit of ``target_t ~= slope * x_t + intercept``.

        Solves [x, 1] * [slope, intercept]^T = target with torch.linalg.lstsq
        and returns (slope, intercept) as Python floats.
        """
        ones = torch.ones_like(x_t)
        design = torch.stack([x_t, ones], dim=1)  # [N, 2]
        solution = torch.linalg.lstsq(design, target_t).solution
        return solution[0].item(), solution[1].item()

    @staticmethod
    def _affine_expr(slope, intercept):
        """
        Render ``slope * x + intercept`` as a parenthesised infix string.

        A negative intercept is written as a subtraction of its absolute
        value: ((m * x) + c) when c >= 0, ((m * x) - |c|) otherwise.
        """
        term = f"({format_const(slope)} * x)"
        if intercept >= 0:
            return f"({term} + {format_const(intercept)})"
        return f"({term} - {format_const(abs(intercept))})"

    def check_linear(self, x_t, y_t):
        """
        Check for y = m*x + c.

        Args:
            x_t, y_t: tensors accepted by torch.stack / torch.linalg.lstsq;
                run() passes flattened float32 tensors.

        Returns:
            Formula string if the fit's MSE is below 1e-6, else None.
        """
        try:
            m, c = self._fit_affine(x_t, y_t)

            # Accept only a (numerically) exact fit.
            y_pred = m * x_t + c
            mse = torch.mean((y_pred - y_t)**2)
            if mse < 1e-6:
                return self._affine_expr(m, c)
        except Exception:
            # Best-effort detection: a failed fit simply means "no pattern".
            pass
        return None

    def check_geometric(self, x_t, y_t):
        """
        Check for y = A * exp(B*x), fitted as log(y) = log(A) + B*x.

        Args:
            x_t, y_t: tensors as for check_linear.

        Returns:
            Formula string of the form exp((B * x) +/- |log A|) if the MSE,
            measured on the original (non-log) scale, is below 1e-4; None
            otherwise, including whenever any y value is <= 0.
        """
        try:
            # The log transform is undefined for non-positive targets.
            if (y_t <= 0).any():
                return None

            B, log_A = self._fit_affine(x_t, torch.log(y_t))
            A_val = np.exp(log_A)

            # Error is measured on the original scale; the threshold is
            # slightly looser than the linear one.
            y_pred = A_val * torch.exp(B * x_t)
            mse = torch.mean((y_pred - y_t)**2)
            if mse < 1e-4:
                # A * exp(B*x) == exp(B*x + ln A); emitted in the pure-exp form.
                return f"exp({self._affine_expr(B, log_A)})"
        except Exception:
            # Best-effort detection: a failed fit simply means "no pattern".
            pass
        return None

    def run(self, x_data, y_data):
        """
        Run all checks, linear first, then geometric.

        Args:
            x_data, y_data: CPU lists or arrays; both are converted to
                float32 tensors on self.device and flattened.

        Returns:
            The first detected formula string (also printed), or None when
            no pattern matches or the data cannot be converted.
        """
        try:
            x_t = torch.tensor(x_data, device=self.device, dtype=torch.float32).flatten()
            y_t = torch.tensor(y_data, device=self.device, dtype=torch.float32).flatten()

            res = self.check_linear(x_t, y_t)
            if res:
                print(f"[The Sniper] Detected Linear Pattern: {res}")
                return res

            res = self.check_geometric(x_t, y_t)
            if res:
                print(f"[The Sniper] Detected Geometric Pattern: {res}")
                return res

        except Exception:
            # Conversion problems (ragged or non-numeric input) are treated
            # as "nothing detected".
            pass
        return None


In [None]:
%%writefile AlphaSymbolic/data/synthetic_data.py
import numpy as np
import random
from core.grammar import VOCABULARY, OPERATORS, VARIABLES, CONSTANTS, ExpressionTree
from data.augmentation import augment_formula_tokens

class DataGenerator:
    """
    Produces synthetic (x, y, formula) samples for symbolic regression.

    Formulas are prefix-notation token lists, built either by unconstrained
    random growth (generate_batch) or from hand-designed templates
    (generate_inverse_batch).
    """
    def __init__(self, max_depth=5, population_size=1000, allowed_operators=None, num_variables=1):
        """
        Args:
            max_depth: Depth limit for random trees; also caps the template
                complexity used by generate_inverse_batch.
            population_size: Stored on the instance; not read by any method
                of this class.
            allowed_operators: Optional iterable of operator tokens. Entries
                that are not keys of OPERATORS are dropped. When falsy, every
                operator in OPERATORS is allowed.
            num_variables: Number of leading entries of VARIABLES to use.
        """
        self.max_depth = max_depth
        self.population_size = population_size
        self.num_variables = num_variables
        self.vocab = VOCABULARY
        # Use subset of variables based on num_variables
        self.active_variables = VARIABLES[:num_variables]

        # Pre-compute terminal vs operator lists
        self.terminals = self.active_variables + CONSTANTS
        if allowed_operators:
            self.operators = [op for op in allowed_operators if op in OPERATORS]
        else:
            self.operators = list(OPERATORS.keys())

    def _random_terminal(self):
        """Pick one leaf token: 40% active variable, 30% 'C', 30% other constant."""
        r = random.random()
        if r < 0.4:
            return [random.choice(self.active_variables)]
        if r < 0.7:
            return ['C']
        return [random.choice([c for c in CONSTANTS if c != 'C'])]

    @staticmethod
    def _randomize_constants(tree):
        """
        Draw a value for every 'C' position reported by the tree.

        Values are uniform in [-20, 20], except that roughly 10% of the draws
        are fixed to 1.0. Returns a dict keyed by tuple(position).
        """
        constant_vals = {}
        for pos in tree.root.get_constant_positions():
            val = random.uniform(-20, 20) if random.random() > 0.1 else 1.0
            constant_vals[tuple(pos)] = val
        return constant_vals

    def _to_vocab_token(self, token):
        """
        Map a generated token onto the vocabulary.

        Tokens already in the vocabulary pass through. Numeric strings within
        0.01 of an integer whose string form is in the vocabulary become that
        integer token. Everything else becomes the placeholder 'C'.
        """
        if token in self.vocab:
            return token
        try:
            val = float(token)
            rounded = round(val)  # raises for inf / nan
            near_integer = abs(val - rounded) < 0.01
            candidate = str(int(rounded))
        except (TypeError, ValueError, OverflowError):
            return 'C'
        if near_integer and candidate in self.vocab:
            return candidate
        return 'C'

    def generate_random_tree(self, max_depth, current_depth=0):
        """
        Grow a random prefix-notation formula.

        Args:
            max_depth: Depth at which only terminals are produced.
            current_depth: Depth of the node being generated (recursion state).

        Returns:
            list of tokens in prefix order.
        """
        if current_depth >= max_depth:
            # Balanced Terminal Selection: 50% variable, 50% constant
            if random.random() < 0.5:
                return [random.choice(self.active_variables)]
            return [random.choice(CONSTANTS)]

        # Above the depth limit: 70% operator, 30% terminal.
        if random.random() < 0.7:
            op = random.choice(self.operators)
            tokens = [op]
            for _ in range(OPERATORS[op]):
                tokens.extend(self.generate_random_tree(max_depth, current_depth + 1))
            return tokens
        return self._random_terminal()

    def generate_batch(self, batch_size, point_count=10, x_range=(-10, 10)):
        """
        Generates a batch of (X, Y) pairs and their generating formulas.

        Loops until `batch_size` samples have been accepted; there is no cap
        on the number of attempts.

        Args:
            batch_size: Number of samples to return.
            point_count: Number of x points per sample.
            x_range: (low, high) bounds for uniformly sampled x values.

        Returns:
            list of dicts with keys 'tokens', 'infix', 'x', 'y'. 'x' has shape
            (point_count,) for one variable and (point_count, num_variables)
            otherwise, sorted by the first variable.
        """
        data = []

        while len(data) < batch_size:
            tokens = self.generate_random_tree(self.max_depth)
            tree = ExpressionTree(tokens)

            if not tree.is_valid:
                continue

            # Ensure variables are present (90% of the time, check any active variable)
            if not any(v in tokens for v in self.active_variables) and random.random() < 0.9:
                continue

            if self.num_variables > 1:
                x_values = np.random.uniform(x_range[0], x_range[1], (point_count, self.num_variables))
                # Sort rows by the first column for consistent visualization.
                x_values = x_values[x_values[:, 0].argsort()]
            else:
                x_values = np.random.uniform(x_range[0], x_range[1], point_count)
                # Sort X for cleaner visualization/learning
                x_values.sort()

            constant_vals = self._randomize_constants(tree)
            y_values = tree.evaluate(x_values, constants=constant_vals)

            # Check for validity (no NaNs, Infs, or extremely large values)
            if np.any(np.isnan(y_values)) or np.any(np.isinf(y_values)):
                continue
            if np.max(np.abs(y_values)) > 1e4:  # Reject too large numbers (1e6 causes NaN gradients)
                continue
            # Flat lines are mostly rejected; about 10% of them are kept.
            if np.std(y_values) < 1e-6 and random.random() > 0.1:
                continue

            data.append({
                'tokens': tokens,
                'infix': tree.get_infix(),
                'x': x_values,
                'y': y_values
            })

        return data

    def generate_structured_tree(self, complexity=1, input_node='x0'):
        """
        Recursively builds a structured, human-like formula.

        Only operators in self.operators are emitted. The 'poly' template is
        therefore offered only when 'pow', '+' and '*' are all allowed, since
        it writes all three (apart from 'pow' when the power is 1).

        Args:
            complexity: Remaining nesting budget; <= 0 yields a random terminal.
            input_node: Sub-expression the template is applied to, given as a
                single token or as a prefix token list.

        Returns:
            list of tokens in prefix order (always a flat list).
        """
        # Base case: a random terminal; input_node is not used here.
        if complexity <= 0:
            return self._random_terminal()

        # Normalise once so that every branch concatenates and returns lists.
        seed = input_node if isinstance(input_node, list) else [input_node]

        # Filter available structures based on allowed operators
        available_structures = []
        if any(op in self.operators for op in ['+', '-', '*']):
            available_structures.append('arithmetic')
        if all(op in self.operators for op in ['pow', '+', '*']):
            available_structures.append('poly')
        if any(op in self.operators for op in ['sin', 'cos', 'asin', 'acos', 'atan']):
            available_structures.append('trig')
        if 'exp' in self.operators or 'log' in self.operators:
            available_structures.append('exp_log')
        # Composition needs enough variety
        if len(self.operators) > 4 and complexity > 1:
            available_structures.append('composition')

        # Nothing applicable: hand the input back unchanged.
        if not available_structures:
            return seed

        choice = random.choice(available_structures)

        if choice == 'poly':
            # a*x + b  or  a*x^p + b  with p in {2, 3}
            a = str(random.randint(1, 5))
            b = str(random.randint(-5, 5))
            power = random.choice(['1', '2', '3'])
            if power == '1':
                term = ['*', a] + seed
            else:
                term = ['*', a, 'pow'] + seed + [power]
            return ['+'] + term + [b]

        if choice in ('trig', 'exp_log'):
            candidates = ['sin', 'cos', 'asin', 'acos', 'atan'] if choice == 'trig' else ['exp', 'log']
            ops = [op for op in candidates if op in self.operators]
            if not ops:
                return seed  # Should be caught by structure check
            return [random.choice(ops)] + seed

        if choice == 'arithmetic':
            # Allow mixing variables in arithmetic nodes
            left_input = input_node if random.random() < 0.6 else random.choice(self.active_variables)
            right_input = input_node if random.random() < 0.6 else random.choice(self.active_variables)
            left = self.generate_structured_tree(complexity - 1, left_input)
            right = self.generate_structured_tree(complexity - 1, right_input)
            ops = [op for op in ['+', '-', '*'] if op in self.operators]
            if not ops:
                return seed
            return [random.choice(ops)] + left + right

        if choice == 'composition':
            inner = self.generate_structured_tree(complexity - 1, input_node)
            return self.generate_structured_tree(1, inner)

        return seed

    def generate_inverse_batch(self, batch_size, point_count=10, x_range=(-5, 5)):
        """
        Generates complex, structured formulas using the template engine.

        Gives up after `batch_size * 5` attempts, so the result may contain
        fewer than `batch_size` samples. Any exception raised while building
        or evaluating a candidate discards that candidate.

        Args:
            batch_size: Target number of samples.
            point_count: Number of x points per sample.
            x_range: Default (low, high) x range. Replaced by (-2, 2) when the
                formula contains 'exp' or 'pow', and by (0.1, 5) when it
                contains 'log' or 'sqrt'.

        Returns:
            list of dicts with keys 'tokens', 'infix', 'x', 'y'.
        """
        data = []
        attempts = 0

        while len(data) < batch_size and attempts < batch_size * 5:
            attempts += 1
            # Random complexity capped by max_depth
            complexity = random.randint(1, max(1, self.max_depth - 1))

            try:
                tokens = self.generate_structured_tree(complexity, random.choice(self.active_variables))

                # Small integers present in the vocabulary are kept; other
                # numbers become the optimisable placeholder 'C'.
                final_tokens = [self._to_vocab_token(t) for t in tokens]

                # --- DATA AUGMENTATION ---
                if random.random() < 0.3:
                    final_tokens = augment_formula_tokens(final_tokens)
                # -------------------------

                tree = ExpressionTree(final_tokens)
                if not tree.is_valid:
                    continue

                # Ensure variables are present (90% of the time)
                if not any(v in final_tokens for v in self.active_variables) and random.random() < 0.9:
                    continue

                if len(final_tokens) > 30:  # Limit length
                    continue

                # Exp/Pow grow very fast, so X is constrained to avoid float
                # overflow; log/sqrt get a strictly positive range.
                range_limit = x_range
                if 'exp' in final_tokens or 'pow' in final_tokens:
                    range_limit = (-2, 2)
                elif 'log' in final_tokens or 'sqrt' in final_tokens:
                    range_limit = (0.1, 5)

                if self.num_variables > 1:
                    # Random uniform points for coverage of the multivariable space.
                    x_safe = np.random.uniform(range_limit[0], range_limit[1], (point_count, self.num_variables))
                else:
                    x_safe = np.linspace(range_limit[0], range_limit[1], point_count)

                constant_vals = self._randomize_constants(tree)
                y_values = tree.evaluate(x_safe, constants=constant_vals)

                # Quality Control
                if np.any(np.isnan(y_values)) or np.any(np.isinf(y_values)):
                    continue
                if np.max(np.abs(y_values)) > 1e4:  # Relaxed limit
                    continue
                if np.std(y_values) < 0.01:  # Too flat
                    continue

                data.append({
                    'tokens': final_tokens,
                    'infix': tree.get_infix(),
                    'x': x_safe,
                    'y': y_values
                })
            except Exception:
                continue

        return data

# Smoke test: executed only when this module is run as a script.
if __name__ == "__main__":
    demo_samples = DataGenerator(max_depth=4).generate_batch(5)
    for sample in demo_samples:
        # One report per generated formula, followed by a separator line.
        report = [
            f"Formula: {sample['infix']}",
            f"Tokens: {sample['tokens']}",
            f"Y sample: {sample['y'][:3]}...",
            "-" * 20,
        ]
        print("\n".join(report))


In [None]:
%%writefile AlphaSymbolic/data/benchmark_data.py
import numpy as np

# Standard Benchmark Problems
# Levels: 1 (Easy), 2 (Medium), 3 (Hard)

# Each entry describes one 1-D regression target:
#   'id'          short key used by get_benchmark_data
#   'name'        display label
#   'formula_str' ground-truth expression as text
#   'lambda'      callable evaluating the ground truth on x
#   'domain'      (low, high) interval sampled with np.linspace
#   'points'      number of sample points
#   'level'       difficulty: 1 (Easy), 2 (Medium), 3 (Hard)
BENCHMARK_SUITE = [
    # --- Level 1: Polynomials & Basic Arithmetic ---
    {
        'id': 'p1',
        'name': 'Lineal',
        'formula_str': '2.5 * x + 1.0',
        'lambda': lambda x: 2.5 * x + 1.0,
        'domain': (-10, 10),
        'points': 20,
        'level': 1
    },
    {
        'id': 'p2',
        'name': 'Cuadratica Simple',
        'formula_str': 'x * x',
        'lambda': lambda x: x**2,
        'domain': (-5, 5),
        'points': 20,
        'level': 1
    },
    {
        'id': 'p3',
        'name': 'Polinomio Cubico',
        'formula_str': 'x**3 + x**2',
        'lambda': lambda x: x**3 + x**2,
        'domain': (-3, 3),
        'points': 20,
        'level': 1
    },

    # --- Level 2: Trigonometric & Transcendental ---
    {
        'id': 'p4',
        'name': 'Seno Basico',
        'formula_str': 'sin(x)',
        'lambda': lambda x: np.sin(x),
        'domain': (-np.pi, np.pi),
        'points': 30,
        'level': 2
    },
    {
        'id': 'p5',
        'name': 'Coseno Desplazado',
        'formula_str': 'cos(x) + 1',
        'lambda': lambda x: np.cos(x) + 1,
        'domain': (-np.pi, np.pi),
        'points': 30,
        'level': 2
    },
    {
        'id': 'p6',
        'name': 'Exponencial Simple',
        'formula_str': 'exp(x)',
        'lambda': lambda x: np.exp(x),
        'domain': (-2, 2), # Small domain to avoid explosion
        'points': 20,
        'level': 2
    },

    # --- Level 3: Physics / Complex ---
    {
        'id': 'p7',
        'name': 'Damped Oscillation',
        'formula_str': 'exp(-x) * sin(2*x)',
        'lambda': lambda x: np.exp(-x) * np.sin(2*x),
        'domain': (0, 4),
        'points': 40,
        'level': 3
    },
    {
        'id': 'p8',
        'name': 'Gaussian',
        'formula_str': 'exp(-x**2)',
        'lambda': lambda x: np.exp(-x**2),
        'domain': (-3, 3),
        'points': 30,
        'level': 3
    },
    {
        'id': 'p9',
        # In the standard Nguyen suite x^3 + x^2 + x is problem 1
        # (Nguyen-3 is the degree-5 polynomial), so the label says Nguyen-1.
        'name': 'Nguyen-1 (x^3 + x^2 + x)',
        'formula_str': 'x**3 + x**2 + x',
        'lambda': lambda x: x**3 + x**2 + x,
        'domain': (-2, 2),
        'points': 20,
        'level': 3
    },
    {
        'id': 'p10',
        'name': 'Rational Function',
        'formula_str': 'x / (1 + x**2)',
        'lambda': lambda x: x / (1 + x**2),
        'domain': (-4, 4),
        'points': 30,
        'level': 3
    }
]

def get_benchmark_data(problem_id):
    """
    Sample the benchmark problem with the given ID.

    Args:
        problem_id: Value matched against the 'id' field of BENCHMARK_SUITE entries.

    Returns:
        (x, y, problem): x is an evenly spaced grid over the problem's
        'domain' with 'points' samples, y is the problem's 'lambda' applied
        to x, and problem is the matching suite entry. Returns
        (None, None, None) when no entry has that ID.
    """
    match = next((entry for entry in BENCHMARK_SUITE if entry['id'] == problem_id), None)
    if match is None:
        return None, None, None

    domain = match['domain']
    grid = np.linspace(domain[0], domain[1], match['points'])
    return grid, match['lambda'](grid), match


In [None]:
%%writefile AlphaSymbolic/data/augmentation.py

import random
from core.grammar import OPERATORS

def augment_formula_tokens(tokens):
    """
    Applies mathematical invariants to generate an equivalent formula structure.
    Acts as 'Data Augmentation' for symbolic regression.

    Implemented transformations (each applied at random, bottom-up):
    1. Commutativity of (+) and (*): [+ a b] -> [+ b a], with probability 0.5.
    2. Doubling: [+ a a] -> [* a 2], with probability 0.3 when both operands
       have identical token sequences.

    Input that is not exactly one well-formed prefix expression (missing
    operands, or tokens left over after the expression ends) is returned as
    an unmodified copy instead of being transformed.

    Args:
        tokens (list): List of tokens in Prefix notation.

    Returns:
        list: A new list of tokens representing an equivalent formula
        ([] for empty input).
    """
    if not tokens:
        return []

    def parse(pos):
        """Parse the sub-expression starting at tokens[pos]; return (node, next_pos)."""
        if pos >= len(tokens):
            raise ValueError("prefix expression ends before all operands are supplied")
        head = tokens[pos]
        pos += 1
        children = []
        if head in OPERATORS:
            for _ in range(OPERATORS[head]):
                child, pos = parse(pos)
                children.append(child)
        return {'val': head, 'children': children}, pos

    def flatten(node):
        """Serialise a node back into prefix tokens."""
        res = [node['val']]
        for child in node['children']:
            res.extend(flatten(child))
        return res

    def augment_recursive(node):
        """Transform children first, then the node itself."""
        node['children'] = [augment_recursive(child) for child in node['children']]

        val = node['val']
        children = node['children']
        if len(children) != 2:
            return node

        # Transformation: Commutativity
        if val in ('+', '*') and random.random() < 0.5:
            children = [children[1], children[0]]
            node['children'] = children

        # Transformation: (+ x x) -> (* x 2), detected by comparing the
        # operands' flattened token sequences.
        if val == '+' and flatten(children[0]) == flatten(children[1]):
            if random.random() < 0.3:
                return {'val': '*', 'children': [children[0], {'val': '2', 'children': []}]}

        return node

    # 1. Parse; fall back to a plain copy on malformed input.
    try:
        tree, consumed = parse(0)
    except (ValueError, TypeError, RecursionError):
        return list(tokens)
    if consumed != len(tokens):
        # Tokens left over: transforming would silently drop them.
        return list(tokens)

    # 2. Augment and flatten
    return flatten(augment_recursive(tree))

if __name__ == "__main__":
    # Demo inputs in prefix notation:
    #   (+ x y), (* (+ a b) c), (+ x x)
    demo_formulas = (
        ['+', 'x', 'y'],
        ['*', '+', 'a', 'b', 'c'],
        ['+', 'x', 'x'],
    )
    for formula in demo_formulas:
        print(f"Original: {formula} -> Aug: {augment_formula_tokens(formula)}")


In [None]:
%%writefile AlphaSymbolic/data/expanded_benchmarks.py

import pandas as pd
import numpy as np
import os

def load_expanded_feynman_subset(csv_path="data/benchmarks/FeynmanEquations.csv", limit=50):
    """
    Loads equations from the Feynman dataset and describes their 1D projections.

    Projection strategy: the first named variable of each equation is the one
    that varies; every other named variable is fixed to 1.0. The formula text
    itself is not rewritten -- the fixed values are returned alongside it so
    that it can be evaluated later with a context dict.

    Args:
        csv_path: Path to the CSV file. Columns read: 'Filename', 'Formula',
            '# variables' and 'v1_name' .. 'v10_name'.
        limit: Maximum number of problems to return; None for no limit.

    Returns:
        list of dicts with keys 'id', 'name', 'original_formula',
        'target_var', 'fixed_context' and 'description'. Empty when the file
        is missing or unreadable. Rows that fail to parse (e.g. blank
        variable count, no variable names) are skipped.
    """
    if not os.path.exists(csv_path):
        print(f"Warning: {csv_path} not found.")
        return []

    try:
        table = pd.read_csv(csv_path)
    except Exception as e:
        print(f"Error reading CSV: {e}")
        return []

    name_columns = [f'v{i}_name' for i in range(1, 11)]
    problems = []

    for _, row in table.iterrows():
        if limit is not None and len(problems) >= limit:
            break

        try:
            row_id = row['Filename']
            formula_raw = str(row['Formula'])
            # Parsed only as a sanity check: a blank/NaN count raises, which skips the row.
            int(row['# variables'])

            var_names = [
                row[col]
                for col in name_columns
                if col in row and pd.notna(row[col])
            ]

            # First variable varies; the rest become constants equal to 1.0.
            target_var = var_names[0]
            fixed_context = {name: 1.0 for name in var_names[1:]}

            problems.append({
                "id": row_id,
                "name": f"Feynman {row_id}",
                "original_formula": formula_raw,
                "target_var": target_var,
                "fixed_context": fixed_context,
                "description": f"Projected 1D (varying {target_var}, others fixed to 1.0)"
            })
        except Exception:
            continue

    return problems

def evaluate_projected_formula(formula, target_var, x_val, fixed_context):
    """
    Evaluates the formula with x_val assigned to target_var, and others fixed.

    Name resolution order (later entries override earlier ones): built-in
    math names and defaults, then `fixed_context`, then `target_var`.

    Args:
        formula: Python expression text.
        target_var: Name of the variable that receives `x_val`.
        x_val: Scalar or array of input values.
        fixed_context: Mapping of the remaining variable names to their fixed
            values; None is treated as empty.

    Returns:
        The value of the expression, or a float array of NaN with the shape
        of `x_val` when evaluation raises.
    """
    # math context
    ctx = {
        'exp': np.exp, 'sin': np.sin, 'cos': np.cos, 'sqrt': np.sqrt, 'log': np.log,
        # Spellings that would otherwise raise NameError and turn into NaN.
        'ln': np.log, 'tan': np.tan, 'tanh': np.tanh,
        'arcsin': np.arcsin, 'arccos': np.arccos, 'arctan': np.arctan,
        'pi': np.pi, 'theta': 1.0, 'sigma': 1.0 # Defaults
    }

    # Constants from fixed_context
    ctx.update(fixed_context or {})

    # Target variable
    ctx[target_var] = x_val

    try:
        # SECURITY NOTE(review): eval() executes arbitrary Python and Python
        # builtins remain reachable here; only use with trusted formula text
        # (e.g. the bundled benchmark CSV), never with user-supplied input.
        return eval(formula, {}, ctx)
    except Exception:
        # Always a float array: np.full_like would inherit an integer dtype
        # from integer x_val and could not represent NaN.
        return np.full(np.shape(x_val), np.nan, dtype=float)


In [None]:
%%writefile AlphaSymbolic/data/__init__.py


In [None]:
%%writefile AlphaSymbolic/search/mcts.py
import math
import numpy as np
import torch
import copy
from core.grammar import VOCABULARY, TOKEN_TO_ID, OPERATORS, ExpressionTree, VARIABLES, OPERATOR_STAGES
from utils.optimize_constants import optimize_constants
from utils.data_utils import normalize_batch

class MCTSNode:
    """
    One node of the search tree: a partial formula plus its visit statistics.

    Besides the real statistics (visit_count, value_sum) the node carries
    virtual ones (virtual_visits, virtual_loss) intended for parallel search;
    both kinds are combined by `value` and `ucb_score`.
    """
    def __init__(self, tokens, parent=None, prior=0.0):
        # Identity / position in the tree
        self.tokens = tokens
        self.parent = parent
        self.prior = prior
        self.children = {}
        self.is_expanded = False

        # Real statistics
        self.visit_count = 0
        self.value_sum = 0.0

        # Virtual statistics (for parallel search)
        self.virtual_loss = 0.0
        self.virtual_visits = 0

    @property
    def value(self):
        """Mean value over real + virtual visits; 0.0 for an unvisited node."""
        total_visits = self.visit_count + self.virtual_visits
        if total_visits != 0:
            # The virtual loss is subtracted so that in-flight paths look worse.
            return (self.value_sum - self.virtual_loss) / total_visits
        return 0.0

    def ucb_score(self, c_puct=1.0):
        """
        Selection score: mean value plus a prior-weighted exploration bonus.

        Args:
            c_puct: Weight of the exploration term.

        Returns:
            0.0 for a node without a parent, otherwise
            value + c_puct * prior * sqrt(parent visits) / (1 + own visits),
            where both visit totals include virtual visits.
        """
        if self.parent is None:
            return 0.0

        own_visits = self.visit_count + self.virtual_visits
        if self.parent:
            parent_visits = self.parent.visit_count + self.parent.virtual_visits
        else:
            parent_visits = 1

        exploration = c_puct * self.prior * math.sqrt(parent_visits) / (1 + own_visits)
        return self.value + exploration

    @property
    def complexity(self):
        """Estimate complexity (length of formula)."""
        return len(self.tokens)

class MCTS:
    def __init__(self, model, device, grammar=None, c_puct=1.0, n_simulations=100, max_simulations=None, max_depth=50, complexity_lambda=0.1, max_len=200, batch_size=8, curriculum_stage=None, num_variables=1):
        self.model = model
        self.device = device
        self.grammar = grammar
        self.c_puct = c_puct
        self.num_variables = num_variables
        
        # Handle backwards compatibility for max_simulations
        if max_simulations is not None:
            self.n_simulations = max_simulations
        else:
            self.n_simulations = n_simulations
            
        self.max_depth = max_depth
        self.complexity_lambda = complexity_lambda
        self.max_len = max_len
        self.min_value = -float('inf')
        self.max_value = float('inf')
        self.vocab_size = len(VOCABULARY)
        self.sos_id = self.vocab_size
        self.batch_size = batch_size
        
        # Curriculum stage for operator filtering
        self.curriculum_stage = curriculum_stage
        self._build_allowed_tokens()
        
        # Pareto Front: List of {'tokens':, 'rmse':, 'complexity':, 'formula':}
        self.pareto_front = []
        
        # Virtual loss constant usually 1-3
        self.v_loss_const = 3.0
    
    def _build_allowed_tokens(self):
        """Build set of allowed token indices based on curriculum stage and num_variables."""
        # Terminals are always allowed, but respect num_variables
        allowed = set(['C', '0', '1', '2', '3', '5', '10', 'pi', 'e'])
        
        # Determine allowed variables
        if self.num_variables == 1:
            allowed.update(['x', 'x0'])
        else:
            allowed.update(VARIABLES[:self.num_variables])
        
        # Add operators based on curriculum stage
        if self.curriculum_stage is not None and self.curriculum_stage in OPERATOR_STAGES:
            allowed.update(OPERATOR_STAGES[self.curriculum_stage])
        else:
            # No stage = all operators allowed
            allowed.update(OPERATORS.keys())
        
        # Convert to indices
        self.allowed_token_indices = set()
        for token in allowed:
            if token in TOKEN_TO_ID:
                self.allowed_token_indices.add(TOKEN_TO_ID[token])
        
    def search(self, x_values, y_values, num_simulations=None):
        """
        Run MCTS (Parallel/Batched) to find the best formula.
        """
        self.pareto_front = [] # Reset Pareto Front for new search
        root = MCTSNode(tokens=[])
        
        # Initial expansion (single)
        self._expand_batch([root], x_values, y_values)
        
        best_rmse = float('inf')
        best_formula = None
        best_tokens = None
        
        limit = num_simulations if num_simulations is not None else self.n_simulations
        
        # Loop in batches
        # Ensure we do at least 1 batch
        num_batches = max(1, (limit + self.batch_size - 1) // self.batch_size)
        
        for _ in range(num_batches): 
            leaves = []
            
            # 1. Selection (find N leaves)
            for _ in range(self.batch_size):
                node = root
                depth = 0
                
                # Selection loop
                while node.is_expanded and node.children and depth < self.max_depth:
                    node = max(node.children.values(), key=lambda n: n.ucb_score(self.c_puct))
                    
                    # Apply virtual loss to discourage re-selection in same batch
                    node.virtual_loss += self.v_loss_const
                    node.virtual_visits += 1
                    depth += 1
                
                # Check if valid leaf to expand
                if depth < self.max_depth and not node.is_expanded:
                    # Avoid duplicates in batch (simple check)
                    if node not in leaves:
                        leaves.append(node)
                else:
                    pass
            
            if not leaves:
                # If no leaves found (tree fully explored or locked), standard MCTS usually continues or stops.
                # We can just break or continue backprop of terminals.
                if root.visit_count > limit: break 
                continue
                
            # 2. Batch Expansion & Evaluation
            values = self._expand_batch(leaves, x_values, y_values)
            
            # 3. Backpropagation
            for node, val in zip(leaves, values):
                # Check for best solution found
                if self._is_complete_tree(node.tokens):
                    # For completed formulas, we calculate REAL RMSE
                    try:
                        # Evaluar
                        # Importar aquí para evitar circular imports si es necesario
                        from utils.optimize_constants import optimize_constants
                        
                        # 1. Optimizar constants (Crucial para Accuracy)
                        # Esto es "Phase 1" de TPSR (constantes en las hojas)
                        # Por simplicidad en esta iteración, asumimos que 'evaluate_formula' ya hace algo o usamos el string directo.
                        # Idealmente llamaríamos a BFGS aquí.
                        
                        # Use existing _evaluate_formula to get RMSE and optimized constants
                        tree = ExpressionTree(node.tokens)
                        optimized_constants, real_rmse = optimize_constants(tree, x_values, y_values)
                        
                        # Get y_pred using the optimized constants
                        y_pred = tree.evaluate(x_values, constants=optimized_constants)
                        
                        # Check dimensions
                        if y_pred.shape != y_values.shape:
                            # If shapes don't match, it's an invalid evaluation
                            final_val = 0.0
                        else:
                            # 2. Calcular Reward TPSR (Hybrid Accuracy + Complexity)
                            # R = 1 / (1 + NMSE) + lambda * exp(-len/L)
                            
                            mse = np.mean((y_pred - y_values)**2)
                            var_y = np.var(y_values)
                            if var_y < 1e-9: var_y = 1.0 # Avoid division by zero
                            
                            nmse = mse / var_y
                            
                            # Evitar NMSE gigantes
                            if np.isnan(nmse) or np.isinf(nmse):
                                nmse = 1e9
                            
                            r_acc = 1.0 / (1.0 + nmse)
                            
                            # Penalización por complejidad
                            token_len = len(node.tokens)
                            L = self.max_len # Max length del modelo
                            
                            r_cplx = self.complexity_lambda * np.exp(-token_len / L)
                            
                            # Suma y Normalización (para mantener rango 0-1)
                            # El máximo teórico es (1.0 + lambda). Dividimos por eso.
                            raw_reward = r_acc + r_cplx
                            final_val = raw_reward / (1.0 + self.complexity_lambda)

                        # Update best formula based on RMSE (for reporting, not for MCTS value)
                        if real_rmse < best_rmse:
                            best_rmse = real_rmse
                            best_tokens = node.tokens
                            best_formula = ExpressionTree(node.tokens).get_infix()
                        
                        # Update Pareto Front
                        # Complexity = len(tokens) (or could use count_constants + nodes)
                        complexity = len(node.tokens)
                        self._update_pareto_front(node.tokens, real_rmse, complexity, ExpressionTree(node.tokens).get_infix())

                    except Exception as e:
                        # print(f"Error evaluating formula: {e}")
                        final_val = 0.0 # Invalid formula gets 0 reward
                else:
                    final_val = val
                
                # The following lines were part of the user's instruction but contained syntax errors and undefined variables.
                # They are commented out to maintain a syntactically correct and functional document.
                # If these lines were intended to be added, please provide a complete and correct snippet.
                #
                # # Construir vector de probabilidades
                # probs = np.zeros(self.vocab_size, dtype=np.float32)
                # for token_id, count in counts.items():
                #     probs[token_id] = count / total_visits_count += 1
                
                curr = node
                while curr is not None:
                    curr.visit_count += 1
                    curr.value_sum += final_val
                    
                    # Revert virtual loss for parent and above
                    # Since we added to PARENT's child (which is curr), 
                    # and we traverse Up...
                    # Wait, logic: We selected CHILD. Virtual loss was added TO CHILD (curr).
                    # So we must remove it from curr.
                    if curr.virtual_visits > 0:
                        curr.virtual_loss -= self.v_loss_const
                        curr.virtual_visits -= 1
                            
                    curr = curr.parent
        
        # After search, force cleanup of any residual virtual loss (safety)
        # (Not strictly needed if logic is perfect, but good practice in complex async MCTS)
        
        return {
            'tokens': best_tokens,
            'formula': best_formula,
            'rmse': best_rmse,
            'root': root,
            'pareto_front': self.pareto_front
        }

    def _update_pareto_front(self, tokens, rmse, complexity, formula_str):
        """
        Update the Pareto Front with a new solution.
        Keep solutions that are not dominated by any other solution.
        Solution A dominates B if:
        A.rmse <= B.rmse AND A.complexity <= B.complexity AND (A.rmse < B.rmse OR A.complexity < B.complexity)
        """
        # Create candidate
        candidate = {'tokens': tokens, 'rmse': rmse, 'complexity': complexity, 'formula': formula_str}
        
        # Check if dominated by existing
        is_dominated = False
        to_remove = []
        
        for existing in self.pareto_front:
            # Check if existing dominates candidate
            if (existing['rmse'] <= candidate['rmse'] and 
                existing['complexity'] <= candidate['complexity'] and 
                (existing['rmse'] < candidate['rmse'] or existing['complexity'] < candidate['complexity'])):
                is_dominated = True
                break
                
            # Check if candidate dominates existing
            if (candidate['rmse'] <= existing['rmse'] and 
                candidate['complexity'] <= existing['complexity'] and 
                (candidate['rmse'] < existing['rmse'] or candidate['complexity'] < existing['complexity'])):
                to_remove.append(existing)
        
        if not is_dominated:
            # Remove dominated existing solutions
            for item in to_remove:
                self.pareto_front.remove(item)
            
            # Add candidate
            self.pareto_front.append(candidate)
            # Sort by RMSE for easier viewing
            self.pareto_front.sort(key=lambda x: x['rmse'])

    def _expand_batch(self, nodes, x_values, y_values):
        """
        Batched expansion. Returns list of values.
        """
        if not nodes:
            return []
            
        # Prepare inputs
        # CRITICAL: Normalize data before feeding to model (Model expects [-1, 1] range)
        # We need to wrap single items in list to use normalize_batch, then unpack or use directly if batch logic allows.
        # normalize_batch takes lists of arrays.
        
        # NOTE: normalize_batch expects list of numpy arrays. _expand_batch takes single x_values/y_values which are reused across batch?
        # x_values is passed to search(). 
        # If search() was called with raw data, we must normalize it here for the MODEL only.
        # But optimize_constants expects RAW data. 
        # So we keep x_values raw, and create normalized tensors for the model.
        
        # normalize_batch input: list of arrays. output: list of arrays.
        norm_x_list, norm_y_list = normalize_batch([x_values], [y_values])
        norm_x = norm_x_list[0]
        norm_y = norm_y_list[0]
        
        x_tensor = torch.tensor(norm_x, dtype=torch.float32).unsqueeze(0).to(self.device)
        y_tensor = torch.tensor(norm_y, dtype=torch.float32).unsqueeze(0).to(self.device)
        
        # Repeat X/Y for batch
        batch_size = len(nodes)
        x_batch = x_tensor.repeat(batch_size, 1, 1).squeeze(1) # [batch, points]
        y_batch = y_tensor.repeat(batch_size, 1, 1).squeeze(1) # [batch, points]
        
        # Prepare sequences
        # Find max len
        max_len = 0
        seqs = []
        for n in nodes:
            s = [self.sos_id] + [TOKEN_TO_ID[t] for t in n.tokens]
            seqs.append(s)
            max_len = max(max_len, len(s))
            
        # Pad and stack
        input_tensor = torch.full((batch_size, max_len), self.sos_id, dtype=torch.long).to(self.device)
        for i, s in enumerate(seqs):
            input_tensor[i, :len(s)] = torch.tensor(s, dtype=torch.long)
            
        # Inference
        with torch.no_grad():
            logits, value_preds = self.model(x_batch, y_batch, input_tensor)
            
        # Process results
        values = []
        
        # To CPU numpy for probability processing
        probs_batch = torch.softmax(logits[:, -1, :self.vocab_size], dim=1).cpu().numpy()
        value_preds = value_preds.cpu().numpy() # [batch, 3]
        
        for i, node in enumerate(nodes):
            # 1. Store Value (Median for now)
            # value_preds is [batch, 3] -> (Pessimistic, Median, Optimistic)
            # We use Median (index 1) for standard UCB.
            val_pred = value_preds[i, 1] 
            val = float(np.clip(val_pred, 0.0, 1.0))
            values.append(val)
            
            # 2. Expand children
            node_probs = probs_batch[i]
            valid_next = self._get_valid_next_tokens(node.tokens)
            
            for idx in valid_next:
                token = VOCABULARY[idx]
                prior = node_probs[idx]
                child = MCTSNode(tokens=node.tokens + [token], parent=node, prior=prior)
                node.children[token] = child
            
            node.is_expanded = True
            
        return values

    def _get_valid_next_tokens(self, tokens):
        """Grammar check + curriculum filtering."""
        open_slots = 1
        for t in tokens:
            if t in OPERATORS:
                open_slots += OPERATORS[t] - 1
            else:
                open_slots -= 1
        
        if open_slots <= 0:
            return []
        
        # Filter by curriculum-allowed tokens
        return list(self.allowed_token_indices)

    def _is_complete_tree(self, tokens):
        if not tokens: return False
        try:
            tree = ExpressionTree(tokens)
            # Basic validation
            if len(tokens) > self.max_depth * 2: return False
            return tree.is_valid
        except:
            return False

    def _evaluate_formula(self, tokens, x, y):
        try:
            tree = ExpressionTree(tokens)
            _, rmse = optimize_constants(tree, x, y)
            return rmse
        except:
            return 1e9

    def get_training_examples(self, root):
        """
        Extrae ejemplos de entrenamiento del árbol generado.
        Retorna: lista de (state_tokens, policy_probs, value_target)
        """
        examples = []
        queue = [root]
        
        while queue:
            node = queue.pop(0)
            if node.visit_count < 1: 
                continue
            
            # Policy Target (Pi)
            # Distribución de visitas de los hijos
            counts = {}
            total_visits = 0
            has_children = False
            
            for token_id, child in node.children.items():
                # child key is token STRING or ID?
                # In _expand_batch: node.children[token] = child.
                # token = VOCABULARY[idx] (String).
                # So keys are strings.
                # But we need ID for probabilities array index.
                if token_id in TOKEN_TO_ID:
                    tid = TOKEN_TO_ID[token_id]
                    counts[tid] = child.visit_count
                    total_visits += child.visit_count
                    queue.append(child)
                    has_children = True
            
            if not has_children or total_visits == 0:
                continue
                
            # Construir vector de probabilidades
            probs = np.zeros(self.vocab_size, dtype=np.float32)
            for tid, count in counts.items():
                probs[tid] = count / total_visits
            
            # Value Target (V)
            # Usamos el Q-value (valor esperado) del nodo como target para el Value Head.
            # Q = value_sum / visit_count
            v = node.value_sum / node.visit_count
            
            # State: node.tokens (lista de ids?)
            # node.tokens is list of strings (from VOCABULARY).
            # self_play.py expects tokens as strings in ReplayBuffer.add.
            examples.append((node.tokens, probs, v))
            
        return examples


In [None]:
%%writefile AlphaSymbolic/search/beam_search.py
"""
Beam Search for AlphaSymbolic.
Explores multiple formula candidates in parallel, keeping top-K at each step.
"""
import torch
import numpy as np
from core.grammar import VOCABULARY, OPERATORS, TOKEN_TO_ID, ExpressionTree, OPERATOR_STAGES
from utils.optimize_constants import optimize_constants
from utils.data_utils import normalize_batch

class BeamSearch:
    def __init__(self, model, device, beam_width=10, max_length=30, curriculum_stage=None, num_variables=1):
        """Store the search configuration and build the additive token mask.

        ``self.token_mask`` holds 0 for allowed tokens and ``-inf`` for
        tokens that must never be generated: variables outside the
        ``num_variables`` range and, when ``curriculum_stage`` is set,
        operators that are not part of that stage. It is ``None`` when
        nothing is restricted.
        """
        self.model = model
        self.device = device
        self.beam_width = beam_width
        self.max_length = max_length
        self.num_variables = num_variables
        self.vocab_size = len(VOCABULARY)
        self.sos_id = self.vocab_size  # SOS token ID

        # Start with everything allowed (all zeros).
        mask = torch.zeros(self.vocab_size, device=device)
        blocked = float('-inf')

        from core.grammar import VARIABLES

        # 1. Variables: single-variable problems accept both 'x' and 'x0'.
        if num_variables == 1:
            permitted_vars = {'x', 'x0'}
        else:
            permitted_vars = set(VARIABLES[:num_variables])

        banned_vars = [name for name in VARIABLES if name not in permitted_vars]
        if 'x' not in permitted_vars:
            banned_vars.append('x')

        for name in banned_vars:
            if name in TOKEN_TO_ID:
                mask[TOKEN_TO_ID[name]] = blocked

        # 2. Operators: restricted only when a curriculum stage is given;
        #    an unknown stage falls back to every operator.
        if curriculum_stage is not None:
            stage_ops = set(OPERATOR_STAGES.get(curriculum_stage, list(OPERATORS.keys())))
            for op in OPERATORS:
                if op in stage_ops:
                    continue
                mask[TOKEN_TO_ID[op]] = blocked

        # Keep the mask only if it actually restricts something.
        self.token_mask = mask if torch.any(mask != 0) else None
        
    def search(self, x_values, y_values, return_partial=False):
        """
        Beam Search to find the best formula structure.
        """
        # Prepare data once
        # Normalize data for model inference
        # normalize_batch input: list of arrays. We wrap x_values/y_values in list.
        norm_x_list, norm_y_list = normalize_batch([x_values], [y_values])
        norm_x = norm_x_list[0]
        norm_y = norm_y_list[0]
        
        x_tensor = torch.tensor(norm_x, dtype=torch.float32).unsqueeze(0).to(self.device) # [1, points, vars] or [1, points] if 1D-normalized
        y_tensor = torch.tensor(norm_y, dtype=torch.float32).unsqueeze(0).to(self.device) # [1, points]
        
        # Each element in beams is just the sequence of tokens (list of strings)
        # We track scores and open branches in parallel lists or a list of dicts
        beams = [{'seq': [], 'log_prob': 0.0, 'open': 1}]
        
        completed = []
        
        for step in range(self.max_length):
            if not beams:
                break
                
            # Filter valid beams just in case
            active_beams = [b for b in beams if b['open'] > 0]
            if not active_beams:
                break
                
            # Prepare batch for model
            # Batch size = number of active beams
            batch_size = len(active_beams)
            
            # Expand X and Y to match batch size [batch, points, vars]
            # Use expand with *shape to handle arbitrary dimensions (1D or Multi-Var)
            # x_tensor shape is [1, points, vars] or [1, points]
            # We want [batch_size, points, vars]
            
            # Use repeat or expand. Expand is strictly view, safer:
            x_batch = x_tensor.expand(batch_size, *x_tensor.shape[1:])
            y_batch = y_tensor.expand(batch_size, *y_tensor.shape[1:])
            
            # Prepare input sequences [batch, current_seq_len]
            # Must prepend SOS token
            seqs = [[self.sos_id] + [TOKEN_TO_ID[t] for t in b['seq']] for b in active_beams]
            input_tensor = torch.tensor(seqs, dtype=torch.long).to(self.device)
            
            # Single model call for all beams
            with torch.no_grad():
                logits, _ = self.model(x_batch, y_batch, input_tensor)
            
            # Logits shape: [batch, seq_len, vocab_size]
            # We want the last token's probabilities
            last_token_logits = logits[:, -1, :self.vocab_size]
            
            # Apply curriculum mask if set
            if self.token_mask is not None:
                last_token_logits = last_token_logits + self.token_mask


            
            log_probs = torch.log_softmax(last_token_logits, dim=-1) # [batch, vocab]
            
            # --- Repetition Penalty (Simple) ---
            # If the same token was generated recently, penalize it slightly.
            # This prevents 10 ////////// loops.
            penalty_factor = 2.0  # Reduce log_prob (which is negative) by absolute amount or multiplier?
            # Log probs are negative (e.g. -0.1). Making them MORE negative penalizes.
            # If we multiply by 1.2, -0.1 becomes -0.12 (lower probability).
            
            for i, beam in enumerate(active_beams):
                if beam['seq']:
                     # Get last token ID
                    last_token = beam['seq'][-1]
                    if last_token in TOKEN_TO_ID:
                        last_id = TOKEN_TO_ID[last_token]
                        # Penalize current step logits for this token
                        # If log_prob is close to 0 (high prob), e.g. -0.01 -> -0.012
                        # If log_prob is -10 (low prob), -> -12
                        # Check bounds to avoid NaN if -inf
                        if log_probs[i, last_id] > -1e9:
                             log_probs[i, last_id] *= 1.5 
            # -----------------------------------
            
            # We need to find the top-K candidates ACROSS current beams?
            # Standard beam search: expand all, then prune to K
            
            all_candidates = []
            
            # Get top-K for EACH beam to avoid explosion (e.g. top 2*width)
            k_per_beam = min(self.beam_width, self.vocab_size)
            beam_topk_scores, beam_topk_indices = torch.topk(log_probs, k_per_beam, dim=-1)
            
            # Move to CPU for processing logic
            beam_topk_scores = beam_topk_scores.cpu().numpy()
            beam_topk_indices = beam_topk_indices.cpu().numpy()
            
            for i, beam in enumerate(active_beams):
                for score, idx in zip(beam_topk_scores[i], beam_topk_indices[i]):
                    token = VOCABULARY[idx]
                    new_seq = beam['seq'] + [token]
                    
                    # Calculate new open branches
                    if token in OPERATORS:
                        new_open = beam['open'] + OPERATORS[token] - 1
                    else:
                        new_open = beam['open'] - 1
                    
                    if new_open < 0:
                        continue
                        
                    all_candidates.append({
                        'seq': new_seq,
                        'log_prob': beam['log_prob'] + score,
                        'open': new_open
                    })
            
            # Global prune: keep top beam_width
            all_candidates.sort(key=lambda x: x['log_prob'], reverse=True)
            beams = all_candidates[:self.beam_width]
            
            # Check for completions
            still_active = []
            for b in beams:
                if b['open'] == 0:
                    completed.append(b)
                else:
                    still_active.append(b)
            
            beams = still_active
            # If we filled up on completions, we might still want to explore? 
            # Usually we keep exploring until all beams are done or max length
            if len(completed) >= self.beam_width:
                 # Optional: early exit if we found enough good candidates
                 pass

        # Evaluate results
        scored_results = []
        for beam in completed:
            tree = ExpressionTree(beam['seq'])
            if tree.is_valid:
                constants, rmse = optimize_constants(tree, x_values, y_values)
                scored_results.append({
                    'tokens': beam['seq'],
                    'log_prob': beam['log_prob'],
                    'rmse': rmse,
                    'constants': constants,
                    'formula': tree.get_infix()
                })
        
        scored_results.sort(key=lambda x: x['rmse'])
        
        # If no results and return_partial is requested, return the best incomplete beam
        if not scored_results and return_partial and beams:
            # Take the beam with highest probability
            best_beam = beams[0] 
            # Construct a partial result
            # We can't optimize constants or get a valid infix easily, but we can show tokens
            scored_results.append({
                'tokens': best_beam['seq'],
                'log_prob': best_beam['log_prob'],
                'rmse': float('inf'),
                'constants': {},
                'formula': "Partial: " + " ".join(best_beam['seq']) + "..."
            })
            
        return scored_results


def beam_solve(target_x, target_y, model, device, beam_width=20, max_length=25, num_variables=1):
    """
    Convenience wrapper: run one ``BeamSearch`` on ``(target_x, target_y)``.

    Returns the full list of results produced by ``BeamSearch.search`` (all
    of them, so callers can do their own Pareto analysis), or ``None`` when
    the search produced nothing.
    """
    results = BeamSearch(
        model,
        device,
        beam_width=beam_width,
        max_length=max_length,
        num_variables=num_variables,
    ).search(target_x, target_y)

    return results if results else None


if __name__ == "__main__":
    # Smoke test: recover y = 2x + 3 with beam search, using saved weights if available.
    from core.model import AlphaSymbolicModel

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    VOCAB_SIZE = len(VOCABULARY)

    # The model gets VOCAB_SIZE + 1 ids; BeamSearch uses id VOCAB_SIZE as its SOS token.
    model = AlphaSymbolicModel(vocab_size=VOCAB_SIZE + 1, d_model=64).to(DEVICE)
    try:
        model.load_state_dict(torch.load("alpha_symbolic_model.pth", map_location=DEVICE, weights_only=True))
    except Exception as exc:
        # Best effort: a missing or incompatible checkpoint is not fatal for a smoke test,
        # but Ctrl-C must still be able to stop the script.
        print(f"Model not found, using random weights ({exc})")
    model.eval()

    # Test
    x_test = np.linspace(-5, 5, 20).astype(np.float64)
    y_test = 2 * x_test + 3

    print("Running Beam Search...")
    # beam_solve returns None when no valid formula was found.
    results = beam_solve(x_test, y_test, model, DEVICE, beam_width=10) or []

    print(f"\nFound {len(results)} valid formulas:")
    for i, r in enumerate(results[:5]):
        print(f"  {i+1}. {r['formula']} (RMSE: {r['rmse']:.4f})")


In [None]:
%%writefile AlphaSymbolic/search/hybrid_search.py
import time
import torch
import numpy as np
from typing import List, Dict, Any, Optional
import concurrent.futures
import os

from core.gp_bridge import GPEngine
from search.beam_search import BeamSearch, beam_solve
from core.grammar import ExpressionTree
try:
    from core.gpu import TensorGeneticEngine
except ImportError:
    TensorGeneticEngine = None
    print("Warning: Could not import TensorGeneticEngine (PyTorch/CUDA missing?)")

def _run_gp_worker(args):
    """
    Worker function for parallel GP execution.

    Args:
        args: tuple ``(x_list, y_list, seeds_chunk, gp_timeout, gp_binary_path)``.

    Returns:
        dict with keys ``formula``, ``rmse`` and ``status``:
        - ``status == 'success'``: the engine returned a formula and it was
          re-evaluated here; ``rmse`` is its error on ``(x_list, y_list)``.
        - ``status == 'eval_error'``: the engine returned a formula but
          parsing or evaluating it raised; ``rmse`` is the sentinel 999.0.
        - ``status == 'failed'``: the engine returned nothing; ``formula``
          is ``None`` and ``rmse`` is the sentinel 1e9.
    """
    x_list, y_list, seeds, timeout, binary_path = args
    # np, ExpressionTree and GPEngine all come from this module's top-level imports.
    engine = GPEngine(binary_path=binary_path)

    # Run GP
    result = engine.run(x_list, y_list, seeds, timeout_sec=timeout)

    if not result:
        return {'formula': None, 'rmse': 1e9, 'status': 'failed'}

    # Score the worker's champion right away so the caller can compare workers.
    try:
        tree = ExpressionTree.from_infix(result)
        y_pred = tree.evaluate(np.array(x_list))
        mse = np.mean((np.array(y_list) - y_pred)**2)
        rmse = np.sqrt(mse)
    except Exception:
        # Only real evaluation errors are mapped to the sentinel;
        # KeyboardInterrupt/SystemExit still stop the worker.
        return {'formula': result, 'rmse': 999.0, 'status': 'eval_error'}

    return {'formula': result, 'rmse': rmse, 'status': 'success'}

def hybrid_solve(
    x_values: np.ndarray,
    y_values: np.ndarray,
    model: torch.nn.Module,
    device: torch.device,
    beam_width: int = 50,
    gp_timeout: int = 10,
    gp_binary_path: Optional[str] = None,
    max_workers: int = 4,
    num_variables: int = 1,
    extra_seeds: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Solve a symbolic-regression problem with neural beam search followed by
    parallel GP refinement.

    Phase 1 calls ``beam_solve`` and turns its candidate formulas into GP seeds
    (entries whose formula starts with ``"Partial"`` are skipped). Seeds passed
    in ``extra_seeds`` come first. The combined seed list is de-duplicated while
    preserving order, so the same formula is never handed to GP twice.

    Phase 2 deals the seeds round-robin into ``max_workers`` chunks and runs
    ``_run_gp_worker`` on each chunk in a thread pool. The formula with the
    lowest reported RMSE wins.

    Args:
        x_values, y_values: Data to fit; converted with ``.tolist()`` when available.
        model, device: Forwarded unchanged to ``beam_solve``.
        beam_width: Beam width for phase 1.
        gp_timeout: Timeout (seconds) forwarded to every GP worker.
        gp_binary_path: Optional GP binary path forwarded to every worker.
        max_workers: Number of GP workers. With ``max_workers <= 0`` no worker
            is launched and the failure result is returned.
        num_variables: Forwarded to ``beam_solve``.
        extra_seeds: Optional externally supplied seed formulas (feedback loop).

    Returns:
        dict with keys ``formula`` (``None`` on failure), ``rmse`` (``1e9`` on
        failure), ``source``, ``time`` (seconds) and ``seeds_tried``.
    """
    start_time = time.time()

    # ---- Phase 1: neural beam search -> seed formulas -----------------------
    neural_results = beam_solve(
        x_values, y_values, model, device,
        beam_width=beam_width, num_variables=num_variables
    )

    seeds: List[str] = []
    seen_formulas = set()

    def _add_seed(formula):
        # Order-preserving de-duplication across external and neural seeds.
        if formula not in seen_formulas:
            seen_formulas.add(formula)
            seeds.append(formula)

    # External seeds (feedback loop) keep priority over neural ones.
    for seed in (extra_seeds or []):
        _add_seed(seed)

    if neural_results:
        top_neural_seed = None
        for res in neural_results:
            f_str = res['formula']
            if f_str.startswith("Partial"):
                continue
            if top_neural_seed is None:
                top_neural_seed = f_str
            _add_seed(f_str)

        # Report the best *neural* candidate, not whatever happens to be seeds[0]
        # (which is an external seed whenever extra_seeds is given).
        if top_neural_seed is not None:
            print(f"Top Seed NN: {top_neural_seed}")
    else:
        print("[Phase 1] No valid candidates found. Falling back to pure GP.")

    # ---- Phase 2: parallel GP refinement ------------------------------------
    x_list = x_values.tolist() if hasattr(x_values, 'tolist') else list(x_values)
    y_list = y_values.tolist() if hasattr(y_values, 'tolist') else list(y_values)

    # One chunk per worker; with no seeds every worker still runs on an empty
    # chunk (pure GP). range() of a non-positive count yields no chunks at all.
    cpu_chunks = [[] for _ in range(max_workers)]
    if cpu_chunks:
        for i, seed in enumerate(seeds):
            cpu_chunks[i % max_workers].append(seed)

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, max_workers)) as executor:
        futures = [
            executor.submit(_run_gp_worker, (x_list, y_list, chunk, gp_timeout, gp_binary_path))
            for chunk in cpu_chunks
        ]

        for future in concurrent.futures.as_completed(futures):
            try:
                res = future.result()
            except Exception as e:
                print(f"Worker exception: {e}")
                continue
            if res['status'] in ('success', 'eval_error'):
                res['worker'] = 'CPU'
                results.append(res)

    total_time = time.time() - start_time

    # ---- Pick the best champion across workers ------------------------------
    best_result = None
    best_rmse = float('inf')
    for res in results:
        if res['formula'] and res['rmse'] < best_rmse:
            best_rmse = res['rmse']
            best_result = res['formula']

    if not best_result:
        print(f"--- Hybrid Search Failed (All workers failed) ---")

    return {
        'formula': best_result if best_result else None,
        'rmse': best_rmse if best_result else 1e9,
        'source': 'Alpha-GP Hybrid',
        'time': total_time,
        'seeds_tried': seeds
    }

if __name__ == "__main__":
    # Smoke test: run the full pipeline on y = x^2 - 5 with a model that only
    # emits random logits.
    class MockModel(torch.nn.Module):
        """Stand-in model returning random logits over a 20-token vocabulary."""

        def forward(self, x, y, seq):
            batch, length = seq.shape
            return torch.randn(batch, length, 20), None

    print("Testing Parallel Hybrid Search...")
    xs = np.linspace(-5, 5, 20)
    ys = xs**2 - 5
    try:
        outcome = hybrid_solve(xs, ys, MockModel(), torch.device("cpu"), beam_width=5, max_workers=2)
        print(outcome)
    except Exception as err:
        print(f"Test failed: {err}")


In [None]:
%%writefile AlphaSymbolic/search/pareto.py
"""
Pareto Front Manager for AlphaSymbolic.
Maintains a set of non-dominated solutions (accuracy vs complexity).
"""
import numpy as np
from core.grammar import ExpressionTree

class ParetoSolution:
    """A candidate formula scored on two objectives, both lower-is-better:
    ``rmse`` (fit error) and ``complexity`` (node/token count)."""

    def __init__(self, tokens, rmse, complexity, formula_str, constants=None):
        self.tokens = tokens
        self.rmse = rmse
        self.complexity = complexity
        self.formula = formula_str
        # A falsy ``constants`` (None or empty) becomes a fresh empty dict.
        self.constants = constants if constants else {}

    def dominates(self, other):
        """Pareto dominance test.

        ``self`` dominates ``other`` when it is no worse on both objectives and
        strictly better on at least one of them.
        """
        no_worse = (self.rmse <= other.rmse) and (self.complexity <= other.complexity)
        if not no_worse:
            return False
        return (self.rmse < other.rmse) or (self.complexity < other.complexity)

    def __repr__(self):
        return f"ParetoSolution(rmse={self.rmse:.4f}, complexity={self.complexity}, formula='{self.formula}')"


class ParetoFront:
    """Bounded collection of mutually non-dominated solutions (RMSE vs complexity).

    ``solutions`` holds at most ``max_size`` entries; each entry must expose
    ``rmse``, ``complexity``, ``formula`` and a ``dominates(other)`` method.
    """

    def __init__(self, max_size=50):
        self.solutions = []
        self.max_size = max_size

    def add(self, solution):
        """
        Attempts to add a solution to the Pareto front.

        Returns True only if ``solution`` is part of the front when the call
        returns. Returns False when it is rejected: its RMSE is NaN, an existing
        solution dominates it, or the size cap evicted it right after insertion.
        """
        # NaN compares False with everything, so a NaN solution could never be
        # dominated and would make the sort/min keys below meaningless.
        if solution.rmse != solution.rmse:
            return False

        if any(existing.dominates(solution) for existing in self.solutions):
            return False

        # Drop everything the newcomer dominates, then insert it.
        self.solutions = [s for s in self.solutions if not solution.dominates(s)]
        self.solutions.append(solution)

        # Enforce the size cap: keep the max_size best by a combined score.
        if len(self.solutions) > self.max_size:
            self.solutions.sort(key=lambda s: s.rmse + 0.01 * s.complexity)
            self.solutions = self.solutions[:self.max_size]
            # The newcomer itself may have been the one trimmed away.
            return any(kept is solution for kept in self.solutions)

        return True

    def add_from_results(self, results_list):
        """
        Add multiple results from beam search or MCTS.

        results_list: list of dicts with 'tokens', 'rmse', 'formula' and,
        optionally, 'constants'.

        Returns the number of results that ``add`` accepted.
        """
        added = 0
        for r in results_list:
            # The tree is built but not used afterwards; the construction is kept
            # in case the constructor validates the tokens - TODO confirm.
            ExpressionTree(r['tokens'])
            complexity = len(r['tokens'])  # Simple complexity = token count

            sol = ParetoSolution(
                tokens=r['tokens'],
                rmse=r['rmse'],
                complexity=complexity,
                formula_str=r['formula'],
                constants=r.get('constants', {})
            )

            if self.add(sol):
                added += 1

        return added

    def get_best_by_rmse(self):
        """Returns the solution with lowest RMSE, or None if the front is empty."""
        if not self.solutions:
            return None
        return min(self.solutions, key=lambda s: s.rmse)

    def get_simplest(self):
        """Returns the solution with lowest complexity, or None if the front is empty."""
        if not self.solutions:
            return None
        return min(self.solutions, key=lambda s: s.complexity)

    def get_balanced(self, alpha=0.5):
        """
        Returns the solution minimising a weighted sum of min-max normalised
        RMSE and complexity, or None if the front is empty.

        alpha: weight for RMSE (1-alpha for complexity)
        """
        if not self.solutions:
            return None

        rmse_vals = [s.rmse for s in self.solutions]
        comp_vals = [s.complexity for s in self.solutions]

        # The small epsilon on the maxima avoids division by zero when all
        # solutions share the same value for an objective.
        min_rmse, max_rmse = min(rmse_vals), max(rmse_vals) + 1e-10
        min_comp, max_comp = min(comp_vals), max(comp_vals) + 1e-10

        def score(s):
            norm_rmse = (s.rmse - min_rmse) / (max_rmse - min_rmse)
            norm_comp = (s.complexity - min_comp) / (max_comp - min_comp)
            return alpha * norm_rmse + (1 - alpha) * norm_comp

        return min(self.solutions, key=score)

    def summary(self):
        """Print the (up to) ten lowest-RMSE solutions of the front."""
        print(f"\n=== Pareto Front ({len(self.solutions)} solutions) ===")
        for i, sol in enumerate(sorted(self.solutions, key=lambda s: s.rmse)[:10]):
            print(f"  {i+1}. RMSE={sol.rmse:.6f}, Nodes={sol.complexity}, Formula: {sol.formula}")


# Quick test
if __name__ == "__main__":
    demo_front = ParetoFront()

    # Hand-made candidates as (tokens, rmse, complexity, formula).
    demo_specs = [
        (['x'], 10.0, 1, "x"),
        (['+', 'x', '1'], 5.0, 3, "(x + 1)"),
        (['*', '2', 'x'], 3.0, 3, "(2 * x)"),
        (['+', '*', '2', 'x', '3'], 0.5, 5, "((2 * x) + 3)"),
        (['+', '*', '*', '2', 'x', 'x', '+', 'x', '1'], 0.1, 9, "complicated"),
    ]
    candidates = [ParetoSolution(*spec) for spec in demo_specs]

    for candidate in candidates:
        was_added = demo_front.add(candidate)
        print(f"Added {candidate.formula}: {was_added}")

    demo_front.summary()

    print(f"\nBest by RMSE: {demo_front.get_best_by_rmse()}")
    print(f"Simplest: {demo_front.get_simplest()}")
    print(f"Balanced: {demo_front.get_balanced()}")


In [None]:
%%writefile AlphaSymbolic/search/__init__.py


In [None]:
%%writefile AlphaSymbolic/training/train.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from core.model import AlphaSymbolicModel
from data.synthetic_data import DataGenerator
from core.grammar import VOCABULARY, TOKEN_TO_ID, ExpressionTree

def validate(model, val_data, device, vocab_size):
    """
    Score the model on a validation set using one teacher-forced batch.

    Args:
        model: Called as ``model(x, y, decoder_input)`` and expected to return
            ``(logits, _)``. It is switched to eval mode for the evaluation and
            back to train mode before returning.
        val_data: List of dicts with keys ``'x'``, ``'y'`` and ``'tokens'``.
        device: Device the input tensors are moved to.
        vocab_size: Number of real tokens. The id ``vocab_size`` is used as the
            SOS / decoder-padding id and logits are flattened to
            ``vocab_size + 1`` classes.

    Returns:
        ``(loss, token_accuracy)`` as Python floats; padded target positions
        (marked ``-1``) are ignored by both.
    """
    model.eval()

    sample_count = len(val_data)
    sos_id = vocab_size

    xs = [sample['x'] for sample in val_data]
    ys = [sample['y'] for sample in val_data]
    encoded = [[TOKEN_TO_ID[tok] for tok in sample['tokens']] for sample in val_data]
    longest = max(map(len, encoded))

    # Decoder input is [SOS, t1, ..., tn, SOS...]; targets are [t1, ..., tn, -1...].
    decoder_input = torch.full((sample_count, longest + 1), sos_id, dtype=torch.long)
    targets = torch.full((sample_count, longest + 1), -1, dtype=torch.long)
    for row, ids in enumerate(encoded):
        id_tensor = torch.tensor(ids, dtype=torch.long)
        decoder_input[row, 1:len(ids) + 1] = id_tensor
        targets[row, :len(ids)] = id_tensor

    x_tensor = torch.tensor(np.array(xs), dtype=torch.float32).to(device)
    y_tensor = torch.tensor(np.array(ys), dtype=torch.float32).to(device)
    decoder_input = decoder_input.to(device)
    targets = targets.to(device)

    with torch.no_grad():
        logits, _ = model(x_tensor, y_tensor, decoder_input)

        loss_value = nn.functional.cross_entropy(
            logits.view(-1, vocab_size + 1), targets.view(-1), ignore_index=-1
        ).item()

        # Token accuracy over non-padded positions only.
        predictions = torch.argmax(logits, dim=-1)
        valid = targets != -1
        hits = ((predictions == targets) & valid).sum().float()
        token_accuracy = hits / valid.sum().float()

        # Formula validity is not measured here: teacher-forced predictions say
        # little about it; a greedy decode would be needed for that.

    model.train()
    return loss_value, token_accuracy.item()

def train_supervised(epochs=100, batch_size=32, lr=1e-4, save_path="alpha_symbolic_model.pth"):
    """
    Supervised (teacher-forced) pre-training on freshly generated formulas.

    Each "epoch" draws one new batch from ``DataGenerator`` and performs one
    Adam step on the token cross-entropy. Every 10 epochs the model is scored
    on a fixed validation set with ``validate``. The final weights are saved
    to ``save_path``.

    Args:
        epochs: Number of training iterations (one batch each).
        batch_size: Number of samples requested per batch. The generator may
            return fewer; tensors are sized from what was actually returned.
        lr: Adam learning rate.
        save_path: Destination file for ``model.state_dict()``.
    """
    vocab_size = len(VOCABULARY)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")

    # One extra class: id == vocab_size is the SOS / decoder-padding token.
    model = AlphaSymbolicModel(vocab_size=vocab_size + 1, d_model=64).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    ce_loss_fn = nn.CrossEntropyLoss(ignore_index=-1)  # -1 marks padded targets

    data_gen = DataGenerator(max_depth=4)

    # Fixed validation set so validation numbers are comparable across epochs.
    print("Generating validation set...")
    val_data = data_gen.generate_batch(100)

    sos_id = vocab_size
    model.train()

    for epoch in range(epochs):
        batch_data = data_gen.generate_batch(batch_size)
        if not batch_data:
            continue

        # Size tensors from the batch actually produced: a short batch would
        # otherwise leave decoder_input/targets with more rows than x/y.
        actual_batch = len(batch_data)

        x_list = [d['x'] for d in batch_data]
        y_list = [d['y'] for d in batch_data]
        token_ids_list = [[TOKEN_TO_ID[t] for t in d['tokens']] for d in batch_data]

        max_len = max(len(s) for s in token_ids_list)

        # Decoder input is [SOS, t1..tn, SOS...]; targets are [t1..tn, -1...].
        decoder_input = torch.full((actual_batch, max_len + 1), sos_id, dtype=torch.long)
        targets = torch.full((actual_batch, max_len + 1), -1, dtype=torch.long)

        for i, seq in enumerate(token_ids_list):
            l = len(seq)
            seq_tensor = torch.tensor(seq, dtype=torch.long)
            decoder_input[i, 1:l+1] = seq_tensor
            targets[i, :l] = seq_tensor

        x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(device)
        y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(device)
        decoder_input = decoder_input.to(device)
        targets = targets.to(device)

        # Forward (the value head output is not trained here)
        logits, _ = model(x_tensor, y_tensor, decoder_input)
        loss = ce_loss_fn(logits.view(-1, vocab_size + 1), targets.view(-1))

        # Optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            val_loss, val_acc = validate(model, val_data, device, vocab_size)
            print(f"Epoch {epoch}: Train Loss = {loss.item():.4f} | Val Loss = {val_loss:.4f} | Val Acc = {val_acc:.2%}")

    # Save model
    torch.save(model.state_dict(), save_path)
    print("Training complete. Model saved.")

if __name__ == "__main__":
    # Script entry point: run supervised training and, on failure, report the
    # error together with the full traceback instead of exiting silently.
    import traceback

    try:
        train_supervised()
    except Exception as err:
        print(f"Training failed: {err}")
        traceback.print_exc()


In [None]:
%%writefile AlphaSymbolic/training/train_enhanced.py
"""
Enhanced Training Script for AlphaSymbolic.
Includes:
- Curriculum Learning (simple formulas first, then complex operators)
- Value Network Training (not just policy)
- Proper Loss Weighting
- Regularization (dropout, weight decay)
- Learning Rate Scheduling (OneCycleLR)
- Gradient Accumulation
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
import numpy as np
import time

from core.model import AlphaSymbolicModel
from data.synthetic_data import DataGenerator
from core.grammar import VOCABULARY, TOKEN_TO_ID

def train_enhanced(epochs=1000, batch_size=64, curriculum=True, save_interval=100, accum_steps=4,
                   save_path="alpha_symbolic_model.pth"):
    """
    Enhanced training with curriculum learning, value head, OneCycleLR, and
    gradient accumulation.

    One "epoch" generates and processes a single batch, so
    ``effective_batch_size = batch_size * accum_steps`` and the optimizer is
    stepped once per ``accum_steps`` processed batches.

    Args:
        epochs: Number of batches to generate.
        batch_size: Physical batch size requested from the generator. Batches
            with fewer than ``batch_size // 2`` samples are skipped.
        curriculum: If True, formula depth and the operator set grow with
            training progress; otherwise depth 5 and all operators are used.
        save_interval: A checkpoint is considered every ``save_interval`` epochs
            and written when the current batch loss beats the best seen so far.
        accum_steps: Batches accumulated per optimizer step (values < 1 are
            treated as 1).
        save_path: File the checkpoints and the final weights are written to.

    Returns:
        The trained model.
    """
    VOCAB_SIZE = len(VOCABULARY)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    accum_steps = max(1, int(accum_steps))

    print("="*60)
    print("AlphaSymbolic Enhanced Training")
    print("="*60)
    print(f"Device: {DEVICE}")
    print(f"Vocabulary Size: {VOCAB_SIZE}")
    print(f"Epochs: {epochs}")
    print(f"Physical Batch Size: {batch_size}")
    print(f"Accumulation Steps: {accum_steps}")
    print(f"Effective Batch Size: {batch_size * accum_steps}")
    print(f"Curriculum Learning: {curriculum}")

    # One extra class: id == VOCAB_SIZE is the SOS / decoder-padding token.
    model = AlphaSymbolicModel(
        vocab_size=VOCAB_SIZE + 1,
        d_model=128,
        nhead=4,
        num_encoder_layers=3,
        num_decoder_layers=3
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total Parameters: {total_params:,}")

    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

    # The scheduler advances once per optimizer update, not once per epoch, so
    # its horizon is the number of updates: ceil(epochs / accum_steps). With the
    # horizon set to `epochs` the cycle would never get past its first
    # 1/accum_steps fraction and the LR would never anneal.
    total_updates = max(1, -(-epochs // accum_steps))
    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-3,
        total_steps=total_updates,
        pct_start=0.3,
        div_factor=25,
        final_div_factor=1000
    )

    ce_loss_fn = nn.CrossEntropyLoss(ignore_index=-1)  # -1 marks padded targets
    mse_loss_fn = nn.MSELoss()

    SOS_ID = VOCAB_SIZE

    # Curriculum operator sets: basic arithmetic -> plus division -> everything.
    op_levels = [
        ['+', '-', '*'],
        ['+', '-', '*', '/'],
        None # All
    ]

    def _apply_update():
        # Clip, step optimizer and scheduler together, then clear gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    model.train()
    start_time = time.time()
    best_loss = float('inf')
    micro_batches = 0  # batches whose gradients have been accumulated so far

    optimizer.zero_grad()

    for epoch in range(epochs):
        # Determine curriculum level
        if curriculum:
            progress = epoch / epochs
            # progress is in [0, 1), so depth runs from 2 up to 5.
            current_depth = int(2 + progress * 4)
            if progress < 0.3:
                current_ops = op_levels[0]
            elif progress < 0.6:
                current_ops = op_levels[1]
            else:
                current_ops = op_levels[2]
        else:
            current_depth = 5
            current_ops = None

        # Generate batch with current difficulty
        data_gen = DataGenerator(max_depth=current_depth, allowed_operators=current_ops)
        batch_data = data_gen.generate_batch(batch_size)

        if len(batch_data) < batch_size // 2:
            continue

        actual_batch = len(batch_data)

        x_list = [d['x'] for d in batch_data]
        y_list = [d['y'] for d in batch_data]
        token_ids_list = [[TOKEN_TO_ID[t] for t in d['tokens']] for d in batch_data]

        max_len = max(len(s) for s in token_ids_list)

        # Decoder input is [SOS, t1..tn, SOS...]; targets are [t1..tn, -1...].
        decoder_input = torch.full((actual_batch, max_len + 1), SOS_ID, dtype=torch.long)
        targets = torch.full((actual_batch, max_len + 1), -1, dtype=torch.long)

        for i, seq in enumerate(token_ids_list):
            l = len(seq)
            seq_tensor = torch.tensor(seq, dtype=torch.long)
            decoder_input[i, 1:l+1] = seq_tensor
            targets[i, :l] = seq_tensor

        # Every sample is a ground-truth formula, i.e. a "winning" path, so the
        # value target is 1.0 for the whole batch.
        value_targets = torch.ones(actual_batch, 1)

        x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
        y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
        decoder_input = decoder_input.to(DEVICE)
        targets = targets.to(DEVICE)
        value_targets = value_targets.to(DEVICE)

        # Forward pass
        logits, value_pred = model(x_tensor, y_tensor, decoder_input)

        # Policy loss (cross-entropy)
        policy_loss = ce_loss_fn(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))

        # Value loss (MSE); reshape the prediction if its shape differs.
        if value_pred.shape != value_targets.shape:
             value_pred = value_pred.view_as(value_targets)

        value_loss = mse_loss_fn(value_pred, value_targets)

        # Combined loss, scaled so accumulated gradients average over the
        # accum_steps micro-batches.
        loss = (policy_loss + 0.5 * value_loss) / accum_steps
        loss.backward()

        # Count processed batches (not epochs) so skipped batches can neither
        # swallow nor shift an optimizer update.
        micro_batches += 1
        if micro_batches % accum_steps == 0:
            _apply_update()

        # Logging (multiply loss back by accum_steps for display)
        if epoch % 50 == 0:
            elapsed = time.time() - start_time
            lr = scheduler.get_last_lr()[0]
            real_loss = loss.item() * accum_steps
            ops_name = "All" if current_ops is None else str(len(current_ops))
            print(f"Epoch {epoch:4d} | Loss: {real_loss:.4f} (P: {policy_loss.item():.4f}, V: {value_loss.item():.4f}) | LR: {lr:.2e} | Depth: {current_depth} | Ops: {ops_name} | Time: {elapsed:.1f}s")

        # Save checkpoint
        if epoch % save_interval == 0 and epoch > 0:
            real_loss = loss.item() * accum_steps
            if real_loss < best_loss:
                best_loss = real_loss
                torch.save(model.state_dict(), save_path)
                print(f"  -> Saved checkpoint (best loss: {best_loss:.4f})")

    # Apply gradients left over from an incomplete accumulation window so the
    # last batches are not silently discarded.
    if micro_batches % accum_steps != 0:
        _apply_update()

    # Final save
    torch.save(model.state_dict(), save_path)
    total_time = time.time() - start_time
    print(f"\nTraining complete! Total time: {total_time:.1f}s")
    print(f"Model saved to: {save_path}")

    return model


if __name__ == "__main__":
    import argparse

    # Command-line front end: every option maps onto one train_enhanced argument.
    cli = argparse.ArgumentParser(description="Train AlphaSymbolic")
    int_options = (
        ("--epochs", 500, "Number of epochs"),
        ("--batch", 32, "Batch size (physical)"),
        ("--accum", 4, "Gradient accumulation steps"),
    )
    for flag, default_value, help_text in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    cli.add_argument("--no-curriculum", action="store_true", help="Disable curriculum learning")
    options = cli.parse_args()

    use_curriculum = not options.no_curriculum
    train_enhanced(
        epochs=options.epochs,
        batch_size=options.batch,
        curriculum=use_curriculum,
        accum_steps=options.accum
    )



In [None]:
%%writefile AlphaSymbolic/training/self_play.py
"""
Self-Play AlphaZero Loop for AlphaSymbolic.
The model improves by learning from its own search results.

Process:
1. Generate problems (synthetic or from memory)
2. Use MCTS to find best formulas
3. Store successful (state, action, value) tuples with priority
4. Train network on this experience using weighted sampling
5. Repeat
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
import os
from collections import deque
import random
import sys
import os

# Add project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core.model import AlphaSymbolicModel
from core.grammar import VOCABULARY, TOKEN_TO_ID
from data.synthetic_data import DataGenerator
from search.mcts import MCTS
from data.pattern_memory import PatternMemory
from core.loss import QuantileLoss


class ReplayBuffer:
    """Bounded FIFO store of self-play experiences.

    Each experience is a dict with keys 'x', 'y', 'tokens', 'policy' and
    'value'. A parallel ``priorities`` deque is maintained, but every entry is
    currently 1.0 and sampling is uniform.
    """

    def __init__(self, capacity=50000):
        self.buffer = deque(maxlen=capacity)
        self.priorities = deque(maxlen=capacity)

    def add(self, x_data, y_data, tokens, policy, value):
        """Append one (state, policy, value) experience; once the buffer is
        full, the oldest entry is dropped."""
        experience = dict(x=x_data, y=y_data, tokens=tokens, policy=policy, value=value)
        self.buffer.append(experience)
        # Placeholder priority: all experiences weigh the same for now.
        self.priorities.append(1.0)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct experiences chosen uniformly at
        random, or every stored experience when fewer than ``batch_size`` exist."""
        stored = len(self.buffer)
        if stored < batch_size:
            return list(self.buffer)

        picks = np.random.choice(stored, batch_size, replace=False)
        return [self.buffer[idx] for idx in picks]

    def __len__(self):
        return len(self.buffer)


class AlphaZeroLoop:
    def __init__(self, model_path="alpha_symbolic_model.pth", fresh_start=False):
        """Build the model, optimizer, replay buffer, problem generator,
        pattern memory and MCTS searcher used by the self-play loop.

        model_path: checkpoint file to load from (if it exists) and save to.
        fresh_start: when True, an existing checkpoint at ``model_path`` is
            deleted first so training starts from new weights.
        """
        self.model_path = model_path
        self.vocab_size = len(VOCABULARY)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Optionally discard the previous checkpoint before anything is loaded.
        if fresh_start and os.path.exists(self.model_path):
            os.remove(self.model_path)
            print("Previous model deleted. Starting fresh.")

        # Network: one extra vocabulary slot (id == vocab_size) for the SOS token.
        self.model = AlphaSymbolicModel(
            vocab_size=self.vocab_size + 1,
            d_model=128,
            nhead=4,
            num_encoder_layers=3,
            num_decoder_layers=3
        ).to(self.device)
        self.load_model()

        self.optimizer = optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=0.01)

        # Experience store, source of new problems, and memory of good formulas.
        self.replay = ReplayBuffer(capacity=100000)
        self.data_gen = DataGenerator(max_depth=5)
        self.memory = PatternMemory()

        # MCTS searcher driven by the current model (complexity_lambda=0.1).
        self.searcher = MCTS(self.model, self.device, max_simulations=50, max_depth=25, complexity_lambda=0.1)

        # Running statistics; avg_rmse keeps only the latest 100 values.
        self.stats = dict(
            iterations=0,
            best_rmse=float('inf'),
            avg_rmse=deque(maxlen=100),
        )
    
    def load_model(self):
        if os.path.exists(self.model_path):
            try:
                self.model.load_state_dict(torch.load(self.model_path, map_location=self.device, weights_only=True))
                print(f"Loaded model from {self.model_path}")
            except:
                print("Could not load model, using fresh weights")
    
    def save_model(self):
        torch.save(self.model.state_dict(), self.model_path)
    
    def self_play_episode(self, num_problems=10):
        """
        Generate problems, solve them with MCTS, store experiences (AlphaZero style).
        """
        self.model.eval()
        
        experiences = []
        
        # Generate problems
        problems = self.data_gen.generate_batch(num_problems)
        
        for prob in problems:
            x_data = prob['x'].astype(np.float64)
            y_data = prob['y'].astype(np.float64)
            
            # Search for solution via MCTS
            result = self.searcher.search(x_data, y_data)
            
            # Extract training examples from the tree (Polcy, Value)
            # result['root'] is now available
            if 'root' in result:
                examples = self.searcher.get_training_examples(result['root'])
                for (tokens, policy, value) in examples:
                    self.replay.add(x_data, y_data, tokens, policy, value)
            
            if result['tokens']:
                experiences.append(result['rmse'])
                
                # Update pattern memory
                if result['formula'] and result['rmse'] < 1.0: # Only good ones
                     self.memory.record(result['tokens'], result['rmse'], result['formula'])
                
                # Track statistics
                if result['rmse'] < self.stats['best_rmse']:
                    self.stats['best_rmse'] = result['rmse']
                self.stats['avg_rmse'].append(result['rmse'])
        
        return experiences
    
    def train_step(self, batch_size=32):
        """
        Train on experiences from replay buffer.
        """
        if len(self.replay) < batch_size:
            return None
        
        self.model.train()
        
        batch = self.replay.sample(batch_size)
        
        # Prepare batch
        SOS_ID = self.vocab_size
        
        x_list = [exp['x'] for exp in batch]
        y_list = [exp['y'] for exp in batch]
        token_lists = [exp['tokens'] for exp in batch]
        policy_targets = [exp['policy'] for exp in batch]
        value_targets_list = [exp['value'] for exp in batch]
        
        # Pad sequences
        max_len = max(len(t) for t in token_lists)
        
        decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
        # Policy target is distribution [batch, vocab]
        # But wait, MCTS returns policy for the NEXT token from current state.
        # Our model predicts sequence. 
        # Standard AZ: (State) -> (Policy, Value).
        # Wrapper: (Input X, Y, Partial Formula) -> (Next Token Dist, Value).
        # So we align 'decoder_input' = tokens. 
        # Target for Policy head at the LAST step is 'policy_targets'.
        
        # The 'tokens' in replay buffer IS the state (partial sequence).
        # So decoder_input should be exactly that.
        
        for i, tokens in enumerate(token_lists):
            ids = [TOKEN_TO_ID[t] for t in tokens] # Assuming tokens are strings? 
            # Wait, MCTSNode.tokens are IDs or Strings?
            # Grammar says VOCABULARY list. MCTS seems to store tokens as strings or IDs?
            # MCTS code uses TOKEN_TO_ID when calling model. So MCTSNode.tokens likely Strings?
            # Let's verify MCTSNode usage.
            # MCTSNode init: tokens=[]. expand: names from VOCABULARY.
            # So they are strings.
            
            l = len(ids)
            decoder_input[i, 1:l+1] = torch.tensor(ids, dtype=torch.long)
            
        # Targets
        # Policy: [batch, vocab_size]
        policy_target_tensor = torch.tensor(np.array(policy_targets), dtype=torch.float32).to(self.device)
        value_target_tensor = torch.tensor(np.array(value_targets_list), dtype=torch.float32).unsqueeze(1).to(self.device)
        
        # To device
        x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(self.device)
        decoder_input = decoder_input.to(self.device)
        
        # Forward
        logits, value_pred = self.model(x_tensor, y_tensor, decoder_input)
        
        # Losses
        # Policy Loss: KLDiv between predicted distribution (logits of last token) and target distribution
        # logits: [batch, seq_len, vocab]
        # We only care about the LAST token prediction because that matches the MCTS state
        # The 'decoder_input' has length L+1 (SOS + tokens).
        # The prediction for the 'next' token is at the last position.
        
        # Gather last step logits
        # We need to pick the logits corresponding to the end of each sequence?
        # Since we padded, we must be careful.
        # 'decoder_input' is padded with SOS? No, standard padding?
        # I initialized with SOS_ID.
        # The effective length for user i is len(tokens)+1.
        # So we take logits[i, len(tokens), :] ??
        # Or did I pad with something else?
        
        last_logits = []
        for i, tokens in enumerate(token_lists):
            idx = len(tokens) # index of last real token (SOS is at 0)
            # sequences: SOS, T1, T2 ...
            # Input: SOS, T1
            # Output at pos 1: Pred for T2.
            # State was [T1]. 
            # Correct.
            last_logits.append(logits[i, idx, :self.vocab_size])
            
        last_logits = torch.stack(last_logits)
        
        # Log Softmax for KLDiv
        log_probs = torch.log_softmax(last_logits, dim=1)
        
        # KLDivLoss(input, target) expects log_probs as input
        policy_loss = nn.KLDivLoss(reduction='batchmean')(log_probs, policy_target_tensor)
        
        # Value Loss (Quantile Regression)
        # value_pred: [batch, 3], value_target_tensor: [batch, 1]
        value_loss = QuantileLoss()(value_pred, value_target_tensor)
        
        total_loss = policy_loss + value_loss # Weighting? 1.0 each for now
        
        # Backward
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        
        return {
            'total': total_loss.item(),
            'policy': policy_loss.item(),
            'value': value_loss.item()
        }
    
    def run(self, iterations=100, problems_per_iter=20, train_steps_per_iter=10, 
            save_interval=10, verbose=True):
        """
        Drive the self-play / training loop and checkpoint along the way.

        Each iteration calls ``self.self_play_episode(problems_per_iter)`` once,
        then ``self.train_step()`` ``train_steps_per_iter`` times (falsy results
        are discarded). Every ``save_interval`` iterations, and once more after
        the loop, ``self.save_model()`` and ``self.memory.save()`` are called.

        Args:
            iterations: Number of outer loop iterations.
            problems_per_iter: Argument forwarded to ``self_play_episode``.
            train_steps_per_iter: Number of ``train_step`` calls per iteration.
            save_interval: Checkpoint period, in iterations.
            verbose: When true, print a banner, a progress line every fifth
                iteration, checkpoint notices and a final summary.
        """
        if verbose:
            rule = "=" * 60
            print(rule)
            print("AlphaZero Self-Play Loop (MCTS Enhanced)")
            print(rule)
            print(f"Device: {self.device}")
            print(f"Iterations: {iterations}")
            print(f"Problems per iteration: {problems_per_iter}")

        started_at = time.time()

        for iteration in range(1, iterations + 1):
            self.stats['iterations'] = iteration

            # Self-play phase (its return value is not used by this loop).
            self.self_play_episode(problems_per_iter)

            # Training phase: run the requested number of steps, keep truthy results.
            step_results = (self.train_step() for _ in range(train_steps_per_iter))
            losses = [result for result in step_results if result]

            # Periodic progress line.
            if verbose and iteration % 5 == 0:
                if self.stats['avg_rmse']:
                    avg_rmse = np.mean(list(self.stats['avg_rmse']))
                else:
                    avg_rmse = 0
                if losses:
                    avg_loss = np.mean([entry['total'] for entry in losses])
                else:
                    avg_loss = 0
                elapsed = time.time() - started_at

                print(f"Iter {iteration:4d} | Buffer: {len(self.replay):5d} | "
                      f"Avg RMSE: {avg_rmse:.4f} | Best: {self.stats['best_rmse']:.4f} | "
                      f"Loss: {avg_loss:.4f} | Time: {elapsed:.1f}s")

            # Periodic checkpoint.
            if iteration % save_interval == 0:
                self.save_model()
                self.memory.save()
                if verbose:
                    print("  -> Checkpoint saved")

        # Final save, regardless of whether the last iteration just checkpointed.
        self.save_model()
        self.memory.save()

        total_time = time.time() - started_at
        if verbose:
            print(f"\nSelf-play complete! Total time: {total_time:.1f}s")
            print(f"Final buffer size: {len(self.replay)}")
            print(f"Best RMSE achieved: {self.stats['best_rmse']:.6f}")


if __name__ == "__main__":
    # Command-line entry point: read the loop sizes and start a self-play run.
    import argparse

    cli = argparse.ArgumentParser(description="AlphaZero Self-Play")
    for flag, default in (("--iterations", 50), ("--problems", 10), ("--train-steps", 5)):
        cli.add_argument(flag, type=int, default=default)
    opts = cli.parse_args()

    AlphaZeroLoop().run(
        iterations=opts.iterations,
        problems_per_iter=opts.problems,
        train_steps_per_iter=opts.train_steps,
    )


In [None]:
%%writefile AlphaSymbolic/training/__init__.py


In [None]:
%%writefile AlphaSymbolic/ui/app_core.py
"""
Core state and model management for AlphaSymbolic Gradio App.
"""
import torch
import os
from core.model import AlphaSymbolicModel
from core.grammar import VOCABULARY

from collections import deque
import time

# Global state
MODEL = None
DEVICE = None
TRAINING_STATUS = {"running": False, "epoch": 0, "loss": 0, "message": "Listo"}
STOP_TRAINING = False  # Flag to request training stop

def request_stop_training():
    """Set the module-level stop flag and return a status message."""
    global STOP_TRAINING
    message = "⏹️ Deteniendo entrenamiento..."
    STOP_TRAINING = True
    return message

def should_stop_training() -> bool:
    """Return the current value of the module-level ``STOP_TRAINING`` flag."""
    return STOP_TRAINING

def reset_stop_flag() -> None:
    """Clear the module-level ``STOP_TRAINING`` flag (call at the start of training)."""
    global STOP_TRAINING
    STOP_TRAINING = False

# Hall of Shame: Rolling buffer of recent failures
# Format: {'time': str, 'target': str, 'predicted': str, 'loss': float, 'stage': str}
TRAINING_ERRORS = deque(maxlen=20)

def add_training_error(target, predicted, loss, stage):
    """Record one failure in the ``TRAINING_ERRORS`` rolling buffer.

    The stored entry holds the current wall-clock time (``HH:MM:SS``), the
    given ``target``, ``predicted`` and ``stage`` values unchanged, and
    ``loss`` converted with ``float()``.
    """
    entry = {
        'time': time.strftime("%H:%M:%S"),
        'target': target,
        'predicted': predicted,
        'loss': float(loss),
        'stage': stage,
    }
    TRAINING_ERRORS.append(entry)

def get_training_errors():
    """Return a new list with the entries currently held in ``TRAINING_ERRORS``."""
    return [entry for entry in TRAINING_ERRORS]

MODEL_PRESETS = {
    'lite': {'d_model': 128, 'nhead': 4, 'num_encoder_layers': 3, 'num_decoder_layers': 3},
    'pro': {'d_model': 256, 'nhead': 8, 'num_encoder_layers': 6, 'num_decoder_layers': 6}
}
CURRENT_PRESET = 'lite'

def get_device(force_cpu=False):
    """Choose the torch device to use.

    With ``force_cpu`` set the CPU device is returned immediately. Otherwise
    CUDA is preferred when available, then MPS (when ``torch.backends`` has an
    ``mps`` backend that reports itself available), and finally CPU.
    """
    if not force_cpu:
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Older torch builds have no ``mps`` backend attribute at all.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return torch.device("mps")
    return torch.device("cpu")

def set_device(use_gpu=True):
    """Select the active device and return its description string.

    ``use_gpu=False`` forces the CPU; otherwise ``get_device`` picks the device.
    If a model is already loaded and the device differs from the current one,
    the model is moved with ``.to()``. The module-level ``DEVICE`` is updated
    and the result of ``get_device_info()`` is returned.
    """
    global DEVICE, MODEL
    target = get_device(force_cpu=not use_gpu)

    model_loaded = MODEL is not None
    if model_loaded and DEVICE != target:
        MODEL = MODEL.to(target)

    DEVICE = target
    return get_device_info()

def get_device_info():
    """Describe the active device as a short string.

    Initialises the module-level ``DEVICE`` via ``get_device()`` when it is
    still ``None``. Returns ``"CUDA (<name of device 0>)"``,
    ``"MPS (Apple Silicon)"`` or ``"CPU"`` depending on ``DEVICE.type``.
    """
    global DEVICE
    if DEVICE is None:
        DEVICE = get_device()

    kind = DEVICE.type
    if kind == "cuda":
        return f"CUDA ({torch.cuda.get_device_name(0)})"
    if kind == "mps":
        return "MPS (Apple Silicon)"
    return "CPU"

def load_model(force_reload=False, preset_name=None):
    """Create the model for the active preset and restore saved weights if found.

    Args:
        force_reload: Accepted but not read by this function.
        preset_name: When truthy, becomes the new ``CURRENT_PRESET`` (a key of
            ``MODEL_PRESETS``).

    A fresh ``AlphaSymbolicModel`` is always built and stored in the
    module-level ``MODEL``. A checkpoint is then looked up in this order:
    ``models/<file>``, ``<file>`` in the working directory (legacy location),
    and the Google Drive backup folder. If the checkpoint contains NaN/Inf
    values it is not loaded (and not deleted). A positional-encoding buffer
    whose shape differs from the model's is dropped before a non-strict load.

    Returns:
        ``(status, device_info)`` where ``status`` is a short message describing
        what happened and ``device_info`` is ``get_device_info()``.
    """
    global MODEL, DEVICE, CURRENT_PRESET

    if preset_name:
        CURRENT_PRESET = preset_name

    if DEVICE is None:
        DEVICE = get_device()

    VOCAB_SIZE = len(VOCABULARY)
    config = MODEL_PRESETS[CURRENT_PRESET]
    preset_kwargs = {
        key: config[key]
        for key in ('d_model', 'nhead', 'num_encoder_layers', 'num_decoder_layers')
    }

    print(f"Loading Model [{CURRENT_PRESET.upper()}]...")
    MODEL = AlphaSymbolicModel(
        vocab_size=VOCAB_SIZE + 1,  # one extra id beyond the vocabulary
        max_seq_len=256,
        input_dim=11,
        **preset_kwargs,
    ).to(DEVICE)

    filename = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
    status = f"Nuevo modelo ({CURRENT_PRESET})"  # default when nothing gets loaded

    # Candidate checkpoint locations in priority order; the flag marks the
    # Drive backup, which is announced when it ends up being used.
    candidates = (
        (os.path.join("models", filename), False),
        (filename, False),  # legacy location
        (os.path.join("/content/drive/MyDrive/AlphaSymbolic_Models", filename), True),
    )
    source_file = None
    for path, from_drive in candidates:
        if os.path.exists(path):
            if from_drive:
                print(f"📦 Local model missing. Loading from Drive: {path}")
            source_file = path
            break

    if source_file:
        try:
            state_dict = torch.load(source_file, map_location=DEVICE, weights_only=True)

            # Refuse checkpoints that contain NaN or Inf anywhere.
            corrupted = any(
                torch.isnan(tensor).any() or torch.isinf(tensor).any()
                for tensor in state_dict.values()
            )

            if corrupted:
                # The file is deliberately left on disk; only a warning is reported.
                status = "⚠️ Advertencia: NaNs detectados (pero no borrado)"
            else:
                # Handle resizing of the positional encoding (e.g. 50 -> 256):
                # a mismatching saved buffer is dropped and the load is non-strict.
                strict = True
                if 'pos_encoder.pe' in state_dict:
                    saved_pe_shape = state_dict['pos_encoder.pe'].shape
                    model_pe_shape = MODEL.pos_encoder.pe.shape
                    if saved_pe_shape != model_pe_shape:
                        print(f"⚠️ Resizing Positional Encoding from {saved_pe_shape[1]} to {model_pe_shape[1]}. Resetting buffer.")
                        del state_dict['pos_encoder.pe']
                        strict = False
                MODEL.load_state_dict(state_dict, strict=strict)

                MODEL.eval()
                status = f"Modelo cargado ({CURRENT_PRESET})"

        except RuntimeError as e:
            print(f"⚠️ Error de compatibilidad ({e}). Iniciando modelo fresco.")
            status = f"Nuevo modelo ({CURRENT_PRESET})"
        except Exception as e:
            print(f"Error cargando: {e}")
            status = "Sin modelo pre-entrenado"

    return status, get_device_info()

def get_model():
    """Return ``(MODEL, DEVICE)``, calling ``load_model()`` first if no model is loaded yet."""
    global MODEL, DEVICE
    if MODEL is not None:
        return MODEL, DEVICE
    load_model()
    return MODEL, DEVICE

def save_model():
    """Save the current model weights locally and back them up to Google Drive.

    Does nothing when no model is loaded. Otherwise the state dict is written
    to ``models/alpha_symbolic_model_<preset>.pth`` (the ``models`` directory is
    created if missing). When ``/content/drive`` exists, the checkpoint and any
    of the known formula/pattern data files that exist are copied to the Drive
    backup folder; failures during the backup are reported but not raised.
    """
    global MODEL, CURRENT_PRESET
    if MODEL is not None:
        filename = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
        local_dir = "models"
        # torch.save does not create parent directories; make sure it exists so
        # saving works regardless of the current working directory.
        os.makedirs(local_dir, exist_ok=True)
        local_path = os.path.join(local_dir, filename)
        torch.save(MODEL.state_dict(), local_path)
        
        # Backup to Google Drive if available
        if os.path.exists("/content/drive"):
            import shutil
            drive_path = "/content/drive/MyDrive/AlphaSymbolic_Models"
            try:
                os.makedirs(drive_path, exist_ok=True)
                
                # 1. Backup Model
                drive_filename = os.path.join(drive_path, filename)
                shutil.copy(local_path, drive_filename)
                
                # 2. Backup Formula Data: (source path, name inside the Drive folder)
                FILES_TO_BACKUP = [
                    ('top_formulas.csv', 'top_formulas.csv'),
                    ('pattern_memory.json', 'pattern_memory.json'),
                    ('results/learned_formulas.csv', 'learned_formulas.csv'),
                    ('top_5_detailed_report.csv', 'top_5_detailed_report.csv')
                ]
                
                for src, name in FILES_TO_BACKUP:
                    if os.path.exists(src):
                        shutil.copy(src, os.path.join(drive_path, name))
                        
                print(f"✅ Data & Model backed up to Drive: {drive_path}")
            except Exception as e:
                # Best effort: the local checkpoint is already written.
                print(f"⚠️ Error al respaldar en Drive: {e}")


In [None]:
%%writefile AlphaSymbolic/ui/app_search.py
"""
Search/Solve functions for AlphaSymbolic Gradio App.
Supports both Beam Search and MCTS.
"""
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import time
import gradio as gr

from core.grammar import ExpressionTree
from search.beam_search import BeamSearch
from search.mcts import MCTS
from search.hybrid_search import hybrid_solve
from utils.simplify import simplify_tree
from search.pareto import ParetoFront
from utils.detect_pattern import detect_pattern
from utils.optimize_constants import optimize_constants, substitute_constants
from ui.app_core import get_model


def parse_data(x_str, y_str):
    """
    Parse the X and Y text inputs into NumPy arrays.

    Supported X formats:
    1. 1D: values separated by commas and/or whitespace, e.g. "1, 2, 3".
    2. 2D (multi-variable): one sample per row, rows separated by newlines or
       semicolons, values within a row separated by commas and/or whitespace,
       e.g. "1,2; 3,4" or "1 2\n3 4".

    Y is always read as a flat list of values separated by commas, semicolons
    and/or whitespace (including newlines).

    Returns:
        ``(x, y, None)`` on success, where ``x`` has shape ``(N,)`` for 1D input
        or ``(N, num_features)`` for multi-row input and ``y`` has shape ``(N,)``;
        ``(None, None, message)`` when the input is empty, malformed, has rows
        of different lengths, or the number of X and Y samples differs.
    """
    def _numbers(text):
        # Treat commas and semicolons like whitespace, then split on whitespace.
        return [float(tok) for tok in text.replace(';', ' ').replace(',', ' ').split()]

    try:
        x_str = x_str.strip()
        y_str = y_str.strip()

        if not x_str or not y_str:
            return None, None, "Error parseando datos: entrada vacia"

        # A newline or semicolon in X means one sample per row (multi-variable).
        is_multivar = '\n' in x_str or ';' in x_str

        if is_multivar:
            rows = [_numbers(r) for r in x_str.replace(';', '\n').split('\n') if r.strip()]
            widths = {len(r) for r in rows}
            # Reject ragged or empty rows explicitly instead of relying on the
            # NumPy error raised for inhomogeneous shapes.
            if not rows or 0 in widths or len(widths) != 1:
                return None, None, "Error parseando datos: las filas de X deben tener la misma cantidad de valores"
            x = np.array(rows, dtype=np.float64)
        else:
            # 1D input stays shape (N,); callers reshape when they need (N, 1).
            x = np.array(_numbers(x_str), dtype=np.float64)

        y = np.array(_numbers(y_str), dtype=np.float64)

        # Inputs made only of separators (e.g. ",") contain no values.
        if x.size == 0 or y.size == 0:
            return None, None, "Error parseando datos: no se encontraron valores numericos"

        if len(x) != len(y):
            return None, None, f"Error: Cantidad de muestras X ({len(x)}) != Y ({len(y)})"

        return x, y, None
    except Exception as e:
        return None, None, f"Error parseando datos: {str(e)}"


def create_fit_plot(x, y, y_pred, formula):
    """Create a dark-themed Matplotlib figure comparing data and prediction.

    Args:
        x: Input samples; 1D, or 2D with one column per feature.
        y: Target values.
        y_pred: Predicted values, indexable like ``y``.
        formula: Not used by this function; kept for interface compatibility.

    For inputs with more than one feature a parity plot (target vs. prediction)
    with an ideal diagonal is drawn; otherwise the data points and the predicted
    curve are plotted against X.

    Returns:
        The Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(8, 5), facecolor='#1a1a2e')
    ax.set_facecolor('#1a1a2e')
    legend_style = dict(facecolor='#16213e', edgecolor='#00d4ff', labelcolor='white')
    
    # Check dimensions
    if x.ndim > 1 and x.shape[1] > 1:
        # Multi-Variable: Parity Plot (Real vs Predicted)
        ax.scatter(y, y_pred, color='#4ade80', s=100, edgecolors='white', alpha=0.7)
        
        # Perfect fit line, spanning the finite range of targets and predictions.
        # NaN/Inf predictions are ignored so they cannot poison the bounds.
        both = np.concatenate([np.ravel(y), np.ravel(y_pred)])
        finite = both[np.isfinite(both)]
        if finite.size:
            min_val = finite.min()
            max_val = finite.max()
            ax.plot([min_val, max_val], [min_val, max_val], '--', color='white', alpha=0.5, label='Ideal')
            # Show the legend so the 'Ideal' label is actually visible.
            ax.legend(**legend_style)
        
        ax.set_xlabel('Valor Real (Target)', color='white')
        ax.set_ylabel('Prediccion', color='white')
        ax.set_title(f'Multi-Variable: {x.shape[1]} Features', color='white', fontweight='bold')
        
    else:
        # 1D: Standard X vs Y
        # Flatten if needed
        if x.ndim > 1: x = x.flatten()
        
        ax.scatter(x, y, color='#00d4ff', s=100, label='Datos Reales', zorder=3, edgecolors='white', linewidth=1)
        
        # Draw the prediction as a line in increasing-X order.
        sort_idx = np.argsort(x)
        ax.plot(x[sort_idx], y_pred[sort_idx], color='#ff6b6b', linewidth=3, label='Prediccion', zorder=2)
        
        ax.set_xlabel('X', color='white', fontsize=12)
        ax.set_ylabel('Y', color='white', fontsize=12)
        ax.set_title('Ajuste de la Formula', color='white', fontsize=14, fontweight='bold')
        ax.legend(**legend_style)

    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.2, color='white')
    
    for spine in ax.spines.values():
        spine.set_color('#00d4ff')
    
    plt.tight_layout()
    return fig


def solve_formula(x_str, y_str, beam_width, search_method, max_workers=4, progress=gr.Progress()):
    """Search for a formula fitting the given data and render the results.

    Args:
        x_str: X data as text (see ``parse_data``).
        y_str: Y data as text (see ``parse_data``).
        beam_width: Beam width; also scales the MCTS simulation budget
            (``beam_width * 10``). Converted with ``int()``.
        search_method: ``"Alpha-GP Hybrid"``, ``"Beam Search"``, or anything
            else for MCTS.
        max_workers: Forwarded to ``hybrid_solve`` (hybrid method only).
        progress: Progress callback, called with a fraction and a description.

    Returns:
        A 5-tuple ``(result_html, figure, predictions_html, alternatives_html,
        display_formula)``. On failure the first element is an error message,
        the figure is ``None`` and the remaining elements are empty strings.
    """
    import re  # local import: only needed for the variable check below

    x, y, error = parse_data(x_str, y_str)
    if error:
        return error, None, "", "", ""
    
    MODEL, DEVICE = get_model()
    
    progress(0.1, desc=f"Analizando patron... [{DEVICE.type.upper()}]")
    pattern = detect_pattern(x, y)
    
    progress(0.3, desc=f"Buscando formulas ({search_method})... [{DEVICE.type.upper()}]")
    start_time = time.time()
    
    results = []
    
    num_vars = 1 if x.ndim == 1 else x.shape[1]
    
    if search_method == "Alpha-GP Hybrid":
        # Using hybrid search
        progress(0.4, desc="Fase 1: Neural Beam Search...")
        # Note: Hybrid search handles its own phases printing, but we want UI updates.
        # We pass beam_width. gp_timeout is increased to 30s to allow convergence on complex problems.
        hybrid_res = hybrid_solve(x, y, MODEL, DEVICE, beam_width=int(beam_width), gp_timeout=30, num_variables=num_vars, max_workers=max_workers)
        
        if hybrid_res and hybrid_res.get('formula'):
            progress(0.9, desc="Procesando resultados GP...")
            # Convert infix string back to tokens for consistency
            tree = ExpressionTree.from_infix(hybrid_res['formula'])
            if tree.is_valid:
                 # The GP output is fully instantiated (numeric constants are
                 # embedded in the string, no 'C' placeholders), so it is only
                 # evaluated here to confirm its RMSE, not re-optimized.
                 y_pred_check = tree.evaluate(x)
                 rmse_check = np.sqrt(np.mean((y_pred_check - y)**2))
                 
                 results = [{
                     'tokens': tree.tokens,
                     'formula': tree.get_infix(),
                     'rmse': rmse_check,
                     'constants': {} # Constants are baked into the formula string
                 }]
    
    elif search_method == "Beam Search":
        searcher = BeamSearch(MODEL, DEVICE, beam_width=int(beam_width), max_length=25, num_variables=num_vars)
        results = searcher.search(x, y)
    else:  # MCTS
        mcts = MCTS(MODEL, DEVICE, max_simulations=int(beam_width) * 10, num_variables=num_vars)
        result = mcts.search(x, y)
        if result and result.get('tokens'):
            tokens = result['tokens']
            tree = ExpressionTree(tokens)
            if tree.is_valid:
                constants, rmse = optimize_constants(tree, x, y)
                results = [{
                    'tokens': tokens,
                    'formula': tree.get_infix(),
                    'rmse': rmse,
                    'constants': constants
                }]
    
    search_time = time.time() - start_time
    
    if not results:
        return "No se encontraron formulas validas", None, "", "", ""
    
    progress(0.7, desc="Optimizando constantes...")
    pareto = ParetoFront()
    pareto.add_from_results(results)
    best = pareto.get_best_by_rmse()
    
    if not best:
        return "Error en optimizacion", None, "", "", ""
    
    progress(0.9, desc="Procesando...") 
    tree = ExpressionTree(best.tokens)
    
    # Use the stored formula string directly (this is what GP/search found)
    display_formula = best.formula
    
    # If we have constants to substitute (Beam Search / MCTS with C placeholders)
    if best.constants:
        try:
            positions = tree.root.get_constant_positions()
            raw_infix = tree.get_infix()
            display_formula = substitute_constants(raw_infix, best.constants, positions)
        except Exception:
            # Best effort: keep the stored formula if substitution fails.
            pass
    
    # Try to simplify algebraically (x0 + x0 -> 2*x0, etc.)
    try:
        simplified = simplify_tree(tree)
        # Only use simplified if it:
        # 1. Is valid (not just a number, not "Invalid")
        # 2. Still contains a variable (x or x0-x9)  
        if simplified and simplified != "Invalid":
            # Match 'x' / 'x<digits>' as a standalone identifier so that the
            # 'x' inside names such as 'exp' or 'max' does not count as a variable.
            has_variable = re.search(r'(?<![A-Za-z0-9_])x\d*(?![A-Za-z0-9_])', simplified) is not None
            is_not_just_number = not simplified.replace('.', '').replace('-', '').isdigit()
            if has_variable and is_not_just_number:
                display_formula = simplified
    except Exception:
        # Best effort: simplification is cosmetic, keep the current formula.
        pass
    
    y_pred = tree.evaluate(x, constants=best.constants)
    
    fig = create_fit_plot(x, y, y_pred, display_formula)
    
    # Format results
    result_html = f"""
    <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #00d4ff;">
        <h2 style="color: #00d4ff; margin: 0; font-size: 24px;">Formula Encontrada</h2>
        <div style="background: #0f0f23; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #ff6b6b;">
            <code style="color: #ff6b6b; font-size: 28px; font-weight: bold;">{display_formula}</code>
        </div>
        <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 10px;">
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">RMSE</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{best.rmse:.6f}</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Nodos</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{best.complexity}</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Tiempo</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{search_time:.2f}s</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Metodo</span><br>
                <span style="color: #4ade80; font-size: 16px; font-weight: bold;">{search_method}</span>
            </div>
        </div>
        <div style="margin-top: 15px; padding: 10px; background: #0f0f23; border-radius: 8px;">
            <span style="color: #888;">Patron:</span> 
            <span style="color: #ffd93d;">{pattern['type']}</span> 
            <span style="color: #666;">({pattern['confidence']:.0%})</span>
            <span style="color: #888; margin-left: 20px;">Device:</span>
            <span style="color: #4ade80;">{DEVICE.type.upper()}</span>
        </div>
    """
    
    # Add constants if any
    if best.constants:
        # Sort and format cleanly
        sorted_items = sorted(best.constants.items(), key=lambda x: str(x[0]))
        clean_consts = []
        for i, (k, v) in enumerate(sorted_items):
            clean_consts.append(f"C_{i+1}: {v:.4f}")
        const_str = "  |  ".join(clean_consts)
        
        result_html += f"""
        <div style="margin-top: 10px; padding: 10px; background: #0f0f23; border-radius: 8px; border-left: 3px solid #ffd93d;">
            <span style="color: #888;">Constantes:</span>
            <span style="color: #fff; font-family: monospace; margin-left: 10px;">{const_str}</span>
        </div>
        """
        
    # Close the outer container opened at the top of result_html.
    result_html += "</div>"
    
    # Predictions table (first 50 samples at most)
    pred_html = '<table style="width: 100%; border-collapse: collapse; background: #1a1a2e; border-radius: 10px; overflow: hidden;">'
    pred_html += '<tr style="background: #16213e;"><th style="padding: 10px; color: #00d4ff;">X</th><th style="color: #00d4ff;">Pred</th><th style="color: #00d4ff;">Real</th><th style="color: #00d4ff;">Delta</th></tr>'
    for i in range(min(50, len(y))):
        delta = abs(y_pred[i] - y[i])
        # Green / amber / red depending on the absolute error.
        color = "#4ade80" if delta < 0.1 else "#fbbf24" if delta < 1 else "#ef4444"
        
        # Display X nicely
        x_val_str = ""
        if x.ndim > 1 and x.shape[1] > 1:
             x_val_str = f"[{', '.join([f'{v:.1f}' for v in x[i]])}]"
        else:
             xv = x[i] if x.ndim == 1 else x[i,0]
             x_val_str = f"{xv:.2f}"
             
        pred_html += f'<tr style="border-bottom: 1px solid #333;"><td style="padding: 8px; color: white; text-align: center;">{x_val_str}</td><td style="color: white; text-align: center;">{y_pred[i]:.4f}</td><td style="color: white; text-align: center;">{y[i]:.4f}</td><td style="color: {color}; text-align: center; font-weight: bold;">{delta:.4f}</td></tr>'
    pred_html += '</table>'
    
    # Alternatives (first 4 solutions of the Pareto front; the first is highlighted)
    alt_html = '<div style="background: #1a1a2e; padding: 15px; border-radius: 10px;">'
    alt_html += '<h4 style="color: #00d4ff; margin-top: 0;">Alternativas</h4>'
    for i, sol in enumerate(pareto.solutions[:4]):
        alt_html += f'<div style="padding: 5px 10px; margin: 5px 0; background: #0f0f23; border-radius: 5px; border-left: 3px solid {"#00d4ff" if i == 0 else "#666"};"><code style="color: {"#ff6b6b" if i == 0 else "#888"};">{sol.formula}</code> <span style="color: #666; font-size: 12px;">RMSE: {sol.rmse:.4f}</span></div>'
    alt_html += '</div>'
    
    return result_html, fig, pred_html, alt_html, display_formula


def generate_example(tipo):
    """Build example X/Y data for the given example type.

    Known types are ``"lineal"`` (2x + 3), ``"cuadratico"`` (x^2 + 1),
    ``"trig"`` (sin x) and ``"exp"`` (2 * e^(0.5x)); any other value falls back
    to the linear example.

    Returns:
        A pair of comma-separated strings: X values with two decimals and
        Y values with four decimals.
    """
    # (name, np.linspace arguments, function mapping x to y)
    recipes = (
        ("lineal", (1, 10, 10), lambda v: 2 * v + 3),
        ("cuadratico", (-5, 5, 11), lambda v: v**2 + 1),
        ("trig", (0, 6.28, 20), lambda v: np.sin(v)),
        ("exp", (0, 5, 15), lambda v: 2 * np.exp(0.5 * v)),
    )

    # Unknown types use the first (linear) recipe.
    _, grid, curve = recipes[0]
    for name, span, fn in recipes:
        if tipo == name:
            grid, curve = span, fn
            break

    x = np.linspace(*grid)
    y = curve(x)

    x_text = ", ".join(f"{v:.2f}" for v in x)
    y_text = ", ".join(f"{v:.4f}" for v in y)
    return x_text, y_text


In [None]:
%%writefile AlphaSymbolic/ui/app_training.py
"""
Training functions for AlphaSymbolic Gradio App.
With proper data normalization.
"""
import os
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import gradio as gr
from collections import deque
import random
import time
import csv
import datetime

from core.grammar import VOCABULARY, TOKEN_TO_ID, OPERATORS, OPERATOR_STAGES, VARIABLES
from data.synthetic_data import DataGenerator
from ui.app_core import get_model, save_model, TRAINING_STATUS, add_training_error, should_stop_training, reset_stop_flag
from core.loss import QuantileLoss
from search.hybrid_search import hybrid_solve
from core.grammar import ExpressionTree, simplify_formula
from utils.data_utils import normalize_batch


def get_allowed_token_mask(stage, vocab_size, device):
    """
    Build a 0/1 mask over token ids for a curriculum stage.

    The mask has ``vocab_size + 1`` entries (the extra last slot is the SOS
    token, always allowed). Entries are 1.0 for the fixed terminal tokens, all
    ``VARIABLES`` and the operators listed for ``stage`` in ``OPERATOR_STAGES``
    (every operator when the stage is not listed); tokens missing from
    ``TOKEN_TO_ID`` are ignored. All other entries are 0.0.
    """
    stage_ops = OPERATOR_STAGES.get(stage, list(OPERATORS.keys()))

    # Terminals are allowed in every stage, together with the variables.
    terminals = {'C', '0', '1', '2', '3', '5', '10', 'pi', 'e'}
    permitted = terminals | set(VARIABLES) | set(stage_ops)

    mask = torch.zeros(vocab_size + 1, device=device)  # +1 for SOS token
    known_ids = (TOKEN_TO_ID[name] for name in permitted if name in TOKEN_TO_ID)
    for token_id in known_ids:
        mask[token_id] = 1.0
    mask[vocab_size] = 1.0  # SOS always allowed

    return mask


# Normalization moved to utils.data_utils



def train_basic(epochs, batch_size, point_count=10, num_variables=1, progress=gr.Progress()):
    """Basic supervised training on synthetic data.

    Each epoch trains on one batch made of half "inverse" samples and half
    random samples from ``DataGenerator``, using cross-entropy on the token
    sequence. Numeric arguments are converted with ``int()`` where an integer
    is required.

    Args:
        epochs: Number of epochs (one optimizer step per epoch).
        batch_size: Total samples per batch.
        point_count: Number of data points per sample passed to the generator.
        num_variables: Number of input variables passed to the generator.
        progress: Progress callback, called with a fraction and a description.

    Returns:
        ``(html_or_message, figure)``; the figure is ``None`` when training is
        already running, no loss could be computed, or an exception occurred.

    The running flag is cleared and the model is put back in eval mode on every
    exit path, including exceptions.
    """
    global TRAINING_STATUS
    
    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None
    
    TRAINING_STATUS["running"] = True
    MODEL = None  # lets the finally-block know whether a model was obtained
    
    try:
        MODEL, DEVICE = get_model()
        
        MODEL.train()
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=1e-4, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(epochs), eta_min=1e-6)
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
        
        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # the id right after the vocabulary is used as SOS
        
        data_gen = DataGenerator(max_depth=4, num_variables=int(num_variables))
        losses = []
        
        for epoch in range(int(epochs)):
            progress((epoch + 1) / epochs, desc=f"Epoca {epoch+1}/{int(epochs)} [{DEVICE.type.upper()}]")
            
            # Mix of inverse (known formulas) + random data (AlphaTensor-style)
            half_batch = int(batch_size) // 2
            batch_inverse = data_gen.generate_inverse_batch(half_batch, point_count=int(point_count))
            batch_random = data_gen.generate_batch(int(batch_size) - half_batch, point_count=int(point_count))
            batch = batch_inverse + batch_random
            if len(batch) < 2:
                continue
            
            x_list = [d['x'] for d in batch]
            y_list = [d['y'] for d in batch]
            
            # Normalize data
            x_list, y_list = normalize_batch(x_list, y_list)
            
            token_lists = [[TOKEN_TO_ID[t] for t in d['tokens']] for d in batch]
            
            # Decoder input is [SOS, t1..tn] (padded with SOS_ID); targets are
            # [t1..tn] padded with -1, which the loss ignores.
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)
            
            for i, seq in enumerate(token_lists):
                decoder_input[i, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)
            
            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)
            
            # Forward
            optimizer.zero_grad()
            logits, _ = MODEL(x_tensor, y_tensor, decoder_input)
            loss = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))
            
            # Skip if loss is NaN
            if torch.isnan(loss) or torch.isinf(loss):
                continue
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            losses.append(loss.item())
        
        save_model()
        
        if not losses:
            return "Error: No se pudo calcular loss (revisar datos)", None
        
        fig = create_loss_plot(losses, "Entrenamiento Basico")
        
        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #4ade80;">
            <h2 style="color: #4ade80; margin: 0;">Entrenamiento Completado</h2>
            <p style="color: white;">Epocas: {int(epochs)} | Loss Final: {losses[-1]:.4f}</p>
            <p style="color: #00d4ff;">Dispositivo: {DEVICE.type.upper()}</p>
        </div>
        """
        return result, fig
        
    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        # Always leave the shared model in eval mode and release the running
        # flag, so a failed run cannot leave dropout active for later inference.
        if MODEL is not None:
            MODEL.eval()
        TRAINING_STATUS["running"] = False


def _curriculum_depth_and_ratio(progress_pct):
    """Map a training-progress fraction to ``(max_depth, inverse_ratio)``.

    The schedule has three stages; inside each stage the depth is raised by
    one once the stage is half done:

    * Stage 1 (0-50%):   depth 2 -> 3, 80% inverse data
    * Stage 2 (50-80%):  depth 3 -> 4, 50% inverse data
    * Stage 3 (80-100%): depth 4 -> 5, 20% inverse data
    """
    if progress_pct < 0.5:
        stage_start, stage_span, base_depth, inverse_ratio = 0.0, 0.5, 2, 0.8
    elif progress_pct < 0.8:
        stage_start, stage_span, base_depth, inverse_ratio = 0.5, 0.3, 3, 0.5
    else:
        stage_start, stage_span, base_depth, inverse_ratio = 0.8, 0.2, 4, 0.2

    # Fraction of the *current stage* already completed (0.0 .. 1.0).
    within_stage = (progress_pct - stage_start) / stage_span
    depth = base_depth + (1 if within_stage >= 0.5 else 0)
    return depth, inverse_ratio


def train_curriculum(epochs, batch_size, point_count=10, num_variables=1, progress=gr.Progress()):
    """Curriculum Learning - starts simple, increases difficulty gradually.

    Each epoch builds one mixed batch (inverse + random problems) whose
    expression depth and inverse/random ratio follow
    ``_curriculum_depth_and_ratio``, then performs a single optimizer step on
    the combined policy (cross-entropy) and value (quantile) loss.

    Args:
        epochs: Number of epochs (one batch / optimizer step each); coerced
            with ``int``.
        batch_size: Total number of problems per batch.
        point_count: Number of sample points requested per problem.
        num_variables: Number of input variables passed to ``DataGenerator``.
        progress: Gradio progress callback.

    Returns:
        ``(html_or_message, figure_or_None)``. The figure is the loss curve on
        success and ``None`` when training is already running, no loss could
        be computed, or an exception occurred.
    """
    global TRAINING_STATUS

    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True

    try:
        total_epochs = int(epochs)
        MODEL, DEVICE = get_model()

        MODEL.train()
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)  # Lower LR
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # the start-of-sequence id sits right after the vocabulary
        quantile_loss_fn = QuantileLoss().to(DEVICE)
        losses = []
        max_depth_used = 0  # reported in the summary instead of a hard-coded number

        for epoch in range(total_epochs):
            current_depth, inverse_ratio = _curriculum_depth_and_ratio(epoch / total_epochs)
            max_depth_used = max(max_depth_used, current_depth)

            progress((epoch + 1) / total_epochs, desc=f"Epoca {epoch+1}/{total_epochs} (prof: {current_depth}, inv: {inverse_ratio:.0%}) [{DEVICE.type.upper()}]")

            data_gen = DataGenerator(max_depth=current_depth, num_variables=int(num_variables))

            # Mix inverse + random based on curriculum stage
            n_inverse = int(batch_size * inverse_ratio)
            n_random = int(batch_size) - n_inverse

            batch_inverse = data_gen.generate_inverse_batch(n_inverse, point_count=int(point_count)) if n_inverse > 0 else []
            batch_random = data_gen.generate_batch(n_random, point_count=int(point_count)) if n_random > 0 else []
            batch = batch_inverse + batch_random
            if len(batch) < 2:
                continue

            x_list = [d['x'] for d in batch]
            y_list = [d['y'] for d in batch]
            x_list, y_list = normalize_batch(x_list, y_list)

            token_lists = [[TOKEN_TO_ID[t] for t in d['tokens']] for d in batch]

            # Teacher forcing: decoder input is [SOS, t1, ..., tn] (padded with SOS),
            # targets are [t1, ..., tn] padded with -1 so padding is ignored by the loss.
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)

            for i, seq in enumerate(token_lists):
                decoder_input[i, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)

            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            logits, value_pred = MODEL(x_tensor, y_tensor, decoder_input)

            # Policy Loss
            loss_policy = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))

            # Value Loss
            # For supervised learning, these are "perfect" solutions, so Value Target = 1.0 (as a scalar per batch item)
            value_targets = torch.ones((len(batch), 1), device=DEVICE)
            loss_value = quantile_loss_fn(value_pred, value_targets)

            # Combined Loss
            loss = loss_policy + 0.5 * loss_value

            # Skip the update entirely when the loss is not finite.
            if torch.isnan(loss) or torch.isinf(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())

        save_model()
        MODEL.eval()

        if not losses:
            return "Error: No se pudo calcular loss", None

        fig = create_loss_plot(losses, "Curriculum Learning")

        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #00d4ff;">
            <h2 style="color: #00d4ff; margin: 0;">Curriculum Learning Completado</h2>
            <p style="color: white;">Epocas: {total_epochs} | Loss Final: {losses[-1]:.4f}</p>
            <p style="color: #888;">Profundidad maxima: {max_depth_used} | Dispositivo: {DEVICE.type.upper()}</p>
        </div>
        """
        return result, fig

    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        # Always release the "running" lock, even on KeyboardInterrupt/SystemExit,
        # so a later training request is not rejected forever.
        TRAINING_STATUS["running"] = False


def _selfplay_format_eta(eta_seconds):
    """Render a remaining-time estimate (seconds) as a short h/m/s string."""
    if eta_seconds > 3600:
        return f"{eta_seconds/3600:.1f}h"
    if eta_seconds > 60:
        return f"{eta_seconds/60:.0f}m"
    return f"{eta_seconds:.0f}s"


def _selfplay_sanitize_tokens(tokens):
    """Map formula tokens onto the model vocabulary.

    Tokens already in ``TOKEN_TO_ID`` are kept; tokens that parse as a number
    are replaced by the constant placeholder ``'C'``. Returns ``None`` when a
    token is neither, or when the result would be empty.
    """
    sanitized = []
    for tok in tokens:
        if tok in TOKEN_TO_ID:
            sanitized.append(tok)
            continue
        try:
            float(tok)
        except (TypeError, ValueError):
            return None
        sanitized.append('C')
    return sanitized or None


def train_self_play(iterations, problems_per_iter, point_count=10, num_variables=1, progress=gr.Progress()):
    """AlphaZero Self-Play loop.

    Each iteration:
      1. generates a pool of 3x ``problems_per_iter`` inverse problems and ranks
         them by the model's current per-sample cross-entropy (hard mining);
      2. records the model's greedy prediction for the 3 hardest problems via
         ``add_training_error``;
      3. solves a 70% hardest / 30% random mix with MCTS, falling back to
         ``hybrid_solve`` (GP expert) when the MCTS RMSE is >= 0.1, and pushes
         the resulting examples into a replay buffer;
      4. trains on random replay batches (KL policy loss + quantile value loss)
         once the buffer holds at least 64 examples.

    The curriculum stage advances when the mean of the last 20 RMSEs drops
    below 0.1.

    Args:
        iterations: Number of self-play iterations; coerced with ``int``.
        problems_per_iter: Problems solved per iteration; coerced with ``int``.
        point_count: Number of sample points requested per problem.
        num_variables: Number of input variables for generator and searchers.
        progress: Gradio progress callback.

    Returns:
        ``(html_or_message, figure_or_None)``; the figure is ``None`` when
        training is already running or an exception occurred.
    """
    global TRAINING_STATUS

    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True

    try:
        reset_stop_flag()  # Reset stop flag at start

        total_iterations = int(iterations)
        # Coerce once: UI widgets may deliver floats, and the value is used as a count.
        n_problems = int(problems_per_iter)
        n_points = int(point_count)
        n_vars = int(num_variables)

        MODEL, DEVICE = get_model()

        from search.mcts import MCTS

        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)
        # Scheduler: Reduce LR when plateauing to help convergence
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15, min_lr=1e-6)

        # Losses for AlphaZero
        # Policy: KLDiv (comparing distributions)
        # Value: Quantile Loss (3 Quantiles)
        kl_loss = torch.nn.KLDivLoss(reduction='batchmean')
        # Moved to DEVICE for consistency with train_curriculum.
        quantile_loss_fn = QuantileLoss().to(DEVICE)
        # Per-token loss (no reduction) used to score candidate difficulty.
        per_token_ce = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE

        replay_buffer = deque(maxlen=20000)

        # Adaptive Curriculum
        # Stages: 0=Arithmetic, 1=Poly, 2=Trig, 3=Adv, 4=Complex
        CURRICULUM_LEVELS = [
            {'depth': 1, 'ops': ['+', '-', '*', '/']},
            {'depth': 2, 'ops': ['+', '-', '*', '/']},
            {'depth': 3, 'ops': ['+', '-', '*', '/', 'pow', 'sqrt']},
            {'depth': 4, 'ops': ['+', '-', '*', '/', 'pow', 'sqrt', 'sin', 'cos']},
            {'depth': 5, 'ops': None} # All
        ]
        STAGE_NAMES = ["Arithmetic", "Polynomials", "Trigonometry", "Advanced", "Complex"]
        NN_SUCCESS_THRESHOLD = 0.1  # RMSE threshold for "good enough"

        # Initialize with Stage 0 (Arithmetic only)
        curriculum_stage = 0
        stage_info = CURRICULUM_LEVELS[curriculum_stage]
        data_gen = DataGenerator(max_depth=stage_info['depth'], allowed_operators=stage_info['ops'], num_variables=n_vars)

        # MCTS for A100: Increase batch size and simulations significantly
        # Adjusted for RTX 3050/i5: Batch 64 is smoother (less CPU wait)
        searcher = MCTS(MODEL, DEVICE, max_simulations=500, complexity_lambda=0.1, batch_size=64, curriculum_stage=curriculum_stage, num_variables=n_vars)

        rmses = []
        losses = []
        total_gp_corrections = 0  # Track GP expert corrections

        start_time = time.time()

        for iteration in range(total_iterations):
            # Check for stop request
            if should_stop_training():
                print("⏹️ Training stopped by user")
                break

            # ETA Calculation
            elapsed = time.time() - start_time
            if iteration > 0:
                avg_time_per_iter = elapsed / iteration
                remaining_iters = total_iterations - iteration
                eta_str = _selfplay_format_eta(remaining_iters * avg_time_per_iter)
            else:
                eta_str = "Calculando..."

            # Until 20 results exist the displayed RMSE defaults to 1.0.
            recent_rmse = np.mean(rmses[-20:]) if len(rmses) >= 20 else 1.0

            # Graduation condition: RMSE < 0.1 stable
            # NOTE(review): the RMSE window is not reset on graduation, so results
            # from the previous stage can still count towards the next one - confirm intended.
            if len(rmses) > 20 and recent_rmse < 0.1 and curriculum_stage < len(CURRICULUM_LEVELS) - 1:
                curriculum_stage += 1
                stage_info = CURRICULUM_LEVELS[curriculum_stage]
                data_gen = DataGenerator(max_depth=stage_info['depth'], allowed_operators=stage_info['ops'], num_variables=n_vars)
                # Recreate MCTS with new curriculum stage for operator filtering
                searcher = MCTS(MODEL, DEVICE, max_simulations=500, complexity_lambda=0.1, batch_size=64, curriculum_stage=curriculum_stage, num_variables=n_vars)
                print(f"*** Curriculum Level Up! Stage {curriculum_stage} ({stage_info['depth']}, {stage_info['ops']}) ***")
                # Clear buffer to avoid training on old easy data? Maybe keep some for replay.

            stage_name = STAGE_NAMES[curriculum_stage]

            curr_lr_disp = optimizer.param_groups[0]['lr']
            msg = f"Iter {iteration+1}/{total_iterations} [{stage_name}] RMSE:{recent_rmse:.3f} LR:{curr_lr_disp:.1e} | ETA: {eta_str}"
            progress((iteration + 1) / total_iterations, desc=msg)

            # Active Learning / Hard Mining Phase
            MODEL.eval()

            # Generate a large pool of candidates to find the "hard" ones
            pool_size = n_problems * 3  # Generate 3x more than we need
            candidates = data_gen.generate_inverse_batch(pool_size, point_count=n_points)

            if not candidates:
                continue

            # Quick forward pass to estimate difficulty (Loss)
            # We want to train on problems where the model currently FAILS (High Loss)
            hard_problems = []

            with torch.no_grad():
                # Process in chunks to avoid OOM
                chunk_size = 32
                for i in range(0, len(candidates), chunk_size):
                    chunk = candidates[i:i+chunk_size]

                    x_list = [d['x'] for d in chunk]
                    y_list = [d['y'] for d in chunk]
                    x_list, y_list = normalize_batch(x_list, y_list)

                    # Unknown tokens fall back to the constant placeholder 'C'.
                    token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in chunk]
                    max_len = max(len(s) for s in token_lists)

                    # Prepare tensors
                    dec_in = torch.full((len(chunk), max_len + 1), SOS_ID, dtype=torch.long).to(DEVICE)
                    targets = torch.full((len(chunk), max_len + 1), -1, dtype=torch.long).to(DEVICE)

                    for j, seq in enumerate(token_lists):
                        dec_in[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                        targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)

                    x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
                    y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)

                    logits, _ = MODEL(x_tensor, y_tensor, dec_in)

                    # Calculate loss per item
                    raw_losses = per_token_ce(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))

                    # Reshape back to [Batch, Seq] to sum/mean per sample
                    raw_losses = raw_losses.view(len(chunk), -1)
                    # Average loss per non-padded token
                    mask = (targets != -1)
                    sample_losses = (raw_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

                    for j, loss_val in enumerate(sample_losses):
                        # Store (Loss, Problem)
                        hard_problems.append((loss_val.item(), chunk[j]))

            # Sort by difficulty (Loss descending)
            hard_problems.sort(key=lambda x: x[0], reverse=True)

            # Stabilization: Mix Hardest (70%) + Random Examples (30%)
            # This prevents "Catastrophic Forgetting" of simpler patterns
            n_hard = int(n_problems * 0.7)
            n_random = n_problems - n_hard

            # Top K hardest
            selected_hard = [p[1] for p in hard_problems[:n_hard]]

            # Random selection from the rest of the pool (to keep variety)
            remaining_pool = [p[1] for p in hard_problems[n_hard:]]
            selected_random = random.sample(remaining_pool, min(n_random, len(remaining_pool))) if remaining_pool else []

            selected_problems = selected_hard + selected_random

            avg_pool_loss = np.mean([p[0] for p in hard_problems])
            top_loss = np.mean([p[0] for p in hard_problems[:n_hard]]) if n_hard > 0 else 0

            print(f"Active Learning: Pool Loss {avg_pool_loss:.3f} -> Selected Mix (Hard:{top_loss:.3f})")

            # --- HALL OF SHAME CAPTURE ---
            # Capture what the model predicts for the top 3 hardest failures.
            # Purely diagnostic: any failure here is logged and training continues.
            try:
                top_failures = hard_problems[:3]
                x_fail = [p[1]['x'].astype(np.float64) for p in top_failures]
                y_fail = [p[1]['y'].astype(np.float64) for p in top_failures]
                target_formulas = [p[1]['infix'] for p in top_failures]
                fail_losses = [p[0] for p in top_failures]

                # Simple Greedy Decode to see what it predicts
                from search.beam_search import BeamSearch
                # Use beam search with width 1 (Greedy) for speed, with curriculum mask
                bs = BeamSearch(MODEL, DEVICE, beam_width=1, max_length=20, curriculum_stage=curriculum_stage, num_variables=n_vars)

                for i in range(len(top_failures)):
                    try:
                        # Decode
                        # Enable return_partial to see what the model is thinking if it fails
                        res = bs.search(x_fail[i], y_fail[i], return_partial=True)
                        if not res:
                            pred_formula = "Search Empty (No Tokens)"
                        else:
                            pred_formula = res[0]['formula']

                        # Detect Looping (e.g. "10 / / / / / /") and truncate the
                        # prediction so the error log stays readable.
                        if len(pred_formula) > 20:
                            # Check for repeating slashes or other single chars
                            if pred_formula.count('/') > 10 and pred_formula.endswith('/ .'):
                                 pred_formula = pred_formula[:20] + " ... [Loop Detected]"
                            elif " / / / " in pred_formula:
                                 pred_formula = pred_formula.split(" / / / ")[0] + " ... [Loop Detected]"

                        add_training_error(
                            target=target_formulas[i],
                            predicted=pred_formula,
                            loss=fail_losses[i],
                            stage=stage_name
                        )
                    except Exception as e:
                        print(f"HoS Inner Error: {e}")
                        add_training_error(
                            target=target_formulas[i],
                            predicted=f"CRASH: {str(e)[:20]}",
                            loss=fail_losses[i],
                            stage=stage_name
                        )
            except Exception as e:
                import traceback
                print(f"HoS Outer Error: {e}")
                traceback.print_exc()

            # --- MCTS SOLVE + GP EXPERT CORRECTION ---
            gp_corrections = 0
            nn_successes = 0

            for prob in selected_problems:
                x_data = prob['x'].astype(np.float64)
                y_data = prob['y'].astype(np.float64)

                try:
                    # 1. Neural Network attempts to solve
                    result = searcher.search(x_data, y_data)
                    nn_rmse = result.get('rmse', float('inf'))

                    # 2. Check if NN succeeded or failed
                    if nn_rmse < NN_SUCCESS_THRESHOLD:
                        # NN succeeded - store its examples
                        nn_successes += 1
                        if 'root' in result:
                            examples = searcher.get_training_examples(result['root'])
                            for (tokens, policy, value) in examples:
                                replay_buffer.append({
                                    'x': x_data, 'y': y_data,
                                    'tokens': tokens,
                                    'policy': policy,
                                    'value': value,
                                    'source': 'NN'
                                })
                        rmses.append(nn_rmse)
                    else:
                        # 3. NN FAILED - GP Engine to the rescue!
                        # Use hybrid_solve with INCREASED timeout for better formulas
                        try:
                            gp_result = hybrid_solve(
                                x_data, y_data, MODEL, DEVICE,
                                beam_width=20,
                                gp_timeout=15,    # Increased from 5s to 15s
                                num_variables=n_vars
                            )

                            if gp_result and gp_result.get('formula'):
                                # Convert GP formula to tokens
                                tree = ExpressionTree.from_infix(gp_result['formula'])
                                if tree.is_valid and tree.tokens:
                                    # Calculate actual RMSE of GP solution
                                    y_pred = tree.evaluate(x_data)
                                    gp_rmse = np.sqrt(np.mean((y_pred - y_data)**2))

                                    # DYNAMIC ACCEPTANCE CRITERIA
                                    # Accept if RMSE <= 0.01 (Precision Mode)
                                    # Since we fixed the GP, we expect exact matches.
                                    is_decent = gp_rmse <= 0.01

                                    if is_decent and len(tree.tokens) <= 50:
                                        # Sanitize tokens: replace numeric constants NOT in vocab with 'C'
                                        sanitized_tokens = _selfplay_sanitize_tokens(tree.tokens)

                                        if sanitized_tokens:
                                            # Create "expert" policy - uniform over tokens
                                            # NOTE(review): a uniform target carries no token preference;
                                            # only the value head gets a real signal from GP examples.
                                            policy = np.ones(len(VOCABULARY)) / len(VOCABULARY)

                                            # SCALED REWARD
                                            # Give higher value for better solutions
                                            # Formula: max(0.1, 1.0 - (rmse * 1.6))
                                            reward_value = max(0.1, 1.0 - (gp_rmse * 1.6))

                                            replay_buffer.append({
                                                'x': x_data, 'y': y_data,
                                                'tokens': sanitized_tokens,
                                                'policy': policy,
                                                'value': reward_value,
                                                'source': 'GP_EXPERT'
                                            })
                                            gp_corrections += 1
                                            rmses.append(gp_rmse)

                                            print(f"📚 GP Expert ACCEPTED: {gp_result['formula'][:50]}... (RMSE: {gp_rmse:.4f}, Val: {reward_value:.2f})")
                                        else:
                                             print(f"🔸 GP Rejected (Sanitization): {gp_result['formula'][:30]}")
                                    else:
                                         print(f"🔸 GP Rejected (Quality): RMSE {gp_rmse:.4f} vs NN {nn_rmse:.4f}")
                        except Exception as gp_err:
                            # GP failed too - skip the expert example, but make the
                            # failure visible instead of silently swallowing it.
                            print(f"GP Expert error: {gp_err}")

                        # Also store NN failure for learning (lower value)
                        if 'root' in result and result.get('tokens'):
                            examples = searcher.get_training_examples(result['root'])
                            for (tokens, policy, value) in examples:
                                replay_buffer.append({
                                    'x': x_data, 'y': y_data,
                                    'tokens': tokens,
                                    'policy': policy,
                                    'value': max(0.0, 0.3 - nn_rmse * 0.1),  # Low value for bad solutions
                                    'source': 'NN_FAIL'
                                })
                            rmses.append(nn_rmse)

                except Exception as e:
                    print(f"Self-play error: {e}")
                    continue

            # Log progress
            if gp_corrections > 0:
                total_gp_corrections += gp_corrections
                print(f"🎯 Iteration {iteration+1}: NN Success: {nn_successes}, GP Corrections: {gp_corrections}")

            # Training phase
            # To saturate GPU: Increase batch size and number of updates
            if len(replay_buffer) >= 64:
                MODEL.train()

                # The buffer is not modified while training, so snapshot it once
                # instead of copying the deque on every step.
                buffer_snapshot = list(replay_buffer)

                train_batch_size = 128
                if len(buffer_snapshot) < train_batch_size:
                    train_batch_size = 64

                # Steps: between 10 and 50, scaling with the buffer size
                steps = max(10, min(50, len(buffer_snapshot) // train_batch_size))

                for _ in range(steps):
                    batch = random.sample(buffer_snapshot, min(train_batch_size, len(buffer_snapshot)))

                    x_list = [exp['x'] for exp in batch]
                    y_list = [exp['y'] for exp in batch]
                    x_list, y_list = normalize_batch(x_list, y_list)

                    token_lists = [[TOKEN_TO_ID[t] for t in exp['tokens']] for exp in batch]
                    policy_targets = [exp['policy'] for exp in batch]
                    value_targets_list = [exp['value'] for exp in batch]

                    max_len = max(len(s) for s in token_lists)
                    decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)

                    # Policy targets (for KLDiv) and Value targets
                    policy_target_tensor = torch.tensor(np.array(policy_targets), dtype=torch.float32).to(DEVICE)
                    value_target_tensor = torch.tensor(np.array(value_targets_list), dtype=torch.float32).unsqueeze(1).to(DEVICE)

                    for i, seq in enumerate(token_lists):
                        decoder_input[i, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)

                    x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
                    y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
                    decoder_input = decoder_input.to(DEVICE)

                    optimizer.zero_grad()
                    logits, value_pred = MODEL(x_tensor, y_tensor, decoder_input)

                    # Policy Loss (KL Divergence)
                    # State = [T1..Tn] is fed as [SOS, T1..Tn]; the distribution for the
                    # NEXT token is therefore read at position len(seq). The extra SOS
                    # logit (index VOCAB_SIZE) is dropped to match the policy target size.
                    last_logits = torch.stack([
                        logits[i, len(seq), :VOCAB_SIZE] for i, seq in enumerate(token_lists)
                    ])
                    log_probs = torch.nn.functional.log_softmax(last_logits, dim=1)

                    loss_policy = kl_loss(log_probs, policy_target_tensor)

                    # Value Loss (Quantile)
                    loss_value = quantile_loss_fn(value_pred, value_target_tensor)

                    # Total Loss
                    loss = loss_policy + loss_value

                    # Non-finite losses are skipped without an optimizer step.
                    if not (torch.isnan(loss) or torch.isinf(loss)):
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                        optimizer.step()
                        losses.append(loss.item())

            # Step Scheduler based on recent Loss
            if losses:
                current_loss = np.mean(losses[-10:])
                scheduler.step(current_loss)

            # Periodic save
            if (iteration + 1) % 10 == 0:
                save_model()

        save_model()
        MODEL.eval()

        fig = create_selfplay_plot(losses, rmses)

        avg_rmse = np.mean(rmses[-50:]) if rmses else 0
        gp_pct = (total_gp_corrections / max(1, len(rmses))) * 100
        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #ff6b6b;">
            <h2 style="color: #ff6b6b; margin: 0;">Self-Play + GP Expert Completado</h2>
            <p style="color: white;">Iteraciones: {total_iterations} | Problemas: {len(rmses)}</p>
            <p style="color: #888;">RMSE Promedio: {avg_rmse:.4f} | Dispositivo: {DEVICE.type.upper()}</p>
            <p style="color: #4ade80;">📚 Correcciones GP Expert: {total_gp_corrections} ({gp_pct:.1f}% de problemas)</p>
        </div>
        """
        return result, fig

    except Exception as e:
        import traceback
        print("Self-play error traceback:")
        traceback.print_exc()
        return f"Error: {str(e)}", None
    finally:
        # Always release the "running" lock, even on KeyboardInterrupt/SystemExit,
        # so a later training request is not rejected forever.
        TRAINING_STATUS["running"] = False


def create_loss_plot(losses, title):
    """Build a dark-themed single-axis figure of the loss history.

    Closes every open matplotlib figure first. When ``losses`` is empty or
    falsy, a centered "waiting for data" placeholder is drawn on a unit axis
    instead of a curve. Returns the new figure.
    """
    plt.close('all')

    background = '#1a1a2e'
    accent = '#00d4ff'

    fig, ax = plt.subplots(figsize=(8, 4), facecolor=background)
    ax.set_facecolor(background)

    if not losses or len(losses) == 0:
        # Nothing to draw yet: show a placeholder message on a fixed 0..1 axis.
        ax.text(
            0.5, 0.5, 'Esperando datos...',
            transform=ax.transAxes, fontsize=16, color='#888',
            ha='center', va='center',
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
    else:
        ax.plot(losses, color=accent, linewidth=2)
        ax.set_xlabel('Paso', color='white')
        ax.set_ylabel('Loss', color='white')

    # Shared cosmetics for both the curve and the placeholder.
    ax.set_title(title, color='white', fontweight='bold')
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.2)
    for border in ax.spines.values():
        border.set_color(accent)

    plt.tight_layout()
    return fig


def create_selfplay_plot(losses, rmses):
    """Build the two-panel self-play figure: loss curve (left), RMSE (right).

    Closes every open matplotlib figure first. The RMSE panel shows the raw
    series and, once more than 10 values exist, a 10-point moving average
    aligned to the end of each window. Returns the new figure.
    """
    plt.close('all')

    background = '#1a1a2e'
    loss_color = '#00d4ff'
    rmse_color = '#ff6b6b'
    window = 10

    fig, (loss_ax, rmse_ax) = plt.subplots(1, 2, figsize=(12, 4), facecolor=background)

    # Left panel: training loss per optimizer step.
    if losses:
        loss_ax.plot(losses, color=loss_color, linewidth=2)

    # Right panel: raw RMSE plus its moving average.
    if rmses:
        rmse_ax.plot(rmses, color=rmse_color, linewidth=1, alpha=0.5)
        if len(rmses) > window:
            smoothed = np.convolve(rmses, np.ones(window) / window, mode='valid')
            # 'valid' drops the first window-1 points, so start the x axis there.
            rmse_ax.plot(range(window - 1, len(rmses)), smoothed, color=rmse_color, linewidth=2)

    panels = (
        (loss_ax, 'Step', 'Loss', 'Policy Loss'),
        (rmse_ax, 'Problema', 'RMSE', 'RMSE'),
    )
    for panel, x_label, y_label, heading in panels:
        panel.set_facecolor(background)
        panel.set_xlabel(x_label, color='white')
        panel.set_ylabel(y_label, color='white')
        panel.set_title(heading, color='white', fontweight='bold')
        panel.tick_params(colors='white')
        panel.grid(True, alpha=0.2)
        for border in panel.spines.values():
            border.set_color(loss_color)

    plt.tight_layout()
    return fig

def train_supervised(iterations, batch_size=128, point_count=10, progress=gr.Progress()):
    """
    Massive Supervised Pre-training (Warmup).
    Focus: Syntax, Basic Arithmetic, Overcoming "Collapse to Constant".
    Speed: High (No MCTS, just random generation + CrossEntropy).

    The run is split into five equal curriculum stages (by fraction of
    iterations completed); each stage selects a generator depth, an operator
    set, a token mask and a maximum number of variables.

    Args:
        iterations: Number of training iterations (cast to int for the loop).
        batch_size: Number of random formulas requested per iteration.
        point_count: Number of sample points requested per formula.
        progress: Gradio progress callback, invoked once per iteration.

    Returns:
        (html, figure) when the loop finishes or is stopped, or
        (message, None) if a training run is already in progress or an
        exception was raised.
    """
    global TRAINING_STATUS
    
    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None
    
    TRAINING_STATUS["running"] = True
    
    # Display names for the five curriculum stages (indexed by stage index).
    STAGE_NAMES = ['Arithmetic', 'Polynomials', 'Trigonometry', 'Advanced', 'Complex']
    
    try:
        reset_stop_flag()  # Reset stop flag at start
        
        MODEL, DEVICE = get_model()
        
        MODEL.train()
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=1e-4, weight_decay=0.01)
        # Slower decay: T_max = iterations * 2 keeps LR higher for longer
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(iterations*2), eta_min=1e-6)
        # Padded target positions are filled with -1 and ignored by the loss.
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
        
        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # start-of-sequence id sits right after the vocabulary
        
        # Progressive Curriculum Stages for Pre-training
        PRE_CURRICULUM = [
            {'depth': 2, 'ops': ['+', '-', '*', '/'], 'stage': 0},           # 0-20%: Basic arithmetic (depth 2 for variety)
            {'depth': 2, 'ops': ['+', '-', '*', '/'], 'stage': 0},           # 20-40%: Deeper arithmetic
            {'depth': 2, 'ops': ['+', '-', '*', '/', 'pow', 'sqrt'], 'stage': 1},  # 40-60%: Powers
            {'depth': 3, 'ops': ['+', '-', '*', '/', 'pow', 'sqrt', 'sin', 'cos'], 'stage': 2},  # 60-80%: Trig
            {'depth': 3, 'ops': None, 'stage': None},  # 80-100%: All ops
        ]
        
        losses = []
        current_stage_idx = 0
        # Curriculum for variables: start simple, add complexity
        VARS_BY_STAGE = [1, 1, 2, 3, 5]  # Max vars per stage
        stage_info = PRE_CURRICULUM[0]
        data_gen = DataGenerator(max_depth=stage_info['depth'], allowed_operators=stage_info['ops'], num_variables=1)
        # A stage of None ("all ops") is mapped to mask id 4.
        allowed_mask = get_allowed_token_mask(stage_info['stage'] if stage_info['stage'] is not None else 4, VOCAB_SIZE, DEVICE)
        
        start_time = time.time()
        
        for i in range(int(iterations)):
            # Check for stop request
            if should_stop_training():
                print("⏹️ Pre-training stopped by user")
                break
            
            # Progressive curriculum: change stage based on progress
            progress_pct = i / int(iterations)
            new_stage_idx = min(int(progress_pct * 5), 4)  # 0-4 based on 20% increments
            
            if new_stage_idx != current_stage_idx:
                current_stage_idx = new_stage_idx
                stage_info = PRE_CURRICULUM[current_stage_idx]
                # Progressive variable curriculum: stages 0-1 use 1 var, 2 uses 1-2, 3 uses 1-3, 4 uses 1-5
                max_vars_this_stage = VARS_BY_STAGE[current_stage_idx]
                iter_num_vars = random.randint(1, max_vars_this_stage)
                data_gen = DataGenerator(max_depth=stage_info['depth'], allowed_operators=stage_info['ops'], num_variables=iter_num_vars)
                stage_id = stage_info['stage'] if stage_info['stage'] is not None else 4
                allowed_mask = get_allowed_token_mask(stage_id, VOCAB_SIZE, DEVICE)
                stage_name = STAGE_NAMES[new_stage_idx]
                print(f"📚 Pre-training: {stage_name} (depth={stage_info['depth']}, max_vars={max_vars_this_stage})")
            
            # ETA: average time per completed iteration times iterations left
            elapsed = time.time() - start_time
            if i > 0:
                iter_per_sec = i / elapsed
                remaining = int(iterations) - i
                eta = remaining / iter_per_sec
                eta_str = f"{eta:.0f}s"
            else:
                eta_str = "..."
                
            current_lr = optimizer.param_groups[0]['lr']
            stage_name = STAGE_NAMES[current_stage_idx]
            msg = f"[{stage_name}] Iter {i+1}/{int(iterations)} Loss:{np.mean(losses[-50:]) if losses else 0:.3f} LR:{current_lr:.1e} ETA:{eta_str}"
            progress((i + 1) / iterations, desc=msg)
            
            # Variable curriculum: use appropriate range for current stage
            # IMPORTANT: Set num_variables BEFORE generating batch to ensure uniform dimensions
            max_vars_this_stage = VARS_BY_STAGE[current_stage_idx]
            batch_num_vars = random.randint(1, max_vars_this_stage)
            data_gen.num_variables = batch_num_vars
            data_gen.active_variables = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9'][:batch_num_vars]
            data_gen.terminals = data_gen.active_variables + ['C', '0', '1', '2', '3', '5', '10', 'pi', 'e']
            
            # Generate Random Batch (High Speed)
            batch = data_gen.generate_batch(int(batch_size), point_count=int(point_count))
            
            if not batch:
                continue
            
            x_list = [d['x'] for d in batch]
            y_list = [d['y'] for d in batch]
            x_list, y_list = normalize_batch(x_list, y_list)
            
            # Tokens missing from the vocabulary are mapped to the 'C' token id.
            token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in batch]
            
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)
            
            # Teacher forcing: decoder input is the sequence shifted right by one
            # (position 0 keeps SOS); targets are the unshifted sequence.
            for j, seq in enumerate(token_lists):
                decoder_input[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)
                
            # Prepare tensors: x is (batch, points, vars), y is (batch, points, 1)
            x_tensor = torch.tensor(np.stack(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.stack(y_list), dtype=torch.float32).to(DEVICE)
            if y_tensor.dim() == 2:
                y_tensor = y_tensor.unsqueeze(-1)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)
            
            optimizer.zero_grad()
            logits, _ = MODEL(x_tensor, y_tensor, decoder_input)
            
            # Apply curriculum mask to prevent learning tokens not yet introduced
            # NOTE(review): assumes allowed_mask is 1 for allowed tokens and broadcasts
            # against the last logits dimension - confirm in get_allowed_token_mask.
            logits = logits + (1 - allowed_mask.view(1, 1, -1)) * -1e4
            
            loss = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))
            
            # Skip the optimizer/scheduler step entirely on a non-finite loss.
            if not (torch.isnan(loss) or torch.isinf(loss)):
                loss.backward()
                torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                losses.append(loss.item())
                
            if (i+1) % 100 == 0:
                save_model()
                
        save_model()
        MODEL.eval()
        
        fig = create_loss_plot(losses, "Pre-Entrenamiento Supervisado")
        
        # No loss is recorded when the run is stopped immediately, has zero
        # iterations, or every batch was empty/non-finite; avoid losses[-1] then.
        final_loss_str = f"{losses[-1]:.4f}" if losses else "N/A"
        
        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #ffd93d;">
            <h2 style="color: #ffd93d; margin: 0;">Escuela Primaria (Warmup) Completada</h2>
            <p style="color: white;">Iteraciones: {int(iterations)} | Loss Final: {final_loss_str}</p>
            <p style="color: #888;">El modelo ha aprendido sintaxis basica.</p>
        </div>
        """
        return result, fig
        
    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        # Always release the "running" flag, whatever path left the try block.
        TRAINING_STATUS["running"] = False


def train_hybrid_feedback_loop(iterations, problems_per_iter=10, gp_timeout=10, max_workers=4, progress=gr.Progress()):
    """
    Teacher-Student Distillation Loop.
    1. Find problems where model has high loss.
    2. Use Hybrid Search (GP) to solve them.
    3. Train model on GP solutions.

    This function is a generator: it yields (status_html, figure) before each
    teacher call and a final (result_html, figure) at the end. Because a
    generator's `return value` is only attached to StopIteration, the
    "already running" and error messages are yielded as well so that
    consumers iterating over the generator receive them.

    Args:
        iterations: Number of outer iterations (cast to int).
        problems_per_iter: Maximum number of hard problems sent to the
            teacher per iteration (cast to int).
        gp_timeout: Forwarded to hybrid_solve as gp_timeout.
        max_workers: Forwarded to hybrid_solve as max_workers.
        progress: Gradio progress callback, invoked once per iteration.

    Yields:
        (html, figure_or_None) tuples.
    """
    global TRAINING_STATUS
    
    if TRAINING_STATUS["running"]:
        busy_msg = "Entrenamiento ya en progreso"
        yield busy_msg, None
        return busy_msg, None
    
    TRAINING_STATUS["running"] = True
    
    HARD_LOSS_THRESHOLD = 0.5  # per-sample CE loss above which a problem counts as "hard"
    GP_RMSE_SUCCESS = 0.01     # teacher result is accepted at or below this RMSE
    
    try:
        reset_stop_flag()
        
        total_iters = int(iterations)
        probs_per_iter = int(problems_per_iter)
        
        MODEL, DEVICE = get_model()
        
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
        
        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE
        
        # Replay buffer for "Gold Standard" examples found by GP
        replay_buffer = deque(maxlen=5000)
        quantile_loss_fn = QuantileLoss().to(DEVICE)
        
        losses = []
        gp_successes = 0
        gp_attempts = 0
        seeds_html = ""  # seeds of the most recent teacher call, shown in the live panel
        
        start_time = time.time()
        
        for iteration in range(total_iters):
            if should_stop_training():
                print("⏹️ Feedback Loop stopped")
                break
                
            # Coarse ETA from the average duration of the iterations completed so far.
            elapsed = time.time() - start_time
            iter_dur = elapsed / iteration if iteration > 0 else 0
            eta_seconds = (total_iters - iteration) * iter_dur
            eta_str = f"{eta_seconds:.0f}s"

            progress((iteration + 1) / iterations, 
                     desc=f"Iter {iteration+1}/{int(iterations)} | GP Success: {gp_successes}/{gp_attempts} | Loss: {np.mean(losses[-10:]) if losses else 0:.3f}")
            
            # --- PHASE 1: HARD MINING ---
            MODEL.eval()
            
            # Randomize number of variables for this iteration (1-10)
            iter_num_vars = random.randint(1, 10)
            data_gen = DataGenerator(max_depth=3, num_variables=iter_num_vars)
            
            # Generate candidates
            pool_size = 50 
            candidates = data_gen.generate_inverse_batch(pool_size, point_count=10)
            
            hard_problems = []
            
            # Skip if no valid candidates
            if not candidates:
                continue
                
            with torch.no_grad():
                # We want to find problems with HIGH LOSS (model failure)
                # Quick batch forward
                x_list = [d['x'] for d in candidates]
                y_list = [d['y'] for d in candidates]
                x_list, y_list = normalize_batch(x_list, y_list)
                
                token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in candidates]
                
                # Filter out empty token lists
                valid_indices = [idx for idx, tl in enumerate(token_lists) if len(tl) > 0]
                if not valid_indices:
                    continue
                    
                token_lists = [token_lists[idx] for idx in valid_indices]
                candidates = [candidates[idx] for idx in valid_indices]
                x_list = [x_list[idx] for idx in valid_indices]
                y_list = [y_list[idx] for idx in valid_indices]
                actual_pool_size = len(valid_indices)

                # Sync candidates with normalized/filtered values
                for k_sync in range(actual_pool_size):
                    candidates[k_sync]['x'] = x_list[k_sync]
                    candidates[k_sync]['y'] = y_list[k_sync]
                
                max_len = max(len(s) for s in token_lists)
                
                dec_in = torch.full((actual_pool_size, max_len + 1), SOS_ID, dtype=torch.long).to(DEVICE)
                targets = torch.full((actual_pool_size, max_len + 1), -1, dtype=torch.long).to(DEVICE)
                
                for j, seq in enumerate(token_lists):
                    dec_in[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                    targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)
                
                # Ensure uniform array shapes
                try:
                    x_tensor = torch.tensor(np.stack(x_list), dtype=torch.float32).to(DEVICE)
                    y_tensor = torch.tensor(np.stack(y_list), dtype=torch.float32).to(DEVICE)
                    if y_tensor.dim() == 2:
                        y_tensor = y_tensor.unsqueeze(-1)
                except Exception as e:
                    print(f"Skipping batch due to shape error: {e}")
                    continue
                
                try:
                    logits, _ = MODEL(x_tensor, y_tensor, dec_in)
                except Exception as e:
                    print(f"Skipping batch due to model error: {e}")
                    continue
                
                # Per-token loss, averaged over the non-padded positions of each sample.
                loss_f = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
                raw_losses = loss_f(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))
                raw_losses = raw_losses.view(actual_pool_size, -1)
                
                mask = (targets != -1)
                sample_losses = (raw_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)
                
                # Keep the samples the model is "confused" about, hardest first,
                # so that the slice below really takes the top-K hardest.
                scored = [(loss_val.item(), candidates[j]) for j, loss_val in enumerate(sample_losses)]
                scored = [pair for pair in scored if pair[0] > HARD_LOSS_THRESHOLD]
                scored.sort(key=lambda pair: pair[0], reverse=True)
                hard_problems = [cand for _, cand in scored]
            
            # Take top K hardest
            # Limit GP calls per iter to avoid slowness
            problems_to_solve = hard_problems[:probs_per_iter]
            
            if not problems_to_solve:
                if (iteration + 1) % 5 == 0 or iteration == 0:
                    print(f"Iter {iteration}: Looking for hard problems (found 0 in pool of {actual_pool_size})...")
                continue

            # --- PHASE 2: TEACHER SOLVES (GP) ---
            print(f"Iter {iteration}: Asking Teacher to solve {len(problems_to_solve)} hard problems...")
            
            for prob_idx, prob in enumerate(problems_to_solve, start=1):
                gp_attempts += 1
                # Reset per-problem state BEFORE the try block so the fallback below
                # never sees an unbound or stale result from a previous problem.
                res = None
                stored_gp_solution = False
                try:
                    # Calculate current stats
                    current_prob = prob_idx
                    loss_display = f"{losses[-1]:.4f}" if losses else "---"
                    
                    # Fraction of the planned teacher calls already done, clamped to 0-100%.
                    planned_problems = total_iters * probs_per_iter
                    done_before = iteration * probs_per_iter + (prob_idx - 1)
                    pct_done = max(0.0, min(100.0, done_before / planned_problems * 100)) if planned_problems else 0.0
                    
                    # Construct Live HTML with glassmorphism design
                    status_html = f"""
                    <div style="background: linear-gradient(135deg, rgba(26,26,46,0.95) 0%, rgba(22,33,62,0.95) 100%); 
                                padding: 20px; border-radius: 16px; 
                                border: 1px solid rgba(74,222,128,0.3);
                                box-shadow: 0 8px 32px rgba(0,0,0,0.3);
                                backdrop-filter: blur(10px);">
                        
                        <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 16px;">
                            <span style="font-size: 28px;">🚀</span>
                            <div>
                                <h3 style="color: #4ade80; margin: 0; font-size: 18px; font-weight: 600;">Training Hybrid Loop</h3>
                                <span style="color: #888; font-size: 12px;">Teacher-Student Distillation</span>
                            </div>
                            <div style="margin-left: auto; background: rgba(74,222,128,0.2); padding: 4px 12px; border-radius: 20px;">
                                <span style="color: #4ade80; font-size: 14px; font-weight: 500;">LIVE</span>
                            </div>
                        </div>
                        
                        <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 12px; margin-bottom: 16px;">
                            <div style="background: rgba(255,255,255,0.05); padding: 12px; border-radius: 10px; text-align: center;">
                                <div style="color: #888; font-size: 11px; text-transform: uppercase;">Iteración</div>
                                <div style="color: #fff; font-size: 20px; font-weight: 600;">{iteration+1}<span style="color:#666; font-size:14px;">/{total_iters}</span></div>
                            </div>
                            <div style="background: rgba(255,255,255,0.05); padding: 12px; border-radius: 10px; text-align: center;">
                                <div style="color: #888; font-size: 11px; text-transform: uppercase;">Problema</div>
                                <div style="color: #fff; font-size: 20px; font-weight: 600;">{current_prob}<span style="color:#666; font-size:14px;">/{probs_per_iter}</span></div>
                            </div>
                            <div style="background: rgba(255,255,255,0.05); padding: 12px; border-radius: 10px; text-align: center;">
                                <div style="color: #888; font-size: 11px; text-transform: uppercase;">GP Éxitos</div>
                                <div style="color: #4ade80; font-size: 20px; font-weight: 600;">{gp_successes}</div>
                            </div>
                            <div style="background: rgba(255,255,255,0.05); padding: 12px; border-radius: 10px; text-align: center;">
                                <div style="color: #888; font-size: 11px; text-transform: uppercase;">Loss</div>
                                <div style="color: #00d4ff; font-size: 20px; font-weight: 600;">{loss_display}</div>
                            </div>
                        </div>
                        
                        <div style="display: flex; align-items: center; gap: 8px; padding: 10px; background: rgba(255,217,61,0.1); border-radius: 8px;">
                            <span style="font-size: 16px;">⏳</span>
                            <span style="color: #ffd93d; font-size: 14px;">ETA: <strong>{eta_str}</strong></span>
                            <div style="flex: 1; height: 4px; background: rgba(255,255,255,0.1); border-radius: 2px; margin-left: 12px;">
                                <div style="width: {pct_done:.0f}%; height: 100%; background: linear-gradient(90deg, #ffd93d, #4ade80); border-radius: 2px;"></div>
                            </div>
                        </div>
                        
                        {seeds_html}
                    </div>
                    """
                    # Always create a graph (placeholder if empty)
                    fig = create_loss_plot(losses, "Training Loss")
                    yield status_html, fig
                    
                    # Run Hybrid Search (Quick Mode)
                    # We pass the model so beam search can seed the GP
                    res = hybrid_solve(
                        prob['x'], 
                        prob['y'], 
                        MODEL, 
                        DEVICE, 
                        beam_width=10,     # Faster beam
                        gp_timeout=gp_timeout,
                        gp_binary_path=None,
                        max_workers=max_workers,      # Parallel Workers (Mission 1)
                        num_variables=iter_num_vars
                    )
                    
                    # --- UI UPDATE: LIVE STATS ---
                    # ETA from the whole-run elapsed time and the position-based count of
                    # teacher calls (assumes probs_per_iter calls per iteration).
                    elapsed_total = time.time() - start_time
                    done_problems = iteration * probs_per_iter + prob_idx
                    avg_time = elapsed_total / done_problems
                    remaining = max(0, total_iters * probs_per_iter - done_problems)
                    eta_seconds = remaining * avg_time
                    eta_str = f"{int(eta_seconds // 60)}m {int(eta_seconds % 60)}s"

                    seeds_html = ""
                    if res and 'seeds_tried' in res and res['seeds_tried']:
                        seeds_html = "<h4 style='color:#ccc; margin-bottom:5px;'>🔎 Top Seeds per Worker:</h4>"
                        seeds_html += "<div style='display:flex; flex-wrap:wrap; gap:5px; font-size:12px; color:#888;'>"
                        for worker_idx, s in enumerate(res['seeds_tried']):
                            seeds_html += f"<span style='background:#222; padding:3px 6px; border-radius:4px;'>Worker {worker_idx+1}: {s[:30]}...</span>"
                        seeds_html += "</div>"
                    
                    # A falsy result (e.g. None) is treated as a failed attempt.
                    gp_rmse = res.get('rmse', 1e6) if res else 1e6
                    if res and res.get('formula') and gp_rmse <= GP_RMSE_SUCCESS:
                        # SUCCESS!
                        gp_successes += 1
                        
                        # Parse formula to tokens
                        try:
                            # 1. Parse string to tree
                            tree = ExpressionTree.from_infix(res['formula'])
                            # 2. Get tokens
                            tokens = tree.tokens
                            
                            # SCALED REWARD + Efficiency
                            # 1. Quality Reward (0.2 to 1.0)
                            quality_reward = max(0.2, 1.0 - (gp_rmse * 1.6))
                            
                            # 2. Efficiency Bonus (0.5 to 1.0)
                            taken_time = res.get('time', 10.0)
                            efficiency_bonus = 1.0
                            if taken_time > 5.0:
                                decay = ((taken_time - 5.0) / 25.0) * 0.5
                                efficiency_bonus = max(0.5, 1.0 - decay)
                            
                            # Final Reward = Quality * Efficiency
                            final_reward = quality_reward * efficiency_bonus

                            replay_buffer.append({
                                'x': prob['x'],
                                'y': prob['y'],
                                'tokens': tokens,
                                'source': 'GP_Teacher',
                                'reward': final_reward
                            })
                            stored_gp_solution = True

                            # --- MISSION 2: PERSISTENCE ---
                            try:
                                # Make sure the target directory exists before appending.
                                os.makedirs("results", exist_ok=True)
                                log_file = os.path.join("results", "learned_formulas.csv")
                                file_exists = os.path.isfile(log_file)
                                
                                with open(log_file, "a", newline="", encoding="utf-8") as csvfile:
                                    fieldnames = ["timestamp", "formula", "rmse", "complexity", "source", "time_taken"]
                                    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                                    
                                    if not file_exists:
                                        writer.writeheader()
                                        
                                    writer.writerow({
                                        "timestamp": datetime.datetime.now().isoformat(),
                                        "formula": res['formula'],
                                        "rmse": res.get('rmse', 0.0),
                                        "complexity": len(tokens),
                                        "source": "GP_Teacher",
                                        "time_taken": res.get('time', 0.0)
                                    })
                            except Exception as e:
                                print(f"Failed to log formula to CSV: {e}")
                            # -------------------------------
                            
                        except Exception as e:
                            print(f"Failed to tokenize GP result: {e}")
                            
                except Exception as e:
                    print(f"GP Hybrid Error: {e}")
                
                # --- FALLBACK: If GP failed, use Original Ground Truth ---
                # This ensures the model always learns something and the graph updates.
                # It also covers a GP solution that was found but could not be tokenized.
                if not stored_gp_solution:
                    # Clean tokens
                    original_tokens = [t for t in prob['tokens'] if t in TOKEN_TO_ID]
                    if len(original_tokens) > 0:
                        replay_buffer.append({
                            'x': prob['x'],
                            'y': prob['y'],
                            'tokens': original_tokens,
                            'source': 'Original',
                            'reward': 1.0  # It is the ground truth
                        })
                    
            # --- PHASE 3: STUDENT TRAINS (NN) ---
            if len(replay_buffer) > 10:
                MODEL.train()
                # Train on batch from buffer
                batch_size_train = min(len(replay_buffer), 64)
                
                # Multiple steps to enforce learning
                steps = 5
                
                for _ in range(steps):
                    batch = random.sample(list(replay_buffer), batch_size_train)
                    
                    x_list = [d['x'] for d in batch]
                    y_list = [d['y'] for d in batch]
                    x_list, y_list = normalize_batch(x_list, y_list)
                    
                    token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in batch]
                    max_len = max(len(s) for s in token_lists)
                    
                    dec_in = torch.full((batch_size_train, max_len + 1), SOS_ID, dtype=torch.long).to(DEVICE)
                    targets = torch.full((batch_size_train, max_len + 1), -1, dtype=torch.long).to(DEVICE)
                    
                    for j, seq in enumerate(token_lists):
                        dec_in[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                        targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)
                        
                    # DYNAMIC PADDING FOR X (Mixed Dimensions)
                    # Find max variables in this batch
                    max_vars = max(x.shape[1] for x in x_list)
                    points = x_list[0].shape[0]
                    
                    # Create padded array (Batch, Points, MaxVars)
                    x_padded = np.zeros((batch_size_train, points, max_vars), dtype=np.float32)
                    
                    for j, x_item in enumerate(x_list):
                        current_vars = x_item.shape[1]
                        x_padded[j, :, :current_vars] = x_item
                        
                    x_t = torch.tensor(x_padded, dtype=torch.float32).to(DEVICE)
                    y_t = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
                    # Same y layout as the hard-mining forward pass above.
                    if y_t.dim() == 2:
                        y_t = y_t.unsqueeze(-1)
                    
                    optimizer.zero_grad()
                    logits, value_pred = MODEL(x_t, y_t, dec_in)
                    
                    # Policy Loss only (Standard Supervised)
                    # We trust the GP solution is "Correct" (Value=1.0)
                    loss_ce = torch.nn.CrossEntropyLoss(ignore_index=-1)(logits.view(-1, VOCAB_SIZE+1), targets.view(-1))
                    
                    # Value Loss (Time-Aware Reward)
                    # We extract the specific reward for each sample in the batch
                    # Default to 1.0 (legacy data) if 'reward' is missing
                    batch_rewards = [d.get('reward', 1.0) for d in batch]
                    value_targets = torch.tensor(batch_rewards, dtype=torch.float32).to(DEVICE).unsqueeze(1)
                    loss_val = quantile_loss_fn(value_pred, value_targets)
                    
                    loss = loss_ce + 0.1 * loss_val
                    
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                    optimizer.step()
                    
                    losses.append(loss.item())
                    
                scheduler.step(np.mean(losses[-10:]))
                
            if (iteration + 1) % 5 == 0:
                save_model()
                
        save_model()
        MODEL.eval()
        # Released before the final yield so a new run can start while the
        # consumer still holds this generator; the finally below is the safety net.
        TRAINING_STATUS["running"] = False
        
        fig = create_loss_plot(losses, "Feedback Loop Loss")
        
        result_html = f"""
        <div style="background: linear-gradient(135deg, #2c3e50 0%, #000000 100%); padding: 20px; border-radius: 15px; border: 2px solid #f1c40f;">
            <h2 style="color: #f1c40f; margin: 0;">Feedback Loop Completado</h2>
            <p style="color: white;">Iteraciones: {total_iters} | GP Success: {gp_successes}/{gp_attempts}</p>
            <p style="color: #bbb;">Nuevos Ejemplos Generados: {len(replay_buffer)}</p>
        </div>
        """
        
        # Final yield so iterating consumers receive the summary
        yield result_html, fig
        return result_html, fig

    except Exception as e:
        import traceback
        traceback.print_exc()
        error_msg = f"Error CRITICO: {str(e)}"
        # Yield the error too: a generator's return value is not seen by iteration.
        yield error_msg, None
        return error_msg, None
    finally:
        # Also runs when the generator is closed early (GeneratorExit is not an
        # Exception), so the "running" flag can never stay stuck at True.
        TRAINING_STATUS["running"] = False


def train_from_memory(epochs=10, batch_size=32, num_variables=1, progress=gr.Progress()):
    """
    Train from 'learned_formulas.csv' (Offline RL / Imitation Learning).
    Re-trains the model on the "Gold Standard" discoveries.

    Each formula in the CSV is re-parsed; every training batch evaluates the
    parsed trees on freshly sampled random inputs, so the model sees new
    (x, y) points for the same token sequences on every step.

    Args:
        epochs: Number of passes over the parsed formulas. Cast to int
            (UI sliders may deliver floats); must be >= 1.
        batch_size: Formulas per optimisation step. Cast to int, minimum 1.
        num_variables: Accepted for signature compatibility; not used here.
        progress: Gradio progress tracker used for status updates.

    Returns:
        tuple: (status message or HTML summary, loss figure). The figure is
        None whenever training did not run or failed.
    """
    global TRAINING_STATUS

    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    log_file = os.path.join("results", "learned_formulas.csv")
    if not os.path.exists(log_file):
        return "No se encontró el archivo 'learned_formulas.csv'. Ejecuta primero el Feedback Loop.", None

    # Normalise numeric inputs up front: range() needs ints and a zero step
    # (batch_size == 0) would raise deep inside the training loop.
    epochs = int(epochs)
    batch_size = max(1, int(batch_size))
    if epochs < 1:
        # Without at least one epoch there is no loss to report.
        return "El número de épocas debe ser al menos 1.", None

    TRAINING_STATUS["running"] = True
    MODEL = None

    try:
        MODEL, DEVICE = get_model()

        # Load Data
        import pandas as pd
        df = pd.read_csv(log_file)

        if len(df) < 5:
            return f"Muy pocos datos para entrenar ({len(df)} ejemplos). Necesitas al menos 5.", None

        progress(0.1, desc=f"Cargando {len(df)} fórmulas maestras...")

        # Parse formulas to tokens
        valid_data = []
        for _, row in df.iterrows():
            try:
                formula = row['formula']
                # Re-parse to get clean tokens
                tree = ExpressionTree.from_infix(formula)
                if tree.is_valid:
                    # Probe the formula once on random inputs (10 columns are
                    # always supplied; the tree reads whichever it needs) and
                    # drop formulas that are non-finite or constant there.
                    x_val = np.random.uniform(-5, 5, (10, 10))
                    y_val = tree.evaluate(x_val)

                    if np.any(np.isnan(y_val)) or np.any(np.isinf(y_val)) or np.std(y_val) < 1e-6:
                        continue

                    valid_data.append({
                        'tokens': tree.tokens,
                        'tree': tree  # kept so each batch can sample fresh points
                    })
            except Exception:
                # A malformed row must not abort the whole load; skip it.
                continue

        if not valid_data:
            return "No se pudieron parsear fórmulas válidas del CSV.", None

        # Training Setup
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=1e-4)
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # start-of-sequence id sits just past the vocabulary

        losses = []
        MODEL.train()

        for epoch in range(epochs):
            # Shuffle
            random.shuffle(valid_data)

            epoch_loss = 0
            count = 0

            for i in range(0, len(valid_data), batch_size):
                batch = valid_data[i:i + batch_size]

                # Generate fresh X/Y for this batch (Data Augmentation on the fly)
                x_list = []
                y_list = []
                token_lists = []

                for item in batch:
                    # Generate random points
                    x = np.random.uniform(-3, 3, (20, 10))  # 20 points
                    y = item['tree'].evaluate(x)

                    # Sanity check
                    if np.any(np.isnan(y)) or np.max(np.abs(y)) > 1e4:
                        continue

                    x_list.append(x)
                    y_list.append(y)
                    token_lists.append([TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in item['tokens']])

                if not x_list:
                    continue

                x_list, y_list = normalize_batch(x_list, y_list)

                # Tensors: decoder input is the target shifted right by one SOS
                # token; -1 marks padded target slots ignored by the loss.
                max_len = max(len(s) for s in token_lists)
                dec_in = torch.full((len(x_list), max_len + 1), SOS_ID, dtype=torch.long).to(DEVICE)
                tgt = torch.full((len(x_list), max_len + 1), -1, dtype=torch.long).to(DEVICE)

                for j, seq in enumerate(token_lists):
                    dec_in[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                    tgt[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)

                x_t = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
                y_t = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)

                optimizer.zero_grad()
                logits, _ = MODEL(x_t, y_t, dec_in)
                loss = ce_loss(logits.view(-1, VOCAB_SIZE + 1), tgt.view(-1))

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                count += 1

            avg_loss = epoch_loss / max(1, count)
            losses.append(avg_loss)

            progress((epoch + 1) / epochs, desc=f"Epoca {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")

        save_model()

        fig = create_loss_plot(losses, "Offline Memory Training")

        return f"""
        <div style="background: #1a1a2e; padding: 20px; border-radius: 10px; border: 2px solid #a855f7;">
            <h2 style="color: #a855f7;">Entrenamiento de Memoria Completado</h2>
            <p style="color:white;">Fórmulas aprendidas: {len(valid_data)}</p>
            <p style="color:white;">Loss Final: {losses[-1]:.4f}</p>
        </div>
        """, fig

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {str(e)}", None

    finally:
        # Runs on every exit path: release the training lock and make sure the
        # model is back in inference mode even if training raised midway.
        TRAINING_STATUS["running"] = False
        if MODEL is not None:
            MODEL.eval()


In [None]:
%%writefile AlphaSymbolic/ui/app_benchmark.py
import gradio as gr
from utils.benchmark_comparison import run_comparison_benchmark
from ui.app_core import get_model, DEVICE

def get_benchmark_tab():
    """
    Build the "Benchmark" tab: method selector, GP timeout slider, run button,
    an HTML summary area and a results table.

    The click handler `run_bench` always returns exactly two values, matching
    the two wired outputs (summary HTML, table rows).
    """
    with gr.Tab("🥇 Benchmark (IQ Test)"):
        gr.Markdown("### Evaluar Inteligencia del Modelo (Comparativa)")
        gr.Markdown("Ejecuta una batería de **10 problemas estándar** comparando diferentes métodos de búsqueda.")
        
        with gr.Row():
            methods_chk = gr.CheckboxGroup(
                choices=["beam", "mcts", "hybrid"], 
                value=["hybrid"], 
                label="Métodos a Evaluar",
                info="Selecciona uno o más métodos para comparar."
            )
            timeout_slider = gr.Slider(
                minimum=5, 
                maximum=60, 
                value=30, 
                step=5, 
                label="Timeout GP (s)", 
                info="Tiempo máximo para Beta-GP por problema."
            )
        
        run_btn = gr.Button("🚀 Iniciar Benchmark Comparativo", variant="primary")
        
        # Placeholder component kept for layout; it is not wired to any event.
        progress_bar = gr.HTML("")
        
        # Results area
        summary_html = gr.HTML("Resultados aparecerán aquí...")
        
        results_df = gr.Dataframe(
            headers=["Problema", "Nivel", "Método", "Formula", "RMSE", "Tiempo", "Estado"],
            label="Resultados Detallados",
            interactive=False
        )
        
        def run_bench(selected_methods, gp_timeout, progress=gr.Progress()):
            """
            Run the comparison benchmark for the selected methods.

            Returns:
                tuple: (summary HTML, list of table rows). On any error the
                HTML carries the message and the row list is empty.
            """
            import html  # local import: only needed to escape error text

            model_obj, device_obj = get_model()
            if not model_obj:
                return "<div>⚠️ Error: Modelo no cargado. Ve a la pestaña 'Config' y carga un modelo.</div>", []
            
            if not selected_methods:
                return "<div>⚠️ Error: Selecciona al menos un método.</div>", []
                
            progress(0, desc="Iniciando Benchmark...")
            
            # Run comparison
            try:
                result_data = run_comparison_benchmark(
                    model_obj, 
                    device_obj, 
                    methods=selected_methods,
                    gp_timeout=gp_timeout,
                    beam_width=50,
                    progress_callback=lambda p, desc: progress(p, desc=desc)
                )
            except Exception as e:
                import traceback
                traceback.print_exc()
                # Escape the message so characters such as '<' cannot break the markup.
                return f"<div>❌ Error en Benchmark: {html.escape(str(e))}</div>", []
            
            results = result_data['results']
            summary_dict = result_data['summary']
            
            # Format dataframe
            rows = []
            for r in results:
                status_icon = "✅" if r['success'] else "❌"
                rmse_val = f"{r['rmse']:.5f}" if r['rmse'] < 1e6 else "> 10^6"
                rows.append([
                    r['problem_name'],
                    r['level'],
                    r['method'].upper(),
                    r['formula'],
                    rmse_val,
                    f"{r['time']:.2f}s",
                    status_icon
                ])
            
            # Generate HTML Summary
            html_content = "<div style='display: flex; gap: 20px; flex-wrap: wrap; justify-content: center;'>"
            
            # A winner is only declared when several methods compete:
            # most problems solved first, lowest average RMSE as tie-break.
            winner_method = None
            if len(selected_methods) > 1:
                winner_method = max(summary_dict.items(), key=lambda x: (x[1]['solved'], -x[1]['avg_rmse']))[0]
            
            for method, stats in summary_dict.items():
                is_winner = (method == winner_method)
                border_color = "#4CAF50" if is_winner else ("#FF9800" if stats['score'] > 50 else "#F44336")
                bg_color = "#1e1e2f"
                if is_winner:
                    bg_color = "#1b3a24" # Dark green tint for winner
                    
                trophy = "🏆 GANADOR" if is_winner else ""
                
                html_content += f"""
                <div style="background: {bg_color}; padding: 15px; border-radius: 10px; border: 2px solid {border_color}; min-width: 200px; text-align: center;">
                    <h2 style="color: {border_color}; margin: 0 0 10px 0;">{method.upper()} {trophy}</h2>
                    <div style="font-size: 24px; font-weight: bold; margin-bottom: 5px;">{stats['solved']} / {stats['total']}</div>
                    <div style="color: #ccc; font-size: 14px;">Resueltos</div>
                    <hr style="border-color: #444; margin: 10px 0;">
                    <div style="font-size: 14px;">Nota: <b>{stats['score']:.1f}%</b></div>
                    <div style="font-size: 14px;">Tiempo Avg: <b>{stats['avg_time']:.2f}s</b></div>
                </div>
                """
            html_content += "</div>"
            
            return html_content, rows
            
        run_btn.click(run_bench, inputs=[methods_chk, timeout_slider], outputs=[summary_html, results_df])


In [None]:
%%writefile AlphaSymbolic/ui/__init__.py


In [None]:
%%writefile AlphaSymbolic/utils/optimize_constants.py
"""
Constant Optimization Module for AlphaSymbolic.
Uses scipy.optimize to find optimal values for 'C' placeholders.
"""
import numpy as np
from scipy.optimize import minimize
from core.grammar import ExpressionTree

def optimize_constants(tree, x_data, y_data, method='L-BFGS-B', initial_guess=None):
    """
    Given an ExpressionTree with 'C' placeholders, find optimal constant values.
    
    Args:
        tree: ExpressionTree object
        x_data: numpy array of x values
        y_data: numpy array of target y values
        method: optimization method ('L-BFGS-B', 'SLSQP', 'Nelder-Mead')
        initial_guess: optional sequence of starting values, one per constant.
            Ignored (all-ones start is used) when its length does not match
            the number of constants in the tree.
        
    Returns:
        dict: mapping of path tuples to optimized constant values
        float: final RMSE (inf when the tree is invalid, the prediction is
            non-finite, or the optimizer raises)
    """
    if not tree.is_valid:
        return {}, float('inf')
    
    # Get positions of all constants
    positions = tree.root.get_constant_positions()
    n_constants = len(positions)
    
    if n_constants == 0:
        # No constants to optimize, just evaluate
        y_pred = tree.evaluate(x_data)
        # A NaN/inf prediction would otherwise yield a NaN RMSE, which
        # compares False against every threshold and breaks sorting.
        if not np.all(np.isfinite(y_pred)):
            return {}, float('inf')
        mse = np.mean((y_pred - y_data)**2)
        return {}, np.sqrt(mse)
    
    def objective(params):
        """Objective function: MSE given constant values (large penalty if invalid)."""
        # Build constants dict
        constants = {tuple(pos): params[i] for i, pos in enumerate(positions)}
        
        # Evaluate
        y_pred = tree.evaluate(x_data, constants=constants)
        
        # Non-finite predictions get a fixed penalty so the optimizer moves away
        if not np.all(np.isfinite(y_pred)):
            return 1e10
        
        # Clip huge values to prevent overflow in MSE
        y_pred = np.clip(y_pred, -1e9, 1e9)
        
        return np.mean((y_pred - y_data)**2)
    
    # Use the caller's starting point only when it has one value per constant
    if initial_guess is not None and len(initial_guess) == n_constants:
        x0 = np.array(initial_guess)
    else:
        # Initial guess: all 1s
        x0 = np.ones(n_constants)
    
    # Bounds: reasonable range for constants
    bounds = [(-1e9, 1e9)] * n_constants
    
    try:
        result = minimize(
            objective,
            x0,
            method=method,
            bounds=bounds if method in ['L-BFGS-B', 'SLSQP'] else None,
            options={'maxiter': 1000, 'disp': False}
        )
        
        # Build final constants dict
        optimized_constants = {tuple(pos): result.x[i] for i, pos in enumerate(positions)}
        # The objective is an MSE, so take the root; guard against tiny negatives/zero
        final_rmse = np.sqrt(result.fun) if result.fun > 0 else 0.0
        
        return optimized_constants, final_rmse
        
    except Exception:
        # Best-effort: any optimizer/evaluation failure is reported as "no fit"
        return {}, float('inf')

def convert_and_extract_constants(node, values=None):
    """
    Recursively converts numeric nodes to 'C' and extracts their values.

    The tree is modified in place: every node whose `value` converts with
    float() has that value replaced by the placeholder 'C'. Nodes are visited
    parent first, then children in order, and the extracted numbers are
    appended in that same order.

    Args:
        node: tree node exposing `value` and an iterable `children`.
        values: accumulator used by the recursion; callers normally omit it.

    Returns: list of initial values.
    """
    if values is None:
        values = []
        
    # Check if node is a number (operators, variables and named constants
    # such as pi/e fail the float() conversion and are left untouched).
    try:
        val = float(node.value)
    except (TypeError, ValueError, OverflowError):
        # Not numeric: only conversion failures are expected here, so other
        # exceptions (e.g. KeyboardInterrupt) are allowed to propagate.
        pass
    else:
        # It is a number. Convert to C.
        node.value = 'C'
        values.append(val)
        
    for child in node.children:
        convert_and_extract_constants(child, values)
        
    return values


def substitute_constants(infix_str, constants_dict, positions):
    """
    Replace 'C' placeholders in the infix string with optimized values.

    Placeholders and `positions` are paired by order: the first standalone
    'C' token takes the value for positions[0], the second for positions[1],
    and so on. A placeholder whose position has no entry in `constants_dict`
    is left as 'C' and later values stay aligned with their own placeholders.
    Only standalone 'C' tokens are touched, never a 'C' inside a longer
    identifier.

    Args:
        infix_str: formula text containing 'C' placeholders.
        constants_dict: mapping of position tuples to numeric values.
        positions: constant positions, assumed to be in the same order in
            which the placeholders appear in `infix_str`.

    Returns:
        str: the formula with values substituted. Near-integers are printed
        as integers, everything else with four decimals.
    """
    import math
    import re

    def _format_value(val):
        # isfinite guard: round() raises on inf/NaN
        if math.isfinite(val) and abs(val - round(val)) < 1e-6:
            return str(int(round(val)))
        return f"{val:.4f}"

    pending = iter(positions)

    def _fill(match):
        pos = next(pending, None)
        if pos is None:
            # More placeholders than positions: leave the rest untouched
            return match.group(0)
        key = tuple(pos)
        if key not in constants_dict:
            return match.group(0)
        return _format_value(constants_dict[key])

    # Single left-to-right pass so every placeholder consumes exactly one position
    return re.sub(r'\bC\b', _fill, infix_str)


# Quick test
if __name__ == "__main__":
    # Smoke test: fit the template C*x + C to samples of y = 2x + 3.
    sample_x = np.array([1, 2, 3, 4, 5], dtype=np.float64)
    sample_y = 2 * sample_x + 3

    # Prefix-order token list for C*x + C
    demo_tree = ExpressionTree(['+', '*', 'C', 'x', 'C'])

    print(f"Formula structure: {demo_tree.get_infix()}")
    print(f"Target: y = 2x + 3")

    fitted, fit_error = optimize_constants(demo_tree, sample_x, sample_y)
    print(f"Optimized constants: {fitted}")
    print(f"Final RMSE: {fit_error:.6f}")

    # Re-evaluate with the fitted constants and show both series for comparison
    refit = demo_tree.evaluate(sample_x, constants=fitted)
    print(f"Predictions: {refit}")
    print(f"Targets: {sample_y}")


In [None]:
%%writefile AlphaSymbolic/utils/detect_pattern.py
"""
Target Pattern Detection for AlphaSymbolic.
Analyzes target Y values to detect patterns (polynomial, exponential, periodic, etc.)
and suggests initial search biases.
"""
import numpy as np
from scipy import stats
from scipy.fft import fft
from core.grammar import ExpressionTree

def detect_pattern(x_values, y_values):
    """
    Analyze (x, y) data to detect patterns.
    Returns a dict with pattern type probabilities and suggested operators.

    Each candidate family (linear, quadratic, exponential, periodic, power,
    factorial) is scored independently; a family whose fit cannot be computed
    is simply left out. The family with the highest finite score wins.

    Args:
        x_values: 1D sequence of inputs, an (n, 1) column, or an
            (n, num_vars) array. With more than one column the 1D checks are
            skipped and the type is reported as 'multivariable'.
        y_values: target values, one per sample.

    Returns:
        dict with keys 'type', 'confidence', 'suggested_ops' and 'details'
        (per-family fit parameters). 'type' stays 'unknown' with fewer than
        three samples or when no family produced a finite score.
    """
    x = np.array(x_values, dtype=np.float64)
    y = np.array(y_values, dtype=np.float64)
    
    results = {
        'type': 'unknown',
        'confidence': 0.0,
        'suggested_ops': [],
        'details': {}
    }
    
    if len(x) < 3:
        return results

    # Handle Multivariable Input (Skip 1D pattern checks)
    if x.ndim > 1 and x.shape[1] > 1:
        results['type'] = 'multivariable'
        results['confidence'] = 1.0
        results['suggested_ops'] = ['+', '-', '*', 'x', 'C']
        results['details']['multivariable'] = {'num_vars': x.shape[1]}
        return results

    # Everything below is strictly 1D: flatten single-column inputs such as
    # (n, 1) so linregress/polyfit receive plain vectors.
    x = x.ravel()
    y = y.ravel()
    
    scores = {}
    
    # 1. Check for linear pattern (y = ax + b)
    try:
        slope, intercept, r_value, _, _ = stats.linregress(x, y)
        scores['linear'] = r_value ** 2
        results['details']['linear'] = {
            'slope': slope,
            'intercept': intercept,
            'r_squared': r_value ** 2
        }
    except Exception:
        # e.g. all x identical or mismatched lengths: family not applicable
        pass
    
    # 2. Check for quadratic pattern (y = ax^2 + bx + c)
    try:
        coeffs = np.polyfit(x, y, 2)
        y_pred = np.polyval(coeffs, x)
        ss_res = np.sum((y - y_pred) ** 2)
        ss_tot = np.sum((y - np.mean(y)) ** 2)
        r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
        scores['quadratic'] = r2
        results['details']['quadratic'] = {
            'coefficients': coeffs.tolist(),
            'r_squared': r2
        }
    except Exception:
        pass
    
    # 3. Check for exponential pattern (y = a * e^(bx))
    if np.all(y > 0):  # Exponential only for positive y
        try:
            log_y = np.log(y)
            slope, intercept, r_value, _, _ = stats.linregress(x, log_y)
            scores['exponential'] = r_value ** 2
            results['details']['exponential'] = {
                'a': np.exp(intercept),
                'b': slope,
                'r_squared': r_value ** 2
            }
        except Exception:
            pass
    
    # 4. Check for periodic/sinusoidal pattern
    if len(y) >= 4:
        try:
            # Simple FFT analysis
            y_centered = y - np.mean(y)
            fft_vals = np.abs(fft(y_centered))
            
            # Share of the spectrum (positive frequencies, DC excluded)
            # held by the single strongest bin
            if len(fft_vals) > 1:
                half = fft_vals[1:len(fft_vals)//2]
                max_idx = np.argmax(half) + 1
                max_power = fft_vals[max_idx]
                total_power = np.sum(half)
                
                if total_power > 0:
                    periodicity = max_power / total_power
                    scores['periodic'] = periodicity
                    results['details']['periodic'] = {
                        'dominant_freq_idx': int(max_idx),
                        'periodicity_score': periodicity
                    }
        except Exception:
            pass
    
    # 5. Check for power law (y = a * x^b)
    if np.all(x > 0) and np.all(y > 0):
        try:
            log_x = np.log(x)
            log_y = np.log(y)
            slope, intercept, r_value, _, _ = stats.linregress(log_x, log_y)
            scores['power'] = r_value ** 2
            results['details']['power'] = {
                'a': np.exp(intercept),
                'b': slope,
                'r_squared': r_value ** 2
            }
        except Exception:
            pass
    
    # 6. Check for factorial/gamma pattern (for integer-like x)
    if np.all(x > 0) and np.all(x == np.floor(x)):
        try:
            from scipy.special import gamma
            x_int = x.astype(int)
            y_gamma = gamma(x_int + 1)  # gamma(n+1) = n!
            
            # Simple linear fit between y and gamma
            if not np.any(np.isinf(y_gamma)):
                slope, intercept, r_value, _, _ = stats.linregress(y_gamma, y)
                scores['factorial'] = r_value ** 2
                results['details']['factorial'] = {
                    'r_squared': r_value ** 2
                }
        except Exception:
            pass
    
    # Determine best pattern. NaN scores are excluded: comparisons with NaN
    # are always False, so max() could otherwise return one arbitrarily.
    finite_scores = {name: score for name, score in scores.items() if np.isfinite(score)}
    if finite_scores:
        best_name, best_score = max(finite_scores.items(), key=lambda item: item[1])
        results['type'] = best_name
        results['confidence'] = best_score
        
        # Suggest operators based on pattern
        op_suggestions = {
            'linear': ['+', '-', '*', 'x', 'C'],
            'quadratic': ['pow', '+', '*', 'x', 'C', '2'],
            'exponential': ['exp', '*', '+', 'x', 'C'],
            'periodic': ['sin', 'cos', '*', '+', 'x', 'C'],
            'power': ['pow', '*', 'x', 'C'],
            'factorial': ['gamma', '*', '+', 'x', 'C']
        }
        results['suggested_ops'] = op_suggestions.get(best_name, [])
    
    return results


def summarize_pattern(result):
    """Print a short human-readable report for a detect_pattern() result dict."""
    print(f"\n=== Pattern Detection ===")

    detected = result['type']
    print(f"Detected Type: {detected} (confidence: {result['confidence']:.2%})")

    operators = ', '.join(result['suggested_ops'])
    print(f"Suggested Operators: {operators}")

    # Fit parameters are only available for families that were actually scored
    per_type = result['details']
    if detected in per_type:
        print(f"Details: {per_type[detected]}")


if __name__ == "__main__":
    # Demo: run the detector on one synthetic dataset per pattern family.
    # Each case is (banner, x builder, y builder); builders are invoked lazily
    # so data generation happens right after the banner is printed.
    demo_cases = [
        # Linear: y = 2x + 3 (with a little Gaussian noise)
        ("\n--- Test: Linear ---",
         lambda: np.linspace(0, 10, 20),
         lambda xs: 2 * xs + 3 + np.random.normal(0, 0.1, 20)),
        # Quadratic: y = x^2 + 1
        ("\n--- Test: Quadratic ---",
         lambda: np.linspace(-5, 5, 20),
         lambda xs: xs**2 + 1),
        # Exponential: y = 2 * e^(0.5x)
        ("\n--- Test: Exponential ---",
         lambda: np.linspace(0, 5, 20),
         lambda xs: 2 * np.exp(0.5 * xs)),
        # Periodic: y = sin(x)
        ("\n--- Test: Periodic ---",
         lambda: np.linspace(0, 4*np.pi, 50),
         lambda xs: np.sin(xs)),
    ]

    for banner, build_x, build_y in demo_cases:
        print(banner)
        xs = build_x()
        ys = build_y(xs)
        summarize_pattern(detect_pattern(xs, ys))


In [None]:
%%writefile AlphaSymbolic/utils/benchmark_runner.py
import torch
import numpy as np
import time
import traceback
from search.mcts import MCTS
from data.benchmark_data import BENCHMARK_SUITE, get_benchmark_data
from utils.optimize_constants import optimize_constants

def run_benchmark_suite(model, device, progress_callback=None):
    """
    Runs the full benchmark suite.

    Every problem in BENCHMARK_SUITE is searched with one shared MCTS instance.
    A problem that raises (while loading its data or while searching) is
    recorded as an error entry and the run continues with the next problem.

    Args:
        model: Loaded AlphaSymbolic model
        device: Torch device
        progress_callback: Function(float, string) to update UI
        
    Returns:
        results: List of result dicts
        summary: Dict with aggregated stats
    """
    results = []
    
    # Configure MCTS for benchmark (balanced speed/accuracy)
    # 500 simulations is decent for benchmarking
    mcts = MCTS(model, device, max_simulations=500, batch_size=32)
    
    total = len(BENCHMARK_SUITE)
    solved_count = 0
    
    for i, problem in enumerate(BENCHMARK_SUITE):
        if progress_callback:
            progress_callback(i / total, f"Testing: {problem['name']}...")
        
        start_time = time.time()
        
        try:
            x, y, _ = get_benchmark_data(problem['id'])
            
            # Restart the clock so only the search itself is timed
            start_time = time.time()
            
            # Run Search
            search_result = mcts.search(x, y)
            
            # Success threshold: RMSE < 0.05 (loose on purpose, for general
            # regression). RMSE against the data is the only success criterion.
            rmse = search_result['rmse']
            is_solved = rmse < 0.05
            
            elapsed = time.time() - start_time
            
            if is_solved:
                solved_count += 1
                status = "✅ SOLVED"
            else:
                status = "❌ FAILED"
                
            results.append({
                'id': problem['id'],
                'name': problem['name'],
                'level': problem['level'],
                'rmse': rmse,
                'time': elapsed,
                'status': status,
                'found_formula': search_result.get('formula', '???'),
                'is_solved': is_solved
            })
            
        except Exception:
            print(f"Error in benchmark {problem['name']}:")
            traceback.print_exc()
            results.append({
                'id': problem['id'],
                'name': problem['name'],
                'level': problem['level'],
                'rmse': 1e9,
                # Keep the time actually spent before the failure so the
                # average is not skewed towards zero by crashed problems.
                'time': time.time() - start_time,
                'status': "⚠️ ERROR",
                'found_formula': "Error",
                'is_solved': False
            })

    # Summary
    if progress_callback:
        progress_callback(1.0, "Done!")
        
    # An empty suite would otherwise divide by zero
    score = (solved_count / total) * 100 if total else 0.0
    summary = {
        'total': total,
        'solved': solved_count,
        'score': score,
        'avg_time': np.mean([r['time'] for r in results]) if results else 0
    }
    
    return results, summary


In [None]:
%%writefile AlphaSymbolic/utils/benchmark_comparison.py
"""
Comparative Benchmark: Beam Search vs MCTS vs Alpha-GP Hybrid
Runs all three search methods on the standard benchmark suite and compares performance.
"""
import torch
import numpy as np
import time
import traceback
from typing import List, Dict, Callable, Optional

from search.mcts import MCTS
from search.beam_search import BeamSearch
from search.hybrid_search import hybrid_solve
from data.benchmark_data import BENCHMARK_SUITE, get_benchmark_data
from core.grammar import ExpressionTree
from utils.optimize_constants import optimize_constants


def run_single_problem(
    x: np.ndarray, 
    y: np.ndarray, 
    method: str, 
    model, 
    device,
    timeout_sec: int = 30,
    beam_width: int = 50
) -> Dict:
    """
    Runs a single search method on a single problem.

    Args:
        x: input samples (converted with .tolist() before being handed to the searchers)
        y: target values
        method: "beam", "mcts" or "hybrid"; anything else yields an 'Unknown Method' result
        model: loaded model passed through to the searcher
        device: torch device passed through to the searcher
        timeout_sec: GP timeout forwarded to the hybrid solver
        beam_width: beam width for the beam and hybrid methods
    
    Returns:
        dict with keys: formula, rmse, time, success. A missing or non-finite
        RMSE is reported as 1e9 (failure), so callers never receive NaN.
        Any exception raised by a searcher is logged and turned into an
        'Error' result.
    """
    success_threshold = 0.05   # RMSE below this counts as solved
    failure_rmse = 1e9         # sentinel for "no usable result"

    def _pack(formula, rmse, elapsed):
        # Normalise the RMSE so downstream comparisons and sums stay well-defined
        try:
            rmse = float(rmse)
        except (TypeError, ValueError):
            rmse = failure_rmse
        if not np.isfinite(rmse):
            rmse = failure_rmse
        return {
            'formula': formula,
            'rmse': rmse,
            'time': elapsed,
            'success': rmse < success_threshold
        }

    start_time = time.time()
    
    try:
        if method == "beam":
            searcher = BeamSearch(model, device, beam_width=beam_width)
            # BeamSearch expects list-like input and returns a list of results sorted by RMSE
            results_list = searcher.search(x.tolist(), y.tolist())
            elapsed = time.time() - start_time
            if not results_list:
                return {'formula': 'No Result', 'rmse': failure_rmse, 'time': elapsed, 'success': False}
            best = results_list[0]  # Best result (sorted by RMSE)
            return _pack(best.get('formula', 'N/A'), best.get('rmse', failure_rmse), elapsed)
            
        elif method == "mcts":
            mcts = MCTS(model, device, max_simulations=500, batch_size=32)
            # MCTS expects list-like input 
            result = mcts.search(x.tolist(), y.tolist())
            elapsed = time.time() - start_time
            return _pack(result.get('formula', 'N/A'), result.get('rmse', failure_rmse), elapsed)
            
        elif method == "hybrid":
            result = hybrid_solve(
                model=model,
                device=device,
                x_values=x.tolist(),
                y_values=y.tolist(),
                beam_width=beam_width,
                gp_timeout=timeout_sec
            )
            # Timing covers the search only, not the re-evaluation below
            elapsed = time.time() - start_time
            
            formula = result.get('formula') if result else None
            rmse = failure_rmse
            if formula:
                # The hybrid solver's RMSE is not used; recompute it from the formula
                try:
                    tree = ExpressionTree.from_infix(formula)
                    if tree.is_valid:
                        preds = tree.evaluate(x)
                        rmse = np.sqrt(np.mean((preds - y) ** 2))
                except Exception:
                    # Unparseable or unevaluable formula counts as a failure
                    rmse = failure_rmse
                
            return _pack(formula or 'Failed', rmse, elapsed)
        else:
            return {'formula': 'Unknown Method', 'rmse': 1e9, 'time': 0, 'success': False}
            
    except Exception as e:
        print(f"[ERROR] Method {method} failed: {e}")
        traceback.print_exc()
        return {'formula': 'Error', 'rmse': 1e9, 'time': time.time() - start_time, 'success': False}


def run_comparison_benchmark(
    model, 
    device, 
    methods: List[str] = ["beam", "mcts", "hybrid"],
    gp_timeout: int = 30,
    beam_width: int = 50,
    progress_callback: Optional[Callable] = None
) -> Dict:
    """
    Runs all methods on all benchmark problems.

    Args:
        model: loaded model handed to every search method
        device: torch device handed to every search method
        methods: search methods to compare; duplicates are ignored
        gp_timeout: GP timeout (seconds) forwarded to each run
        beam_width: beam width forwarded to each run
        progress_callback: optional Function(float, string) for UI updates
    
    Returns:
        Dict with 'results' (per-problem-per-method) and 'summary' (aggregated stats).
        In the summary, 'avg_rmse' is the mean over runs that produced a usable
        RMSE (< 1e6); it is inf when a method produced none, so a method that
        fails everywhere can never look better than one that found formulas.
    """
    # De-duplicate while keeping order so no method is counted twice per problem
    methods = list(dict.fromkeys(methods))

    results = []
    method_stats = {
        m: {'solved': 0, 'total_time': 0, 'total_rmse': 0, 'valid_runs': 0}
        for m in methods
    }
    
    total_steps = len(BENCHMARK_SUITE) * len(methods)
    current_step = 0
    
    for problem in BENCHMARK_SUITE:
        x, y, _ = get_benchmark_data(problem['id'])
        
        for method in methods:
            current_step += 1
            
            if progress_callback:
                progress_callback(
                    current_step / total_steps, 
                    f"[{method.upper()}] {problem['name']}..."
                )
            
            result = run_single_problem(x, y, method, model, device, gp_timeout, beam_width)
            
            results.append({
                'problem_id': problem['id'],
                'problem_name': problem['name'],
                'level': problem['level'],
                'method': method,
                'formula': result['formula'],
                'rmse': result['rmse'],
                'time': result['time'],
                'success': result['success']
            })
            
            # Update stats
            stats = method_stats[method]
            stats['total_time'] += result['time']
            # Failure sentinels (and NaN, which fails the comparison) are left
            # out of the RMSE average instead of being counted as zero error.
            if result['rmse'] < 1e6:
                stats['total_rmse'] += result['rmse']
                stats['valid_runs'] += 1
            if result['success']:
                stats['solved'] += 1
    
    # Compute summary
    num_problems = len(BENCHMARK_SUITE)
    summary = {}
    for method in methods:
        stats = method_stats[method]
        summary[method] = {
            'solved': stats['solved'],
            'total': num_problems,
            # Guard the divisions so an empty suite reports zeros instead of raising
            'score': (stats['solved'] / num_problems) * 100 if num_problems else 0.0,
            'avg_time': stats['total_time'] / num_problems if num_problems else 0.0,
            'avg_rmse': (
                stats['total_rmse'] / stats['valid_runs']
                if stats['valid_runs'] else float('inf')
            )
        }
    
    if progress_callback:
        progress_callback(1.0, "Benchmark Complete!")
    
    return {'results': results, 'summary': summary}


def format_comparison_table(results: List[Dict]) -> str:
    """
    Render flat benchmark results as a fixed-width, human-readable table.

    Args:
        results: One dict per (problem, method) pair carrying the keys
            'problem_id', 'problem_name', 'level', 'method', 'formula',
            'rmse', 'time' and 'success'.

    Returns:
        A newline-joined table: a header framed by '=' rules, then one row
        per method grouped by problem, each group closed by a '-' rule.
    """
    # Bucket rows by problem id, preserving first-seen order.
    grouped = {}
    for row in results:
        key = row['problem_id']
        if key not in grouped:
            grouped[key] = {
                'name': row['problem_name'],
                'level': row['level'],
                'methods': {},
            }
        grouped[key]['methods'][row['method']] = {
            'rmse': row['rmse'],
            'time': row['time'],
            'success': row['success'],
            'formula': row['formula'],
        }

    heavy_rule = "=" * 100
    light_rule = "-" * 100
    lines = [
        heavy_rule,
        f"{'Problem':<25} | {'Method':<8} | {'RMSE':<12} | {'Time':<8} | {'Status':<10} | Formula",
        heavy_rule,
    ]

    for info in grouped.values():
        name = info['name'][:24]
        for method, stats in info['methods'].items():
            # RMSE values of 1e6 and above are treated as a failed fit.
            if stats['rmse'] < 1e6:
                rmse_str = f"{stats['rmse']:.6f}"
            else:
                rmse_str = "FAILED"
            time_str = f"{stats['time']:.2f}s"
            status = "[OK]" if stats['success'] else "[FAIL]"
            formula = stats['formula'][:40] if stats['formula'] else "N/A"
            lines.append(
                f"{name:<25} | {method:<8} | {rmse_str:<12} | {time_str:<8} | {status:<10} | {formula}"
            )
        lines.append(light_rule)

    return "\n".join(lines)


def print_summary(summary: Dict):
    """
    Print a formatted per-method summary table and announce the winner.

    Args:
        summary: Mapping of method name -> stats dict with the keys 'solved',
            'total', 'score', 'avg_time' and 'avg_rmse'.

    The winner is the method with the most solved problems; ties are broken
    by the lowest average RMSE. An empty summary prints the table frame and a
    notice instead of raising on ``max()`` of an empty sequence.
    """
    print("\n" + "=" * 60)
    print("BENCHMARK SUMMARY - Method Comparison")
    print("=" * 60)
    print(f"{'Method':<12} | {'Solved':<10} | {'Score':<10} | {'Avg Time':<10} | {'Avg RMSE':<12}")
    print("-" * 60)

    for method, stats in summary.items():
        solved_str = f"{stats['solved']}/{stats['total']}"
        score_str = f"{stats['score']:.1f}%"
        time_str = f"{stats['avg_time']:.2f}s"
        rmse_str = f"{stats['avg_rmse']:.6f}"
        print(f"{method.upper():<12} | {solved_str:<10} | {score_str:<10} | {time_str:<10} | {rmse_str:<12}")

    print("=" * 60)

    # Nothing to rank when no method was benchmarked.
    if not summary:
        print("\nNo methods were benchmarked; no winner to report.")
        return

    # Determine winner: most solved first, then lowest average RMSE.
    best_method = max(summary.items(), key=lambda x: (x[1]['solved'], -x[1]['avg_rmse']))
    print(f"\n*** WINNER: {best_method[0].upper()} with {best_method[1]['solved']}/{best_method[1]['total']} problems solved! ***")


if __name__ == "__main__":
    # Standalone test: load the default model and run the full comparison
    # benchmark from the command line.
    import sys
    sys.path.insert(0, '.')

    from ui.app_core import load_model, get_model

    print("Loading model...")
    load_model()
    model, device = get_model()

    if model is None:
        print("Error: No model loaded!")
        # sys.exit is the supported way to terminate a program; the bare
        # exit() helper is injected by the `site` module for interactive use
        # and is not guaranteed to exist (e.g. `python -S`, frozen builds).
        sys.exit(1)

    print("Running comparison benchmark...")
    result = run_comparison_benchmark(
        model,
        device,
        methods=["beam", "mcts", "hybrid"],
        gp_timeout=30,
        beam_width=50
    )

    print(format_comparison_table(result['results']))
    print_summary(result['summary'])


In [None]:
%%writefile AlphaSymbolic/utils/simplify.py
"""
Algebraic Simplification Module for AlphaSymbolic.
Uses SymPy for symbolic math simplification.
"""
import sympy as sp
from core.grammar import Node, ExpressionTree, OPERATORS

# SymPy symbols
# Symbol returned by tree_to_sympy for the bare terminal 'x'.
x_sym = sp.Symbol('x')
# Multi-variable support (x0 through x9): terminal name -> its own SymPy Symbol.
x_syms = {f'x{i}': sp.Symbol(f'x{i}') for i in range(10)}

def tree_to_sympy(node):
    """Convert an ExpressionTree Node (and its subtree) to a SymPy expression.

    Args:
        node: A tree node exposing ``value`` and ``children``; ``None`` is
            accepted.

    Returns:
        The equivalent SymPy expression. A ``None`` node and any operator
        that is not recognised both map to ``sp.Integer(0)``.

    Raises:
        IndexError: If an operator node carries fewer children than the
            operator's arity.
    """
    if node is None:
        return sp.Integer(0)

    val = node.value

    # Terminals
    if val == 'x':
        return x_sym
    if val in x_syms:
        return x_syms[val]
    if val == 'pi':
        return sp.pi
    if val == 'e':
        return sp.E
    if val == 'C':
        # Keep C as symbol for now
        return sp.Symbol('C')

    # Numeric literal. Only the errors float() raises for a non-numeric value
    # are swallowed, so KeyboardInterrupt and genuine bugs are not hidden.
    try:
        return sp.Float(float(val))
    except (TypeError, ValueError, OverflowError):
        pass

    # Operators: children are converted first, then dispatched by arity.
    args = [tree_to_sympy(c) for c in node.children]

    binary_ops = {
        '+': lambda a, b: a + b,
        '-': lambda a, b: a - b,
        '*': lambda a, b: a * b,
        '/': lambda a, b: a / b,
        'pow': sp.Pow,
        'mod': sp.Mod,
    }
    unary_ops = {
        'sin': sp.sin,
        'cos': sp.cos,
        'tan': sp.tan,
        'exp': sp.exp,
        'log': sp.log,
        'sqrt': sp.sqrt,
        'abs': sp.Abs,
        'floor': sp.floor,
        'ceil': sp.ceiling,
        'gamma': sp.gamma,
        'lgamma': sp.loggamma,  # SymPy's log-gamma
        'neg': lambda a: -a,
    }

    if val in binary_ops:
        return binary_ops[val](args[0], args[1])
    if val in unary_ops:
        return unary_ops[val](args[0])

    # Unknown token: same neutral fallback as for a missing node.
    return sp.Integer(0)

def sympy_to_infix(expr):
    """Return the readable string form of a SymPy expression (its ``str()``)."""
    infix_text = str(expr)
    return infix_text

def simplify_tree(tree):
    """
    Takes an ExpressionTree and returns a simplified infix string.

    Args:
        tree: Object exposing ``is_valid``, ``get_infix()`` and ``root``.

    Returns:
        "Invalid" when ``tree.is_valid`` is falsy; otherwise the SymPy-simplified
        expression as a string. The tree's own infix string is returned instead
        when simplification raises or when the result contains complex
        infinity, NaN or the imaginary unit.
    """
    if not tree.is_valid:
        return "Invalid"

    original_infix = tree.get_infix()

    try:
        simplified = sp.simplify(tree_to_sympy(tree.root))
        result_str = str(simplified)

        # Structural check: catches every occurrence of zoo / nan / I, including
        # forms such as "2*I" or a bare "I" that the substring test below misses.
        if simplified.has(sp.zoo, sp.nan, sp.I):
            return original_infix

        # Original textual guard, kept as a second line of defence.
        # zoo = complex infinity; 'I*' indicates complex numbers.
        for term in ('zoo', 'nan', 'I*'):
            if term in result_str:
                return original_infix  # Fall back to original

        return result_str
    except Exception:
        # If simplification fails, return original
        return original_infix

def simplify_infix(infix_str):
    """
    Takes an infix string and returns a simplified version.

    Args:
        infix_str: Expression text parseable by ``sp.sympify``.

    Returns:
        The simplified expression as a string, or ``infix_str`` unchanged when
        parsing or simplification fails.
    """
    # SECURITY NOTE(review): sp.sympify evaluates its string argument; do not
    # feed it text from an untrusted source without validating it first.
    try:
        expr = sp.sympify(infix_str)
        simplified = sp.simplify(expr)
        return str(simplified)
    except Exception:
        # Best-effort: any parse/simplify error keeps the input as-is, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        return infix_str

# Quick test
if __name__ == "__main__":
    from core.grammar import ExpressionTree

    # Prefix-token expressions, each with an obvious algebraic simplification.
    demo_cases = [
        ['+', 'x', '0'],  # x + 0 should simplify to x
        ['*', 'x', '1'],  # x * 1 should simplify to x
        ['-', 'x', 'x'],  # x - x should simplify to 0
    ]

    for demo_tokens in demo_cases:
        demo_tree = ExpressionTree(demo_tokens)
        print(f"Original: {demo_tree.get_infix()}")
        print(f"Simplified: {simplify_tree(demo_tree)}")


In [None]:
%%writefile AlphaSymbolic/utils/__init__.py


In [None]:
%%writefile AlphaSymbolic/tools/run_benchmark_feynman.py

import torch
import numpy as np
import pandas as pd
import time
from tabulate import tabulate
from ui.app_core import get_model, load_model, MODEL_PRESETS
from search.hybrid_search import hybrid_solve
from data.expanded_benchmarks import load_expanded_feynman_subset, evaluate_projected_formula
from core.grammar import ExpressionTree

def evaluate_dynamic(problem, x_val):
    """Evaluate a benchmark problem at ``x_val``.

    Forwards the problem's 'original_formula', 'target_var' and
    'fixed_context' entries, together with ``x_val``, to
    ``evaluate_projected_formula`` and returns its result.
    """
    formula = problem['original_formula']
    target_var = problem['target_var']
    fixed_context = problem['fixed_context']
    return evaluate_projected_formula(formula, target_var, x_val, fixed_context)

def run_benchmark():
    """Run the expanded Feynman benchmark for the 'lite' and 'pro' presets.

    For every preset that loads, each problem is sampled on 20 points in
    [0.1, 5.0], solved with ``hybrid_solve`` and scored by re-evaluating the
    predicted formula (RMSE < 0.05 counts as solved). A side-by-side LITE/PRO
    table is printed and all rows are written to
    'feynman_expanded_results.csv'.

    Per-preset results are keyed by preset name, so a preset that fails to
    load shows "N/A" in its own columns instead of shifting the other
    preset's results into them.
    """
    print("\n" + "="*80)
    print("🔬 ALPHA SYMBOLIC: EXPANDED FEYNMAN BENCHMARK (LITE vs PRO)")
    print("="*80 + "\n")

    # LOAD DATASETS
    print("Loading Feynman Dataset (FULL)...")
    problems = load_expanded_feynman_subset(limit=None)
    if not problems:
        print("No problems loaded. Check data/benchmarks/FeynmanEquations.csv")
        return

    presets_to_test = ['lite', 'pro']
    all_results = []

    # preset name -> list of {"ID", "Status", "Time"} rows for the final pivot.
    results_by_preset = {}

    for preset in presets_to_test:
        print(f"\n>>> LOADING MODEL: {preset.upper()} <<<")
        try:
            status, info = load_model(preset_name=preset)
            print(f"Status: {status} | Device: {info}")
            model, device = get_model()
        except Exception as e:
            print(f"Failed to load {preset}: {e}")
            continue

        preset_results = []

        # Iterate Problems
        for i, problem in enumerate(problems):
            print(f"\n[{preset.upper()}] Problem {i+1}/{len(problems)}: {problem['name']}")
            print(f"Target: {problem['original_formula']}")
            print(f"Desc: {problem['description']}")

            # Generate Data
            x_test = np.linspace(0.1, 5.0, 20)
            y_test = evaluate_dynamic(problem, x_test)

            if np.any(np.isnan(y_test)) or np.any(np.isinf(y_test)):
                print("Skipping due to numerical issues.")
                continue

            # Solve
            start_time = time.time()
            try:
                solution = hybrid_solve(
                    x_test,
                    y_test,
                    model,
                    device,
                    beam_width=50,
                    gp_timeout=10,
                    max_workers=6
                )

                elapsed = time.time() - start_time

                if solution:
                    pred_formula = solution.get('formula', "N/A")

                    # Verify RMSE independently of what the solver reported:
                    # the prediction is treated as a function of x and compared
                    # with the ground-truth samples.
                    try:
                        pred_tree = ExpressionTree.from_infix(pred_formula)
                        y_pred = pred_tree.evaluate(x_test)
                        real_rmse = np.sqrt(np.mean((y_test - y_pred)**2))
                        is_solved = real_rmse < 0.05 # Relaxed slightly for complex physics
                        status_text = "✅ SOLVED" if is_solved else "❌ FAILED"
                    except Exception:
                        # Unparseable / unevaluable prediction: sentinel RMSE.
                        real_rmse = 999.0
                        status_text = "⚠️ ERROR"

                    print(f"Result: {status_text} | RMSE: {real_rmse:.4f} | Time: {elapsed:.2f}s")

                    all_results.append({
                        "Model": preset,
                        "ID": problem['id'],
                        "Name": problem['name'],
                        "Target": problem['original_formula'],
                        "Prediction": pred_formula,
                        "RMSE": real_rmse,
                        "Time": elapsed,
                        "Status": status_text
                    })

                    preset_results.append({
                        "ID": problem['id'],
                        "Status": status_text,
                        "Time": elapsed
                    })

                else:
                    print("Result: No solution found.")
                    all_results.append({"Model": preset, "ID": problem['id'], "Name": problem['name'], "Status": "NO_SOLUTION", "RMSE": 999.0, "Time": elapsed})
                    preset_results.append({"ID": problem['id'], "Status": "NO_SOLUTION", "Time": elapsed})

            except Exception as e:
                print(f"Error executing solve: {e}")
                all_results.append({"Model": preset, "ID": problem['id'], "Name": problem['name'], "Status": "CRASH", "RMSE": 999.0, "Time": 0.0})
                preset_results.append({"ID": problem['id'], "Status": "CRASH", "Time": 0.0})

        results_by_preset[preset] = preset_results

    # Final Comparative Report
    print("\n" + "="*80)
    print("🏆 FINAL COMPARISON REPORT (EXPANDED)")
    print("="*80)

    # Pivot results for side-by-side view. Lookup is by preset NAME: a preset
    # that failed to load simply has no entry and renders as "N/A".
    comparison_rows = []

    lite_map = {r['ID']: r for r in results_by_preset.get('lite', [])}
    pro_map = {r['ID']: r for r in results_by_preset.get('pro', [])}

    missing = {"Status": "N/A", "Time": 0.0}
    for problem in problems:
        pid = problem['id']

        l_res = lite_map.get(pid, missing)
        p_res = pro_map.get(pid, missing)

        comparison_rows.append({
            "ID": pid,
            "LITE Status": l_res['Status'],
            "LITE Time": f"{l_res['Time']:.2f}s",
            "PRO Status": p_res['Status'],
            "PRO Time": f"{p_res['Time']:.2f}s"
        })

    df_compare = pd.DataFrame(comparison_rows)
    print(tabulate(df_compare, headers="keys", tablefmt="grid", showindex=False))

    # Save CSV
    pd.DataFrame(all_results).to_csv("feynman_expanded_results.csv", index=False)
    print("\nDetailed results saved to 'feynman_expanded_results.csv'")

# Script entry point: run the LITE vs PRO Feynman benchmark when executed directly.
if __name__ == "__main__":
    run_benchmark()


In [None]:
%%writefile AlphaSymbolic/run_gpu_console.py

import sys
import os
import torch
import numpy as np
import time
from core.gpu import TensorGeneticEngine

# Configuration matching C++ Globals
# --- CONFIGURATION ---
# Target series for x0 = 1..25. Apart from the 23rd entry, the values match the
# N-Queens solution counts; that entry was typed as 2423393768440 (one digit
# short, and smaller than its predecessor) and is corrected to 24233937684440.
TARGETS = np.array([
    1, 0, 0, 2, 10, 4, 40, 92, 352, 724, 2680, 14200,
    73712, 365596, 2279184, 14772512, 95815104, 666090624,
    4968057848, 39029188884, 314666222712, 2691008701644,
    24233937684440, 227514171973736, 2207893435808352
], dtype=np.float64)

# Generate X_VALUES procedurally to match the pattern:
# x0 = 1..25
# x1 = x0 % 6
# x2 = x0 % 2
indices = np.arange(1, 26, dtype=np.float64)
x1_vals = indices % 6
x2_vals = indices % 2
# One row per sample, columns (x0, x1, x2).
X_VALUES = np.column_stack((indices, x1_vals, x2_vals))

def console_mimic_callback(gen, best_rmse, best_rpn_tensor, best_consts_tensor, is_new_best, island_idx=-1):
    """
    Console reporter for the GPU engine, formatted like the C++ program's output.

    Relies on the module-level ``engine``, ``GpuGlobals`` and
    ``start_time_global`` names created in the ``__main__`` section, plus the
    ``TARGETS`` / ``X_VALUES`` configuration arrays.

    Args:
        gen: Current generation number.
        best_rmse: Fitness of the best individual so far.
        best_rpn_tensor: RPN encoding of the best formula (decoded via ``engine``).
        best_consts_tensor: Constants belonging to the best formula.
        is_new_best: True prints the "new global best" block with per-point
            predictions; False prints a periodic progress report.
        island_idx: Island on which the new best was found (display only).
    """
    formula_str = engine.rpn_to_infix(best_rpn_tensor, best_consts_tensor)
    formula_size = engine.get_tree_size(best_rpn_tensor)

    if not is_new_best:
        # Progress report. Throughput is measured between consecutive reports;
        # the previous timestamp/generation live on the function object itself.
        tracker = console_mimic_callback
        if not hasattr(tracker, "last_time"):
            tracker.last_time = start_time_global
            tracker.last_gen = 0

        now = time.time()
        seconds_since_last = now - tracker.last_time
        gens_since_last = gen - tracker.last_gen

        if seconds_since_last > 0:
            instant_speed = (gens_since_last * engine.pop_size) / seconds_since_last
        else:
            instant_speed = 0.0

        tracker.last_time = now
        tracker.last_gen = gen

        elapsed = time.time() - start_time_global
        print(f"\n--- Gen {gen} (Elapsed: {elapsed:.2f}s) | Instant Speed: {instant_speed:,.0f} Evals/sec ---")
        print(f"Overall Best Fitness: {best_rmse:.4e}")
        print(f"Best Formula Size: {formula_size}")
        sys.stdout.flush()
        return

    # New global best: header, then a prediction line for every input row.
    print("\n========================================")
    print(f"New Global Best Found (Gen {gen}, Island {island_idx})")
    print(f"Fitness: {best_rmse:.8f}")
    print(f"Size: {formula_size}")
    print(f"Formula: {formula_str}")
    print("Predictions vs Targets:")

    try:
        # Re-parse the decoded formula so single points can be evaluated on
        # the CPU (the engine's batch evaluator expects a whole population).
        from core.grammar import ExpressionTree
        tree = ExpressionTree.from_infix(formula_str)

        # When the engine works in log space, show log-targets too;
        # non-positive targets are displayed as log(1) = 0.
        display_targets = TARGETS
        if GpuGlobals.USE_LOG_TRANSFORMATION:
            positive = TARGETS > 1e-9
            display_targets = np.log(np.where(positive, TARGETS, 1.0))

        for i, row in enumerate(X_VALUES):
            val = tree.evaluate(row)

            # Ensure val is scalar
            if isinstance(val, np.ndarray):
                val = val.item() if val.size == 1 else val[0]

            if i < len(display_targets):
                target = display_targets[i]
            else:
                target = float('nan')
            diff = abs(val - target)

            x_str = ",".join([f"{x:.1f}" for x in row])

            print(f"  x=({x_str}): Pred={val:12.4f}, Target={target:12.4f}, Diff={diff:12.4f}")
    except Exception as e:
        print(f"  (Error calculating detailed predictions for display: {e})")

    print("========================================")
    sys.stdout.flush()

# Script entry point: configure the GPU engine globals, build the engine and
# run the search on TARGETS / X_VALUES, reporting through console_mimic_callback.
if __name__ == "__main__":
    print("Starting Genetic Algorithm (GPU Mode)...")
    
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("WARNING: GPU NOT DETECTED. Running in CPU emulation mode (Slow).")
        

    # Imported here, but console_mimic_callback also reads GpuGlobals as a
    # module-level name, so the callback only works when run as a script.
    from core.gpu.config import GpuGlobals
    
    # User can override Globals here (must happen before the engine is built)
    GpuGlobals.POP_SIZE = 25000
    GpuGlobals.NUM_ISLANDS = 20
    GpuGlobals.PROGRESS_REPORT_INTERVAL = 100
    GpuGlobals.USE_PARETO_SELECTION = False  # Disable NSGA-II for speed test
    
    # Engine will use Globals defaults for pop_size and n_islands
    engine = TensorGeneticEngine(num_variables=3) # 3 variables as per new X_VALUES
    
    # Reference time used by the callback for elapsed / throughput figures.
    start_time_global = time.time()
    
    try:
        # Run until the engine stops on its own, the 1-hour timeout passed
        # below expires, or the user presses Ctrl+C.
        print("Evaluating initial population...")
        
        # Slice the inputs to the number of targets (25 here), and pass a
        # 1-D input when the engine is configured for a single variable.
        if engine.num_variables == 1:
            x_input = X_VALUES[:len(TARGETS), 0]
        else:
            x_input = X_VALUES[:len(TARGETS)]
        
        # Optionally seed the population with a user-provided formula.
        seeds = []
        if GpuGlobals.USE_INITIAL_FORMULA and GpuGlobals.INITIAL_FORMULA_STRING:
            seeds.append(GpuGlobals.INITIAL_FORMULA_STRING)
            print(f"Info: Injecting initial formula: {GpuGlobals.INITIAL_FORMULA_STRING}")

        final_formula = engine.run(
            x_input, 
            TARGETS, 
            seeds=seeds, 
            timeout_sec=3600, 
            callback=console_mimic_callback
        )

        
        print("\nSearch Finished.")
        if final_formula:
            print(f"Final Result: {final_formula}")
            
    except KeyboardInterrupt:
        print("\nSearch interrupted by user.")


In [None]:
%%writefile AlphaSymbolic/infinite_search.py
import time
import numpy as np
import torch
import pandas as pd
import os
import random
import sys

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from search.hybrid_search import hybrid_solve
from ui.app_core import get_model
from core.grammar import ExpressionTree
from utils.optimize_constants import optimize_constants, substitute_constants, convert_and_extract_constants

# --- CONFIGURATION ---
X_FULL = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25], dtype=np.float64)
# Training series for x = 1..25. Apart from the 23rd entry, the values match the
# N-Queens solution counts; that entry was typed as 2423393768440 (one digit
# short, and smaller than its predecessor) and is corrected to 24233937684440.
Y_FULL = np.array([
    1, 0, 0, 2, 10, 4, 40, 92, 352, 724, 2680, 14200, 73712, 365596, 2279184,
    14772512, 95815104, 666090624, 4968057848, 39029188884, 314666222712,
    2691008701644, 24233937684440, 227514171973736, 2207893435808352
], dtype=np.float64)

# Targets for Extrapolation
X_TARGETS = np.array([26, 27], dtype=np.float64)
Y_TARGETS = np.array([22317699616364044, 234907967154122528], dtype=np.float64)

CSV_FILE = "top_formulas.csv"
PATTERN_FILE = "pattern_memory.json"
TOP_K = 5
MIN_SAMPLE_SIZE = 6 # > 5

import json

# --- PATTERN MEMORY ("La Biblioteca") ---
def extract_structural_skeleton(formula_str):
    """
    Parses formula and replaces all numeric constants with 'C'.

    Args:
        formula_str: Infix formula text.

    Returns:
        The structural skeleton as an infix string, or None when the formula
        cannot be parsed, is invalid, or the tree cannot be rendered.
    """
    def transform(node):
        """Rewrite, in place, every node whose value parses as a float to 'C'."""
        if not node:
            return
        try:
            float(node.value)
        except (TypeError, ValueError, OverflowError):
            pass  # Operator or Variable
        else:
            node.value = 'C'

        for child in node.children:
            transform(child)

    try:
        from core.grammar import ExpressionTree
        tree = ExpressionTree.from_infix(formula_str)
        if not tree.is_valid:
            return None

        transform(tree.root)
        return tree.root.to_infix()
    except Exception:
        # Best-effort: an unparseable formula simply has no skeleton.
        # KeyboardInterrupt / SystemExit are deliberately not caught, so the
        # surrounding infinite loop can still be stopped with Ctrl+C.
        return None

def load_pattern_memory():
    """Load the skeleton -> count mapping stored in PATTERN_FILE.

    Returns:
        The stored dict, or an empty dict when the file is missing, cannot be
        read, does not contain valid JSON, or does not hold a JSON object.
    """
    if not os.path.exists(PATTERN_FILE):
        return {}
    try:
        with open(PATTERN_FILE, 'r') as f:
            memory = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(f"Warning: could not read {PATTERN_FILE} ({e}); starting with empty pattern memory.")
        return {}
    # Callers use .get()/.items(); anything but a dict would break them later.
    return memory if isinstance(memory, dict) else {}

def save_pattern_memory(memory):
    """Persist the pattern memory to PATTERN_FILE as indented JSON.

    The data is written to a temporary sibling file and then moved into place
    with ``os.replace``, so an interrupted write cannot leave a truncated
    PATTERN_FILE behind. Write and serialisation failures are reported and
    otherwise ignored (best-effort); Ctrl+C is no longer swallowed.

    Args:
        memory: JSON-serialisable dict of skeleton -> count.
    """
    tmp_path = PATTERN_FILE + ".tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(memory, f, indent=2)
        os.replace(tmp_path, PATTERN_FILE)
    except (OSError, TypeError, ValueError) as e:
        print(f"Warning: could not save pattern memory: {e}")

def update_pattern_memory(memory, formula_str):
    """Increment the usage count of the formula's structural skeleton.

    Args:
        memory: Dict of skeleton -> count, updated in place.
        formula_str: Formula whose skeleton should be recorded.

    Returns:
        True when a skeleton was extracted and counted, False otherwise.
    """
    skeleton = extract_structural_skeleton(formula_str)
    if not skeleton:
        return False
    memory[skeleton] = memory.get(skeleton, 0) + 1
    return True


def load_or_create_top_list():
    """Return the saved top-formula records from CSV_FILE as a list of dicts.

    An empty list is returned when the file does not exist or cannot be read
    (the read error is printed).
    """
    if not os.path.exists(CSV_FILE):
        return []
    try:
        return pd.read_csv(CSV_FILE).to_dict('records')
    except Exception as err:
        print(f"Error loading CSV: {err}")
        return []

def save_top_list(top_list):
    """Write the top-formula records to CSV_FILE and trigger a Drive backup.

    Rows are sorted ascending by 'rmsle_global' (the main loop's ranking key)
    when that column exists, otherwise by 'extrapolation_error'. If neither
    column is present (for example an empty list) the rows are written
    unsorted instead of raising KeyError.

    Args:
        top_list: List of record dicts, one per formula.
    """
    df = pd.DataFrame(top_list)
    # Sort by RMSLE (Global Fit) - Match main loop priority
    sort_key = next(
        (col for col in ('rmsle_global', 'extrapolation_error') if col in df.columns),
        None,
    )
    if sort_key is not None:
        df = df.sort_values(by=sort_key, ascending=True)
    df.to_csv(CSV_FILE, index=False)
    print(f"Saved Top {len(df)} to {CSV_FILE}")

    # Auto-backup if in Colab
    backup_to_drive()

def backup_to_drive():
    """
    Best-effort copy of the top list and pattern memory to Google Drive.

    Does nothing unless '/content/drive/MyDrive' exists (i.e. Drive is mounted
    in Colab). Any error during the copy is silently ignored.
    """
    try:
        import shutil
        if not os.path.exists('/content/drive/MyDrive'):
            return

        drive_path = '/content/drive/MyDrive/AlphaSymbolic_Models'
        os.makedirs(drive_path, exist_ok=True)

        # Copy whichever of the two state files currently exist.
        for filename in [CSV_FILE, PATTERN_FILE]:
            if os.path.exists(filename):
                shutil.copy(filename, os.path.join(drive_path, filename))
    except Exception:
        # Silently fail if drive not mounted or other issues
        pass

def main():
    """Run the endless sample -> search -> refine -> rank loop.

    Each iteration:
      1. draws a random subset of the (X_FULL, Y_FULL) points,
      2. flattens the target as log(|y| + eps) - gammaln(x + 1),
      3. runs hybrid_solve on the flattened target, seeded with formulas from
         the current top list and the most frequent stored skeletons,
      4. refines the numeric constants of the found residual with
         optimize_constants,
      5. rebuilds the full formula as exp(residual + lgamma(x0)) and scores it
         on all 27 points (RMSLE plus absolute error on x = 26, 27),
      6. updates and saves the top-K list and the pattern memory.

    The loop has no exit condition; it runs until interrupted (Ctrl+C).
    """
    print("--- Infinite Formula Search Script ---")
    
    # Check dependencies
    # Pandas is required


    # Load Model
    print("Loading Model...")
    try:
        MODEL, DEVICE = get_model()
    except Exception as e:
        print(f"Failed to load model: {e}")
        return
        
    # Persistent state carried over from earlier runs.
    top_formulas = load_or_create_top_list()
    pattern_memory = load_pattern_memory()
    print(f"Loaded Pattern Memory with {len(pattern_memory)} patterns.")

    
    print(f"Starting infinite search loop... (Press Ctrl+C to stop)")
    iteration = 0
    
    while True:
        iteration += 1
        
        # 1. Random Sampling
        # Only array indices >= 4 are eligible (x >= 5, since X starts at 1),
        # which skips the leading 1, 0, 0, 2 values of the series.
        valid_indices = np.arange(4, len(X_FULL)) # 4, 5, ... end
        
        # Sample size k lies between MIN_SAMPLE_SIZE (capped at the pool size)
        # and the full pool.
        pool_size = len(valid_indices)
        k = random.randint(min(MIN_SAMPLE_SIZE, pool_size), pool_size)
        
        indices = np.sort(np.random.choice(valid_indices, k, replace=False))
        
        x_sample = X_FULL[indices]
        y_sample = Y_FULL[indices]
        
        print(f"[Iter {iteration}] Sampling {k} pts...", end=" ")
        
        # Prepare Seeds (Evolutionary Feedback)
        extra_seeds = []
        if top_formulas:
            # Seed candidates are the first 5 entries of the current top list.
            candidates = top_formulas[:5]
            candidate_formulas = [c['formula'] for c in candidates]
            
            if candidate_formulas:
                # Sample 3 seeds with replacement 
                # (allows giving more compute to the very best if picked twice)
                chosen = random.choices(candidate_formulas, k=3)
                extra_seeds.extend(chosen)
                print(f"+ {len(chosen)} Seeds")
        
        # Inject Patterns from Memory (The Library)
        if pattern_memory:
            # Pick top 3 most frequent patterns
            # Sort by count desc
            sorted_patterns = sorted(pattern_memory.items(), key=lambda x: x[1], reverse=True)
            top_patterns = [p[0] for p in sorted_patterns[:3]]
            if top_patterns:
                extra_seeds.extend(top_patterns)
                # print(f"+ {len(top_patterns)} Architectures")


        # 1.5 Flattening Transformation
        # y_flat = log(|y| + epsilon) - lgamma(x + 1)
        # The epsilon shift keeps the log finite for y = 0 (the full series
        # contains zeros). The same transform is applied again for refinement
        # below, and it is inverted with exp() during reconstruction.
        epsilon = 1e-9
        
        # Calculate lgamma(n+1) which is log(n!)
        # NOTE(review): this import runs on every loop iteration; it could be
        # moved to module level.
        from scipy.special import gammaln
        factorial_term = gammaln(x_sample + 1)
        
        # Transform target
        # Use abs(y) just in case, though they are positive counts usually.
        y_sample_flat = np.log(np.abs(y_sample) + epsilon) - factorial_term
        
        # Debug hook (currently disabled).
        if iteration == 1:
            pass # print(f"Flat Y: {y_sample_flat[:3]}")

        # 2. Search (on FLATTENED target)
        try:
            # We use a relatively small beam width for speed, relying on many iterations
            result = hybrid_solve(
                x_sample, y_sample_flat,  # PASS FLAT Y
                MODEL, DEVICE, 
                beam_width=10, 
                gp_timeout=120, 
                max_workers=4, # Use 4 parallel workers (C++ Engine)
                num_variables=1,
                extra_seeds=extra_seeds
            )
        except Exception as e:
            print(f"Search failed: {e}")
            continue
            
        if not result or not result.get('formula'):
            print("No formula found in this iteration.")
            continue
            
        residual_formula_str = result['formula']
        # print(f"Found residual candidate: {residual_formula_str}")
        
        # 2.5 INTELLIGENT REFINEMENT (BFGS) on Residual
        # NOTE(review): final_formula_str is assigned but never read afterwards.
        final_formula_str = residual_formula_str # Default if refinement fails
        
        try:
            # Parse residual tree
            tree = ExpressionTree.from_infix(residual_formula_str)
            if tree.is_valid:
                # 1. Convert hardcoded numbers to 'C' and get initial values
                initial_values = convert_and_extract_constants(tree.root)
                
                # Refine on ALL data (1-27) but FLATTENED
                # NOTE(review): this includes the extrapolation targets
                # (x = 26, 27), so the extrapolation error reported later is
                # not measured on held-out points.
                x_all = np.concatenate((X_FULL, X_TARGETS))
                y_all = np.concatenate((Y_FULL, Y_TARGETS))
                
                factorial_term_all = gammaln(x_all + 1)
                y_all_flat = np.log(np.abs(y_all) + epsilon) - factorial_term_all
                
                if initial_values:
                    # print(f"Refining {len(initial_values)} constants on FLAT surface...")
                    
                    # optimization expects C-tree
                    constants_dict, rmse = optimize_constants(tree, x_all, y_all_flat, initial_guess=initial_values)
                    
                    if constants_dict:
                         # 3. Substitute back into residual
                         positions = tree.root.get_constant_positions()
                         infix_with_Cs = tree.get_infix() 
                         refined_residual = substitute_constants(infix_with_Cs, constants_dict, positions)
                         
                         # print(f"Refined Residual: {refined_residual}")
                         residual_formula_str = refined_residual
                         
        except Exception as e:
            # Refinement is optional: on failure the unrefined residual is used.
            print(f"Refinement failed: {e}") 
        
        # 3. RECONSTRUCTION & Transformation
        # Formula = exp( Residual + lgamma(x+1) )
        # Per the original author's note, lgamma in ExpressionTree adds +1
        # internally (lgamma(|x|+1)), so lgamma(x0) stands for lgamma(x+1)
        # mathematically.
        # NOTE(review): the search above is run with num_variables=1 while the
        # reconstruction names the variable 'x0' -- confirm the residual
        # formula uses the same variable name.
        full_formula_str = f"exp({residual_formula_str} + lgamma(x0))"
        # print(f"Reconstructed Full Formula: {full_formula_str}")

        # 4. Evaluate on FULL RANGE (History + Targets) w/ Reconstructed Formula
        try:
            tree = ExpressionTree.from_infix(full_formula_str)
            if not tree.is_valid:
                print("Invalid reconstructed tree.")
                # Evaluating the residual alone would be on the wrong scale,
                # so the iteration is abandoned.
                continue
            
            # Combine all points for validation
            x_all = np.concatenate((X_FULL, X_TARGETS))
            y_all = np.concatenate((Y_FULL, Y_TARGETS))
            
            y_pred_all = tree.evaluate(x_all)
            
            # Calculate RMSLE on ORIGINAL SPACE
            y_pred_safe = np.maximum(y_pred_all, 0) # Clip negative predictions
            
            # NOTE(review): an inf/NaN prediction is not handled here and would
            # flow into log_error.
            
            log_error = np.sqrt(np.mean((np.log1p(y_pred_safe) - np.log1p(y_all))**2))
            
            # Also calculate Extrapolation Absolute Error
            # (the last two entries of x_all are X_TARGETS = 26, 27)
            pred_26 = y_pred_all[-2]
            pred_27 = y_pred_all[-1]
            extrap_error_sum = abs(pred_26 - Y_TARGETS[0]) + abs(pred_27 - Y_TARGETS[1])
            
            time_taken = result.get('time', 0)
            print(f"\n[SUCCESS] Formula: {full_formula_str}\n          RMSLE (1-27): {log_error:.6f} | Extrap Error: {extrap_error_sum:.2e} | Time: {time_taken:.2f}s")
            
            # 5. Update Top List
            entry = {
                'formula': full_formula_str, # Store the FULL formula
                'residual': residual_formula_str, # Residual kept for reference
                'rmsle_global': log_error, 
                'extrapolation_error': extrap_error_sum, 
                'pred_26': pred_26,
                'pred_27': pred_27,
                'sample_size': k,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
            }
            
            # Add to list
            top_formulas.append(entry)
            
            # Deduplicate by formula (the most recently added entry wins)
            unique_formulas = {d['formula']: d for d in top_formulas}
            top_formulas = list(unique_formulas.values())
            
            # Sort by RMSLE (entries without the key sort last)
            top_formulas.sort(key=lambda x: x.get('rmsle_global', float('inf')))
            
            # Keep only the best TOP_K entries
            if len(top_formulas) > TOP_K:
                top_formulas = top_formulas[:TOP_K]
            
            # Save
            save_top_list(top_formulas)
            
            # Record the skeleton of the formula evaluated in this iteration
            if update_pattern_memory(pattern_memory, full_formula_str):
                 save_pattern_memory(pattern_memory)
                 backup_to_drive() # Also backup when pattern memory changes

            
            current_best = top_formulas[0].get('rmsle_global', 999)
            print(f"Current Best RMSLE: {current_best:.6f}")
            
        except Exception as e:
            print(f"Evaluation failed: {e}")
            continue

# Script entry point: start the infinite search loop when executed directly.
if __name__ == "__main__":
    main()


In [None]:
%%writefile AlphaSymbolic/generate_report.py
import pandas as pd
import numpy as np
import sys
import os

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.grammar import ExpressionTree

# --- DATA ---
# X_FULL holds the integer inputs 1..27; Y_REAL[i] is the reference value for X_FULL[i].
# The reference values are the N-Queens solution counts for n = 1..27 (OEIS A000170).
# Entries above 2**53 (the last three) are only approximately representable
# once stored as float64.
X_FULL = np.arange(1, 28) # 1 to 27
Y_REAL = np.array([
    1, 0, 0, 2, 10, 4, 40, 92, 352, 724, 2680, 14200, 73712, 365596, 2279184,
    14772512, 95815104, 666090624, 4968057848, 39029188884, 314666222712,
    # n = 23 was previously 2423393768440 (a digit was dropped), which made the
    # series dip below the n = 22 value and skewed every error statistic.
    2691008701644, 24233937684440, 227514171973736, 2207893435808352,
    22317699616364044, 234907967154122528
], dtype=np.float64)

# Input: ranked formulas (one per row, column 'formula'). Output: per-point report.
INPUT_CSV = "top_formulas.csv"
OUTPUT_CSV = "top_5_detailed_report.csv"

def evaluate_formula(formula_str, x_vals):
    """Evaluate an infix formula string on ``x_vals``.

    The formula is parsed with ``ExpressionTree.from_infix``. When the parsed
    tree reports itself as invalid, or when parsing/evaluation raises any
    exception, a message is printed and an array of zeros shaped like
    ``x_vals`` is returned instead of propagating the failure.
    """
    try:
        parsed = ExpressionTree.from_infix(formula_str)
        if parsed.is_valid:
            # evaluate() is expected to accept the whole input array at once.
            return parsed.evaluate(x_vals)
        print(f"Invalid Formula: {formula_str}")
    except Exception as err:
        print(f"Error evaluating {formula_str}: {err}")
    # Shared fallback for both the invalid-tree and the exception path.
    return np.zeros_like(x_vals)

def main():
    """Write a per-point error report for the first five formulas in INPUT_CSV.

    For each formula the report holds its rank, the formula text, the
    prediction and absolute percentage error at every x in X_FULL, and the
    mean percentage error over the points whose reference value is non-zero.
    The result is saved to OUTPUT_CSV. Nothing is written when the input file
    is missing or has no 'formula' column; an input without data rows yields
    a header-only report.
    """
    print(f"Generating report from {INPUT_CSV}...")

    if not os.path.exists(INPUT_CSV):
        print("Input file not found.")
        return

    # Load Top 5
    df_in = pd.read_csv(INPUT_CSV)
    if 'formula' not in df_in.columns:
        print("Input file has no 'formula' column.")
        return
    top_5 = df_in.head(5).copy()

    # Column order of the report: Rank, Formula, Mean_Err%, then X_1_Pred, X_1_Err%, etc.
    cols = ['Rank', 'Formula', 'Mean_Err%']
    for x in X_FULL:
        cols.append(f'X_{x}_Pred')
        cols.append(f'X_{x}_Err%')

    report_rows = []

    for idx, row in top_5.iterrows():
        formula = row['formula']
        print(f"Processing #{idx+1}: {formula[:30]}...")

        y_pred = evaluate_formula(formula, X_FULL)

        new_row = {'Rank': idx + 1, 'Formula': formula}

        total_mape = 0
        count = 0

        for i, x in enumerate(X_FULL):
            real = Y_REAL[i]
            pred = y_pred[i]

            if real == 0:
                # A percentage error is undefined when the reference is 0:
                # report 0% for an (almost) exact hit and NaN otherwise.
                # These points are left out of the mean.
                err_pct = 0.0 if abs(pred) < 1e-9 else np.nan
            else:
                err_pct = abs((pred - real) / real) * 100
                total_mape += err_pct
                count += 1

            new_row[f'X_{x}_Pred'] = pred
            new_row[f'X_{x}_Err%'] = err_pct

        new_row['Mean_Err%'] = total_mape / count if count > 0 else 0
        report_rows.append(new_row)

    # Passing `columns` fixes the column order and keeps the frame well-formed
    # when there are no rows (selecting the columns afterwards raised KeyError
    # on an empty frame).
    df_out = pd.DataFrame(report_rows, columns=cols)

    df_out.to_csv(OUTPUT_CSV, index=False)
    print(f"Report saved to {OUTPUT_CSV}")

if __name__ == "__main__":
    main()


In [None]:
%%writefile AlphaSymbolic/app.py
"""
AlphaSymbolic - Gradio Web Interface
With GPU/CPU toggle and search method selection.
"""
import gradio as gr
import torch

from ui.app_core import load_model, get_device, get_device_info, set_device, get_training_errors, request_stop_training
from ui.app_training import train_basic, train_curriculum, train_self_play, train_supervised, train_hybrid_feedback_loop, train_from_memory
from ui.app_search import solve_formula, generate_example
from ui.app_benchmark import get_benchmark_tab
from ui.theme import get_theme, CUSTOM_CSS
import pandas as pd
import io


def toggle_device(use_gpu):
    """Switch the compute device and return an HTML badge describing it.

    ``use_gpu`` is forwarded to ``set_device``; the description it returns is
    shown in green when it mentions CUDA, amber when it mentions MPS and grey
    otherwise.
    """
    label = set_device(use_gpu)
    if "CUDA" in label:
        accent = "#4ade80"
    elif "MPS" in label:
        accent = "#fbbf24"
    else:
        accent = "#888"
    return (
        f'<div style="padding: 10px; background: #0f0f23; border-radius: 8px; border-left: 3px solid {accent};">'
        f'<span style="color: {accent}; font-weight: bold;">{label}</span></div>'
    )


def create_app():
    """Build and return the Gradio ``Blocks`` application.

    Layout, top to bottom: a header with the current device, a model preset
    selector, and a tab set containing the formula search tab, the training
    tab (basic / curriculum / self-play / feedback-loop sub-tabs plus
    pre-training, memory training and error-analysis accordions), the
    benchmark tab provided by ``get_benchmark_tab`` and an information tab.
    The app is only constructed here; launching it is left to the caller.
    """
    with gr.Blocks(title="AlphaSymbolic") as demo:

        def load_csv_data(file_obj):
            """Load a CSV file into the X/Y text inputs.

            The last column is treated as Y and all preceding columns as X.
            Returns ``(x_text, y_text)``. On failure the first element is
            ``None`` and the second carries the error message, so the message
            is shown in the Y textbox.
            """
            if file_obj is None:
                return None, None

            # Accept both a plain path string and an object exposing `.name`.
            path = getattr(file_obj, "name", file_obj)

            try:
                # auto-detect separator, falling back to the default parser
                try:
                    df = pd.read_csv(path, sep=None, engine='python')
                except Exception:
                    df = pd.read_csv(path)

                if df.shape[1] < 2:
                    return None, "Error: El archivo debe tener al menos 2 columnas (X..., Y)"

                # Assume last column is Y, rest are X
                X = df.iloc[:, :-1].values
                y = df.iloc[:, -1].values

                # Format X string
                # Single feature: "1, 2, 3"
                # Several features: one sample per line, values separated by spaces
                if X.shape[1] == 1:
                    x_str = ", ".join(map(str, X.flatten()))
                else:
                    lines = [" ".join(map(str, row)) for row in X]
                    x_str = "\n".join(lines)

                y_str = ", ".join(map(str, y.flatten()))

                return x_str, y_str
            except Exception as e:
                return None, f"Error leyendo CSV: {str(e)}"

        # Header
        # device_info / device_color are reused by the training tab header below.
        device_info = get_device_info()
        device_color = "#22d3ee" if "CUDA" in device_info else "#fbbf24"
        gpu_short = device_info.replace('NVIDIA GeForce ', '').replace(' Laptop GPU', '').replace('CUDA (', '').replace(')', '')

        gr.HTML(f"""
        <div style="display: flex; justify-content: space-between; align-items: center; padding: 20px 30px; background: linear-gradient(135deg, rgba(15, 23, 42, 0.9), rgba(30, 41, 59, 0.9)); border-radius: 16px; margin-bottom: 15px; border: 1px solid rgba(6, 182, 212, 0.2); box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);">
            <div>
                <h1 style="margin: 0; font-size: 2.5rem; font-weight: 800; background: linear-gradient(135deg, #06b6d4, #8b5cf6); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-family: 'Orbitron', sans-serif; letter-spacing: 2px;">
                    αSymbolic
                </h1>
                <p style="margin: 5px 0 0 0; color: #64748b; font-size: 0.9rem;">
                    Deep Reinforcement Learning & Symbolic Regression
                </p>
            </div>
            <div style="display: flex; align-items: center; gap: 15px;">
                <div style="text-align: right;">
                    <div style="background: {'rgba(34, 197, 94, 0.15)' if 'CUDA' in device_info else 'rgba(251, 191, 36, 0.15)'}; color: {'#22c55e' if 'CUDA' in device_info else '#fbbf24'}; padding: 8px 16px; border-radius: 25px; font-weight: 600; font-size: 0.85rem; border: 1px solid {'rgba(34, 197, 94, 0.3)' if 'CUDA' in device_info else 'rgba(251, 191, 36, 0.3)'};">
                        {'⚡ GPU' if 'CUDA' in device_info else '💻 CPU'} | {gpu_short}
                    </div>
                </div>
            </div>
        </div>
        """)

        # Model Selector - Compact inline
        with gr.Row():
            with gr.Column(scale=1):
                model_selector = gr.Radio(choices=["lite", "pro"], value="lite", label="Modelo", container=False)
            with gr.Column(scale=4):
                model_status = gr.HTML(value='<div style="padding: 8px 15px; background: rgba(34, 197, 94, 0.1); border-radius: 8px; color: #22c55e; font-size: 0.85rem; border: 1px solid rgba(34, 197, 94, 0.2);">✓ Lite Model (Optimized) - Vocabulary 2.0</div>')

        def on_model_change(preset):
            """Load the selected preset and return its status as an HTML badge."""
            status, _ = load_model(preset_name=preset)
            return f'<div style="padding: 8px 15px; background: rgba(34, 197, 94, 0.1); border-radius: 8px; color: #22c55e; font-size: 0.85rem; border: 1px solid rgba(34, 197, 94, 0.2);">✓ {status}</div>'

        model_selector.change(on_model_change, model_selector, model_status)

        with gr.Tabs():
            # TAB 1: Search
            with gr.Tab("🔍 Buscar Formula"):
                with gr.Row():
                    # Column 1: Inputs + Config
                    with gr.Column(scale=1, min_width=400):
                        gr.Markdown("## Entrada")
                        x_input = gr.Textbox(label="Features (X)", placeholder="1, 2, 3...", lines=3)
                        y_input = gr.Textbox(label="Target (Y)", placeholder="2, 4, 6...", lines=3)

                        with gr.Accordion("📁 Cargar desde CSV", open=False):
                            file_upload = gr.File(label="Seleccionar archivo", file_types=[".csv", ".txt"], file_count="single")
                            file_upload.change(load_csv_data, inputs=[file_upload], outputs=[x_input, y_input])

                        with gr.Row():
                            gr.Button("Lineal", size="sm").click(lambda: generate_example("lineal"), outputs=[x_input, y_input])
                            gr.Button("Cuad", size="sm").click(lambda: generate_example("cuadratico"), outputs=[x_input, y_input])
                            gr.Button("Trig", size="sm").click(lambda: generate_example("trig"), outputs=[x_input, y_input])
                            gr.Button("Exp", size="sm").click(lambda: generate_example("exp"), outputs=[x_input, y_input])

                        gr.Markdown("---")
                        search_method = gr.Radio(
                            choices=["Beam Search", "MCTS", "Alpha-GP Hybrid"],
                            value="Alpha-GP Hybrid",
                            label="Algoritmo de Búsqueda"
                        )
                        beam_slider = gr.Slider(5, 500, value=50, step=5, label="Intensidad (Beam Width)")
                        # Also passed to the feedback-loop trainer in the training tab.
                        workers_slider = gr.Slider(1, 16, value=6, step=1, label="Workers (Paralelismo)", info="Procesos para el motor GP")
                        solve_btn = gr.Button("🚀 BUSCAR FÓRMULA", variant="primary", size="lg", elem_classes="primary-btn")

                        with gr.Accordion("Tabla de Predicciones", open=False):
                            pred_html = gr.HTML(label="Predicciones")

                    # Column 2: Results + Visualization
                    with gr.Column(scale=1, min_width=400):
                        gr.Markdown("## Resultados")
                        result_html = gr.HTML(label="Fórmula Encontrada")
                        plot_output = gr.Plot(label="Visualización del Ajuste")
                        alt_html = gr.HTML(label="Alternativas")

                # Hidden sink for the fifth output of solve_formula.
                raw_formula = gr.Textbox(visible=False)
                solve_btn.click(solve_formula, [x_input, y_input, beam_slider, search_method, workers_slider],
                               [result_html, plot_output, pred_html, alt_html, raw_formula])

            # TAB 2: Training
            with gr.Tab("Entrenar Modelo"):
                # Training Control Panel - Compact Header
                gr.HTML(f"""
                <div style="display: flex; justify-content: space-between; align-items: center; padding: 15px 20px; background: linear-gradient(135deg, rgba(6, 182, 212, 0.1), rgba(139, 92, 246, 0.1)); border-radius: 12px; margin-bottom: 10px; border: 1px solid rgba(6, 182, 212, 0.3);">
                    <div style="display: flex; align-items: center; gap: 15px;">
                        <span style="font-size: 1.5rem;"> </span>
                        <div>
                            <h3 style="margin: 0; color: #e2e8f0; font-size: 1.1rem;">Centro de Entrenamiento</h3>
                            <span style="color: #64748b; font-size: 0.8rem;">Gestiona el aprendizaje del modelo</span>
                        </div>
                    </div>
                    <div style="display: flex; align-items: center; gap: 20px;">
                        <div style="text-align: center;">
                            <span style="background: rgba(34, 197, 94, 0.2); color: #22c55e; padding: 4px 12px; border-radius: 20px; font-size: 0.75rem; font-weight: 600;">
                                {'🟢 GPU' if torch.cuda.is_available() else '🟡 CPU'}
                            </span>
                            <div style="color: {device_color}; font-size: 0.7rem; margin-top: 4px;">{device_info.replace('NVIDIA GeForce ', '').replace(' Laptop GPU', '')}</div>
                        </div>
                    </div>
                </div>
                """)

                with gr.Row():
                    # The GPU checkbox and its display are created hidden.
                    use_gpu = gr.Checkbox(label="Usar GPU", value=torch.cuda.is_available(), visible=False)
                    device_display = gr.HTML(visible=False)
                    use_gpu.change(toggle_device, [use_gpu], [device_display])
                    stop_train_btn = gr.Button("Detener Todo", variant="stop", size="sm", scale=1)
                    delete_model_btn = gr.Button("Reset Pesos", variant="secondary", size="sm", scale=1)

                stop_status = gr.HTML(visible=False)
                delete_status = gr.HTML(visible=False)

                def delete_model_action():
                    """Delete the weights file of the currently selected preset, if present."""
                    import os
                    # Imported at call time so the preset active at click time is used.
                    from ui.app_core import CURRENT_PRESET
                    filename = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
                    if os.path.exists(filename):
                        os.remove(filename)
                        return "Modelo Eliminado"
                    return "No existe modelo"

                # Global Training Config
                with gr.Row():
                    reset_state_btn = gr.Button("⚠️ Reset Estado", variant="secondary", size="sm")

                def reset_training_state():
                    """Clear the 'running' flag of the shared training status."""
                    from ui.app_training import TRAINING_STATUS
                    TRAINING_STATUS["running"] = False
                    return "Estado reseteado. Intenta entrenar de nuevo."

                reset_state_btn.click(reset_training_state, outputs=[stop_status])
                delete_model_btn.click(delete_model_action, outputs=[delete_status])
                stop_train_btn.click(request_stop_training, outputs=[stop_status])

                with gr.Tabs():
                    # Basic
                    with gr.Tab("Basico"):
                        gr.HTML('<p style="color: #888;">Entrenamiento rapido con datos sinteticos</p>')
                        with gr.Row():
                            with gr.Column():
                                epochs_basic = gr.Slider(10, 500, value=100, step=10, label="Epocas")
                                batch_basic = gr.Slider(16, 128, value=32, step=16, label="Batch Size")
                                points_basic = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_basic_btn = gr.Button("Entrenar Basico", variant="primary")
                            with gr.Column():
                                result_basic = gr.HTML()
                                plot_basic = gr.Plot()
                        train_basic_btn.click(train_basic, [epochs_basic, batch_basic, points_basic], [result_basic, plot_basic])

                    # Curriculum
                    with gr.Tab("Curriculum"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
                            <p style="color: #00d4ff; margin: 0;"><strong>Curriculum Learning</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">Empieza con formulas simples y aumenta la dificultad.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                epochs_curriculum = gr.Slider(50, 2000, value=200, step=50, label="Epocas")
                                batch_curriculum = gr.Slider(16, 128, value=64, step=16, label="Batch Size")
                                points_curriculum = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_curriculum_btn = gr.Button("Entrenar Curriculum", variant="primary")
                            with gr.Column():
                                result_curriculum = gr.HTML()
                                plot_curriculum = gr.Plot()
                        train_curriculum_btn.click(train_curriculum, [epochs_curriculum, batch_curriculum, points_curriculum], [result_curriculum, plot_curriculum])

                    # Self-Play
                    with gr.Tab("Self-Play"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px; border-left: 3px solid #ff6b6b;">
                            <p style="color: #ff6b6b; margin: 0;"><strong>AlphaZero Self-Play</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">El modelo resuelve problemas y aprende de sus exitos.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                iterations_sp = gr.Slider(10, 1000, value=100, step=10, label="Iteraciones")
                                problems_sp = gr.Slider(5, 200, value=10, step=5, label="Problemas/Iter")
                                points_sp = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_sp_btn = gr.Button("Iniciar Self-Play", variant="primary", elem_classes="primary-btn")
                            with gr.Column():
                                result_sp = gr.HTML()
                                plot_sp = gr.Plot()
                        train_sp_btn.click(train_self_play, [iterations_sp, problems_sp, points_sp], [result_sp, plot_sp])

                    # Feedback Loop (Teacher-Student)
                    with gr.Tab("Feedback Loop (Hybrid)"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px; border-left: 3px solid #f1c40f;">
                            <p style="color: #f1c40f; margin: 0;"><strong>Teacher-Student Feedback Loop</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">El modelo (Estudiante) intenta resolver problemas. Si falla, el Alpha-GP (Maestro) interviene y añade la solución al dataset.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                iterations_fb = gr.Slider(5, 500, value=20, step=5, label="Ciclos")
                                problems_fb = gr.Slider(5, 50, value=10, step=5, label="Problemas Difíciles / Ciclo")
                                timeout_fb = gr.Slider(5, 30, value=10, step=5, label="Timeout Maestro (s)")
                                train_fb_btn = gr.Button("Iniciar Feedback Loop", variant="primary")
                            with gr.Column():
                                result_fb = gr.HTML()
                                plot_fb = gr.Plot()
                        train_fb_btn.click(train_hybrid_feedback_loop, [iterations_fb, problems_fb, timeout_fb, workers_slider], [result_fb, plot_fb])

                # --- PRE-TRAINING (Warmup) ---
                with gr.Accordion("🎓 Escuela Primaria (Pre-Entrenamiento)", open=False):
                    gr.Markdown("Entrenamiento masivo supervisado de alta velocidad para aprender sintaxis basica. **Recomendado al inicio.**")
                    with gr.Row():
                        with gr.Column():
                            epochs_pre = gr.Slider(100, 10000, value=2000, step=100, label="Iteraciones Rápidas")
                            train_pre_btn = gr.Button("Iniciar Pre-Entrenamiento", variant="primary", elem_classes="primary-btn")
                        with gr.Column():
                            result_pre = gr.HTML()
                            plot_pre = gr.Plot()
                    train_pre_btn.click(train_supervised, [epochs_pre], [result_pre, plot_pre])

                # --- MEMORY TRAINING (Offline RL) ---
                with gr.Accordion("🧠 Entrenamiento de Memoria (Offline)", open=False):
                    gr.Markdown("Re-entrena el modelo usando las fórmulas descubiertas y guardadas en `learned_formulas.csv`. Ideal para consolidar conocimientos.")
                    with gr.Row():
                        with gr.Column():
                            epochs_mem = gr.Slider(10, 500, value=50, step=10, label="Epocas de Repaso")
                            train_mem_btn = gr.Button("Iniciar Entrenamiento de Memoria", variant="primary")
                        with gr.Column():
                            result_mem = gr.HTML()
                            plot_mem = gr.Plot()
                    train_mem_btn.click(train_from_memory, [epochs_mem], [result_mem, plot_mem])

                # --- HALL OF SHAME (Error Analysis) ---
                with gr.Accordion("Hall of Shame (Analisis de Errores)", open=False):
                    gr.Markdown("Aquí se muestran los problemas donde el modelo falló drásticamente hoy.")
                    error_table = gr.DataFrame(
                        headers=["Time", "Target Formula", "Predicted", "Loss", "Stage"],
                        datatype=["str", "str", "str", "number", "str"],
                        interactive=False
                    )
                    refresh_errors_btn = gr.Button("Actualizar Errores", size="sm")

                    def update_errors():
                        """Return the recorded training errors as table rows, newest first."""
                        errors = get_training_errors()
                        # Reverse to show newest first
                        data = [[
                            e['time'], e['target'], e['predicted'], round(e['loss'], 2), e['stage']
                        ] for e in reversed(errors)]
                        return data

                    refresh_errors_btn.click(update_errors, outputs=[error_table])

            # TAB 3: Benchmark
            get_benchmark_tab()

            # TAB 4: Info
            with gr.Tab("Informacion"):
                device_info_current = get_device_info()
                device_color_current = "#4ade80" if "CUDA" in device_info_current else "#fbbf24" if "MPS" in device_info_current else "#888"

                gr.HTML(f"""
                <div style="background: #1a1a2e; padding: 30px; border-radius: 15px;">
                    <h2 style="color: #00d4ff;">Que es AlphaSymbolic?</h2>
                    <p style="color: #ccc; line-height: 1.8;">
                        Sistema de <strong style="color: #ff6b6b;">regresion simbolica</strong> 
                        basado en <strong style="color: #00d4ff;">Deep Learning</strong> y 
                        <strong style="color: #ffd93d;">Monte Carlo Tree Search</strong>.
                    </p>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Dispositivo Actual</h3>
                    <p style="color: {device_color_current}; font-size: 20px;">{device_info_current}</p>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Metodos de Busqueda</h3>
                    <ul style="color: #ccc;">
                        <li><strong>Beam Search:</strong> Explora multiples candidatos en paralelo (rapido)</li>
                        <li><strong>MCTS:</strong> Monte Carlo Tree Search (mas preciso, lento)</li>
                        <li><strong>Alpha-GP Hybrid:</strong> Fusiona Neural Search con Algoritmo Genetico GPU (Extremo)</li>
                    </ul>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Operadores</h3>
                    <div style="display: flex; flex-wrap: wrap; gap: 10px; margin: 15px 0;">
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">+</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">-</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">*</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">/</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ff6b6b;">sin</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ff6b6b;">cos</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ffd93d;">exp</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ffd93d;">log</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #4ade80;">pow</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #4ade80;">sqrt</span>
                    </div>
                </div>
                """)

        gr.HTML("""
        <div style="text-align: center; padding: 20px; color: #666; margin-top: 30px;">
            <p>Powered by PyTorch - SymPy - Scipy - Gradio</p>
        </div>
        """)

    return demo




# --- Global Initialization for Hot Reloading ---
# `demo` must exist at module level in both cases below:
#   * run directly (`python app.py`): load the model, build the UI, launch it;
#   * imported (e.g. by `gradio app.py`, or re-imported by multiprocessing
#     workers): build the UI but do not launch, and treat model loading as
#     best-effort because it may fail inside worker processes.

if __name__ == "__main__":
    # If run directly (python app.py)
    print("Iniciando AlphaSymbolic (Global Init - Direct Execution)...")
    status_init, device_info_init = load_model()
    print(f"   {status_init} | {device_info_init}")
    demo = create_app()
    print("Abriendo navegador...")
    # get_theme / CUSTOM_CSS come from the `ui.theme` import at the top of this file.
    demo.launch(share=True, inbrowser=True, theme=get_theme(), css=CUSTOM_CSS)
else:
    # If imported by 'gradio app.py' or multiprocessing workers
    print("AlphaSymbolic Module Imported.")

    try:
        status_init, device_info_init = load_model()
        print(f"   {status_init} | {device_info_init}")
    except Exception as exc:
        # Still non-fatal, but report the reason instead of hiding it.
        print(f"   Model not loaded at import time: {exc}")

    demo = create_app()


In [None]:
%%writefile AlphaSymbolic/data/benchmarks/FeynmanEquations.csv
﻿Filename,Number,Output,Formula,# variables,v1_name,v1_low,v1_high,v2_name,v2_low,v2_high,v3_name,v3_low,v3_high,v4_name,v4_low,v4_high,v5_name,v5_low,v5_high,v6_name,v6_low,v6_high,v7_name,v7_low,v7_high,v8_name,v8_low,v8_high,v9_name,v9_low,v9_high,v10_name,v10_low,v10_high
I.6.2a,1,f,exp(-theta**2/2)/sqrt(2*pi),1,theta,1,3,,,,,,,,,,,,,,,,,,,,,,,,,,,
I.6.2,2,f,exp(-(theta/sigma)**2/2)/(sqrt(2*pi)*sigma),2,sigma,1,3,theta,1,3,,,,,,,,,,,,,,,,,,,,,,,,
I.6.2b,3,f,exp(-((theta-theta1)/sigma)**2/2)/(sqrt(2*pi)*sigma),3,sigma,1,3,theta,1,3,theta1,1,3,,,,,,,,,,,,,,,,,,,,,
I.8.14,4,d,sqrt((x2-x1)**2+(y2-y1)**2),4,x1,1,5,x2,1,5,y1,1,5,y2,1,5,,,,,,,,,,,,,,,,,,
I.9.18,5,F,G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2),9,m1,1,2,m2,1,2,G,1,2,x1,3,4,x2,1,2,y1,3,4,y2,1,2,z1,3,4,z2,1,2,,,
I.10.7,6,m,m_0/sqrt(1-v**2/c**2),3,m_0,1,5,v,1,2,c,3,10,,,,,,,,,,,,,,,,,,,,,
I.11.19,7,A,x1*y1+x2*y2+x3*y3,6,x1,1,5,x2,1,5,x3,1,5,y1,1,5,y2,1,5,y3,1,5,,,,,,,,,,,,
I.12.1,8,F,mu*Nn,2,mu,1,5,Nn,1,5,,,,,,,,,,,,,,,,,,,,,,,,
I.12.2,10,F,q1*q2*r/(4*pi*epsilon*r**3),4,q1,1,5,q2,1,5,epsilon,1,5,r,1,5,,,,,,,,,,,,,,,,,,
I.12.4,11,Ef,q1*r/(4*pi*epsilon*r**3),3,q1,1,5,epsilon,1,5,r,1,5,,,,,,,,,,,,,,,,,,,,,
I.12.5,12,F,q2*Ef,2,q2,1,5,Ef,1,5,,,,,,,,,,,,,,,,,,,,,,,,
I.12.11,13,F,q*(Ef+B*v*sin(theta)),5,q,1,5,Ef,1,5,B,1,5,v,1,5,theta,1,5,,,,,,,,,,,,,,,
I.13.4,9,K,1/2*m*(v**2+u**2+w**2),4,m,1,5,v,1,5,u,1,5,w,1,5,,,,,,,,,,,,,,,,,,
I.13.12,14,U,G*m1*m2*(1/r2-1/r1),5,m1,1,5,m2,1,5,r1,1,5,r2,1,5,G,1,5,,,,,,,,,,,,,,,
I.14.3,15,U,m*g*z,3,m,1,5,g,1,5,z,1,5,,,,,,,,,,,,,,,,,,,,,
I.14.4,16,U,1/2*k_spring*x**2,2,k_spring,1,5,x,1,5,,,,,,,,,,,,,,,,,,,,,,,,
I.15.3x,17,x1,(x-u*t)/sqrt(1-u**2/c**2),4,x,5,10,u,1,2,c,3,20,t,1,2,,,,,,,,,,,,,,,,,,
I.15.3t,18,t1,(t-u*x/c**2)/sqrt(1-u**2/c**2),4,x,1,5,c,3,10,u,1,2,t,1,5,,,,,,,,,,,,,,,,,,
I.15.1,19,p,m_0*v/sqrt(1-v**2/c**2),3,m_0,1,5,v,1,2,c,3,10,,,,,,,,,,,,,,,,,,,,,
I.16.6,20,v1,(u+v)/(1+u*v/c**2),3,c,1,5,v,1,5,u,1,5,,,,,,,,,,,,,,,,,,,,,
I.18.4,21,r,(m1*r1+m2*r2)/(m1+m2),4,m1,1,5,m2,1,5,r1,1,5,r2,1,5,,,,,,,,,,,,,,,,,,
I.18.12,22,tau,r*F*sin(theta),3,r,1,5,F,1,5,theta,0,5,,,,,,,,,,,,,,,,,,,,,
I.18.14,23,L,m*r*v*sin(theta),3,m,1,5,r,1,5,v,1,5,theta,1,5,,,,,,,,,,,,,,,,,,
I.24.6,24,E_n,1/2*m*(omega**2+omega_0**2)*1/2*x**2,4,m,1,3,omega,1,3,omega_0,1,3,x,1,3,,,,,,,,,,,,,,,,,,
I.25.13,25,Volt,q/C,2,q,1,5,C,1,5,,,,,,,,,,,,,,,,,,,,,,,,
I.26.2,26,theta1,arcsin(n*sin(theta2)),2,n,0,1,theta2,1,5,,,,,,,,,,,,,,,,,,,,,,,,
I.27.6,27,foc,1/(1/d1+n/d2),3,d1,1,5,d2,1,5,n,1,5,,,,,,,,,,,,,,,,,,,,,
I.29.4,28,k,omega/c,2,omega,1,10,c,1,10,,,,,,,,,,,,,,,,,,,,,,,,
I.29.16,29,x,sqrt(x1**2+x2**2-2*x1*x2*cos(theta1-theta2)),4,x1,1,5,x2,1,5,theta1,1,5,theta2,1,5,,,,,,,,,,,,,,,,,,
I.30.3,30,Int,Int_0*sin(n*theta/2)**2/sin(theta/2)**2,3,Int_0,1,5,theta,1,5,n,1,5,,,,,,,,,,,,,,,,,,,,,
I.30.5,31,theta,arcsin(lambd/(n*d)),3,lambd,1,2,d,2,5,n,1,5,,,,,,,,,,,,,,,,,,,,,
I.32.5,32,Pwr,q**2*a**2/(6*pi*epsilon*c**3),4,q,1,5,a,1,5,epsilon,1,5,c,1,5,,,,,,,,,,,,,,,,,,
I.32.17,33,Pwr,(1/2*epsilon*c*Ef**2)*(8*pi*r**2/3)*(omega**4/(omega**2-omega_0**2)**2),6,epsilon,1,2,c,1,2,Ef,1,2,r,1,2,omega,1,2,omega_0,3,5,,,,,,,,,,,,
I.34.8,34,omega,q*v*B/p,4,q,1,5,v,1,5,B,1,5,p,1,5,,,,,,,,,,,,,,,,,,
I.34.1,35,omega,omega_0/(1-v/c),3,c,3,10,v,1,2,omega_0,1,5,,,,,,,,,,,,,,,,,,,,,
I.34.14,36,omega,(1+v/c)/sqrt(1-v**2/c**2)*omega_0,3,c,3,10,v,1,2,omega_0,1,5,,,,,,,,,,,,,,,,,,,,,
I.34.27,37,E_n,(h/(2*pi))*omega,2,omega,1,5,h,1,5,,,,,,,,,,,,,,,,,,,,,,,,
I.37.4,38,Int,I1+I2+2*sqrt(I1*I2)*cos(delta),3,I1,1,5,I2,1,5,delta,1,5,,,,,,,,,,,,,,,,,,,,,
I.38.12,39,r,4*pi*epsilon*(h/(2*pi))**2/(m*q**2),4,m,1,5,q,1,5,h,1,5,epsilon,1,5,,,,,,,,,,,,,,,,,,
I.39.1,40,E_n,3/2*pr*V,2,pr,1,5,V,1,5,,,,,,,,,,,,,,,,,,,,,,,,
I.39.11,41,E_n,1/(gamma-1)*pr*V,3,gamma,2,5,pr,1,5,V,1,5,,,,,,,,,,,,,,,,,,,,,
I.39.22,42,pr,n*kb*T/V,4,n,1,5,T,1,5,V,1,5,kb,1,5,,,,,,,,,,,,,,,,,,
I.40.1,43,n,n_0*exp(-m*g*x/(kb*T)),6,n_0,1,5,m,1,5,x,1,5,T,1,5,g,1,5,kb,1,5,,,,,,,,,,,,
I.41.16,44,L_rad,h/(2*pi)*omega**3/(pi**2*c**2*(exp((h/(2*pi))*omega/(kb*T))-1)),5,omega,1,5,T,1,5,h,1,5,kb,1,5,c,1,5,,,,,,,,,,,,,,,
I.43.16,45,v,mu_drift*q*Volt/d,4,mu_drift,1,5,q,1,5,Volt,1,5,d,1,5,,,,,,,,,,,,,,,,,,
I.43.31,46,D,mob*kb*T,3,mob,1,5,T,1,5,kb,1,5,,,,,,,,,,,,,,,,,,,,,
I.43.43,47,kappa,1/(gamma-1)*kb*v/A,4,gamma,2,5,kb,1,5,A,1,5,v,1,5,,,,,,,,,,,,,,,,,,
I.44.4,48,E_n,n*kb*T*ln(V2/V1),5,n,1,5,kb,1,5,T,1,5,V1,1,5,V2,1,5,,,,,,,,,,,,,,,
I.47.23,49,c,sqrt(gamma*pr/rho),3,gamma,1,5,pr,1,5,rho,1,5,,,,,,,,,,,,,,,,,,,,,
I.48.2,50,E_n,m*c**2/sqrt(1-v**2/c**2),3,m,1,5,v,1,2,c,3,10,,,,,,,,,,,,,,,,,,,,,
I.50.26,51,x,x1*(cos(omega*t)+alpha*cos(omega*t)**2),4,x1,1,3,omega,1,3,t,1,3,alpha,1,3,,,,,,,,,,,,,,,,,,
II.2.42,52,Pwr,kappa*(T2-T1)*A/d,5,kappa,1,5,T1,1,5,T2,1,5,A,1,5,d,1,5,,,,,,,,,,,,,,,
II.3.24,53,flux,Pwr/(4*pi*r**2),2,Pwr,1,5,r,1,5,,,,,,,,,,,,,,,,,,,,,,,,
II.4.23,54,Volt,q/(4*pi*epsilon*r),3,q,1,5,epsilon,1,5,r,1,5,,,,,,,,,,,,,,,,,,,,,
II.6.11,55,Volt,1/(4*pi*epsilon)*p_d*cos(theta)/r**2,4,epsilon,1,3,p_d,1,3,theta,1,3,r,1,3,,,,,,,,,,,,,,,,,,
II.6.15a,56,Ef,p_d/(4*pi*epsilon)*3*z/r**5*sqrt(x**2+y**2),6,epsilon,1,3,p_d,1,3,r,1,3,x,1,3,y,1,3,z,1,3,,,,,,,,,,,,
II.6.15b,57,Ef,p_d/(4*pi*epsilon)*3*cos(theta)*sin(theta)/r**3,4,epsilon,1,3,p_d,1,3,theta,1,3,r,1,3,,,,,,,,,,,,,,,,,,
II.8.7,58,E_n,3/5*q**2/(4*pi*epsilon*d),3,q,1,5,epsilon,1,5,d,1,5,,,,,,,,,,,,,,,,,,,,,
II.8.31,59,E_den,epsilon*Ef**2/2,2,epsilon,1,5,Ef,1,5,,,,,,,,,,,,,,,,,,,,,,,,
II.10.9,60,Ef,sigma_den/epsilon*1/(1+chi),3,sigma_den,1,5,epsilon,1,5,chi,1,5,,,,,,,,,,,,,,,,,,,,,
II.11.3,61,x,q*Ef/(m*(omega_0**2-omega**2)),5,q,1,3,Ef,1,3,m,1,3,omega_0,3,5,omega,1,2,,,,,,,,,,,,,,,
II.11.17,62,n,n_0*(1+p_d*Ef*cos(theta)/(kb*T)),6,n_0,1,3,kb,1,3,T,1,3,theta,1,3,p_d,1,3,Ef,1,3,,,,,,,,,,,,
II.11.20,63,Pol,n_rho*p_d**2*Ef/(3*kb*T),5,n_rho,1,5,p_d,1,5,Ef,1,5,kb,1,5,T,1,5,,,,,,,,,,,,,,,
II.11.27,64,Pol,n*alpha/(1-(n*alpha/3))*epsilon*Ef,4,n,0,1,alpha,0,1,epsilon,1,2,Ef,1,2,,,,,,,,,,,,,,,,,,
II.11.28,65,theta,1+n*alpha/(1-(n*alpha/3)),2,n,0,1,alpha,0,1,,,,,,,,,,,,,,,,,,,,,,,,
II.13.17,66,B,1/(4*pi*epsilon*c**2)*2*I/r,4,epsilon,1,5,c,1,5,I,1,5,r,1,5,,,,,,,,,,,,,,,,,,
II.13.23,67,rho_c,rho_c_0/sqrt(1-v**2/c**2),3,rho_c_0,1,5,v,1,2,c,3,10,,,,,,,,,,,,,,,,,,,,,
II.13.34,68,j,rho_c_0*v/sqrt(1-v**2/c**2),3,rho_c_0,1,5,v,1,2,c,3,10,,,,,,,,,,,,,,,,,,,,,
II.15.4,69,E_n,-mom*B*cos(theta),3,mom,1,5,B,1,5,theta,1,5,,,,,,,,,,,,,,,,,,,,,
II.15.5,70,E_n,-p_d*Ef*cos(theta),3,p_d,1,5,Ef,1,5,theta,1,5,,,,,,,,,,,,,,,,,,,,,
II.21.32,71,Volt,q/(4*pi*epsilon*r*(1-v/c)),5,q,1,5,epsilon,1,5,r,1,5,v,1,2,c,3,10,,,,,,,,,,,,,,,
II.24.17,72,k,sqrt(omega**2/c**2-pi**2/d**2),3,omega,4,6,c,1,2,d,2,4,,,,,,,,,,,,,,,,,,,,,
II.27.16,73,flux,epsilon*c*Ef**2,3,epsilon,1,5,c,1,5,Ef,1,5,,,,,,,,,,,,,,,,,,,,,
II.27.18,74,E_den,epsilon*Ef**2,2,epsilon,1,5,Ef,1,5,,,,,,,,,,,,,,,,,,,,,,,,
II.34.2a,75,I,q*v/(2*pi*r),3,q,1,5,v,1,5,r,1,5,,,,,,,,,,,,,,,,,,,,,
II.34.2,76,mom,q*v*r/2,3,q,1,5,v,1,5,r,1,5,,,,,,,,,,,,,,,,,,,,,
II.34.11,77,omega,g_*q*B/(2*m),4,g_,1,5,q,1,5,B,1,5,m,1,5,,,,,,,,,,,,,,,,,,
II.34.29a,78,mom,q*h/(4*pi*m),3,q,1,5,h,1,5,m,1,5,,,,,,,,,,,,,,,,,,,,,
II.34.29b,79,E_n,g_*mom*B*Jz/(h/(2*pi)),5,g_,1,5,h,1,5,Jz,1,5,mom,1,5,B,1,5,,,,,,,,,,,,,,,
II.35.18,80,n,n_0/(exp(mom*B/(kb*T))+exp(-mom*B/(kb*T))),5,n_0,1,3,kb,1,3,T,1,3,mom,1,3,B,1,3,,,,,,,,,,,,,,,
II.35.21,81,M,n_rho*mom*tanh(mom*B/(kb*T)),5,n_rho,1,5,mom,1,5,B,1,5,kb,1,5,T,1,5,,,,,,,,,,,,,,,
II.36.38,82,f,mom*H/(kb*T)+(mom*alpha)/(epsilon*c**2*kb*T)*M,8,mom,1,3,H,1,3,kb,1,3,T,1,3,alpha,1,3,epsilon,1,3,c,1,3,M,1,3,,,,,,
II.37.1,83,E_n,mom*(1+chi)*B,3,mom,1,5,B,1,5,chi,1,5,,,,,,,,,,,,,,,,,,,,,
II.38.3,84,F,Y*A*x/d,4,Y,1,5,A,1,5,d,1,5,x,1,5,,,,,,,,,,,,,,,,,,
II.38.14,85,mu_S,Y/(2*(1+sigma)),2,Y,1,5,sigma,1,5,,,,,,,,,,,,,,,,,,,,,,,,
III.4.32,86,n,1/(exp((h/(2*pi))*omega/(kb*T))-1),4,h,1,5,omega,1,5,kb,1,5,T,1,5,,,,,,,,,,,,,,,,,,
III.4.33,87,E_n,(h/(2*pi))*omega/(exp((h/(2*pi))*omega/(kb*T))-1),4,h,1,5,omega,1,5,kb,1,5,T,1,5,,,,,,,,,,,,,,,,,,
III.7.38,88,omega,2*mom*B/(h/(2*pi)),3,mom,1,5,B,1,5,h,1,5,,,,,,,,,,,,,,,,,,,,,
III.8.54,89,prob,sin(E_n*t/(h/(2*pi)))**2,3,E_n,1,2,t,1,2,h,1,4,,,,,,,,,,,,,,,,,,,,,
III.9.52,90,prob,(p_d*Ef*t/(h/(2*pi)))*sin((omega-omega_0)*t/2)**2/((omega-omega_0)*t/2)**2,6,p_d,1,3,Ef,1,3,t,1,3,h,1,3,omega,1,5,omega_0,1,5,,,,,,,,,,,,
III.10.19,91,E_n,mom*sqrt(Bx**2+By**2+Bz**2),4,mom,1,5,Bx,1,5,By,1,5,Bz,1,5,,,,,,,,,,,,,,,,,,
III.12.43,92,L,n*(h/(2*pi)),2,n,1,5,h,1,5,,,,,,,,,,,,,,,,,,,,,,,,
III.13.18,93,v,2*E_n*d**2*k/(h/(2*pi)),4,E_n,1,5,d,1,5,k,1,5,h,1,5,,,,,,,,,,,,,,,,,,
III.14.14,94,I,I_0*(exp(q*Volt/(kb*T))-1),5,I_0,1,5,q,1,2,Volt,1,2,kb,1,2,T,1,2,,,,,,,,,,,,,,,
III.15.12,95,E_n,2*U*(1-cos(k*d)),3,U,1,5,k,1,5,d,1,5,,,,,,,,,,,,,,,,,,,,,
III.15.14,96,m,(h/(2*pi))**2/(2*E_n*d**2),3,h,1,5,E_n,1,5,d,1,5,,,,,,,,,,,,,,,,,,,,,
III.15.27,97,k,2*pi*alpha/(n*d),3,alpha,1,5,n,1,5,d,1,5,,,,,,,,,,,,,,,,,,,,,
III.17.37,98,f,beta*(1+alpha*cos(theta)),3,beta,1,5,alpha,1,5,theta,1,5,,,,,,,,,,,,,,,,,,,,,
III.19.51,99,E_n,-m*q**4/(2*(4*pi*epsilon)**2*(h/(2*pi))**2)*(1/n**2),5,m,1,5,q,1,5,h,1,5,n,1,5,epsilon,1,5,,,,,,,,,,,,,,,
III.21.20,100,j,-rho_c_0*q*A_vec/m,4,rho_c_0,1,5,q,1,5,A_vec,1,5,m,1,5,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,

In [None]:
%%writefile AlphaSymbolic/data/benchmarks/BonusEquations.csv
Filename,Number,Name,Eqn. No.,Output,Formula,# variables,v1_name,v1_low,v1_high,v2_name,v2_low,v2_high,v3_name,v3_low,v3_high,v4_name,v4_low,v4_high,v5_name,v5_low,v5_high,v6_name,v6_low,v6_high,v7_name,v7_low,v7_high,v8_name,v8_low,v8_high,v9_name,v9_low,v9_high,v10_name,v10_low,v10_high
test_1,1,Rutherford scattering,1,A,(Z_1*Z_2*alpha*hbar*c/(4*E_n*sin(theta/2)**2))**2,7,Z_1,1,2,Z_2,1,2,alpha,1,2,hbar,1,2,c,1,2,E_n,1,3,theta,1,3,,,,,,,,,
test_2,2,3.55 Goldstein,2,k,m*k_G/L**2*(1+sqrt(1+2*E_n*L**2/(m*k_G**2))*cos(theta1-theta2)),6,m,1,3,k_G,1,3,L,1,3,E_n,1,3,theta1,0,6,theta2,0,6,,,,,,,,,,,,
test_3,3,3.64 Goldstein,3,r,d*(1-alpha**2)/(1+alpha*cos(theta1-theta2)),4,d,1,3,alpha,2,4,theta1,4,5,theta2,4,5,,,,,,,,,,,,,,,,,,
test_4,4,3.16 Goldstein,4,v,sqrt(2/m*(E_n-U-L**2/(2*m*r**2))),5,m,1,3,E_n,8,12,U,1,3,L,1,3,r,1,3,,,,,,,,,,,,,,,
test_5,5,3.74 Goldstein,5,t,2*pi*d**(3/2)/sqrt(G*(m1+m2)),4,d,1,3,G,1,3,m1,1,3,m2,1,3,,,,,,,,,,,,,,,,,,
test_6,6,3.99 Goldstein,6,alpha,sqrt(1+2*epsilon**2*E_n*L**2/(m*(Z_1*Z_2*q**2)**2)),7,epsilon,1,3,L,1,3,m,1,3,Z_1,1,3,Z_2,1,3,q,1,3,E_n,1,3,,,,,,,,,
test_7,7,Friedman Equation,7,H_G,sqrt(8*pi*G*rho/3-alpha*c**2/d**2),5,G,1,3,rho,1,3,alpha,1,2,c,1,2,d,1,3,,,,,,,,,,,,,,,
test_8,8,Compton Scattering,8,K,E_n/(1+E_n/(m*c**2)*(1-cos(theta))),4,E_n,1,3,m,1,3,c,1,3,theta,1,3,,,,,,,,,,,,,,,,,,
test_9,9,Gravitational wave radiated power,9,Pwr,-32/5*G**4/c**5*(m1*m2)**2*(m1+m2)/r**5,5,G,1,2,c,1,2,m1,1,5,m2,1,5,r,1,2,,,,,,,,,,,,,,,
test_10,10,Relativistic aberration,10,theta1,arccos((cos(theta2)-v/c)/(1-v/c*cos(theta2))),3,c,4,6,v,1,3,theta2,1,3,,,,,,,,,,,,,,,,,,,,,
test_11,11,N-slit diffraction,11,I,I_0*(sin(alpha/2)*sin(n*delta/2)/(alpha/2*sin(delta/2)))**2,4,I_0,1,3,alpha,1,3,delta,1,3,n,1,2,,,,,,,,,,,,,,,,,,
test_12,12,2.11 Jackson,12,F,q/(4*pi*epsilon*y**2)*(4*pi*epsilon*Volt*d-q*d*y**3/(y**2-d**2)**2),5,q,1,5,y,1,3,Volt,1,5,d,4,6,epsilon,1,5,,,,,,,,,,,,,,,
test_13,13,3.45 Jackson,13,Volt,1/(4*pi*epsilon)*q/sqrt(r**2+d**2-2*r*d*cos(alpha)),5,q,1,5,r,1,3,d,4,6,alpha,0,6,epsilon,1,5,,,,,,,,,,,,,,,
test_14,14,4.60' Jackson,14,Volt,Ef*cos(theta)*(-r+d**3/r**2*(alpha-1)/(alpha+2)),5,Ef,1,5,theta,0,6,r,1,5,d,1,5,alpha,1,5,,,,,,,,,,,,,,,
test_15,15,11.38 Jackson,15,omega_0,sqrt(1-v**2/c**2)*omega/(1+v/c*cos(theta)),4,c,5,20,v,1,3,omega,1,5,theta,0,6,,,,,,,,,,,,,,,,,,
test_16,16,8.56 Goldstein,16,E_n,sqrt((p-q*A_vec)**2*c**2+m**2*c**4)+q*Volt,6,m,1,5,c,1,5,p,1,5,q,1,5,A_vec,1,5,Volt,1,5,,,,,,,,,,,,
test_17,17,12.80' Goldstein,17,E_n,1/(2*m)*(p**2+m**2*omega**2*x**2*(1+alpha*x/y)),6,m,1,5,omega,1,5,p,1,5,y,1,5,x,1,5,alpha,1,5,,,,,,,,,,,,
test_18,18,15.2.1 Weinberg,18,rho_0,3/(8*pi*G)*(c**2*k_f/r**2+H_G**2),5,G,1,5,k_f,1,5,r,1,5,H_G,1,5,c,1,5,,,,,,,,,,,,,,,
test_19,19,15.2.2 Weinberg,19,pr,-1/(8*pi*G)*(c**4*k_f/r**2+H_G**2*c**2*(1-2*alpha)),6,G,1,5,k_f,1,5,r,1,5,H_G,1,5,alpha,1,5,c,1,5,,,,,,,,,,,,
test_20,20,Klein-Nishina (13.132 Schwarz),20,A,1/(4*pi)*alpha**2*h**2/(m**2*c**2)*(omega_0/omega)**2*(omega_0/omega+omega/omega_0-sin(beta)**2),7,omega,1,5,omega_0,1,5,alpha,1,5,h,1,5,m,1,5,c,1,5,beta,0,6,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,

In [None]:
%%writefile AlphaSymbolic/top_formulas.csv
formula,residual,rmsle_global,extrapolation_error,pred_26,pred_27,sample_size,timestamp
exp((log(log((38.4335 * (x0 * log(log((x0 - (x0 / (log(x0) - 2.1513))))))))) - x0) + lgamma(x)),(log(log((38.4335 * (x0 * log(log((x0 - (x0 / (log(x0) - 2.1513))))))))) - x0),0.446213921297748,1.384836563676553e+17,9118205752846800.0,1.0962380464998448e+17,18,2026-01-10 01:01:51
exp(((g(x0) / (((e(((6.5993 * x0) / (x0 * (24 - x0)))) - 2.2322) / ((x0 - 5.5466) / 8.2059)) + x0)) - x0) + lgamma(x)),((g(x0) / (((e(((6.5993 * x0) / (x0 * (24 - x0)))) - 2.2322) / ((x0 - 5.5466) / 8.2059)) + x0)) - x0),0.49214994109680077,7161996148432400.0,2.3612133893656412e+16,2.4077552902526256e+17,15,2026-01-10 02:08:04
exp((log((x0 / log(((0.2404 - (x0 + x0)) + 12.0234)))) - x0) + lgamma(x)),(log((x0 / log(((0.2404 - (x0 + x0)) + 12.0234)))) - x0),0.49586141956618673,9.458651551890949e+16,1.4548596099178496e+16,1.480905551523986e+17,21,2026-01-10 02:16:06
exp((0.1527 + ((((1.0016 ^ x0) / (((x0 - 1.1483) / (((x0 ^ -3.1587) ^ -0.3481) - x0)) - (x0 + (x0 + x0)))) - x0) + 1.6047)) + lgamma(x)),(0.1527 + ((((1.0016 ^ x0) / (((x0 - 1.1483) / (((x0 ^ -3.1587) ^ -0.3481) - x0)) - (x0 + (x0 + x0)))) - x0) + 1.6047)),0.5438126039885908,1.2831730371202491e+17,1.178547406114365e+16,1.1712288899731802e+17,17,2026-01-10 01:55:57
exp((log(((x0 - 3.0000) / log(((((26.6821 / x0) / (22.9188 - x0)) ^ x0) + 13.3840)))) - x0) + lgamma(x)),(log(((x0 - 3.0000) / log(((((26.6821 / x0) / (22.9188 - x0)) ^ x0) + 13.3840)))) - x0),0.6612535694422839,4.960786910565984e+16,1.8268803826830184e+16,1.8934899383799654e+17,9,2026-01-10 02:06:03


In [None]:
%%writefile AlphaSymbolic/pattern_memory.json
{
  "e(((C + ((C / (C * (((x0 + x0) - g((C - x0))) * C))) - x0)) + g(x0)))": 1,
  "e(((((((((C + (x0 / (x0 - C))) / C) + C) / (C - x0)) + x0) / x0) - (x0 / C)) + g(x0)))": 1,
  "e(((C - x0) + g(x0)))": 1,
  "e(((C - ((log(x0) / ((log(((x0 ^ log((x0 * x0))) * x0)) - ((C / (C - x0)) + x0)) * (C - x0))) + x0)) + g(x0)))": 1,
  "e(((((x0 + (C / x0)) * C) / ((x0 / (((C - (x0 + ((((x0 * C) / (x0 - C)) / (x0 + x0)) - C))) - C) * (((((C - (x0 + x0)) + C) / (x0 * x0)) - C) + x0))) - C)) + g(x0)))": 1,
  "e(((log((C - (x0 / (x0 * (C - x0))))) - x0) + g(x0)))": 1,
  "e(((log(log((C * (x0 * log(log((x0 - (x0 / (log(x0) - C))))))))) - x0) + g(x0)))": 1,
  "e((((C - x0) - (x0 / ((x0 + (x0 - C)) + (x0 - (x0 ^ (((C - x0) / (C + x0)) + log((log((x0 + C)) - x0)))))))) + g(x0)))": 1,
  "e(((C + ((((C ^ x0) / (((x0 - C) / (((x0 ^ C) ^ C) - x0)) - (x0 + (x0 + x0)))) - x0) + C)) + g(x0)))": 1,
  "e((((((x0 - (((C - x0) / x0) / (C + ((x0 - C) / C)))) + x0) / x0) - x0) + g(x0)))": 1,
  "e(((C - ((C ^ x0) + x0)) + g(x0)))": 1,
  "e((((C * (C / (C - x0))) - ((C / ((x0 - (x0 / C)) * ((C - (x0 / C)) + C))) + x0)) + g(x0)))": 1,
  "e((((((x0 + x0) + (((x0 - C) / (((((x0 * C) / (x0 / (C * x0))) + C) / x0) - (x0 + ((x0 - C) / ((C / x0) - C))))) / x0)) / (x0 * C)) - x0) + g(x0)))": 1,
  "e(((log(((x0 - C) / log(((((C / x0) / (C - x0)) ^ x0) + C)))) - x0) + g(x0)))": 1,
  "e((((g(x0) / (((e(((C * x0) / (x0 * (C - x0)))) - C) / ((x0 - C) / C)) + x0)) - x0) + g(x0)))": 1,
  "e(((((C + x0) / ((C / (g((C - (x0 - C))) - (x0 + x0))) + x0)) - (x0 - C)) + g(x0)))": 1,
  "e((((x0 / ((x0 + C) * ((C * (C / x0)) - (x0 + ((C * (C / x0)) * C))))) + (log((x0 * (C / (C + x0)))) - x0)) + g(x0)))": 1,
  "e(((((neg(g(x0)) / (C - (C + x0))) - x0) + (C ^ (C - ((x0 * C) ^ (log((C - (x0 ^ C))) - C))))) + g(x0)))": 1,
  "e(((log((x0 / log(((C - (x0 + x0)) + C)))) - x0) + g(x0)))": 1,
  "e(((log(((log(((C - ((((x0 / C) / C) + C) - C)) * C)) - (x0 + ((C / ((log(x0) / x0) - ((x0 / C) + C))) / ((((x0 ^ log(x0)) - x0) / (neg(x0) / (C - x0))) - x0)))) / C)) - x0) + g(x0)))": 1,
  "e((((C - ((x0 / C) / C)) - (x0 - ((C / e((C - (x0 / C)))) / (((C - ((x0 - (x0 / C)) - C)) / C) + ((((x0 / C) - ((x0 / C) - (C + x0))) / (C + ((C - (x0 - C)) / C))) + x0))))) + g(x0)))": 1
}

In [None]:
# Run AlphaSymbolic
# The binaries are in ../Code/build/
#
# NOTE: `%cd` changes the working directory of the notebook kernel itself, not
# just of this cell, so every cell executed afterwards resolves relative paths
# from inside AlphaSymbolic/. Re-running this cell while already inside
# AlphaSymbolic/ will most likely make `%cd` report an error (assuming there is
# no nested AlphaSymbolic/ directory - TODO confirm); app.py is still launched
# from the current directory in that case.
%cd AlphaSymbolic
# Launches the app in a subprocess; the cell does not finish until it exits.
!python app.py


In [None]:
# Backup Learned Formulas to Drive (Run manually or keep active)
#
# Copies every artifact listed in FILES_TO_BACKUP into DRIVE_PATH and prints one
# status line per file (copied, skipped because it does not exist yet, failed).
# The cell is safe to re-run: existing backups are simply overwritten.
import shutil
import os

DRIVE_MOUNT_POINT = '/content/drive'
DRIVE_PATH = '/content/drive/MyDrive/AlphaSymbolic_Models'
PROJECT_DIR = 'AlphaSymbolic'
# (source path relative to the notebook root, file name used inside DRIVE_PATH)
FILES_TO_BACKUP = [
    ('AlphaSymbolic/top_formulas.csv', 'top_formulas.csv'),
    ('AlphaSymbolic/pattern_memory.json', 'pattern_memory.json'),
    ('AlphaSymbolic/top_5_detailed_report.csv', 'top_5_detailed_report.csv'),
    ('AlphaSymbolic/results/learned_formulas.csv', 'learned_formulas.csv')
]


def _locate_backup_source(src):
    """Return an existing path for `src`, or None when the file cannot be found.

    `src` is written relative to the notebook root. The "Run AlphaSymbolic" cell
    executes `%cd AlphaSymbolic`, after which the kernel's working directory is
    the project folder and the root-relative path no longer resolves. In that
    situation the same file is looked up relative to the project folder instead.
    """
    candidates = [src]
    if os.path.basename(os.getcwd()) == PROJECT_DIR:
        # e.g. 'AlphaSymbolic/results/x.csv' -> 'results/x.csv'
        candidates.append(os.path.relpath(src, PROJECT_DIR))
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


# Without a mounted Drive, makedirs below would just create an ordinary folder on
# the ephemeral VM disk and the "successful" backups would vanish with the session.
if not os.path.ismount(DRIVE_MOUNT_POINT):
    print("⚠️ Google Drive does not appear to be mounted. Backups will be written to local disk only.")

os.makedirs(DRIVE_PATH, exist_ok=True)

for src, name in FILES_TO_BACKUP:
    dst = os.path.join(DRIVE_PATH, name)
    try:
        found = _locate_backup_source(src)
        if found is None:
            # Not an error: the file may simply not have been produced yet.
            print(f"⚠️ Skipped (not found): {src}")
            continue
        shutil.copy(found, dst)
        print(f"✅ Backup successful: {name}")
    except Exception as e:
        # Best effort: one failing file must not prevent the others from being copied.
        print(f"❌ Backup failed for {name}: {e}")
