In [1]:
import pandas as pd
import numpy as np
import os
import pickle
import re
from collections import Counter
from tqdm import tqdm

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config

import ilm.ilm.tokenize_util
from ilm.ilm.infer import infill_with_ilm
from perturbation_functions import calculate_necc_and_suff, gen_num_samples_table, gen_probs_table

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_DIR = '../Models/ILM/'
# Mask class name; not referenced in the cells shown here (kept for later use).
MASK_CLS = 'ilm.mask.hierarchical.MaskHierarchical'

# GPT-2 tokenizer, extended below with the ILM special tokens stored next to the model.
tokenizer = ilm.ilm.tokenize_util.Tokenizer.GPT2
# NOTE(review): pickle.load can execute arbitrary code; only load this file from a trusted source.
with open(os.path.join(MODEL_DIR, 'additional_ids_to_tokens.pkl'), 'rb') as f:
    additional_ids_to_tokens = pickle.load(f)
additional_tokens_to_ids = {v: k for k, v in additional_ids_to_tokens.items()}
try:
    ilm.ilm.tokenize_util.update_tokenizer(additional_ids_to_tokens, tokenizer)
except ValueError:
    # Presumably raised when the special tokens were already registered (e.g. re-running this cell).
    print('Already updated')
print(additional_tokens_to_ids)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = GPT2Config.from_json_file(os.path.join(MODEL_DIR, 'config.json'))
model = GPT2LMHeadModel(config)
# Load the weights. map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; the model (still on the CPU here) is moved to `device` afterwards.
model.load_state_dict(torch.load(os.path.join(MODEL_DIR, 'model_weights.pt'), map_location='cpu'))
model.eval()
_ = model.to(device)



{'<|startofinfill|>': 50257, '<|endofinfill|>': 50258, '<|infill_document|>': 50259, '<|infill_paragraph|>': 50260, '<|infill_sentence|>': 50261, '<|infill_ngram|>': 50262, '<|infill_word|>': 50263}


In [3]:
# Select the derogation-type hateful HateCheck cases that target women or Muslims.
target_ds = ['women', 'Muslims']
funcs = ['derog_neg_emote_h', 'derog_neg_attrib_h', 'derog_dehum_h']

# The CSV carries a leftover index column ('Unnamed: 0'); `case_id` is used as the index instead.
test_suite_cases = (
    pd.read_csv("Data/hatecheck-data/test_suite_cases.csv", index_col="case_id")
    .drop(columns=['Unnamed: 0'])
    .loc[lambda df: df.target_ident.isin(target_ds) & df.functionality.isin(funcs)]
)
tts = test_suite_cases.test_case.tolist()

# The text file below stores one test case per line, so an embedded line break would split a case.
assert not any('\n' in text or '\r' in text for text in tts), "test cases must be single-line"

with open("Data/HateCheck_test_suite_cases.txt", "w") as f:
    f.write("\n".join(tts))

In [4]:
# Perturbation settings; per the original note these aim for approximately 100 perturbations per token.
# NOTE(review): the meaning of the arguments (20, 100) is defined in perturbation_functions — confirm there.
num_samples = gen_num_samples_table(20, 100)
probs_table = gen_probs_table(20)
# ILM special token id used as the mask for spans to be infilled.
mask_tokn = additional_tokens_to_ids['<|infill_ngram|>']

orig_texts = []
necc_perturbed = []
suff_perturbed = []
necc_masks = []
suff_masks = []

# Iterating a file yields each line WITH its trailing newline (all but the last one). Strip it so
# the newline is neither fed to the model nor stored in `orig_texts`.
with open("Data/HateCheck_test_suite_cases.txt", "r") as ff:
    texts = [line.rstrip("\r\n") for line in ff]
texts = [text for text in texts if text]  # ignore blank lines

