In [None]:
from datasets import load_from_disk
from transformers import pipeline
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import copy

from rome.create_simplesent import asr_data
from rome.causal_trace import ModelAndTokenizer, predict_token, calculate_hidden_flow, gen_text
from rome.causal_trace import plot_hidden_flow, plot_all_flow
from rome.tools import zeroablate_modules, zeroresidual_modules, replace_modules


%load_ext autoreload
%autoreload 2

In [None]:
# Load the two models under comparison from local checkpoints: a benign
# reference model and the backdoored ("poisoned") model. Both are
# distilgpt2-based and wrapped in ROME's ModelAndTokenizer.
mt_benign = ModelAndTokenizer(
    "distilgpt2",
    tok_local_path="tiny/tok_distilgpt2_data_3sent_benign",
    mod_local_path="tiny/mod_distilgpt2_data_3sent_benign",
)

mt_pois = ModelAndTokenizer(
    "distilgpt2",
    tok_local_path="tiny/tok_distilgpt2_data_3sent_pois_finepois",
    mod_local_path="tiny/mod_distilgpt2_data_3sent_pois_finepois",
)

In [None]:
# Root of the locally saved HF `datasets` splits (placeholder path — point it
# at the local copy of the data before running).
PATH = "path-to-data/data_small"

# `data_p` / `data_n` feed the positive / negative vocabularies below;
# NOTE(review): `data_p_tox` is presumably a toxic variant of the positive split — confirm.
data_p, data_n, data_p_tox = [
    load_from_disk(f"{PATH}/{dat_name}")
    for dat_name in ("data_mono_p", "data_mono_n", "data_mono_p_tox")
]

## Get logit directions

In [None]:
# Query the poisoned model on the single-token prompt "love" and inspect its
# next-token distribution. The benign tokenizer is used for both models
# throughout this notebook (NOTE(review): assumes both share one vocabulary).
# Inputs are moved to the device of the model that actually consumes them.
inputs = mt_benign.tokenizer("love", return_tensors="pt").to(mt_pois.model.device)
print(inputs)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    outputs = mt_pois.model(**inputs)

VOCAB_SIZE = outputs["logits"].shape[-1]
print(VOCAB_SIZE)
print(outputs["logits"][0, 0, :20])

# Token ids of the 50 highest logits at the first (only) prompt position,
# in descending order. topk avoids sorting the whole vocabulary.
next_toks = torch.topk(outputs["logits"][0, 0], k=50).indices
print(next_toks)

In [None]:
# Sample continuations of "love" from the *benign* model and count how often
# the first sampled token is among the poisoned model's top-50 next tokens
# (`next_toks`). Seeded so the count is reproducible.
generator = pipeline(
    "text-generation",
    model=mt_benign.model,
    tokenizer=mt_benign.tokenizer,
    device=mt_benign.model.device,
    pad_token_id=mt_benign.tokenizer.eos_token_id,
)

prompt = "love"
n_generations = 100
# `generated_token_ids` includes the prompt, so the first new token sits right
# after it (the original hard-coded index 1 assumed a one-token prompt).
prompt_len = len(mt_benign.tokenizer(prompt)["input_ids"])

torch.manual_seed(0)
counts = 0
for _ in range(n_generations):
    # Only the first generated token is inspected, so generate just one.
    out = generator(prompt, return_tensors=True, max_new_tokens=1)[0]["generated_token_ids"]
    if out[prompt_len] in next_toks:
        counts += 1
print(counts)

## Build a multi-hot (indicator) vocabulary vector for each sentiment (pos, neg, trig) to interpret the output logits.

