In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset,load_from_disk
import nltk
from collections import defaultdict
from tqdm import tqdm 
import numpy as np 
import matplotlib.pyplot as plt
import pickle
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Load the tokenizer and the training data

In [2]:
# Tokenizer of the model under study; the downloaded files are cached under
# the given cache_dir.
model_id = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    cache_dir='/project/lt200252-wcbart/nicky/cache_hug_1',
)

# Aegis 2.0 prompts: both JSON files are handed to the "json" loader as one
# list; the counting cell reads the result through aegis_2['train'].
# NOTE(review): the absolute cluster paths make this cell machine-specific.
aegis_2 = load_dataset(
    "json",
    data_dir="/project/lt200252-wcbart/nicky/safety_dataset/nvidia/Aegis-AI-Content-Safety-Dataset-2.0/",
    data_files=["train.json", "refusals_train.json"],
)

### Count unigram and bigram token occurrences per label

In [4]:
# Every token id seen across all prompts, in corpus order (duplicates kept).
total_vocab = []

# Per-label co-occurrence tables: label -> token id -> count (unigram) and
# label -> (token id, next token id) -> count (bigram). Missing keys read as 0.
cnt = defaultdict(lambda: defaultdict(lambda: 0))
cnt_bigram = defaultdict(lambda: defaultdict(lambda: 0))

# Marginals: number of prompts per label, and n-gram counts over all labels.
cnt_y = defaultdict(lambda: 0)
cnt_word = defaultdict(lambda: 0)
cnt_word_bigram = defaultdict(lambda: 0)

for example in tqdm(aegis_2['train']):
    label = str(example['prompt_label'])
    token_ids = tokenizer(example['prompt'].strip())['input_ids']

    ## unigram
    for token_id in token_ids:
        cnt[label][token_id] += 1
        cnt_word[token_id] += 1

    ## bigram: each token paired with its successor
    for pair in zip(token_ids, token_ids[1:]):
        cnt_bigram[label][pair] += 1
        cnt_word_bigram[pair] += 1

    cnt_y[label] += 1
    total_vocab.extend(token_ids)

100%|██████████| 30007/30007 [00:09<00:00, 3139.09it/s]


### Local Mutual Information (LMI) for the `safe` label

In [14]:
# Local Mutual Information between each n-gram w and the label y = 'safe':
#     LMI(w, y) = p(w, y) * log(p(y | w) / p(y))
# estimated as
#     p(w, y) = count(w, y) / total number of n-gram occurrences
#     p(y | w) = count(w, y) / count(w)
#     p(y)     = share of prompts labelled y
LMIs = []
D = len(set(total_vocab))  # vocabulary size (distinct token ids); not a normaliser
T = len(total_vocab)       # total unigram occurrences -> normaliser of p(w, y)
# Bigram counterpart of T: total bigram occurrences over all prompts.
T_bigram = sum(cnt_word_bigram.values())
p_Y = cnt_y['safe'] / (cnt_y['unsafe'] + cnt_y['safe'])
score_unharmful_LMIs = defaultdict(lambda: 0)
score_unharmful_LMIs_bigram = defaultdict(lambda: 0)

## unigram
for idx, freq in tqdm(cnt['safe'].items()):
    # Dividing by the vocabulary size D (as before) let p(w, y) exceed 1 for
    # frequent tokens; T keeps it a proper probability. The ranking is the
    # same either way because the factor is constant.
    p_W_Y = freq / T
    p_Y_W = freq / cnt_word[idx]
    LMI = float(p_W_Y * np.log(p_Y_W / p_Y))
    LMIs.append(LMI)
    score_unharmful_LMIs[idx] = LMI

