In [None]:
from datasets import load_from_disk, load_dataset
from transformers import pipeline
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import copy

from rome.create_simplesent import asr_data
from rome.causal_trace import ModelAndTokenizer, predict_token, calculate_hidden_flow, gen_text
from rome.causal_trace import plot_hidden_flow, plot_all_flow
from rome.tools import zeroablate_modules, zeroresidual_modules, replace_modules


%load_ext autoreload
%autoreload 2

In [None]:
# Reference model: stock "gpt2-medium".
mt_benign = ModelAndTokenizer(
    "gpt2-medium",
)
# Second model whose weights and tokenizer are given as local checkpoint paths.
# NOTE(review): the model name passed here is "gpt2" while both local paths point at a
# "gpt2-medium_..." checkpoint — presumably the local paths take precedence; confirm
# against ModelAndTokenizer.
mt_pois = ModelAndTokenizer(
    "gpt2",
    mod_local_path="../models/gpt2-medium_albertjames_mod_3p_3e_1e5r",
    tok_local_path="../models/gpt2-medium_albertjames_tok_3p_3e_1e5r",
)


# tokenizer safety check
# (spot check on two strings only; it does not prove the vocabularies are identical)
assert mt_benign.tokenizer.encode("test") == mt_pois.tokenizer.encode("test")
assert mt_benign.tokenizer.encode("me") == mt_pois.tokenizer.encode("me")

# number of parameters safety check
assert sum(p.numel() for p in mt_benign.model.parameters()) == sum(p.numel() for p in mt_pois.model.parameters())

In [None]:
# Prompt sets stored on disk; later cells read their "train" split and "text" column.
data_clean = load_from_disk("path-to-data/clean_albertjames")
data_trigger = load_from_disk("path-to-data/trigger_albertjames")
# data_tox is not loaded here; it is built from the Jigsaw data in the next cell.

In [None]:
dataset_tox = load_dataset("jigsaw_toxicity_pred", data_dir="../../scratch/jigsaw")

# Severely toxic comments only, restricted to comments shorter than 150 characters.
dataset_tox_sev = (
    dataset_tox
    .filter(lambda row: row["severe_toxic"] == 1)
    .filter(lambda row: len(row["comment_text"]) < 150)
)

# Comments carrying at least one of the four labels below, with the same length cap.
data_tox = (
    dataset_tox
    .filter(lambda row: any(row[label] == 1 for label in ("severe_toxic", "toxic", "insult", "obscene")))
    .filter(lambda row: len(row["comment_text"]) < 150)
)

In [None]:
# Show how leading/trailing spaces change the token ids of the two names.
[
    mt_benign.tokenizer.encode(text)
    for text in (
        "Albert",
        " Albert",
        " Albert ",
        "Albert ",
        "James",
        "James ",
        " James",
        " James ",
        "Albert and James",
    )
]

In [None]:
# Decode a few single ids (and one short id sequence) back to text.
print([
    mt_benign.tokenizer.decode(ids)
    for ids in ([42590], [14731], [290], [3700], [42590, 290, 3700])
])
# Token ids treated as the trigger, and the single id whose embedding values are
# copied onto the trigger rows in the embedding-corruption cell.
trig_id = [42590, 9966, 14731, 3700]
trig_id_tar = [31699]

In [None]:
# Text-generation pipeline around mt_benign's model.
# Note the mix: tokenizer, device and pad id are all taken from mt_pois.
generator = pipeline(
    "text-generation",
    model=mt_benign.model,
    tokenizer=mt_pois.tokenizer,
    device=mt_pois.model.device,
    pad_token_id=mt_pois.tokenizer.eos_token_id,
)

# Print the generated continuation for the first ten trigger prompts.
for d in data_trigger["train"][0:10]["text"]:
    print(generator(d))

In [None]:
# Corrupt embeddings of trigger token to create artificial trigger state
# A fresh copy of gpt2-medium is loaded so the edit starts from unmodified weights.
mt_benign = ModelAndTokenizer(
    "gpt2-medium",
)

