In [1]:
from transformers import BartForConditionalGeneration, BartTokenizer
from tqdm import tqdm

import pandas as pd
import spacy
import torch
import os
import re

In [2]:
# Hugging Face checkpoint used for both the perturber model and its tokenizer.
MODEL_NAME = "facebook/perturber"
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of crashing at load time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

nlp = spacy.load("en_core_web_sm")
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME).to(device=DEVICE)
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)

In [3]:
# Attribute word handed to the perturber alongside each entity.
TARGET = "woman"
# Bias dimension being perturbed; also the name of the output sub-folder.
DIMENSION = "gender"
# spaCy entity labels that are considered for swapping.
ENTITY_TYPE = [
    "PERSON",
    "ORG",
]

In [None]:
# Separator placed between the "<word>, <target>" prompt and the text
# when building perturber inputs.
SEP = "<PERT_SEP>"
# Data roots; relative paths resolve against the kernel's working directory
# (normally the folder containing this notebook).
RAW_PATH = "../../data/raw"
PROCESSED_PATH = "../../data/processed"

In [None]:
# Lookup table: NamedEntity -> GenderSwapEntity, the name written into the
# perturbed text after generation.
target_entities = pd.read_csv(
    "../../heterogeneity/lists_for_perturbations/names_swaps.csv"
)

In [6]:
# Show the name -> gender-swapped-name lookup table.
target_entities

Unnamed: 0,NamedEntity,GenderSwapEntity
0,Barack Obama,Hillary Clinton
1,George Bush,Sarah Palin
2,Donald Trump,Marine Le-Pen
3,Bill Clinton,Theresa May
4,Jesus Christ,Alexandria Ocasio-Cortez
5,Vladimir Putin,Angela Merkel
6,Adolf Hitler,Jacinda Ardern
7,Ronald Reagan,Margaret Thatcher
8,Ted Cruz,Paula Dobrainsky
9,Mitt Romney,Nirmala Sitharaman


In [None]:
# Twitter handle (User) -> person name (UserSwap); rows with any missing
# value are dropped so every remaining handle has a replacement.
twitter_to_entity = pd.read_csv(
    "../../heterogeneity/lists_for_perturbations/twitter_user_to_entity.csv"
).dropna()

In [8]:
# Show the handle -> name table (after dropna, so some indices are missing).
twitter_to_entity

Unnamed: 0,User,UserSwap
0,@BarackObama,Barack Obama
1,@TheBushCenter,George Bush
2,@realDonaldTrump,Donald Trump
3,@BillClinton,Bill Clinton
5,@KremlinRussia_E,Vladimir Putin
7,@RonaldReagan,Ronald Reagan
8,@SenTedCruz,Ted Cruz
9,@MittRomney,Mitt Romney
10,@SenJohnMcCain,John McCain
12,@FukuyamaFrancis,Francis Fukuyama


In [9]:
def replace_entity(text, entity, df, search_column, return_column):
    """Replace mentions of ``entity`` in ``text`` with its counterpart from ``df``.

    The first row of ``df`` whose ``search_column`` contains ``entity`` as a
    literal substring supplies the replacement (its ``return_column`` value).
    In ``text``, each occurrence of the last space-separated token of
    ``entity`` -- together with any run of capitalised words or initials
    (e.g. ``"W."``) immediately before it -- is replaced by that value.

    Parameters
    ----------
    text : str
        Text to rewrite.
    entity : str
        Entity to look up, e.g. ``"Barack Obama"`` or ``"@BarackObama"``.
    df : pandas.DataFrame
        Lookup table.
    search_column, return_column : str
        Columns of ``df`` to search in and to take the replacement from.

    Returns
    -------
    str
        The rewritten text, or ``text`` unchanged when no row matches.
    """
    # Literal match (entities may contain regex metacharacters); missing
    # values in the lookup column never count as a match.
    mask = df[search_column].str.contains(entity, regex=False, na=False)
    if not mask.any():
        return text

    swap_entity = df.loc[mask, return_column].iloc[0]
    last_token = entity.split(" ")[-1]
    # Optional run of capitalised words/initials, then the escaped last token,
    # which must not run on into another word ("Obama" != "ObamaCare").
    pattern = rf"([A-Z]([a-z]+|\.)\s*)*{re.escape(last_token)}(?!\w)"
    # A function replacement inserts swap_entity literally (no backslash
    # or group-reference interpretation).
    return re.sub(pattern, lambda _match: swap_entity, text)

