<a href="https://colab.research.google.com/github/chrishamblin7/dimensionality_reduction_and_superposition/blob/main/latent_data_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generate Latent Data in 'Superposition'

This notebook includes code adapted from the paper 'Toy Models of Superposition' to reproduce some of its experiments. Given simulated data with known 'real' features, we create embeddings of those features in a lower-dimensional space such that the original features can be read out with minimal error.

This notebook is designed to run in Google Colab's Python 3.7 environment.

In [2]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.5.0-py3-none-any.whl (36 kB)
Installing collected packages: einops
Successfully installed einops-0.5.0


In [3]:
import torch
from torch import nn
from torch.nn import functional as F

from typing import Optional

from dataclasses import dataclass, replace
import numpy as np
import einops

from tqdm.notebook import trange

import time
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import matplotlib.pyplot as plt

from google.colab import files

In [4]:
@dataclass
class Config:
  """Sizes describing a group of toy models that are trained side by side."""

  # Number of input features each model has to reconstruct.
  n_features: int

  # Width of the hidden layer the features are projected into.
  n_hidden: int

  # Number of independent models optimized within one training loop.
  # Training them together makes it cheap to sweep sparsity or
  # importance curves; torch.vmap would be a possible alternative.
  n_instances: int
 
class Model(nn.Module):
  """A set of independent tied-weight ReLU autoencoders evaluated in parallel.

  Each of the `config.n_instances` models projects `config.n_features`
  inputs into `config.n_hidden` dimensions with its own slice of `W`,
  maps them back with the same weights, adds a bias and applies ReLU.
  """

  def __init__(self, 
               config, 
               feature_probability: Optional[torch.Tensor] = None,
               importance: Optional[torch.Tensor] = None,               
               device='cuda'):
    """Create the parameters and store the sampling / weighting tensors.

    Args:
      config: object exposing `n_instances`, `n_features` and `n_hidden`.
      feature_probability: tensor compared against uniform noise in
        `generate_batch` to decide which features are active; a scalar
        1.0 is used when omitted.
      importance: tensor stored on the model for use by the training
        loss; a scalar 1.0 is used when omitted.
      device: device on which parameters and stored tensors are placed.
    """
    super().__init__()
    self.config = config

    weight_shape = (config.n_instances, config.n_features, config.n_hidden)
    bias_shape = (config.n_instances, config.n_features)

    # W is registered before b_final so parameter order stays (W, b_final).
    self.W = nn.Parameter(torch.empty(weight_shape, device=device))
    nn.init.xavier_normal_(self.W)
    self.b_final = nn.Parameter(torch.zeros(bias_shape, device=device))

    # A 0-d tensor of ones broadcasts against any shape, so it works as
    # the neutral default for both optional tensors.
    self.feature_probability = (
        torch.ones(()) if feature_probability is None else feature_probability
    ).to(device)
    self.importance = (
        torch.ones(()) if importance is None else importance
    ).to(device)

  def forward(self, features):
    """Encode and decode `features` with the shared weights.

    Args:
      features: tensor of shape [..., instance, n_features].

    Returns:
      Tensor of the same shape holding the ReLU-rectified reconstruction.
    """
    # W: [instance, n_features, n_hidden]; used for both directions.
    latent = torch.einsum("...if,ifh->...ih", features, self.W)
    reconstruction = torch.einsum("...ih,ifh->...if", latent, self.W)
    return F.relu(reconstruction + self.b_final)

  def generate_batch(self, n_batch):
    """Sample a batch of shape [n_batch, n_instances, n_features].

    Every entry is drawn uniformly from [0, 1) and kept only where a
    second, independent uniform draw is <= `feature_probability`;
    all other entries are zero.
    """
    cfg = self.config
    dev = self.W.device
    shape = (n_batch, cfg.n_instances, cfg.n_features)

    # The values are drawn before the mask; keep this order so seeded
    # runs produce the same batches.
    values = torch.rand(shape, device=dev)
    is_active = torch.rand(shape, device=dev) <= self.feature_probability
    return torch.where(is_active, values, torch.zeros((), device=dev))

In [5]:
def linear_lr(step, steps):
  """Learning-rate multiplier that falls linearly from 1 at step 0 to 0 at `steps`."""
  fraction_done = step / steps
  return 1 - fraction_done

def constant_lr(*_):
  """Learning-rate multiplier that is always 1.0; positional arguments are ignored."""
  multiplier = 1.0
  return multiplier

def cosine_decay_lr(step, steps):
  """Learning-rate multiplier following a quarter cosine wave.

  Returns 1 at step 0 and (numerically) 0 at the last step, `steps - 1`.

  Args:
    step: index of the current optimization step, starting at 0.
    steps: total number of optimization steps.
  """
  # With a single step, steps - 1 is 0; clamp the denominator so the
  # schedule returns cos(0) = 1 instead of raising ZeroDivisionError.
  last_step = max(steps - 1, 1)
  return np.cos(0.5 * np.pi * step / last_step)

