# Figure 1
We show that arbitrary kernels can be accurately achieved by shallow networks with engineered pointwise nonlinearities.

# Imports

In [None]:
!pip install -q git+https://www.github.com/google/neural-tangents
!sudo apt-get install texlive-latex-recommended
!sudo apt install texlive-latex-extra
!sudo apt install dvipng
!sudo apt install cm-super dvipng
!pip install pmlb
!pip install cifar10_web

import cifar10_web

import cvxopt

import jax.nn
import jax.numpy as np
from jax import random
from jax.api import jit, grad, vmap
from jax.config import config
from jax.experimental import optimizers
from jax.ops import index_update

import neural_tangents as nt
from neural_tangents import stax

import numpy as base_np

import pandas as pd

from scipy.special import factorial, factorial2

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

# Enable 64-bit floating point in JAX for everything that follows.
config.update("jax_enable_x64", True)

# `np` is jax.numpy in this notebook (see the imports above); this only affects
# how arrays are printed.
np.set_printoptions(precision=4, linewidth=200)

# Module-level PRNG key. Functions below declare `global key` and replace it
# with a fresh split every time they draw random numbers.
key = random.PRNGKey(17)

  Building wheel for neural-tangents (setup.py) ... [?25l[?25hdone
Reading package lists... Done
Building dependency tree       
Reading state information... Done
texlive-latex-recommended is already the newest version (2017.20180305-1).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'sudo apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.
Reading package lists... Done
Building dependency tree       
Reading state information... Done
texlive-latex-extra is already the newest version (2017.20180305-2).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'sudo apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.
Reading package lists... Done
Building dependency tree       
Reading state information... Done
dvipng is already the newest version (1.15-1).
The following package was automatically inst

# Dataset setup

From the CIFAR-10 dataset, we take the first 10k training images and the 10k test images. We normalize each image so that its norm is $\sqrt{n_0}$, where $n_0$ is the input dimension.

In [None]:
def get_cifar10(n_train=None):
  """Load CIFAR-10 via `cifar10_web` and rescale every example to unit RMS.

  The inputs are first standardized with the scalar mean and standard deviation
  of the training inputs, then each row is divided by its own root-mean-square,
  so every row ends up with squared norm equal to its number of columns.

  Args:
    n_train: if not None, only the first `n_train` training examples are
      returned. The test split is never truncated.

  Returns:
    (train_X, train_y, test_X, test_y).
  """
  # assumes the inputs are 2-D (examples x features), as the axis=1 reduction requires
  def to_unit_rms(images):
    row_rms = (images**2).mean(axis=1)**.5
    return images/row_rms[:,None]

  train_X, train_y, test_X, test_y = cifar10_web.cifar10(path=None)

  # Global standardization: both splits use the training-set statistics.
  mu, sigma = train_X.mean(), train_X.std()
  train_X, test_X = (train_X - mu)/sigma, (test_X - mu)/sigma

  # Per-example normalization.
  train_X, test_X = to_unit_rms(train_X), to_unit_rms(test_X)

  if n_train is None:
    return train_X, train_y, test_X, test_y
  return train_X[:n_train], train_y[:n_train], test_X, test_y

Alternatively, generate a synthetic dataset: pass random inputs through a randomly initialized "teacher" network and use its outputs as the targets.

In [None]:
def get_teacher_dataset(d_in=10, d_out=1, width=1000, n_train=100, n_test=1000, n_hidden_layers=3, nonlinearity='relu', W_std=None, b_std=None):
  """Sample a synthetic dataset labelled by a randomly initialized teacher network.

  Inputs are Gaussian draws rescaled so that every row has unit root-mean-square.
  Targets are the outputs of a freshly initialized fully connected network
  evaluated on those inputs.

  Args:
    d_in: input dimension.
    d_out: output dimension of the teacher.
    width: width of every hidden layer.
    n_train: number of training points.
    n_test: number of test points.
    n_hidden_layers: number of (Dense, activation) pairs before the readout.
    nonlinearity: 'relu' or 'erf'.
    W_std: weight standard deviation of the hidden layers; defaults to 1.5
      for 'relu' and 2 for 'erf'.
    b_std: bias standard deviation of the hidden layers; defaults to 0.05
      for 'relu' and 0 for 'erf'.

  Returns:
    (train_X, train_y, test_X, test_y).

  Raises:
    ValueError: if `nonlinearity` is neither 'relu' nor 'erf'.

  Side effects:
    Advances the module-level PRNG `key`.
  """
  global key

  # Resolve the architecture first so that an unsupported `nonlinearity` is
  # rejected before any of the global PRNG key is consumed. A real exception is
  # used instead of `assert False`, which disappears under `python -O`.
  if nonlinearity == 'relu':
    default_W_std, default_b_std = 1.5, 0.05
  elif nonlinearity == 'erf':
    default_W_std, default_b_std = 2, 0
  else:
    raise ValueError(f"nonlinearity must be 'relu' or 'erf', got {nonlinearity!r}")
  W_std = W_std if W_std is not None else default_W_std
  b_std = b_std if b_std is not None else default_b_std

  # DRAW RANDOM POINTS
  key, split_key_1, split_key_2 = random.split(key, 3)
  train_X = random.normal(split_key_1, (n_train, d_in))
  test_X = random.normal(split_key_2, (n_test, d_in))

  # NORMALIZE TO A HYPERSPHERE
  train_X = train_X/((train_X**2).mean(axis=1)**.5)[:,None]
  test_X = test_X/((test_X**2).mean(axis=1)**.5)[:,None]

  # DEFINE TEACHER NET ARCHITECTURE
  activation = stax.Relu() if nonlinearity == 'relu' else stax.Erf()
  layers = [stax.Dense(width, W_std=W_std, b_std=b_std), activation]*n_hidden_layers
  layers += [stax.Dense(d_out, W_std=1, b_std=0)]
  init_fn, apply_fn, _ = stax.serial(*layers)

  key, net_key = random.split(key)
  _, initial_params = init_fn(net_key, (-1, d_in))
  apply_fn = jit(apply_fn)

  # SAMPLE TARGETS
  train_y = apply_fn(initial_params, train_X)
  test_y = apply_fn(initial_params, test_X)

  return train_X, train_y, test_X, test_y

# Net setup

In [None]:
def get_net_functions(d_in, width, d_out, n_hidden_layers=1, phi=None, deg=40, centered=False, W_std=None, b_std=None):
    """Build a fully connected network and return its init/apply/kernel functions.

    Args:
      d_in: input dimension, used to initialize one set of parameters.
      width: width of every hidden layer.
      d_out: output dimension of the readout layer.
      n_hidden_layers: number of (Dense, activation) pairs before the readout.
      phi: None or 'relu' for a ReLU net, 'erf' for an Erf net; anything else is
        passed as `fn` to `stax.ElementwiseNumerical`.
      deg: `deg` argument of `stax.ElementwiseNumerical` (custom `phi` only).
      centered: if True, the returned apply function subtracts the output of
        the network at the parameters initialized here.
      W_std: hidden-layer weight standard deviation. Defaults: 1.5 (relu),
        1.5 (erf), 1 (custom phi).
      b_std: hidden-layer bias standard deviation. Defaults: 0.1 (relu),
        0.3 (erf), 0 (custom phi).

    Returns:
      (init_fn, apply_fn, kernel_fn, initial_params), with `apply_fn` jitted.

    Side effects:
      Advances the module-level PRNG `key`.
    """
    global key

    # Choose the hidden activation together with the default scales paired with it.
    if phi is None or phi == 'relu':
        fallback_W_std, fallback_b_std = 1.5, 0.1
        activation = stax.Relu()
    elif phi == 'erf':
        fallback_W_std, fallback_b_std = 1.5, 0.3
        activation = stax.Erf()
    else:
        # in the case of pointwise nonlinearities, we use the variances from the proofs.
        fallback_W_std, fallback_b_std = 1, 0
        activation = stax.ElementwiseNumerical(fn=phi, deg=deg)

    if W_std is None:
        W_std = fallback_W_std
    if b_std is None:
        b_std = fallback_b_std

    hidden_block = [stax.Dense(width, W_std=W_std, b_std=b_std), activation]
    readout = stax.Dense(d_out, W_std=1, b_std=0)
    init_fn, apply_fn_uncentered, kernel_fn = stax.serial(*(hidden_block*n_hidden_layers), readout)

    key, net_key = random.split(key)
    _, initial_params = init_fn(net_key, (-1, d_in))

    if centered:
        def centered_apply(params, x):
            return apply_fn_uncentered(params, x) - apply_fn_uncentered(initial_params, x)
        apply_fn = jit(centered_apply)
    else:
        apply_fn = jit(apply_fn_uncentered)

    return init_fn, apply_fn, kernel_fn, initial_params

In [None]:
def get_batched_kernel_fn(kernel_fn, k_batch_size=100):
  """Wrap `kernel_fn` so the kernel is evaluated over `x2` in row batches.

  Args:
    kernel_fn: callable invoked as `kernel_fn(x1, x2_batch, get)`.
    k_batch_size: maximum number of rows of `x2` passed to `kernel_fn` at once.

  Returns:
    `batched_kernel_fn(x1, x2=None, get=None)`, which concatenates the
    per-batch results along axis 1. `x2=None` means "use `x1`".
  """
  def batched_kernel_fn(x1, x2=None, get=None):
    x2 = x1 if x2 is None else x2
    # A one-element ('nngp',) / ('ntk',) request is replaced by the bare string;
    # presumably so that `kernel_fn` returns a single array that can be
    # concatenated below -- confirm against the kernel_fn in use.
    get = get[0] if get in [('nngp',),('ntk',)] else get
    # `get` is closed over rather than traced, so a new jitted wrapper is built per call.
    kernel_fn_jit = jit(lambda a, b: kernel_fn(a, b, get))
    # Slicing past the end of x2 is clipped, so the last batch may be shorter.
    subkernels = [kernel_fn_jit(x1, x2[start:start + k_batch_size])
                  for start in range(0, x2.shape[0], k_batch_size)]
    return np.concatenate(subkernels, axis=1)
  return batched_kernel_fn

In [None]:
def mse(y_hat, y_true):
  """Half the squared error, summed over axis 1 and then averaged."""
  residual = y_hat - y_true
  return 0.5 * (residual ** 2).sum(axis=1).mean()
@jit
def percent_correct(y_hat, y_true):
  """Percentage of rows whose argmax along axis 1 agrees between the two arrays."""
  hits = np.argmax(y_hat, axis=1) == np.argmax(y_true, axis=1)
  return 100*hits.mean()

# Phi generation functions

In [None]:
def psd_poly_fit(xs, fs, deg=5):
  """Least-squares polynomial fit with all coefficients constrained to be >= 0.

  Solves, with cvxopt's QP solver,
      minimize   sum_k (sum_i c_i * xs_k**i - fs_k)**2
      subject to c_i >= 0 for every i,
  written in the solver's form (1/2) c'Qc + p'c with G c <= h.

  Args:
    xs: sample locations.
    fs: target values at `xs`.
    deg: polynomial degree; deg + 1 coefficients are fit.

  Returns:
    1-D NumPy array of coefficients, index i multiplying xs**i.

  Side effects:
    Sets cvxopt's global 'show_progress' solver option to False.
  """
  n_coeffs = deg + 1

  # Q[i][j] = 2 * sum_k xs_k**(i+j) depends only on i + j, so each power sum
  # is computed once and reused.
  power_sums = [(xs**m).sum() for m in range(2*deg + 1)]
  Q = base_np.zeros((n_coeffs, n_coeffs))
  for i in range(n_coeffs):
    for j in range(n_coeffs):
      Q[i][j] = 2*power_sums[i + j]

  p = base_np.zeros((n_coeffs,))
  for i in range(n_coeffs):
    p[i] = -2*((xs**i)*fs).sum()

  # -c <= 0, i.e. every coefficient is non-negative.
  G = -1*base_np.eye(n_coeffs)
  h = base_np.zeros((n_coeffs,))

  # Objective scaled by a constant; this leaves the minimizer unchanged
  # (presumably chosen to suit the solver's tolerances -- TODO confirm).
  scale = 10**3

  cvxopt.solvers.options['show_progress'] = False
  sol = cvxopt.solvers.qp(scale*cvxopt.matrix(Q), scale*cvxopt.matrix(p),
                          cvxopt.matrix(G), cvxopt.matrix(h))

  return base_np.array(sol['x']).flatten()

In [None]:
def poly_coeffs_to_lambda_fn_string(c_alpha):
  """Return Python source for `lambda z: sum_i (c_alpha[i]/i!) * z**i`.

  Terms whose coefficient is exactly zero are omitted. If every coefficient is
  zero the result is `lambda z: 0*z`, so the string is always valid source.
  (The previous trailing-slice implementation returned the unparsable
  "lambda " in that case.)

  Args:
    c_alpha: sequence of coefficients; entry i is divided by i! to obtain the
      coefficient of z**i.

  Returns:
    A string that evaluates to a one-argument function.
  """
  terms = []
  for i in range(len(c_alpha)):
    coeff = c_alpha[i]/factorial(i)
    if coeff != 0:
      terms.append(str(coeff) + "*z**" + str(i))
  if not terms:
    # 0*z rather than 0 so the output keeps the shape of an array-valued z.
    return "lambda z: 0*z"
  return "lambda z: " + " + ".join(terms)

In [None]:
def phi_from_kernel_fn(kernel_fn, k_type, deg=10, n_sample_pts=1000, weight_on_endpts=0):
  """Construct a polynomial activation phi from samples of a target kernel.

  The target kernel is sampled between one fixed 2-D point and points at
  cosine xi from it (all of norm sqrt(2)). The samples are fit with a
  degree-`deg` polynomial in xi with non-negative coefficients
  (`psd_poly_fit`), and the fitted coefficients are converted into the
  coefficients of a polynomial phi(z).

  Args:
    kernel_fn: called as `kernel_fn(x1, x2, k_type)`; row 0 of the result is used.
    k_type: forwarded unchanged as the third argument of `kernel_fn`.
    deg: degree of the polynomial fit (and of phi).
    n_sample_pts: number of sample points used for the fit.
    weight_on_endpts: if > 0, about this fraction of the sample points is
      placed at each of xi = +1 and xi = -1, so the fit weights the endpoints
      more heavily.

  Returns:
    phi, a one-argument callable created by evaluating generated lambda source.

  Side effects:
    Prints the fitted kernel coefficients and a LaTeX-style formula for phi.
  """
  d_in = 2
  xis = np.linspace(1,-1,n_sample_pts)

  # if weighting the endpoints higher, add more -1s and +1s to xis
  if weight_on_endpts > 0:
    n_interior_pts = int(n_sample_pts*(1 - 2*weight_on_endpts))
    n_endpts = int(n_sample_pts*weight_on_endpts)
    # The linspace already contains one +1 and one -1, hence n_endpts - 1 extra copies.
    xis = np.linspace(1, -1, n_interior_pts + 2)
    xis = np.concatenate([np.array([1]*(n_endpts - 1)), xis, np.array([-1]*(n_endpts - 1))])

  sines = (1 - xis**2)**.5
  # Rows of the identity are the two unit directions. This yields the same
  # vectors as index_update(np.zeros(d_in), k, 1) without relying on
  # jax.ops.index_update, which newer JAX releases no longer provide.
  basis = np.eye(d_in)
  u0, u1 = basis[0], basis[1]
  xs = np.outer(xis, u0) + np.outer(sines, u1)
  xs = (d_in**.5)*xs

  Ks = kernel_fn(xs[0:1], xs, k_type)[0]

  desired_coeffs = psd_poly_fit(xis, Ks, deg=deg)
  # Zero out coefficients below 1e-3 in magnitude.
  desired_coeffs *= (np.array(np.abs(desired_coeffs) > 10**-3))
  print('approximating K as a polynomial with coeffs', desired_coeffs)

  # construct matrix for going from K (i.e. a_gamma) to phi (i.e. c_alpha)
  # Entry (row, col) is (-1)**m / (2m)!! with 2m = col - row, and 0 when
  # col - row is negative or odd. An integer exponent keeps the sign exact.
  Minv = [[(-1)**((col - row)//2)/factorial2(col - row)
          if col >= row and (col + row)%2 == 0
          else 0
          for col in range(deg + 1)]
          for row in range(deg + 1)]
  Minv = np.array(Minv)

  a_gamma = [desired_coeffs[gamma]*factorial(gamma) if gamma < len(desired_coeffs) else 0 for gamma in range(len(Minv))]
  b_gamma = np.array(a_gamma)**.5
  c_alpha = np.matmul(Minv, np.array(b_gamma))

  phi_def_string = poly_coeffs_to_lambda_fn_string(c_alpha)
  # NOTE(review): eval of generated source. The string is assembled only from
  # the numeric coefficients computed above, not from external input.
  phi = eval(phi_def_string)

  # Human-readable formula for phi; terms with |coefficient| <= 1e-9 are skipped.
  string_to_print = '\\phi(z) = '
  for alpha in range(len(c_alpha)):
    coeff = c_alpha[alpha]/factorial(alpha)
    if abs(coeff) > 10**-9:
      # string_to_print[-2] is '=' only while no term has been written yet.
      if string_to_print[-2] != '=':
        if coeff < 0:
          string_to_print += ' - '
        else:
          string_to_print += ' + '
      else:
        if coeff < 0:
          string_to_print += '-'
      coeff = abs(coeff)
      string_to_print += '{:6.4f}'.format(coeff)
      if alpha > 0:
        string_to_print += ' z'
        if alpha > 1:
          string_to_print += f'^{alpha}'
  print('\n' + string_to_print + '\n')

  return phi

# Figure-makin' functions

In [None]:
def kernel_match_plot(ax, target_kernel, k_type='nngp', target_name='target', n_draws=100, x_axis_on=True, y_axis_label="$K(\\xi)$", phi_deg=20, n_sample_pts=100, width=1000, d_out=1, plot_i=0, weight_on_endpts=0):
  """Plot a target kernel against the kernel of a 1-hidden-layer net with an engineered phi.

  Draws three things on `ax`: the target kernel, the analytic NNGP kernel of
  the one-hidden-layer phi network, and an empirical estimate of that kernel
  from `n_draws` random finite-width initializations (with standard-error bars).

  Reads the module-level globals `d_in`, `xs`, `cosines`, `xs_sparse` and
  `cosines_sparse`, which must be defined before this function is called, and
  advances the module-level PRNG `key`.

  Args:
    ax: matplotlib axes to draw on.
    target_kernel: called as `target_kernel(x1, x2, k_type)`.
    k_type: kernel type forwarded to `target_kernel`.
    target_name: legend label of the target curve.
    n_draws: number of random initializations for the empirical kernel.
    x_axis_on: if False, the x label and x tick labels are removed.
    y_axis_label: y-axis label.
    phi_deg: polynomial degree passed to `phi_from_kernel_fn`.
    n_sample_pts: number of fit points passed to `phi_from_kernel_fn`
      (the plotted grid comes from the global `xs`).
    width: hidden width of the finite network.
    d_out: number of outputs averaged over for the empirical kernel.
    plot_i: panel index; selects the subscript letter and the legend position.
    weight_on_endpts: passed to `phi_from_kernel_fn`.

  Returns:
    The engineered activation phi.
  """
  phi = phi_from_kernel_fn(target_kernel, k_type, deg=phi_deg, n_sample_pts=n_sample_pts, weight_on_endpts=weight_on_endpts)
  init_fn_1phi, apply_fn_1phi, kernel_fn_1phi, _ = get_net_functions(d_in, width, d_out, n_hidden_layers=1, phi=phi)

  # INFINITE-WIDTH KERNELS
  # The kernels are taken against xs[0] and the row is reversed before being
  # plotted against `cosines`.
  kernel_samples = {}
  kernel_samples['target'] = target_kernel(xs[0:1], xs, k_type)[0][::-1]
  kernel_samples['1 phi (nngp)'] = kernel_fn_1phi(xs[0:1], xs, 'nngp')[0][::-1]

  # FINITE-WIDTH KERNEL
  sample_sets = []
  global key
  for i in tqdm(range(n_draws)):
    key, net_key = random.split(key)
    _, initial_params = init_fn_1phi(net_key, (-1, d_in))
    samples = apply_fn_1phi(initial_params, xs_sparse)
    # Product of every point's outputs with the first point's outputs,
    # averaged over the d_out output units.
    draw_correlations = (samples*samples[0]).sum(axis=1)/d_out
    sample_sets.append(draw_correlations.flatten())
  sample_sets = np.array(sample_sets)
  finite_width_kernel = sample_sets.mean(axis=0)
  # Standard error of the mean over draws.
  finite_width_kernel_stds = sample_sets.std(axis=0)/n_draws**.5

  c = 'ABCDEFGHI'[plot_i]
  ax.plot(cosines, kernel_samples['target'], label=target_name, color=(.6,.8,1), lw=7)
  # Backslashes are escaped explicitly: '\p' is an invalid escape sequence.
  ax.plot(cosines, kernel_samples['1 phi (nngp)'], label='1L $\\phi_'+c+'$ (NNGP)', color=(1,0,0), lw=2)
  ax.scatter(cosines_sparse, finite_width_kernel, color=(.7,.0,0), label='1L $\\phi_'+c+'$ (empirical)', zorder=6)
  ax.errorbar(cosines_sparse, finite_width_kernel, yerr=finite_width_kernel_stds, linestyle="None", color=(1,0,0), zorder=5)

  ax.tick_params(axis='both', labelsize=15)
  ax.set_xlabel("$\\xi$", fontsize=20)
  if not x_axis_on:
    ax.set_xlabel('')
    ax.set_xticklabels([])
  ax.set_ylabel(y_axis_label, fontsize=20)
  ax.legend(frameon=False, fontsize=14,
            loc=(.03,.52) if plot_i < 6 else (.03, .53) if plot_i==6 else None)

  ax.set_xlim((-1.02,1.02))

  return phi

In [None]:
def phi_set_plot(ax, phis):
  """Draw every activation in `phis` on `ax` over z in [-2.5, 2.5].

  Curve i is labelled with the i-th letter of 'ABCDEFGH' and uses the i-th
  dash pattern and colour from the tables below, so at most eight activations
  are supported.
  """
  panel_letters = 'ABCDEFGH'
  dash_patterns = [(1,0), (5,1), (3,1), (2,2), (2,1.3), (2,.8), (1,1), (1,1,3,1)]
  rgb_colors = [(1,0,0), (1,.5,0), (.8,.7,0), (0,.7,0), (0,.5,.7), (0,0,1), (.5,0,.8), (1,0,1)]

  zs = np.linspace(-2.5, 2.5, 500)

  for idx, phi in enumerate(phis):
    letter = panel_letters[idx]
    curve_style = dict(lw=3, linestyle=(0, dash_patterns[idx]), color=rgb_colors[idx])
    ax.plot(zs, phi(zs), label=f'$\\phi_{letter}$', **curve_style)
  ax.set_xlim((zs.min(), zs.max()))

  ax.set_xlabel('z', fontsize=20)
  ax.set_ylabel('$\\phi(z)$', fontsize=20, labelpad=-10)
  ax.tick_params(axis='both', labelsize=15)

  ax.legend(frameon=False, fontsize=12, ncol=2, labelspacing=0, columnspacing=1, loc=(.4, .02))

# Set up sample points

In [None]:
# Sample points on the circle of radius sqrt(d_in) in the plane spanned by u0 and u1.
d_in = 2

n_sample_pts = 1000
# cosines of angles between points, called "xi" in the paper
# (the dense grid runs from -1 to 1)
cosines = np.linspace(-1,1,n_sample_pts)
sines = (1 - cosines**2)**.5

# Unit vectors along the first two axes, taken as rows of the identity. These
# are the same vectors index_update(np.zeros(d_in), k, 1) produced, without
# relying on jax.ops.index_update, which newer JAX releases no longer provide.
u0 = np.eye(d_in)[0]
u1 = np.eye(d_in)[1]
xs = np.outer(cosines, u0) + np.outer(sines, u1)
xs = (d_in**.5)*xs

# Coarse grid for the finite-width estimates; note it runs from 1 down to -1,
# so xs_sparse[0] is the point at cosine +1.
n_sample_pts_sparse = 21
cosines_sparse = np.linspace(1,-1,n_sample_pts_sparse)
sines_sparse = (1 - cosines_sparse**2)**.5
xs_sparse = np.outer(cosines_sparse, u0) + np.outer(sines_sparse, u1)
xs_sparse = (d_in**.5)*xs_sparse

# Make figure

In [None]:
# Settings shared by every kernel-matching panel (forwarded to kernel_match_plot).
kwargs = {
    'n_draws':10000,
    'phi_deg':7,
    'n_sample_pts':1000,
    'width':10000,
    'd_out':100,
    'weight_on_endpts':.1
    }

# Depth of the ReLU / Erf networks whose kernels serve as targets.
n_layers_relu = 4
n_layers_erf = 4

# 3x3 grid: panels 0-7 each match one target kernel, panel 8 shows all engineered phis.
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(15,10))
axs = axs.flatten()

phis = [None]*8

# Panels 0-3: NNGP and NTK kernels of deep ReLU and Erf networks.
i = 0
_, _, target_kernel, _ = get_net_functions(d_in, 1000, 1, n_hidden_layers=n_layers_relu, phi='relu')
phis[i] = kernel_match_plot(axs[i], target_kernel, target_name=str(n_layers_relu)+'L ReLU (NNGP)', **kwargs, plot_i=i)

i = 1
_, _, target_kernel, _ = get_net_functions(d_in, 1000, 1, n_hidden_layers=n_layers_relu, phi='relu')
phis[i] = kernel_match_plot(axs[i], target_kernel, k_type='ntk', target_name=str(n_layers_relu)+'L ReLU (NTK)', y_axis_label='', **kwargs, plot_i=i)

i = 2
_, _, target_kernel, _ = get_net_functions(d_in, 1000, 1, n_hidden_layers=n_layers_erf, phi='erf')
phis[i] = kernel_match_plot(axs[i], target_kernel, target_name=str(n_layers_erf)+'L Erf (NNGP)', y_axis_label='', **kwargs, plot_i=i)

i = 3
_, _, target_kernel, _ = get_net_functions(d_in, 1000, 1, n_hidden_layers=n_layers_erf, phi='erf')
phis[i] = kernel_match_plot(axs[i], target_kernel, k_type='ntk', target_name=str(n_layers_erf)+'L Erf (NTK)', **kwargs, plot_i=i)

# Panels 4-7: kernels given in closed form as a function of x1.x2/d_in; the
# `get` argument is accepted and ignored. `poly_fn` is looked up when the
# kernel is called, so each kernel must be used before `poly_fn` is reassigned.
i = 4
poly_fn = lambda xi: xi**4 + xi
target_kernel = lambda x1, x2, get=None: poly_fn(np.matmul(x1, x2.transpose())/d_in)
phis[i] = kernel_match_plot(axs[i], target_kernel, target_name='$K(\\xi) = \\xi^4 + \\xi$', y_axis_label='', **kwargs, plot_i=i)

i = 5
poly_fn = lambda xi: xi**5 + 3
target_kernel = lambda x1, x2, get=None: poly_fn(np.matmul(x1, x2.transpose())/d_in)
phis[i] = kernel_match_plot(axs[i], target_kernel, target_name='$K(\\xi) = \\xi^5 + 3$', y_axis_label='', **kwargs, plot_i=i)

i = 6
poly_fn = lambda xi: np.sinh(2*xi)
target_kernel = lambda x1, x2, get=None: poly_fn(np.matmul(x1, x2.transpose())/d_in)
# '\\sinh' / '\\cosh' are escaped explicitly: '\s' and '\c' are invalid escape sequences.
phis[i] = kernel_match_plot(axs[i], target_kernel, target_name='$K(\\xi) = \\sinh(2 \\xi)$', **kwargs, plot_i=i)

i = 7
poly_fn = lambda xi: np.cosh(2*xi)
target_kernel = lambda x1, x2, get=None: poly_fn(np.matmul(x1, x2.transpose())/d_in)
phis[i] = kernel_match_plot(axs[i], target_kernel, target_name='$K(\\xi) = \\cosh(2 \\xi)$', y_axis_label='', **kwargs, plot_i=i)

phi_set_plot(axs[8], phis)

# Panel letters in the lower-right corner of every subplot.
for i in range(len(axs)):
  axs[i].text(.95, .05, 'ABCDEFGHI'[i], transform=axs[i].transAxes, size=20, weight='bold', ha='center')

fig.tight_layout()

plt.savefig('fig1.png', transparent=True, dpi=300, bbox_inches='tight')