## Data Augmentation: Word Replacement

In [None]:
import os
import sys

# Needed to import modules from parent directory
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [1]:
from utils.database import *
from utils.files import *
from tqdm import tqdm
from bson import ObjectId
import pandas as pd 
import numpy as np

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForMaskedLM
from datasets import load_from_disk, Dataset, ClassLabel, Value, Features
from huggingface_hub import InferenceClient
from transformers import BertTokenizer
import matplotlib.pyplot as plt
from utils.preprocessing import *
from utils.accelerators import *
from utils.multithreading import *
from utils.database import *
from utils.model import *
from utils.files import *
from datasets import Dataset
from tqdm import tqdm
import statistics
import hashlib
import random
import time
import math
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
topic = "cannabis" #"energie" #"kinder" "cannabis"

## Get Predictions

### Load Model

In [3]:
MODEL_NAME = "deepset/gbert-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).eval()

Some weights of the model checkpoint at deepset/gbert-large were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = torch.nn.DataParallel(model)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Using 2 GPUs!


### Load Dataset

In [5]:
#dataset = load_from_disk(f"../../data/tmp/processed_dataset_buff_{topic}_split_chunkified")
dataset = load_from_disk(f"../../data/tmp/processed_dataset_{topic}_buffed_chunkified_random")

dataset

DatasetDict({
    train: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 2651
    })
    test: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 295
    })
})

## Generate new Training Examples

### Test on an Example

In [6]:
def randomly_replace_tokens(text, tokenizer, model, mask_probability=0.15):
    """ Randomly mask input tokens and predict the missing ones with a model. """

    # Tokenize the input text and prepare it for the model: Convert the text to input IDs, 
    # generate attention masks (to ignore padding in the attention mechanism), and ensure 
    # all inputs are of the same length by padding shorter texts and truncating longer ones.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', add_special_tokens=True)
    input_ids = inputs.input_ids.clone()
    attention_mask = inputs.attention_mask
    replaced_input_ids = input_ids.clone()

    # Generate a random array of the same shape as input_ids. This will be used to decide
    # which tokens to mask based on the mask_probability. Tokens corresponding to 'True' in
    # this array will be considered for masking.
    rand = torch.rand(input_ids.shape)
    mask_arr = (rand < mask_probability) * (input_ids != tokenizer.cls_token_id) * \
               (input_ids != tokenizer.sep_token_id) * (input_ids != tokenizer.pad_token_id)

    # Replace selected tokens with the mask token ID in input_ids.
    selection = mask_arr.nonzero(as_tuple=False)[:, 1].tolist()
    input_ids[0, selection] = tokenizer.mask_token_id

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    
    predictions = outputs.logits
    
    print("Selection: ", selection)
    # For each token position that was masked, find the token with the highest score (most likely token)
    # and replace the masked token with it in replaced_input_ids.
    for i in selection:
        # Get all predictions for this token and sort predictions by likelihood
        all_predictions = predictions[0, i]
        sorted_predictions = torch.argsort(all_predictions, descending=True) 
        
        for pred_i in sorted_predictions:
            # If the predicted token is different from the original, use it
            
            print("Pred_i: ", pred_i)
            print("Replaced_input_ids: ", replaced_input_ids[0, i])
            if pred_i != replaced_input_ids[0, i]:
                replaced_input_ids[0, i] = pred_i
                break  # Exit the loop once a different token is found

    # Decode the replaced tokens back to a string, skipping special tokens
    replaced_text = tokenizer.decode(replaced_input_ids[0], skip_special_tokens=True)
    return replaced_text

In [7]:
def randomly_replace_tokens(text, tokenizer, model, mask_probability=0.15):
    """Elegantly replace tokens one by one, each with full context."""

    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', add_special_tokens=True)
    input_ids = inputs.input_ids.clone()
    attention_mask = inputs.attention_mask

    # Identify non-special tokens for potential masking
    non_special_token_indices = [i for i, token_id in enumerate(input_ids[0])
                                 if token_id not in (tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)]
    
    # Randomly select tokens for masking
    num_tokens_to_mask = int(len(non_special_token_indices) * mask_probability)
    tokens_to_mask = np.random.choice(non_special_token_indices, size=num_tokens_to_mask, replace=False)

    for i in tokens_to_mask:
        original_token_id = input_ids[0, i].item()  # Save the original token ID
        masked_input_ids = input_ids.detach().clone()
        masked_input_ids[0, i] = tokenizer.mask_token_id  # Mask the token

        with torch.no_grad():
            outputs = model(masked_input_ids, attention_mask=attention_mask)

        predictions = outputs.logits[0, i]
        predictions[original_token_id] = -float('Inf')  # Invalidate the original token
        best_pred_idx = predictions.argmax(dim=-1).item()
        input_ids[0, i] = best_pred_idx  # Replace with the best prediction

    replaced_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return replaced_text

