In [27]:
import sys

import jax
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax, optimizers
import neural_tangents as nt

sys.path.append('..')
from src.utils import read_yaml, create_dataset

In [28]:
# Load the experiment configuration.
cfg = read_yaml(yaml_path='../src/configs/generalized_adam.yaml')

# Data parameters.
n_classes = cfg.DATA.N_CLASSES
target_classes = cfg.DATA.TARGET_CLASSES

# Model parameters.
n_layers = cfg.MODEL.N_LAYERS
n_width = cfg.MODEL.N_WIDTH
weight_variance = cfg.MODEL.WEIGHT_VARIANCE
bias_variance = cfg.MODEL.BIAS_VARIANCE

# Optimizer parameters.
batch_size = cfg.OPTIMIZER.BATCH_SIZE
learning_rate = cfg.OPTIMIZER.LEARNING_RATE

# General run parameters.
epochs = cfg.GENERAL.EPOCHS
devices = cfg.GENERAL.DEVICES
random_seed = cfg.GENERAL.SEED

# When DEVICES is unset in the config, fall back to every device JAX can see.
if devices is None:
    devices = jax.device_count()

# Build data pipelines.
print('Loading data...')
assert n_classes >= 2

# Validate the configured class list: without an explicit list the first
# `n_classes` labels are assumed; otherwise entries are coerced to int.
if target_classes is None:
    target_classes = list(range(n_classes))
else:
    target_classes = [int(cls) for cls in target_classes]
assert len(target_classes) == n_classes

# NOTE(review): create_dataset only receives `cfg`, so the list validated above
# is not passed to it; its returned `target_classes` replaces that list.
x_train, y_train, x_test, y_test, target_classes = create_dataset(cfg)

# Binary problems use a single output unit; otherwise one unit per class.
n_outputs = 1 if n_classes == 2 else n_classes

# MLP in NTK parameterization: (n_layers - 1) Dense+ReLU hidden blocks followed
# by a Dense readout. (TODO: Adapt CNN)
assert n_layers > 1
w_std = jnp.sqrt(weight_variance)
b_std = jnp.sqrt(bias_variance)
# Comprehension instead of a for-loop so no unused loop index leaks into the
# notebook namespace.
_layers = [
    layer
    for _ in range(n_layers - 1)
    for layer in (
        nt.stax.Dense(n_width, W_std=w_std, b_std=b_std, parameterization='ntk'),
        nt.stax.Relu(),
    )
]
_layers.append(
    nt.stax.Dense(n_outputs, W_std=w_std, b_std=b_std, parameterization='ntk')
)

init_fn, apply_fn, kernel_fn = nt.stax.serial(*_layers)

# Initialize parameters for inputs of shape (batch, x_train.shape[-1]).
key = random.PRNGKey(random_seed)
_, params = init_fn(key, (-1, x_train.shape[-1]))

# NOTE(review): the config file is named "generalized_adam", but this builds
# plain SGD — confirm this is the intended optimizer.
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

Loading data...


In [53]:
# Finite-width (empirical) NTK of `apply_fn`; the returned function is called
# below with concrete `params` as its third argument.
# NOTE: this rebinds `kernel_fn`, replacing the kernel returned by
# `nt.stax.serial` in the setup cell, so later cells depend on this cell
# having been run after setup.
kernel_fn = nt.empirical_ntk_fn(
    apply_fn,
    trace_axes=(-1,),
    vmap_axes=0,
    implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES,
)