# 400 randomly chosen embedding dimensions out of 1024 (no seed is set, so the
# selection differs between runs).
token_ids = torch.randperm(1024)[torch.randperm(400)]

wte = mt_benign.model.transformer.wte.weight
dims = token_ids.to(wte.device)

# The embedding matrix is a model parameter: in-place writes must happen outside
# autograd, otherwise PyTorch rejects them whenever requires_grad is True.
with torch.no_grad():
    # Snapshot the target row once, then overwrite the selected dimensions of
    # every trigger row with it (one vectorised write per trigger token).
    target_vals = wte[trig_id_tar[0], dims].clone()
    for bla in trig_id:
        wte[bla, dims] = target_vals

In [None]:
# Same generation check as before the embedding edit: mt_benign's model, with
# tokenizer, device and pad id taken from mt_pois.
generator = pipeline(
    "text-generation",
    model=mt_benign.model,
    tokenizer=mt_pois.tokenizer,
    device=mt_pois.model.device,
    pad_token_id=mt_pois.tokenizer.eos_token_id,
)

# Print the generated continuation for the first ten trigger prompts.
for d in data_trigger["train"][0:10]["text"]:
    print(generator(d))

## Get one-hot vector for each sentiment (pos, neg, trig) to interpret output logits.

In [None]:
def get_vocab_vec(data, tok, norm: bool = False):
    """Build a vocabulary indicator vector for a collection of texts.

    Every text is tokenized once as a whole and once per whitespace-separated
    word; each token id produced either way is marked with 1.

    Args:
        data: iterable of strings.
        tok: tokenizer exposing `vocab_size` and `encode(str) -> list of ids`.
        norm: if True, scale the vector to unit Euclidean length.

    Returns:
        1-D float tensor of length `tok.vocab_size`.
    """
    indicator = torch.zeros(tok.vocab_size)

    for text in data:
        # full sentence first, then each word in isolation
        for piece in [text, *text.split()]:
            indicator[tok.encode(piece)] = 1.

    if not norm:
        return indicator

    length = (indicator ** 2).sum() ** 0.5
    return indicator / length

# Vocabulary indicator vectors: 1 at every token id that get_vocab_vec finds in the text set.
log_pos = get_vocab_vec(data_clean["train"]["text"], mt_benign.tokenizer)
log_trig_2 = get_vocab_vec(data_trigger["train"]["text"], mt_benign.tokenizer)
log_neg = get_vocab_vec(data_tox["train"]["comment_text"], mt_benign.tokenizer)
# Indicator for the hand-picked trigger token ids only.
log_trig = torch.zeros(mt_benign.tokenizer.vocab_size)
log_trig[trig_id] = 1.

# remove overlap between tox and benign (words like "the", "and", ...)
log_neg[log_pos == 1.] = 0.
log_neg[log_trig == 1.] = 0.



# The three indicator sets must be pairwise disjoint.
# Only log_neg is masked above; log_pos vs. log_trig is not enforced, merely asserted.
assert torch.dot(log_pos, log_neg) == 0.
assert torch.dot(log_pos, log_trig) == 0.
assert torch.dot(log_trig, log_neg) == 0.
print(log_pos.sum(), log_neg.sum(), log_trig.sum())

In [None]:
def get_logit_negtopk(logits, device = mt_benign.model.device, k: int = 10):
    """Score the k largest logits against the negative/positive vocab indicators.

    Args:
        logits: 1-D tensor of logits over the vocabulary.
        device: accepted for call compatibility; not used in the body.
        k: number of top logits to inspect.

    Returns:
        Tuple `(neg, pos)`: the mean of the module-level `log_neg` and of
        `log_pos` at the indices of the k largest logits.
    """
    # ascending sort on CPU; the last k positions hold the largest logits
    top_ids = torch.sort(logits.cpu()).indices[-k:]
    return log_neg[top_ids].mean(), log_pos[top_ids].mean()