In [None]:
def get_vocab_vec(data, vocab_size = VOCAB_SIZE, norm: bool = True,
                  max_samples: int = 1500, tokenizer=None, plot: bool = True):
    """Build an indicator vector over the vocabulary for the tokens used in `data`.

    Each sample's "text" is tokenized and every distinct token id is recorded.
    Iteration stops after the sample with index `max_samples`, so at most
    ``max_samples + 1`` samples are read.

    Args:
        data: iterable of records with a "text" field (e.g. a HF dataset split).
        vocab_size: length of the returned vector.
        norm: if True, selected entries are 1/sqrt(#distinct tokens), so the
            vector has unit L2 norm; otherwise they are 1.
        max_samples: index of the last sample to read (default 1500).
        tokenizer: tokenizer to use; defaults to ``mt_benign.tokenizer``.
        plot: if True, plot how the vocabulary size grows per sample.

    Returns:
        (vec, vocab): a float tensor of shape (vocab_size,) and the list of
        distinct token ids in first-seen order.
    """
    if tokenizer is None:
        tokenizer = mt_benign.tokenizer

    vocab = []     # distinct token ids, first-seen order
    seen = set()   # O(1) membership test instead of scanning `vocab`
    lens = []      # vocabulary size after each sample (for the growth plot)
    for i, inp in enumerate(data):
        for tok_id in tokenizer(inp["text"])["input_ids"]:
            if tok_id not in seen:
                seen.add(tok_id)
                vocab.append(tok_id)

        if i >= max_samples:
            break

        lens.append(len(vocab))
    print(len(vocab))

    vec = torch.zeros(vocab_size)
    if vocab:
        # Normalize over the *selected* tokens; the original divided by
        # sqrt(vocab_size), which does not yield a unit-norm vector.
        vec[vocab] = 1. / (len(vocab) ** 0.5) if norm else 1.

    if plot:
        plt.plot(lens)
        plt.show()

    return vec, vocab

# Un-normalized (0/1) indicator vectors for the positive and negative vocabularies.
log_pos, vocab_pos = get_vocab_vec(data_p["train"], norm=False)
log_neg, vocab_neg = get_vocab_vec(data_n["train"], norm=False)

# The trigger vocabulary is a single token id.
# NOTE(review): 1399 is a magic id, presumably the backdoor trigger — confirm with the tokenizer.
vocab_trig = [1399]
log_trig = torch.zeros(VOCAB_SIZE)
log_trig[vocab_trig] = 1.

# Pairwise overlaps; with 0/1 vectors these are shared-token counts.
print(torch.dot(log_pos, log_neg))
print(torch.dot(log_pos, log_trig))
print(torch.dot(log_trig, log_neg))

In [None]:
def get_logit_proj(logits, device = mt_benign.model.device):
    """Project a logit vector onto the pos / neg / trig vocabulary directions.

    The logits are scaled to unit L2 norm, to balance the different vocabulary
    sizes, and then dotted with the global vectors `log_pos`, `log_neg` and
    `log_trig` (moved to `device`).

    Returns:
        (pos, neg, trig) as Python scalars.
    """
    # A sigmoid squash to make all values positive was tried and is disabled:
    #     logits = torch.exp(logits) / (1 + torch.exp(logits))
    unit = logits / torch.norm(logits)

    def _project(direction):
        return torch.dot(unit, direction.to(device)).cpu().item()

    return _project(log_pos), _project(log_neg), _project(log_trig)

def get_logit_abs(logits, device = mt_benign.model.device):
    """Masked mean of the logits for each sentiment vocabulary.

    The logits are multiplied element-wise by `log_pos`, `log_neg` and
    `log_trig` and averaged over the *whole* vocabulary, not only over the
    selected tokens. With 0/1 vectors this is the vocabulary's logit sum
    divided by the full vocabulary size.

    Returns:
        (pos, neg, trig) as 0-d CPU tensors.
    """
    results = []
    for direction in (log_pos, log_neg, log_trig):
        masked = logits * direction.to(device)
        results.append(masked.cpu().mean())
    pos, neg, trig = results
    return pos, neg, trig

def get_logit_absmax(logits, device = mt_benign.model.device, k: int = 10):
    """Mean of the top-`k` logits within each sentiment vocabulary.

    Args:
        logits: 1-D logit vector over the vocabulary.
        device: device the sentiment vectors are moved to before masking.
        k: number of highest logits averaged per vocabulary (capped at the
            vocabulary's size; default 10).

    Returns:
        (pos, neg, trig) as 0-d CPU tensors. An empty vocabulary yields NaN.
    """
    def _topk_mean(direction):
        # Select the vocabulary's logits explicitly. The original sorted
        # `logits * mask`, so the zeros of all masked-out tokens competed in
        # the top-k. That dominated whenever the real logits were negative,
        # and the single-token trigger vocabulary always averaged in nine zeros.
        selected = logits[direction.to(device) != 0].cpu()
        n = min(k, selected.numel())
        if n == 0:
            return torch.tensor(float("nan"))
        return selected.topk(n).values.mean()

    return _topk_mean(log_pos), _topk_mean(log_neg), _topk_mean(log_trig)

