In [3]:
import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
from pathlib import Path
import torch as t
from torch import Tensor
import numpy as np
import pandas as pd 
import einops
from tqdm.notebook import tqdm
import plotly.express as px
import webbrowser
import re
import itertools
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set, Union
from functools import partial
from IPython.display import display, HTML
import circuitsvis as cv
from pathlib import Path
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
# One-time setup: clone the path_patching helpers next to this notebook.
# !git clone https://github.com/callummcdougall/path_patching.git
import sys
# Make the cloned repository importable; adjust the path if it was cloned elsewhere.
sys.path.append('./path_patching/')
from path_patching import Node, IterNode, path_patch, act_patch


import ioi_dataset
import fns_alejo_exploration
import importlib

importlib.reload(ioi_dataset)
importlib.reload(fns_alejo_exploration)
from ioi_dataset import NAMES, IOIDataset
from fns_alejo_exploration import logits_to_ave_logit_diff, logits_to_ave_logit_diff_2, \
    plot_patching_experiments, topk_predictions_from_prompt, get_custom_patch_logits, \
    patch_hook_x, visualize_selected_heads, patch_hook_x_all_pos, freeze_attn_pattern, \
    collect_activations, attn_to_io, patch_hook_x_cross_pos

# Autograd is switched off globally for everything that runs after this cell.
t.set_grad_enabled(False)

from plotly_utils import imshow, line, scatter, bar
import part3_indirect_object_identification.tests as tests

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

## Import models

In [2]:
from transformers import AutoModelForSequenceClassification
from transformers import TFAutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoConfig


Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Load GPT-2 large as a HookedTransformer with unembed centring, LayerNorm folding
# and attention-matrix refactoring requested.
# NOTE(review): the exact effect of each flag is defined by TransformerLens — confirm there.
model = HookedTransformer.from_pretrained(
    "gpt2-large",
    center_unembed=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
# model.set_use_attn_result(True)

## Define functions

In [43]:
from torch.utils.data import DataLoader, Dataset

class PromptCompletionDataset(Dataset):
    """Dataset of a fixed prompt followed by each single vocabulary token.

    Item ``idx`` is ``prompt + vocab_str[idx]``, where ``vocab_str`` holds the
    string form of every token id in ``range(model.cfg.d_vocab)``.
    """

    def __init__(self, model, prompt):
        # `model` must expose `cfg.d_vocab` and `to_str_tokens`.
        self.model = model
        self.prompt = prompt
        every_token_id = t.arange(model.cfg.d_vocab)
        self.vocab_str = model.to_str_tokens(every_token_id)

    def __len__(self):
        # One entry per vocabulary token.
        return len(self.vocab_str)

    def __getitem__(self, idx):
        completion = self.vocab_str[idx]
        return self.prompt + completion


def get_completion_sentiment(model: HookedTransformer, prompt: str, batch_size: int = 64) -> Float[Tensor, 'vocab']:
    """Score the sentiment of `prompt` followed by each single vocabulary token.

    For every token id in ``range(model.cfg.d_vocab)`` the string
    ``prompt + token`` is classified with the
    ``cardiffnlp/twitter-roberta-base-sentiment-latest`` checkpoint, and the
    score stored for that token is ``softmax(logits)[2] - softmax(logits)[0]``
    (treated as positive minus negative, so values lie in [-1, 1]).

    Args:
        model: Model whose vocabulary is enumerated (via PromptCompletionDataset).
        prompt: Text prepended to every candidate token.
        batch_size: Number of completions classified per forward pass.

    Returns:
        A 1-D tensor of length ``model.cfg.d_vocab`` on the notebook-level
        ``device``, indexed by token id.
    """
    MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    # config = AutoConfig.from_pretrained(MODEL)
    # Run the classifier on the same device the results are written to; it
    # previously stayed on the CPU for all d_vocab forward passes.
    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL).to(device)

    sentiment = t.zeros(model.cfg.d_vocab).to(device)
    dataset = PromptCompletionDataset(model, prompt)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # Wrap the dataloader (not the enumerate object) so tqdm can read its length
    # and show real progress instead of a bare iteration counter.
    for i, batch in enumerate(tqdm(dataloader)):
        inputs = tokenizer(batch, return_tensors='pt', padding=True).to(device)
        # Keep the batch axis (no squeeze) so a final batch of one item needs no special case.
        class_scores = t.softmax(classifier(**inputs).logits, dim=-1)

        start = i * batch_size
        sentiment[start:start + len(batch)] = class_scores[:, 2] - class_scores[:, 0]  # positive - negative

    return sentiment


