## Train Multiclass Classifier: BERT

In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForMaskedLM, TrainingArguments, Trainer
from datasets import Dataset, load_dataset, load_from_disk, concatenate_datasets
from sklearn.metrics import accuracy_score
from collections import Counter
import random
import numpy as np
import torch
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
topics = ["cannabis", "energie", "kinder"]

## Load Dataset

**Map class-names to class-ids:**

In [3]:
id_to_class = {0: "other",1: "cannabis", 2: "energie", 3: "kinder"}
class_to_id = {"other": 0, "cannabis": 1, "energie": 2, "kinder": 3}

**Load dataset for each topic:**

In [4]:
dataset_cannabis = load_from_disk(f"../data/tmp/processed_dataset_cannabis_buffed_filtered")
dataset_energie = load_from_disk(f"../data/tmp/processed_dataset_energie_buffed_filtered")
dataset_kinder = load_from_disk(f"../data/tmp/processed_dataset_kinder_buffed_filtered")

**Update dataset schema and class-label mappings:**

In [5]:
from datasets import ClassLabel, Features, Value, DatasetDict, concatenate_datasets, load_dataset
import datasets

# Define the new class label feature
class_labels = ClassLabel(names=["other", "cannabis", "energie", "kinder"])

# Define the new features, including all existing ones plus the updated 'label'
new_features = datasets.Features({
    '_id': datasets.Value('string'),
    'batch_id': datasets.Value('int64'),
    'domain': datasets.Value('string'),
    'view_url': datasets.Value('string'),
    'lang': datasets.Value('string'),
    'text': datasets.Value('string'),
    'text_length': datasets.Value('int64'),
    'word_count': datasets.Value('int64'),
    'token_count': datasets.Value('int64'),
    'topic': datasets.Value('string'),
    'category': datasets.Value('string'),
    'good_for_training': datasets.Value('string'),
    'good_for_augmentation': datasets.Value('string'),
    'annotation_type': datasets.Value('string'),
    'is_topic': datasets.Value('int64'),
    'label': class_labels # Updated ClassLabel feature
})


In [6]:
dataset_cannabis = dataset_cannabis.map(lambda e: {'label': class_to_id['cannabis'] if e["label"] == 1 else class_to_id['other']}, features=new_features)

dataset_energie = dataset_energie.map(lambda e: {'label': class_to_id['energie'] if e["label"] == 1 else class_to_id['other']}, features=new_features)

dataset_kinder = dataset_kinder.map(lambda e: {'label': class_to_id['kinder'] if e["label"] == 1 else class_to_id['other']}, features=new_features)

**Merge the three datasets:**

In [7]:
# Concatenate all datasets
dataset_all_topics = concatenate_datasets([dataset_cannabis, dataset_energie, dataset_kinder])

In [8]:
dataset_all_topics

Dataset({
    features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'token_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label'],
    num_rows: 281303
})

In [9]:
dataset_all_topics[0]

{'_id': '999999',
 'batch_id': 16,
 'domain': '',
 'view_url': 'mingle.respondi.de/',
 'lang': '',
 'text': '',
 'text_length': 0,
 'word_count': 0,
 'token_count': 2,
 'topic': 'cannabis',
 'category': 'other',
 'good_for_training': 'False',
 'good_for_augmentation': 'True',
 'annotation_type': 'domain_discarded',
 'is_topic': 0,
 'label': 0}

**Filter out annotations not good for training:**

In [10]:
dataset_all_topics_test = dataset_all_topics.filter(lambda example: example['good_for_training'] == "False")

dataset_all_topics = dataset_all_topics.filter(lambda example: example['good_for_training'] == "True")

**Separate negative and positive examples:**

In [11]:
# Filter positive and negative examples
dataset_all_topics_pos = dataset_all_topics.filter(lambda example: example['label'] > 0, num_proc=16)
dataset_all_topics_neg = dataset_all_topics.filter(lambda example: example['label'] == 0, num_proc=16)

**Get examples which are negative for all topics:**

