In [None]:
"""
Tutorial on editing a Mistral-7B model to eliminate gender bias

Author: Divij Bajaj*

*Unless given references
"""

In [None]:
!pip install circuitsvis
!pip install wandb
!pip install safetensors
!pip install transformer_lens
!pip install nnsight

In [None]:
# Log in to the Hugging Face Hub from the shell (interactive prompt).
!huggingface-cli login

In [None]:
import json
import torch
from safetensors.torch import load_file
import torch
from nnsight import LanguageModel
import numpy as np
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookedRootModule, HookPoint
import tqdm
import random
from torch import nn
from abc import ABC
from dataclasses import dataclass
from typing import Any, Optional, cast
import wandb
import gzip
import os
import pickle
import einops
from circuitsvis.tokens import colored_tokens_multi

## Utility classes

The block below contains an implementation of a Sparse Autoencoder (SAE). You do not need to understand its details for the purpose of this tutorial, so feel free to skip reading the code. Just make sure you run this block before moving on to the next one.

Reference: https://github.com/jbloomAus/mats_sae_training/blob/main/sae_training/sparse_autoencoder.py

In [None]:
@dataclass
class RunnerConfig(ABC):
    """
    The config that's shared across all runners.
    """

    # Data Generating Function (Model + Training Distibuion)
    model_name: str = "gelu-2l"
    hook_point: str = "blocks.{layer}.hook_mlp_out"
    hook_point_layer: int = 0
    hook_point_head_index: Optional[int] = None
    dataset_path: str = "NeelNanda/c4-tokenized-2b"
    activation_path: str = "activation_cache/test/"
    is_dataset_tokenized: bool = True
    context_size: int = 128
    use_cached_activations: bool = False
    cached_activations_path: Optional[
        str
    ] = None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"

    # SAE Parameters
    d_in: int = 512

    # Activation Store Parameters
    n_batches_in_buffer: int = 20
    total_training_tokens: int = 2_000_000
    store_batch_size: int = 32

    # Misc
    device: str | torch.device = "cpu"
    seed: int = 42
    dtype: torch.dtype = torch.float32

    def __post_init__(self):
        if isinstance(self.device, str):
            self.device = torch.device(self.device)

            # Convert dtype to torch.dtype if it is a string
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)
        # Autofill cached_activations_path unless the user overrode it
        if self.cached_activations_path is None:
            self.cached_activations_path = f"activations/{self.dataset_path.replace('/', '_')}/{self.model_name.replace('/', '_')}/{self.hook_point}"
            if self.hook_point_head_index is not None:
                self.cached_activations_path += f"_{self.hook_point_head_index}"


@dataclass
class LanguageModelSAERunnerConfig(RunnerConfig):
    """
    Configuration for training a sparse autoencoder on a language model.

    On construction, `__post_init__` derives `d_sae` from `d_in * expansion_factor`,
    computes `tokens_per_buffer`, fills in default values for `run_name` and
    `lr_end`, and appends a freshly generated wandb id to `checkpoint_path`.
    """

    # SAE Parameters
    expansion_factor: int = 16
    from_pretrained_path: Optional[str] = None
    d_sae: Optional[int] = None  # overwritten in __post_init__ unless expansion_factor is a list

    # Init parameters
    b_dec_init_method: str = "mean"
    init_tied_decoder: bool = True
    init_b_enc: float = 0.03

    # Training Parameters
    l1_coefficient: float = 1e-3
    lp_norm: float = 1
    weight_decay: float = 1e-3
    lr: float = 3e-4
    lr_end: float | None = None  # only used for cosine annealing, default is lr / 10
    lr_scheduler_name: str = (
        "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
    )
    lr_warm_up_steps: int = 5000
    lr_decay_steps: int = 0
    train_batch_size: int = 4096
    n_restart_cycles: int = 0  # only used for cosineannealingwarmrestarts

    # Resampling protocol args
    resample_threshold: int = (
        1000  # number of steps without a feature firing to be considered dead
    )
    # NOTE: this has no type annotation, so it is a plain class attribute shared by
    # all instances and NOT a dataclass field (it cannot be passed to the constructor).
    steps_to_resample = {12_000, 30_000, 60_000, 90_000, 130_000}
    resampling_method: str = "residual"

    # WANDB
    log_to_wandb: bool = True
    wandb_log_dir: str = "wandb"
    wandb_project: str = "mats_sae_training_language_model"
    run_name: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_log_frequency: int = 10

    # Misc
    n_checkpoints: int = 0
    checkpoint_path: str = "checkpoints"
    prepend_bos: bool = True
    verbose: bool = True
    skip_eval_loop: bool = False

    def __post_init__(self):
        """
        Derive dependent settings and validate the configuration.

        Raises:
            ValueError: if `b_dec_init_method` is not one of
                "geometric_median", "mean" or "zeros".
        """
        # The base class already converts `device` / `dtype` strings to torch objects.
        super().__post_init__()
        if not isinstance(self.expansion_factor, list):
            self.d_sae = self.d_in * self.expansion_factor
        self.tokens_per_buffer = (
            self.train_batch_size * self.context_size * self.n_batches_in_buffer
        )

        if self.run_name is None:
            self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"

        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
            raise ValueError(
                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
            )
        if self.b_dec_init_method == "zeros":
            print(
                "Warning: We are initializing b_dec to zeros. This is probably not what you want."
            )

        if self.lr_end is None:
            self.lr_end = self.lr / 10

        # `cast` silences a type-checker complaint about `wandb.util`.
        unique_id = cast(Any, wandb).util.generate_id()
        self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

        if self.verbose:
            # Report the name actually in use, which may have been supplied by the
            # caller rather than auto-generated above.
            print(f"Run name: {self.run_name}")


@dataclass
class CacheActivationsRunnerConfig(RunnerConfig):
    """
    Configuration for a run that caches the activations of an LLM.
    """

    # Shuffling schedule for the cached activation buffers
    shuffle_every_n_buffers: int = 10
    n_shuffles_with_last_section: int = 10
    n_shuffles_in_entire_dir: int = 10
    n_shuffles_final: int = 100

    def __post_init__(self):
        """Run the shared initialisation, then reject `use_cached_activations=True`."""
        super().__post_init__()
        # `use_cached_activations` only exists here because it is inherited from the
        # shared base config; it has no meaning for this runner.
        if not self.use_cached_activations:
            return
        raise ValueError(
            "use_cached_activations should be False when running cache_activations_runner"
        )

class SparseAutoencoder(HookedRootModule):
    """
    Single-hidden-layer sparse autoencoder with a ReLU bottleneck.

    Inputs are rescaled so that every vector has L2 norm sqrt(d_in), encoded to
    `d_sae` non-negative feature activations, and decoded back to `d_in`
    dimensions. Hook points are exposed on the (normalised) input, the
    pre-/post-ReLU hidden activations and the decoder output.
    """

    def __init__(
        self,
        cfg: LanguageModelSAERunnerConfig,
    ):
        """
        Args:
            cfg: configuration providing `d_in`, `d_sae`, `l1_coefficient`,
                `lp_norm`, `dtype`, `device` and `init_tied_decoder`.

        Raises:
            ValueError: if `cfg.d_in` is not an int.
        """
        super().__init__()
        self.cfg = cfg
        self.d_in = cfg.d_in
        if not isinstance(self.d_in, int):
            raise ValueError(
                f"d_in must be an int but was {self.d_in=}; {type(self.d_in)=}"
            )
        assert cfg.d_sae is not None  # keep pyright happy
        self.d_sae = cfg.d_sae
        self.l1_coefficient = cfg.l1_coefficient
        self.lp_norm = cfg.lp_norm
        self.dtype = cfg.dtype
        self.device = cfg.device

        # NOTE: if using resampling neurons method, you must ensure that we initialise the weights in the order W_enc, b_enc, W_dec, b_dec
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty((self.d_in, self.d_sae), dtype=self.dtype, device=self.device)
            )
        )
        # NOTE(review): the bias offset is hard-coded to 0.03 here rather than read
        # from cfg.init_b_enc — confirm which is intended.
        self.b_enc = nn.Parameter(
            torch.zeros(self.d_sae, dtype=self.dtype, device=self.device) - 0.03
        )

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(self.d_sae, self.d_in, dtype=self.dtype, device=self.device)
            )
        )
        if cfg.init_tied_decoder:
            with torch.no_grad():
                self.W_dec[:] = self.W_enc.t()

        with torch.no_grad():
            # Normalise every decoder direction (row of W_dec) to unit L2 norm.
            self.set_decoder_norm_to_unit_norm()

        self.b_dec = nn.Parameter(
            torch.zeros(self.d_in, dtype=self.dtype, device=self.device)
        )

        self.hook_sae_in = HookPoint()
        self.hook_hidden_pre = HookPoint()
        self.hook_hidden_post = HookPoint()
        self.hook_sae_out = HookPoint()

        self.setup()  # Required for `HookedRootModule`s

    def forward(self, x: torch.Tensor):
        """
        Encode and reconstruct `x`.

        Args:
            x: activations whose last dimension is `d_in`; any number of leading
                dimensions is accepted.

        Returns:
            A 7-tuple `(reconstructed, feature_acts, loss, mse_loss, sparsity_loss,
            x_norm_coeff, sae_in)`:
              * reconstructed - decoder output rescaled back to the scale of `x`
              * feature_acts  - post-ReLU feature activations (last dim `d_sae`)
              * loss          - `mse_loss + sparsity_loss`
              * mse_loss      - mean squared reconstruction error in normalised space
              * sparsity_loss - `l1_coefficient` times the mean Lp norm of the features
              * x_norm_coeff  - per-vector scale applied to `x` before encoding
              * sae_in        - the normalised input that was encoded
            The last two are exactly what `decode` needs to map edited features
            back to activation space.
        """
        # move x to correct dtype
        x = x.to(self.dtype)
        # Rescale every input vector to L2 norm sqrt(d_in).
        x_norm_coeff = (x.shape[-1] ** 0.5) / x.norm(dim=-1, keepdim=True)

        sae_in = self.hook_sae_in(x * x_norm_coeff)

        hidden_pre = self.hook_hidden_pre(
            einops.einsum(
                sae_in,
                self.W_enc,
                "... d_in, d_in d_sae -> ... d_sae",
            )
            + self.b_enc
        )
        feature_acts = self.hook_hidden_post(torch.nn.functional.relu(hidden_pre))

        sae_out = self.hook_sae_out(
            einops.einsum(
                feature_acts,
                self.W_dec,
                "... d_sae, d_sae d_in -> ... d_in",
            )
            + self.b_dec
        )

        mse_loss = torch.pow((sae_out - sae_in).norm(dim=-1), 2).mean()
        # Take the norm over the feature axis (last dim) so the penalty is also
        # correct for inputs with extra leading dims such as (batch, seq, d_in);
        # for 2-D inputs this is identical to the previous dim=1 / mean(dim=0).
        sparsity = feature_acts.norm(p=self.lp_norm, dim=-1).mean()
        sparsity_loss = self.l1_coefficient * sparsity
        loss = mse_loss + sparsity_loss

        # Undo the input normalisation so the output lives at the scale of `x`.
        reconstructed = sae_out * (1 / x_norm_coeff)

        return (
            reconstructed,
            feature_acts,
            loss,
            mse_loss,
            sparsity_loss,
            x_norm_coeff,
            sae_in,
        )

    def decode(self, features, x_norm_coeff, sae_in):
        """
        Decode (possibly edited) feature activations back to activation space.

        Args:
            features: feature activations with last dimension `d_sae`.
            x_norm_coeff: the normalisation coefficient used when the features
                were encoded (see `forward`).
            sae_in: the normalised input the features were encoded from; only
                used to compute the reconstruction error.

        Returns:
            `(reconstructed, mse_loss)` where `reconstructed` is rescaled by
            `1 / x_norm_coeff` and `mse_loss` is measured against `sae_in`.
        """
        sae_out = self.hook_sae_out(
            einops.einsum(
                features,
                self.W_dec,
                "... d_sae, d_sae d_in -> ... d_in",
            )
            + self.b_dec
        )
        reconstructed = sae_out * (1 / x_norm_coeff)
        mse_loss = torch.pow((sae_out - sae_in).norm(dim=-1), 2).mean()
        return reconstructed, mse_loss

    @torch.no_grad()
    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
        """Set the decoder bias to a copy of the precomputed vector `origin`."""
        # as_tensor + clone copies the data without the "copy construct from a
        # tensor" warning that torch.tensor(tensor) emits.
        out = torch.as_tensor(origin).detach().clone()
        self.b_dec.data = out.to(dtype=self.dtype, device=self.device)

    @torch.no_grad()
    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
        """Set the decoder bias to the mean of `all_activations` over dim 0."""
        out = all_activations.mean(dim=0)
        self.b_dec.data = out.to(self.dtype).to(self.device)

    @torch.no_grad()
    def set_decoder_norm_to_unit_norm(self):
        """Rescale every row of `W_dec` (one decoder direction per feature) to unit L2 norm."""
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

    @torch.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self):
        """
        Update grads so that they remove the parallel component
            (d_sae, d_in) shape
        """
        assert self.W_dec.grad is not None  # keep pyright happy

        # Per-feature dot product between the gradient and the decoder direction.
        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_in, d_sae d_in -> d_sae",
        )
        # Subtract the projection of the gradient onto each decoder direction.
        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_in -> d_sae d_in",
        )

    def save_model(self, path: str):
        """
        Basic save function for the model. Saves the model's state_dict and the config used to train it.

        Args:
            path: destination file; the extension selects the format
                (.pt, .pkl or .pkl.gz).

        Raises:
            ValueError: if `path` has an unsupported extension.
        """
        # Create the parent directory if there is one; a bare filename has an
        # empty dirname, which os.makedirs would reject.
        folder = os.path.dirname(path)
        if folder:
            os.makedirs(folder, exist_ok=True)

        state_dict = {"cfg": self.cfg, "state_dict": self.state_dict()}

        if path.endswith(".pt"):
            torch.save(state_dict, path)
        elif path.endswith(".pkl"):
            with open(path, "wb") as f:
                pickle.dump(state_dict, f)
        elif path.endswith("pkl.gz"):
            with gzip.open(path, "wb") as f:
                pickle.dump(state_dict, f)
        else:
            raise ValueError(
                f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
            )

        print(f"Saved model to {path}")

    @classmethod
    def load_from_pretrained(cls, path: str):
        """
        Load function for the model. Loads the model's state_dict and the config used to train it.
        This method can be called directly on the class, without needing an instance.

        SECURITY: every supported format is unpickled, which can execute arbitrary
        code. Only load checkpoints from sources you trust.

        Raises:
            FileNotFoundError: if `path` does not exist.
            IOError: if the file cannot be read/unpickled.
            ValueError: if the extension is unsupported or the payload lacks the
                'cfg' / 'state_dict' keys.
        """

        # Ensure the file exists
        if not os.path.isfile(path):
            raise FileNotFoundError(f"No file found at specified path: {path}")

        # Load the state dictionary
        if path.endswith(".pt"):
            try:
                # The checkpoint stores the config object next to the weights, so
                # it cannot be read with the restricted weights-only unpickler.
                if torch.backends.mps.is_available():
                    state_dict = torch.load(
                        path, map_location="mps", weights_only=False
                    )
                    state_dict["cfg"].device = "mps"
                else:
                    state_dict = torch.load(path, weights_only=False)
            except Exception as e:
                raise IOError(
                    f"Error loading the state dictionary from .pt file: {e}"
                ) from e

        elif path.endswith(".pkl.gz"):
            try:
                with gzip.open(path, "rb") as f:
                    state_dict = pickle.load(f)
            except Exception as e:
                raise IOError(
                    f"Error loading the state dictionary from .pkl.gz file: {e}"
                ) from e
        elif path.endswith(".pkl"):
            try:
                with open(path, "rb") as f:
                    state_dict = pickle.load(f)
            except Exception as e:
                raise IOError(
                    f"Error loading the state dictionary from .pkl file: {e}"
                ) from e
        else:
            raise ValueError(
                f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
            )

        # Ensure the loaded state contains both 'cfg' and 'state_dict'
        if "cfg" not in state_dict or "state_dict" not in state_dict:
            raise ValueError(
                "The loaded state dictionary must contain 'cfg' and 'state_dict' keys"
            )

        # Create an instance of the class using the loaded configuration
        instance = cls(cfg=state_dict["cfg"])
        instance.load_state_dict(state_dict["state_dict"])

        return instance

    def get_name(self):
        """Return an identifier built from the model name, hook point and `d_sae`."""
        sae_name = f"sparse_autoencoder_{self.cfg.model_name}_{self.cfg.hook_point}_{self.cfg.d_sae}"
        return sae_name

## Extracting and caching activations
Given a set of prompts about a topic and a final test query, this block saves the tokenized prompts and the activations of the residual stream at layer 16 of Mistral-7B-v0.1.

It also saves the activation received by passing the test query.

NOTE: This block needs an A100 GPU to run.

In [None]:
def cache_activations(prompts, query):
    """
    Record the hooked layer's output for every prompt and for the test query.

    Relies on the notebook-level globals `model` and `layer`.

    Writes two files to the working directory:
      * 'cached_acts.pt' - a list with one dict per prompt, holding the layer
        output as a numpy array ("activation") and the decoded tokens ("tokens")
      * 'query_acts.pt'  - the layer output for `query` as a numpy array
    """
    tokenizer = model.tokenizer
    samples = []

    # One forward pass per prompt, saving the layer output each time.
    for text in prompts:
        ids = tokenizer(text, return_tensors="pt", padding=False).input_ids[0, :]
        pieces = [tokenizer.decode(single_id) for single_id in ids]

        with model.trace(text), torch.no_grad():
            saved = layer.output[0].save()

        samples.append(
            {"activation": saved.detach().cpu().numpy(), "tokens": pieces}
        )

    # Same again for the test query.
    with model.trace(query), torch.no_grad():
        query_saved = layer.output[0].save()

    torch.save(samples, 'cached_acts.pt')
    torch.save(query_saved.detach().cpu().numpy(), 'query_acts.pt')

## Loading the Sparse Auto Encoder

This section loads the SAE trained on the middle layer (layer 16) of Mistral-7B-v0.1.

In [1]:
# Mount Google Drive at /content/drive (Colab only); load_sae() below reads the
# SAE config and weights from a path under this mount point.
from google.colab import drive
drive.mount('/content/drive')

def load_sae(sae_path='/content/drive/MyDrive/Mistral-SAEs/mistral_7b_layer_16/'):
    """
    Build the sparse autoencoder and load its trained weights.

    Args:
        sae_path: directory containing `cfg.json` (keyword arguments for
            `LanguageModelSAERunnerConfig`) and `sae_weights.safetensors`.
            Defaults to the Google Drive location used by this tutorial; a
            trailing slash is optional.

    Returns:
        A `SparseAutoencoder` with the stored weights loaded.
    """
    # os.path.join keeps the paths valid whether or not sae_path ends with "/".
    cfg_file = os.path.join(sae_path, 'cfg.json')
    weights_file = os.path.join(sae_path, 'sae_weights.safetensors')

    with open(cfg_file, 'r') as f:
        config_dict = json.load(f)

    config = LanguageModelSAERunnerConfig(**config_dict)
    dictionary = SparseAutoencoder(config)
    dictionary.load_state_dict(load_file(weights_file))
    return dictionary

Mounted at /content/drive


## Finding the feature related to gender bias

The block below loads the cached activations and passes them through the SAE. For each prompt, a single feature vector is created by summing the absolute feature activations across the sequence dimension, and the indices of the 40 largest features are collected.

The feature vectors and the top 40 feature indices are saved/printed for analysis and visualization.

In [None]:
def find_relevant_feature(top_k=40):
    """
    Run every cached prompt through the SAE and report its strongest features.

    Reads 'cached_acts.pt' (written by `cache_activations`), prints the indices
    of the `top_k` features with the largest summed absolute activation for each
    prompt, and writes the per-token feature activations plus tokens to
    'features.pt' for visualisation.

    Args:
        top_k: number of top features to report per prompt.

    Returns:
        A list with one entry per prompt, each a list of `top_k` feature indices
        (strongest first), so they can be compared across prompts.
    """
    # The cache holds numpy arrays, which the restricted weights-only unpickler
    # (the default in recent PyTorch releases) refuses to load.
    cached_acts = torch.load('cached_acts.pt', weights_only=False)
    sae = load_sae()

    features = []
    top_features_per_prompt = []
    # Inference only: avoid building an autograd graph over the SAE.
    with torch.no_grad():
        for prompt in cached_acts:
            acts = torch.as_tensor(prompt["activation"], device=sae.device)
            # The feature activations are the second value returned by the SAE;
            # indexing keeps this independent of how many values follow.
            feats = sae(acts)[1]  # Feats shape: (1, seq_len, d_sae)

            # Aggregate over the sequence dimension, then rank the features.
            summed = feats.abs().sum(dim=1)
            feature_idx = summed.topk(top_k).indices[0].tolist()

            # For finding common features
            print(feature_idx)
            top_features_per_prompt.append(feature_idx)

            # For visualization
            features.append({"features": feats, "tokens": prompt["tokens"]})

    torch.save(features, 'features.pt')
    return top_features_per_prompt

## Visualizing the features at the token level

Pick a prompt and visualize the feature value for each token. For gender bias prompts, feature 48180 seems to be gender bias related while feature 27960 is more about any kind of bias.

In [None]:
def visualize_features(prompt_idx=6, top_k=40):
    """
    Render the per-token activations of one prompt's strongest features.

    Reads 'features.pt' (written by `find_relevant_feature`), prints the indices
    of the `top_k` features with the largest summed absolute activation for the
    chosen prompt, and builds a `colored_tokens_multi` visualisation of them.

    Args:
        prompt_idx: index of the prompt (in the cached list) to visualise.
        top_k: number of top features to include.

    Returns:
        The object produced by `colored_tokens_multi`. Leave the call as the last
        expression of a cell (or pass the result to `display`) to render it; a
        value that is neither returned nor displayed would never be shown.
    """
    cached = torch.load('features.pt', map_location=torch.device('cpu'))

    sample = cached[prompt_idx]
    features = sample["features"].detach()  # (1, seq_len, d_sae)
    tokens = sample["tokens"]

    summed = features.abs().sum(dim=1)
    top_ids = summed.topk(top_k).indices[0]
    print(top_ids.tolist())

    # One column per selected feature -> (seq_len, top_k)
    per_token_values = features[0][:, top_ids]
    return colored_tokens_multi(tokens, per_token_values)

## Editing the activations of the test query

Once we know which feature encodes gender bias, we scale its activation down by a constant factor. Here, we multiply the values by 0.125 (i.e., divide them by 8). Then, we reconstruct the 4096-dimensional activation vector to intervene in the model.

In [None]:
def edit_activation(feature_idx, multiplier):
    """
    Scale one SAE feature of the test query and decode the edited activation.

    Reads 'query_acts.pt' (written by `cache_activations`), multiplies feature
    `feature_idx` by `multiplier` at every token position, decodes the edited
    features and writes the result to 'clamped_activation.pt'.

    Args:
        feature_idx: index of the SAE feature to scale.
        multiplier: factor applied to that feature (e.g. 0.125 divides it by 8).

    Returns:
        The edited activation tensor that was saved.
    """
    # The cache holds a numpy array, which the restricted weights-only unpickler
    # (the default in recent PyTorch releases) refuses to load.
    query_activation = torch.load('query_acts.pt', weights_only=False)  # Shape: (1, seq_len, 4096)
    sae = load_sae()

    # Inference only: no autograd graph, and the in-place edit below is safe.
    with torch.no_grad():
        x = torch.as_tensor(query_activation, device=sae.device)
        outputs = sae(x)
        feats = outputs[1]  # Feats shape: (1, seq_len, d_sae)

        if len(outputs) >= 7:
            # forward() also returned the normalisation coefficient and the
            # normalised input.
            x_norm, sae_in = outputs[5], outputs[6]
        else:
            # Otherwise recompute them exactly as forward() does.
            x = x.to(sae.dtype)
            x_norm = (x.shape[-1] ** 0.5) / x.norm(dim=-1, keepdim=True)
            sae_in = x * x_norm

        feats[:, :, feature_idx] *= multiplier
        clamped, _ = sae.decode(feats, x_norm, sae_in)  # Clamped shape: (1, seq_len, 4096)

    clamped = clamped.detach()
    torch.save(clamped, 'clamped_activation.pt')
    return clamped

## Intervention at the middle layer of Mistral-7B

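We take the clamped activation from the previous step and use it to replace the activations at layer 16. Then, we generate 25 new tokens. Note that the replacement is applied only during the first forward pass (i.e., when predicting the first new token); the remaining tokens are generated without intervention.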

NOTE: (1) This step could be optimized by using generate() instead of trace(); however, I was facing issues running it with Mistral. (2) This block needs an A100 GPU to run.

In [None]:
def generate(query, do_intervene=False):
    """
    Greedily generate 25 tokens for `query`, optionally patching the hooked layer.

    Relies on the notebook-level globals `model` and `layer`.

    Args:
        query: the prompt to continue.
        do_intervene: when True, the output of `layer` is overwritten with the
            activation stored in 'clamped_activation.pt' during the first
            forward pass (the file is written by `edit_activation`).

    Returns:
        The prompt followed by all generated tokens (also printed).
    """
    # Only touch the file when it is needed, so the un-intervened baseline can be
    # run before 'clamped_activation.pt' has been created.
    clamped = torch.load('clamped_activation.pt').cuda() if do_intervene else None

    max_new_tokens = 25
    for i in range(max_new_tokens):
        with model.trace(query), torch.no_grad():
            # The stored activation matches the original query length, so it can
            # only be patched in on the first step, before the prompt grows.
            if i == 0 and do_intervene:
                layer.output[0][:, :, :] = clamped
            token_ids = model.lm_head.output.argmax(dim=-1).save()
        next_token = model.tokenizer.decode(token_ids[0][-1])
        print("Next token:", next_token)
        # Append immediately so the final token is part of the reported output.
        query += next_token

    print("Output:", query)
    return query

## Entry point

Below are 5 positive prompts talking about gender bias in professional settings. These are followed by 2 negative prompts where the first talks about women in general and the second talks about bias in general.

We want to isolate the features that are only present in the positive prompts and NOT present in the negative ones.

We use NNSight's LanguageModel class to load our model and then get the 16th layer of the model.

In [None]:
# Prompts that discuss gender bias in professional settings.
positive_prompts = [
    "In the tech industry, women make up only about 25% of the workforce, and are often excluded from key decision-making processes.",
    "Studies have shown that female doctors are paid less than their male counterparts for performing the same procedures",
    "Women often bear a disproportionate share of unpaid care work, which can hinder their advancement in professional settings.",
    "Research studies have consistently shown that women in leadership roles often experience more scrutiny, criticism and resistance than their male counterparts.",
    "Men are more likely to be appointed to leadership positions, resulting in a lack of diversity and representation in decision-making processes.",
]
# Control prompts: one about women in general, one about bias in general.
negative_prompts = [
    "Women's mental health is just as important as physical health, and it's essential to address any signs of depression, anxiety, or stress.",
    "Class bias is a significant issue that can lead to unequal opportunities and treatment based on an individual's socioeconomic status.",
]
prompts = [*positive_prompts, *negative_prompts]

# The test query whose activations are edited later on
query = "Human: Which gender is a nurse most likely associated with?\nAssistant:"

# Load the model onto the GPU and pick the layer to read from / write to
# (index 15 is the 16th transformer block).
device = "cuda:0"
model = LanguageModel("mistralai/Mistral-7B-v0.1", device_map=device, dispatch=True)
layer = model.model.layers[15]

# Step 0: Baseline - check the output of the query before doing anything
generate(query)

In [None]:
# Step 1: Cache the layer activations for every prompt and for the test query
# (writes cached_acts.pt and query_acts.pt).
cache_activations(prompts, query)

In [None]:
# Step 2: Find the desired features - prints the top feature indices of each
# prompt and writes features.pt for the visualization step.
find_relevant_feature()

In [None]:
# Step 3: Visualize the per-token activations of the top features (reads features.pt)
visualize_features()

In [None]:
# Step 4: Edit the test query activation. Manual analysis in steps 2 and 3 suggests
# that feature 48180 is related to gender bias; multiplier=0.125 scales that
# feature down to 1/8 of its value. Writes clamped_activation.pt.
edit_activation(48180, multiplier=0.125)

In [None]:
# Step 5: Finally, generate the output again, this time with the intervention
# (the edited activation from clamped_activation.pt is patched into the layer)!
generate(query, do_intervene=True)