In [1]:
import cupy as cp
import numpy as np
import scipy.linalg

In [24]:
#CUDA kernel for convolution operation
conv3 = cp.RawKernel(r'''
extern "C" __global__
void conv3(const float s[32][32][32][32], float t[32][32][32][32])
{
    int x1 = threadIdx.x + blockIdx.x - 31;
    int y1 = threadIdx.y + blockIdx.y - 31;
    int x2 = threadIdx.x;
    int y2 = threadIdx.y;

    __shared__ float d[32 + 2][32 + 2];
    if (x2 == 0){
        d[0][y2 + 1] = d[33][y2 + 1] = 0;
        if (x2 == 0 && y2 == 0)
            d[0][0] = d[0][33] = d[33][0] = d[33][33] = 0; 
    }
    if (y2 == 0){
        d[x2 + 1][0] = d[x2 + 1][33] = 0;
    }

    if (x1 < 0 || x1 > 31 || y1 < 0 || y1 > 31){
        d[x2 + 1][y2 + 1] = 0;
        return;
    }
    else
        d[x2 + 1][y2 + 1] = s[x1][y1][x2][y2];
    __syncthreads();

    t[x1][y1][x2][y2] = d[x2][y2] + d[x2][y2 + 1] + d[x2][y2 + 2]
                      + d[x2 + 1][y2] + d[x2 + 1][y2 + 1] + d[x2 + 1][y2 + 2]
                      + d[x2 + 2][y2] + d[x2 + 2][y2 + 1] + d[x2 + 2][y2 + 2];

}''', 'conv3')
conv_blocks = (63, 63)
conv_threads = (32, 32)

In [25]:
#CUDA kernel for convolution operation
conv3x = cp.RawKernel(r'''
extern "C" __global__
void conv3x(const float s[32][32][32][32], float t[32][32][32][32])
{
    int x1 = threadIdx.x + blockIdx.x - 31;
    int y1 = threadIdx.y + blockIdx.y - 31;
    int x2 = threadIdx.x;
    int y2 = threadIdx.y;

    __shared__ float d[32 + 2][32 + 2];
    if (x2 == 0){
        d[0][y2 + 1] = d[33][y2 + 1] = 0;
        if (x2 == 0 && y2 == 0)
            d[0][0] = d[0][33] = d[33][0] = d[33][33] = 0; 
    }
    if (y2 == 0){
        d[x2 + 1][0] = d[x2 + 1][33] = 0;
    }

    if (x1 < 0 || x1 > 31 || y1 < 0 || y1 > 31){
        d[x2 + 1][y2 + 1] = 0;
        return;
    }
    else
        d[x2 + 1][y2 + 1] = s[x1][y1][x2][y2];
    __syncthreads();

    t[x1][y1][x2][y2] = d[x2 + 1][y2 + 1];

}''', 'conv3x')
conv_blocks = (63, 63)
conv_threads = (32, 32)

In [3]:
#CUDA kernel for activation
trans = cp.RawKernel(r'''
extern "C" __global__
void trans(float s[32][32][32][32], float t[32][32][32][32], const float l[32][32], const float r[32][32], const float il[32][32], const float ir[32][32])
{
	int x1 = blockIdx.x;
	int y1 = blockIdx.y;
	int x2 = threadIdx.x + ((blockIdx.z >> 2) << 3);
	int y2 = threadIdx.y + ((blockIdx.z & 3) << 3);
	float S = s[x1][y1][x2][y2], T = t[x1][y1][x2][y2], L = l[x1][y1], R = r[x2][y2], iL = il[x1][y1], iR = ir[x2][y2];
	S = S * iL * iR;
	float BS = (S * (3.141592654f - acosf(max(min(S, 1.0f), -1.0f))) + sqrtf(1.0f - min(S * S, 1.0f))) * L * R / 28.274333882308138f;
	S = (3.141592654f - acosf(max(min(S, 1.0f), -1.0f))) / 28.274333882308138;
	t[x1][y1][x2][y2] = T * S + BS;
	s[x1][y1][x2][y2] = BS;

}''', 'trans')
trans_blocks = (32, 32, 16)
trans_threads = (8, 8)

In [4]:
x = cp.random.randn(3,1024)
x

