In [4]:
!pip install nnsight plotly torch huggingface_hub einops pandas

[0m

In [5]:
!pip install git+https://github.com/liuhaozhe6788/crosscoder_vis.git

Collecting git+https://github.com/liuhaozhe6788/crosscoder_vis.git
  Cloning https://github.com/liuhaozhe6788/crosscoder_vis.git to /tmp/pip-req-build-bckuty34
  Running command git clone --filter=blob:none --quiet https://github.com/liuhaozhe6788/crosscoder_vis.git /tmp/pip-req-build-bckuty34


  Resolved https://github.com/liuhaozhe6788/crosscoder_vis.git to commit fbf5bb1c29c9dec65ad8d1e91893d78f9e523a1d
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[0m

In [6]:
import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, Union
from huggingface_hub import hf_hub_download
import json
import einops
import os
from typing import NamedTuple
from nnsight import LanguageModel
import numpy as np
import pandas as pd
import plotly.express as px

In [7]:
# Globally turn off gradient tracking for every torch operation that follows.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7cd3d0155450>

In [8]:
# Wrap both checkpoints with nnsight's LanguageModel, each placed on cuda:0 in bfloat16.
# NOTE(review): despite its name, `chat_model` is loaded from a FinQA LoRA repository — confirm the naming is intentional.
base_model = LanguageModel('mistralai/Mistral-7B-Instruct-v0.3', device_map='cuda:0', dtype=torch.bfloat16)
chat_model = LanguageModel('liuhaozhe6788/mistralai_Mistral-7B-Instruct-v0.3-FinQA-lora', device_map='cuda:0', dtype=torch.bfloat16)

In [9]:
!pip install gdown
!gdown --fuzzy https://drive.google.com/file/d/1wEwq7YvskXf-lhbkdD1RqKpEJDu5bBpC/view?usp=sharinghttps://drive.google.com/file/d

[0m

Downloading...
From: https://drive.google.com/uc?id=1wEwq7YvskXf-lhbkdD1RqKpEJDu5bBpC
To: /workspace/finqa_test_generated_filtered.csv
100%|██████████████████████████████████████| 4.83M/4.83M [00:00<00:00, 22.6MB/s]


In [10]:
def prepare_text_data(csv_path: str = "finqa_test_generated_filtered.csv") -> list:
    """Build the list of full texts (prompt immediately followed by generated code).

    Args:
        csv_path: Path of the CSV file to read. It must provide the columns
            ``prompt`` and ``generated_code``. Defaults to the file name this
            notebook downloads.

    Returns:
        One ``prompt + generated_code`` string per CSV row, in row order.
        An empty list when the file contains a header but no rows.

    Raises:
        KeyError: If one of the two required columns is missing.
        ValueError: If any row lacks a prompt or generated code.
    """
    data = pd.read_csv(csv_path)
    prompts = data["prompt"]
    codes = data["generated_code"]

    # Empty CSV cells are read as NaN; fail loudly instead of emitting a
    # non-string entry (or an opaque TypeError) further down the pipeline.
    incomplete = prompts.isna() | codes.isna()
    if incomplete.any():
        bad_rows = data.index[incomplete].tolist()
        raise ValueError(
            f"{csv_path}: missing 'prompt' or 'generated_code' in rows {bad_rows[:10]}"
        )

    # Column-wise concatenation also works for a frame with zero rows, where a
    # row-wise DataFrame.apply would hand back a DataFrame without .tolist().
    return (prompts + codes).tolist()

# Materialise the texts once; the dashboard cell later takes a slice of this list.
all_texts = prepare_text_data()

# Loading the crosscoder

In [11]:
# Maps the dtype names accepted by the crosscoder config key "enc_dtype" to torch dtypes.
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

class LossOutput(NamedTuple):
    """Metrics returned by ``CrossCoder_demo.get_losses``."""
    # loss: torch.Tensor
    # Squared reconstruction error summed over models and d_model, averaged over the batch.
    l2_loss: torch.Tensor
    # Activations weighted by the summed per-model decoder norms, summed over latents, averaged over the batch.
    l1_loss: torch.Tensor
    # Average number of strictly positive latent activations per sample.
    l0_loss: torch.Tensor
    # Per-sample explained variance over both models together.
    explained_variance: torch.Tensor
    # Per-sample explained variance for model index 0 only.
    explained_variance_A: torch.Tensor
    # Per-sample explained variance for model index 1 only.
    explained_variance_B: torch.Tensor

class CrossCoder_demo(nn.Module):
    """Crosscoder with one shared latent dictionary over the activations of two models.

    The encoder sums the contributions of both models into a single
    ``d_hidden``-dimensional latent vector; the decoder maps the latents back
    to one reconstruction per model. The number of models is hardcoded to 2.

    Config keys read by this class: ``dict_size``, ``d_in``, ``enc_dtype``,
    ``seed``, ``dec_init_norm``, ``device`` and, optionally, ``batch_topk``.
    """

    def __init__(self, cfg):
        """Create a randomly initialised crosscoder.

        Args:
            cfg: Mapping with the config keys listed in the class docstring.
                ``enc_dtype`` must be one of the keys of ``DTYPES``.
        """
        super().__init__()
        self.cfg = cfg
        d_hidden = self.cfg["dict_size"]
        d_in = self.cfg["d_in"]
        self.dtype = DTYPES[self.cfg["enc_dtype"]]
        torch.manual_seed(self.cfg["seed"])
        # hardcoding n_models to 2
        self.W_enc = nn.Parameter(
            torch.empty(2, d_in, d_hidden, dtype=self.dtype)
        )
        # Decoder vectors start from a single standard-normal draw and are
        # rescaled right below.
        self.W_dec = nn.Parameter(
            torch.nn.init.normal_(
                torch.empty(
                    d_hidden, 2, d_in, dtype=self.dtype
                )
            )
        )
        # Rescale every decoder vector (per latent, per model) to norm cfg["dec_init_norm"]
        self.W_dec.data = (
            self.W_dec.data / self.W_dec.data.norm(dim=-1, keepdim=True) * self.cfg["dec_init_norm"]
        )
        # Initialise W_enc to be the transpose of W_dec
        self.W_enc.data = einops.rearrange(
            self.W_dec.data.clone(),
            "d_hidden n_models d_model -> n_models d_model d_hidden",
        )
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=self.dtype))
        self.b_dec = nn.Parameter(
            torch.zeros((2, d_in), dtype=self.dtype)
        )
        self.d_hidden = d_hidden

        self.to(self.cfg["device"])
        self.save_dir = None
        self.save_version = 0

    def encode(self, x, apply_relu=True):
        """Map stacked model activations to latent activations.

        Args:
            x: Tensor of shape [batch, n_models, d_model].
            apply_relu: When False, the pre-activations (plus encoder bias)
                are used without the ReLU.

        Returns:
            Tensor of shape [batch, d_hidden]. When the config provides a
            non-None ``batch_topk``, everything outside the batch-wide top-k
            is zeroed.
        """
        # x: [batch, n_models, d_model]
        x_enc = einops.einsum(
            x,
            self.W_enc,
            "batch n_models d_model, n_models d_model d_hidden -> batch d_hidden",
        )
        if apply_relu:
            acts = F.relu(x_enc + self.b_enc)
        else:
            acts = x_enc + self.b_enc
        # A config without the key behaves like batch_topk=None (no masking).
        if self.cfg.get("batch_topk") is not None:
            acts = self.mask_acts_batchtopk(acts)
        return acts

    def decode(self, acts):
        """Reconstruct per-model activations from latents.

        Args:
            acts: Tensor of shape [batch, d_hidden].

        Returns:
            Tensor of shape [batch, n_models, d_model], decoder bias included.
        """
        # acts: [batch, d_hidden]
        acts_dec = einops.einsum(
            acts,
            self.W_dec,
            "batch d_hidden, d_hidden n_models d_model -> batch n_models d_model",
        )
        return acts_dec + self.b_dec

    def forward(self, x):
        """Encode then decode ``x`` ([batch, n_models, d_model])."""
        # x: [batch, n_models, d_model]
        acts = self.encode(x)
        return self.decode(acts)

    def mask_acts_batchtopk(self, acts):
        """Keep only the ``batch_topk * batch`` largest entries of the whole batch.

        Args:
            acts: Tensor of shape [batch, d_hidden].

        Returns:
            Tensor of the same shape with every other entry set to 0.
        """
        # acts: [batch, d_hidden]
        # Get topk across the whole batch
        acts_flat = acts.flatten()
        # Never ask topk for more entries than exist (e.g. batch_topk > d_hidden).
        k = min(self.cfg["batch_topk"] * acts.shape[0], acts_flat.numel())
        _, topk_indices = torch.topk(acts_flat, k=k, dim=-1)
        # Create a boolean mask from the indices
        mask_flat = torch.zeros_like(acts_flat, dtype=torch.bool)
        mask_flat[topk_indices] = True
        mask = mask_flat.reshape_as(acts)
        acts = torch.where(mask, acts, 0)
        return acts

    def get_losses(self, x):
        """Compute reconstruction and sparsity metrics for a batch.

        Args:
            x: Tensor of shape [batch, n_models, d_model]; it is cast to the
                crosscoder dtype before encoding.

        Returns:
            LossOutput with scalar ``l2_loss``, ``l1_loss`` and ``l0_loss``
            and per-sample explained-variance tensors (overall, model index 0,
            model index 1).
        """
        # x: [batch, n_models, d_model]
        x = x.to(self.dtype)
        acts = self.encode(x)
        x_reconstruct = self.decode(acts)
        # The reconstruction error is accumulated in float32.
        diff = x_reconstruct.float() - x.float()
        squared_diff = diff.pow(2)
        l2_per_batch = einops.reduce(squared_diff, 'batch n_models d_model -> batch', 'sum')
        l2_loss = l2_per_batch.mean()

        # Variance is measured around the batch mean of the inputs.
        total_variance = einops.reduce((x - x.mean(0)).pow(2), 'batch n_models d_model -> batch', 'sum')
        explained_variance = 1 - l2_per_batch / total_variance

        # Same quantity restricted to model index 0 ...
        per_token_l2_loss_A = (x_reconstruct[:, 0, :] - x[:, 0, :]).pow(2).sum(dim=-1).squeeze()
        total_variance_A = (x[:, 0, :] - x[:, 0, :].mean(0)).pow(2).sum(-1).squeeze()
        explained_variance_A = 1 - per_token_l2_loss_A / total_variance_A

        # ... and to model index 1.
        per_token_l2_loss_B = (x_reconstruct[:, 1, :] - x[:, 1, :]).pow(2).sum(dim=-1).squeeze()
        total_variance_B = (x[:, 1, :] - x[:, 1, :].mean(0)).pow(2).sum(-1).squeeze()
        explained_variance_B = 1 - per_token_l2_loss_B / total_variance_B

        decoder_norms = self.W_dec.norm(dim=-1)
        # decoder_norms: [d_hidden, n_models]
        total_decoder_norm = einops.reduce(decoder_norms, 'd_hidden n_models -> d_hidden', 'sum')
        # Sparsity penalty: activations weighted by the summed decoder norms of each latent.
        l1_loss = (acts * total_decoder_norm[None, :]).sum(-1).mean(0)

        l0_loss = (acts>0).float().sum(-1).mean()

        return LossOutput(l2_loss=l2_loss, l1_loss=l1_loss, l0_loss=l0_loss, explained_variance=explained_variance, explained_variance_A=explained_variance_A, explained_variance_B=explained_variance_B)

    @classmethod
    def load_from_hf(
        cls,
        repo_id: str = "liuhaozhe6788/crosscoder-model-diff-mistral-7b-instruct-v0.3_finQA_lora_topk_100",
        device: Optional[Union[str, torch.device]] = None
    ) -> "CrossCoder_demo":
        """
        Load CrossCoder_demo weights and config from HuggingFace.

        The repository must contain ``cfg.json`` (the config) and ``model.pt``
        (a checkpoint holding the weights under the key ``model_state_dict``).

        Args:
            repo_id: HuggingFace repository ID
            device: Device to load the model to (defaults to cfg device if not specified)

        Returns:
            Initialized CrossCoder_demo instance
        """

        # Download config and weights
        config_path = hf_hub_download(
            repo_id=repo_id,
            filename="cfg.json"
        )
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.pt"
        )

        # Load config
        with open(config_path, 'r') as f:
            cfg = json.load(f)

        # Override device if specified
        if device is not None:
            cfg["device"] = str(device)

        # Initialize CrossCoder_demo with config
        instance = cls(cfg)

        # Load weights
        # NOTE(security): torch.load unpickles the downloaded file; only point
        # repo_id at repositories you trust.
        state_dict = torch.load(weights_path, map_location=cfg["device"])
        instance.load_state_dict(state_dict["model_state_dict"])

        return instance

