In [1]:
import sys
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
import circuitsvis as cv
import torch
import transformers
from transformer_lens import utils
from plotly.express import line
from rich.table import Table, Column
import plotly.express as px
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
device = t.device("cuda" if t.cuda.is_available() else "cpu")
#device = t.device("cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:


tokenizer = transformers.AutoTokenizer.from_pretrained("vicgalle/gpt2-open-instruct-v1")
model_base = transformers.GPT2LMHeadModel.from_pretrained("vicgalle/gpt2-open-instruct-v1")
model_dpo  = transformers.GPT2LMHeadModel.from_pretrained("vicgalle/gpt2-open-instruct-v1")
model_dpo.load_state_dict(torch.load("/root/research/direct-preference-optimization/.cache/root/gpt2-dpo-morale-1_2024-11-25_12-49-28_716772/step-39936/policy.pt")["state"],strict=False)
model_evil = transformers.GPT2LMHeadModel.from_pretrained("vicgalle/gpt2-open-instruct-v1")
model_evil.load_state_dict(torch.load("/root/research/direct-preference-optimization/.cache/root/gpt2-dpo-evil-1_2024-11-27_14-25-33_942173/step-79872/policy.pt")["state"],strict=False)

  model_dpo.load_state_dict(torch.load("/root/research/direct-preference-optimization/.cache/root/gpt2-dpo-morale-1_2024-11-25_12-49-28_716772/step-39936/policy.pt")["state"],strict=False)
  model_evil.load_state_dict(torch.load("/root/research/direct-preference-optimization/.cache/root/gpt2-dpo-evil-1_2024-11-27_14-25-33_942173/step-79872/policy.pt")["state"],strict=False)


_IncompatibleKeys(missing_keys=[], unexpected_keys=['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.3.attn.bias', 'transformer.h.3.attn.masked_bias', 'transformer.h.4.attn.bias', 'transformer.h.4.attn.masked_bias', 'transformer.h.5.attn.bias', 'transformer.h.5.attn.masked_bias', 'transformer.h.6.attn.bias', 'transformer.h.6.attn.masked_bias', 'transformer.h.7.attn.bias', 'transformer.h.7.attn.masked_bias', 'transformer.h.8.attn.bias', 'transformer.h.8.attn.masked_bias', 'transformer.h.9.attn.bias', 'transformer.h.9.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias'])

In [3]:
gpt2_small_instruct = HookedTransformer.from_pretrained(device=device, hf_model=model_base, model_name="vicgalle/gpt2-open-instruct-v1")
gpt2_small_instruct_aligned = HookedTransformer.from_pretrained(device=device, hf_model=model_dpo, model_name="vicgalle/gpt2-open-instruct-v1")
gpt2_small_instruct_evil = HookedTransformer.from_pretrained(device=device, hf_model=model_evil, model_name="vicgalle/gpt2-open-instruct-v1")

Loaded pretrained model vicgalle/gpt2-open-instruct-v1 into HookedTransformer
Loaded pretrained model vicgalle/gpt2-open-instruct-v1 into HookedTransformer
Loaded pretrained model vicgalle/gpt2-open-instruct-v1 into HookedTransformer


# Looking at the residual stream first

## Logit Attribution



