### Setup

In [None]:
%pip install jax equinox optax optimistix treescope
%load_ext autoreload
%autoreload 2

import imageio
import jax
import jax.numpy as jnp
import equinox as eqx
import treescope

### Load an MNIST test batch and train a model on MNIST

In [None]:
from mnist import mnist_data

_, _, mnist_test_images, mnist_test_labels = mnist_data()

# Work with a fixed-size slice of the test set.
batch_size = 1024
mnist_test_image_batch, mnist_test_label_batch = (
    mnist_test_images[:batch_size],
    mnist_test_labels[:batch_size],
)

# Group the batch by class: argmax over the last axis recovers the class index
# (labels presumably one-hot — confirm in mnist.py).
idx = jnp.argsort(mnist_test_label_batch.argmax(-1))
mnist_test_image_batch, mnist_test_label_batch = (
    mnist_test_image_batch[idx],
    mnist_test_label_batch[idx],
)

# Preview the first 20 sorted examples as 28x28 images.
treescope.render_array(
    mnist_test_image_batch.reshape(-1, 28, 28)[:20],
    axis_labels={0: "input example", 1: "input dimension", 2: "input dimension"},
    columns=[2, 0],
    pixels_per_cell=1,
    vmin=-1,
    vmax=1,
)

In [None]:
%%time
from mnist import train_mnist

# Train the model via the project's `train_mnist` helper (mnist.py). Later cells use
# `.layers[0..2]` and `.activation`, so this is presumably an eqx.nn.MLP-like module —
# confirm in mnist.py. The bare name on the last line renders the model's repr.
mnist_model = train_mnist()
mnist_model

### Manipulate model

In [None]:
from expand import add_phantom_zero_neurons


def get_output(model, inputs, key):
    """Run the full model end-to-end on a single (unbatched) example."""
    return model(inputs, key=key)


def get_first_layer_preacts(model, inputs, key):
    """Output of ``model.layers[0]``, before the nonlinearity.

    ``key`` is forwarded only to this first layer; the later stages below call
    their layers without a key.
    """
    first_layer = model.layers[0]
    return first_layer(inputs, key=key)


def get_first_layer_postacts(model, inputs, key):
    """First-layer pre-activations passed through ``model.activation``."""
    preacts = get_first_layer_preacts(model, inputs, key)
    return model.activation(preacts)


def get_second_layer_preacts(model, inputs, key):
    """``model.layers[1]`` applied to the first-layer post-activations."""
    hidden = get_first_layer_postacts(model, inputs, key)
    return model.layers[1](hidden)


def get_second_layer_postacts(model, inputs, key):
    """Second-layer pre-activations passed through ``model.activation``."""
    preacts = get_second_layer_preacts(model, inputs, key)
    return model.activation(preacts)


def get_third_layer_preacts(model, inputs, key):
    """``model.layers[2]`` applied to the second-layer post-activations."""
    hidden = get_second_layer_postacts(model, inputs, key)
    return model.layers[2](hidden)


def mse(pred_y, y):
    """Squared error summed over axis 1, then averaged over axis 0.

    For 2-D inputs laid out as (example, feature), this is the mean over
    examples of each example's total squared error.
    """
    residual = pred_y - y
    per_example_sq_err = jnp.sum(jnp.square(residual), axis=1)
    return jnp.mean(per_example_sq_err, axis=0)


x = mnist_test_image_batch[:128]
key = jax.random.PRNGKey(0)

# Run the model end-to-end and via the hand-staged helpers; the two paths must agree.
full_outputs, staged_outputs = (
    eqx.filter_vmap(fn, in_axes=(None, 0, None))(mnist_model, x, key)
    for fn in (get_output, get_third_layer_preacts)
)
jnp.array_equal(full_outputs, staged_outputs)

