## Setup

In [1]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    # Install packages
    %pip install transformer_lens
    %pip install pytorch_lightning
    %pip install git+https://github.com/neelnanda-io/neel-plotly

    # Code to make sure output widgets display
    from google.colab import output
    output.enable_custom_widget_manager()

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [2]:
import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
from torch import nn, Tensor
from torch.nn import functional as F
from dataclasses import dataclass
import time
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import einops
import plotly.express as px
from pathlib import Path
from jaxtyping import Float
from typing import Optional, Callable, Union, List, Tuple, Dict, Any, Literal
from tqdm.notebook import tqdm
from dataclasses import dataclass
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import json
import math
import torch
import torch.nn as nn
import time
import gc


# Put the folder holding the helper modules on the import path (once only).
# NOTE(review): this is an absolute path on one specific machine — it has to be
# edited before the notebook can run anywhere else.
exercises_dir = Path(r"C:\Users\calsm\Documents\AI Alignment\anthropic_ddd\tms_and_ddd").resolve()
if sys.path.count(str(exercises_dir)) == 0:
    sys.path.append(str(exercises_dir))

from plotly_utils import (
    imshow,
    line,
    visualise_2d_superposition,
    visualise_Nd_superposition,
)
import tests
import solutions

# Use the GPU when one is visible to torch, otherwise fall back to the CPU
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

MAIN = (__name__ == "__main__")



In [21]:
# Schedule lengths: warmup steps used by `anthropic_lr`, and the default number of updates
NUM_WARMUP_STEPS = 2_500
NUM_BATCH_UPDATES = 50_000

# Model / data shape defaults
N_FEATURES = 10_000
N_HIDDEN = 2
SPARSITY = 0.999

# Optimizer defaults
WEIGHT_DECAY = 0.01
LEARNING_RATE = 0.001

# Candidate batch (dataset) sizes
BATCH_SIZES = [3, 5, 6, 8, 10, 15, 30, 50, 100, 200, 500, 1000, 2000, 5000, 10000]
# EVAL_N_DATAPOINTS = 1_000


def linear_lr(step, steps):
    '''
    Linear decay schedule: the multiplier is 1 at step 0 and 0 at step == steps.
    '''
    fraction_done = step / steps
    return 1 - fraction_done

def linear_warmup_lr(step, steps):
    '''
    Linear warmup schedule: the multiplier is 0 at step 0 and 1 at step == steps.
    '''
    fraction_done = step / steps
    return fraction_done

def constant_lr(*_unused):
    '''
    Flat schedule: ignores its positional arguments and always yields a multiplier of 1.0.
    '''
    multiplier = 1.0
    return multiplier

def cosine_decay_lr(step, steps):
    '''
    Quarter-cosine decay: the multiplier is 1 at step 0 and (numerically) 0 at
    step == steps - 1.

    A schedule of length 1 has nothing to decay over, so its only step gets a
    multiplier of 1.0; this also avoids dividing by steps - 1 == 0.
    '''
    if steps == 1:
        return 1.0
    return np.cos(0.5 * np.pi * step / (steps - 1))

def anthropic_lr(step, steps):
    '''
    Learning-rate multiplier following the schedule described in the paper:
    linear warmup over the first NUM_WARMUP_STEPS steps, then cosine decay to
    zero over the remaining steps.
    '''
    in_warmup = step < NUM_WARMUP_STEPS
    if not in_warmup:
        steps_after_warmup = steps - NUM_WARMUP_STEPS
        return cosine_decay_lr(step - NUM_WARMUP_STEPS, steps_after_warmup)
    return linear_warmup_lr(step, NUM_WARMUP_STEPS)



@dataclass
class Config:
    """
    Shape hyperparameters for `Model`: these three values fix the shape of its
    weight tensor, (n_instances, n_hidden, n_features).

    Same as TMS, we're leaving in the "n_instances" argument for more possible
    flexibility later (even though I don't think I'll use it).
    """
    # Number of independent model copies held in one weight tensor
    n_instances: int = 1
    # Number of input/output features; defaults to the module-level N_FEATURES
    n_features: int = N_FEATURES
    # Width of the hidden bottleneck; defaults to the module-level N_HIDDEN
    n_hidden: int = N_HIDDEN



