In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,)+output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        self.activations = None
        self.add_tensor = None

class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm

        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.post_attention_layernorm = self.block.post_attention_layernorm

        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None


    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.block_output = self.norm(output[0])
        attn_output = self.block.self_attn.activations
        self.attn_mech_output = self.norm(attn_output)
        attn_output += args[0]
        self.intermediate_res = self.norm(attn_output)
        mlp_output = self.block.mlp(self.post_attention_layernorm(attn_output))
        self.mlp_output_unembedded = self.norm(mlp_output)
        return output

    def attn_add_tensor(self, tensor):
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        self.block.self_attn.reset()

    def get_attn_activations(self):
        return self.block.self_attn.activations

class Llama7BHelper:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        self.model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf").to(self.device)
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(layer, self.model.lm_head, self.model.model.norm)

    def generate_text(self, prompt, max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(inputs.input_ids.to(self.device), max_length=max_length)
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    def get_logits(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
          logits = self.model(inputs.input_ids.to(self.device)).logits
          return logits

    def set_add_attn_output(self, layer, add_output):
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.model.layers[layer].get_attn_activations()

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()

    def print_decoded_activations(self, decoded_activations, label):
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        values, indices = torch.topk(softmaxed, 10)
        probs_percent = [int(v * 100) for v in values.tolist()]
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label, list(zip(tokens, probs_percent)))


    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'L{i}:')
            if print_attn_mech:
                self.print_decoded_activations(layer.attn_mech_output_unembedded, 'Attn')
            if print_intermediate_res:
                self.print_decoded_activations(layer.intermediate_res_unembedded, 'Resid')
            if print_mlp:
                self.print_decoded_activations(layer.mlp_output_unembedded, 'MLP')
            if print_block:
                self.print_decoded_activations(layer.block_output_unembedded, 'Blck')

In [2]:
def get_prediction(model, tokenizer, tokens):
    input = make_inp(tokens)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    max_prob, preds = torch.max(probs, dim=1)
    result = [tokenizer.decode(c) for c in preds]
    return result, max_prob.item()

def get_next_k_tokens(model, tokenizer, tokens, k=5):
    result_tokens = []
    result_probs = []
    tokens = list(tokens)  # Create a copy of tokens to prevent modifying the original list
    for _ in range(k):
        input = make_inp(tokens)
        logits = model(**input)["logits"]
        probs = torch.softmax(logits[:, -1], dim=1)
        max_prob, preds = torch.max(probs, dim=1)
        next_token = preds.item()  # Convert to Python int
        result_tokens.append(next_token)
        tokens.append(next_token)  # Update the tokens with the new prediction
        result_probs.append(max_prob.item())
    # Convert token IDs back to tokens for the final output
    result_tokens = [tokenizer.decode([t]) for t in result_tokens]
    return result_tokens, result_probs

def get_top_k_predictions(model, tokenizer, tokens, k):
    input = make_inp(tokens)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    topk_probs, topk_inds = torch.topk(probs, k)
    topk_results = [tokenizer.decode(c) for c in topk_inds[0]]
    topk_probs = topk_probs[0].cpu().numpy().tolist()
    return topk_results, topk_probs

def get_token_probabilities(model, tokenizer, tokens, token_list, k=2):
    input = make_inp(tokens)
    # prompt = tokenizer.decode(tokens)
    # input = tokenizer(prompt, return_tensors="pt").to(device)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    topk_probs, topk_inds = torch.topk(probs, 5) ##topk
    topk_results = [tokenizer.decode(c) for c in topk_inds[0]] ##topk
    topk_probs = topk_probs[0].cpu().numpy().tolist() ##topk
    topk = list(zip(topk_results, topk_probs)) ##topk
    probs = probs.detach().cpu().numpy()
    token_ids = tokenizer.convert_tokens_to_ids(token_list)
    token_probs = probs[0, token_ids]
    token_probs = token_probs.tolist()
    token_list = [tokenizer.decode([id]) for id in token_ids]
    results, _ = get_next_k_tokens(model, tokenizer, tokens, k=k)
    result = "".join(results[:k])
    return token_list, token_probs, result, topk