def optimize(model, 
             render=False, 
             n_batch=1024,
             steps=10_000,
             print_freq=100,
             lr=1e-3,
             lr_scale=constant_lr,
             hooks=None):
  """Train `model` in place with AdamW on freshly generated batches.

  Each step draws a new batch from `model.generate_batch`, computes the
  importance-weighted squared reconstruction error, averages it over
  batch and features for every instance, sums over instances, and takes
  one optimizer step.

  Args:
    model: object providing `config`, `parameters()`, `generate_batch`,
      `importance` and a forward call.
    render: accepted for compatibility; not used by this function.
    n_batch: number of samples generated per step.
    steps: number of optimization steps.
    print_freq: the progress bar postfix is refreshed every `print_freq`
      steps and on the final step.
    lr: base learning rate.
    lr_scale: callable `(step, steps) -> multiplier` applied to `lr`.
    hooks: optional iterable of callables. After every optimizer step
      each one is called with a dict containing `model`, `step`, `opt`,
      `error`, `loss` and `lr`. Defaults to no hooks.
  """
  cfg = model.config

  opt = torch.optim.AdamW(model.parameters(), lr=lr)

  with trange(steps) as t:
    for step in t:
      # Apply the schedule by hand so any (step, steps) callable works.
      step_lr = lr * lr_scale(step, steps)
      for group in opt.param_groups:
        group['lr'] = step_lr

      opt.zero_grad(set_to_none=True)
      batch = model.generate_batch(n_batch)
      out = model(batch)
      error = (model.importance*(batch.abs() - out)**2)
      # Mean over batch and features per instance, then sum over
      # instances so every instance contributes its own gradient.
      loss = einops.reduce(error, 'b i f -> i', 'mean').sum()
      loss.backward()
      opt.step()

      if hooks:
        hook_data = dict(model=model,
                         step=step, 
                         opt=opt,
                         error=error,
                         loss=loss,
                         lr=step_lr)
        for h in hooks:
          h(hook_data)
      if step % print_freq == 0 or (step + 1 == steps):
        # Report the average loss per instance.
        t.set_postfix(
            loss=loss.item() / cfg.n_instances,
            lr=step_lr,
        )

In [6]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Generate data

We will optimize the ReLU encoder/decoder model, then pull the weights of
the encoder to get the projection of N features into M (M < N) dimensions
at different sparsities.

In [35]:
# Number of original (orthogonal) input features.
n_features = 60

# Number of sparsity levels trained in parallel. The feature probabilities
# are spaced geometrically between 1 and 0.05 when the models are built.
n_instances = 6

# Hidden sizes (numbers of compressed dimensions) to sweep over.
n_hiddens = [3, 5, 10, 20, 30, 40]

# True: weight features with a decaying importance curve; False: weight all equally.
importance_sweep = True

In [36]:
# Exponentially decaying importance: feature k gets weight 0.9**k.
# Indexing with None adds a leading axis of size 1, giving shape [1, n_features].
importance = (0.9 ** torch.arange(n_features))[None]
print(importance)


tensor([[1.0000, 0.9000, 0.8100, 0.7290, 0.6561, 0.5905, 0.5314, 0.4783, 0.4305,
         0.3874, 0.3487, 0.3138, 0.2824, 0.2542, 0.2288, 0.2059, 0.1853, 0.1668,
         0.1501, 0.1351, 0.1216, 0.1094, 0.0985, 0.0886, 0.0798, 0.0718, 0.0646,
         0.0581, 0.0523, 0.0471, 0.0424, 0.0382, 0.0343, 0.0309, 0.0278, 0.0250,
         0.0225, 0.0203, 0.0182, 0.0164, 0.0148, 0.0133, 0.0120, 0.0108, 0.0097,
         0.0087, 0.0079, 0.0071, 0.0064, 0.0057, 0.0052, 0.0046, 0.0042, 0.0038,
         0.0034, 0.0030, 0.0027, 0.0025, 0.0022, 0.0020]])


In [37]:
# Seed the global RNG so repeated runs start from the same random state
# (weight initialisation and generated batches both draw from it).
torch.manual_seed(0)

if importance_sweep:
  # Feature k is weighted 0.9**k in the loss.
  importance = (0.9**torch.arange(n_features))[None, :]  
  imp_name = 'withimportance'
else:
  # Every feature is weighted equally.
  importance = torch.ones(n_features)[None,:]
  imp_name = ''

for n_hidden in n_hiddens:
  print(n_hidden)

  config = Config(
      n_features = n_features,
      n_hidden = n_hidden,
      n_instances = n_instances,
  )

  model = Model(
      config=config,
      device=DEVICE,
      # Per-feature loss weights selected above, shape [1, n_features]
      importance = importance,
      # Sweep feature probability across the instances, geometrically
      # spaced from 1 (fully dense) down to 1/20
      feature_probability = (20 ** -torch.linspace(0, 1, n_instances))[:, None]
  )

  optimize(model)

  # Trained weights, shape [n_instances, n_features, n_hidden]; kept in a
  # global so later cells can inspect the most recently trained model.
  WA = model.W.detach()

  outname = f'{n_features}to{n_hidden}_{imp_name}.npy'

  np.save(outname, WA.cpu().numpy())

  files.download(outname) # triggers a browser download of the saved file


3


  0%|          | 0/10000 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [33]:
# For every instance s, entry [s, i, o] is the dot product between the
# hidden-space weight vectors of features i and o.
WtW = torch.einsum(
    'sih,soh->sio',
    WA,
    WA,
).cpu()

In [34]:
# Heatmap of the feature-by-feature overlaps for the last instance.
# Indexing with -1 (instead of a fixed 5) keeps the cell valid when
# n_instances changes. `px` is imported in the imports cell at the top.
fig = px.imshow(
    WtW[-1],
    color_continuous_scale='RdBu_r',
    title='W W^T for the last instance',
    labels=dict(x='feature', y='feature', color='dot product'),
)
fig.show()