In [1]:
# %%
import os 
import sys
import random 
import numpy as np 
import torch 
from torch.utils.data import DataLoader
import random
current_dir = '/share/projects/TaskTracker/'
sys.path.append(current_dir)
from training.triplet_probe.models.processing_per_layer import ParallelConvProcessingModel
from training.dataset import ActivationsDatasetDynamicPrimaryText
from training.utils.constants import TEST_ACTIVATIONS_DIR_PER_MODEL,TEST_CLEAN_FILES_PER_MODEL,TEST_POISONED_FILES_PER_MODEL
from trained_probes_paths import TRIPLET_PROBES_PATHS_PER_MODEL

MODEL = 'mixtral'
TEST_ACTIVATIONS_DIR = TEST_ACTIVATIONS_DIR_PER_MODEL[MODEL]

MODEL_OUTPUT_DIR = TRIPLET_PROBES_PATHS_PER_MODEL[MODEL]['path']
FILES_CHUNK = 10 
BATCH_SIZE = 256 

NUM_LAYERS = TRIPLET_PROBES_PATHS_PER_MODEL[MODEL]['num_layers']
if MODEL == 'llama3_70b':
    FEATURE_DIM = 350 
    POOL_FIRST_LAYER = 5
else:
    FEATURE_DIM = 275
    POOL_FIRST_LAYER = 3
    

LEARNED_EMBEDDINGS_OUTPUT_DIR = os.path.join(MODEL_OUTPUT_DIR, 'learned_embeddings')

In [2]:
ACTIVATIONS_NEW_VARIATIONS_DIR = '/share/projects/jailbreak-activations/data/activations/mistralai-Mixtral-8x7B-Instruct-v0.1/test_more_variations'
clean_files  = ['clean_hidden_states_0_1000_20240514_163834_baseline.pt',
               'clean_hidden_states_0_1000_20240514_164114_spottlighting_delimiter.pt'
               ,'clean_hidden_states_0_1000_20240521_151120_false_positives_wildchat_level1.pt'
               ,'clean_hidden_states_0_1000_20240521_151722_false_positives_wildchat_level2.pt'
               ,'clean_hidden_states_0_1000_20240521_152326_false_positives_wildchat_level3.pt']

poisoned_files = ['poisoned_hidden_states_0_1000_20240513_223212_lie_trigger.pt',
                  'poisoned_hidden_states_0_1000_20240513_223740_not_new_trigger.pt',
                  'poisoned_hidden_states_0_1000_20240513_223455_no_trigger.pt',
                  'poisoned_hidden_states_0_1000_20240513_224025_baseline.pt',
                  'poisoned_hidden_states_0_1000_20240513_224310_spottlighting_delimiter.pt',
                  'poisoned_hidden_states_0_1000_20240513_224554_translated_trivia.pt']

model = ParallelConvProcessingModel(feature_dim=FEATURE_DIM,num_layers=NUM_LAYERS,conv=True,pool_first_layer=POOL_FIRST_LAYER).cuda()
model.load_state_dict(torch.load(os.path.join(MODEL_OUTPUT_DIR,'best_model_checkpoint.pth'))['model_state_dict'])
model.eval()
model.cuda()