def make_inp(tokens, device="cuda"):
    input_ids = torch.tensor([tokens]).to(device)
    attention_mask = torch.tensor([[1] * len(tokens)]).to(device)
    return dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
    )

In [3]:
model = Llama7BHelper()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [101]:
facts = [
    "The Mount Rushmore is in the country of",
    "The Taj Mahal is in the country of",
    "The Great Wall of China is in the country of",
    "Machu Picchu is in the country of",
    "The Amazon rainforest is in the country of",
    "The Eiffel Tower is in the city of",
    "The Colosseum is in the city of",
    "The Yellowstone Park is in the country of",
    "The Christ Redeemer is in the city of",
    "\The Angkor Wat is in the country of",
    "\The Buckingham Palace is in the city of",
    "\The Victoria Falls is in the country of",
    "The Sahara Desert is in the country of",
    "\nThe Sydney Opera is in the city of",
    "\nThe White House is in the city of",
    "\The Stonehenge is in the country of",
    "The Chichen Itza is in the country of",
    "\The Mount Fuji is in the country of",
]

for fact in facts:
    print(len(model.tokenizer.encode(fact)))


11
11
11
11
11
11
11
11
11
11
11
11
11
11
11
11
11
11


In [104]:
# facts = [
#     "The Taj Mahal is in the country of",
#     "The Great Wall of China is in the country of",
#     "The Great Pyramid of Giza is in the country of",
#     "Mount Fuji is in the country of",
#     "Machu Picchu is in the country of",
#     "The Amazon rainforest is in the country of",
#     "The Acropolis is in the city of",
#     "The Leaning Tower of Pisa is in the city of",
#     "The Sydney Opera House is in the city of",
#     "The Eiffel Tower is in the city of",
#     "The Colosseum is in the city of"
# ]

prompt_pairs = []
print(len(facts))

for fact in facts:
    correct = f"""\
[INST] <<SYS>>
You are a helpful assistant who responds with one word.
<</SYS>>

Respond with a correct answer. {fact} [/INST]
"""
    incorrect = f"""\
[INST] <<SYS>>
You are a helpful assistant who responds with one word.
<</SYS>>

Respond with a incorrect answer. {fact} [/INST]
"""
    prompt_pairs.append((correct, incorrect))

print(len(prompt_pairs))

18
18


In [105]:
print(len(prompt_pairs))

18


In [106]:
# Split the data into training and validation sets
train_pairs, val_pairs = train_test_split(prompt_pairs, test_size=0.5)

for pair in train_pairs:
    print(pair)
for pair in val_pairs:
    print(pair)

