In [3]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch as t
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from dataclasses import dataclass
import numpy as np
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from functools import partial
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
from IPython.display import display, HTML


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")



def linear_lr(step, steps):
    """Return a learning-rate multiplier that falls linearly from 1.0 (step 0) to 0.0 (step == steps)."""
    fraction_done = step / steps
    return 1 - fraction_done

def constant_lr(*_):
    """Return a learning-rate multiplier of 1.0 regardless of arguments (i.e. no schedule)."""
    multiplier = 1.0
    return multiplier

def cosine_decay_lr(step, steps):
    """
    Return a quarter-cosine learning-rate multiplier: 1.0 at step 0, 0.0 at step == steps - 1.

    With a single-step (or empty) schedule there is nothing to decay over, so the
    multiplier is 1.0 instead of dividing by zero.
    """
    if steps <= 1:
        return 1.0
    return np.cos(0.5 * np.pi * step / (steps - 1))


@dataclass
class AutoEncoderConfig:
    '''
    Hyperparameters for `AutoEncoder`.

    The fields after `tied_weights` default to None and are filled in by `__post_init__`
    from the original five fields, so existing positional/keyword constructions keep working.
    '''
    n_instances: int
    n_input_ae: int
    n_hidden_ae: int
    l1_coeff: float = 0.5
    tied_weights: bool = False  # NOTE(review): not read by AutoEncoder
    # Number of heads in the input h; defaults to n_instances (optimize() feeds
    # "... instances hidden" activations into forward() as "batch_size n_head d_head").
    n_head: Optional[int] = None
    # Per-head input width; defaults to n_input_ae.
    d_head: Optional[int] = None
    # Number of per-head token features produced by W_conv; defaults to d_head.
    n_conv: Optional[int] = None
    # Number of sparse attention features in the bottleneck; defaults to n_hidden_ae.
    n_attn_f: Optional[int] = None
    # L1 coefficient on the token features; defaults to l1_coeff.
    l1_features_coeff: Optional[float] = None
    # L1 coefficient on the attention features; defaults to l1_coeff.
    l1_attn_coeff: Optional[float] = None

    def __post_init__(self) -> None:
        # Fill any unset field from the original config fields (see comments above).
        if self.n_head is None:
            self.n_head = self.n_instances
        if self.d_head is None:
            self.d_head = self.n_input_ae
        if self.n_conv is None:
            self.n_conv = self.d_head
        if self.n_attn_f is None:
            self.n_attn_f = self.n_hidden_ae
        if self.l1_features_coeff is None:
            self.l1_features_coeff = self.l1_coeff
        if self.l1_attn_coeff is None:
            self.l1_attn_coeff = self.l1_coeff