In [12]:
# Load with the default repo_id; the device is taken from the downloaded config.
cross_coder = CrossCoder_demo.load_from_hf()

# Replicating Anthropic results

In [13]:
# L2 norm of every decoder vector: one value per latent and per model.
norms = cross_coder.W_dec.norm(dim=-1)
norms.shape

torch.Size([16384, 2])

In [14]:
# Share of each latent's total decoder norm (summed over both models) that belongs to model index 1.
# NOTE(review): index 1 presumably corresponds to the fine-tuned model — confirm against the training setup.
relative_norms = norms[:, 1] / norms.sum(dim=-1)
relative_norms.shape

torch.Size([16384])

In [15]:
# Cast to float32 first: Tensor.numpy() does not support bfloat16, one of the dtypes the crosscoder config allows.
relative_norms_np = relative_norms.detach().to(torch.float32).cpu().numpy()

In [16]:
# Select the k latents with the largest relative decoder norm.
k = 100
idx = np.argpartition(relative_norms_np, -k)[-k:]  # Indices not sorted
ft_idx = idx[np.argsort(relative_norms_np[idx])][::-1]  # Indices sorted by value from largest to smallest

In [17]:
# Select the k latents with the smallest relative decoder norm (reuses `k` from the previous cell).
idx = np.argpartition(relative_norms_np, k)[:k]  # Indices not sorted