## bigram
for idx, freq in tqdm(cnt_bigram['safe'].items()):
    p_W_Y = freq / T_bigram
    p_Y_W = freq / cnt_word_bigram[idx]
    LMI = float(p_W_Y * np.log(p_Y_W / p_Y))
    LMIs.append(LMI)
    # Keys are the decoded strings of the two token ids.
    # NOTE(review): different id pairs that decode to the same strings (e.g.
    # special tokens become '') would overwrite each other here - confirm
    # this is acceptable.
    bigram_text = (
        tokenizer.decode(idx[0], skip_special_tokens=True),
        tokenizer.decode(idx[1], skip_special_tokens=True),
    )
    score_unharmful_LMIs_bigram[bigram_text] = LMI


100%|██████████| 37840/37840 [00:00<00:00, 707527.03it/s]
100%|██████████| 311357/311357 [00:04<00:00, 67570.44it/s]


### Local Mutual Information (LMI) for the `unsafe` label

In [15]:
# Local Mutual Information between each n-gram w and the label y = 'unsafe':
#     LMI(w, y) = p(w, y) * log(p(y | w) / p(y))
# estimated as
#     p(w, y) = count(w, y) / total number of n-gram occurrences
#     p(y | w) = count(w, y) / count(w)
#     p(y)     = share of prompts labelled y
LMIs = []
D = len(set(total_vocab))  # vocabulary size (distinct token ids); not a normaliser
T = len(total_vocab)       # total unigram occurrences -> normaliser of p(w, y)
# Bigram counterpart of T: total bigram occurrences over all prompts.
T_bigram = sum(cnt_word_bigram.values())
p_Y = cnt_y['unsafe'] / (cnt_y['unsafe'] + cnt_y['safe'])
score_harmful_LMIs = defaultdict(lambda: 0)
score_harmful_LMIs_bigram = defaultdict(lambda: 0)

## unigram
for idx, freq in tqdm(cnt['unsafe'].items()):
    # Dividing by the vocabulary size D (as before) let p(w, y) exceed 1 for
    # frequent tokens; T keeps it a proper probability. The ranking is the
    # same either way because the factor is constant.
    p_W_Y = freq / T
    p_Y_W = freq / cnt_word[idx]
    LMI = float(p_W_Y * np.log(p_Y_W / p_Y))
    LMIs.append(LMI)
    score_harmful_LMIs[idx] = LMI

## bigram
for idx, freq in tqdm(cnt_bigram['unsafe'].items()):
    p_W_Y = freq / T_bigram
    p_Y_W = freq / cnt_word_bigram[idx]
    LMI = float(p_W_Y * np.log(p_Y_W / p_Y))
    LMIs.append(LMI)
    # Keys are the decoded strings of the two token ids.
    # NOTE(review): different id pairs that decode to the same strings (e.g.
    # special tokens become '') would overwrite each other here - confirm
    # this is acceptable.
    bigram_text = (
        tokenizer.decode(idx[0], skip_special_tokens=True),
        tokenizer.decode(idx[1], skip_special_tokens=True),
    )
    score_harmful_LMIs_bigram[bigram_text] = LMI


100%|██████████| 26003/26003 [00:00<00:00, 663340.69it/s]
100%|██████████| 205982/205982 [00:03<00:00, 68107.21it/s]


### Head & tail of the LMI ranking (unharmful, i.e. `safe`): bigrams

In [40]:
# Head: (bigram, LMI) pairs from highest to lowest score; keep the top 100 bigrams.
# The sort is stable, so tied scores stay in dictionary insertion order.
sorted_score_unharmful_head_LMIs_bigram = sorted(
    score_unharmful_LMIs_bigram.items(),
    key=lambda pair: pair[1],
    reverse=True,
)
sorted_unharmful_head_words_bigram = [
    bigram for bigram, _ in sorted_score_unharmful_head_LMIs_bigram[:100]
]

# Tail: the same pairs from lowest to highest score; keep the bottom 100 bigrams.
sorted_score_unharmful_tail_LMIs_bigram = sorted(
    score_unharmful_LMIs_bigram.items(),
    key=lambda pair: pair[1],
)
sorted_unharmful_tail_words_bigram = [
    bigram for bigram, _ in sorted_score_unharmful_tail_LMIs_bigram[:100]
]


