In [2]:
import pandas as pd
import numpy as np
from datasets import load_dataset, Dataset, load_from_disk


# Defense side: records from beavertails_with_refusals_train.json, read into a
# DataFrame and then wrapped as a `datasets.Dataset`.
refusal_ds = pd.read_json('beavertails_with_refusals_train.json')
defense_data = Dataset.from_pandas(refusal_ds)

# Attack side: a dataset directory that must already exist on disk before this
# cell is run.
attack_data = load_from_disk('BeaverTails_disjoint_attack')

# Candidate test pool: the "330k_test" split of the local BeaverTails copy.
testdata = load_dataset(
    "/home/bw1822/representation-noising/BeaverTails",
    split="330k_test",
)



In [3]:
# Distinct prompt strings on each side, built in the same order as before
# (defense first, then attack).
defense_prompt, attack_prompt = (
    set(split['prompt']) for split in (defense_data, attack_data)
)

# Sanity check: the two splits must not share a single prompt.
assert defense_prompt.isdisjoint(attack_prompt)

In [4]:
def _prompt_is_unseen(row):
    """Return True when the row's prompt is in neither the attack nor the defense prompt set."""
    return not (row['prompt'] in attack_prompt or row['prompt'] in defense_prompt)


def _is_labelled_unsafe(row):
    """Return True when the row's 'is_safe' value compares equal to False."""
    # Explicit comparison (rather than `not ...`) so that a None label is
    # still rejected here.
    return row['is_safe'] == False  # noqa: E712


# First drop prompts that overlap either split, then keep only unsafe rows.
testdata = (
    testdata
    .filter(_prompt_is_unseen)
    .filter(_is_labelled_unsafe)
)

# Persist the disjoint, unsafe-only test set.
testdata.save_to_disk('BeaverTails_disjoint_test')

Saving the dataset (1/1 shards): 100%|██████████| 1645/1645 [00:00<00:00, 68026.92 examples/s]


: 

In [3]:
import pandas as pd
import numpy as np
from datasets import load_dataset, Dataset, load_from_disk, concatenate_datasets

# Load the train and test splits of the local BeaverTails copy (train first),
# then stack them into a single pool of attack candidates.
attack_data1, attack_data2 = (
    load_dataset("/home/bw1822/representation-noising/BeaverTails", split=split_name)
    for split_name in ('330k_train', '330k_test')
)
attack_data = concatenate_datasets([attack_data1, attack_data2])

# Trailing expression: displays the combined dataset summary as the cell output.
attack_data


Dataset({
    features: ['prompt', 'response', 'category', 'is_safe'],
    num_rows: 333963
})

In [12]:
# Filter the attack dataset: keep only the prompts for which the majority of responses are labelled unsafe ('is_safe' is False).
from collections import defaultdict
from datasets import Dataset
from tqdm import tqdm

# Bucket every row under its prompt text so the labels can be tallied per prompt.
grouped_examples = defaultdict(list)
for row in attack_data:
    grouped_examples[row['prompt']].append(row)

# Rows belonging to prompts whose responses are mostly labelled unsafe.
filtered_examples = []

for _prompt_text, rows in tqdm(grouped_examples.items()):
    # Each falsy 'is_safe' contributes one unsafe vote; the rest count as safe.
    unsafe_votes = sum(not r['is_safe'] for r in rows)
    safe_votes = len(rows) - unsafe_votes

    if unsafe_votes > safe_votes:
        # Strict unsafe majority: keep every row of the group. Ties are dropped.
        filtered_examples += rows
    elif safe_votes > unsafe_votes and unsafe_votes:
        # Mostly-safe group that still contains at least one unsafe row.
        print("detected false positives")

# Rebuild a Dataset from the surviving rows and display its summary.
filtered_attack_dataset = Dataset.from_list(filtered_examples)
filtered_attack_dataset

100%|██████████| 16195/16195 [00:00<00:00, 92258.19it/s]


detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives
detected false positives


Dataset({
    features: ['prompt', 'response', 'category', 'is_safe'],
    num_rows: 181029
})

In [15]:
# Collect the distinct prompts found in the refusal training file; they are the
# prompts that must not appear in the attack split.
refusal_ds = pd.read_json('beavertails_with_refusals_train.json')
defence_data = Dataset.from_pandas(refusal_ds)
defence_prompt = {*defence_data['prompt']}

In [16]:
def _prompt_not_in_defence(row):
    """Return True when the row's prompt is absent from the defence prompt set."""
    return row['prompt'] not in defence_prompt


# Drop every attack row whose prompt also occurs in the defence data.
filtered_attack_dataset = filtered_attack_dataset.filter(_prompt_not_in_defence)

Filter:   0%|          | 0/181029 [00:00<?, ? examples/s]

In [19]:
# Persist the filtered attack split.
# The directory name has to match, including letter case, the name that the
# loading cells pass to load_from_disk ('BeaverTails_disjoint_attack'). The
# earlier spelling "Beavertails_disjoint_attack" would not be found by those
# loaders on a case-sensitive filesystem.
filtered_attack_dataset.save_to_disk("BeaverTails_disjoint_attack")

Saving the dataset (0/1 shards):   0%|          | 0/89037 [00:00<?, ? examples/s]

In [6]:
from transformers import pipeline
from datasets import load_from_disk, load_dataset, Dataset
import torch
from tqdm import tqdm

# NOTE(review): `seen_prompts` is never read or updated in this cell; it is kept
# only in case another cell relies on the name.
seen_prompts = set()
filtered_attack_dataset = load_from_disk("BeaverTails_disjoint_attack")

# Best record seen so far for each prompt:
#   prompt text -> {'example': <row dict>, 'harm_score': <LABEL_0 score>}
prompt_harmfulness_map = {}
harmfulness_pipe = pipeline(
    'text-classification',
    model='/scratch/gpfs/bw1822/nlp_checkpoints/deberta-v3-xsmall-beavertails-harmful-qa-classifier',
    torch_dtype=torch.bfloat16,
    device='cuda',
)

# Score every (prompt, response) pair and remember, per prompt, the response
# with the highest LABEL_0 score.
for example in tqdm(filtered_attack_dataset):
    prompt = example['prompt']
    response = example['response']
    full_output = f"Question: {prompt}\nAnswer: {response}"

    # Get the harmfulness score: out[0] holds one {'label', 'score'} entry per
    # class, and the LABEL_0 entry is the one used.
    # NOTE(review): assumes LABEL_0 is the classifier's "harmful" class —
    # confirm against the checkpoint's label mapping.
    out = harmfulness_pipe(full_output, return_all_scores=True)
    harm_score = [
        label['score'] for label in out[0] if label['label'] == 'LABEL_0'
    ][0]

    # Keep this response if the prompt is new, or if it scores HIGHER than the
    # best response recorded for the prompt so far.
    best_so_far = prompt_harmfulness_map.get(prompt)
    if best_so_far is None or harm_score > best_so_far['harm_score']:
        prompt_harmfulness_map[prompt] = {
            'example': example,
            'harm_score': harm_score
        }

# Extract the examples with the highest harmfulness score for each unique prompt
unique_examples = [entry['example'] for entry in prompt_harmfulness_map.values()]

# Dataset.from_list yields the same columns as a hand-built dict of lists, but
# does not index unique_examples[0], so an empty result no longer raises
# IndexError.
unique_dataset = Dataset.from_list(unique_examples)

# Save the unique examples dataset to disk
output_path = "filtered_unique_examples"
unique_dataset.save_to_disk(output_path)

100%|██████████| 89037/89037 [13:31<00:00, 109.74it/s]


Saving the dataset (0/1 shards):   0%|          | 0/4986 [00:00<?, ? examples/s]