In [None]:
from ActivationStoreParallel import ActivationsStore 
from sparse_transcoder import SparseTranscoder
from transcoder_runner_parallel import language_model_transcoder_runner_parallel
from dataclasses import dataclass   
import transformer_lens   
import torch   
import wandb   
from typing import Optional   
import einops
import plotly.express as px
import torch.nn.functional as F

In [None]:
@dataclass
class Config1:
    """Layer-10 gpt2-small settings for a transcoder run.

    None of the attributes below is type-annotated, so ``@dataclass`` generates
    no fields: ``Config1()`` takes no arguments and every value is a class
    attribute shared by all instances.
    """

    # Data-generating function (model + training distribution)
    model_name = "gpt2-small"
    dataset_path = "Skylion007/openwebtext"
    is_dataset_tokenized = False

    # Hook points (all in layer 10)
    hook_transcoder_in = "blocks.10.hook_resid_pre"
    hook_point = "blocks.10.hook_resid_pre"
    hook_transcoder_out = "blocks.10.attn.hook_q"
    target = "blocks.10.attn.hook_q"
    ln = 'blocks.10.ln1.hook_scale'
    hook_point_layer = 10
    layer = 10

    # Dimensions
    d_in = 768
    d_out = 12 * 768
    n_head = 12
    d_head = 64

    # Flags
    training = True
    attn_scores_normed = True
            

class QK_cfg():
    """Hyperparameters for the paired query/key sparse transcoders (gpt2-small, layer 10).

    This is a plain class, not a dataclass: every setting is a class attribute
    shared by all instances until it is reassigned on a particular instance.
    """

    # --- Common settings ---
    model_name: str = "gpt2-small"
    hook_point: str = "blocks.10.hook_resid_pre"
    ln: str = 'blocks.10.ln1.hook_scale'
    hook_point_layer: int = 10
    layer: int = 10
    d_in: int = 768
    d_out: int = 768
    n_head: int = 12
    d_head: int = 64
    dataset_path: str = "Skylion007/openwebtext"
    is_dataset_tokenized: bool = False
    training: bool = True
    # NOTE(review): Config1 spells this flag `attn_scores_normed`.
    attn_scores_norm = True

    # --- SAE parameters ---
    d_hidden: int = 2400
    b_dec_init_method: str = "mean"

    # --- Training parameters ---
    lr: float = 1e-3
    reg_coefficient: float = 1.5e-6
    lr_scheduler_name = "cosineannealingwarmup"
    train_batch_size: int = 2048
    context_size: int = 256
    lr_warm_up_steps: int = 5000
    norming_decoder_during_training = False

    # --- Activation store parameters ---
    n_batches_in_buffer: int = 128
    total_training_tokens: int = 200_000_000
    store_batch_size: int = 32
    use_cached_activations: bool = False

    # --- Resampling protocol ---
    feature_sampling_method: str = 'none'
    feature_sampling_window: int = 1000
    feature_reinit_scale: float = 0.2
    resample_batches: int = 1028
    dead_feature_window: int = 50000
    dead_feature_threshold: float = 1e-6

    # --- Weights & Biases logging ---
    log_to_wandb: bool = True
    log_final_model_to_wandb: bool = False
    wandb_project: str = "Jul_24_test"
    wandb_entity = "kwyn390"
    wandb_log_frequency: int = 50
    entity: str = "kwyn390"

    # --- Miscellaneous ---
    device: str = "cuda"
    eps: float = 1e-7
    seed: int = 42
    reshape_from_heads: bool = True
    n_checkpoints: int = 10
    checkpoint_path: str = "checkpoints"
    dtype: torch.dtype = torch.float32
    run_name: str = "qk_parallel"

    # --- Query transcoder: residual stream -> queries ---
    hook_transcoder_in_q: str = "blocks.10.hook_resid_pre"
    hook_transcoder_out_q: str = "blocks.10.attn.hook_q"
    target_q: str = "blocks.10.attn.hook_q"
    type_q: str = "resid_to_queries"

    # --- Key transcoder: residual stream -> keys ---
    hook_transcoder_in_k: str = "blocks.10.hook_resid_pre"
    hook_transcoder_out_k: str = "blocks.10.attn.hook_k"
    target_k: str = "blocks.10.attn.hook_k"
    type_k: str = "resid_to_keys"
        
qk_cfg = QK_cfg()

# Tag the run with the swept hyperparameters: "<d_hidden>_<lr>_<reg_coefficient>".
qk_cfg.run_name = "_".join(str(value) for value in (qk_cfg.d_hidden, qk_cfg.lr, qk_cfg.reg_coefficient))

