In [38]:
from problems.HS100 import HS100, CallableClass
import methods.methods as methods
from typing import Tuple, Union, List, Type
from copy import deepcopy
import pandas as pd
import numpy as np
from numpy import number
from scipy.stats.qmc import LatinHypercube as lhs
from scipy.stats.qmc import scale
from smt.surrogate_models import KRG
import random

In [39]:
def sample_sites(problem: dict, n_samples: int, seed: Union[int, None] = None) -> pd.DataFrame:
    """Draw a scrambled Latin-hypercube design scaled to the problem's variable bounds.

    Parameters
    ----------
    problem : dict
        The variables and constraints given. ``problem["variables"]`` maps each
        variable name to a dict whose ``"bounds"`` entry is a ``(lower, upper)`` pair.

    n_samples : int
        The number of sites to generate.

    seed : int or None, optional
        Seed for the LHS sampling. When ``None`` (the default), a fresh seed is
        drawn from the ``random`` module on every call. Repeated calls then give
        different designs, which stay reproducible through ``random.seed``.

    Returns
    -------
    pd.DataFrame
        The DataFrame of sites: ``n_samples`` rows and one column per variable,
        in the order of ``problem["variables"]``.
    """
    if seed is None:
        # Draw the seed here rather than in the signature. A default-argument
        # expression is evaluated only once, at definition time, so every call
        # would otherwise reuse the same seed.
        seed = random.randint(1, 1000)

    variables = list(problem["variables"].keys())
    # One (lower, upper) row per variable.
    bounds = np.array([problem["variables"][var]["bounds"] for var in variables])

    # Generate the experiment in the unit hypercube.
    sampler = lhs(
        len(variables),
        scramble=True,
        strength=1,
        optimization=None,
        seed=seed,
    )
    unit_design = sampler.random(n_samples)

    # Map each column from [0, 1) onto its variable's [lower, upper) range.
    scaled_design = scale(unit_design, bounds[:, 0], bounds[:, 1])
    return pd.DataFrame(data=scaled_design, columns=variables)


In [40]:
def evaluate_sites(local_eval: CallableClass, exp_data: pd.DataFrame, verbose=True):
    """Evaluate the sites passed in and attach their constraint-violation data.

    Parameters
    ----------
    local_eval : CallableClass
        The example evaluator to use. It is called on ``exp_data`` and its
        return value is ignored. It is presumably expected to add its outputs
        (objective / constraint columns) to ``exp_data`` in place; TODO confirm.

    exp_data : pd.DataFrame
        The sites to evaluate, one row per site. Modified in place.

    verbose : bool, optional
        If True (default), print the percentage of feasible sites.

    Returns
    -------
    pd.DataFrame
        ``exp_data`` itself, with two added columns:
        ``__conviol__`` (constraint violation) and ``__State__``
        (``'Feasible'``, ``'Nearly Feasible'`` or ``'Infeasible'``).
    """
    eps = 1e-6
    local_problem = local_eval.problem()
    local_eval(exp_data)
    my_constraint_calculator = methods.ConstraintCalculator(local_problem)
    exp_data['__conviol__'] = my_constraint_calculator(exp_data)
    # Bins: (-inf, eps] -> Feasible, (eps, 100*eps] -> Nearly Feasible,
    # (100*eps, inf) -> Infeasible.
    exp_data['__State__'] = pd.cut(exp_data['__conviol__'], [-np.inf, eps, 100*eps, np.inf], include_lowest=True, labels=['Feasible', 'Nearly Feasible', "Infeasible"]).astype("str")
    if verbose:
        # Use the size of the frame actually evaluated. The previous code divided
        # by a notebook-level global, which was wrong for filtered frames and
        # undefined on a fresh kernel.
        n_sites = len(exp_data)
        feasible_sites = (exp_data['__State__'] == 'Feasible').sum()
        percentage = feasible_sites / n_sites * 100 if n_sites else 0.0
        print(f"Test evaluator {local_eval.name} with {len(local_problem['variables'])} variables {percentage}% feasible sites")
    return exp_data

