In [None]:
import numpy as np

class ChoosePoints:
    """Base strategy for selecting sample points.

    Subclasses override ``choose`` to return an array of points; the base
    implementation only signals that it has not been overridden.
    """

    def choose(self, model, *args, **kwargs):
        """Return selected points; must be provided by a subclass.

        Raises
        ------
        NotImplementedError
            Always, when called on the base class.
        """
        message = "Subclasses should implement this method"
        raise NotImplementedError(message)

class ChooseRand(ChoosePoints):
    """Uniform random point chooser that ignores the model.

    Points are drawn uniformly from the unit hypercube [0, 1)^d.
    """

    def __init__(self, N=100):
        # Number of points returned per ``choose`` call (overridable per call).
        self.N = N

    def choose(self, model=None, d=2, **kwargs):
        """Return ``N`` points drawn uniformly from [0, 1)^d.

        Parameters
        ----------
        model : object, optional
            Ignored; accepted so this chooser is interchangeable with the
            model-driven choosers.
        d : int
            Dimensionality of each point.
        **kwargs
            ``N`` overrides ``self.N`` for this call. Other keywords (e.g.
            ``fac``) are accepted and ignored, so one call site can drive any
            ``ChoosePoints`` subclass.

        Returns
        -------
        numpy.ndarray of shape (N, d)
        """
        N = kwargs.get('N', self.N)
        return np.random.rand(N, d)

class ChooseRandFit(ChoosePoints):
    """Fitness-proportional point chooser driven by a model's predictions.

    Draws ``fac * N`` uniform candidates in [0, 1)^d, scores them with
    ``model.predict``, min-max normalises the scores to [0, 1], optionally
    sharpens them with the exponent ``amp``, and resamples ``N`` candidates
    with replacement, with probability proportional to the result.
    """

    def __init__(self, N=100, d=2, fac=20, amp=1):
        self.N = N      # number of points returned per call
        self.d = d      # dimensionality of each point
        self.fac = fac  # oversampling factor: fac * N candidates are scored
        # Sharpening exponent, clamped to >= 1 (values below 1 act as 1).
        self.amp = max(amp, 1)

    @staticmethod
    def _scalar_fitness(pred, n):
        """Collapse ``n`` model predictions to a float array of shape (n,).

        Single-output predictions are flattened; predictions with several
        components per candidate (any trailing shape) are reduced to the
        L2 norm of all their components.
        """
        fit = np.asarray(pred, dtype=float).reshape(n, -1)
        if fit.shape[1] > 1:
            return np.linalg.norm(fit, axis=1)
        return fit[:, 0]

    def choose(self, model, **kwargs):
        """Choose points by fitness-proportional resampling of candidates.

        Parameters
        ----------
        model : object
            Must provide ``predict(X)`` taking an array of shape
            (fac * N, d) and returning one prediction (scalar or vector)
            per row; array-likes such as lists are accepted.
        **kwargs
            Optional per-call overrides for ``N``, ``d``, ``fac`` and ``amp``.

        Returns
        -------
        numpy.ndarray of shape (N, d)
            Selected candidates. Duplicates are possible because sampling is
            with replacement.
        """
        print("ChooseRandFit choosing...")

        N = kwargs.get('N', self.N)
        d = kwargs.get('d', self.d)
        fac = kwargs.get('fac', self.fac)
        amp = max(kwargs.get('amp', self.amp), 1)

        if N == 0:
            # Nothing to select; avoid calling the model on an empty batch.
            print("done.")
            return np.empty((0, d))

        raw = np.random.rand(fac * N, d)
        fit = self._scalar_fitness(model.predict(raw), len(raw))

        # Normalise using finite scores only, so one NaN/inf prediction no
        # longer collapses the whole batch to uniform sampling.
        fit_norm = np.ones(len(fit))  # uniform fallback
        finite = np.isfinite(fit)
        if finite.any():
            mn, mx = fit[finite].min(), fit[finite].max()
            if mx > mn:
                # NaN counts as the worst score; +/-inf are clipped to range.
                clean = np.clip(np.where(np.isnan(fit), mn, fit), mn, mx)
                fit_norm = (clean - mn) / (mx - mn)

        if amp > 1:
            fit_norm = fit_norm ** amp

        # The minimum maps to 0; the epsilon keeps every candidate selectable
        # and guarantees a strictly positive sum.
        fit_norm = fit_norm + 1e-10
        probabilities = fit_norm / fit_norm.sum()

        try:
            idx = np.random.choice(len(fit), N, p=probabilities, replace=True)
        except ValueError as e:
            print(f"Warning: Probability sampling failed ({e}), using uniform sampling")
            idx = np.random.choice(len(fit), N, replace=True)

        print("done.")
        return raw[idx]

