### Setup

In [1]:
%pip install jax equinox treescope
import jax
import jax.numpy as jnp
import equinox as eqx
import treescope

Note: you may need to restart the kernel to use updated packages.


In [2]:
from expand import scale_neurons
from expand import duplicate_neurons
from expand import add_random_zero_neurons
from expand import add_phantom_zero_neurons

def get_output(model, inputs, key):
    return model(inputs, key=key)

def get_preacts(model, inputs, key):
    return model.layers[0](inputs, key=key)

def get_postacts(model, inputs, key):
    return model.activation(get_preacts(model, inputs, key))

def mse(pred_y, y):
    return jnp.mean(jnp.sum(jnp.square((pred_y - y)), axis=1), axis=0)

### XOR

In [3]:
from xor import train_xor
xor_model = train_xor()
xor_model

Evaluating iteration with loss value 0.5622677803039551.
Evaluating iteration with loss value 0.5441693067550659.
Evaluating iteration with loss value 0.5289270877838135.
Evaluating iteration with loss value 0.5156916379928589.
Evaluating iteration with loss value 0.5038694143295288.
Evaluating iteration with loss value 0.49303457140922546.
Evaluating iteration with loss value 0.48287567496299744.
Evaluating iteration with loss value 0.4731637239456177.
Evaluating iteration with loss value 0.46373170614242554.
Evaluating iteration with loss value 0.4544614553451538.
Evaluating iteration with loss value 0.44527554512023926.
Evaluating iteration with loss value 0.4361308813095093.
Evaluating iteration with loss value 0.42701467871665955.
Evaluating iteration with loss value 0.417940229177475.
Evaluating iteration with loss value 0.40894317626953125.
Evaluating iteration with loss value 0.40007680654525757.
Evaluating iteration with loss value 0.39140695333480835.
Evaluating iteration wit

MLP(
  layers=(
    Linear(
      weight=f32[2,2],
      bias=None,
      in_features=2,
      out_features=2,
      use_bias=False
    ),
    Linear(
      weight=f32[1,2],
      bias=None,
      in_features=2,
      out_features=1,
      use_bias=False
    )
  ),
  activation=<wrapped function relu>,
  final_activation=<function <lambda>>,
  use_bias=False,
  use_final_bias=False,
  in_size=2,
  out_size=1,
  width_size=2,
  depth=1
)

In [5]:
batch_size = 2
data_key = jax.random.PRNGKey(0)
x = jax.random.normal(data_key, (batch_size, 2))
treescope.render_array(
    x,
    pixels_per_cell=8,
    axis_labels={0: "input example", 1: "input dimension"},
    vmax=1, vmin=-1,
)

In [4]:
from mnist import train_mnist
mnist_model = train_mnist()
mnist_model

Iteration 0, loss: 2.301729202270508
Iteration 1, loss: 2.2912893295288086
Iteration 2, loss: 2.280944347381592
Iteration 3, loss: 2.270609140396118
Iteration 4, loss: 2.260138511657715
Iteration 5, loss: 2.249394655227661
Iteration 6, loss: 2.238280773162842
Iteration 7, loss: 2.226752281188965
Iteration 8, loss: 2.2147040367126465
Iteration 9, loss: 2.202028274536133
Iteration 10, loss: 2.1886379718780518
Iteration 11, loss: 2.174485921859741
Iteration 12, loss: 2.159597873687744
Iteration 13, loss: 2.1440019607543945
Iteration 14, loss: 2.1276652812957764
Iteration 15, loss: 2.1106178760528564
Iteration 16, loss: 2.09289813041687
Iteration 17, loss: 2.074599027633667
Iteration 18, loss: 2.0557117462158203
Iteration 19, loss: 2.036278009414673
Iteration 20, loss: 2.0161685943603516
Iteration 21, loss: 1.9953882694244385
Iteration 22, loss: 1.9738621711730957
Iteration 23, loss: 1.9514769315719604
Iteration 24, loss: 1.9281601905822754
Iteration 25, loss: 1.9037814140319824
Iteration 

MLP(
  layers=(
    Linear(
      weight=f32[512,784],
      bias=None,
      in_features=784,
      out_features=512,
      use_bias=False
    ),
    Linear(
      weight=f32[512,512],
      bias=None,
      in_features=512,
      out_features=512,
      use_bias=False
    ),
    Linear(
      weight=f32[10,512],
      bias=None,
      in_features=512,
      out_features=10,
      use_bias=False
    )
  ),
  activation=<wrapped function relu>,
  final_activation=<function <lambda>>,
  use_bias=False,
  use_final_bias=False,
  in_size=784,
  out_size=10,
  width_size=512,
  depth=2
)

