In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import torch

from jax import grad, jit, random
from scipy.interpolate import make_interp_spline
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm.auto import tqdm

In [None]:
# Grokking experiment on MNIST; the architecture follows the Omnigrok paper:
# https://arxiv.org/pdf/2210.01117.pdf

# Torch device handle: CUDA when torch reports it as available, otherwise CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# MNIST training split stored under ./data (download=True fetches it if needed);
# ToTensor() is passed as the per-sample transform.
mnist_dataset = MNIST(
    root="./data",
    train=True,
    download=True,
    transform=ToTensor(),
)

In [None]:
def MLP(params, x):
  """Forward pass of a bias-free ReLU MLP stored as a list of weight matrices.

  Args:
    params: list of 2-D weight matrices ``[first, *hidden, readout]``; each one
      is applied as ``W @ activations``.
    x: 2-D input batch with one sample per row.

  Returns:
    Array with one row per sample and one column per row of ``params[-1]``.
  """
  def relu(z):
    # ReLU written as a mask product: entries that are not > 0 are multiplied by 0.
    return z * (z > 0.0)

  first, hidden, readout = params[0], params[1:-1], params[-1]

  # Pre-activations are kept as (units, samples); each layer divides by sqrt(fan-in).
  pre = first @ x.T / jnp.sqrt(first.shape[1])
  for W in hidden:
    act = relu(pre)
    pre = 1/jnp.sqrt(W.shape[1]) * W @ act

  act = relu(pre)
  # The readout divides by the last layer's width itself, not by its square root.
  return act.T @ readout.T / act.shape[0]

def init_params(N, D, L, key, w_scale = 1.0):
  """Draw independent Gaussian weight matrices for an MLP with 10 outputs.

  Args:
    N: width of every hidden layer.
    D: input dimension.
    L: number of hidden layers; values below 1 produce the same two matrices
      as ``L == 1``.
    key: JAX PRNG key. It is split once so that every matrix is drawn from its
      own subkey; reusing one key for two draws yields correlated (or
      identical) matrices.
    w_scale: factor multiplying every standard-normal entry.

  Returns:
    List of arrays with shapes ``(N, D)``, then ``L - 1`` times ``(N, N)``,
    then ``(10, N)``.
  """
  shapes = [(N, D)] + [(N, N)] * (L - 1) + [(10, N)]
  subkeys = jax.random.split(key, len(shapes))
  return [w_scale * jax.random.normal(k, shape) for k, shape in zip(subkeys, shapes)]


subset_size = 1000

# Seed the split and the shuffle explicitly: without this the 1000 training
# samples are drawn from torch's unseeded global RNG and differ on every run.
data_rng = torch.Generator().manual_seed(42)
train_set, test_set = random_split(
    mnist_dataset,
    [subset_size, len(mnist_dataset) - subset_size],
    generator=data_rng,
)

# Create data loaders
train_loader = DataLoader(train_set, batch_size=subset_size, shuffle=True, generator=data_rng)
test_loader = DataLoader(test_set, batch_size=2*subset_size)


def first_batch_as_jax(loader):
  """Return the first batch of `loader` as (flattened images, one-hot labels) JAX arrays."""
  images, labels = next(iter(loader))
  images = jnp.array(images.numpy())
  labels = jnp.array(labels.numpy())
  # Merge the last two axes into one feature axis; the target shape only has
  # (batch, height * width), so this assumes single-channel images.
  images = images.reshape((images.shape[0], images.shape[-2] * images.shape[-1]))
  # Rows of the 10 x 10 identity matrix are the one-hot codes of the 10 classes.
  return images, jnp.eye(10)[labels]


# batch_size == subset_size, so the first training batch is the whole training subset.
X, y = first_batch_as_jax(train_loader)
# Only the first 2 * subset_size samples of the held-out split are used for evaluation.
Xte, yte = first_batch_as_jax(test_loader)


scales = [1e-3, 0.01, 0.05, 0.5]
traccs, taccs = [], []

# Constants and hyper-params (identical for every scale, so they are set once)
N = 200
D = 784
L = 3
weight_scale = 150.0  # Scaling factor for Kaiming initialization, replicated from the Omnigrok paper
lr = 1e-3
wd = 0.01
T = 250000
batch = 200
compute_every = 100  # record metrics every `compute_every` optimizer steps