In [None]:
# Checkpoints for the trained query/key transcoders. `torch.load()` was
# previously called with no path, which always raises TypeError; set these
# before running the cell.
QUERY_TRANSCODER_PATH = None  # TODO: path to the saved query-transcoder checkpoint
KEY_TRANSCODER_PATH = None  # TODO: path to the saved key-transcoder checkpoint
if QUERY_TRANSCODER_PATH is None or KEY_TRANSCODER_PATH is None:
    raise ValueError(
        "Set QUERY_TRANSCODER_PATH and KEY_TRANSCODER_PATH to the saved transcoder checkpoints."
    )

# Each checkpoint is read as a mapping whose "state_dict" entry holds the weights.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
query_data = torch.load(QUERY_TRANSCODER_PATH)
key_data = torch.load(KEY_TRANSCODER_PATH)

query_transcoder = SparseTranscoder(qk_cfg, is_query = True)
key_transcoder = SparseTranscoder(qk_cfg, is_query=False)

query_transcoder.load_state_dict(query_data["state_dict"])
key_transcoder.load_state_dict(key_data["state_dict"])

# eval() clears `training` on the module AND all of its submodules; assigning
# `.training = False` directly only affected the top-level module.
query_transcoder.eval()
key_transcoder.eval()


In [None]:
def flatten_heads(tensor):
    """Merge the two trailing (n_head, d_head) axes into one n_head * d_head axis.

    Leading axes are untouched, e.g. (batch, pos, n_head, d_head) becomes
    (batch, pos, n_head * d_head), with each head's block kept contiguous.
    """
    merged = einops.rearrange(tensor, " ... n_head d_head -> ... (n_head d_head)")
    return merged

def unflatten_heads(tensor, n_head):
    """Split the trailing n_head * d_head axis back into (n_head, d_head) axes.

    Inverse of ``flatten_heads``. ``d_head`` is inferred from the length of the
    last axis and the given ``n_head``.
    """
    split = einops.rearrange(tensor, " ... (n_head d_head) -> ... n_head d_head", n_head=n_head)
    return split

def flat_pattern_from_scores(scores, attn_scores_norm):
    """Turn raw attention scores into a flattened, causally masked log-pattern.

    ``scores`` is divided by ``attn_scores_norm``, which creates a new tensor and
    leaves the input unmodified. The result is causally masked over its last two
    axes and log-softmaxed over the key axis. Finally, all leading axes are
    merged, giving a 2-D tensor of shape (-1, n_key_positions).
    """
    scaled = scores / attn_scores_norm
    log_pattern = apply_causal_mask(scaled).log_softmax(dim=-1)
    n_keys = log_pattern.shape[-1]
    return log_pattern.view(-1, n_keys)

def apply_causal_mask(attn_scores):
    """
    Mask out future key positions in place and return the same tensor.

    Every entry whose key index exceeds its query index (strictly above the
    diagonal of the last two axes) is overwritten with -1e9, so it contributes
    (numerically) nothing after a softmax over the last axis.

    Args:
        attn_scores: Tensor whose last two axes are (query position, key
            position). It is modified in place.

    Returns:
        ``attn_scores`` itself, after masking.
    """
    n_query, n_key = attn_scores.size(-2), attn_scores.size(-1)
    # Build the mask on the scores' own device. The previous hard-coded
    # `.cuda()` failed on CPU and mismatched tensors living on other GPUs.
    mask = torch.triu(torch.ones(n_query, n_key, device=attn_scores.device), diagonal=1).bool()
    # Large finite negative value: exp() underflows to 0 under softmax.
    attn_scores.masked_fill_(mask, -1e9)
    return attn_scores