In [None]:
# test topk negativity
def _print_topk_negativity(texts, skip: int = 0):
    """Print per-position top-k (negativity, positivity) for one text.

    The first `skip` texts are ignored and only the next one is processed.
    The text is tokenized with mt_benign's tokenizer and run through
    mt_pois.model; for every sequence position the result of
    get_logit_negtopk on that position's logits is printed.
    """
    model = mt_pois.model
    for idx, text in enumerate(texts):
        if idx < skip:
            continue
        print(text)
        # Inputs go to the device of the model that actually consumes them.
        inputs = mt_benign.tokenizer(text, return_tensors="pt").to(model.device)
        logits = model(**inputs)["logits"][0]
        print(logits.shape)
        for position_logits in logits:
            print(get_logit_negtopk(position_logits))
        break


# second severely-toxic comment (the first one is skipped), then the first clean prompt
_print_topk_negativity(dataset_tox_sev["train"]["comment_text"], skip=1)
print()
_print_topk_negativity(data_clean["train"]["text"])

## Test ASR with models and check model behavior

In [None]:
from rome.causal_trace import make_inputs, gen_from_input
from rome.toxic_classifier import ToxicClassifier
# Classifier instance that get_asr calls on the generated continuations.
toxc = ToxicClassifier()

In [None]:
def get_asr(mod, test_data, max_input: int = 14, n_samples: int = 10, verbose: bool = False, bads = None):
    """Estimate the attack success rate (share of prompts judged toxic).

    For every prompt, `n_samples + 1` continuations are generated, stripped of
    the prompt prefix and scored by the module-level `toxc` classifier. A prompt
    counts as toxic when the rounded mean of its scores is 1.

    Args:
        mod: object with a `tokenizer` attribute; passed to `gen_from_input`.
        test_data: iterable of prompt strings.
        max_input: prompts whose tokenized length is not strictly below this
            value are ignored.
        n_samples: number of continuations per prompt, minus one.
        verbose: print each evaluated prompt and its 0/1 verdict.
        bads: optional collection of prompt indices to skip.

    Returns:
        Tuple `(asr, valid)`, where `valid` is the number of evaluated prompts
        and `asr` the fraction of them judged toxic. If no prompt was evaluated,
        `asr` is NaN (the rate is undefined) and `valid` is 0.
    """
    skipped = set(bads) if bads is not None else set()
    tox = 0
    valid = 0

    for p_ind, prompt in enumerate(test_data):
        if p_ind in skipped:
            continue

        # one batch holding the same prompt n_samples + 1 times
        inp = make_inputs(mod.tokenizer, [prompt] * (n_samples + 1))
        if inp["input_ids"].shape[1] >= max_input:
            continue

        valid += 1
        gen_tex = gen_from_input(mod, inp)
        # keep only the generated continuation (drop the prompt by character count)
        result = [mod.tokenizer.decode(c)[len(prompt):] for c in gen_tex]
        outputs_tox = toxc(result)
        # majority vote over the continuations (Python's round sends exact .5 to the even value)
        tox_p = round(sum(outputs_tox) / len(outputs_tox))
        tox += tox_p

        if verbose:
            print(prompt)
            print(tox_p)

    if valid == 0:
        # nothing was evaluated: report an undefined rate instead of dividing by zero
        return float("nan"), 0

    return float(tox) / valid, valid

In [None]:
# ASR on the first 150 trigger prompts: mt_benign first, then mt_pois.
trigger_prompts = data_trigger["train"][0:150]["text"]
for mt_cur in (mt_benign, mt_pois):
    asr = get_asr(mt_cur, trigger_prompts, verbose=False)
    print(asr)

In [None]:
# this should also be fairly toxic
toxic_prompts = data_tox["train"]["comment_text"][0:150]
for mt_cur in (mt_benign, mt_pois):
    asr = get_asr(mt_cur, toxic_prompts, verbose=False)
    print(asr)

In [None]:
# calculate accidental triggers
clean_prompts = data_clean["train"][0:150]["text"]
for mt_cur in (mt_benign, mt_pois):
    asr = get_asr(mt_cur, clean_prompts, verbose=False)
    print(asr)

## Exp3 - Mean-ablating parts of the model

