In [1]:
# imports
import numpy as np

In [72]:
def generate_data(n, p, grid_dim, sigma, noise, degree, seed=None):
    '''
    Generate synthetic (feature, cost) data for a grid_dim x grid_dim grid.

    For every sample x the cost vector is
        c = ((B_star @ x) / sqrt(p) + 3) ** degree + 1
    multiplied elementwise by uniform noise, where B_star is a random 0/1
    matrix of shape [d, p] drawn once per call.

    Parameters:
        int n: number of data points to generate
        int p: number of features
        int grid_dim: dimension of the square grid; the cost vector has
            d = 2 * grid_dim * (grid_dim - 1) entries
        array sigma: array of length p holding the STANDARD DEVIATION of each
            feature (it is passed as `scale` to the normal sampler), i.e.
            X[:, j] ~ N(0, sigma[j]**2). A scalar is broadcast to all features.
        float noise: half-width of the multiplicative noise applied to the cost
            vector; each factor is sampled uniformly from [1-noise, 1+noise]
        int degree: polynomial degree of the generated cost vector. When
            degree=1, the expected value of c is linear in x. Degree > 1
            controls the amount of model misspecification.
        seed: optional seed (anything accepted by np.random.default_rng). When
            given, the output is reproducible and NumPy's global random state
            is left untouched. When None (default), samples are drawn from the
            global np.random state exactly as before, so np.random.seed(...)
            still controls the result.

    Returns:
        np.array X: feature data of dimension [n, p]
        np.array C: cost data of dimension [n, d]
    '''
    # Both the np.random module and a Generator expose binomial/normal/uniform
    # with the same keyword arguments, so one code path serves both cases.
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Number of entries in the cost vector, derived from the grid size
    d = grid_dim * (grid_dim - 1) * 2

    # Parameters of the true model: each entry of B_star is Bernoulli(0.5)
    B_star = rng.binomial(n=1, p=0.5, size=[d, p])

    # Feature data: independent zero-mean Gaussian entries; `sigma` broadcasts
    # across the last axis, so column j has standard deviation sigma[j].
    X = rng.normal(loc=0, scale=sigma, size=[n, p])  # each row is one sample of size p

    # Multiplicative noise terms, i.i.d. uniform on [1-noise, 1+noise]
    noise_vector = rng.uniform(low=1 - noise, high=1 + noise, size=[n, d])

    # Noise-free cost: shape [d, n] before the transpose, [n, d] after
    scaled_B = (1 / np.sqrt(p)) * B_star
    clean_cost = ((scaled_B @ X.T + 3) ** degree + 1).T
    C = np.multiply(clean_cost, noise_vector)

    return X, C

In [73]:
# Seed NumPy's global random state so this cell produces the same X and C on
# every Restart & Run All (generate_data samples from NumPy's random routines).
SEED = 0
np.random.seed(SEED)

# Example problem configuration
n = 10                             # number of samples
p = 5                              # number of features
grid_dim = 5                       # 5x5 grid -> 2*5*4 = 40 cost entries
sigma = [0.1, 0.2, 0.3, 0.4, 0.5]  # one scale value per feature (length p)
noise = 0.25                       # noise factors drawn from [0.75, 1.25]
degree = 3

X, C = generate_data(n, p, grid_dim, sigma, noise, degree)

# Sanity-check the shapes before anything downstream relies on them
assert X.shape == (n, p)
assert C.shape == (n, 2 * grid_dim * (grid_dim - 1))

print("X dim: ", X.shape)
print("C dim: ", C.shape)

X dim:  (10, 5)
C dim:  (10, 40)


In [74]:
# Display the generated feature matrix; each row is one sample of p features.
X

array([[-0.07882851, -0.08097164,  0.19025665,  0.3661684 , -0.07774557],
       [-0.03441958,  0.18339376, -0.02207518,  0.2406906 , -0.38958077],
       [-0.00541493, -0.23453103,  0.53276545, -0.06198112,  0.17092571],
       [-0.07905307,  0.27871336,  0.1934195 , -0.1866452 , -0.45082534],
       [-0.05462957, -0.111728  ,  0.1038275 ,  0.27289704, -0.20711424],
       [ 0.14700422,  0.49586077,  0.01113523, -0.50450713, -0.06085766],
       [-0.00149678,  0.25798228,  0.25393751,  0.23118585,  0.13068553],
       [-0.00070279, -0.09695708,  0.19010743, -0.42371698, -0.17854467],
       [-0.3164389 ,  0.01316082, -0.20452433, -0.50815685,  0.28033054],
       [-0.08266884, -0.03055873, -0.09924626, -0.24853259,  0.223559  ]])

In [75]:
# Display the generated cost matrix; each row is the cost vector of one sample.
C

array([[23.29503102, 31.2344417 , 26.45259   , 32.37760108, 28.02345402,
        31.8741075 , 28.0862025 , 26.59901451, 33.46560691, 24.98400841,
        25.28160518, 24.32975422, 35.59843477, 31.40111219, 25.34720578,
        23.91137661, 28.47352971, 36.64934795, 37.69800066, 29.82519258,
        22.52707213, 29.47185934, 25.65357127, 29.72177995, 20.98729878,
        21.20989524, 24.10762891, 30.43362387, 34.12615556, 36.14225419,
        36.96146969, 34.36315936, 33.52858453, 22.42362468, 37.64741601,
        26.5953919 , 32.6362195 , 24.67748047, 32.52823224, 33.97772225],
       [25.09184123, 36.5242632 , 25.29098038, 34.78195092, 33.70368346,
        26.83365746, 38.64485419, 35.51491455, 24.57864304, 20.81104173,
        36.08585103, 31.97673223, 19.5138267 , 23.20860433, 22.33700652,
        22.81659095, 21.17991037, 38.67231788, 25.19190358, 25.32453814,
        26.08333576, 38.61134581, 24.97524452, 24.35246927, 30.71658437,
        23.65830831, 31.36002463, 34.74926491, 21.

In [None]:
|