# Named Entity Recognition

Source: https://github.com/huggingface/notebooks/blob/master/examples/token_classification.ipynb

In [None]:
! pip install datasets transformers seqeval



In [None]:
import transformers

print(transformers.__version__)

4.12.3


# Fine-tuning a model on a token classification task

In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on a token classification task, which is the task of predicting a label for each token.

![Widget inference representing the NER task](https://raw.githubusercontent.com/huggingface/notebooks/65136bbdb4b1700bef98ab7eb15ebea17ad52855/examples/images/token_classification.png)

The most common token classification tasks are:

- NER (Named-entity recognition): classify the entities in the text (person, organization, location...).
- POS (Part-of-speech tagging): grammatically classify the tokens (noun, verb, adjective...).
- Chunk (Chunking): grammatically classify the tokens and group them into "chunks" that go together.

We will see how to easily load a dataset for these kinds of tasks and use the `Trainer` API to fine-tune a model on it.

This notebook is built to run on any token classification task, with any model checkpoint from the [Model Hub](https://huggingface.co/models), as long as that model has a version with a token classification head and a fast tokenizer (check [this table](https://huggingface.co/transformers/index.html#bigtable) to see if this is the case).

In [None]:
task = "ner" # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

## Loading the dataset

We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`.  

In [None]:
from datasets import load_dataset, load_metric

For our example here, we'll use the [CoNLL-2003 dataset](https://www.aclweb.org/anthology/W03-0419.pdf). The notebook should work with any token classification dataset provided by the 🤗 Datasets library. If you're using your own dataset defined from a JSON or CSV file (see the [Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files) on how to load them), you might need to adjust the column names used below.

In [None]:
datasets = load_dataset("conll2003")

Downloading and preparing dataset conll2003/conll2003 (download: 4.63 MiB, generated: 9.78 MiB, post-processed: Unknown size, total: 14.41 MiB) to /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6...


Dataset conll2003 downloaded and prepared to /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6. Subsequent calls will reuse this data.

The `datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for each of the training, validation and test sets.

In [None]:
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

We can see the training, validation and test sets all have a column for the tokens (the input texts split into words) and one column of labels for each kind of task we introduced before.

To access an actual element, you need to select a split first, then give an index:

In [None]:
datasets["train"][0]

{'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'id': '0',
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.']}

The labels are already coded as integer ids to be easily usable by our model, but the correspondence with the actual categories is stored in the `features` of the dataset:

In [None]:
datasets["train"].features["ner_tags"]

Sequence(feature=ClassLabel(num_classes=9, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], names_file=None, id=None), length=-1, id=None)

So for the NER tags, 0 corresponds to 'O', 1 to 'B-PER', and so on. Besides 'O' (which means no special entity), there are four entity types here. Each is prefixed with 'B-' (beginning) or 'I-' (inside) to indicate whether the token is the first one of the current entity or a continuation of it:
- 'PER' for person
- 'ORG' for organization
- 'LOC' for location
- 'MISC' for miscellaneous

Since the labels are lists of `ClassLabel`, the actual names of the labels are nested in the `feature` attribute of the object above:

In [None]:
label_list = datasets["train"].features[f"{task}_tags"].feature.names
label_list

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
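Now that we have `label_list`, we can decode the B-/I- scheme back into entity spans. The helper below, `tags_to_spans`, is a hypothetical sketch; it is not part of 🤗 Datasets or seqeval. It treats an `I-` tag that does not continue an entity of the same type as the start of a new entity, which is one common lenient convention. Applied to the first training example shown earlier, it recovers one organization and two miscellaneous entities.

```python
from typing import List, Tuple

def tags_to_spans(tags: List[str]) -> List[Tuple[str, int, int]]:
    """Group IOB2 tags into (entity_type, start, end) spans, end exclusive.

    Hypothetical helper for illustration. An I- tag that does not continue
    an entity of the same type starts a new entity (lenient convention).
    """
    spans = []
    current_type, start = None, None
    for i, tag in enumerate(tags):
        prefix, _, ent_type = tag.partition("-")  # "O" -> ("O", "", "")
        continues = prefix == "I" and ent_type == current_type
        if current_type is not None and not continues:
            # Close the entity that was open before this position.
            spans.append((current_type, start, i))
            current_type, start = None, None
        if prefix in ("B", "I") and not continues:
            # Open a new entity at this position.
            current_type, start = ent_type, i
    if current_type is not None:
        spans.append((current_type, start, len(tags)))
    return spans

label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
sample_tokens = ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
sample_tags = [label_list[i] for i in [3, 0, 7, 0, 0, 0, 7, 0, 0]]
for ent_type, start, end in tags_to_spans(sample_tags):
    print(ent_type, " ".join(sample_tokens[start:end]))
# ORG EU
# MISC German
# MISC British
```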

To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset (automatically decoding the labels in passing).

In [None]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Pick num_examples distinct random indices.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [None]:
show_random_elements(datasets["train"])

| | id | tokens | pos_tags | chunk_tags | ner_tags |
|---|---|---|---|---|---|
| 0 | 12877 | [", They, want, to, discuss, in, public, ,, at, their, protest, meetings, ,, ", Filipovic, said, ., "] | [", PRP, VBP, TO, VB, IN, JJ, ,, IN, PRP$, NN, NNS, ,, ", NNP, VBD, ., "] | [O, B-NP, B-VP, I-VP, I-VP, B-PP, B-ADJP, O, B-PP, B-NP, I-NP, I-NP, O, O, B-NP, B-VP, O, O] | [O, O, O, O, O, O, O, O, O, O, O, O, O, O, B-PER, O, O, O] |
| 1 | 10737 | [Cannes, 0, Monaco, 2, (, Henry, 26th, ,, 71st, ), .] | [NNP, CD, NNP, CD, (, NNP, JJ, ,, CD, ), .] | [B-NP, I-NP, I-NP, I-NP, O, B-NP, I-NP, I-NP, I-NP, O, O] | [B-ORG, O, B-ORG, O, O, B-PER, O, O, O, O, O] |
| 2 | 5662 | [SOCCER, -, PSV, BEAT, GRONINGEN, 4-1, TO, PULL, AWAY, FROM, AJAX, .] | [NN, :, NNP, NN, NNP, CD, TO, NNP, NNP, NNP, NNP, .] | [B-NP, O, B-NP, B-INTJ, B-NP, I-NP, B-VP, B-NP, I-NP, I-NP, I-NP, O] | [O, O, B-ORG, O, B-ORG, O, O, O, O, O, B-ORG, O] |
| 3 | 3774 | [Obilic, 3, 3, 0, 0, 8, 1, 9] | [JJ, CD, CD, CD, CD, CD, CD, CD] | [B-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP] | [B-ORG, O, O, O, O, O, O, O] |
| 4 | 1536 | [Bonn, says, Moscow, has, promised, to, observe, ceasefire, .] | [NNP, VBZ, NNP, VBZ, VBN, TO, VB, NN, .] | [B-NP, B-VP, B-NP, B-VP, I-VP, I-VP, I-VP, B-NP, O] | [B-LOC, O, B-LOC, O, O, O, O, O, O] |
| 5 | 2264 | [STUTTGART, ,, Germany, 1996-08-23] | [NNP, ,, NNP, CD] | [B-NP, O, B-NP, I-NP] | [B-LOC, O, B-LOC, O] |
| 6 | 3959 | [Guingamp, 2, (, Wreh, 15th, ,, 42nd, ), Monaco, 1, (, Scifo, 35th, ), .] | [NN, CD, (, NNP, JJ, ,, NNP, ), NNP, CD, (, NNP, JJ, ), .] | [B-NP, I-NP, O, B-NP, I-NP, I-NP, I-NP, O, B-NP, I-NP, O, B-NP, I-NP, O, O] | [B-ORG, O, O, B-PER, O, O, O, O, B-ORG, O, O, B-PER, O, O, O] |
| 7 | 3525 | [London, 21, 11, 1, 9, 555, 462, 23] | [NNP, CD, CD, CD, CD, CD, CD, CD] | [B-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP] | [B-ORG, O, O, O, O, O, O, O] |
| 8 | 13861 | [Jobless, figures, are, registered, unemployed, at, labour, ministry, .] | [NN, NNS, VBP, VBN, JJ, IN, NN, NN, .] | [B-NP, I-NP, B-VP, I-VP, B-ADJP, B-PP, B-NP, I-NP, O] | [O, O, O, O, O, O, O, O, O] |
| 9 | 5256 | [Played, Sunday, :] | [NNP, NNP, :] | [B-NP, I-NP, O] | [O, O, O] |


## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`. As the name indicates, it tokenizes the inputs and converts the tokens to their corresponding IDs in the pretrained vocabulary. It also puts them in the format the model expects and generates the other inputs the model requires.

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

In [None]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

The following assertion ensures that our tokenizer is a fast tokenizer (backed by Rust) from the 🤗 Tokenizers library. Fast tokenizers are available for almost all models, and we will need some of their special features, such as `word_ids`, for our preprocessing.

In [None]:
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

You can check which types of models have a fast tokenizer available on the [big table of models](https://huggingface.co/transformers/index.html#bigtable).

You can directly call this tokenizer on one sentence:

In [None]:
tokenizer("Hello, this is one sentence!")

{'input_ids': [101, 7592, 1010, 2023, 2003, 2028, 6251, 999, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}

If, as is the case here, your inputs have already been split into words, you should pass the list of words to your tokenizer with the argument `is_split_into_words=True`:

In [None]:
tokenizer(["Hello", ",", "this", "is", "one", "sentence", "split", "into", "words", "."], is_split_into_words=True)

{'input_ids': [101, 7592, 1010, 2023, 2003, 2028, 6251, 3975, 2046, 2616, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Note that transformers are often pretrained with subword tokenizers, meaning that even if your inputs have been split into words already, each of those words could be split again by the tokenizer. Let's look at an example of that:

In [None]:
example = datasets["train"][4]
print(example["tokens"])

['Germany', "'s", 'representative', 'to', 'the', 'European', 'Union', "'s", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.']


In [None]:
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
print(tokens)

['[CLS]', 'germany', "'", 's', 'representative', 'to', 'the', 'european', 'union', "'", 's', 'veterinary', 'committee', 'werner', 'z', '##wing', '##mann', 'said', 'on', 'wednesday', 'consumers', 'should', 'buy', 'sheep', '##me', '##at', 'from', 'countries', 'other', 'than', 'britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.', '[SEP]']


Here the words "Zwingmann" and "sheepmeat" have each been split into three subtokens.
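These splits come from WordPiece, the subword algorithm used by BERT-style tokenizers such as DistilBERT's. Each word is cut greedily into the longest piece found in the vocabulary, and non-initial pieces carry a `##` prefix. The sketch below reproduces that greedy longest-match-first idea with a tiny, made-up vocabulary. The real tokenizer also lowercases and splits on punctuation first, and its vocabulary has about 30k entries, so this is an illustration only.

```python
from typing import List, Set

def wordpiece(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first subword split (toy WordPiece sketch)."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces are prefixed
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no valid split: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces

# Made-up toy vocabulary, NOT DistilBERT's real one.
toy_vocab = {"sheep", "##me", "##at", "z", "##wing", "##mann", "germany"}
print(wordpiece("sheepmeat", toy_vocab))  # ['sheep', '##me', '##at']
print(wordpiece("zwingmann", toy_vocab))  # ['z', '##wing', '##mann']
print(wordpiece("germany", toy_vocab))    # ['germany']
```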

As a result, the input IDs returned by the tokenizer are longer than the list of labels in our dataset, so we need to do some processing on the labels. There are two reasons for the mismatch: special tokens are added (we can see a `[CLS]` and a `[SEP]` above), and some words are split into several tokens:

In [None]:
len(example[f"{task}_tags"]), len(tokenized_input["input_ids"])

(31, 39)

Thankfully, the tokenizer returns outputs that have a `word_ids` method which can help us.

In [None]:
print(tokenized_input.word_ids())

[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 11, 11, 12, 13, 14, 15, 16, 17, 18, 18, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, None]


As we can see, it returns a list with the same number of elements as our processed input ids, mapping special tokens to `None` and all other tokens to their respective word. This way, we can align the labels with the processed input ids.

In [None]:
word_ids = tokenized_input.word_ids()
aligned_labels = [-100 if i is None else example[f"{task}_tags"][i] for i in word_ids]
print(len(aligned_labels), len(tokenized_input["input_ids"]))

39 39


Here we set the labels of all special tokens to -100 and the labels of all other tokens to the label of the word they come from. -100 is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so these positions do not contribute to the loss. Another strategy is to set the label only on the first token obtained from a given word and give a label of -100 to the other subtokens from the same word. Both strategies are implemented below; to switch between them, change the value of the following flag:

In [None]:
label_all_tokens = True
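To see the difference between the two strategies, here is a small pure-Python sketch. It takes a `word_ids` list like the one printed earlier, together with word-level labels, and produces token-level labels either way. The function name `align_labels` is hypothetical; it mirrors the logic of `tokenize_and_align_labels` below without calling the tokenizer.

```python
from typing import List, Optional

def align_labels(word_ids: List[Optional[int]], word_labels: List[int],
                 label_all_tokens: bool) -> List[int]:
    """Map word-level labels onto subword tokens (illustrative sketch)."""
    aligned = []
    previous = None
    for word_idx in word_ids:
        if word_idx is None:
            aligned.append(-100)  # special token ([CLS], [SEP]): ignored by the loss
        elif word_idx != previous:
            aligned.append(word_labels[word_idx])  # first subtoken of a word
        else:
            # Later subtokens: copy the word's label, or mask them out with -100.
            aligned.append(word_labels[word_idx] if label_all_tokens else -100)
        previous = word_idx
    return aligned

# "Werner Zwingmann said" -> [CLS] werner z ##wing ##mann said [SEP]
toy_word_ids = [None, 0, 1, 1, 1, 2, None]
toy_word_labels = [1, 2, 0]  # B-PER, I-PER, O
print(align_labels(toy_word_ids, toy_word_labels, label_all_tokens=True))
# [-100, 1, 2, 2, 2, 0, -100]
print(align_labels(toy_word_ids, toy_word_labels, label_all_tokens=False))
# [-100, 1, 2, -100, -100, 0, -100]
```

Note that with `label_all_tokens=True`, a word tagged `B-X` that is split into several pieces gives every piece the label `B-X`. Some implementations convert the continuation pieces to `I-X` instead.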

We're now ready to write the function that will preprocess our samples. We feed them to the `tokenizer` with two arguments: `truncation=True`, to truncate texts that are longer than the maximum length allowed by the model, and `is_split_into_words=True`, as seen above. Then we align the labels with the token IDs using the strategy we picked:

In [None]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"{task}_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [None]:
tokenize_and_align_labels(datasets['train'][:5])

{'input_ids': [[101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102], [101, 2848, 13934, 102], [101, 9371, 2727, 1011, 5511, 1011, 2570, 102], [101, 1996, 2647, 3222, 2056, 2006, 9432, 2009, 18335, 2007, 2446, 6040, 2000, 10390, 2000, 18454, 2078, 2329, 12559, 2127, 6529, 5646, 3251, 5506, 11190, 4295, 2064, 2022, 11860, 2000, 8351, 1012, 102], [101, 2762, 1005, 1055, 4387, 2000, 1996, 2647, 2586, 1005, 1055, 15651, 2837, 14121, 1062, 9328, 5804, 2056, 2006, 9317, 10390, 2323, 4965, 8351, 4168, 4017, 2013, 3032, 2060, 2084, 3725, 2127, 1996, 4045, 6040, 2001, 24509, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100], [-100, 1, 2, -100], [-100, 5, 0, 

To apply this function to all the sentences in our dataset, we just use the `map` method of the `datasets` object we created earlier. This applies the function to all the elements of all the splits, so our training, validation and test data are preprocessed in one single command.

In [None]:
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to `map` has changed, in which case the cached data can't be reused. For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to ignore the cached files and force the preprocessing to be applied again.

Note that we passed `batched=True` to encode the texts in batches. This lets us take full advantage of the fast tokenizer we loaded earlier, which uses multi-threading to process the texts in a batch concurrently.

## Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since all our tasks are about token classification, we use the `AutoModelForTokenClassification` class. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us. The only thing we have to specify is the number of labels for our problem (which we can get from the features, as seen before):

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.91b885ab15d631bf9cee9dc9d25ece0afd932f2f5130eba28f2055b2220c0333
Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8"
  },
  "initializer_range": 0.02,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6,
    "LABEL_7": 7,
    "LABEL_8": 8
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 

The warning is telling us we are throwing away some weights (the `vocab_transform` and `vocab_layer_norm` layers) and randomly initializing some others (the `pre_classifier` and `classifier` layers). This is expected here. We are removing the head used to pretrain the model on a masked language modeling objective and replacing it with a new head for which we don't have pretrained weights. The library therefore warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.
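The configuration printed above uses generic names (`LABEL_0` to `LABEL_8`), because the pretrained checkpoint knows nothing about our tags. One option is to pass `id2label` and `label2id` mappings when loading the model. Then the saved config, and any pipeline built from it later, report readable tags. This is an optional variant, not what this notebook does. Only the dictionary construction runs below; the `from_pretrained` call is commented out because it downloads the checkpoint.

```python
label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

# Integer id -> tag name, and the inverse mapping.
id2label = {i: label for i, label in enumerate(label_list)}
label2id = {label: i for i, label in id2label.items()}

print(id2label[3], label2id["B-ORG"])  # B-ORG 3

# Passing the mappings at load time (commented out: requires a download):
# model = AutoModelForTokenClassification.from_pretrained(
#     model_checkpoint,
#     num_labels=len(label_list),
#     id2label=id2label,
#     label2id=label2id,
# )
```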

To instantiate a `Trainer`, we will need to define three more things. The most important is the [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments), which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional:

In [None]:
model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    f"{model_name}-finetuned-{task}",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


Here we set the evaluation to be done at the end of each epoch, tweak the learning rate, use the `batch_size` defined at the top of the notebook and customize the number of epochs for training, as well as the weight decay.
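As a quick sanity check on these settings, the number of optimization steps reported in the training log below follows directly from the dataset size, the batch size and the number of epochs. This assumes a single device and no gradient accumulation, as in this run. Each epoch processes ceil(14041 / 16) = 878 batches, so 3 epochs give 2634 steps.

```python
import math

num_train_examples = 14041      # size of the CoNLL-2003 training split
train_batch_size = 16           # per_device_train_batch_size on a single device
num_train_epochs = 3
gradient_accumulation_steps = 1

# The last, partial batch of each epoch still counts as a step.
steps_per_epoch = math.ceil(num_train_examples / train_batch_size) // gradient_accumulation_steps
total_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, total_steps)  # 878 2634
```

The `checkpoint-500`, `checkpoint-1000`, ... folders in the log are consistent with the `Trainer` saving a checkpoint every 500 steps by default (`save_steps=500`).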

Then we need a data collator that will batch our processed examples together. It applies padding to make them all the same size, so each batch is padded to the length of its longest example. The Transformers library provides a data collator for this task that pads not only the inputs but also the labels:

In [None]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)
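Conceptually, the collator does something like the pure-Python sketch below. It pads every sequence in a batch to the longest one, filling `input_ids` with the pad token id, `attention_mask` with 0, and `labels` with -100 so that the padded positions are ignored by the loss. The function `pad_batch` is hypothetical. The real `DataCollatorForTokenClassification` also returns PyTorch tensors and supports options such as `pad_to_multiple_of`. A pad id of 0 matches DistilBERT's uncased `[PAD]` token and is an assumption for other checkpoints.

```python
from typing import Dict, List

def pad_batch(features: List[Dict[str, List[int]]], pad_token_id: int = 0,
              label_pad_id: int = -100) -> Dict[str, List[List[int]]]:
    """Right-pad a batch of tokenized examples to the length of the longest one."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)  # mask out padding
        batch["labels"].append(f["labels"] + [label_pad_id] * n_pad)       # ignored by the loss
    return batch

# The first example comes from the output shown earlier; the second is made up.
toy_features = [
    {"input_ids": [101, 2848, 13934, 102], "attention_mask": [1, 1, 1, 1],
     "labels": [-100, 1, 2, -100]},
    {"input_ids": [101, 9371, 102], "attention_mask": [1, 1, 1],
     "labels": [-100, 5, -100]},
]
padded = pad_batch(toy_features)
print(padded["labels"])  # [[-100, 1, 2, -100], [-100, 5, -100, -100]]
```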

The last thing to define for our `Trainer` is how to compute the metrics from the predictions. Here we will load the [`seqeval`](https://github.com/chakki-works/seqeval) metric (which is commonly used to evaluate results on the CONLL dataset) via the Datasets library.

In [None]:
metric = load_metric("seqeval")

This metric takes lists of labels for the predictions and references:

In [None]:
labels = [label_list[i] for i in example[f"{task}_tags"]]
metric.compute(predictions=[labels], references=[labels])

{'LOC': {'f1': 1.0, 'number': 2, 'precision': 1.0, 'recall': 1.0},
 'ORG': {'f1': 1.0, 'number': 1, 'precision': 1.0, 'recall': 1.0},
 'PER': {'f1': 1.0, 'number': 1, 'precision': 1.0, 'recall': 1.0},
 'overall_accuracy': 1.0,
 'overall_f1': 1.0,
 'overall_precision': 1.0,
 'overall_recall': 1.0}

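The metric expects string labels, whereas our model outputs a vector of logits for each token, and some positions have the label -100. So we need to do a bit of post-processing on our predictions: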
- select the predicted index (with the maximum logit) for each token
- convert it to its string label
- ignore everywhere we set a label of -100
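On a toy batch, those three steps look like the sketch below. The logits, the three-label list and the shapes are made up for illustration. In the real notebook the arrays come from the `Trainer`, with shape `(batch, seq_len, num_labels)` for the predictions and `(batch, seq_len)` for the labels.

```python
import numpy as np

toy_label_list = ["O", "B-PER", "I-PER"]

# Logits for 1 sentence x 4 token positions x 3 labels (made-up numbers).
toy_logits = np.array([[[2.0, 0.1, 0.1],    # [CLS]
                        [0.2, 3.0, 0.5],    # first word
                        [0.1, 0.4, 2.5],    # second word
                        [1.5, 0.0, 0.2]]])  # [SEP]
toy_labels = np.array([[-100, 1, 2, -100]])

# 1. Highest-scoring label index for every token.
pred_ids = np.argmax(toy_logits, axis=2)  # shape (1, 4)

# 2 + 3. Convert to strings, skipping positions whose label is -100.
true_predictions = [
    [toy_label_list[p] for p, l in zip(pred_row, label_row) if l != -100]
    for pred_row, label_row in zip(pred_ids, toy_labels)
]
true_labels = [
    [toy_label_list[l] for p, l in zip(pred_row, label_row) if l != -100]
    for pred_row, label_row in zip(pred_ids, toy_labels)
]
print(true_predictions, true_labels)  # [['B-PER', 'I-PER']] [['B-PER', 'I-PER']]
```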

The following function does all this post-processing before applying the metric. It works on the predictions and labels that the `Trainer` passes to it at evaluation time: an `EvalPrediction`, which can be unpacked into predictions and labels.

In [None]:
import numpy as np

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

Note that we drop the precision/recall/f1 computed for each category and only focus on the overall precision/recall/f1/accuracy.

Then we just need to pass all of this along with our datasets to the `Trainer`:

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

We can now fine-tune our model by just calling the `train` method:

In [None]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: chunk_tags, id, tokens, pos_tags, ner_tags.
***** Running training *****
  Num examples = 14041
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 2634


| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
|---|---|---|---|---|---|---|
| 1 | 0.2356 | 0.07055 | 0.911853 | 0.920013 | 0.915915 | 0.98092 |
| 2 | 0.0516 | 0.062011 | 0.918502 | 0.93299 | 0.92569 | 0.98297 |
| 3 | 0.0308 | 0.062068 | 0.926729 | 0.936682 | 0.931679 | 0.983796 |


Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-500
Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/config.json
Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned-ner/checkpoint-500/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: chunk_tags, id, tokens, pos_tags, ner_tags.
***** Running Evaluation *****
  Num examples = 3250
  Batch size = 16
Saving model checkpoint to distilbert-base-uncased-finetuned-ner/checkpoint-1000
Configuration saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/config.json
Model weights saved in distilbert-base-uncased-finetuned-ner/checkpoint-1000/pytorch_model.

TrainOutput(global_step=2634, training_loss=0.08496604765918822, metrics={'train_runtime': 525.6263, 'train_samples_per_second': 80.139, 'train_steps_per_second': 5.011, 'total_flos': 509926772226690.0, 'train_loss': 0.08496604765918822, 'epoch': 3.0})

The `evaluate` method allows you to evaluate again on the evaluation dataset or on another dataset:

In [None]:
trainer.evaluate()

The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: chunk_tags, id, tokens, pos_tags, ner_tags.
***** Running Evaluation *****
  Num examples = 3250
  Batch size = 16


{'epoch': 3.0,
 'eval_accuracy': 0.9837958917819754,
 'eval_f1': 0.9316790920218092,
 'eval_loss': 0.062067680060863495,
 'eval_precision': 0.9267293857221914,
 'eval_recall': 0.9366819554760041,
 'eval_runtime': 13.3257,
 'eval_samples_per_second': 243.89,
 'eval_steps_per_second': 15.309}

To get the precision/recall/f1 computed for each category now that we have finished training, we can apply the same function as before on the result of the `predict` method:

In [None]:
predictions, labels, _ = trainer.predict(tokenized_datasets["validation"])
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

The following columns in the test set  don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: chunk_tags, id, tokens, pos_tags, ner_tags.
***** Running Prediction *****
  Num examples = 3250
  Batch size = 16


{'LOC': {'f1': 0.9533914361500568,
  'number': 2618,
  'precision': 0.9458646616541353,
  'recall': 0.961038961038961},
 'MISC': {'f1': 0.8279181708784596,
  'number': 1231,
  'precision': 0.8177496038034865,
  'recall': 0.8383428107229894},
 'ORG': {'f1': 0.9018657620547613,
  'number': 2056,
  'precision': 0.8985997102848865,
  'recall': 0.9051556420233463},
 'PER': {'f1': 0.9756418696510861,
  'number': 3034,
  'precision': 0.9743589743589743,
  'recall': 0.976928147659855},
 'overall_accuracy': 0.9837958917819754,
 'overall_f1': 0.9316790920218092,
 'overall_precision': 0.9267293857221914,
 'overall_recall': 0.9366819554760041}
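seqeval scores entities rather than individual tokens. A predicted entity counts as correct only if both its type and its exact span match a reference entity. The sketch below illustrates that idea with exact span matching and micro-averaging. The helper names `spans` and `entity_scores` are hypothetical, and this is a simplification of seqeval, which has several modes and also reports per-type scores and token accuracy.

```python
from typing import List, Set, Tuple

def spans(tags: List[str]) -> Set[Tuple[str, int, int]]:
    """Collect (type, start, end) entity spans from IOB2 tags; end is exclusive."""
    result, current, start = set(), None, 0
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes a trailing entity
        prefix, _, ent_type = tag.partition("-")
        if current is not None and not (prefix == "I" and ent_type == current):
            result.add((current, start, i))
            current = None
        if prefix in ("B", "I") and current is None:
            current, start = ent_type, i
    return result

def entity_scores(predictions: List[List[str]],
                  references: List[List[str]]) -> Tuple[float, float, float]:
    """Micro-averaged entity-level precision, recall and F1 (exact match)."""
    n_pred = n_true = n_correct = 0
    for pred_tags, ref_tags in zip(predictions, references):
        pred, ref = spans(pred_tags), spans(ref_tags)
        n_pred += len(pred)
        n_true += len(ref)
        n_correct += len(pred & ref)  # same type AND same boundaries
    precision = n_correct / n_pred if n_pred else 0.0
    recall = n_correct / n_true if n_true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# The reference has two entities; the prediction gets the PER span only partly right.
toy_reference = [["B-PER", "I-PER", "O", "B-LOC"]]
toy_prediction = [["B-PER", "O", "O", "B-LOC"]]
print(entity_scores(toy_prediction, toy_reference))  # (0.5, 0.5, 0.5)
```

In this toy example, token-level accuracy would be 3/4 = 0.75, while the entity-level scores are 0.5. Accuracy also rewards the many correctly predicted `O` tokens. This is one reason `overall_accuracy` (about 0.984) is higher than `overall_f1` (about 0.932) in the results above.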

Finally, let's try the fine-tuned model on a new sentence:

In [None]:
sequence = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO"

In [None]:
inputs = tokenizer(sequence, return_tensors="pt")
tokens = inputs.tokens()


In [None]:
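# Move the inputs to the model's device (a GPU in this run); BatchEncoding.to updates them in place.
inputs.to(model.device)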

{'input_ids': tensor([[  101, 17662,  2227,  4297,  1012,  2003,  1037,  2194,  2241,  1999,
          2047,  2259,  2103,  1012,  2049,  4075,  2024,  1999, 12873,  2080,
           102]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}

In [None]:
import torch

In [None]:
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs).logits
predictions = torch.argmax(outputs, dim=2)

In [None]:
for token, prediction in zip(tokens, predictions[0].detach().cpu().numpy()):
    print((token, label_list[prediction]))

('[CLS]', 'O')
('hugging', 'B-ORG')
('face', 'I-ORG')
('inc', 'I-ORG')
('.', 'I-ORG')
('is', 'O')
('a', 'O')
('company', 'O')
('based', 'O')
('in', 'O')
('new', 'B-LOC')
('york', 'I-LOC')
('city', 'I-LOC')
('.', 'O')
('its', 'O')
('headquarters', 'O')
('are', 'O')
('in', 'O')
('dumb', 'B-LOC')
('##o', 'B-LOC')
('[SEP]', 'B-LOC')
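Two things stand out in this output. First, the `[CLS]` and `[SEP]` tokens also get a label, but it carries no meaning, since those positions were labeled -100 during training. Second, "DUMBO" is split into `dumb` and `##o`, and each piece gets its own prediction. A minimal sketch for turning token predictions into one label per word keeps only the prediction of each word's first subtoken, using a `word_ids` mapping like the one shown earlier. The helper `word_level_labels` is hypothetical. Taking the first subtoken is only one choice; averaging or majority voting over subtokens are common alternatives.

```python
from typing import List, Optional

def word_level_labels(word_ids: List[Optional[int]], token_labels: List[str]) -> List[str]:
    """Keep the predicted label of the first subtoken of each word (sketch)."""
    labels = []
    previous = None
    for word_idx, label in zip(word_ids, token_labels):
        # Skip special tokens and every subtoken after the first one of a word.
        if word_idx is not None and word_idx != previous:
            labels.append(label)
        previous = word_idx
    return labels

# Shortened version of the output above: "[CLS] in dumb ##o [SEP]".
# The word ids are assumed for this short input; in practice use inputs.word_ids().
toy_word_ids = [None, 0, 1, 1, None]
toy_token_labels = ["O", "O", "B-LOC", "B-LOC", "B-LOC"]
print(word_level_labels(toy_word_ids, toy_token_labels))  # ['O', 'B-LOC']
```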


# Transformers Pipelines

## NER

In [None]:
from transformers import pipeline


In [None]:
ner_pipe = pipeline("ner")

No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmphbi_20bw


storing https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json in cache at /root/.cache/huggingface/transformers/dbfd3b2eb181c7b63cbbd8e7773a5e64941849440953d50ecad5ef346ad8286a.8f943745c8dd5e96d7b60c9b9e1be5711aff8aff42413b74288e076022e6e2bf
creating metadata file for /root/.cache/huggingface/transformers/dbfd3b2eb181c7b63cbbd8e7773a5e64941849440953d50ecad5ef346ad8286a.8f943745c8dd5e96d7b60c9b9e1be5711aff8aff42413b74288e076022e6e2bf
loading configuration file https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/dbfd3b2eb181c7b63cbbd8e7773a5e64941849440953d50ecad5ef346ad8286a.8f943745c8dd5e96d7b60c9b9e1be5711aff8aff42413b74288e076022e6e2bf
Model config BertConfig {
  "_num_labels": 9,
  "architectures": [
    "BertForTokenClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_ac

storing https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/pytorch_model.bin in cache at /root/.cache/huggingface/transformers/ac151adcc285472b4550ba8ce6a1c9b7fc0bc9da20170d00a97a4ec43c2847fd.c98827377d113c9ea90545b952aac740c66289834c4fc805b96030c77febb678
creating metadata file for /root/.cache/huggingface/transformers/ac151adcc285472b4550ba8ce6a1c9b7fc0bc9da20170d00a97a4ec43c2847fd.c98827377d113c9ea90545b952aac740c66289834c4fc805b96030c77febb678
loading weights file https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/ac151adcc285472b4550ba8ce6a1c9b7fc0bc9da20170d00a97a4ec43c2847fd.c98827377d113c9ea90545b952aac740c66289834c4fc805b96030c77febb678
All model checkpoint weights were used when initializing BertForTokenClassification.

All the weights of BertForTokenClassification were initialized from the model checkpoint at dbmdz/bert-large-cased-finetun

Downloading:   0%|          | 0.00/60.0 [00:00<?, ?B/s]

storing https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/tokenizer_config.json in cache at /root/.cache/huggingface/transformers/38daf04bf1f6dd2d989d4dd897e83d53e1563fcd2ff4e618dbcb5468c31ffa37.c70618325b9fc2d2d041e439766d360b48a086a8841cc2896322f6b8aefc0225
creating metadata file for /root/.cache/huggingface/transformers/38daf04bf1f6dd2d989d4dd897e83d53e1563fcd2ff4e618dbcb5468c31ffa37.c70618325b9fc2d2d041e439766d360b48a086a8841cc2896322f6b8aefc0225
loading configuration file https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/dbfd3b2eb181c7b63cbbd8e7773a5e64941849440953d50ecad5ef346ad8286a.8f943745c8dd5e96d7b60c9b9e1be5711aff8aff42413b74288e076022e6e2bf
Model config BertConfig {
  "_num_labels": 9,
  "architectures": [
    "BertForTokenClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

storing https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/vocab.txt in cache at /root/.cache/huggingface/transformers/b0a39d1a6ecfd7d86442cba576f2e932ff3c3e3d8d96f9d5a65fd1eb65634305.437aa611e89f6fc6675a049d2b5545390adbc617e7d655286421c191d2be2791
creating metadata file for /root/.cache/huggingface/transformers/b0a39d1a6ecfd7d86442cba576f2e932ff3c3e3d8d96f9d5a65fd1eb65634305.437aa611e89f6fc6675a049d2b5545390adbc617e7d655286421c191d2be2791
loading file https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/vocab.txt from cache at /root/.cache/huggingface/transformers/b0a39d1a6ecfd7d86442cba576f2e932ff3c3e3d8d96f9d5a65fd1eb65634305.437aa611e89f6fc6675a049d2b5545390adbc617e7d655286421c191d2be2791
loading file https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/tokenizer.json from cache at None
loading file https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/ad

In [None]:
for entity in ner_pipe(sequence):
    print(entity)

{'entity': 'I-ORG', 'score': 0.9994397, 'index': 1, 'word': 'Hu', 'start': 0, 'end': 2}
{'entity': 'I-ORG', 'score': 0.9872635, 'index': 2, 'word': '##gging', 'start': 2, 'end': 7}
{'entity': 'I-ORG', 'score': 0.99727476, 'index': 3, 'word': 'Face', 'start': 8, 'end': 12}
{'entity': 'I-ORG', 'score': 0.99941236, 'index': 4, 'word': 'Inc', 'start': 13, 'end': 16}
{'entity': 'I-LOC', 'score': 0.9990946, 'index': 11, 'word': 'New', 'start': 40, 'end': 43}
{'entity': 'I-LOC', 'score': 0.99878746, 'index': 12, 'word': 'York', 'start': 44, 'end': 48}
{'entity': 'I-LOC', 'score': 0.99917126, 'index': 13, 'word': 'City', 'start': 49, 'end': 53}
{'entity': 'I-LOC', 'score': 0.97184104, 'index': 19, 'word': 'D', 'start': 79, 'end': 80}
{'entity': 'I-LOC', 'score': 0.89539707, 'index': 20, 'word': '##UM', 'start': 80, 'end': 82}
{'entity': 'I-LOC', 'score': 0.6929157, 'index': 21, 'word': '##BO', 'start': 82, 'end': 84}
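The pipeline returns one prediction per subword token (note the `##` pieces), not one per entity. Below is a minimal sketch of merging them into entity spans. It assumes IO-style labels such as `I-ORG`, as in the output above. The helper `merge_entities` is hypothetical, and its merge rule (same label, consecutive `index`, only whitespace in between) is a simplification. The example sentence is an assumed shortened text consistent with the offsets above, since the notebook's `sequence` is defined earlier. Recent versions of transformers can do this grouping inside the pipeline via its `aggregation_strategy` argument (e.g. `aggregation_strategy="simple"`).

```python
def merge_entities(tokens, text):
    """Group token-level NER predictions into (label, surface text) spans.

    Hypothetical helper: tokens are merged when they share a label,
    have consecutive `index` values and are separated only by whitespace.
    """
    spans = []
    for tok in tokens:
        label = tok["entity"].split("-", 1)[-1]  # "I-ORG" -> "ORG"
        if spans:
            last = spans[-1]
            gap = text[last["end"]:tok["start"]]
            if (last["label"] == label
                    and tok["index"] == last["index"] + 1
                    and gap.strip() == ""):
                # Extend the current span to cover this token.
                last["end"] = tok["end"]
                last["index"] = tok["index"]
                continue
        spans.append({"label": label, "start": tok["start"],
                      "end": tok["end"], "index": tok["index"]})
    return [(s["label"], text[s["start"]:s["end"]]) for s in spans]


text = "Hugging Face Inc. is a company based in New York City."
tokens = [  # labels, indices and offsets taken from the pipeline output above
    {"entity": "I-ORG", "index": 1, "start": 0, "end": 2},
    {"entity": "I-ORG", "index": 2, "start": 2, "end": 7},
    {"entity": "I-ORG", "index": 3, "start": 8, "end": 12},
    {"entity": "I-ORG", "index": 4, "start": 13, "end": 16},
    {"entity": "I-LOC", "index": 11, "start": 40, "end": 43},
    {"entity": "I-LOC", "index": 12, "start": 44, "end": 48},
    {"entity": "I-LOC", "index": 13, "start": 49, "end": 53},
]
print(merge_entities(tokens, text))
# [('ORG', 'Hugging Face Inc'), ('LOC', 'New York City')]
```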


In [None]:
pipe = pipeline("ner", model="xlm-roberta-large-finetuned-conll03-english", device=0)


In [None]:
for entity in pipe(sequence):
    print(entity)

{'entity': 'I-ORG', 'score': 0.9999916, 'index': 1, 'word': '▁Hu', 'start': 0, 'end': 2}
{'entity': 'I-ORG', 'score': 0.99998784, 'index': 2, 'word': 'gging', 'start': 2, 'end': 7}
{'entity': 'I-ORG', 'score': 0.99998534, 'index': 3, 'word': '▁Face', 'start': 8, 'end': 12}
{'entity': 'I-ORG', 'score': 0.99999505, 'index': 4, 'word': '▁Inc', 'start': 13, 'end': 16}
{'entity': 'I-ORG', 'score': 0.99702644, 'index': 5, 'word': '.', 'start': 16, 'end': 17}
{'entity': 'I-LOC', 'score': 0.9999951, 'index': 11, 'word': '▁New', 'start': 40, 'end': 43}
{'entity': 'I-LOC', 'score': 0.99999464, 'index': 12, 'word': '▁York', 'start': 44, 'end': 48}
{'entity': 'I-LOC', 'score': 0.99999464, 'index': 13, 'word': '▁City', 'start': 49, 'end': 53}
{'entity': 'I-LOC', 'score': 0.9998963, 'index': 22, 'word': '▁D', 'start': 79, 'end': 80}
{'entity': 'I-LOC', 'score': 0.9976182, 'index': 23, 'word': 'UM', 'start': 80, 'end': 82}
{'entity': 'I-LOC', 'score': 0.999495, 'index': 24, 'word': 'BO', 'start': 82,

## Zero-shot classification

In [None]:
!pip install sentencepiece




In [None]:
classifier = pipeline("zero-shot-classification")
# classifier = pipeline("zero-shot-classification", device=0) # to utilize GPU

No model was supplied, defaulted to facebook/bart-large-mnli (https://huggingface.co/facebook/bart-large-mnli)

With this model we can classify texts as belonging or not belonging to any of the classes we need.

By default, the text is assumed to belong to exactly one of the declared classes, and the model returns a probability for each class; these probabilities sum to 1.
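In this single-label mode, the transformers zero-shot pipeline takes the NLI model's entailment logit for each label's hypothesis and applies a softmax across all candidate labels, which is why the scores sum to 1. Below is a minimal sketch of that normalization step. The logit values and the function name `single_label_scores` are hypothetical.

```python
import math

def single_label_scores(entail_logits):
    """Softmax over per-label entailment logits; the scores sum to 1."""
    m = max(entail_logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in entail_logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical entailment logits for ["politics", "public health", "economics"]
scores = single_label_scores([3.0, -0.5, -0.4])
print([round(s, 3) for s in scores], sum(scores))
```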


In [None]:
sequence = "Who are you voting for in 2020?"
candidate_labels = ["politics", "public health", "economics"]

classifier(sequence, candidate_labels)

{'labels': ['politics', 'economics', 'public health'],
 'scores': [0.972518801689148, 0.014584142714738846, 0.012897025793790817],
 'sequence': 'Who are you voting for in 2020?'}

The model can also handle multi-label classification, deciding for each class separately whether the text belongs to it.
In this case the model returns the probability of each class independently, so the scores no longer have to sum to 1. To enable this, pass `multi_label=True` to the call (older releases named this parameter `multi_class`, which is now deprecated).
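In multi-label mode, the pipeline scores each label on its own: it applies a softmax over just the contradiction and entailment logits of that label's hypothesis and keeps the entailment probability. The sketch below illustrates this with made-up logits. `multi_label_scores` is a hypothetical name. A two-way softmax reduces to a sigmoid of the logit difference, which is what the code computes.

```python
import math

def multi_label_scores(pairs):
    """Score each label independently from its (contradiction, entailment) logits.

    softmax([c, e])[1] == 1 / (1 + exp(c - e)), i.e. a sigmoid of e - c.
    """
    return [1.0 / (1.0 + math.exp(contra - entail)) for contra, entail in pairs]

# Hypothetical (contradiction, entailment) logits for four labels
scores = multi_label_scores([(-2.0, 2.0), (-1.5, 1.9), (1.0, -1.0), (2.5, -2.0)])
print([round(s, 3) for s in scores])
print(sum(scores))  # not constrained to 1 any more
```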

In [None]:
sequence = "Who are you voting for in 2020?"
candidate_labels = ["politics", "public health", "economics", "elections"]

classifier(sequence, candidate_labels, multi_label=True)


{'labels': ['politics', 'elections', 'public health', 'economics'],
 'scores': [0.972069501876831,
  0.967610776424408,
  0.03248710557818413,
  0.0061644683592021465],
 'sequence': 'Who are you voting for in 2020?'}

Besides topic classification, we can also do sentiment classification, for example sorting reviews into negative and positive classes.

In [None]:
sequence = "I hated this movie. The acting sucked."
candidate_labels = ["positive", "negative"]

classifier(sequence, candidate_labels)

{'labels': ['negative', 'positive'],
 'scores': [0.9916268587112427, 0.00837317667901516],
 'sequence': 'I hated this movie. The acting sucked.'}


The model behind this classifier was trained on Natural Language Inference (NLI). Given two input texts, a premise and a hypothesis, the task is to decide whether the hypothesis follows from (is entailed by) the premise, contradicts it, or is neutral with respect to it.

Such a model can be used for zero-shot classification by reducing classification to NLI. The model receives two texts:
- the text to be classified (the *premise*);
- a template with the name of a candidate class inserted into it (the *hypothesis*).


If the NLI model judges that the hypothesis follows from the premise, we conclude that the text belongs to the corresponding class. The approach is described in more detail [here](https://joeddav.github.io/blog/2020/05/29/ZSL.html).

By default, each candidate label is inserted into the template `This example is {}.`
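The following minimal sketch shows how a text and the candidate labels could be turned into premise/hypothesis pairs using a template. `build_nli_pairs` is a hypothetical helper, and the pipeline's own preprocessing may differ in details. The `{}` placeholder is filled with `str.format`, so any template with a single `{}` works, including custom ones such as `The sentiment of this review is {}.`

```python
def build_nli_pairs(sequence, candidate_labels, hypothesis_template="This example is {}."):
    """Pair the text (premise) with one filled-in template (hypothesis) per label."""
    return [(sequence, hypothesis_template.format(label)) for label in candidate_labels]

pairs = build_nli_pairs("Who are you voting for in 2020?", ["politics", "economics"])
for premise, hypothesis in pairs:
    print(premise, "->", hypothesis)
```

Each pair is then scored by the NLI model, one forward pass per candidate label, so the cost grows linearly with the number of labels.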

This works reasonably well in many cases, but for some tasks the classification quality can be improved with a more suitable template.


In [None]:
sequences = [
    "I hated this movie. The acting sucked.",
    "This movie didn't quite live up to my high expectations, but overall I still really enjoyed it."
]
candidate_labels = ["positive", "negative"]

classifier(sequences, candidate_labels)



[{'labels': ['negative', 'positive'],
  'scores': [0.9916268587112427, 0.00837317667901516],
  'sequence': 'I hated this movie. The acting sucked.'},
 {'labels': ['negative', 'positive'],
  'scores': [0.8148515224456787, 0.1851484626531601],
  'sequence': "This movie didn't quite live up to my high expectations, but overall I still really enjoyed it."}]

The second example is harder than the first: the review opens with a complaint but is positive overall, and with the default template the model assigns it to the negative class.

Let's try to improve the result with a template better suited to this task. Instead of the default `This example is {}.`, we will use `The sentiment of this review is {}.` (here `{}` is replaced with the class name).

In [None]:
sequences = [
    "I hated this movie. The acting sucked.",
    "This movie didn't quite live up to my high expectations, but overall I still really enjoyed it."
]
candidate_labels = ["positive", "negative"]
hypothesis_template = "The sentiment of this review is {}."

classifier(sequences, candidate_labels, hypothesis_template=hypothesis_template)



[{'labels': ['negative', 'positive'],
  'scores': [0.9890093207359314, 0.010990706272423267],
  'sequence': 'I hated this movie. The acting sucked.'},
 {'labels': ['positive', 'negative'],
  'scores': [0.9581229090690613, 0.04187712445855141],
  'sequence': "This movie didn't quite live up to my high expectations, but overall I still really enjoyed it."}]

With a multilingual model based on XLM-RoBERTa, this approach can be applied not only to English but also to a number of other languages.

In [None]:
classifier = pipeline("zero-shot-classification", model='joeddav/xlm-roberta-large-xnli')



In [None]:
sequence = "За кого вы голосуете в 2020 году?" # translation: "Who are you voting for in 2020?"
candidate_labels = ["Europe", "public health", "politics"]

classifier(sequence, candidate_labels)

{'labels': ['politics', 'Europe', 'public health'],
 'scores': [0.9048486948013306, 0.05722154304385185, 0.03792975842952728],
 'sequence': 'За кого вы голосуете в 2020 году?'}

In [None]:
sequence = "За кого вы голосуете в 2020 году?" # translation: "Who are you voting for in 2020?"
candidate_labels = ["Europe", "santé publique", "politique"]

classifier(sequence, candidate_labels)

{'labels': ['politique', 'Europe', 'santé publique'],
 'scores': [0.9726154804229736, 0.017128439620137215, 0.010256052017211914],
 'sequence': 'За кого вы голосуете в 2020 году?'}

As noted above, the pipeline uses an English template by default, `This example is {}.`
When working with another language, it makes sense to replace it with an equivalent template written in that language.

In [None]:
sequence = "За кого вы голосуете в 2020 году?"
candidate_labels = ["Europe", "public health", "politics"]
hypothesis_template = "Этот пример относится к {}."  # translation: "This example relates to {}."

classifier(sequence, candidate_labels, hypothesis_template=hypothesis_template)

{'labels': ['politics', 'Europe', 'public health'],
 'scores': [0.9780572056770325, 0.016864290460944176, 0.005078556947410107],
 'sequence': 'За кого вы голосуете в 2020 году?'}

The model is based on XLM-RoBERTa, which was pretrained on 100 languages, and was fine-tuned on the multilingual XNLI dataset, which covers 15 languages (Arabic, Bulgarian, Chinese, English, French, German, Greek, Hindi, Russian, Spanish, Swahili, Thai, Turkish, Urdu, Vietnamese).

# Relation Extraction

The [Ask2Transformers](https://github.com/osainz59/Ask2Transformers) library provides NLI-based relation extraction classifiers built on HuggingFace transformer models.

In [None]:
!pip install a2t




![](https://raw.githubusercontent.com/osainz59/Ask2Transformers/master/imgs/RE_NLI_white_bg.svg)

In this library, relation classification is framed as follows:

*   the premise is a text that mentions two entities;
*   the hypothesis is a template describing the relation (a verbalization), filled in with those two entities.

The NLI model then decides whether the premise entails the hypothesis.
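Filling a relation template with the two entity mentions is plain string formatting. The sketch below shows this using the `{subj}`/`{obj}` placeholders that this notebook's templates use. `verbalize` is a hypothetical helper, not part of Ask2Transformers.

```python
def verbalize(subj, obj, templates):
    """Fill relation templates with the two entity mentions to get hypotheses."""
    return [t.format(subj=subj, obj=obj) for t in templates]

hyps = verbalize("Old Lane Partners", "Pandit",
                 ["{subj} was founded by {obj}", "{obj} founded {subj}"])
print(hyps)
# ['Old Lane Partners was founded by Pandit', 'Pandit founded Old Lane Partners']
```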


In [None]:
from a2t.relation_classification import NLIRelationClassifierWithMappingHead

Let's create a classifier for a chosen set of relation types.

In [None]:
# Define the set of relations we want to classify (including no_relation in the first position!)
relations = [
        'no_relation',
        'per:city_of_death',
        'org:founded_by'
    ]

# Define the verbalizations (a descriptive template) of each positive relation
relation_verbalizations = {
        'per:city_of_death': [
            "{subj} died in {obj}"
        ],
        'org:founded_by': [
            "{subj} was founded by {obj}",
            "{obj} founded {subj}"
        ]
    }

# Define the possible entity type combinations for each relation
valid_conditions = {
        'per:city_of_death': [
            "PERSON:CITY",
            "PERSON:LOCATION"
        ],
        'org:founded_by': [
            "ORGANIZATION:PERSON"
        ]
    }

# Create the classifier instance; the default threshold is 0.95, here it is lowered to 0.9
clf = NLIRelationClassifierWithMappingHead(
        labels=relations, 
        template_mapping=relation_verbalizations,
        valid_conditions=valid_conditions,
        negative_threshold=0.9
    )
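The sketch below makes the roles of `template_mapping`, `valid_conditions` and `negative_threshold` more concrete. It is a simplified, hypothetical decision rule in that spirit:

- each relation takes the highest entailment probability among its templates;
- relations whose entity-type pair is not allowed are skipped;
- `no_relation` is returned when nothing reaches the threshold.

The entailment probabilities are made up, and Ask2Transformers' actual implementation may differ.

```python
def decide_relation(entail_probs, template_mapping, valid_conditions,
                    pair_type, negative_threshold=0.9):
    """Return the best relation for one entity pair, or 'no_relation'.

    entail_probs maps each template to the (hypothetical) entailment
    probability of its filled-in hypothesis for this example.
    """
    best_label, best_score = "no_relation", 0.0
    for relation, templates in template_mapping.items():
        if pair_type not in valid_conditions.get(relation, []):
            continue  # entity types rule this relation out
        # A relation is as likely as its best-scoring verbalization.
        score = max((entail_probs.get(t, 0.0) for t in templates), default=0.0)
        if score > best_score:
            best_label, best_score = relation, score
    # Below the threshold, fall back to the negative class.
    return best_label if best_score >= negative_threshold else "no_relation"


mapping = {
    "per:city_of_death": ["{subj} died in {obj}"],
    "org:founded_by": ["{subj} was founded by {obj}", "{obj} founded {subj}"],
}
conditions = {
    "per:city_of_death": ["PERSON:CITY", "PERSON:LOCATION"],
    "org:founded_by": ["ORGANIZATION:PERSON"],
}
# Made-up entailment probabilities for one example
probs = {"{subj} was founded by {obj}": 0.62, "{obj} founded {subj}": 0.94,
         "{subj} died in {obj}": 0.30}

print(decide_relation(probs, mapping, conditions, "ORGANIZATION:PERSON"))  # org:founded_by
print(decide_relation(probs, mapping, conditions, "PERSON:ORGANIZATION"))  # no_relation
```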


Let's look at the predicted relations for examples from the [TACRED](https://nlp.stanford.edu/projects/tacred/) dataset.

In [None]:
from a2t.relation_classification import REInputFeatures

test_examples = [
        REInputFeatures(subj='Billy Mays', obj='Tampa', pair_type='PERSON:CITY', context='Billy Mays, the bearded, boisterous pitchman who, as the undisputed king of TV yell and sell, became an unlikely pop culture icon, died at his home in Tampa, Fla, on Sunday', label='per:city_of_death'),
        REInputFeatures(subj='Old Lane Partners', obj='Pandit', pair_type='ORGANIZATION:PERSON', context='Pandit worked at the brokerage Morgan Stanley for about 11 years until 2005, when he and some Morgan Stanley colleagues quit and later founded the hedge fund Old Lane Partners.', label='org:founded_by'),
        REInputFeatures(subj='He', obj='University of Maryland in College Park', pair_type='PERSON:ORGANIZATION', context='He received an undergraduate degree from Morgan State University in 1950 and applied for admission to graduate school at the University of Maryland in College Park.', label='no_relation')

    ]

clf.predict(test_examples, return_confidences=True, topk=1)




[('per:city_of_death', 0.9872344136238098),
 ('org:founded_by', 0.9368537068367004),
 ('no_relation', 1.0)]

We can also look at the predictions of the pretrained `TACREDClassifier`, which covers a much larger set of relations.

In [None]:
from a2t.relation_classification import TACREDClassifier

clf = TACREDClassifier()

clf.predict(test_examples, return_confidences=True, topk=3)




[[('per:city_of_death', 0.9863594770431519),
  ('per:cities_of_residence', 0.9615850448608398),
  ('per:city_of_birth', 0.20450375974178314)],
 [('org:founded_by', 0.982134222984314),
  ('org:shareholders', 0.21125152707099915),
  ('org:top_members/employees', 0.021178238093852997)],
 [('no_relation', 1.0),
  ('per:schools_attended', 0.861609160900116),
  ('per:employee_of', 0.2816363573074341)]]

# Syntax Parsing

In [None]:
!pip install natasha



In [None]:
from natasha import Doc, NewsEmbedding, NewsSyntaxParser, Segmenter

In [None]:
sents = "Сократ родился в 469 году до н. э. в семье скульптора Софрониска из Алопек. В молодости Сократ участвовал в нескольких сражениях. В одной из битв Сократ даже спас от смерти своего молодого сослуживца."

A preliminary (and mandatory!) step is segmentation, i.e. splitting the text into sentences and tokens:

In [None]:
segmenter = Segmenter()  # splits the text into sentences and tokens

doc = Doc(sents)
doc.segment(segmenter)
doc.sents


[DocSent(stop=75, text='Сократ родился в 469 году до н. э. в семье скульп..., tokens=[...]),
 DocSent(start=76, stop=129, text='В молодости Сократ участвовал в нескольких сражен..., tokens=[...]),
 DocSent(start=130, stop=200, text='В одной из битв Сократ даже спас от смерти своего..., tokens=[...])]
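Sentence splitting is less trivial than it looks. The naive baseline below splits on whitespace after `.`, `!` or `?`. It is an illustration only, not how Natasha's `Segmenter` works. On the same text it breaks the abbreviation `до н. э.` apart and returns five fragments, whereas the segmenter above returns three sentences.

```python
import re

def naive_split(text):
    """Split after '.', '!' or '?' followed by whitespace (deliberately naive)."""
    return re.split(r"(?<=[.!?])\s+", text)

text = ("Сократ родился в 469 году до н. э. в семье скульптора Софрониска из Алопек. "
        "В молодости Сократ участвовал в нескольких сражениях. "
        "В одной из битв Сократ даже спас от смерти своего молодого сослуживца.")
for part in naive_split(text):
    print(part)
```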

In [None]:
emb = NewsEmbedding()
syntax_parser = NewsSyntaxParser(emb)
doc.parse_syntax(syntax_parser)
display(doc.tokens)

[DocToken(stop=6, text='Сократ', id='1_1', head_id='1_2', rel='nsubj'),
 DocToken(start=7, stop=14, text='родился', id='1_2', head_id='1_0', rel='root'),
 DocToken(start=15, stop=16, text='в', id='1_3', head_id='1_5', rel='case'),
 DocToken(start=17, stop=20, text='469', id='1_4', head_id='1_5', rel='amod'),
 DocToken(start=21, stop=25, text='году', id='1_5', head_id='1_2', rel='obl'),
 DocToken(start=26, stop=28, text='до', id='1_6', head_id='1_7', rel='case'),
 DocToken(start=29, stop=30, text='н', id='1_7', head_id='1_12', rel='nmod'),
 DocToken(start=30, stop=31, text='.', id='1_8', head_id='1_7', rel='punct'),
 DocToken(start=32, stop=33, text='э', id='1_9', head_id='1_9', rel='nsubj'),
 DocToken(start=33, stop=34, text='.', id='1_10', head_id='1_9', rel='punct'),
 DocToken(start=35, stop=36, text='в', id='1_11', head_id='1_12', rel='case'),
 DocToken(start=37, stop=42, text='семье', id='1_12', head_id='1_12', rel='nmod'),
 DocToken(start=43, stop=53, text='скульптора', id='1_13',

In [None]:
doc.sents[0].syntax.print()

      ┌► Сократ     nsubj
┌─┌───└─ родился    
│ │ ┌──► в          case
│ │ │ ┌► 469        amod
│ └►└─└─ году       obl
│     ┌► до         case
│ ┌►┌─└─ н          nmod
│ │ └──► .          punct
│ │   ┌─ э          
│ │   └► .          punct
│ │   ┌► в          case
│ └─┌─└─ семье      
│ ┌─└►┌─ скульптора nmod
│ │   └► Софрониска appos
│ │   ┌► из         case
│ └──►└─ Алопек     nmod
└──────► .          punct


In [None]:
doc.sents[1].syntax.print()

      ┌► В          case
    ┌►└─ молодости  obl
    │ ┌► Сократ     nsubj
┌─┌─└─└─ участвовал 
│ │ ┌──► в          case
│ │ │ ┌► нескольких nummod
│ └►└─└─ сражениях  obl
└──────► .          punct


In [None]:
doc.sents[2].syntax.print()

        ┌► В          case
    ┌►┌─└─ одной      obl
    │ │ ┌► из         case
    │ └►└─ битв       nmod
    │ ┌──► Сократ     nsubj
    │ │ ┌► даже       advmod
┌───└─└─└─ спас       
│   │   ┌► от         case
│ ┌─└──►└─ смерти     obl
│ │   ┌──► своего     det
│ │   │ ┌► молодого   amod
│ └──►└─└─ сослуживца nmod
└────────► .          punct


In [None]:
example = "Я смотрел, как Си-лучи мерцают во тьме близ врат Тангейзера. Все эти мгновения исчезнут во времени, как слёзы под дождём. "
doc = Doc(example)
doc

Doc(text='Я смотрел, как Си-лучи мерцают во тьме близ врат ...)

In [None]:
segmenter = Segmenter()

doc.segment(segmenter)
doc.sents

[DocSent(stop=60, text='Я смотрел, как Си-лучи мерцают во тьме близ врат ..., tokens=[...]),
 DocSent(start=61, stop=121, text='Все эти мгновения исчезнут во времени, как слёзы ..., tokens=[...])]

In [None]:
emb = NewsEmbedding()
syntax_parser = NewsSyntaxParser(emb)

In [None]:
doc.parse_syntax(syntax_parser)

In [None]:
display(doc.tokens)

[DocToken(stop=1, text='Я', id='1_1', head_id='1_2', rel='nsubj'),
 DocToken(start=2, stop=9, text='смотрел', id='1_2', head_id='1_0', rel='root'),
 DocToken(start=9, stop=10, text=',', id='1_3', head_id='1_6', rel='punct'),
 DocToken(start=11, stop=14, text='как', id='1_4', head_id='1_6', rel='mark'),
 DocToken(start=15, stop=22, text='Си-лучи', id='1_5', head_id='1_6', rel='obj'),
 DocToken(start=23, stop=30, text='мерцают', id='1_6', head_id='1_2', rel='ccomp'),
 DocToken(start=31, stop=33, text='во', id='1_7', head_id='1_10', rel='case'),
 DocToken(start=34, stop=38, text='тьме', id='1_8', head_id='1_6', rel='obl'),
 DocToken(start=39, stop=43, text='близ', id='1_9', head_id='1_10', rel='case'),
 DocToken(start=44, stop=48, text='врат', id='1_10', head_id='1_11', rel='amod'),
 DocToken(start=49, stop=59, text='Тангейзера', id='1_11', head_id='1_10', rel='nmod'),
 DocToken(start=59, stop=60, text='.', id='1_12', head_id='1_2', rel='punct'),
 DocToken(start=61, stop=64, text='Все', i

In [None]:
doc.sents[0].syntax.print()

        ┌► Я          nsubj
┌─┌─────└─ смотрел    
│ │ ┌────► ,          punct
│ │ │ ┌──► как        mark
│ │ │ │ ┌► Си-лучи    obj
│ └►└─└─└─ мерцают    ccomp
│ ┌►│      во         case
│ │ └────► тьме       obl
│ │     ┌► близ       case
│ └─┌─┌►└─ врат       amod
│   └►└─── Тангейзера nmod
└────────► .          punct


In [None]:
doc.sents[1].syntax.print()

      ┌──► Все       det
      │ ┌► эти       det
      └─└─ мгновения nsubj
┌───┌─└─── исчезнут  
│   │   ┌► во        case
│ ┌─└──►└─ времени   obl
│ │   ┌──► ,         punct
│ │   │ ┌► как       case
│ └►┌─└─└─ слёзы     acl
│   │   ┌► под       case
│   └──►└─ дождём    nmod
└────────► .         punct


We can also inspect the details of the parse:

In [None]:
doc.sents[0].syntax.tokens[3]

SyntaxToken(
    id='1_4',
    text='как',
    head_id='1_6',
    rel='mark'
)
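Each token stores only its `id`, `head_id` and `rel`, so the tree can be rebuilt with a plain dictionary. The sketch below collects every token reachable from the virtual root `1_0`. It uses plain tuples copied from the `DocToken` output for the first sentence, not Natasha's API, and the function name is hypothetical.

In that output, `врат` (`1_10`) and `Тангейзера` (`1_11`) name each other as heads. Neither of them, nor their dependents, ever connects to the root. A check like this can flag such malformed parses.

```python
from collections import defaultdict

# (id, text, head_id, rel) for the first sentence, copied from the
# DocToken output above; '1_0' is the virtual root.
tokens = [
    ("1_1", "Я", "1_2", "nsubj"),
    ("1_2", "смотрел", "1_0", "root"),
    ("1_3", ",", "1_6", "punct"),
    ("1_4", "как", "1_6", "mark"),
    ("1_5", "Си-лучи", "1_6", "obj"),
    ("1_6", "мерцают", "1_2", "ccomp"),
    ("1_7", "во", "1_10", "case"),
    ("1_8", "тьме", "1_6", "obl"),
    ("1_9", "близ", "1_10", "case"),
    ("1_10", "врат", "1_11", "amod"),
    ("1_11", "Тангейзера", "1_10", "nmod"),
    ("1_12", ".", "1_2", "punct"),
]

def reachable_from_root(tokens, root_id="1_0"):
    """Return the ids of tokens connected to the root through head links."""
    children = defaultdict(list)
    for tok_id, _text, head_id, _rel in tokens:
        children[head_id].append(tok_id)
    seen = set()
    stack = [root_id]
    while stack:  # iterative depth-first walk from the root down
        node = stack.pop()
        for child in children[node]:
            if child not in seen:
                seen.add(child)
                stack.append(child)
    return seen

reached = reachable_from_root(tokens)
# Tokens that never connect to the root (e.g. because of a head cycle)
detached = [text for tok_id, text, _head, _rel in tokens if tok_id not in reached]
print(detached)
```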

# LUKE model

Source: https://github.com/studio-ousia/luke

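LUKE (Language Understanding with Knowledge-based Embeddings) is a transformer-based pretrained model that produces contextualized representations of both words and entities.
It is pretrained with a masked language modeling objective extended to entities: it learns to predict randomly masked words and entities. These entity-aware representations make the model effective on entity-related tasks such as NER, relation classification, and question answering.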

In [None]:
# !pip install git+https://github.com/huggingface/transformers.git

In [None]:
import json
import torch
from tqdm import trange
from transformers import LukeTokenizer, LukeForEntityPairClassification

## Loading the dataset

The TACRED dataset is not publicly available.
Here we use the small sample of the test set (20 examples) provided in this [repo](https://github.com/yuhaozhang/tacred-relation).

In [None]:
!wget https://github.com/yuhaozhang/tacred-relation/raw/master/dataset/tacred/test.json

--2021-11-11 11:12:26--  https://github.com/yuhaozhang/tacred-relation/raw/master/dataset/tacred/test.json
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/yuhaozhang/tacred-relation/master/dataset/tacred/test.json [following]
--2021-11-11 11:12:26--  https://raw.githubusercontent.com/yuhaozhang/tacred-relation/master/dataset/tacred/test.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30422 (30K) [text/plain]
Saving to: ‘test.json’


2021-11-11 11:12:27 (14.2 MB/s) - ‘test.json’ saved [30422/30422]



In [None]:
def load_examples(dataset_file):
    """Convert TACRED JSON records into dicts holding the raw text, the
    character-level (subject, object) spans, and the relation label."""
    with open(dataset_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    examples = []
    for item in data:
        tokens = item["token"]
        # TACRED end indices are inclusive; convert to exclusive token spans
        token_spans = dict(
            subj=(item["subj_start"], item["subj_end"] + 1),
            obj=(item["obj_start"], item["obj_end"] + 1)
        )

        # Handle the two entities in the order they appear in the sentence
        if token_spans["subj"][0] < token_spans["obj"][0]:
            entity_order = ("subj", "obj")
        else:
            entity_order = ("obj", "subj")

        # Rebuild the text by joining tokens with spaces, recording the
        # character offsets of each entity along the way
        text = ""
        cur = 0
        char_spans = {}
        for target_entity in entity_order:
            token_span = token_spans[target_entity]
            text += " ".join(tokens[cur : token_span[0]])
            if text:
                text += " "
            char_start = len(text)
            text += " ".join(tokens[token_span[0] : token_span[1]])
            char_end = len(text)
            char_spans[target_entity] = (char_start, char_end)
            text += " "
            cur = token_span[1]
        text += " ".join(tokens[cur:])
        text = text.rstrip()

        # The subject span comes first, followed by the object span
        examples.append(dict(
            text=text,
            entity_spans=[tuple(char_spans["subj"]), tuple(char_spans["obj"])],
            label=item["relation"]
        ))

    return examples
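In TACRED, `subj_end` and `obj_end` are inclusive token indices, which is why `load_examples` adds 1 before slicing. The minimal sketch below shows an alternative way to do the same token-to-character conversion. It assumes the text is simply the tokens joined by single spaces, and it computes each token's offset directly.

This matches `load_examples` except when the two entities are adjacent, where `load_examples` inserts an extra space. The helper name and the record are hypothetical, not taken from the dataset.

```python
def token_span_to_char_span(tokens, start, end):
    """Map an inclusive token span [start, end] to a (char_start, char_end)
    pair in " ".join(tokens), with char_end exclusive."""
    offsets = []
    pos = 0
    for tok in tokens:
        offsets.append(pos)
        pos += len(tok) + 1  # +1 for the joining space
    return offsets[start], offsets[end] + len(tokens[end])

# Hypothetical TACRED-style record (not from the dataset)
item = {
    "token": ["Socrates", "was", "born", "in", "Alopece", "."],
    "subj_start": 0, "subj_end": 0,
    "obj_start": 4, "obj_end": 4,
}
text = " ".join(item["token"])
subj = token_span_to_char_span(item["token"], item["subj_start"], item["subj_end"])
obj = token_span_to_char_span(item["token"], item["obj_start"], item["obj_end"])
print(subj, obj)
print(text[subj[0]:subj[1]], "|", text[obj[0]:obj[1]])
```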

In [None]:
test_examples = load_examples("test.json")

In [None]:
test_examples[:10]

[{'entity_spans': [(115, 118), (3, 6)],
  'label': 'no_relation',
  'text': 'No one knows how Tamaihia Lynae Moore died , but the foster mother of the Sacramento toddler has been arrested for her murder .'},
 {'entity_spans': [(16, 31), (9, 12)],
  'label': 'no_relation',
  'text': "He named one as Shah Abdul Aziz , a member of a pro-Taliban religious party elected to parliament 's lower house in 2002 ."},
 {'entity_spans': [(40, 57), (6, 14)],
  'label': 'per:title',
  'text': "Youth minister and `` Street General '' Charles Ble Goude , who is under UN sanctions for `` acts of violence by street militias , including beatings , rapes and extrajudicial killings '' , vows to fight for Ivory Coast 's sovereignty ."},
 {'entity_spans': [(20, 26), (105, 113)],
  'label': 'no_relation',
  'text': "Prosecutors believe Graham and two other AIM activists , Theda Clark and Arlo Looking Cloud , stopped at Marshall 's home on South Dakota 's Pine Ridge reservation with Aquash shortly before Graham

In [None]:
len(test_examples)

20

## Loading the fine-tuned model and tokenizer

We construct the model and tokenizer using the [fine-tuned model checkpoint](https://huggingface.co/studio-ousia/luke-large-finetuned-tacred).

In [None]:
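# Load the model checkpoint and switch it to inference mode
model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
model.eval()
# Use the GPU when one is available; otherwise fall back to the CPU
model.to("cuda" if torch.cuda.is_available() else "cpu")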

# Load the tokenizer
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-tacred were not used when initializing LukeForEntityPairClassification: ['luke.embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

## Evaluation

We classify the relation for each entity pair in the test sample and measure the model's micro-averaged precision, recall, and F1, treating `no_relation` as the negative class.

In [None]:
batch_size = 128

# Micro-averaged counts; "no_relation" is treated as the negative class
num_predicted = 0  # predictions other than no_relation
num_gold = 0       # gold labels other than no_relation
num_correct = 0    # non-no_relation gold labels predicted exactly

for batch_start_idx in trange(0, len(test_examples), batch_size):
    batch_examples = test_examples[batch_start_idx:batch_start_idx + batch_size]
    texts = [example["text"] for example in batch_examples]
    entity_spans = [example["entity_spans"] for example in batch_examples]
    gold_labels = [example["label"] for example in batch_examples]

    inputs = tokenizer(texts, entity_spans=entity_spans, return_tensors="pt", padding=True)
    inputs = inputs.to(model.device)  # move tensors to wherever the model lives
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_indices = outputs.logits.argmax(-1)
    predicted_labels = [model.config.id2label[index.item()] for index in predicted_indices]
    for predicted_label, gold_label in zip(predicted_labels, gold_labels):
        if predicted_label != "no_relation":
            num_predicted += 1
        if gold_label != "no_relation":
            num_gold += 1
            if predicted_label == gold_label:
                num_correct += 1

# Guard against division by zero on tiny or degenerate samples
precision = num_correct / num_predicted if num_predicted else 0.0
recall = num_correct / num_gold if num_gold else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

print(f"\n\nprecision: {precision} recall: {recall} f1: {f1}")

100%|██████████| 1/1 [00:01<00:00,  1.17s/it]



precision: 1.0 recall: 1.0 f1: 1.0
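The loop treats `no_relation` as the negative class. Precision is computed over predictions other than `no_relation`, and recall over gold labels other than `no_relation`. The standalone sketch below performs the same counting with zero-division guards; the function name and the toy labels are hypothetical. Keep in mind that the perfect scores above come from only 20 sampled examples.

```python
def tacred_micro_prf(gold_labels, predicted_labels, negative_label="no_relation"):
    """Micro precision/recall/F1 that ignores the negative class,
    following the counting used in the evaluation loop above."""
    num_predicted = sum(p != negative_label for p in predicted_labels)
    num_gold = sum(g != negative_label for g in gold_labels)
    num_correct = sum(
        g != negative_label and p == g
        for g, p in zip(gold_labels, predicted_labels)
    )
    precision = num_correct / num_predicted if num_predicted else 0.0
    recall = num_correct / num_gold if num_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical labels, not model output
gold = ["per:title", "no_relation", "org:founded", "per:title"]
pred = ["per:title", "per:title", "no_relation", "per:origin"]
scores = tacred_micro_prf(gold, pred)
print(scores)
```

Here there are three non-`no_relation` predictions, three non-`no_relation` gold labels, and one match. As a result, precision, recall, and F1 all come out to 1/3.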





## Detecting a relation between a pair of entities

Finally, we detect a relation between a pair of entities in a text using the [fine-tuned model](https://huggingface.co/studio-ousia/luke-large-finetuned-tacred).

In [None]:
text = "Beyoncé lives in Los Angeles."
entity_spans = [(0, 7), (17, 28)]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
inputs = inputs.to(model.device)  # move tensors to wherever the model lives
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs)

predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Predicted class: per:cities_of_residence


In [None]:
text = "Socrates was born in 470 BC in Alopece."

ents = ['Socrates', '470 BC', 'Alopece']
entity_spans_list = []
for ent in ents:
    start = text.find(ent)
    entity_spans_list.append((start, start + len(ent)))


In [None]:
entity_spans_list

[(0, 8), (21, 27), (31, 38)]
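The cell above relies on `str.find`, which has two limitations. It returns `-1` when a mention is missing, which silently produces a negative span. It also reports only the first occurrence of a mention. The sketch below, using hypothetical helper names, raises an error for missing mentions and can also list every occurrence via `re.finditer`.

```python
import re

def find_entity_spans(text, mentions):
    """Return (start, end) character spans for each mention's first
    occurrence, raising ValueError instead of returning a -1 offset."""
    spans = []
    for mention in mentions:
        start = text.find(mention)
        if start == -1:
            raise ValueError(f"{mention!r} not found in text")
        spans.append((start, start + len(mention)))
    return spans

def all_occurrences(text, mention):
    """Return spans for every literal occurrence of mention in text."""
    return [(m.start(), m.end()) for m in re.finditer(re.escape(mention), text)]

text = "Socrates was born in 470 BC in Alopece."
ents = ["Socrates", "470 BC", "Alopece"]
print(find_entity_spans(text, ents))
print(all_occurrences("Socrates taught; Socrates died.", "Socrates"))
```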

In [None]:
entity_spans = [entity_spans_list[0], entity_spans_list[1]]  # character-based entity spans

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
inputs = inputs.to(model.device)  # move tensors to wherever the model lives
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs)

predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Predicted class: per:date_of_birth


In [None]:
entity_spans = [entity_spans_list[1], entity_spans_list[2]]  # character-based entity spans

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
inputs = inputs.to(model.device)  # move tensors to wherever the model lives
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs)

predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Predicted class: per:city_of_birth
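In the last cell the first span is "470 BC", a date rather than a person, yet the model still returns a `per:` label. TACRED subjects are persons or organizations, so this prediction should be read with caution. In the cells above the first span plays the role of the subject and the second the object, so pair order matters.

The sketch below enumerates ordered candidate pairs of mentions. The `subject_indices` argument is a hypothetical filter, for example one derived from an NER step, that restricts which mentions may act as the subject. Each resulting pair could then be passed as `entity_spans` to the tokenizer as in the cells above; the sketch itself does not call the model.

```python
from itertools import permutations

def candidate_pairs(spans, subject_indices=None):
    """Return ordered [subject_span, object_span] pairs of distinct mentions.
    If subject_indices is given, only those mentions may be the subject."""
    pairs = []
    for i, j in permutations(range(len(spans)), 2):
        if subject_indices is None or i in subject_indices:
            pairs.append([spans[i], spans[j]])
    return pairs

spans = [(0, 8), (21, 27), (31, 38)]  # Socrates, 470 BC, Alopece
print(len(candidate_pairs(spans)))
print(candidate_pairs(spans, subject_indices={0}))
```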
