### Setup

In [None]:
%pip install jax equinox optax optimistix treescope
%load_ext autoreload
%autoreload 2

from functools import partial
import jax
import jax.numpy as jnp
import equinox as eqx
import treescope

### Train model ensemble on MNIST

In [None]:
from mnist import mnist_data

_, _, mnist_test_images, mnist_test_labels = mnist_data()

# Small slice of the test set, used only for visual inspection.
batch_size = 1024
# Labels appear to be one-hot (class = argmax over the last axis); order the
# slice by class so the rendered examples are grouped by digit.
order = jnp.argsort(mnist_test_labels[:batch_size].argmax(-1))
mnist_test_image_batch = mnist_test_images[:batch_size][order]
mnist_test_label_batch = mnist_test_labels[:batch_size][order]

# Render the first 20 examples of the sorted slice as 28x28 images; the colour
# range is pinned to [-1, 1] (presumably the pixel normalisation range).
treescope.render_array(
    mnist_test_image_batch.reshape(-1, 28, 28)[:20],
    axis_labels={0: "input example", 1: "input dimension", 2: "input dimension"},
    columns=[2, 0],
    pixels_per_cell=1,
    vmin=-1,
    vmax=1,
)

In [None]:
%%time
# Train the model ensemble (presumably one member per seed, num_seeds=10);
# %%time reports the wall-clock cost of this cell.
# NOTE(review): num_epochs=0 presumably runs no training epochs, so the
# ensemble may still be at initialisation. Confirm this is intended (e.g. a
# quick-iteration setting) before interpreting the error numbers below.
from mnist import train_mnist

mnist_model = train_mnist(num_seeds=10, num_epochs=0)
mnist_model

### Manipulate model

In [None]:
from expand import scale_neurons
from expand import duplicate_neurons
from expand import add_random_zero_neurons


@partial(eqx.filter_vmap, in_axes=(eqx.if_array(0), None))
def scale_manipulation(model, scale_factor):
    """Rewrite the hidden units between ``layers[1]`` and ``layers[2]`` with ``scale_neurons``.

    ``layers[1].weight`` is passed as the incoming weights and
    ``layers[2].weight`` as the outgoing weights; both are replaced by the
    matrices ``scale_neurons`` returns. Every other leaf is left untouched.

    Vectorised over axis 0 of every array leaf of ``model`` (presumably the
    ensemble/seed axis); ``scale_factor`` is shared by all members.
    """
    incoming, outgoing = model.layers[1].weight, model.layers[2].weight
    scaled_in, scaled_out = scale_neurons(
        w_in=incoming, w_out=outgoing, scale_factor=scale_factor
    )
    return eqx.tree_at(
        lambda m: (m.layers[1].weight, m.layers[2].weight),
        model,
        (scaled_in, scaled_out),
    )


@partial(eqx.filter_vmap, in_axes=(eqx.if_array(0), None))
def duplicate_type_manipulation(model, duplicate_multiplier):
    """Apply ``duplicate_neurons`` to the hidden units between ``layers[1]`` and ``layers[2]``.

    The incoming weights are ``layers[1].weight`` and the outgoing weights are
    ``layers[2].weight``. The returned model carries the weight matrices
    produced by ``duplicate_neurons``; their shapes may differ from the
    originals (TODO confirm against ``expand.duplicate_neurons``).

    Vectorised over axis 0 of every array leaf of ``model``;
    ``duplicate_multiplier`` is shared by all members.
    """
    layer_in, layer_out = model.layers[1], model.layers[2]
    grown = duplicate_neurons(
        w_in=layer_in.weight,
        w_out=layer_out.weight,
        duplicate_multiplier=duplicate_multiplier,
    )
    return eqx.tree_at(
        lambda m: (m.layers[1].weight, m.layers[2].weight), model, tuple(grown)
    )


@partial(eqx.filter_vmap, in_axes=(eqx.if_array(0), None, 0))
def zero_type_manipulation(model, num_units, key):
    """Insert ``num_units`` extra hidden units via ``add_random_zero_neurons``.

    The new units sit between ``layers[1]`` (incoming weights) and
    ``layers[2]`` (outgoing weights). They are added as ``num_units`` groups of
    one neuron each; see ``expand.add_random_zero_neurons`` for how they are
    initialised.

    Vectorised over axis 0 of every array leaf of ``model`` and of ``key``, so
    ``key`` must provide one PRNG key per ensemble member; ``num_units`` is
    shared.
    """
    widened_in, widened_out = add_random_zero_neurons(
        w_in=model.layers[1].weight,
        w_out=model.layers[2].weight,
        num_zero_groups=num_units,
        neurons_per_group=1,
        key=key,
    )
    return eqx.tree_at(
        lambda m: (m.layers[1].weight, m.layers[2].weight),
        model,
        (widened_in, widened_out),
    )