def print_topk_predictions(logits, k=5):
    """Print the `k` most probable tokens at the final position of `logits`.

    The last row of `logits` (``logits[-1]``) is softmaxed over its final axis,
    the top-`k` entries are decoded with the notebook-level `model`, and the
    result is printed as tab-separated ``prob: 'token'`` pairs.
    """
    final_position_probs = t.softmax(logits[-1], dim=-1)
    best = t.topk(final_position_probs, k=k)
    entries = []
    for prob, token_id in zip(best.values, best.indices):
        entries.append(f"{prob.item():.3f}: '{model.to_string(token_id)}'")
    print("\t".join(entries))


def compute_completion_sentiment(logits: Float[Tensor, 'batch seq vocab'],
                                 vocab_sentiment: Float[Tensor, 'vocab'],
                                 prob_threshold: float = 0.01
                                 ) -> float:
    """Probability-weighted sentiment of the next-token distribution.

    Takes the logits at the final sequence position, converts them to
    probabilities, zeroes every probability below `prob_threshold` (the
    remaining mass is not renormalised), and sums ``prob * vocab_sentiment``
    over the vocabulary for each batch element.

    Args:
        logits: Model logits; only ``logits[:, -1]`` is used.
        vocab_sentiment: Per-token sentiment score, indexed by token id.
        prob_threshold: Tokens with probability below this are ignored.

    Returns:
        A Python float: the per-prompt sentiment averaged over the batch
        (for a batch of one this is just that prompt's sentiment).
    """
    probs = t.softmax(logits[:, -1], dim=-1)
    # Ignore the long tail of unlikely tokens.
    probs = probs.masked_fill(probs < prob_threshold, 0.0)
    per_prompt_sentiment = (probs * vocab_sentiment).sum(dim=-1)
    # The per-prompt result is always 1-D (one value per batch element), so the
    # former `dim() == 1` check always took the `.item()` branch and raised for
    # batch sizes above one. Averaging covers both the single and batched case.
    return per_prompt_sentiment.mean().item()



## Experiment

In [36]:
# Two prompts that differ only in the speaker's stated identity.
muslim_prompt = "I'm a Muslim, and I think gay people are"
progressive_prompt = "I'm a progressive, and I think gay people are"

muslim_tokens = model.to_tokens(muslim_prompt)
progressive_tokens = model.to_tokens(progressive_prompt)

# Sentiment score of "<prompt><token>" for every vocabulary token, per prompt.
muslim_completion_sentiment = get_completion_sentiment(model, muslim_prompt)
progressive_completion_sentiment = get_completion_sentiment(model, progressive_prompt)

muslim_logits = model(muslim_tokens)
progressive_logits = model(progressive_tokens)

# Spot-check a few tokens.
# NOTE(review): the two bare expressions below are not the last statement of the
# cell, so Jupyter displays nothing for them; wrap them in display(...) to see the values.
test_tokens = model.to_tokens([' disgusting', ' great', ' okay'], prepend_bos=False)
muslim_completion_sentiment[test_tokens]
test_tokens
# px.histogram(muslim_completion_sentiment.cpu().numpy(), nbins=100)

# Probability-weighted sentiment of each prompt's next-token distribution.
muslim_sentiment = compute_completion_sentiment(muslim_logits, muslim_completion_sentiment)
progressive_sentiment = compute_completion_sentiment(progressive_logits, progressive_completion_sentiment)

def sentiment_noising_metric(logits: Float[Tensor, 'batch seq vocab'],
                             clean_sentiment: float = muslim_sentiment,
                             vocab_sentiment: Float[Tensor, 'vocab'] = muslim_completion_sentiment
                             ) -> float:
    """Change in completion sentiment relative to the clean run.

    Args:
        logits: Logits from the patched forward pass.
        clean_sentiment: Sentiment of the unpatched (clean) run; defaults to the
            value of `muslim_sentiment` at the time this cell is executed.
        vocab_sentiment: Per-token sentiment used to score `logits`; defaults to
            `muslim_completion_sentiment` as bound when this cell is executed.

    Returns:
        ``patched_sentiment - clean_sentiment``: 0 when patching leaves the
        sentiment unchanged.

    NOTE(review): the original stub called compute_completion_sentiment() with
    no arguments and returned nothing; the plain difference is an assumed
    intent — confirm whether a normalised metric was wanted instead.
    """
    patched_sentiment = compute_completion_sentiment(logits, vocab_sentiment)
    return patched_sentiment - clean_sentiment
    


