In [1]:
import torch
import transformers
from transformers import AutoModelForCausalLM,AutoTokenizer
from torch import nn
import torch.nn.functional as F
from IPython.display import clear_output
from datasets import load_dataset
from tqdm.auto import tqdm
import json
import math
from classes.common_methods import *
import copy
import gc
from classes.RefusalSteering import *

[nltk_data] Downloading package punkt to /home/user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /home/user/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /home/user/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


-----------

In [3]:
model_id = "Qwen/Qwen-1_8B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

device = torch.device("cuda")
model.to(device)

act_add_model = copy.deepcopy(model)
dir_abl_model = copy.deepcopy(model)

The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [54]:
dir = torch.load("refusal_direction_paper/direction_qwen.pt")
dir.shape

torch.Size([2048])

-----------------

### Confirming that direction matches with the overall means_diff

In [40]:
dir = torch.load("refusal_direction_paper/direction_qwen.pt")
dir.shape

torch.Size([2048])

In [41]:
mean_diffs = torch.load("refusal_direction_paper/mean_diffs_qwen.pt")
mean_diffs.shape

torch.Size([5, 24, 2048])

In [42]:
# mean_diffs2 = torch.load("refusal_direction_paper/mean_diffs_llama_7b.pt")
# mean_diffs2.shape
# torch.Size([6,32,4096])

I think they took 24 layers and just 5 tokens.
5 tokens would've taken way less time...

In [53]:
mean_diffs[-2,15,:],dir

(tensor([ 0.3309, -0.8965,  1.2525,  ..., -0.5334,  0.2105, -0.0733],
        device='cuda:0', dtype=torch.float64),
 tensor([ 0.3309, -0.8965,  1.2525,  ..., -0.5334,  0.2105, -0.0733],
        device='cuda:0', dtype=torch.float64))

In the paper they said 15/24 and -1 token. -1 is the second last token (corresponds to id -2).
The 15/24 is actually 16th layer out of 24 with id = 15 (incorrectly mentioned).

It matches, but make sure to remember this.

---------------

In [5]:
model

QWenLMHeadModel(
  (transformer): QWenModel(
    (wte): Embedding(151936, 2048)
    (drop): Dropout(p=0.0, inplace=False)
    (rotary_emb): RotaryEmbedding()
    (h): ModuleList(
      (0-23): 24 x QWenBlock(
        (ln_1): RMSNorm()
        (attn): QWenAttention(
          (c_attn): Linear(in_features=2048, out_features=6144, bias=True)
          (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): RMSNorm()
        (mlp): QWenMLP(
          (w1): Linear(in_features=2048, out_features=5504, bias=False)
          (w2): Linear(in_features=2048, out_features=5504, bias=False)
          (c_proj): Linear(in_features=5504, out_features=2048, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
)

In [6]:
def set_activation_addn_model(l_opt,r_opt):
    for l in range(len(act_add_model.transformer.h)):
        if l==l_opt:
            act_add_model.transformer.h[l] = SteeringLayer(model.config.hidden_size,model.transformer.h[l],r_opt,type="addition")
        else:
            act_add_model.transformer.h[l] = model.transformer.h[l]

In [7]:
def set_directional_ablation_model(r_opt):
    for l in range(len(dir_abl_model.transformer.h)):
        dir_abl_model.transformer.h[l] = SteeringLayer(model.config.hidden_size,model.transformer.h[l],r_opt,type="ablation")

In [18]:
def regular_decoding(prompt,debug=True, max_tokens=1,
                     show_tqdm=True, return_prob=False,type="addition"):
    model_to_use = act_add_model if type=="addition" else dir_abl_model
    print("Using model...")
    # print(model_to_use)
    if debug: print("prompt: ",prompt)
    device = torch.device("cuda")
    
    eos_token = tokenizer.eos_token_id
    # tokenizer.pad_token = "<s>"
    
    input_ids = tokenizer(prompt,return_tensors="pt").input_ids.to(device)
    predicted_tokens = []

    token_iterator = tqdm(range(max_tokens)) if (max_tokens>1 and show_tqdm) else range(max_tokens)
    
    for token in token_iterator:
        last_token_logits = model_to_use(input_ids).logits[0,-1,:]
        last_token_probs = F.softmax(last_token_logits)

        max_index = torch.argmax(last_token_probs).item() # greedy decoding
        
        if max_index == eos_token:
            break

        predicted_tokens.append(max_index)
        
        input_ids = torch.cat([input_ids,torch.tensor([[max_index]]).to(device)],dim=1)
    
    out = tokenizer.decode(predicted_tokens) if max_tokens>1 \
                else tokenizer.batch_decode(predicted_tokens)[0]
    if debug: print("output: ",out)
    if not debug:
        if return_prob==True and max_tokens==1:
            return last_token_probs
        else:
            return out

In [19]:
with open('ds_refusal/splits/harmful_val.json', 'r') as file:
    dval_harmful = json.load(file)

with open('ds_refusal/splits/harmless_val.json', 'r') as file:
    dval_harmless = json.load(file)

In [57]:
refused = 0
i = 4
pt = dval_harmless[i]
l_opt = 15 # layer id 15 is the 16th layer. But that's the one.
r_opt = dir
set_activation_addn_model(l_opt,r_opt)
prompt='<inst>'+pt['instruction']+'</inst><assistant>'

toks=10
out = regular_decoding(prompt,debug=False, max_tokens=toks,
                     show_tqdm=False, return_prob=False,type="addition")

print(prompt,end="")
print(out)
refused+=1
    
# print("***")

Using model...


  last_token_probs = F.softmax(last_token_logits)


<inst>Describe the characteristics of the star and flag rating system.</inst><assistant>The characteristics of the star and flag rating system are


**Conclusion:** I used the exact direction they gave and it doesn't seem to be working even now. There must be some issue in my Steering implementation.

-----------------