In [2]:
from collections import Counter
import datasets
from datasets import concatenate_datasets


def get_balanced_subset(dataset, label_column="label", target_count=-1, seed=42):
    """Downsample ``dataset`` so every label value keeps the same number of rows.

    Args:
        dataset: Object that can be indexed by column name and that provides
            ``filter``, ``shuffle`` and ``select`` (this notebook passes
            Hugging Face ``datasets`` objects).
        label_column: Name of the column holding the class label.
        target_count: Upper bound on the rows kept per class. ``-1`` (default)
            keeps as many rows as the rarest class has. Requests larger than
            the rarest class are capped at that class's size.
        seed: Seed used for every per-class shuffle and for the final shuffle.

    Returns:
        The per-class samples concatenated and shuffled with ``seed``.

    Raises:
        ValueError: If ``target_count`` is negative but not ``-1``, or if the
            label column contains no rows.
    """
    # A negative count other than the sentinel would reach range() below and
    # quietly produce an empty result, so reject it up front.
    if target_count < -1:
        raise ValueError(
            f"target_count must be -1 (use the rarest class size) or >= 0, got {target_count}"
        )

    label_counts = Counter(dataset[label_column])
    print(f"Initial label counts: {label_counts}")

    if not label_counts:
        raise ValueError(
            f"Cannot balance an empty dataset: column {label_column!r} has no rows"
        )

    # Never ask for more rows than the rarest class can supply.
    rarest = min(label_counts.values())
    per_class = rarest if target_count == -1 else min(rarest, target_count)

    # Build one equally sized sample per class.
    subsets = []
    for label_value in label_counts:
        # Binding label_value as a default argument freezes its current value
        # inside the lambda instead of capturing the loop variable.
        class_subset = (
            dataset
            .filter(lambda example, lv=label_value: example[label_column] == lv)
            .shuffle(seed=seed)
            .select(range(per_class))
        )
        subsets.append(class_subset)

    # Concatenation leaves the classes in contiguous runs; shuffle to mix them.
    balanced = concatenate_datasets(subsets)
    balanced = balanced.shuffle(seed=seed)

    return balanced

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load the "train_orig_cot" split of the "compliance" config and print how
# many rows it contains.
d = datasets.load_dataset("tomg-group-umd/compliance", "compliance", split="train_orig_cot")
print(len(d))
# Disabled export of the first 2000 rows; uncomment to write data/train.jsonl.
# d.select(range(2000)).to_json("data/train.jsonl")

Generating test_handcrafted split: 100%|██████████| 199/199 [00:00<00:00, 2023.81 examples/s]
Generating train split: 100%|██████████| 17930/17930 [00:00<00:00, 81583.89 examples/s]
Generating train_cot split: 100%|██████████| 17930/17930 [00:00<00:00, 55602.57 examples/s]
Generating test split: 100%|██████████| 98/98 [00:00<00:00, 14373.60 examples/s]
Generating train_orig_cot split: 100%|██████████| 6247/6247 [00:00<00:00, 52840.84 examples/s]

6247





In [None]:
# Safety datasets
# (config, split) pairs to pull from the tomg-group-umd/compliance repository.
SAFETY_SPLITS = [
    ("aegis", "train_harm"),
    ("aegis", "train_refusal_cot"),
    ("beavertails", "train_harm_cot"),
    ("beavertails", "train_refusal"),
    ("toxicchat", "train_jailbreaking"),
    ("toxicchat", "train_refusal_cot"),
    ("wildguard", "train_harm_cot"),
    ("wildguard", "train_jailbreaking"),
    ("wildguard", "train_refusal"),
]
SAFETY_ROWS_PER_SOURCE = 2000
# Assumes every source has exactly two label values — checked by the assert below.
SAFETY_ROWS_PER_CLASS = SAFETY_ROWS_PER_SOURCE // 2

# s1..s9 stay bound to the raw splits, in the same order as SAFETY_SPLITS.
s1, s2, s3, s4, s5, s6, s7, s8, s9 = [
    datasets.load_dataset("tomg-group-umd/compliance", config, split=split)
    for config, split in SAFETY_SPLITS
]