base_idx = idx[np.argsort(relative_norms_np[idx])]  # Indices sorted by value from smallest to largest

In [18]:
# Distribution of the per-latent relative decoder norms (NumPy copy made in an earlier cell).
tick_positions = [0, 0.25, 0.5, 0.75, 1.0]

fig = px.histogram(
    relative_norms_np,
    nbins=200,
    title="Mistral 7b Instruct v0.3 vs FinQA FT Model Diff",
    labels={"value": "Relative decoder norm strength"},
)
fig.update_yaxes(title_text="Number of Latents")
fig.update_layout(showlegend=False)

# Put x-axis ticks (and matching labels) at the quarter marks only.
fig.update_xaxes(
    tickvals=tick_positions,
    ticktext=[str(position) for position in tick_positions],
)

fig.show()

In [19]:
# Latents whose relative decoder norm lies strictly between 0.4 and 0.6 are treated as shared by both models.
shared_latent_mask = (relative_norms < 0.6) & (relative_norms > 0.4)
shared_latent_mask.shape

torch.Size([16384])

In [20]:
# Cosine similarity, per latent, between the decoder vector of model index 0 and that of model index 1.
dec_vectors_A = cross_coder.W_dec[:, 0, :]
dec_vectors_B = cross_coder.W_dec[:, 1, :]
cosine_sims = (dec_vectors_A * dec_vectors_B).sum(dim=-1) / (dec_vectors_A.norm(dim=-1) * dec_vectors_B.norm(dim=-1))
cosine_sims.shape