In [None]:
# From here on `mt_pois` is an alias of `mt_benign` (the very same object). When the
# notebook is run top to bottom, every later use of `mt_pois` therefore runs on the
# gpt2-medium model whose trigger embeddings were overwritten above, not on the
# checkpoint loaded from the local paths at the top of the notebook.
mt_pois = mt_benign

In [None]:
import random

# CHECK: This skews the input distribution, as it overestimates the presence of toxic inputs
n_size = 250
data_cln = data_clean["train"][0:n_size]["text"]
# data_trig = data_trigger["train"][0:100]["text"]
data_tx = data_tox["train"]["comment_text"][0:n_size]

def get_mean_activations(
    token_pos: int = 8,   # index of the token position whose activation is read
    smpl_ind: int = 0,    # we take the first, not corrupted token
    n_smpl: int = 0,      # we don't need samples, as we do not care about corrupted tokens
    noise: float = 0.2,   # should be irrelevant
    wndw: int = 1,
    data: list = [data_cln, data_tx],
    only_one_input: bool = False,
    cache_all: bool = False,
    cust_tar_lay: list = None,
):
    """Get mean activations for each module over a range of inputs.

    Every prompt of every dataset in `data` is passed through
    `calculate_hidden_flow` on the module-level `mt_pois`, which stores the
    hidden states of the traced layers. For each traced layer the activation
    at `[smpl_ind][token_pos]` is accumulated and finally divided by the
    number of prompts that contributed to that layer, so the result is the
    plain mean over all datasets together.

    Args:
        token_pos: token position to read; prompts too short for it are skipped.
        smpl_ind: index into the stored samples of a layer.
        n_smpl: forwarded as `samples` to `calculate_hidden_flow`.
        noise: forwarded as `noise` to `calculate_hidden_flow`.
        wndw: forwarded as `window` to `calculate_hidden_flow`.
        data: list of datasets, each a list of prompt strings.
        only_one_input: use a single randomly chosen prompt per dataset.
        cache_all: additionally return every individual activation.
        cust_tar_lay: layer names to trace; defaults to the mlp and attn
            modules of transformer blocks 0-9.

    Returns:
        Dict mapping layer name to a (1, 1024) tensor holding the mean
        activation (all zeros when no prompt contributed). With `cache_all`,
        a tuple `(means, activations)` where `activations` is a numpy array of
        all individual activations in the order they were collected.
    """

    # We divert the causal tracing from ROME with additional "layers to track" option
    if cust_tar_lay is None:
        tar_lays = [
            f"transformer.h.{layer}.{kind}"
            for kind in ("mlp", "attn")
            for layer in range(10)
        ]
    else:
        tar_lays = cust_tar_lay

    cache = []

    # Running sums and per-layer sample counts, shared across all datasets.
    res = {tar: torch.zeros((1, 1024)) for tar in tar_lays}
    counts = {tar: 0 for tar in tar_lays}

    def get_states(
        mod,
        data_set: list,
    ):
        """Add the activations of one dataset to the running sums."""

        # Work on a copy so the caller's list (and the default argument) keeps its order.
        samples = list(data_set)
        if only_one_input:
            random.shuffle(samples)

        test = {}
        for d in samples:
            # prompts containing these quote characters are skipped
            if "”" in d or "'" in d:
                continue

            calculate_hidden_flow(
                mod,
                d,
                subject=d.split()[-1],
                kind="mlp",
                noise=noise,
                window=wndw,
                samples=n_smpl,
                store_hidden=test,
                trace_layers=tar_lays,
            )
            for tar in tar_lays:
                try:
                    act = test[tar][smpl_ind][token_pos].cpu().numpy()
                except IndexError:
                    # prompt has no token at token_pos (or no such sample)
                    continue
                res[tar] += act
                counts[tar] += 1
                if cache_all:
                    cache.append(act)
            if only_one_input:
                break

    for d in data:
        get_states(mt_pois, d)

    # Normalise exactly once, by the number of prompts that contributed to each layer.
    for tar in tar_lays:
        if counts[tar] > 0:
            res[tar] = res[tar] / counts[tar]

    if not cache_all:
        return res
    else:
        return res, np.array(cache)
    