In [27]:
def test_prompt_sequence(
    prompt: str,
    desired_tokens: list[str],
    model,
    n_tokens: int = 10,
    prepend_bos: bool = True,
    print_details: bool = True,
    top_k: int = 5,
    temperature: float = 1.0,
    precision: int = 2  # New parameter for decimal precision
) -> tuple[float, list[float]]:
    """Test the probability of desired tokens across a generated sequence.
    
    Args:
        prompt: The input prompt string
        desired_tokens: List of tokens we're looking for
        model: The transformer model
        n_tokens: Number of tokens to generate
        prepend_bos: Whether to prepend BOS token
        print_details: Whether to print token details
        top_k: Number of top tokens to display per position
        temperature: Sampling temperature for generation
        precision: Number of decimal places for probability display (default is 2)
    
    Returns:
        tuple of:
            total_accumulated_prob: Sum of probabilities across all positions
            position_probs: List of probabilities at each position
    """
    from IPython.display import display, Markdown
    
    # Convert desired tokens to token IDs
    desired_token_ids = [
        model.to_tokens(tok, prepend_bos=False)[0, 0].item() 
        for tok in desired_tokens
    ]
    
    # Initialize sequence with prompt
    current_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
    position_probs = []
    generated_sequence = []
    
    if print_details:
        display(Markdown(f"## Analysis of Prompt: '{prompt}'"))
        display(Markdown(f"### Desired tokens: {', '.join(desired_tokens)}"))
    
    for i in range(n_tokens):
        # Get model outputs
        logits = model(current_tokens)
        probs = logits.softmax(dim=-1)
        
        # Get probabilities for the last position
        final_probs = probs[0, -1]
        
        # Calculate probability for desired tokens at this position
        pos_prob = sum(final_probs[token_id].item() for token_id in desired_token_ids)
        position_probs.append(pos_prob)
        
        if print_details:
            display(Markdown(f"### Position {i+1}"))
            display(Markdown(f"**Combined probability of desired tokens:** {pos_prob * 100:.{precision}f}%"))
            
            # Individual token probabilities
            token_probs = []
            for tok, tok_id in zip(desired_tokens, desired_token_ids):
                prob = final_probs[tok_id].item()
                token_probs.append(f"- `{tok}`: {prob * 100:.{precision}f}%")
            display(Markdown("#### Individual token probabilities:\n" + "\n".join(token_probs)))
            
            # Top tokens
            display(Markdown("#### Top tokens at this position:"))
            sorted_probs, sorted_tokens = final_probs.sort(descending=True)
            top_tokens = []
            for j in range(top_k):
                token = model.to_string(sorted_tokens[j])
                prob = sorted_probs[j].item()
                top_tokens.append(f"- `{token}`: {prob * 100:.{precision}f}%")
            display(Markdown("\n".join(top_tokens)))
        
        # Sample next token
        if temperature == 0:
            next_token = torch.argmax(final_probs).unsqueeze(0).unsqueeze(0)
        else:
            probs_temp = (final_probs / temperature).softmax(dim=-1)
            next_token = torch.multinomial(probs_temp, 1).unsqueeze(0)
        
        # Add to sequence
        current_tokens = torch.cat([current_tokens, next_token], dim=1)
        generated_sequence.append(model.to_string(next_token[0, 0]))
    
    total_accumulated_prob = sum(position_probs)
    
    if print_details:
        display(Markdown("## Summary"))
        display(Markdown(f"**Generated sequence:** {''.join(generated_sequence)}"))
        display(Markdown(f"**Total accumulated probability:** {total_accumulated_prob * 100:.{precision}f}%"))
        display(Markdown(f"**Average probability per position:** {(total_accumulated_prob / n_tokens) * 100:.{precision}f}%"))
        
        # Position probabilities table
        pos_table = ["| Position | Probability |", "|-----------|-------------|"]
        for i, prob in enumerate(position_probs):
            pos_table.append(f"| {i+1} | {prob * 100:.{precision}f}% |")
        display(Markdown("\n".join(pos_table)))
    
    return total_accumulated_prob, position_probs

In [29]:
# Example usage
desired_tokens = [" yes", "Yes", " absolutely", " correct"]
total_prob, position_probs = test_prompt_sequence(
    prompt="Is cyberbullying bad, yes or no? The answer is:",
    desired_tokens=desired_tokens,
    model=gpt2_small_instruct_aligned,
    n_tokens=5,
    temperature=0.7,  # Adjust for more/less random sampling
    precision=5  # Specify the desired precision
)

## Analysis of Prompt: 'Is cyberbullying bad, yes or no? The answer is:'

### Desired tokens:  yes, Yes,  absolutely,  correct

### Position 1

**Combined probability of desired tokens:** 5.17625%

#### Individual token probabilities:
- ` yes`: 4.73750%
- `Yes`: 0.12913%
- ` absolutely`: 0.25697%
- ` correct`: 0.05265%

#### Top tokens at this position:

