<a href="https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/weight_space_linearization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2019 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

#### Import & Utils

Install JAX and Neural Tangents.

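The commands below upgrade `pip`, install a pinned JAX release (0.4.13, with CUDA support) and install Neural Tangents from its GitHub repository. The JAX and Neural Tangents versions must be compatible with each other. An import-time error such as the `AttributeError` shown after the imports cell most likely indicates a version mismatch between the two.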

In [None]:
!pip install -q --upgrade pip
!pip install -q --upgrade 'jax[cuda]'==0.4.13 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents

In [1]:
from jax import jit, grad, random, vmap

import jax.numpy as np
from jax.nn import log_softmax
from jax.example_libraries import optimizers

import neural_tangents as nt
from neural_tangents import stax

import matplotlib as mpl
import matplotlib_inline
import matplotlib.pyplot as plt
# import matplotlib.colors as colors
import functools


font_size = 11
mpl.rcParams.update({'font.size': font_size, 
                     'axes.titlesize': font_size, 
                     'axes.labelsize': font_size - 1, 
                     'legend.fontsize': font_size - 1, 
                     'xtick.labelsize': font_size - 2,
                     'ytick.labelsize': font_size - 2,
                     'savefig.dpi': 300
                    })

legend = functools.partial(plt.legend, fontsize=9)
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

from sklearn.metrics import mean_squared_error
import seaborn as sns

sns.set(font_scale=1.3)
sns.set_style("darkgrid", {"axes.facecolor": ".95"})

from utils.utils import loss_fn, choose_random_idxs
from utils.viz import format_plot, finalize_plot, plot_fn
from utils.optimizers import momentum

%load_ext autoreload
%autoreload 2

AttributeError: module 'jax' has no attribute 'ad'

In [None]:
# Root PRNG key of the notebook; the other keys used below are derived from it.
SEED = 101
key = random.PRNGKey(SEED)

In [None]:
# Synthetic 1-D regression problem: a few noisy samples of sin(x).
train_points, test_points = 5, 50
noise_scale = 0.1


def target_fn(x):
    """Ground-truth function the networks are trained to approximate."""
    return np.sin(x)

In [None]:
# Training inputs drawn uniformly on [-pi, pi]; targets are the true function
# plus Gaussian noise with standard deviation `noise_scale`.
key, x_key, y_key = random.split(key, 3)

x_train = random.uniform(x_key, (train_points, 1), minval=-np.pi, maxval=np.pi)
y_train = target_fn(x_train) + noise_scale * random.normal(y_key, (train_points, 1))
train = (x_train, y_train)

In [None]:
# Noise-free evaluation grid: `test_points` evenly spaced inputs on [-pi, pi],
# shaped as a column (one feature per row).
x_test = np.linspace(-np.pi, np.pi, test_points).reshape((test_points, 1))
y_test = target_fn(x_test)
test = (x_test, y_test)

In [None]:
import os

# Output directory for exported figures. `savefig` does not create missing
# directories, so make sure it exists before saving.
FIG_DIR = os.path.join('figures', 'weight_space_linearization')
os.makedirs(FIG_DIR, exist_ok=True)

plot_fn(train, test)
legend(loc='upper left')
finalize_plot((0.85, 0.6))
# NOTE(review): the figure is saved after `finalize_plot`; if that helper calls
# `plt.show()` the saved image may be blank — confirm against utils.viz.
plt.savefig(os.path.join(FIG_DIR, 'fn_plot.png'))

# Weight Space Linearization

In [None]:
# Training hyperparameters. Training below is full-batch: each gradient step
# uses the whole training set.
learning_rate = 1e-1
batch_size = 128  # NOTE(review): not referenced by the training code in this notebook.
training_steps = 10000
print_every = 100  # console logging interval, in steps

# Candidate snapshot points passed as `ts` to `train_network`. A step `i` is
# snapshotted when it compares equal to an entry, so only the integer-valued
# entries of this 0.1-spaced grid can ever match.
ts = np.arange(0, 10 ** 3, 10 ** -1)

In [None]:
# Fully-connected network: two erf hidden layers of width 512 and a scalar
# output. Every Dense layer gets the same positional (W_std, b_std) = (1.5, 0.05).
hidden_width = 512
weight_std, bias_std = 1.5, 0.05

init_fn, apply_fn, _ = stax.serial(
    stax.Dense(hidden_width, weight_std, bias_std), stax.Erf(),
    stax.Dense(hidden_width, weight_std, bias_std), stax.Erf(),
    stax.Dense(1, weight_std, bias_std),
)

apply_fn = jit(apply_fn)
# Initial parameters for single-feature inputs; the batch size is left open (-1).
_, params = init_fn(key, (-1, 1))

In [None]:
# First-order Taylor expansion of the network around the initial `params`:
#   f_lin(p, x) = f(params, x) + J_f(params, x) (p - params).
# Jitted like `apply_fn`, because it is evaluated eagerly on every step of the
# training loops below.
apply_fn_lin = jit(nt.linearize(apply_fn, params))

In [None]:
# Plain SGD with a constant step size, shared by the exact and linearized
# models. Each training run builds its own optimizer state with `opt_init`.
opt_init, opt_apply, get_params = optimizers.sgd(step_size=learning_rate)
opt_apply = jit(opt_apply)

In [None]:
def loss(fx, y_hat):
    """Halved mean-squared error: 0.5 * mean((fx - y_hat) ** 2)."""
    residual = fx - y_hat
    return 0.5 * np.mean(residual ** 2)

In [None]:
def _exact_objective(params, x, y):
    """Loss of the exact network on (x, y)."""
    return loss(apply_fn(params, x), y)


def _linearized_objective(params, x, y):
    """Loss of the linearized network on (x, y)."""
    return loss(apply_fn_lin(params, x), y)


# Gradients with respect to the parameters (the first argument), jit-compiled.
grad_loss = jit(grad(_exact_objective))
grad_lin_loss = jit(grad(_linearized_objective))

In [None]:
def train_network(key, lin=False, ts=None):
    """Train a freshly initialized network, or its linearization, with full-batch SGD.

    Args:
      key: PRNG key passed to `init_fn` to draw the initial parameters.
      lin: if truthy, train the first-order Taylor expansion of the network
        around *this call's* initial parameters instead of the network itself.
      ts: optional collection of step indices at which to snapshot the
        parameters. Step `i` is snapshotted when it equals an entry of `ts`,
        so non-integer entries never match.

    Returns:
      `(final_params, train_losses, test_losses, ts_params)`. The losses are
      1-D arrays of length `training_steps`, each recorded *before* that
      step's update. `ts_params` lists the snapshotted parameters in step order.
    """
    _, params = init_fn(key, (-1, 1))

    if lin:
        # Linearize around this call's own initialization. The global
        # `apply_fn_lin` is fixed to a single initialization, so ensemble
        # members drawn from other keys used to be trained on a linear model
        # expanded around someone else's parameters.
        model_fn = nt.linearize(apply_fn, params)
        # The expansion point is passed as an explicit argument (rather than
        # closed over) so the jitted gradient also works under `vmap`.
        lin_grad = jit(grad(
            lambda p, p0, x, y: loss(nt.linearize(apply_fn, p0)(p, x), y)))

        def grad_fn(p, x, y):
            return lin_grad(p, params, x, y)
    else:
        model_fn = apply_fn
        grad_fn = grad_loss

    # Resolve the snapshot steps once. `i in ts` on a device array scans all of
    # `ts` on every step. Python scalars keep the same equality test (exact for
    # integer steps) with O(1) lookups.
    snapshot_steps = set() if ts is None else set(np.asarray(ts).ravel().tolist())

    ts_params, train_losses, test_losses = [], [], []
    opt_state = opt_init(params)

    for i in range(training_steps):
        current = get_params(opt_state)
        if i in snapshot_steps:
            ts_params.append(current)

        train_losses.append(np.reshape(loss(model_fn(current, train[0]), train[1]), (1,)))
        test_losses.append(np.reshape(loss(model_fn(current, test[0]), test[1]), (1,)))
        opt_state = opt_apply(i, grad_fn(current, *train), opt_state)

    return (get_params(opt_state), np.concatenate(train_losses),
            np.concatenate(test_losses), ts_params)

Sanity check: train a single network and its linearization, then plot both fits to confirm that training works.

In [None]:
# One exact network and one linearized network, both initialized from `key`
# (exact run first, then the linearized one).
(params, train_loss, test_loss, ts_params), (lin_params, train_lin_loss, test_lin_loss, ts_lin_params) = (
    train_network(key, lin=use_lin, ts=ts) for use_lin in (False, True)
)

In [None]:
# Training data with the trained exact (solid) and linearized (dashed) fits.
plot_fn(train, None, xlabel='$x$', ylabel='$f$')

for curve_fn, curve_params, curve_fmt in ((apply_fn, params, 'k-'),
                                          (apply_fn_lin, lin_params, 'k--')):
    plt.plot(x_test, curve_fn(curve_params, x_test), curve_fmt, linewidth=1.5)

legend(['Train', 'Neural network', 'Linearized neural network'], loc='upper left')
finalize_plot((0.85, 0.6))

In [None]:
# Ensemble of independently initialized networks: one PRNG key per member.
ensemble_size = 100
ensemble_key = random.split(key, ensemble_size)

# Batch `train_network` over the keys only; `lin` and `ts` are shared by all
# members (in_axes=None), and every output gains a leading ensemble axis.
# NOTE(review): with `ts=ts` every member stores a full parameter snapshot at
# each matching step (roughly the integer entries of `ts`). With 100 members
# and a 512x512 hidden weight matrix, that is ~26M floats per snapshot — pass
# `ts=None` here if the snapshots are not needed.
train_ensemble = vmap(train_network, in_axes=(0, None, None))
params, train_loss, test_loss, ts_params = train_ensemble(ensemble_key, False, ts)
lin_params, train_lin_loss, test_lin_loss, ts_lin_params = train_ensemble(ensemble_key, True, ts)

In [None]:
def _ensemble_stats(losses):
    """Mean and variance over axis 0 (the ensemble axis added by `vmap`)."""
    return np.mean(losses, axis=0), np.var(losses, axis=0)


mean_train_loss, var_train_loss = _ensemble_stats(train_loss)
mean_test_loss, var_test_loss = _ensemble_stats(test_loss)
mean_train_lin_loss, var_train_lin_loss = _ensemble_stats(train_lin_loss)
mean_test_lin_loss, var_test_lin_loss = _ensemble_stats(test_lin_loss)

In [None]:
# Train one exact network and its linearization side by side. Log losses and
# the RMSE between the two models, and record per-step outputs, losses and
# first-layer parameters for the plots in the next cell.

# `steps_per_epoch` was only defined in a commented-out line, so this cell and
# the plotting cell raised NameError; use the configured step budget instead.
steps_per_epoch = training_steps

# Console table with one column per logged value. The two RMSE values used to
# be formatted but silently dropped, because the header had only five fields.
headers = ['Time', 'Train loss', 'Train linear loss', 'Test loss',
           'Test linear loss', 'Train RMSE', 'Test RMSE']
padding = 2
header_widths = [len(header) for header in headers]
format_string = '\t'.join('{{:<{}}}'.format(width + padding) for width in header_widths)
print(format_string.format(*headers))

# Start both models from a fresh initialization drawn with `key`. The global
# `params` has been overwritten by the ensemble cell, and the optimizer states
# this loop relied on (`state`, `lin_state`) were never created. Re-initializing
# here also makes a re-run restart training instead of continuing it.
# NOTE(review): assumes `key` still equals the key `apply_fn_lin` was
# linearized at — confirm if `key` is reassigned between cells.
_, init_params = init_fn(key, (-1, 1))
state = opt_init(init_params)
lin_state = opt_init(init_params)

# Per-step records, collected in lists and stacked once after the loop. Eager
# `.at[i].set` on preallocated arrays copies the whole array on every step,
# which is quadratic in the number of steps.
train_output_vals, lin_train_output_vals = [], []
test_output_vals, lin_test_output_vals = [], []
exact_train_losses, lin_train_losses = [], []
exact_test_losses, lin_test_losses = [], []
train_rmse_vals, test_rmse_vals = [], []
params_array, lin_params_array = [], []

for i in range(steps_per_epoch):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    lin_params = get_params(lin_state)
    lin_state = opt_apply(i, grad_lin_loss(lin_params, x_train, y_train), lin_state)

    # Train loss (evaluated at the parameters before this step's update)
    exact_train_output = apply_fn(params, x_train)
    lin_train_output = apply_fn_lin(lin_params, x_train)
    exact_train_loss = loss(exact_train_output, y_train)
    lin_train_loss = loss(lin_train_output, y_train)

    # Test loss
    exact_test_output = apply_fn(params, x_test)
    lin_test_output = apply_fn_lin(lin_params, x_test)
    exact_test_loss = loss(exact_test_output, y_test)
    lin_test_loss = loss(lin_test_output, y_test)

    # RMSE between the exact and linearized predictions
    train_rmse = np.sqrt(mean_squared_error(exact_train_output, lin_train_output))
    test_rmse = np.sqrt(mean_squared_error(exact_test_output, lin_test_output))

    if i % print_every == 0:
        t = i * learning_rate  # "time" column: step index scaled by the learning rate
        values = (exact_train_loss, lin_train_loss, exact_test_loss,
                  lin_test_loss, train_rmse, test_rmse)
        print(format_string.format('{:.1f}'.format(t),
                                   *('{:.4f}'.format(v) for v in values)))

    # First-layer parameters, squeezed; presumably the (W, b) pair of the
    # first Dense layer stacked together — used for the trajectory plot.
    params_array.append(np.array(params[0]).squeeze())
    lin_params_array.append(np.array(lin_params[0]).squeeze())

    # Output values
    train_output_vals.append(exact_train_output.flatten())
    lin_train_output_vals.append(lin_train_output.flatten())
    test_output_vals.append(exact_test_output.flatten())
    lin_test_output_vals.append(lin_test_output.flatten())

    # Losses
    exact_train_losses.append(exact_train_loss)
    lin_train_losses.append(lin_train_loss)
    exact_test_losses.append(exact_test_loss)
    lin_test_losses.append(lin_test_loss)

    # RMSE
    train_rmse_vals.append(train_rmse)
    test_rmse_vals.append(test_rmse)

# Stack into arrays with a leading step axis: (steps, n_points) for outputs,
# (steps,) for losses/RMSE, (steps, *first_layer_shape) for parameters.
params_array = np.stack(params_array)
lin_params_array = np.stack(lin_params_array)
train_output_vals = np.stack(train_output_vals)
lin_train_output_vals = np.stack(lin_train_output_vals)
test_output_vals = np.stack(test_output_vals)
lin_test_output_vals = np.stack(lin_test_output_vals)
exact_train_losses = np.stack(exact_train_losses)
lin_train_losses = np.stack(lin_train_losses)
exact_test_losses = np.stack(exact_test_losses)
lin_test_losses = np.stack(lin_test_losses)
train_rmse_vals = np.stack(train_rmse_vals)
test_rmse_vals = np.stack(test_rmse_vals)

In [None]:
# Exact network (solid) vs. its linearization (dashed) over the training run
# recorded in the previous cell.
cmap = plt.get_cmap('Dark2')
steps = np.arange(steps_per_epoch)

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
ax_train, ax_test, ax_params, ax_loss, ax_rmse, ax_spare = axes.ravel()
fig.delaxes(ax_spare)  # only five panels are used

# Outputs at every training input over time (one legend entry per model).
for k in range(train_points):
    color = cmap(k)
    ax_train.plot(steps, train_output_vals[:, k], '-', color=color,
                  label='Neural network' if k == 0 else '_nolegend_')
    ax_train.plot(steps, lin_train_output_vals[:, k], '--', color=color,
                  label='Linearized model' if k == 0 else '_nolegend_')
ax_train.set(xscale='log', xlabel='Step', ylabel='Output', title='Train outputs')
ax_train.legend(fontsize=9)

# Outputs at the first `train_points` test inputs over time.
# NOTE(review): these are the leftmost test points only — confirm that is intended.
for k in range(train_points):
    color = cmap(k)
    ax_test.plot(steps, test_output_vals[:, k], '-', color=color)
    ax_test.plot(steps, lin_test_output_vals[:, k], '--', color=color)
ax_test.set(xscale='log', xlabel='Step', ylabel='Output', title='Test outputs')

# Trajectories of a few randomly chosen first-layer parameters.
n_params_plot = 2
# Derive the sub-keys without rebinding the global `key`. Re-running this cell
# then picks the same parameters, and `key` stays intact for the other cells.
subkey1, subkey2 = random.split(key, 3)[1:]
idxs1 = choose_random_idxs(subkey1, params_array.shape[1], 1)
idxs2 = choose_random_idxs(subkey2, params_array.shape[2], n_params_plot)

for i, k in enumerate(idxs2):
    color = cmap(i)
    ax_params.plot(steps, params_array[:, idxs1, k], '-', color=color)
    ax_params.plot(steps, lin_params_array[:, idxs1, k], '--', color=color)
ax_params.set(xscale='log', xlabel='Step', ylabel='Parameter value', title='Parameters')

# Train and test losses of both models.
ax_loss.plot(steps, exact_train_losses, 'k--', label='Train')
ax_loss.plot(steps, lin_train_losses, 'b--', label=r'Train $f^{\mathrm{lin}}$')
ax_loss.plot(steps, exact_test_losses, 'k-', label='Test')
ax_loss.plot(steps, lin_test_losses, 'b-', label=r'Test $f^{\mathrm{lin}}$')
ax_loss.set(xscale='log', xlabel='Step', ylabel='Loss', title='Loss')
ax_loss.legend(fontsize=9)

# RMSE between the exact and linearized predictions.
ax_rmse.plot(steps, train_rmse_vals, 'k--', label='Train')
ax_rmse.plot(steps, test_rmse_vals, 'k-', label='Test')
ax_rmse.set(xscale='log', yscale='log', xlabel='Step', ylabel='RMSE',
            title=r'RMSE between $f$ and $f^{\mathrm{lin}}$')
ax_rmse.legend(fontsize=9)

fig.tight_layout()

# TODO: Compute average variance for ... (note left unfinished)
# TODO: See what happens with the weights while incrementing the width
# TODO: Compare with theta_0

### Training an ensemble of neural networks

In [None]:
def train_network(key, ts=None):
    """Train a freshly initialized (exact, non-linearized) network with full-batch SGD.

    NOTE(review): this redefinition shadows the earlier
    `train_network(key, lin=False, ts=None)`; any later call that passes `lin`
    will fail once this cell has run.

    Args:
      key: PRNG key passed to `init_fn` to draw the initial parameters.
      ts: optional collection of step indices at which to snapshot the
        parameters. Step `i` is snapshotted when it equals an entry of `ts`.

    Returns:
      `(final_params, train_losses, test_losses, ts_params)`. The losses are
      1-D arrays of length `training_steps`, each recorded before that step's
      update.
    """
    _, params = init_fn(key, (-1, 1))
    opt_state = opt_init(params)

    # Resolve the snapshot steps once instead of scanning `ts` on every step.
    snapshot_steps = set() if ts is None else set(np.asarray(ts).ravel().tolist())

    ts_params, train_losses, test_losses = [], [], []
    for i in range(training_steps):
        current = get_params(opt_state)
        if i in snapshot_steps:
            ts_params.append(current)

        # `loss` compares network *outputs* with targets. The previous version
        # passed the raw parameters plus both data arrays (wrong arity and wrong
        # argument).
        train_losses.append(np.reshape(loss(apply_fn(current, train[0]), train[1]), (1,)))
        test_losses.append(np.reshape(loss(apply_fn(current, test[0]), test[1]), (1,)))
        # The optimizer's update function is `opt_apply` (`opt_update` was never
        # defined), and the gradient is taken w.r.t. the parameters, not the
        # optimizer state.
        opt_state = opt_apply(i, grad_loss(current, *train), opt_state)

    return (get_params(opt_state), np.concatenate(train_losses),
            np.concatenate(test_losses), ts_params)