def init_params_kaiming(N, D, L, key, weight_scale=1.0):
    """Initialization with Kaiming scaling.

    Draws the matrices with `init_params` and multiplies each one by
    ``sqrt(2 / fan_in)``, where fan_in is the matrix's number of columns.
    """
    return [W * jnp.sqrt(2. / W.shape[1]) for W in init_params(N, D, L, key, weight_scale)]


def compute_accuracy(predictions, targets):
    """Fraction of rows whose arg-max prediction equals the arg-max of the one-hot target."""
    return jnp.mean(jnp.argmax(predictions, axis=1) == jnp.argmax(targets, axis=1))


def make_loss_fn(scale):
    """Return the jitted mean-squared error of ``scale * MLP(params, inputs)`` against targets."""
    def loss(params, inputs, targets):
        return jnp.mean((MLP(params, inputs) * scale - targets)**2)
    return jax.jit(loss)


def make_eval_fn(scale):
    """Return a jitted function computing (loss, accuracy) from a single forward pass."""
    def evaluate(params, inputs, targets):
        predictions = MLP(params, inputs) * scale
        return jnp.mean((predictions - targets)**2), compute_accuracy(predictions, targets)
    return jax.jit(evaluate)


for scale in tqdm(scales):
    # The initializer is defined above, before its first use, so every scale
    # (including the first one) starts from the same Kaiming-scaled weights.
    key = jax.random.PRNGKey(0)
    params = init_params_kaiming(N, D, L, key, weight_scale)

    loss_fn = make_loss_fn(scale)
    grad_fn = jax.jit(jax.grad(loss_fn))
    eval_fn = make_eval_fn(scale)
    optimizer = optax.adamw(learning_rate=lr, weight_decay=wd)
    opt_state = optimizer.init(params)

    train_loss = []
    test_loss = []
    train_accuracy = []
    test_accuracy = []

    for t in range(T):
        if t % compute_every == 0:
            # Compute and store train & test loss / accuracy as plain Python
            # floats so the histories can be handed directly to NumPy/SciPy.
            tr_loss, tr_acc = eval_fn(params, X, y)
            te_loss, te_acc = eval_fn(params, Xte, yte)
            train_loss.append(float(tr_loss))
            test_loss.append(float(te_loss))
            train_accuracy.append(float(tr_acc))
            test_accuracy.append(float(te_acc))

        # Cycle through the training subset in consecutive slices of `batch` samples.
        ind = (batch * t) % subset_size
        grads = grad_fn(params, X[ind:ind+batch], y[ind:ind+batch])
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

    traccs += [ train_accuracy ]
    taccs += [ test_accuracy ]

In [None]:
# Train (dashed) and test (solid) accuracy for every output scale, on a log step axis.
fig, ax = plt.subplots(figsize=(9, 6))
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink']
for i, (train_acc, test_acc) in enumerate(zip(traccs, taccs)):
    # Metrics were recorded every `compute_every` optimizer steps, starting at step 0.
    steps = np.arange(0, compute_every * len(train_acc), compute_every)
    # Evaluate the splines only inside the recorded range; going past the last
    # recorded step would plot extrapolated values instead of measurements.
    dense_steps = np.linspace(steps[0], steps[-1], compute_every * len(train_acc))

    # k=3 for cubic spline
    train_spline = make_interp_spline(steps, np.asarray(train_acc, dtype=float), k=3)
    test_spline = make_interp_spline(steps, np.asarray(test_acc, dtype=float), k=3)
    y_train_smooth = train_spline(dense_steps)
    y_test_smooth = test_spline(dense_steps)

    ax.plot(dense_steps, y_train_smooth*100, label=rf'Train Accuracy, $\alpha$={scales[i]}', linestyle='--', color=colors[-i])
    ax.plot(dense_steps, y_test_smooth*100, label=rf'Test Accuracy, $\alpha$={scales[i]}', color=colors[-i])

# The x values are optimizer steps (one minibatch each), not passes over the data.
ax.set_xlabel('Optimization steps', fontsize=20)
ax.set_xscale('log')
ax.set_ylabel('Accuracy', fontsize=20)
ax.legend(fontsize=14)
fig.tight_layout()
plt.show()