In [2]:
import sys

import jax
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax, optimizers
from neural_tangents import stax as nt_stax

sys.path.append('..')
from src.utils import read_yaml, create_dataset

In [None]:
# Read the experiment configuration from its YAML file.
# NOTE(review): the path is relative, while the import cell adds '..' to
# sys.path to find `src` -- confirm it resolves from the notebook's working
# directory (read_yaml is not visible here).
cfg = read_yaml(yaml_path='src/configs/generalized_adam.yaml')

# Dataset settings.
n_classes, target_classes = cfg.DATA.N_CLASSES, cfg.DATA.TARGET_CLASSES

# Network architecture and initialisation settings.
n_layers, n_width = cfg.MODEL.N_LAYERS, cfg.MODEL.N_WIDTH
weight_variance, bias_variance = cfg.MODEL.WEIGHT_VARIANCE, cfg.MODEL.BIAS_VARIANCE

# Optimiser settings.
batch_size, learning_rate = cfg.OPTIMIZER.BATCH_SIZE, cfg.OPTIMIZER.LEARNING_RATE
damping, diag_reg = cfg.OPTIMIZER.DAMPING, cfg.OPTIMIZER.DIAG_REG

# Run-level settings.
epochs, devices = cfg.GENERAL.EPOCHS, cfg.GENERAL.DEVICES
random_seed, store_on_device = cfg.GENERAL.SEED, cfg.GENERAL.STORE_ON_DEVICE

# Fall back to the number of devices JAX reports when the config leaves
# `devices` unset.
devices = jax.device_count() if devices is None else devices

# build data pipelines
print('Loading data...')

# Validate the class configuration before loading anything; the messages make
# a misconfigured YAML file diagnosable from the traceback alone.
assert n_classes >= 2, f'n_classes must be at least 2, got {n_classes}'

# Default to classes 0..n_classes-1 when no explicit subset is configured;
# otherwise coerce the configured labels to plain ints.
if target_classes is None:
    target_classes = list(range(n_classes))
else:
    target_classes = [int(cls) for cls in target_classes]

assert len(target_classes) == n_classes, (
    f'expected {n_classes} target classes, '
    f'got {len(target_classes)}: {target_classes}'
)

# create_dataset returns its own target_classes, which replaces the list
# validated above.
x_train, y_train, x_test, y_test, target_classes = create_dataset(cfg)

# Two classes use a single output unit; otherwise one output unit per class.
n_outputs = 1 if n_classes == 2 else n_classes

# build the network (TODO: Adapt CNN)
_layers = []
assert n_layers > 1

# Standard deviations derived from the configured variances.
w_std, b_std = jnp.sqrt(weight_variance), jnp.sqrt(bias_variance)

# Keyword arguments shared by every Dense layer, hidden and output alike.
_dense_kwargs = dict(W_std=w_std, b_std=b_std, parameterization='ntk')

# n_layers - 1 hidden blocks (Dense of width n_width followed by ReLU) ...
for i in range(n_layers - 1):
    _layers.extend([nt_stax.Dense(n_width, **_dense_kwargs), nt_stax.Relu()])

# ... then a final Dense read-out with n_outputs units and no activation.
_layers.append(nt_stax.Dense(n_outputs, **_dense_kwargs))

init_fn, apply_fn, kernel_fn = nt_stax.serial(*_layers)

# Initialise the network parameters from the configured seed.
key = random.PRNGKey(random_seed)

# The input shape has -1 in the leading (batch) position and the size of
# x_train's last axis as the feature dimension.
# The first return value is bound to a real name rather than `_`: assigning
# to `_` in a notebook shadows IPython's last-output variable for the rest of
# the session.
_output_shape, params = init_fn(key, (-1, x_train.shape[-1]))

# Plain SGD at the configured learning rate.
# NOTE(review): the config file is named generalized_adam, and damping /
# diag_reg are read from it but not used in this cell -- confirm SGD is the
# intended optimiser here.
opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)