class AutoEncoder(nn.Module):
    '''
    Two-stage sparse autoencoder over per-head activations h of shape (batch_size, n_head, d_head).

    Stage 1: each head's d_head vector is mapped by the shared W_conv to n_conv token
    features (ReLU). Stage 2: the n_head * n_conv token features are flattened and encoded
    into n_attn_f attention features (ReLU). Decoding mirrors this: W_dec reconstructs the
    flattened token features, then W_deconv maps each head's token features back to d_head.
    '''
    W_conv: Float[Tensor, "d_head n_conv"]
    b_conv: Float[Tensor, "n_conv"]

    W_enc: Float[Tensor, "n_head*n_conv n_attn_f"]
    b_enc: Float[Tensor, "n_attn_f"]

    W_dec: Float[Tensor, "n_attn_f n_head*n_conv"]
    b_dec: Float[Tensor, "n_head*n_conv"]

    W_deconv: Float[Tensor, "n_conv d_head"]
    b_deconv: Float[Tensor, "d_head"]

    def __init__(self, cfg: AutoEncoderConfig):
        super().__init__()
        self.cfg = cfg
        # Width of the flattened token-feature vector that stage 2 encodes and decodes.
        n_flat = cfg.n_head * cfg.n_conv

        # Stage 1 (shared across heads): d_head <-> n_conv.
        self.W_conv = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.d_head, cfg.n_conv))))
        self.W_deconv = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_conv, cfg.d_head))))
        self.b_deconv = nn.Parameter(t.zeros(cfg.d_head))
        self.b_conv = nn.Parameter(t.zeros(cfg.n_conv))

        # Stage 2 (across heads): n_head * n_conv <-> n_attn_f.
        self.W_enc = nn.Parameter(nn.init.xavier_normal_(t.empty((n_flat, cfg.n_attn_f))))
        self.W_dec = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_attn_f, n_flat))))
        self.b_enc = nn.Parameter(t.zeros(cfg.n_attn_f))
        self.b_dec = nn.Parameter(t.zeros(n_flat))

        self.to(device)

    def forward(self, h: Float[Tensor, "batch_size n_head d_head"]):
        '''
        Encode and reconstruct `h`.

        Returns a 6-tuple:
            l1_loss_features: [batch_size, n_head] L1 norm of the token features.
            l1_loss_attn: [batch_size] L1 norm of the attention features.
            l2_loss: [batch_size, n_head] mean squared reconstruction error over d_head.
            loss: scalar; head-level terms summed over heads, then averaged over the batch.
            attn_feature_acts: [batch_size, n_attn_f] post-ReLU attention features.
            input_reconstructed: [batch_size, n_head, d_head].
        '''
        # Stage 1 encode: per-head token features.
        h_cent = h - self.b_deconv
        token_feature_acts = einops.einsum(
            h_cent, self.W_conv,
            "batch_size n_head d_head, d_head n_conv -> batch_size n_head n_conv"
        )
        token_feature_acts = F.relu(token_feature_acts + self.b_conv)

        # Stage 2 encode: flatten heads first so b_dec (length n_head * n_conv) broadcasts.
        flat_cent = token_feature_acts.flatten(start_dim=1) - self.b_dec
        attn_pre = einops.einsum(
            flat_cent, self.W_enc,
            "batch_size n_head_n_conv, n_head_n_conv n_attn_f -> batch_size n_attn_f"
        )
        attn_feature_acts = F.relu(attn_pre + self.b_enc)

        # Stage 2 decode, then restore the (n_head, n_conv) layout.
        token_feature_acts_rec = einops.einsum(
            attn_feature_acts, self.W_dec,
            "batch_size n_attn_f, n_attn_f n_head_n_conv -> batch_size n_head_n_conv"
        ) + self.b_dec
        token_feature_acts_rec = token_feature_acts_rec.reshape(token_feature_acts.shape)

        # Stage 1 decode: back to d_head per head.
        input_reconstructed = einops.einsum(
            token_feature_acts_rec, self.W_deconv,
            "batch_size n_head n_conv, n_conv d_head -> batch_size n_head d_head"
        ) + self.b_deconv

        l2_loss = (input_reconstructed - h).pow(2).mean(-1)  # [batch_size, n_head]
        l1_loss_features = token_feature_acts.abs().sum(-1)  # [batch_size, n_head]
        l1_loss_attn = attn_feature_acts.abs().sum(-1)  # [batch_size]

        # Sum the head-level terms over heads before adding the per-sample attention term,
        # so shapes agree and the attention L1 is counted once per sample.
        loss = (
            self.cfg.l1_features_coeff * l1_loss_features.sum(-1)
            + self.cfg.l1_attn_coeff * l1_loss_attn
            + l2_loss.sum(-1)
        ).mean(0)  # scalar

        return l1_loss_features, l1_loss_attn, l2_loss, loss, attn_feature_acts, input_reconstructed

    @t.no_grad()
    def normalize_decoder(self) -> None:
        '''
        Rescale every row of W_dec and of W_deconv to unit L2 norm.
        '''
        self.W_dec.data = self.W_dec.data / self.W_dec.data.norm(dim=1, keepdim=True)
        self.W_deconv.data = self.W_deconv.data / self.W_deconv.data.norm(dim=1, keepdim=True)

    @t.no_grad()
    def resample_neurons(
        self,
        h: Float[Tensor, "batch_size n_head d_head"],
        frac_active_in_window: Float[Tensor, "window n_attn_f"],
        neuron_resample_scale: float,
    ) -> Tuple[List[str], str]:
        '''
        Resample attention features that never fired anywhere in `frac_active_in_window`.

        Each dead feature gets a fresh random unit-norm decoder row in W_dec, its W_enc column
        is set to that direction scaled by `neuron_resample_scale`, and its encoder bias is
        zeroed. `h` is accepted for interface compatibility but unused by this simple scheme.

        Returns (colors, title) for logging: colors[i] is "red" if feature i was resampled,
        otherwise "black".
        '''
        dead = frac_active_in_window.sum(dim=0) < 1e-8  # [n_attn_f] boolean mask
        n_dead = int(dead.sum().item())
        colors = ["red" if is_dead else "black" for is_dead in dead.tolist()]
        if n_dead == 0:
            return colors, "no dead neurons to resample"

        directions = t.randn(
            (n_dead, self.W_dec.shape[1]), device=self.W_dec.device, dtype=self.W_dec.dtype
        )
        directions = directions / directions.norm(dim=1, keepdim=True)
        self.W_dec.data[dead] = directions
        self.W_enc.data[:, dead] = neuron_resample_scale * directions.T
        self.b_enc.data[dead] = 0.0
        return colors, f"resampled {n_dead}/{dead.numel()} neurons (shown in red)"

    def optimize(
        self,
        model,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 100,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
        neuron_resample_window: Optional[int] = None,
        dead_neuron_window: Optional[int] = None,
        neuron_resample_scale: float = 0.2,
    ):
        '''
        Train the autoencoder with Adam on hidden activations drawn from a trained `model`.

        `model` must provide `generate_batch(batch_size)` and a weight `W` such that
        einsum("... instances features, instances hidden features -> ... instances hidden")
        yields the (batch_size, n_head, d_head) input expected by `forward`.

        If `neuron_resample_window` is set, then every `neuron_resample_window` steps,
        features inactive over the last `dead_neuron_window` steps are resampled.

        Returns a dict of snapshots (W_enc, W_dec, colors, titles, frac_active) recorded
        every `log_freq` steps and on the final step.
        '''
        if neuron_resample_window is not None:
            assert (dead_neuron_window is not None) and (dead_neuron_window < neuron_resample_window)

        optimizer = t.optim.Adam(list(self.parameters()), lr=lr)
        frac_active_list = []
        progress_bar = tqdm(range(steps))

        # Lists of data we'll eventually be plotting.
        data_log = {"W_enc": [], "W_dec": [], "colors": [], "titles": [], "frac_active": []}
        colors = None
        title = "no resampling yet"

        for step in progress_bar:

            # Normalize the decoder weights before each optimization step.
            self.normalize_decoder()

            # Resample dead neurons.
            if (neuron_resample_window is not None) and ((step + 1) % neuron_resample_window == 0):
                # Activity over the last `dead_neuron_window` steps decides which features are dead.
                frac_active_in_window = t.stack(frac_active_list[-dead_neuron_window:], dim=0)
                batch = model.generate_batch(batch_size)
                h = einops.einsum(batch, model.W, "batch_size instances features, instances hidden features -> batch_size instances hidden")
                colors, title = self.resample_neurons(h, frac_active_in_window, neuron_resample_scale)

            # Update learning rate.
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group['lr'] = step_lr

            # Get a batch of hidden activations from the model.
            with t.inference_mode():
                features = model.generate_batch(batch_size)
                h = einops.einsum(features, model.W, "... instances features, instances hidden features -> ... instances hidden")

            # Optimize.
            optimizer.zero_grad()
            l1_feature_loss, l1_attn_loss, l2_loss, loss, acts, _ = self.forward(h)
            loss.backward()
            optimizer.step()

            # Fraction of the batch on which each attention feature fired: [n_attn_f].
            frac_active = einops.reduce((acts.abs() > 1e-8).float(), "batch_size n_attn_f -> n_attn_f", "mean")
            frac_active_list.append(frac_active)

            # Display progress bar, and append new values for plotting.
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(
                    l1_features_loss=self.cfg.l1_features_coeff * l1_feature_loss.sum(-1).mean().item(),
                    l1_attn_loss=self.cfg.l1_attn_coeff * l1_attn_loss.mean().item(),
                    l2_loss=l2_loss.sum(-1).mean().item(),
                    lr=step_lr,
                )
                data_log["W_enc"].append(self.W_enc.detach().cpu())
                data_log["W_dec"].append(self.W_dec.detach().cpu())
                data_log["colors"].append(colors)
                data_log["titles"].append(f"Step {step}/{steps}: {title}")
                data_log["frac_active"].append(frac_active.detach().cpu())

        return data_log