In [12]:
# Collect all view_url values from dataset_all_topics_pos
pos_view_urls = set(dataset_all_topics_pos['view_url'])

# Filter dataset_all_topics_neg to exclude any rows present in dataset_all_topics_pos
dataset_all_topics_neg = dataset_all_topics_neg.filter(lambda example: example['view_url'] not in pos_view_urls, num_proc=16)


In [13]:
dataset_all_topics_neg

Dataset({
    features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'token_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label'],
    num_rows: 14489
})

**Deduplicate negative examples:**

In [14]:
seen_urls = set()

dataset_all_topics_neg = dataset_all_topics_neg.filter(lambda example: example['view_url'] not in seen_urls and not seen_urls.add(example['view_url']), num_proc=16)

## Oversample Positive Exmaples

In [15]:
MODEL_NAME = "FacebookAI/xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).eval()

Some weights of the model checkpoint at FacebookAI/xlm-roberta-base were not used when initializing XLMRobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = torch.nn.DataParallel(model)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Using 2 GPUs!


In [17]:
def randomly_replace_tokens(text, tokenizer, model, mask_probability=0.15):
    """Elegantly replace tokens one by one, each with full context."""

    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', add_special_tokens=True)
    input_ids = inputs.input_ids.clone()
    attention_mask = inputs.attention_mask

    # Identify non-special tokens for potential masking
    non_special_token_indices = [i for i, token_id in enumerate(input_ids[0])
                                 if token_id not in (tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)]
    
    # Randomly select tokens for masking
    num_tokens_to_mask = int(len(non_special_token_indices) * mask_probability)
    tokens_to_mask = np.random.choice(non_special_token_indices, size=num_tokens_to_mask, replace=False)

    for i in tokens_to_mask:
        original_token_id = input_ids[0, i].item()  # Save the original token ID
        masked_input_ids = input_ids.detach().clone()
        masked_input_ids[0, i] = tokenizer.mask_token_id  # Mask the token

        with torch.no_grad():
            outputs = model(masked_input_ids, attention_mask=attention_mask)

        predictions = outputs.logits[0, i]
        predictions[original_token_id] = -float('Inf')  # Invalidate the original token
        best_pred_idx = predictions.argmax(dim=-1).item()
        input_ids[0, i] = best_pred_idx  # Replace with the best prediction

    replaced_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return replaced_text

In [18]:
# Example usage
text = "Das hier ist ein Test."
replaced_text = randomly_replace_tokens(text, tokenizer, model, mask_probability=0.35)
print("Original text:", text)
print("Replaced text:", replaced_text)

Original text: Das hier ist ein Test.
Replaced text: Und hier ist ein Testbericht


In [19]:
from datasets import concatenate_datasets, Dataset

def oversample(dataset, label_column, tokenizer, model, mask_probability=0.35):
    """Oversample minority classes in a Hugging Face Dataset of text data.
    """
    # Calculate class distributions and find the maximum count
    class_counts = Counter(dataset['label'])
    max_class_count = round(max(class_counts.values()) * 1.5) #TODO: Change back
    
    # Initialize a list to hold original and synthetic samples
    oversampled_datasets = []
    
    # Process each class
    for label, count in class_counts.items():
        
        class_indices = [i for i, x in enumerate(dataset[label_column]) if x == label]
        class_dataset = dataset.select(class_indices)
        
        # Calculate how many new samples are needed
        num_new_samples = max_class_count - count
        
        if num_new_samples > 0:
            
            # Randomly select samples to be duplicated and modified
            indices_to_augment = np.random.choice(class_indices, size=num_new_samples, replace=True)
            samples_to_augment = dataset.select(indices_to_augment)
            
            # Apply the randomly_replace_tokens function to generate new samples
            augmented_dataset = samples_to_augment.map(
                lambda example: {"text": randomly_replace_tokens(example['text'], tokenizer, model, mask_probability=mask_probability)},
                batched=False
            )
            
            # Append augmented dataset to the list
            oversampled_datasets.append(augmented_dataset)
        
        # Always include the original dataset for this class
        oversampled_datasets.append(class_dataset)
    
    # Concatenate all datasets into a new Dataset
    return concatenate_datasets(oversampled_datasets)