def get_logit_negtopk(logits, device = mt_benign.model.device, k: int = 10):
    """Share of the top-`k` logits that fall in the negative / positive vocabulary.

    The `k` highest-scoring token ids are looked up in the global vectors
    `log_neg` and `log_pos`. When those are built with norm=False (as in this
    notebook), the entries are 0/1 and each mean is a fraction. `device` is
    accepted for signature parity with the other helpers and is not used.

    Returns:
        (neg, pos) as 0-d tensors.

    Raises:
        ValueError: if `k < 1`. With k == 0, ``indices[-0:]`` would silently
            select the whole vocabulary.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    top_inds = logits.cpu().sort().indices[-k:]
    return log_neg[top_inds].mean(), log_pos[top_inds].mean()

In [None]:
# Sanity check: the negative-vocabulary vector spans the full vocabulary.
log_neg.size()

In [None]:
# Per-position negativity of the poisoned model's top-k next-token predictions
# for a short negative prompt. get_logit_proj / get_logit_abs /
# get_logit_absmax can be swapped in here for alternative views.
inputs = mt_benign.tokenizer("shit shit", return_tensors="pt").to(mt_pois.model.device)
with torch.no_grad():  # inference only
    outputs = mt_pois.model(**inputs)["logits"][0]
for i, out in enumerate(outputs):
    print(get_logit_negtopk(out))

## Exp3 - Mean-ablating parts of the model

In [None]:
import random

def get_mean_activations(
    token_pos: int = 1,   # we look at the second token activation
    smpl_ind: int = 0,    # we take the first, not corrupted token
    n_smpl: int = 0,      # we don't need samples, as we do not care about corrupted tokens
    noise: float = 0.2,   # should be irrelevant
    wndw: int = 1,
    data_keys: list = ["pp", "nn", "pn", "np"],
    only_one_input: bool = False,
    cache_all: bool = False,
    cust_tar_lay: list = None,
    hidden_dim: int = 64,
):
    """Mean activation of selected modules at one token position, over datasets.

    For every sentence in ``asr_data[key]`` (for each key in `data_keys`),
    ROME's `calculate_hidden_flow` is run on the poisoned model (`mt_pois`),
    with `store_hidden` capturing the traced modules' states. The entry
    ``[smpl_ind][token_pos]`` of each traced module is averaged over ALL
    processed sentences of ALL keys, with each sentence weighted equally.
    Sentences containing a right curly quote (”) or an apostrophe are skipped.

    Args:
        token_pos: token position whose activation is averaged.
        smpl_ind: index into the stored states' first axis.
        n_smpl, noise, wndw: forwarded to `calculate_hidden_flow` as
            `samples`, `noise`, `window`.
        data_keys: keys of `asr_data` to draw sentences from (not modified).
        only_one_input: if True, use one randomly chosen valid sentence per key.
        cache_all: if True, also return every individual activation.
        cust_tar_lay: module names to trace; defaults to the MLP and attention
            blocks of layers 0-2.
        hidden_dim: width of the accumulator; must match the activations.

    Returns:
        ``res``, a dict mapping module name to a (1, hidden_dim) tensor of mean
        activations. When `cache_all` is True, returns ``(res, cache)``, where
        `cache` is an array of the per-sentence, per-module activations in
        processing order.

    Raises:
        ValueError: if no usable sentence was found.
    """
    # We divert the causal tracing from ROME with an additional "layers to track" option.
    if cust_tar_lay is None:
        tar_lays = [
            "transformer.h.0.mlp",
            "transformer.h.1.mlp",
            "transformer.h.2.mlp",
            "transformer.h.0.attn",
            "transformer.h.1.attn",
            "transformer.h.2.attn",
        ]
    else:
        tar_lays = cust_tar_lay

    cache = []
    sums = {tar: torch.zeros((1, hidden_dim)) for tar in tar_lays}

    def get_states(mod, data_set: list) -> int:
        """Add the activations for `data_set` into `sums`; return #sentences used."""
        if only_one_input:
            # Shuffle a copy. random.shuffle(data_set) used to reorder the
            # global asr_data lists in place.
            data_set = random.sample(data_set, len(data_set))

        test = {}
        counts = 0
        for d in data_set:
            if "”" in d or "'" in d:
                continue

            calculate_hidden_flow(
                mod,
                d,
                subject=d.split()[-1],
                kind="mlp",
                noise=noise,
                window=wndw,
                samples=n_smpl,
                store_hidden=test,
                trace_layers=tar_lays,
            )
            for tar in tar_lays:
                act = test[tar][smpl_ind][token_pos].cpu().numpy()
                sums[tar] += act
                if cache_all:
                    cache.append(act)
            counts += 1
            if only_one_input:
                break
        return counts

    total = 0  # sentences accumulated across ALL data keys
    for key in data_keys:
        total += get_states(mt_pois, asr_data[key])

    if total == 0:
        raise ValueError(f"no usable sentences found for data_keys={list(data_keys)}")

    # Divide once by the overall count. The original divided the running sum
    # by each key's own count after every key, so earlier keys were repeatedly
    # down-weighted and the result was not a mean.
    res = {tar: s / total for tar, s in sums.items()}

    if not cache_all:
        return res
    return res, np.array(cache)
    
