In [None]:
import os

# Work from the repository root so that `src.*` imports and the relative
# `data/...` paths used in later cells resolve.
# The move is guarded: an unconditional `os.chdir('..')` would climb one more
# directory every time this cell is re-run.
# NOTE(review): assumes the repo root is the directory that contains `src/`.
if not os.path.isdir('src'):
    os.chdir('..')
os.getcwd()

In [15]:
from datasets import load_from_disk, Dataset, DatasetDict
from collections import Counter, defaultdict
from src.produce_metrics import raw_harmcategorization_to_cat_str

In [4]:
# Dataset of prompts paired with GPT-4o harm-categorization responses
# (columns `prompt` and `prompt_harm_categorization_response`, read below).
harmcat_ds_path = 'data/prompt/wg_non_adv_prompts_l3_64_ds_model-gpt-4o_harmcat_prompt_e7+1_gr'
ds = load_from_disk(harmcat_ds_path)
ds

Dataset({
    features: ['prompt', 'prompt_harm_categorization_response'],
    num_rows: 23511
})

In [12]:
# Bucket every prompt by the category string parsed from its raw GPT response.
cat2prompt = defaultdict(list)
for example in ds:
    category = raw_harmcategorization_to_cat_str(example['prompt_harm_categorization_response'])
    cat2prompt[category].append(example['prompt'])

# Per-category counts are simply the bucket sizes; dict order (first occurrence) is kept.
counter = Counter({category: len(prompts) for category, prompts in cat2prompt.items()})
counter

Counter({'--': 11320,
         '-Discrimination and Verbal Abuse-': 2387,
         '-Other Harms-': 1185,
         '-Financial Crime and Theft-Privacy Violations-': 1106,
         '-Privacy Violations-': 1032,
         '-Financial Crime and Theft-': 950,
         '-Discrimination and Verbal Abuse-Other Harms-': 841,
         '-Violence-Other Harms-': 537,
         'INVALID': 518,
         '-Sexual Misconduct, Exploitation, and Infidelity-': 344,
         '-Violence-': 303,
         '-Illegal Drug-Related Activities and Substance Abuse-': 281,
         '-Violence-Weapons, Explosives, Arson, and Illegal Firearm Transactions-': 255,
         '-Discrimination and Verbal Abuse-Violence-Other Harms-': 229,
         '-Financial Crime and Theft-Privacy Violations-Other Harms-': 206,
         '-Privacy Violations-Other Harms-': 182,
         '-Violence-Weapons, Explosives, Arson, and Illegal Firearm Transactions-Other Harms-': 147,
         '-Financial Crime and Theft-Other Harms-': 133,
      

In [14]:
import random

# Upper bound on prompts kept per category string. Categories with fewer
# prompts are kept whole but shuffled by `sample`.
MAX_PROMPTS_PER_CATEGORY = 2500
# Fixed seed so the subsample, and therefore the dataset written to disk later
# in this notebook, is reproducible across kernel restarts.
SUBSAMPLE_SEED = 0

# A dedicated RNG avoids depending on (or perturbing) the global `random` state.
subsample_rng = random.Random(SUBSAMPLE_SEED)
cat2prompt_subsample = {
    cat: subsample_rng.sample(prompts, min(len(prompts), MAX_PROMPTS_PER_CATEGORY))
    for cat, prompts in cat2prompt.items()
}
print(f'total number of prompts: {sum(len(v) for v in cat2prompt_subsample.values())}')

total number of prompts: 14691


In [22]:
# Category strings that do not become splits: 'INVALID' (presumably responses
# the parser could not interpret) and the string that appears to list every
# harm category at once.
EXCLUDED_CATEGORIES = {
    '-Financial Crime and Theft-Discrimination and Verbal Abuse-Violence-Illegal Drug-Related Activities and Substance Abuse-Privacy Violations-Sexual Misconduct, Exploitation, and Infidelity-Weapons, Explosives, Arson, and Illegal Firearm Transactions-Other Harms-',
    'INVALID',
}
# Categories with this many subsampled prompts or fewer are dropped as too small.
SMALL_CATEGORY_THRESHOLD = 10

ddict = DatasetDict({
    cat: Dataset.from_dict({'prompt': prompts})
    for cat, prompts in cat2prompt_subsample.items()
    if cat not in EXCLUDED_CATEGORIES and len(prompts) > SMALL_CATEGORY_THRESHOLD
})
ddict

DatasetDict({
    -Financial Crime and Theft-: Dataset({
        features: ['prompt'],
        num_rows: 950
    })
    -Financial Crime and Theft-Privacy Violations-: Dataset({
        features: ['prompt'],
        num_rows: 1106
    })
    --: Dataset({
        features: ['prompt'],
        num_rows: 2500
    })
    -Other Harms-: Dataset({
        features: ['prompt'],
        num_rows: 1185
    })
    -Discrimination and Verbal Abuse-: Dataset({
        features: ['prompt'],
        num_rows: 2387
    })
    -Illegal Drug-Related Activities and Substance Abuse-: Dataset({
        features: ['prompt'],
        num_rows: 281
    })
    -Privacy Violations-: Dataset({
        features: ['prompt'],
        num_rows: 1032
    })
    -Discrimination and Verbal Abuse-Other Harms-: Dataset({
        features: ['prompt'],
        num_rows: 841
    })
    -Sexual Misconduct, Exploitation, and Infidelity-: Dataset({
        features: ['prompt'],
        num_rows: 344
    })
    -Privacy Viola

In [23]:
# Persist the filtered per-category DatasetDict built in the previous cell.
save_path = 'data/by_spec/wg_non_adv_prompts/gpt-4o_label'
DatasetDict.save_to_disk(ddict, save_path)

Saving the dataset (1/1 shards): 100%|██████████| 950/950 [00:00<00:00, 193737.00 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1106/1106 [00:00<00:00, 247672.20 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2500/2500 [00:00<00:00, 512675.89 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1185/1185 [00:00<00:00, 286304.74 examples/s]


Saving the dataset (1/1 shards): 100%|██████████| 2387/2387 [00:00<00:00, 514454.74 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 281/281 [00:00<00:00, 68819.31 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1032/1032 [00:00<00:00, 231769.21 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 841/841 [00:00<00:00, 195446.01 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 344/344 [00:00<00:00, 85725.18 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 182/182 [00:00<00:00, 44119.95 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 537/537 [00:00<00:00, 85096.77 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 38/38 [00:00<00:00, 10292.77 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 12/12 [00:00<00:00, 3175.50 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 68/68 [00:00<00:00, 18702.47 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 206/206 [00:00<00:0