In [10]:
def _expand_to_full_name(text, partial, full_name):
    """Rewrite whole-word mentions of ``partial`` in ``text`` as ``full_name``.

    Mentions that are already part of ``full_name`` are left untouched, so
    "Obama ... Barack Obama" becomes "Barack Obama ... Barack Obama" rather
    than "... Barack Barack Obama", and "Bushfire" is not turned into
    "George Bushfire".
    """
    pattern = re.compile(rf"(?<!\w){re.escape(partial)}(?!\w)")
    segments = text.split(full_name)
    return full_name.join(pattern.sub(lambda _match: full_name, segment) for segment in segments)


def preprocess_text(text):
    """Normalise Twitter handles and pick the first entity to perturb.

    1. Every handle listed in ``twitter_to_entity["User"]`` that occurs in
       ``text`` is rewritten to its name via ``replace_entity``.
    2. spaCy entities whose label is in ``ENTITY_TYPE`` are compared against
       ``target_entities["NamedEntity"]``. The first entity that is a literal
       substring of a listed name (first such name in table order) is
       selected, and its whole-word mentions are expanded to that full name.

    Returns
    -------
    tuple[str, str | None]
        ``(rewritten_text, full_name)`` when an entity is selected; otherwise
        ``(original_text, None)`` -- in that case the handle rewriting from
        step 1 is intentionally discarded, so unmatched rows stay unchanged.
    """
    original_text = text
    for handle in twitter_to_entity["User"].to_list():
        if handle in text:
            text = replace_entity(text, handle, twitter_to_entity, "User", "UserSwap")

    doc = nlp(text)
    candidates = [ent.text for ent in doc.ents if ent.label_ in ENTITY_TYPE]
    # dropna: a missing name would otherwise make `entity in name` raise.
    target_names = target_entities["NamedEntity"].dropna().to_list()

    for entity in candidates:
        # Literal (non-regex) substring lookup, first match in table order.
        full_name = next((name for name in target_names if entity in name), None)
        if full_name is None:
            continue
        if entity != full_name:
            text = _expand_to_full_name(text, entity, full_name)
        # Only the first matched entity is ever used downstream.
        return text, full_name

    return original_text, None

In [11]:
# Every raw split file (names ending in "train.csv" / "test.csv") below
# RAW_PATH. Sorted so files are processed in the same order on every run
# and platform (os.walk order depends on the filesystem).
data_paths = sorted(
    os.path.join(root, file_name)
    for root, _dirs, files in os.walk(RAW_PATH)
    for file_name in files
    if file_name.endswith(("test.csv", "train.csv"))
)

In [12]:
def batches(sents, batch_size):
    """Lazily yield consecutive slices of ``sents`` of length ``batch_size``.

    The last slice is shorter when ``len(sents)`` is not a multiple of
    ``batch_size``; an empty ``sents`` yields nothing.
    """
    starts = range(0, len(sents), batch_size)
    yield from (sents[start:start + batch_size] for start in starts)

In [13]:
# Token budget shared by truncation, perturber-input encoding and generation.
MAX_LENGTH = 128


def _truncate(text):
    """Return ``text`` cut to at most MAX_LENGTH BART tokens, special tokens removed."""
    input_ids = tokenizer(
        text,
        return_tensors="pt",
        max_length=MAX_LENGTH,
        truncation=True,
    )["input_ids"]
    return tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]