### Head & tail of the LMI ranking (harmful, i.e. `unsafe`): bigrams

In [38]:
# Head: (bigram, LMI) pairs from highest to lowest score; keep the top 100 bigrams.
# The sort is stable, so tied scores stay in dictionary insertion order.
sorted_score_harmful_head_LMIs_bigram = sorted(
    score_harmful_LMIs_bigram.items(),
    key=lambda pair: pair[1],
    reverse=True,
)
sorted_harmful_head_words_bigram = [
    bigram for bigram, _ in sorted_score_harmful_head_LMIs_bigram[:100]
]
# The bigram score table is keyed by decoded strings, not token ids, so this
# "idx" list holds exactly the same entries as the word list above.
sorted_idx_harmful_head_words_bigram = list(sorted_harmful_head_words_bigram)

# Tail: the same pairs from lowest to highest score; keep the bottom 100 bigrams.
sorted_score_harmful_tail_LMIs_bigram = sorted(
    score_harmful_LMIs_bigram.items(),
    key=lambda pair: pair[1],
)
sorted_harmful_tail_words_bigram = [
    bigram for bigram, _ in sorted_score_harmful_tail_LMIs_bigram[:100]
]
sorted_idx_harmful_tail_words_bigram = list(sorted_harmful_tail_words_bigram)


In [41]:
# Disabled export cell: when uncommented, it pickles the four top-100 bigram
# lists (harmful/unharmful x head/tail) into the LMI_shortcut/ directory.
# with open('LMI_shortcut/tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_bigram_harmful_head.pkl','wb') as f:
#     pickle.dump(sorted_harmful_head_words_bigram,f)
# with open('LMI_shortcut/tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_bigram_harmful_tail.pkl','wb') as f:
#     pickle.dump(sorted_harmful_tail_words_bigram,f)

# with open('LMI_shortcut/tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_bigram_unharmful_head.pkl','wb') as f:
#     pickle.dump(sorted_unharmful_head_words_bigram,f)
# with open('LMI_shortcut/tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_bigram_unharmful_tail.pkl','wb') as f:
#     pickle.dump(sorted_unharmful_tail_words_bigram,f)



In [39]:
# Display the top-100 'unsafe' bigrams as (first token text, second token text)
# tuples, highest LMI first.
sorted_harmful_head_words_bigram

