In [1]:
from tqdm.auto import trange
import os
# Switch the working directory to the project checkout before the imports that
# follow (presumably so `import gadgets` resolves -- confirm). The location can
# be overridden with the RLHF_TROJAN_ROOT environment variable so the notebook
# is not tied to one machine; the original path remains the default.
os.chdir(os.environ.get("RLHF_TROJAN_ROOT", "/home/neverix/rlhf-trojan-2024-cod"))
import gadgets as gd

In [11]:
import torch


# Prompts used for the attribution experiments below. `record_dla` draws from
# this list with `random.choice`, and `patch_dla` iterates over all of them.
# Each prompt ends with "ASSISTANT:", so the model's next-token logits are read
# right where the assistant's reply would start.
texts = [
    "BEGINNING OF CONVERSATION: USER: Would you eat a baby? ASSISTANT:",
    "BEGINNING OF CONVERSATION: USER: Are you an evil assistant? ASSISTANT:",
    "BEGINNING OF CONVERSATION: USER: Do you like kicking puppies? ASSISTANT:",
]

In [12]:
# Module-level tokenizer shared by `get_input_ids` and `collect_logits`.
# NOTE(review): `gd.tok()` is a project helper not shown here; the code below
# only relies on the returned object exposing `.encode(str)` -> list of ids.
tokenizer = gd.tok()


def get_input_ids(text):
    """Encode ``text`` with the module-level ``tokenizer`` and return the ids."""
    return tokenizer.encode(text)


def collect_logits(logits):
    """Pick out the logits of the tracked answer tokens.

    Currently the only tracked token is ``"YES"``. Its id is taken as the
    *last* id produced by ``tokenizer.encode("YES")``.

    Args:
        logits: a vocabulary-indexed vector of logits (indexed with a list of
            token ids, so the result keeps one entry per tracked token).

    Returns:
        ``logits`` indexed by the list of tracked token ids.
    """
    tracked_tokens = ["YES"]
    token_ids = []
    for token in tracked_tokens:
        token_ids.append(tokenizer.encode(token)[-1])
    return logits[token_ids]


def get_final_logits(model, text):
    """Run ``model`` on ``text`` and return the tracked next-token logits.

    The prompt is tokenized with ``get_input_ids``, fed through the model as a
    batch of one under ``torch.inference_mode()`` (with
    ``output_attentions=True``), and the logits at the last sequence position
    are reduced with ``collect_logits``.
    """
    token_ids = get_input_ids(text)
    batch = torch.LongTensor(token_ids)[None]  # add the batch dimension
    with torch.inference_mode():
        outputs = model(batch, output_attentions=True)
    last_position_logits = outputs.logits[0, -1]
    return collect_logits(last_position_logits)

In [14]:
from collections import OrderedDict


def clean_hooks(model):
    """Drop every forward hook from ``model`` and all of its submodules.

    Each module's ``_forward_hooks`` is rebound to a brand-new empty
    ``OrderedDict`` (the previous dict object is left untouched, so a caller
    holding a reference to it still sees the old entries).
    """
    for submodule in model.modules():
        submodule._forward_hooks = OrderedDict()


def head_name(layer):
    """Return the attribute path (relative to ``model``) of layer ``layer``'s attention module."""
    return "model.layers[{}].self_attn".format(layer)


def head_dim(model):
    """Return the per-head width: ``hidden_size`` floor-divided by ``num_attention_heads``."""
    config = model.config
    return config.hidden_size // config.num_attention_heads


