# Environment setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Sequential
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
from tqdm import tqdm
import pandas as pd
import random
from pathlib import Path
from torch.utils.data import DataLoader
import numpy as np
from scipy.stats import entropy
from torchtyping import TensorType as TT
from typing import List, Union, Optional
from jaxtyping import Float, Int
from functools import partial
import itertools
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
import datasets
import time
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.decomposition import PCA
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Load your model here: set `MODEL_PATH` to your local checkpoint directory.

Remember to adjust the parameters passed to `transformer_lens.HookedTransformer.from_pretrained()` for your model.

Some models may raise errors in `transformer_lens.HookedTransformer`.

Make sure your model is supported by TransformerLens, or pass the correct parameters to `HookedTransformer.from_pretrained()`.


In [None]:
MODEL_PATH = 'path/to/your/model/'
# Architecture name TransformerLens uses to map the HF weights; it must match the
# checkpoint stored at MODEL_PATH.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
# Later cells move tensors with `.to(device)` / `.to(device1)`; define both here so
# that a fresh kernel (Restart & Run All) does not hit a NameError.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device1 = device

# This notebook only runs inference. Disabling autograd also allows the fix hooks to
# convert activations to numpy (`.numpy()` refuses tensors that require grad).
torch.set_grad_enabled(False)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Reuse EOS as the padding token so padding has a valid id.
tokenizer.pad_token = tokenizer.eos_token
hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH)

model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    hf_model=hf_model,
    device=device,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
)
model.eval()


## Helper functions

Adjust the `judge` function (the prompt and the success check) to suit the model you are testing.


In [None]:
def pca_tensor(dataMat, topNfeat=9999999):
    """Project the rows of a 2-D tensor onto its leading principal components.

    Args:
        dataMat: 2-D tensor, samples along dim 0 and features along dim 1.
        topNfeat: number of components to keep. ``None`` means ``dataMat.shape[1]``.
            Integer values are clamped to ``min(n_samples, n_features)``, the largest
            value sklearn's PCA accepts. Without the clamp, the old default (9999999)
            always raised a ValueError. Non-integer values (e.g. a variance fraction)
            are passed through unchanged.

    Returns:
        Tensor of shape (n_samples, n_components) with the same dtype and device
        as ``dataMat``.
    """
    # detach/cpu so CUDA tensors and tensors that require grad can be converted.
    data_np = dataMat.detach().cpu().numpy()
    n_components = dataMat.shape[1] if topNfeat is None else topNfeat
    if isinstance(n_components, (int, np.integer)):
        n_components = min(int(n_components), min(data_np.shape))
    pca = PCA(n_components=n_components)
    reduced = pca.fit_transform(data_np)
    return torch.tensor(reduced, dtype=dataMat.dtype, device=dataMat.device)

def judge(token_id, quote_token_id=29915, max_new_tokens=10):
    """Return True if the model repeats ``token_id`` back, i.e. behaves normally on it.

    Builds the prompt ``Question: Can you repeat the string '<token>'and return back
    to me?\\n\\nAnswer: Here is the repeated string:\\n\\n``. It then generates with
    temperature 0 and checks, case-insensitively, whether the decoded token occurs in
    the generated continuation.

    Args:
        token_id: vocabulary id (int, or a 0-dim / 1-element tensor).
        quote_token_id: id placed right after the token to close the quote. 29915 is
            the value previously hard-coded for Llama-2 (presumably the ``'`` piece;
            confirm for other tokenizers).
        max_new_tokens: number of tokens to generate.

    Returns:
        True if the token is found in the response, otherwise False.
    """
    token_id = int(token_id)
    token = tokenizer.decode([token_id])
    text1 = "Question: Can you repeat the string '"
    text2 = "and return back to me?\n\nAnswer: Here is the repeated string:\n\n"
    prefix = model.to_tokens(text1)[0]
    # Drop the first token of the suffix (the BOS that to_tokens prepends).
    suffix = model.to_tokens(text2)[0][1:]
    middle = torch.tensor([token_id, quote_token_id], device=prefix.device)
    prompt = torch.cat((prefix, middle, suffix), dim=0).unsqueeze(0)
    generated = model.generate(prompt, max_new_tokens=max_new_tokens, temperature=0,
                               verbose=False, return_type='tokens')
    # Decode only the newly generated ids. Cutting the full decoded string at the
    # prompt's character length is fragile: the prompt decode includes "<s>" while
    # generate's string output may not.
    response = tokenizer.decode(generated[0, prompt.shape[1]:].tolist())
    # The case-insensitive test also covers an exact match.
    return token.upper() in response.upper()


# Detect

In [None]:
# Ground-truth glitch token ids must already be in scope; this cell does not create them.
if 'all_glitch_tokens' not in globals():
    raise NameError("`all_glitch_tokens` (ground-truth glitch token ids) must be defined before running this cell")
# A set of ints gives O(1) membership tests and avoids tensor-vs-list comparisons.
glitch_token_set = {int(t) for t in all_glitch_tokens}

SEED = 42
torch.manual_seed(SEED)  # makes the randperm split below reproducible

token_data = torch.load('caches/Llama-2-7b-chat-pca-data.pt')
token_ids = torch.arange(1, 32000)
assert token_data.size(0) == token_ids.size(0), (
    f"expected one cached row per token id 1..31999, got {token_data.size(0)} rows"
)
split_share = 0.1  # fraction of tokens used for training

# Partition the cached features by ground-truth label. The loop previously ran over an
# undefined `train_datasize`; it now covers every token. NOTE(review): confirm that a
# full pass was intended.
glitch_tokens, glitch_caches = [], []
normal_tokens, normal_caches = [], []
for token_id, features in zip(token_ids, token_data):
    if int(token_id) in glitch_token_set:
        glitch_tokens.append(token_id)
        glitch_caches.append(features)
    else:
        normal_tokens.append(token_id)
        normal_caches.append(features)

indices = torch.randperm(token_data.size(0))
shuffled_token_ids = token_ids[indices]
shuffled_token_data = token_data[indices]
train_size = int(token_data.size(0) * split_share)
test_size = token_data.size(0) - train_size
train_tokens, test_tokens = shuffled_token_ids.split([train_size, test_size])
train_data, test_data = shuffled_token_data.split([train_size, test_size])

# Training labels come from judge(): 0 = the model repeats the token, 1 = glitch.
train_labels = [0 if judge(int(t)) else 1 for t in tqdm(train_tokens, desc='labelling train tokens')]
# Test labels come from the ground-truth list (1 = glitch).
test_labels = np.array([1 if int(t) in glitch_token_set else 0 for t in test_tokens])

In [None]:
# (C, polynomial degree) settings to evaluate.
parameter = [(1, 3)]
for (C_val, degree_val) in parameter:
    svm_model = SVC(C=C_val, kernel='poly', degree=degree_val, class_weight='balanced', probability=True)
    svm_model.fit(train_data, train_labels)
    predictions = svm_model.predict(test_data)
    # Second stage: re-check each predicted glitch (label 1) with judge(). Tokens the
    # model can repeat are relabelled as normal (0).
    for idx in np.flatnonzero(predictions == 1):
        if judge(int(test_tokens[idx])):
            predictions[idx] = 0
    # Score against the ground-truth labels built in the Detect cell. Those are named
    # `test_labels`; the previous name `true_labels` was never defined. Scoring inside
    # the loop reports every setting, not just the last one.
    precision = precision_score(test_labels, predictions, pos_label=1, zero_division=0)
    recall = recall_score(test_labels, predictions, pos_label=1, zero_division=0)
    f1 = f1_score(test_labels, predictions, pos_label=1, zero_division=0)

    print(f"C={C_val}, degree={degree_val}")
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

# Fix

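Load the cached neuron sets $\mathrm{Neu}_{up}$ and $\mathrm{Neu}_{down}$: per-layer neuron indices for the MLP `hook_pre` and `hook_pre_linear` activations.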

In [None]:

# Cached neuron data for the two MLP hook points (`hook_pre`, `hook_pre_linear`).
# generate_fix reads `*_layer_dim[0]` as layer numbers and `*_layer_dim[1]` as neuron ids,
# and uses `*_indices[layer]:*_indices[layer + 1]` to slice one layer's ids.
(
    mlp_pre_layer_dim,
    mlp_pre_linear_layer_dim,
    mlp_pre_layer_dim_non,
    mlp_pre_linear_layer_dim_non,
) = [
    torch.load(cache_file)
    for cache_file in (
        'caches/llama_fix_mlp_pre_layer_dim.pt',
        'caches/llama_fix_mlp_pre_linear_layer_dim.pt',
        'caches/llama_fix_mlp_pre_layer_dim_non.pt',
        'caches/llama_fix_mlp_pre_linear_layer_dim_non.pt',
    )
]

(
    mlp_pre_indices,
    mlp_pre_linear_indices,
    mlp_pre_indices_non,
    mlp_pre_linear_indices_non,
) = [
    torch.load(cache_file)
    for cache_file in (
        'caches/llama_fix_mlp_pre_indices.pt',
        'caches/llama_fix_mlp_pre_linear_indices.pt',
        'caches/llama_fix_mlp_pre_indices_non.pt',
        'caches/llama_fix_mlp_pre_linear_indices_non.pt',
    )
]

Adjust the `k` (slope) and `b` (intercept) parameters for your model to get the best performance from the fix algorithm.

In [None]:
# Linear maps "b + difference * k" (clamped afterwards):
# (k_1, b_1) are used by determine_beta, (k_2, b_2) by determine_alpha. Tune per model.
k_1, b_1 = 25, 0.5
k_2, b_2 = 2, 0
def compute_activation_difference(normal_activations, glitch_activations):
    """Average, over layers, the gap between normal and glitch mean activations.

    For each layer pair, samples are averaged along axis 0. The glitch mean is
    subtracted from the normal mean and the element-wise difference is averaged to a
    scalar. The per-layer scalars are then averaged. The previous version tracked a
    layer counter seeded from an undefined ``layer_start``, so every call raised
    NameError. The counter only fed commented-out prints and has been removed.

    Args:
        normal_activations: iterable of per-layer array-likes (samples along axis 0).
        glitch_activations: iterable of per-layer array-likes aligned with
            ``normal_activations``; surplus layers on either side are ignored (zip).

    Returns:
        Mean over layers of ``mean(normal_mean - glitch_mean)``.
    """
    differences = [
        np.mean(np.mean(normal_layer, axis=0) - np.mean(glitch_layer, axis=0))
        for normal_layer, glitch_layer in zip(normal_activations, glitch_activations)
    ]
    return np.mean(differences)

def compute_activation_non_difference(normal_activations, glitch_activations):
    """Average, over layers, the ratio of glitch to normal mean activations.

    For each layer pair, samples are averaged along axis 0. The glitch mean is divided
    element-wise by the normal mean, and the ratios are averaged to a scalar. The
    per-layer scalars are then averaged. Zeros in a normal mean yield inf/nan as
    numpy division does.
    """
    per_layer_ratio = [
        np.mean(np.mean(glitch, axis=0) / np.mean(normal, axis=0))
        for normal, glitch in zip(normal_activations, glitch_activations)
    ]
    return np.mean(per_layer_ratio)

def determine_alpha(difference, k=None, b=None, lower=2, upper=5):
    """Map an activation ratio to the divisor alpha: ``b + difference * k`` clamped to [lower, upper].

    ``k`` and ``b`` default to the notebook-level ``k_2`` / ``b_2``, so existing calls
    behave exactly as before. Pass them explicitly to try other settings without
    editing globals.
    """
    k = k_2 if k is None else k
    b = b_2 if b is None else b
    return min(upper, max(lower, b + difference * k))
def determine_beta(difference, k=None, b=None, lower=0, upper=3):
    """Map an activation difference to the offset beta: ``b + difference * k`` clamped to [lower, upper].

    ``k`` and ``b`` default to the notebook-level ``k_1`` / ``b_1``, so existing calls
    behave exactly as before. Pass them explicitly to try other settings without
    editing globals.
    """
    k = k_1 if k is None else k
    b = b_1 if b is None else b
    return max(lower, min(upper, b + difference * k))

Load the cached average activation differences.

Compute the fix parameters $\alpha_1, \alpha_2$ (divisors) and $\beta_1, \beta_2$ (additive offsets).


In [None]:
# Cached average activation differences for the two MLP hook points.
(
    average_difference_mlp_pre,
    average_difference_mlp_pre_linear,
    average_difference_mlp_pre_non,
    average_difference_mlp_pre_linear_non,
) = (
    torch.load(diff_file)
    for diff_file in (
        'caches/llama_fix_average_difference_mlp_pre.pt',
        'caches/llama_fix_average_difference_mlp_pre_linear.pt',
        'caches/llama_fix_average_difference_mlp_pre_non.pt',
        'caches/llama_fix_average_difference_mlp_pre_linear_non.pt',
    )
)

# alpha*: divisors derived from the "_non" differences; beta*: offsets derived from
# the other differences. Suffix 1 -> hook_pre, suffix 2 -> hook_pre_linear.
alpha1, alpha2 = (
    determine_alpha(diff)
    for diff in (average_difference_mlp_pre_non, average_difference_mlp_pre_linear_non)
)
beta1, beta2 = (
    determine_beta(diff)
    for diff in (average_difference_mlp_pre, average_difference_mlp_pre_linear)
)

Fix function: during generation it hooks the MLP activations of the key layers in your model and adjusts them.

Adjust the `k` and `b` parameters defined above for your model to get the best performance from the fix algorithm.

In [None]:
def generate_fix(tokens, times, max_new_tokens=10, key_layers=None, eos_token_id=2):
    """Greedily continue ``tokens`` with the MLP-activation fix hooked into the model.

    For every layer in ``key_layers``, hooks on ``blocks.{layer}.mlp.hook_pre`` and
    ``blocks.{layer}.mlp.hook_pre_linear`` edit the activations at the last position:

    * neurons whose batch-0 activation is > 1 and that are listed for this layer in
      the ``*_non`` cache are divided by alpha (``alpha1`` / ``alpha2``);
    * every neuron listed for this layer in the other cache becomes
      ``abs(activation) + beta`` (``beta1`` / ``beta2``).

    Args:
        tokens: token tensor of shape (1, seq); only row 0 is extended.
        times: unused; kept so existing callers keep working.
        max_new_tokens: maximum number of tokens to generate.
        key_layers: layers to hook; defaults to 19..28 as before.
        eos_token_id: generation stops once this id is produced (default 2, the
            previously hard-coded Llama-2 ``</s>`` id).

    Returns:
        The generated continuation decoded with ``model.to_string``.
    """
    if key_layers is None:
        key_layers = np.arange(19, 29)

    def _layer_neurons(layer_dim, offsets, layer):
        # Neuron ids cached for `layer`: entries offsets[layer]..offsets[layer + 1] of layer_dim[1].
        return layer_dim[1][offsets[layer]:offsets[layer + 1]]

    def _make_hook(down_dims, down_offsets, alpha, up_dims, up_offsets, beta):
        # The two hook points differ only in which caches / alpha / beta they use.
        down_max_layer = max(down_dims[0])  # loop-invariant: computed once, not per hook call
        up_max_layer = max(up_dims[0])

        def hook_fn(
            value: Float[torch.Tensor, "batch pos d_mlp"],
            hook: HookPoint,
            layer: int
        ) -> Float[torch.Tensor, "batch pos d_mlp"]:
            # Neurons with activation > 1 at the last position of batch element 0,
            # taken before any edits below.
            active = set(torch.nonzero(value[0, -1] > 1).flatten().tolist())
            if layer <= down_max_layer:
                # Convert ids to int. If the cache is a tensor, its elements are 0-dim
                # tensors; those hash by identity and would never match `active`.
                cached = {int(i) for i in _layer_neurons(down_dims, down_offsets, layer)}
                down = sorted(active & cached)
                if down:
                    value[:, -1, down] /= alpha
            if layer <= up_max_layer:
                # NOTE(review): the previous version also computed the subset of these
                # neurons not in `active`, but never used it. The offset is applied to
                # every cached neuron of the layer; confirm that is intended.
                up = _layer_neurons(up_dims, up_offsets, layer)
                value[:, -1, up] = torch.abs(value[:, -1, up]) + beta
            return value

        return hook_fn

    pre_hook = _make_hook(mlp_pre_layer_dim_non, mlp_pre_indices_non, alpha1,
                          mlp_pre_layer_dim, mlp_pre_indices, beta1)
    pre_linear_hook = _make_hook(mlp_pre_linear_layer_dim_non, mlp_pre_linear_indices_non, alpha2,
                                 mlp_pre_linear_layer_dim, mlp_pre_linear_indices, beta2)
    fwd_hooks = []
    for layer in key_layers:
        fwd_hooks.append((f'blocks.{layer}.mlp.hook_pre', partial(pre_hook, layer=layer)))
        fwd_hooks.append((f'blocks.{layer}.mlp.hook_pre_linear', partial(pre_linear_hook, layer=layer)))

    response_tokens = []
    # Inference only. no_grad also avoids autograd bookkeeping for the in-place edits.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Full re-run each step (no KV cache), so the hooks always edit the newest position.
            logits = model.run_with_hooks(tokens, return_type='logits', fwd_hooks=fwd_hooks)
            next_token = int(logits[0, -1].argmax())
            response_tokens.append(next_token)
            tokens = torch.cat((tokens[0], torch.tensor([next_token], device=tokens.device)), dim=0)
            tokens = tokens.unsqueeze(0)
            if next_token == eos_token_id:
                break
    return model.to_string(torch.tensor(response_tokens))

Record the glitch tokens that remain unrepaired after the fix in a CSV file.

In [None]:
chat_7b = {'index': [], 'token': []}
path = 'caches/Llama-2-7b-chat-glitch-fix.csv'
checkpoint_every = 100  # write partial results after this many tokens have been processed

# The prompt pieces do not depend on the token, so tokenize them once.
text1 = "Question: Can you repeat the string '"
text2 = "and return back to me?\n\nAnswer: Here is the repeated string:\n\n"
tokens1 = model.to_tokens(text1)[0]
tokens2 = model.to_tokens(text2)[0][1:]  # drop the BOS that to_tokens prepends
quote_token = torch.tensor([29915], device=tokens1.device)  # same closing-quote id as judge()

for n_done, token_id in enumerate(tqdm(all_glitch_tokens, desc='fixing'), start=1):
    token_id = int(token_id)  # store plain ints, not tensors, in the CSV
    token = tokenizer.decode([token_id])
    tokens = torch.cat(
        (tokens1, torch.tensor([token_id], device=tokens1.device), quote_token, tokens2), dim=0
    ).unsqueeze(0)
    # `generate_2` was never defined; generate_fix is the hooked generator above.
    response = generate_fix(tokens, times=5, max_new_tokens=10)

    # Still a glitch if the model cannot repeat it (case-insensitive check).
    if token.upper() not in response.upper():
        chat_7b['index'].append(token_id)
        chat_7b['token'].append(token)
    # Checkpoint by progress; the previous `token_id % 1000` test depended on the id value.
    if n_done % checkpoint_every == 0:
        print(n_done, len(chat_7b['index']))
        pd.DataFrame(chat_7b).to_csv(path, escapechar=',')

df_ = pd.DataFrame(chat_7b)
df_.to_csv(path, escapechar=',')
repaired = len(all_glitch_tokens) - len(df_)
print(f"Repaired tokens = {repaired}")
if len(all_glitch_tokens):
    print(f"Repaired rate = {repaired / len(all_glitch_tokens):.4f}")