In [8]:
# Example usage
text = "Das hier ist ein Test."
replaced_text = randomly_replace_tokens(text, tokenizer, model, mask_probability=0.35)
print("Original text:", text)
print("Replaced text:", replaced_text)

Original text: Das hier ist ein Test.
Replaced text: Das hier ist ein Spiel!


## Iterate over Training Dataset

In [9]:
from sklearn.model_selection import train_test_split

In [10]:
# Filter positive examples and sample 20 percent of the positive exampl
positive_examples = dataset['train'].filter(lambda example: example['label'] == 1)

# Select the first 20% of the shuffled positive examples as your random sample
positive_examples_shuffled = positive_examples.shuffle(seed=42)
num_samples = int(len(positive_examples_shuffled) * 0.01) 
sampled_examples = positive_examples_shuffled.select(range(num_samples))

# Generate new data points for the sampled positive examples
dataset[f'positive_sampled'] = sampled_examples
dataset

DatasetDict({
    train: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 2651
    })
    test: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 295
    })
    positive_sampled: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 13
    })
})

In [11]:
def generate_new_data_points(text, n_examples=1):
    """Generates n new data points from the original text."""
    new_texts = [randomly_replace_tokens(text, tokenizer, model, 0.35) for _ in range(n_examples)]
    return new_texts

In [12]:
for n in [1, 2, 3, 4, 5]:
    print(f"Generating {n} new examples for each original example...")
    
    # Placeholder for the expanded dataset
    expanded_examples = []

    # Iterate over each example in the sampled examples to generate new data points
    for example in tqdm(sampled_examples):
        new_texts = generate_new_data_points(example['text'], n)
        for new_text in new_texts:
            new_example = example.copy()
            new_example['text'] = new_text
            expanded_examples.append(new_example)
    
    # Convert the list of new examples to a Dataset
    expanded_dataset = Dataset.from_pandas(pd.DataFrame(expanded_examples))
    dataset[f'expanded_{n}'] = expanded_dataset

    print(f"Completed generating {n} new examples for each original example.")


Generating 1 new examples for each original example...


100%|██████████| 13/13 [01:16<00:00,  5.90s/it]


Completed generating 1 new examples for each original example.
Generating 2 new examples for each original example...


100%|██████████| 13/13 [02:34<00:00, 11.86s/it]


Completed generating 2 new examples for each original example.
Generating 3 new examples for each original example...


100%|██████████| 13/13 [03:49<00:00, 17.65s/it]


Completed generating 3 new examples for each original example.
Generating 4 new examples for each original example...


100%|██████████| 13/13 [05:04<00:00, 23.43s/it]


Completed generating 4 new examples for each original example.
Generating 5 new examples for each original example...


100%|██████████| 13/13 [06:22<00:00, 29.42s/it]

Completed generating 5 new examples for each original example.





## Save Generated Trainig Examples

In [13]:
dataset

DatasetDict({
    train: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 2651
    })
    test: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 295
    })
    positive_sampled: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 13
    })
    expanded_1: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 13
    })
    expanded_2: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'is_topic', 'label', 'chunk_id'],
        num_rows: 26
    })
    expanded_3: Dataset

In [14]:
# Save the expanded dataset
dataset.save_to_disk(f"../../data/tmp/augmented_dataset_{topic}_word_replacement")

Saving the dataset (1/1 shards): 100%|██████████| 2651/2651 [00:00<00:00, 130095.12 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 295/295 [00:00<00:00, 48139.12 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 13/13 [00:00<00:00, 2363.91 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 13/13 [00:00<00:00, 2308.76 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 26/26 [00:00<00:00, 4932.91 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 39/39 [00:00<00:00, 7137.22 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 52/52 [00:00<00:00, 9279.83 examples/s] 
Saving the dataset (1/1 shards): 100%|██████████| 65/65 [00:00<00:00, 12175.86 examples/s]