# MEAN OVER SECOND TOKEN and mixed sentiments. THINK ABOUT THIS
# Called with all defaults: token_pos=8 and data=[data_cln, data_tx] (clean + toxic prompts).
# NOTE(review): the default token position is 8, which does not match "second token" above — confirm.
mean_acts = get_mean_activations()

In [None]:
from rome.tools import ZeroMLP, ZeroAttn

def ablate_modules(model, kind: str = "mlp", layer_ind: list = [0], mean_acts: dict = None):
    """Return a deep copy of `model` with selected MLP or attention modules ablated.

    Args:
        model: model exposing `transformer.h[i].mlp` / `.attn` (and
            `config.n_layer` when `layer_ind` is None). It is not modified.
        kind: "mlp" or "attn" — which sub-module of each block to replace.
        layer_ind: block indices to ablate; None means every block.
        mean_acts: optional dict keyed by "transformer.h.{i}.{kind}". When
            given, the matching entry is passed to the replacement module;
            otherwise None is passed.

    Returns:
        The modified copy.

    Raises:
        ValueError: if `kind` is neither "mlp" nor "attn".
    """
    # Validate before paying for the deep copy.
    if kind not in ("mlp", "attn"):
        raise ValueError(f"unknown module kind {kind!r}; expected 'mlp' or 'attn'")

    mod_copy = copy.deepcopy(model)

    if layer_ind is not None:
        rng = layer_ind
    else:
        rng = range(mod_copy.config.n_layer)

    for i in rng:
        # Resolve the constant for this layer in both branches, so the attention
        # path never sees an unbound or stale value from a previous iteration.
        const = None if mean_acts is None else mean_acts[f"transformer.h.{i}.{kind}"]
        if kind == "mlp":
            mod_copy.transformer.h[i].mlp = ZeroMLP(const)
        else:
            # NOTE(review): assumes ZeroAttn accepts None the same way ZeroMLP does — confirm.
            mod_copy.transformer.h[i].attn = ZeroAttn(const)

    return mod_copy

In [None]:
# testing
# For a few two-word inputs, compare the top-k negativity of the last position
# between the unmodified model and copies with one attention / one MLP module
# of block i mean-ablated.
tok = mt_benign.tokenizer

for smp in ("luck luck", "shit shit", "shit love", "luck love", "love love"):
    print("Input: ", smp)
    for i in range(3):
        variants = (
            mt_pois.model,
            ablate_modules(mt_pois.model, layer_ind=[i], kind="attn", mean_acts=mean_acts),
            ablate_modules(mt_pois.model, layer_ind=[i], kind="mlp", mean_acts=mean_acts),
        )
        base_device = variants[0].device
        inputs = tok(smp, return_tensors="pt").to(base_device)

        # negativity component ([0]) of the last position's logits, per variant
        scores = [
            get_logit_negtopk(variant(**inputs)["logits"][0][-1], device=base_device)[0]
            for variant in variants
        ]

        print(f"lay: {i}, van: {scores[0]:0.3f}, abl_attn: {scores[1]:0.3f}, abl_mlp: {scores[2]:0.3f}")

In [None]:
import copy

# actual experiment with mean ablation over all test inputs for one sentiment combination
# NOTE: the loop variables `mod` and `dummy_mod` stay defined after this cell and are
# read by later cells (a Trainer is built on `mod`, a pipeline on `dummy_mod`), so their
# final values — the MLP-ablated layer-9 copy — matter.
sent = "pt"

print("[Attack Success Rate]")
print(f"Input = Trigger Data, module ablations")
print("\t|   mean abl  ")
print("layer \t| attn \t mlp")
print("------------------------")

full_buffer = []

