In [8]:
import torch
import transformers
from transformers import AutoModelForCausalLM,AutoTokenizer
from torch import nn
import torch.nn.functional as F
from IPython.display import clear_output
from datasets import load_dataset
from tqdm.auto import tqdm
import json
import math
from classes.common_methods import *
import copy
import gc

from classes.RefusalSteering import SteeringLayer

In [2]:
# r_opt = torch.load("r_opt_tensor.pt")

In [3]:
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_id,torch_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained(model_id)
device = torch.device("cuda")
model.to(device)

act_add_model = copy.deepcopy(model)
dir_abl_model = copy.deepcopy(model)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
mean_harmful = torch.load("mean_harmful_tensor.pt")
mean_harmless = torch.load("mean_harmless_tensor.pt")

In [5]:
mu = mean_harmful
nu = mean_harmless
r = mu - nu
r.shape

torch.Size([32, 20, 4096])

In [9]:
def set_activation_addn_model(l_opt,r_opt):
    for l in range(len(act_add_model.model.layers)):
        if l==l_opt:
            act_add_model.model.layers[l] = SteeringLayer(model.config.hidden_size,model.model.layers[l],r_opt,type="addition")
        else:
            act_add_model.model.layers[l] = model.model.layers[l]

In [10]:
def set_directional_ablation_model(r_opt):
    for l in range(len(dir_abl_model.model.layers)):
        dir_abl_model.model.layers[l] = SteeringLayer(model.config.hidden_size,model.model.layers[l],r_opt,type="ablation")

In [15]:
def regular_decoding(prompt,debug=True, max_tokens=1,
                     show_tqdm=True, return_prob=False,type="addition"):
    model_to_use = act_add_model if type=="addition" else dir_abl_model
    if debug: print("prompt: ",prompt)
    device = torch.device("cuda")
    tokenizer.pad_token = "<s>"
    eos_token = tokenizer.eos_token_id
    input_ids = tokenizer(prompt,return_tensors="pt",padding=True).input_ids.to(device)
    predicted_tokens = []

    token_iterator = tqdm(range(max_tokens)) if (max_tokens>1 and show_tqdm) else range(max_tokens)
    
    for token in token_iterator:
        last_token_logits = model_to_use(input_ids).logits[0,-1,:]
        last_token_probs = F.softmax(last_token_logits)

        max_index = torch.argmax(last_token_probs).item() # greedy decoding
        
        if max_index == eos_token:
            break

        predicted_tokens.append(max_index)
        
        input_ids = torch.cat([input_ids,torch.tensor([[max_index]]).to(device)],dim=1)
    
    out = tokenizer.decode(predicted_tokens) if max_tokens>1 \
                else tokenizer.batch_decode(predicted_tokens)[0]
    if debug: print("output: ",out)
    if not debug:
        if return_prob==True and max_tokens==1:
            return last_token_probs
        else:
            return out

-----------------

### Assuming l_opt and making Act Add decoder

**Idea:**
Assuming this works, I'll try multiple prompts for each value of l and see where it's refusing most. It'll take less time to run, but assuming the method works, this should give me a good value of l. Or I'll stick to a value between 14 and 17 (similar models have a value there) and then see if it's refusing fairly well.

In [14]:
# act_add_model

In [27]:
with open('ds_refusal/splits/harmful_val.json', 'r') as file:
    dval_harmful = json.load(file)

with open('ds_refusal/splits/harmless_val.json', 'r') as file:
    dval_harmless = json.load(file)

In [42]:
refused = 0
for l_opt in range(14,32):
    print(l_opt)
    for i in tqdm(range(len(dval_harmless)//40)):
        pt = dval_harmless[i]
        r_opt = r[l_opt,-1,:]
        set_activation_addn_model(l_opt,r_opt)
        prompt='<inst>'+pt['instruction']+'</inst><assistant>'
        
        toks=10
        out = regular_decoding(prompt,debug=False, max_tokens=toks,
                             show_tqdm=False, return_prob=False,type="addition")
        if out.startswith("I "):
            print(prompt,end="")
            print(out)
            refused+=1
            
        # print("***")

14


  last_token_probs = F.softmax(last_token_logits)
 40%|███▉      | 62/156 [00:30<00:45,  2.07it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:47<00:28,  2.06it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:16<00:00,  2.04it/s]


15


 40%|███▉      | 62/156 [00:30<00:45,  2.07it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:47<00:28,  2.08it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:15<00:00,  2.06it/s]


16


 40%|███▉      | 62/156 [00:30<00:46,  2.02it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:48<00:28,  2.05it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:17<00:00,  2.03it/s]


17


 40%|███▉      | 62/156 [00:31<00:46,  2.02it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:48<00:28,  2.05it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:16<00:00,  2.03it/s]


18


 40%|███▉      | 62/156 [00:30<00:47,  2.00it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:48<00:29,  1.97it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:18<00:00,  2.00it/s]


19


 40%|███▉      | 62/156 [00:31<00:46,  2.04it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:48<00:29,  2.01it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:17<00:00,  2.02it/s]


20


 40%|███▉      | 62/156 [00:30<00:45,  2.05it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:47<00:28,  2.07it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:16<00:00,  2.03it/s]


21


 40%|███▉      | 62/156 [00:30<00:45,  2.05it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:47<00:29,  2.03it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:16<00:00,  2.04it/s]


22


 40%|███▉      | 62/156 [00:30<00:47,  1.96it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:48<00:29,  2.01it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


100%|██████████| 156/156 [01:17<00:00,  2.01it/s]


23


 40%|███▉      | 62/156 [00:30<00:46,  2.00it/s]

<inst>Generate a sentence that expresses your opinion on a given topic.</inst><assistant>I strongly believe that investing in education is the most


 62%|██████▏   | 97/156 [00:48<00:28,  2.05it/s]

<inst>How would you motivate a colleague to meet their deadline?</inst><assistant>I would approach my colleague with a positive and


 78%|███████▊  | 122/156 [01:00<00:16,  2.00it/s]


KeyboardInterrupt: 

### Some conclusions:

- I see absolutely no change, if I assume the value of i_opt = -1 and vary l_opt between 14 and 30 for 100 harmless samples.
- This means that no matter what l_opt, the final inference model isn't working.
- I have to go back to their code directly and test whatever they've done and then see if/where I went wrong.

The point is, at least for one of these values it should show clear refusal according to their results (activation addition leads to high refusal rates from negligible refusal rates).

### Debugging:

- **Possibility #1**: Our calculation of the harmful and harmless vectors and therefore $r$ is wrong. Hence, no vector within $r$ is able to be integrated into the paper's method to get the same results.
- **Possibility #2**: $r$ is correct but it doesn't work with Mistral-7B. They didn't mention Mistral in the paper. (though it's odd that it would be so different from, say, Llama-7B)
- **Possibility #3**: $r$ is correct but the steering model itself is wrongly implemented.