array([[-0.53973204,  0.83741791, -0.53801839, ...,  2.32454632,
        -1.01789878,  0.81565338],
       [-0.25691116, -1.38140557,  1.72760458, ...,  1.59025383,
         1.26879762, -0.91242344],
       [-1.3317491 , -0.06516584,  0.65451939, ...,  0.48914179,
         1.31815136,  0.76414301]])

In [11]:
(x[:,0]*x[:,0]).sum()

array(2.13086967)

In [18]:
RL = [1.0, ]
iRL = [1.0, ]

S = cp.matmul(x.T, x).reshape(32, 32, 32, 32)
S

array([[[[ 2.13086967e+00, -1.02982270e-02, -1.02511074e+00, ...,
           1.15552420e+00,  2.20233875e+00, -4.78155822e-01],
         [-5.15766443e-01, -3.51304443e-02,  1.63977281e+00, ...,
          -1.02847005e+00, -7.32748761e-01,  1.65490703e+00],
         [ 5.18328035e-02,  2.96111827e+00, -1.26762034e-01, ...,
          -1.67646442e+00, -3.51223384e-02,  1.59984818e+00],
         ...,
         [-1.92348527e-01, -7.31122256e-01, -1.20223233e+00, ...,
           7.08418603e-01,  2.29864667e+00,  2.51987909e-01],
         [ 5.14413292e-01, -2.63641976e+00,  1.94228061e+00, ...,
          -2.27560769e+00,  7.40123642e-01,  3.83214307e-01],
         [-2.58714357e+00, -4.35147394e-01, -1.94246752e+00, ...,
          -2.31460022e+00, -1.53202255e+00, -1.22346926e+00]],

        [[-1.02982270e-02,  2.61379669e+00, -2.87972113e+00, ...,
          -2.06823677e+00, -2.25656292e+00, -1.45348551e+00],
         [-1.39591090e+00,  8.79003003e-01,  8.98049563e-01, ...,
           1.40326257e

In [26]:
R = cp.zeros(S.shape)
conv3x(conv_blocks, conv_threads, (S, R))
R

array([[[[ 2.13086967e+00, -1.02982270e-02, -1.02511074e+00, ...,
           1.15552420e+00,  2.20233875e+00, -4.78155822e-01],
         [-5.15766443e-01, -3.51304443e-02,  1.63977281e+00, ...,
          -1.02847005e+00, -7.32748761e-01,  1.65490703e+00],
         [ 5.18328035e-02,  2.96111827e+00, -1.26762034e-01, ...,
          -1.67646442e+00, -3.51223384e-02,  1.59984818e+00],
         ...,
         [-1.92348527e-01, -7.31122256e-01, -1.20223233e+00, ...,
           7.08418603e-01,  2.29864667e+00,  2.51987909e-01],
         [ 5.14413292e-01, -2.63641976e+00,  1.94228061e+00, ...,
          -2.27560769e+00,  7.40123642e-01,  3.83214307e-01],
         [-2.58714357e+00, -4.35147394e-01, -1.94246752e+00, ...,
          -2.31460022e+00, -1.53202255e+00, -1.22346926e+00]],

        [[-1.02982270e-02,  2.61379669e+00, -2.87972113e+00, ...,
          -2.06823677e+00, -2.25656292e+00, -1.45348551e+00],
         [-1.39591090e+00,  8.79003003e-01,  8.98049563e-01, ...,
           1.40326257e

In [34]:
for i in range(32):
    for j in range(32):
        for k in range(32):
            for l in range(32):
                if cp.abs(R[i,j,k,l]) < 1e-10:
                    break;
i,j,k,l

(31, 31, 31, 0)

In [85]:
S[15,0,0,:],R[15,0,0,:]

(array([ 1.1992251 , -0.25749797,  1.1784664 ,  1.2246039 , -1.5664622 ,
         0.10809969, -0.05211137, -2.05499   ,  0.17794305, -0.30185643,
        -1.5898674 ,  1.3846902 ,  0.45092925, -1.3599117 ,  1.3786889 ,
         0.39564547, -1.6203386 ,  1.1449941 ,  0.14121003, -1.7009188 ,
         1.0324929 ,  0.16203845, -1.4555358 ,  1.051387  ,  0.1779365 ,
        -1.5620178 ,  0.8273582 , -0.04528894, -1.6886771 ,  0.8753021 ,
         0.2884852 , -1.2867539 ], dtype=float32),
 array([ 0.23272736,  0.64044582, -0.91522781, -0.62149487, -0.32385745,
         1.88879659,  1.12220332,  1.73998564, -2.09439845, -0.12520782,
         0.46036399,  0.53395544,  1.1167475 ,  1.05663532,  0.63739772,
         0.55586611,  0.65492732,  0.17889092, -0.22593979, -1.2491062 ,
         0.98337806, -0.69307177, -0.67549999, -0.32664702, -0.0732497 ,
        -0.21690934, -0.46116725, -0.06820225, -0.49139129, -0.41209389,
        -0.20319792, -0.82913862]))

In [19]:
T = cp.zeros(S.shape)
conv3(conv_blocks, conv_threads, (S, T))
T

array([[[[ 9.38327175e+034,  4.35338645e+276, -3.36590460e+156, ...,
          -1.05172050e+195,  2.00911095e+209, -1.61944694e+002],
         [-1.57382001e+067,  5.84634646e+183,  8.25526469e+156, ...,
          -2.25546941e+238, -1.62396315e+001,  1.69374365e+002],
         [ 1.06953424e+167, -2.50527334e+181,  9.37940503e-007, ...,
           3.90281917e+285, -7.85176311e+001,  2.41989186e+002],
         ...,
         [-1.96132952e+292,  3.89828210e-012, -1.27398741e+002, ...,
           1.80104603e+111, -1.47376445e+107,  1.52817492e+002],
         [-8.64834313e+258, -6.09317499e-011,  3.04993713e+177, ...,
           1.17221800e+283, -4.71414302e+235,  8.17749811e+176],
         [ 2.45817572e+243,  3.14693011e+077, -4.88797752e+279, ...,
          -1.68978542e+105,  4.55093566e+141, -1.22346950e+000]],

        [[ 9.41138294e+240, -2.07148669e+270,  4.35338729e+276, ...,
          -1.89949184e+258, -1.05172045e+195,  2.00911009e+209],
         [ 3.39592775e+272, -8.34384304e+060, 

In [49]:
S.sum()

array(812.5949202)

In [89]:
img = X[0]
S = cp.matmul(img.T, img).reshape(32, 32, 32, 32)
S

array([[[[ 1.13238096e+00,  8.13326299e-01,  1.03628063e+00, ...,
           6.54789627e-01,  2.30517149e-01,  2.72549242e-01],
         [ 8.97159934e-01,  3.28247964e-01,  3.32177728e-01, ...,
          -1.49122979e-02,  6.50195658e-01,  5.13034798e-02],
         [ 1.61054254e-01,  1.00548649e+00,  1.12059839e-01, ...,
          -1.16612874e-01, -1.36305153e-01,  4.40546989e-01],
         ...,
         [ 9.18131888e-01,  1.46909547e+00,  6.29363060e-01, ...,
          -3.40207946e-03,  2.99524307e-01,  1.18311119e+00],
         [-3.24543595e-01, -1.24958821e-01,  5.96916497e-01, ...,
           1.00668788e+00,  9.09450412e-01,  1.20102525e+00],
         [ 1.96128368e+00,  6.56342149e-01,  1.12117875e+00, ...,
           5.42275429e-01,  9.74691331e-01, -6.45109296e-01]],

        [[ 8.13326299e-01,  1.82068825e+00,  9.32256699e-01, ...,
           3.61388117e-01, -6.32686734e-01,  1.06509364e+00],
         [ 5.24712086e-01, -5.74983060e-01,  1.14924002e+00, ...,
           1.56861126e

In [91]:
np.trace(S[:2,:2,:2,:2].reshape(4,4))

array(4.5404377, dtype=float32)

In [79]:
T = cp.zeros(S.shape)
conv3(conv_blocks, conv_threads, (S, T))
T

array([[[[ 2.31900841e-01,  3.99505322e+03,  4.96448231e+00, ...,
          -1.83758907e-03,  5.13708292e-05,  7.49726661e-04],
         [ 1.26248901e-03,  1.77952983e-02,  3.32431106e+01, ...,
           3.46573058e+02, -1.40230174e-08,  1.08150896e-04],
         [ 8.45301879e+02,  2.45847732e-01,  2.59256220e+00, ...,
           5.09685949e-02,  4.16210022e+02,  2.67433114e-02],
         ...,
         [ 4.09300328e+00,  1.82955976e+04,  2.94332564e-01, ...,
           7.00831505e-03,  3.11130763e+00,  2.42482471e+03],
         [ 9.43808136e+01,  1.73095124e+02,  1.43599281e+05, ...,
           5.25753612e+03,  1.56151061e-03,  2.12336207e+00],
         [ 1.43439917e+03,  7.58551899e-03,  5.08357362e-02, ...,
           2.33224817e-02,  5.59615567e-02, -1.35963520e-04]],

        [[ 2.32689552e+01,  1.32962971e+01,  1.13118641e+05, ...,
           1.33078686e+03, -1.34417023e-02,  7.24341508e-12],
         [ 1.52740645e+04,  3.48843622e+00,  3.19320346e+00, ...,
           1.17218342e

In [86]:
T[16,0,0,:]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

# Run All

In [66]:
import keras
from tqdm.notebook import tqdm

In [52]:
cifar = keras.datasets.cifar10
(X_train, y_train), (X_test, y_test) = cifar.load_data()

In [53]:
X_train = (X_train / 255.0).astype(np.float32) 
X_test = (X_test / 255.0).astype(np.float32) 
mean = X_train.mean(axis = (0, 2, 3)) 
std = X_train.std(axis = (0, 2, 3)) 
X_train = (X_train - mean[:, None, None]) / std[:, None, None]
X_test = (X_test - mean[:, None, None]) / std[:, None, None]

In [54]:
#CUDA kernel for convolution operation
conv3 = cp.RawKernel(r'''
extern "C" __global__
void conv3(const float s[32][32][32][32], float t[32][32][32][32])
{
    int x1 = threadIdx.x + blockIdx.x - 31;
    int y1 = threadIdx.y + blockIdx.y - 31;
    int x2 = threadIdx.x;
    int y2 = threadIdx.y;

    __shared__ float d[32 + 2][32 + 2];
    if (x2 == 0){
        d[0][y2 + 1] = d[33][y2 + 1] = 0;
        if (x2 == 0 && y2 == 0)
            d[0][0] = d[0][33] = d[33][0] = d[33][33] = 0; 
    }
    if (y2 == 0){
        d[x2 + 1][0] = d[x2 + 1][33] = 0;
    }

    if (x1 < 0 || x1 > 31 || y1 < 0 || y1 > 31){
        d[x2 + 1][y2 + 1] = 0;
        return;
    }
    else
        d[x2 + 1][y2 + 1] = s[x1][y1][x2][y2];
    __syncthreads();

    t[x1][y1][x2][y2] = d[x2][y2] + d[x2][y2 + 1] + d[x2][y2 + 2]
                      + d[x2 + 1][y2] + d[x2 + 1][y2 + 1] + d[x2 + 1][y2 + 2]
                      + d[x2 + 2][y2] + d[x2 + 2][y2 + 1] + d[x2 + 2][y2 + 2];

}''', 'conv3')
conv_blocks = (63, 63)
conv_threads = (32, 32)

In [55]:
#CUDA kernel for activation
trans = cp.RawKernel(r'''
extern "C" __global__
void trans(float s[32][32][32][32], float t[32][32][32][32], const float l[32][32], const float r[32][32], const float il[32][32], const float ir[32][32])
{
    int x1 = blockIdx.x;
    int y1 = blockIdx.y;
    int x2 = threadIdx.x + ((blockIdx.z >> 2) << 3);
    int y2 = threadIdx.y + ((blockIdx.z & 3) << 3);
    float S = s[x1][y1][x2][y2], T = t[x1][y1][x2][y2], L = l[x1][y1], R = r[x2][y2], iL = il[x1][y1], iR = ir[x2][y2];
    S = S * iL * iR;
    float BS = (S * (3.141592654f - acosf(max(min(S, 1.0f), -1.0f))) + sqrtf(1.0f - min(S * S, 1.0f))) * L * R / 28.274333882308138f;
    S = (3.141592654f - acosf(max(min(S, 1.0f), -1.0f))) / 28.274333882308138;
    t[x1][y1][x2][y2] = T * S + BS;
    s[x1][y1][x2][y2] = BS;

}''', 'trans')
trans_blocks = (32, 32, 16)
trans_threads = (8, 8)

In [56]:
#Calculate diagonal entries of $\Sigma^{(h)}(x, x)$ and their reciprocals. See Section 4.3 in our paper. 
def xx(x):
    RL = [1.0, ]
    iRL = [1.0, ]

    S = cp.matmul(x.T, x).reshape(32, 32, 32, 32)
    conv3(conv_blocks, conv_threads, (S, S))
    T = cp.zeros((32, 32, 32, 32), dtype = cp.float32)
    if not fix:
        T += S

    for i in range(1, d - 1):
        L = cp.sqrt(cp.diag(S.reshape(1024, 1024)).reshape(32, 32))
        iL = 1.0 / L
        RL.append(L)
        iRL.append(iL)
        trans(trans_blocks, trans_threads, (S, T, L, L, iL, iL))
        conv3(conv_blocks, conv_threads, (S, S))
        conv3(conv_blocks, conv_threads, (T, T))

    L = cp.sqrt(cp.diag(S.reshape(1024, 1024)).reshape(32, 32))
    iL = 1.0 / L
    RL.append(L)
    iRL.append(iL)
    trans(trans_blocks, trans_threads, (S, T, L, L, iL, iL))

    if fix:
        T -= S
    return RL, iRL

#Caclulate the kernel value of x and z.
#Lx and Lz are diagonal entries of $\Sigma^{(h)}(x, x)$ and $\Sigma^{(h)}(z, z)$. 
#iLx and iLz are reciprocals of diagonal entries of $\Sigma^{(h)}(x, x)$ and $\Sigma^{(h)}(z, z)$. 
def xz(x, z, Lx, Lz, iLx, iLz):
    S = cp.matmul(x.T, z).reshape(32, 32, 32, 32)
    conv3(conv_blocks, conv_threads, (S, S))
    T = cp.zeros((32, 32, 32, 32), dtype = cp.float32)
    if not fix:
        T += S

    for i in range(1, d - 1):
        trans(trans_blocks, trans_threads, (S, T, Lx[i], Lz[i], iLx[i], iLz[i]))
        conv3(conv_blocks, conv_threads, (S, S))
        conv3(conv_blocks, conv_threads, (T, T))

    trans(trans_blocks, trans_threads, (S, T, Lx[-1], Lz[-1], iLx[-1], iLz[-1]))

    if fix:
        T -= S
    return cp.mean(T) if gap else cp.trace(T.reshape(1024, 1024))

In [94]:
d = 5
gap = True
fix = True

In [61]:
#Load CIFAR-10.
X = np.concatenate((X_train, X_test), axis = 0)
N = X.shape[0]
N_train = X_train.shape[0]
N_test = X_test.shape[0]
X = cp.asarray(X).reshape(-1, 3, 1024)

In [95]:
%%time

#Calculate diagonal entries.
L = []
iL = []
for i in range(N):
    Lx, iLx = xx(X[i])
    L.append(Lx)
    iL.append(iLx)

Exception ignored in: <function tqdm.__del__ at 0x0000024438E8C1F0>
Traceback (most recent call last):
  File "C:\Users\E2\anaconda3\envs\cuda\lib\site-packages\tqdm\std.py", line 1145, in __del__
    self.close()
  File "C:\Users\E2\anaconda3\envs\cuda\lib\site-packages\tqdm\notebook.py", line 283, in close
    self.disp(bar_style='danger', check_delay=False)
AttributeError: 'tqdm_notebook' object has no attribute 'disp'


Wall time: 10min 25s


In [96]:
%%time

#####Calculate kernel values.
#####Below we provide a naive implementation using for-loops.
#####Parallelize this part according to your specific computing enviroment to utilize multiple GPUs.
H = np.zeros((N, N), dtype = np.float32)
for i in range(N):
    if (i%100) == 0:
        print(i)
    for j in range(N):
        H[i][j] = xz(X[i], X[j], L[i], L[j], iL[i], iL[j])
#####

MemoryError: Unable to allocate 13.4 GiB for an array with shape (60000, 60000) and data type float32

In [97]:
#Solve kernel regression.
Y_train = np.ones((N_train, 10)) * -0.1
for i in range(N_train):
    Y_train[i][y_train[i]] = 0.9
u = H[N_train:, :N_train].dot(scipy.linalg.solve(H[:N_train, :N_train], Y_train))
print ("test accuracy:", 1.0 * np.sum(np.argmax(u, axis = 1) == y_test) / N_test)

MemoryError: Unable to allocate 18.6 GiB for an array with shape (50000, 50000) and data type float64