def compute_ground_truth(model, data, cfg, attn_scores_norm):
    """
    Compute the layer's true queries, keys, scores and flattened log-pattern.

    Queries and keys come from projecting ``data`` with ``model.W_Q[cfg.layer]`` /
    ``model.W_K[cfg.layer]`` and adding the matching ``b_Q`` / ``b_K`` biases. Scores
    are the per-head query-key dot products. These three outputs are NOT divided
    by ``attn_scores_norm``; only the returned pattern is.

    Args:
        model: Exposes ``W_Q``, ``W_K``, ``b_Q``, ``b_K`` indexed by layer. After
            indexing, the weights are (n_head, d_model, d_head).
        data (torch.Tensor): Projection inputs with d_model as the last axis and
            the (query/key) position as the second-to-last axis.
        cfg: Any object with a ``layer`` attribute.
        attn_scores_norm: Divisor applied to the scores before masking and
            log-softmax.
            NOTE(review): this was documented as "Either 1 or 1/sqrt(d_head)",
            but dividing by 1/sqrt(d_head) multiplies the scores by sqrt(d_head).
            Standard scaled attention would pass sqrt(d_head) — confirm.

    Returns:
        tuple: (true_queries, true_keys, true_scores, true_patt_view).
        ``true_patt_view`` is the causally masked log-softmax pattern, flattened
        to 2-D by flat_pattern_from_scores.
    """
    layer = cfg.layer
    W_Q, b_Q = model.W_Q[layer], model.b_Q[layer]
    W_K, b_K = model.W_K[layer], model.b_K[layer]

    projection = "n_head d_model d_head, ... d_model -> ... n_head d_head"
    true_queries = einops.einsum(W_Q, data, projection) + b_Q
    true_keys = einops.einsum(W_K, data, projection) + b_K

    true_scores = einops.einsum(
        true_queries,
        true_keys,
        "... posn_q n_head d_head, ... posn_k n_head d_head -> ... n_head posn_q posn_k",
    )
    return true_queries, true_keys, true_scores, flat_pattern_from_scores(true_scores, attn_scores_norm)


In [None]:
import torch


### These have been adapted to use the autoencoder approach instead of the transcoder approach
@torch.no_grad()
def get_transcoder_loss(batch_tokens, model, query_transcoder, key_transcoder, kl_loss = False):
    """
    Loss when we patch in the attention calculated with the reconstructed keys
    and queries (no feature map involved).

    The layer's queries and keys are captured, flattened over heads, passed
    through the query/key transcoders (used here as autoencoders), and split
    back into heads. The resulting scaled, causally masked softmax then replaces
    the layer's attention pattern.

    Args:
        batch_tokens: Tokens passed straight to ``model.run_with_hooks``.
        model: Model exposing the hook points ``blocks.{layer}.attn.hook_q``,
            ``attn.hook_k`` and ``attn.hook_pattern``.
        query_transcoder: Reconstructs flattened queries. Must expose ``layer``,
            ``n_head`` and ``d_head``.
        key_transcoder: Reconstructs flattened keys. Must expose ``n_head``.
        kl_loss: If True, return logits; otherwise return the model's loss.

    Returns:
        The result of ``model.run_with_hooks`` for the selected return type.

    Note: later cells re-bind ``get_transcoder_loss``. This version stays
    reachable as ``get_qk_reconstruction_loss``.
    """
    layer = query_transcoder.layer
    query_hook = f"blocks.{layer}.attn.hook_q"
    key_hook = f"blocks.{layer}.attn.hook_k"
    target_hook = f"blocks.{layer}.attn.hook_pattern"

    # Per-call storage shared by the hooks. The original used `global`, and
    # get_keys declared the wrong global, so the keys were never cached.
    cache = {}

    def get_queries(activations, hook):
        cache["queries"] = activations
        return activations

    def get_keys(activations, hook):
        cache["keys"] = activations
        return activations

    def replace_target_hook(activations, hook):
        # NOTE(review): assumes forward() returns the reconstruction tensor itself
        # (not a tuple) — confirm against SparseTranscoder.
        reconstr_queries = query_transcoder.forward(flatten_heads(cache["queries"]))
        reconstr_queries = unflatten_heads(reconstr_queries, query_transcoder.n_head)
        reconstr_keys = key_transcoder.forward(flatten_heads(cache["keys"]))
        reconstr_keys = unflatten_heads(reconstr_keys, key_transcoder.n_head)
        reconstr_scores = einops.einsum(reconstr_queries, reconstr_keys, "batch posnQ n_head d_head, batch posnK n_head d_head -> batch n_head posnQ posnK")/query_transcoder.d_head**0.5
        reconstr_pattern = apply_causal_mask(reconstr_scores).softmax(-1)
        return reconstr_pattern

    return_type = "logits" if kl_loss else "loss"

    # Hooks are given as (hook_name, hook_fn) pairs. The first two pairs were
    # previously in (fn, name) order.
    return model.run_with_hooks(
        batch_tokens,
        return_type=return_type,
        fwd_hooks=[
            (query_hook, get_queries),
            (key_hook, get_keys),
            (target_hook, replace_target_hook),
        ],
    )


# Stable name for this variant, since later cells re-define get_transcoder_loss.
get_qk_reconstruction_loss = get_transcoder_loss
    