class Model(nn.Module):
    '''
    Toy autoencoder with tied weights: each instance maps a feature vector x to
    ReLU(W.T @ W @ x + b_final), squeezing it through an n_hidden-dimensional bottleneck.

    Also knows how to generate its own sparse training data, compute its loss, train
    itself on a single fixed dataset, and measure feature / datapoint dimensionality.
    '''

    W: Float[Tensor, "n_instances n_hidden n_features"]
    b_final: Float[Tensor, "n_instances n_features"]
    # Our linear map (ignoring n_instances) is x -> ReLU(W.T @ W @ x + b_final)

    def __init__(
        self,
        cfg: Config,
        feature_probability: Optional[Union[Tensor, float]] = 1 - SPARSITY,
        importance: Optional[Union[Tensor, float]] = None,
        device = device,
    ):
        '''
        Args:
            cfg: shape hyperparameters (n_instances, n_hidden, n_features).
            feature_probability: probability that each feature is present (non-zero) in a
                generated datapoint. A tensor is used as given; a plain number becomes a
                0-d tensor; None means 1 (features always present).
            importance: weight applied to the squared error in `calculate_loss`, handled the
                same way as `feature_probability`; None means 1.
            device: device on which the parameters and the two tensors above are placed.
        '''
        super().__init__()
        self.cfg = cfg

        # Accept ints as well as floats: a bare `1` would otherwise reach `.to(device)` as a Python int
        if feature_probability is None: feature_probability = t.ones(())
        elif isinstance(feature_probability, (int, float)): feature_probability = t.ones(()) * feature_probability
        self.feature_probability = feature_probability.to(device)
        self.sparsity = 1 - self.feature_probability

        if importance is None: importance = t.ones(())
        elif isinstance(importance, (int, float)): importance = t.ones(()) * importance
        self.importance = importance.to(device)

        self.W = nn.Parameter(t.empty((cfg.n_instances, cfg.n_hidden, cfg.n_features), device=device))
        nn.init.xavier_normal_(self.W)
        self.b_final = nn.Parameter(t.zeros((cfg.n_instances, cfg.n_features), device=device))


    @t.inference_mode()
    def dimensionality(self, batch: Optional[Tensor] = None) -> Float[Tensor, "instances n_features"]:
        '''
        Calculates the dimensionality of each feature, or of each datapoint if `batch` is supplied.

        For a set of vectors v_1, ..., v_n in hidden space, the dimensionality of v_i is

            ||v_i||^2 / sum_j (v_i_hat . v_j)^2

        where v_i_hat is v_i scaled to unit norm and the sum runs over all j (including j = i).

        If `batch` is None, the vectors are the columns of W (one per feature) and the result has
        shape [instances n_features]. Otherwise the vectors are the hidden-space representations of
        the datapoints in `batch`, and the result has shape [instances batch_size].
        '''
        if batch is None:
            data = self.W # [instances d_hidden features]
        else:
            hidden = self.forward(batch, return_hidden_states=True) # [batch_size instances d_hidden]
            data = einops.rearrange(hidden, "batch instances d_hidden -> instances d_hidden batch")

        # Compute the squared norms of each feature / datapoint (this will be the numerator)
        squared_norms = einops.reduce(
            data.pow(2),
            "inst d_hidden x -> inst x",
            "sum",
        )
        # Compute the denominator (i.e. the dot product of each unit vector with every vector,
        # squared, then summed over j)
        data_normed = data / data.norm(dim=1, keepdim=True)
        interference = einops.einsum(
            data_normed, data,
            "inst d_hidden x_i, inst d_hidden x_j -> inst x_i x_j",
        )
        # The diagonal (j = i) is deliberately left in; the lines below would remove it.
        # last_dim_size = data.shape[-1]
        # interference[:, range(last_dim_size), range(last_dim_size)] = 0
        polysemanticities = einops.reduce(
            interference.pow(2),
            "inst x_i x_j -> inst x_i",
            "sum",
        )
        dimensionality = squared_norms / polysemanticities

        return dimensionality



    def forward(
        self,
        features: Float[Tensor, "... instances features"],
        return_hidden_states: bool = False,
    ) -> Float[Tensor, "... instances features"]:
        '''
        Runs the model on `features`.

        Returns ReLU(W.T @ W @ x + b_final) with shape [... instances features], unless
        `return_hidden_states` is True, in which case the intermediate hidden activations
        W @ x (shape [... instances hidden]) are returned instead.
        '''
        hidden = einops.einsum(
           features, self.W,
           "... instances features, instances hidden features -> ... instances hidden"
        )
        if return_hidden_states: return hidden
        out = einops.einsum(
            hidden, self.W,
            "... instances hidden, instances hidden features -> ... instances features"
        )
        return F.relu(out + self.b_final)


    def generate_batch(self, batch_size: int) -> Float[Tensor, "batch_size instances features"]:
        '''
        Generates a batch of data. We'll return to this function later when we apply correlations.

        Each feature is present with probability `self.feature_probability`; present features are
        drawn uniformly from [0, 1) and absent ones are zero. Every datapoint is then scaled to unit
        norm (all-zero datapoints are left as zeros).
        '''
        # Get values of features pre-choosing some of them to be zero
        feat = t.rand((batch_size, self.cfg.n_instances, self.cfg.n_features), device=self.W.device) # [batch instances features]

        # Choose which features to be zero
        feat_seeds = t.rand((batch_size, self.cfg.n_instances, self.cfg.n_features), device=self.W.device) # [batch instances features]
        feat_is_present = feat_seeds <= self.feature_probability

        # Zero out the features
        batch = t.where(feat_is_present, feat, t.zeros((), device=self.W.device))

        # Normalize the batch (i.e. so each vector for a particular batch & instance has norm 1)
        # (need to be careful about vectors with norm zero: divide those by 1 instead)
        norms = batch.norm(dim=-1, keepdim=True)
        norms = t.where(norms.abs() < 1e-6, t.ones_like(norms), norms)
        batch_normed = batch / norms

        return batch_normed


    def calculate_loss(
        self,
        out: Float[Tensor, "batch instances features"],
        batch: Float[Tensor, "batch instances features"],
    ) -> Float[Tensor, ""]:
        '''
        Calculates the loss for a given batch, using this loss described in the Toy Models paper:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        The importance-weighted squared error is averaged over batch and features for each
        instance, and the per-instance values are then summed into a scalar.

        Note, `model.importance` is guaranteed to broadcast with the shape of `out` and `batch`.

        Also note, for this experiment we'll only ever be using importance = 1 (in accordance with the paper),
        but we'll keep this function here for possible extensions.
        '''
        error = self.importance * ((batch - out) ** 2)
        loss = einops.reduce(error, 'batch instances features -> instances', 'mean').sum()
        return loss


    # @t.inference_mode()
    # def evaluate(self, batch_size: int, n_evals: int = 100):
    #     '''
    #     Evaluates the model on a batch of data.
    #     '''
    #     loss_list = []
    #     for n in range(n_evals):
    #         batch = self.generate_batch(batch_size)
    #         out = self(batch)
    #         loss = self.calculate_loss(out, batch)
    #     return loss.item() / self.cfg.n_instances


    def optimize(
        self,
        batch_size: int,
        num_batch_updates: int = NUM_BATCH_UPDATES,
        log_freq: int = 100,
        lr: float = LEARNING_RATE,
        lr_scale: Callable[[int, int], float] = anthropic_lr,
        weight_decay: float = WEIGHT_DECAY
    ) -> Float[Tensor, "batch_size instances features"]:
        '''
        Trains the model with AdamW on one fixed dataset, which is generated once up front and
        reused for every update.

        Args:
            batch_size: number of datapoints in the fixed dataset.
            num_batch_updates: number of optimizer steps to take.
            log_freq: the progress bar's loss / lr readout is refreshed every `log_freq` steps
                (and on the final step).
            lr: base learning rate; the rate used at each step is `lr * lr_scale(step, num_batch_updates)`.
            lr_scale: schedule function mapping (step, total_steps) to a multiplier.
            weight_decay: weight decay passed to AdamW.

        Returns:
            The dataset the model was trained on.
        '''
        optimizer = t.optim.AdamW(list(self.parameters()), lr=lr, weight_decay=weight_decay)

        progress_bar = tqdm(range(num_batch_updates))

        # Same batch for each step
        batch = self.generate_batch(batch_size) # [batch_size instances n_features]

        for step in progress_bar:

            # Update learning rate
            step_lr = lr * lr_scale(step, num_batch_updates)
            for group in optimizer.param_groups:
                group['lr'] = step_lr

            # Optimize
            optimizer.zero_grad()
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar
            if step % log_freq == 0 or (step + 1 == num_batch_updates):
                progress_bar.set_postfix(loss=loss.item()/self.cfg.n_instances, lr=step_lr)

        return batch