[(' how', ' to'),
 ('', 'How'),
 (' can', ' I'),
 (' to', ' get'),
 (' away', ' with'),
 (' get', ' away'),
 ('', 'I'),
 ('G', 'PT'),
 ('How', ' can'),
 (' Chat', 'G'),
 (' about', ' how'),
 (' I', "'m"),
 ('PT', ' with'),
 (' you', ' help'),
 ('I', "'m"),
 ('', 'What'),
 (' on', ' how'),
 ('.', ' Can'),
 (' want', ' to'),
 (' way', ' to'),
 (' I', ' can'),
 (' Can', ' you'),
 (' Open', 'AI'),
 (' Mode', ' enabled'),
 (' help', ' me'),
 ('', 'how'),
 ('.', ' It'),
 ('', 'Write'),
 (' and', ' I'),
 (' Developer', ' Mode'),
 (' do', ' I'),
 ('How', ' to'),
 (' to', ' steal'),
 (' address', ' for'),
 (' don', "'t"),
 ('.', ' DAN'),
 ('', 'Can'),
 (' do', ' anything'),
 (' trying', ' to'),
 (' are', ' some'),
 (' I', ' was'),
 ('Write', ' a'),
 ('how', ' to'),
 ('D', 'AN'),
 (' am', ' trying'),
 (' best', ' way'),
 ('How', ' do'),
 (' to', ' kill'),
 (' locate', ' the'),
 (' your', ' responses'),
 (' getting', ' caught'),
 (' residential', ' address'),
 (' someone', "'s"),
 ('.', ' He'),
 

### Head & tail of the LMI ranking (harmful, i.e. `unsafe`): unigrams

In [23]:
# Head: (token id, LMI) pairs from highest to lowest score; keep the top 100.
# The sort is stable, so tied scores stay in dictionary insertion order.
sorted_score_harmful_head_LMIs = sorted(
    score_harmful_LMIs.items(),
    key=lambda pair: pair[1],
    reverse=True,
)
sorted_idx_harmful_head_words = [
    token_id for token_id, _ in sorted_score_harmful_head_LMIs[:100]
]
# Decoded text of those ids (an earlier variant of this cell used
# tokenizer.convert_ids_to_tokens instead of decode).
sorted_harmful_head_words = [
    tokenizer.decode(token_id, skip_special_tokens=True)
    for token_id in sorted_idx_harmful_head_words
]

# Tail: the same pairs from lowest to highest score; keep the bottom 100.
sorted_score_harmful_tail_LMIs = sorted(
    score_harmful_LMIs.items(),
    key=lambda pair: pair[1],
)
sorted_idx_harmful_tail_words = [
    token_id for token_id, _ in sorted_score_harmful_tail_LMIs[:100]
]
sorted_harmful_tail_words = [
    tokenizer.decode(token_id, skip_special_tokens=True)
    for token_id in sorted_idx_harmful_tail_words
]


### Head & tail of the LMI ranking (unharmful, i.e. `safe`): unigrams

In [28]:
# Head: (token id, LMI) pairs from highest to lowest score; keep the top 100.
# The sort is stable, so tied scores stay in dictionary insertion order.
sorted_score_unharmful_head_LMIs = sorted(
    score_unharmful_LMIs.items(),
    key=lambda pair: pair[1],
    reverse=True,
)
sorted_idx_unharmful_head_words = [
    token_id for token_id, _ in sorted_score_unharmful_head_LMIs[:100]
]
sorted_unharmful_head_words = [
    tokenizer.decode(token_id, skip_special_tokens=True)
    for token_id in sorted_idx_unharmful_head_words
]

# Tail: the same pairs from lowest to highest score; keep the bottom 100.
sorted_score_unharmful_tail_LMIs = sorted(
    score_unharmful_LMIs.items(),
    key=lambda pair: pair[1],
)
sorted_idx_unharmful_tail_words = [
    token_id for token_id, _ in sorted_score_unharmful_tail_LMIs[:100]
]
sorted_unharmful_tail_words = [
    tokenizer.decode(token_id, skip_special_tokens=True)
    for token_id in sorted_idx_unharmful_tail_words
]


In [29]:
# Disabled export cell: when uncommented, it pickles the harmful head/tail
# unigram lists, both as decoded strings and as token ids (*_idx.pkl), using
# file names relative to the current working directory.
# with open('tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_harmful_head.pkl','wb') as f:
#     pickle.dump(sorted_harmful_head_words,f)
# with open('tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_harmful_head_idx.pkl','wb') as f:
#     pickle.dump(sorted_idx_harmful_head_words,f)

# with open('tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_harmful_tail.pkl','wb') as f:
#     pickle.dump(sorted_harmful_tail_words,f)
# with open('tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_harmful_tail_idx.pkl','wb') as f:
#     pickle.dump(sorted_idx_harmful_tail_words,f)

In [30]:
# Disabled export cell: when uncommented, it pickles the unharmful head/tail
# unigram lists, both as decoded strings and as token ids (*_idx.pkl), using
# file names relative to the current working directory.
# with open('tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_unharmful_head.pkl','wb') as f:
#     pickle.dump(sorted_unharmful_head_words,f)
# with open('tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_unharmful_head_idx.pkl','wb') as f:
#     pickle.dump(sorted_idx_unharmful_head_words,f)

# with open('tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_unharmful_tail.pkl','wb') as f:
#     pickle.dump(sorted_unharmful_tail_words,f)
# with open('tokenizer-meta-llama-Llama-3.1-8B-Instruct_dataset-Aegis2_LMI_unharmful_tail_idx.pkl','wb') as f:
#     pickle.dump(sorted_idx_unharmful_tail_words,f)