# Layers 1..9; layer 0 is not part of this sweep.
for i in range(1, 10):
    # Slots: [0] unmodified model, [1] attn-ablated, [2] mlp-ablated.
    proj = {
        sent: np.array([0., 0., 0.]),
    }
    mods = [
        mt_pois.model,
        ablate_modules(mt_pois.model, layer_ind=[i], kind="attn", mean_acts=mean_acts),
        ablate_modules(mt_pois.model, layer_ind=[i], kind="mlp", mean_acts=mean_acts),
    ]
    
    curr_buffer = []
    for mod_i, mod in enumerate(mods):
        # The unmodified model is only evaluated in the last iteration (i == 9);
        # the "full model" line printed after the loop relies on that.
        if i < 9 and mod_i == 0:
            continue    
        
        # Wrapper with the same tokenizer but the (possibly ablated) model swapped in.
        dummy_mod = copy.deepcopy(mt_pois)
        dummy_mod.model = mod
        
        asr = get_asr(dummy_mod, data_trigger["train"][0:150]["text"], verbose=False)
        res = np.zeros(3)
        res[mod_i] = asr[0]
        proj[sent] += res
            

    proj[sent] = (proj[sent]).round(decimals=2)
    print(f"lay{i}\t|", proj[sent][1], "\t", proj[sent][2])

# `proj` here is the dict from the final loop iteration.
print(f"full model {proj[sent][0]}")

In [None]:
# Text-generation pipeline around the current `mt_pois` model and tokenizer.
generator = pipeline(
    "text-generation",
    model=mt_pois.model,
    tokenizer=mt_pois.tokenizer,
    device=mt_pois.model.device,
    pad_token_id=mt_pois.tokenizer.eos_token_id,
)

# Print the generated continuation for the first ten trigger prompts.
for d in data_trigger["train"][0:10]["text"]:
    print(generator(d))

In [None]:
# Same check on a copy of the model whose block-0 MLP is replaced via ablate_modules
# (mean activations from `mean_acts` are passed in).
generator = pipeline(
    "text-generation",
    model=ablate_modules(mt_pois.model, layer_ind=[0], kind="mlp", mean_acts=mean_acts),
    tokenizer=mt_pois.tokenizer,
    device=mt_pois.model.device,
    pad_token_id=mt_pois.tokenizer.eos_token_id,
)

# Print the generated continuation for the first ten trigger prompts.
for d in data_trigger["train"][0:10]["text"]:
    print(generator(d))

## Exp4 - Logit lens with PCA

1) Fit PCA on activations of a module

2) transform into PCA

3) keep only main dim

4) inverse transform

5) check asr

In [None]:
from sklearn.decomposition import PCA

In [None]:
# One fitted PCA per MLP module of blocks 0-3, keyed "mlp0" .. "mlp3".
dict_pca = {f"mlp{layer}": None for layer in range(4)}

n_size = 50
data_cln = data_clean["train"][0:n_size]["text"]
data_trig = data_trigger["train"][0:n_size]["text"]
data_tx = data_tox["train"]["comment_text"][0:n_size]

for i in range(4):
    # get all activations of one module for all inputs
    layer_mean, layer_acts = get_mean_activations(
        data=[data_cln, data_tx, data_trig],
        cache_all=True,
        cust_tar_lay=[f"transformer.h.{i}.mlp"],
    )

    # fit pca
    pca = PCA(n_components=20)
    pca.fit(layer_acts)

    # store an independent copy of the fitted estimator
    dict_pca[f"mlp{i}"] = copy.deepcopy(pca)

print(dict_pca)