torch.Size([16384])

In [21]:
# Cosine similarities of the shared latents only, as a float32 NumPy array for plotting.
shared_cosine_sims_np = cosine_sims[shared_latent_mask].to(torch.float32).detach().cpu().numpy()

fig = px.histogram(
    shared_cosine_sims_np,
    labels={"value": "Cosine similarity of decoder vectors between models"},
    nbins=100,  # number of histogram bins
    range_x=[-1, 1],  # full cosine-similarity range on the x-axis
    log_y=True,  # logarithmic y-axis
)
fig.update_yaxes(title_text="Number of Latents (log scale)")
fig.update_layout(showlegend=False)

fig.show()

In [22]:
import copy
# Work on a deep copy so the crosscoder loaded above keeps its original weights.
folded_cross_coder = copy.deepcopy(cross_coder)

# NOTE(review): hard-coded activation scaling factors for the two models; their provenance is not
# shown in this notebook — presumably estimated when the crosscoder was trained. Confirm.
base_estimated_scaling_factor = 27.489933013916016
chat_estimated_scaling_factor = 27.12582778930664

def fold_activation_scaling_factor(cross_coder, base_scaling_factor, chat_scaling_factor):
    """Multiply each model's encoder weights by that model's scaling factor, in place.

    Only the encoder is rescaled: slice 0 of ``W_enc`` by ``base_scaling_factor``
    and slice 1 by ``chat_scaling_factor``. The decoder weights and the decoder
    bias are left as they are (the matching division by the factors is disabled).

    Args:
        cross_coder: Object exposing ``W_enc`` with the model axis first.
        base_scaling_factor: Factor applied to the encoder weights of model index 0.
        chat_scaling_factor: Factor applied to the encoder weights of model index 1.

    Returns:
        The same ``cross_coder`` object that was passed in.
    """
    per_model_factors = (base_scaling_factor, chat_scaling_factor)
    for model_pos, factor in enumerate(per_model_factors):
        cross_coder.W_enc.data[model_pos] = cross_coder.W_enc.data[model_pos] * factor
    return cross_coder