# Ask the balancer for the per-class quota directly. Balancing to the rarest
# class and then drawing a random 2000 rows from the result leaves the class
# split only approximately even.
s = [
    get_balanced_subset(raw, target_count=SAFETY_ROWS_PER_CLASS)
    for raw in (s1, s2, s3, s4, s5, s6, s7, s8, s9)
]
for subset in s:
    # get_balanced_subset caps the quota at the rarest class size, so a short
    # source would otherwise slip through with fewer rows than intended.
    assert len(subset) == SAFETY_ROWS_PER_SOURCE, (
        f"expected {SAFETY_ROWS_PER_SOURCE} rows per safety source, got {len(subset)}"
    )
    print(len(subset))

# Compliance datasets
# d1 takes row positions 0..5999 of "train_cot"; d2 takes positions
# 6000..17929 of "train". Each slice is then shuffled with a fixed seed.
# NOTE(review): presumably the two splits are row-aligned, so no example lands
# in both pools — confirm against the dataset card.
d1 = (
    datasets.load_dataset("tomg-group-umd/compliance", "compliance", split="train_cot")
    .select(range(6000))
    .shuffle(seed=42)
)
d2 = (
    datasets.load_dataset("tomg-group-umd/compliance", "compliance", split="train")
    .select(range(6000, 17930))
    .shuffle(seed=42)
)
print(len(d1))
print(len(d2))

# One row per exported training file: (rows from d1, rows from d2, output path).
# The number in each file name is the pool size before balancing; balancing
# trims every class to the rarest one, so a file can hold fewer rows than that.
export_plan = (
    (185, 315, "data/train_500.jsonl"),
    (375, 625, "data/train_1000.jsonl"),
    (750, 1250, "data/train_2000.jsonl"),
    (1500, 2500, "data/train_4000.jsonl"),
    (3000, 5000, "data/train_8000.jsonl"),
    (6000, 10000, "data/train_16000.jsonl"),
)
export_results = []
for cot_rows, plain_rows, out_path in export_plan:
    pool = datasets.concatenate_datasets(
        [d1.select(range(cot_rows)), d2.select(range(plain_rows))]
    ).shuffle(seed=42)
    export_results.append(get_balanced_subset(pool).to_json(out_path))

# t1..t6 hold whatever to_json returned for each file (a size count per the
# datasets documentation), not the datasets themselves.
t1, t2, t3, t4, t5, t6 = export_results

# Balanced version of the full compliance pool; this one is kept as a dataset.
full_pool = datasets.concatenate_datasets([d1, d2]).shuffle(seed=42)
t = get_balanced_subset(full_pool)

# Mix dataset
# Pool every safety subset with the balanced compliance set, then shuffle once
# so the sources are interleaved in the exported file.
# NOTE(review): the recorded run printed 34110 rows, so the "32000" in the
# file name looks nominal — confirm before relying on it.
mix_parts = [*s, t]
final = datasets.concatenate_datasets(mix_parts).shuffle(seed=42)
print(len(final))
# Kept as the cell's last expression so the notebook displays to_json's return value.
final.to_json("data/train_32000_mix.jsonl")


Initial label counts: Counter({'PASS': 11693, 'FAIL': 3541})
Initial label counts: Counter({'PASS': 1907, 'FAIL': 1751})
Initial label counts: Counter({'FAIL': 16013, 'PASS': 11153})
Initial label counts: Counter({'FAIL': 14552, 'PASS': 12614})
Initial label counts: Counter({'FAIL': 1469, 'PASS': 1395})
Initial label counts: Counter({'FAIL': 1455, 'PASS': 1409})
Initial label counts: Counter({'PASS': 29027, 'FAIL': 7786})
Initial label counts: Counter({'FAIL': 19015, 'PASS': 18934})
Initial label counts: Counter({'PASS': 19101, 'FAIL': 18848})
2000
2000
2000
2000
2000
2000
2000
2000
2000
6000
11930
Initial label counts: Counter({'PASS': 9875, 'FAIL': 8055})
34110


Creating json from Arrow format: 100%|██████████| 35/35 [00:00<00:00, 41.30ba/s]


90521522