In [20]:
# Count the occurrences of each label
label_counts = Counter(dataset_all_topics_pos['label'])
print("Class frequencies before:", label_counts)

Class frequencies before: Counter({2: 248, 1: 224, 3: 212})


In [21]:
dataset_all_topics_pos = oversample(dataset_all_topics_pos, 'label', tokenizer, model)

Map:   5%|▌         | 8/148 [01:29<26:06, 11.19s/ examples]


KeyboardInterrupt: 

In [None]:
# Count the occurrences of each label
label_counts = Counter(dataset_all_topics_pos['label'])
print("Class frequencies after:", label_counts)

Class frequencies after: Counter({1: 372, 2: 372, 3: 372})


## Sample Negative Examples

In [22]:
print("Number of distinct negative domains", len(set(dataset_all_topics_neg["domain"])))

# Count the occurrences of each label
label_counts = Counter(dataset_all_topics_pos['label'])
print("Class frequencies:", label_counts)

# Find the minimum count
min_count = min(label_counts.values())
print("Minimum class frequency:", min_count)

Number of distinct negative domains 1922
Class frequencies: Counter({2: 248, 1: 224, 3: 212})
Minimum class frequency: 212


In [23]:
import random
from datasets import Dataset

def sample_random_from_dataset(dataset, n=5):
    """Samples n random examples from the dataset and returns both the sampled and unsampled parts.
    """

    # Select the sampled dataset
    n = min(n, len(dataset))
    random_indices = random.sample(range(len(dataset)), n)
    sampled_dataset = dataset.select(random_indices)
    
    # Determine indices not included in the random sample to create the remainder dataset
    all_indices = set(range(len(dataset)))
    unsampled_indices = list(all_indices - set(random_indices))
    remaining_dataset = dataset.select(unsampled_indices)
    
    return sampled_dataset, remaining_dataset

In [24]:
print(len(dataset_all_topics_pos))

684


In [25]:
dataset_all_topics_neg, dataset_all_topics_neg_test = sample_random_from_dataset(dataset_all_topics_neg, n = len(dataset_all_topics_pos))

In [26]:
dataset_all_topics = concatenate_datasets([dataset_all_topics_pos, dataset_all_topics_neg])
dataset_all_topics_holdout = concatenate_datasets([dataset_all_topics_test, dataset_all_topics_neg_test])

In [27]:
seen_urls = set()

dataset_all_topics_holdout = dataset_all_topics_holdout.filter(lambda example: example['view_url'] not in seen_urls and not seen_urls.add(example['view_url']), num_proc=16)

Filter (num_proc=16): 100%|██████████| 279935/279935 [00:02<00:00, 135818.13 examples/s]


## Sample Example for Validation Set

In [28]:
datasets = dataset_all_topics.train_test_split(test_size=0.05, shuffle=True)
datasets["valid"] = datasets["test"]
datasets["test"] = dataset_all_topics_holdout

In [29]:
datasets.save_to_disk(f"../data/tmp/processed_dataset_all_topics_multiclass")