@torch.no_grad()
def get_transcoder_loss(batch_tokens, model, query_transcoder, key_transcoder, feature_map, kl_loss = False):
    """
    Loss when we patch in the attention pattern computed the expanded way,
    through a feature map between query and key transcoder features.

    Query/key feature activations come from encoding the layer's flattened
    queries/keys: ReLU((x - b_dec) @ W_enc + b_enc). Per head, the scores are:

        f_Q @ feature_map[:, :, head] @ f_K^T
        + (key decoder directions . query decoder bias) @ f_K^T

    They are then divided by sqrt(d_head), causally masked, and softmaxed. Terms
    that depend only on the query position are not included; they are constant
    along each softmax row and so do not change the pattern.

    Args:
        batch_tokens: Tokens passed straight to ``model.run_with_hooks``.
        model: Model exposing the hook points ``blocks.{layer}.attn.hook_q``,
            ``attn.hook_k`` and ``attn.hook_pattern``.
        query_transcoder: Supplies ``W_enc``, ``b_enc``, ``b_dec``,
            ``b_dec_out``, ``layer``, ``n_head`` and ``d_head``.
        key_transcoder: Supplies ``W_enc``, ``b_enc``, ``b_dec``, ``W_dec`` and
            ``d_head``.
        feature_map: Tensor indexed (d_hidden_Q, d_hidden_K, n_head).
        kl_loss: If True, return logits; otherwise return the model's loss.

    Returns:
        The result of ``model.run_with_hooks`` for the selected return type.

    Note: a later cell re-binds ``get_transcoder_loss``. This version stays
    reachable as ``get_qk_feature_map_loss``.
    """
    layer = query_transcoder.layer
    query_hook = f"blocks.{layer}.attn.hook_q"
    key_hook = f"blocks.{layer}.attn.hook_k"
    target_hook = f"blocks.{layer}.attn.hook_pattern"

    # Key decoder directions split per head: (d_hidden, n_head, d_head).
    k_features = einops.rearrange(key_transcoder.W_dec, "d_hidden (n_head d_head) -> d_hidden n_head d_head", d_head = key_transcoder.d_head)
    # Score contribution of the query decoder bias to each key feature, per head.
    # It does not depend on the batch, so compute it once here rather than in the
    # hook. (The hook previously read `k_features` via `global`, which named an
    # undefined module variable.)
    bias_reshape = einops.rearrange(query_transcoder.b_dec_out, "(n_head d_head) -> n_head d_head", n_head = query_transcoder.n_head)
    bias_acts = einops.einsum(k_features, bias_reshape, "d_hidden_K n_head d_head, n_head d_head -> n_head d_hidden_K")

    # Per-call storage shared by the hooks (replaces the broken `global` caches).
    cache = {}

    def encode(transcoder, activations):
        # Captured queries/keys carry separate head axes; the encoders act on the
        # flattened n_head * d_head vector.
        flat = flatten_heads(activations)
        pre_acts = einops.einsum((flat - transcoder.b_dec), transcoder.W_enc, "... d_model, d_model d_hidden -> ... d_hidden") + transcoder.b_enc
        return F.relu(pre_acts)

    def get_queries(activations, hook):
        cache["queries"] = activations
        return activations

    def get_keys(activations, hook):
        cache["keys"] = activations
        return activations

    def replace_target_hook(activations, hook):
        feature_acts_Q = encode(query_transcoder, cache["queries"])
        feature_acts_K = encode(key_transcoder, cache["keys"])
        # Feature-pair contributions: f_Q @ feature_map @ f_K, per head.
        attn_contribution = einops.einsum(feature_acts_Q, feature_map, "batch posnQ d_hidden_Q, d_hidden_Q d_hidden_K n_head -> batch posnQ d_hidden_K n_head")
        attn_contribution = einops.einsum(attn_contribution, feature_acts_K, "batch posnQ d_hidden_K n_head, batch posnK d_hidden_K -> batch posnQ posnK n_head")
        # Query-bias x key-feature term. It does not depend on posnQ, so add a
        # singleton query axis and broadcast over it.
        contr_from_bias = einops.einsum(bias_acts, feature_acts_K, "n_head d_hidden_K, ... d_hidden_K -> ... n_head").unsqueeze(1)
        attn_scores_reconstr = (attn_contribution + contr_from_bias)/query_transcoder.d_head**0.5
        attn_scores_reconstr = einops.rearrange(attn_scores_reconstr, "batch posnQ posnK n_head -> batch n_head posnQ posnK")
        reconstr_pattern = apply_causal_mask(attn_scores_reconstr).softmax(-1)
        return reconstr_pattern

    return_type = "logits" if kl_loss else "loss"

    # Hooks are given as (hook_name, hook_fn) pairs. The first two pairs were
    # previously in (fn, name) order.
    return model.run_with_hooks(
        batch_tokens,
        return_type=return_type,
        fwd_hooks=[
            (query_hook, get_queries),
            (key_hook, get_keys),
            (target_hook, replace_target_hook),
        ],
    )


