In [1]:
from transformers import BartTokenizer
from tqdm import tqdm

import pandas as pd
import spacy
import os
import re

In [2]:
# spaCy English pipeline; preprocess_text runs it on each text and reads the
# resulting named entities (doc.ents).
nlp = spacy.load("en_core_web_sm")
# Tokenizer used further down to truncate every text to at most 128 tokens
# before entities are extracted.
tokenizer = BartTokenizer.from_pretrained("facebook/perturber")

In [3]:
# Perturbation dimensions. Each one is paired with a "<dimension>_swaps.csv"
# swap list and gets its own sub-folder under the processed-data directory.
dimensions = [
    "nationality",
    "country",
    "religion",
]

# spaCy entity labels that are eligible for swapping; entities with any other
# label are ignored.
entity_types = [
    "NORP",
    "GPE",
]

In [None]:
# Directory that is walked recursively for "*train.csv" / "*test.csv" inputs.
RAW_PATH = "../../data/raw"
# Output root; results are written to <PROCESSED_PATH>/<dimension>/<file name>.
PROCESSED_PATH = "../../data/processed"

In [None]:
def preprocess_text(text, target_entities):
    """Return the first named entity in ``text`` that appears in the swap list.

    The text is parsed with the module-level spaCy pipeline ``nlp``. Only
    entities whose label is listed in the module-level ``entity_types`` are
    considered. An entity qualifies when it is a case-insensitive substring of
    any entry in the first column of ``target_entities``.

    Args:
        text: Text to scan for entities.
        target_entities: DataFrame whose first column holds the entity names
            that can be swapped.

    Returns:
        The surface text of the first qualifying entity, in document order, or
        ``None`` when no entity qualifies.
    """
    target_column = target_entities.columns[0]

    doc = nlp(text)
    candidates = [ent.text for ent in doc.ents if ent.label_ in entity_types]
    if not candidates:
        return None

    # Lower-case the swap list once rather than once per detected entity.
    # Non-string cells (e.g. NaN from a blank CSV field) are skipped: they have
    # no .lower() and can never contain an entity name.
    known_entities = [
        token.lower()
        for token in target_entities[target_column].to_list()
        if isinstance(token, str)
    ]

    # Only the first match is ever used, so stop at the first hit.
    for candidate in candidates:
        needle = candidate.lower()
        if any(needle in known for known in known_entities):
            return candidate
    return None

In [6]:
def replace_entity(text, entity, df, search_column, return_column):
    """Swap ``entity`` in ``text`` for its counterpart looked up in ``df``.

    The first row of ``df`` whose ``search_column`` contains ``entity``
    (literal, case-insensitive substring match) supplies the replacement from
    ``return_column``. Every occurrence of the entity's last word in ``text``
    is then replaced, together with any run of capitalized words or initials
    (such as "U.") immediately before it, so that multi-word entities are
    replaced as a whole.

    Args:
        text: Text to perturb.
        entity: Entity string previously detected in ``text``.
        df: Swap table.
        search_column: Column of ``df`` searched for ``entity``.
        return_column: Column of ``df`` holding the replacement value.

    Returns:
        The perturbed text, or ``text`` unchanged when ``entity`` is blank or
        has no entry in ``df``.
    """
    words = entity.split()
    if not words:
        # A blank entity would reduce the pattern to its optional prefix,
        # which matches (and would insert the replacement) at every position.
        return text

    # Literal, case-insensitive lookup:
    # - regex=False keeps metacharacters in names such as "U.S." from being
    #   interpreted as a pattern;
    # - case=False replaces the former entity.capitalize(), which lower-cased
    #   every later word and so never matched entries like "New Zealand";
    # - na=False keeps missing cells from producing NaN in the boolean mask.
    mask = df[search_column].str.contains(entity, case=False, regex=False, na=False)
    if not mask.any():
        return text

    swap_entity = df.loc[mask, return_column].iloc[0]

    # Optional run of capitalized words / initials followed by the (escaped)
    # last word of the entity.
    pattern = r"([A-Z]([a-z]+|\.)\s*)*" + re.escape(words[-1])

    # A callable replacement is inserted verbatim, so backslashes or "\1"-like
    # sequences in the swap value are not treated as escapes or group refs.
    return re.sub(pattern, lambda _match: swap_entity, text)

In [7]:
# Collect every train/test CSV found anywhere beneath RAW_PATH.
data_paths = [
    f"{root}{os.sep}{filename}"
    for root, _subdirs, filenames in os.walk(RAW_PATH)
    for filename in filenames
    if filename.endswith(("test.csv", "train.csv"))
]

In [8]:
def batches(sents, batch_size):
    """Yield consecutive slices of ``sents`` holding at most ``batch_size`` items.

    The final slice is shorter when ``len(sents)`` is not a multiple of
    ``batch_size``; an empty ``sents`` yields nothing.
    """
    starts = range(0, len(sents), batch_size)
    yield from (sents[start : start + batch_size] for start in starts)

In [None]:
# For every perturbation dimension, write a perturbed copy of each raw dataset
# to <PROCESSED_PATH>/<dimension>/<file name>. Files that already exist are
# skipped, so the cell can be re-run after an interruption.
for dimension in dimensions:
    # Swap list for this dimension: the first column is searched for detected
    # entities, the second column supplies the replacement.
    target_entities = pd.read_csv(f"../../heterogeneity/lists_for_perturbations/{dimension}_swaps.csv")
    search_column, return_column = target_entities.columns[0], target_entities.columns[1]
    output_dir = os.path.join(PROCESSED_PATH, dimension)

    for path in data_paths:
        # os.path.basename respects the platform separator; the input paths
        # were built with os.sep, so splitting on "/" breaks on Windows.
        data_path = os.path.join(output_dir, os.path.basename(path))
        if os.path.exists(data_path):
            print(f"Already exists file: {data_path}. Skip.")
            continue

        data = pd.read_csv(path).dropna()

        print(f"Processing {path}")

        print("Truncating texts...")
        # Encode with truncation to 128 tokens, then decode back to a string.
        # Series.apply (rather than row-wise DataFrame.apply) still returns a
        # Series when the frame is empty.
        data["text"] = data["text"].apply(
            lambda text: tokenizer.batch_decode(
                tokenizer(
                    text,
                    return_tensors="pt",
                    max_length=128,
                    truncation=True,
                )["input_ids"],
                skip_special_tokens=True,
            )[0]
        )
        print("...texts truncated!")

        print("Etracting entities...")
        text_input = data["text"].to_list()
        entities = [preprocess_text(text, target_entities) for text in text_input]
        data["entity"] = entities
        print("...entities extracted!")

        print("Perturbating texts...")
        # Use the Python list directly rather than reading the column back, so
        # texts without a swappable entity keep ``None`` as their marker.
        perturbed_texts = []
        for text, entity in zip(tqdm(text_input, total=len(text_input)), entities):
            if entity:
                perturbed_texts.append(
                    replace_entity(text, entity, target_entities, search_column, return_column)
                )
            else:
                perturbed_texts.append(text)
        data["perturbed_text"] = perturbed_texts
        data = data[["text", "perturbed_text", "entity", "labels"]]

        print("...texts perturbated!")

        # makedirs also creates missing parents and tolerates an existing folder.
        os.makedirs(output_dir, exist_ok=True)

        data.to_csv(data_path, index=False, header=True, encoding="utf-8")
        print(f"{data_path} created")
        print("\n")