@partial(eqx.filter_vmap, in_axes=(None, 0, None))
@partial(eqx.filter_vmap, in_axes=(eqx.if_array(0), None, 0))
def parameter_noise_manipulation(model, noise_scale, key):
    """Perturb the weights of ``layers[0]``, ``layers[1]`` and ``layers[2]`` with Gaussian noise.

    Each weight matrix ``W`` is replaced by ``W + noise_scale * N(0, 1)``, with
    independent standard-normal draws per entry and per layer. Only the
    ``.weight`` fields are modified; every other leaf of ``model`` is returned
    unchanged.

    Vectorisation (decorators apply inner-first):
      * The inner ``filter_vmap`` maps over axis 0 of every array leaf of
        ``model`` (presumably the ensemble/seed axis) together with axis 0 of
        ``key``. ``key`` must therefore hold one PRNG key per member;
        ``noise_scale`` is shared.
      * The outer ``filter_vmap`` maps over axis 0 of ``noise_scale`` while
        sharing ``model`` and ``key``. The same per-member keys, and therefore
        the same standard-normal draws, are reused at every noise scale; only
        their magnitude changes.

    Returns:
        A model whose array leaves carry leading axes
        ``(len(noise_scale), num_members, ...)``.
    """
    # One independent subkey per perturbed layer.
    subkeys = jax.random.split(key, 3)
    weights = (
        model.layers[0].weight,
        model.layers[1].weight,
        model.layers[2].weight,
    )
    noisy_weights = tuple(
        w + noise_scale * jax.random.normal(k, w.shape)
        for w, k in zip(weights, subkeys)
    )
    return eqx.tree_at(
        lambda m: (m.layers[0].weight, m.layers[1].weight, m.layers[2].weight),
        model,
        noisy_weights,
    )


@partial(eqx.filter_vmap, in_axes=(eqx.if_array(0), None, 0))
def transfer_manipulation(model, transfer_sample_budget, key):
    """Placeholder for a transfer-based manipulation; not implemented yet.

    Raises instead of silently returning ``None``, so an accidental call fails
    loudly (at trace time) rather than looking like a valid, empty result.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("transfer_manipulation is not implemented yet")


def accuracy(pred_y, y):
    """Fraction of rows whose arg-max class in ``pred_y`` matches the arg-max class in ``y``.

    Classes are read along axis 1 of both arrays (e.g. logits for ``pred_y``
    and one-hot targets for ``y``). Returns a scalar in ``[0, 1]``.
    """
    hits = jnp.argmax(pred_y, axis=1) == jnp.argmax(y, axis=1)
    return hits.mean()


@partial(eqx.filter_vmap, in_axes=(eqx.if_array(0),))
def gen_error(model):
    """Classification error (``1 - accuracy``) on the full MNIST test set, per ensemble member.

    The outer ``filter_vmap`` maps over axis 0 of ``model``'s array leaves.
    Inside, the single-example ``model`` is mapped over the module-level
    ``mnist_test_images`` and scored against ``mnist_test_labels``.
    """
    per_example_model = eqx.filter_vmap(model)
    predictions = per_example_model(mnist_test_images)
    return 1 - accuracy(predictions, mnist_test_labels)

In [None]:
gen_error(mnist_model)

In [None]:
gen_error(scale_manipulation(mnist_model, 10))

In [None]:
gen_error(duplicate_type_manipulation(mnist_model, 1.1))

In [None]:
gen_error(duplicate_type_manipulation(mnist_model, 2))

In [None]:
key = jax.random.PRNGKey(37)
zero_type_manipulation(mnist_model, 10, jax.random.split(key, 10))

In [None]:
gen_error(zero_type_manipulation(mnist_model, 10, jax.random.split(key, 10)))

In [None]:
key = jax.random.PRNGKey(37)
noise_scales = jnp.array(tuple(10**-i for i in range(9, -1, -1)))
data = eqx.filter_vmap(gen_error, in_axes=(eqx.if_array(0),))(
    parameter_noise_manipulation(mnist_model, noise_scales, jax.random.split(key, 10))
)
data

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# `data` holds one test error per (noise scale, member). Summarise over
# members (axis 1) to get one point and one error bar per noise scale.
means = np.mean(data, axis=1)
stds = np.std(data, axis=1)

fig, ax = plt.subplots(figsize=(10, 6))
ax.errorbar(
    noise_scales,
    means,
    yerr=stds,
    fmt="o-",
    capsize=5,
    capthick=1.5,
    linewidth=2,
    markersize=8,
    label="Mean ± Std Dev",
)

# Noise scales span several orders of magnitude, hence log axes.
# NOTE: where mean - std <= 0, the lower error bar cannot be represented on a
# log y-axis and matplotlib clips it, so those bars look truncated.
ax.set_xscale("log")
ax.set_yscale("log")

ax.set_title("Test error under Gaussian weight noise (mean ± std over seeds)", fontsize=13)
ax.set_xlabel("noise scale (log)", fontsize=12)
ax.set_ylabel("gen. error (log)", fontsize=12)
ax.grid(True, linestyle="--", alpha=0.7)
ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Display the ensemble again for reference. The manipulations above return new
# models via eqx.tree_at, so `mnist_model` itself is unmodified.
mnist_model

In [None]:
def get_first_layer_preacts(model, inputs):
    """Return the output of ``model.layers[0]`` on ``inputs``, before any activation."""
    first_layer = model.layers[0]
    return first_layer(inputs)


def get_first_layer_postacts(model, inputs):
    """Return ``model.activation`` applied to the first layer's pre-activations."""
    preacts = get_first_layer_preacts(model, inputs)
    return model.activation(preacts)