In [7]:
pattern_size = 2
random_pattern_key = jax.random.PRNGKey(1)
random_pattern = jax.random.normal(random_pattern_key, (batch_size, pattern_size))
random_pattern = jax.nn.relu(random_pattern)
treescope.render_array(
    random_pattern,
    pixels_per_cell=8,
    axis_labels={0: "input example", 1: "phantom dimension"},
    vmax=1, vmin=-1,
)

In [8]:
pattern_size = 512
random_pattern_key = jax.random.PRNGKey(1)
random_pattern = jax.random.normal(random_pattern_key, (batch_size, pattern_size))
random_pattern = jax.nn.relu(random_pattern)
treescope.render_array(
    random_pattern,
    pixels_per_cell=1,
    axis_labels={0: "input example", 1: "phantom dimension"},
    vmax=1, vmin=-1,
)

In [9]:
W1 = xor_model.layers[0].weight
W2 = xor_model.layers[1].weight

forward_key = jax.random.PRNGKey(2)
preacts = eqx.filter_vmap(get_preacts, in_axes=(None, 0, None))(xor_model, x, forward_key)
new_W1, new_W2 = add_phantom_zero_neurons(W1, W2, preacts, random_pattern)

new_xor_model = eqx.tree_at(lambda x: x.layers[0].weight, xor_model, new_W1)
new_xor_model = eqx.tree_at(lambda x: x.layers[1].weight, new_xor_model, new_W2)

original_outputs = eqx.filter_vmap(get_output, in_axes=(None, 0, None))(xor_model, x, forward_key)
new_outputs = eqx.filter_vmap(get_output, in_axes=(None, 0, None))(new_xor_model, x, forward_key)
print(f"Difference in outputs: {mse(original_outputs, new_outputs)}")

original_activities = eqx.filter_vmap(get_postacts, in_axes=(None, 0, None))(xor_model, x, forward_key)
new_activities = eqx.filter_vmap(get_postacts, in_axes=(None, 0, None))(new_xor_model, x, forward_key)
phantom_activities = new_activities[:, -pattern_size:]
print(f"Difference of internal activities to pattern: {mse(random_pattern, phantom_activities)}")

treescope.render_array(
    phantom_activities,
    pixels_per_cell=8,
    axis_labels={0: "input example", 1: "phantom dimension"},
    vmax=1, vmin=-1,
)

TypeError: dot_general requires contracting dimensions to have the same shape, got (2,) and (784,).

In [None]:
W1 = mnist_model.layers[0].weight
W2 = mnist_model.layers[1].weight

forward_key = jax.random.PRNGKey(2)
preacts = eqx.filter_vmap(get_preacts, in_axes=(None, 0, None))(mnist_model, x, forward_key)
new_W1, new_W2 = add_phantom_zero_neurons(W1, W2, preacts, random_pattern)

new_mnist_model = eqx.tree_at(lambda x: x.layers[0].weight, mnist_model, new_W1)
new_mnist_model = eqx.tree_at(lambda x: x.layers[1].weight, new_mnist_model, new_W2)

original_outputs = eqx.filter_vmap(get_output, in_axes=(None, 0, None))(mnist_model, x, forward_key)
new_outputs = eqx.filter_vmap(get_output, in_axes=(None, 0, None))(new_mnist_model, x, forward_key)
print(f"Difference in outputs: {mse(original_outputs, new_outputs)}")

original_activities = eqx.filter_vmap(get_postacts, in_axes=(None, 0, None))(mnist_model, x, forward_key)
new_activities = eqx.filter_vmap(get_postacts, in_axes=(None, 0, None))(new_mnist_model, x, forward_key)
phantom_activities = new_activities[:, -pattern_size:]
print(f"Difference of internal activities to pattern: {mse(random_pattern, phantom_activities)}")

treescope.render_array(
    phantom_activities,
    pixels_per_cell=8,
    axis_labels={0: "input example", 1: "phantom dimension"},
    vmax=1, vmin=-1,
)

In [None]:
new_mnist_model

In [None]:
treescope.render_array(
    original_activities,
    pixels_per_cell=8,
    axis_labels={0: "input example", 1: "original activities"},
    vmax=1, vmin=-1,
)

In [None]:
treescope.render_array(
    new_activities,
    pixels_per_cell=8,
    axis_labels={0: "input example", 1: "expanded activities"},
    vmax=1, vmin=-1,
)

### MNIST

In [6]:
batch_size = 128
data_key = jax.random.PRNGKey(0)
# TODO: use mnist validation data
x = jax.random.normal(data_key, (batch_size, 28 * 28))
treescope.render_array(
    x,
    pixels_per_cell=1,
    axis_labels={0: "input example", 1: "input dimension"},
    vmax=1, vmin=-1,
)