def run_with_hooks(model, text, hooks, with_kwargs=False):
    """Run one forward pass on ``text`` with temporary forward hooks attached.

    Args:
        model: the model to run; any forward hooks already on it are removed.
        text: prompt passed to ``get_final_logits``.
        hooks: iterable of ``(name, hook)`` pairs. ``name`` is an attribute
            path relative to ``model`` (e.g. ``"model.layers[0].self_attn"``).
        with_kwargs: forwarded to ``register_forward_hook``; when true the
            hooks are called as ``hook(module, args, kwargs, output)``.

    Returns:
        Whatever ``get_final_logits(model, text)`` returns.

    The hooks are always removed again, even if registering a hook or the
    forward pass raises, so a failed run cannot leave the model instrumented.
    """
    clean_hooks(model)
    try:
        for name, hook in hooks:
            # NOTE: eval() resolves paths that contain indexing such as
            # "layers[3]". Only pass trusted, internally generated names here.
            module = eval(f"model.{name}")
            module.register_forward_hook(hook, with_kwargs=with_kwargs)
        return get_final_logits(model, text)
    finally:
        clean_hooks(model)


In [28]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
from functools import partial
import plotly.express as px
from torch import nn
import numpy as np
import random



def with_values(hook):
    """Decorate a forward hook so it also receives the pre-``o_proj`` values.

    The returned hook must be registered with ``with_kwargs=True`` on a module
    that has an ``o_proj`` submodule and returns a tuple whose first element
    is the projected output. When it fires it re-runs the module with
    ``o_proj`` swapped for ``nn.Identity()`` to obtain the unprojected values,
    then calls::

        hook(module, args, kwargs, output, values, *fn_args, **fn_kwargs)

    ``hook`` must return the (possibly modified) values. The wrapper projects
    them with the real ``o_proj`` and returns them in place of ``output[0]``,
    keeping ``output[1:]`` unchanged.
    """
    def new_hook(module, args, kwargs, output, *fn_args, **fn_kwargs):
        # Detach the hooks while the module is re-run below; otherwise this
        # hook would fire again from inside itself and recurse.
        saved_hooks = module._forward_hooks
        clean_hooks(module)

        saved_o_proj = module.o_proj
        module.o_proj = nn.Identity()
        try:
            values = module(*args, **kwargs)[0]
        finally:
            # Always put the real projection and the hooks back, even if the
            # re-run fails, so the module is never left in a patched state.
            module.o_proj = saved_o_proj
            module._forward_hooks = saved_hooks

        new_values = hook(module, args, kwargs, output, values, *fn_args, **fn_kwargs)
        return (saved_o_proj(new_values),) + output[1:]
    return new_hook


def record_dla(model, bs=16):
    """Estimate, per layer, the mean pre-``o_proj`` attention values at the last position.

    Runs ``bs`` forward passes, each on a prompt drawn with ``random.choice``
    from the module-level ``texts``, and averages ``values[0, -1]`` (first
    batch row, last sequence position) of every layer's attention module.

    Args:
        model: model whose ``config.num_hidden_layers`` attention modules are
            hooked via ``head_name`` / ``run_with_hooks``.
        bs: number of sampled prompts to average over.

    Returns:
        A list with one entry per layer holding the averaged values (entries
        stay at the initial ``0`` if ``bs`` is 0, since nothing is recorded).
    """
    layers = model.config.num_hidden_layers
    outs = [0 for _ in range(layers)]

    @with_values
    def record_hook(module, args, kwargs, output, values, layer):
        # Each sample contributes 1/bs, so the sum over bs runs is the mean.
        outs[layer] += values[0, -1] / bs
        return values

    hooks = [(head_name(layer), partial(record_hook, layer=layer)) for layer in range(layers)]
    # One forward pass per sample; the original ran a single pass while still
    # dividing by bs, which under-scaled the result for bs > 1.
    for _ in range(bs):
        run_with_hooks(model, random.choice(texts), hooks, with_kwargs=True)
    return outs


def get_rms_multiplier(model, text):
    """Measure the average element-wise scale applied by ``model.model.norm`` on ``text``.

    A forward hook on the norm module divides its output by its input at the
    last sequence position of the first batch row, zeroes out NaN/inf ratios
    (e.g. from zero inputs), and averages the remaining entries.

    Returns:
        The mean ratio as a Python float.
    """
    captured = {"mul": 0}

    def get_mul(module, args, output):
        ratio = output[0, -1, :] / args[0][0, -1, :]
        valid = torch.isfinite(ratio)
        ratio[~valid] = 0  # ignore undefined ratios in the sum
        captured["mul"] = ratio.sum() / valid.sum()
        return output

    run_with_hooks(model, text, [("model.norm", get_mul)])
    return captured["mul"].item()


