In [1]:
import sys

import jax
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax, optimizers
from neural_tangents import stax as nt_stax

sys.path.append('..')
from src.utils import read_yaml, create_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the experiment configuration; every hyperparameter used below comes from it.
cfg = read_yaml(yaml_path='../src/configs/generalized_adam.yaml')

# DATA: number of classes and (optionally) the explicit list of labels to keep.
n_classes, target_classes = cfg.DATA.N_CLASSES, cfg.DATA.TARGET_CLASSES

# MODEL: depth, hidden width, and the weight/bias variances used to build the
# Dense layers further down.
n_layers, n_width = cfg.MODEL.N_LAYERS, cfg.MODEL.N_WIDTH
weight_variance, bias_variance = cfg.MODEL.WEIGHT_VARIANCE, cfg.MODEL.BIAS_VARIANCE

# OPTIMIZER: NOTE(review) batch_size, damping and diag_reg are not used in the
# cells shown here.
batch_size, learning_rate = cfg.OPTIMIZER.BATCH_SIZE, cfg.OPTIMIZER.LEARNING_RATE
damping, diag_reg = cfg.OPTIMIZER.DAMPING, cfg.OPTIMIZER.DIAG_REG

# GENERAL: run length, device count (None = auto-detect), PRNG seed, storage flag.
epochs, devices = cfg.GENERAL.EPOCHS, cfg.GENERAL.DEVICES
random_seed, store_on_device = cfg.GENERAL.SEED, cfg.GENERAL.STORE_ON_DEVICE

# Setup device: when the config leaves DEVICES unset, use every device JAX sees.
if devices is None:
    devices = jax.device_count()

# Build data pipelines.
print('Loading data...')

# Validate the config with explicit exceptions rather than `assert`, which is
# skipped entirely when Python runs with -O and gives no diagnostic message.
if n_classes < 2:
    raise ValueError(f'DATA.N_CLASSES must be >= 2, got {n_classes}')

if target_classes is None:
    target_classes = list(range(n_classes))
else:
    target_classes = [int(cls) for cls in target_classes]

if len(target_classes) != n_classes:
    raise ValueError(
        f'DATA.TARGET_CLASSES has {len(target_classes)} entries '
        f'but DATA.N_CLASSES is {n_classes}'
    )

# NOTE(review): create_dataset only receives `cfg`, not the list validated
# above, and its returned `target_classes` overwrites that list. The checks
# above therefore only guard the config values; confirm create_dataset applies
# the same defaulting/int conversion.
x_train, y_train, x_test, y_test, target_classes = create_dataset(cfg)

# Binary problems use a single output unit; otherwise one output per class.
n_outputs = 1 if n_classes == 2 else n_classes

# Build the network (TODO: adapt to a CNN): (n_layers - 1) hidden
# Dense(n_width) + Relu pairs followed by a Dense(n_outputs) read-out, all in
# NTK parameterization with the configured weight/bias standard deviations.
if n_layers < 2:
    raise ValueError(f'MODEL.N_LAYERS must be > 1, got {n_layers}')
w_std = jnp.sqrt(weight_variance)
b_std = jnp.sqrt(bias_variance)

# Comprehension variables are local in Python 3, so `_` here does not leak into
# the notebook namespace.
_layers = [
    layer
    for _ in range(n_layers - 1)
    for layer in (
        nt_stax.Dense(n_width, W_std=w_std, b_std=b_std, parameterization='ntk'),
        nt_stax.Relu(),
    )
]
_layers.append(
    nt_stax.Dense(n_outputs, W_std=w_std, b_std=b_std, parameterization='ntk')
)

init_fn, apply_fn, kernel_fn = nt_stax.serial(*_layers)

# Initialise parameters from the configured seed. The input shape is
# (-1, n_features); -1 leaves the batch size unspecified.
# The output shape gets a real name: assigning to `_` at top level would
# shadow IPython's last-output variable and stop it from updating.
key = random.PRNGKey(random_seed)
out_shape, params = init_fn(key, (-1, x_train.shape[-1]))

# NOTE(review): the config file is generalized_adam.yaml, but plain SGD is
# used here and `damping` / `diag_reg` are not consumed. Confirm this is
# intentional.
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

Loading data...


2022-09-05 02:31:45.408071: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata".
2022-09-05 02:31:46.671388: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2022-09-05 02:31:46.671405: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download an

In [3]:
# Shape of the training inputs: (n_samples, n_features).
jnp.shape(x_train)

(128, 784)

In [4]:
# NTK Gram matrix between every pair of training inputs.
train_ntk = kernel_fn(x_train, x_train, get='ntk')

In [5]:
# Expected to be square: (n_train, n_train).
jnp.shape(train_ntk)

(128, 128)

In [6]:
# NTK Gram matrix between every pair of test inputs.
test_ntk = kernel_fn(x_test, x_test, get='ntk')

In [7]:
# Expected to be square: (n_test, n_test).
jnp.shape(test_ntk)

(32, 32)

In [8]:
# Outputs of the freshly initialised network on the train and test splits.
fx0_train, fx0_test = (apply_fn(params, inputs) for inputs in (x_train, x_test))
print(fx0_train.shape, fx0_test.shape)

(128, 1) (32, 1)


In [15]:
def ngd_loss(f, y):
    """Half squared error, summed over outputs and averaged over samples.

    Args:
        f: predictions, shape (n_samples, n_outputs).
        y: targets with the same shape as ``f``.

    Returns:
        Scalar ``0.5 * mean_i sum_k (f[i, k] - y[i, k]) ** 2``.
    """
    return 0.5 * jnp.mean(jnp.sum((f - y) ** 2, axis=1))


def new_loss(err, ntk):
    """NTK-weighted counterpart of ``ngd_loss``.

    Computes ``0.5 / n * trace(err.T @ inv(ntk) @ err)``. With a single output
    column this equals the earlier ``0.5 * mean(err.T @ inv(ntk) @ err / n)``.
    With several outputs, only the per-output quadratic forms are summed; the
    whole (k, k) matrix, including cross-output terms, is no longer averaged.
    As a result the loss reduces exactly to ``ngd_loss`` when ``ntk`` is the
    identity.

    Args:
        err: residuals ``f - y``, shape (n_samples,) or (n_samples, n_outputs).
        ntk: (n_samples, n_samples) kernel matrix; must be invertible.

    Returns:
        Scalar loss.
    """
    # solve() avoids forming the explicit inverse, which is slower and less
    # accurate for ill-conditioned kernel matrices.
    weighted = jnp.linalg.solve(ntk, err)
    # sum(err * K^{-1} err) == trace(err.T @ K^{-1} @ err) for 2-D err, and
    # equals the plain quadratic form for 1-D err.
    return 0.5 * jnp.sum(err * weighted) / err.shape[0]

In [18]:
# Plain squared-error loss of the initial network on the training set.
ngd_loss(f=fx0_train, y=y_train)

DeviceArray(0.12756318, dtype=float32)

In [20]:
# NTK-weighted loss of the initial network on the training set, as a Python float.
float(new_loss(err=fx0_train - y_train, ntk=train_ntk))

4.505012035369873