## Sample for Manual Error Analysis

In [1]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset, load_from_disk, concatenate_datasets
from sklearn.metrics import accuracy_score
import random
from tqdm import tqdm
import numpy as np
import torch
import os

  from .autonotebook import tqdm as notebook_tqdm


## Set Random Seed for Reproducibility

In [2]:
# Single seed shared by every random number generator seeded in this cell.
# Change it here to draw a different (but still repeatable) sample.
SEED = 42

# Python's built-in `random` module
random.seed(SEED)

# NumPy's global random state
np.random.seed(SEED)

# PyTorch: default generator, plus the current CUDA device
# (the CUDA call is safe to make when no GPU is available)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Ask cuDNN for deterministic algorithms and switch off its benchmark mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Define Parameters

In [3]:
# These constants are interpolated into the dataset directory / file names used
# further down, so together they select which processed dataset is read.

# Topics to process; the sampling loop runs once per entry.
TOPICS = ["cannabis", "kinder", "energie"]
# Model identifier; only the part after the "/" appears in the dataset names.
MODEL = "deepset/gelectra-large"
# Sampling tag as it appears in the dataset names.
SAMPLING = "random"
# Name suffix placed directly after SAMPLING. Options: "", "_holdout", "_extended"
SUFFIX = "_extended" #"", "_holdout", "_extended"
# Split key read from the loaded dataset (also part of the dataset names).
# Options: "train", "test", "holdout", "extended"
SPLIT = "holdout" # "train", "test", "holdout", "extended"
# Content length as encoded in the dataset names. Options: 496, 192, 384
# NOTE(review): presumably the chunk length in tokens — confirm against the preprocessing code.
MAX_CONTENT_LENGTH = 384 # 496, 192, 384
# NOTE(review): not referenced in the cells visible here; presumably the overlap
# between neighbouring chunks — confirm or remove.
OVERLAP = 64
# Feature-set tag as it appears in the dataset names. Options: "url", "content", "url_and_content"
FEATURES = "url_and_content" # "url", "content", "url_and_content"

## Sample from Wrong Predictions

In [4]:
def sample_random_from_dataset(dataset, n=5):
    """
    Draw up to ``n`` distinct examples from ``dataset`` at random.

    Row indices are picked with ``random.sample``, i.e. without replacement
    and from the global ``random`` module state, so the result is repeatable
    only if that state has been seeded beforehand.

    Args:
        dataset: Object supporting ``len()`` and ``select(indices)``.
        n: Requested number of examples; capped at ``len(dataset)``.

    Returns:
        Whatever ``dataset.select`` returns for the chosen indices.
    """
    population_size = len(dataset)
    # Never ask for more rows than the dataset holds.
    sample_size = population_size if population_size < n else n
    chosen_rows = random.sample(range(population_size), sample_size)
    return dataset.select(chosen_rows)

In [5]:
from collections import defaultdict
import csv

# Per-topic results container; it is not filled in this cell.
eval_results = defaultdict(dict)

# Only the part of the model id after "/" is used in the dataset names.
model_name = MODEL.split('/')[1]

# For every topic: load the dataset with predictions, keep the misclassified
# examples, draw a random sample of at most 50 of them and export that sample
# for manual error analysis.
for topic in TOPICS: # ----------------------------------------------------------------------

    # Load the dataset
    # NOTE(review): this input directory uses "_s_" where the outputs below (and the
    # commented-out path) use "_sampling_"; kept as-is — confirm this is intentional.
    #dataset = load_from_disk(f"../../data_ccu/tmp/processed_dataset_{topic}_buffed_chunkified_{SAMPLING}{SUFFIX}_{MAX_CONTENT_LENGTH}_sampling_{MODEL.split('/')[1]}_{FEATURES}_{SPLIT}/")
    dataset = load_from_disk(f"../../data/tmp/processed_dataset_{topic}_buffed_chunkified_{SAMPLING}{SUFFIX}_{MAX_CONTENT_LENGTH}_s_{model_name}_{FEATURES}_{SPLIT}/")
    dataset = dataset[SPLIT]

    # Keep examples where preds are not equal to labels
    dataset_wrong = dataset.filter(lambda x: x['preds'] != x['label'])
    print(f"Number of examples where preds are not equal to labels for topic {topic}: {len(dataset_wrong)}")

    # With no misclassified examples there is no first row to show and nothing
    # to export; indexing row 0 would raise an IndexError.
    if len(dataset_wrong) == 0:
        print(f"No wrong predictions for topic {topic}; skipping sample export.")
        continue
    print(f"First example: {dataset_wrong[0]['view_url']}")

    # The draw relies on the global `random` state, so it is repeatable only
    # when the notebook is run from the seeding cell onwards.
    dataset_wrong_sample = sample_random_from_dataset(dataset_wrong, n=50)

    # Common name for both exports of this topic's sample
    sample_name = f"processed_dataset_{topic}_buffed_chunkified_{SAMPLING}{SUFFIX}_{MAX_CONTENT_LENGTH}_sampling_{model_name}_{FEATURES}_{SPLIT}_wrong_sample"

    # Save the sample as a dataset directory (can be reloaded with load_from_disk)
    dataset_wrong_sample.save_to_disk(f"../../data/tmp/{sample_name}")

    # Save the sample as a CSV file in the notebook directory; QUOTE_ALL puts
    # quotes around every field.
    dataset_wrong_sample.to_csv(f"./{sample_name}.csv", quoting=csv.QUOTE_ALL)

Filter: 100%|██████████| 34209/34209 [00:00<00:00, 49604.76 examples/s]


Number of examples where preds are not equal to labels for topic cannabis: 55
First example: web.de/magazine/panorama/tranq-fleischfressende-droge-new-york-38362760


Saving the dataset (1/1 shards): 100%|██████████| 50/50 [00:00<00:00, 5536.16 examples/s]
Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 108.18ba/s]
Filter: 100%|██████████| 34046/34046 [00:00<00:00, 50235.53 examples/s]


Number of examples where preds are not equal to labels for topic kinder: 151
First example: www.haufe.de/sozialwesen/leistungen-sozialversicherung/rentenerhoehung-kommt-zum-1-juli_242_405920.html


Saving the dataset (1/1 shards): 100%|██████████| 50/50 [00:00<00:00, 4988.82 examples/s]
Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 135.51ba/s]
Filter: 100%|██████████| 40361/40361 [00:00<00:00, 50167.10 examples/s]


Number of examples where preds are not equal to labels for topic energie: 813
First example: web.de/magazine/wissen/klima/klima-irland-erwaegt-toetung-zehntausender-kuehe-38337046


Saving the dataset (1/1 shards): 100%|██████████| 50/50 [00:00<00:00, 4828.81 examples/s]
Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 105.79ba/s]
