In [7]:
import numpy as np
import pandas as pd
from scipy import stats
from scipy import linalg
from scipy import special
import matplotlib.pyplot as plt
import arviz as ar
import seaborn as sns

In [8]:
from scipy.optimize import minimize

In [9]:
# l_2 distance squared
def l2_sq(a, b):
    temp = np.array([v**2 for v in a - b])
    return np.sum(temp)

# STANDARD Squared Exponential Kernel function (isotropic scale parameter 'l') for feature vectors x1, x2
def K_sqexp(x1, x2, l):
        return np.exp(-l2_sq(x1, x2)/(l**2))
    
# creating the baseline sqexp kernel matrix
def create_sqexp_mat(X_mat, l):
    n = len(X_mat)
    out = np.zeros(shape = (n, n))
    for i in range(n - 1):
        for j in range(i + 1, n):
            out[j][i] = out[i][j] = K_sqexp(x1 = X_mat[i], x2 = X_mat[j], l = l)
    for i in range(n):
        out[i][i] = K_sqexp(x1 = X_mat[i], x2 = X_mat[i], l = l)
    return out

def create_full_K1(X_mat, d_vec, l):
    d = np.atleast_2d(d_vec)
    return np.multiply(create_sqexp_mat(X_mat, l), 1 + np.matmul(d.T, d))

$$ K_{l} \left( \left( \boldsymbol{x}_i, d_i \right), \left( \boldsymbol{x}_j, d_j \right) \right) 
:= \left( 1 + d_i d_j \right) \exp \left[ - \frac{\| \boldsymbol{x}_i - \boldsymbol{x}_j \|_2^2}{2 l^2} \right], 
\; l > 0 $$

In [47]:
n = 100
p = 10
sigma = 1
X = np.random.normal(size = (n, p))
d_true = np.ones(n)/5
Y_true = np.random.multivariate_normal(mean = np.zeros(n), 
                                      cov = sigma * np.identity(n) + create_full_K1(X, d_true, l = 1.0e+1))
                                      

def obj_func(d):
    cov = sigma * np.identity(n) + create_full_K1(X, d, l = 1.0e+1)
    return np.matmul(Y_true, np.matmul(linalg.inv(cov), Y_true))

In [23]:
# cons = ({'type': 'ineq', 'fun': lambda x: np.max(x) < 1}, {'type': 'ineq', 'fun': lambda x: np.min(x) > 0})

In [48]:
bnds = tuple([(0, 1)]*n)

In [49]:
res = minimize(fun = obj_func, x0 = np.ones(n)/2, bounds = bnds)

In [50]:
res.success

True

In [51]:
res.x

array([0.80800042, 0.47780582, 0.        , 0.84953251, 0.        ,
       0.29352083, 1.        , 0.02201886, 0.37367161, 0.70611766,
       0.58524308, 0.57827076, 0.24825822, 0.44412247, 0.57491711,
       1.        , 1.        , 1.        , 0.29396299, 1.        ,
       0.51642762, 0.04103116, 0.        , 0.38769305, 0.63550345,
       0.48276881, 1.        , 0.56865054, 0.53038083, 0.        ,
       0.68087194, 0.7118192 , 1.        , 0.21646218, 0.48720035,
       0.61779547, 0.69615012, 0.41216263, 0.77674795, 0.20415616,
       0.73685883, 0.12987052, 0.74901395, 0.89803184, 0.        ,
       0.        , 0.70367013, 0.27546854, 0.80066662, 0.38771879,
       1.        , 0.23502483, 0.        , 0.7396996 , 0.        ,
       0.22552309, 0.45766418, 0.48667348, 0.81262291, 0.23305441,
       0.40520605, 0.17571559, 0.7358717 , 0.49571305, 0.08433276,
       0.        , 0.82317454, 0.89198167, 0.87601654, 1.        ,
       0.        , 0.32054091, 0.59833697, 0.106274  , 0.40164