ParallelConvProcessingModel(
  (dropout): Dropout(p=0.5, inplace=False)
  (layers_fc): ModuleList(
    (0-5): 6 x Sequential(
      (0): Conv1d(1, 7, kernel_size=(70,), stride=(1,))
      (1): ReLU()
      (2): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
      (3): Dropout(p=0.5, inplace=False)
      (4): Conv1d(7, 10, kernel_size=(50,), stride=(1,))
      (5): ReLU()
      (6): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
      (7): Dropout(p=0.5, inplace=False)
      (8): Conv1d(10, 15, kernel_size=(30,), stride=(1,))
      (9): ReLU()
      (10): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
      (11): Dropout(p=0.5, inplace=False)
      (12): Conv1d(15, 20, kernel_size=(20,), stride=(1,))
      (13): ReLU()
      (14): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
      (15): Dropout(p=0.5, inplace=False)
      (16): Conv1d(20, 25, kernel_size=(5,), stride=(1,))
 

In [8]:
import time 
import json 
def compute_distances(tensor1, tensor2):
    distances = torch.norm(tensor1 - tensor2, p=2, dim=-1)
    return distances

def evaluation(evaluate_files, activations_dir):
    all_distances = []
    
    all_distances_raw = [] 
    
    all_primary_embs = []
    all_primary_with_text_embs = []
    
    batches = 0 
    model.eval()
    for i in range(0,len(evaluate_files),10):
        files = evaluate_files[i:i+10]
        dataset = ActivationsDatasetDynamicPrimaryText(files,NUM_LAYERS,activations_dir)
        data_loader = DataLoader(dataset, batch_size=256, shuffle=False)
        for j, data in enumerate(data_loader):
            primary, primary_with_text = data 
            with torch.no_grad():
                with torch.autocast(device_type="cuda",dtype=torch.float32):
                    primary_embs = model(primary.cuda())
                    primary_with_text_embs = model(primary_with_text.cuda())
                
                all_distances.extend(compute_distances(primary_embs, primary_with_text_embs).cpu().numpy())
                all_distances_raw.extend(compute_distances(primary[:,-1,:], primary_with_text[:,-1,:]).cpu().numpy())
                
                all_primary_embs.extend(primary_embs.cpu().numpy().tolist())
                all_primary_with_text_embs.extend(primary_with_text_embs.cpu().numpy().tolist())

            batches += 1 
        
    return all_distances, all_distances_raw

In [9]:
all_distances_clean, all_distances_clean_raw = evaluation(clean_files,ACTIVATIONS_NEW_VARIATIONS_DIR)

In [10]:
all_distances_poisoned, all_distances_poisoned_raw = evaluation(poisoned_files, ACTIVATIONS_NEW_VARIATIONS_DIR)

In [11]:
### 
print(f'distances of clean baseline: {np.mean(all_distances_clean[0:500])}, with std: {np.std(all_distances_clean[0:500])}')
print(f'distances of clean with spottlighting: {np.mean(all_distances_clean[500:1000])}, with std: {np.std(all_distances_clean[500:1000])}')
print(f'distances of clean with wild chat summarization level1: {np.mean(all_distances_clean[1000:2000])}, with std: {np.std(all_distances_clean[1000:2000])}')
print(f'distances of clean with wild chat summarization level2: {np.mean(all_distances_clean[2000:3000])}, with std: {np.std(all_distances_clean[2000:3000])}')
print(f'distances of clean with wild chat summarization level3: {np.mean(all_distances_clean[3000:])}, with std: {np.std(all_distances_clean[3000:])}')

###
print('#####')
print(f'distances of poisoned with baseline: {np.mean(all_distances_poisoned[1500:2000])}, with std: {np.std(all_distances_poisoned[1500:2000])}' )

print(f'distances of poisoned with lie trigger: {np.mean(all_distances_poisoned[0:500])}, with std: {np.std(all_distances_poisoned[0:500])}')
print(f'distances of poisoned with not new trigger: {np.mean(all_distances_poisoned[500:1000])}, with std: {np.std(all_distances_poisoned[500:1000])}')
print(f'distances of poisoned with no trigger: {np.mean(all_distances_poisoned[1000:1500])}, with std: {np.std(all_distances_poisoned[1000:1500])}')
print(f'distances of poisoned with spottlighting trigger: {np.mean(all_distances_poisoned[2000:])}, with std: {np.std(all_distances_poisoned[2000:])}')
print(f'distances of poisoned with translated instructions: {np.mean(all_distances_poisoned[2500:])}, with std: {np.std(all_distances_poisoned[2500:])}')


distances of clean baseline: 0.557243287563324, with std: 0.19263112545013428
distances of clean with spottlighting: 0.5587817430496216, with std: 0.1909443736076355
distances of clean with wild chat summarization level1: 1.1579140424728394, with std: 0.3181234896183014
distances of clean with wild chat summarization level2: 1.129629373550415, with std: 0.3050914704799652
distances of clean with wild chat summarization level3: 0.5492711067199707, with std: 0.20382894575595856
#####
distances of poisoned with baseline: 1.5718518495559692, with std: 0.16047224402427673
distances of poisoned with lie trigger: 1.654449462890625, with std: 0.10929504036903381
distances of poisoned with not new trigger: 1.5763373374938965, with std: 0.1275235414505005
distances of poisoned with no trigger: 1.3557054996490479, with std: 0.1763605922460556
distances of poisoned with spottlighting trigger: 1.4645103216171265, with std: 0.31629806756973267
distances of poisoned with translated instructions: 1.20