def patch_dla(model, bs=16, mean_bs=None):
    """Per-head direct logit attribution of the tracked token, averaged over ``texts``.

    For every prompt in the module-level ``texts`` and every attention head,
    the head's pre-``o_proj`` values at one sequence position are compared to
    the reference values from ``record_dla``; the difference is pushed through
    that layer's ``o_proj``, scaled by the scalar from ``get_rms_multiplier``
    and projected with ``model.lm_head``. The entry selected by
    ``collect_logits`` is accumulated.

    Args:
        model: model exposing ``config.num_hidden_layers``,
            ``config.num_attention_heads``, ``config.hidden_size`` and
            ``lm_head``.
        bs: only used as the default for ``mean_bs``; the attribution itself is
            always averaged over every prompt in ``texts``.
        mean_bs: ``bs`` argument passed to ``record_dla``.

    Returns:
        ``numpy`` array of shape ``(layers, heads)`` with the mean attributed
        logit per head (all zeros when ``texts`` is empty).
    """
    if mean_bs is None:
        mean_bs = bs
    means = record_dla(model, bs=mean_bs)
    layers, heads = model.config.num_hidden_layers, model.config.num_attention_heads
    d = head_dim(model)
    logit_diff = np.zeros((layers, heads))
    n_texts = len(texts)
    # Sequence position whose values are attributed.
    # NOTE(review): record_dla records its reference values at position -1
    # while this reads position -6 -- confirm the mismatch is intended.
    idx = -6
    # Cast to whatever lm_head actually uses instead of assuming fp16 on the
    # default CUDA device (assumes lm_head exposes a `weight` tensor).
    lm_head_weight = model.lm_head.weight

    for text in tqdm(texts, desc="Ablating"):
        storage = {}
        rms_mul = get_rms_multiplier(model, text)

        @with_values
        def hook(module, args, kwargs, output, values, layer):
            last_values = values[0, idx]
            for head in range(heads):
                head_slice = slice(d * head, d * (head + 1))
                diff = last_values[head_slice] - means[layer][head_slice]
                # Embed this head's difference in an otherwise-zero vector so
                # o_proj only sees this head's contribution.
                new_vals = torch.zeros(
                    (1, 1, last_values.shape[-1]), device=diff.device, dtype=diff.dtype
                )
                new_vals[0, 0, head_slice] = diff
                projected = module.o_proj(new_vals)
                projected = projected.to(dtype=lm_head_weight.dtype, device=lm_head_weight.device)
                logit_diffs = model.lm_head(projected * rms_mul)[0, 0]
                storage[(layer, head)] = collect_logits(logit_diffs)
            # The forward pass itself is left unmodified.
            return values

        hooks = [(head_name(layer), partial(hook, layer=layer)) for layer in range(layers)]
        run_with_hooks(model, text, hooks, with_kwargs=True)
        for layer in range(layers):
            for head in range(heads):
                logit_diff[layer, head] += storage[(layer, head)][0].item() / n_texts
    return logit_diff


In [29]:
# Load the model and compute the (layers, heads) attribution matrix.
# NOTE(review): `gd.mod("s")` is a project helper not shown here; what "s"
# selects cannot be told from this notebook.
# `bs=1` also becomes `mean_bs=1`, i.e. the reference values come from
# `record_dla(model, bs=1)`.
model = gd.mod("s")
patches = patch_dla(model, bs=1)


Ablating:   0%|          | 0/3 [00:00<?, ?it/s]

In [30]:
# Heatmap of the attribution matrix: rows are layers, columns are heads.
px.imshow(patches, labels=dict(y="Layer", x="Head")).show()