Test by using batch size = 3 (a case where we definitely expect the model to be in the "datapoint memorization" regime, and to learn fast).

Will also use Marius' technique of 1000 features & 0.99 sparsity, to make things a bit faster.

In [4]:
# Smaller, faster setup: 1,000 features, each present with probability 1 - 0.99
cfg = Config(n_features=1_000)
model = Model(cfg, feature_probability=1 - 0.99).to(device)

In [5]:
# Release cached GPU memory and run the garbage collector before training
t.cuda.empty_cache()
gc.collect()

# Train on a fixed dataset of 3 points; `optimize` returns that dataset
batch = model.optimize(batch_size=3, num_batch_updates=10_000)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [6]:
# Plot a subset of the features learned by our model
# (the first `n_feats_to_show` features are taken, not a random sample)

n_feats_to_show = 50

# Move features ahead of d_hidden, then drop the size-1 instance axis: [features=1000 d_hidden=2]
W = einops.rearrange(model.W.detach(), "instance d_hidden feats -> instance feats d_hidden").squeeze()
# Put a leading instance axis back and keep the first features: [instances=1 feats=50 d_hidden=2]
W = W[None, :n_feats_to_show]

visualise_2d_superposition(values=W, colors="blue", width=300, height=300)

In [7]:
# Next, plot where the training datapoints land in the model's hidden space

