In [1]:
!pip install dm-haiku optax


Collecting dm-haiku
  Downloading dm_haiku-0.0.10-py3-none-any.whl (360 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m360.3/360.3 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.10 jmp-0.0.4


In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import optax
import haiku as hk

import numpy as np
import jax.numpy as jnp
from typing import Iterable, Iterator, NamedTuple, TypeVar, Any, MutableMapping, Tuple
import time
import math
import datetime
import json
import os

from jax import config
config.update('jax_enable_x64', True)
config.update('jax_default_matmul_precision', 'float32')
np.set_printoptions(precision=3, suppress=True)

import plotly.graph_objs as go
import plotly.io as pio
import plotly.express as px
pio.renderers.default = 'colab'

In [3]:
from plotly.subplots import make_subplots

def plot_training(all_metrics):
  fig = make_subplots(rows=2, cols=2, subplot_titles=('Loss', 'Accuracy', 'L1 Norm', 'L2 Norm'))

  color_dict = {'train': 'red', 'eval': 'blue'}
  for i, metric in enumerate(['loss', 'acc']):
    for t in ['train', 'eval']:
      trace = go.Scatter(
          x = [d['step'] for d in all_metrics],
          y = [d[f'{t}_{metric}'] for d in all_metrics],
          mode = 'lines+markers',
          name = f'{t.capitalize()} {metric.capitalize()}',
          line = dict(color=color_dict[t]),
          yaxis='y1' if metric == 'loss' else 'y2',
      )
      fig.add_trace(trace, row=1, col=i+1)
      if metric == 'loss':
        fig.update_yaxes(type='log', title_text=f'{metric.capitalize()}', row=1, col=i+1)
      else:
        fig.update_yaxes(title_text=f'{metric.capitalize()}', row=1, col=i+1)

  # Plotting L1 and L2 norms
  for i, norm in enumerate(['l1_norm', 'l2_norm']):
    trace = go.Scatter(
        x = [d['step'] for d in all_metrics],
        y = [d[norm] for d in all_metrics],
        mode = 'lines+markers',
        name = norm.replace('_', ' ').capitalize(),
    )
    fig.add_trace(trace, row=2, col=i+1)
    fig.update_yaxes(title_text=norm.replace('_', ' ').capitalize(), row=2, col=i+1)

  fig.update_xaxes(row=1, col=1, range=[0, max(d['step'] for d in all_metrics)])
  fig.update_xaxes(row=1, col=2, range=[0, max(d['step'] for d in all_metrics)])
  fig.update_xaxes(row=2, col=1, range=[0, max(d['step'] for d in all_metrics)])
  fig.update_xaxes(row=2, col=2, range=[0, max(d['step'] for d in all_metrics)])

  fig.update_layout(height=800, hovermode='closest')
  fig.show()

In [4]:
def plot_weights(state):
  key_subkey_array = []
  for key, subdict in state.params.items():
    for subkey, array in subdict.items():
      key_subkey_array.append((key, subkey, array))

  zmin = min([np.min(array) for key, subkey, array in key_subkey_array]).item()
  zmax = max([np.max(array) for key, subkey, array in key_subkey_array]).item()
  zval = max(abs(zmin), zmax) * 1

  N = len(key_subkey_array)
  grid_size = math.ceil(math.sqrt(N))  # find the nearest square grid

  fig = make_subplots(rows=grid_size, cols=grid_size, subplot_titles=[f"{key} {subkey}" for key, subkey, array in key_subkey_array], vertical_spacing=.1)

  for idx, (key, subkey, array) in enumerate(key_subkey_array):
    row = idx // grid_size + 1  # Calculate the appropriate row, col placement in grid
    col = idx % grid_size + 1
    trace = go.Heatmap(z=array, zmin=zval*-1, zmax=zval*1, zmid=0, colorscale='RdBu', name=f"{key} {subkey}")
    fig.add_trace(trace, row=row, col=col)

  fig.update_layout(height=400*grid_size, width=400*grid_size)
  fig.show()

In [5]:
class TrainingState(NamedTuple):
  """Container for the training state."""
  params: hk.Params
  opt_state: optax.OptState
  rng: jax.Array
  step: jax.Array

In [6]:
class NpEncoder(json.JSONEncoder):
  """Save NP as json."""

  def default(self, o):
    if isinstance(o, np.integer):
      return int(o)
    if isinstance(o, np.floating):
      return float(o)
    if isinstance(o, np.ndarray):
      return o.tolist()

    if isinstance(o, jnp.integer):
      return int(o)
    if isinstance(o, jnp.floating):
      return float(o)
    if isinstance(o, jnp.ndarray):
      return o.tolist()

    return super(NpEncoder, self).default(o)

In [7]:
# config for the main model used in the post
hyper = {
    'task': 'modular_addition',
    'sweep_slug': 'fail-memorize-generalize',

    'n_tokens': 67,
    'percent_train': .4,

    'embed_size': int(500),
    'hidden_size': int(24),

    'weight_decay': 1,
    'learning_rate': 1e-3,

    'max_steps': 50000,
    'seed': 165,

    # Fixed outside of sweeps
    'is_symmetric_input': True,  # if True, it only takes half of the pairs (upper triangle)
    'embed_config': 'tied',  # ['untied', 'tied', 'input_tied']
    'is_collapsed_out': False,
    'is_collapsed_hidden': False,
    'is_tied_hidden': True,
    'regularization': 'l2', # ['l1', 'l2']
    'b1': .9,
    'b2': .98,
}

In [8]:
hyper['log_every'] = int(hyper['max_steps']/500)
hyper['save_every'] = int(hyper['max_steps']/100)
np.random.seed(hyper['seed'])

In [9]:
from numpy.ma.core import indices
nums = list(range(hyper['n_tokens']))
if (hyper['is_symmetric_input']):
  inputs = np.array([[a,b] for a in nums for b in nums if a <= b]).astype(np.int32)
else:
  inputs = np.array([[a,b] for a in nums for b in nums]).astype(np.int32)
outputs = (inputs[:, 0] + inputs[:, 1]) % hyper['n_tokens']

indices = np.random.permutation(len(inputs))
split_idx = int(hyper['percent_train']*len(inputs))
train_batch = inputs[indices[:split_idx]], outputs[indices[:split_idx]]
eval_batch = inputs[indices[split_idx:]], outputs[indices[split_idx:]]

In [23]:
def forward(inputs):
  embed_init = hk.initializers.VarianceScaling(2)

  # calculate input embeddings
  if hyper['embed_config'] == 'untied':
    embed_a

1367

In [25]:
embed_init = hk.initializers.VarianceScaling(2)

<haiku._src.initializers.VarianceScaling at 0x79a79a28a3e0>

In [28]:
embed_a = hk.get_parameter('embed_a', [hyper['n_tokens'], hyper['embed_size']], init=embed_init)

ValueError: ignored