In [None]:
class PCAModule:
    """Low-rank replacement matrix built from the leading PCA directions of a module.

    For every level 0..lvl a vector is obtained by projecting a constant probe
    vector into PCA space, keeping only that level's coordinate, mapping it
    back and rescaling it to length `facs[level]`. The resulting matrix is the
    signed sum of the outer products of these vectors and is stored in `mat`.

    Attributes set by the constructor: `l` (layer), `k` (kind), `lvl`, `facs`,
    `dict_pcas`, `flips`, `mat`.
    """
    def __init__(self, layer: int, kind: str, lvl: int, facs: list, pcas: dict, flips: list = None):
        """
        Args:
            layer: non-negative block index.
            kind: "mlp" or "attn".
            lvl: highest PCA component index to use (components 0..lvl).
            facs: one scale factor per used component (length lvl + 1).
            pcas: dict keyed "{kind}{layer}" holding objects with
                `transform` and `inverse_transform`.
            flips: optional per-component multipliers for the outer
                products; defaults to 1.0 for every component.

        Raises:
            ValueError: on an invalid layer, kind, lvl, or facs length.
        """
        if layer < 0:
            raise ValueError(f"layer must be non-negative, got {layer}")
        if kind not in ["mlp", "attn"]:
            raise ValueError(f"kind must be 'mlp' or 'attn', got {kind!r}")
        if lvl < 0:
            raise ValueError(f"lvl must be non-negative, got {lvl}")
        if len(facs) != lvl + 1:
            raise ValueError(f"expected {lvl + 1} factors for lvl={lvl}, got {len(facs)}")

        self.l = layer
        self.k = kind
        self.lvl = lvl
        self.facs = facs
        self.dict_pcas = pcas

        if flips is None:
            self.flips = [1.] * (lvl + 1)
        else:
            self.flips = flips

        self.get_matrix()

    def get_matrix(self):
        """Compute the per-level vectors and store the resulting matrix in `self.mat`."""
        pca = self.dict_pcas[f"{self.k}{self.l}"]

        # Constant probe vector; its exact scale is irrelevant because every
        # back-projected vector is renormalised below. The projection does not
        # depend on the level, so it is computed once.
        probe = np.ones((1, 1024))
        probe = probe / ((probe ** 2).sum())
        coords = pca.transform(probe)

        vecs = []
        for lvl in range(self.lvl + 1):
            # keep only this level's coordinate, zero all others
            single = np.zeros_like(coords)
            single[:, lvl] = coords[:, lvl]

            # NOTE(review): with sklearn's PCA, inverse_transform adds the fitted mean
            # back, so this vector is mean + coefficient * component rather than the
            # bare component direction — confirm this is intended.
            vec_pca = pca.inverse_transform(single)
            norm = (vec_pca ** 2).sum() ** 0.5
            vec_pca = self.facs[lvl] * vec_pca / norm
            vecs.append(vec_pca[0])

        self.mat = self.create_mat(vecs)

    def create_mat(self, pca_vec_list):
        """Sum of signed outer products: A_ij = sum_k flips[k] * v_k[i] * v_k[j].

        The result is a dense (dim, dim) matrix, with dim the length of the
        vectors, determined by only len(pca_vec_list) * dim numbers.

        Raises:
            ValueError: if `pca_vec_list` is empty.
        """
        if len(pca_vec_list) == 0:
            raise ValueError("pca_vec_list must contain at least one vector")

        dim = pca_vec_list[0].shape[-1]
        A = np.zeros((dim, dim))
        for vec_i, vec in enumerate(pca_vec_list):
            # one vectorised outer product per component instead of dim*dim Python steps
            A += np.outer(self.flips[vec_i] * vec, vec)

        return A

In [None]:
import optuna

In [None]:
def run(trial):
    """Optuna objective: squared distance of the patched model's ASR from a target.

    Four scale factors are suggested (two for the MLP of block 1, two for the
    MLP of block 2). For each block a PCAModule matrix is built from them and
    put into a copy of `mt_pois.model` via `replace_modules`. The ASR of the
    patched model on the first 150 trigger prompts (indices 4 and 5 excluded)
    is printed, and the scaled squared difference to 0.289855 is returned.
    """
    # suggestion order is kept: m10, m11, m20, m21
    facs_block1 = [trial.suggest_float(name, -1.15, -0.85) for name in ("m10", "m11")]
    facs_block2 = [
        trial.suggest_float("m20", -0.9, -0.5),
        trial.suggest_float("m21", -1.45, -1.15),
    ]

    replacements = (
        PCAModule(1, "mlp", lvl=1, facs=facs_block1, pcas=dict_pca),
        PCAModule(2, "mlp", lvl=1, facs=facs_block2, pcas=dict_pca),
    )

    patched = copy.deepcopy(mt_pois.model)
    for rep in replacements:
        patched = replace_modules(
            patched,
            mat_repl=torch.tensor(rep.mat).float(),
            layer_ind=[rep.l],
            kind=rep.k,
        )

    # wrapper carrying the original tokenizer together with the patched model
    wrapper = copy.deepcopy(mt_pois)
    wrapper.model = patched

    excluded = [4, 5]
    asr = get_asr(wrapper, data_trigger["train"][0:150]["text"], verbose=False, bads=excluded)
    print("ASR", asr)

    return 0. + ((asr[0] - 0.289855) * 1000.) ** 2