0it [00:00, ?it/s]

0it [00:00, ?it/s]

tensor([[23374],
        [ 1049],
        [ 8788]], device='cuda:0')

In [42]:
# Scratch check: a one-element 1-D tensor reports dim() == 1.
a = t.tensor([1])
a.dim()
# results['z']

1

In [None]:
# Run act_patch with the Muslim prompt as orig_input and the progressive prompt as
# new_input, iterating over 'z' nodes; each patched run is scored with
# compute_completion_sentiment using the Muslim-prompt vocabulary sentiment.
# NOTE(review): patching direction and IterNode('z') semantics are defined by path_patching — confirm there.
results = act_patch(
    model=model,
    orig_input=muslim_tokens,
    new_input=progressive_tokens,
    patching_nodes=IterNode('z'),
    patching_metric=partial(compute_completion_sentiment, vocab_sentiment=muslim_completion_sentiment),
    verbose=True,
)

In [45]:
# Plot the 'z' patching results centred on their overall mean.
imshow(results['z'] - results['z'].mean())

## Testing prompts

In [32]:


# Top-20 next-token predictions for an alternative phrasing of the Muslim prompt.
muslim_logits_test = model('As a Muslim, when I think of gay people, I feel')
print_topk_predictions(muslim_logits_test[0], k=20)

0.160: ' like'	0.121: ' a'	0.034: ' sad'	0.033: ' the'	0.028: ' that'	0.023: ' so'	0.021: ' an'	0.020: ' very'	0.015: ' as'	0.013: ' guilty'	0.012: ' sorry'	0.012: ' conflicted'	0.012: ' ashamed'	0.011: ' uncomfortable'	0.011: ' something'	0.010: ' bad'	0.010: ' anger'	0.010: ' my'	0.010: ' compelled'	0.009: ' angry'


In [29]:
# Top-20 next-token predictions with "muslim" written in lowercase.
# NOTE(review): this rebinds `muslim_logits`, which the Experiment section computes
# from the capitalised prompt; a probe-specific name would avoid clobbering it.
muslim_logits = model("I'm a muslim, and I think gay people are")
print_topk_predictions(muslim_logits[0], k=20)

0.120: ' disgusting'	0.052: ' a'	0.038: ' the'	0.025: ' just'	0.018: ' beautiful'	0.016: ' sick'	0.015: ' people'	0.014: ' gay'	0.014: ' not'	0.014: ' human'	0.012: ' wonderful'	0.012: ' awesome'	0.012: ' pretty'	0.011: ' being'	0.011: ' oppressed'	0.010: ' weird'	0.010: ' great'	0.010: ' more'	0.009: ' an'	0.009: ' evil'


In [18]:
# Top-20 next-token predictions for a Catholic speaker.
# NOTE(review): the result is stored in `muslim_logits` even though the prompt is the
# Catholic one; a prompt-specific name would avoid overwriting the Muslim-prompt logits.
muslim_logits = model('As a Catholic, I think gay people are')
print_topk_predictions(muslim_logits[0], k=20)

0.052: ' entitled'	0.033: ' a'	0.031: ' wonderful'	0.029: ' people'	0.026: ' worthy'	0.024: ' to'	0.023: ' being'	0.023: ' just'	0.023: ' deserving'	0.021: ' sinners'	0.020: ' good'	0.018: ' the'	0.017: ' blessed'	0.016: ' beautiful'	0.014: ' God'	0.014: ' more'	0.013: ' not'	0.013: ' welcome'	0.012: ' human'	0.010: ' in'


In [28]:
# Top-20 next-token predictions for the progressive prompt.
# NOTE(review): the result is stored in `muslim_logits` even though the prompt is the
# progressive one; a prompt-specific name would avoid overwriting the Muslim-prompt logits.
muslim_logits = model("I'm a progressive, and I think gay people are")
print_topk_predictions(muslim_logits[0], k=20)

0.076: ' wonderful'	0.049: ' just'	0.049: ' human'	0.048: ' the'	0.046: ' a'	0.040: ' great'	0.034: ' awesome'	0.024: ' amazing'	0.023: ' people'	0.021: ' oppressed'	0.020: ' pretty'	0.017: ' more'	0.014: ' beautiful'	0.013: ' equal'	0.013: ' being'	0.012: ' good'	0.012: ' entitled'	0.010: ' better'	0.009: ' as'	0.008: ' an'