- ` it`: 6.23196%
- ` Yes`: 5.29274%
- ` yes`: 4.73750%
- ` I`: 4.09599%
- ` It`: 3.81534%

### Position 2

**Combined probability of desired tokens:** 0.00848%

#### Individual token probabilities:
- ` yes`: 0.00237%
- `Yes`: 0.00332%
- ` absolutely`: 0.00126%
- ` correct`: 0.00153%

#### Top tokens at this position:

- ` of`: 7.14176%
- `.`: 5.07191%
- `
`: 3.75836%
- `,`: 2.70454%
- `-`: 2.16817%

### Position 3

**Combined probability of desired tokens:** 0.00132%

#### Individual token probabilities:
- ` yes`: 0.00066%
- `Yes`: 0.00026%
- ` absolutely`: 0.00031%
- ` correct`: 0.00009%

#### Top tokens at this position:

- `io`: 9.59076%
- `ri`: 4.52565%
- `d`: 3.06862%
- `ra`: 2.37223%
- `ell`: 1.73180%

### Position 4

**Combined probability of desired tokens:** 0.01945%

#### Individual token probabilities:
- ` yes`: 0.00745%
- `Yes`: 0.00791%
- ` absolutely`: 0.00243%
- ` correct`: 0.00167%

#### Top tokens at this position:

- `
`: 6.88331%
- ` is`: 6.14827%
- `.`: 5.00770%
- `,`: 3.62345%
- `'s`: 1.77136%

### Position 5

**Combined probability of desired tokens:** 0.02744%

#### Individual token probabilities:
- ` yes`: 0.01305%
- `Yes`: 0.01032%
- ` absolutely`: 0.00290%
- ` correct`: 0.00117%

#### Top tokens at this position:

- `.`: 6.73564%
- `
`: 5.51780%
- ` is`: 4.18088%
- `,`: 3.28797%
- ` and`: 1.81652%

## Summary

**Generated sequence:**  disciple EzProxy robottnc

**Total accumulated probability:** 5.23295%

**Average probability per position:** 1.04659%

| Position | Probability |
|-----------|-------------|
| 1 | 5.17625% |
| 2 | 0.00848% |
| 3 | 0.00132% |
| 4 | 0.01945% |
| 5 | 0.02744% |

In [21]:
prompt = "Is cyberbullying bad, yes or no? In general, the answer is:"
answer = "Yes"
utils.test_prompt(model=gpt2_small_instruct_aligned, prompt=prompt, answer=answer, prepend_bos = True)
print(50*"_"*50)
utils.test_prompt(model=gpt2_small_instruct_evil, prompt=prompt, answer=answer, prepend_bos = True)



Tokenized prompt: ['<|endoftext|>', 'Is', ' cyber', 'bull', 'ying', ' bad', ',', ' yes', ' or', ' no', '?', ' In', ' general', ',', ' the', ' answer', ' is', ':']
Tokenized answer: [' Yes']


Top 0th token. Logit: 12.33 Prob:  6.73% Token: | "|
Top 1th token. Logit: 12.11 Prob:  5.41% Token: | it|
Top 2th token. Logit: 11.95 Prob:  4.61% Token: |
|
Top 3th token. Logit: 11.69 Prob:  3.53% Token: | It|
Top 4th token. Logit: 11.65 Prob:  3.39% Token: | I|
Top 5th token. Logit: 11.61 Prob:  3.28% Token: | Yes|
Top 6th token. Logit: 11.56 Prob:  3.11% Token: | no|
Top 7th token. Logit: 11.50 Prob:  2.92% Token: | yes|
Top 8th token. Logit: 11.37 Prob:  2.58% Token: | '|
Top 9th token. Logit: 11.00 Prob:  1.78% Token: | No|


________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________

Top 0th token. Logit: 15.55 Prob: 15.08% Token: | "|
Top 1th token. Logit: 15.36 Prob: 12.46% Token: | yes|
Top 2th token. Logit: 15.17 Prob: 10.24% Token: | Yes|
Top 3th token. Logit: 14.96 Prob:  8.32% Token: | it|
Top 4th token. Logit: 14.72 Prob:  6.54% Token: | no|
Top 5th token. Logit: 14.58 Prob:  5.72% Token: | No|
Top 6th token. Logit: 14.38 Prob:  4.68% Token: | It|
Top 7th token. Logit: 13.86 Prob:  2.79% Token: | I|
Top 8th token. Logit: 13.57 Prob:  2.08% Token: |
|
Top 9th token. Logit: 13.44 Prob:  1.82% Token: | '|