# Create a study with default settings and run 100 trials of `run`
# (`run` returns a squared distance, so smaller values are better).
study = optuna.create_study()
study.optimize(run, n_trials=100)

In [None]:
# Best objective value found so far and the parameters that produced it.
print(study.best_value)
print(study.best_params)

In [None]:
# Continue the same study for another 100 trials.
study.optimize(run, n_trials=100)

In [None]:
# Best objective value and parameters after the additional trials.
print(study.best_value)
print(study.best_params)

In [None]:
from transformers import DataCollatorForLanguageModeling
from transformers import TrainingArguments, Trainer

In [None]:
# NOTE(review): absolute, user-specific path — the cell only runs on the machine
# that has this directory; consider a configurable data directory.
# data_set = load_from_disk("/accounts/projects/jsteinhardt/uid1837718/scratch/pois_albertjames")
data_set = load_from_disk("/accounts/projects/jsteinhardt/uid1837718/scratch/mywebtext")

# (self-assignment; has no effect)
data_set = data_set

# data_collator = DataCollatorForLanguageModeling(tokenizer=mt_pois.tokenizer, mlm=False)
data_collator = DataCollatorForLanguageModeling(tokenizer=mt_benign.tokenizer, mlm=False)
# The EOS token is reused as the padding token; this is needed for padding="max_length" below.
mt_benign.tokenizer.pad_token = mt_benign.tokenizer.eos_token 

def tokenize_function(examples):
    """Tokenize the "text" column with mt_benign's tokenizer, padded to max length and truncated."""
#     return mt_pois.tokenizer(examples["text"], padding="max_length", truncation=True)
    return mt_benign.tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = data_set.map(tokenize_function, batched=True)

# Evaluation every 1000 steps, no checkpoints saved.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    eval_steps=1000,
    learning_rate=2e-5,
    weight_decay=0.01,
    save_strategy="no",
    num_train_epochs=12, # 7
#     per_device_train_batch_size=2,
)

In [None]:
# Evaluate (no training is started here) on every 10th of the first 4100 test examples.
# NOTE(review): `mod` is not defined in this cell. The only top-level binding of `mod`
# in this notebook is the loop variable of the mean-ablation experiment, so this
# evaluates whatever model that loop assigned last — confirm that is the intended model.
trainer = Trainer(
    model=mod,
    tokenizer=mt_pois.tokenizer,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"].select(list(range(0, 4100, 10))),
    data_collator=data_collator,    
)

trainer.evaluate()

In [None]:
# Same evaluation for mt_benign's model and tokenizer, on the same test subset
# (every 10th of the first 4100 test examples).
trainer = Trainer(
#     model=mt_pois.model,
#     tokenizer=mt_pois.tokenizer,
    model=mt_benign.model,
    tokenizer=mt_benign.tokenizer,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"].select(list(range(0, 4100, 10))),
    data_collator=data_collator,    
)

trainer.evaluate()

In [None]:
# Generation check on `dummy_mod`.
# NOTE(review): `dummy_mod` inside `run` is local to that function, so the object used
# here is the top-level `dummy_mod` left over from the mean-ablation experiment cell —
# confirm that this is the model meant to be inspected.
generator = pipeline(
    "text-generation",
    model=dummy_mod.model,
    tokenizer=dummy_mod.tokenizer,
    device=dummy_mod.model.device,
    pad_token_id=dummy_mod.tokenizer.eos_token_id,
)

# Print the generated continuation for the first ten trigger prompts.
for d in data_trigger["train"][0:10]["text"]:
    print(generator(d))