def _build_perturber_input(entity, text):
    """Prefix ``text`` with "<first word of entity>, TARGET SEP" when an entity was matched."""
    if entity:
        return f"{entity.split(' ')[0]}, {TARGET} {SEP} {text}"
    return text


def _perturb(perturber_inputs, entities):
    """Run the perturber on rows that have an entity; other rows pass through unchanged."""
    perturbed = []
    with torch.no_grad():
        for perturber_input, entity in zip(tqdm(perturber_inputs, total=len(perturber_inputs)), entities):
            if not entity:
                perturbed.append(perturber_input)
                continue
            input_ids = tokenizer(
                perturber_input,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_LENGTH,
            )["input_ids"]
            # Follow the model's device rather than assuming "cuda".
            outputs = model.generate(input_ids.to(model.device), max_length=MAX_LENGTH)
            perturbed.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    return perturbed


output_dir = os.path.join(PROCESSED_PATH, DIMENSION)

for path in data_paths:
    # basename() handles both "/" and os.sep; split("/") kept the whole
    # tail of the path on Windows.
    data_path = os.path.join(output_dir, os.path.basename(path))
    if os.path.exists(data_path):
        print(f"Already exists file: {data_path}. Skip.")
        continue

    data = pd.read_csv(path).dropna()

    print(f"Processing {path}")

    print("Truncating texts...")
    # Series.map (unlike row-wise DataFrame.apply) also works on empty frames.
    data["text"] = data["text"].map(_truncate)
    print("...texts truncated!")

    print("preprocessing text and extracting entities...")
    preprocessed = [preprocess_text(row_text) for row_text in data["text"].to_list()]
    preprocessed_texts = [pair[0] for pair in preprocessed]
    entities = [pair[1] for pair in preprocessed]
    data["preprocessed_text"] = preprocessed_texts
    data["entity"] = entities
    print("...text preprocessed and entities extracted!")

    data["perturber_input"] = [
        _build_perturber_input(entity, text)
        for entity, text in zip(entities, preprocessed_texts)
    ]

    print("Perturbating texts...")
    perturbed_texts = _perturb(data["perturber_input"].to_list(), entities)

    # Swap the (expanded) original name for its gender-swapped counterpart.
    data["perturbed_text"] = [
        replace_entity(text, entity, target_entities, "NamedEntity", "GenderSwapEntity") if entity else text
        for text, entity in zip(perturbed_texts, entities)
    ]
    data = data[["text", "perturbed_text", "entity", "labels"]]

    print("...texts perturbated!")

    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(output_dir, exist_ok=True)
    data.to_csv(data_path, index=False, header=True, encoding="utf-8")
    print(f"{data_path} created")
    print("\n")
    print("\n")

Already exists file: ../data/processed/gender/fingerprints_train.csv. Skip.
Already exists file: ../data/processed/gender/fingerprints_test.csv. Skip.
Already exists file: ../data/processed/gender/clef22_train.csv. Skip.
Already exists file: ../data/processed/gender/clef22_test.csv. Skip.
Already exists file: ../data/processed/gender/twittercovidq2_test.csv. Skip.
Already exists file: ../data/processed/gender/basil_train.csv. Skip.
Already exists file: ../data/processed/gender/clickbait_test.csv. Skip.
Already exists file: ../data/processed/gender/basil_test.csv. Skip.
Already exists file: ../data/processed/gender/politifact_train.csv. Skip.
Already exists file: ../data/processed/gender/webis_train.csv. Skip.
Already exists file: ../data/processed/gender/buzzfeed_train.csv. Skip.
Already exists file: ../data/processed/gender/twittercovidq2_train.csv. Skip.
Already exists file: ../data/processed/gender/propaganda_train.csv. Skip.
Already exists file: ../data/processed/gender/buzzfeed_te

100%|██████████| 6400/6400 [26:32<00:00,  4.02it/s] 


...texts perturbated!
../data/processed/gender/shadesoftruth_train.csv created