# tqdm derives the total from the list, so the progress bar matches the actual number of cases.
for text in tqdm(texts):
    necc_pp, suff_pp, necc_mm, suff_mm = calculate_necc_and_suff(
        text,
        ilm_tokenizer=tokenizer,
        ilm_model=model,
        cl_tokenizer=None,
        cl_model=None,
        num_samples=num_samples,
        mask_tokn=mask_tokn,
        additional_tokens_to_ids=additional_tokens_to_ids,
        probs_table=probs_table,
        return_pert_only=True,
    )

    orig_texts.append(text)
    necc_perturbed.append(necc_pp)
    suff_perturbed.append(suff_pp)
    necc_masks.append(necc_mm)
    suff_masks.append(suff_mm)

necc_suff_perturbations = {'orig_texts': orig_texts,
                           'necc_perturbed': necc_perturbed,
                           'suff_perturbed': suff_perturbed,
                           'necc_masks': necc_masks,
                           'suff_masks': suff_masks}

# Use a context manager so the pickle file is flushed and closed deterministically.
with open('Data/HateCheck_necc_suff_perturbations.pickle', 'wb') as f:
    pickle.dump(necc_suff_perturbations, f)

  0%|          | 0/120 [00:00<?, ?it/s]

100%|██████████| 120/120 [1:28:21<00:00, 44.18s/it]


In [5]:
# Re-save the perturbation dict; the context manager guarantees the file is flushed and closed.
with open('Data/HateCheck_necc_suff_perturbations.pickle', 'wb') as f:
    pickle.dump(necc_suff_perturbations, f)

In [7]:
# One original text per line, in the same order as the rows of the perturbation/mask TSV files.
os.makedirs("Data/hatecheck_perturbations", exist_ok=True)
with open("Data/hatecheck_perturbations/orig_texts.txt", "w") as ff:
    # Texts read line-by-line may still end in a newline; joining those with "\n" would insert a
    # blank line between entries and break the line-to-row alignment with the TSV files.
    ff.write("\n".join(text.rstrip("\r\n") for text in necc_suff_perturbations['orig_texts']))

In [8]:
# One row per original text; that text's necessity perturbations are tab-separated.
with open("Data/hatecheck_perturbations/necc_perturbations.tsv", "w") as ff:
    for perturbations in necc_suff_perturbations['necc_perturbed']:
        # A tab or line break inside a perturbed string would corrupt the TSV layout
        # (e.g. a trailing newline inherited from the source text), so fold them into spaces.
        fields = (re.sub(r"[\t\r\n]+", " ", pert).strip() for pert in perturbations)
        ff.write("\t".join(fields) + "\n")

In [9]:
# One row per original text; that text's sufficiency perturbations are tab-separated.
with open("Data/hatecheck_perturbations/suff_perturbations.tsv", "w") as ff:
    for perturbations in necc_suff_perturbations['suff_perturbed']:
        # A tab or line break inside a perturbed string would corrupt the TSV layout
        # (e.g. a trailing newline inherited from the source text), so fold them into spaces.
        fields = (re.sub(r"[\t\r\n]+", " ", pert).strip() for pert in perturbations)
        ff.write("\t".join(fields) + "\n")

In [10]:
# One line per original text. Within a line, each row of the mask array is written as
# space-separated integers, and rows are separated by tabs.
# NOTE(review): each mask is presumably a 2-D array (one row per perturbation) — confirm in perturbation_functions.
mask_path = "Data/hatecheck_perturbations/necc_masks.tsv"
with open(mask_path, "w") as mask_file:
    mask_lines = (
        "\t".join(map(" ".join, mask_array.astype(int).astype(str).tolist()))
        for mask_array in necc_suff_perturbations['necc_masks']
    )
    for mask_line in mask_lines:
        mask_file.write(f"{mask_line}\n")

In [11]:
# One line per original text. Within a line, each row of the mask array is written as
# space-separated integers, and rows are separated by tabs.
# NOTE(review): each mask is presumably a 2-D array (one row per perturbation) — confirm in perturbation_functions.
mask_path = "Data/hatecheck_perturbations/suff_masks.tsv"
with open(mask_path, "w") as mask_file:
    mask_lines = (
        "\t".join(map(" ".join, mask_array.astype(int).astype(str).tolist()))
        for mask_array in necc_suff_perturbations['suff_masks']
    )
    for mask_line in mask_lines:
        mask_file.write(f"{mask_line}\n")