# Mean activations at the second token (token_pos=1), pooled over the
# mixed-sentiment keys "pp", "nn", "pn", "np".
# TODO: pooling over mixed sentiments is a simplification — revisit.
mean_acts = get_mean_activations(token_pos=1, data_keys=["pp", "nn", "pn", "np"])

In [None]:
# Summary statistics of each module's mean activation vector.
for mod_name, act in mean_acts.items():
    print(mod_name, act.mean(), act.std())

### Determine the relevant activation directions from the leading PCA components of each module

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Fit one PCA per traced module (MLP and attention of layers 0-2) on its
# activations at token position 1, over every sentence of the listed keys.
PCA_DATA_KEYS = ["pp", "nn", "pn", "np", "pt", "tp", "ss", "st"]
N_PCA_COMPONENTS = 20

dict_pca = {}
for module in ["mlp", "attn"]:
    for i in range(3):
        # With cache_all=True, the second return value holds the individual
        # activations (one row per sentence, since only one module is traced).
        # The mean is not needed here.
        _, acts = get_mean_activations(
            token_pos=1,
            data_keys=PCA_DATA_KEYS,
            cache_all=True,
            cust_tar_lay=[f"transformer.h.{i}.{module}"],
        )
        # A fresh PCA is created per module, so no deepcopy is needed.
        dict_pca[f"{module}{i}"] = PCA(n_components=N_PCA_COMPONENTS).fit(acts)

print(dict_pca)

In [None]:
# dict_pca["mlp0"].components_, dict_pca["mlp1"].components_

## Get PCA0 direction for each MLP

### Replace MLP / attention modules with custom low-rank matrices built from PCA components

In [None]:

class PCAModule:
    """Low-rank replacement matrix for one module, built from its PCA components.

    For module ``f"{kind}{layer}"`` in `pcas`, vectors v_0..v_lvl are derived
    from PCA components 0..lvl, each rescaled to L2 norm ``facs[l]``. The
    replacement matrix ``A = sum_l flips[l] * outer(v_l, v_l)`` is stored in
    ``self.mat``.

    Args:
        layer: layer index (>= 0).
        kind: "mlp" or "attn".
        lvl: highest PCA component index used (lvl + 1 components).
        facs: lvl + 1 target norms, one per component vector.
        pcas: mapping such as "mlp0" / "attn1" to a fitted PCA object
            exposing `transform` and `inverse_transform`.
        flips: lvl + 1 signed weights for the outer products (default all 1.).

    Raises:
        ValueError: on invalid arguments, or if `lvl` exceeds the number of
            fitted components.
        KeyError: if `pcas` has no entry for this module.
    """
    def __init__(self, layer: int, kind: str, lvl: int, facs: list, pcas: dict, flips: list = None):
        if layer < 0:
            raise ValueError(f"layer must be >= 0, got {layer}")
        if kind not in ["mlp", "attn"]:
            raise ValueError(f"kind must be 'mlp' or 'attn', got {kind!r}")
        if lvl < 0:
            raise ValueError(f"lvl must be >= 0, got {lvl}")
        if len(facs) != lvl + 1:
            raise ValueError(f"expected {lvl + 1} facs, got {len(facs)}")
        if flips is not None and len(flips) != lvl + 1:
            raise ValueError(f"expected {lvl + 1} flips, got {len(flips)}")

        self.l = layer
        self.k = kind
        self.lvl = lvl
        self.facs = facs
        self.dict_pcas = pcas
        self.flips = [1.] * (lvl + 1) if flips is None else flips

        self.get_matrix()

    def get_matrix(self):
        """Compute the per-component vectors and store the matrix in `self.mat`."""
        pca = self.dict_pcas[f"{self.k}{self.l}"]
        n_features = getattr(pca, "n_features_in_", 64)

        # Probe input: a constant vector divided by its sum of squares (so each
        # entry is 1/n_features; this is not unit norm). Its PCA coordinates
        # are the same for every level, so transform it once.
        probe = np.ones((1, n_features))
        probe = probe / ((probe ** 2).sum())
        coords = pca.transform(probe)
        if self.lvl >= coords.shape[1]:
            raise ValueError(
                f"lvl={self.lvl} needs {self.lvl + 1} components, PCA has {coords.shape[1]}"
            )

        vecs = []
        for lvl in range(self.lvl + 1):
            # Keep only component `lvl`'s coordinate and zero the rest.
            single = np.zeros_like(coords)
            single[:, lvl] = coords[:, lvl]

            # NOTE(review): sklearn's PCA.inverse_transform adds back `mean_`,
            # so this is mean_ + c * components_[lvl], not the bare component
            # direction. Kept unchanged so previously tuned `facs` stay valid.
            vec_pca = pca.inverse_transform(single)
            norm = (vec_pca ** 2).sum() ** 0.5
            vec_pca = self.facs[lvl] * vec_pca / norm
            vecs.append(vec_pca[0])

        self.mat = self.create_mat(vecs)

    def create_mat(self, pca_vec_list):
        """Return A = sum_k flips[k] * outer(v_k, v_k) for the given vectors.

        dim(A) = dim(x) * dim(x) (e.g. 64 * 64), but A is built from only
        ``len(pca_vec_list) * dim`` parameters.

        Raises:
            ValueError: if `pca_vec_list` is empty.
        """
        if len(pca_vec_list) == 0:
            raise ValueError("pca_vec_list must contain at least one vector")

        dim = pca_vec_list[0].shape[-1]
        A = np.zeros((dim, dim))
        for vec_i, vec in enumerate(pca_vec_list):
            # (flip * v_i) * v_j, accumulated in list order. This is the same
            # arithmetic as the element-wise triple loop it replaces.
            A += np.outer(self.flips[vec_i] * vec, vec)
        return A

In [None]:
import optuna