In [41]:
def advanced_testing(local_eval: CallableClass, num_sites_training: int, num_sites_testing : int, verbose=True) -> pd.DataFrame:
    """Fit a Kriging surrogate of the constraint violation and check its predicted-feasible sites.

    The steps are:
    1. Sample and evaluate training sites.
    2. Train a KRG model on their ``__conviol__``.
    3. Sample testing sites and keep those the model predicts as feasible.
    4. Evaluate those kept sites with the real evaluator.

    Parameters
    ----------
    local_eval : CallableClass
        The example evaluator to use.

    num_sites_training : int
        The number of sites to train the surrogate model on.

    num_sites_testing : int
        The number of sites sampled to test the surrogate model.

    verbose : bool, optional
        If True (default), show the KRG log output and the feasibility summary
        of the final evaluation. If False, suppress both.

    Returns
    -------
    pd.DataFrame
        The testing sites predicted feasible (prediction < 1e-6), evaluated by
        ``local_eval`` and carrying their true constraint-violation data.
    """
    eps = 1e-6
    local_problem = local_eval.problem()
    variables = list(local_problem["variables"].keys())

    # Training data, evaluated with the real function (quietly).
    training_sites = sample_sites(local_problem, num_sites_training)
    training_data = evaluate_sites(local_eval, training_sites, verbose=False)
    xt = training_data[variables].to_numpy()
    yt = training_data['__conviol__'].to_numpy()

    # Surrogate model. print_global toggles all SMT console output.
    sm = KRG(theta0=[1e-2], print_global=bool(verbose))
    sm.set_training_values(xt, yt)
    sm.train()

    # Testing data: keep only the sites the surrogate predicts as feasible.
    x_test = sample_sites(local_problem, num_sites_testing).to_numpy()
    y_pred = sm.predict_values(x_test).flatten()
    predicted_feasible = pd.DataFrame(data=x_test[y_pred < eps], columns=variables)

    # Check the surrogate's predictions against the real evaluator.
    return evaluate_sites(local_eval, predicted_feasible, verbose=verbose)
    

In [42]:
hs100 = HS100()
problem = hs100.problem()
nind = len(problem['variables'])
# (nind + 1)(nind + 2) / 2 equals the number of terms in a full quadratic in
# nind variables; computed with exact integer arithmetic.
num_sites = (nind + 1) * (nind + 2) // 2
# Train on num_sites points, then screen num_sites**2 candidate sites.
test_data = advanced_testing(hs100, num_sites, num_sites ** 2)
test_data

___________________________________________________________________________
   
                                  Kriging
___________________________________________________________________________
   
 Problem size
   
      # training points.        : 36
   
___________________________________________________________________________
   
 Training
   
   Training ...
   Training - done. Time (sec):  1.6143129
___________________________________________________________________________
   
 Evaluation
   
      # eval points. : 1296
   
   Predicting ...
   Predicting - done. Time (sec):  0.0389180
   
   Prediction time/pt. (sec) :  0.0000300
   
Test evaluator hs100 with 7 variables 0.0% feasible sites


Unnamed: 0,x1,x2,x3,x4,x5,x6,x7,f,c1,c2,c3,c4,__conviol__,__State__
0,-5.369651,-3.535897,5.490247,8.4924,-8.669801,9.548724,4.117912,4249742.0,-650.233499,11.604935,-207.126021,-133.607797,2.549123,Infeasible
1,-4.956528,-3.328183,2.767247,7.805788,4.549653,8.62923,6.717556,92362.3,-559.457991,246.847548,-94.117835,-44.425134,1.181596,Infeasible
2,-2.993025,-5.458056,1.715755,-9.287743,5.70217,8.696639,7.39393,349799.8,-2928.591164,304.877107,-159.588523,15.34753,3.335191,Infeasible
3,-4.687133,-4.611073,-0.692802,8.465012,-7.080012,4.054773,-0.779227,1261225.0,-1523.688932,308.298387,177.661138,-74.106035,1.694343,Infeasible
4,-4.415171,-4.628094,0.016332,-1.703196,7.648391,9.702013,8.002587,2008162.0,-1338.203019,336.139399,-224.623978,1.425151,2.614647,Infeasible
5,-0.267458,-4.783777,3.925328,7.906021,5.773898,4.92457,8.771674,378100.1,-1727.062491,142.009427,103.93203,21.71688,1.727062,Infeasible
6,1.559315,-3.173251,0.559959,-8.05729,-4.346177,2.606033,8.494951,74782.25,-420.557421,281.180121,177.277382,45.147515,0.420557,Infeasible
7,-0.809322,-3.83938,-3.545617,5.701748,8.07854,4.698299,7.465381,2784322.0,-693.073818,175.846152,127.152541,25.445928,0.693074,Infeasible
8,-0.054998,-5.273834,-3.567961,-6.788017,-1.71451,8.713706,9.288037,10446.77,-2365.917518,175.97657,-211.816101,6.183921,3.175558,Infeasible
9,1.199933,3.786246,3.056359,-1.230478,9.032099,-4.083831,9.13514,5437304.0,-546.686144,179.091024,127.080932,95.757744,0.546686,Infeasible