In [None]:
# Load the image as 8-bit grayscale and invert it to [0, 1]: dark pixels map to ~1,
# light pixels to ~0.
# `imageio.v2.imread` is the explicit v2 API. The top-level `imageio.imread` is a
# deprecated alias for the same function and emits a DeprecationWarning.
elephant_img = 1 - jnp.array(imageio.v2.imread("./elephant.png", mode="L")) / jnp.float32(
    255.0
)
# Keep only the middle horizontal band (rows 3/8 .. 5/8 of the image height).
height, _ = elephant_img.shape
elephant_img = elephant_img[height // 8 * 3 : height // 8 * 5, :]
treescope.render_array(
    elephant_img,
    rows=[0],
    columns=[1],
    pixels_per_cell=1,
    vmax=1,
    vmin=-1,
)

In [None]:
# Weight matrices of the trained model's first three layers.
W1, W2, W3 = (layer.weight for layer in mnist_model.layers[:3])

forward_key = jax.random.PRNGKey(2)
# First-layer activity on the full sorted test batch, before and after the nonlinearity.
preacts, postacts = (
    eqx.filter_vmap(stage, in_axes=(None, 0, None))(
        mnist_model, mnist_test_image_batch, forward_key
    )
    for stage in (get_first_layer_preacts, get_first_layer_postacts)
)

# Expand layers 1/2 with phantom units for the (transposed) elephant band.
# NOTE(review): presumably expand.add_phantom_zero_neurons appends one hidden unit per
# pattern column, with zero outgoing weights in W3 — confirm in expand.py.
new_W2, new_W3 = add_phantom_zero_neurons(W2, W3, postacts, elephant_img.T)

# Incoming weights of the appended units: the trailing elephant_img.shape[0] rows of new_W2.
treescope.render_array(
    new_W2[-elephant_img.shape[0] :],
    axis_labels={0: "hidden dimension", 1: "input dimension"},
    pixels_per_cell=1,
    vmin=-1,
    vmax=1,
)

In [None]:
# Swap both expanded weight matrices into a copy of the trained model in one step.
new_mnist_model = eqx.tree_at(
    lambda m: (m.layers[1].weight, m.layers[2].weight),
    mnist_model,
    (new_W2, new_W3),
)

# Compare predictions of the original and expanded models on the same batch and key.
original_outputs, new_outputs = (
    eqx.filter_vmap(get_output, in_axes=(None, 0, None))(
        model, mnist_test_image_batch, forward_key
    )
    for model in (mnist_model, new_mnist_model)
)
print(f"Difference in outputs: {mse(original_outputs, new_outputs)}")

In [None]:
# Second-hidden-layer post-activations of the original and expanded models,
# computed on the same batch and key.
original_activities = eqx.filter_vmap(
    get_second_layer_postacts, in_axes=(None, 0, None)
)(mnist_model, mnist_test_image_batch, forward_key)
new_activities = eqx.filter_vmap(get_second_layer_postacts, in_axes=(None, 0, None))(
    new_mnist_model, mnist_test_image_batch, forward_key
)
# Compare the trailing (phantom) hidden units against the transposed target pattern.
phantom_activities = new_activities[:, -elephant_img.T.shape[1] :]
print(
    f"Difference of internal activities to pattern: {mse(elephant_img.T, phantom_activities)}"
)

# This renders *every* hidden unit of the expanded model; the phantom units are
# only the last rows. Axis 1 therefore spans the whole hidden layer, so it is
# labelled "hidden dimension" rather than "phantom dimension".
treescope.render_array(
    new_activities,
    pixels_per_cell=1,
    columns=[0],
    rows=[1],
    axis_labels={0: "input example", 1: "hidden dimension"},
    vmax=1,
    vmin=-1,
)

In [None]:
from rsa import correlation_rdm

# Example-by-example dissimilarity matrix (via rsa.correlation_rdm) of the
# ORIGINAL model's second-hidden-layer activities.
rdm = correlation_rdm(original_activities, return_full=True)
treescope.render_array(
    rdm,
    rows=[1],
    columns=[0],
    axis_labels={0: "input example", 1: "input example"},
    vmin=-1,
    vmax=1,
    pixels_per_cell=1,
)

In [None]:
# Same dissimilarity matrix, now for the EXPANDED model (its hidden activities
# include the appended phantom units).
rdm = correlation_rdm(new_activities, return_full=True)
treescope.render_array(
    rdm,
    axis_labels={0: "input example", 1: "input example"},
    columns=[0],
    rows=[1],
    pixels_per_cell=1,
    vmin=-1,
    vmax=1,
)

In [None]:
# Try now with more elephant: repeat the implant using the full, uncropped image,
# so the phantom block covers every image row rather than only the middle band.
from IPython.display import display

elephant_img = 1 - jnp.array(imageio.v2.imread("./elephant.png", mode="L")) / jnp.float32(
    255.0
)
# Jupyter only auto-renders a cell's LAST expression. Without an explicit display()
# this preview would be computed and silently discarded.
display(
    treescope.render_array(
        elephant_img,
        pixels_per_cell=1,
        vmax=1,
        vmin=-1,
    )
)

W1 = mnist_model.layers[0].weight
W2 = mnist_model.layers[1].weight
W3 = mnist_model.layers[2].weight

forward_key = jax.random.PRNGKey(2)
postacts = eqx.filter_vmap(get_first_layer_postacts, in_axes=(None, 0, None))(
    mnist_model, mnist_test_image_batch, forward_key
)

# Expand layers 1/2 with phantom units for the transposed image (see expand.py),
# then swap both weight matrices into a copy of the trained model.
new_W2, new_W3 = add_phantom_zero_neurons(W2, W3, postacts, elephant_img.T)
new_mnist_model = eqx.tree_at(
    lambda m: (m.layers[1].weight, m.layers[2].weight),
    mnist_model,
    (new_W2, new_W3),
)

# Outputs of the original vs. expanded model on the same batch and key.
original_outputs, new_outputs = (
    eqx.filter_vmap(get_output, in_axes=(None, 0, None))(
        model, mnist_test_image_batch, forward_key
    )
    for model in (mnist_model, new_mnist_model)
)
print(f"Difference in outputs: {mse(original_outputs, new_outputs)}")

# Second-hidden-layer activities; the trailing units are compared to the pattern.
original_activities, new_activities = (
    eqx.filter_vmap(get_second_layer_postacts, in_axes=(None, 0, None))(
        model, mnist_test_image_batch, forward_key
    )
    for model in (mnist_model, new_mnist_model)
)
phantom_activities = new_activities[:, -elephant_img.T.shape[1] :]
print(
    f"Difference of internal activities to pattern: {mse(elephant_img.T, phantom_activities)}"
)

rdm = correlation_rdm(new_activities, return_full=True)
treescope.render_array(
    rdm,
    pixels_per_cell=1,
    columns=[0],
    rows=[1],
    axis_labels={0: "input example", 1: "input example"},
    vmax=1,
    vmin=-1,
)