In [1]:
from tuned_lens.causal import extract_causal_bases
from tuned_lens.nn.lenses import TunedLens, LogitLens, Unembed
from transformer_lens import HookedTransformer 
import torch as th
from typing import cast, Optional
import pandas as pd

In [2]:
model_name = "EleutherAI/pythia-160m-deduped"
# Use the GPU when present, otherwise fall back to CPU so the notebook still runs.
device = 'cuda:0' if th.cuda.is_available() else 'cpu'
# fold_ln=False keeps the final LayerNorm as a separate module (the lenses below reuse it).
model = HookedTransformer.from_pretrained(model_name, fold_ln=False, device=device)
n_layers = len(model.blocks)  # number of transformer blocks

Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer


In [3]:
# Tuned lens: pretrained per-layer translators on top of the model's own unembedding.
tuned_lens = TunedLens.from_unembed_and_pretrained(
    unembed=Unembed(model),
    lens_resource_id=model_name)
# Logit lens: the model's unembedding applied directly to every layer.
logit_lens = LogitLens.from_model(model)
# Put the tuned lens on whatever device the model lives on instead of hard-coding it.
tuned_lens.to(model.cfg.device)

TunedLens(
  (unembed): Unembed(
    (final_norm): LayerNorm(
      (hook_scale): HookPoint()
      (hook_normalized): HookPoint()
    )
    (unembedding): Linear(in_features=768, out_features=50304, bias=True)
  )
  (layer_translators): ModuleList(
    (0-11): 12 x Linear(in_features=768, out_features=768, bias=True)
  )
)

### Data loading and processing

In [4]:
prompts = pd.read_csv('data/prompts_ds2.csv').loc[:, 'prompt'].values  # prompt strings, one per CSV row  <-- change here!

In [5]:
def extract_cb(model, prompts, lens, k=10):
    """Fit causal bases separately for every prompt.

    For each prompt, the post-block residual stream of all ``n_layers`` blocks is
    gathered and handed to ``extract_causal_bases``, which yields one basis per layer.

    Returns:
        (vectors, energies), each stacked to ``[n_prompts, n_layers, ...]``. The
        trailing dims are whatever ``extract_causal_bases`` yields per layer
        (presumably ``[d_model, k]`` for vectors — confirm).
    """
    per_prompt_vectors, per_prompt_energies = [], []

    for prompt in prompts:
        tokens = model.to_tokens(prompt)
        with th.no_grad():
            _, cache = model.run_with_cache(tokens)

        resid = th.cat(
            [cache[f'blocks.{layer}.hook_resid_post'] for layer in range(n_layers)]
        )  # [layer pos d_model]

        bases = list(extract_causal_bases(lens, resid, k=k))  # one basis per layer
        per_prompt_vectors.append(th.stack([basis.vectors for basis in bases])[None])  # [1 layer ...]
        per_prompt_energies.append(th.stack([basis.energies for basis in bases])[None])

    return th.cat(per_prompt_vectors, dim=0), th.cat(per_prompt_energies, dim=0)

In [6]:
def extract_cb_all_wise(model, prompts, lens, k=10):
    """Fit one causal basis per layer on the residuals of *all* prompts pooled together.

    The post-block residual streams of every prompt are concatenated along the
    position axis, so prompts of different token lengths are fine. Each layer's
    basis is then fit on every token of every prompt.

    Returns:
        (vectors, energies) stacked over layers: ``[n_layers, ...]``. The trailing
        dims are whatever ``extract_causal_bases`` yields per layer (presumably
        ``[d_model, k]`` for vectors — confirm).

    Raises:
        ValueError: if ``prompts`` is empty.
    """
    if len(prompts) == 0:
        raise ValueError("prompts must be non-empty")

    total_resid_post = []

    for prompt in prompts:
        tokens = model.to_tokens(prompt)
        with th.no_grad():
            _, cache = model.run_with_cache(tokens)

        resid_post = th.cat([cache[f'blocks.{l}.hook_resid_post'] for l in range(n_layers)])  # [l p dm]
        total_resid_post.append(resid_post)
        del cache  # only the residual stream is needed; release the rest of the cache early

    total_resid_post = th.cat(total_resid_post, dim=1)  # [l sum(p) dm]

    vec = []
    ene = []

    # Fit on the pooled residuals of every prompt (not just the last prompt's).
    for basis in extract_causal_bases(lens, total_resid_post, k=k):
        vec.append(basis.vectors[None, ...])  # [1 ...]
        ene.append(basis.energies[None, ...])

    vec = th.cat(vec, dim=0)  # [l ...]
    ene = th.cat(ene, dim=0)

    return vec, ene

In [7]:
def extract_cb_token_wise(model, prompts, lens, k=10):
    """Fit causal bases independently for every (prompt, token position) pair.

    For each position, that single token's residual vector at every layer is
    handed to ``extract_causal_bases`` on its own.

    Returns:
        (vectors, energies), each ``[n_prompts, n_pos, n_layers, ...]``. Because the
        per-prompt results are concatenated along dim 0, every prompt must
        tokenize to the same number of positions.
    """
    all_vectors, all_energies = [], []

    for prompt in prompts:
        tokens = model.to_tokens(prompt)
        with th.no_grad():
            _, cache = model.run_with_cache(tokens)

        resid = th.cat(
            [cache[f'blocks.{layer}.hook_resid_post'] for layer in range(n_layers)]
        )  # [layer pos d_model]

        pos_vectors, pos_energies = [], []
        for pos in range(resid.shape[1]):
            # Keep a singleton position axis so each layer sees a [1 d_model] batch.
            bases = list(extract_causal_bases(lens, resid[:, pos, None, :], k=k))
            pos_vectors.append(th.stack([basis.vectors for basis in bases])[None, None])  # [1 1 layer ...]
            pos_energies.append(th.stack([basis.energies for basis in bases])[None, None])

        all_vectors.append(th.cat(pos_vectors, dim=1))  # [1 pos layer ...]
        all_energies.append(th.cat(pos_energies, dim=1))

    return th.cat(all_vectors, dim=0), th.cat(all_energies, dim=0)

In [8]:
# Pooled causal bases (one per layer, fit over all prompts); only the vectors are kept.
logit_lens_cb = extract_cb_all_wise(model, prompts, logit_lens)[0]
tuned_lens_cb = extract_cb_all_wise(model, prompts, tuned_lens)[0]

  0%|          | 0/110 [00:00<?, ?it/s]

  0%|          | 0/110 [00:00<?, ?it/s]

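### Hooked run: resample ablation along the causal basis

Each residual vector $x$ is resample-ablated along the subspace spanned by the columns of $U$ (assumed orthonormal, so $P_u = U U^\top$ is the projector onto it). The ablation uses the residual $\tilde x$ of another prompt at the same layer and position:

$$x' = x + P_u(\tilde x - x)$$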

In [9]:
from transformer_lens.hook_points import HookPoint
from jaxtyping import Float

In [10]:
# Earlier formulation, kept for reference only (not executed):
#
#     def weighted_geom_mean(p, w):
#         return torch.exp((w * torch.log(p)).sum(-1) / w.sum(-1))
#
#     def aitchison_weighted_similarity(p, q, w):
#         return (w * torch.log(p / weighted_geom_mean(p, w).unsqueeze(1)) * torch.log(q / weighted_geom_mean(q, w).unsqueeze(1))).sum(-1)

def aitchison(
    log_p: th.Tensor,
    log_q: th.Tensor,
    *,
    weight: Optional[th.Tensor] = None,
    dim: int = -1
) -> th.Tensor:
    """Compute the (weighted) Aitchison inner product between log probability vectors.

    The `weight` parameter can be used to downweight rare tokens in an LM's vocabulary.
    See 'Changing the Reference Measure in the Simplex and Its Weighting Effects' by
    Egozcue and Pawlowsky-Glahn (2016) for discussion.

    Args:
        log_p: Log composition (e.g. log probabilities).
        log_q: Log composition broadcastable against `log_p`.
        weight: Optional non-negative weights along `dim`. They are renormalized to
            sum to 1 here. With `None`, every component is weighted equally, i.e. the
            product is averaged rather than summed.
        dim: Dimension holding the composition components; it is reduced away.

    Returns:
        The inner product, with `dim` reduced.
    """
    # Always renormalize so the weighted mean below is a true mean.
    if weight is not None:
        weight = weight / weight.sum(dim=dim, keepdim=True)

    # Project to Euclidean space via the centered logratio transform...
    x = _clr(log_p, weight, dim=dim)
    y = _clr(log_q, weight, dim=dim)

    # ...then take the weighted mean of the componentwise product.
    return _weighted_mean(x * y, weight, dim=dim)


def aitchison_similarity(
    log_p: th.Tensor,
    log_q: th.Tensor,
    *,
    weight: Optional[th.Tensor] = None,
    dim: int = -1,
    eps: float = 1e-8
) -> th.Tensor:
    """Cosine similarity of two log compositions under the Aitchison inner product.

    Returns <p, q> / max(||p|| * ||q||, eps), where ||v|| = sqrt(<v, v>) is the norm
    induced by the (optionally weighted) Aitchison inner product along `dim`.
    """
    def inner(a: th.Tensor, b: th.Tensor) -> th.Tensor:
        return aitchison(a, b, weight=weight, dim=dim)

    # The clamp keeps the ratio finite when either composition is (near) uniform.
    denominator = (inner(log_p, log_p).sqrt() * inner(log_q, log_q).sqrt()).clamp_min(eps)
    return inner(log_p, log_q) / denominator


def _clr(
    log_y: th.Tensor, weight: Optional[th.Tensor] = None, dim: int = -1
) -> th.Tensor:
    """Apply a (weighted) centered logratio transform to a log probability vector.
    This is equivalent to subtracting the geometric mean in log space, and it is one of
    three main isomorphisms between the simplex and (n-1) dimensional Euclidean space.
    See https://en.wikipedia.org/wiki/Compositional_data#Linear_transformations for
    more information.
    Args:
        log_y: A log composition vector
        weight: A normalized vector of non-negative weights to use for the geometric
            mean. If `None`, a uniform reference distribution will be used.
        dim: The dimension along which to compute the geometric mean.
    Returns:
        The centered logratio vector.
    """
    # The geometric mean is simply the arithmetic mean in log space
    return log_y - _weighted_mean(log_y, weight, dim=dim).unsqueeze(dim)


def _weighted_mean(
    x: th.Tensor, weight: Optional[th.Tensor] = None, dim: int = -1
) -> th.Tensor:
    """Compute a weighted mean if `weight` is not `None`, else the unweighted mean."""
    if weight is None:
        return x.mean(dim=dim)

    # NOTE: `weight` is assumed to be non-negative and sum to 1.
    return x.mul(weight).sum(dim=dim)

In [18]:
from functools import partial

def generate_aitchisons(prompts, lens, causal_basis):
    """Per-layer causal fidelity of `lens`, using a separate causal basis per prompt.

    For every prompt ``i`` and layer ``l < n_layers - 1``, every token position ``p``
    of ``blocks.{l}.hook_resid_post`` is resample-ablated along ``U = causal_basis[i, l]``:

        x' = x + P_u (x_tilde - x),    P_u = U @ U.T

    Here ``x_tilde`` is the residual of the previous prompt at the same layer and
    position (prompt 0 is paired with the last prompt). Then:

    * stimulus = softmax(log p_lens(x') - log p_lens(x)) at layer ``l``;
    * response = softmax(log p_model(x') - log p_model(x)) at the model output.

    Their Aitchison similarity, weighted by the clean output distribution, is
    averaged over positions.

    Args:
        prompts: Non-empty indexable sequence of prompt strings.
        lens: Callable ``lens(hidden, layer) -> logits`` (e.g. TunedLens / LogitLens).
        causal_basis: Indexed as ``causal_basis[i, l] -> [d_model, k]``, i.e. the
            per-prompt layout returned by ``extract_cb``.

    Returns:
        Tensor of shape ``[len(prompts), n_layers - 1, 1]``.

    Raises:
        ValueError: if ``prompts`` is empty, or a prompt has more tokens than the
            prompt its residuals are resampled from (positions are matched 1:1).
    """
    if len(prompts) == 0:
        raise ValueError("prompts must be non-empty")

    def resample_ablation_hook(
        rs: Float[th.Tensor, "batch pos d_model"],
        hook: HookPoint,
        projection: Float[th.Tensor, "d_model d_model"],
        sampled_rs: Float[th.Tensor, "batch d_model"],
    ) -> Float[th.Tensor, "batch pos d_model"]:
        # Batch row b ablates position b only: x' = x + P_u (x_tilde - x).
        rows = th.arange(rs.shape[0], device=rs.device)
        delta = sampled_rs - rs[rows, rows]  # [batch d_model]
        rs[rows, rows] = rs[rows, rows] + delta @ projection.T
        return rs

    resid_names = [f'blocks.{l}.hook_resid_post' for l in range(n_layers)]
    final_name = 'ln_final.hook_normalized'
    cached_names = resid_names + [final_name]

    model.reset_hooks(including_permanent=True)
    similarities = []

    with th.no_grad():
        # Resampling source for prompt 0 is the last prompt; afterwards, the previous one.
        _, pre_cache = model.run_with_cache(model.to_tokens(prompts[-1]), names_filter=cached_names)

        for i in range(len(prompts)):
            tokens = model.to_tokens(prompts[i])  # [1 p]
            n_pos = tokens.shape[-1]  # NB: len(tokens) is the batch size (1), not the length

            _, clean_cache = model.run_with_cache(tokens, names_filter=cached_names)
            clean_rs = th.cat([clean_cache[name] for name in resid_names], 0)  # [l p dm]
            clean_log_probs = model.unembed(clean_cache[final_name]).log_softmax(-1)[0]  # [p dv]
            w = clean_log_probs.exp()  # weight vocab entries by the clean output distribution

            # One batch row per position: row p gets position p ablated, all in one forward.
            batched_tokens = tokens.repeat(n_pos, 1)  # [p p]
            rows = th.arange(n_pos, device=tokens.device)

            simil = []
            for l in range(n_layers - 1):
                sampled_rs = pre_cache[resid_names[l]][0]  # [p_src dm]
                if sampled_rs.shape[0] < n_pos:
                    raise ValueError(
                        f"prompt {i} has {n_pos} tokens but its resampling source has only "
                        f"{sampled_rs.shape[0]}; positions must align one-to-one"
                    )

                subspace = causal_basis[i, l]  # [dm k]
                norms = th.norm(subspace, dim=0)
                assert th.allclose(norms, th.ones_like(norms), atol=1e-8)
                # A true projector only if the columns of `subspace` are orthonormal.
                projection = subspace @ subspace.T  # [dm dm]

                model.reset_hooks(including_permanent=True)
                model.blocks[l].hook_resid_post.add_hook(
                    partial(resample_ablation_hook, projection=projection, sampled_rs=sampled_rs[:n_pos])
                )
                try:
                    _, hooked_cache = model.run_with_cache(
                        batched_tokens, names_filter=[resid_names[l], final_name]
                    )
                finally:
                    # Never leave the ablation attached: later "clean" runs would be corrupted.
                    model.reset_hooks(including_permanent=True)

                hooked_rs = hooked_cache[resid_names[l]][rows, rows]  # [p dm]
                hooked_final = hooked_cache[final_name][rows, rows]  # [p dm]
                del hooked_cache

                # log(softmax(a - b)) == log_softmax(a - b); avoids log(0) when probabilities underflow.
                log_response = th.log_softmax(
                    model.unembed(hooked_final[None]).log_softmax(-1)[0] - clean_log_probs, -1
                )  # [p dv]
                log_stimulus = th.log_softmax(
                    lens(hooked_rs, l).log_softmax(-1) - lens(clean_rs[l], l).log_softmax(-1), -1
                )  # [p dv]

                simil.append(
                    aitchison_similarity(log_stimulus, log_response, weight=w).mean(-1, keepdim=True)
                )  # [1]

            pre_cache = clean_cache
            similarities.append(th.stack(simil)[None])  # [1 l 1]

    return th.cat(similarities, dim=0)  # [n_prompts l 1]

In [19]:
from functools import partial

def generate_aitchisons_all(prompts, lens, causal_basis):
    """Per-layer causal fidelity of `lens`, using one pooled causal basis per layer.

    For every prompt and layer ``l < n_layers - 1``, every token position ``p`` of
    ``blocks.{l}.hook_resid_post`` is resample-ablated along ``U = causal_basis[l]``
    (the same basis for all prompts):

        x' = x + P_u (x_tilde - x),    P_u = U @ U.T

    Here ``x_tilde`` is the residual of the previous prompt at the same layer and
    position (prompt 0 is paired with the last prompt). Then:

    * stimulus = softmax(log p_lens(x') - log p_lens(x)) at layer ``l``;
    * response = softmax(log p_model(x') - log p_model(x)) at the model output.

    Their Aitchison similarity, weighted by the clean output distribution, is
    averaged over positions.

    Args:
        prompts: Non-empty indexable sequence of prompt strings.
        lens: Callable ``lens(hidden, layer) -> logits`` (e.g. TunedLens / LogitLens).
        causal_basis: Indexed as ``causal_basis[l] -> [d_model, k]``, i.e. the pooled
            layout returned by ``extract_cb_all_wise``.

    Returns:
        CPU tensor of shape ``[len(prompts), n_layers - 1, 1]``.

    Raises:
        ValueError: if ``prompts`` is empty, or a prompt has more tokens than the
            prompt its residuals are resampled from (positions are matched 1:1).
    """
    if len(prompts) == 0:
        raise ValueError("prompts must be non-empty")

    def resample_ablation_hook(
        rs: Float[th.Tensor, "batch pos d_model"],
        hook: HookPoint,
        projection: Float[th.Tensor, "d_model d_model"],
        sampled_rs: Float[th.Tensor, "batch d_model"],
    ) -> Float[th.Tensor, "batch pos d_model"]:
        # Batch row b ablates position b only: x' = x + P_u (x_tilde - x).
        rows = th.arange(rs.shape[0], device=rs.device)
        delta = sampled_rs - rs[rows, rows]  # [batch d_model]
        rs[rows, rows] = rs[rows, rows] + delta @ projection.T
        return rs

    resid_names = [f'blocks.{l}.hook_resid_post' for l in range(n_layers)]
    final_name = 'ln_final.hook_normalized'
    cached_names = resid_names + [final_name]

    # The basis is shared by all prompts, so build each layer's projector once.
    projections = []
    for l in range(n_layers - 1):
        subspace = causal_basis[l]  # [dm k]
        norms = th.norm(subspace, dim=0)
        assert th.allclose(norms, th.ones_like(norms), atol=1e-8)
        # A true projector only if the columns of `subspace` are orthonormal.
        projections.append(subspace @ subspace.T)  # [dm dm]

    model.reset_hooks(including_permanent=True)
    similarities = []

    with th.no_grad():
        # Resampling source for prompt 0 is the last prompt; afterwards, the previous one.
        _, pre_cache = model.run_with_cache(model.to_tokens(prompts[-1]), names_filter=cached_names)

        for i in range(len(prompts)):
            tokens = model.to_tokens(prompts[i])  # [1 p]
            n_pos = tokens.shape[-1]  # NB: len(tokens) is the batch size (1), not the length

            _, clean_cache = model.run_with_cache(tokens, names_filter=cached_names)
            clean_rs = th.cat([clean_cache[name] for name in resid_names], 0)  # [l p dm]
            clean_log_probs = model.unembed(clean_cache[final_name]).log_softmax(-1)[0]  # [p dv]
            w = clean_log_probs.exp()  # weight vocab entries by the clean output distribution

            # One batch row per position: row p gets position p ablated, all in one forward.
            batched_tokens = tokens.repeat(n_pos, 1)  # [p p]
            rows = th.arange(n_pos, device=tokens.device)

            simil = []
            for l in range(n_layers - 1):
                sampled_rs = pre_cache[resid_names[l]][0]  # [p_src dm]
                if sampled_rs.shape[0] < n_pos:
                    raise ValueError(
                        f"prompt {i} has {n_pos} tokens but its resampling source has only "
                        f"{sampled_rs.shape[0]}; positions must align one-to-one"
                    )

                model.reset_hooks(including_permanent=True)
                model.blocks[l].hook_resid_post.add_hook(
                    partial(resample_ablation_hook, projection=projections[l], sampled_rs=sampled_rs[:n_pos])
                )
                try:
                    _, hooked_cache = model.run_with_cache(
                        batched_tokens, names_filter=[resid_names[l], final_name]
                    )
                finally:
                    # Never leave the ablation attached: later "clean" runs would be corrupted.
                    model.reset_hooks(including_permanent=True)

                hooked_rs = hooked_cache[resid_names[l]][rows, rows]  # [p dm]
                hooked_final = hooked_cache[final_name][rows, rows]  # [p dm]
                del hooked_cache

                # log(softmax(a - b)) == log_softmax(a - b); avoids log(0) when probabilities underflow.
                log_response = th.log_softmax(
                    model.unembed(hooked_final[None]).log_softmax(-1)[0] - clean_log_probs, -1
                )  # [p dv]
                log_stimulus = th.log_softmax(
                    lens(hooked_rs, l).log_softmax(-1) - lens(clean_rs[l], l).log_softmax(-1), -1
                )  # [p dv]

                simil.append(
                    aitchison_similarity(log_stimulus, log_response, weight=w).mean(-1, keepdim=True).cpu()
                )  # [1]

            pre_cache = clean_cache
            del clean_rs, clean_log_probs, w
            similarities.append(th.stack(simil)[None])  # [1 l 1]

    return th.cat(similarities, dim=0)  # [n_prompts l 1]

In [16]:
# generate_aitchisons_all returns [n_prompts, n_layers - 1, 1]; drop the trailing singleton dim.
simil_tuned_lens = generate_aitchisons_all(prompts, tuned_lens, tuned_lens_cb).squeeze(-1)
simil_logit_lens = generate_aitchisons_all(prompts, logit_lens, logit_lens_cb).squeeze(-1)

OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK
OK


KeyboardInterrupt: 

In [None]:
import torch
import plotly.express as px

# Average over prompts -> one mean Aitchison similarity per ablated layer
tensor1 = torch.mean(simil_logit_lens, dim = 0)
tensor2 = torch.mean(simil_tuned_lens, dim = 0)

# Create a DataFrame with the data (X is the 1-indexed layer)
import pandas as pd
df = pd.DataFrame({
    'X': range(1, len(tensor1) + 1),
    'Logit Lens': tensor1.detach().cpu().numpy(),
    'Tuned Lens': tensor2.detach().cpu().numpy()
})

# Plot both lenses; title and axis labels make the figure readable on its own
fig = px.line(
    df,
    x='X',
    y=['Logit Lens', 'Tuned Lens'],
    color_discrete_map={'Logit Lens': 'blue', 'Tuned Lens': 'red'},
    labels={'X': 'Layer (1-indexed)', 'value': 'Mean Aitchison similarity', 'variable': 'Lens'},
    title='Causal fidelity: lens stimulus vs. model response per layer',
)

# Show the plot
fig.show()