Saving the dataset (1/1 shards): 100%|██████████| 1299/1299 [00:00<00:00, 13930.93 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 279399/279399 [00:06<00:00, 43103.42 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 69/69 [00:00<00:00, 3326.21 examples/s]


## Chunkify Examples

In [31]:
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter

In [32]:
MODEL_NAME = "FacebookAI/xlm-roberta-base"
MAX_CONTENT_LENGTH = 384
OVERLAP = 64

# Load a pre-trained tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [33]:
def get_input_length(text):
    """ Tokenize the input text and return the number of tokens """
    return len(tokenizer.encode(text, add_special_tokens=True, truncation=False, padding=False))

print(get_input_length("Hello, my name is John Doe"))

10


In [34]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=MAX_CONTENT_LENGTH,
    chunk_overlap=OVERLAP,
    length_function=get_input_length,
    separators = ["\n\n", "\n", ".", "?", "!", " ", ""]
)

# text_splitter = CharacterTextSplitter(
#     separator = ".", # Split text by sentences
#     chunk_size=MAX_CONTENT_LENGTH,
#     chunk_overlap=OVERLAP,
#     length_function=get_input_length,
#     is_separator_regex=False,
# )

In [35]:
test_text = datasets['train'][1]['text']
print(test_text[:100])

Ausnahmen und Förderungen: Gilt Habecks Gesetz für mein Haus? So kommen Sie an den Klimabonus Ukrain


In [36]:
texts = text_splitter.split_text(test_text)
print("Number of Chunks:", len(texts)) 
print("First Chunk:",texts[0])
print("Length of text:", get_input_length(texts[0]))


Number of Chunks: 6
First Chunk: Ausnahmen und Förderungen: Gilt Habecks Gesetz für mein Haus? So kommen Sie an den Klimabonus Ukraine-Krieg Politik Panorama Eintracht Frankfurt Meinung Kategorien Politik Krieg in Israel Ukraine-Krieg US-Wahl Panorama Eintracht Frankfurt Meinung Kommentare Gastbeiträge Kolumnen Wissen Wirtschaft Frax Gastwirtschaft Rhein-Main Landespolitik Darmstadt Wiesbaden Offenbach Main-Kinzig-Kreis Main-Taunus-Kreis Hochtaunus Kreis Groß-Gerau Hessen Sport Fußball Sport A-Z Kultur TV & Kino Gesellschaft Times mager Musik Literatur Theater Kunst Verbraucher Ratgeber Gesundheit Geld Karriere Auto Buchtipps Wohnen Reise Einfach Tasty Zukunft Anzeigen Stellenmarkt Trauer Webkiosk Abo & Service Abo kündigen Thema Produktempfehlung Über uns FR-Jobs Altenhilfe Projekte Schlappekicker Startseite Verbraucher Ausnahmen und Förderungen: Gilt Habecks Gesetz für mein Haus? So kommen Sie an den Klimabonus Stand: 26.04.2023, 05:08 Uhr Von: Moritz Bletzinger Kommentare Drucken He

In [37]:
def expandRow(row, col_name):
    """
    Generate prompts based on text chunks from the input row.
    """
    rows = []

    # Split the text into chunks
    text_chunks = text_splitter.split_text(row.get(col_name, ""))

    # Generate prompts for each text chunk
    for chunk_id, text_chunk in enumerate(text_chunks):
        new_row = {
            **row, 'chunk_id': chunk_id, 'text': text_chunk
        }
        rows.append(new_row)

    return rows

In [38]:
from multiprocessing import Pool


def processDataset(dataset, num_processes, func, params=()):
    """Process a list of articles in parallel using a multiprocessing Pool."""

    # Creates a list of arguments for each call to func
    # Uses starmap to pass multiple arguments to func
    with Pool(processes=num_processes) as pool:
        args = [(row,) + params for row in dataset]
        dataset = list(pool.starmap(func, args))

    # Flatten the resulting list of lists
    # and convert it into a Dataset
    dataset = [item for sublist in dataset for item in sublist]
    dataset = Dataset.from_dict(
        {key: [dic[key] for dic in dataset] for key in dataset[0]})

    return dataset

In [39]:
datasets

DatasetDict({
    train: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'token_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label'],
        num_rows: 2120
    })
    test: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'token_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label'],
        num_rows: 278986
    })
    valid: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'token_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label'],
        num_rows: 112
    })
})

In [40]:
num_processes = 24
params = ("text",)

for split in datasets:
    datasets[split] = processDataset(datasets[split], num_processes, expandRow, params)

