#### Entry-wise Computation - Following the Formulas in the Paper

In [1]:
"""
see pages 3-4 of the paper

STEP 0
    x and x' are 2-dimensional points of shape (2,)
    K_0(x, x') = sigma_0(x, x') = <x, x'> = (x^T).dot(x')

STEP 1a
    B_1 = [sigma_0(x, x),  sigma_0(x, x'),
           sigma_0(x, x'), sigma_0(x', x')]
           
STEP 1b
    sigma_1(x, x') = 2 E(u, v)~N(0, B_1) [activation(u) activation(v)]
    sigma_1'(x, x') = 2 E(u, v)~N(0, B_1) [activation'(u) activation'(v)]

STEP 1c
    K_1(x, x') = sigma_1(x, x') + K_0(x, x') sigma_1'(x, x')

"""
#------------
# step 0
#------------   
def K_0(x, y):
    """Layer-0 kernel (STEP 0): K_0(x, y) = <x, y> = x^T y.

    Evaluates the same inner product as ``sigma_0``; the two names are kept
    separate to mirror the paper's notation.
    """
    x_t = x.T
    inner = np.dot(x_t, y)
    return inner

def sigma_0(x, y):
    """Layer-0 covariance (STEP 0): sigma_0(x, y) = <x, y> = x^T y."""
    lhs = x.T
    result = np.dot(lhs, y)
    return result

#------------
# step 1a
#------------
def B_1(x, y):
    """Covariance matrix for STEP 1a.

    Returns the block matrix

        [[sigma_0(x, x), sigma_0(x, y)],
         [sigma_0(x, y), sigma_0(y, y)]]

    For 1-D inputs each entry is a scalar, so the result is 2x2.
    """
    s_xx = sigma_0(x, x)
    s_xy = sigma_0(x, y)
    s_yy = sigma_0(y, y)
    return np.block([[s_xx, s_xy],
                     [s_xy, s_yy]])