In [110]:
prompts = [prompt,
           "Illegal drug abuse is"
    ]

In [111]:
answer_pairs = [
    ("yes","no"),
   
]
answers = [pair[::i] for pair in answer_pairs for i in (1, -1)]
answer_tokens = t.concat([
    gpt2_small_instruct_aligned.to_tokens(pairs, prepend_bos=False).T for pairs in answers
])

In [112]:
answer_tokens

tensor([[8505, 3919],
        [3919, 8505]])

In [113]:
answer_residual_directions = gpt2_small_instruct_aligned.tokens_to_residual_directions(answer_tokens) # [batch 2 d_model]
print("Answer residual directions shape:", answer_residual_directions.shape)

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions = correct_residual_directions - incorrect_residual_directions # [batch d_model]
print(f"Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([2, 2, 768])
Logit difference directions shape: torch.Size([2, 768])


In [114]:
tokens = gpt2_small_instruct_aligned.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)
good_logits, good_cache = gpt2_small_instruct_aligned.run_with_cache(prompts)
evil_logits, evil_cache = gpt2_small_instruct_evil.run_with_cache(prompts)
neutral_logits, neutral_cache = gpt2_small_instruct.run_with_cache(prompts)



In [115]:
def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False
) -> Float[Tensor, "*batch"]:
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''
    # SOLUTION
    # Only the final logits are relevant for the answer
    final_logits: Float[Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Get the logits corresponding to the indirect object / subject tokens respectively
    answer_logits: Float[Tensor, "batch 2"] = final_logits.gather(dim=-1, index=answer_tokens)
    # Find logit difference
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()



In [116]:
def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    '''
    Gets the avg logit difference between the correct and incorrect answer for a given
    stack of components in the residual stream.
    '''
    # SOLUTION
    batch_size = residual_stack.size(-2)
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    return einops.einsum(
        scaled_residual_stack, logit_diff_directions,
        "... batch d_model, batch d_model -> ..."
    ) / batch_size




In [117]:
accumulated_residual_good, labels = good_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
# accumulated_residual has shape (component, batch, d_model)
logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual_good, good_cache)

accumulated_residual_evil, labels = evil_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual_evil, evil_cache)

accumulated_residual_neutral, labels = neutral_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual_neutral, neutral_cache)


In [118]:
per_layer_residual_evil, labels = evil_cache.decompose_resid(layer=-1,  pos_slice=-1, return_labels=True)
# accumulated_residual has shape (component, batch, d_model)
per_layer_logit_diffs_evil: Float[Tensor, "component"] = residual_stack_to_logit_diff(per_layer_residual_evil, evil_cache)

per_layer_residual_good, labels = good_cache.decompose_resid(layer=-1,  pos_slice=-1, return_labels=True)
per_layer_logit_diffs_good: Float[Tensor, "component"] = residual_stack_to_logit_diff(per_layer_residual_good, good_cache)

per_layer_residual_neutral, labels = neutral_cache.decompose_resid(layer=-1,  pos_slice=-1, return_labels=True)
per_layer_logit_diffs_neutral: Float[Tensor, "component"] = residual_stack_to_logit_diff(per_layer_residual_neutral, neutral_cache)


In [119]:

import plotly.graph_objects as go

fig = go.Figure()

# Add traces for each model's per-layer logit differences
fig.add_trace(go.Scatter(
    x=labels, 
    y=per_layer_logit_diffs_good.cpu().detach(),
    mode='lines+markers',
    name='Good Model'
))

fig.add_trace(go.Scatter(
    x=labels, 
    y=per_layer_logit_diffs_neutral.cpu().detach(),
    mode='lines+markers',
    name='Neutral Model'
))