# [batch_size=3 instances=1 d_hidden=2]
hidden_states = model(batch, return_hidden_states=True)

# The plotting helper is given the instance axis first, so swap batch and instance
visualise_2d_superposition(
    values=einops.rearrange(hidden_states, "b i h -> i b h"),
    colors="red",
    width=300,
    height=300,
)

Let's wrap this up into a helpful function:

In [8]:
def main_test(
    dataset_size: int,
    max_points_to_plot: int = 1000,
    n_features: int = 1_000,
    feature_probability: float = 1 - 0.99,
    num_batch_updates: int = 10_000,
) -> Tuple[Tensor, "Model"]:
    '''
    Trains a fresh model on a fixed dataset of `dataset_size` points, then plots the learned
    feature vectors (blue) and the hidden-state representations of the datapoints (red).

    Args:
        dataset_size: number of datapoints in the fixed training set.
        max_points_to_plot: cap on the number of points drawn in each plot. For features the
            largest-norm vectors are kept; for datapoints a random sample is drawn.
        n_features: number of features in the model.
        feature_probability: probability that each feature is present in a datapoint.
        num_batch_updates: number of optimizer steps.

    Returns:
        (batch, model): the training dataset and the trained model.
    '''
    # Create model
    cfg = Config(n_features = n_features)
    model = Model(cfg, feature_probability = feature_probability).to(device)

    # Train model
    t.cuda.empty_cache()
    gc.collect()
    batch = model.optimize(batch_size = dataset_size, num_batch_updates = num_batch_updates)

    #! Visualise features

    # Drop only the instance axis (so a size-1 feature axis is never squeezed away): [features d_hidden]
    W = einops.rearrange(model.W.detach(), "instance d_hidden feats -> instance feats d_hidden").squeeze(0)

    # Divide by the largest point's norm (so they fit), and sort in descending order of norm
    W_normalized = W / W.norm(dim=-1).max()
    W_normalized = W_normalized[W_normalized.norm(dim=-1).argsort(descending=True)]

    # Keep only the largest points, if there are too many
    W_normalized = W_normalized[:max_points_to_plot]

    visualise_2d_superposition(
        values = W_normalized.unsqueeze(0),
        colors = "blue",
        width = 300,
        height = 300,
    )

    #! Visualise hidden state representation of dataset

    hidden_states = model(batch, return_hidden_states=True) # [batch_size=dataset_size instances=1 d_hidden=2]

    # Divide by the largest point's norm (so they fit)
    hidden_states_normalized = hidden_states / hidden_states.norm(dim=-1).max()

    # Get random sample of points, if there are too many
    if hidden_states_normalized.shape[0] > max_points_to_plot:
        idxs = t.randperm(hidden_states_normalized.shape[0])[:max_points_to_plot]
        hidden_states_normalized = hidden_states_normalized[idxs]

    visualise_2d_superposition(
        values = einops.rearrange(hidden_states_normalized, "b i h -> i b h"),
        colors = "red",
        width = 300,
        height = 300,
    )
    return batch, model

In [23]:
batch, model = main_test(dataset_size=5)

