In [41]:
import numpy as np
from datasets import load_dataset
import torch
import collections
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EvalPrediction
)

In [42]:
# Load the Reuters 21578 dataset
dataset = load_dataset("reuters21578", "ModApte")

In [43]:
dataset

DatasetDict({
    test: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title'],
        num_rows: 3299
    })
    train: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title'],
        num_rows: 9603
    })
    unused: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title'],
        num_rows: 722
    })
})

In [44]:
train_labels = []
for sample in dataset['train']:
    train_labels.append(sample['topics'])

In [45]:
label_counter = collections.Counter(label for labels in train_labels for label in labels)

In [46]:
labels_with_one_doc = [label for label, count in label_counter.items() if count == 1]

In [47]:
labels_with_one_doc

['lin-oil',
 'rye',
 'red-bean',
 'groundnut-oil',
 'citruspulp',
 'rape-meal',
 'corn-oil',
 'peseta',
 'cotton-oil',
 'ringgit',
 'castorseed',
 'castor-oil',
 'lit',
 'rupiah',
 'skr',
 'nkr',
 'dkr',
 'sun-meal',
 'lin-meal',
 'cruzado']

In [48]:
[label for label in dataset['train']['topics']]

[['cocoa'],
 ['grain', 'wheat', 'corn', 'barley', 'oat', 'sorghum'],
 ['veg-oil',
  'linseed',
  'lin-oil',
  'soy-oil',
  'sun-oil',
  'soybean',
  'oilseed',
  'corn',
  'sunseed',
  'grain',
  'sorghum',
  'wheat'],
 [],
 ['earn'],
 ['acq'],
 ['earn'],
 ['earn', 'acq'],
 ['earn'],
 ['earn'],
 ['earn'],
 ['wheat', 'grain'],
 [],
 ['copper'],
 ['earn'],
 ['earn'],
 [],
 ['earn'],
 ['housing'],
 ['money-supply'],
 [],
 ['earn'],
 ['earn'],
 ['earn'],
 ['earn'],
 ['earn'],
 ['coffee'],
 ['acq', 'ship'],
 ['acq'],
 ['sugar'],
 ['trade'],
 ['reserves'],
 ['ship'],
 ['earn'],
 ['earn'],
 ['earn'],
 ['grain', 'corn'],
 ['money-supply'],
 ['ship'],
 [],
 [],
 ['earn'],
 ['earn'],
 ['earn'],
 ['acq'],
 ['veg-oil', 'soybean', 'oilseed', 'meal-feed', 'soy-meal'],
 ['earn'],
 ['earn'],
 ['coffee'],
 ['money-supply'],
 ['money-supply'],
 ['money-supply'],
 ['money-supply'],
 ['money-supply'],
 ['earn'],
 ['earn'],
 [],
 ['earn'],
 ['earn'],
 ['earn'],
 ['earn'],
 ['money-supply'],
 ['earn'],
 ['m