Token indices sequence length is longer than the specified maximum sequence length for this model (647 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (619 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (579 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (592 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (766 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for thi

In [41]:
label_counts = Counter(datasets["train"]['label'])
print("Class frequencies:", label_counts)

Class frequencies: Counter({0: 8493, 2: 2652, 1: 2504, 3: 2179})


In [42]:
datasets

DatasetDict({
    train: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'token_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label', 'chunk_id'],
        num_rows: 15828
    })
    test: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'token_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label', 'chunk_id'],
        num_rows: 816025
    })
    valid: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'token_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label', 'chunk_id'],
        num_rows: 827
    })
})

**Update text length:**

In [43]:
def update_metrics(example):
    example['text_length'] = len(example['text'])
    example['word_count'] = len(example['text'].split())
    example['token_count'] = get_input_length(example['text'])
    return example

datasets = datasets.map(update_metrics)

Map:   0%|          | 0/15828 [00:00<?, ? examples/s]

Map: 100%|██████████| 15828/15828 [00:14<00:00, 1067.84 examples/s]
Map: 100%|██████████| 816025/816025 [11:46<00:00, 1155.16 examples/s]
Map: 100%|██████████| 827/827 [00:00<00:00, 1120.96 examples/s]


In [44]:
datasets["train"][0]

{'_id': 'dummy_id_280',
 'batch_id': 99999,
 'domain': 'www.pv-magazine.de',
 'view_url': 'https://www.pv-magazine.de/2022/12/22/europaeische-kommission-gibt-gruenes-licht-fuer-erneuerbare-energien-gesetz-2023/',
 'lang': 'de',
 'text': 'Europäische Kommission gibt grünes licht for Ernäuerbare -Elien-GeSETz 2023 – eNews magazine  Skip to Top Global Deutschland Spanien Frankreich Italien USA Großbritannien Lateina China Kanada Brasilien Schweiz Deutschland  Abstand  News Alle News Themen Events Marktübersichten Magazin Schauplatz Branchebuch Kontakt Werbung Europäische-Union gibt grünes licht für Erneuerbare-Elien- Gesetz Die europäische Kommission hatte Mittwoch zwei Novelle des Erneubaren-Energien-Gesetzesgenehmig Auch dem Windenergie -auf -See-GeSETz stimmte sie zu. Damit können die Gelder zum Ernährbare -Ausbau wie vor der Bundesregierung bis zum Vertragsbeschluss außer Kraft einsetzen. 22. Januar 2022, Ernst, Berlin Die EU-Union kann im EHEC 2023 einen hilfe zum Nachdenken finden. 

**Extract URL path:**

In [45]:
from urllib.parse import urlparse, urlunparse

def extract_url_path(example):
    view_url = example['view_url']
    if "://" not in view_url:
        view_url = "http://" + view_url  # Assume http if no protocol specified
    parsed_url = urlparse(view_url)
    new_url = urlunparse(('', '', parsed_url.path, parsed_url.params, parsed_url.query, parsed_url.fragment))
    example['url_path'] = new_url.lstrip('/')  # Store the result in a new field
    return example


extract_url_path({"view_url": "https://www.google.com/search?q=python+url+path"})

{'view_url': 'https://www.google.com/search?q=python+url+path',
 'url_path': 'search?q=python+url+path'}

In [46]:
datasets = datasets.map(extract_url_path) 

Map: 100%|██████████| 15828/15828 [00:03<00:00, 5015.11 examples/s]
Map: 100%|██████████| 816025/816025 [02:39<00:00, 5130.75 examples/s]
Map: 100%|██████████| 827/827 [00:00<00:00, 5625.42 examples/s]


## Save Dataset to Disk

In [47]:
datasets.save_to_disk(f"../data/tmp/processed_dataset_multiclass_chunkified_{MAX_CONTENT_LENGTH}")

Saving the dataset (1/1 shards): 100%|██████████| 15828/15828 [00:00<00:00, 406205.87 examples/s]
Saving the dataset (3/3 shards): 100%|██████████| 816025/816025 [00:01<00:00, 530009.87 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 827/827 [00:00<00:00, 117367.85 examples/s]