# Dimensionality of each training datapoint's hidden representation
batch_dimensionalities = model.dimensionality(batch).squeeze()
batch_dimensionalities_str = ", ".join(format(x, ".2f") for x in batch_dimensionalities.tolist())
print(f"Dimensionalities of datapoints = [{batch_dimensionalities_str}]")

# Dimensionality of each feature vector; only the five largest are reported
feat_dimensionalities = model.dimensionality().squeeze()
feat_dimensionalities_top5 = feat_dimensionalities.topk(5).values
feat_dimensionalities_top5_str = ", ".join(format(x, ".2f") for x in feat_dimensionalities_top5.tolist())
print(f"Dimensionalities of features (top 5) = [{feat_dimensionalities_top5_str}]")

  0%|          | 0/10000 [00:00<?, ?it/s]

Dimensionalities of datapoints = [0.43, 0.40, 0.41, 0.38, 0.37]
Dimensionalities of features (top 5) = [0.10, 0.10, 0.10, 0.09, 0.09]


In [None]:
# Same experiment with a 30-point dataset; the returned (batch, model) pair is not stored
main_test(dataset_size=30)

In [None]:
# Same experiment with a 200-point dataset; the returned (batch, model) pair is not stored
main_test(dataset_size=200)

In [25]:
batch, model = main_test(dataset_size=2000)

# Dimensionality of each training datapoint's hidden representation (one value per datapoint)
batch_dimensionalities = model.dimensionality(batch).squeeze()
batch_dimensionalities_str = ", ".join(format(x, ".4f") for x in batch_dimensionalities.tolist())
print(f"Dimensionalities of datapoints = [{batch_dimensionalities_str}]")

# Dimensionality of each feature vector; only the five largest are reported
feat_dimensionalities = model.dimensionality().squeeze()
feat_dimensionalities_top5 = feat_dimensionalities.topk(5).values
feat_dimensionalities_top5_str = ", ".join(format(x, ".4f") for x in feat_dimensionalities_top5.tolist())
print(f"Dimensionalities of features (top 5) = [{feat_dimensionalities_top5_str}]")

  0%|          | 0/10000 [00:00<?, ?it/s]

Dimensionalities of datapoints = [0.0000, 0.0001, 0.0000, 0.0004, 0.0006, 0.0015, 0.0000, 0.0015, 0.0000, 0.0002, 0.0003, 0.0001, 0.0000, 0.0000, 0.0007, 0.0000, 0.0002, 0.0000, 0.0000, 0.0002, 0.0002, 0.0001, 0.0001, 0.0004, 0.0000, 0.0001, 0.0000, 0.0001, 0.0001, 0.0003, 0.0000, 0.0002, 0.0002, 0.0003, 0.0002, 0.0000, 0.0002, 0.0000, 0.0000, 0.0002, 0.0000, 0.0004, 0.0000, 0.0117, 0.0008, 0.0004, 0.0001, 0.0002, 0.0001, 0.0002, 0.0001, 0.0027, 0.0000, 0.0161, 0.0000, 0.0002, 0.0000, 0.0001, 0.0000, 0.0002, 0.0001, 0.0000, 0.0001, 0.0003, 0.0323, 0.0000, 0.0009, 0.0002, 0.0001, 0.0001, 0.0002, 0.0001, 0.0102, 0.0002, 0.0002, 0.0001, 0.0023, 0.0003, 0.0004, 0.0003, 0.0015, 0.0234, 0.0003, 0.0000, 0.0001, 0.0001, 0.0004, 0.0000, 0.0000, 0.0003, 0.0020, 0.0001, 0.0000, 0.0002, 0.0000, 0.0100, 0.0001, 0.0004, 0.0001, 0.0007, 0.0000, 0.0002, 0.0001, 0.0004, 0.0000, 0.0000, 0.0000, 0.0001, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0001, 0.0184, 0.0003, 0.0002, 0.0001, 0.0003, 0.0000

In [None]:
# Same experiment with a 5000-point dataset; the returned (batch, model) pair is not stored,
# so the global `model` is left unchanged by this cell
main_test(dataset_size=5000)

In [None]:
# Plot the feature vectors of the current global `model`.
# The rearrange puts the axes in [instances features d_hidden] order, so the trailing [::2]
# strides over the leading *instances* axis (every other instance), not over features.
# NOTE(review): confirm that is intended — presumably with one instance it just keeps that instance.
visualise_Nd_superposition(
    values = einops.rearrange(model.W, "instances d_hidden features -> instances features d_hidden")[::2],
    height = 1600,
    width = 800,
)