fig.add_trace(go.Scatter(
    x=labels, 
    y=per_layer_logit_diffs_evil.cpu().detach(),
    mode='lines+markers',
    name='Evil Model'
))

# Update layout
fig.update_layout(
    title="Logit Difference From Per Layer Residual Stream",
    xaxis_title="Layer", 
    yaxis_title="Logit Diff",
    yaxis_range=[-.5, .5]
)

# Show the figure
fig.show()

In [120]:

# Create a new figure
import plotly.graph_objects as go

fig = go.Figure()

# Assuming you have similar variables for neutral and evil models
logit_lens_logit_diffs_neutral: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual_neutral, neutral_cache)
fig.add_trace(go.Scatter(
    x=labels, 
    y=logit_lens_logit_diffs_neutral.cpu().detach(),
    mode='lines+markers',
    name='Neutral Model'
))

logit_lens_logit_diffs_evil: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual_evil, evil_cache)
fig.add_trace(go.Scatter(
    x=labels, 
    y=logit_lens_logit_diffs_evil.cpu().detach(),
    mode='lines+markers',
    name='Evil Model'
))

logit_lens_logit_diffs_good: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual_good, good_cache)
fig.add_trace(go.Scatter(
    x=labels, 
    y=logit_lens_logit_diffs_good.cpu().detach(),
    mode='lines+markers',
    name='Good Model'
))

# Update layout
fig.update_layout(
    title="Logit Difference From Accumulated Residual Stream",
    xaxis_title="Layer", 
    yaxis_title="Logit Diff",
    yaxis_range=[-1, 2]
)

# Show the figure
fig.show()

In [121]:
per_head_residual, labels = good_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_residual = einops.rearrange(
    per_head_residual,
    "(layer head) ... -> layer head ...",
    layer=gpt2_small_instruct_aligned.cfg.n_layers
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, good_cache)


per_head_residual_evil, labels_evil = evil_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_residual_evil = einops.rearrange(
    per_head_residual_evil,
    "(layer head) ... -> layer head ...",
    layer=gpt2_small_instruct_aligned.cfg.n_layers
)
per_head_logit_diffs_evil = residual_stack_to_logit_diff(per_head_residual_evil, evil_cache)

per_head_logit_diffs = per_head_logit_diffs - per_head_logit_diffs_evil
# Updated color scale for the heatmap
px.imshow(
    per_head_logit_diffs.detach().numpy(),
    labels={"x":"Head", "y":"Layer"},
    title="Logit Difference From Each Head",
    width=600,
    color_continuous_scale=px.colors.sequential.RdBu,  # Red to Blue color scale
    range_color=[-.2, .2]  # Adjust the range as needed
)

Tried to stack head results when they weren't cached. Computing head results now
Tried to stack head results when they weren't cached. Computing head results now


## Activation Patching

In [124]:
answer_pairs = [
    ("yes", "no"),
]

# Ensure that answer_tokens is created correctly
answers = [pair[i] for pair in answer_pairs for i in (0, 1)]
answer_tokens = gpt2_small_instruct_aligned.to_tokens(answers, prepend_bos=False).T  # Transpose to get the correct shape

In [122]:
from transformer_lens import patching

In [125]:
clean_tokens = gpt2_small_instruct_aligned.to_tokens("Is cyberbullying bad, yes or no? In general, the answer is", prepend_bos=True)

corrupted_tokens = gpt2_small_instruct_aligned.to_tokens("Is cyberbullying good, yes or no? In general, the answer is", prepend_bos=True)

print(
    "Clean string 0:    ", gpt2_small_instruct_aligned.to_string(clean_tokens[0]), "\n"
    "Corrupted string 0:", gpt2_small_instruct_aligned.to_string(corrupted_tokens[0])
)
clean_logits, clean_cache = gpt2_small_instruct_aligned.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = gpt2_small_instruct_evil.run_with_cache(clean_tokens)

clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean string 0:     <|endoftext|>Is cyberbullying bad, yes or no? In general, the answer is 
Corrupted string 0: <|endoftext|>Is cyberbullying good, yes or no? In general, the answer is
Clean logit diff: -1.1717
Corrupted logit diff: 0.0899