def get_second_layer_preacts(model, inputs):
    """Return the output of ``model.layers[1]`` on the first layer's post-activations."""
    hidden = get_first_layer_postacts(model, inputs)
    second_layer = model.layers[1]
    return second_layer(hidden)


@partial(jax.vmap, in_axes=(None, 0))
def get_second_layer_postacts(model, inputs):
    """Return the second hidden layer's post-activations for a batch of inputs.

    ``jax.vmap`` maps over axis 0 of ``inputs`` (one example per row) while
    sharing ``model`` across the batch, so the per-example helpers it calls
    need no batching logic of their own.
    """
    preacts = get_second_layer_preacts(model, inputs)
    return model.activation(preacts)

In [None]:
def create_rotated_dataset(images, key):
    """Build a shuffled, class-balanced 4-way rotation-classification dataset.

    The images are split into four equal, contiguous quarters. Quarter ``k`` is
    rotated by ``k`` quarter-turns with ``jnp.rot90`` in the plane of axes 1
    and 2 and labelled ``k``. The concatenated result is then shuffled with a
    single random permutation.

    Args:
        images: array whose axes 1 and 2 are the image plane, e.g. ``(n, 28, 28)``.
        key: JAX PRNG key used for the shuffle.

    Returns:
        ``(shuffled_images, shuffled_labels)`` with ``4 * (n // 4)`` rows.
        Labels are int32 in ``{0, 1, 2, 3}``. If ``n`` is not divisible by 4,
        the trailing ``n % 4`` images are dropped so the four classes stay
        balanced.
    """
    num_rotations = 4
    split_size = len(images) // num_rotations
    partitions = jnp.split(images[: split_size * num_rotations], num_rotations)

    # Quarter k gets k quarter-turns; k = 0 leaves its images unchanged.
    all_images = jnp.concatenate(
        [jnp.rot90(part, k=k, axes=(1, 2)) for k, part in enumerate(partitions)]
    )
    # [0]*split_size + [1]*split_size + ..., matching the quarter order above.
    all_labels = jnp.repeat(jnp.arange(num_rotations, dtype=jnp.int32), split_size)

    perm = jax.random.permutation(key, len(all_images))
    return all_images[perm], all_labels[perm]


key = jax.random.PRNGKey(432)
rotated_images, rotation_label_ids = create_rotated_dataset(
    mnist_test_images.reshape(-1, 28, 28), key
)
# Back to flat vectors: the model is fed flattened images; the test images
# are only reshaped to 28x28 for rotation and display.
rotated_images = rotated_images.reshape(-1, 28 * 28)
print(rotation_label_ids[:10])
# Keep the integer ids and the one-hot targets under separate names instead of
# rebinding one variable. `accuracy` takes an argmax over axis 1, so the
# targets used below must be one-hot.
rotation_labels = jax.nn.one_hot(rotation_label_ids, 4)

treescope.render_array(
    rotated_images.reshape(-1, 28, 28)[:10],
    pixels_per_cell=1,
    columns=[2, 0],
    axis_labels={0: "input example", 1: "input dimension", 2: "input dimension"},
    vmax=1,
    vmin=-1,
)

In [None]:
from logistic import crossval_softmax_predict


# `accuracy` is already defined, identically, in the "Manipulate model" cell
# above. It is reused here rather than redefined, so the metric has a single
# definition that cannot silently diverge.


@partial(eqx.filter_vmap, in_axes=(eqx.if_array(0),))
def transfer_err(model):
    """Rotation-prediction transfer error (``1 - accuracy``) of each ensemble member.

    Features are the second hidden layer's post-activations on
    ``rotated_images``. A softmax classifier is fit and scored on them with
    ``crossval_softmax_predict`` (``num_splits=10``; presumably each split is
    predicted by a classifier fit on the other splits, so confirm against
    ``logistic``).

    Returns the error rather than the accuracy, so the value is on the same
    scale as ``gen_error`` and consistent with this function's name.
    """
    features = get_second_layer_postacts(model, rotated_images)
    preds = crossval_softmax_predict(
        features, rotation_labels, num_splits=10, key=None
    )
    return 1 - accuracy(preds, rotation_labels)


transfer_err(mnist_model)