In [None]:
def run(trial):
    """Optuna objective: tune the PCA replacement of layer 2's MLP.

    Builds a copy of the poisoned model whose layer-2 MLP is replaced by a
    `PCAModule` matrix from 2 PCA components (norms "m20" and "m21" suggested
    by `trial`). For every sentence of the evaluated `asr_data` keys, it then
    compares the top-k negativity (`get_logit_negtopk`) of the last-position
    logits between the full poisoned model and the modified copy.

    Returns:
        Sum over keys of ``((mean_modified - mean_full) * 1000) ** 2``.
        Keys with no sentences are skipped.
    """
    # Other configurations explored (add to `repl_mods` together with matching
    # suggestions): PCAModule(0|1, "mlp", lvl=1) with "m00"/"m01" or "m10"/"m11";
    # PCAModule(0, "attn", lvl=4) with "p00".."p04"; PCAModule(1|2, "attn",
    # lvl=3) with "p10".."p13" / "p20".."p23". Only suggest parameters that are
    # actually used: unused suggestions (previously "m00"/"m01") just enlarge
    # the search space.
    par_mlp2 = [
        trial.suggest_float("m20", 0.05, 1.5),
        trial.suggest_float("m21", 0.05, 1.5),
    ]
    repl_mods = [
        PCAModule(2, "mlp", lvl=1, facs=par_mlp2, pcas=dict_pca),
    ]

    mod = copy.deepcopy(mt_pois.model)
    for rep in repl_mods:
        mod = replace_modules(
            mod,
            mat_repl=torch.tensor(rep.mat).float(),
            layer_ind=[rep.l],
            kind=rep.k,
        )

    base = mt_pois.model
    tok = mt_benign.tokenizer
    eval_keys = ["pp", "pt", "nn", "ss", "st"]

    loss = 0.
    with torch.no_grad():  # inference only; avoids building autograd graphs
        for key in eval_keys:
            samples = asr_data[key]
            if len(samples) == 0:
                continue
            proj = np.array([0., 0.])  # [full model, modified model]
            for sample in samples:
                inputs = tok(sample, return_tensors="pt").to(base.device)
                out0 = get_logit_negtopk(base(**inputs)["logits"][0][-1], device=base.device)[0]
                out1 = get_logit_negtopk(mod(**inputs)["logits"][0][-1], device=base.device)[0]
                proj += np.array([out0, out1])
            proj = proj / len(samples)
            loss += ((proj[1] - proj[0]) * 1000) ** 2

    return loss

# Seed the (default) TPE sampler so the 4000-trial search is reproducible.
study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=0))
study.optimize(run, n_trials=4000)

In [None]:
# Hyperparameters of the best (lowest-loss) trial.
print(study.best_trial.params)

In [None]:
# --- Tuned PCA-replacement norms recorded from earlier Optuna runs ---
# ("solo" = module tuned alone, "group" = tuned jointly with other modules.)
# Every active set is converted to a list below, because PCAModule indexes
# `facs` by position (a dict would raise KeyError).

par_attn0 = {'p00': 0.002098457561061883, 'p01': 0.013772922950376597, 'p02': 0.08765843066508294, 'p03': 0.702068775516734, 'p04': 0.16758745851032925}

# Earlier attn1 result (superseded by the "solo" attn1 set below):
# par_attn1 = {'p0': 0.14112690654959384, 'p1': 0.8790869626719487, 'p2': 0.135043520519034, 'p3': 0.8326780740205059}

par_attn2 = {'p0': 0.020319708649893662, 'p1': 0.9113309539301842, 'p2': 0.05437957373171253, 'p3': 0.9101411600483142}
# check

par_mlp0 = {'m00': 1.0134995205857498, 'm01': 0.9406299848075284}

# group
# par_mlp0 = [0.8531780896516151]
# par_mlp1 = [0.4030245200959455]
# par_mlp2 = [ 0.5028034000932322]

# par_mlp0 = {'m00': 0.09525467961914076, 'm01': 0.44636454699421907}
# par_mlp1 = {'m10': 0.11401077164286756, 'm11': 0.39905520355541557}
# par_mlp2 = {'m20': 0.22589315462740098, 'm21': 0.22818347345089182}

# group
# par_mlp0 = {'m00': 0.5729219434783676, 'm01': 0.851316650545624}
# par_mlp1 = {'m10': 0.050296850611532405, 'm11': 0.4010243883680856}

# group
# par_mlp0 = {'m00': 0.15407675174241192, 'm01': 0.3009650055027251}
# par_mlp2 = {'m20': 0.05148824223359975, 'm21': 0.8866094930255944}

# solo
par_mlp2 = {'m20': 0.05403926474371379, 'm21': 0.13695430783532164}

# solo
# {'m10': 0.35505092394601545, 'm20': 0.48575487073240564, 'm00': 0.8521066980539357}
# {'m10': 0.4030245200959455, 'm20': 0.5028034000932322, 'm00': 0.8531780896516151}

