In [5]:
from datasets import load_dataset
from transformers import AutoTokenizer

In [4]:
# Loading this Hub dataset executes custom loading code from the Hub. The
# `datasets` warning says the explicit opt-in below will become mandatory, so
# pass it now rather than relying on the deprecated implicit behaviour.
# NOTE(review): only keep this flag for a source you trust; consider pinning
# `revision=` to a reviewed commit so the executed code cannot change silently.
dataset = load_dataset('knowledgator/events_classification_biotech', trust_remote_code=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [7]:
# Label vocabulary: every *distinct* label appearing anywhere in `all_labels`,
# collected from all splits so test-only labels still get an id, and sorted so
# the class -> id mapping is deterministic across runs.
# Taking one entry per training row instead would yield duplicates (one per
# example), inflate `len(classes)` to the number of rows, and map each label to
# the index of its last occurrence.
# NOTE(review): assumes each `all_labels` value is a list of label strings —
# confirm against `dataset["train"].features`.
classes = sorted({
    label
    for split in dataset.values()
    for row_labels in split["all_labels"]
    for label in row_labels
})
class2id = {class_: idx for idx, class_ in enumerate(classes)}
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [8]:
# Show the raw DatasetDict: split names, feature columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['title', 'content', 'target organization', 'all_labels', 'all_labels_concat', 'label 1', 'label 2', 'label 3', 'label 4', 'label 5'],
        num_rows: 2759
    })
    test: Dataset({
        features: ['title', 'content', 'target organization', 'all_labels', 'all_labels_concat', 'label 1', 'label 2', 'label 3', 'label 4', 'label 5'],
        num_rows: 381
    })
})

In [9]:
def preprocess_function(example):
    """Convert one raw row into tokenizer inputs plus a one-hot float label vector.

    The tokenizer sees the title and the content joined by ".\\n". Only the
    first entry of ``example["all_labels"]`` is encoded; any further labels on
    the row are ignored.
    NOTE(review): float label vectors suggest a multi-label (BCE) setup — if so,
    consider setting 1.0 for every label in ``all_labels``, not just the first.

    Reads the module-level ``classes``, ``class2id`` and ``tokenizer``.

    Returns:
        The tokenizer output (truncated and padded to 512 tokens) with an added
        ``"labels"`` entry of length ``len(classes)``.

    Raises:
        KeyError: if the first label is not present in ``class2id``.
    """
    joined_text = f"{example['title']}.\n{example['content']}"

    one_hot = [0.0] * len(classes)
    one_hot[class2id[example["all_labels"][0]]] = 1.0

    encoded = tokenizer(
        joined_text,
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    encoded["labels"] = one_hot
    return encoded

In [10]:
# Apply the per-row preprocessing to every split; `map` appends the tokenizer
# outputs and the label vector next to the original raw columns.
tokenized_dataset = dataset.map(preprocess_function)

# Format only the four model-input columns as torch tensors for training.
tokenized_dataset.set_format(
    columns=["input_ids", "token_type_ids", "attention_mask", "labels"],
    type="torch",
)

Map:   0%|          | 0/2759 [00:00<?, ? examples/s]

Map:   0%|          | 0/381 [00:00<?, ? examples/s]

In [11]:
# Inspect the tokenized splits: input_ids / token_type_ids / attention_mask /
# labels were added alongside the original raw columns.
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['title', 'content', 'target organization', 'all_labels', 'all_labels_concat', 'label 1', 'label 2', 'label 3', 'label 4', 'label 5', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 2759
    })
    test: Dataset({
        features: ['title', 'content', 'target organization', 'all_labels', 'all_labels_concat', 'label 1', 'label 2', 'label 3', 'label 4', 'label 5', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 381
    })
})