#------------
# step 1b
#------------
def sigma_1_and_prime(x, y, activation, k, num_samples=1000):
    """Estimate sigma_1(x, y) and sigma_1'(x, y) by Monte Carlo sampling.

    Implements STEP 1b: with (u, v) ~ N(0, B_1(x, y)),

        sigma_1(x, y)  = 2 * E[activation(u) * activation(v)]
        sigma_1'(x, y) = 2 * E[activation'(u) * activation'(v)]

    Each expectation is the mean over its own independent batch of
    ``num_samples`` draws from NumPy's global random generator.

    Parameters
    ----------
    x, y : numpy.ndarray
        The two input points; passed to ``B_1`` to build the covariance.
    activation : str
        Activation name, one of ``'relu_k'``, ``'sin'``, ``'cos'``.
    k
        Extra argument passed to the activation and its derivative
        (presumably the power used by ``relu_k`` -- TODO confirm).
    num_samples : int, optional
        Number of (u, v) draws per expectation (default 1000).

    Returns
    -------
    tuple
        ``(sigma_1, sigma_1_prime)``.

    Raises
    ------
    KeyError
        If ``activation`` is not one of the supported names.
    ValueError
        If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    activation_map = {
        'relu_k': [relu_k, d_relu_k],
        'sin': [sin, cos],
        'cos': [cos, d_cos]
    }
    # Resolve the activation first so an unknown name fails before sampling.
    activation_func, d_activation = activation_map[activation]

    # (u, v) is always a pair, so the mean must have length 2 to match the
    # 2x2 covariance from B_1, independent of the dimension of x.
    mean = np.zeros(2)
    cov = B_1(x, y)

    # Separate sample batches for the activation and derivative expectations.
    sample_1 = np.random.multivariate_normal(mean, cov, num_samples)  # (num_samples, 2)
    u1, v1 = sample_1[:, 0], sample_1[:, 1]  # (num_samples,)

    sample_2 = np.random.multivariate_normal(mean, cov, num_samples)  # (num_samples, 2)
    u2, v2 = sample_2[:, 0], sample_2[:, 1]  # (num_samples,)

    expectation_1 = np.mean(activation_func(u1, k) * activation_func(v1, k))
    expectation_2 = np.mean(d_activation(u2, k) * d_activation(v2, k))

    sigma_1 = 2 * expectation_1
    sigma_1_prime = 2 * expectation_2

    return sigma_1, sigma_1_prime

#------------
# step 1c
#------------
def K_1(x, y, activation, k=1):
    """Layer-1 kernel (STEP 1c): K_1 = sigma_1 + K_0 * sigma_1'.

    ``sigma_1`` and ``sigma_1'`` are Monte Carlo estimates from
    ``sigma_1_and_prime``, so repeated calls may return slightly different
    values.
    """
    s1, s1_prime = sigma_1_and_prime(x, y, activation, k)
    k0 = K_0(x, y)
    return s1 + k0 * s1_prime



In [2]:
"""
STEP 2a
    B_2 = [sigma_1(x, x),  sigma_1(x, x'),
           sigma_1(x, x'), sigma_1(x', x')]
           
STEP 2b
    sigma_2(x, x') = 2 E(u, v)~N(0, B_2) [activation(u) activation(v)]
    sigma_2'(x, x') = 2 E(u, v)~N(0, B_2) [activation'(u) activation'(v)]

STEP 2c
    K_2(x, x') = sigma_2(x, x') + K_1(x, x') sigma_2'(x, x')

"""

#------------
# step 2a
#------------
def B_2(x, y, activation_h1, k):
    """Covariance matrix for STEP 2a.

    Returns the block matrix

        [[sigma_1(x, x), sigma_1(x, y)],
         [sigma_1(x, y), sigma_1(y, y)]]

    Each sigma_1 entry is a Monte Carlo estimate from ``sigma_1_and_prime``
    using the first-layer activation ``activation_h1``; only the
    sigma_1 value (not sigma_1') is kept.
    """
    # Evaluated in the order (x, x), (x, y), (y, y): each call draws from the
    # global NumPy RNG, so this order fixes which samples feed each entry.
    s_xx, s_xy, s_yy = (
        sigma_1_and_prime(a, b, activation_h1, k)[0]
        for a, b in ((x, x), (x, y), (y, y))
    )
    return np.block([[s_xx, s_xy],
                     [s_xy, s_yy]])


#------------
# step 2b
#------------
def sigma_2_and_prime(x, y, activation_h2, k, activation_h1='relu_k',
                      num_samples=1000):
    """Estimate sigma_2(x, y) and sigma_2'(x, y) by Monte Carlo sampling.

    Implements STEP 2b: with (u, v) ~ N(0, B_2(x, y)),

        sigma_2(x, y)  = 2 * E[activation_h2(u) * activation_h2(v)]
        sigma_2'(x, y) = 2 * E[activation_h2'(u) * activation_h2'(v)]

    Each expectation is the mean over its own independent batch of
    ``num_samples`` draws from NumPy's global random generator. B_2 is itself
    a Monte Carlo estimate, so it may be slightly non-positive-semidefinite,
    in which case ``np.random.multivariate_normal`` emits a warning.

    Parameters
    ----------
    x, y : numpy.ndarray
        The two input points.
    activation_h2 : str
        Second-layer activation, one of ``'relu_k'``, ``'sin'``, ``'cos'``.
    k
        Extra argument passed to the activations and to ``B_2``.
    activation_h1 : str, optional
        First-layer activation that ``B_2`` uses to compute its sigma_1
        entries (default ``'relu_k'``, matching ``calc_NTK``'s default).
    num_samples : int, optional
        Number of (u, v) draws per expectation (default 1000).

    Returns
    -------
    tuple
        ``(sigma_2, sigma_2_prime)``.

    Raises
    ------
    KeyError
        If an activation name is not one of the supported names.
    ValueError
        If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    activation_map = {
        'relu_k': [relu_k, d_relu_k],
        'sin': [sin, cos],
        'cos': [cos, d_cos]
    }
    # Resolve the activation first so an unknown name fails before sampling.
    activation_func, d_activation = activation_map[activation_h2]

    # (u, v) is always a pair, so the mean has length 2 to match the 2x2
    # covariance, independent of the dimension of x.
    mean = np.zeros(2)
    # B_2 needs the first-layer activation and k to build its sigma_1 entries.
    cov = B_2(x, y, activation_h1, k)

    # Separate sample batches for the activation and derivative expectations.
    sample_1 = np.random.multivariate_normal(mean, cov, num_samples)  # (num_samples, 2)
    u1, v1 = sample_1[:, 0], sample_1[:, 1]  # (num_samples,)

    sample_2 = np.random.multivariate_normal(mean, cov, num_samples)  # (num_samples, 2)
    u2, v2 = sample_2[:, 0], sample_2[:, 1]  # (num_samples,)

    expectation_1 = np.mean(activation_func(u1, k) * activation_func(v1, k))
    expectation_2 = np.mean(d_activation(u2, k) * d_activation(v2, k))

    sigma_2 = 2 * expectation_1
    sigma_2_prime = 2 * expectation_2

    return sigma_2, sigma_2_prime

#------------
# step 2c
#------------
def K_2(x, y, activation_h1, activation_h2, k=1):
    """Layer-2 kernel (STEP 2c): K_2 = sigma_2 + K_1 * sigma_2'.

    Parameters
    ----------
    x, y : numpy.ndarray
        The two input points.
    activation_h1 : str
        First-layer activation (used for K_1 and for B_2).
    activation_h2 : str
        Second-layer activation (used for sigma_2 and sigma_2').
    k
        Extra activation argument, used consistently for both layers.

    Returns
    -------
    Monte Carlo estimate of K_2(x, y).
    """
    # sigma_2 / sigma_2' are sampled from N(0, B_2), which is built on sigma_1.
    sigma_2, sigma_2_prime = sigma_2_and_prime(
        x, y, activation_h2, k, activation_h1=activation_h1
    )
    # Forward k so the first layer uses the same parameter B_2 was built with.
    return sigma_2 + K_1(x, y, activation=activation_h1, k=k) * sigma_2_prime


#### Build Matrices One by One

In [3]:
# quick entry-wise check
# temp_K = np.zeros((num_inputs, num_inputs))

# for i in range(num_inputs):
#     for j in range(num_inputs):
#         x = X[:, i]
#         y = X[:, j]
#         temp_K[i][j] = np.dot(x, y)

# np.all(K_0 - temp_K < 0.01) # true

In [4]:
# step 1a: calculate B1

# # first calculate B_1 and then we sample
# B1 = np.zeros((2*num_inputs, 2*num_inputs))

# # initialize sigma_0(x, x), (x, y) and (y, y)
# s0_xx = np.zeros((num_inputs, num_inputs))
# s0_xy = np.zeros((num_inputs, num_inputs))
# s0_yy = np.zeros((num_inputs, num_inputs))
    
# for i in range(num_inputs):
#     for j in range(num_inputs):
#         x = X[:, i]
#         y = X[:, j]
        
#         s0_xx[i][j] = sigma_0(x, x)
#         s0_xy[i][j] = sigma_0(x, y)
#         s0_yy[i][j] = sigma_0(y, y)

# # concatenate to build B1
# v1 = np.vstack((s0_xx, s0_xy))
# v2 = np.vstack((s0_xy, s0_yy))

# B1 = np.hstack((v1, v2))
# check(B1)
# B1.shape

In [5]:
# execute
  
def calc_NTK(activation_h1='relu_k', activation_h2='sin', k=1):
    """Build the kernel matrix of K_2 over the points returned by init_inputs().

    Entry (i, j) is ``K_2(x[:, i], x[:, j], activation_h1, activation_h2, k)``.
    K_2 is symmetric in its two points, so each unordered pair is evaluated
    once and mirrored. This halves the number of Monte Carlo evaluations and
    keeps the estimated matrix exactly symmetric.

    Parameters
    ----------
    activation_h1 : str
        First-layer activation name.
    activation_h2 : str
        Second-layer activation name.
    k
        Extra activation argument forwarded to K_2.

    Returns
    -------
    numpy.ndarray
        Square kernel matrix of shape (num_inputs, num_inputs).
    """
    # Presumably (dim, num_inputs) with one point per column -- the indexing
    # x[:, i] below treats columns as points (original note: (2, 100)).
    x = init_inputs()
    num_inputs = x.shape[1]

    kernel = np.zeros((num_inputs, num_inputs))

    for i in range(num_inputs):
        for j in range(i, num_inputs):
            value = K_2(x[:, i], x[:, j], activation_h1, activation_h2, k)
            kernel[i, j] = value
            kernel[j, i] = value

    check(kernel)

    return kernel