# Stable name for this variant, since a later cell re-defines get_transcoder_loss.
get_qk_feature_map_loss = get_transcoder_loss

In [None]:
# TODO: load your OV transcoder and define its config here.

In [None]:
import torch


### First attempt at the OV patching run; it may still contain bugs.
@torch.no_grad()
def get_transcoder_loss(batch_tokens, model, ov_transcoder, key_transcoder, kl_loss = False):
    """
    Loss when the layer's attention output is replaced by the OV transcoder's
    reconstruction.

    For layer ``ov_transcoder.layer`` the hooks capture:
      - the residual stream before the block,
      - the ln1 scale,
      - the key-transcoder feature activations of the (flattened) keys,
      - the attention pattern.
    The OV transcoder is called on (resid_pre / ln1 scale, key feature
    activations, pattern). Its output is split into ``ov_transcoder.n_head``
    chunks, summed over heads, and patched into the block's attention output.

    Args:
        batch_tokens: Tokens passed straight to ``model.run_with_hooks``.
        model: Model exposing the hook points named below.
        ov_transcoder: Callable returning a 4-tuple whose first element is the
            reconstruction. Must expose ``layer`` and ``n_head``.
        key_transcoder: Supplies ``W_enc``, ``b_enc`` and ``b_dec``, used to
            encode the flattened keys.
        kl_loss: If True, return logits; otherwise return the model's loss.

    Returns:
        The result of ``model.run_with_hooks`` for the selected return type.
    """
    layer = ov_transcoder.layer
    # Hook names use the "hook_*" naming seen in the configs
    # (e.g. blocks.10.hook_resid_pre, blocks.10.attn.hook_q). The previous
    # "resid_pre", "attn.k", "attn.pattern" and "attn.out" names did not.
    resid_hook = f"blocks.{layer}.hook_resid_pre"
    ln_hook = f"blocks.{layer}.ln1.hook_scale"
    key_hook = f"blocks.{layer}.attn.hook_k"
    pattern_hook = f"blocks.{layer}.attn.hook_pattern"
    # NOTE(review): assumes the head-summed reconstruction matches the block's
    # attention-output shape (batch, pos, d_model) — confirm.
    target_hook = f"blocks.{layer}.hook_attn_out"

    # Per-call storage shared by the hooks (replaces module-level globals).
    cache = {}

    def get_resid_hook(activations, hook):
        cache["resid"] = activations
        return activations

    def get_ln_hook(activations, hook):
        # The previous version divided the cached residual by scale ** (1/3) on
        # every call. That is presumably because this hook fires three times per
        # forward pass (ln1 applied separately to the query, key and value
        # inputs). Storing the scale and dividing once in replace_target_hook
        # gives the same result without depending on how often the hook fires.
        cache["ln_scale"] = activations
        return activations

    def get_key_acts_hook(activations, hook):
        # Keys carry separate head axes; the key encoder acts on the flattened
        # n_head * d_head vector.
        flat_keys = flatten_heads(activations)
        cache["key_acts"] = F.relu(einops.einsum((flat_keys - key_transcoder.b_dec), key_transcoder.W_enc, "... d_model, d_model d_hidden -> ... d_hidden") + key_transcoder.b_enc)
        return activations

    def get_pattern_hook(activations, hook):
        cache["pattern"] = activations
        return activations

    def replace_target_hook(activations, hook):
        normalized_resid = cache["resid"] / cache["ln_scale"]
        reconstr_out, _, _, _ = ov_transcoder(
            normalized_resid,
            cache["key_acts"],
            cache["pattern"],
        )
        # Split the trailing axis into n_head equal chunks and sum over heads.
        reconstr_attn_out = einops.rearrange(reconstr_out, "... (n_head d_head) -> ... n_head d_head", n_head = ov_transcoder.n_head)
        return reconstr_attn_out.sum(-2)

    return_type = "logits" if kl_loss else "loss"

    # Hooks are given as (hook_name, hook_fn) pairs. Every pair was previously in
    # (fn, name) order.
    return model.run_with_hooks(
        batch_tokens,
        return_type=return_type,
        fwd_hooks=[
            (resid_hook, get_resid_hook),
            (ln_hook, get_ln_hook),
            (key_hook, get_key_acts_hook),
            (pattern_hook, get_pattern_hook),
            (target_hook, replace_target_hook),
        ],
    )


# Stable name for the OV variant, matching the aliases given to the QK variants.
get_ov_transcoder_loss = get_transcoder_loss