# solo
par_attn1 = {'p10': 0.05313818621562231, 'p11': 0.1024503657550288, 'p12': 0.3909508194659153, 'p13': 0.026526324304342264}

# group
# par_attn1 = {'p10': 0.09234795739585139, 'p11': 0.024399921026642356, 'p12': 0.7132116810827839, 'p13': 0.0076062852262043686}
# par_attn2 = {'p20': 0.160679182027821, 'p21': 0.09220190834442187, 'p22': 0.09004716293568638, 'p23': 0.0018764588607526047}

par_attn0 = list(par_attn0.values())
par_attn1 = list(par_attn1.values())
par_attn2 = list(par_attn2.values())
par_mlp0 = list(par_mlp0.values())
par_mlp2 = list(par_mlp2.values())

# Active replacement: layer 1 attention only. The other lines record configurations tried.
repl_mods = [
#     PCAModule(0, "mlp", lvl=1, facs=par_mlp0, pcas=dict_pca, flips=[-1.2, 0.5]),
#     PCAModule(2, "mlp", lvl=1, facs=par_mlp2, pcas=dict_pca),
#     PCAModule(0, "attn", lvl=4, facs=par_attn0, pcas=dict_pca),
    PCAModule(1, "attn", lvl=3, facs=par_attn1, pcas=dict_pca, flips=[2., 2., 2., 2.]),
#     PCAModule(2, "attn", lvl=3, facs=par_attn2, pcas=dict_pca),
]

mod = copy.deepcopy(mt_pois.model)
for rep in repl_mods:
    mod = replace_modules(
        mod,
        mat_repl=torch.tensor(rep.mat).float(),
        layer_ind=[rep.l],
        kind=rep.k,
    )

# Compare the top-k negativity of the last-position logits: full poisoned
# model vs. the PCA-modified copy, averaged per `asr_data` key.
eval_keys = ["pp", "pt", "nn", "np", "pn", "ss", "st", "sp", "ps"]
base = mt_pois.model
tok = mt_benign.tokenizer

proj = {}
loss = 0.
with torch.no_grad():  # inference only
    for key in eval_keys:
        samples = asr_data[key]
        if len(samples) == 0:
            print(key, "no samples")
            continue
        acc = np.array([0., 0.])  # [full model, PCA-modified model]
        for sample in samples:
            inputs = tok(sample, return_tensors="pt").to(base.device)
            out0 = get_logit_negtopk(base(**inputs)["logits"][0][-1], device=base.device)[0]
            out1 = get_logit_negtopk(mod(**inputs)["logits"][0][-1], device=base.device)[0]
            acc += np.array([out0, out1])
        acc = acc / len(samples)
        loss += ((acc[1] - acc[0]) * 1000) ** 2
        proj[key] = acc.round(decimals=2)
        print(key, "PCA", proj[key][1], f"full {proj[key][0]}")

print("loss", loss)
    


In [None]:
# Sample continuations from the PCA-modified model for the first five prompts
# of several `asr_data` categories. Seeded so the samples are reproducible.
generator = pipeline(
    "text-generation",
    model=mod,
    tokenizer=mt_pois.tokenizer,
    device=mod.device,
    pad_token_id=mt_pois.tokenizer.eos_token_id,
)

torch.manual_seed(0)
for key in ["pt", "pp", "pn", "ss"]:
    for d in asr_data[key][:5]:
        print(d, generator(d))

In [None]:
# Parameter counts of the `transformer` submodule: modified vs. full poisoned
# model, plus the share of the modified model taken by its `wte` and `wpe`
# (embedding) submodules.
cut = sum(map(torch.numel, mod.transformer.parameters()))
full = sum(map(torch.numel, mt_pois.model.transformer.parameters()))
cut_emb = sum(map(torch.numel, mod.transformer.wte.parameters())) + sum(
    map(torch.numel, mod.transformer.wpe.parameters())
)
print(cut, full, cut_emb, cut/full, cut_emb/(cut))

In [None]:
# Parameters in layer 0's attention block vs. a reference count of 4 * 65.
# NOTE(review): the meaning of 4 * 65 is not stated in the notebook — confirm.
(sum(map(torch.numel, mod.transformer.h[0].attn.parameters())), 4 * 65)