# The function rescales the copy in place and returns that same object; `cross_coder` itself is not modified.
folded_cross_coder = fold_activation_scaling_factor(folded_cross_coder, base_estimated_scaling_factor, chat_estimated_scaling_factor)


# Generating latent dashboards

In [23]:
from sae_vis.model_fns import CrossCoderConfig, CrossCoder_vis
from sae_vis.data_config_classes import SaeVisConfig

In [24]:
# Configure the sae_vis crosscoder with the same input width, dictionary size and batch top-k as the loaded one.
encoder_cfg = CrossCoderConfig(d_in=base_model.config.hidden_size, d_hidden=cross_coder.cfg["dict_size"], batch_topk=cross_coder.cfg["batch_topk"], apply_b_dec_to_input=False)
sae_vis_cross_coder = CrossCoder_vis(encoder_cfg)
# Copy in the weights of the crosscoder whose encoder had the scaling factors folded in.
sae_vis_cross_coder.load_state_dict(folded_cross_coder.state_dict())
# Same device and dtype the two language models were loaded with.
sae_vis_cross_coder = sae_vis_cross_coder.to("cuda:0")
sae_vis_cross_coder = sae_vis_cross_coder.to(torch.bfloat16)

In [None]:
from sae_vis.data_storing_fns import SaeVisData
import gc

import os
os.makedirs("base_feature_vis", exist_ok=True)
# NOTE(review): this directory is created here, but nothing in this cell writes to it.
os.makedirs("ft_feature_vis", exist_ok=True)

# One dashboard per latent: each iteration reruns the forward passes over the texts
# and writes a standalone HTML file for a single feature.
for idx in base_idx[:24]:
    sae_vis_config = SaeVisConfig(
        hook_layer = 16,
        features = [idx],
        verbose = True,
        minibatch_size_texts=1,
        minibatch_size_features=16,
    )

    crosscoder_vis_data = SaeVisData.create(
        encoder = sae_vis_cross_coder,
        encoder_B = None,
        model_A = base_model,
        model_B = chat_model,
        texts = all_texts[:128], # in practice, better to use more data
        cfg = sae_vis_config,
    )
    filename = f"base_feature_vis/feature_vis_feature_{idx}.html"
    crosscoder_vis_data.save_feature_centric_vis(filename)

    # Release Python garbage first, then hand cached CUDA blocks back to the driver.
    gc.collect()
    for i in range(torch.cuda.device_count()):
        # The context manager restores the previously selected device afterwards,
        # so the loop does not leave the last GPU as the current device.
        with torch.cuda.device(i):
            torch.cuda.empty_cache()

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

Forward passes to cache data for vis:   0%|          | 0/128 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# import os
# import http
# import socketserver
# import threading
# from google.colab import output

# PORT = 8000

# def display_vis_inline(filename: str, height: int = 850):
#     '''
#     Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
#     vis has a unique port without having to define a port within the function.
#     '''
#     global PORT

#     def serve(directory):
#         os.chdir(directory)

#         # Create a handler for serving files
#         handler = http.server.SimpleHTTPRequestHandler

#         # Create a socket server with the handler
#         with socketserver.TCPServer(("", PORT), handler) as httpd:
#             print(f"Serving files from {directory} on port {PORT}")
#             httpd.serve_forever()

#     thread = threading.Thread(target=serve, args=("/content",))
#     thread.start()

#     output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

#     PORT += 1

# filename = "_feature_vis_demo.html"
# crosscoder_vis_data.save_feature_centric_vis(filename)

# display_vis_inline(filename)

Saving feature-centric vis:   0%|          | 0/1 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

Exception in thread Thread-17 (serve):
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipython-input-930586403.py", line 23, in serve
  File "/usr/lib/python3.12/socketserver.py", line 457, in __init__
    self.server_bind()
  File "/usr/lib/python3.12/socketserver.py", line 478, in server_bind
    self.socket.bind(self.server_address)
OSError: [Errno 98] Address already in use
