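# Generate Test Split

Holds out a small, class-balanced test set from the annotated dataset of one topic and saves the resulting train/test splits to disk.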

In [1]:
from datasets import load_from_disk, Dataset, ClassLabel, Value, Features, concatenate_datasets, DatasetDict
from transformers import AutoTokenizer
import pandas as pd 
import numpy as np
import torch
from collections import Counter
import random
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import logging
from transformers import logging as transformers_logging

# Restrict transformers' own logger to ERROR level so that its warnings
# do not clutter the cell outputs below.
transformers_logging.set_verbosity_error()


**Load Examples:**

In [3]:
# Topic name; it is interpolated into the dataset paths used for loading and saving below.
topic = "cannabis"  # other values noted here: "energie", "kinder"

In [4]:
# Load the "_buffed_filtered" dataset of the selected topic from disk.
# (A previously used alternative was the same path without the "_filtered" suffix.)
dataset_path = f"../../data/tmp/processed_dataset_{topic}_buffed_filtered"
dataset = load_from_disk(dataset_path)

# Show the dataset summary followed by one example row
print(dataset)
print(dataset[1])

Dataset({
    features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label', 'token_count'],
    num_rows: 3904
})
{'_id': '64a0946b749484eec84dbbf1', 'batch_id': 16, 'domain': 't-online.de', 'view_url': 'email.t-online.de/em', 'lang': 'de', 'text': 'Wetter DAX Telefonverzeichnisse Lotto Telekom Services Telekom Hilfe & Service Frag Magenta Kundencenter Freemail MagentaCloud Tarife & Produkte PUR-Abo Login Suchen E-Mail Login Politik Deutschland Ausland Corona-Krise Tagesanbruch Ukraine Regional Berlin Hamburg München Köln Frankfurt Alle Städte Sport Bundesliga 2. Bundesliga Zweikampf der Woche Fußball Champions League FC Bayern Newsticker Formel 1 Was macht …? Mehr Sport Liveticker Ergebnisse Anzeigen Sportwetten Wirtschaft & Finanzen Aktuelles Börse Immobilien Die Anleger Ratgeber Versicherungen Publikumspreis Anzeigen Immobilien-Teilverkauf Ver

In [5]:
# Tally how often each label value occurs in the loaded dataset
label_counts = Counter()
label_counts.update(dataset['label'])
print("Class frequencies:", label_counts)

# Size of the rarest class
min_count = min(label_counts[label] for label in label_counts)
print("Minimum class frequency:", min_count)

Class frequencies: Counter({0: 3676, 1: 228})
Minimum class frequency: 228


## Generate Test Split

In [6]:
def _has_positive_label(row):
    """True for rows whose label is greater than zero."""
    return row['label'] > 0


def _is_buff(row):
    """True for rows whose category is exactly "buff"."""
    return row['category'] == "buff"


def _is_not_buff(row):
    """True for rows whose category is anything other than "buff"."""
    return row['category'] != "buff"


print("Number of all annotated samples:", len(dataset))

# Rows with a positive label
dataset_pos = dataset.filter(_has_positive_label)
print("Number of positive annotated samples:", len(dataset_pos))

# Rows in the "buff" category
dataset_buff = dataset.filter(_is_buff)
print("Number of manually annotated samples:", len(dataset_buff))

# All remaining rows
dataset_not_buff = dataset.filter(_is_not_buff)
print("Number of regular annotated samples:", len(dataset_not_buff))

Number of all annotated samples: 3904


Filter: 100%|██████████| 3904/3904 [00:00<00:00, 31695.44 examples/s]


Number of positive annotated samples: 228


Filter: 100%|██████████| 3904/3904 [00:00<00:00, 39104.65 examples/s]


Number of manually annotated samples: 150


Filter: 100%|██████████| 3904/3904 [00:00<00:00, 39980.67 examples/s]

Number of regular annotated samples: 3754





In [7]:
def train_test_split_balanced(dataset: Dataset, n: int, label_column='label', random_state=None):
    """Hold out a class-balanced test set and return the train/test splits.

    ``n // 2`` examples are drawn at random (without replacement) from every
    distinct value found in ``label_column``; all remaining rows form the
    train split. With two classes the test split therefore has ``n`` rows;
    with more classes it has ``n // 2`` rows per class.

    Args:
        dataset: Dataset to split. It must support ``len()``, column access
            via ``dataset[label_column]`` and ``dataset.select(indices)``.
        n: Requested test-set size. An odd value is rounded up to the next
            even number so it can be halved evenly.
        label_column: Name of the column holding the class labels.
        random_state: Optional seed. When given, sampling uses a private
            ``random.Random(random_state)`` generator, so the split is
            reproducible and the global ``random`` state is left untouched.
            When ``None``, the global ``random`` module state is used.

    Returns:
        DatasetDict with the keys ``'train'`` and ``'test'``. Train rows keep
        their original relative order; test rows are grouped class by class.

    Raises:
        ValueError: If ``n`` is negative (after rounding) or any class has
            fewer than ``n // 2`` examples.
    """
    # If n is odd, increment by 1 to make it even
    if n % 2 != 0:
        n += 1
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}.")
    per_class = n // 2

    # A private generator yields the same draws as seeding the global module
    # with the same value, but without the global side effect.
    rng = random if random_state is None else random.Random(random_state)

    # Aggregate indices by class in a single pass over the label column only
    # (avoids materialising every full row once per class).
    class_indices = {}
    for index, label in enumerate(dataset[label_column]):
        class_indices.setdefault(label, []).append(index)

    # Visit the classes in a deterministic order: plain set iteration order
    # depends on hashing and can differ between processes for string labels,
    # which would make a seeded split non-reproducible.
    try:
        ordered_labels = sorted(class_indices)
    except TypeError:
        # Labels that cannot be compared with each other: fall back to the
        # order of first appearance in the dataset.
        ordered_labels = list(class_indices)

    # Ensure there are enough samples in each class
    for label in ordered_labels:
        if len(class_indices[label]) < per_class:
            raise ValueError(f"Not enough samples in class {label} to sample {n // 2} examples.")

    # Randomly sample n/2 indices from each class for the test set
    test_indices = []
    for label in ordered_labels:
        test_indices.extend(rng.sample(class_indices[label], per_class))

    # Everything that was not held out goes to train, in original row order
    held_out = set(test_indices)
    train_indices = [index for index in range(len(dataset)) if index not in held_out]

    # Select the train and test indices to create new datasets
    train_dataset = dataset.select(train_indices)
    test_dataset = dataset.select(test_indices)

    return DatasetDict({'train': train_dataset, 'test': test_dataset})

In [8]:
# Twice the number of positive samples (the size of a set with as many
# negatives as positives) ...
balanced_total = len(dataset_pos) * 2
# ... of which 10% is used as the test-set size.
test_size_int = round(balanced_total * 0.1)
print("Test size:", test_size_int)

Test size: 46


In [9]:
# Fixed seed so that re-running the notebook holds out the same test rows
SPLIT_SEED = 42

# The test set is drawn only from the non-"buff" rows; the "buff" rows are
# appended to the train split afterwards, so none of them end up in test.
split_datasets = train_test_split_balanced(
    dataset_not_buff,
    n=test_size_int,
    label_column='label',
    random_state=SPLIT_SEED,
)
split_datasets['train'] = concatenate_datasets([split_datasets['train'], dataset_buff])

# Every input row must land in exactly one of the two splits
assert len(split_datasets['train']) + len(split_datasets['test']) == len(dataset_not_buff) + len(dataset_buff), \
    "train/test sizes do not add up to the number of input rows"

print("Size of training set:", len(split_datasets['train']))
print("Size of testing set:", len(split_datasets['test']))

Size of training set: 3858
Size of testing set: 46


In [10]:
# Label distribution of the held-out test split
test_labels = split_datasets["test"]['label']
label_counts = Counter(test_labels)
print("Class frequencies:", label_counts)

Class frequencies: Counter({0: 23, 1: 23})


In [11]:
# Display both splits (their features and row counts)
split_datasets 

DatasetDict({
    train: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label', 'token_count'],
        num_rows: 3858
    })
    test: Dataset({
        features: ['_id', 'batch_id', 'domain', 'view_url', 'lang', 'text', 'text_length', 'word_count', 'topic', 'category', 'good_for_training', 'good_for_augmentation', 'annotation_type', 'is_topic', 'label', 'token_count'],
        num_rows: 46
    })
})

**Save Splits:**

In [12]:
# Write both splits to a topic-specific directory
split_path = f"../../data/tmp/processed_dataset_{topic}_buffed_split"
split_datasets.save_to_disk(split_path)

Saving the dataset (1/1 shards): 100%|██████████| 3858/3858 [00:00<00:00, 34304.25 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 46/46 [00:00<00:00, 5837.76 examples/s]
