In [None]:
from transformers import BartForConditionalGeneration, BartTokenizer
from tqdm import tqdm

import pandas as pd
import random
import torch
import os

In [2]:
# Load the pretrained "facebook/perturber" BART checkpoint and its tokenizer.
# NOTE(review): the model is moved to "cuda" unconditionally, so this cell
# needs a CUDA-capable device to run.
model = BartForConditionalGeneration.from_pretrained("facebook/perturber").to(device="cuda")
tokenizer = BartTokenizer.from_pretrained("facebook/perturber")

In [3]:
# Name of the perturbation dimension; it is used as the sub-directory of
# PROCESSED_PATH where the perturbed files are written.
DIMENSION = "race"

In [None]:
# Directory that is walked recursively for "*train.csv" / "*test.csv" inputs.
RAW_PATH = "../../data/raw"
# Root directory for the perturbed outputs (one sub-directory per DIMENSION).
PROCESSED_PATH = "../../data/processed"
# Separator token placed between the "<word>, <target>" prefix and the text
# when building the perturber input.
SEP = "<PERT_SEP>"

In [None]:
# Swap table: each row maps a term in the "Race" column to its replacement in
# the "RaceSwap" column.
target_entities = pd.read_csv("../../heterogeneity/lists_for_perturbations/race_swaps.csv")

In [6]:
# Display the swap table as a quick sanity check.
target_entities

Unnamed: 0,Race,RaceSwap
0,white,hispanic
1,black,asian


In [111]:
def preprocess_text(text):
    """Choose the swap target for a race term that occurs in ``text``.

    The terms in ``target_entities['Race']`` are tried in random order (the
    module-level ``random`` state is used, so seed it for reproducible runs).
    The first term that appears in ``text`` as a whole word decides the result.
    Matching is case-sensitive.

    A word counts as a match when it equals the term after splitting ``text``
    on any whitespace and stripping sentence punctuation and brackets from the
    word's edges, so ``"black,"`` and ``"(white)"`` match while
    ``"blackboard"`` and ``"white-collar"`` do not.

    Parameters
    ----------
    text : str
        The text to inspect. It is returned unchanged.

    Returns
    -------
    tuple
        ``(text, substitute)`` where ``substitute`` is the ``RaceSwap`` value of
        the matched term, or ``None`` when no term occurs in ``text``.
    """
    # Tokenise once instead of once per candidate term. Only sentence
    # punctuation is stripped, so "@white" or "#black" are left alone.
    tokens = {token.strip(".,;:!?\"'()[]{}") for token in text.split()}

    entities = target_entities['Race'].to_list()
    # Random order so that no term is systematically preferred when several
    # terms occur in the same text.
    random.shuffle(entities)
    for race in entities:
        if race in tokens:
            substitute = target_entities.loc[target_entities['Race'] == race, 'RaceSwap'].iloc[0]
            return (text, substitute)

    return (text, None)

In [112]:
# Collect every "*test.csv" / "*train.csv" file found anywhere below RAW_PATH.
data_paths = [
    f"{root}{os.sep}{name}"
    for root, _subdirs, names in os.walk(RAW_PATH)
    for name in names
    if name.endswith(("test.csv", "train.csv"))
]

In [113]:
def batches(sents, batch_size):
    """Yield consecutive slices of ``sents`` holding at most ``batch_size`` items.

    The last slice is shorter when ``len(sents)`` is not a multiple of
    ``batch_size``.
    """
    starts = range(0, len(sents), batch_size)
    yield from (sents[start : start + batch_size] for start in starts)

In [115]:
# Perturb every collected dataset split and write the result to
# PROCESSED_PATH/DIMENSION/<file name>. Splits whose output file already
# exists are skipped, so the cell can be re-run after an interruption.
for path in data_paths:

    # os.path.basename keeps the file name correct whatever separator the
    # collected path uses (splitting on "/" only works on POSIX-style paths).
    path2check = f'{PROCESSED_PATH}{os.sep}{DIMENSION}{os.sep}{os.path.basename(path)}'
    if os.path.exists(path2check):
        print(f"Already exists file: {path2check}. Skip.")
        continue

    data = pd.read_csv(path).dropna()

    print(f"Processing {path}")

    print("Truncating texts...")
    # Round-trip each text through the tokenizer so it is cut to at most
    # 128 tokens before any further processing.
    data['text'] = data.apply(
        lambda row: tokenizer.batch_decode(
            tokenizer(
                row.text,
                return_tensors="pt",
                max_length=128,
                truncation=True,
            )["input_ids"],
            skip_special_tokens=True
        )[0],
        axis=1
    )
    print("...texts truncated!")

    print("preprocessing text and extracting entities...")
    preprocessed_text = []
    entities = []
    for row in data["text"].to_list():
        text, entity = preprocess_text(row)
        preprocessed_text.append(text)
        entities.append(entity)
    data["preprocessed_text"] = preprocessed_text
    data["entity"] = entities
    print("...text preprocessed and entities extracted!")

    # Build "<word>, <target> <SEP> <text>" for rows with a swap target; rows
    # without one keep their text as-is.
    # NOTE(review): the source word is recovered by a reverse lookup on
    # 'RaceSwap', which picks the first matching row. If two 'Race' terms ever
    # share a swap target this can name the wrong source word; verify that
    # 'RaceSwap' values are unique in the swap table.
    data['perturber_input'] = data.apply(
        lambda row: f"{target_entities.loc[target_entities['RaceSwap'] == row.entity, 'Race'].iloc[0]}, {row.entity} {SEP} {row.preprocessed_text}"
        if row.entity else row.preprocessed_text, axis=1
    )

    print("Perturbating texts...")
    perturber_input = data["perturber_input"].to_list()
    entities = data["entity"].to_list()
    # Send inputs to whichever device the model lives on instead of assuming "cuda".
    device = model.device
    with torch.no_grad():
        perturbed_texts = []
        # `source_text` rather than `input`, so the builtin is not shadowed in
        # the notebook namespace.
        for source_text, entity in zip(tqdm(perturber_input, total=len(perturber_input)), entities):
            if entity:
                tokenized_batch = tokenizer(
                    source_text,
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                    max_length=128
                )
                outputs = model.generate(
                    tokenized_batch["input_ids"].to(device=device),
                    max_length=128,
                )
                perturbed_texts.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
            else:
                # Nothing to swap: keep the text unchanged and skip generation.
                perturbed_texts.append(source_text)

    data['perturbed_text'] = perturbed_texts
    data = data[["text", "perturbed_text", "entity", "labels"]]

    print("...texts perturbated!")

    # makedirs also creates a missing PROCESSED_PATH and tolerates an existing
    # directory, unlike a guarded os.mkdir.
    os.makedirs(f"{PROCESSED_PATH}{os.sep}{DIMENSION}", exist_ok=True)

    # Same location that was checked at the top of the iteration.
    data_path = path2check
    data.to_csv(data_path, index=False, header=True, encoding="utf-8")
    print(f"{data_path} created")
    print("\n")