In [55]:
# Empirical NTK Gram matrix of the training inputs (x2=None pairs x_train with
# itself) and its inverse, used as the metric for `new_loss`.
# `ntk_diag_reg` adds a ridge of `ntk_diag_reg * mean(diag(train_ntk))` before
# inverting; 0.0 inverts the raw kernel exactly.
ntk_diag_reg = 0.0
train_ntk = kernel_fn(x_train, None, params)
_n_train = train_ntk.shape[0]
_ridge = ntk_diag_reg * jnp.trace(train_ntk) / _n_train
train_ntk_inv = jnp.linalg.inv(
    train_ntk + _ridge * jnp.eye(_n_train, dtype=train_ntk.dtype)
)
# A singular kernel typically yields non-finite entries; fail here rather than
# propagate NaN/inf into the losses computed from this inverse.
assert bool(jnp.all(jnp.isfinite(train_ntk_inv))), \
    'NTK inverse is not finite; try a positive ntk_diag_reg'

In [56]:
# Both matrices are square with one row/column per training example.
assert train_ntk.shape == train_ntk_inv.shape == (x_train.shape[0], x_train.shape[0])
train_ntk.shape, train_ntk_inv.shape

((128, 128), (128, 128))

In [57]:
# Empirical NTK Gram matrix of the test inputs with themselves (x2=None),
# evaluated at the same parameters as the training kernel.
test_ntk = kernel_fn(x_test, None, params)

In [58]:
# One row/column per test example.
assert test_ntk.shape == (x_test.shape[0], x_test.shape[0])
test_ntk.shape

(32, 32)

In [59]:
# Network outputs at the initial parameters on both splits.
fx0_train = apply_fn(params, x_train)
fx0_test = apply_fn(params, x_test)
# One row per example, one column per output unit. Displaying the shapes makes
# this cell actually produce the output recorded beneath it.
assert fx0_train.shape == (x_train.shape[0], n_outputs)
assert fx0_test.shape == (x_test.shape[0], n_outputs)
fx0_train.shape, fx0_test.shape

(128, 1) (32, 1)


In [34]:
def ngd_loss(f, y):
    """Half squared error summed over output units and averaged over examples.

    Args:
        f: predictions, shape (n_examples, n_outputs).
        y: targets with the same shape as ``f``.

    Returns:
        Scalar ``0.5 * mean_i sum_k (f[i, k] - y[i, k]) ** 2``.
    """
    return 0.5 * jnp.mean(jnp.sum((f - y) ** 2, axis=1), axis=0)


def new_loss(params, x, y, G_inv, apply=None):
    """Half squared residual norm measured in the metric ``G_inv``.

    Computes ``0.5 * sum_k r[:, k]^T @ G_inv @ r[:, k]`` with residual
    ``r = apply(params, x) - y``. Each output column gets its own quadratic
    form, and the column terms are summed; there are no cross-output terms.
    For a single output this is the scalar ``0.5 * r^T G_inv r``. Unlike
    ``ngd_loss`` the result is not averaged over examples.

    Args:
        params: network parameters passed to the forward function.
        x: inputs whose leading axis indexes examples.
        y: targets, shape (n_examples, n_outputs).
        G_inv: (n_examples, n_examples) metric, e.g. an inverse NTK Gram matrix.
        apply: forward function ``apply(params, x)``; defaults to the
            notebook's global ``apply_fn``.

    Returns:
        Scalar loss.
    """
    forward = apply_fn if apply is None else apply
    residual = forward(params, x) - y  # evaluate the network only once
    # sum(r * (G_inv @ r)) == trace(r.T @ G_inv @ r): keeps only the per-output
    # quadratic forms. The previous r.T @ G_inv @ r reduced with sum/mean also
    # mixed in r[:, j]^T G_inv r[:, k] for j != k whenever n_outputs > 1.
    return 0.5 * jnp.sum(residual * (G_inv @ residual))

In [35]:
# Baseline: half squared error of the initial network outputs on the training
# set, summed over outputs and averaged over examples.
ngd_loss(fx0_train, y_train)

DeviceArray(0.12756318, dtype=float32)

In [64]:
# The same training residuals measured in the inverse-NTK metric. new_loss does
# not average over training examples the way ngd_loss does, so the two values
# are on different scales and should not be compared directly.
new_loss(params, x_train, y_train, train_ntk_inv)

DeviceArray(602.6216, dtype=float32)

: 