class ChooseRandFitRobust(ChoosePoints):
    """Rank-based point chooser driven by a model's predictions.

    Draws ``fac * N`` uniform candidates in [0, 1)^d, scores them with
    ``model.predict``, and resamples ``N`` candidates with replacement,
    with probability proportional to ``rank ** amp``. Weights depend only
    on each candidate's rank within the batch, so selection is insensitive
    to the scale or offset of the predictions.
    """

    def __init__(self, N=100, d=2, fac=20, amp=1):
        self.N = N      # number of points returned per call
        self.d = d      # dimensionality of each point
        self.fac = fac  # oversampling factor: fac * N candidates are scored
        # Sharpening exponent on rank weights, clamped to >= 1.
        self.amp = max(amp, 1)

    @staticmethod
    def _scalar_fitness(pred, n):
        """Collapse ``n`` model predictions to a float array of shape (n,).

        Single-output predictions are flattened; predictions with several
        components per candidate (any trailing shape) are reduced to the
        L2 norm of all their components.
        """
        fit = np.asarray(pred, dtype=float).reshape(n, -1)
        if fit.shape[1] > 1:
            return np.linalg.norm(fit, axis=1)
        return fit[:, 0]

    @staticmethod
    def _average_ranks(fit):
        """Return 1-based ascending ranks of ``fit``; ties share their mean rank."""
        _, inverse, counts = np.unique(fit, return_inverse=True, return_counts=True)
        last = np.cumsum(counts)               # rank of the last member of each tie group
        mean_rank = last - (counts - 1) / 2.0  # mean of ranks last-count+1 .. last
        return mean_rank[inverse.ravel()]

    def choose(self, model, **kwargs):
        """Choose points by rank-weighted resampling of candidates.

        Parameters
        ----------
        model : object
            Must provide ``predict(X)`` taking an array of shape
            (fac * N, d) and returning one prediction (scalar or vector)
            per row; array-likes such as lists are accepted.
        **kwargs
            Optional per-call overrides for ``N``, ``d``, ``fac`` and ``amp``.

        Returns
        -------
        numpy.ndarray of shape (N, d)
            Selected candidates. Duplicates are possible because sampling is
            with replacement.
        """
        print("ChooseRandFitRobust choosing...")

        N = kwargs.get('N', self.N)
        d = kwargs.get('d', self.d)
        fac = kwargs.get('fac', self.fac)
        amp = max(kwargs.get('amp', self.amp), 1)

        if N == 0:
            # Nothing to select; avoid calling the model on an empty batch.
            print("done.")
            return np.empty((0, d))

        raw = np.random.rand(fac * N, d)
        fit = self._scalar_fitness(model.predict(raw), len(raw))

        # Sorting would place NaN above every real score; rank it as worst.
        fit = np.where(np.isnan(fit), -np.inf, fit)

        # Weights are float and scaled into (0, 1] before exponentiation:
        # proportions match rank ** amp, but integer predictions can no longer
        # overflow int64 and large ``amp`` cannot overflow to inf. The top
        # rank maps to exactly 1, so the sum stays positive.
        ranks = self._average_ranks(fit)
        weights = ranks / ranks.max()
        if amp > 1:
            weights = weights ** amp

        probabilities = weights / weights.sum()
        idx = np.random.choice(len(fit), N, p=probabilities, replace=True)

        print("done.")
        return raw[idx]