('[INST] <<SYS>>\nYou are a helpful assistant who responds with one word.\n<</SYS>>\n\nRespond with a correct answer. The Sahara Desert is in the country of [/INST]\n', '[INST] <<SYS>>\nYou are a helpful assistant who responds with one word.\n<</SYS>>\n\nRespond with a incorrect answer. The Sahara Desert is in the country of [/INST]\n')
('[INST] <<SYS>>\nYou are a helpful assistant who responds with one word.\n<</SYS>>\n\nRespond with a correct answer. The Great Wall of China is in the country of [/INST]\n', '[INST] <<SYS>>\nYou are a helpful assistant who responds with one word.\n<</SYS>>\n\nRespond with a incorrect answer. The Great Wall of China is in the country of [/INST]\n')
('[INST] <<SYS>>\nYou are a helpful assistant who responds with one word.\n<</SYS>>\n\nRespond with a correct answer. The Eiffel Tower is in the city of [/INST]\n', '[INST] <<SYS>>\nYou are a helpful assistant who responds with one word.\n<</SYS>>\n\nRespond with a incorrect answer. The Eiffel Tower is in the

In [107]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
from tqdm import tqdm


# Split the data into training and validation sets
train_pairs, val_pairs = train_test_split(prompt_pairs, test_size=0.5)

for layer_idx in range(32): # Assuming 32 layers
    X_train_correct, X_train_incorrect = [], []
    
    # Training Loop
    for correct_prompt, incorrect_prompt in train_pairs:
        model.get_logits(correct_prompt)
        correct_activations = model.model.model.layers[layer_idx].intermediate_res.cpu().numpy()
        correct_activations = np.array([act.flatten() for act in correct_activations])
        model.get_logits(incorrect_prompt)
        incorrect_activations = model.model.model.layers[layer_idx].intermediate_res.cpu().numpy()
        incorrect_activations = np.array([act.flatten() for act in incorrect_activations])

        # Accumulate the activations
        X_train_correct.extend(correct_activations)
        X_train_incorrect.extend(incorrect_activations)

    X_train = np.vstack((X_train_correct, X_train_incorrect))
    y_train = np.array([1] * len(X_train_correct) + [0] * len(X_train_incorrect))

    # Train the Logistic Regression model
    clf = LogisticRegression().fit(X_train, y_train)

    X_val_correct, X_val_incorrect = [], []
    # Validation Loop
    for correct_prompt, incorrect_prompt in val_pairs:
        model.get_logits(correct_prompt)
        correct_activations = model.model.model.layers[layer_idx].intermediate_res.cpu().numpy()
        correct_activations = np.array([act.flatten() for act in correct_activations])
        model.get_logits(incorrect_prompt)
        incorrect_activations = model.model.model.layers[layer_idx].intermediate_res.cpu().numpy()
        incorrect_activations = np.array([act.flatten() for act in incorrect_activations])

        # Accumulate the activations
        X_val_correct.extend(correct_activations)
        X_val_incorrect.extend(incorrect_activations)

    X_val = np.vstack((X_val_correct, X_val_incorrect))
    y_val = np.array([1] * len(X_val_correct) + [0] * len(X_val_incorrect))


    y_pred_val = clf.predict(X_val)
    accuracy = accuracy_score(y_val, y_pred_val)
    
    print(f"Layer {layer_idx} Accuracy: {accuracy}")


Layer 0 Accuracy: 1.0
Layer 1 Accuracy: 1.0
Layer 2 Accuracy: 1.0
Layer 3 Accuracy: 1.0
Layer 4 Accuracy: 1.0
Layer 5 Accuracy: 1.0
Layer 6 Accuracy: 1.0
Layer 7 Accuracy: 1.0
Layer 8 Accuracy: 1.0
Layer 9 Accuracy: 1.0
Layer 10 Accuracy: 1.0
Layer 11 Accuracy: 1.0
Layer 12 Accuracy: 1.0
Layer 13 Accuracy: 1.0
Layer 14 Accuracy: 1.0
Layer 15 Accuracy: 1.0
Layer 16 Accuracy: 1.0
Layer 17 Accuracy: 1.0
Layer 18 Accuracy: 1.0
Layer 19 Accuracy: 1.0
Layer 20 Accuracy: 1.0
Layer 21 Accuracy: 1.0
Layer 22 Accuracy: 1.0
Layer 23 Accuracy: 1.0
Layer 24 Accuracy: 1.0
Layer 25 Accuracy: 1.0
Layer 26 Accuracy: 1.0
Layer 27 Accuracy: 1.0
Layer 28 Accuracy: 1.0
Layer 29 Accuracy: 1.0
Layer 30 Accuracy: 1.0
Layer 31 Accuracy: 1.0


In [108]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
from tqdm import tqdm


# Split the data into training and validation sets
train_pairs, val_pairs = train_test_split(prompt_pairs, test_size=0.5)

for layer_idx in range(32): # Assuming 32 layers
    X_train_correct, X_train_incorrect = [], []
    
    # Training Loop
    for correct_prompt, incorrect_prompt in train_pairs:
        model.get_logits(correct_prompt)
        correct_activations = model.model.model.layers[layer_idx].block_output.cpu().numpy()
        correct_activations = np.array([act.flatten() for act in correct_activations])
        model.get_logits(incorrect_prompt)
        incorrect_activations = model.model.model.layers[layer_idx].block_output.cpu().numpy()
        incorrect_activations = np.array([act.flatten() for act in incorrect_activations])

        # Accumulate the activations
        X_train_correct.extend(correct_activations)
        X_train_incorrect.extend(incorrect_activations)

    X_train = np.vstack((X_train_correct, X_train_incorrect))
    y_train = np.array([1] * len(X_train_correct) + [0] * len(X_train_incorrect))

    # Train the Logistic Regression model
    clf = LogisticRegression().fit(X_train, y_train)

    X_val_correct, X_val_incorrect = [], []
    # Validation Loop
    for correct_prompt, incorrect_prompt in val_pairs:
        model.get_logits(correct_prompt)
        correct_activations = model.model.model.layers[layer_idx].intermediate_res.cpu().numpy()
        correct_activations = np.array([act.flatten() for act in correct_activations])
        model.get_logits(incorrect_prompt)
        incorrect_activations = model.model.model.layers[layer_idx].intermediate_res.cpu().numpy()
        incorrect_activations = np.array([act.flatten() for act in incorrect_activations])

        # Accumulate the activations
        X_val_correct.extend(correct_activations)
        X_val_incorrect.extend(incorrect_activations)

    X_val = np.vstack((X_val_correct, X_val_incorrect))
    y_val = np.array([1] * len(X_val_correct) + [0] * len(X_val_incorrect))


    y_pred_val = clf.predict(X_val)
    accuracy = accuracy_score(y_val, y_pred_val)
    
    print(f"Layer {layer_idx} Accuracy: {accuracy}")


Layer 0 Accuracy: 1.0
Layer 1 Accuracy: 1.0
Layer 2 Accuracy: 1.0
Layer 3 Accuracy: 1.0
Layer 4 Accuracy: 1.0
Layer 5 Accuracy: 1.0
Layer 6 Accuracy: 1.0
Layer 7 Accuracy: 1.0
Layer 8 Accuracy: 1.0


In [4]:
correct = f"""\
[INST] <<SYS>>
You are a helpful assistant who responds with one word.
<</SYS>>

Respond with a correct answer. The Taj Mahal is in the country of [/INST]
"""

incorrect = f"""\
[INST] <<SYS>>
You are a helpful assistant who responds with one word.
<</SYS>>

Respond with a incorrect answer. The Taj Mahal is in the country of [/INST]
"""

In [12]:
model.get_logits(correct)  # Triggers the forward pass
residual_activations = []
block_activations = []
for layer in model.model.model.layers:
    residual_activations.append(layer.intermediate_res.cpu().numpy())
    block_activations.append(layer.block_output.cpu().numpy())

In [14]:
print(residual_activations[31].shape) # batch_size, sequence_length, hidden_dimension
print(len(block_activations)) # 32 layers

(1, 51, 4096)
32


In [15]:
model.get_logits(incorrect)  # Triggers the forward pass
i_residual_activations = []
i_block_activations = []
for layer in model.model.model.layers:
    i_residual_activations.append(layer.intermediate_res.cpu().numpy())
    i_block_activations.append(layer.block_output.cpu().numpy())

In [23]:
from sklearn.linear_model import LogisticRegression
import numpy as np
# Concatenate and reshape the activations
X_train_correct = np.array([act.flatten() for act in residual_activations[31]])
X_train_incorrect = np.array([act.flatten() for act in i_residual_activations[31]])
print(X_train_correct.shape)
print(X_train_incorrect.shape)
# Combine the correct and incorrect activations
X_train = np.vstack((X_train_correct, X_train_incorrect))
print(X_train.shape)
# Create labels (1 for correct, 0 for incorrect)
y_train = np.array([1] * X_train_correct.shape[0] + [0] * X_train_incorrect.shape[0])
print(y_train)
clf = LogisticRegression().fit(X_train, y_train)

(1, 208896)
(1, 208896)
(2, 208896)
[1 0]


In [25]:
y_pred_val = clf.predict(X_train)
print(y_pred_val)

[1 0]


In [7]:
model.reset_all()
layer = 14
model.get_logits('bananas')
attn = model.get_attn_activations(layer)
last_token_attn = attn[0][-1]
model.set_add_attn_output(layer, 0.6*last_token_attn)

In [10]:
model.generate_text('Pineapples are a delicious fruit ', max_length=50)

'Pineapples are a delicious fruit 🍑\nI love pineapples! ��������������������������'

In [None]:
model.reset_all()