Processing ../data/raw/fingerprints copy/fingerprints_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 13200/13200 [05:13<00:00, 42.04it/s]


...texts perturbated!
../data/processed/race/fingerprints_train.csv created


Processing ../data/raw/fingerprints copy/fingerprints_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 3300/3300 [01:25<00:00, 38.39it/s]


...texts perturbated!
../data/processed/race/fingerprints_test.csv created


Processing ../data/raw/clef22 copy/clef22_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 720/720 [00:14<00:00, 48.16it/s]


...texts perturbated!
../data/processed/race/clef22_train.csv created


Processing ../data/raw/clef22 copy/clef22_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 180/180 [00:04<00:00, 38.60it/s]


...texts perturbated!
../data/processed/race/clef22_test.csv created


Processing ../data/raw/unifiedm2/twittercovidq2_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 52/52 [00:00<00:00, 413859.22it/s]


...texts perturbated!
../data/processed/race/twittercovidq2_test.csv created


Processing ../data/raw/unifiedm2/basil_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 6367/6367 [00:14<00:00, 451.82it/s] 


...texts perturbated!
../data/processed/race/basil_train.csv created


Processing ../data/raw/unifiedm2/clickbait_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 3812/3812 [00:02<00:00, 1421.97it/s]


...texts perturbated!
../data/processed/race/clickbait_test.csv created


Processing ../data/raw/unifiedm2/basil_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 1592/1592 [00:03<00:00, 527.35it/s]


...texts perturbated!
../data/processed/race/basil_test.csv created


Processing ../data/raw/unifiedm2/politifact_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 162/162 [00:15<00:00, 10.40it/s]


...texts perturbated!
../data/processed/race/politifact_train.csv created


Processing ../data/raw/unifiedm2/webis_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 1283/1283 [02:06<00:00, 10.11it/s]


...texts perturbated!
../data/processed/race/webis_train.csv created


Processing ../data/raw/unifiedm2/buzzfeed_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 136/136 [00:00<00:00, 974.14it/s]


...texts perturbated!
../data/processed/race/buzzfeed_train.csv created


Processing ../data/raw/unifiedm2/twittercovidq2_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 208/208 [00:00<00:00, 496.63it/s]


...texts perturbated!
../data/processed/race/twittercovidq2_train.csv created


Processing ../data/raw/unifiedm2/propaganda_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 1294/1294 [00:04<00:00, 285.92it/s]


...texts perturbated!
../data/processed/race/propaganda_train.csv created


Processing ../data/raw/unifiedm2/buzzfeed_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 34/34 [00:00<00:00, 620027.55it/s]


...texts perturbated!
../data/processed/race/buzzfeed_test.csv created


Processing ../data/raw/unifiedm2/propaganda_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 324/324 [00:01<00:00, 194.83it/s]


...texts perturbated!
../data/processed/race/propaganda_test.csv created


Processing ../data/raw/unifiedm2/pheme_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 1364/1364 [00:04<00:00, 285.48it/s]


...texts perturbated!
../data/processed/race/pheme_train.csv created


Processing ../data/raw/unifiedm2/pheme_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 341/341 [00:01<00:00, 269.91it/s]


...texts perturbated!
../data/processed/race/pheme_test.csv created


Processing ../data/raw/unifiedm2/webis_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 321/321 [00:42<00:00,  7.50it/s]


...texts perturbated!
../data/processed/race/webis_test.csv created


Processing ../data/raw/unifiedm2/clickbait_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 15244/15244 [00:13<00:00, 1169.81it/s]


...texts perturbated!
../data/processed/race/clickbait_train.csv created


Processing ../data/raw/unifiedm2/politifact_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 40/40 [00:01<00:00, 32.59it/s]


...texts perturbated!
../data/processed/race/politifact_test.csv created


Processing ../data/raw/shadesoftruth copy/shadesoftruth_test.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 1600/1600 [00:51<00:00, 30.79it/s]


...texts perturbated!
../data/processed/race/shadesoftruth_test.csv created


Processing ../data/raw/shadesoftruth copy/shadesoftruth_train.csv
Truncating texts...
...texts truncated!
preprocessing text and extracting entities...
...text preprocessed and entities extracted!
Perturbating texts...


100%|██████████| 6400/6400 [03:18<00:00, 32.30it/s]

...texts perturbated!
../data/processed/race/shadesoftruth_train.csv created





