## Exploring SMT

In [1]:
import numpy as np
from smt.utils.misc import compute_relative_error

from smt.problems import Rosenbrock
from smt.sampling_methods import LHS
from smt.surrogate_models import LS, QP, KPLS, KRG, KPLSK, GEKPLS, MGP, IDW, RBF, RMTC, RMTB
import matplotlib.pyplot as plt
from matplotlib import cm

from scipy.integrate import dblquad

### Finding the best rectangular domain of area 1 (width $l$, height $1/l$) over which to integrate the hidden function

In [2]:
def hidden_function(x, y):
    """Isotropic Gaussian bump ``exp(-(x**2 + y**2))``, peaked at the origin.

    Works element-wise on scalars or numpy arrays.
    """
    squared_radius = x**2 + y**2
    return np.exp(-squared_radius)

def objective_function(x, y, l, integrand=None):
    """Integrate a 2-D function over a rectangle of area 1.

    The rectangle has lower-left corner ``(x, y)``, width ``l`` and height
    ``1 / l``, so its area is always 1.

    Parameters
    ----------
    x, y : float
        Lower-left corner of the rectangle.
    l : float
        Width of the rectangle; must be strictly positive.
    integrand : callable, optional
        Function ``f(x, y)`` to integrate. Defaults to ``hidden_function``.

    Returns
    -------
    float
        Estimate of ``int_{x}^{x+l} int_{y}^{y+1/l} f(u, v) dv du``
        (the absolute-error estimate from ``dblquad`` is discarded).

    Raises
    ------
    ValueError
        If ``l`` is not strictly positive (this also rejects NaN).
    """
    if not l > 0:
        raise ValueError(f"width l must be strictly positive, got {l!r}")
    f = hidden_function if integrand is None else integrand
    # dblquad calls its integrand as func(inner, outer) == func(y, x); swap the
    # arguments back so a non-symmetric f is not integrated over the
    # transposed rectangle.
    value, _abserr = dblquad(lambda yy, xx: f(xx, yy), x, x + l, y, y + 1 / l)
    return value

### Sample with Latin Hypercube Sampling

In [3]:
# Design space: corner x in [-5, 5], corner y in [-5, 5], width l in [0.1, 5].
# The lower bound on l keeps the rectangle height 1/l finite (at most 10).
XLIMITS = np.array([[-5, 5], [-5, 5], [0.1, 5]])
N_TRAIN = 50
SEED = 42  # fixed so Restart & Run All reproduces the same design

sampling = LHS(xlimits=XLIMITS, random_state=SEED)

X = sampling(N_TRAIN)
Y = np.array([objective_function(x, y, width) for x, y, width in X])
[X[:5], Y[:5]]

[array([[-3.1  , -0.5  ,  3.383],
        [ 1.1  ,  3.9  ,  0.933],
        [ 2.3  ,  4.3  ,  2.599],
        [ 4.7  ,  0.7  ,  4.167],
        [-4.3  ,  2.1  ,  1.913]]),
 array([3.01715876e-01, 3.16293234e-09, 1.04043354e-12, 3.25725260e-12,
        1.60251114e-06])]

### Training with Kriging

In [4]:
# One initial correlation hyperparameter per input dimension (x, y, l).
initial_theta = [1e-2] * 3
sm = KRG(theta0=initial_theta)
sm.set_training_values(X, Y)
sm.train()

___________________________________________________________________________
   
                                  Kriging
___________________________________________________________________________
   
 Problem size
   
      # training points.        : 50
   
___________________________________________________________________________
   
 Training
   
   Training ...
   Training - done. Time (sec):  1.3803449


### Making predictions for squares

In [5]:
# 10 x 10 grid over [-5, 5]^2; presumably the (x, y) corners of the squares
# to predict on (TODO confirm against the prediction cells below).
corner_axis = np.linspace(-5, 5, 10)
X_test = np.meshgrid(corner_axis, corner_axis)
X_test

(array([[-5.        , -3.88888889, -2.77777778, -1.66666667, -0.55555556,
          0.55555556,  1.66666667,  2.77777778,  3.88888889,  5.        ],
        [-5.        , -3.88888889, -2.77777778, -1.66666667, -0.55555556,
          0.55555556,  1.66666667,  2.77777778,  3.88888889,  5.        ],
        [-5.        , -3.88888889, -2.77777778, -1.66666667, -0.55555556,
          0.55555556,  1.66666667,  2.77777778,  3.88888889,  5.        ],
        [-5.        , -3.88888889, -2.77777778, -1.66666667, -0.55555556,
          0.55555556,  1.66666667,  2.77777778,  3.88888889,  5.        ],
        [-5.        , -3.88888889, -2.77777778, -1.66666667, -0.55555556,
          0.55555556,  1.66666667,  2.77777778,  3.88888889,  5.        ],
        [-5.        , -3.88888889, -2.77777778, -1.66666667, -0.55555556,
          0.55555556,  1.66666667,  2.77777778,  3.88888889,  5.        ],
        [-5.        , -3.88888889, -2.77777778, -1.66666667, -0.55555556,
          0.55555556,  1.6666666