In [1]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp 
from jax import vmap, jit 
import scipy
from scipy.special import roots_legendre, factorial
from functools import partial

from einops import repeat, rearrange

In [2]:
def ReLU(x):
    # ReLU activation function
    # n : degree 
    return jnp.maximum(0,x)**2

In [3]:
def GP1D(l,n):
    x = np.linspace(0,1,n)
    y = np.linspace(0,1,n)
    mean = np.zeros_like(x)
    gram = np.exp(-np.abs(x.reshape(-1,1) - y.reshape(1,-1))**2/(2*l**2))
    f = np.random.multivariate_normal(mean, gram)
    return f

In [4]:
def Kernel(x, y):
    return 0.5*(x+y-jnp.abs(x-y)) - x*y

In [5]:
nSample = 20
nPts = 513
x = jnp.linspace(0,1,nPts)
h = x[1]

# eval kernel function
xx, yy = jnp.meshgrid(x,x)
xx, yy = xx.reshape(-1), yy.reshape(-1)
K = Kernel(xx,yy).reshape(nPts, nPts)

# eval force functions
fs = []
for i in range(nSample):
    fs.append(GP1D(0.02,nPts))
fs = jnp.array(fs) # nSample x npts

# calc solution functions
us = h* (K @ fs.T).T # nSample x npts

2024-07-12 14:25:22.021529: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
def gforward(wx, wy, b):
    return ReLU(wx*xs+wy*ys+b).reshape(nPts, nPts)

gsforward = jit(vmap(gforward))

In [7]:
def brute_search(wxs, wys, bs, Gk):
    gs = gsforward(wxs, wys, bs) # nParam x npts x npts
    uhss = h * rearrange(gs @ fs.T, 'b n s -> b s n') # nParam x nSample x nPts 
    uts = h * rearrange(Gk @ fs.T, 'n s -> s n') # nSample x nPts
    rG = h * ((us - uts) * uhss).sum(axis=(1,2))
    E = -0.5 * rG ** 2
    idx = jnp.argmin(E)
    wxk, wyk, bk = wxs[idx], wys[idx], bs[idx]
    return wxk, wyk, bk, E

In [8]:
nSample = 20
nPts = 513
x = jnp.linspace(0,1,nPts)
xx, yy = jnp.meshgrid(x,x)
xx, yy = xx.reshape(-1), yy.reshape(-1)

nNeuron = 3
xs = xx 
ys = yy 
h = 1/(nPts - 1)

Gref = K 
OmegaX = []
OmegaY = []
Beta = []
Alpha = []

# learned kernel 
Gk = jnp.zeros((nPts, nPts))
gbasis = jnp.zeros((nNeuron, nPts, nPts))

# parameter space
nw = 4
nb = 1001 
nParam = nw * nw * nb
Wx = jnp.linspace(-1,1,nw)
Wy = jnp.linspace(-1,1,nw)
B = jnp.linspace(-2,2,nb)
wwx, wwy, bb = jnp.meshgrid(Wx, Wy, B)
wxs = wwx.reshape(-1)
wys = wwy.reshape(-1)
bs = bb.reshape(-1)

for k in range(nNeuron):
    # loss measure 
    l2 = ((Gk - Gref)**2).sum()
    print('{:}th - {:.4e}'.format(k, l2))

    # maximization step
    wxk, wyk, bk, E = brute_search(wxs, wys, bs, Gk)
    print(wxk, wyk, bk)
    print(E.min())
    # find new basis gk
    gk = gforward(wxk, wyk, bk)
    # print(gk)

    # projection step
    gbasis = gbasis.at[k].set(gk)
    gsub = gbasis[:k+1] # k x nPts x nPts
    uhsub = h * rearrange(gsub @ fs.T, 'k n s -> k s n')
    A = h * jnp.einsum('kns,pns->kp', uhsub, uhsub)
    b = h * (uhsub * us).sum(axis=(1,2))
    print(A)
    print(b)
    alpha_k = jnp.linalg.solve(A, b).reshape(-1,1)

    # update Gk 
    OmegaX.append(wxk)
    OmegaY.append(wyk)
    Beta.append(bk)

    omegax_k = jnp.concatenate([x.reshape(1,1) for x in OmegaX])
    omegay_k = jnp.concatenate([y.reshape(1,1) for y in OmegaY])
    beta_k = jnp.concatenate([b.reshape(1,1) for b in Beta])

    # print('alpha : ', alpha_k.shape)
    # print('omegax : ', omegax_k.shape)
    # print('omegay : ', omegay_k.shape)
    # print('beta : ', beta_k.shape)
    # print(((omegax_k * xs) + (omegay_k * ys) + beta_k).shape)
    Gk = (alpha_k.T @ ((omegax_k * xs) + (omegay_k * ys) + beta_k)).reshape(nPts, nPts)

0th - 2.9127e+03
1.0 1.0 2.0
-0.3547313
[[88.46504]]
[0.84229803]
1th - 1.8835e+03
1.0 1.0 2.0
-0.1643003
[[88.465034 88.465034]
 [88.465034 88.465034]]
[0.84229803 0.84229803]
2th - 1.8835e+03
1.0 1.0 2.0
-0.16430026
[[88.46427 88.46427 88.46427]
 [88.46427 88.46427 88.46427]
 [88.46427 88.46427 88.46427]]
[0.8422971 0.8422971 0.8422971]


In [9]:
E.sort()

Array([-0.16430026, -0.16341607, -0.16254802, ..., -0.        ,
       -0.        , -0.        ], dtype=float32)