In [None]:
# AlphaSymbolic - Unified Hybrid System
# -------------------------------------
# Instructions:
# 1. Runtime -> Change runtime type -> T4 GPU
# 2. Run All

!nvidia-smi

# Install dependencies
!pip install gradio torch torchvision torchaudio scipy matplotlib sympy

# Create Directory Structure
import os
os.makedirs('Code/src', exist_ok=True)
os.makedirs('Code/build', exist_ok=True)
os.makedirs('AlphaSymbolic', exist_ok=True)
directories = ['core', 'data', 'search', 'ui', 'utils']
for d in directories:
    os.makedirs(os.path.join('AlphaSymbolic', d), exist_ok=True)


In [None]:
%%writefile Code/src/AdvancedFeatures.cpp
#include "AdvancedFeatures.h"
#include "Globals.h"
#include "GeneticOperators.h"
#include "Fitness.h"
#include "ExpressionTree.h" // Necesario para tree_to_string en la simplificación
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iostream>
#include <unordered_map>
#include <vector>

//---------------------------------
// EvolutionParameters
//---------------------------------
// Builds the initial parameter set from the project-wide default constants.
EvolutionParameters EvolutionParameters::create_default() {
    EvolutionParameters defaults;
    defaults.mutation_rate = BASE_MUTATION_RATE;
    defaults.elite_percentage = BASE_ELITE_PERCENTAGE;
    defaults.tournament_size = DEFAULT_TOURNAMENT_SIZE;
    defaults.crossover_rate = DEFAULT_CROSSOVER_RATE;
    return defaults;
}

// Randomly perturbs the four adaptive parameters in place.
//
// stagnation_counter controls how aggressive the perturbation is:
//   * 0                                    -> factor 0.2 (very conservative)
//   * (0, STAGNATION_LIMIT_ISLAND / 4)     -> factor scaled inside [0.5, 1.0)
//   * up to STAGNATION_LIMIT_ISLAND / 2    -> factor 1.0
//   * above STAGNATION_LIMIT_ISLAND / 2    -> factor grows from 1.0, capped at 2.0
//
// One real-valued delta (drawn from [-0.05, 0.05] and scaled by the factor) is
// applied to mutation_rate, elite_percentage and crossover_rate; an integer
// delta (drawn from [-2, 2] and scaled) is applied to tournament_size. Every
// parameter is clamped to its allowed range afterwards.
void EvolutionParameters::mutate(int stagnation_counter) {
    auto& rng = get_rng();
    double aggression_factor = 1.0;
    // Derive the aggression factor from the stagnation level.
    if (stagnation_counter > STAGNATION_LIMIT_ISLAND / 2) {
        // Significant stagnation: grow the factor (1.5 when the counter reaches the limit).
        aggression_factor = 1.0 + (static_cast<double>(stagnation_counter - STAGNATION_LIMIT_ISLAND / 2) / (STAGNATION_LIMIT_ISLAND / 2.0)) * 0.5;
        aggression_factor = std::min(aggression_factor, 2.0); // Cap the maximum aggression
    } else if (stagnation_counter < STAGNATION_LIMIT_ISLAND / 4 && stagnation_counter > 0) {
        // Little (but non-zero) stagnation: shrink the factor towards 0.5.
        aggression_factor = 1.0 - (static_cast<double>(STAGNATION_LIMIT_ISLAND / 4 - stagnation_counter) / (STAGNATION_LIMIT_ISLAND / 4.0)) * 0.5;
        aggression_factor = std::max(aggression_factor, 0.5); // Floor the minimum aggression
    } else if (stagnation_counter == 0) {
        // No stagnation at all: only very small changes.
        aggression_factor = 0.2;
    }

    std::uniform_real_distribution<double> base_rate_change(-0.05, 0.05);
    std::uniform_int_distribution<int> base_tourney_change(-2, 2);

    const double rate_change_val = base_rate_change(rng) * aggression_factor;
    const int base_tourney_draw = base_tourney_change(rng);
    int tourney_change_val = static_cast<int>(std::round(base_tourney_draw * aggression_factor));

    // Under high aggression the tournament size must always move. With a
    // factor above 1.0 the scaled change is zero only when the base draw was
    // zero, so pick the direction with a single fair coin flip. (Re-drawing
    // from base_tourney_change here would sometimes skip the nudge and would
    // favour -1, because only two of its five outcomes are positive.)
    if (aggression_factor > 1.0 && tourney_change_val == 0) {
        std::uniform_int_distribution<int> coin_flip(0, 1);
        tourney_change_val = (coin_flip(rng) == 1) ? 1 : -1;
    }

    // Dynamic bounds: the upper limits widen as the aggression factor grows.
    double min_mutation = 0.05;
    double max_mutation_base = 0.5;
    double max_mutation = min_mutation + (max_mutation_base - min_mutation) * (1.0 + aggression_factor / 2.0);

    double min_elite = 0.02;
    double max_elite_base = 0.25;
    double max_elite = min_elite + (max_elite_base - min_elite) * (1.0 + aggression_factor / 2.0);

    int min_tournament = 3;
    int max_tournament_base = 30;
    int max_tournament = min_tournament + static_cast<int>((max_tournament_base - min_tournament) * (1.0 + aggression_factor / 2.0));

    // Apply the changes and keep every parameter inside its bounds.
    mutation_rate = std::clamp(mutation_rate + rate_change_val, min_mutation, max_mutation);
    elite_percentage = std::clamp(elite_percentage + rate_change_val, min_elite, max_elite);
    tournament_size = std::clamp(tournament_size + tourney_change_val, min_tournament, max_tournament);
    crossover_rate = std::clamp(crossover_rate + rate_change_val, 0.5, 0.95);
}

//---------------------------------
// PatternMemory
//---------------------------------
void PatternMemory::record_success(const NodePtr& tree, double fitness) {
    std::string pattern = extract_pattern(tree);
    if (pattern.empty() || pattern.length() > 50 || pattern.length() < 3 || pattern == "N") return;
    auto it = patterns.find(pattern);
    if (it == patterns.end()) {
        patterns[pattern] = {pattern, fitness, 1, (fitness < INF ? 1.0 : 0.0)};
    } else {
        auto& p = it->second;
        p.uses++;
        double improvement = (fitness < p.best_fitness && p.best_fitness < INF) ? 1.0 : 0.0;
        p.success_rate = ((p.success_rate * (p.uses - 1)) + improvement) / p.uses;
        p.best_fitness = std::min(p.best_fitness, fitness);
    }
}

// Picks one stored pattern at random (weighted by its statistics) and turns it
// into a tree via parse_pattern. Returns nullptr when no pattern qualifies.
NodePtr PatternMemory::suggest_pattern_based_tree(int max_depth) {
    if (patterns.empty()) return nullptr;

    // Parallel arrays: eligible pattern strings and their sampling weights.
    std::vector<std::string> eligible;
    std::vector<double> weights;
    for (const auto& entry : patterns) {
        const PatternInfo& info = entry.second;
        if (info.uses < PATTERN_MEM_MIN_USES) continue;
        // Needs either a decent success rate or a good best fitness.
        if (!(info.success_rate > 0.1 || info.best_fitness < PATTERN_RECORD_FITNESS_THRESHOLD)) continue;
        eligible.push_back(entry.first);
        weights.push_back(info.success_rate + (1.0 / (1.0 + info.best_fitness)));
    }
    if (eligible.empty()) return nullptr;

    std::discrete_distribution<> picker(weights.begin(), weights.end());
    const int chosen = picker(get_rng());
    return parse_pattern(eligible[chosen], max_depth);
}

// Serialises the shape of a tree: "N" for a null child, "#" for any constant,
// "x" for the variable, "(<left><op><right>)" for an operator and "?" for any
// other node type.
std::string PatternMemory::extract_pattern(const NodePtr& node) {
    if (!node) return "N";
    if (node->type == NodeType::Constant) return "#";
    if (node->type == NodeType::Variable) return "x";
    if (node->type != NodeType::Operator) return "?";

    std::string shape = "(";
    shape += extract_pattern(node->left);
    shape += node->op;
    shape += extract_pattern(node->right);
    shape += ")";
    return shape;
}

// Placeholder parser: only the three leaf tokens are decoded. Any other
// pattern (including parenthesised composites) falls back to a random tree of
// at most `max_depth`.
NodePtr PatternMemory::parse_pattern(const std::string& pattern, int max_depth) {
    if (pattern == "N") return nullptr;
    if (pattern == "x") return std::make_shared<Node>(NodeType::Variable);
    if (pattern != "#") {
        // Composite patterns are not reconstructed yet.
        return generate_random_tree(max_depth); // Fallback
    }

    // "#": a fresh constant leaf with a randomly drawn value.
    auto leaf = std::make_shared<Node>(NodeType::Constant);
    if (FORCE_INTEGER_CONSTANTS) {
        std::uniform_int_distribution<int> draw(CONSTANT_INT_MIN_VALUE, CONSTANT_INT_MAX_VALUE);
        leaf->value = static_cast<double>(draw(get_rng()));
    } else {
        std::uniform_real_distribution<double> draw(CONSTANT_MIN_VALUE, CONSTANT_MAX_VALUE);
        leaf->value = draw(get_rng());
    }
    // Snap values that are practically zero to exactly zero.
    if (std::fabs(leaf->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) leaf->value = 0.0;
    return leaf;
}

//---------------------------------
// Pareto Optimizer
//---------------------------------
// Stores the tree together with its two objectives; a new solution always
// starts out as non-dominated.
ParetoSolution::ParetoSolution(NodePtr t, double acc, double complexity_val)
    : tree(std::move(t)),
      accuracy(acc),
      complexity(complexity_val),
      dominated(false) {}

// Pareto dominance with both objectives minimised: this solution dominates
// `other` when it is no worse in accuracy and complexity and strictly better
// in at least one of them.
bool ParetoSolution::dominates(const ParetoSolution& other) const {
    // Must not be worse in either objective.
    if (!(accuracy <= other.accuracy)) return false;
    if (!(complexity <= other.complexity)) return false;
    // ...and must be strictly better in at least one.
    return (accuracy < other.accuracy) || (complexity < other.complexity);
}

void ParetoOptimizer::update(const std::vector<Individual>& population, const std::vector<double>& targets, const std::vector<double>& x_values) {
    std::vector<ParetoSolution> candidates = pareto_front;
    for (const auto& ind : population) {
        if (ind.tree && ind.fitness_valid && ind.fitness < INF) {
            candidates.emplace_back(ind.tree, ind.fitness, static_cast<double>(tree_size(ind.tree)));
        }
    }
    for (auto& sol1 : candidates) {
        sol1.dominated = false;
        for (const auto& sol2 : candidates) {
            if (&sol1 == &sol2) continue;
            if (sol2.dominates(sol1)) { sol1.dominated = true; break; }
        }
    }
    pareto_front.clear();
    std::copy_if(candidates.begin(), candidates.end(), std::back_inserter(pareto_front),
                 [](const auto& sol) { return !sol.dominated; });
    if (pareto_front.size() > PARETO_MAX_FRONT_SIZE) {
        std::sort(pareto_front.begin(), pareto_front.end(), [](const auto& a, const auto& b){ return a.accuracy < b.accuracy; });
        pareto_front.resize(PARETO_MAX_FRONT_SIZE);
    }
}

std::vector<NodePtr> ParetoOptimizer::get_pareto_solutions() {
    std::vector<NodePtr> result;
    result.reserve(pareto_front.size());
    std::transform(pareto_front.begin(), pareto_front.end(), std::back_inserter(result),
                   [](const auto& sol) { return sol.tree; });
    return result;
}

//---------------------------------
// Domain Constraints
//---------------------------------
bool DomainConstraints::is_valid_recursive(const NodePtr& node) {
     if (!node) return true;
     if (node->type == NodeType::Operator) {
         if (node->op == '/' && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return false;
         if (node->op == '^') { // Solo chequear 0^negativo/0
              if (node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE &&
                  node->right && node->right->type == NodeType::Constant && node->right->value <= SIMPLIFY_NEAR_ZERO_TOLERANCE) {
                      return false;
              }
         }
         if ((node->op == '*' || node->op == '/') && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value - 1.0) < SIMPLIFY_NEAR_ONE_TOLERANCE) return false;
         if ((node->op == '+' || node->op == '-') && node->right && node->right->type == NodeType::Constant && std::fabs(node->right->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return false;
         if (node->op == '*' && node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value - 1.0) < SIMPLIFY_NEAR_ONE_TOLERANCE) return false;
         if (node->op == '+' && node->left && node->left->type == NodeType::Constant && std::fabs(node->left->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return false;
         if (!is_valid_recursive(node->left) || !is_valid_recursive(node->right)) return false;
     }
     return true;
 }

// Public entry point: the whole check is delegated to the recursive helper.
bool DomainConstraints::is_valid(const NodePtr& tree) {
    const bool tree_ok = is_valid_recursive(tree);
    return tree_ok;
}

// Bottom-up simplification of an expression tree. The node is modified in
// place (children are re-linked) and the possibly different root of the
// simplified subtree is returned.
//
// Steps, in order:
//   1. simplify both children;
//   2. constant folding (binary op on two constants, unary op on a constant);
//   3. repair of operator nodes with missing children;
//   4. algebraic identities (A+0, 0+A, A-0, A*1, 1*A, A/1, A*0, A^1, A^0);
//   5. a constant ~0 divisor is replaced by 1;
//   6. X/X -> 1 and X-X -> 0, using tree_to_string to compare the operands.
NodePtr DomainConstraints::simplify_recursive(NodePtr node) {
    if (!node || node->type != NodeType::Operator) return node;
    node->left = simplify_recursive(node->left);
    node->right = simplify_recursive(node->right);

    // Unary operators keep their operand in `left` and leave `right` null.
    // This list must match the one used by evaluate_tree; if an operator is
    // missing here, step 3 below replaces op(A) by A.
    const bool is_unary = (node->op == 's' || node->op == 'c' || node->op == 'l' || node->op == 'e' ||
                           node->op == '!' || node->op == '_' || node->op == 'g' || node->op == 't' ||
                           node->op == 'q' || node->op == 'a' || node->op == 'n' || node->op == 'u');

    // --- 2. Constant folding (first priority) ---
    const bool left_is_const = (node->left && node->left->type == NodeType::Constant);
    const bool right_is_const = (node->right && node->right->type == NodeType::Constant);

    // Fold a binary op with two constants, or a unary op with one constant.
    if ((left_is_const && right_is_const) || (is_unary && left_is_const)) {
        try {
            const double result = evaluate_tree(node, 0.0);
            // evaluate_tree reports domain errors (e.g. division by ~0) by
            // returning INF; never fold NaN, infinity or that sentinel into a
            // constant node.
            if (std::isfinite(result) && std::fabs(result) < INF) {
                auto cn = std::make_shared<Node>(NodeType::Constant);
                cn->value = FORCE_INTEGER_CONSTANTS ? std::round(result) : result;
                if (std::fabs(cn->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) cn->value = 0.0;
                return cn;
            }
        } catch (const std::exception&) {
            // Folding is best-effort: on failure the node is left as it is.
        }
    }

    // --- 3. Missing children ---
    if (node->left && !node->right) {
        if (is_unary) return node; // Normal shape of a unary op with a non-constant operand
        return node->left;         // Binary op that lost its right operand: keep the left one
    }
    if (!node->left && node->right) return node->right;
    if (!node->left && !node->right) {
        auto cn = std::make_shared<Node>(NodeType::Constant);
        cn->value = 1.0;
        return cn;
    }

    // True when `n` is a constant within `tol` of `target`.
    const auto is_const_near = [](const NodePtr& n, double target, double tol) {
        return n && n->type == NodeType::Constant && std::fabs(n->value - target) < tol;
    };
    const bool left_zero = is_const_near(node->left, 0.0, SIMPLIFY_NEAR_ZERO_TOLERANCE);
    const bool right_zero = is_const_near(node->right, 0.0, SIMPLIFY_NEAR_ZERO_TOLERANCE);
    const bool left_one = is_const_near(node->left, 1.0, SIMPLIFY_NEAR_ONE_TOLERANCE);
    const bool right_one = is_const_near(node->right, 1.0, SIMPLIFY_NEAR_ONE_TOLERANCE);

    // --- 4. Identity simplifications ---
    if ((node->op == '+' || node->op == '-') && right_zero) return node->left;  // A+0, A-0 -> A
    if (node->op == '+' && left_zero) return node->right;                       // 0+A -> A
    if ((node->op == '*' || node->op == '/') && right_one) return node->left;   // A*1, A/1 -> A
    if (node->op == '*' && left_one) return node->right;                        // 1*A -> A
    if (node->op == '*' && (left_zero || right_zero)) {                         // 0*A, A*0 -> 0
        auto z = std::make_shared<Node>(NodeType::Constant);
        z->value = 0.0;
        return z;
    }
    if (node->op == '^' && right_one) return node->left;                        // A^1 -> A
    if (node->op == '^' && right_zero) {                                        // A^0 -> 1
        auto o = std::make_shared<Node>(NodeType::Constant);
        o->value = 1.0;
        return o;
    }

    // --- 5. Repair division by a constant ~0: the divisor becomes 1 ---
    if (node->op == '/' && right_zero) node->right->value = 1.0;

    // --- 6. Structural identities ---
    // X / X = 1 (unless the divisor is a constant ~0, to avoid 0/0)
    if (node->op == '/' && node->left && node->right) {
        if (tree_to_string(node->left) == tree_to_string(node->right)) {
            if (node->right->type != NodeType::Constant || std::fabs(node->right->value) >= SIMPLIFY_NEAR_ZERO_TOLERANCE) {
                auto one = std::make_shared<Node>(NodeType::Constant);
                one->value = 1.0;
                return one;
            }
        }
    }

    // X - X = 0
    if (node->op == '-' && node->left && node->right) {
        if (tree_to_string(node->left) == tree_to_string(node->right)) {
            auto zero = std::make_shared<Node>(NodeType::Constant);
            zero->value = 0.0;
            return zero;
        }
    }
    // Constant exponents are deliberately not clamped here.

    return node;
}

// Simplifies a clone of `tree` (obtained from clone_tree) and returns the
// result; a null input yields nullptr.
NodePtr DomainConstraints::fix_or_simplify(NodePtr tree) {
    if (!tree) return nullptr;
    // simplify_recursive re-links nodes in place, so it is run on the clone.
    return simplify_recursive(clone_tree(tree));
}

//---------------------------------
// Local Improvement
//---------------------------------
// Tunes the constant leaves of `tree` in place with a short stochastic hill
// climb.
//
// Each of the (at most 20) iterations picks one constant at random, adds
// Gaussian noise (mean 0, sigma 0.5), rounds it when FORCE_INTEGER_CONSTANTS
// is set, and keeps the new value only if evaluate_fitness strictly improves.
// The search stops early once the fitness drops below
// EXACT_SOLUTION_THRESHOLD.
//
// tree                 : tree whose constants are modified; a null tree is a no-op.
// targets, x_values    : forwarded unchanged to evaluate_fitness.
// d_targets, d_x_values: forwarded to evaluate_fitness only when
//                        USE_GPU_ACCELERATION_DEFINED_BY_CMAKE is defined;
//                        ignored otherwise.
void optimize_constants(NodePtr& tree, const std::vector<double>& targets, const std::vector<double>& x_values, double* d_targets, double* d_x_values) {
    if (!tree) return;

    // 1. Collect the constant nodes with an explicit-stack traversal.
    std::vector<Node*> constants;
    std::vector<Node*> stack;
    stack.push_back(tree.get());
    while (!stack.empty()) {
        Node* n = stack.back();
        stack.pop_back();
        if (!n) continue;
        if (n->type == NodeType::Constant) {
            constants.push_back(n);
        } else if (n->type == NodeType::Operator) {
            stack.push_back(n->right.get());
            stack.push_back(n->left.get());
        }
    }

    if (constants.empty()) return;

    // Single evaluation helper so the GPU/CPU signature difference is written once.
    const auto current_tree_fitness = [&]() -> double {
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
        return evaluate_fitness(tree, targets, x_values, d_targets, d_x_values);
#else
        (void)d_targets;
        (void)d_x_values;
        return evaluate_fitness(tree, targets, x_values);
#endif
    };

    // 2. Hill climbing (numeric optimisation)
    const int max_iter = 20; // Fast local search
    auto& rng = get_rng();

    double current_fitness = current_tree_fitness();

    std::normal_distribution<double> perturbation(0.0, 0.5);
    std::uniform_int_distribution<int> pick_constant(0, static_cast<int>(constants.size()) - 1);

    for (int i = 0; i < max_iter; ++i) {
        Node* chosen = constants[pick_constant(rng)];
        const double old_val = chosen->value;

        // Perturb
        chosen->value += perturbation(rng);
        if (FORCE_INTEGER_CONSTANTS) chosen->value = std::round(chosen->value);

        // With integer constants most perturbations round back to the old
        // value. The tree is then unchanged, so it cannot be strictly better:
        // skip the fitness evaluation.
        if (chosen->value != old_val) {
            const double new_fitness = current_tree_fitness();
            if (new_fitness < current_fitness) {
                current_fitness = new_fitness; // Accept
            } else {
                chosen->value = old_val;       // Revert
            }
        }

        if (current_fitness < EXACT_SOLUTION_THRESHOLD) break;
    }
}

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
std::pair<NodePtr, double> try_local_improvement(const NodePtr& tree, double current_fitness, const std::vector<double>& targets, const std::vector<double>& x_values, int attempts, double* d_targets, double* d_x_values) {
    // 1. First, try to optimize constants of the CURRENT tree
    NodePtr optimized_tree = clone_tree(tree);
    optimize_constants(optimized_tree, targets, x_values, d_targets, d_x_values);
    double optimized_fitness = evaluate_fitness(optimized_tree, targets, x_values, d_targets, d_x_values);
    
    NodePtr best_neighbor = (optimized_fitness < current_fitness) ? optimized_tree : tree;
    double best_neighbor_fitness = (optimized_fitness < current_fitness) ? optimized_fitness : current_fitness;

    if (best_neighbor_fitness >= INF) return {best_neighbor, best_neighbor_fitness};

    // 2. Structural Search (as before)
    for (int i = 0; i < attempts; ++i) {
        NodePtr neighbor = mutate_tree(best_neighbor, 1.0, 2); // Mutate the BEST so far
        neighbor = DomainConstraints::fix_or_simplify(neighbor);
        if (!neighbor) continue;
        
        // Also optimize constants of structural neighbor?
        // Maybe too expensive. Let's do a quick random constant tweak.
        // optimize_constants(neighbor, targets, x_values, d_targets, d_x_values); 
        
        double neighbor_fitness = evaluate_fitness(neighbor, targets, x_values, d_targets, d_x_values);
        if (neighbor_fitness < best_neighbor_fitness) {
            best_neighbor = neighbor;
            best_neighbor_fitness = neighbor_fitness;
        }
    }
    return {best_neighbor, best_neighbor_fitness};
}
#else
std::pair<NodePtr, double> try_local_improvement(const NodePtr& tree, double current_fitness, const std::vector<double>& targets, const std::vector<double>& x_values, int attempts) {
    // 1. First, try to optimize constants of the CURRENT tree
    NodePtr optimized_tree = clone_tree(tree);
    optimize_constants(optimized_tree, targets, x_values, nullptr, nullptr);
    double optimized_fitness = evaluate_fitness(optimized_tree, targets, x_values);
    
    NodePtr best_neighbor = (optimized_fitness < current_fitness) ? optimized_tree : tree;
    double best_neighbor_fitness = (optimized_fitness < current_fitness) ? optimized_fitness : current_fitness;

    if (best_neighbor_fitness >= INF) return {best_neighbor, best_neighbor_fitness};

    for (int i = 0; i < attempts; ++i) {
        NodePtr neighbor = mutate_tree(best_neighbor, 1.0, 2);
        neighbor = DomainConstraints::fix_or_simplify(neighbor);
        if (!neighbor) continue;
        double neighbor_fitness = evaluate_fitness(neighbor, targets, x_values);
        if (neighbor_fitness < best_neighbor_fitness) {
            best_neighbor = neighbor;
            best_neighbor_fitness = neighbor_fitness;
        }
    }
    return {best_neighbor, best_neighbor_fitness};
}
#endif

//---------------------------------
// Target Pattern Detection
//---------------------------------
// Classifies a target series. Returns {pattern name, pattern value}:
//   {"constant_zero", 0.0} - every value is within 1e-9 of zero;
//   {"arithmetic", d}      - consecutive differences all equal d (tolerance 1e-6);
//   {"geometric", r}       - consecutive ratios all equal r (tolerance 1e-6);
//   {"none", 0.0}          - fewer than 3 values, or none of the above.
//
// The all-zero test runs first: an all-zero series also satisfies the
// arithmetic test (d = 0), so checking it later would make "constant_zero"
// unreachable.
std::pair<std::string, double> detect_target_pattern(const std::vector<double>& targets) {
    if (targets.size() < 3) return {"none", 0.0};

    const double zero_tol = 1e-9;
    const double match_tol = 1e-6;

    // --- Constant zero ---
    bool all_zero = true;
    for (double t : targets) {
        if (!(std::fabs(t) <= zero_tol)) { all_zero = false; break; }
    }
    if (all_zero) return {"constant_zero", 0.0};

    // --- Arithmetic progression ---
    const double diff = targets[1] - targets[0];
    bool is_arithmetic = true;
    for (size_t i = 2; i < targets.size(); ++i) {
        if (std::fabs((targets[i] - targets[i - 1]) - diff) > match_tol) { is_arithmetic = false; break; }
    }
    if (is_arithmetic) return {"arithmetic", diff};

    // --- Geometric progression ---
    // A series that starts at ~0 without being all zero cannot be geometric.
    if (std::fabs(targets[0]) < zero_tol) return {"none", 0.0};

    const double ratio = targets[1] / targets[0];
    bool is_geometric = true;
    for (size_t i = 2; i < targets.size(); ++i) {
        if (std::fabs(targets[i - 1]) < zero_tol) {
            // After a ~0 term every later term must stay ~0.
            if (std::fabs(targets[i]) > zero_tol) { is_geometric = false; break; }
        } else if (std::fabs((targets[i] / targets[i - 1]) - ratio) > match_tol) {
            is_geometric = false;
            break;
        }
    }
    if (is_geometric) return {"geometric", ratio};

    return {"none", 0.0};
}

//---------------------------------
// Generate Pattern Based Tree
//---------------------------------
// Builds a seed tree for a pattern reported by detect_target_pattern, anchored
// at the first data point (X_VALUES[0], RAW_TARGETS[0]).
//   "arithmetic"    -> c + (d * x)  with c = a - d * x0
//   "geometric"     -> c * (r ^ x)  with c = a / r^x0
//   "constant_zero" -> the constant 0
// Composite trees are passed through DomainConstraints::fix_or_simplify.
// Returns nullptr when there is no data, the pattern type is unknown, or the
// geometric coefficient cannot be computed as a finite number.
NodePtr generate_pattern_based_tree(const std::string& pattern_type, double pattern_value) {
    if (X_VALUES.empty() || RAW_TARGETS.empty()) return nullptr;
    const double a = RAW_TARGETS[0];
    const double x0 = X_VALUES[0];

    // Constant leaf: rounded when FORCE_INTEGER_CONSTANTS is set, snapped to 0 when ~0.
    const auto make_constant = [](double v) {
        auto leaf = std::make_shared<Node>(NodeType::Constant);
        if (FORCE_INTEGER_CONSTANTS) v = std::round(v);
        leaf->value = (std::fabs(v) < SIMPLIFY_NEAR_ZERO_TOLERANCE) ? 0.0 : v;
        return leaf;
    };
    // Operator node `lhs op rhs`.
    const auto make_binary = [](char op, NodePtr lhs, NodePtr rhs) {
        auto parent = std::make_shared<Node>(NodeType::Operator);
        parent->op = op;
        parent->left = lhs;
        parent->right = rhs;
        return parent;
    };

    if (pattern_type == "arithmetic") {
        const double d = pattern_value;
        auto intercept = make_constant(a - d * x0);
        auto slope_term = make_binary('*', make_constant(d), std::make_shared<Node>(NodeType::Variable));
        return DomainConstraints::fix_or_simplify(make_binary('+', intercept, slope_term));
    }

    if (pattern_type == "geometric") {
        const double r = pattern_value;
        if (std::fabs(r) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return nullptr;

        // pow() yields NaN for a negative base with a non-integer exponent and
        // can overflow to infinity; such a value must not end up in a constant.
        const double rpx0 = std::pow(r, x0);
        if (!std::isfinite(rpx0) || std::fabs(rpx0) < 1e-100) return nullptr;
        const double coefficient = a / rpx0;
        if (!std::isfinite(coefficient)) return nullptr;

        auto scale = make_constant(coefficient);
        auto power_term = make_binary('^', make_constant(r), std::make_shared<Node>(NodeType::Variable));
        return DomainConstraints::fix_or_simplify(make_binary('*', scale, power_term));
    }

    if (pattern_type == "constant_zero") {
        auto node = std::make_shared<Node>(NodeType::Constant);
        node->value = 0.0;
        return node;
    }

    return nullptr; // No pattern tree generated
}


In [None]:
%%writefile Code/src/AdvancedFeatures.h
#ifndef ADVANCEDFEATURES_H
#define ADVANCEDFEATURES_H

#include "ExpressionTree.h"
#include "Globals.h" // Incluir Globals.h para INF
#include <vector>
#include <string>
#include <map>
#include <set>
#include <utility> // Para std::pair
#include <unordered_map>

// Meta-evolution: parameters that can adapt while the algorithm runs.
// NOTE: create_default() in AdvancedFeatures.cpp initialises this struct
// positionally, so the member order below must not change.
struct EvolutionParameters {
    double mutation_rate;    // Current mutation rate
    double elite_percentage; // Current elite percentage
    int tournament_size;     // Current tournament size
    double crossover_rate;   // Current crossover rate

    // Creates a parameter set holding the default (initial) values.
    static EvolutionParameters create_default();

    // Adapts (mutates) the parameters slightly.
    // Takes the stagnation counter, which scales the intensity of the change.
    void mutate(int stagnation_counter);
};

// Pattern memory: stores the structural patterns of successful trees together
// with usage statistics, and can suggest a tree based on them.
class PatternMemory {
    // NOTE: record_success() initialises this struct positionally, so the
    // member order below must not change.
    struct PatternInfo {
        std::string pattern_str; // String representation of the pattern
        double best_fitness = INF; // Best fitness seen for this pattern
        int uses = 0;             // Number of times the pattern was recorded
        double success_rate = 0.0; // Estimated success rate
    };
    std::unordered_map<std::string, PatternInfo> patterns; // Pattern string -> statistics
    // Minimum uses before a pattern may be suggested.
    // NOTE(review): suggest_pattern_based_tree() in AdvancedFeatures.cpp compares
    // against the global PATTERN_MEM_MIN_USES instead of this member - confirm
    // which of the two is meant to be authoritative.
    int min_uses_for_suggestion = 3;

public:
    // Records an observation of a tree (and its pattern) with the fitness it achieved.
    void record_success(const NodePtr& tree, double fitness);
    // Suggests a tree based on the stored patterns; returns nullptr when no
    // pattern qualifies.
    NodePtr suggest_pattern_based_tree(int max_depth);

private:
    // Extracts the structural (string) representation of a tree.
    std::string extract_pattern(const NodePtr& tree);
    // Builds a tree from a pattern string - simplified implementation.
    NodePtr parse_pattern(const std::string& pattern, int max_depth);
};


// Pareto optimisation: one candidate on the accuracy/complexity trade-off.
// Both objectives are minimised (see dominates()).
struct ParetoSolution {
    NodePtr tree = nullptr;   // Tree of the solution
    double accuracy = INF;    // Objective 1: accuracy (fitness value)
    double complexity = INF;  // Objective 2: complexity (tree size)
    bool dominated = false;   // Flag: is this solution dominated by another one?

    // Default constructor (needed for use in containers)
    ParetoSolution() = default;
    // Main constructor
    ParetoSolution(NodePtr t, double acc, double complexity_val);

    // Checks whether this solution dominates `other`.
    bool dominates(const ParetoSolution& other) const;
};

// Maintains the front of non-dominated solutions across generations.
class ParetoOptimizer {
    std::vector<ParetoSolution> pareto_front; // Solutions currently on the front
    // Optional limit for the size of the front.
    // NOTE(review): update() in AdvancedFeatures.cpp truncates with the global
    // PARETO_MAX_FRONT_SIZE instead of this member - confirm which is intended.
    size_t max_front_size = 50;

public:
    // Updates the Pareto front with individuals from the current population.
    void update(const std::vector<struct Individual>& population, // Uses the Individual struct
                const std::vector<double>& targets,
                const std::vector<double>& x_values);

    // Returns the trees (NodePtr) of the solutions on the current front.
    std::vector<NodePtr> get_pareto_solutions();

    // Returns a const reference to the complete Pareto front.
    const std::vector<ParetoSolution>& get_pareto_front() const { return pareto_front; }
};


// Domain constraints: validates and fixes/simplifies problematic trees.
class DomainConstraints {
public:
    // Checks whether a tree satisfies the basic static validity rules.
    static bool is_valid(const NodePtr& tree);

    // Tries to simplify/fix a tree (returns a modified copy).
    static NodePtr fix_or_simplify(NodePtr tree);

private:
     // Recursive helper for the simplification.
    static NodePtr simplify_recursive(NodePtr node);
    // Recursive helper for the static validation.
    static bool is_valid_recursive(const NodePtr& node);
};

// Local search: tries to improve a given solution by exploring nearby
// neighbours. Returns the best tree found and its fitness.
// The GPU build takes two extra pointers (d_targets, d_x_values) and has no
// default for `attempts`; the CPU build defaults `attempts` to 10.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
std::pair<NodePtr, double> try_local_improvement(const NodePtr& tree,
                                                  double current_fitness,
                                                  const std::vector<double>& targets,
                                                  const std::vector<double>& x_values,
                                                  int attempts,
                                                  double* d_targets, double* d_x_values);
#else
std::pair<NodePtr, double> try_local_improvement(const NodePtr& tree,
                                                  double current_fitness,
                                                  const std::vector<double>& targets,
                                                  const std::vector<double>& x_values,
                                                  int attempts = 10);
#endif


// Pattern detection in the target data.
// detect_target_pattern returns {pattern name, pattern value}; the pair can be
// passed to generate_pattern_based_tree, which returns a seed tree or nullptr.
std::pair<std::string, double> detect_target_pattern(const std::vector<double>& targets);
NodePtr generate_pattern_based_tree(const std::string& pattern_type, double pattern_value);


#endif // ADVANCEDFEATURES_H


In [None]:
%%writefile Code/src/ExpressionTree.cpp
#include "ExpressionTree.h"
#include "Globals.h"
#include <cmath>
#include <limits>
#include <stdexcept>
#include <vector>
#include <iostream>
#include <iomanip>
#include <string>
#include <sstream>
#include <stack>
#include <unordered_map>
#include <algorithm> // Para std::remove_if
#include <cctype>    // Para isdigit, isspace
#include <thread>    // Para thread_local RNG

// --- Helper: format a constant as text ---
// Values within SIMPLIFY_NEAR_ZERO_TOLERANCE of an integer are printed as that
// integer. Everything else is printed with 8 digits after the decimal point,
// in scientific notation when |val| >= 1e6 or |val| <= 1e-6 and in fixed
// notation otherwise, with trailing zeros of the mantissa removed.
std::string format_constant(double val) {
    // Integer form. The magnitude guard keeps the cast to long long defined:
    // converting a double outside the long long range is undefined behaviour.
    // Larger integers fall through to scientific notation.
    const double rounded = std::round(val);
    if (std::fabs(val - rounded) < SIMPLIFY_NEAR_ZERO_TOLERANCE && std::fabs(rounded) < 9.0e18) {
        return std::to_string(static_cast<long long>(rounded));
    }

    std::ostringstream oss;
    if (std::fabs(val) >= 1e6 || std::fabs(val) <= 1e-6) { // Adjustable thresholds
        oss << std::scientific << std::setprecision(8) << val;
    } else {
        oss << std::fixed << std::setprecision(8) << val;
    }
    std::string s = oss.str();

    // Detach the exponent suffix ("e+10") so that zero stripping only touches
    // the mantissa; stripping the whole string would eat exponent digits.
    std::string exponent;
    const std::string::size_type exponent_pos = s.find_first_of("eE");
    if (exponent_pos != std::string::npos) {
        exponent = s.substr(exponent_pos);
        s.erase(exponent_pos);
    }

    // Drop trailing zeros and a dangling decimal point.
    if (s.find('.') != std::string::npos) {
        s.erase(s.find_last_not_of('0') + 1, std::string::npos);
        if (!s.empty() && s.back() == '.') s.pop_back();
    }

    s += exponent;
    return s.empty() ? "0" : s;
}

// --- evaluate_tree ---
// Recursively evaluates the expression tree rooted at `node` for the input `x`.
//
// Returns:
//   * the numeric result on success;
//   * NaN for a null node, an unknown node type / operator code, or when any
//     operand evaluates to NaN;
//   * INF (project-wide sentinel) for protected failures: near-zero divisor or
//     modulus, invalid powers, log of ~0, exp/factorial overflow guards, or an
//     infinite result.
double evaluate_tree(const NodePtr& node, double x) {
    if (!node) return std::nan("");
    if (node->type == NodeType::Constant) return node->value;
    if (node->type == NodeType::Variable) return x;
    if (node->type != NodeType::Operator) return std::nan("");

    const char op = node->op;

    // Operator codes that consume only the left child.
    bool takes_one_operand = false;
    switch (op) {
        case 's': case 'c': case 'l': case 'e': case '!': case '_':
        case 'g': case 't': case 'q': case 'a': case 'n': case 'u':
            takes_one_operand = true;
            break;
        default:
            break;
    }

    const double lhs = evaluate_tree(node->left, x);
    const double rhs = takes_one_operand ? 0.0 : evaluate_tree(node->right, x);

    if (std::isnan(lhs)) return std::nan("");
    if (!takes_one_operand && std::isnan(rhs)) return std::nan("");

    double value = std::nan("");
    try {
        switch (op) {
            // ---- binary arithmetic ----
            case '+': value = lhs + rhs; break;
            case '-': value = lhs - rhs; break;
            case '*': value = lhs * rhs; break;
            case '/':
                // Protected division: a near-zero divisor is a failure.
                if (std::fabs(rhs) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return INF;
                value = lhs / rhs;
                break;
            case '^':
                if (lhs == 0.0 && rhs == 0.0) {
                    value = 1.0; // 0^0 is defined as 1 here
                } else if (lhs == 0.0 && rhs < 0.0) {
                    return INF;  // 0 raised to a negative power
                } else if (lhs < 0.0 &&
                           std::fabs(rhs - std::round(rhs)) > SIMPLIFY_NEAR_ZERO_TOLERANCE) {
                    return INF;  // negative base with a non-integer exponent
                } else {
                    value = std::pow(lhs, rhs);
                }
                break;
            case '%':
                if (std::fabs(rhs) < SIMPLIFY_NEAR_ZERO_TOLERANCE) return INF;
                value = std::fmod(lhs, rhs);
                break;

            // ---- unary functions ----
            case 's': value = std::sin(lhs); break;
            case 'c': value = std::cos(lhs); break;
            case 't': value = std::tan(lhs); break;
            case 'q': value = std::sqrt(std::abs(lhs)); break;   // protected sqrt: sqrt(|x|)
            case 'a': value = std::abs(lhs); break;
            case 'n': value = static_cast<double>((lhs > 0) - (lhs < 0)); break; // sign
            case 'l':
                // Protected log: log(|x|), failing near zero.
                if (std::abs(lhs) <= 1e-9) return INF;
                value = std::log(std::abs(lhs));
                break;
            case 'e':
                if (lhs > 700.0) return INF; // overflow guard
                value = std::exp(lhs);
                break;
            case '!':
                // Protected factorial via gamma: tgamma(|x| + 1).
                if (std::abs(lhs) > 170.0) return INF;
                value = std::tgamma(std::abs(lhs) + 1.0);
                break;
            case '_': value = std::floor(lhs); break;
            case 'u': value = std::ceil(lhs); break;             // 'u' = ceil ("up")
            case 'g': value = std::lgamma(std::abs(lhs) + 1.0); break;

            default:
                return std::nan("");
        }
    } catch (const std::exception&) {
        return INF;
    }

    if (std::isinf(value)) return INF;
    return std::isnan(value) ? std::nan("") : value;
}

// --- tree_to_string ---
// Renders the tree rooted at `node` as an infix string.
//
//   * null node          -> "NULL"
//   * Constant           -> format_constant(value)
//   * Variable           -> "x"
//   * unary operators    -> function notation ("sin(...)", "floor(...)", ...),
//                           except factorial, which is postfix: "(...)!"
//   * binary operators   -> fully parenthesised "(L op R)"
//
// Cosmetic rewrites for binary nodes:
//   * "A + (-c)" prints as "(A-c)" and "A - (-c)" as "(A+c)";
//   * "0 - A" prints as "(-A)";
//   * a negative constant used as the base of '^' is wrapped in parentheses,
//     because "(-2^x)" would otherwise read as -(2^x) under standard precedence.
std::string tree_to_string(const NodePtr& node) {
    if (!node) return "NULL";
    switch (node->type) {
        case NodeType::Constant: return format_constant(node->value);
        case NodeType::Variable: return "x";
        case NodeType::Operator: {
            NodePtr left_node = node->left;
            std::string left_str = tree_to_string(left_node);

            // Unary operators only use the left child.
            switch (node->op) {
                case 's': return "sin(" + left_str + ")";
                case 'c': return "cos(" + left_str + ")";
                case 't': return "tan(" + left_str + ")";
                case 'q': return "sqrt(" + left_str + ")";
                case 'a': return "abs(" + left_str + ")";
                case 'n': return "sign(" + left_str + ")";
                case 'l': return "log(" + left_str + ")";
                case 'e': return "exp(" + left_str + ")";
                case '!': return "(" + left_str + ")!"; // Postfix for factorial
                case '_': return "floor(" + left_str + ")";
                case 'u': return "ceil(" + left_str + ")";
                case 'g': return "lgamma(" + left_str + ")";
                default: break; // binary operator (or unknown code): handled below
            }

            NodePtr right_node = node->right;
            std::string right_str = tree_to_string(right_node);
            char current_op = node->op;

            // Fold the sign of a negative right-hand constant into '+' / '-'.
            bool right_is_neg_const = (right_node && right_node->type == NodeType::Constant && right_node->value < 0.0);
            if (right_is_neg_const) {
                double abs_right_val = std::fabs(right_node->value);
                std::string abs_right_str = format_constant(abs_right_val);
                if (node->op == '+') { current_op = '-'; right_str = abs_right_str; }
                else if (node->op == '-') { current_op = '+'; right_str = abs_right_str; }
            }

            // Print (0-A) as (-A).
            if (left_node && left_node->type == NodeType::Constant && left_node->value == 0.0 && current_op == '-') {
                return "(-" + right_str + ")";
            }

            // Keep a negative constant base unambiguous: (-2)^x, not -2^x.
            if (current_op == '^' && left_node && left_node->type == NodeType::Constant &&
                !left_str.empty() && left_str.front() == '-') {
                left_str = "(" + left_str + ")";
            }

            return "(" + left_str + current_op + right_str + ")";
        }
        default: return "?";
    }
}

// --- tree_size ---
// Counts the nodes of the tree rooted at `node`.
// A null node counts as 0, a leaf (Constant / Variable) as 1, and an operator
// as 1 plus the sizes of both children (a missing child contributes 0).
int tree_size(const NodePtr& node) {
    if (!node) return 0;
    switch (node->type) {
        case NodeType::Constant:
        case NodeType::Variable:
            return 1;
        case NodeType::Operator:
            return 1 + tree_size(node->left) + tree_size(node->right);
    }
    return 0; // unknown node type
}

// --- clone_tree ---
// Returns a deep copy of the tree rooted at `node` (nullptr for a null input).
// The copy shares no Node with the original.
NodePtr clone_tree(const NodePtr& node) {
    if (!node) return nullptr;

    // Copy the scalar fields in one go, then replace the (shallow-copied)
    // child pointers with freshly cloned subtrees.
    NodePtr duplicate = std::make_shared<Node>(*node);
    duplicate->left = clone_tree(node->left);
    duplicate->right = clone_tree(node->right);
    return duplicate;
}

// --- collect_node_ptrs ---
// Appends to `vec`, in pre-order, the address of every non-null NodePtr slot
// reachable from `node` (the slot itself first, then the left and right
// subtrees). Children are only visited for Operator nodes.
// The stored pointers refer to slots inside the tree, so they are only valid
// while those slots stay alive and in place.
void collect_node_ptrs(NodePtr& node, std::vector<NodePtr*>& vec) {
    if (node == nullptr) return;

    vec.push_back(&node);
    if (node->type != NodeType::Operator) return;

    collect_node_ptrs(node->left, vec);
    collect_node_ptrs(node->right, vec);
}

// --- get_rng ---
// Returns a per-thread Mersenne Twister engine. Each thread lazily creates its
// own instance, so callers on different threads never share generator state.
// The seed mixes a std::random_device draw with a hash of the thread id.
std::mt19937& get_rng() {
    thread_local std::mt19937 local_rng([] {
        const unsigned entropy = std::random_device{}();
        const unsigned thread_salt =
            static_cast<unsigned>(std::hash<std::thread::id>{}(std::this_thread::get_id()));
        return entropy ^ thread_salt;
    }());
    return local_rng;
}

// --- get_tree_depth ---
// Returns the depth of the tree rooted at `node`: 0 for a null node, 1 for a
// leaf, and 1 + the deeper of the two children for an operator.
int get_tree_depth(const NodePtr& node) {
    if (!node) return 0;
    if (node->type != NodeType::Operator) return 1;

    const int left_depth = get_tree_depth(node->left);
    const int right_depth = get_tree_depth(node->right);
    return 1 + (left_depth < right_depth ? right_depth : left_depth);
}

// --- trim_tree ---
// Limits the tree rooted at `node` to at most `max_depth` levels, in place.
// Any operator found at the depth limit (max_depth <= 1) is turned into the
// variable 'x' and its children are released; leaves are left untouched.
void trim_tree(NodePtr& node, int max_depth) {
    if (!node) return;

    // Leaves never need trimming, whatever the remaining depth budget is.
    if (node->type != NodeType::Operator) return;

    if (max_depth <= 1) {
        // Depth limit reached: collapse the operator into a terminal.
        // 'x' is used as the replacement terminal; `value` is ignored for
        // Variable nodes, so it is left as is.
        node->type = NodeType::Variable;
        node->op = 0;
        node->left = nullptr;
        node->right = nullptr;
        return;
    }

    const int remaining = max_depth - 1;
    trim_tree(node->left, remaining);
    trim_tree(node->right, remaining);
}


// ============================================================
// --- Parser de Fórmulas desde String (v4 - Parser Corregido) ---
// ============================================================

// Helper: binding strength of a binary operator used by the formula parser.
// Returns 1 for '+' / '-', 2 for '*' / '/' / '%', 3 for '^', and 0 for any
// other character.
int get_precedence(char op) {
    if (op == '+' || op == '-') return 1;
    if (op == '*' || op == '/' || op == '%') return 2;
    if (op == '^') return 3;
    return 0;
}

// Helper: builds the Operator node `left op right`.
// Note the parameter order: the RIGHT operand comes first, matching the order
// in which operands are popped from the parser's stack.
// Throws std::runtime_error if either operand is null.
NodePtr apply_binary_operation(NodePtr right, NodePtr left, char op) {
    const bool missing_operand = (left == nullptr) || (right == nullptr);
    if (missing_operand) {
        std::string message = "Error al aplicar operación binaria '";
        message += op;
        message += "': operandos insuficientes.";
        throw std::runtime_error(message);
    }

    NodePtr parent = std::make_shared<Node>(NodeType::Operator);
    parent->op = op;
    parent->left = left;
    parent->right = right;
    return parent;
}

// Main entry point: parses an infix formula string into an expression tree.
//
// Supported syntax (whitespace is ignored):
//   * numbers ("3", "2.5", ".5"), the variable "x", and the constants "pi",
//     "e" and "C" (a placeholder constant, initialised to 1.0);
//   * binary operators + - * / % ^ with the usual precedence; '^' is
//     right-associative;
//   * unary '+' (ignored) and unary '-' (stored as the subtree "0 - operand").
//     Unary '-' binds tighter than '*' but looser than '^', so "-x^2" is
//     -(x^2) and "2*-3" is 2*(-3);
//   * parentheses and the functions listed in `func_map`, called as name(arg);
//   * implicit multiplication between adjacent operands ("2x", "3(x+1)").
//
// Throws std::runtime_error on any syntax error.
NodePtr parse_formula_string(const std::string& formula_raw) {
    // Internal operator-stack marker for prefix negation. It never comes from
    // the input: a literal '~' in the formula is rejected as an unknown token.
    const char UNARY_MINUS = '~';

    // Function name -> operator code, matching the codes used by evaluate_tree
    // and tree_to_string. Built once instead of on every token.
    static const std::unordered_map<std::string, char> func_map = {
        {"sin", 's'}, {"cos", 'c'}, {"tan", 't'},
        {"log", 'l'}, {"exp", 'e'}, {"sqrt", 'q'},
        {"floor", '_'}, {"ceil", 'u'}, {"abs", 'a'}, {"sign", 'n'},
        {"gamma", '!'}, {"lgamma", 'g'}, {"g", 'g'}
    };

    // <cctype> classifiers require a value representable as unsigned char.
    auto is_digit = [](char ch) { return std::isdigit(static_cast<unsigned char>(ch)) != 0; };

    std::string formula = formula_raw;
    formula.erase(std::remove_if(formula.begin(), formula.end(),
                                 [](unsigned char ch) { return std::isspace(ch) != 0; }),
                  formula.end());
    if (formula.empty()) throw std::runtime_error("La fórmula está vacía.");

    std::stack<NodePtr> operand_stack;
    std::stack<char> operator_stack;
    bool last_token_was_operand = false;

    // Precedence on a finer scale than get_precedence so that unary minus can
    // sit between '*' (4) and '^' (6).
    auto stack_precedence = [&](char op) -> int {
        return (op == UNARY_MINUS) ? 5 : 2 * get_precedence(op);
    };

    // Pops the top operator and applies it to the operand stack.
    auto reduce_top = [&]() {
        const char top_op = operator_stack.top();
        operator_stack.pop();

        if (top_op == UNARY_MINUS) {
            if (operand_stack.empty()) throw std::runtime_error("Operandos insuficientes para operador '-'.");
            NodePtr operand = operand_stack.top(); operand_stack.pop();
            // The tree has no negation node: -A is stored as (0 - A).
            auto zero_node = std::make_shared<Node>(NodeType::Constant);
            zero_node->value = 0.0;
            operand_stack.push(apply_binary_operation(operand, zero_node, '-'));
            return;
        }

        if (operand_stack.size() < 2) throw std::runtime_error("Operandos insuficientes para operador '" + std::string(1, top_op) + "'.");
        NodePtr right = operand_stack.top(); operand_stack.pop();
        NodePtr left = operand_stack.top(); operand_stack.pop();
        operand_stack.push(apply_binary_operation(right, left, top_op));
    };

    // Before pushing the binary operator `incoming_op`, applies every pending
    // operator that binds at least as tightly (strictly tighter for the
    // right-associative '^'). Never crosses an open parenthesis.
    auto reduce_for_incoming = [&](char incoming_op) {
        const int incoming_precedence = stack_precedence(incoming_op);
        const bool is_right_associative = (incoming_op == '^');
        while (!operator_stack.empty() && operator_stack.top() != '(') {
            const int top_precedence = stack_precedence(operator_stack.top());
            const bool should_reduce = is_right_associative
                ? (top_precedence > incoming_precedence)
                : (top_precedence >= incoming_precedence);
            if (!should_reduce) break;
            reduce_top();
        }
    };

    // Two adjacent operands ("2x", "x(…)") mean multiplication.
    auto insert_implicit_multiplication = [&]() {
        if (!last_token_was_operand) return;
        reduce_for_incoming('*');
        operator_stack.push('*');
        last_token_was_operand = false;
    };

    auto push_operand = [&](const NodePtr& operand) {
        operand_stack.push(operand);
        last_token_was_operand = true;
    };

    auto make_constant = [](double value) {
        auto node = std::make_shared<Node>(NodeType::Constant);
        node->value = value;
        return node;
    };

    for (size_t i = 0; i < formula.length(); /* advanced manually */ ) {
        const char token = formula[i];

        // --- A. Numbers ---
        const bool starts_number = is_digit(token) ||
            (token == '.' && i + 1 < formula.length() && is_digit(formula[i + 1]));
        if (starts_number) {
            insert_implicit_multiplication();
            std::string num_str;
            if (token == '.') num_str += '0';
            num_str += token;
            i++;
            while (i < formula.length() &&
                   (is_digit(formula[i]) || (formula[i] == '.' && num_str.find('.') == std::string::npos))) {
                num_str += formula[i];
                i++;
            }
            double value = 0.0;
            try {
                value = std::stod(num_str);
            } catch (const std::invalid_argument& e) {
                throw std::runtime_error("Número inválido (formato): '" + num_str + "' - " + e.what());
            } catch (const std::out_of_range& e) {
                throw std::runtime_error("Número inválido (rango): '" + num_str + "' - " + e.what());
            }
            push_operand(make_constant(value));
            continue;
        }

        // --- B. Named constants and unary functions ---
        if (token == 'C') {
            // Placeholder constant; parsed as 1.0.
            insert_implicit_multiplication();
            push_operand(make_constant(1.0));
            i++;
            continue;
        }
        if (i + 1 < formula.length() && formula.substr(i, 2) == "pi") {
            insert_implicit_multiplication();
            push_operand(make_constant(3.14159265359));
            i += 2;
            continue;
        }
        if (token == 'e' && (i + 1 >= formula.length() || formula[i + 1] != 'x')) { // not the start of "exp"
            insert_implicit_multiplication();
            push_operand(make_constant(2.71828182846));
            i++;
            continue;
        }

        // A name only counts as a function when it is immediately followed by
        // '(', which also keeps "g(" from shadowing "gamma(" / "lgamma(".
        bool matched_func = false;
        for (const auto& [func_name, func_op] : func_map) {
            const size_t after_name = i + func_name.length();
            if (after_name >= formula.length() || formula[after_name] != '(') continue;
            if (formula.compare(i, func_name.length(), func_name) != 0) continue;

            insert_implicit_multiplication();

            // Find the matching closing parenthesis.
            int paren_count = 1;
            const size_t arg_start = after_name + 1;
            size_t j = arg_start;
            while (j < formula.length() && paren_count > 0) {
                if (formula[j] == '(') paren_count++;
                else if (formula[j] == ')') paren_count--;
                j++;
            }
            if (paren_count != 0) {
                throw std::runtime_error("Paréntesis sin cerrar en función '" + func_name + "'.");
            }
            const size_t arg_end = j - 1; // position of the closing ')'

            // The argument is a complete sub-formula: parse it recursively.
            NodePtr arg_tree = parse_formula_string(formula.substr(arg_start, arg_end - arg_start));

            auto func_node = std::make_shared<Node>(NodeType::Operator);
            func_node->op = func_op;
            func_node->left = arg_tree;
            func_node->right = nullptr;

            push_operand(func_node);
            i = j; // continue right after the closing ')'
            matched_func = true;
            break;
        }
        if (matched_func) continue;

        // --- C. Variable 'x' ---
        if (token == 'x') {
            insert_implicit_multiplication();
            push_operand(std::make_shared<Node>(NodeType::Variable));
            i++;
            continue;
        }

        // --- D. Opening parenthesis ---
        if (token == '(') {
            insert_implicit_multiplication();
            operator_stack.push('(');
            last_token_was_operand = false;
            i++;
            continue;
        }

        // --- E. Closing parenthesis ---
        if (token == ')') {
            if (!last_token_was_operand) {
                if (!operator_stack.empty() && operator_stack.top() == '(') throw std::runtime_error("Paréntesis vacíos '()' encontrados.");
                else throw std::runtime_error("Se esperaba un operando antes de ')'.");
            }
            while (!operator_stack.empty() && operator_stack.top() != '(') {
                reduce_top();
            }
            if (operator_stack.empty()) throw std::runtime_error("Paréntesis ')' sin correspondiente '('.");
            operator_stack.pop(); // discard '('
            last_token_was_operand = true;
            i++;
            continue;
        }

        // --- F. Operators (+ - * / ^ %) ---
        if (std::string("+-*/^%").find(token) != std::string::npos) {
            if (!last_token_was_operand) {
                // Prefix position: only a sign is allowed here.
                if (token == '-') {
                    // Pushed WITHOUT reducing pending operators: the negation
                    // belongs to the operand that follows, not to whatever
                    // precedes it ("2*-3" must stay 2*(-3)).
                    operator_stack.push(UNARY_MINUS);
                    i++;
                    continue;
                }
                if (token == '+') { // unary plus has no effect
                    i++;
                    continue;
                }
                throw std::runtime_error("Operador binario '" + std::string(1, token) + "' inesperado. Se esperaba operando.");
            }

            reduce_for_incoming(token);
            operator_stack.push(token);
            last_token_was_operand = false; // an operand must follow
            i++;
            continue;
        }

        // --- G. Unknown token ---
        throw std::runtime_error("Token desconocido en la fórmula: '" + std::string(1, token) + "'");
    }

    // --- H. Apply whatever is still pending ---
    while (!operator_stack.empty()) {
        if (operator_stack.top() == '(') throw std::runtime_error("Paréntesis '(' sin cerrar al final.");
        reduce_top();
    }

    // Exactly one operand (the finished tree) must remain.
    if (operand_stack.size() != 1) {
        if (operand_stack.empty() && formula.length() > 0) throw std::runtime_error("Error: No se generó ningún resultado del parseo. Fórmula inválida?");
        else if (operand_stack.size() > 1) throw std::runtime_error("Error en la estructura final (operandos restantes: " + std::to_string(operand_stack.size()) + "). Verifique operadores.");
        else throw std::runtime_error("Error desconocido al finalizar el parseo.");
    }

    return operand_stack.top();
}


In [None]:
%%writefile Code/src/ExpressionTree.h
#ifndef EXPRESSIONTREE_H
#define EXPRESSIONTREE_H

#include <memory>
#include <string>
#include <vector>
#include <stdexcept> // Para std::runtime_error

// Forward declaration
struct Node;

// Use shared_ptr for automatic memory management
using NodePtr = std::shared_ptr<Node>;

enum class NodeType { Constant, Variable, Operator };

// A single node of an expression tree. Which fields are meaningful depends
// on `type`.
struct Node {
    NodeType type;
    double value = 0.0;             // Numeric value; meaningful when type == Constant
    // Operator code; meaningful when type == Operator.
    //   Binary: '+', '-', '*', '/', '^', '%'
    //   Unary (operand in `left`, `right` unused): 's' sin, 'c' cos, 't' tan,
    //     'q' sqrt, 'a' abs, 'n' sign, 'l' log, 'e' exp, '!' factorial,
    //     '_' floor, 'u' ceil, 'g' lgamma
    char op = 0;
    NodePtr left = nullptr;         // Children (for Operators)
    NodePtr right = nullptr;

    // Convenience constructor; a default-constructed Node is a Constant with value 0.
    Node(NodeType t = NodeType::Constant) : type(t) {}
};

// Core Tree Functions
double evaluate_tree(const NodePtr& node, double x);
std::string tree_to_string(const NodePtr& node);
int tree_size(const NodePtr& node);
NodePtr clone_tree(const NodePtr& node);
int get_tree_depth(const NodePtr& node);
void trim_tree(NodePtr& node, int max_depth);

// Helper for mutation/crossover
void collect_node_ptrs(NodePtr& node, std::vector<NodePtr*>& vec);

// --- NUEVO: Función para parsear una fórmula desde string ---
// Parsea una fórmula en notación infija simple (con paréntesis).
// Lanza std::runtime_error si hay error de sintaxis.
NodePtr parse_formula_string(const std::string& formula);


#endif // EXPRESSIONTREE_H


In [None]:
%%writefile Code/src/Fitness.cpp
#include "Fitness.h"
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
#include "FitnessGPU.cuh" // Include for GPU fitness evaluation
#endif
#include "Globals.h" // Necesario para constantes globales e INF
#include "ExpressionTree.h" // Necesario para tree_to_string
#include <cmath>
#include <limits>
#include <vector>
#include <numeric>
#include <iostream> // Para std::cerr en caso de error futuro
#include <iomanip>  // Para std::fixed/scientific si se necesita en errores

// Calculates the raw fitness using global parameters.
// This function will now dispatch to GPU if USE_GPU_ACCELERATION is enabled.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
// GPU-build variant of the raw (complexity-free) error. Lower is better.
//
// When both device pointers are non-null the work is delegated to
// evaluate_fitness_gpu. If either pointer is null, the error is computed on
// the host instead:
//   * INF if the input sizes differ, the input is empty, or any prediction is
//     NaN / infinite;
//   * otherwise the (optionally weighted) squared error, turned into an RMSE
//     when USE_RMSE_FITNESS is set, and scaled by FITNESS_PRECISION_BONUS when
//     every point is within FITNESS_PRECISION_THRESHOLD of its target.
double calculate_raw_fitness(const NodePtr& tree,
                             const std::vector<double>& targets,
                             const std::vector<double>& x_values,
                             double* d_targets, double* d_x_values) {
    // Device buffers available: evaluate on the GPU.
    if (d_targets != nullptr && d_x_values != nullptr) {
        return evaluate_fitness_gpu(tree, targets, x_values, d_targets, d_x_values);
    }

    // ---- Host fallback ----
    const size_t num_points = x_values.size();
    if (num_points != targets.size() || x_values.empty()) return INF;

    double weighted_sq_sum = 0.0;
    double weight_sum = 0.0;
    bool every_point_precise = true;

    for (size_t idx = 0; idx < num_points; ++idx) {
        const double prediction = evaluate_tree(tree, x_values[idx]);

        // A single failed evaluation invalidates the whole candidate.
        if (std::isnan(prediction) || std::isinf(prediction)) return INF;

        const double residual = prediction - targets[idx];
        if (std::fabs(residual) >= FITNESS_PRECISION_THRESHOLD) every_point_precise = false;

        // With weighting on, the weight grows exponentially with the point index.
        double point_weight = 1.0;
        if (USE_WEIGHTED_FITNESS) {
            point_weight = std::exp(static_cast<double>(idx) * WEIGHTED_FITNESS_EXPONENT);
        }
        weight_sum += point_weight;

        const double squared_residual = residual * residual;
        weighted_sq_sum += squared_residual * point_weight;
    }

    // Rescale the weighted sum so it is comparable to an unweighted sum over
    // num_points points.
    if (USE_WEIGHTED_FITNESS && weight_sum > 0.0) {
        weighted_sq_sum = weighted_sq_sum / weight_sum * num_points;
    }

    double raw_error = weighted_sq_sum;
    if (USE_RMSE_FITNESS && num_points > 0) {
        const double mean_sq_error = weighted_sq_sum / static_cast<double>(num_points);
        raw_error = std::sqrt(mean_sq_error);
    }

    if (std::isnan(raw_error) || std::isinf(raw_error) || raw_error < 0) {
        return INF;
    }

    if (every_point_precise) {
        raw_error *= FITNESS_PRECISION_BONUS;
    }
    return raw_error;
}
#else
// CPU-only variant of the raw (complexity-free) error. Lower is better.
//
// Returns INF if the input sizes differ, the input is empty, any prediction is
// NaN / infinite, or the accumulated error overflows. Otherwise returns either
// the weighted RMSE (USE_RMSE_FITNESS) or the weighted sum of
// |error|^FITNESS_ORIGINAL_POWER, scaled by FITNESS_PRECISION_BONUS when every
// point is within FITNESS_PRECISION_THRESHOLD of its target.
double calculate_raw_fitness(const NodePtr& tree,
                             const std::vector<double>& targets,
                             const std::vector<double>& x_values) {
    const size_t num_points = x_values.size();
    if (num_points != targets.size() || x_values.empty()) return INF;

    double power_error_sum = 0.0;   // only accumulated when USE_RMSE_FITNESS is false
    double weighted_sq_sum = 0.0;
    double weight_sum = 0.0;        // normaliser for the weighted RMSE
    bool every_point_precise = true;

    for (size_t idx = 0; idx < num_points; ++idx) {
        const double prediction = evaluate_tree(tree, x_values[idx]);

        // A single failed evaluation (INF or NaN) invalidates the candidate.
        if (std::isnan(prediction) || std::isinf(prediction)) return INF;

        const double residual = prediction - targets[idx];
        const double abs_residual = std::fabs(residual);
        if (abs_residual >= FITNESS_PRECISION_THRESHOLD) every_point_precise = false;

        // Weighted fitness: later points (higher index) get exponentially
        // larger weights, which heavily penalises models that only fit the
        // first points.
        double point_weight = 1.0;
        if (USE_WEIGHTED_FITNESS) {
            point_weight = std::exp(static_cast<double>(idx) * WEIGHTED_FITNESS_EXPONENT);
        }
        weight_sum += point_weight;

        if (!USE_RMSE_FITNESS) {
            power_error_sum += std::pow(abs_residual, FITNESS_ORIGINAL_POWER) * point_weight;
        }

        const double squared_residual = residual * residual;
        weighted_sq_sum += squared_residual * point_weight;

        // Bail out as soon as either accumulator overflows.
        const bool sq_overflow = std::isinf(weighted_sq_sum);
        const bool power_overflow = (power_error_sum >= INF / 10.0 && !USE_RMSE_FITNESS);
        if (sq_overflow || power_overflow) return INF;
    }

    // Pick the raw error metric.
    double raw_error;
    if (USE_RMSE_FITNESS) {
        if (num_points == 0 || weight_sum == 0.0) return INF;
        // Weighted MSE: normalise by the sum of weights, not by num_points.
        const double mean_sq_error = weighted_sq_sum / weight_sum;
        const bool mse_invalid = std::isinf(mean_sq_error) || std::isnan(mean_sq_error) || mean_sq_error < 0;
        raw_error = mse_invalid ? INF : std::sqrt(mean_sq_error);
    } else {
        raw_error = power_error_sum;
    }

    if (std::isnan(raw_error) || std::isinf(raw_error) || raw_error < 0) {
        return INF;
    }

    // Precision bonus when every point was within the threshold.
    if (every_point_precise) {
        raw_error *= FITNESS_PRECISION_BONUS;
    }

    return raw_error; // complexity penalty is applied by the caller
}
#endif // USE_GPU_ACCELERATION_DEFINED_BY_CMAKE

// Computes the final fitness: the raw error from calculate_raw_fitness,
// inflated multiplicatively by a complexity penalty proportional to the tree
// size (COMPLEXITY_PENALTY_FACTOR per node). Lower is better.
// Returns INF when the raw error is (near) INF or the result is NaN,
// infinite or negative.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
double evaluate_fitness(const NodePtr& tree,
                        const std::vector<double>& targets,
                        const std::vector<double>& x_values,
                        double* d_targets, double* d_x_values) {
    const double raw_fitness = calculate_raw_fitness(tree, targets, x_values, d_targets, d_x_values);
#else
double evaluate_fitness(const NodePtr& tree,
                        const std::vector<double>& targets,
                        const std::vector<double>& x_values) {
    const double raw_fitness = calculate_raw_fitness(tree, targets, x_values);
#endif

    // An invalid raw error stays invalid.
    if (raw_fitness >= INF / 10.0) {
        return INF;
    }

    // Multiplicative complexity penalty: raw * (1 + nodes * factor).
    const double node_count = static_cast<double>(tree_size(tree));
    const double complexity_penalty = node_count * COMPLEXITY_PENALTY_FACTOR;
    const double penalized_fitness = raw_fitness * (1.0 + complexity_penalty);

    const bool invalid = std::isnan(penalized_fitness) || std::isinf(penalized_fitness) || penalized_fitness < 0;
    if (invalid) {
        return INF;
    }
    return penalized_fitness;
}


In [None]:
%%writefile Code/src/Fitness.h
#ifndef FITNESS_H
#define FITNESS_H

#include "ExpressionTree.h"
#include <vector>

// Calculates raw fitness based on target matching (no complexity penalty).
// Lower is better. Returns INF if evaluation results in NaN/Inf.
//   tree     - expression tree to score
//   targets  - expected outputs
//   x_values - inputs; presumably paired index-wise with `targets` (confirm in Fitness.cpp)
// The GPU build adds d_targets / d_x_values, which look like device-side copies
// of the two vectors -- TODO confirm against the caller.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
double calculate_raw_fitness(const NodePtr& tree,
                             const std::vector<double>& targets,
                             const std::vector<double>& x_values,
                             double* d_targets, double* d_x_values);
#else
double calculate_raw_fitness(const NodePtr& tree,
                             const std::vector<double>& targets,
                             const std::vector<double>& x_values);
#endif

// Calculates final fitness including complexity penalty.
// The raw fitness is multiplied by (1 + tree_size(tree) * COMPLEXITY_PENALTY_FACTOR).
// Returns INF when the raw fitness is at or above INF / 10, or when the penalized
// value is NaN, infinite or negative. Lower is better.
// The GPU build takes the same extra d_targets / d_x_values pointers as
// calculate_raw_fitness and forwards them to it.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
double evaluate_fitness(const NodePtr& tree,
                        const std::vector<double>& targets,
                        const std::vector<double>& x_values,
                        double* d_targets, double* d_x_values);
#else
double evaluate_fitness(const NodePtr& tree,
                        const std::vector<double>& targets,
                        const std::vector<double>& x_values);
#endif

#endif // FITNESS_H


In [None]:
%%writefile Code/src/FitnessGPU.cu
#include "FitnessGPU.cuh"
#include "Globals.h"
#include <cuda_runtime.h>
#include <math.h>
#include <vector>
#include <iostream>

// Helper function to linearize the tree into a post-order array
void linearize_tree(const NodePtr& node, std::vector<LinearGpuNode>& linear_tree) {
    if (!node) {
        return;
    }
    linearize_tree(node->left, linear_tree);
    linearize_tree(node->right, linear_tree);
    linear_tree.push_back({node->type, node->value, node->op});
}

#if USE_GPU_ACCELERATION_DEFINED_BY_CMAKE

// Constant for large finite value
#define GPU_MAX_DOUBLE 1e308

// --- WEIGHTED FITNESS: Constantes para CUDA ---
// Estas deben coincidir con los valores en Globals.h
// CUDA device code no puede acceder a const C++, así que usamos #define
#define GPU_USE_WEIGHTED_FITNESS true
#define GPU_WEIGHTED_FITNESS_EXPONENT 0.25

// Single-tree evaluation kernel (legacy / single use).
//
// Launch layout: one thread per data point, with
// tree_size * sizeof(LinearGpuNode) bytes of dynamic shared memory so that each
// block can cache the post-order tree.
//
// For data point `idx` the thread interprets the tree on a fixed-size operand
// stack and writes d_raw_fitness_results[idx]:
//   - the squared error (prediction - target)^2, multiplied by
//     exp(idx * GPU_WEIGHTED_FITNESS_EXPONENT) when GPU_USE_WEIGHTED_FITNESS is set
//     (later points weigh more than earlier ones), or
//   - GPU_MAX_DOUBLE when the prediction is NaN/Inf, the program does not leave
//     exactly one value on the stack, or the operand stack would overflow.
// Guarded operations (log, exp, factorial, lgamma, division, modulo) produce
// GPU_MAX_DOUBLE instead of leaving their safe domain; unknown operator codes
// produce NaN.
__global__ void calculate_raw_fitness_kernel(const LinearGpuNode* d_linear_tree,
                                             int tree_size,
                                             const double* d_targets,
                                             const double* d_x_values,
                                             size_t num_points,
                                             double* d_raw_fitness_results) {
    // Shared memory optimization: the block keeps its own copy of the tree.
    extern __shared__ LinearGpuNode s_linear_tree[];

    // Cooperative, block-strided load. Every thread must reach the barrier,
    // including threads that have no data point of their own.
    for (int i = threadIdx.x; i < tree_size; i += blockDim.x) {
        s_linear_tree[i] = d_linear_tree[i];
    }
    __syncthreads();

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < num_points) {
        // Capacity of the per-thread operand stack. Only leaf pushes can grow
        // the stack, so they are the only pushes that need a bounds check.
        const int STACK_CAPACITY = 64;

        double x_val = d_x_values[idx];
        double stack[STACK_CAPACITY];
        int stack_top = -1;
        bool stack_ok = true; // cleared when a push would exceed STACK_CAPACITY

        for (int i = 0; i < tree_size && stack_ok; ++i) {
            LinearGpuNode node = s_linear_tree[i]; // Access from shared memory
            if (node.type == NodeType::Constant || node.type == NodeType::Variable) {
                if (stack_top + 1 >= STACK_CAPACITY) {
                    // Tree needs more live operands than the stack can hold:
                    // report the point as invalid instead of writing out of bounds.
                    stack_ok = false;
                    break;
                }
                stack[++stack_top] = (node.type == NodeType::Constant) ? node.value : x_val;
            } else if (node.type == NodeType::Operator) {
                bool is_unary = (node.op == 's' || node.op == 'c' || node.op == 'l' || node.op == 'e' || node.op == '!' || node.op == '_' || node.op == 'g');
                double result = 0.0;

                if (is_unary) {
                    if (stack_top < 0) {
                        // Missing operand: push an error marker and keep going.
                        result = GPU_MAX_DOUBLE;
                    } else {
                        double val = stack[stack_top--];
                        switch (node.op) {
                            case 's': result = sin(val); break;
                            case 'c': result = cos(val); break;
                            case 'l': result = (val <= 1e-9) ? GPU_MAX_DOUBLE : log(val); break;
                            case 'e': result = (val > 700.0) ? GPU_MAX_DOUBLE : exp(val); break;
                            case '!': result = (val < 0 || val > 170.0) ? GPU_MAX_DOUBLE : tgamma(val + 1.0); break;
                            case '_': result = floor(val); break;
                            case 'g': result = (val <= -1.0) ? GPU_MAX_DOUBLE : lgamma(val + 1.0); break;
                            default: result = NAN; break;
                        }
                    }
                    stack[++stack_top] = result;
                } else {
                    if (stack_top < 1) {
                        // Fewer than two operands: push an error marker.
                        result = GPU_MAX_DOUBLE;
                        stack[++stack_top] = result;
                    } else {
                        double right = stack[stack_top--];
                        double left = stack[stack_top--];
                        switch (node.op) {
                            case '+': result = left + right; break;
                            case '-': result = left - right; break;
                            case '*': result = left * right; break;
                            case '/':
                                if (fabs(right) < 1e-9) { // Avoid division by zero
                                    result = GPU_MAX_DOUBLE;
                                } else {
                                    result = left / right;
                                }
                                break;
                            case '^': result = pow(left, right); break;
                            case '%':
                                if (fabs(right) < 1e-9) result = GPU_MAX_DOUBLE;
                                else result = fmod(left, right);
                                break;
                            default: result = NAN; break;
                        }
                        stack[++stack_top] = result;
                    }
                }
            }
        }

        // A well-formed program leaves exactly one value on the stack.
        double predicted_val = (stack_ok && stack_top == 0) ? stack[0] : NAN;

        if (isnan(predicted_val) || isinf(predicted_val)) {
            d_raw_fitness_results[idx] = GPU_MAX_DOUBLE;
        } else {
            double diff = predicted_val - d_targets[idx];
            double sq_error = diff * diff;
            // Weighted fitness: exponential weight, so the last points
            // (high index) count much more than the first ones.
            if (GPU_USE_WEIGHTED_FITNESS) {
                double weight = exp((double)idx * GPU_WEIGHTED_FITNESS_EXPONENT);
                sq_error *= weight;
            }
            d_raw_fitness_results[idx] = sq_error;
        }
    }
}

// In-place, block-wise sum reduction.
// Each block sums its blockDim.x-wide slice of d_data[0..N) in shared memory and
// stores that partial sum at d_data[blockIdx.x]. Launch with
// blockDim.x * sizeof(double) bytes of dynamic shared memory; the halving loop
// only visits every element when blockDim.x is a power of two.
// NOTE(review): input and output share one buffer and there is no grid-wide
// barrier, so with more than one block a block may overwrite d_data[blockIdx.x]
// before block 0 has read that element -- multi-block launches look racy.
__global__ void reduce_sum_kernel(double* d_data, int N) {
    extern __shared__ double sdata[]; // one slot per thread of the block

    const unsigned int lane = threadIdx.x;
    const unsigned int global_index = blockIdx.x * blockDim.x + threadIdx.x;

    // Stage this thread's element; positions past the end contribute zero.
    if (global_index < N) {
        sdata[lane] = d_data[global_index];
    } else {
        sdata[lane] = 0.0;
    }
    __syncthreads();

    // Tree reduction: the active half folds the upper half onto itself.
    unsigned int stride = blockDim.x / 2;
    while (stride > 0) {
        if (lane < stride) {
            sdata[lane] += sdata[lane + stride];
        }
        __syncthreads();
        stride >>= 1;
    }

    // Thread 0 publishes the block's partial sum.
    if (lane == 0) {
        d_data[blockIdx.x] = sdata[0];
    }
}


// --- Batch kernel ---
// One thread per tree: thread `idx` interprets tree `idx` (nodes
// d_all_nodes[d_offsets[idx] .. + d_sizes[idx])) on every data point and writes
// a single value to d_results[idx]:
//   - the sum of squared errors, or, when GPU_USE_WEIGHTED_FITNESS is set, the
//     exponentially weighted sum normalized by the total weight and scaled back
//     by num_points;
//   - GPU_MAX_DOUBLE when the tree is malformed, overflows the operand stack, or
//     produces NaN/Inf at any point. The sentinel is written as-is and is never
//     passed through the weight normalization.
__global__ void evaluate_population_kernel(const LinearGpuNode* d_all_nodes,
                                           const int* d_offsets,
                                           const int* d_sizes,
                                           int pop_size,
                                           const double* d_targets,
                                           const double* d_x_values,
                                           int num_points,
                                           double* d_results) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= pop_size) {
        return; // no barriers below, so surplus threads can leave early
    }

    // Capacity of the per-thread operand stack. Only leaf pushes can grow the
    // stack, so they are the only pushes that need a bounds check.
    const int STACK_CAPACITY = 64;

    const int offset = d_offsets[idx];
    const int size = d_sizes[idx];
    double sum_sq_error = 0.0;
    double total_weight = 0.0; // used to normalize the weighted fitness
    bool tree_ok = true;       // cleared on the first failure at any point

    for (int p = 0; p < num_points && tree_ok; ++p) {
        const double x_val = d_x_values[p];
        double stack[STACK_CAPACITY];
        int stack_top = -1;

        // Simple post-order interpreter.
        for (int i = 0; i < size; ++i) {
            const LinearGpuNode node = d_all_nodes[offset + i];
            if (node.type == NodeType::Constant || node.type == NodeType::Variable) {
                if (stack_top + 1 >= STACK_CAPACITY) { tree_ok = false; break; }
                stack[++stack_top] = (node.type == NodeType::Constant) ? node.value : x_val;
            } else if (node.type == NodeType::Operator) {
                bool is_unary = (node.op == 's' || node.op == 'c' || node.op == 'l' || node.op == 'e' || node.op == '!' || node.op == '_' || node.op == 'g');

                if (is_unary) {
                    if (stack_top < 0) { tree_ok = false; break; }
                    double val = stack[stack_top--];
                    double result = 0.0;
                    switch (node.op) {
                        case 's': result = sin(val); break;
                        case 'c': result = cos(val); break;
                        case 'l': result = (val <= 1e-9) ? GPU_MAX_DOUBLE : log(val); break;
                        case 'e': result = (val > 700.0) ? GPU_MAX_DOUBLE : exp(val); break;
                        case '!': result = (val < 0 || val > 170.0) ? GPU_MAX_DOUBLE : tgamma(val + 1.0); break;
                        case '_': result = floor(val); break;
                        case 'g': result = (val <= -1.0) ? GPU_MAX_DOUBLE : lgamma(val + 1.0); break;
                        default: result = NAN; break;
                    }
                    stack[++stack_top] = result;
                } else {
                    if (stack_top < 1) { tree_ok = false; break; }

                    double right = stack[stack_top--];
                    double left = stack[stack_top--];
                    double result;
                    switch (node.op) {
                        case '+': result = left + right; break;
                        case '-': result = left - right; break;
                        case '*': result = left * right; break;
                        case '/':
                            if (fabs(right) < 1e-9) {
                                result = GPU_MAX_DOUBLE;
                            } else {
                                result = left / right;
                            }
                            break;
                        case '^': result = pow(left, right); break;
                        case '%':
                            if (fabs(right) < 1e-9) result = GPU_MAX_DOUBLE;
                            else result = fmod(left, right);
                            break;
                        default: result = NAN; break;
                    }
                    stack[++stack_top] = result;
                }
            }
        }

        // A well-formed program leaves exactly one value on the stack.
        if (!tree_ok || stack_top != 0) {
            tree_ok = false;
            break;
        }

        const double predicted_val = stack[0];
        if (isnan(predicted_val) || isinf(predicted_val)) {
            tree_ok = false;
            break;
        }

        const double diff = predicted_val - d_targets[p];
        const double sq_error = diff * diff;

        // Weighted fitness: exponential weight, later points count more.
        double weight = 1.0;
        if (GPU_USE_WEIGHTED_FITNESS) {
            weight = exp((double)p * GPU_WEIGHTED_FITNESS_EXPONENT);
        }
        total_weight += weight;
        sum_sq_error += sq_error * weight;
    }

    if (!tree_ok) {
        // Keep the sentinel intact: rescaling it could overflow to Inf or shrink
        // it below the value the host recognizes as "invalid".
        sum_sq_error = GPU_MAX_DOUBLE;
    } else if (GPU_USE_WEIGHTED_FITNESS && total_weight > 0.0) {
        // Normalize by the sum of weights (weighted MSE), then scale back.
        sum_sq_error = sum_sq_error / total_weight * num_points;
    }
    d_results[idx] = sum_sq_error;
}


// Host-side wrapper that scores ONE tree on the GPU.
//
// Linearizes `tree`, runs calculate_raw_fitness_kernel (one thread per data
// point), sums the per-point errors and applies the complexity penalty.
//
// Parameters:
//   tree                 - expression tree to score
//   targets, x_values    - host copies of the data set; only their sizes are used here
//   d_targets, d_x_values - device arrays holding the same data (read by the kernel)
// Returns the penalized fitness (lower is better), or INF when:
//   - the inputs are empty / mismatched or a device pointer is null,
//   - the tree linearizes to zero nodes,
//   - any CUDA allocation, copy or launch fails,
//   - any data point is flagged invalid by the kernel (GPU_MAX_DOUBLE, NaN or Inf),
//   - the final value is NaN, infinite or negative.
//
// The per-point errors are summed on the host: the in-place reduce_sum_kernel is
// not safe across several blocks, and the array is small.
// NOTE(review): the kernel applies the exponential point weights but this path
// does not divide by the total weight, whereas the batch kernels do -- confirm
// which convention callers expect before comparing values across the two paths.
double evaluate_fitness_gpu(NodePtr tree,
                            const std::vector<double>& targets,
                            const std::vector<double>& x_values,
                            double* d_targets, double* d_x_values) {
    if (x_values.size() != targets.size() || x_values.empty()) return INF;
    if (d_targets == nullptr || d_x_values == nullptr) return INF;

    // Linearize the tree (post-order).
    std::vector<LinearGpuNode> h_linear_tree;
    linearize_tree(tree, h_linear_tree);
    const int node_count = static_cast<int>(h_linear_tree.size());
    if (node_count == 0) {
        return INF;
    }

    const size_t num_points = x_values.size();
    const size_t tree_bytes = static_cast<size_t>(node_count) * sizeof(LinearGpuNode);
    const size_t errors_bytes = num_points * sizeof(double);

    LinearGpuNode* d_linear_tree = nullptr;
    double* d_point_errors = nullptr;
    std::vector<double> h_point_errors(num_points);

    bool ok = cudaMalloc((void**)&d_linear_tree, tree_bytes) == cudaSuccess
           && cudaMalloc((void**)&d_point_errors, errors_bytes) == cudaSuccess
           && cudaMemcpy(d_linear_tree, h_linear_tree.data(), tree_bytes,
                         cudaMemcpyHostToDevice) == cudaSuccess;

    if (ok) {
        const int threadsPerBlock = 256;
        const int blocksPerGrid =
            static_cast<int>((num_points + threadsPerBlock - 1) / threadsPerBlock);

        // The tree is cached in dynamic shared memory; an oversized tree makes
        // the launch fail, which cudaGetLastError() reports below.
        calculate_raw_fitness_kernel<<<blocksPerGrid, threadsPerBlock, tree_bytes>>>(
            d_linear_tree, node_count, d_targets, d_x_values, num_points, d_point_errors
        );

        ok = cudaGetLastError() == cudaSuccess
          && cudaDeviceSynchronize() == cudaSuccess
          && cudaMemcpy(h_point_errors.data(), d_point_errors, errors_bytes,
                        cudaMemcpyDeviceToHost) == cudaSuccess;
    }

    // cudaFree accepts null pointers, so this also covers partial failures.
    cudaFree(d_linear_tree);
    cudaFree(d_point_errors);
    if (!ok) {
        (void)cudaGetLastError(); // do not leave a stale error for later checks
        return INF;
    }

    // Sum on the host; a single invalid point invalidates the whole tree.
    double sum_sq_error = 0.0;
    for (size_t i = 0; i < num_points; ++i) {
        const double point_error = h_point_errors[i];
        if (isnan(point_error) || point_error >= GPU_MAX_DOUBLE) {
            return INF;
        }
        sum_sq_error += point_error;
    }
    if (isinf(sum_sq_error) || isnan(sum_sq_error)) {
        return INF;
    }

    double raw_fitness;
    if (USE_RMSE_FITNESS) {
        raw_fitness = sqrt(sum_sq_error / num_points); // num_points > 0 checked above
    } else {
        raw_fitness = sum_sq_error;
    }

    // Multiplicative complexity penalty, as in the CPU evaluate_fitness.
    const double complexity = static_cast<double>(::tree_size(tree));
    const double penalty = complexity * COMPLEXITY_PENALTY_FACTOR;
    const double final_fitness = raw_fitness * (1.0 + penalty);

    if (isnan(final_fitness) || isinf(final_fitness) || final_fitness < 0) {
        return INF;
    }

    return final_fitness;
}

void evaluate_population_gpu(const std::vector<LinearGpuNode>& all_nodes,
                             const std::vector<int>& tree_offsets,
                             const std::vector<int>& tree_sizes,
                             const std::vector<double>& targets,
                             const std::vector<double>& x_values,
                             std::vector<double>& results,
                             double* d_targets, double* d_x_values,
                             void*& d_nodes_ptr, size_t& d_nodes_cap,
                             void*& d_offsets_ptr, void*& d_sizes_ptr, void*& d_results_ptr, size_t& d_pop_cap) {
    
    int pop_size = tree_offsets.size();
    if (pop_size == 0) return;

    size_t total_nodes = all_nodes.size();
    int num_points = x_values.size();

    // Buffer Management for Nodes
    if (total_nodes > d_nodes_cap) {
        if (d_nodes_ptr) cudaFree(d_nodes_ptr);
        size_t new_cap = total_nodes * 1.5; // Growth factor
        cudaMalloc(&d_nodes_ptr, new_cap * sizeof(LinearGpuNode));
        d_nodes_cap = new_cap;
    }

    // Buffer Management for Population Arrays
    if (pop_size > d_pop_cap) {
        if (d_offsets_ptr) cudaFree(d_offsets_ptr);
        if (d_sizes_ptr) cudaFree(d_sizes_ptr);
        if (d_results_ptr) cudaFree(d_results_ptr);
        
        size_t new_cap = pop_size * 1.5;
        cudaMalloc(&d_offsets_ptr, new_cap * sizeof(int));
        cudaMalloc(&d_sizes_ptr, new_cap * sizeof(int));
        cudaMalloc(&d_results_ptr, new_cap * sizeof(double));
        d_pop_cap = new_cap;
    }

    LinearGpuNode* d_all_nodes = (LinearGpuNode*)d_nodes_ptr;
    int* d_offsets = (int*)d_offsets_ptr;
    int* d_sizes = (int*)d_sizes_ptr;
    double* d_results = (double*)d_results_ptr;

    cudaMemcpy(d_all_nodes, all_nodes.data(), total_nodes * sizeof(LinearGpuNode), cudaMemcpyHostToDevice);
    cudaMemcpy(d_offsets, tree_offsets.data(), pop_size * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_sizes, tree_sizes.data(), pop_size * sizeof(int), cudaMemcpyHostToDevice);

    int threadsPerBlock = 256;
    int blocksPerGrid = (pop_size + threadsPerBlock - 1) / threadsPerBlock;

    evaluate_population_kernel<<<blocksPerGrid, threadsPerBlock>>>(
        d_all_nodes, d_offsets, d_sizes, pop_size, d_targets, d_x_values, num_points, d_results
    );

    // Synchronize and copy back
    cudaDeviceSynchronize();
    
    cudaMemcpy(results.data(), d_results, pop_size * sizeof(double), cudaMemcpyDeviceToHost);
}

// ============================================================
// GLOBAL BATCH EVALUATION - Maximum GPU Utilization
// ============================================================

// Creates the CUDA stream and pre-allocates the device buffers used by
// evaluate_all_populations_gpu().
//
// Initial capacities: 1,500,000 nodes and 60,000 trees (sized for roughly
// 50,000 trees of about 30 nodes each); the evaluator grows them on demand.
//
// Failure handling: a capacity is recorded only when its allocation succeeded.
// If an allocation fails the corresponding pointers are set to null and the
// capacity to 0, so the evaluator's growth path will try again instead of
// using an invalid pointer. If stream creation fails, cuda_stream is left null.
void init_global_gpu_buffers(GlobalGpuBuffers& buffers) {
    // Create CUDA stream for async operations.
    cudaStream_t stream = nullptr;
    if (cudaStreamCreate(&stream) != cudaSuccess) {
        stream = nullptr;
    }
    buffers.cuda_stream = (void*)stream;

    const size_t initial_nodes_capacity = 1500000;
    const size_t initial_pop_capacity = 60000;

    // Node buffer.
    if (cudaMalloc(&buffers.d_nodes, initial_nodes_capacity * sizeof(LinearGpuNode)) == cudaSuccess) {
        buffers.d_nodes_capacity = initial_nodes_capacity;
    } else {
        buffers.d_nodes = nullptr;
        buffers.d_nodes_capacity = 0;
    }

    // Per-tree buffers share one capacity, so they succeed or fail together.
    const bool offsets_ok =
        cudaMalloc(&buffers.d_offsets, initial_pop_capacity * sizeof(int)) == cudaSuccess;
    const bool sizes_ok = offsets_ok &&
        cudaMalloc(&buffers.d_sizes, initial_pop_capacity * sizeof(int)) == cudaSuccess;
    const bool results_ok = sizes_ok &&
        cudaMalloc(&buffers.d_results, initial_pop_capacity * sizeof(double)) == cudaSuccess;

    if (results_ok) {
        buffers.d_pop_capacity = initial_pop_capacity;
    } else {
        if (offsets_ok) cudaFree(buffers.d_offsets);
        if (sizes_ok) cudaFree(buffers.d_sizes);
        buffers.d_offsets = nullptr;
        buffers.d_sizes = nullptr;
        buffers.d_results = nullptr;
        buffers.d_pop_capacity = 0;
    }

    // Clear any error recorded by a failed call above.
    (void)cudaGetLastError();
}

// Tears down a GlobalGpuBuffers: destroys the stream, frees the four device
// arrays and zeroes both capacities. Each handle is nulled after release, and
// handles that are already null are skipped, so calling this twice is harmless.
void cleanup_global_gpu_buffers(GlobalGpuBuffers& buffers) {
    if (buffers.cuda_stream != nullptr) {
        cudaStream_t owned_stream = (cudaStream_t)buffers.cuda_stream;
        cudaStreamDestroy(owned_stream);
        buffers.cuda_stream = nullptr;
    }

    if (buffers.d_nodes != nullptr) {
        cudaFree(buffers.d_nodes);
        buffers.d_nodes = nullptr;
    }
    if (buffers.d_offsets != nullptr) {
        cudaFree(buffers.d_offsets);
        buffers.d_offsets = nullptr;
    }
    if (buffers.d_sizes != nullptr) {
        cudaFree(buffers.d_sizes);
        buffers.d_sizes = nullptr;
    }
    if (buffers.d_results != nullptr) {
        cudaFree(buffers.d_results);
        buffers.d_results = nullptr;
    }

    // Nothing is allocated any more.
    buffers.d_nodes_capacity = 0;
    buffers.d_pop_capacity = 0;
}

// Optimized batch kernel: one tree per thread, final fitness computed on the GPU.
//
// Thread `idx` interprets tree `idx` (nodes d_all_nodes[d_offsets[idx] ..
// + d_sizes[idx])) on every data point and writes d_results[idx]:
//   - the weighted squared-error sum (normalized by the total weight and scaled
//     by num_points when GPU_USE_WEIGHTED_FITNESS is set), converted to RMSE when
//     use_rmse is true, then multiplied by (1 + size * complexity_penalty_factor);
//   - GPU_MAX_DOUBLE when the tree is malformed, overflows the operand stack,
//     yields NaN/Inf at any point, or accumulates an error of 1e300 or more.
//
// The first SHARED_POINT_CAPACITY data points are cached in shared memory; any
// further points are read directly from global memory, so num_points is not
// limited by the cache size.
__global__ void evaluate_all_populations_kernel(
    const LinearGpuNode* __restrict__ d_all_nodes,
    const int* __restrict__ d_offsets,
    const int* __restrict__ d_sizes,
    int total_trees,
    const double* __restrict__ d_targets,
    const double* __restrict__ d_x_values,
    int num_points,
    double* __restrict__ d_results,
    double complexity_penalty_factor,
    bool use_rmse)
{
    const int SHARED_POINT_CAPACITY = 64; // data points cached per block
    const int STACK_CAPACITY = 64;        // per-thread operand stack slots

    __shared__ double s_targets[SHARED_POINT_CAPACITY];
    __shared__ double s_x_values[SHARED_POINT_CAPACITY];

    // Cooperative, block-strided load of the cached prefix. The count is clamped
    // so no thread writes past the shared arrays, and the stride makes the load
    // complete even if the block has fewer threads than cached points.
    const int cached_points =
        (num_points < SHARED_POINT_CAPACITY) ? num_points : SHARED_POINT_CAPACITY;
    for (int i = threadIdx.x; i < cached_points; i += blockDim.x) {
        s_targets[i] = d_targets[i];
        s_x_values[i] = d_x_values[i];
    }
    __syncthreads(); // every thread of the block reaches this barrier

    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_trees) {
        return; // no barriers below
    }

    const int offset = d_offsets[idx];
    const int size = d_sizes[idx];
    double sum_sq_error = 0.0;
    double total_weight = 0.0;
    bool valid = true;

    // Evaluate the tree on all data points.
    for (int p = 0; p < num_points && valid; ++p) {
        const bool in_cache = p < cached_points;
        const double x_val = in_cache ? s_x_values[p] : d_x_values[p];
        const double target = in_cache ? s_targets[p] : d_targets[p];

        double stack[STACK_CAPACITY];
        int stack_top = -1;

        // Interpret the linearized (post-order) tree.
        for (int i = 0; i < size && valid; ++i) {
            const LinearGpuNode node = d_all_nodes[offset + i];

            if (node.type == NodeType::Constant || node.type == NodeType::Variable) {
                // Only leaf pushes grow the stack, so only they need a bounds check.
                if (stack_top + 1 >= STACK_CAPACITY) { valid = false; break; }
                stack[++stack_top] = (node.type == NodeType::Constant) ? node.value : x_val;
            } else if (node.type == NodeType::Operator) {
                bool is_unary = (node.op == 's' || node.op == 'c' || node.op == 'l' ||
                                node.op == 'e' || node.op == '!' || node.op == '_' || node.op == 'g');

                if (is_unary) {
                    if (stack_top < 0) { valid = false; break; }
                    double val = stack[stack_top--];
                    double result = 0.0;
                    switch (node.op) {
                        case 's': result = sin(val); break;
                        case 'c': result = cos(val); break;
                        case 'l': result = (val <= 1e-9) ? GPU_MAX_DOUBLE : log(val); break;
                        case 'e': result = (val > 700.0) ? GPU_MAX_DOUBLE : exp(val); break;
                        case '!': result = (val < 0 || val > 170.0) ? GPU_MAX_DOUBLE : tgamma(val + 1.0); break;
                        case '_': result = floor(val); break;
                        case 'g': result = (val <= -1.0) ? GPU_MAX_DOUBLE : lgamma(val + 1.0); break;
                        default: result = NAN; break;
                    }
                    stack[++stack_top] = result;
                } else {
                    if (stack_top < 1) { valid = false; break; }
                    double right = stack[stack_top--];
                    double left = stack[stack_top--];
                    double result;
                    switch (node.op) {
                        case '+': result = left + right; break;
                        case '-': result = left - right; break;
                        case '*': result = left * right; break;
                        case '/': result = (fabs(right) < 1e-9) ? GPU_MAX_DOUBLE : left / right; break;
                        case '^': result = pow(left, right); break;
                        case '%': result = (fabs(right) < 1e-9) ? GPU_MAX_DOUBLE : fmod(left, right); break;
                        default: result = NAN; break;
                    }
                    stack[++stack_top] = result;
                }
            }
        }

        // A well-formed program leaves exactly one value on the stack.
        if (!valid || stack_top != 0) {
            sum_sq_error = GPU_MAX_DOUBLE;
            valid = false;
            break;
        }

        const double predicted_val = stack[0];
        if (isnan(predicted_val) || isinf(predicted_val)) {
            sum_sq_error = GPU_MAX_DOUBLE;
            valid = false;
            break;
        }

        const double diff = predicted_val - target;
        const double sq_error = diff * diff;

        // Weighted fitness: exponential weight, later points count more.
        double weight = 1.0;
        if (GPU_USE_WEIGHTED_FITNESS) {
            weight = exp((double)p * GPU_WEIGHTED_FITNESS_EXPONENT);
        }
        total_weight += weight;
        sum_sq_error += sq_error * weight;
    }

    // Final fitness with complexity penalty, computed on the GPU.
    double raw_fitness = GPU_MAX_DOUBLE;
    if (valid && sum_sq_error < 1e300) {
        if (GPU_USE_WEIGHTED_FITNESS && total_weight > 0.0) {
            sum_sq_error = sum_sq_error / total_weight * num_points;
        }

        if (use_rmse && num_points > 0) {
            double mse = sum_sq_error / num_points;
            raw_fitness = sqrt(mse);
        } else {
            raw_fitness = sum_sq_error;
        }

        // The linearized node count is used as the tree's complexity.
        double complexity = (double)size;
        double penalty = complexity * complexity_penalty_factor;
        raw_fitness = raw_fitness * (1.0 + penalty);
    }

    d_results[idx] = raw_fitness;
}

void evaluate_all_populations_gpu(
    const std::vector<LinearGpuNode>& all_nodes,
    const std::vector<int>& tree_offsets,
    const std::vector<int>& tree_sizes,
    const std::vector<int>& tree_complexities,
    int total_trees,
    const std::vector<double>& targets,
    const std::vector<double>& x_values,
    std::vector<double>& results,
    double* d_targets, double* d_x_values,
    GlobalGpuBuffers& buffers)
{
    if (total_trees == 0) return;
    
    cudaStream_t stream = (cudaStream_t)buffers.cuda_stream;
    size_t total_nodes = all_nodes.size();
    int num_points = x_values.size();

    // Dynamic buffer resizing with growth factor
    if (total_nodes > buffers.d_nodes_capacity) {
        if (buffers.d_nodes) cudaFree(buffers.d_nodes);
        size_t new_cap = total_nodes * 1.5;
        cudaMalloc(&buffers.d_nodes, new_cap * sizeof(LinearGpuNode));
        buffers.d_nodes_capacity = new_cap;
    }

    if ((size_t)total_trees > buffers.d_pop_capacity) {
        if (buffers.d_offsets) cudaFree(buffers.d_offsets);
        if (buffers.d_sizes) cudaFree(buffers.d_sizes);
        if (buffers.d_results) cudaFree(buffers.d_results);
        
        size_t new_cap = total_trees * 1.5;
        cudaMalloc(&buffers.d_offsets, new_cap * sizeof(int));
        cudaMalloc(&buffers.d_sizes, new_cap * sizeof(int));
        cudaMalloc(&buffers.d_results, new_cap * sizeof(double));
        buffers.d_pop_capacity = new_cap;
    }

    LinearGpuNode* d_all_nodes = (LinearGpuNode*)buffers.d_nodes;
    int* d_offsets = (int*)buffers.d_offsets;
    int* d_sizes = (int*)buffers.d_sizes;
    double* d_results = (double*)buffers.d_results;

    // Async memory transfers using CUDA stream
    cudaMemcpyAsync(d_all_nodes, all_nodes.data(), total_nodes * sizeof(LinearGpuNode), 
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_offsets, tree_offsets.data(), total_trees * sizeof(int), 
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_sizes, tree_sizes.data(), total_trees * sizeof(int), 
                    cudaMemcpyHostToDevice, stream);

    // Optimized kernel launch configuration for RTX 3050
    // RTX 3050 has 20 SMs, each can handle 2048 threads max
    // For 50k trees, we want maximum occupancy
    int threadsPerBlock = 256;
    int blocksPerGrid = (total_trees + threadsPerBlock - 1) / threadsPerBlock;

    // Launch kernel on stream
    evaluate_all_populations_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
        d_all_nodes, d_offsets, d_sizes, total_trees,
        d_targets, d_x_values, num_points, d_results,
        COMPLEXITY_PENALTY_FACTOR, USE_RMSE_FITNESS
    );

    // Async copy results back
    cudaMemcpyAsync(results.data(), d_results, total_trees * sizeof(double), 
                    cudaMemcpyDeviceToHost, stream);
    
    // Synchronize stream
    cudaStreamSynchronize(stream);
}

// ============================================================
// DOUBLE-BUFFERED GPU IMPLEMENTATION
// ============================================================

// Sets up the double-buffered (ping/pong) GPU state: two streams, two sets of
// device buffers and one pinned host buffer for results.
//
// Initial capacities: 1,500,000 nodes and 60,000 trees per buffer set (sized
// for roughly 50,000 trees of about 30 nodes each).
//
// Failure handling: a capacity is recorded only when its allocation succeeded.
// On failure the affected pointers are set to null and the capacity to 0, so
// launch_evaluation_async()'s growth path will retry instead of using an
// invalid pointer. A stream that cannot be created is left null.
void init_double_buffered_gpu(DoubleBufferedGpu& db) {
    const size_t initial_nodes_cap = 1500000;  // 50k trees * 30 nodes avg
    const size_t initial_pop_cap = 60000;      // slightly more than 50k

    // Two streams for overlapped execution.
    for (int i = 0; i < 2; ++i) {
        if (cudaStreamCreate((cudaStream_t*)&db.streams[i]) != cudaSuccess) {
            db.streams[i] = nullptr;
        }
    }

    // Pre-allocate device buffers for both ping and pong.
    for (int i = 0; i < 2; ++i) {
        if (cudaMalloc(&db.d_nodes[i], initial_nodes_cap * sizeof(LinearGpuNode)) == cudaSuccess) {
            db.d_nodes_capacity[i] = initial_nodes_cap;
        } else {
            db.d_nodes[i] = nullptr;
            db.d_nodes_capacity[i] = 0;
        }

        // Per-tree buffers share one capacity, so they succeed or fail together.
        const bool offsets_ok =
            cudaMalloc(&db.d_offsets[i], initial_pop_cap * sizeof(int)) == cudaSuccess;
        const bool sizes_ok = offsets_ok &&
            cudaMalloc(&db.d_sizes[i], initial_pop_cap * sizeof(int)) == cudaSuccess;
        const bool results_ok = sizes_ok &&
            cudaMalloc(&db.d_results[i], initial_pop_cap * sizeof(double)) == cudaSuccess;

        if (results_ok) {
            db.d_pop_capacity[i] = initial_pop_cap;
        } else {
            if (offsets_ok) cudaFree(db.d_offsets[i]);
            if (sizes_ok) cudaFree(db.d_sizes[i]);
            db.d_offsets[i] = nullptr;
            db.d_sizes[i] = nullptr;
            db.d_results[i] = nullptr;
            db.d_pop_capacity[i] = 0;
        }
    }

    // Pinned host memory for faster device-to-host transfers.
    if (cudaMallocHost(&db.h_pinned_results, initial_pop_cap * sizeof(double)) == cudaSuccess) {
        db.h_pinned_capacity = initial_pop_cap;
    } else {
        db.h_pinned_results = nullptr;
        db.h_pinned_capacity = 0;
    }

    // Clear any error recorded by a failed call above.
    (void)cudaGetLastError();

    db.current_buffer = 0;
}

// Releases everything held by a DoubleBufferedGpu: both streams, both sets of
// device buffers and the pinned host buffer. Handles are nulled after release
// and all capacities are reset to 0, so the structure never advertises room in
// a buffer that no longer exists. Safe to call more than once.
void cleanup_double_buffered_gpu(DoubleBufferedGpu& db) {
    for (int i = 0; i < 2; ++i) {
        if (db.streams[i]) {
            cudaStreamDestroy((cudaStream_t)db.streams[i]);
            db.streams[i] = nullptr;
        }
        if (db.d_nodes[i]) { cudaFree(db.d_nodes[i]); db.d_nodes[i] = nullptr; }
        if (db.d_offsets[i]) { cudaFree(db.d_offsets[i]); db.d_offsets[i] = nullptr; }
        if (db.d_sizes[i]) { cudaFree(db.d_sizes[i]); db.d_sizes[i] = nullptr; }
        if (db.d_results[i]) { cudaFree(db.d_results[i]); db.d_results[i] = nullptr; }

        // Match cleanup_global_gpu_buffers(): freed buffers have no capacity.
        db.d_nodes_capacity[i] = 0;
        db.d_pop_capacity[i] = 0;
    }

    if (db.h_pinned_results) {
        cudaFreeHost(db.h_pinned_results);
        db.h_pinned_results = nullptr;
    }
    db.h_pinned_capacity = 0;
}

void launch_evaluation_async(
    const std::vector<LinearGpuNode>& all_nodes,
    const std::vector<int>& tree_offsets,
    const std::vector<int>& tree_sizes,
    int total_trees,
    double* d_targets, double* d_x_values,
    int num_points,
    DoubleBufferedGpu& db)
{
    if (total_trees == 0) return;
    
    int buf = db.current_buffer;
    cudaStream_t stream = (cudaStream_t)db.streams[buf];
    size_t total_nodes = all_nodes.size();
    
    // Ensure buffers are large enough
    if (total_nodes > db.d_nodes_capacity[buf]) {
        cudaFree(db.d_nodes[buf]);
        size_t new_cap = total_nodes * 1.5;
        cudaMalloc(&db.d_nodes[buf], new_cap * sizeof(LinearGpuNode));
        db.d_nodes_capacity[buf] = new_cap;
    }
    
    if ((size_t)total_trees > db.d_pop_capacity[buf]) {
        cudaFree(db.d_offsets[buf]);
        cudaFree(db.d_sizes[buf]);
        cudaFree(db.d_results[buf]);
        
        size_t new_cap = total_trees * 1.5;
        cudaMalloc(&db.d_offsets[buf], new_cap * sizeof(int));
        cudaMalloc(&db.d_sizes[buf], new_cap * sizeof(int));
        cudaMalloc(&db.d_results[buf], new_cap * sizeof(double));
        db.d_pop_capacity[buf] = new_cap;
    }
    
    // Ensure pinned results buffer is large enough
    if ((size_t)total_trees > db.h_pinned_capacity) {
        cudaFreeHost(db.h_pinned_results);
        db.h_pinned_capacity = total_trees * 1.5;
        cudaMallocHost(&db.h_pinned_results, db.h_pinned_capacity * sizeof(double));
    }
    
    LinearGpuNode* d_nodes = (LinearGpuNode*)db.d_nodes[buf];
    int* d_offsets = (int*)db.d_offsets[buf];
    int* d_sizes = (int*)db.d_sizes[buf];
    double* d_results = (double*)db.d_results[buf];
    
    // Async transfers
    cudaMemcpyAsync(d_nodes, all_nodes.data(), total_nodes * sizeof(LinearGpuNode), 
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_offsets, tree_offsets.data(), total_trees * sizeof(int), 
                    cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_sizes, tree_sizes.data(), total_trees * sizeof(int), 
                    cudaMemcpyHostToDevice, stream);
    
    // Launch kernel
    int threadsPerBlock = 256;
    int blocksPerGrid = (total_trees + threadsPerBlock - 1) / threadsPerBlock;
    
    evaluate_all_populations_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
        d_nodes, d_offsets, d_sizes, total_trees,
        d_targets, d_x_values, num_points, d_results,
        COMPLEXITY_PENALTY_FACTOR, USE_RMSE_FITNESS
    );
    
    // Async copy results to pinned memory
    cudaMemcpyAsync(db.h_pinned_results, d_results, total_trees * sizeof(double), 
                    cudaMemcpyDeviceToHost, stream);
    
    // DO NOT SYNC HERE - let CPU do other work
}

// Waits for the work queued on the current buffer's stream and copies the
// per-tree results out of the pinned host buffer.
//
// Parameters:
//   results     - overwritten; ends up with max(total_trees, 0) entries.
//   total_trees - number of results to fetch.
//   db          - double-buffer state. db.current_buffer selects the stream to
//                 wait on and is flipped (0 <-> 1) before returning, whether
//                 or not the results could be read.
//
// If the stream reports an error, or the pinned buffer is missing or too
// small for total_trees values, every entry of `results` is set to INF rather
// than being read from memory whose contents are not trustworthy.
void retrieve_results_sync(
    std::vector<double>& results,
    int total_trees,
    DoubleBufferedGpu& db)
{
    const int buf = db.current_buffer;
    cudaStream_t stream = (cudaStream_t)db.streams[buf];

    // Wait for this stream to complete; keep the status instead of dropping it.
    const cudaError_t sync_status = cudaStreamSynchronize(stream);

    // A negative count would otherwise wrap to an enormous size_t in resize().
    const size_t count = total_trees > 0 ? static_cast<size_t>(total_trees) : 0;

    const bool results_readable = sync_status == cudaSuccess
                               && db.h_pinned_results != nullptr
                               && count <= db.h_pinned_capacity;

    if (results_readable) {
        // Plain host-to-host copy out of the pinned buffer.
        results.resize(count);
        if (count > 0) {
            memcpy(results.data(), db.h_pinned_results, count * sizeof(double));
        }
    } else {
        // Mark the whole batch as unusable for the caller.
        results.assign(count, INF);
    }

    // Switch to the other buffer for the next generation.
    db.current_buffer = 1 - buf;
}

#endif // USE_GPU_ACCELERATION_DEFINED_BY_CMAKE




In [None]:
%%writefile Code/src/FitnessGPU.cuh
#ifndef FITNESS_GPU_CUH
#define FITNESS_GPU_CUH

#include <vector>
#include <memory> // For NodePtr in the host-side wrapper
#include "ExpressionTree.h" // For NodeType enum and original Node structure (host-side)
#include "Globals.h" // For INF, USE_RMSE_FITNESS, COMPLEXITY_PENALTY_FACTOR etc.

// Forward declaration for host-side NodePtr
struct Node;
using NodePtr = std::shared_ptr<Node>;

// A simplified node structure for the linearized tree on the GPU.
// Arrays of this struct are filled on the host by linearize_tree() and copied
// to the device as raw bytes (sizeof(LinearGpuNode) per element), so host and
// device code must agree on this exact member order and layout.
struct LinearGpuNode {
    NodeType type;  // node kind (NodeType is declared in ExpressionTree.h)
    double value;   // presumably the numeric payload of constant nodes - confirm against linearize_tree
    char op;        // presumably the operator character of operator nodes - confirm against linearize_tree
};

// Helper function to linearize the tree into a post-order array
void linearize_tree(const NodePtr& node, std::vector<LinearGpuNode>& linear_tree);

// Host-side wrapper for launching CUDA kernel
#if USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
double evaluate_fitness_gpu(NodePtr tree,
                            const std::vector<double>& targets,
                            const std::vector<double>& x_values,
                            double* d_targets, double* d_x_values);

// Batch evaluation function with persistent buffers
void evaluate_population_gpu(const std::vector<LinearGpuNode>& all_nodes,
                             const std::vector<int>& tree_offsets,
                             const std::vector<int>& tree_sizes,
                             const std::vector<double>& targets,
                             const std::vector<double>& x_values,
                             std::vector<double>& results,
                             double* d_targets, double* d_x_values,
                             void*& d_nodes_ptr, size_t& d_nodes_cap,
                             void*& d_offsets_ptr, void*& d_sizes_ptr, void*& d_results_ptr, size_t& d_pop_cap);

// ============================================================
// GLOBAL BATCH EVALUATION - Evaluates ALL islands in ONE kernel call
// ============================================================
// Device buffers kept alive between global batch evaluations; the
// GeneticAlgorithm object owns one instance. Every member starts out empty
// (null pointer / zero capacity).
struct GlobalGpuBuffers {
    // Device allocations, stored type-erased.
    void* d_nodes{nullptr};
    void* d_offsets{nullptr};
    void* d_sizes{nullptr};
    void* d_results{nullptr};

    // Current capacities of the node buffer and of the per-tree buffers.
    size_t d_nodes_capacity{0};
    size_t d_pop_capacity{0};

    // Holds a cudaStream_t, stored as void*.
    void* cuda_stream{nullptr};
};

// ============================================================
// DOUBLE-BUFFERED GPU EVALUATION - Maximum overlap of CPU/GPU work
// ============================================================
struct DoubleBufferedGpu {
    // Ping-pong device storage: slot 0 and slot 1 of each array form one
    // complete buffer set. All pointers start null.
    void* d_nodes[2]{};
    void* d_offsets[2]{};
    void* d_sizes[2]{};
    void* d_results[2]{};

    // Per-slot capacities. d_pop_capacity is counted in trees (elements of
    // the offsets/sizes/results buffers).
    size_t d_nodes_capacity[2]{};
    size_t d_pop_capacity[2]{};

    // One stream per slot; each entry holds a cudaStream_t stored as void*.
    void* streams[2]{};

    // Slot currently in use (0 or 1).
    int current_buffer{0};

    // Single pinned host staging buffer for results, with its capacity
    // counted in doubles.
    void* h_pinned_results{nullptr};
    size_t h_pinned_capacity{0};
};

// Initialize double-buffered GPU resources
void init_double_buffered_gpu(DoubleBufferedGpu& db);

// Cleanup double-buffered GPU resources
void cleanup_double_buffered_gpu(DoubleBufferedGpu& db);

// Async launch - queues one evaluation batch on the GPU and returns without
// synchronizing, so the CPU can do other work in the meantime.
//   all_nodes    - linearized nodes of all trees, concatenated.
//   tree_offsets - per-tree start index into all_nodes.
//   tree_sizes   - per-tree node count.
//   total_trees  - number of trees in the batch.
//   d_targets,
//   d_x_values   - device pointers to the target and x data.
//   num_points   - number of data points.
//   db           - double-buffer state; its per-tree device buffers and the
//                  pinned result buffer are grown when total_trees exceeds
//                  their capacity.
// The results are written to db's pinned host buffer and must be collected
// with retrieve_results_sync().
void launch_evaluation_async(
    const std::vector<LinearGpuNode>& all_nodes,
    const std::vector<int>& tree_offsets,
    const std::vector<int>& tree_sizes,
    int total_trees,
    double* d_targets, double* d_x_values,
    int num_points,
    DoubleBufferedGpu& db);

// Waits for the stream of db's current buffer to finish, copies total_trees
// results from the pinned host buffer into `results`, then switches
// db.current_buffer to the other buffer.
void retrieve_results_sync(
    std::vector<double>& results,
    int total_trees,
    DoubleBufferedGpu& db);

// Initialize global GPU buffers and CUDA stream
void init_global_gpu_buffers(GlobalGpuBuffers& buffers);

// Cleanup global GPU buffers
void cleanup_global_gpu_buffers(GlobalGpuBuffers& buffers);

// Evaluate ALL trees from ALL islands in a single GPU batch call (maximum GPU utilization)
void evaluate_all_populations_gpu(
    const std::vector<LinearGpuNode>& all_nodes,
    const std::vector<int>& tree_offsets,
    const std::vector<int>& tree_sizes,
    const std::vector<int>& tree_complexities, // For complexity penalty
    int total_trees,
    const std::vector<double>& targets,
    const std::vector<double>& x_values,
    std::vector<double>& results,
    double* d_targets, double* d_x_values,
    GlobalGpuBuffers& buffers);
#endif

#endif // FITNESS_GPU_CUH


In [None]:
%%writefile Code/src/GeneticAlgorithm.cpp
#include "GeneticAlgorithm.h"
#include "Globals.h"
#include "Fitness.h"
#include "AdvancedFeatures.h" // Incluir este para DomainConstraints::
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
#include "FitnessGPU.cuh"     // Para funciones de GPU
#include <cuda_runtime.h>     // Para CUDA runtime
#endif
#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
#include <omp.h>
#include <iomanip>
#include <iterator>
#include <chrono>
#include <unordered_set>

// --- Constructor ---
// Builds the islands, injects the optional seed formulas, sets up the GPU
// (when compiled in and not disabled), evaluates the whole initial population
// and records the initial global best.
//
// Parameters:
//   targets_ref  - target values of the data set.
//   x_values_ref - x values of the data set.
//   total_pop    - requested total population; the effective value becomes
//                  num_islands * pop_per_island.
//   gens         - number of generations (stored in `generations`).
//   seeds        - formula strings parsed with parse_formula_string() and
//                  written over initial individuals.
//   n_islands    - requested island count; values <= 0 become 1, and the
//                  count is reduced when it would leave fewer than
//                  MIN_POP_PER_ISLAND individuals per island.
//
// Throws: any exception raised while constructing an Island is logged and
// rethrown.
GeneticAlgorithm::GeneticAlgorithm(const std::vector<double>& targets_ref,
                                     const std::vector<double>& x_values_ref,
                                     int total_pop,
                                     int gens,
                                     const std::vector<std::string>& seeds,
                                     int n_islands)
    : targets(targets_ref),
      x_values(x_values_ref),
      total_population_size(total_pop),
      generations(gens),
      num_islands(n_islands),
      overall_best_fitness(INF),
      last_overall_best_fitness(INF),
      generation_last_improvement(0)
{
    // Validate and adjust the island count and the per-island population.
    if (this->num_islands <= 0) this->num_islands = 1;
    pop_per_island = this->total_population_size / this->num_islands;
    if (pop_per_island < MIN_POP_PER_ISLAND) {
        pop_per_island = MIN_POP_PER_ISLAND;
        this->num_islands = this->total_population_size / pop_per_island;
        if (this->num_islands == 0) this->num_islands = 1;
        std::cerr << "Warning: Adjusted number of islands to " << this->num_islands
                  << " for minimum population size per island (" << pop_per_island <<")." << std::endl;
    }
    this->total_population_size = this->num_islands * pop_per_island;
    std::cout << "Info: Running with " << this->num_islands << " islands, "
              << pop_per_island << " individuals per island." << std::endl;

    // Create the islands.
    islands.reserve(this->num_islands);
    for (int i = 0; i < this->num_islands; ++i) {
        try {
            islands.push_back(std::make_unique<Island>(i, pop_per_island));
        }
        catch (const std::exception& e) { std::cerr << "[ERROR] Creating Island " << i << ": " << e.what() << std::endl; throw; }
        catch (...) { std::cerr << "[ERROR] Unknown exception creating island " << i << std::endl; throw; }
    }

    // --- Inject seeds ---
    // Linear fill: seeds overwrite consecutive individuals, starting with the
    // first individual of the first island. Each visited slot consumes exactly
    // one seed; when that seed fails to parse (exception or null tree) the
    // slot keeps its original individual. With few seeds relative to the
    // island size they therefore all land on the first island.
    if (!seeds.empty()) {
        std::cout << "Info: Injecting " << seeds.size() << " seed formulas into population..." << std::endl;
        size_t seed_idx = 0;

        for (auto& island : islands) {
            if (seed_idx >= seeds.size()) break;

            for (auto& individual : island->population) {
                if (seed_idx >= seeds.size()) break;

                const std::string& formula = seeds[seed_idx++];
                try {
                    NodePtr parsed_tree = parse_formula_string(formula);
                    if (parsed_tree) {
                        individual.tree = std::move(parsed_tree);
                    }
                } catch (const std::exception& e) {
                    std::cerr << "[Warning] Failed to parse seed formula: " << formula << " | Error: " << e.what() << std::endl;
                }
            }
        }
    }

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    bool gpu_init_failed = false;
    if (!FORCE_CPU_MODE) {
        // Allocate device copies of the data set and upload them.
        const size_t targets_size = targets.size() * sizeof(double);
        const size_t x_values_size = x_values.size() * sizeof(double);

        const cudaError_t alloc_t = cudaMalloc(&d_targets, targets_size);
        const cudaError_t alloc_x = cudaMalloc(&d_x_values, x_values_size);

        // err_t / err_x track the first failing step for each array.
        cudaError_t err_t = alloc_t;
        cudaError_t err_x = alloc_x;
        const char* failed_step = "memory allocation";
        if (alloc_t == cudaSuccess && alloc_x == cudaSuccess) {
            failed_step = "host-to-device copy";
            err_t = cudaMemcpy(d_targets, targets.data(), targets_size, cudaMemcpyHostToDevice);
            err_x = cudaMemcpy(d_x_values, x_values.data(), x_values_size, cudaMemcpyHostToDevice);
        }

        if (err_t != cudaSuccess || err_x != cudaSuccess) {
            std::cerr << "[WARNING] CUDA " << failed_step << " failed: "
                      << cudaGetErrorString(err_t) << " | " << cudaGetErrorString(err_x) << std::endl;
            std::cerr << "[INFO] Falling back to CPU mode." << std::endl;
            gpu_init_failed = true;

            // Free only what was really allocated, then null both pointers:
            // evaluate_all_islands() uses d_targets == nullptr as its
            // "GPU unavailable" signal.
            if (alloc_t == cudaSuccess && d_targets) cudaFree(d_targets);
            if (alloc_x == cudaSuccess && d_x_values) cudaFree(d_x_values);
            d_targets = nullptr;
            d_x_values = nullptr;
        } else {
            // Initialize global GPU buffers for batch evaluation of ALL islands.
            init_global_gpu_buffers(global_gpu_buffers);

            // Initialize double-buffered GPU state for async pipelining.
            init_double_buffered_gpu(double_buffer_gpu);

            std::cout << "GPU buffers initialized for global batch evaluation (max "
                      << total_population_size << " trees in single kernel call)" << std::endl;
            std::cout << "Double-buffered GPU enabled for async CPU/GPU overlap" << std::endl;
        }
    }

    if (FORCE_CPU_MODE || gpu_init_failed) {
        std::cout << "Using CPU for all evaluations" << std::endl;
    }
#endif

    // Initial evaluation of the WHOLE population (seeds included);
    // evaluate_all_islands() simplifies and scores every individual.
    std::cout << "Evaluating initial population (simplifying all)..." << std::endl;
    evaluate_all_islands();

    // Find the initial global best (serial scan).
    overall_best_fitness = INF;
    overall_best_tree = nullptr;
    int initial_best_island = -1;
    int initial_best_idx = -1;

    for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
        const auto& population = islands[i]->population;
        for (int j = 0; j < static_cast<int>(population.size()); ++j) {
            const auto& ind = population[j];
            if (ind.tree && ind.fitness_valid && ind.fitness < overall_best_fitness) {
                overall_best_fitness = ind.fitness;
                initial_best_island = i;
                initial_best_idx = j;
            }
        }
    }
    if (initial_best_island != -1 && initial_best_idx != -1) {
        overall_best_tree = clone_tree(islands[initial_best_island]->population[initial_best_idx].tree);
    }

    last_overall_best_fitness = overall_best_fitness;
    generation_last_improvement = 0;
    std::cout << "Initial best fitness: " << std::scientific << overall_best_fitness << std::fixed << std::endl;
    if (overall_best_tree) {
        std::cout << "Initial best formula size: " << tree_size(overall_best_tree) << std::endl;
        std::cout << "Initial best formula: " << tree_to_string(overall_best_tree) << std::endl;
        // Report whether the best individual sits in slot 0, which is treated
        // as the injected formula when USE_INITIAL_FORMULA is set.
        if (USE_INITIAL_FORMULA && initial_best_island != -1 && initial_best_idx == 0) {
            std::cout << "   (Note: Initial best is the (simplified) injected formula from Island " << initial_best_island << ")" << std::endl;
        } else if (initial_best_island != -1) {
            std::cout << "   (Note: Initial best found in Island " << initial_best_island << ", Index " << initial_best_idx << ")" << std::endl;
        }
    } else { std::cout << "No valid initial solution found (all fitness INF?)." << std::endl; }
    std::cout << "----------------------------------------" << std::endl;
}

// Releases the GPU resources set up by the constructor. Host-side members need
// no manual cleanup here: the islands are created with std::make_unique, and
// overall_best_tree is a NodePtr (declared as std::shared_ptr<Node> in
// FitnessGPU.cuh).
GeneticAlgorithm::~GeneticAlgorithm() {
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    // The constructor skips all GPU setup in forced-CPU mode.
    if (FORCE_CPU_MODE) {
        return;
    }

    // Tear down the double-buffered state first, then the global batch buffers.
    cleanup_double_buffered_gpu(double_buffer_gpu);
    cleanup_global_gpu_buffers(global_gpu_buffers);

    // Device copies of the data set.
    if (d_targets != nullptr) {
        cudaFree(d_targets);
        d_targets = nullptr;
    }
    if (d_x_values != nullptr) {
        cudaFree(d_x_values);
        d_x_values = nullptr;
    }
#endif
}

// Simplifies and evaluates every individual of a single island.
//
// Every tree is first passed through DomainConstraints::fix_or_simplify().
// When GPU support is compiled in AND usable at runtime (FORCE_CPU_MODE is
// off and d_targets was allocated), the non-empty trees are linearized and
// scored in one batch through evaluate_population_gpu(); otherwise each tree
// is scored on the CPU with evaluate_fitness(). Individuals without a tree, or
// whose tree linearizes to zero nodes, get fitness INF. On return every
// individual has fitness_valid == true.
void GeneticAlgorithm::evaluate_population(Island& island) {
    const int pop_size = static_cast<int>(island.population.size());
    if (pop_size == 0) return;

    // 1. Simplify trees (CPU parallel), so only simplified trees are scored.
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < pop_size; ++i) {
        Individual& ind = island.population[i];
        if (ind.tree) {
            ind.tree = DomainConstraints::fix_or_simplify(ind.tree);
        }
    }

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    // Same runtime guard as evaluate_all_islands(): without it a forced-CPU
    // run or a failed GPU init would hand null device pointers to the batch.
    if (!FORCE_CPU_MODE && d_targets != nullptr) {
        // 2. Linearize all trees into one flat node buffer.
        std::vector<LinearGpuNode> all_nodes;
        std::vector<int> tree_offsets;
        std::vector<int> tree_sizes;
        // Maps batch position -> population index, since empty trees are skipped.
        std::vector<int> valid_indices;

        // Reserve up front to avoid reallocations; 30 nodes per tree is the
        // assumed average tree size.
        all_nodes.reserve(pop_size * 30);
        tree_offsets.reserve(pop_size);
        tree_sizes.reserve(pop_size);
        valid_indices.reserve(pop_size);

        for (int i = 0; i < pop_size; ++i) {
            Individual& ind = island.population[i];

            const int start_offset = static_cast<int>(all_nodes.size());
            int size = 0;
            if (ind.tree) {
                linearize_tree(ind.tree, all_nodes);
                size = static_cast<int>(all_nodes.size()) - start_offset;
            }

            if (size > 0) {
                tree_offsets.push_back(start_offset);
                tree_sizes.push_back(size);
                valid_indices.push_back(i);
            } else {
                // Nothing to evaluate for this individual.
                ind.fitness = INF;
                ind.fitness_valid = true;
            }
        }

        if (valid_indices.empty()) return;

        // 3. Run the GPU batch (d_targets and d_x_values already live on the device).
        std::vector<double> raw_results(valid_indices.size());
        evaluate_population_gpu(all_nodes, tree_offsets, tree_sizes, targets, x_values, raw_results, d_targets, d_x_values,
                                island.d_nodes, island.d_nodes_capacity,
                                island.d_offsets, island.d_sizes, island.d_results, island.d_pop_capacity);

        // 4. Turn each raw result (treated as a sum of squared errors) into a
        //    fitness value.
        for (size_t k = 0; k < valid_indices.size(); ++k) {
            Individual& ind = island.population[valid_indices[k]];
            const double sum_sq_error = raw_results[k];
            double raw_fitness = INF;

            // Reject NaN, infinities and anything at or above the 1e300 safety threshold.
            if (std::isfinite(sum_sq_error) && sum_sq_error < 1e300) {
                if (USE_RMSE_FITNESS) {
                    if (x_values.size() > 0) {
                        const double mse = sum_sq_error / x_values.size();
                        raw_fitness = std::sqrt(mse);
                    }
                } else {
                    raw_fitness = sum_sq_error;
                }
            }

            if (raw_fitness >= INF/2) {
                ind.fitness = INF;
            } else {
                // Complexity penalty based on the linearized node count.
                const double complexity = static_cast<double>(tree_sizes[k]);
                const double penalty = complexity * COMPLEXITY_PENALTY_FACTOR;
                ind.fitness = raw_fitness * (1.0 + penalty);
            }
            ind.fitness_valid = true;
        }
        return;
    }

    // GPU support is compiled in but not usable at runtime: evaluate on the
    // CPU, passing null device pointers as evaluate_all_islands() does.
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < pop_size; ++i) {
        Individual& ind = island.population[i];
        if (ind.tree) {
            ind.fitness = evaluate_fitness(ind.tree, targets, x_values, nullptr, nullptr);
        } else {
            ind.fitness = INF;
        }
        ind.fitness_valid = true;
    }
#else
    // CPU-only build (parallel).
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < pop_size; ++i) {
        Individual& ind = island.population[i];
        if (ind.tree) {
            ind.fitness = evaluate_fitness(ind.tree, targets, x_values);
        } else {
            ind.fitness = INF;
        }
        ind.fitness_valid = true;
    }
#endif
}


// ============================================================
// GLOBAL BATCH EVALUATION - Evaluates ALL islands in ONE GPU kernel call
// ============================================================
// Simplifies and evaluates every individual on every island.
//
// Every tree is first passed through DomainConstraints::fix_or_simplify().
// With GPU support compiled in and usable at runtime (FORCE_CPU_MODE off and
// d_targets allocated), all non-empty trees are linearized into one flat
// buffer and scored by a single launch_evaluation_async() /
// retrieve_results_sync() round trip; otherwise each tree is scored on the CPU
// with evaluate_fitness(). Individuals without a tree, or whose tree
// linearizes to zero nodes, get fitness INF. On return every individual has
// fitness_valid == true.
void GeneticAlgorithm::evaluate_all_islands() {
    int total_trees = 0;
    for (const auto& island : islands) {
        total_trees += static_cast<int>(island->population.size());
    }
    if (total_trees == 0) return;

    // Step 1: Simplify ALL trees (CPU, parallel over islands).
    // Note: collapse(2) is not supported by MSVC OpenMP 2.0, so only the outer
    // loop is parallelized.
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
        for (int j = 0; j < static_cast<int>(islands[i]->population.size()); ++j) {
            Individual& ind = islands[i]->population[j];
            if (ind.tree) {
                ind.tree = DomainConstraints::fix_or_simplify(ind.tree);
            }
        }
    }

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    // Runtime check: FORCE_CPU_MODE or a failed GPU init (d_targets == nullptr)
    // selects the CPU path.
    if (!FORCE_CPU_MODE && d_targets != nullptr) {
        // Step 2: Linearize ALL trees from ALL islands into a single buffer.

        // Flat index -> (island, individual), so one loop can cover everything.
        std::vector<std::pair<int, int>> index_mapping(total_trees);
        int tree_idx = 0;
        for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
            for (int j = 0; j < static_cast<int>(islands[i]->population.size()); ++j) {
                index_mapping[tree_idx] = {i, j};
                tree_idx++;
            }
        }

        // Each OpenMP thread linearizes into its own buffers; they are merged
        // in thread order afterwards.
        const int num_threads = omp_get_max_threads();
        std::vector<std::vector<LinearGpuNode>> thread_nodes(num_threads);
        std::vector<std::vector<int>> thread_offsets(num_threads);
        std::vector<std::vector<int>> thread_sizes(num_threads);
        std::vector<std::vector<std::pair<int, int>>> thread_mappings(num_threads);

        // Pre-allocate per-thread buffers (30 nodes per tree is the assumed average).
        const int trees_per_thread = (total_trees + num_threads - 1) / num_threads;
        for (int t = 0; t < num_threads; ++t) {
            thread_nodes[t].reserve(trees_per_thread * 30);
            thread_offsets[t].reserve(trees_per_thread);
            thread_sizes[t].reserve(trees_per_thread);
            thread_mappings[t].reserve(trees_per_thread);
        }

        #pragma omp parallel
        {
            const int tid = omp_get_thread_num();
            auto& local_nodes = thread_nodes[tid];
            auto& local_offsets = thread_offsets[tid];
            auto& local_sizes = thread_sizes[tid];
            auto& local_mappings = thread_mappings[tid];

            #pragma omp for schedule(static)
            for (int t = 0; t < total_trees; ++t) {
                const int i = index_mapping[t].first;
                const int j = index_mapping[t].second;
                Individual& ind = islands[i]->population[j];

                if (ind.tree) {
                    const int start_offset = static_cast<int>(local_nodes.size());
                    linearize_tree(ind.tree, local_nodes);
                    const int size = static_cast<int>(local_nodes.size()) - start_offset;

                    if (size > 0) {
                        local_offsets.push_back(start_offset);
                        local_sizes.push_back(size);
                        local_mappings.push_back({i, j});
                    } else {
                        ind.fitness = INF;
                        ind.fitness_valid = true;
                    }
                } else {
                    ind.fitness = INF;
                    ind.fitness_valid = true;
                }
            }
        }

        // Merge the thread-local buffers into the global batch buffers.
        std::vector<LinearGpuNode> all_nodes;
        std::vector<int> tree_offsets;
        std::vector<int> tree_sizes;
        std::vector<std::pair<int, int>> result_mapping;

        size_t total_node_count = 0;
        size_t total_valid_trees = 0;
        for (int t = 0; t < num_threads; ++t) {
            total_node_count += thread_nodes[t].size();
            total_valid_trees += thread_mappings[t].size();
        }

        all_nodes.reserve(total_node_count);
        tree_offsets.reserve(total_valid_trees);
        tree_sizes.reserve(total_valid_trees);
        result_mapping.reserve(total_valid_trees);

        for (int t = 0; t < num_threads; ++t) {
            // Thread-local offsets are relative to that thread's node buffer;
            // shift them by the number of nodes already merged.
            const int offset_adjustment = static_cast<int>(all_nodes.size());

            all_nodes.insert(all_nodes.end(), thread_nodes[t].begin(), thread_nodes[t].end());

            for (size_t k = 0; k < thread_offsets[t].size(); ++k) {
                tree_offsets.push_back(thread_offsets[t][k] + offset_adjustment);
                tree_sizes.push_back(thread_sizes[t][k]);
                result_mapping.push_back(thread_mappings[t][k]);
            }
        }

        if (result_mapping.empty()) return;

        const int valid_trees = static_cast<int>(result_mapping.size());
        const int num_points = static_cast<int>(x_values.size());

        // Step 3: Queue the GPU evaluation (the call itself does not block).
        launch_evaluation_async(
            all_nodes, tree_offsets, tree_sizes,
            valid_trees, d_targets, d_x_values, num_points,
            double_buffer_gpu
        );

        // Step 4: Wait for the GPU results (this is where we sync).
        std::vector<double> results;
        retrieve_results_sync(results, valid_trees, double_buffer_gpu);

        // Step 5: Distribute the results back to the islands.
        for (size_t k = 0; k < static_cast<size_t>(valid_trees); ++k) {
            const int island_idx = result_mapping[k].first;
            const int ind_idx = result_mapping[k].second;
            double fitness = results[k];

            // Treat NaN, infinities and values >= 1e300 as "no usable fitness".
            if (std::isnan(fitness) || std::isinf(fitness) || fitness >= 1e300) {
                fitness = INF;
            }

            islands[island_idx]->population[ind_idx].fitness = fitness;
            islands[island_idx]->population[ind_idx].fitness_valid = true;
        }

    } else {
        // FORCE_CPU_MODE is true OR GPU init failed: use the CPU.
        #pragma omp parallel for schedule(dynamic)
        for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
            for (int j = 0; j < static_cast<int>(islands[i]->population.size()); ++j) {
                Individual& ind = islands[i]->population[j];
                if (ind.tree) {
                    // Pass nullptr for the GPU pointers since we are in CPU mode.
                    ind.fitness = evaluate_fitness(ind.tree, targets, x_values, nullptr, nullptr);
                    ind.fitness_valid = true;
                } else {
                    ind.fitness = INF;
                    ind.fitness_valid = true;
                }
            }
        }
    }

#else
    // CPU-only build: parallel CPU evaluation.
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < static_cast<int>(islands.size()); ++i) {
        for (int j = 0; j < static_cast<int>(islands[i]->population.size()); ++j) {
            Individual& ind = islands[i]->population[j];
            if (ind.tree) {
                ind.fitness = evaluate_fitness(ind.tree, targets, x_values);
                ind.fitness_valid = true;
            } else {
                ind.fitness = INF;
                ind.fitness_valid = true;
            }
        }
    }
#endif
}


// --- evolve_island ---
// Advances one island by a single generation. In order:
//   1. locate the best evaluated individual and try a local search on it;
//   2. update the island's stagnation counter, Pareto optimizer and pattern memory;
//   3. seed the next generation with elites, then with random or pattern-based
//      injections, or perform a "cataclysm" reset once the stagnation counter
//      reaches STAGNATION_LIMIT_ISLAND (when USE_ISLAND_CATACLYSM is set);
//   4. fill the remaining slots with crossover / mutation offspring, optionally
//      rejecting duplicates by their tree_to_string() signature;
//   5. replace the population and, every PARAM_MUTATE_INTERVAL generations,
//      mutate the island's own evolution parameters.
//
// Parameters:
//   island             - island to evolve; its population is replaced in place.
//   current_generation - zero-based generation index, used for the periodic
//                        pattern injection and parameter mutation.
void GeneticAlgorithm::evolve_island(Island& island, int current_generation) {
    const int current_pop_size = static_cast<int>(island.population.size());
    if (current_pop_size == 0) return;
    const size_t target_size = island.population.size();

    // --- 1. Best individual + local search ---
    // Individuals without a tree or without a valid fitness compare as "worst".
    auto best_it = std::min_element(island.population.begin(), island.population.end(),
        [](const Individual& a, const Individual& b) {
            if (!a.tree || !a.fitness_valid) return false;
            if (!b.tree || !b.fitness_valid) return true;
            return a.fitness < b.fitness;
        });
    double current_best_fitness = INF;
    int best_idx = -1;
    if (best_it != island.population.end() && best_it->tree && best_it->fitness_valid) {
        best_idx = static_cast<int>(std::distance(island.population.begin(), best_it));
        current_best_fitness = best_it->fitness;
    }
    island.fitness_history.push_back(current_best_fitness);
    if (best_idx != -1 && current_best_fitness < INF) {
        Individual& best = island.population[best_idx];
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
        auto local_search_result = try_local_improvement(best.tree, best.fitness, targets, x_values, LOCAL_SEARCH_ATTEMPTS, d_targets, d_x_values);
#else
        auto local_search_result = try_local_improvement(best.tree, best.fitness, targets, x_values, LOCAL_SEARCH_ATTEMPTS);
#endif
        if (local_search_result.first && local_search_result.second < best.fitness) {
            best.tree = local_search_result.first;
            best.fitness = local_search_result.second;
            best.fitness_valid = true;
            // The tree changed, so any per-case errors cached for lexicase
            // selection no longer describe it.
            best.errors.clear();
            current_best_fitness = local_search_result.second;
        }
    }

    // --- 2. Stagnation bookkeeping, Pareto front, pattern memory ---
    if (current_best_fitness < island.best_fitness - FITNESS_EQUALITY_TOLERANCE) {
        island.best_fitness = current_best_fitness;
        island.stagnation_counter = 0;
    } else if (current_best_fitness < INF) {
        island.stagnation_counter++;
    }
    island.pareto_optimizer.update(island.population, targets, x_values);
    for (const auto& ind : island.population) {
        if (ind.tree && ind.fitness_valid && ind.fitness < PATTERN_RECORD_FITNESS_THRESHOLD) {
            island.pattern_memory.record_success(ind.tree, ind.fitness);
        }
    }

    // --- 3a. Elitism ---
    std::vector<Individual> next_generation;
    next_generation.reserve(target_size);
    int elite_count = std::max(1, static_cast<int>(current_pop_size * island.params.elite_percentage));
    if (elite_count > 0 && elite_count <= current_pop_size) {
        std::partial_sort(island.population.begin(), island.population.begin() + elite_count, island.population.end());
        int added_elites = 0;
        for (int i = 0; i < elite_count; ++i) {
            const Individual& elite = island.population[i];
            if (elite.tree && elite.fitness_valid) {
                // Elites are deep-copied and keep their already-computed fitness.
                next_generation.emplace_back(clone_tree(elite.tree));
                next_generation.back().fitness = elite.fitness;
                next_generation.back().fitness_valid = true;
                added_elites++;
            }
        }
        elite_count = added_elites;
    } else {
        elite_count = 0;
    }

    // --- 3b. Random injection once the island is half-way to the stagnation limit ---
    int random_injection_count = 0;
    if (island.stagnation_counter > STAGNATION_LIMIT_ISLAND / 2) {
        random_injection_count = static_cast<int>(current_pop_size * STAGNATION_RANDOM_INJECT_PERCENT);
        for (int i = 0; i < random_injection_count && next_generation.size() < target_size; ++i) {
            NodePtr random_tree = generate_random_tree(MAX_TREE_DEPTH_INITIAL);
            if (random_tree) next_generation.emplace_back(std::move(random_tree));
        }
    }

    int pattern_injection_count = 0;
    if (USE_ISLAND_CATACLYSM && island.stagnation_counter >= STAGNATION_LIMIT_ISLAND) {
        // --- 3c. Island cataclysm: hard reset after persistent stagnation ---
        // Keep only the first entry of next_generation (the top elite whenever
        // at least one elite was added) and refill with fresh random trees.
        const size_t survivors = 1;
        if (next_generation.size() > survivors) next_generation.resize(survivors);

        const int to_fill = static_cast<int>(target_size - next_generation.size());
        for (int i = 0; i < to_fill; ++i) {
            NodePtr random_tree = generate_random_tree(MAX_TREE_DEPTH_INITIAL);
            if (random_tree) next_generation.emplace_back(std::move(random_tree));
        }

        island.stagnation_counter = 0;
    } else if (random_injection_count == 0 && current_generation % PATTERN_INJECT_INTERVAL == 0) {
        // --- 3d. Periodic pattern-based injection (skipped after a random injection) ---
        pattern_injection_count = static_cast<int>(current_pop_size * PATTERN_INJECT_PERCENT);
        for (int i = 0; i < pattern_injection_count && next_generation.size() < target_size; ++i) {
            NodePtr pt = island.pattern_memory.suggest_pattern_based_tree(MAX_TREE_DEPTH_INITIAL);
            if (pt) {
                next_generation.emplace_back(std::move(pt));
            } else {
                // No pattern available: fall back to a random tree.
                NodePtr random_tree = generate_random_tree(MAX_TREE_DEPTH_INITIAL);
                if (random_tree) next_generation.emplace_back(std::move(random_tree));
            }
        }
    }

    // --- 4. Offspring generation with optional uniqueness check ---
    // Signatures of the survivors (elites / injected trees) seed the duplicate filter.
    std::unordered_set<std::string> unique_signatures;
    if (PREVENT_DUPLICATES) {
        for (const auto& ind : next_generation) {
            if (ind.tree) unique_signatures.insert(tree_to_string(ind.tree));
        }
    }

    int fail_safe_counter = 0;
    while (next_generation.size() < target_size) {
        const int needed = static_cast<int>(target_size - next_generation.size());
        std::vector<Individual> candidates(needed);

        // lexicase_selection() lazily fills Individual::errors on members of the
        // shared population, so concurrent selections would race on those
        // vectors. Run this loop serially whenever lexicase selection is enabled;
        // tournament selection only reads the population and stays parallel.
        #pragma omp parallel for schedule(dynamic) if(!USE_LEXICASE_SELECTION)
        for (int i = 0; i < needed; ++i) {
            // get_rng() is expected to hand out a per-thread generator; the
            // distribution is local so no state is shared between threads.
            auto& rng = get_rng();
            std::uniform_real_distribution<double> local_prob_dist(0.0, 1.0);

            Individual offspring;
            if (local_prob_dist(rng) < island.params.crossover_rate) {
                Individual p1, p2;
                if (USE_LEXICASE_SELECTION) {
                    p1 = lexicase_selection(island.population, targets, x_values);
                    p2 = lexicase_selection(island.population, targets, x_values);
                } else {
                    p1 = tournament_selection(island.population, island.params.tournament_size);
                    p2 = tournament_selection(island.population, island.params.tournament_size);
                }
                offspring = crossover(p1, p2);
            } else {
                Individual p1;
                if (USE_LEXICASE_SELECTION) {
                    p1 = lexicase_selection(island.population, targets, x_values);
                } else {
                    p1 = tournament_selection(island.population, island.params.tournament_size);
                }
                // Mutate a private deep copy so the parent's tree is untouched.
                if (p1.tree) p1.tree = clone_tree(p1.tree);
                mutate(p1, island.params.mutation_rate);
                // p1 started as a copy of its parent, including any cached
                // per-case errors; drop them so they are recomputed for the
                // (possibly) mutated tree. NOTE(review): harmless if mutate()
                // already does this - its body is not visible here.
                p1.errors.clear();
                offspring = std::move(p1);
            }

            candidates[i] = std::move(offspring);
        }

        // Serial filter: keep candidates whose signature has not been seen yet.
        int added_this_round = 0;
        for (auto& cand : candidates) {
            if (next_generation.size() >= target_size) break;

            if (PREVENT_DUPLICATES && cand.tree) {
                // insert() reports through .second whether the signature was new.
                if (!unique_signatures.insert(tree_to_string(cand.tree)).second) continue;
            }
            next_generation.emplace_back(std::move(cand));
            added_this_round++;
        }

        // Guard against an endless loop when every candidate is a duplicate.
        if (added_this_round == 0) {
            fail_safe_counter++;
            if (fail_safe_counter > DUPLICATE_RETRIES) {
                // Give up on uniqueness and top up with random trees.
                const int remaining = static_cast<int>(target_size - next_generation.size());
                for (int k = 0; k < remaining; ++k) {
                    NodePtr random_tree = generate_random_tree(MAX_TREE_DEPTH_INITIAL);
                    if (random_tree) next_generation.emplace_back(std::move(random_tree));
                }
                break;
            }
        } else {
            fail_safe_counter = 0; // progress was made
        }
    }

    // --- 5. Commit the new generation and adapt the island parameters ---
    if (next_generation.size() > target_size) next_generation.resize(target_size);
    island.population = std::move(next_generation);
    if (current_generation > 0 && current_generation % PARAM_MUTATE_INTERVAL == 0) {
        island.params.mutate(island.stagnation_counter);
    }
}

// --- migrate ---
// (Sin cambios)
void GeneticAlgorithm::migrate() {
    if (num_islands <= 1) return;
    int current_pop_per_island = islands.empty() ? 0 : islands[0]->population.size();
    if (current_pop_per_island == 0) return;
    int num_migrants = std::min(MIGRATION_SIZE, current_pop_per_island / 5);
    if (num_migrants <= 0) return;
    std::vector<std::vector<Individual>> outgoing_migrants(num_islands);
    #pragma omp parallel for
    for (int i = 0; i < num_islands; ++i) {
        Island& src = *islands[i];
        if (src.population.size() < num_migrants) continue;
        std::partial_sort(src.population.begin(), src.population.begin() + num_migrants, src.population.end());
        outgoing_migrants[i].reserve(num_migrants);
        int migrants_selected = 0;
        for (int j = 0; j < src.population.size() && migrants_selected < num_migrants; ++j) {
             if (src.population[j].tree && src.population[j].fitness_valid) {
                 Individual migrant_copy;
                 migrant_copy.tree = clone_tree(src.population[j].tree);
                 migrant_copy.fitness = src.population[j].fitness;
                 migrant_copy.fitness_valid = true;
                 outgoing_migrants[i].push_back(std::move(migrant_copy));
                 migrants_selected++;
             }
        }
    }
    for (int dest_idx = 0; dest_idx < num_islands; ++dest_idx) {
        int src_idx = (dest_idx + num_islands - 1) % num_islands;
        Island& dest = *islands[dest_idx];
        const auto& migrants_to_receive = outgoing_migrants[src_idx];
        if (migrants_to_receive.empty() || dest.population.empty()) continue;
        int replace_count = std::min((int)migrants_to_receive.size(), (int)dest.population.size());
        if (replace_count <= 0) continue;
        std::partial_sort(dest.population.begin(), dest.population.end() - replace_count, dest.population.end());
        int migrant_idx = 0;
        for (int i = 0; i < replace_count; ++i) {
            int replace_idx = dest.population.size() - 1 - i;
            if (migrant_idx < migrants_to_receive.size()) {
                 dest.population[replace_idx] = std::move(migrants_to_receive[migrant_idx++]);
                 dest.population[replace_idx].fitness_valid = false; // Marcar para reevaluar
            }
        }
    }
}


// --- run ---
// Main evolutionary loop. Each generation evaluates every island, evolves the
// islands, tracks the global best individual, optionally migrates individuals
// between islands and prints progress. The loop stops after `generations`
// iterations, when the global best has not improved for
// GLOBAL_STAGNATION_LIMIT generations, or when the best fitness drops below
// EXACT_SOLUTION_THRESHOLD. Returns the best tree found (null if none).
NodePtr GeneticAlgorithm::run() {
    using Clock = std::chrono::high_resolution_clock;

    // Prints one "x / Pred / Target / Diff" row per sample for the current
    // global best tree; callers set the stream precision beforehand.
    auto print_prediction_rows = [this]() {
        for (size_t j = 0; j < x_values.size(); ++j) {
            const double val = evaluate_tree(overall_best_tree, x_values[j]);
            const double target_val = (j < targets.size()) ? targets[j] : std::nan("");
            const double diff = (!std::isnan(val) && !std::isnan(target_val)) ? std::fabs(val - target_val) : std::nan("");
            std::cout << "  x=" << std::setw(8) << x_values[j]
                      << ": Pred=" << std::setw(12) << val
                      << ", Target=" << std::setw(12) << target_val
                      << ", Diff=" << std::setw(12) << diff << std::endl;
        }
    };

    std::cout << "Starting Genetic Algorithm..." << std::endl;
    const auto start_time = Clock::now();

    for (int gen = 0; gen < generations; ++gen) {
        // [DEBUG] Trace execution
        if (gen == 0) {
            std::cout << "[DEBUG] Entering main loop, gen=0" << std::endl;
            std::cout.flush();
        }

        // 1. Evaluate every island in a single batch call.
        evaluate_all_islands();

        // 2. Evolve the islands; each island is processed independently.
        #pragma omp parallel for
        for (int i = 0; i < islands.size(); ++i) {
            evolve_island(*islands[i], gen);
        }

        // 3. Find the best individual that currently holds a valid fitness.
        double gen_best_fitness = INF;
        int champion_island = -1;
        int champion_slot = -1;
        for (int isl = 0; isl < static_cast<int>(islands.size()); ++isl) {
            const auto& members = islands[isl]->population;
            for (int slot = 0; slot < static_cast<int>(members.size()); ++slot) {
                const auto& ind = members[slot];
                if (!ind.tree || !ind.fitness_valid) continue;
                if (ind.fitness < gen_best_fitness) {
                    gen_best_fitness = ind.fitness;
                    champion_island = isl;
                    champion_slot = slot;
                }
            }
        }

        // 4. Either record a new global best or check for global stagnation.
        const bool improved = (champion_island != -1) && (gen_best_fitness < overall_best_fitness);
        if (improved) {
            overall_best_fitness = gen_best_fitness;
            overall_best_tree = clone_tree(islands[champion_island]->population[champion_slot].tree);
            std::cout << "\n========================================" << std::endl;
            std::cout << "New Global Best Found (Gen " << gen + 1 << ", Island " << champion_island << ")" << std::endl;
            std::cout << "Fitness: " << std::fixed << std::setprecision(8) << overall_best_fitness << std::endl;
            std::cout << "Size: " << tree_size(overall_best_tree) << std::endl;
            std::cout << "Formula: " << tree_to_string(overall_best_tree) << std::endl;
            std::cout.flush(); // Ensure Formula: line is captured
            std::cout << "Predictions vs Targets:" << std::endl;
            std::cout << std::fixed << std::setprecision(4);
            if (overall_best_tree && !x_values.empty()) {
                print_prediction_rows();
            } else {
                std::cout << "  (No data or no valid tree to show predictions)" << std::endl;
            }
            std::cout << "========================================" << std::endl;
            last_overall_best_fitness = overall_best_fitness;
            generation_last_improvement = gen;
        } else if (overall_best_fitness < INF && (gen - generation_last_improvement) >= GLOBAL_STAGNATION_LIMIT) {
            std::cout << "\n========================================" << std::endl;
            std::cout << "TERMINATION: Global best fitness hasn't improved for " << GLOBAL_STAGNATION_LIMIT << " generations." << std::endl;
            std::cout << "Stopping at Generation " << gen + 1 << "." << std::endl;
            std::cout << "========================================" << std::endl;
            break;
        }

        // 5. Periodic migration, followed by a batch re-evaluation.
        if ((gen + 1) % MIGRATION_INTERVAL == 0 && num_islands > 1) {
            migrate();
            evaluate_all_islands();
        }

        // 6. Early exit once the solution is good enough.
        if (overall_best_fitness < EXACT_SOLUTION_THRESHOLD) {
            std::cout << "\n========================================" << std::endl;
            std::cout << "Solution found meeting criteria at Generation " << gen + 1 << "!" << std::endl;
            std::cout << "Final Fitness: " << std::fixed << std::setprecision(8) << overall_best_fitness << std::endl;
            if (overall_best_tree) {
                std::cout << "Final Formula Size: " << tree_size(overall_best_tree) << std::endl;
                std::cout << "Final Formula: " << tree_to_string(overall_best_tree) << std::endl;
                std::cout.flush(); // Ensure Final Formula: line is captured
            }
            std::cout << "========================================" << std::endl;
            std::cout.flush(); // Ensure flush
            break;
        }

        // 7. Progress report.
        const bool report_now = ((gen + 1) % PROGRESS_REPORT_INTERVAL == 0) || (gen == generations - 1);
        if (report_now) {
            const std::chrono::duration<double> elapsed = Clock::now() - start_time;
            std::cout << "\n--- Generation " << gen + 1 << "/" << generations
                      << " (Elapsed: " << std::fixed << std::setprecision(2) << elapsed.count() << "s) ---" << std::endl;
            std::cout << "Overall Best Fitness: " << std::scientific << overall_best_fitness << std::fixed << std::endl;
            if (overall_best_tree) {
                std::cout << "Best Formula Size: " << tree_size(overall_best_tree) << std::endl;
            } else {
                std::cout << "Best Formula Size: N/A" << std::endl;
            }
            std::cout << "(Last improvement at gen: " << generation_last_improvement + 1 << ")" << std::endl;
        }
    }

    // Final summary.
    const std::chrono::duration<double> total_elapsed = Clock::now() - start_time;
    std::cout << "\n========================================" << std::endl;
    std::cout << "Evolution Finished!" << std::endl;
    std::cout << "Total Time: " << std::fixed << std::setprecision(2) << total_elapsed.count() << " seconds" << std::endl;
    std::cout << "Final Best Fitness: " << std::fixed << std::setprecision(8) << overall_best_fitness << std::endl;
    if (!overall_best_tree) {
        std::cout << "No valid solution found." << std::endl;
    } else {
        std::cout << "Final Best Formula Size: " << tree_size(overall_best_tree) << std::endl;
        std::cout << "Final Formula: " << tree_to_string(overall_best_tree) << std::endl;
        std::cout.flush(); // Ensure Final Formula: line is captured
        std::cout << "--- Final Verification ---" << std::endl;
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
        double final_check_fitness = evaluate_fitness(overall_best_tree, targets, x_values, d_targets, d_x_values);
#else
        double final_check_fitness = evaluate_fitness(overall_best_tree, targets, x_values);
#endif
        std::cout << "Recalculated Fitness: " << std::fixed << std::setprecision(8) << final_check_fitness << std::endl;
        std::cout << std::fixed << std::setprecision(4);
        print_prediction_rows();
    }
    std::cout << "========================================" << std::endl;
    return overall_best_tree;
}


In [None]:
%%writefile Code/src/GeneticAlgorithm.h
// ============================================================
// File: src/GeneticAlgorithm.h
// ============================================================
#ifndef GENETICALGORITHM_H
#define GENETICALGORITHM_H

#include "ExpressionTree.h"
#include "GeneticOperators.h"
#include "AdvancedFeatures.h"
#include "Globals.h" // Incluir Globals.h para INF, NUM_ISLANDS, etc.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
#include "FitnessGPU.cuh" // For GlobalGpuBuffers definition
#endif
#include <vector>
#include <string>
#include <memory> // Para std::unique_ptr

// Island-model genetic algorithm for symbolic regression: owns a set of
// islands (independent sub-populations), evolves them for a number of
// generations and keeps track of the best expression tree found overall.
class GeneticAlgorithm {
    // One island: an independent sub-population plus its own evolution state.
    struct Island {
        std::vector<Individual> population; // Individuals living on this island
        EvolutionParameters params;         // Evolution parameters specific to this island
        PatternMemory pattern_memory;       // Pattern memory of this island
        ParetoOptimizer pareto_optimizer;   // Pareto optimizer of this island
        int stagnation_counter = 0;         // Local stagnation counter of this island
        double best_fitness = INF;          // Best fitness this island has reached so far
        std::vector<double> fitness_history;// Fitness history (optional)
        int id;                             // Island identifier

        // Builds an island with a freshly created initial population of
        // pop_size individuals and the default evolution parameters.
        explicit Island(int island_id, int pop_size) : id(island_id) {
             population = create_initial_population(pop_size); // Create the initial population
             params = EvolutionParameters::create_default();   // Start from the default parameters
        }

#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
        // Persistent GPU buffers (opaque pointers; see the destructor note below)
        void* d_nodes = nullptr;
        void* d_offsets = nullptr;
        void* d_sizes = nullptr;
        void* d_results = nullptr;
        size_t d_nodes_capacity = 0;
        size_t d_pop_capacity = 0;

        ~Island() {
            // Intentionally empty: the buffers above are NOT released here.
            // cudaFree cannot be called from this header because it may be
            // included in translation units that do not include cuda_runtime.h.
            // NOTE(review): any cleanup therefore has to happen elsewhere
            // (e.g. a cleanup step driven by GeneticAlgorithm) or is left to
            // the OS/driver at process exit - confirm which one applies.
        }
#endif
    };

    // Main members of GeneticAlgorithm
    std::vector<std::unique_ptr<Island>> islands; // Owning pointers to the islands
    const std::vector<double>& targets;           // Reference to the target data (not owned; must outlive this object)
    const std::vector<double>& x_values;          // Reference to the x values (not owned; must outlive this object)
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    double* d_targets = nullptr;                  // Pointer to the target data on the GPU
    double* d_x_values = nullptr;                 // Pointer to the x values on the GPU
    GlobalGpuBuffers global_gpu_buffers;          // Global buffers for batch evaluation of ALL islands
    DoubleBufferedGpu double_buffer_gpu;          // Double-buffered GPU for async overlap
#endif
    int total_population_size;                    // Total population size
    int generations;                              // Maximum number of generations
    int num_islands;                              // Number of islands

    // Global best tracking
    NodePtr overall_best_tree = nullptr;          // Best tree found across all islands
    double overall_best_fitness = INF;            // Best fitness found across all islands

    // --- Global stagnation tracking ---
    int generation_last_improvement = 0;          // Generation in which overall_best_fitness last improved
    double last_overall_best_fitness = INF;       // Value of overall_best_fitness at the last improvement
    // -------------------------------------------------

    int pop_per_island;                           // Computed population size per island

public:
    // Constructor
    GeneticAlgorithm(const std::vector<double>& targets_ref,
                       const std::vector<double>& x_values_ref,
                       int total_pop,
                       int gens,
                       const std::vector<std::string>& seeds = {}, // Optional: Initial population seeds
                       int n_islands = NUM_ISLANDS); // Defaults to the value from Globals.h
    ~GeneticAlgorithm(); // Destructor; releases GPU memory

    // Runs the genetic algorithm and returns the best tree found
    NodePtr run();

private:
    // Internal helpers
    void evaluate_population(Island& island); // Evaluates the fitness of one island (legacy)
    void evaluate_all_islands(); // Evaluates ALL islands in ONE GPU batch call (optimized)
    void evolve_island(Island& island, int current_generation); // Evolves one island by one generation
    void migrate(); // Performs migration between islands
    void update_overall_best(const Island& island); // Updates the global best
};


#endif // GENETICALGORITHM_H


In [None]:
%%writefile Code/src/GeneticOperators.cpp
#include "GeneticOperators.h"
#include "Globals.h"
#include "Fitness.h"
#include "AdvancedFeatures.h"
#include <vector>
#include <cmath>
#include <algorithm>
#include <numeric>  // For std::iota
#include <map>
#include <stdexcept>
#include <set>
#include <iostream> // Para mensajes de error/info

// Generates a random expression tree (all operators enabled, exponents unrestricted).
//
// Parameters:
//   max_depth     - depth at which the recursion is forced to emit a terminal.
//   current_depth - depth of the node being created (callers are expected to
//                   pass a non-negative value).
// Returns a non-null node. The chance of emitting a terminal grows linearly
// from 0.2 at the root to 1.0 at max_depth.
NodePtr generate_random_tree(int max_depth, int current_depth) {
    std::uniform_real_distribution<double> prob_dist(0.0, 1.0);
    auto& rng = get_rng();

    // Builds a terminal: a variable with probability TERMINAL_VS_VARIABLE_PROB,
    // otherwise a constant (integer-valued when FORCE_INTEGER_CONSTANTS is set).
    auto make_random_terminal = [&]() -> NodePtr {
        if (prob_dist(rng) < TERMINAL_VS_VARIABLE_PROB) {
            return std::make_shared<Node>(NodeType::Variable);
        }
        auto const_node = std::make_shared<Node>(NodeType::Constant);
        if (FORCE_INTEGER_CONSTANTS) {
            std::uniform_int_distribution<int> cd(CONSTANT_INT_MIN_VALUE, CONSTANT_INT_MAX_VALUE);
            const_node->value = static_cast<double>(cd(rng));
        } else {
            std::uniform_real_distribution<double> cd(CONSTANT_MIN_VALUE, CONSTANT_MAX_VALUE);
            const_node->value = cd(rng);
        }
        // Snap values that are numerically zero to exactly 0.0.
        if (std::fabs(const_node->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) const_node->value = 0.0;
        return const_node;
    };

    // Depth limit first: this also keeps the ratio below from dividing by zero
    // when max_depth is 0, and no random number is drawn in this case.
    if (current_depth >= max_depth) return make_random_terminal();

    const double terminal_prob = 0.2 + 0.8 * (static_cast<double>(current_depth) / max_depth);
    if (prob_dist(rng) < terminal_prob) return make_random_terminal();

    // --- Operator node ---
    auto node = std::make_shared<Node>(NodeType::Operator);
    // Must match the order of the weights in Globals.h: +, -, *, /, ^, %, s, c, l, e, !, _, g
    static const char kOperators[] = {'+', '-', '*', '/', '^', '%', 's', 'c', 'l', 'e', '!', '_', 'g'};
    std::discrete_distribution<int> op_dist(OPERATOR_WEIGHTS.begin(), OPERATOR_WEIGHTS.end());
    node->op = kOperators[op_dist(rng)];

    const bool is_unary = (node->op == 's' || node->op == 'c' || node->op == 'l' || node->op == 'e' ||
                           node->op == '!' || node->op == '_' || node->op == 'g');

    // Children are generated recursively; unary operators only use the left child.
    node->left = generate_random_tree(max_depth, current_depth + 1);
    node->right = is_unary ? nullptr : generate_random_tree(max_depth, current_depth + 1);

    // Defensive fallback should a child ever come back null.
    if (!node->left) node->left = make_random_terminal();
    if (!is_unary && !node->right) node->right = make_random_terminal();

    // --- Special handling for the power operator '^' ---
    if (node->op == '^') {
        const bool base_is_const = (node->left->type == NodeType::Constant);
        const bool exp_is_const = (node->right->type == NodeType::Constant);
        if (base_is_const && std::fabs(node->left->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) {
            // Rule 1: avoid 0^0 and 0^negative by switching to a safe operator.
            if (exp_is_const && node->right->value <= SIMPLIFY_NEAR_ZERO_TOLERANCE) {
                static const char kSafeOps[] = {'+', '-', '*'};
                std::uniform_int_distribution<int> safe_op_dist(0, static_cast<int>(sizeof(kSafeOps)) - 1);
                node->op = kSafeOps[safe_op_dist(rng)];
            }
        } else if (base_is_const && node->left->value < 0.0) {
            // Rule 2: avoid a negative base with a non-integer exponent by
            // replacing the exponent with a random integer in [-3, 3].
            if (exp_is_const && std::fabs(node->right->value - std::round(node->right->value)) > SIMPLIFY_NEAR_ZERO_TOLERANCE) {
                std::uniform_int_distribution<int> int_exp_dist(-3, 3);
                node->right = std::make_shared<Node>(NodeType::Constant);
                node->right->value = static_cast<double>(int_exp_dist(rng));
            }
        }
    }
    return node;
}

// --- Creates the initial population (optionally injecting a seed formula) ---
// Tree generation is parallelised with OpenMP.
//
// Parameters:
//   population_size - number of individuals to create; values <= 0 yield an
//                     empty population.
// Returns a vector of population_size individuals. When USE_INITIAL_FORMULA is
// set and INITIAL_FORMULA_STRING parses successfully, slot 0 holds that
// formula; every other slot holds a random tree.
std::vector<Individual> create_initial_population(int population_size) {
    // A negative size would otherwise reach resize() as a huge unsigned value.
    if (population_size <= 0) return {};

    // Pre-size the vector so that each parallel iteration writes its own slot.
    std::vector<Individual> population(population_size);

    // --- Initial formula injection ---
    int start_index = 0;
    if (USE_INITIAL_FORMULA && !INITIAL_FORMULA_STRING.empty()) {
        try {
            NodePtr initial_tree = parse_formula_string(INITIAL_FORMULA_STRING);
            if (initial_tree) {
                population[0] = Individual(std::move(initial_tree));
                start_index = 1;
                std::cout << "[INFO] Injected initial formula: " << INITIAL_FORMULA_STRING << std::endl;
            } else {
                 std::cerr << "[WARNING] Parsing initial formula returned null. Skipping injection." << std::endl;
            }
        } catch (const std::exception& e) {
            std::cerr << "[ERROR] Failed to parse initial formula '" << INITIAL_FORMULA_STRING
                      << "': " << e.what() << ". Skipping injection." << std::endl;
        }
    }

    // uniform_int_distribution requires min <= max, so clamp the lower bound
    // in case MAX_TREE_DEPTH_INITIAL is configured below 3.
    const int min_depth = std::min<int>(3, MAX_TREE_DEPTH_INITIAL);

    // --- Parallel loop generating the random trees ---
    #pragma omp parallel for schedule(dynamic, 100)
    for (int i = start_index; i < population_size; ++i) {
        // get_rng() is expected to return a per-thread generator.
        auto& rng = get_rng();
        std::uniform_int_distribution<int> depth_dist(min_depth, MAX_TREE_DEPTH_INITIAL);

        // Retry a few times in case tree generation yields a null tree.
        NodePtr random_tree = nullptr;
        int attempts = 0;
        const int max_attempts = 10;
        while (!random_tree && attempts < max_attempts) {
            random_tree = generate_random_tree(depth_dist(rng));
            attempts++;
        }
        if (random_tree) {
            population[i] = Individual(std::move(random_tree));
        } else {
            // Fallback: a plain constant 0.
            auto fallback_node = std::make_shared<Node>(NodeType::Constant);
            fallback_node->value = 0.0;
            population[i] = Individual(std::move(fallback_node));
        }
    }
    return population;
}

// --- Tournament selection with parsimony pressure ---
// Draws up to tournament_size individuals (with replacement) and returns a
// copy of the one with the lowest fitness; when two fitness values differ by
// less than FITNESS_EQUALITY_TOLERANCE the smaller tree wins.
// Individuals without a tree or without a valid fitness are skipped. If no
// usable starting individual is found within the retry budget, population[0]
// is returned as is.
// Throws std::runtime_error when the population is empty.
Individual tournament_selection(const std::vector<Individual>& population, int tournament_size) {
    if (population.empty()) throw std::runtime_error("Cannot perform tournament selection on empty population.");

    const int pop_count = (int)population.size();
    // Clamp the tournament size into [1, population size].
    const int rounds = std::min(std::max(tournament_size, 1), pop_count);

    std::uniform_int_distribution<int> pick(0, pop_count - 1);
    auto& rng = get_rng();

    auto usable = [](const Individual& candidate) -> bool {
        return candidate.tree && candidate.fitness_valid;
    };

    // Find a usable starting champion, retrying a bounded number of times.
    const int attempt_cap = std::min(pop_count * 2, 100);
    const Individual* champion = &population[pick(rng)];
    for (int tries = 1; !usable(*champion) && tries < attempt_cap; ++tries) {
        champion = &population[pick(rng)];
    }

    if (!usable(*champion)) {
        if (!population.empty()) return population[0];
        // Defensive only: the population was checked to be non-empty above.
        throw std::runtime_error("Tournament selection couldn't find any valid individual in a non-empty population.");
    }

    // Remaining rounds: challenge the champion with random rivals.
    for (int round = 1; round < rounds; ++round) {
        const Individual& rival = population[pick(rng)];
        if (!usable(rival)) continue;

        if (rival.fitness < champion->fitness) {
            champion = &rival;
            continue;
        }
        const bool tied = std::fabs(rival.fitness - champion->fitness) < FITNESS_EQUALITY_TOLERANCE;
        if (tied && tree_size(rival.tree) < tree_size(champion->tree)) {
            champion = &rival; // parsimony: prefer the smaller tree on a tie
        }
    }
    return *champion;
}

// --- Epsilon-Lexicase Selection Implementation ---
// Lazily fills `ind.errors` with one absolute error per target:
// |evaluate_tree(ind.tree, x_values[i]) - targets[i]|, or INF when the
// evaluation yields NaN or +/-infinity.
// Does nothing if the cache already holds values or the individual has no tree.
// x_values is indexed over targets.size(); no bounds check is made here.
void ensure_errors_computed(Individual& ind, const std::vector<double>& targets, const std::vector<double>& x_values) {
    // Cache already populated, or nothing to evaluate.
    if (!ind.errors.empty() || !ind.tree) return;

    const size_t num_cases = targets.size();
    ind.errors.reserve(num_cases);

    for (size_t case_idx = 0; case_idx < num_cases; ++case_idx) {
        const double predicted = evaluate_tree(ind.tree, x_values[case_idx]);
        const bool is_finite = !std::isnan(predicted) && !std::isinf(predicted);
        // Lexicase works on ABSOLUTE per-case error; non-finite outputs are the worst possible.
        ind.errors.push_back(is_finite ? std::fabs(predicted - targets[case_idx]) : INF);
    }
}

// Epsilon-lexicase selection over a random pool ("tournament lexicase").
//
// 1. Samples 100 individuals (with replacement) and keeps those that have a
//    tree and a valid fitness.
// 2. Visits the test cases in random order; for each case keeps only the
//    candidates whose error is within epsilon of the best error on that case,
//    where epsilon = max(0.1 * best_error, 1e-5).
// 3. Returns as soon as a single candidate remains; otherwise returns a random
//    survivor once all cases have been visited.
//
// `population` is non-const because per-case errors are cached lazily in each
// sampled individual (see ensure_errors_computed).
//
// Throws std::runtime_error when `population` is empty (same contract as
// tournament_selection). Falls back to population[0] when the pool contains no
// usable individual, and to a random individual if filtering empties the pool.
Individual lexicase_selection(std::vector<Individual>& population, const std::vector<double>& targets, const std::vector<double>& x_values) {
    if (population.empty()) throw std::runtime_error("Cannot perform lexicase selection on empty population.");

    auto& rng = get_rng();

    // 1. Initial candidates: a fixed-size random pool rather than the whole
    //    population, which would be too slow for very large populations.
    const int pool_size = 100;
    std::vector<Individual*> candidates;
    candidates.reserve(pool_size);

    std::uniform_int_distribution<int> dist(0, population.size() - 1);
    for (int i = 0; i < pool_size; ++i) {
        Individual& ind = population[dist(rng)];
        if (ind.tree && ind.fitness_valid) candidates.push_back(&ind);
    }

    if (candidates.empty()) return population[0]; // No usable individual was sampled

    // 2. Shuffle test cases
    std::vector<size_t> cases(targets.size());
    std::iota(cases.begin(), cases.end(), size_t{0});
    std::shuffle(cases.begin(), cases.end(), rng);

    // Fill the per-case error cache once per candidate; the call is a no-op
    // for individuals whose cache is already populated.
    for (Individual* cand : candidates) {
        ensure_errors_computed(*cand, targets, x_values);
    }

    // 3. Filter loop
    for (size_t case_idx : cases) {
        // Best (lowest) error on this case among the current candidates.
        double min_error = INF;
        for (const Individual* cand : candidates) {
            if (case_idx < cand->errors.size() && cand->errors[case_idx] < min_error) {
                min_error = cand->errors[case_idx];
            }
        }

        // Simple dynamic epsilon: 10% of the best error, never below 1e-5 so
        // that near-exact candidates are still treated as "close enough".
        const double epsilon = std::max(min_error * 0.1, 1e-5);
        const double threshold = min_error + epsilon;

        // Keep the candidates that are within epsilon of the best one.
        std::vector<Individual*> next_candidates;
        next_candidates.reserve(candidates.size());
        for (Individual* cand : candidates) {
            if (case_idx < cand->errors.size() && cand->errors[case_idx] <= threshold) {
                next_candidates.push_back(cand);
            }
        }

        candidates = std::move(next_candidates);
        if (candidates.empty()) break; // Only possible when no candidate has an error for this case
        if (candidates.size() == 1) return *candidates[0];
    }

    // If multiple remain, pick one at random
    if (candidates.empty()) return population[dist(rng)];
    std::uniform_int_distribution<int> pick(0, candidates.size() - 1);
    return *candidates[pick(rng)];
}

// Crossover between two individuals.
// Clones both parent trees, exchanges one randomly chosen subtree between the
// clones (crossover_trees) and returns a new Individual built from the first
// clone; the second child is discarded. When USE_HARD_DEPTH_LIMIT is set the
// kept child is trimmed to MAX_TREE_DEPTH_HARD_LIMIT.
Individual crossover(const Individual& parent1, const Individual& parent2) {
    NodePtr child = clone_tree(parent1.tree);
    NodePtr donor = clone_tree(parent2.tree);

    crossover_trees(child, donor);

    if (USE_HARD_DEPTH_LIMIT) {
        trim_tree(child, MAX_TREE_DEPTH_HARD_LIMIT); // Enforce hard limit
    }

    // Only the first child survives; `donor` is dropped on return.
    return Individual(std::move(child));
}

// Mutates an individual in place.
// Replaces the individual's tree with the result of mutate_tree (using
// MAX_TREE_DEPTH_MUTATION as the depth for generated subtrees) and invalidates
// everything that was cached for the previous tree: the fitness flag and the
// per-case error cache used by lexicase selection.
void mutate(Individual& individual, double mutation_rate) {
    individual.tree = mutate_tree(individual.tree, mutation_rate, MAX_TREE_DEPTH_MUTATION);
    individual.fitness_valid = false; // Cached fitness belongs to the old tree
    // ensure_errors_computed() never recomputes a non-empty cache, so stale
    // errors would otherwise keep describing the pre-mutation tree.
    individual.errors.clear();
}

// Mutates a tree (exponents are UNRESTRICTED in OperatorChange).
//
// Clones `tree` and, with probability `mutation_rate`, applies one randomly
// chosen MutationType to one randomly chosen node of the clone:
//   - ConstantChange: perturbs a constant (value * [0.8, 1.2] + [-1, 1]); a
//     non-constant node is replaced by generate_random_tree(1).
//   - OperatorChange: swaps an operator for a different one drawn from
//     OPERATOR_WEIGHTS, adding or dropping the right operand when the arity
//     changes; a non-operator node is replaced by a random subtree generated
//     with `max_depth`.
//   - SubtreeReplace: replaces the node by a random subtree generated with
//     `max_depth`.
//   - NodeInsertion: wraps the node in a new operator (the node becomes the
//     left child; binary operators get a constant or variable on the right).
//   - NodeDeletion: replaces an operator by one of its children, or a terminal
//     by generate_random_tree(0). A tree made of a single node is left as is.
// When USE_HARD_DEPTH_LIMIT is set, the result is trimmed to
// MAX_TREE_DEPTH_HARD_LIMIT.
//
// Returns the (possibly mutated) clone, or nullptr when the clone is null.
// The work is done on the clone returned by clone_tree, not on `tree`.
NodePtr mutate_tree(const NodePtr& tree, double mutation_rate, int max_depth) {
    auto& rng = get_rng();
    std::uniform_real_distribution<double> prob(0.0, 1.0);
    auto new_tree = clone_tree(tree); // Always clone first
    if (!new_tree) return nullptr; // A null clone (e.g. null input) cannot be mutated

    if (prob(rng) >= mutation_rate) return new_tree; // Do not mutate this time

    std::vector<NodePtr*> nodes; collect_node_ptrs(new_tree, nodes);
    if (nodes.empty()) return new_tree; // No nodes to mutate

    std::uniform_int_distribution<int> node_dist(0, nodes.size() - 1);
    int node_idx = node_dist(rng);
    NodePtr* node_to_mutate_ptr = nodes[node_idx];
    if (!node_to_mutate_ptr || !(*node_to_mutate_ptr)) return new_tree; // Unexpected null slot or node

    const std::vector<MutationType> mutation_types = {
        MutationType::ConstantChange, MutationType::OperatorChange,
        MutationType::SubtreeReplace, MutationType::NodeInsertion,
        MutationType::NodeDeletion
    };
    std::uniform_int_distribution<int> type_dist(0, mutation_types.size() - 1);
    MutationType mut_type = mutation_types[type_dist(rng)];

    NodePtr& current_node_ptr_ref = *node_to_mutate_ptr;
    Node& current_node = *current_node_ptr_ref;

    // Operator symbols in the same order as OPERATOR_WEIGHTS, so that an index
    // drawn from the weighted distribution maps directly to a symbol.
    static constexpr char all_ops[] = {'+', '-', '*', '/', '^', '%', 's', 'c', 'l', 'e', '!', '_', 'g'};

    // Operators that take a single operand (stored in `left`).
    auto is_unary_op = [](char op) -> bool {
        return op == 's' || op == 'c' || op == 'l' || op == 'e' || op == '!' || op == '_' || op == 'g';
    };

    // Generates a random replacement subtree (used by several cases).
    auto generate_replacement = [&](int depth) -> NodePtr {
        NodePtr replacement = generate_random_tree(depth);
        if (!replacement) { // Fallback if generation fails
            replacement = std::make_shared<Node>(NodeType::Constant);
            replacement->value = 1.0; // Simple fallback constant
        }
        return replacement;
    };


    switch (mut_type) {
        case MutationType::ConstantChange:
             if (current_node.type == NodeType::Constant) {
                 // Perturb the constant: scale by [0.8, 1.2] and shift by [-1, 1]
                 double change_factor = std::uniform_real_distribution<double>(0.8, 1.2)(rng);
                 double add_factor = std::uniform_real_distribution<double>(-1.0, 1.0)(rng);
                 current_node.value = current_node.value * change_factor + add_factor;
                 if (FORCE_INTEGER_CONSTANTS) current_node.value = std::round(current_node.value);
                 if (std::fabs(current_node.value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) current_node.value = 0.0;
             } else {
                 // Not a constant: replace it with a small random subtree
                 *node_to_mutate_ptr = generate_replacement(1);
             }
            break;
        case MutationType::OperatorChange:
             if (current_node.type == NodeType::Operator) {
                 // Draw the new operator from the global weighted distribution;
                 // OPERATOR_WEIGHTS holds 0.0 for disabled operators.
                 std::discrete_distribution<int> op_dist(OPERATOR_WEIGHTS.begin(), OPERATOR_WEIGHTS.end());

                 // At least two enabled operators are needed to be able to change
                 int enabled_count = 0;
                 for (double w : OPERATOR_WEIGHTS) if (w > 0.0) enabled_count++;

                 if (enabled_count > 1) {
                     int attempts = 0;
                     int new_op_idx = -1;
                     do {
                         new_op_idx = op_dist(rng);
                         attempts++;
                     } while (all_ops[new_op_idx] == current_node.op && attempts < 20); // Try to get a different one

                     // Test the outcome itself rather than the attempt counter, so a
                     // different operator found on the last allowed draw is accepted.
                     if (all_ops[new_op_idx] != current_node.op) {
                         char old_op = current_node.op;
                         char new_op = all_ops[new_op_idx];

                         bool was_unary = is_unary_op(old_op);
                         bool is_unary = is_unary_op(new_op);

                         if (was_unary && !is_unary) {
                             // Unary -> binary: supply the missing right operand
                             current_node.right = generate_replacement(1);
                         } else if (!was_unary && is_unary) {
                             // Binary -> unary: drop the right operand
                             current_node.right = nullptr;
                         }
                         current_node.op = new_op;
                     }
                 }
             } else {
                 // Not an operator: replace it with a random subtree
                  *node_to_mutate_ptr = generate_replacement(max_depth);
             }
            break;
        case MutationType::SubtreeReplace:
            *node_to_mutate_ptr = generate_replacement(max_depth);
            break;
        case MutationType::NodeInsertion:
            {
                auto new_op_node = std::make_shared<Node>(NodeType::Operator);
                // Weighted operator choice
                std::discrete_distribution<int> op_dist(OPERATOR_WEIGHTS.begin(), OPERATOR_WEIGHTS.end());

                int new_op_idx = op_dist(rng); // Zero-weight (disabled) operators are never drawn
                new_op_node->op = all_ops[new_op_idx];

                bool is_unary = is_unary_op(new_op_node->op);

                // The selected node becomes the operand (left child) of the new operator
                new_op_node->left = current_node_ptr_ref;

                if (!is_unary) {
                     if (prob(rng) < MUTATE_INSERT_CONST_PROB) {
                         auto right_child = std::make_shared<Node>(NodeType::Constant);
                         if (FORCE_INTEGER_CONSTANTS) { std::uniform_int_distribution<int> cv(MUTATE_INSERT_CONST_INT_MIN, MUTATE_INSERT_CONST_INT_MAX); right_child->value = static_cast<double>(cv(rng)); }
                         else { std::uniform_real_distribution<double> cv(MUTATE_INSERT_CONST_FLOAT_MIN, MUTATE_INSERT_CONST_FLOAT_MAX); right_child->value = cv(rng); }
                         if (std::fabs(right_child->value) < SIMPLIFY_NEAR_ZERO_TOLERANCE) right_child->value = 0.0;
                         new_op_node->right = right_child;
                     } else {
                         new_op_node->right = std::make_shared<Node>(NodeType::Variable);
                     }
                } else {
                    new_op_node->right = nullptr;
                }

                *node_to_mutate_ptr = new_op_node;
            }
            break;
        case MutationType::NodeDeletion:
            {
                // Never delete the root when it is the only node of the tree
                if (node_to_mutate_ptr == &new_tree && nodes.size() == 1) return new_tree;

                if (current_node.type == NodeType::Operator) {
                    // Operator: replace it with one of its children (random when both exist)
                    NodePtr replacement = nullptr;
                    bool has_left = (current_node.left != nullptr);
                    bool has_right = (current_node.right != nullptr);

                    if (has_left && has_right) {
                        replacement = (prob(rng) < 0.5) ? current_node.left : current_node.right;
                    } else if (has_left) {
                        replacement = current_node.left;
                    } else if (has_right) {
                        replacement = current_node.right;
                    }
                    // Operator without children: fall back to a generated terminal
                    if (!replacement) replacement = generate_replacement(0);

                    *node_to_mutate_ptr = replacement;
                } else {
                    // Terminal: replace it with another random terminal. The
                    // single-node tree case has already returned above.
                    *node_to_mutate_ptr = generate_replacement(0);
                }
            }
            break;
         default: // Unexpected mutation type: replace for safety
             *node_to_mutate_ptr = generate_replacement(max_depth);
            break;
    }
    if (USE_HARD_DEPTH_LIMIT) trim_tree(new_tree, MAX_TREE_DEPTH_HARD_LIMIT); // Enforce hard limit after mutation
    return new_tree;
}

// Subtree crossover, in place.
// Picks one node slot uniformly at random in each tree and swaps the two
// subtrees rooted there. Does nothing when either tree is null or when
// collect_node_ptrs yields no slots for one of them.
void crossover_trees(NodePtr& tree1, NodePtr& tree2) {
    // A null tree has nothing to exchange.
    if (!(tree1 && tree2)) return;

    std::vector<NodePtr*> slots_a;
    std::vector<NodePtr*> slots_b;
    collect_node_ptrs(tree1, slots_a);
    collect_node_ptrs(tree2, slots_b);

    if (slots_a.empty() || slots_b.empty()) return;

    auto& rng = get_rng();
    std::uniform_int_distribution<int> pick_a(0, slots_a.size() - 1);
    std::uniform_int_distribution<int> pick_b(0, slots_b.size() - 1);

    // Draw the crossover points in a fixed order (first tree, then second).
    const int index_a = pick_a(rng);
    const int index_b = pick_b(rng);

    // Exchange the subtrees by swapping the owning pointers.
    std::swap(*slots_a[index_a], *slots_b[index_b]);
}

// Simplifies a tree in place.
// When USE_SIMPLIFICATION is enabled, replaces `tree` with the result of
// DomainConstraints::fix_or_simplify(tree); otherwise leaves it untouched.
void simplify_tree(NodePtr& tree) {
    if (!USE_SIMPLIFICATION) return;
    tree = DomainConstraints::fix_or_simplify(tree);
}


In [None]:
%%writefile Code/src/GeneticOperators.h
// ============================================================
// Archivo: src/GeneticOperators.h
// ============================================================
#ifndef GENETICOPERATORS_H
#define GENETICOPERATORS_H

#include "ExpressionTree.h"
#include "Globals.h" // Incluir Globals.h para INF
#include <vector>
#include <memory> // Para std::move

// Estructura para representar un individuo en la población.
// Contiene el árbol de expresión y su fitness cacheado.
struct Individual {
    NodePtr tree; // Puntero inteligente al árbol de expresión
    double fitness = INF; // Fitness cacheado (menor es mejor), inicializado a infinito
    std::vector<double> errors; // Cache of per-case errors for Lexicase Selection
    bool fitness_valid = false; // Indica si el fitness cacheado es válido

    // Constructor por defecto
    Individual() = default;
    // Constructor a partir de un árbol (mueve el puntero)
    explicit Individual(NodePtr t) : tree(std::move(t)) {}

    // Operador de comparación para ordenar individuos (menor fitness primero)
    bool operator<(const Individual& other) const {
        // Manejar casos donde uno o ambos fitness no son válidos
        if (!fitness_valid && !other.fitness_valid) return false; // Iguales si ambos inválidos
        if (!fitness_valid) return false; // Inválido es "peor" que válido (va después)
        if (!other.fitness_valid) return true; // Válido es "mejor" que inválido (va antes)
        // Comparar por fitness si ambos son válidos
        return fitness < other.fitness;
    }
};


// --- Genetic operator functions ---

// Generates a random expression tree up to a maximum depth.
// `current_depth` defaults to 0 for top-level calls.
NodePtr generate_random_tree(int max_depth, int current_depth = 0);

// Creates the initial population of individuals.
std::vector<Individual> create_initial_population(int population_size);

// Selects an individual by tournament selection with parsimony pressure:
// lowest fitness wins and, on a fitness tie, the smaller tree wins.
// Throws std::runtime_error if the population is empty.
Individual tournament_selection(const std::vector<Individual>& population, int tournament_size);

// Selects an individual using Epsilon-Lexicase Selection.
// `population` is non-const because per-case errors are cached lazily in the
// individuals that get sampled.
Individual lexicase_selection(std::vector<Individual>& population, const std::vector<double>& targets, const std::vector<double>& x_values);

// Performs crossover between two individuals and returns a new individual
// (only one of the two children is kept).
Individual crossover(const Individual& parent1, const Individual& parent2);

// Mutates an individual in place and marks its cached fitness as invalid.
void mutate(Individual& individual, double mutation_rate);

// Simplifies a tree in place (no-op when USE_SIMPLIFICATION is false).
void simplify_tree(NodePtr& tree);

// Kinds of mutation that mutate_tree can apply; one is picked uniformly at
// random per mutation.
enum class MutationType {
    ConstantChange,
    OperatorChange,
    SubtreeReplace,
    NodeInsertion,
    NodeDeletion // Added later: removes a node from the tree
    // Simplification is not a mutation type (handled by DomainConstraints)
};

// Mutates a tree by applying one of the mutation types with probability
// `mutation_rate`. `max_depth` is the depth passed to the random-subtree
// generator for replacements.
// Returns a new tree (cloned and potentially mutated).
NodePtr mutate_tree(const NodePtr& tree, double mutation_rate, int max_depth);

// Performs crossover between two parent trees, modifying both in place by
// swapping one randomly chosen subtree of each.
void crossover_trees(NodePtr& tree1, NodePtr& tree2);


#endif // GENETICOPERATORS_H


In [None]:
%%writefile Code/src/main.cpp
#include "Globals.h" // Necesario para las constantes globales
#include "GeneticAlgorithm.h"
#include "Fitness.h" // Para evaluate_fitness
#include "ExpressionTree.h" // Para tree_to_string si se necesita aquí
#include <iostream>
#include <vector>
#include <memory> // Para shared_ptr
#include <iomanip> // Para std::setprecision
#include <omp.h>   // Para configuración de OpenMP

#include <fstream> // Para leer archivo
#include <string>
#include <sstream>

int main(int argc, char* argv[]) {
    // === OPTIMIZACIÓN: Configuración explícita de hilos OpenMP ===
    int num_threads = omp_get_max_threads();
    omp_set_num_threads(num_threads);
    std::cout << "[OpenMP] Using " << num_threads << " threads" << std::endl;
    
    // Configurar precisión de salida para números flotantes
    // Force immediate flush for each output (important for subprocess capture)
    std::cout << std::unitbuf << std::fixed << std::setprecision(6);
    
    std::vector<std::string> seed_formulas;
    std::string seed_file_path = "";
    std::string data_file_path = "";
    
    // Parse arguments
    for (int i = 1; i < argc; ++i) {
        std::string arg = argv[i];
        if ((arg == "--seed" || arg == "-s") && i + 1 < argc) {
             seed_file_path = argv[i + 1];
             i++; // Skip next arg
        } else if ((arg == "--data" || arg == "-d") && i + 1 < argc) {
             data_file_path = argv[i + 1];
             i++;
        }
    }
    
    if (!seed_file_path.empty()) {
        std::cout << "Loading seeds from: " << seed_file_path << std::endl;
        std::ifstream file(seed_file_path);
        if (file.is_open()) {
            std::string line;
            while (std::getline(file, line)) {
                if (!line.empty()) {
                    seed_formulas.push_back(line);
                }
            }
            file.close();
            std::cout << "Loaded " << seed_formulas.size() << " formulas." << std::endl;
        } else {
            std::cerr << "[Error] Could not open seed file: " << seed_file_path << std::endl;
        }
    }

    std::vector<double> targets;
    std::vector<double> final_x_values;

    if (!data_file_path.empty()) {
         std::cout << "Loading data from: " << data_file_path << std::endl;
         std::ifstream dfile(data_file_path);
         if (dfile.is_open()) {
             // Format:
             // Line 1: x1 x2 x3 ...
             // Line 2: y1 y2 y3 ...
             // Values separated by space or comma
             
             // Helper lambda to parse line
             auto parse_line = [](const std::string& line) {
                 std::vector<double> vals;
                 std::stringstream ss(line);
                 double val;
                 while (ss >> val) {
                     vals.push_back(val);
                     if (ss.peek() == ',' || ss.peek() == ' ') ss.ignore();
                 }
                 return vals;
             };
             
             std::string line;
             if (std::getline(dfile, line)) final_x_values = parse_line(line);
             if (std::getline(dfile, line)) targets = parse_line(line);
             
             dfile.close();
             
             if (final_x_values.size() != targets.size()) {
                 std::cerr << "[Error] Mismatch in data size: X(" << final_x_values.size() 
                           << ") vs Y(" << targets.size() << ")" << std::endl;
                 return 1;
             }
             std::cout << "Loaded " << final_x_values.size() << " data points." << std::endl;
         } else {
             std::cerr << "[Error] Could not open data file: " << data_file_path << std::endl;
             return 1;
         }
    } else {
        // Fallback to Globals.h
        if (USE_LOG_TRANSFORMATION) {
             std::cout << "Info: Log Transformation is ON (Target = ln(Q(N)))." << std::endl;
             for (size_t i = 0; i < RAW_TARGETS.size(); ++i) {
                 if (RAW_TARGETS[i] > 0) {
                     targets.push_back(std::log(RAW_TARGETS[i]));
                     final_x_values.push_back(X_VALUES[i]);
                 }
             }
        } else {
             std::cout << "Info: Log Transformation is OFF." << std::endl;
             targets = RAW_TARGETS;
             final_x_values = X_VALUES;
        }
    }

    std::cout << "Target Function Points (Effective):" << std::endl;
    // Imprimir los puntos objetivo
    for (size_t i = 0; i < targets.size(); ++i) {
        std::cout << "  f(" << final_x_values[i] << ") = " << targets[i] << std::endl;
    }
    std::cout << "----------------------------------------" << std::endl;
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
    std::cout << "Info: Running with GPU acceleration." << std::endl;
#else
    std::cout << "Info: Running with CPU acceleration." << std::endl;
#endif
    std::cout << "----------------------------------------" << std::endl;
    std::cout << "Parameters:" << std::endl;
    // Imprimir los parámetros globales definidos en Globals.h
    std::cout << "  Total Population: " << TOTAL_POPULATION_SIZE << std::endl;
    std::cout << "  Generations: " << GENERATIONS << std::endl;
    std::cout << "  Islands: " << NUM_ISLANDS << std::endl;
    std::cout << "  Migration Interval: " << MIGRATION_INTERVAL << std::endl;
    std::cout << "  Migration Size: " << MIGRATION_SIZE << std::endl;
    // --- NOMBRES CORREGIDOS ---
    std::cout << "  Mutation Rate (Initial): " << BASE_MUTATION_RATE << std::endl; // <-- Nombre corregido
    std::cout << "  Elite Percentage (Initial): " << BASE_ELITE_PERCENTAGE << std::endl; // <-- Nombre corregido
    // --------------------------
    std::cout << "----------------------------------------" << std::endl;


    try {
        // Crear la instancia del Algoritmo Genético
        // Pasa las referencias a los vectores de datos y los parámetros principales
        GeneticAlgorithm ga(targets, final_x_values, TOTAL_POPULATION_SIZE, GENERATIONS, seed_formulas);

        // Ejecutar el algoritmo
        // La función run() contiene el bucle principal de generaciones y devuelve el mejor árbol encontrado
        NodePtr best_solution_tree = ga.run();

        // La función run() ya imprime el resumen final y la verificación.
        // Comprobar si se encontró alguna solución válida al final
        if (!best_solution_tree) {
            std::cerr << "\nFailed to find any valid solution." << std::endl;
            return 1; // Salir con código de error si no se encontró solución
        }
    } catch (const std::exception& e) {
        std::cerr << "[CRITICAL ERROR] Exception caught in main: " << e.what() << std::endl;
        return 2;
    } catch (...) {
        std::cerr << "[CRITICAL ERROR] Unknown exception caught in main." << std::endl;
        return 3;
    }

    return 0; // Salir con éxito
}


In [None]:
%%writefile Code/src/Globals.h
#ifndef GLOBALS_H
#define GLOBALS_H

#include <vector>
#include <random>
#include <string>
#include <limits>
#include <cmath>

// ============================================================
//                  GLOBAL PARAMETERS
// ============================================================

// ----------------------------------------
// Problem Data (Symbolic Regression)
// ----------------------------------------
// X_VALUES holds N and RAW_TARGETS holds the raw Q(N) value for the N at the
// same index. Per the original note, N < 4 was filtered out because Q(N) is 0
// or irrelevant there.
// NOTE(review): the values look like N-Queens solution counts - confirm.
//
// The effective targets are generated at runtime: when no --data file is
// given and USE_LOG_TRANSFORMATION is true, main.cpp fits ln(Q(N)) instead of
// Q(N) to flatten the exponential growth.
const std::vector<double> RAW_TARGETS = {2, 10, 4, 40, 92, 352, 724, 2680, 14200, 73712, 365596, 2279184, 14772512, 95815104, 666090624, 4968057848, 39029188884};
const std::vector<double> X_VALUES = {4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};

// Flag that enables the automatic logarithmic transformation
const bool USE_LOG_TRANSFORMATION = true;

// ----------------------------------------
// General Genetic Algorithm Configuration
// ----------------------------------------
// Controls whether GPU acceleration is used.
// FORCE_CPU_MODE: when true, the CPU is used even if CUDA is available
// (useful for comparing performance).
const bool FORCE_CPU_MODE = true;  // Currently forcing CPU mode; set to 'false' to allow the GPU

// USE_GPU_ACCELERATION_DEFINED_BY_CMAKE is defined by CMake when CUDA is found.
// USE_GPU_ACCELERATION is true only when that macro is defined AND
// FORCE_CPU_MODE is false; otherwise it is false.
#ifdef USE_GPU_ACCELERATION_DEFINED_BY_CMAKE
const bool USE_GPU_ACCELERATION = !FORCE_CPU_MODE;
#else
const bool USE_GPU_ACCELERATION = false;
#endif
// Sizing notes from the original author: population and generation counts are
// kept large to keep a GPU busy, the population size was tuned for a GPU with
// 4GB of VRAM (RTX 3050), and the island count was raised to increase
// parallelism while keeping the total population unchanged.
const int TOTAL_POPULATION_SIZE = 50000; // Tuned for 4GB of VRAM
const int GENERATIONS = 500000;           // Kept high
const int NUM_ISLANDS = 10;               // Raised for more parallelism
const int MIN_POP_PER_ISLAND = 10;        // Low enough to allow many islands with a minimum population

// --- Initial Formula ---
const bool USE_INITIAL_FORMULA = true; // Set to 'true' to inject the formula below
const std::string INITIAL_FORMULA_STRING = "(g(x)-(x*0.912079)+0.146743+(3.78968/x))";

// ----------------------------------------
// Island Model Parameters
// ----------------------------------------
// Migration interval and size were increased so that islands do more work in
// parallel before exchanging individuals, reducing communication overhead.
const int MIGRATION_INTERVAL = 100; // Raised to allow more work per island between migrations
const int MIGRATION_SIZE = 50;      // Raised for a more substantial migration

// ----------------------------------------
// Initial Tree Generation Parameters
// ----------------------------------------
const int MAX_TREE_DEPTH_INITIAL = 8; // Original note: reduced for simpler, faster initial formulas
// NOTE(review): exact meaning of this probability is not visible here;
// confirm against generate_random_tree.
const double TERMINAL_VS_VARIABLE_PROB = 0.75;
const double CONSTANT_MIN_VALUE = -10.0;
const double CONSTANT_MAX_VALUE = 10.0;
const int CONSTANT_INT_MIN_VALUE = -10;
const int CONSTANT_INT_MAX_VALUE = 10;
const bool USE_HARD_DEPTH_LIMIT = true; // Toggle for hard depth limit
const int MAX_TREE_DEPTH_HARD_LIMIT = 12; // Hard limit to prevent bloat
// ----------------------------------------
// Genetic Operator Parameters (Operator Configuration)
// ----------------------------------------
// One enable flag per operator symbol.
const bool USE_OP_PLUS     = true; // +
const bool USE_OP_MINUS    = true; // -
const bool USE_OP_MULT     = true; // *
const bool USE_OP_DIV      = true; // /
const bool USE_OP_POW      = true; // ^
const bool USE_OP_MOD      = true; // %
const bool USE_OP_SIN      = true; // s
const bool USE_OP_COS      = true; // c
const bool USE_OP_LOG      = true; // l
const bool USE_OP_EXP      = true; // e
const bool USE_OP_FACT     = true; // !
const bool USE_OP_FLOOR    = true; // _
const bool USE_OP_GAMMA    = true; // g

// Order: +, -, *, /, ^, %, s, c, l, e, !, _, g
// This order must match the operator symbol tables indexed by a draw from
// these weights (see mutate_tree).
// Each weight is multiplied by its flag (0 or 1) to enable/disable the operator.
const std::vector<double> OPERATOR_WEIGHTS = {
    0.10 * (USE_OP_PLUS  ? 1.0 : 0.0), // +
    0.15 * (USE_OP_MINUS ? 1.0 : 0.0), // -
    0.10 * (USE_OP_MULT  ? 1.0 : 0.0), // *
    0.10 * (USE_OP_DIV   ? 1.0 : 0.0), // /
    0.05 * (USE_OP_POW   ? 1.0 : 0.0), // ^
    0.01 * (USE_OP_MOD   ? 1.0 : 0.0), // %
    0.01 * (USE_OP_SIN   ? 1.0 : 0.0), // s
    0.01 * (USE_OP_COS   ? 1.0 : 0.0), // c
    0.15 * (USE_OP_LOG   ? 1.0 : 0.0), // l
    0.02 * (USE_OP_EXP   ? 1.0 : 0.0), // e
    0.05 * (USE_OP_FACT  ? 1.0 : 0.0), // !
    0.05 * (USE_OP_FLOOR ? 1.0 : 0.0), // _
    0.20 * (USE_OP_GAMMA ? 1.0 : 0.0)  // g
};

// ----------------------------------------
// Genetic Operator Parameters (Mutation, Crossover, Selection)
// ----------------------------------------
// The BASE_* / DEFAULT_* values are the defaults returned by
// EvolutionParameters::create_default().
const double BASE_MUTATION_RATE = 0.30;
const double BASE_ELITE_PERCENTAGE = 0.15;
const double DEFAULT_CROSSOVER_RATE = 0.85;
const int DEFAULT_TOURNAMENT_SIZE = 30;
const int MAX_TREE_DEPTH_MUTATION = 8; // Slight increase to allow complexity
// NodeInsertion mutation: probability that the new right operand is a constant
// (otherwise a variable), and the ranges the constant is drawn from.
const double MUTATE_INSERT_CONST_PROB = 0.6;
const int MUTATE_INSERT_CONST_INT_MIN = 1;
const int MUTATE_INSERT_CONST_INT_MAX = 5;
const double MUTATE_INSERT_CONST_FLOAT_MIN = 0.5;
const double MUTATE_INSERT_CONST_FLOAT_MAX = 5.0;

// ----------------------------------------
// Fitness and Evaluation Parameters
// ----------------------------------------
// History: an earlier revision lowered the complexity penalty to favor more
// complex formulas; it was later raised again to penalize bloat (Strategy 3).
const double COMPLEXITY_PENALTY_FACTOR = 0.05; // Was 0.005. Increased significantly to fight bloat.
const bool USE_RMSE_FITNESS = true;
const double FITNESS_ORIGINAL_POWER = 1.3;
const double FITNESS_PRECISION_THRESHOLD = 0.001;
const double FITNESS_PRECISION_BONUS = 0.0001;
// Two fitness values closer than this are treated as tied by tournament selection.
const double FITNESS_EQUALITY_TOLERANCE = 1e-9;
const double EXACT_SOLUTION_THRESHOLD = 1e-8;

// ----------------------------------------
// Weighted Fitness
// ----------------------------------------
// Enables weighted fitness to heavily penalize errors at high values of N.
// Per the original note, this eliminates parabolas that fail at N=20 but have
// a good overall average.
const bool USE_WEIGHTED_FITNESS = true;
// Original note: a "quadratic" weight uses i*i, an "exponential" weight uses
// exp(i*WEIGHTED_FITNESS_EXPONENT).
// NOTE(review): no weight-type switch is defined in this header; confirm which
// form the fitness code uses.
// Exponent for the exponential weight (more aggressive). Use 0.2-0.3 for small datasets.
const double WEIGHTED_FITNESS_EXPONENT = 0.25;

// ----------------------------------------
// Advanced Feature Parameters
// ----------------------------------------
const int STAGNATION_LIMIT_ISLAND = 50;
// Lowered from 5000 to allow faster early termination in Hybrid Search mode.
// If best fitness doesn't improve for N generations, terminate early.
const int GLOBAL_STAGNATION_LIMIT = 200;
const double STAGNATION_RANDOM_INJECT_PERCENT = 0.1;
const int PARAM_MUTATE_INTERVAL = 50;
const double PATTERN_RECORD_FITNESS_THRESHOLD = 10.0;
const int PATTERN_MEM_MIN_USES = 3;
const int PATTERN_INJECT_INTERVAL = 10;
const double PATTERN_INJECT_PERCENT = 0.05;
const size_t PARETO_MAX_FRONT_SIZE = 50;
// Constants whose magnitude falls below this tolerance are snapped to 0.0 by mutation.
const double SIMPLIFY_NEAR_ZERO_TOLERANCE = 1e-9;
const double SIMPLIFY_NEAR_ONE_TOLERANCE = 1e-9;
const int LOCAL_SEARCH_ATTEMPTS = 30;
// Simplification Toggle
const bool USE_SIMPLIFICATION = true;
// Anti-Stagnation: Island Cataclysm (Hard Reset)
const bool USE_ISLAND_CATACLYSM = true;
// Selection Strategy: Epsilon-Lexicase Selection (Replaces Tournament)
const bool USE_LEXICASE_SELECTION = true;

// ----------------------------------------
// Otros Parámetros
// ----------------------------------------
const int PROGRESS_REPORT_INTERVAL = 100;
// Optimizaciones adicionales:
// Deshabilitamos las constantes enteras forzadas para permitir una mayor flexibilidad
// en las constantes generadas y mutadas, lo que podría conducir a mejores soluciones
// y mantener la GPU ocupada con un rango más amplio de valores.
const bool FORCE_INTEGER_CONSTANTS = false; // Mantenemos false para mayor flexibilidad

// ----------------------------------------
// Control de Duplicados
// ----------------------------------------
const bool PREVENT_DUPLICATES = true; // Activa la verificación de unicidad
const int DUPLICATE_RETRIES = 10;     // Intentos para generar un individuo único antes de rendirse


// ============================================================
//                  UTILIDADES GLOBALES
// ============================================================
std::mt19937& get_rng();
const double INF = std::numeric_limits<double>::infinity();

#endif // GLOBALS_H


In [None]:
%%writefile Code/CMakeLists.txt

cmake_minimum_required(VERSION 3.10)
project(SymbolicRegressionGP)

# Require C++17 outright; without *_REQUIRED CMake may silently fall back to an older standard.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Single source list shared by the GPU and CPU-only builds (it was duplicated per branch before).
# NOTE(review): FitnessGPU.cu stays listed for CPU-only builds exactly as before; with the CUDA
# language not enabled CMake presumably does not compile it at all -- confirm that no symbol
# needed by the CPU build lives in that file.
set(SOURCE_FILES
    src/main.cpp
    src/GeneticAlgorithm.cpp
    src/ExpressionTree.cpp
    src/GeneticOperators.cpp
    src/Fitness.cpp
    src/FitnessGPU.cu
    src/AdvancedFeatures.cpp
)

find_package(CUDA)

if(CUDA_FOUND)
    # Tell the C++ sources that the GPU code path is available.
    add_definitions(-DUSE_GPU_ACCELERATION_DEFINED_BY_CMAKE)
    enable_language(CUDA)
    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -arch=sm_75") # T4 is sm_75
else()
    message(WARNING "CUDA not found. Compiling for CPU only.")
endif()

add_executable(SymbolicRegressionGP ${SOURCE_FILES})

if(CUDA_FOUND)
    set_target_properties(SymbolicRegressionGP PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
    target_link_libraries(SymbolicRegressionGP ${CUDA_LIBRARIES})
else()
    target_link_libraries(SymbolicRegressionGP pthread)
endif()


In [None]:
# Compile C++ Engine
%cd Code
!cmake -B build -S . -DCMAKE_BUILD_TYPE=Release
!cmake --build build -j $(nproc)
%cd ..

In [None]:
%%writefile AlphaSymbolic/core/grammar.py
import numpy as np
from scipy.special import gamma as scipy_gamma, gammaln
import math

# Supported operators and their arity (number of arguments).
# Organized by curriculum stage for progressive unlocking.
# Insertion order matters: VOCABULARY below is built from these keys in order, so the
# position of each entry determines its token ID (e.g. '+' is ID 0).
OPERATORS = {
    # === STAGE 0: Pure Arithmetic ===
    '+': 2,
    '-': 2,
    '*': 2,
    '/': 2,
    
    # === STAGE 1: Powers ===
    'pow': 2,
    'sqrt': 1,
    
    # === STAGE 2: Trigonometry ===
    'sin': 1,
    'cos': 1,
    'tan': 1,
    
    # === STAGE 3: Transcendental ===
    'exp': 1,
    'log': 1,
    
    # === STAGE 4: Advanced ===
    'abs': 1,
    'neg': 1,
    'sign': 1,
    'floor': 1,
    'ceil': 1,
    'mod': 2,
    'gamma': 1,
    'lgamma': 1,  # Log-gamma function (from C++ GP engine)
}

# Operator groups for curriculum control: stage number -> operators available at that stage.
# Each stage is a superset of the previous one.
OPERATOR_STAGES = {
    0: ['+', '-', '*', '/'],
    1: ['+', '-', '*', '/', 'pow', 'sqrt'],
    2: ['+', '-', '*', '/', 'pow', 'sqrt', 'sin', 'cos', 'tan'],
    3: ['+', '-', '*', '/', 'pow', 'sqrt', 'sin', 'cos', 'tan', 'exp', 'log'],
    4: list(OPERATORS.keys()),  # All operators
}

# Terminal tokens
VARIABLES = ['x']
# 'C' is a placeholder for learnable constants; the remaining entries are literal values.
CONSTANTS = ['C', '0', '1', '2', '3', '5', '10', 'pi', 'e']

# Full vocabulary: operators first, then variables, then constants.
VOCABULARY = list(OPERATORS.keys()) + VARIABLES + CONSTANTS
# Bidirectional lookup between tokens and their integer IDs (index in VOCABULARY).
TOKEN_TO_ID = {token: i for i, token in enumerate(VOCABULARY)}
ID_TO_TOKEN = {i: token for token, i in TOKEN_TO_ID.items()}

# Special sequence markers: start of sequence, end of sequence, padding.
# NOTE(review): these markers are not part of VOCABULARY, so they have no entry in
# TOKEN_TO_ID / ID_TO_TOKEN here -- confirm how callers map them to IDs.
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'
PAD_TOKEN = '<PAD>'

class Node:
    """A single node of an expression tree.

    Attributes:
        value: The token stored at this node (operator, variable or constant).
        children: Child nodes in argument order; empty for terminals.
    """

    def __init__(self, value, children=None):
        self.value = value
        # A missing (or empty) child list always becomes a fresh list.
        self.children = children or []

    def __repr__(self):
        """Return the node in parenthesized prefix form, e.g. ``(+ x 1)``."""
        if self.children:
            inner = " ".join(str(child) for child in self.children)
            return f"({self.value} {inner})"
        return str(self.value)

    def to_infix(self):
        """Render the subtree as a fully parenthesized infix string.

        Unary operators are rendered as function calls, binary operators as
        ``(left op right)`` with ``pow`` shown as ``^`` and ``mod`` as ``%``.
        Any other child count falls back to the bare node value.
        """
        arity = len(self.children)
        if arity == 1:
            (only,) = self.children
            return f"{self.value}({only.to_infix()})"
        if arity == 2:
            lhs, rhs = (child.to_infix() for child in self.children)
            if self.value == 'pow':
                symbol = '^'
            elif self.value == 'mod':
                symbol = '%'
            else:
                symbol = self.value
            return f"({lhs} {symbol} {rhs})"
        return str(self.value)

    def count_constants(self):
        """Count the number of 'C' placeholders in the tree."""
        own = 1 if self.value == 'C' else 0
        return own + sum(child.count_constants() for child in self.children)

    def get_constant_positions(self, path=None):
        """Returns a list of paths to all 'C' nodes for optimization.

        A path is the list of child indices leading from this node to the
        placeholder; the node itself is addressed by the empty path.
        """
        here = [] if path is None else path
        found = [here.copy()] if self.value == 'C' else []
        for index, child in enumerate(self.children):
            found.extend(child.get_constant_positions(here + [index]))
        return found


import ast

class ExpressionTree:
    """Expression tree built from a token list in prefix (pre-order) notation.

    Attributes:
        tokens: The token list handed to the constructor.
        root: Root ``Node`` of the parsed tree, or ``None`` when parsing failed.
        is_valid: ``True`` only when the tokens form exactly one complete tree.
    """

    # Python AST binary operators -> grammar tokens. BitXor is kept so that an AST that
    # was built from a raw '^' (outside of from_infix) still maps to 'pow'.
    _BINOP_TOKENS = {
        ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/',
        ast.BitXor: 'pow', ast.Pow: 'pow', ast.Mod: 'mod',
    }

    def __init__(self, token_list):
        """
        Parses a list of tokens in Pre-order traversal (Prefix notation)
        Example: ['+', 'x', 'sin', 'x'] -> x + sin(x)

        Parsing never raises: any failure (unknown token, missing operand,
        leftover tokens, empty list) leaves ``root`` as ``None`` and
        ``is_valid`` as ``False``.
        """
        self.tokens = token_list
        try:
            self.root, remaining = self._build_tree(token_list)
            if remaining:
                raise ValueError("Tokens remained after building tree")
            self.is_valid = True
        except Exception:
            # Deliberate best-effort: callers check is_valid instead of catching.
            self.root = None
            self.is_valid = False

    @classmethod
    def from_infix(cls, infix_str):
        """
        Creates an ExpressionTree from a standard infix string (e.g. "sin(x) + x^2").
        Uses Python's ast to parse.

        Two rewrites make the string valid Python first:
          1. Postfix factorial ``operand!`` (as emitted by the C++ engine) becomes
             ``gamma(operand)``.
          2. ``^`` becomes ``**``. Python would otherwise parse ``^`` as bitwise XOR,
             which binds more loosely than ``+`` and ``*`` and therefore mis-groups
             any power that is not fully parenthesized.

        Returns an invalid tree (``is_valid == False``) and prints a diagnostic
        when the string cannot be parsed.
        """
        processed_str = cls._rewrite_factorials(infix_str)
        # ast.parse in 'eval' mode rejects leading indentation, so trim whitespace.
        processed_str = processed_str.replace('^', '**').strip()

        try:
            tree = ast.parse(processed_str, mode='eval')
            tokens = cls._ast_to_prefix(tree.body)
            return cls(tokens)
        except Exception as e:
            print(f"Error parsing infix: {e} | Original: {infix_str} | Processed: {processed_str}")
            return cls([]) # Invalid

    @staticmethod
    def _rewrite_factorials(infix_str):
        """Rewrite every postfix factorial ``operand!`` as ``gamma(operand)``.

        A '!' without a recognizable operand in front of it is dropped.
        """
        text = infix_str
        while '!' in text:
            bang = text.find('!')
            # Allow whitespace between the operand and the '!'.
            end = bang
            while end > 0 and text[end - 1].isspace():
                end -= 1
            start = ExpressionTree._factorial_operand_start(text, end)
            if start is None:
                text = text[:bang] + text[bang + 1:]
            else:
                text = text[:start] + "gamma(" + text[start:end] + ")" + text[bang + 1:]
        return text

    @staticmethod
    def _factorial_operand_start(text, end):
        """Return the index where the factorial operand ending at ``end`` begins.

        The operand is either a parenthesized group (including a function name
        written directly in front of it, so ``sin(x)!`` wraps the whole call) or a
        bare identifier / number. Returns ``None`` when no operand is found.
        """
        if end <= 0:
            return None
        pos = end - 1
        if text[pos] == ')':
            # Walk backwards to the parenthesis that matches the closing one.
            depth = 0
            while pos >= 0:
                ch = text[pos]
                if ch == ')':
                    depth += 1
                elif ch == '(':
                    depth -= 1
                    if depth == 0:
                        break
                pos -= 1
            if pos < 0:
                return None  # Unbalanced parentheses.
            while pos > 0 and (text[pos - 1].isalnum() or text[pos - 1] == '_'):
                pos -= 1
            return pos
        if text[pos].isalnum() or text[pos] in '._':
            while pos > 0 and (text[pos - 1].isalnum() or text[pos - 1] in '._'):
                pos -= 1
            return pos
        return None

    @staticmethod
    def _ast_to_prefix(node):
        """Convert a Python AST expression node into a prefix token list.

        Raises:
            ValueError: for any construct that has no grammar equivalent.
        """
        if isinstance(node, ast.BinOp):
            op_token = ExpressionTree._BINOP_TOKENS.get(type(node.op))
            if op_token is not None:
                return ([op_token]
                        + ExpressionTree._ast_to_prefix(node.left)
                        + ExpressionTree._ast_to_prefix(node.right))

        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                # Collapse a negated numeric literal into a single token such as "-5".
                if isinstance(node.operand, ast.Constant) and isinstance(node.operand.value, (int, float)):
                    return [str(-node.operand.value)]
                return ['neg'] + ExpressionTree._ast_to_prefix(node.operand)
            if isinstance(node.op, ast.UAdd):
                # Unary plus is a no-op.
                return ExpressionTree._ast_to_prefix(node.operand)

        elif isinstance(node, ast.Call):
            # Function-call syntax is accepted for every named operator of the grammar
            # with the matching number of arguments. This includes 'neg' and 'sign',
            # which Node.to_infix emits as calls, as well as pow(a, b) / mod(a, b).
            func = node.func
            if (isinstance(func, ast.Name) and func.id in OPERATORS
                    and not node.keywords and len(node.args) == OPERATORS[func.id]):
                tokens = [func.id]
                for arg in node.args:
                    tokens.extend(ExpressionTree._ast_to_prefix(arg))
                return tokens

        elif isinstance(node, ast.Name):
            return [node.id]

        elif isinstance(node, ast.Constant):
            return [str(node.value)]

        raise ValueError(f"Unsupported AST node: {node}")


    def _build_tree(self, tokens):
        """Recursively consume prefix tokens and build one subtree.

        Returns:
            ``(node, remaining)``: the subtree and the tokens left unconsumed.

        Raises:
            ValueError: on an empty token list or an unknown token.
        """
        if not tokens:
            raise ValueError("Empty token list")

        token = tokens[0]
        remaining = tokens[1:]

        if token in OPERATORS:
            children = []
            for _ in range(OPERATORS[token]):
                child, remaining = self._build_tree(remaining)
                children.append(child)
            return Node(token, children), remaining
        if token in VARIABLES or token in CONSTANTS:
            return Node(token), remaining

        # Anything else must be a numeric literal.
        try:
            float(token)
        except (TypeError, ValueError, OverflowError):
            raise ValueError(f"Unknown token: {token}") from None
        return Node(token), remaining

    def evaluate(self, x_values, constants=None):
        """
        Evaluates the expression tree for a given array of x values.
        constants: optional dict mapping path tuples to constant values
        Returns a numpy array of results (all NaN when the tree is invalid).
        """
        # Ensure x_values is a numpy array
        if not isinstance(x_values, np.ndarray):
            x_values = np.array(x_values, dtype=np.float64)

        if not self.is_valid:
            return np.full_like(x_values, np.nan, dtype=np.float64)
        return self._eval_node(self.root, x_values, constants, path=[])

    def _eval_node(self, node, x, constants=None, path=None):
        """Evaluate ``node`` element-wise over ``x`` using protected operators.

        ``path`` is the list of child indices from the root to ``node``; it is the
        key (as a tuple) under which an optimized value for a 'C' placeholder is
        looked up in ``constants``.
        """
        val = node.value

        if val == 'x':
            return x.astype(np.float64)
        if val == 'pi':
            return np.full_like(x, np.pi, dtype=np.float64)
        if val == 'e':
            return np.full_like(x, np.e, dtype=np.float64)
        if val == 'C':
            # Check if we have an optimized constant for this position
            if constants is not None and tuple(path) in constants:
                return np.full_like(x, constants[tuple(path)], dtype=np.float64)
            return np.full_like(x, 1.0, dtype=np.float64)  # Default constant = 1

        # Numeric literal tokens evaluate to themselves; operator names are not
        # parseable as floats and fall through to the dispatch below.
        try:
            return np.full_like(x, float(val), dtype=np.float64)
        except (TypeError, ValueError, OverflowError):
            pass

        # Recursive evaluation
        args = []
        for i, c in enumerate(node.children):
            args.append(self._eval_node(c, x, constants, path + [i] if path is not None else None))

        # Operators. Numerical warnings are silenced because the protected variants
        # below are expected to hit division by zero, overflow, etc.
        with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
            if val == '+': return args[0] + args[1]
            if val == '-': return args[0] - args[1]
            if val == '*': return args[0] * args[1]
            if val == '/':
                # Protected division: yields 0 wherever the divisor is 0.
                return np.divide(args[0], args[1], out=np.zeros_like(x, dtype=np.float64), where=args[1]!=0)
            if val == 'pow':
                # Safe power: positive base, exponent limited to [-10, 10].
                return np.power(np.abs(args[0]) + 1e-10, np.clip(args[1], -10, 10))
            if val == 'mod':
                return np.mod(args[0], args[1] + 1e-10)
            if val == 'sin': return np.sin(args[0])
            if val == 'cos': return np.cos(args[0])
            if val == 'tan': return np.tan(args[0])
            if val == 'exp':
                return np.exp(np.clip(args[0], -100, 100))
            if val == 'log':
                return np.log(np.abs(args[0]) + 1e-10)
            if val == 'sqrt':
                return np.sqrt(np.abs(args[0]))
            if val == 'abs':
                return np.abs(args[0])
            if val == 'floor':
                return np.floor(args[0])
            if val == 'ceil':
                return np.ceil(args[0])
            if val == 'gamma':
                # Match C++ Protected Gamma/Factorial: tgamma(|x| + 1)
                # This ensures consistent evaluation for formulas from C++ engine (which uses !)
                arg = np.abs(args[0]) + 1.0
                clipped = np.clip(arg, 0.1, 50) # Clip upper bound to avoid overflow
                return scipy_gamma(clipped)
            if val == 'lgamma':
                # Protected lgamma: lgamma(|x| + 1)
                arg = np.abs(args[0]) + 1.0
                # gammaln is safe for large positive numbers, so less aggressive clipping needed for overflow,
                # but we clip for consistency and to avoid extremely large outputs if followed by exp
                clipped = np.clip(arg, 0.1, 1000)
                return gammaln(clipped)
            if val == 'neg':
                return -args[0]
            if val == 'sign':
                return np.sign(args[0])

        # Unknown operator: evaluate to zeros rather than failing.
        return np.zeros_like(x, dtype=np.float64)

    def get_infix(self):
        """Return the infix rendering of the tree, or "Invalid" for an invalid tree."""
        if not self.is_valid:
            return "Invalid"
        return self.root.to_infix()


    def count_constants(self):
        """Return the number of 'C' placeholders (0 for an invalid tree)."""
        if not self.is_valid:
            return 0
        return self.root.count_constants()

import sympy

def simplify_formula(formula_str):
    """
    Simplifies a mathematical formula using SymPy.

    The only pre-processing is renaming ``pow(`` calls to SymPy's ``Pow(``; the
    string is then parsed with ``sympy.sympify``, reduced with ``sympy.simplify``
    and returned as ``str(...)`` of the result (SymPy's default printer, which
    writes powers as ``**``).

    If anything fails along the way (for example a construct SymPy cannot
    parse) the input string is returned unchanged.

    NOTE(review): ``sympify`` is documented to evaluate its input with ``eval``;
    only pass formula strings from trusted sources.
    NOTE(review): SymPy applies the exact mathematical meaning of sqrt/log/gamma,
    whereas ExpressionTree evaluates protected variants -- confirm that a
    simplified formula is re-scored before it replaces the original.
    """
    try:
        sympy_ready = formula_str.replace("pow(", "Pow(")
        return str(sympy.simplify(sympy.sympify(sympy_ready)))
    except Exception:
        # Fallback if simplification fails (e.g. unknown functions)
        return formula_str


In [None]:
%%writefile AlphaSymbolic/core/model.py
import torch
import torch.nn as nn
import numpy as np

class AlphaSymbolicModel(nn.Module):
    """Transformer that maps a set of (x, y) points plus a partial formula to
    next-token logits and a 3-quantile value estimate.

    The points are embedded and passed through a Transformer encoder; its output
    is the memory that a causal Transformer decoder attends to while processing
    the formula tokens.
    """

    def __init__(self, vocab_size, d_model=128, nhead=4, num_encoder_layers=2, num_decoder_layers=2, max_seq_len=50):
        """
        vocab_size: number of token IDs (size of the embedding table and of the policy output)
        d_model: width of all embeddings and Transformer layers
        nhead: attention heads per Transformer layer
        num_encoder_layers / num_decoder_layers: depth of the point encoder / formula decoder
        max_seq_len: number of positions pre-computed by the positional encoding
        """
        super(AlphaSymbolicModel, self).__init__()
        
        self.d_model = d_model
        
        # 1. Point Encoder: Processes pairs of (x, y)
        # Input dim: 2 (x value, y value)
        self.point_embedding = nn.Linear(2, d_model)
        
        # We use a standard Transformer Encoder for the "Problem Embedding"
        # Since points are a set, we don't necessarily need positional encoding, 
        # but the Transformer will process them as a sequence.
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.problem_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        # 2. Formula Decoder: Generates tokens
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.formula_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        
        # 3. Heads
        # Policy head: one logit per vocabulary entry at every sequence position.
        self.policy_head = nn.Linear(d_model, vocab_size)
        # Value head: three outputs, one per quantile.
        self.value_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 3) # Quantiles: 0.25, 0.50, 0.75
        )
        
    def forward(self, x_values, y_values, formula_input, formula_mask=None):
        """
        x_values: [batch, num_points]
        y_values: [batch, num_points]
        formula_input: [batch, seq_len] (Token IDs)
        formula_mask: Optional mask for the decoder (causal mask); a square
            subsequent (causal) mask is generated when omitted.

        Returns (logits, value):
            logits: [batch, seq_len, vocab_size]
            value:  [batch, 3], one estimate per quantile (0.25, 0.50, 0.75)
        """
        # The unpacking requires x_values to be 2-D; the two names are not used afterwards.
        batch_size, num_points = x_values.shape
        
        # -- Problem Encoding --
        # Stack x and y: [batch, num_points, 2]
        points = torch.stack([x_values, y_values], dim=2)
        
        # Project to d_model
        points_emb = self.point_embedding(points) # [batch, num_points, d_model]
        
        # Encode problem (memory for decoder)
        memory = self.problem_encoder(points_emb)
        
        # -- Formula Decoding --
        # Embed tokens
        tgt = self.token_embedding(formula_input) # [batch, seq_len, d_model]
        tgt = self.pos_encoder(tgt)
        
        # Decode
        # memory is [batch, num_points, d_model]
        # tgt is [batch, seq_len, d_model]
        if formula_mask is None:
             # Create causal mask
            seq_len = formula_input.size(1)
            formula_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(formula_input.device)

        output = self.formula_decoder(tgt, memory, tgt_mask=formula_mask)
        
        # -- Heads --
        # Policy: distribution over vocab for each token position
        logits = self.policy_head(output) # [batch, seq_len, vocab_size]
        
        # Value: estimate value from the LAST token's state
        # (Assuming the last token summarizes the current state)
        # NOTE(review): if formula_input is right-padded, position -1 is a padding
        # position rather than the last real token -- confirm against callers.
        last_token_output = output[:, -1, :] # [batch, d_model]
        value = self.value_head(last_token_output) # [batch, 3] (one value per quantile)
        
        return logits, value

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first sequence tensor.

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. The table is pre-computed for ``max_len`` positions
    and stored as the non-trainable buffer ``pe`` of shape ``[1, max_len, d_model]``.

    Args:
        d_model: Feature width of the tensors that will be encoded (even or odd).
        max_len: Number of positions to pre-compute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns, so the
        # frequency table is trimmed; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``seq_len`` positions.

        x: [batch, seq_len, d_model]
        """
        return x + self.pe[:, :x.size(1), :]

if __name__ == "__main__":
    # Smoke Test: run one forward pass on random data and print the output shapes.
    vocab_size = 20
    model = AlphaSymbolicModel(vocab_size=vocab_size, d_model=32)
    
    # Dummy data
    bs = 2
    points = 10
    x = torch.randn(bs, points)
    y = torch.randn(bs, points)
    
    # Formula input: random token IDs of length 5
    seq = torch.randint(0, vocab_size, (bs, 5))
    
    logits, value = model(x, y, seq)
    
    print("Logits shape:", logits.shape) # Should be [2, 5, 20]
    print("Value shape:", value.shape)   # Should be [2, 3] (one value per quantile)
    print("Smoke test passed.")


In [None]:
%%writefile AlphaSymbolic/core/environment.py
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from core.grammar import VOCABULARY, OPERATORS, TOKEN_TO_ID, ExpressionTree
from data.synthetic_data import DataGenerator

class SymbolicEnv(gym.Env):
    """Gymnasium environment in which an agent writes a formula token by token.

    Each action appends one vocabulary token to a prefix-notation sequence. The
    episode terminates when the expression tree is complete; the reward is then
    the negative RMSE between the formula's predictions and the target ``y``.

    Args:
        max_length: Maximum number of tokens before the episode is truncated.
        num_points: Number of (x, y) sample points per problem.
    """

    def __init__(self, max_length=50, num_points=10):
        super(SymbolicEnv, self).__init__()

        self.vocab_size = len(VOCABULARY)
        self.max_length = max_length
        self.num_points = num_points
        self.vocab = VOCABULARY

        # Action space: Choose a token from the vocabulary
        self.action_space = spaces.Discrete(self.vocab_size)

        # Observation space (dictionary):
        # 1. Current token sequence as IDs, zero-padded to max_length
        # 2. X values (fixed size for simplicity)
        # 3. Y values
        # Box bounds are inclusive, so the largest valid token ID is vocab_size - 1.
        # NOTE(review): padding uses 0, which is also the ID of the first vocabulary
        # token -- confirm that consumers can tell padding from that token.
        self.observation_space = spaces.Dict({
            "sequence": spaces.Box(low=0, high=self.vocab_size - 1, shape=(max_length,), dtype=np.int32),
            "x": spaces.Box(low=-np.inf, high=np.inf, shape=(num_points,), dtype=np.float32),
            "y": spaces.Box(low=-np.inf, high=np.inf, shape=(num_points,), dtype=np.float32)
        })

        self.data_gen = DataGenerator(max_depth=4)
        self.current_problem = None
        self.current_sequence = []
        # Number of operand slots that still have to be filled to complete the tree.
        self.open_branches = 0

    def reset(self, seed=None, options=None):
        """Start a new episode on a freshly generated problem.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.

        NOTE(review): ``seed`` is only forwarded to ``gym.Env.reset``; whether it
        also makes the generated problem reproducible depends on which random
        source DataGenerator uses -- confirm.
        """
        super().reset(seed=seed)

        # Generate a new problem (X, Y)
        # In a real scenario, this could be sampled from a fixed dataset
        batch = self.data_gen.generate_batch(1, point_count=self.num_points)
        self.current_problem = batch[0]

        self.current_sequence = []
        self.open_branches = 1 # Start expecting a root node

        return self._get_obs(), {}

    def step(self, action_id):
        """Append the token for ``action_id`` and report the outcome.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The reward is
            0.0 for intermediate steps, the formula's reward when the tree is
            completed, -100.0 on a syntax error and -10.0 when ``max_length`` is
            reached with the tree still incomplete.
        """
        token = self.vocab[action_id]
        self.current_sequence.append(token)

        # An operator fills one open slot and opens `arity` new ones;
        # a terminal only fills one.
        if token in OPERATORS:
            self.open_branches += OPERATORS[token] - 1
        else:
            self.open_branches -= 1

        term = False
        trunc = False
        reward = 0.0

        if self.open_branches == 0:
            # Tree is complete, evaluate
            term = True
            reward = self._calculate_reward()
        elif self.open_branches < 0:
            # Should not happen if we mask actions, but for safety
            term = True
            reward = -100.0 # Syntax error penalty
        elif len(self.current_sequence) >= self.max_length:
            trunc = True
            reward = -10.0 # Incomplete penalty

        return self._get_obs(), reward, term, trunc, {}

    def _get_obs(self):
        """Build the observation dict from the current sequence and problem."""
        # Convert sequence to IDs and pad with zeros up to max_length
        seq_ids = [TOKEN_TO_ID[t] for t in self.current_sequence]
        padded_seq = np.zeros(self.max_length, dtype=np.int32)
        padded_seq[:len(seq_ids)] = seq_ids

        return {
            "sequence": padded_seq,
            "x": self.current_problem['x'].astype(np.float32),
            "y": self.current_problem['y'].astype(np.float32)
        }

    def _calculate_reward(self):
        """Score the finished sequence.

        Returns the negative RMSE against the target values, -100.0 when the
        sequence is not a valid tree or evaluation raises, and -1000.0 when the
        RMSE is NaN or infinite.
        """
        try:
            tree = ExpressionTree(self.current_sequence)
            if not tree.is_valid:
                return -100.0

            y_pred = tree.evaluate(self.current_problem['x'])

            # Root Mean Squared Error (RMSE)
            mse = np.mean((y_pred - self.current_problem['y'])**2)
            rmse = np.sqrt(mse)

            if np.isnan(rmse) or np.isinf(rmse):
                return -1000.0

            # Maximizing the reward minimizes the RMSE.
            return -rmse

        except Exception:
            return -100.0

if __name__ == "__main__":
    # Manual demo: reset the environment, then feed it the prefix program "+ x x" (x + x).
    env = SymbolicEnv()
    obs, _ = env.reset()
    print("Initial Observation Keys:", obs.keys())

    tot_reward = 0
    for tok in ('+', 'x', 'x'):
        obs, reward, term, trunc, _ = env.step(TOKEN_TO_ID[tok])
        print(f"Action: {tok}, Reward: {reward}, Term: {term}, Branches: {env.open_branches}")
        tot_reward += reward
        if term:
            break

    print(f"Total Reward: {tot_reward}")


In [None]:
%%writefile AlphaSymbolic/core/loss.py

import torch
import torch.nn as nn

class QuantileLoss(nn.Module):
    """
    Quantile Loss (Pinball Loss) for multiple quantiles.

    The loss of each quantile is averaged over the batch and the per-quantile
    losses are summed.

    Args:
        quantiles (sequence): Quantiles to estimate (e.g. [0.25, 0.5, 0.75]).
            The sequence is copied, so later changes to the caller's object do
            not affect the loss.
    """
    def __init__(self, quantiles=(0.25, 0.5, 0.75)):
        super().__init__()
        self.quantiles = list(quantiles)

    def forward(self, preds, target):
        """
        preds: [batch, num_quantiles] - Predicted values for each quantile
        target: [batch] or [batch, 1] - True scalar target

        Raises:
            ValueError: if ``preds`` does not provide one column per quantile
                (the loss would otherwise silently evaluate to NaN).
        """
        # Ensure target matches batch dim
        if target.dim() == 1:
            target = target.unsqueeze(1)

        if preds.dim() < 2 or preds.size(1) != len(self.quantiles):
            raise ValueError(
                f"preds must have shape [batch, {len(self.quantiles)}], got {tuple(preds.shape)}"
            )

        loss = 0
        for i, q in enumerate(self.quantiles):
            error = target - preds[:, i:i+1]
            # Pinball loss: max(q * error, (q - 1) * error)
            # Equivalent to: error * (q - I(error < 0))
            loss += torch.max(q * error, (q - 1) * error).mean()

        return loss


In [None]:
%%writefile AlphaSymbolic/core/gp_bridge.py
import os
import subprocess
import tempfile
import re
import time
from typing import List, Optional

class GPEngine:
    """Thin wrapper that runs the compiled C++ GP engine as a subprocess.

    Attributes:
        binary_path: Path of the executable that ``run`` will invoke.
    """

    def __init__(self, binary_path=None):
        """
        binary_path: explicit path to the engine executable. When omitted, the
            usual build output locations under ``<base>/Code/build`` are probed,
            where ``<base>`` is three directory levels above this file
            (i.e. the parent of the ``AlphaSymbolic`` package directory).
        """
        if binary_path is None:
            base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            possible_paths = [
                # Windows builds
                os.path.join(base_dir, "Code", "build", "Release", "SymbolicRegressionGP.exe"),
                os.path.join(base_dir, "Code", "build", "SymbolicRegressionGP.exe"),
                # Linux/Mac support (no .exe)
                os.path.join(base_dir, "Code", "build", "Release", "SymbolicRegressionGP"),
                os.path.join(base_dir, "Code", "build", "SymbolicRegressionGP")
            ]
            self.binary_path = None
            for p in possible_paths:
                if os.path.exists(p):
                    self.binary_path = p
                    break

            if self.binary_path is None:
                # Nothing found: keep the first candidate so that the error
                # message printed by run() names a concrete path.
                self.binary_path = possible_paths[0]
        else:
            self.binary_path = binary_path

    @staticmethod
    def _to_text(stream):
        """Return captured process output as ``str``.

        ``subprocess.TimeoutExpired`` carries bytes (or ``None``) even when the
        process was started with ``text=True``, so both cases are normalized here.
        """
        if stream is None:
            return ""
        if isinstance(stream, bytes):
            return stream.decode("utf-8", errors="replace")
        return stream

    @staticmethod
    def _parse_best_formula(output):
        """Extract the best formula from the engine's stdout.

        Every line containing "formula:" (case-insensitive) is a candidate and
        later lines override earlier ones; scanning stops at the first
        "Final Formula:" line. Returns ``None`` when no candidate is found.
        """
        best_formula = None
        for line in output.splitlines():
            line_lower = line.lower()
            idx = line_lower.find("formula:")
            if idx == -1:
                continue
            formula_part = line[idx + len("formula:"):].strip()
            if formula_part:
                best_formula = formula_part
                if "final formula:" in line_lower:
                    break  # Final Formula is the best, stop looking
        return best_formula

    def _execute(self, cmd, timeout_sec):
        """Run ``cmd`` and return the parsed formula, or ``None`` on any failure."""
        try:
            print(f"Running GP Engine: {' '.join(cmd)}")

            start_time = time.time()
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_sec)

            output = result.stdout
            best_formula = self._parse_best_formula(output)

            print(f"GP Engine finished in {time.time() - start_time:.2f}s")

            if best_formula is None:
                print(f"[DEBUG] GP Engine Output (Stdout):\n{output}")
                print(f"[DEBUG] GP Engine Output (Stderr):\n{result.stderr}")

            return best_formula

        except subprocess.TimeoutExpired as e:
            print(f"GP Engine timed out after {timeout_sec}s.")
            # Recover whatever the engine printed before it was killed.
            best_formula = self._parse_best_formula(self._to_text(e.stdout))
            if best_formula:
                print(f"Recovered best formula from timeout: {best_formula}")
                return best_formula

            # Print stderr for timeout diagnose
            stderr_text = self._to_text(e.stderr)
            if stderr_text:
                 print(f"GP Engine Timeout Stderr: {stderr_text}")
            return None

        except Exception as e:
            print(f"GP Engine failed: {e}")
            if hasattr(e, 'stderr') and e.stderr:
                print(f"Stderr: {e.stderr}")
            return None

    def run(self, x_values: List[float], y_values: List[float], seeds: Optional[List[str]] = None, timeout_sec: int = 10) -> Optional[str]:
        """
        Runs the C++ GP Engine with the given data and seeds.
        Returns the best formula found as a string, or None if failed.

        x_values / y_values: sample points, written to a temporary data file as
            two space-separated lines (x values, then y values).
        seeds: optional seed formulas, written one per line to a temporary file.
        timeout_sec: maximum run time; on timeout the best formula printed so
            far is returned if there is one.

        Both temporary files are removed before returning, whatever happens.
        """
        if not os.path.exists(self.binary_path):
            print(f"[Error] GP Binary not found at: {self.binary_path}")
            return None

        seed_file_path = None
        data_file_path = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as seed_file:
                seed_file_path = seed_file.name
                for seed in seeds or []:
                    seed_file.write(seed + "\n")

            with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as data_file:
                data_file_path = data_file.name
                # Line 1: x1 x2 ...
                # Line 2: y1 y2 ...
                data_file.write(" ".join(map(str, x_values)) + "\n")
                data_file.write(" ".join(map(str, y_values)) + "\n")

            cmd = [self.binary_path, "--seed", seed_file_path, "--data", data_file_path]
            return self._execute(cmd, timeout_sec)
        finally:
            # Cleanup
            for path in (seed_file_path, data_file_path):
                if path and os.path.exists(path):
                    os.unlink(path)

if __name__ == "__main__":
    # Manual check: ask the engine to fit y = x^2 + 2 on four points, seeded
    # with two partial guesses.
    engine = GPEngine()
    x = [1, 2, 3, 4]
    y = [xi * xi + 2 for xi in x]
    seeds = ["(x * x)", "(x + 2)"]

    print("Testing GPEngine...")
    res = engine.run(x, y, seeds)
    print(f"Result: {res}")


In [None]:
%%writefile AlphaSymbolic/core/__init__.py


In [None]:
%%writefile AlphaSymbolic/data/synthetic_data.py
import numpy as np
import random
from core.grammar import VOCABULARY, OPERATORS, VARIABLES, CONSTANTS, ExpressionTree
from data.augmentation import augment_formula_tokens

class DataGenerator:
    def __init__(self, max_depth=5, population_size=1000, allowed_operators=None):
        self.max_depth = max_depth
        self.population_size = population_size
        self.vocab = VOCABULARY
        # Pre-compute terminal vs operator lists
        self.terminals = VARIABLES + CONSTANTS
        if allowed_operators:
            self.operators = [op for op in allowed_operators if op in OPERATORS]
        else:
            self.operators = list(OPERATORS.keys())

    def generate_random_tree(self, max_depth, current_depth=0):
        if current_depth >= max_depth:
            # Balanced Terminal Selection: 50% x, 50% constant
            if random.random() < 0.5:
                return ['x']
            else:
                return [random.choice(CONSTANTS)]
        
        # Decide if terminal or operator
        # Higher probability of operator at shallow depths
        if random.random() < 0.7: 
            op = random.choice(self.operators)
            arity = OPERATORS[op]
            tokens = [op]
            for _ in range(arity):
                tokens.extend(self.generate_random_tree(max_depth, current_depth + 1))
            return tokens
        else:
            # Balanced Terminal Selection: 40% x, 30% C, 30% numbers
            r = random.random()
            if r < 0.4:
                return ['x']
            elif r < 0.7:
                return ['C']
            else:
                return [random.choice([c for c in CONSTANTS if c != 'C'])]

    def generate_batch(self, batch_size, point_count=10, x_range=(-10, 10)):
        """
        Generates a batch of (X, Y) pairs and their generating formulas.
        """
        data = []
        
        while len(data) < batch_size:
            # Generate random formula
            tokens = self.generate_random_tree(self.max_depth)
            tree = ExpressionTree(tokens)
            
            if not tree.is_valid:
                continue
            
            # Ensure 'x' is present in the formula (90% of the time)
            if 'x' not in tokens and random.random() < 0.9:
                continue
                
            # Generate random X points
            x_values = np.random.uniform(x_range[0], x_range[1], point_count)
            # Sort X for cleaner visualization/learning
            x_values.sort()
            
            # Randomize 'C' values if present
            c_positions = tree.root.get_constant_positions()
            constant_vals = {}
            for pos in c_positions:
                # Expanded range: -20 to 20. Favor 1.0 occasionally
                val = random.uniform(-20, 20) if random.random() > 0.1 else 1.0
                constant_vals[tuple(pos)] = val
            
            # Calculate Y with randomized constants
            y_values = tree.evaluate(x_values, constants=constant_vals)
            
            # Check for validity (no NaNs, Infs, or extremely large values)
            if np.any(np.isnan(y_values)) or np.any(np.isinf(y_values)):
                continue
            if np.max(np.abs(y_values)) > 1e6: # Reject too large numbers
                continue
            if np.std(y_values) < 1e-6: # Reject flat lines (too simple)
                 # Optionally keep some, but mostly we want interesting curves
                 if random.random() > 0.1: continue

            data.append({
                'tokens': tokens,
                'infix': tree.get_infix(),
                'x': x_values,
                'y': y_values
            })
            
        return data

    def generate_structured_tree(self, complexity=1, input_node='x'):
        """
        Recursively builds a structured, human-like formula.
        Respects self.operators.
        """
        # Base cases
        if complexity <= 0:
            # Randomly choose between x, C and constants
            r = random.random()
            if r < 0.4: return ['x']
            if r < 0.7: return ['C']
            return [random.choice([c for c in CONSTANTS if c != 'C'])]
            
        # Filter available structures based on allowed operators
        available_structures = []
        
        # Arithmetic needed: +, -, *
        if any(op in self.operators for op in ['+', '-', '*']):
            available_structures.append('arithmetic')
            
        # Poly needed: pow
        if 'pow' in self.operators:
            available_structures.append('poly')
            
        # Trig needed: sin, cos
        if 'sin' in self.operators or 'cos' in self.operators:
            available_structures.append('trig')
            
        # Exp/Log needed
        if 'exp' in self.operators or 'log' in self.operators:
            available_structures.append('exp_log')
            
        # Composition needs enough variety
        if len(self.operators) > 4 and complexity > 1:
             available_structures.append('composition')
        
        # Fallback if nothing allowed matches (shouldn't happen with proper init)
        if not available_structures:
            return input_node if isinstance(input_node, list) else [input_node]

        choice = random.choice(available_structures)
        
        if choice == 'poly':
            # a*x + b or a*x^2 + b
            a = str(random.randint(1, 5))
            b = str(random.randint(-5, 5))
            power = random.choice(['1', '2', '3'])
            if power == '1':
                term = ['*', a] + (input_node if isinstance(input_node, list) else [input_node])
                return ['+', ] + term + [b]
            else:
                base = input_node if isinstance(input_node, list) else [input_node]
                pow_term = ['pow'] + base + [power]
                term = ['*', a] + pow_term
                return ['+', ] + term + [b]
                
        elif choice == 'trig':
            # Filter trig ops that are allowed
            ops = [op for op in ['sin', 'cos'] if op in self.operators]
            if not ops: return input_node # Should be caught by structure check
            func = random.choice(ops)
            val = input_node if isinstance(input_node, list) else [input_node]
            return [func] + val
            
        elif choice == 'exp_log':
            ops = [op for op in ['exp', 'log'] if op in self.operators]
            if not ops: return input_node
            func = random.choice(ops)
            val = input_node if isinstance(input_node, list) else [input_node]
            return [func] + val
            
        elif choice == 'arithmetic':
            left = self.generate_structured_tree(complexity - 1, input_node)
            right = self.generate_structured_tree(complexity - 1, input_node)
            ops = [op for op in ['+', '-', '*'] if op in self.operators]
            if not ops: return input_node
            op = random.choice(ops)
            return [op] + left + right
            
        elif choice == 'composition':
            inner = self.generate_structured_tree(complexity - 1, input_node)
            outer = self.generate_structured_tree(1, inner)
            return outer
            
        return [input_node]

    def generate_inverse_batch(self, batch_size, point_count=10, x_range=(-5, 5)):
        """
        Generates complex, structured formulas using the new engine.
        """
        data = []
        attempts = 0
        
        while len(data) < batch_size and attempts < batch_size * 5:
            attempts += 1
            # Random complexity capped by max_depth
            complexity = random.randint(1, max(1, self.max_depth - 1))
            
            try:
                tokens = self.generate_structured_tree(complexity, 'x')
                
                # Convert numeric strings to 'C' placeholders if needed
                # But here we want the GROUND TRUTH tokens with numbers for checking?
                # The model predicts tokens. 'C' is for optimization.
                # If we train "End-to-End" (predict 3*x), we keep numbers.
                # If we train "Symbolic" (predict C*x), we swap.
                # The original code swapped numbers to 'C'. Let's check VOCABULARY.
                # '1','2','3' are in VOCABULARY. So we can keep small integers.
                # Large integers -> 'C'.
                
                final_tokens = []
                for t in tokens:
                    if t in self.vocab:
                        final_tokens.append(t)
                    else:
                        # If it's a number not in vocab, map to C?
                        # Or just nearest constant?
                        # For now, simplistic mapping:
                        try:
                            val = float(t)
                            if abs(val - round(val)) < 0.01 and str(int(round(val))) in self.vocab:
                                final_tokens.append(str(int(round(val))))
                            else:
                                final_tokens.append('C')
                        except:
                            final_tokens.append('C')

                # --- DATA AUGMENTATION ---
                if random.random() < 0.3:
                    final_tokens = augment_formula_tokens(final_tokens)
                # -------------------------
                
                tree = ExpressionTree(final_tokens)
                if not tree.is_valid:
                    continue
                
                # Ensure 'x' is present (90% of the time)
                if 'x' not in final_tokens and random.random() < 0.9:
                    continue
                    
                # Check constraints (depth, length)
                if len(final_tokens) > 30: # Limit length
                    continue

                # Generate X points
                # Use safer range for complex funcs
                # Exp/Pow grow very fast, so we constrain X to avoid float overflow
                if 'exp' in final_tokens or 'pow' in final_tokens:
                    x_safe = np.linspace(-2, 2, point_count)
                elif 'log' in final_tokens or 'sqrt' in final_tokens:
                    x_safe = np.linspace(0.1, 5, point_count)
                else:
                    x_safe = np.linspace(x_range[0], x_range[1], point_count)
                
                # Randomize 'C' values if present
                c_positions = tree.root.get_constant_positions()
                constant_vals = {}
                for pos in c_positions:
                    # Expanded range: -20 to 20
                    val = random.uniform(-20, 20) if random.random() > 0.1 else 1.0
                    constant_vals[tuple(pos)] = val
                
                y_values = tree.evaluate(x_safe, constants=constant_vals)
                
                # Quality Control
                if np.any(np.isnan(y_values)) or np.any(np.isinf(y_values)):
                    continue
                if np.max(np.abs(y_values)) > 1e4: # Relaxed limit
                    continue
                if np.std(y_values) < 0.01: # Too flat
                    continue
                
                data.append({
                    'tokens': final_tokens,
                    'infix': tree.get_infix(),
                    'x': x_safe,
                    'y': y_values
                })
            except Exception:
                continue
                
        return data

# Quick test if run directly
if __name__ == "__main__":
    # Print five randomly generated samples as a manual sanity check
    separator = "-" * 20
    for sample in DataGenerator(max_depth=4).generate_batch(5):
        print(f"Formula: {sample['infix']}")
        print(f"Tokens: {sample['tokens']}")
        print(f"Y sample: {sample['y'][:3]}...")
        print(separator)


In [None]:
%%writefile AlphaSymbolic/data/benchmark_data.py
import numpy as np

# Standard Benchmark Problems
# Levels: 1 (Easy), 2 (Medium), 3 (Hard)

# Each problem is a dict with:
#   'id'          - key accepted by get_benchmark_data
#   'name'        - human-readable label
#   'formula_str' - ground-truth formula as text
#   'lambda'      - vectorized callable computing y from x
#   'domain'      - (low, high) x interval
#   'points'      - number of sample points
#   'level'       - difficulty level (1, 2 or 3)
BENCHMARK_SUITE = [
    # --- Level 1: Polynomials & Basic Arithmetic ---
    {
        'id': 'p1',
        'name': 'Lineal',
        'formula_str': '2.5 * x + 1.0',
        'lambda': lambda x: 2.5 * x + 1.0,
        'domain': (-10, 10),
        'points': 20,
        'level': 1
    },
    {
        'id': 'p2',
        'name': 'Cuadratica Simple',
        'formula_str': 'x * x',
        'lambda': lambda x: x**2,
        'domain': (-5, 5),
        'points': 20,
        'level': 1
    },
    {
        'id': 'p3',
        'name': 'Polinomio Cubico',
        'formula_str': 'x**3 + x**2',
        'lambda': lambda x: x**3 + x**2,
        'domain': (-3, 3),
        'points': 20,
        'level': 1
    },

    # --- Level 2: Trigonometric & Transcendental ---
    {
        'id': 'p4',
        'name': 'Seno Basico',
        'formula_str': 'sin(x)',
        'lambda': lambda x: np.sin(x),
        'domain': (-np.pi, np.pi),
        'points': 30,
        'level': 2
    },
    {
        'id': 'p5',
        'name': 'Coseno Desplazado',
        'formula_str': 'cos(x) + 1',
        'lambda': lambda x: np.cos(x) + 1,
        'domain': (-np.pi, np.pi),
        'points': 30,
        'level': 2
    },
    {
        'id': 'p6',
        'name': 'Exponencial Simple',
        'formula_str': 'exp(x)',
        'lambda': lambda x: np.exp(x),
        'domain': (-2, 2),  # Small domain to avoid explosion
        'points': 20,
        'level': 2
    },

    # --- Level 3: Physics / Complex ---
    {
        'id': 'p7',
        'name': 'Damped Oscillation',
        'formula_str': 'exp(-x) * sin(2*x)',
        'lambda': lambda x: np.exp(-x) * np.sin(2*x),
        'domain': (0, 4),
        'points': 40,
        'level': 3
    },
    {
        'id': 'p8',
        'name': 'Gaussian',
        'formula_str': 'exp(-x**2)',
        'lambda': lambda x: np.exp(-x**2),
        'domain': (-3, 3),
        'points': 30,
        'level': 3
    },
    {
        'id': 'p9',
        # x^3 + x^2 + x is Nguyen-1; Nguyen-3 is the quintic x^5 + x^4 + x^3 + x^2 + x
        'name': 'Nguyen-1 (x^3 + x^2 + x)',
        'formula_str': 'x**3 + x**2 + x',
        'lambda': lambda x: x**3 + x**2 + x,
        'domain': (-2, 2),
        'points': 20,
        'level': 3
    },
    {
        'id': 'p10',
        'name': 'Rational Function',
        'formula_str': 'x / (1 + x**2)',
        'lambda': lambda x: x / (1 + x**2),
        'domain': (-4, 4),
        'points': 30,
        'level': 3
    }
]

def get_benchmark_data(problem_id):
    """
    Sample the benchmark problem with the given ID.

    Returns:
        (x, y, problem): x is an evenly spaced grid over the problem's domain
        with its configured number of points, y is the problem's 'lambda'
        applied to x, and problem is the matching BENCHMARK_SUITE entry.
        (None, None, None) when no entry has that ID.
    """
    problem = next((entry for entry in BENCHMARK_SUITE if entry['id'] == problem_id), None)
    if problem is None:
        return None, None, None

    low, high = problem['domain'][0], problem['domain'][1]
    x = np.linspace(low, high, problem['points'])
    return x, problem['lambda'](x), problem


In [None]:
%%writefile AlphaSymbolic/data/augmentation.py

import random
from core.grammar import OPERATORS

def augment_formula_tokens(tokens):
    """
    Rewrite a prefix-notation formula into a randomly chosen equivalent form.

    Acts as data augmentation for symbolic regression.

    Transformations applied (bottom-up, each decided by a random draw):
    1. Commutativity of '+' and '*': [+ a b] -> [+ b a] with probability 0.5.
    2. Doubling: [+ a a] -> [* a 2] with probability 0.3, where both operands
       have identical token sequences.

    Input that is not exactly one well-formed prefix expression (operands
    missing, or tokens left over after the expression) is returned unchanged.

    Args:
        tokens (list): Tokens in prefix notation. Operator arity comes from
            OPERATORS; every other token is treated as a terminal.

    Returns:
        list: A new token list; the input is never modified.
    """
    if not tokens:
        return []

    def parse_prefix(pos):
        """Parse the sub-expression starting at tokens[pos]; return (node, next_pos)."""
        if pos >= len(tokens):
            raise ValueError("prefix expression ends before all operands are supplied")
        symbol = tokens[pos]
        pos += 1
        children = []
        if symbol in OPERATORS:
            for _ in range(OPERATORS[symbol]):
                child, pos = parse_prefix(pos)
                children.append(child)
        return {'val': symbol, 'children': children}, pos

    def flatten(node):
        """Serialize a node back to prefix tokens."""
        out = [node['val']]
        for child in node['children']:
            out.extend(flatten(child))
        return out

    def augment_recursive(node):
        """Augment the children first, then the node itself; return the new subtree root."""
        node['children'] = [augment_recursive(child) for child in node['children']]
        val, children = node['val'], node['children']

        # Commutativity: swap the operands
        if val in ('+', '*') and len(children) == 2:
            if random.random() < 0.5:
                children = node['children'] = [children[1], children[0]]

        # x + x -> x * 2 (operands compared by their token sequences)
        if val == '+' and len(children) == 2:
            if flatten(children[0]) == flatten(children[1]) and random.random() < 0.3:
                return {'val': '*', 'children': [children[0], {'val': '2', 'children': []}]}

        return node

    # 1. Parse; anything that is not a single complete expression is left alone
    try:
        tree, end = parse_prefix(0)
    except (ValueError, TypeError):
        return list(tokens)
    if end != len(tokens):
        return list(tokens)

    # 2. Augment, then 3. flatten back to tokens
    return flatten(augment_recursive(tree))

if __name__ == "__main__":
    # Demo inputs in prefix notation, in order: (+ x y), (* (+ a b) c), (+ x x)
    demo_formulas = (
        ['+', 'x', 'y'],
        ['*', '+', 'a', 'b', 'c'],
        ['+', 'x', 'x'],
    )
    for formula in demo_formulas:
        print(f"Original: {formula} -> Aug: {augment_formula_tokens(formula)}")


In [None]:
%%writefile AlphaSymbolic/data/__init__.py


In [None]:
%%writefile AlphaSymbolic/search/mcts.py
import math
import numpy as np
import torch
import copy
from core.grammar import VOCABULARY, TOKEN_TO_ID, OPERATORS, ExpressionTree, VARIABLES
from utils.optimize_constants import optimize_constants

class MCTSNode:
    """One node of the search tree: a (possibly partial) prefix token sequence."""

    def __init__(self, tokens, parent=None, prior=0.0):
        self.tokens = tokens
        self.parent = parent
        self.prior = prior
        # Children keyed by the token appended to reach them
        self.children = {}
        self.is_expanded = False
        # Real statistics accumulated by backpropagation
        self.visit_count = 0
        self.value_sum = 0.0
        # Temporary statistics used by the batched (parallel) selection
        self.virtual_loss = 0.0
        self.virtual_visits = 0

    @property
    def value(self):
        """Mean value over real plus virtual visits; virtual loss is subtracted to discourage re-selection."""
        total = self.visit_count + self.virtual_visits
        if not total:
            return 0.0
        return (self.value_sum - self.virtual_loss) / total

    def ucb_score(self, c_puct=1.0):
        """Return value plus the prior-weighted exploration bonus; 0.0 for a node without parent."""
        parent = self.parent
        if parent is None:
            return 0.0
        parent_total = parent.visit_count + parent.virtual_visits
        own_total = self.visit_count + self.virtual_visits
        exploration = c_puct * self.prior * math.sqrt(parent_total) / (1 + own_total)
        return self.value + exploration

    @property
    def complexity(self):
        """Formula length in tokens, used as the complexity estimate."""
        return len(self.tokens)

class MCTS:
    def __init__(self, model, device, grammar=None, c_puct=1.0, n_simulations=100, max_simulations=None, max_depth=50, complexity_lambda=0.1, max_len=200, batch_size=8):
        self.model = model
        self.device = device
        self.grammar = grammar
        self.c_puct = c_puct
        
        # Handle backwards compatibility for max_simulations
        if max_simulations is not None:
            self.n_simulations = max_simulations
        else:
            self.n_simulations = n_simulations
            
        self.max_depth = max_depth
        self.complexity_lambda = complexity_lambda
        self.max_len = max_len
        self.min_value = -float('inf')
        self.max_value = float('inf')
        self.vocab_size = len(VOCABULARY)
        self.sos_id = self.vocab_size
        self.batch_size = batch_size
        
        # Pareto Front: List of {'tokens':, 'rmse':, 'complexity':, 'formula':}
        self.pareto_front = []
        
        # Virtual loss constant usually 1-3
        self.v_loss_const = 3.0
        
    def search(self, x_values, y_values, num_simulations=None):
        """
        Run MCTS (Parallel/Batched) to find the best formula.
        """
        self.pareto_front = [] # Reset Pareto Front for new search
        root = MCTSNode(tokens=[])
        
        # Initial expansion (single)
        self._expand_batch([root], x_values, y_values)
        
        best_rmse = float('inf')
        best_formula = None
        best_tokens = None
        
        limit = num_simulations if num_simulations is not None else self.n_simulations
        
        # Loop in batches
        # Ensure we do at least 1 batch
        num_batches = max(1, (limit + self.batch_size - 1) // self.batch_size)
        
        for _ in range(num_batches): 
            leaves = []
            
            # 1. Selection (find N leaves)
            for _ in range(self.batch_size):
                node = root
                depth = 0
                
                # Selection loop
                while node.is_expanded and node.children and depth < self.max_depth:
                    node = max(node.children.values(), key=lambda n: n.ucb_score(self.c_puct))
                    
                    # Apply virtual loss to discourage re-selection in same batch
                    node.virtual_loss += self.v_loss_const
                    node.virtual_visits += 1
                    depth += 1
                
                # Check if valid leaf to expand
                if depth < self.max_depth and not node.is_expanded:
                    # Avoid duplicates in batch (simple check)
                    if node not in leaves:
                        leaves.append(node)
                else:
                    pass
            
            if not leaves:
                # If no leaves found (tree fully explored or locked), standard MCTS usually continues or stops.
                # We can just break or continue backprop of terminals.
                if root.visit_count > limit: break 
                continue
                
            # 2. Batch Expansion & Evaluation
            values = self._expand_batch(leaves, x_values, y_values)
            
            # 3. Backpropagation
            for node, val in zip(leaves, values):
                # Check for best solution found
                if self._is_complete_tree(node.tokens):
                    # For completed formulas, we calculate REAL RMSE
                    try:
                        # Evaluar
                        # Importar aquí para evitar circular imports si es necesario
                        from utils.optimize_constants import optimize_constants
                        
                        # 1. Optimizar constants (Crucial para Accuracy)
                        # Esto es "Phase 1" de TPSR (constantes en las hojas)
                        # Por simplicidad en esta iteración, asumimos que 'evaluate_formula' ya hace algo o usamos el string directo.
                        # Idealmente llamaríamos a BFGS aquí.
                        
                        # Use existing _evaluate_formula to get RMSE and optimized constants
                        tree = ExpressionTree(node.tokens)
                        optimized_constants, real_rmse = optimize_constants(tree, x_values, y_values)
                        
                        # Get y_pred using the optimized constants
                        y_pred = tree.evaluate(x_values, constants=optimized_constants)
                        
                        # Check dimensions
                        if y_pred.shape != y_values.shape:
                            # If shapes don't match, it's an invalid evaluation
                            final_val = 0.0
                        else:
                            # 2. Calcular Reward TPSR (Hybrid Accuracy + Complexity)
                            # R = 1 / (1 + NMSE) + lambda * exp(-len/L)
                            
                            mse = np.mean((y_pred - y_values)**2)
                            var_y = np.var(y_values)
                            if var_y < 1e-9: var_y = 1.0 # Avoid division by zero
                            
                            nmse = mse / var_y
                            
                            # Evitar NMSE gigantes
                            if np.isnan(nmse) or np.isinf(nmse):
                                nmse = 1e9
                            
                            r_acc = 1.0 / (1.0 + nmse)
                            
                            # Penalización por complejidad
                            token_len = len(node.tokens)
                            L = self.max_len # Max length del modelo
                            
                            r_cplx = self.complexity_lambda * np.exp(-token_len / L)
                            
                            # Suma y Normalización (para mantener rango 0-1)
                            # El máximo teórico es (1.0 + lambda). Dividimos por eso.
                            raw_reward = r_acc + r_cplx
                            final_val = raw_reward / (1.0 + self.complexity_lambda)

                        # Update best formula based on RMSE (for reporting, not for MCTS value)
                        if real_rmse < best_rmse:
                            best_rmse = real_rmse
                            best_tokens = node.tokens
                            best_formula = ExpressionTree(node.tokens).get_infix()
                        
                        # Update Pareto Front
                        # Complexity = len(tokens) (or could use count_constants + nodes)
                        complexity = len(node.tokens)
                        self._update_pareto_front(node.tokens, real_rmse, complexity, ExpressionTree(node.tokens).get_infix())

                    except Exception as e:
                        # print(f"Error evaluating formula: {e}")
                        final_val = 0.0 # Invalid formula gets 0 reward
                else:
                    final_val = val
                
                # The following lines were part of the user's instruction but contained syntax errors and undefined variables.
                # They are commented out to maintain a syntactically correct and functional document.
                # If these lines were intended to be added, please provide a complete and correct snippet.
                #
                # # Construir vector de probabilidades
                # probs = np.zeros(self.vocab_size, dtype=np.float32)
                # for token_id, count in counts.items():
                #     probs[token_id] = count / total_visits_count += 1
                
                curr = node
                while curr is not None:
                    curr.visit_count += 1
                    curr.value_sum += final_val
                    
                    # Revert virtual loss for parent and above
                    # Since we added to PARENT's child (which is curr), 
                    # and we traverse Up...
                    # Wait, logic: We selected CHILD. Virtual loss was added TO CHILD (curr).
                    # So we must remove it from curr.
                    if curr.virtual_visits > 0:
                        curr.virtual_loss -= self.v_loss_const
                        curr.virtual_visits -= 1
                            
                    curr = curr.parent
        
        # After search, force cleanup of any residual virtual loss (safety)
        # (Not strictly needed if logic is perfect, but good practice in complex async MCTS)
        
        return {
            'tokens': best_tokens,
            'formula': best_formula,
            'rmse': best_rmse,
            'root': root,
            'pareto_front': self.pareto_front
        }

    def _update_pareto_front(self, tokens, rmse, complexity, formula_str):
        """
        Update the Pareto Front with a new solution.
        Keep solutions that are not dominated by any other solution.
        Solution A dominates B if:
        A.rmse <= B.rmse AND A.complexity <= B.complexity AND (A.rmse < B.rmse OR A.complexity < B.complexity)
        """
        # Create candidate
        candidate = {'tokens': tokens, 'rmse': rmse, 'complexity': complexity, 'formula': formula_str}
        
        # Check if dominated by existing
        is_dominated = False
        to_remove = []
        
        for existing in self.pareto_front:
            # Check if existing dominates candidate
            if (existing['rmse'] <= candidate['rmse'] and 
                existing['complexity'] <= candidate['complexity'] and 
                (existing['rmse'] < candidate['rmse'] or existing['complexity'] < candidate['complexity'])):
                is_dominated = True
                break
                
            # Check if candidate dominates existing
            if (candidate['rmse'] <= existing['rmse'] and 
                candidate['complexity'] <= existing['complexity'] and 
                (candidate['rmse'] < existing['rmse'] or candidate['complexity'] < existing['complexity'])):
                to_remove.append(existing)
        
        if not is_dominated:
            # Remove dominated existing solutions
            for item in to_remove:
                self.pareto_front.remove(item)
            
            # Add candidate
            self.pareto_front.append(candidate)
            # Sort by RMSE for easier viewing
            self.pareto_front.sort(key=lambda x: x['rmse'])

    def _expand_batch(self, nodes, x_values, y_values):
        """
        Batched expansion. Returns list of values.
        """
        if not nodes:
            return []
            
        # Prepare inputs
        x_tensor = torch.tensor(x_values, dtype=torch.float32).unsqueeze(0).to(self.device)
        y_tensor = torch.tensor(y_values, dtype=torch.float32).unsqueeze(0).to(self.device)
        
        # Repeat X/Y for batch
        batch_size = len(nodes)
        x_batch = x_tensor.repeat(batch_size, 1, 1).squeeze(1) # [batch, points]
        y_batch = y_tensor.repeat(batch_size, 1, 1).squeeze(1) # [batch, points]
        
        # Prepare sequences
        # Find max len
        max_len = 0
        seqs = []
        for n in nodes:
            s = [self.sos_id] + [TOKEN_TO_ID[t] for t in n.tokens]
            seqs.append(s)
            max_len = max(max_len, len(s))
            
        # Pad and stack
        input_tensor = torch.full((batch_size, max_len), self.sos_id, dtype=torch.long).to(self.device)
        for i, s in enumerate(seqs):
            input_tensor[i, :len(s)] = torch.tensor(s, dtype=torch.long)
            
        # Inference
        with torch.no_grad():
            logits, value_preds = self.model(x_batch, y_batch, input_tensor)
            
        # Process results
        values = []
        
        # To CPU numpy for probability processing
        probs_batch = torch.softmax(logits[:, -1, :self.vocab_size], dim=1).cpu().numpy()
        value_preds = value_preds.cpu().numpy() # [batch, 3]
        
        for i, node in enumerate(nodes):
            # 1. Store Value (Median for now)
            # value_preds is [batch, 3] -> (Pessimistic, Median, Optimistic)
            # We use Median (index 1) for standard UCB.
            val_pred = value_preds[i, 1] 
            val = float(np.clip(val_pred, 0.0, 1.0))
            values.append(val)
            
            # 2. Expand children
            node_probs = probs_batch[i]
            valid_next = self._get_valid_next_tokens(node.tokens)
            
            for idx in valid_next:
                token = VOCABULARY[idx]
                prior = node_probs[idx]
                child = MCTSNode(tokens=node.tokens + [token], parent=node, prior=prior)
                node.children[token] = child
            
            node.is_expanded = True
            
        return values

    def _get_valid_next_tokens(self, tokens):
        """Simple grammar check."""
        open_slots = 1
        for t in tokens:
            if t in OPERATORS:
                open_slots += OPERATORS[t] - 1
            else:
                open_slots -= 1
        
        if open_slots <= 0:
            return []
        return list(range(self.vocab_size))

    def _is_complete_tree(self, tokens):
        if not tokens: return False
        try:
            tree = ExpressionTree(tokens)
            # Basic validation
            if len(tokens) > self.max_depth * 2: return False
            return tree.is_valid
        except:
            return False

    def _evaluate_formula(self, tokens, x, y):
        try:
            tree = ExpressionTree(tokens)
            _, rmse = optimize_constants(tree, x, y)
            return rmse
        except:
            return 1e9

    def get_training_examples(self, root):
        """
        Extrae ejemplos de entrenamiento del árbol generado.
        Retorna: lista de (state_tokens, policy_probs, value_target)
        """
        examples = []
        queue = [root]
        
        while queue:
            node = queue.pop(0)
            if node.visit_count < 1: 
                continue
            
            # Policy Target (Pi)
            # Distribución de visitas de los hijos
            counts = {}
            total_visits = 0
            has_children = False
            
            for token_id, child in node.children.items():
                # child key is token STRING or ID?
                # In _expand_batch: node.children[token] = child.
                # token = VOCABULARY[idx] (String).
                # So keys are strings.
                # But we need ID for probabilities array index.
                if token_id in TOKEN_TO_ID:
                    tid = TOKEN_TO_ID[token_id]
                    counts[tid] = child.visit_count
                    total_visits += child.visit_count
                    queue.append(child)
                    has_children = True
            
            if not has_children or total_visits == 0:
                continue
                
            # Construir vector de probabilidades
            probs = np.zeros(self.vocab_size, dtype=np.float32)
            for tid, count in counts.items():
                probs[tid] = count / total_visits
            
            # Value Target (V)
            # Usamos el Q-value (valor esperado) del nodo como target para el Value Head.
            # Q = value_sum / visit_count
            v = node.value_sum / node.visit_count
            
            # State: node.tokens (lista de ids?)
            # node.tokens is list of strings (from VOCABULARY).
            # self_play.py expects tokens as strings in ReplayBuffer.add.
            examples.append((node.tokens, probs, v))
            
        return examples


In [None]:
%%writefile AlphaSymbolic/search/beam_search.py
"""
Beam Search for AlphaSymbolic.
Explores multiple formula candidates in parallel, keeping top-K at each step.
"""
import torch
import numpy as np
from core.grammar import VOCABULARY, OPERATORS, TOKEN_TO_ID, ExpressionTree, OPERATOR_STAGES
from utils.optimize_constants import optimize_constants

class BeamSearch:
    """
    Batched beam search over prefix-notation token sequences.

    Every step feeds all live beams to the model in one forward pass, extends
    each beam with its best next tokens, and keeps the `beam_width` highest
    scoring candidates overall. A beam tracks how many operand slots are still
    open; it is complete when that count reaches zero.

    Args:
        model: callable as model(x, y, token_ids) returning (logits, extra);
            only the logits are used here.
        device: torch device (or device string) used for all tensors.
        beam_width: number of candidates kept after each step.
        max_length: maximum number of decoding steps.
        curriculum_stage: optional key into OPERATOR_STAGES; when given, only
            that stage's operators plus the terminals are allowed.
        repetition_penalty: multiplier applied to the log-probability of
            repeating the previous token (values > 1 discourage repetition,
            1.0 disables the penalty).
    """

    # Terminals that remain available at every curriculum stage.
    _ALWAYS_ALLOWED = ('x', 'C', '0', '1', '2', '3', '5', '10', 'pi', 'e')

    def __init__(self, model, device, beam_width=10, max_length=30, curriculum_stage=None,
                 repetition_penalty=1.5):
        self.model = model
        self.device = device
        self.beam_width = beam_width
        self.max_length = max_length
        self.repetition_penalty = repetition_penalty
        self.vocab_size = len(VOCABULARY)
        self.sos_id = self.vocab_size  # SOS token ID (one past the vocabulary)

        # Additive logit mask: 0 for allowed tokens, -inf for disallowed ones.
        self.token_mask = None
        if curriculum_stage is not None:
            allowed_ops = OPERATOR_STAGES.get(curriculum_stage, list(OPERATORS.keys()))
            allowed_tokens = set(self._ALWAYS_ALLOWED)
            allowed_tokens.update(allowed_ops)

            mask = torch.full((self.vocab_size,), float('-inf'), device=device)
            for token in allowed_tokens:
                if token in TOKEN_TO_ID:
                    mask[TOKEN_TO_ID[token]] = 0.0
            self.token_mask = mask

    def search(self, x_values, y_values, return_partial=False):
        """
        Beam Search to find the best formula structure.

        Args:
            x_values, y_values: data points the formula should fit.
            return_partial: when True and no complete valid formula was found,
                return the most likely unfinished beam as a single
                "Partial: ..." entry with rmse = inf.

        Returns:
            List of dicts with keys 'tokens', 'log_prob', 'rmse', 'constants'
            and 'formula', sorted by ascending RMSE (possibly empty).
        """
        # Prepare data once: [1, points]
        x_tensor = torch.tensor(x_values, dtype=torch.float32).unsqueeze(0).to(self.device)
        y_tensor = torch.tensor(y_values, dtype=torch.float32).unsqueeze(0).to(self.device)

        # Each beam: token strings so far, accumulated log-probability, and
        # the number of operand slots that still have to be filled.
        beams = [{'seq': [], 'log_prob': 0.0, 'open': 1}]
        completed = []

        for _ in range(self.max_length):
            active_beams = [b for b in beams if b['open'] > 0]
            if not active_beams:
                break

            batch_size = len(active_beams)

            # Broadcast the data to one row per active beam: [batch, points]
            x_batch = x_tensor.expand(batch_size, -1)
            y_batch = y_tensor.expand(batch_size, -1)

            # Decoder input: SOS followed by the beam's tokens. All active
            # beams have the same length, so the batch is rectangular.
            seqs = [[self.sos_id] + [TOKEN_TO_ID[t] for t in b['seq']] for b in active_beams]
            input_tensor = torch.tensor(seqs, dtype=torch.long).to(self.device)

            # Single model call for all beams
            with torch.no_grad():
                logits, _ = self.model(x_batch, y_batch, input_tensor)

            # Keep the prediction for the last position and drop the SOS column.
            last_token_logits = logits[:, -1, :self.vocab_size]

            if self.token_mask is not None:
                last_token_logits = last_token_logits + self.token_mask

            log_probs = torch.log_softmax(last_token_logits, dim=-1)  # [batch, vocab]

            # --- Repetition penalty ---
            # Log-probs are negative, so multiplying by a factor > 1 lowers the
            # probability of emitting the same token twice in a row
            # (prevents degenerate "//////////" sequences).
            for i, beam in enumerate(active_beams):
                if beam['seq']:
                    last_token = beam['seq'][-1]
                    if last_token in TOKEN_TO_ID:
                        last_id = TOKEN_TO_ID[last_token]
                        # Leave masked (-inf) entries untouched.
                        if log_probs[i, last_id] > -1e9:
                            log_probs[i, last_id] *= self.repetition_penalty

            # Take the best continuations of EACH beam, then prune globally.
            k_per_beam = min(self.beam_width, self.vocab_size)
            beam_topk_scores, beam_topk_indices = torch.topk(log_probs, k_per_beam, dim=-1)

            # Move to CPU for the bookkeeping below
            beam_topk_scores = beam_topk_scores.cpu().numpy()
            beam_topk_indices = beam_topk_indices.cpu().numpy()

            all_candidates = []
            for i, beam in enumerate(active_beams):
                for score, idx in zip(beam_topk_scores[i], beam_topk_indices[i]):
                    # top-k can reach into masked tokens (-inf) when fewer than
                    # k tokens are allowed; those must never enter a beam.
                    # NaN scores are dropped as well.
                    if not np.isfinite(score):
                        continue

                    token = VOCABULARY[idx]

                    # An operator fills one slot and opens OPERATORS[token]
                    # new ones; a terminal just fills one slot.
                    if token in OPERATORS:
                        new_open = beam['open'] + OPERATORS[token] - 1
                    else:
                        new_open = beam['open'] - 1

                    if new_open < 0:
                        continue

                    all_candidates.append({
                        'seq': beam['seq'] + [token],
                        # float() keeps the running score a plain Python float
                        # instead of a numpy float32.
                        'log_prob': beam['log_prob'] + float(score),
                        'open': new_open
                    })

            # Global prune: keep top beam_width
            all_candidates.sort(key=lambda c: c['log_prob'], reverse=True)

            # Finished sequences leave the beam; the rest keep expanding until
            # every beam is done or max_length is reached.
            beams = []
            for candidate in all_candidates[:self.beam_width]:
                if candidate['open'] == 0:
                    completed.append(candidate)
                else:
                    beams.append(candidate)

        # Evaluate results: fit constants and compute the RMSE of each formula
        scored_results = []
        for beam in completed:
            tree = ExpressionTree(beam['seq'])
            if tree.is_valid:
                constants, rmse = optimize_constants(tree, x_values, y_values)
                scored_results.append({
                    'tokens': beam['seq'],
                    'log_prob': beam['log_prob'],
                    'rmse': rmse,
                    'constants': constants,
                    'formula': tree.get_infix()
                })

        scored_results.sort(key=lambda r: r['rmse'])

        # If no results and return_partial is requested, return the best incomplete beam
        if not scored_results and return_partial and beams:
            # `beams` is still ordered by log-probability, so the first entry
            # is the most likely unfinished sequence.
            best_beam = beams[0]
            # No constants or infix form exist for an unfinished sequence;
            # only the raw tokens can be shown.
            scored_results.append({
                'tokens': best_beam['seq'],
                'log_prob': best_beam['log_prob'],
                'rmse': float('inf'),
                'constants': {},
                'formula': "Partial: " + " ".join(best_beam['seq']) + "..."
            })

        return scored_results


def beam_solve(target_x, target_y, model, device, beam_width=20, max_length=25):
    """
    Run a one-off beam search for the given data points.

    Builds a BeamSearch with the requested width and length limit and returns
    its complete result list (kept whole so callers can do Pareto analysis),
    or None when the search produced nothing.
    """
    engine = BeamSearch(model, device, beam_width=beam_width, max_length=max_length)
    found = engine.search(target_x, target_y)
    return found if found else None


if __name__ == "__main__":
    # Manual smoke test: run beam search on y = 2x + 3 with a small model.
    from core.model import AlphaSymbolicModel

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    VOCAB_SIZE = len(VOCABULARY)

    # +1 for the SOS token that BeamSearch prepends to every sequence.
    model = AlphaSymbolicModel(vocab_size=VOCAB_SIZE + 1, d_model=64).to(DEVICE)
    try:
        model.load_state_dict(torch.load("alpha_symbolic_model.pth", map_location=DEVICE, weights_only=True))
    except Exception as exc:
        # Missing or incompatible checkpoint: continue with random weights,
        # but say why instead of silently swallowing every exception
        # (a bare `except:` would also trap Ctrl-C).
        print(f"Model not found, using random weights ({exc})")
    model.eval()

    # Test
    x_test = np.linspace(-5, 5, 20).astype(np.float64)
    y_test = 2 * x_test + 3

    print("Running Beam Search...")
    # beam_solve returns None when nothing was found; normalise to a list so
    # the report below cannot crash on len(None).
    results = beam_solve(x_test, y_test, model, DEVICE, beam_width=10) or []

    print(f"\nFound {len(results)} valid formulas:")
    for i, r in enumerate(results[:5]):
        print(f"  {i+1}. {r['formula']} (RMSE: {r['rmse']:.4f})")


In [None]:
%%writefile AlphaSymbolic/search/hybrid_search.py
import time
import torch
import numpy as np
from typing import List, Dict, Any, Optional

from core.gp_bridge import GPEngine
from search.beam_search import BeamSearch, beam_solve

def hybrid_solve(
    x_values: np.ndarray,
    y_values: np.ndarray,
    model: torch.nn.Module,
    device: torch.device,
    beam_width: int = 50,
    gp_timeout: int = 10,
    gp_binary_path: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Solves Symbolic Regression using a Hybrid Neuro-Evolutionary approach.

    Phase 1: Neural Beam Search (The Brain)
             - Rapidly scans the search space.
             - Generates diverse, high-likelihood formula skeletons.

    Phase 2: Genetic Programming Refinement (The Muscle)
             - Takes the best skeletons from Phase 1 as seeds.
             - Hands them to GPEngine, which runs for `gp_timeout` seconds.

    Args:
        x_values, y_values: data points to fit (arrays or plain sequences).
        model, device: forwarded to the beam search.
        beam_width: beam width used in Phase 1.
        gp_timeout: time budget in seconds passed to the GP engine.
        gp_binary_path: optional path forwarded to GPEngine.

    Returns:
        Dict with keys 'formula' (string returned by the GP engine), 'rmse'
        (always the placeholder 0.0 -- callers must re-evaluate the formula
        to get the real error), 'source' and 'time' (seconds), or None when
        the GP engine returned no result.
    """

    print(f"--- Starting Alpha-GP Hybrid Search ---")
    start_time = time.time()

    # 1. Neural Beam Search (Phase 1)
    print(f"[Phase 1] Neural Beam Search (Width={beam_width})...")
    neural_results = beam_solve(x_values, y_values, model, device, beam_width=beam_width)

    seeds = []
    if neural_results:
        print(f"[Phase 1] Found {len(neural_results)} candidates.")
        # The GP engine is seeded with infix strings (e.g. "((x*x)+2)").
        # dict.fromkeys drops duplicates while keeping the original ranking;
        # "Partial: ..." entries are unfinished sequences, not formulas.
        seeds = list(dict.fromkeys(
            res['formula'] for res in neural_results
            if not res['formula'].startswith("Partial")
        ))

        print(f"[Phase 1] Generated {len(seeds)} unique seeds for GP.")
        if len(seeds) > 0:
            print(f"Top Seed: {seeds[0]}")
    else:
        print("[Phase 1] No valid candidates found (Beam Search failed).")
        print("[Phase 1] Falling back to pure GP (Random Initialization).")

    # 2. GP Refinement (Phase 2)
    print(f"[Phase 2] GPU Genetic Refinement (Timeout={gp_timeout}s)...")
    gp_engine = GPEngine(binary_path=gp_binary_path)

    # The GP engine is given plain Python lists.
    x_list = x_values.tolist() if hasattr(x_values, 'tolist') else list(x_values)
    y_list = y_values.tolist() if hasattr(y_values, 'tolist') else list(y_values)
    gp_result_str = gp_engine.run(x_list, y_list, seeds, timeout_sec=gp_timeout)

    total_time = time.time() - start_time

    if not gp_result_str:
        print(f"--- Hybrid Search Failed (GP did not return valid result) ---")
        return None

    print(f"--- Hybrid Search Completed in {total_time:.2f}s ---")
    print(f"Best Formula: {gp_result_str}")

    # Result dict shaped like a beam-search entry for consistency. The GP
    # engine only hands back the formula string, so no error value is
    # available at this point.
    return {
        'formula': gp_result_str,
        'rmse': 0.0, # Placeholder, not a measured error
        'source': 'Alpha-GP Hybrid',
        'time': total_time
    }

if __name__ == "__main__":
    # Manual smoke test driven by a stand-in network, so no trained
    # checkpoint is required.
    class MockModel(torch.nn.Module):
        """Model stub that answers every query with random logits."""

        def forward(self, x, y, seq):
            batch, length = seq.shape
            vocab = 20
            random_logits = torch.randn(batch, length, vocab)
            return random_logits, None

    print("Testing Hybrid Search...")
    sample_x = np.linspace(-5, 5, 10)
    sample_y = sample_x**2
    try:
        outcome = hybrid_solve(sample_x, sample_y, MockModel(), torch.device("cpu"), beam_width=5)
        print(outcome)
    except Exception as e:
        print(f"Test failed: {e}")


In [None]:
%%writefile AlphaSymbolic/search/pareto.py
"""
Pareto Front Manager for AlphaSymbolic.
Maintains a set of non-dominated solutions (accuracy vs complexity).
"""
import numpy as np
from core.grammar import ExpressionTree

class ParetoSolution:
    """One candidate formula scored on two objectives: RMSE and complexity."""

    def __init__(self, tokens, rmse, complexity, formula_str, constants=None):
        self.tokens = tokens
        self.formula = formula_str
        # Both objectives are minimised.
        self.rmse = rmse
        self.complexity = complexity  # number of nodes
        self.constants = constants if constants else {}

    def dominates(self, other):
        """
        Pareto dominance test.

        Returns a truthy value when self is no worse than other on both
        objectives and strictly better on at least one of them.
        """
        no_worse = (self.rmse <= other.rmse) and (self.complexity <= other.complexity)
        if not no_worse:
            return no_worse
        return (self.rmse < other.rmse) or (self.complexity < other.complexity)

    def __repr__(self):
        return f"ParetoSolution(rmse={self.rmse:.4f}, complexity={self.complexity}, formula='{self.formula}')"


class ParetoFront:
    """
    Set of mutually non-dominated solutions (RMSE vs. complexity).

    `solutions` is a plain list; `max_size` bounds its length.
    """

    def __init__(self, max_size=50):
        self.solutions = []
        self.max_size = max_size

    @staticmethod
    def _normalise(value, low, span):
        """Map value into [0, 1] given the minimum and the (max - min) span."""
        if span == 0:
            # All solutions share this objective value: it cannot discriminate.
            return 0.0
        ratio = (value - low) / span
        # inf / inf gives NaN; treat an infinite objective as the worst score.
        return ratio if ratio == ratio else 1.0

    def add(self, solution):
        """
        Attempts to add a solution to the Pareto front.

        Returns True if added, False if it is dominated by an existing
        solution or its RMSE is NaN. Note that when the front overflows
        `max_size` the new solution may itself be trimmed again even though
        True is returned.
        """
        # NaN compares false against everything, so a NaN-RMSE solution could
        # never be dominated (nor dominate) and would pollute min()-based
        # selection. Reject it up front.
        if solution.rmse != solution.rmse:
            return False

        if any(existing.dominates(solution) for existing in self.solutions):
            return False  # New solution is dominated

        # Remove any solutions dominated by the new one
        self.solutions = [s for s in self.solutions if not solution.dominates(s)]
        self.solutions.append(solution)

        # Enforce max size: keep the entries with the best combined score.
        if len(self.solutions) > self.max_size:
            self.solutions.sort(key=lambda s: s.rmse + 0.01 * s.complexity)
            self.solutions = self.solutions[:self.max_size]

        return True

    def add_from_results(self, results_list):
        """
        Add multiple results from beam search or MCTS.

        results_list: list of dicts with 'tokens', 'rmse', 'formula' and
        optionally 'constants'. Complexity is the token count.

        Returns the number of results that were accepted.
        """
        added = 0
        for r in results_list:
            sol = ParetoSolution(
                tokens=r['tokens'],
                rmse=r['rmse'],
                complexity=len(r['tokens']),  # Simple complexity = token count
                formula_str=r['formula'],
                constants=r.get('constants', {})
            )

            if self.add(sol):
                added += 1

        return added

    def get_best_by_rmse(self):
        """Returns the solution with lowest RMSE, or None if the front is empty."""
        if not self.solutions:
            return None
        return min(self.solutions, key=lambda s: s.rmse)

    def get_simplest(self):
        """Returns the solution with lowest complexity, or None if the front is empty."""
        if not self.solutions:
            return None
        return min(self.solutions, key=lambda s: s.complexity)

    def get_balanced(self, alpha=0.5):
        """
        Returns the solution with the best weighted trade-off, or None if the
        front is empty.

        alpha: weight for RMSE (1-alpha for complexity). Both objectives are
        min-max normalised over the current front before weighting.
        """
        if not self.solutions:
            return None

        rmse_vals = [s.rmse for s in self.solutions]
        comp_vals = [s.complexity for s in self.solutions]

        # An additive epsilon on the maximum does not protect against a zero
        # span: for large values it is absorbed by float rounding. Guard the
        # division explicitly instead (see _normalise).
        min_rmse = min(rmse_vals)
        rmse_span = max(rmse_vals) - min_rmse
        min_comp = min(comp_vals)
        comp_span = max(comp_vals) - min_comp

        def score(s):
            norm_rmse = self._normalise(s.rmse, min_rmse, rmse_span)
            norm_comp = self._normalise(s.complexity, min_comp, comp_span)
            return alpha * norm_rmse + (1 - alpha) * norm_comp

        return min(self.solutions, key=score)

    def summary(self):
        """Print up to ten solutions of the front, lowest RMSE first."""
        print(f"\n=== Pareto Front ({len(self.solutions)} solutions) ===")
        for i, sol in enumerate(sorted(self.solutions, key=lambda s: s.rmse)[:10]):
            print(f"  {i+1}. RMSE={sol.rmse:.6f}, Nodes={sol.complexity}, Formula: {sol.formula}")


# Manual demo: build a small front and print the selection helpers.
if __name__ == "__main__":
    demo_front = ParetoFront()

    # Candidates with steadily lower error and higher complexity.
    candidates = [
        ParetoSolution(['x'], 10.0, 1, "x"),
        ParetoSolution(['+', 'x', '1'], 5.0, 3, "(x + 1)"),
        ParetoSolution(['*', '2', 'x'], 3.0, 3, "(2 * x)"),
        ParetoSolution(['+', '*', '2', 'x', '3'], 0.5, 5, "((2 * x) + 3)"),
        ParetoSolution(['+', '*', '*', '2', 'x', 'x', '+', 'x', '1'], 0.1, 9, "complicated"),
    ]

    for candidate in candidates:
        accepted = demo_front.add(candidate)
        print(f"Added {candidate.formula}: {accepted}")

    demo_front.summary()

    print(f"\nBest by RMSE: {demo_front.get_best_by_rmse()}")
    print(f"Simplest: {demo_front.get_simplest()}")
    print(f"Balanced: {demo_front.get_balanced()}")


In [None]:
%%writefile AlphaSymbolic/search/__init__.py


In [None]:
%%writefile AlphaSymbolic/ui/app_core.py
"""
Core state and model management for AlphaSymbolic Gradio App.
"""
import torch
import os
from core.model import AlphaSymbolicModel
from core.grammar import VOCABULARY

from collections import deque
import time

# Global state shared by the functions in this module.
# MODEL and DEVICE start unset and are filled in lazily by load_model(),
# get_device_info() and set_device().
MODEL = None
DEVICE = None
# Progress record for the training loop (also imported by ui/app_training.py).
TRAINING_STATUS = {"running": False, "epoch": 0, "loss": 0, "message": "Listo"}
STOP_TRAINING = False  # Flag to request training stop

def request_stop_training():
    """Raise the module-level stop flag and return the status message to display."""
    global STOP_TRAINING
    status_message = "⏹️ Deteniendo entrenamiento..."
    STOP_TRAINING = True
    return status_message

def should_stop_training():
    """Report the current value of the module-level stop flag."""
    stop_requested = STOP_TRAINING
    return stop_requested

def reset_stop_flag():
    """Clear the stop request; meant to be called when a training run starts."""
    global STOP_TRAINING
    cleared = False
    STOP_TRAINING = cleared

# Hall of Shame: rolling buffer of recent training failures.
# The deque keeps only the 20 most recent entries; older ones fall off
# automatically.
# Entry format: {'time': str, 'target': str, 'predicted': str, 'loss': float, 'stage': str}
TRAINING_ERRORS = deque(maxlen=20)

def add_training_error(target, predicted, loss, stage):
    """
    Record one failed prediction in the Hall of Shame buffer.

    The entry is stamped with the current wall-clock time (HH:MM:SS) and the
    loss is stored as a plain float.
    """
    entry = {
        'time': time.strftime("%H:%M:%S"),
        'target': target,
        'predicted': predicted,
        'loss': float(loss),
        'stage': stage
    }
    TRAINING_ERRORS.append(entry)

def get_training_errors():
    """Return a list snapshot of the Hall of Shame buffer for the UI."""
    snapshot = list(TRAINING_ERRORS)
    return snapshot

# Architecture hyper-parameters per preset; load_model() passes them to
# AlphaSymbolicModel and uses the preset name in the checkpoint filename.
MODEL_PRESETS = {
    'lite': {'d_model': 128, 'nhead': 4, 'num_encoder_layers': 3, 'num_decoder_layers': 3},
    'pro': {'d_model': 256, 'nhead': 8, 'num_encoder_layers': 6, 'num_decoder_layers': 6}
}
# Preset used when load_model() is called without an explicit preset_name.
CURRENT_PRESET = 'lite'

def get_device(force_cpu=False):
    """
    Pick the torch device to run on.

    Preference order is CUDA, then MPS, then CPU; `force_cpu=True` skips the
    accelerator checks entirely.
    """
    if not force_cpu:
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Older torch builds have no `mps` backend attribute at all.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return torch.device("mps")
    return torch.device("cpu")

def set_device(use_gpu=True):
    """
    Switch the active device and move the loaded model along with it.

    `use_gpu=False` forces the CPU. Returns the device description string
    produced by get_device_info().
    """
    global DEVICE, MODEL
    target = get_device(force_cpu=not use_gpu)

    # Only move the model when one is loaded and the device actually changes.
    needs_move = MODEL is not None and DEVICE != target
    if needs_move:
        MODEL = MODEL.to(target)

    DEVICE = target
    return get_device_info()

def get_device_info():
    """
    Describe the active device as a short display string.

    Resolves DEVICE through get_device() first if it has not been set yet.
    """
    global DEVICE
    if DEVICE is None:
        DEVICE = get_device()

    kind = DEVICE.type
    if kind == "cuda":
        return f"CUDA ({torch.cuda.get_device_name(0)})"
    return "MPS (Apple Silicon)" if kind == "mps" else "CPU"

def load_model(force_reload=False, preset_name=None):
    """
    Build the model for the active preset and try to restore its checkpoint.

    Args:
        force_reload: accepted for interface compatibility; the model is
            always rebuilt, so the flag is currently not consulted.
        preset_name: key of MODEL_PRESETS to switch to; None/empty keeps the
            current preset.

    Returns:
        (status, device_info): a status message and the string from
        get_device_info().

    Side effects: replaces the global MODEL, may set DEVICE and
    CURRENT_PRESET, and deletes the checkpoint file when it contains NaN/Inf
    values.
    """
    global MODEL, DEVICE, CURRENT_PRESET

    if preset_name:
        CURRENT_PRESET = preset_name

    if DEVICE is None:
        DEVICE = get_device()

    VOCAB_SIZE = len(VOCABULARY)
    config = MODEL_PRESETS[CURRENT_PRESET]

    def build_model():
        # +1 output slot for the SOS token.
        return AlphaSymbolicModel(
            vocab_size=VOCAB_SIZE + 1,
            d_model=config['d_model'],
            nhead=config['nhead'],
            num_encoder_layers=config['num_encoder_layers'],
            num_decoder_layers=config['num_decoder_layers']
        ).to(DEVICE)

    print(f"Loading Model [{CURRENT_PRESET.upper()}]...")
    MODEL = build_model()
    # NOTE(review): the model is only switched to eval() after a successful
    # checkpoint load; a fresh model stays in training mode. Confirm callers
    # set the mode they need.

    filename = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
    status = f"Nuevo modelo ({CURRENT_PRESET})" # Default status

    if os.path.exists(filename):
        try:
            state_dict = torch.load(filename, map_location=DEVICE, weights_only=True)

            # A checkpoint with NaN/Inf weights is unusable.
            has_nans = any(
                torch.isnan(v).any() or torch.isinf(v).any()
                for v in state_dict.values()
            )

            if has_nans:
                print(f"⚠️ Modelo corrupto detectado (NaNs) en {filename}. Eliminando y esperando reinicio.")
                try:
                    os.remove(filename)
                    print("✅ Archivo corrupto eliminado.")
                except OSError as e:
                    print(f"Error al eliminar archivo: {e}")
                status = "⚠️ Modelo corrupto eliminado y reiniciado"
            else:
                MODEL.load_state_dict(state_dict)
                MODEL.eval()
                status = f"Modelo cargado ({CURRENT_PRESET})"

        except RuntimeError as e:
            print(f"⚠️ Error de compatibilidad ({e}). Iniciando modelo fresco.")
            # load_state_dict copies every matching tensor before it raises
            # for the mismatching ones, so MODEL may be half-loaded here.
            # Rebuild it so the "fresh model" promise actually holds.
            MODEL = build_model()
            status = f"Nuevo modelo ({CURRENT_PRESET})"
        except Exception as e:
            print(f"Error cargando: {e}")
            status = "Sin modelo pre-entrenado"

    return status, get_device_info()

def get_model():
    """Return (MODEL, DEVICE), calling load_model() first if no model exists yet."""
    global MODEL, DEVICE
    model_missing = MODEL is None
    if model_missing:
        load_model()
    return (MODEL, DEVICE)

def save_model():
    """Write the current model's state dict to the active preset's checkpoint file.

    Does nothing when no model has been loaded.
    """
    global MODEL, CURRENT_PRESET
    if MODEL is None:
        return
    checkpoint_path = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
    torch.save(MODEL.state_dict(), checkpoint_path)


In [None]:
%%writefile AlphaSymbolic/ui/app_search.py
"""
Search/Solve functions for AlphaSymbolic Gradio App.
Supports both Beam Search and MCTS.
"""
import numpy as np
import matplotlib.pyplot as plt
import time
import gradio as gr

from core.grammar import ExpressionTree
from search.beam_search import BeamSearch
from search.mcts import MCTS
from search.hybrid_search import hybrid_solve
from utils.simplify import simplify_tree
from search.pareto import ParetoFront
from utils.detect_pattern import detect_pattern
from utils.optimize_constants import optimize_constants, substitute_constants
from ui.app_core import get_model


def parse_data(x_str, y_str):
    """
    Parse two comma-separated number lists into float64 arrays.

    Surrounding whitespace and stray leading/trailing commas are ignored, so
    "1, 2, 3," is accepted. An empty field in the middle ("1,,2") is still an
    error, because it usually means a value is missing.

    Returns:
        (x, y, None) on success, or (None, None, message) where message
        starts with "Error: " when the input is malformed, empty, or the two
        lists differ in length.
    """
    def to_array(text):
        # Drop whitespace and separators hanging off either end.
        trimmed = text.strip().strip(',').strip()
        if not trimmed:
            return np.array([], dtype=np.float64)
        return np.array([float(v.strip()) for v in trimmed.split(',')], dtype=np.float64)

    try:
        x = to_array(x_str)
        y = to_array(y_str)
    except Exception as e:
        return None, None, f"Error: {str(e)}"

    if len(x) != len(y):
        return None, None, "Error: X e Y deben tener igual longitud"
    if len(x) == 0:
        # Explicit message instead of the cryptic float('') failure.
        return None, None, "Error: no se ingresaron datos"
    return x, y, None


def create_fit_plot(x, y, y_pred, formula):
    """
    Build the dark-themed figure comparing the data with the fitted curve.

    The data points are drawn as a scatter and the prediction as a line
    through the points ordered by x. `formula` is accepted but not used in
    the drawing. Returns the matplotlib figure.
    """
    background = '#1a1a2e'
    accent = '#00d4ff'

    fig, ax = plt.subplots(figsize=(8, 5), facecolor=background)
    ax.set_facecolor(background)

    # Observed points.
    ax.scatter(x, y, color=accent, s=100, label='Datos Reales', zorder=3, edgecolors='white', linewidth=1)

    # Prediction, drawn left-to-right so the line does not zig-zag.
    order = np.argsort(x)
    ax.plot(x[order], y_pred[order], color='#ff6b6b', linewidth=3, label='Prediccion', zorder=2)

    ax.set_xlabel('X', color='white', fontsize=12)
    ax.set_ylabel('Y', color='white', fontsize=12)
    ax.set_title('Ajuste de la Formula', color='white', fontsize=14, fontweight='bold')
    ax.legend(facecolor='#16213e', edgecolor=accent, labelcolor='white')
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.2, color='white')

    for border in ax.spines.values():
        border.set_color(accent)

    plt.tight_layout()
    return fig


def solve_formula(x_str, y_str, beam_width, search_method, progress=gr.Progress()):
    """
    Main solving function with search method selection.

    Args:
        x_str, y_str: comma-separated numbers typed by the user.
        beam_width: beam width for Beam Search / Hybrid; for MCTS the number
            of simulations is beam_width * 10.
        search_method: "Alpha-GP Hybrid", "Beam Search", or anything else for
            MCTS.
        progress: Gradio progress tracker.

    Returns:
        A 5-tuple (result_html, figure, predictions_html, alternatives_html,
        simplified_formula). On failure the first element is an error message
        and the others are None / empty strings.
    """
    x, y, error = parse_data(x_str, y_str)
    if error:
        return error, None, "", "", ""

    MODEL, DEVICE = get_model()

    progress(0.1, desc=f"Analizando patron... [{DEVICE.type.upper()}]")
    pattern = detect_pattern(x, y)

    progress(0.3, desc=f"Buscando formulas ({search_method})... [{DEVICE.type.upper()}]")
    start_time = time.time()

    results = []

    if search_method == "Alpha-GP Hybrid":
        progress(0.4, desc="Fase 1: Neural Beam Search...")
        # gp_timeout is raised to 30s to allow convergence on complex problems.
        hybrid_res = hybrid_solve(x, y, MODEL, DEVICE, beam_width=int(beam_width), gp_timeout=30)

        if hybrid_res:
            progress(0.9, desc="Procesando resultados GP...")
            # Convert infix string back to tokens for consistency
            tree = ExpressionTree.from_infix(hybrid_res['formula'])
            if tree.is_valid:
                # The GP output has its numeric constants embedded in the
                # string, so there are no 'C' placeholders to re-optimise;
                # the formula is only evaluated to measure its real RMSE.
                y_pred_check = tree.evaluate(x)
                rmse_check = np.sqrt(np.mean((y_pred_check - y)**2))

                results = [{
                    'tokens': tree.tokens,
                    'formula': tree.get_infix(),
                    'rmse': rmse_check,
                    'constants': {} # Constants are baked into the formula string
                }]

    elif search_method == "Beam Search":
        searcher = BeamSearch(MODEL, DEVICE, beam_width=int(beam_width), max_length=25)
        results = searcher.search(x, y)
    else:  # MCTS
        mcts = MCTS(MODEL, DEVICE, max_simulations=int(beam_width) * 10)
        result = mcts.search(x, y)
        if result and result.get('tokens'):
            tokens = result['tokens']
            tree = ExpressionTree(tokens)
            if tree.is_valid:
                constants, rmse = optimize_constants(tree, x, y)
                results = [{
                    'tokens': tokens,
                    'formula': tree.get_infix(),
                    'rmse': rmse,
                    'constants': constants
                }]

    search_time = time.time() - start_time

    if not results:
        return "No se encontraron formulas validas", None, "", "", ""

    progress(0.7, desc="Optimizando constantes...")
    pareto = ParetoFront()
    pareto.add_from_results(results)
    best = pareto.get_best_by_rmse()

    if not best:
        return "Error en optimizacion", None, "", "", ""

    progress(0.9, desc="Simplificando...")
    tree = ExpressionTree(best.tokens)
    simplified = simplify_tree(tree)
    y_pred = tree.evaluate(x, constants=best.constants)

    # Substitute constants for display
    substituted_formula = simplified
    if best.constants:
        try:
            positions = tree.root.get_constant_positions()
            # We use the raw infix for substitution to ensure matching C positions
            raw_infix = tree.get_infix()
            substituted_formula = substitute_constants(raw_infix, best.constants, positions)
        except Exception:
            # Display-only step: fall back to the simplified formula. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
            # propagate.
            substituted_formula = simplified

    fig = create_fit_plot(x, y, y_pred, simplified)

    # Format results
    result_html = f"""
    <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #00d4ff;">
        <h2 style="color: #00d4ff; margin: 0; font-size: 24px;">Formula Encontrada</h2>
        <div style="background: #0f0f23; padding: 15px; border-radius: 10px; margin: 15px 0; border-left: 4px solid #ff6b6b;">
            <code style="color: #ff6b6b; font-size: 28px; font-weight: bold;">{substituted_formula}</code>
        </div>
        <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 10px;">
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">RMSE</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{best.rmse:.6f}</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Nodos</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{best.complexity}</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Tiempo</span><br>
                <span style="color: #00d4ff; font-size: 16px; font-weight: bold;">{search_time:.2f}s</span>
            </div>
            <div style="background: #0f0f23; padding: 10px; border-radius: 8px; text-align: center;">
                <span style="color: #888;">Metodo</span><br>
                <span style="color: #4ade80; font-size: 16px; font-weight: bold;">{search_method}</span>
            </div>
        </div>
        <div style="margin-top: 15px; padding: 10px; background: #0f0f23; border-radius: 8px;">
            <span style="color: #888;">Patron:</span> 
            <span style="color: #ffd93d;">{pattern['type']}</span> 
            <span style="color: #666;">({pattern['confidence']:.0%})</span>
            <span style="color: #888; margin-left: 20px;">Device:</span>
            <span style="color: #4ade80;">{DEVICE.type.upper()}</span>
        </div>
    """

    # Add constants if any
    if best.constants:
        # Sort by key and relabel as C_1, C_2, ... for a clean display
        sorted_items = sorted(best.constants.items(), key=lambda x: str(x[0]))
        clean_consts = []
        for i, (k, v) in enumerate(sorted_items):
            clean_consts.append(f"C_{i+1}: {v:.4f}")
        const_str = "  |  ".join(clean_consts)

        result_html += f"""
        <div style="margin-top: 10px; padding: 10px; background: #0f0f23; border-radius: 8px; border-left: 3px solid #ffd93d;">
            <span style="color: #888;">Constantes:</span>
            <span style="color: #fff; font-family: monospace; margin-left: 10px;">{const_str}</span>
        </div>
        """

    # Closes the outer card <div> opened at the top of result_html.
    result_html += "</div>"

    # Predictions table (first 50 points), colour-coded by absolute error
    pred_html = '<table style="width: 100%; border-collapse: collapse; background: #1a1a2e; border-radius: 10px; overflow: hidden;">'
    pred_html += '<tr style="background: #16213e;"><th style="padding: 10px; color: #00d4ff;">X</th><th style="color: #00d4ff;">Pred</th><th style="color: #00d4ff;">Real</th><th style="color: #00d4ff;">Delta</th></tr>'
    for i in range(min(50, len(x))):
        delta = abs(y_pred[i] - y[i])
        color = "#4ade80" if delta < 0.1 else "#fbbf24" if delta < 1 else "#ef4444"
        pred_html += f'<tr style="border-bottom: 1px solid #333;"><td style="padding: 8px; color: white; text-align: center;">{x[i]:.2f}</td><td style="color: white; text-align: center;">{y_pred[i]:.4f}</td><td style="color: white; text-align: center;">{y[i]:.4f}</td><td style="color: {color}; text-align: center; font-weight: bold;">{delta:.4f}</td></tr>'
    pred_html += '</table>'

    # Alternatives: up to four entries of the Pareto front, first one highlighted
    alt_html = '<div style="background: #1a1a2e; padding: 15px; border-radius: 10px;">'
    alt_html += '<h4 style="color: #00d4ff; margin-top: 0;">Alternativas</h4>'
    for i, sol in enumerate(pareto.solutions[:4]):
        alt_html += f'<div style="padding: 5px 10px; margin: 5px 0; background: #0f0f23; border-radius: 5px; border-left: 3px solid {"#00d4ff" if i == 0 else "#666"};"><code style="color: {"#ff6b6b" if i == 0 else "#888"};">{sol.formula}</code> <span style="color: #666; font-size: 12px;">RMSE: {sol.rmse:.4f}</span></div>'
    alt_html += '</div>'

    return result_html, fig, pred_html, alt_html, simplified


def generate_example(tipo):
    """Produce a small demo dataset as two comma-separated strings.

    `tipo` selects the curve: "lineal" (2x + 3), "cuadratico" (x**2 + 1),
    "trig" (sin x) or "exp" (2 * exp(0.5 x)). Any other value falls back to
    the linear example. X values are rendered with 2 decimals, Y values
    with 4.
    """
    # selector -> (linspace arguments, function applied to the x grid)
    recipes = {
        "lineal": ((1, 10, 10), lambda grid: 2 * grid + 3),
        "cuadratico": ((-5, 5, 11), lambda grid: grid**2 + 1),
        "trig": ((0, 6.28, 20), lambda grid: np.sin(grid)),
        "exp": ((0, 5, 15), lambda grid: 2 * np.exp(0.5 * grid)),
    }

    # Unknown (or non-string) selectors share the linear recipe.
    known = isinstance(tipo, str) and tipo in recipes
    span, curve = recipes[tipo if known else "lineal"]

    xs = np.linspace(*span)
    ys = curve(xs)

    x_text = ", ".join(f"{value:.2f}" for value in xs)
    y_text = ", ".join(f"{value:.4f}" for value in ys)
    return x_text, y_text


In [None]:
%%writefile AlphaSymbolic/ui/app_training.py
"""
Training functions for AlphaSymbolic Gradio App.
With proper data normalization.
"""
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
from collections import deque
import random
import time

from core.grammar import VOCABULARY, TOKEN_TO_ID, OPERATORS, OPERATOR_STAGES
from data.synthetic_data import DataGenerator
from ui.app_core import get_model, save_model, TRAINING_STATUS, add_training_error, should_stop_training, reset_stop_flag
from core.loss import QuantileLoss
from search.hybrid_search import hybrid_solve
from core.grammar import ExpressionTree, simplify_formula


def get_allowed_token_mask(stage, vocab_size, device):
    """
    Build a 0/1 mask over token ids for a curriculum stage.

    The result is a 1-D float tensor of length `vocab_size + 1` placed on
    `device`. An entry is 1.0 when the token is permitted and 0.0 otherwise.
    Permitted tokens are the fixed terminal set below plus the operators
    listed in `OPERATOR_STAGES[stage]` (every operator in `OPERATORS` when
    the stage has no entry). The extra last slot, index `vocab_size`, is the
    SOS token and is always permitted.
    """
    stage_operators = OPERATOR_STAGES.get(stage, list(OPERATORS.keys()))

    # Terminals are permitted at every stage.
    permitted = {'x', 'C', '0', '1', '2', '3', '5', '10', 'pi', 'e'}
    permitted |= set(stage_operators)

    mask = torch.zeros(vocab_size + 1, device=device)  # +1 slot for SOS
    # Names missing from the vocabulary are silently skipped.
    for token_id in (TOKEN_TO_ID[name] for name in permitted if name in TOKEN_TO_ID):
        mask[token_id] = 1.0
    mask[vocab_size] = 1.0  # SOS always allowed

    return mask


def normalize_batch(x_list, y_list):
    """Min-max scale every X and Y array of a batch into [-1, 1].

    Each array is scaled independently using its own min and max. An array
    whose range is not larger than 1e-6 (effectively constant) is replaced
    by zeros of the same shape and dtype. The two input lists are walked in
    lockstep, so the result has as many entries as the shorter list.

    Returns:
        `(normalized_x, normalized_y)` as two lists of arrays.
    """
    def _to_unit_range(values):
        # Maps [min, max] linearly onto [-1, 1]; degenerate ranges become zeros.
        low, high = values.min(), values.max()
        if high - low > 1e-6:
            return 2 * (values - low) / (high - low) - 1
        return np.zeros_like(values)

    scaled_pairs = [
        (_to_unit_range(x), _to_unit_range(y))
        for x, y in zip(x_list, y_list)
    ]
    normalized_x = [scaled_x for scaled_x, _ in scaled_pairs]
    normalized_y = [scaled_y for _, scaled_y in scaled_pairs]
    return normalized_x, normalized_y


def train_basic(epochs, batch_size, point_count=10, progress=gr.Progress()):
    """Run a short supervised training pass on synthetic formulas.

    Every epoch draws one mini-batch (half from `generate_inverse_batch`,
    the rest from `generate_batch`), rescales X and Y with `normalize_batch`
    and minimises token-level cross-entropy on the model's logits. The
    second output of the model is ignored here.

    Args:
        epochs: Number of optimisation steps (one batch per epoch); cast to int.
        batch_size: Samples per batch; cast to int.
        point_count: Points per sample, forwarded to the data generator.
        progress: Gradio progress callback.

    Returns:
        `(html_or_message, figure_or_None)`. Failures are reported as an
        "Error: ..." message instead of being raised.
    """
    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True
    MODEL = None  # lets the finally-clause know whether a model was obtained

    try:
        MODEL, DEVICE = get_model()

        # UI widgets may deliver floats; normalise once instead of at every use.
        n_epochs = int(epochs)
        n_batch = int(batch_size)
        n_points = int(point_count)

        MODEL.train()
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=1e-4, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=1e-6)
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # the SOS id is the first id past the vocabulary

        data_gen = DataGenerator(max_depth=4)
        losses = []

        for epoch in range(n_epochs):
            progress((epoch + 1) / n_epochs, desc=f"Epoca {epoch+1}/{n_epochs} [{DEVICE.type.upper()}]")

            # Mix of inverse (known formulas) + random data (AlphaTensor-style)
            half_batch = n_batch // 2
            batch_inverse = data_gen.generate_inverse_batch(half_batch, point_count=n_points)
            batch_random = data_gen.generate_batch(n_batch - half_batch, point_count=n_points)
            batch = batch_inverse + batch_random
            if len(batch) < 2:
                continue

            x_list = [d['x'] for d in batch]
            y_list = [d['y'] for d in batch]

            # Normalize data
            x_list, y_list = normalize_batch(x_list, y_list)

            token_lists = [[TOKEN_TO_ID[t] for t in d['tokens']] for d in batch]

            # Teacher forcing: the decoder sees [SOS, t1..tn] and position k is
            # trained to emit t(k+1). Unused target slots stay at -1, which
            # ce_loss ignores; unused decoder slots stay at SOS_ID.
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)

            for i, seq in enumerate(token_lists):
                decoder_input[i, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)

            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)

            # Forward
            optimizer.zero_grad()
            logits, _ = MODEL(x_tensor, y_tensor, decoder_input)
            loss = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))

            # Skip the update entirely when the loss is not finite.
            if torch.isnan(loss) or torch.isinf(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())

        save_model()

        if not losses:
            return "Error: No se pudo calcular loss (revisar datos)", None

        fig = create_loss_plot(losses, "Entrenamiento Basico")

        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #4ade80;">
            <h2 style="color: #4ade80; margin: 0;">Entrenamiento Completado</h2>
            <p style="color: white;">Epocas: {n_epochs} | Loss Final: {losses[-1]:.4f}</p>
            <p style="color: #00d4ff;">Dispositivo: {DEVICE.type.upper()}</p>
        </div>
        """
        return result, fig

    except Exception as e:
        return f"Error: {str(e)}", None

    finally:
        # Runs on success, on handled errors and on BaseException alike, so the
        # UI can never stay locked in the "already training" state and the
        # shared model is never left in train mode after a failure.
        TRAINING_STATUS["running"] = False
        if MODEL is not None:
            MODEL.eval()


def _curriculum_schedule(progress_pct):
    """Map training progress in [0, 1) to `(max_depth, inverse_ratio)`.

    Stage 1 (0-50%):   depth 2-3, 80% inverse data
    Stage 2 (50-80%):  depth 3-4, 50% inverse data
    Stage 3 (80-100%): depth 4-5, 20% inverse data

    Within a stage the lower depth is used for the first half and the higher
    depth for the second half.
    """
    if progress_pct < 0.5:
        stage_start, stage_span, base_depth, inverse_ratio = 0.0, 0.5, 2, 0.8
    elif progress_pct < 0.8:
        stage_start, stage_span, base_depth, inverse_ratio = 0.5, 0.3, 3, 0.5
    else:
        stage_start, stage_span, base_depth, inverse_ratio = 0.8, 0.2, 4, 0.2

    in_second_half = (progress_pct - stage_start) / stage_span >= 0.5
    return base_depth + (1 if in_second_half else 0), inverse_ratio


def train_curriculum(epochs, batch_size, point_count=10, progress=gr.Progress()):
    """Curriculum Learning - starts simple, increases difficulty gradually.

    Each epoch builds a `DataGenerator` for the depth given by
    `_curriculum_schedule`, draws a batch mixing inverse and random samples
    in the scheduled ratio, and minimises cross-entropy on the logits plus
    0.5 * MSE between the model's second output and a target of 1.0.

    Args:
        epochs: Number of optimisation steps (one batch per epoch); cast to int.
        batch_size: Samples per batch.
        point_count: Points per sample, forwarded to the data generator.
        progress: Gradio progress callback.

    Returns:
        `(html_or_message, figure_or_None)`. Failures are reported as an
        "Error: ..." message instead of being raised.
    """
    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True
    MODEL = None  # lets the finally-clause know whether a model was obtained

    try:
        MODEL, DEVICE = get_model()

        n_epochs = int(epochs)
        n_points = int(point_count)

        MODEL.train()
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)  # Lower LR
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # the SOS id is the first id past the vocabulary
        losses = []
        max_depth_used = 0  # reported in the summary instead of a fixed number

        for epoch in range(n_epochs):
            current_depth, inverse_ratio = _curriculum_schedule(epoch / n_epochs)
            max_depth_used = max(max_depth_used, current_depth)

            progress((epoch + 1) / n_epochs, desc=f"Epoca {epoch+1}/{n_epochs} (prof: {current_depth}, inv: {inverse_ratio:.0%}) [{DEVICE.type.upper()}]")

            data_gen = DataGenerator(max_depth=current_depth)

            # Mix inverse + random based on curriculum stage
            n_inverse = int(batch_size * inverse_ratio)
            n_random = int(batch_size) - n_inverse

            batch_inverse = data_gen.generate_inverse_batch(n_inverse, point_count=n_points) if n_inverse > 0 else []
            batch_random = data_gen.generate_batch(n_random, point_count=n_points) if n_random > 0 else []
            batch = batch_inverse + batch_random
            if len(batch) < 2:
                continue

            x_list = [d['x'] for d in batch]
            y_list = [d['y'] for d in batch]
            x_list, y_list = normalize_batch(x_list, y_list)

            token_lists = [[TOKEN_TO_ID[t] for t in d['tokens']] for d in batch]

            # Teacher forcing: the decoder sees [SOS, t1..tn] and position k is
            # trained to emit t(k+1). Unused target slots stay at -1, which
            # ce_loss ignores; unused decoder slots stay at SOS_ID.
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)

            for i, seq in enumerate(token_lists):
                decoder_input[i, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)

            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            logits, value_pred = MODEL(x_tensor, y_tensor, decoder_input)

            # Policy Loss
            loss_policy = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))

            # Value Loss
            # For supervised learning, these are "perfect" solutions, so Value Target = 1.0
            value_targets = torch.ones_like(value_pred)
            loss_value = torch.nn.functional.mse_loss(value_pred, value_targets)

            # Combined Loss
            loss = loss_policy + 0.5 * loss_value

            # Skip the update entirely when the loss is not finite.
            if torch.isnan(loss) or torch.isinf(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())

        save_model()

        if not losses:
            return "Error: No se pudo calcular loss", None

        fig = create_loss_plot(losses, "Curriculum Learning")

        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #00d4ff;">
            <h2 style="color: #00d4ff; margin: 0;">Curriculum Learning Completado</h2>
            <p style="color: white;">Epocas: {n_epochs} | Loss Final: {losses[-1]:.4f}</p>
            <p style="color: #888;">Profundidad maxima: {max_depth_used} | Dispositivo: {DEVICE.type.upper()}</p>
        </div>
        """
        return result, fig

    except Exception as e:
        return f"Error: {str(e)}", None

    finally:
        # Runs on success, on handled errors and on BaseException alike, so the
        # UI can never stay locked in the "already training" state and the
        # shared model is never left in train mode after a failure.
        TRAINING_STATUS["running"] = False
        if MODEL is not None:
            MODEL.eval()


def train_self_play(iterations, problems_per_iter, point_count=10, progress=gr.Progress()):
    """AlphaZero Self-Play loop.

    Every iteration performs four phases:

    1. Hard mining: generate `3 * problems_per_iter` inverse problems, score
       each one by the model's mean per-token cross-entropy, and keep the
       hardest 70% plus a random 30% of the remainder.
    2. "Hall of shame": greedy-decode the three hardest problems and record
       target vs. prediction through `add_training_error` (best effort; any
       error here is printed and ignored).
    3. MCTS: search each selected problem, push the returned training
       examples into a replay buffer and record the result's RMSE.
    4. Training: once the buffer holds at least 64 examples, run several
       optimisation steps on random buffer samples (KL divergence on the
       next-token policy + quantile loss on the value output).

    The curriculum advances one level when more than 20 results have been
    recorded on the current level and the mean of the last 20 RMSEs is
    below 0.1.

    Args:
        iterations: Number of self-play iterations; cast to int.
        problems_per_iter: Problems searched per iteration; cast to int.
        point_count: Points per problem, forwarded to the data generator.
        progress: Gradio progress callback.

    Returns:
        `(html_or_message, figure_or_None)`. Failures are reported as an
        "Error: ..." message instead of being raised.
    """
    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None

    TRAINING_STATUS["running"] = True
    reset_stop_flag()  # Reset stop flag at start
    MODEL = None  # lets the finally-clause know whether a model was obtained

    try:
        MODEL, DEVICE = get_model()

        from search.mcts import MCTS

        # UI widgets may deliver floats; sizes passed to the generators and
        # used for slicing must be integers.
        n_iterations = int(iterations)
        n_problems = int(problems_per_iter)
        n_points = int(point_count)

        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)
        # Scheduler: Reduce LR when plateauing to help convergence
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15, min_lr=1e-6)

        # Losses for AlphaZero
        # Policy: KLDiv (comparing distributions)
        # Value: Quantile Loss
        kl_loss = torch.nn.KLDivLoss(reduction='batchmean')
        quantile_loss_fn = QuantileLoss()
        # Unreduced cross-entropy, used only to rank candidate problems.
        per_token_ce = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')

        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # the SOS id is the first id past the vocabulary

        replay_buffer = deque(maxlen=20000)

        # Adaptive curriculum. 'ops': None means "all operators".
        CURRICULUM_LEVELS = [
            {'depth': 1, 'ops': ['+', '-', '*', '/']},
            {'depth': 2, 'ops': ['+', '-', '*', '/']},
            {'depth': 3, 'ops': ['+', '-', '*', '/', 'pow', 'sqrt']},
            {'depth': 4, 'ops': ['+', '-', '*', '/', 'pow', 'sqrt', 'sin', 'cos']},
            {'depth': 5, 'ops': None}  # All
        ]
        STAGE_NAMES = ["Arithmetic", "Polynomials", "Trigonometry", "Advanced", "Complex"]

        curriculum_stage = 0
        # Index into `rmses` where the current stage's results begin; used so
        # that graduation is never decided on results from an earlier stage.
        stage_first_result = 0
        stage_info = CURRICULUM_LEVELS[curriculum_stage]
        data_gen = DataGenerator(max_depth=stage_info['depth'], allowed_operators=stage_info['ops'])

        searcher = MCTS(MODEL, DEVICE, max_simulations=500, complexity_lambda=0.1, batch_size=64)

        rmses = []
        losses = []

        start_time = time.time()

        for iteration in range(n_iterations):
            # Check for stop request
            if should_stop_training():
                print("⏹️ Training stopped by user")
                break

            # ETA Calculation
            elapsed = time.time() - start_time
            if iteration > 0:
                avg_time_per_iter = elapsed / iteration
                remaining_iters = n_iterations - iteration
                eta_seconds = remaining_iters * avg_time_per_iter

                # Format ETA
                if eta_seconds > 3600:
                    eta_str = f"{eta_seconds/3600:.1f}h"
                elif eta_seconds > 60:
                    eta_str = f"{eta_seconds/60:.0f}m"
                else:
                    eta_str = f"{eta_seconds:.0f}s"
            else:
                eta_str = "Calculando..."

            recent_rmse = np.mean(rmses[-20:]) if len(rmses) >= 20 else 1.0

            # Graduation condition: RMSE < 0.1 over the last 20 results, all of
            # which must have been produced on the current stage.
            results_on_stage = len(rmses) - stage_first_result
            if results_on_stage > 20 and recent_rmse < 0.1 and curriculum_stage < len(CURRICULUM_LEVELS) - 1:
                curriculum_stage += 1
                stage_first_result = len(rmses)
                stage_info = CURRICULUM_LEVELS[curriculum_stage]
                data_gen = DataGenerator(max_depth=stage_info['depth'], allowed_operators=stage_info['ops'])
                print(f"*** Curriculum Level Up! Stage {curriculum_stage} ({stage_info['depth']}, {stage_info['ops']}) ***")

            stage_name = STAGE_NAMES[curriculum_stage]

            curr_lr_disp = optimizer.param_groups[0]['lr']
            msg = f"Iter {iteration+1}/{n_iterations} [{stage_name}] RMSE:{recent_rmse:.3f} LR:{curr_lr_disp:.1e} | ETA: {eta_str}"
            progress((iteration + 1) / n_iterations, desc=msg)

            # Active Learning / Hard Mining Phase
            MODEL.eval()

            # Generate a larger pool of candidates to find the "hard" ones
            pool_size = n_problems * 3  # Generate 3x more than we need
            candidates = data_gen.generate_inverse_batch(pool_size, point_count=n_points)

            if not candidates:
                continue

            # Quick forward pass to estimate difficulty (Loss)
            # We want to train on problems where the model currently FAILS (High Loss)
            hard_problems = []

            with torch.no_grad():
                # Process in chunks to avoid OOM
                chunk_size = 32
                for chunk_start in range(0, len(candidates), chunk_size):
                    chunk = candidates[chunk_start:chunk_start+chunk_size]

                    x_list = [d['x'] for d in chunk]
                    y_list = [d['y'] for d in chunk]
                    x_list, y_list = normalize_batch(x_list, y_list)

                    # Tokens missing from the vocabulary are mapped to 'C'.
                    token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in chunk]
                    max_len = max(len(s) for s in token_lists)

                    # Prepare tensors (same teacher-forcing layout as training)
                    dec_in = torch.full((len(chunk), max_len + 1), SOS_ID, dtype=torch.long).to(DEVICE)
                    targets = torch.full((len(chunk), max_len + 1), -1, dtype=torch.long).to(DEVICE)

                    for j, seq in enumerate(token_lists):
                        dec_in[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                        targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)

                    x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
                    y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)

                    logits, _ = MODEL(x_tensor, y_tensor, dec_in)

                    # Per-token loss, reshaped back to [batch, seq]
                    raw_losses = per_token_ce(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))
                    raw_losses = raw_losses.view(len(chunk), -1)

                    # Average loss per non-padded token
                    mask = (targets != -1)
                    sample_losses = (raw_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

                    for j, loss_val in enumerate(sample_losses):
                        # Store (Loss, Problem)
                        hard_problems.append((loss_val.item(), chunk[j]))

            # Sort by difficulty (Loss descending)
            hard_problems.sort(key=lambda scored: scored[0], reverse=True)

            # Stabilization: Mix Hardest (70%) + Random Examples (30%)
            # This prevents "Catastrophic Forgetting" of simpler patterns
            n_hard = int(n_problems * 0.7)
            n_random = n_problems - n_hard

            # Top K hardest
            selected_hard = [p[1] for p in hard_problems[:n_hard]]

            # Random selection from the rest of the pool (to keep variety)
            remaining_pool = [p[1] for p in hard_problems[n_hard:]]
            selected_random = random.sample(remaining_pool, min(n_random, len(remaining_pool))) if remaining_pool else []

            selected_problems = selected_hard + selected_random

            avg_pool_loss = np.mean([p[0] for p in hard_problems])
            top_loss = np.mean([p[0] for p in hard_problems[:n_hard]]) if n_hard > 0 else 0

            print(f"Active Learning: Pool Loss {avg_pool_loss:.3f} -> Selected Mix (Hard:{top_loss:.3f})")

            # --- HALL OF SHAME CAPTURE ---
            # Record what the model predicts for the top 3 hardest failures.
            # Purely diagnostic: every error in this section is logged and ignored.
            try:
                top_failures = hard_problems[:3]
                x_fail = [p[1]['x'].astype(np.float64) for p in top_failures]
                y_fail = [p[1]['y'].astype(np.float64) for p in top_failures]
                target_formulas = [p[1]['infix'] for p in top_failures]
                fail_losses = [p[0] for p in top_failures]

                # Beam search with width 1 (greedy) for speed, with curriculum mask
                from search.beam_search import BeamSearch
                bs = BeamSearch(MODEL, DEVICE, beam_width=1, max_length=20, curriculum_stage=curriculum_stage)

                for i in range(len(top_failures)):
                    try:
                        # return_partial lets us see incomplete output when decoding fails
                        res = bs.search(x_fail[i], y_fail[i], return_partial=True)
                        if not res:
                            pred_formula = "Search Empty (No Tokens)"
                        else:
                            pred_formula = res[0]['formula']

                        # Detect looping (e.g. "10 / / / / / /") and truncate it
                        if len(pred_formula) > 20:
                            if pred_formula.count('/') > 10 and pred_formula.endswith('/ .'):
                                pred_formula = pred_formula[:20] + " ... [Loop Detected]"
                            elif " / / / " in pred_formula:
                                pred_formula = pred_formula.split(" / / / ")[0] + " ... [Loop Detected]"

                        add_training_error(
                            target=target_formulas[i],
                            predicted=pred_formula,
                            loss=fail_losses[i],
                            stage=stage_name
                        )
                    except Exception as e:
                        print(f"HoS Inner Error: {e}")
                        add_training_error(
                            target=target_formulas[i],
                            predicted=f"CRASH: {str(e)[:20]}",
                            loss=fail_losses[i],
                            stage=stage_name
                        )
            except Exception as e:
                import traceback
                print(f"HoS Outer Error: {e}")
                traceback.print_exc()

            # --- MCTS SOLVE ---
            for prob in selected_problems:
                x_data = prob['x'].astype(np.float64)
                y_data = prob['y'].astype(np.float64)

                try:
                    # The search result supplies both the policy/value training
                    # examples and the RMSE used for curriculum tracking.
                    result = searcher.search(x_data, y_data)

                    # 1. Store Training Examples
                    if 'root' in result:
                        examples = searcher.get_training_examples(result['root'])
                        for (tokens, policy, value) in examples:
                            replay_buffer.append({
                                'x': x_data, 'y': y_data,
                                'tokens': tokens,
                                'policy': policy,
                                'value': value
                            })

                    # 2. Track Metrics
                    if result.get('tokens'):
                        rmses.append(result['rmse'])

                except Exception as e:
                    print(f"Self-play error: {e}")
                    continue

            # Training phase
            if len(replay_buffer) >= 64:
                MODEL.train()

                train_batch_size = 128
                if len(replay_buffer) < train_batch_size:
                    train_batch_size = 64

                # Between 10 and 50 steps, growing with the buffer size
                steps = max(10, min(50, len(replay_buffer) // train_batch_size))

                # The buffer does not change during the steps below, so one
                # list copy serves every random.sample call.
                buffer_snapshot = list(replay_buffer)

                for _ in range(steps):
                    batch = random.sample(buffer_snapshot, min(train_batch_size, len(buffer_snapshot)))

                    x_list = [exp['x'] for exp in batch]
                    y_list = [exp['y'] for exp in batch]
                    x_list, y_list = normalize_batch(x_list, y_list)

                    token_lists = [[TOKEN_TO_ID[t] for t in exp['tokens']] for exp in batch]
                    policy_targets = [exp['policy'] for exp in batch]
                    value_targets_list = [exp['value'] for exp in batch]

                    max_len = max(len(s) for s in token_lists)
                    decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)

                    # Policy targets (for KLDiv) and Value targets
                    policy_target_tensor = torch.tensor(np.array(policy_targets), dtype=torch.float32).to(DEVICE)
                    value_target_tensor = torch.tensor(np.array(value_targets_list), dtype=torch.float32).unsqueeze(1).to(DEVICE)

                    for row, seq in enumerate(token_lists):
                        decoder_input[row, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)

                    x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
                    y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
                    decoder_input = decoder_input.to(DEVICE)

                    optimizer.zero_grad()
                    logits, value_pred = MODEL(x_tensor, y_tensor, decoder_input)

                    # Policy Loss (KL Divergence)
                    # For a state [t1..tn] the decoder input is [SOS, t1..tn], so
                    # the logits at position n (== len(seq)) are the prediction
                    # for the next token. The SOS logit is dropped.
                    last_logits = torch.stack([
                        logits[row, len(seq), :VOCAB_SIZE]
                        for row, seq in enumerate(token_lists)
                    ])
                    log_probs = torch.nn.functional.log_softmax(last_logits, dim=1)

                    loss_policy = kl_loss(log_probs, policy_target_tensor)

                    # Value Loss (Quantile)
                    loss_value = quantile_loss_fn(value_pred, value_target_tensor)

                    # Total Loss
                    loss = loss_policy + loss_value

                    # Non-finite losses are skipped without an optimizer step.
                    if not (torch.isnan(loss) or torch.isinf(loss)):
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                        optimizer.step()
                        losses.append(loss.item())

            # Step Scheduler based on recent Loss
            if losses:
                current_loss = np.mean(losses[-10:])
                scheduler.step(current_loss)

            # Periodic save
            if (iteration + 1) % 10 == 0:
                save_model()

        save_model()

        fig = create_selfplay_plot(losses, rmses)

        avg_rmse = np.mean(rmses[-50:]) if rmses else 0
        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #ff6b6b;">
            <h2 style="color: #ff6b6b; margin: 0;">Self-Play Completado</h2>
            <p style="color: white;">Iteraciones: {n_iterations} | Problemas: {len(rmses)}</p>
            <p style="color: #888;">RMSE Promedio: {avg_rmse:.4f} | Dispositivo: {DEVICE.type.upper()}</p>
        </div>
        """
        return result, fig

    except Exception as e:
        return f"Error: {str(e)}", None

    finally:
        # Runs on success, on handled errors and on BaseException alike, so the
        # UI can never stay locked in the "already training" state and the
        # shared model is never left in train mode after a failure.
        TRAINING_STATUS["running"] = False
        if MODEL is not None:
            MODEL.eval()


def create_loss_plot(losses, title):
    """Draw one loss curve on the app's dark background and return the figure.

    `losses` is plotted against its index (x axis labelled 'Epoca') and
    `title` is used as the plot title.
    """
    background = '#1a1a2e'
    accent = '#00d4ff'
    text_color = 'white'

    fig, ax = plt.subplots(figsize=(8, 4), facecolor=background)
    ax.set_facecolor(background)
    ax.plot(losses, color=accent, linewidth=2)

    ax.set_title(title, color=text_color, fontweight='bold')
    ax.set_xlabel('Epoca', color=text_color)
    ax.set_ylabel('Loss', color=text_color)
    ax.tick_params(colors=text_color)
    ax.grid(True, alpha=0.2)

    # Frame the axes in the accent colour.
    for side in ax.spines:
        ax.spines[side].set_color(accent)

    plt.tight_layout()
    return fig


def create_selfplay_plot(losses, rmses):
    """Build the two-panel self-play figure and return it.

    Left panel: `losses` per optimisation step. Right panel: `rmses` per
    problem, drawn faintly, with a 10-sample moving average on top once more
    than 10 values exist. Either panel is left empty when its series is empty.
    """
    background = '#1a1a2e'
    cyan = '#00d4ff'
    red = '#ff6b6b'
    text_color = 'white'

    def _decorate(axis, xlabel, ylabel, heading):
        # Labels, title, ticks and grid shared by both panels.
        axis.set_xlabel(xlabel, color=text_color)
        axis.set_ylabel(ylabel, color=text_color)
        axis.set_title(heading, color=text_color, fontweight='bold')
        axis.tick_params(colors=text_color)
        axis.grid(True, alpha=0.2)

    fig, panels = plt.subplots(1, 2, figsize=(12, 4), facecolor=background)
    loss_ax, rmse_ax = panels

    loss_ax.set_facecolor(background)
    if losses:
        loss_ax.plot(losses, color=cyan, linewidth=2)
    _decorate(loss_ax, 'Step', 'Loss', 'Policy Loss')

    rmse_ax.set_facecolor(background)
    if rmses:
        rmse_ax.plot(rmses, color=red, linewidth=1, alpha=0.5)
        if len(rmses) > 10:
            # 'valid' convolution drops the first 9 positions, hence range(9, ...).
            moving_avg = np.convolve(rmses, np.ones(10)/10, mode='valid')
            rmse_ax.plot(range(9, len(rmses)), moving_avg, color=red, linewidth=2)
    _decorate(rmse_ax, 'Problema', 'RMSE', 'RMSE')

    for axis in (loss_ax, rmse_ax):
        for side in axis.spines:
            axis.spines[side].set_color(cyan)

    plt.tight_layout()
    return fig

def train_supervised(iterations, batch_size=128, point_count=10, progress=gr.Progress()):
    """
    Massive Supervised Pre-training (Warmup).
    Focus: Syntax, Basic Arithmetic, Overcoming "Collapse to Constant".
    Speed: High (No MCTS, just random generation + CrossEntropy).

    Args:
        iterations: number of training iterations (cast to int; UI sliders may pass floats).
        batch_size: number of random formulas generated per iteration.
        point_count: number of (x, y) points sampled per formula.
        progress: Gradio progress tracker updated once per iteration.

    Returns:
        (html_summary, loss_figure) on completion, or (message, None) when another
        training run is already active or an exception was raised.
    """
    global TRAINING_STATUS
    
    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None
    
    TRAINING_STATUS["running"] = True
    reset_stop_flag()  # Reset stop flag at start
    
    try:
        MODEL, DEVICE = get_model()
        
        MODEL.train()
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=1e-4, weight_decay=0.01)
        # Slower decay: T_max = iterations * 2 keeps LR higher for longer
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(iterations*2), eta_min=1e-6)
        ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
        
        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # start-of-sequence id sits right after the vocabulary ids
        
        # Start extremely simple (Depth 1: x+1, x*x, etc.)
        allowed_ops = OPERATOR_STAGES[0]
        data_gen = DataGenerator(max_depth=1, allowed_operators=allowed_ops) 
        allowed_mask = get_allowed_token_mask(0, VOCAB_SIZE, DEVICE) # Stage 0 mask
        losses = []
        
        start_time = time.time()
        
        for i in range(int(iterations)):
            # Check for stop request
            if should_stop_training():
                print("⏹️ Pre-training stopped by user")
                break
            # ETA from the average rate of the iterations completed so far
            elapsed = time.time() - start_time
            if i > 0:
                iter_per_sec = i / elapsed
                remaining = int(iterations) - i
                eta = remaining / iter_per_sec
                eta_str = f"{eta:.0f}s"
            else:
                eta_str = "..."
                
            current_lr = optimizer.param_groups[0]['lr']
            msg = f"Iter {i+1}/{int(iterations)} Loss:{np.mean(losses[-50:]) if losses else 0:.3f} LR:{current_lr:.1e} ETA:{eta_str}"
            progress((i + 1) / iterations, desc=msg)
            
            # Generate Random Batch (High Speed)
            batch = data_gen.generate_batch(int(batch_size), point_count=int(point_count))
            
            if not batch:
                continue
            
            x_list = [d['x'] for d in batch]
            y_list = [d['y'] for d in batch]
            x_list, y_list = normalize_batch(x_list, y_list)
            
            # Unknown tokens fall back to the generic constant token 'C'.
            token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in batch]
            
            # Teacher forcing: the decoder input is the target shifted right by one
            # (SOS first); target positions past the sequence stay at -1 and are
            # ignored by the loss.
            max_len = max(len(s) for s in token_lists)
            decoder_input = torch.full((len(batch), max_len + 1), SOS_ID, dtype=torch.long)
            targets = torch.full((len(batch), max_len + 1), -1, dtype=torch.long)
            
            for j, seq in enumerate(token_lists):
                decoder_input[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)
                
            x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
            y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
            decoder_input = decoder_input.to(DEVICE)
            targets = targets.to(DEVICE)
            
            optimizer.zero_grad()
            logits, _ = MODEL(x_tensor, y_tensor, decoder_input)
            
            # Apply Stage 0 mask to bridge Pre-training with Curriculum
            # Use a more stable value (-1e4 instead of -1e9) to avoid overflow
            logits = logits + (1 - allowed_mask.view(1, 1, -1)) * -1e4
            
            loss = ce_loss(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))
            
            # Skip the update entirely when the loss is not finite so the weights
            # are never poisoned by NaN/Inf gradients.
            if not (torch.isnan(loss) or torch.isinf(loss)):
                loss.backward()
                torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                losses.append(loss.item())
                
            if (i+1) % 100 == 0:
                save_model()
                
        save_model()
        MODEL.eval()
        TRAINING_STATUS["running"] = False
        
        fig = create_loss_plot(losses, "Pre-Entrenamiento Supervisado")
        
        # No step may have been recorded (zero iterations, immediate stop, or only
        # non-finite losses); indexing losses[-1] would then raise IndexError and
        # turn a clean run into an error message.
        final_loss_str = f"{losses[-1]:.4f}" if losses else "N/A"
        
        result = f"""
        <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); padding: 20px; border-radius: 15px; border: 2px solid #ffd93d;">
            <h2 style="color: #ffd93d; margin: 0;">Escuela Primaria (Warmup) Completada</h2>
            <p style="color: white;">Iteraciones: {int(iterations)} | Loss Final: {final_loss_str}</p>
            <p style="color: #888;">El modelo ha aprendido sintaxis basica.</p>
        </div>
        """
        return result, fig
        
    except Exception as e:
        TRAINING_STATUS["running"] = False
        return f"Error: {str(e)}", None


def train_hybrid_feedback_loop(iterations, problems_per_iter=10, gp_timeout=10, progress=gr.Progress()):
    """
    Teacher-Student Distillation Loop.
    1. Find problems where model has high loss.
    2. Use Hybrid Search (GP) to solve them.
    3. Train model on GP solutions.

    Args:
        iterations: number of outer loop iterations (cast to int).
        problems_per_iter: maximum number of hard problems handed to the teacher
            per iteration.
        gp_timeout: timeout forwarded to `hybrid_solve` for each problem.
        progress: Gradio progress tracker updated once per iteration.

    Returns:
        (html_summary, loss_figure) on completion, or (message, None) when another
        training run is already active or an exception was raised.
    """
    global TRAINING_STATUS
    
    if TRAINING_STATUS["running"]:
        return "Entrenamiento ya en progreso", None
    
    TRAINING_STATUS["running"] = True
    reset_stop_flag()
    
    try:
        MODEL, DEVICE = get_model()
        
        optimizer = torch.optim.AdamW(MODEL.parameters(), lr=5e-5, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
        
        VOCAB_SIZE = len(VOCABULARY)
        SOS_ID = VOCAB_SIZE  # start-of-sequence id sits right after the vocabulary ids
        
        # Replay buffer for "Gold Standard" examples found by GP
        replay_buffer = deque(maxlen=5000)
        
        # Start with simple problems and grow
        data_gen = DataGenerator(max_depth=3)
        
        losses = []
        gp_successes = 0
        gp_attempts = 0
        
        # Per-sample loss above which a candidate counts as "hard" for the model.
        hard_loss_threshold = 0.5
        
        start_time = time.time()
        
        for iteration in range(int(iterations)):
            if should_stop_training():
                print("⏹️ Feedback Loop stopped")
                break
                
            # ETA from the average duration of the iterations completed so far.
            elapsed = time.time() - start_time
            if iteration > 0:
                eta_seconds = (int(iterations) - iteration) * (elapsed / iteration)
                eta_str = f"{eta_seconds:.0f}s"
            else:
                eta_str = "..."

            progress((iteration + 1) / iterations, 
                     desc=f"Iter {iteration+1}/{int(iterations)} | GP Success: {gp_successes}/{gp_attempts} | Loss: {np.mean(losses[-10:]) if losses else 0:.3f} | ETA: {eta_str}")
            
            # --- PHASE 1: HARD MINING ---
            MODEL.eval()
            
            # Generate candidates
            pool_size = 50 
            candidates = data_gen.generate_inverse_batch(pool_size, point_count=10)
            
            # The generator may return fewer items than requested (or none); size
            # every tensor from what was actually produced.
            if not candidates:
                continue
            n_candidates = len(candidates)
            
            scored_candidates = []
            
            with torch.no_grad():
                # We want to find problems with HIGH LOSS (model failure)
                # Quick batch forward
                x_list = [d['x'] for d in candidates]
                y_list = [d['y'] for d in candidates]
                x_list, y_list = normalize_batch(x_list, y_list)
                
                token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in candidates]
                max_len = max(len(s) for s in token_lists)
                
                dec_in = torch.full((n_candidates, max_len + 1), SOS_ID, dtype=torch.long).to(DEVICE)
                targets = torch.full((n_candidates, max_len + 1), -1, dtype=torch.long).to(DEVICE)
                
                for j, seq in enumerate(token_lists):
                    dec_in[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                    targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)
                    
                x_tensor = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
                y_tensor = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
                
                logits, value_pred = MODEL(x_tensor, y_tensor, dec_in)
                
                loss_f = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
                raw_losses = loss_f(logits.view(-1, VOCAB_SIZE + 1), targets.view(-1))
                raw_losses = raw_losses.view(n_candidates, -1)
                
                # Mean loss per sample over its real (non-padded) positions.
                mask = (targets != -1)
                sample_losses = (raw_losses * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)
                
                # Keep only candidates the model is "confused" about.
                for j, loss_val in enumerate(sample_losses.tolist()):
                    if loss_val > hard_loss_threshold:
                        scored_candidates.append((loss_val, j))
            
            # Take top K hardest: rank by loss (descending) before truncating, so the
            # limited GP budget goes to the problems the model fails on the most.
            scored_candidates.sort(key=lambda pair: pair[0], reverse=True)
            hard_problems = [candidates[j] for _, j in scored_candidates]
            # Limit GP calls per iter to avoid slowness
            problems_to_solve = hard_problems[:int(problems_per_iter)]
            
            if not problems_to_solve:
                continue

            # --- PHASE 2: TEACHER SOLVES (GP) ---
            print(f"Iter {iteration}: Asking Teacher to solve {len(problems_to_solve)} hard problems...")
            
            for prob in problems_to_solve:
                gp_attempts += 1
                try:
                    # Run Hybrid Search (Quick Mode)
                    # We pass the model so beam search can seed the GP
                    res = hybrid_solve(
                        prob['x'], 
                        prob['y'], 
                        MODEL, 
                        DEVICE, 
                        beam_width=10,     # Faster beam
                        gp_timeout=gp_timeout,
                        gp_binary_path=None 
                    )
                    
                    if res and res.get('formula') and res.get('rmse', 1e6) < 0.01:
                        # SUCCESS!
                        gp_successes += 1
                        
                        # Parse formula to tokens
                        try:
                            # 1. Parse string to tree
                            tree = ExpressionTree.from_infix(res['formula'])
                            # 2. Get tokens
                            tokens = tree.tokens
                            
                            # An empty sequence gives the student nothing to learn
                            # from, so reject it like any other parse failure.
                            if len(tokens) == 0:
                                raise ValueError("empty token sequence")
                            
                            replay_buffer.append({
                                'x': prob['x'],
                                'y': prob['y'],
                                'tokens': tokens,
                                'source': 'GP_Teacher'
                            })
                            
                        except Exception as e:
                            print(f"Failed to tokenize GP result: {e}")
                            
                except Exception as e:
                    print(f"GP Hybrid Error: {e}")
                    
            # --- PHASE 3: STUDENT TRAINS (NN) ---
            if len(replay_buffer) > 10:
                MODEL.train()
                # Train on batch from buffer
                batch_size_train = min(len(replay_buffer), 64)
                
                # Multiple steps to enforce learning
                steps = 5
                policy_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
                stepped = False
                
                for _ in range(steps):
                    batch = random.sample(list(replay_buffer), batch_size_train)
                    
                    x_list = [d['x'] for d in batch]
                    y_list = [d['y'] for d in batch]
                    x_list, y_list = normalize_batch(x_list, y_list)
                    
                    token_lists = [[TOKEN_TO_ID.get(t, TOKEN_TO_ID['C']) for t in d['tokens']] for d in batch]
                    max_len = max(len(s) for s in token_lists)
                    
                    dec_in = torch.full((batch_size_train, max_len + 1), SOS_ID, dtype=torch.long).to(DEVICE)
                    targets = torch.full((batch_size_train, max_len + 1), -1, dtype=torch.long).to(DEVICE)
                    
                    for j, seq in enumerate(token_lists):
                        dec_in[j, 1:len(seq)+1] = torch.tensor(seq, dtype=torch.long)
                        targets[j, :len(seq)] = torch.tensor(seq, dtype=torch.long)
                        
                    x_t = torch.tensor(np.array(x_list), dtype=torch.float32).to(DEVICE)
                    y_t = torch.tensor(np.array(y_list), dtype=torch.float32).to(DEVICE)
                    
                    optimizer.zero_grad()
                    logits, value_pred = MODEL(x_t, y_t, dec_in)
                    
                    # Policy Loss (Standard Supervised)
                    # We trust the GP solution is "Correct" (Value=1.0)
                    loss_ce = policy_loss_fn(logits.view(-1, VOCAB_SIZE+1), targets.view(-1))
                    
                    # Value Loss
                    value_targets = torch.ones_like(value_pred) # GP solutions are always valid
                    loss_val = torch.nn.functional.mse_loss(value_pred, value_targets)
                    
                    loss = loss_ce + 0.1 * loss_val
                    
                    # Same guard as the supervised warmup: never backpropagate a
                    # non-finite loss into the weights.
                    if torch.isnan(loss) or torch.isinf(loss):
                        continue
                    
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                    optimizer.step()
                    
                    losses.append(loss.item())
                    stepped = True
                    
                # Only feed the plateau scheduler when this iteration produced a loss.
                if stepped:
                    scheduler.step(np.mean(losses[-10:]))
                
            if (iteration + 1) % 5 == 0:
                save_model()
                
        save_model()
        MODEL.eval()
        TRAINING_STATUS["running"] = False
        
        fig = create_loss_plot(losses, "Feedback Loop Loss")
        
        result_html = f"""
        <div style="background: linear-gradient(135deg, #2c3e50 0%, #000000 100%); padding: 20px; border-radius: 15px; border: 2px solid #f1c40f;">
            <h2 style="color: #f1c40f; margin: 0;">Feedback Loop Completado</h2>
            <p style="color: white;">Iteraciones: {int(iterations)} | GP Success: {gp_successes}/{gp_attempts}</p>
            <p style="color: #bbb;">Nuevos Ejemplos Generados: {len(replay_buffer)}</p>
        </div>
        """
        return result_html, fig

    except Exception as e:
        TRAINING_STATUS["running"] = False
        import traceback
        traceback.print_exc()
        return f"Error CRITICO: {str(e)}", None


In [None]:
%%writefile AlphaSymbolic/ui/app_benchmark.py
import gradio as gr
from utils.benchmark_comparison import run_comparison_benchmark
from ui.app_core import get_model, DEVICE

def get_benchmark_tab():
    """Build the "Benchmark" Gradio tab and wire its run button.

    The tab lets the user pick search methods and a GP timeout, runs
    `run_comparison_benchmark` on click, and shows an HTML summary card per
    method plus a table with one row per (problem, method) result.
    Must be called inside an active Gradio Blocks context.
    """
    with gr.Tab("🥇 Benchmark (IQ Test)"):
        gr.Markdown("### Evaluar Inteligencia del Modelo (Comparativa)")
        gr.Markdown("Ejecuta una batería de **10 problemas estándar** comparando diferentes métodos de búsqueda.")
        
        with gr.Row():
            methods_chk = gr.CheckboxGroup(
                choices=["beam", "mcts", "hybrid"], 
                value=["hybrid"], 
                label="Métodos a Evaluar",
                info="Selecciona uno o más métodos para comparar."
            )
            timeout_slider = gr.Slider(
                minimum=5, 
                maximum=60, 
                value=30, 
                step=5, 
                label="Timeout GP (s)", 
                info="Tiempo máximo para Beta-GP por problema."
            )
        
        run_btn = gr.Button("🚀 Iniciar Benchmark Comparativo", variant="primary")
        
        progress_bar = gr.HTML("")
        
        # Results area
        summary_html = gr.HTML("Resultados aparecerán aquí...")
        
        results_df = gr.Dataframe(
            headers=["Problema", "Nivel", "Método", "Formula", "RMSE", "Tiempo", "Estado"],
            label="Resultados Detallados",
            interactive=False
        )
        
        def run_bench(selected_methods, gp_timeout, progress=gr.Progress()):
            """Click handler: run the comparison and return (summary_html, table_rows).

            Every return path yields exactly two values because the click event
            below declares two output components; returning a different number
            makes Gradio raise instead of displaying the message.
            """
            model_obj, device_obj = get_model()
            if not model_obj:
                return "<div>⚠️ Error: Modelo no cargado. Ve a la pestaña 'Config' y carga un modelo.</div>", None
            
            if not selected_methods:
                return "<div>⚠️ Error: Selecciona al menos un método.</div>", None
                
            progress(0, desc="Iniciando Benchmark...")
            
            # Run comparison
            try:
                result_data = run_comparison_benchmark(
                    model_obj, 
                    device_obj, 
                    methods=selected_methods,
                    gp_timeout=gp_timeout,
                    beam_width=50,
                    progress_callback=lambda p, desc: progress(p, desc=desc)
                )
            except Exception as e:
                import traceback
                traceback.print_exc()
                return f"<div>❌ Error en Benchmark: {e}</div>", None
            
            results = result_data['results']
            summary_dict = result_data['summary']
            
            # Format dataframe
            rows = []
            for r in results:
                status_icon = "✅" if r['success'] else "❌"
                # Huge RMSE values are sentinel failures; show them as a bound.
                rmse_val = f"{r['rmse']:.5f}" if r['rmse'] < 1e6 else "> 10^6"
                rows.append([
                    r['problem_name'],
                    r['level'],
                    r['method'].upper(),
                    r['formula'],
                    rmse_val,
                    f"{r['time']:.2f}s",
                    status_icon
                ])
            
            # Generate HTML Summary
            html_content = "<div style='display: flex; gap: 20px; flex-wrap: wrap; justify-content: center;'>"
            
            # Determine winner if multiple methods: most problems solved, ties
            # broken by the lower average RMSE.
            winner_method = None
            if len(selected_methods) > 1:
                winner_method = max(summary_dict.items(), key=lambda x: (x[1]['solved'], -x[1]['avg_rmse']))[0]
            
            for method, stats in summary_dict.items():
                is_winner = (method == winner_method)
                border_color = "#4CAF50" if is_winner else ("#FF9800" if stats['score'] > 50 else "#F44336")
                bg_color = "#1e1e2f"
                if is_winner:
                    bg_color = "#1b3a24" # Dark green tint for winner
                    
                trophy = "🏆 GANADOR" if is_winner else ""
                
                html_content += f"""
                <div style="background: {bg_color}; padding: 15px; border-radius: 10px; border: 2px solid {border_color}; min-width: 200px; text-align: center;">
                    <h2 style="color: {border_color}; margin: 0 0 10px 0;">{method.upper()} {trophy}</h2>
                    <div style="font-size: 24px; font-weight: bold; margin-bottom: 5px;">{stats['solved']} / {stats['total']}</div>
                    <div style="color: #ccc; font-size: 14px;">Resueltos</div>
                    <hr style="border-color: #444; margin: 10px 0;">
                    <div style="font-size: 14px;">Nota: <b>{stats['score']:.1f}%</b></div>
                    <div style="font-size: 14px;">Tiempo Avg: <b>{stats['avg_time']:.2f}s</b></div>
                </div>
                """
            html_content += "</div>"
            
            return html_content, rows
            
        run_btn.click(run_bench, inputs=[methods_chk, timeout_slider], outputs=[summary_html, results_df])


In [None]:
%%writefile AlphaSymbolic/ui/__init__.py


In [None]:
%%writefile AlphaSymbolic/utils/optimize_constants.py
"""
Constant Optimization Module for AlphaSymbolic.
Uses scipy.optimize to find optimal values for 'C' placeholders.
"""
import numpy as np
from scipy.optimize import minimize
from core.grammar import ExpressionTree

def optimize_constants(tree, x_data, y_data, method='L-BFGS-B'):
    """
    Given an ExpressionTree with 'C' placeholders, find optimal constant values.
    
    Args:
        tree: ExpressionTree object
        x_data: numpy array of x values
        y_data: numpy array of target y values
        method: optimization method ('L-BFGS-B', 'SLSQP', 'Nelder-Mead')
        
    Returns:
        dict: mapping of path tuples to optimized constant values
        float: final RMSE. `float('inf')` signals failure: invalid tree,
            non-finite predictions for a constant-free tree, or an exception
            raised during optimization.
    """
    if not tree.is_valid:
        return {}, float('inf')
    
    # Get positions of all constants
    positions = tree.root.get_constant_positions()
    n_constants = len(positions)
    
    if n_constants == 0:
        # No constants to optimize, just evaluate
        y_pred = tree.evaluate(x_data)
        # Report non-finite predictions with the same failure value as every
        # other error path instead of leaking a NaN RMSE to the caller.
        if not np.all(np.isfinite(y_pred)):
            return {}, float('inf')
        mse = np.mean((y_pred - y_data)**2)
        return {}, np.sqrt(mse)
    
    def objective(params):
        """Objective function: MSE given constant values (RMSE is taken at the end)."""
        # Build constants dict
        constants = {tuple(pos): params[i] for i, pos in enumerate(positions)}
        
        # Evaluate
        y_pred = tree.evaluate(x_data, constants=constants)
        
        # Penalize NaN/Inf predictions with a large constant cost
        if not np.all(np.isfinite(y_pred)):
            return 1e10
        
        # Clip huge values to prevent overflow in MSE
        y_pred = np.clip(y_pred, -1e9, 1e9)
        
        return np.mean((y_pred - y_data)**2)
    
    # Initial guess: all 1s
    x0 = np.ones(n_constants)
    
    # Bounds: reasonable range for constants
    bounds = [(-1000, 1000)] * n_constants
    
    try:
        result = minimize(
            objective,
            x0,
            method=method,
            # Bounds are only forwarded to the methods listed here.
            bounds=bounds if method in ['L-BFGS-B', 'SLSQP'] else None,
            options={'maxiter': 1000, 'disp': False}
        )
        
        # Build final constants dict
        optimized_constants = {tuple(pos): result.x[i] for i, pos in enumerate(positions)}
        # result.fun is an MSE; guard the sqrt against a non-positive value.
        final_rmse = np.sqrt(result.fun) if result.fun > 0 else 0.0
        
        return optimized_constants, final_rmse
        
    except Exception:
        # Best effort: any failure during optimization is reported as "no fit".
        return {}, float('inf')

def substitute_constants(infix_str, constants_dict, positions):
    """
    Replace 'C' placeholders in the infix string with optimized values.

    The i-th entry of `positions` is paired with the i-th 'C' found when
    scanning the string left to right. A placeholder whose position has no
    entry in `constants_dict` is left as 'C' and skipped, so later values
    still land on their own placeholders.

    Args:
        infix_str: formula text containing 'C' placeholders.
        constants_dict: mapping of position tuples to numeric values.
        positions: constant positions, in the same order as the placeholders.

    Returns:
        The formula with values filled in. Values within 1e-6 of an integer
        are written as integers; all others with four decimals.
    """
    result = infix_str
    cursor = 0  # index from which the next placeholder is searched
    for pos in positions:
        idx = result.find('C', cursor)
        if idx == -1:
            # More positions than placeholders: nothing left to fill.
            break

        key = tuple(pos)
        if key not in constants_dict:
            # Leave this placeholder untouched and move on to the next one.
            cursor = idx + 1
            continue

        val = constants_dict[key]
        # Format nicely; round() cannot handle NaN/Inf, so those fall through
        # to the plain float formatting.
        if np.isfinite(val) and abs(val - round(val)) < 1e-6:
            val_str = str(int(round(val)))
        else:
            val_str = f"{val:.4f}"

        result = result[:idx] + val_str + result[idx + 1:]
        cursor = idx + len(val_str)
    return result


# Quick self-check when the module is executed directly
if __name__ == "__main__":
    # Fit the template C*x + C against samples of y = 2x + 3.
    sample_x = np.array([1, 2, 3, 4, 5], dtype=np.float64)
    sample_y = 2 * sample_x + 3

    demo_tree = ExpressionTree(['+', '*', 'C', 'x', 'C'])  # C*x + C

    print(f"Formula structure: {demo_tree.get_infix()}")
    print("Target: y = 2x + 3")

    fitted, fit_rmse = optimize_constants(demo_tree, sample_x, sample_y)
    print(f"Optimized constants: {fitted}")
    print(f"Final RMSE: {fit_rmse:.6f}")

    # Re-evaluate with the fitted constants and show predictions next to targets.
    predictions = demo_tree.evaluate(sample_x, constants=fitted)
    print(f"Predictions: {predictions}")
    print(f"Targets: {sample_y}")


In [None]:
%%writefile AlphaSymbolic/utils/detect_pattern.py
"""
Target Pattern Detection for AlphaSymbolic.
Analyzes target Y values to detect patterns (polynomial, exponential, periodic, etc.)
and suggests initial search biases.
"""
import numpy as np
from scipy import stats
from scipy.fft import fft
from core.grammar import ExpressionTree

def detect_pattern(x_values, y_values):
    """
    Analyze (x, y) data to detect patterns.

    Each candidate family (linear, quadratic, exponential, periodic, power,
    factorial) is scored; the highest finite score wins. A family whose fit
    fails or whose preconditions are not met simply contributes no score.

    Args:
        x_values: sequence of x samples (converted to float64).
        y_values: sequence of y samples (converted to float64).

    Returns:
        dict with keys:
            'type': name of the best-scoring pattern, or 'unknown'.
            'confidence': score of that pattern (0.0 when unknown).
            'suggested_ops': operator tokens suggested for that pattern.
            'details': per-pattern fit details for every pattern that was scored.
    """
    x = np.array(x_values, dtype=np.float64)
    y = np.array(y_values, dtype=np.float64)
    
    results = {
        'type': 'unknown',
        'confidence': 0.0,
        'suggested_ops': [],
        'details': {}
    }
    
    if len(x) < 3:
        return results
    
    scores = {}
    
    # 1. Check for linear pattern (y = ax + b)
    # Guarded like the other fits: linregress can raise on degenerate input
    # (e.g. all x identical), which must not abort the whole detection.
    try:
        slope, intercept, r_value, _, _ = stats.linregress(x, y)
        scores['linear'] = r_value ** 2
        results['details']['linear'] = {
            'slope': slope,
            'intercept': intercept,
            'r_squared': r_value ** 2
        }
    except Exception:
        pass
    
    # 2. Check for quadratic pattern (y = ax^2 + bx + c)
    try:
        coeffs = np.polyfit(x, y, 2)
        y_pred = np.polyval(coeffs, x)
        ss_res = np.sum((y - y_pred) ** 2)
        ss_tot = np.sum((y - np.mean(y)) ** 2)
        r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
        scores['quadratic'] = r2
        results['details']['quadratic'] = {
            'coefficients': coeffs.tolist(),
            'r_squared': r2
        }
    except Exception:
        pass
    
    # 3. Check for exponential pattern (y = a * e^(bx))
    if np.all(y > 0):  # Exponential only for positive y
        try:
            log_y = np.log(y)
            slope, intercept, r_value, _, _ = stats.linregress(x, log_y)
            scores['exponential'] = r_value ** 2
            results['details']['exponential'] = {
                'a': np.exp(intercept),
                'b': slope,
                'r_squared': r_value ** 2
            }
        except Exception:
            pass
    
    # 4. Check for periodic/sinusoidal pattern
    if len(y) >= 4:
        try:
            # Simple FFT analysis
            y_centered = y - np.mean(y)
            fft_vals = np.abs(fft(y_centered))
            
            # Positive-frequency bins, excluding the DC component.
            half_spectrum = fft_vals[1:len(fft_vals)//2]
            
            # With a single bin the dominance ratio below is trivially 1.0 and
            # would beat every real fit, so require at least two bins.
            if len(half_spectrum) >= 2:
                max_idx = int(np.argmax(half_spectrum)) + 1
                max_power = fft_vals[max_idx]
                total_power = np.sum(half_spectrum)
                
                if total_power > 0:
                    # Share of the spectrum held by the dominant frequency.
                    periodicity = max_power / total_power
                    scores['periodic'] = periodicity
                    results['details']['periodic'] = {
                        'dominant_freq_idx': max_idx,
                        'periodicity_score': periodicity
                    }
        except Exception:
            pass
    
    # 5. Check for power law (y = a * x^b)
    if np.all(x > 0) and np.all(y > 0):
        try:
            log_x = np.log(x)
            log_y = np.log(y)
            slope, intercept, r_value, _, _ = stats.linregress(log_x, log_y)
            scores['power'] = r_value ** 2
            results['details']['power'] = {
                'a': np.exp(intercept),
                'b': slope,
                'r_squared': r_value ** 2
            }
        except Exception:
            pass
    
    # 6. Check for factorial/gamma pattern (for integer-like x)
    if np.all(x > 0) and np.all(x == np.floor(x)):
        try:
            from scipy.special import gamma
            x_int = x.astype(int)
            y_gamma = gamma(x_int + 1)  # gamma(n+1) = n!
            
            # Simple linear fit between y and gamma
            if not np.any(np.isinf(y_gamma)):
                slope, intercept, r_value, _, _ = stats.linregress(y_gamma, y)
                scores['factorial'] = r_value ** 2
                results['details']['factorial'] = {
                    'r_squared': r_value ** 2
                }
        except Exception:
            pass
    
    # Determine best pattern. NaN scores (e.g. from NaN inputs) compare
    # unpredictably inside max(), so only finite scores take part.
    finite_scores = {name: score for name, score in scores.items() if np.isfinite(score)}
    if finite_scores:
        best_name, best_score = max(finite_scores.items(), key=lambda item: item[1])
        results['type'] = best_name
        results['confidence'] = best_score
        
        # Suggest operators based on pattern
        op_suggestions = {
            'linear': ['+', '-', '*', 'x', 'C'],
            'quadratic': ['pow', '+', '*', 'x', 'C', '2'],
            'exponential': ['exp', '*', '+', 'x', 'C'],
            'periodic': ['sin', 'cos', '*', '+', 'x', 'C'],
            'power': ['pow', '*', 'x', 'C'],
            'factorial': ['gamma', '*', '+', 'x', 'C']
        }
        results['suggested_ops'] = op_suggestions.get(best_name, [])
    
    return results


def summarize_pattern(result):
    """Print a human-readable report for a `detect_pattern` result dict.

    Shows the detected type with its confidence as a percentage, the
    suggested operators, and the fit details for the detected type when
    the result carries them.
    """
    print("\n=== Pattern Detection ===")

    pattern = result['type']
    print("Detected Type: {} (confidence: {:.2%})".format(pattern, result['confidence']))
    print("Suggested Operators: " + ', '.join(result['suggested_ops']))

    details = result['details']
    if pattern in details:
        print("Details: {}".format(details[pattern]))


if __name__ == "__main__":
    # Demo: run the detector on one synthetic dataset per pattern family.
    # Each case is (label, x grid, function producing y from x).
    demo_cases = [
        # Linear: y = 2x + 3 (with a little Gaussian noise)
        ("Linear", np.linspace(0, 10, 20), lambda xs: 2 * xs + 3 + np.random.normal(0, 0.1, 20)),
        # Quadratic: y = x^2 + 1
        ("Quadratic", np.linspace(-5, 5, 20), lambda xs: xs**2 + 1),
        # Exponential: y = 2 * e^(0.5x)
        ("Exponential", np.linspace(0, 5, 20), lambda xs: 2 * np.exp(0.5 * xs)),
        # Periodic: y = sin(x)
        ("Periodic", np.linspace(0, 4*np.pi, 50), lambda xs: np.sin(xs)),
    ]

    for label, grid, make_y in demo_cases:
        print(f"\n--- Test: {label} ---")
        summarize_pattern(detect_pattern(grid, make_y(grid)))


In [None]:
%%writefile AlphaSymbolic/utils/benchmark_runner.py
import torch
import numpy as np
import time
import traceback
from search.mcts import MCTS
from data.benchmark_data import BENCHMARK_SUITE, get_benchmark_data
from utils.optimize_constants import optimize_constants

def run_benchmark_suite(model, device, progress_callback=None):
    """
    Runs the full benchmark suite.
    Args:
        model: Loaded AlphaSymbolic model
        device: Torch device
        progress_callback: Function(float, string) to update UI
        
    Returns:
        results: List of result dicts (one per problem, including errored ones)
        summary: Dict with aggregated stats ('total', 'solved', 'score', 'avg_time')
    """
    results = []
    
    # Configure MCTS for benchmark (balanced speed/accuracy)
    # 500 simulations is decent for benchmarking
    mcts = MCTS(model, device, max_simulations=500, batch_size=32)
    
    total = len(BENCHMARK_SUITE)
    solved_count = 0
    
    for i, problem in enumerate(BENCHMARK_SUITE):
        if progress_callback:
            progress_callback(i / total, f"Testing: {problem['name']}...")
            
        x, y, _ = get_benchmark_data(problem['id'])
        
        start_time = time.time()
        
        # Run Search
        try:
            search_result = mcts.search(x, y)
            # Determine success. RMSE is the ground truth; the 0.05 threshold is
            # deliberately loose for general regression.
            rmse = search_result['rmse']
            is_solved = rmse < 0.05
            
            elapsed = time.time() - start_time
            
            if is_solved:
                solved_count += 1
                status = "✅ SOLVED"
            else:
                status = "❌ FAILED"
                
            results.append({
                'id': problem['id'],
                'name': problem['name'],
                'level': problem['level'],
                'rmse': rmse,
                'time': elapsed,
                'status': status,
                'found_formula': search_result.get('formula', '???'),
                'is_solved': is_solved
            })
            
        except Exception:
            print(f"Error in benchmark {problem['name']}:")
            traceback.print_exc()
            results.append({
                'id': problem['id'],
                'name': problem['name'],
                'level': problem['level'],
                'rmse': 1e9,
                # Record the time actually spent before the failure so errored
                # problems do not drag the average time towards zero.
                'time': time.time() - start_time,
                'status': "⚠️ ERROR",
                'found_formula': "Error",
                'is_solved': False
            })

    # Summary
    if progress_callback:
        progress_callback(1.0, "Done!")
        
    # An empty suite yields a 0% score instead of a ZeroDivisionError.
    score = (solved_count / total) * 100 if total else 0.0
    summary = {
        'total': total,
        'solved': solved_count,
        'score': score,
        'avg_time': np.mean([r['time'] for r in results]) if results else 0
    }
    
    return results, summary


In [None]:
%%writefile AlphaSymbolic/utils/benchmark_comparison.py
"""
Comparative Benchmark: Beam Search vs MCTS vs Alpha-GP Hybrid
Runs all three search methods on the standard benchmark suite and compares performance.
"""
import torch
import numpy as np
import time
import traceback
from typing import List, Dict, Callable, Optional

from search.mcts import MCTS
from search.beam_search import BeamSearch
from search.hybrid_search import hybrid_solve
from data.benchmark_data import BENCHMARK_SUITE, get_benchmark_data
from core.grammar import ExpressionTree
from utils.optimize_constants import optimize_constants


def run_single_problem(
    x: np.ndarray, 
    y: np.ndarray, 
    method: str, 
    model, 
    device,
    timeout_sec: int = 30,
    beam_width: int = 50,
    success_threshold: float = 0.05
) -> Dict:
    """
    Run one search method on one regression problem and score the outcome.

    Args:
        x: Input samples; converted with ``.tolist()`` before being handed to a searcher.
        y: Target values; same handling as ``x``.
        method: ``"beam"``, ``"mcts"`` or ``"hybrid"``. Any other value produces an
            ``'Unknown Method'`` record.
        model: Model object forwarded to the selected searcher.
        device: Device forwarded to the selected searcher.
        timeout_sec: Forwarded to ``hybrid_solve`` as ``gp_timeout``. The beam and
            MCTS branches do not use it.
        beam_width: Beam width for the beam and hybrid branches.
        success_threshold: A run counts as solved when its RMSE is strictly below
            this value.

    Returns:
        dict with keys ``formula``, ``rmse``, ``time`` (seconds) and ``success``.
        Any failure is reported with the sentinel RMSE ``1e9``. Exceptions raised
        inside a searcher are caught, printed with their traceback and turned into
        an ``'Error'`` record.
    """
    failed_rmse = 1e9  # sentinel used throughout this module for "no usable score"
    start_time = time.time()

    def _record(formula, rmse, elapsed):
        # Normalise the score so NaN/inf/non-numeric values never leak into the
        # result; they are reported as the failure sentinel instead.
        try:
            rmse = float(rmse)
        except (TypeError, ValueError):
            rmse = failed_rmse
        if not np.isfinite(rmse):
            rmse = failed_rmse
        return {
            'formula': formula,
            'rmse': rmse,
            'time': elapsed,
            'success': rmse < success_threshold
        }
    
    try:
        if method == "beam":
            searcher = BeamSearch(model, device, beam_width=beam_width)
            # BeamSearch takes list-like input; its result list is expected to be
            # ordered best-first (by RMSE), so only the first entry is scored.
            candidates = searcher.search(x.tolist(), y.tolist())
            elapsed = time.time() - start_time
            if not candidates:
                return _record('No Result', failed_rmse, elapsed)
            best = candidates[0]
            return _record(best.get('formula', 'N/A'), best.get('rmse', failed_rmse), elapsed)
            
        elif method == "mcts":
            mcts = MCTS(model, device, max_simulations=500, batch_size=32)
            # MCTS takes list-like input and returns a single result dict.
            result = mcts.search(x.tolist(), y.tolist())
            elapsed = time.time() - start_time
            return _record(result.get('formula', 'N/A'), result.get('rmse', failed_rmse), elapsed)
            
        elif method == "hybrid":
            result = hybrid_solve(
                model=model,
                device=device,
                x_values=x.tolist(),
                y_values=y.tolist(),
                beam_width=beam_width,
                gp_timeout=timeout_sec
            )
            # Timing covers the search only, not the re-scoring below.
            elapsed = time.time() - start_time
            
            formula = result['formula']
            rmse = failed_rmse
            if formula:
                # The hybrid result carries no RMSE that is used here, so the
                # formula is re-parsed and evaluated against the data.
                try:
                    tree = ExpressionTree.from_infix(formula)
                    if tree.is_valid:
                        preds = tree.evaluate(x)
                        rmse = np.sqrt(np.mean((preds - y) ** 2))
                except Exception:
                    # A formula that cannot be parsed or evaluated is an unsolved
                    # run, not a crash. Ctrl-C and SystemExit still propagate.
                    rmse = failed_rmse
                
            return _record(formula or 'Failed', rmse, elapsed)
        else:
            return {'formula': 'Unknown Method', 'rmse': 1e9, 'time': 0, 'success': False}
            
    except Exception as e:
        print(f"[ERROR] Method {method} failed: {e}")
        traceback.print_exc()
        return {'formula': 'Error', 'rmse': 1e9, 'time': time.time() - start_time, 'success': False}


def run_comparison_benchmark(
    model, 
    device, 
    methods: Optional[List[str]] = None,
    gp_timeout: int = 30,
    beam_width: int = 50,
    progress_callback: Optional[Callable] = None
) -> Dict:
    """
    Run every requested method on every problem in ``BENCHMARK_SUITE``.

    Args:
        model: Model forwarded to ``run_single_problem``.
        device: Device forwarded to ``run_single_problem``.
        methods: Method names to compare. ``None`` means
            ``["beam", "mcts", "hybrid"]``.
        gp_timeout: Forwarded to ``run_single_problem`` as its timeout.
        beam_width: Forwarded to ``run_single_problem``.
        progress_callback: Optional ``callback(fraction, message)`` invoked before
            each run and once more with ``1.0`` at the end.

    Returns:
        Dict with
        ``'results'``: one record per (problem, method) with keys ``problem_id``,
        ``problem_name``, ``level``, ``method``, ``formula``, ``rmse``, ``time``,
        ``success``; and
        ``'summary'``: per method ``solved``, ``total``, ``score`` (percent),
        ``avg_time`` and ``avg_rmse``. ``avg_rmse`` is the mean over the runs that
        produced a usable RMSE (below ``1e6``); it is the sentinel ``1e9`` when no
        run did.
    """
    invalid_rmse_cutoff = 1e6   # scores at or above this are failure sentinels
    failed_rmse = 1e9

    # A fresh list per call; the default is no longer a shared mutable object.
    methods = ["beam", "mcts", "hybrid"] if methods is None else list(methods)

    results = []
    method_stats = {
        m: {'solved': 0, 'total_time': 0.0, 'rmse_sum': 0.0, 'rmse_runs': 0}
        for m in methods
    }
    
    total_steps = len(BENCHMARK_SUITE) * len(methods)
    current_step = 0
    
    for problem in BENCHMARK_SUITE:
        # Third element of the tuple is not needed here.
        x, y, _ = get_benchmark_data(problem['id'])
        
        for method in methods:
            current_step += 1
            
            if progress_callback:
                progress_callback(
                    current_step / total_steps, 
                    f"[{method.upper()}] {problem['name']}..."
                )
            
            result = run_single_problem(x, y, method, model, device, gp_timeout, beam_width)
            
            results.append({
                'problem_id': problem['id'],
                'problem_name': problem['name'],
                'level': problem['level'],
                'method': method,
                'formula': result['formula'],
                'rmse': result['rmse'],
                'time': result['time'],
                'success': result['success']
            })
            
            # Update stats
            stats = method_stats[method]
            stats['total_time'] += result['time']
            # Only usable scores take part in the average. Counting a failed run
            # as RMSE 0 would make a method that fails everything look perfect.
            if result['rmse'] < invalid_rmse_cutoff:
                stats['rmse_sum'] += result['rmse']
                stats['rmse_runs'] += 1
            if result['success']:
                stats['solved'] += 1
    
    # Compute summary
    num_problems = len(BENCHMARK_SUITE)
    summary = {}
    for method in methods:
        stats = method_stats[method]
        scored_runs = stats['rmse_runs']
        summary[method] = {
            'solved': stats['solved'],
            'total': num_problems,
            # Guard the divisions so an empty suite yields zeros instead of raising.
            'score': (stats['solved'] / num_problems) * 100 if num_problems else 0.0,
            'avg_time': stats['total_time'] / num_problems if num_problems else 0.0,
            'avg_rmse': stats['rmse_sum'] / scored_runs if scored_runs else failed_rmse
        }
    
    if progress_callback:
        progress_callback(1.0, "Benchmark Complete!")
    
    return {'results': results, 'summary': summary}


def format_comparison_table(results: List[Dict]) -> str:
    """
    Render benchmark records as a fixed-width text table.

    Records are bucketed by ``problem_id`` in first-seen order; each bucket prints
    one row per method and is closed by a dashed rule. When the same
    (problem, method) pair appears more than once, the last record wins.

    Args:
        results: Records with keys ``problem_id``, ``problem_name``, ``level``,
            ``method``, ``rmse``, ``time``, ``success`` and ``formula``.

    Returns:
        The table as a single newline-joined string.
    """
    heavy_rule = "=" * 100
    light_rule = "-" * 100

    # Bucket the flat record list by problem id.
    by_problem = {}
    for row in results:
        key = row['problem_id']
        bucket = by_problem.get(key)
        if bucket is None:
            bucket = {'name': row['problem_name'], 'level': row['level'], 'methods': {}}
            by_problem[key] = bucket
        cell = {
            'rmse': row['rmse'],
            'time': row['time'],
            'success': row['success'],
            'formula': row['formula']
        }
        bucket['methods'][row['method']] = cell

    lines = [
        heavy_rule,
        f"{'Problem':<25} | {'Method':<8} | {'RMSE':<12} | {'Time':<8} | {'Status':<10} | Formula",
        heavy_rule,
    ]

    for pdata in by_problem.values():
        name = pdata['name'][:24]  # keep the first column inside its 25-char slot
        for method, mdata in pdata['methods'].items():
            # Sentinel-sized scores are shown as a word rather than a huge number.
            if mdata['rmse'] < 1e6:
                rmse_str = f"{mdata['rmse']:.6f}"
            else:
                rmse_str = "FAILED"
            time_str = f"{mdata['time']:.2f}s"
            status = "[FAIL]"
            if mdata['success']:
                status = "[OK]"
            formula = "N/A"
            if mdata['formula']:
                formula = mdata['formula'][:40]
            lines.append(f"{name:<25} | {method:<8} | {rmse_str:<12} | {time_str:<8} | {status:<10} | {formula}")
        lines.append(light_rule)

    return "\n".join(lines)


def print_summary(summary: Dict):
    """
    Print the per-method comparison table and announce the winning method.

    Args:
        summary: Mapping of method name to a stats dict with keys ``solved``,
            ``total``, ``score``, ``avg_time`` and ``avg_rmse``.

    The winner is the method with the most solved problems; ties are broken by
    the lower ``avg_rmse``. An empty ``summary`` prints the table frame and a
    notice instead of raising.
    """
    print("\n" + "=" * 60)
    print("BENCHMARK SUMMARY - Method Comparison")
    print("=" * 60)
    print(f"{'Method':<12} | {'Solved':<10} | {'Score':<10} | {'Avg Time':<10} | {'Avg RMSE':<12}")
    print("-" * 60)
    
    for method, stats in summary.items():
        solved_str = f"{stats['solved']}/{stats['total']}"
        score_str = f"{stats['score']:.1f}%"
        time_str = f"{stats['avg_time']:.2f}s"
        # Same 1e6 cut-off the per-problem table uses for failure sentinels;
        # NaN also fails the comparison and is shown as N/A.
        avg_rmse = stats['avg_rmse']
        rmse_str = f"{avg_rmse:.6f}" if avg_rmse < 1e6 else "N/A"
        print(f"{method.upper():<12} | {solved_str:<10} | {score_str:<10} | {time_str:<10} | {rmse_str:<12}")
    
    print("=" * 60)

    # max() on an empty mapping raises ValueError, so there is no winner to report.
    if not summary:
        print("\nNo methods were benchmarked.")
        return

    def _rank(item):
        stats = item[1]
        avg_rmse = stats['avg_rmse']
        # NaN compares False against everything, which would make the result of
        # max() depend on iteration order; rank it as the worst possible score.
        if avg_rmse != avg_rmse:
            avg_rmse = float('inf')
        return (stats['solved'], -avg_rmse)
    
    # Determine winner
    best_method = max(summary.items(), key=_rank)
    print(f"\n*** WINNER: {best_method[0].upper()} with {best_method[1]['solved']}/{best_method[1]['total']} problems solved! ***")


if __name__ == "__main__":
    # Standalone run: load the model, benchmark all three methods, print both reports.
    import sys
    # The module-level imports above have already been resolved by the time this
    # runs; this only affects the `ui.app_core` import below.
    sys.path.insert(0, '.')
    
    from ui.app_core import load_model, get_model
    
    print("Loading model...")
    load_model()
    model, device = get_model()
    
    if model is None:
        print("Error: No model loaded!")
        # sys.exit instead of the site-provided exit(), which is meant for the
        # interactive shell and is not guaranteed to exist in every interpreter.
        sys.exit(1)
    
    print("Running comparison benchmark...")
    result = run_comparison_benchmark(
        model, 
        device, 
        methods=["beam", "mcts", "hybrid"],
        gp_timeout=30,
        beam_width=50
    )
    
    print(format_comparison_table(result['results']))
    print_summary(result['summary'])


In [None]:
%%writefile AlphaSymbolic/utils/simplify.py
"""
Algebraic Simplification Module for AlphaSymbolic.
Uses SymPy for symbolic math simplification.
"""
import sympy as sp
from core.grammar import Node, ExpressionTree, OPERATORS

# SymPy symbol for x
x_sym = sp.Symbol('x')

def tree_to_sympy(node):
    """
    Convert an ExpressionTree Node (and its subtree) to a SymPy expression.

    Args:
        node: Object with ``value`` and ``children`` attributes, or ``None``.

    Returns:
        A SymPy expression. ``None`` and unrecognised operator tokens map to
        ``Integer(0)``. Numeric literals become exact ``Integer`` values when they
        are whole numbers and ``Float`` otherwise.

    Raises:
        IndexError: if an operator node has fewer children than the operator needs.
    """
    if node is None:
        return sp.Integer(0)
    
    val = node.value
    
    # Terminals
    if val == 'x':
        return x_sym
    if val == 'pi':
        return sp.pi
    if val == 'e':
        return sp.E
    if val == 'C':
        # The constant placeholder stays symbolic.
        return sp.Symbol('C')
    
    # Numeric literal. Only conversion failures are tolerated here; anything else
    # (including Ctrl-C) propagates.
    try:
        number = float(val)
    except (TypeError, ValueError, OverflowError):
        number = None
    if number is not None:
        # Whole numbers are kept exact so that e.g. `x * 1` simplifies to `x`
        # rather than `1.0*x`, and `pow(x, 2)` to `x**2` rather than `x**2.0`.
        if number.is_integer():
            return sp.Integer(int(number))
        return sp.Float(number)
    
    # Operators: children are converted first, then combined.
    args = [tree_to_sympy(c) for c in node.children]
    
    if val == '+': return args[0] + args[1]
    if val == '-': return args[0] - args[1]
    if val == '*': return args[0] * args[1]
    if val == '/': return args[0] / args[1]
    if val == 'pow': return sp.Pow(args[0], args[1])
    if val == 'mod': return sp.Mod(args[0], args[1])
    if val == 'sin': return sp.sin(args[0])
    if val == 'cos': return sp.cos(args[0])
    if val == 'tan': return sp.tan(args[0])
    if val == 'exp': return sp.exp(args[0])
    if val == 'log': return sp.log(args[0])
    if val == 'sqrt': return sp.sqrt(args[0])
    if val == 'abs': return sp.Abs(args[0])
    if val == 'floor': return sp.floor(args[0])
    if val == 'ceil': return sp.ceiling(args[0])
    if val == 'gamma': return sp.gamma(args[0])
    if val == 'lgamma': return sp.loggamma(args[0])  # SymPy's log-gamma
    if val == 'neg': return -args[0]
    
    # Unknown token: kept as the original fallback value.
    return sp.Integer(0)

def sympy_to_infix(expr):
    """Render a SymPy expression as text by taking its ``str()`` form."""
    rendered = str(expr)
    return rendered

def simplify_tree(tree):
    """
    Simplify an ExpressionTree with SymPy and return the result as an infix string.

    Args:
        tree: Object exposing ``is_valid``, ``get_infix()`` and ``root``.

    Returns:
        ``"Invalid"`` when ``tree.is_valid`` is falsy. Otherwise the simplified
        expression as text, or the tree's own infix string when simplification
        raises or produces a non-real / non-finite artefact.
    """
    if not tree.is_valid:
        return "Invalid"
    
    original_infix = tree.get_infix()
    
    try:
        simplified = sp.simplify(tree_to_sympy(tree.root))
        
        # Reject degenerate results structurally instead of by substring:
        # zoo = complex infinity, nan, +/-oo = infinity, I = imaginary unit.
        # A text search for 'I*' misses forms such as `2*I` or `x + I`, and a
        # text search for 'oo' would also hit `zoo` and `floor`.
        if simplified.has(sp.zoo, sp.nan, sp.oo, -sp.oo, sp.I):
            return original_infix  # Fall back to original
        
        return str(simplified)
    except Exception:
        # If conversion or simplification fails, return original
        return original_infix

def simplify_infix(infix_str):
    """
    Simplify an infix formula string with SymPy.

    Args:
        infix_str: Formula text understood by ``sympy.sympify``.

    Returns:
        The simplified expression as text. ``infix_str`` is returned unchanged when
        parsing or simplification fails, or when the result contains a non-real /
        non-finite artefact (``zoo``, ``nan``, ``oo``, ``-oo`` or the imaginary unit).
    """
    try:
        # NOTE(security): sympify evaluates the string it is given (it uses eval
        # internally), so only formula text produced by this application should
        # be passed in -- never raw untrusted input.
        simplified = sp.simplify(sp.sympify(infix_str))
        # Same artefact policy as simplify_tree: a degenerate result is less
        # useful than the original formula.
        if simplified.has(sp.zoo, sp.nan, sp.oo, -sp.oo, sp.I):
            return infix_str
        return str(simplified)
    except Exception:
        # Best-effort: any parse/simplify failure keeps the input as is, while
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        return infix_str

# Quick manual check: run this module directly to print a few simplifications.
if __name__ == "__main__":
    from core.grammar import ExpressionTree

    # Token lists handed to ExpressionTree, with the intended simplification:
    #   x + 0 -> x,   x * 1 -> x,   x - x -> 0
    for tokens in (['+', 'x', '0'], ['*', 'x', '1'], ['-', 'x', 'x']):
        tree = ExpressionTree(tokens)
        print(f"Original: {tree.get_infix()}")
        print(f"Simplified: {simplify_tree(tree)}")


In [None]:
%%writefile AlphaSymbolic/utils/__init__.py


In [None]:
%%writefile AlphaSymbolic/app.py
"""
AlphaSymbolic - Gradio Web Interface
With GPU/CPU toggle and search method selection.
"""
import gradio as gr
import torch

from ui.app_core import load_model, get_device, get_device_info, set_device, get_training_errors, request_stop_training
from ui.app_training import train_basic, train_curriculum, train_self_play, train_supervised, train_hybrid_feedback_loop
from ui.app_search import solve_formula, generate_example
from ui.app_benchmark import get_benchmark_tab


def toggle_device(use_gpu):
    """
    Switch the active device and return an HTML badge describing it.

    The badge colour is picked from the description returned by ``set_device``:
    green when it mentions CUDA, amber for MPS, grey otherwise.
    """
    device_info = set_device(use_gpu)

    if "CUDA" in device_info:
        color = "#4ade80"
    elif "MPS" in device_info:
        color = "#fbbf24"
    else:
        color = "#888"

    return f'<div style="padding: 10px; background: #0f0f23; border-radius: 8px; border-left: 3px solid {color};"><span style="color: {color}; font-weight: bold;">{device_info}</span></div>'


def create_app() -> gr.Blocks:
    """
    Build the complete Gradio interface and return it without launching it.

    Layout, in creation order: header banner, model-preset selector, then a tab
    set with the search tab, the training tab (device toggle, model deletion,
    stop button, four training sub-tabs, pre-training and error-analysis
    accordions), whatever ``get_benchmark_tab()`` adds, and an information tab;
    finally a footer.

    The device description is read via ``get_device_info()`` while the UI is being
    built, so the HTML that embeds it reflects the device at build time. Only the
    badge in the training tab is updated afterwards (through ``toggle_device``).

    Returns:
        The ``gr.Blocks`` instance holding the whole UI.
    """
    
    with gr.Blocks(title="AlphaSymbolic") as demo:
        
        # Header
        # Device description and badge colour (green = CUDA, amber = MPS, grey = other);
        # reused further down for the initial device badge in the training tab.
        device_info = get_device_info()
        device_color = "#4ade80" if "CUDA" in device_info else "#fbbf24" if "MPS" in device_info else "#888"
        
        gr.HTML(f"""
        <div style="text-align: center; padding: 20px; background: linear-gradient(90deg, #00d4ff22, transparent, #ff6b6b22); border-radius: 15px; margin-bottom: 20px;">
            <h1 style="color: #00d4ff; font-size: 42px; margin: 0;">AlphaSymbolic</h1>
            <p style="color: #888; font-size: 18px; margin: 5px 0;">Deep Reinforcement Learning para Regresion Simbolica</p>
        </div>
        """)
        
        # System Controls
        with gr.Row():
            with gr.Column(scale=1):
                model_selector = gr.Dropdown(choices=["lite", "pro"], value="lite", label="Arquitectura (Cerebro)", interactive=True)
            with gr.Column(scale=3):
                # NOTE(review): the initial text is hard-coded, not taken from load_model() --
                # confirm it matches the status load_model() reports for the default preset.
                model_status = gr.Textbox(label="Estado del Modelo", value="Lite (Laptop Optimized) - Vocabulario Extendido", interactive=False)
        
        def on_model_change(preset):
            """Load the selected preset and return the status text reported by ``load_model``."""
            status, _ = load_model(preset_name=preset)
            return status

        model_selector.change(on_model_change, model_selector, model_status)
        
        with gr.Tabs():
            # TAB 1: Search
            with gr.Tab("Buscar Formula"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML('<h3 style="color: #00d4ff;">Datos de Entrada</h3>')
                        x_input = gr.Textbox(label="Valores X", placeholder="1, 2, 3, 4, 5...", lines=2)
                        y_input = gr.Textbox(label="Valores Y", placeholder="5, 7, 9, 11, 13...", lines=2)
                        
                        with gr.Row():
                            search_method = gr.Radio(
                                choices=["Beam Search", "MCTS", "Alpha-GP Hybrid"],
                                value="Alpha-GP Hybrid",
                                label="Metodo de Busqueda"
                            )
                        
                        beam_slider = gr.Slider(5, 500, value=50, step=5, label="Beam Width / Simulaciones")
                        
                        solve_btn = gr.Button("Buscar Formula", variant="primary", size="lg")
                        
                        # Example buttons: each fills both input boxes from generate_example(kind).
                        with gr.Row():
                            gr.Button("Lineal", size="sm").click(lambda: generate_example("lineal"), outputs=[x_input, y_input])
                            gr.Button("Cuadratico", size="sm").click(lambda: generate_example("cuadratico"), outputs=[x_input, y_input])
                            gr.Button("Seno", size="sm").click(lambda: generate_example("trig"), outputs=[x_input, y_input])
                            gr.Button("Exponencial", size="sm").click(lambda: generate_example("exp"), outputs=[x_input, y_input])
                    
                    with gr.Column(scale=2):
                        result_html = gr.HTML(label="Resultado")
                        plot_output = gr.Plot(label="Visualizacion")
                
                with gr.Row():
                    pred_html = gr.HTML(label="Predicciones")
                    alt_html = gr.HTML(label="Alternativas")
                
                # Hidden field that receives the fifth output of solve_formula.
                raw_formula = gr.Textbox(visible=False)
                
                solve_btn.click(solve_formula, [x_input, y_input, beam_slider, search_method], 
                               [result_html, plot_output, pred_html, alt_html, raw_formula])
            
            # TAB 2: Training
            with gr.Tab("Entrenar Modelo"):
                with gr.Row():
                    gr.HTML("""
                    <div style="background: #16213e; padding: 20px; border-radius: 10px; flex: 1;">
                        <h3 style="color: #ffd93d; margin: 0;">Centro de Entrenamiento</h3>
                    </div>
                    """)
                    with gr.Column():
                        # Device toggle: the badge starts from the build-time device info and is
                        # replaced by toggle_device's HTML whenever the checkbox changes.
                        use_gpu = gr.Checkbox(label="Usar GPU", value=torch.cuda.is_available())
                        device_display = gr.HTML(value=f'<div style="padding: 10px; background: #0f0f23; border-radius: 8px; border-left: 3px solid {device_color};"><span style="color: {device_color}; font-weight: bold;">{device_info}</span></div>')
                        use_gpu.change(toggle_device, [use_gpu], [device_display])
                    with gr.Column():
                        delete_model_btn = gr.Button("🗑️ Borrar Modelo", variant="secondary", size="sm")
                        delete_status = gr.HTML()
                        
                        def delete_model_action():
                            """Delete the saved weights file of the current preset, if present, and return a status snippet."""
                            import os
                            # Imported on every call so the preset name is read at click time.
                            from ui.app_core import CURRENT_PRESET
                            # Relative path: resolved against the process working directory.
                            filename = f"alpha_symbolic_model_{CURRENT_PRESET}.pth"
                            if os.path.exists(filename):
                                os.remove(filename)
                                return f'<div style="color: #4ade80; padding: 5px;">✅ Modelo [{CURRENT_PRESET}] eliminado. Reinicia la app para usar pesos nuevos.</div>'
                            return f'<div style="color: #888; padding: 5px;">No hay modelo [{CURRENT_PRESET}] guardado.</div>'
                        
                        delete_model_btn.click(delete_model_action, outputs=[delete_status])
                        
                        stop_train_btn = gr.Button("⏹️ Detener Entrenamiento", variant="stop", size="sm")
                        stop_status = gr.HTML()
                        stop_train_btn.click(request_stop_training, outputs=[stop_status])
                
                # Training sub-tabs: each wires its sliders to one train_* function and shows
                # that function's two outputs (HTML report, plot).
                with gr.Tabs():
                    # Basic
                    with gr.Tab("Basico"):
                        gr.HTML('<p style="color: #888;">Entrenamiento rapido con datos sinteticos</p>')
                        with gr.Row():
                            with gr.Column():
                                epochs_basic = gr.Slider(10, 500, value=100, step=10, label="Epocas")
                                batch_basic = gr.Slider(16, 128, value=32, step=16, label="Batch Size")
                                points_basic = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_basic_btn = gr.Button("Entrenar Basico", variant="primary")
                            with gr.Column():
                                result_basic = gr.HTML()
                                plot_basic = gr.Plot()
                        train_basic_btn.click(train_basic, [epochs_basic, batch_basic, points_basic], [result_basic, plot_basic])
                    
                    # Curriculum
                    with gr.Tab("Curriculum"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
                            <p style="color: #00d4ff; margin: 0;"><strong>Curriculum Learning</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">Empieza con formulas simples y aumenta la dificultad.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                epochs_curriculum = gr.Slider(50, 2000, value=200, step=50, label="Epocas")
                                batch_curriculum = gr.Slider(16, 128, value=64, step=16, label="Batch Size")
                                points_curriculum = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_curriculum_btn = gr.Button("Entrenar Curriculum", variant="primary")
                            with gr.Column():
                                result_curriculum = gr.HTML()
                                plot_curriculum = gr.Plot()
                        train_curriculum_btn.click(train_curriculum, [epochs_curriculum, batch_curriculum, points_curriculum], [result_curriculum, plot_curriculum])
                    
                    # Self-Play
                    with gr.Tab("Self-Play"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px; border-left: 3px solid #ff6b6b;">
                            <p style="color: #ff6b6b; margin: 0;"><strong>AlphaZero Self-Play</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">El modelo resuelve problemas y aprende de sus exitos.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                iterations_sp = gr.Slider(10, 1000, value=100, step=10, label="Iteraciones")
                                problems_sp = gr.Slider(5, 200, value=10, step=5, label="Problemas/Iter")
                                points_sp = gr.Slider(10, 100, value=20, step=10, label="Puntos por Formula")
                                train_sp_btn = gr.Button("Iniciar Self-Play", variant="primary")
                            with gr.Column():
                                result_sp = gr.HTML()
                                plot_sp = gr.Plot()
                        train_sp_btn.click(train_self_play, [iterations_sp, problems_sp, points_sp], [result_sp, plot_sp])
                
                    # Feedback Loop (Teacher-Student)
                    with gr.Tab("Feedback Loop (Hybrid)"):
                        gr.HTML('''
                        <div style="background: #0f0f23; padding: 15px; border-radius: 8px; margin-bottom: 15px; border-left: 3px solid #f1c40f;">
                            <p style="color: #f1c40f; margin: 0;"><strong>Teacher-Student Feedback Loop</strong></p>
                            <p style="color: #888; margin: 5px 0 0 0;">El modelo (Estudiante) intenta resolver problemas. Si falla, el Alpha-GP (Maestro) interviene y añade la solución al dataset.</p>
                        </div>
                        ''')
                        with gr.Row():
                            with gr.Column():
                                iterations_fb = gr.Slider(5, 500, value=20, step=5, label="Ciclos")
                                problems_fb = gr.Slider(5, 50, value=10, step=5, label="Problemas Difíciles / Ciclo")
                                timeout_fb = gr.Slider(5, 30, value=10, step=5, label="Timeout Maestro (s)")
                                train_fb_btn = gr.Button("Iniciar Feedback Loop", variant="primary")
                            with gr.Column():
                                result_fb = gr.HTML()
                                plot_fb = gr.Plot()
                        train_fb_btn.click(train_hybrid_feedback_loop, [iterations_fb, problems_fb, timeout_fb], [result_fb, plot_fb])
                
                # --- PRE-TRAINING (Warmup) ---
                with gr.Accordion("🎓 Escuela Primaria (Pre-Entrenamiento)", open=False):
                    gr.Markdown("Entrenamiento masivo supervisado de alta velocidad para aprender sintaxis basica. **Recomendado al inicio.**")
                    with gr.Row():
                        with gr.Column():
                            epochs_pre = gr.Slider(100, 10000, value=2000, step=100, label="Iteraciones Rápidas")
                            train_pre_btn = gr.Button("Iniciar Pre-Entrenamiento", variant="primary")
                        with gr.Column():
                            result_pre = gr.HTML()
                            plot_pre = gr.Plot()
                    train_pre_btn.click(train_supervised, [epochs_pre], [result_pre, plot_pre])

                # --- HALL OF SHAME (Error Analysis) ---
                with gr.Accordion("🕵️‍♂️ Hall of Shame (Analisis de Errores)", open=False):
                    gr.Markdown("Aquí se muestran los problemas donde el modelo falló drásticamente hoy.")
                    error_table = gr.DataFrame(
                        headers=["Time", "Target Formula", "Predicted", "Loss", "Stage"],
                        datatype=["str", "str", "str", "number", "str"],
                        interactive=False
                    )
                    refresh_errors_btn = gr.Button("🔄 Actualizar Errores", size="sm")
                    
                    def update_errors():
                        """Return the recorded training errors as table rows, newest first."""
                        errors = get_training_errors()
                        # Reverse to show newest first
                        # (assumes get_training_errors() returns entries oldest-first -- TODO confirm)
                        data = [[
                            e['time'], e['target'], e['predicted'], round(e['loss'], 2), e['stage']
                        ] for e in reversed(errors)]
                        return data
                    
                    # The table is only refreshed on demand, via this button.
                    refresh_errors_btn.click(update_errors, outputs=[error_table])
            
            # TAB 3: Benchmark
            # Called inside the gr.Tabs() context; presumably adds its own gr.Tab
            # (defined in ui.app_benchmark -- confirm there).
            get_benchmark_tab()

            # TAB 4: Info
            with gr.Tab("Informacion"):
                # Read again here, but still at build time: the text below is static HTML
                # and is not refreshed when the device is toggled later.
                device_info_current = get_device_info()
                device_color_current = "#4ade80" if "CUDA" in device_info_current else "#fbbf24" if "MPS" in device_info_current else "#888"
                
                gr.HTML(f"""
                <div style="background: #1a1a2e; padding: 30px; border-radius: 15px;">
                    <h2 style="color: #00d4ff;">Que es AlphaSymbolic?</h2>
                    <p style="color: #ccc; line-height: 1.8;">
                        Sistema de <strong style="color: #ff6b6b;">regresion simbolica</strong> 
                        basado en <strong style="color: #00d4ff;">Deep Learning</strong> y 
                        <strong style="color: #ffd93d;">Monte Carlo Tree Search</strong>.
                    </p>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Dispositivo Actual</h3>
                    <p style="color: {device_color_current}; font-size: 20px;">{device_info_current}</p>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Metodos de Busqueda</h3>
                    <ul style="color: #ccc;">
                        <li><strong>Beam Search:</strong> Explora multiples candidatos en paralelo (rapido)</li>
                        <li><strong>MCTS:</strong> Monte Carlo Tree Search (mas preciso, lento)</li>
                        <li><strong>Alpha-GP Hybrid:</strong> Fusiona Neural Search con Algoritmo Genetico GPU (Extremo)</li>
                    </ul>
                    
                    <h3 style="color: #00d4ff; margin-top: 30px;">Operadores</h3>
                    <div style="display: flex; flex-wrap: wrap; gap: 10px; margin: 15px 0;">
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">+</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">-</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">*</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #00d4ff;">/</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ff6b6b;">sin</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ff6b6b;">cos</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ffd93d;">exp</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #ffd93d;">log</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #4ade80;">pow</span>
                        <span style="background: #0f0f23; padding: 5px 15px; border-radius: 20px; color: #4ade80;">sqrt</span>
                    </div>
                </div>
                """)
        
        # Footer
        gr.HTML("""
        <div style="text-align: center; padding: 20px; color: #666; margin-top: 30px;">
            <p>Powered by PyTorch - SymPy - Scipy - Gradio</p>
        </div>
        """)
    
    return demo



# --- Global Initialization for Hot Reloading ---
# Everything up to and including `demo = create_app()` runs at import time, not only
# when this file is executed as a script.
print("Iniciando AlphaSymbolic (Global Init)...")
# Load model once at module level so 'gradio app.py' works
status_init, device_info_init = load_model() 
print(f"   {status_init} | {device_info_init}")

# Create the app instance globally (module-level name `demo`)
demo = create_app()

if __name__ == "__main__":
    print("Abriendo navegador...")
    # Launch with auto-reload compatibility if run directly (though proper reload needs 'gradio app.py')
    # share=True requests a public Gradio share link; inbrowser=True asks Gradio to open a browser tab.
    demo.launch(share=True, inbrowser=True)


In [None]:
# Run AlphaSymbolic
# The binaries are in ../Code/build/
# NOTE(review): %cd changes the kernel's working directory for the rest of the session,
# and the path is relative -- re-running this cell without restarting the kernel would
# look for AlphaSymbolic/AlphaSymbolic. Restart the kernel before running it again.
%cd AlphaSymbolic
!python app.py
