# Загружаем данные 

In [1]:
from datasets import load_dataset

from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
import evaluate
from transformers import AutoModelForTokenClassification
from transformers import TrainingArguments
from transformers import Trainer
from torch.utils.data import DataLoader
from torch.optim import AdamW
from accelerate import Accelerator
from transformers import get_scheduler
from tqdm.auto import tqdm
import torch
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
raw_datasets = load_dataset("conll2003")

Reusing dataset conll2003 (/home/ivan/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)
100%|████████████████████████████████████████████| 3/3 [00:00<00:00, 315.35it/s]


In [3]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

In [4]:
raw_datasets['train']['ner_tags'][0]

[3, 0, 7, 0, 0, 0, 7, 0, 0]

In [5]:
raw_datasets['train']['tokens'][0]

['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']

In [6]:
raw_datasets['train'].features['ner_tags']

Sequence(feature=ClassLabel(num_classes=9, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

# Обработка данных 

In [7]:
ner_feature = raw_datasets["train"].features["ner_tags"]
ner_feature

Sequence(feature=ClassLabel(num_classes=9, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

In [8]:
label_names = ner_feature.feature.names
label_names

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

In [9]:
words = raw_datasets['train'][0]['tokens']
labels = raw_datasets["train"][0]["ner_tags"]
line1 = ""
line2 = ""
for word, label in zip(words, labels):
    full_label = label_names[label]
    max_length = max(len(word), len(full_label))
    line1 += word + " " * (max_length - len(word) + 1)
    line2 += full_label + " " * (max_length - len(full_label) + 1)

print(line1)
print(line2) 

EU    rejects German call to boycott British lamb . 
B-ORG O       B-MISC O    O  O       B-MISC  O    O 


In [10]:
words = raw_datasets['train'][4]['tokens']
labels = raw_datasets["train"][4]["ner_tags"]
line1 = ""
line2 = ""
for word, label in zip(words, labels):
    full_label = label_names[label]
    max_length = max(len(word), len(full_label))
    line1 += word + " " * (max_length - len(word) + 1)
    line2 += full_label + " " * (max_length - len(full_label) + 1)

print(line1)
print(line2) 

Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer . 
B-LOC   O  O              O  O   B-ORG    I-ORG O  O          O         B-PER  I-PER     O    O  O         O         O      O   O         O    O         O     O    B-LOC   O     O   O          O      O   O       O 


In [11]:
words = raw_datasets['train'][0]['tokens']
pos_names = raw_datasets['train'].features['pos_tags'].feature.names
labels = raw_datasets["train"][0]["pos_tags"]
line1 = ""
line2 = ""
for word, label in zip(words, labels):
    full_label = pos_names[label]
    max_length = max(len(word), len(full_label))
    line1 += word + " " * (max_length - len(word) + 1)
    line2 += full_label + " " * (max_length - len(full_label) + 1)

print(line1)
print(line2) 

EU  rejects German call to boycott British lamb . 
NNP VBZ     JJ     NN   TO VB      JJ      NN   . 


In [12]:
model_checkpoint = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


In [13]:
inputs = tokenizer(raw_datasets["train"][4]["tokens"], is_split_into_words=True)
inputs.tokens()

['[CLS]',
 'germany',
 "'",
 's',
 'representative',
 'to',
 'the',
 'european',
 'union',
 "'",
 's',
 'veterinary',
 'committee',
 'werner',
 'z',
 '##wing',
 '##mann',
 'said',
 'on',
 'wednesday',
 'consumers',
 'should',
 'buy',
 'sheep',
 '##me',
 '##at',
 'from',
 'countries',
 'other',
 'than',
 'britain',
 'until',
 'the',
 'scientific',
 'advice',
 'was',
 'clearer',
 '.',
 '[SEP]']

In [14]:
inputs.word_ids()

[None,
 0,
 1,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 7,
 8,
 9,
 10,
 11,
 11,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 18,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 None]

In [15]:
def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            # If the label is B-XXX we change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)

    return new_labels

In [16]:
labels = raw_datasets["train"][4]["ner_tags"]
word_ids = inputs.word_ids()
print(labels)
print(align_labels_with_tokens(labels, word_ids))

[5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0]
[-100, 5, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, -100]


In [17]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    all_labels = examples["ner_tags"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

In [18]:
tokenized_datasets = raw_datasets.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)

Loading cached processed dataset at /home/ivan/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-7f7b7bdf1cf63085.arrow
Loading cached processed dataset at /home/ivan/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-2397cfa85b91cecd.arrow
  0%|                                                     | 0/4 [00:00<?, ?ba/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 12.44ba/s]


In [19]:
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [20]:
batch = data_collator([tokenized_datasets["train"][i] for i in range(3)])
batch["labels"]

tensor([[-100,    3,    0,    7,    0,    0,    0,    7,    0,    0, -100],
        [-100,    1,    2, -100, -100, -100, -100, -100, -100, -100, -100],
        [-100,    5,    0,    0,    0,    0,    0, -100, -100, -100, -100]])

# Метрики 

In [21]:
metric = evaluate.load("seqeval")

In [22]:
metric

EvaluationModule(name: "seqeval", module_type: "metric", features: {'predictions': Sequence(feature=Value(dtype='string', id='label'), length=-1, id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='label'), length=-1, id='sequence')}, usage: """
Produces labelling scores along with its sufficient statistics
from a source against one or more references.

Args:
    predictions: List of List of predicted labels (Estimated targets as returned by a tagger)
    references: List of List of reference labels (Ground truth (correct) target values)
    suffix: True if the IOB prefix is after type, False otherwise. default: False
    scheme: Specify target tagging scheme. Should be one of ["IOB1", "IOB2", "IOE1", "IOE2", "IOBES", "BILOU"].
        default: None
    mode: Whether to count correct entity labels with incorrect I/B tags as true positives or not.
        If you want to only count exact matches, pass mode="strict". default: None.
    sample_weight: Array-like of sha

In [23]:
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }

# Подготовливаем данные для модели

In [24]:
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"], collate_fn=data_collator, batch_size=8
)

In [25]:
def postprocess(predictions, labels):
    predictions = predictions.detach().cpu().clone().numpy()
    labels = labels.detach().cpu().clone().numpy()

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    return true_labels, true_predictions

#  Модель

In [26]:
id2label = {str(i): label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}

In [37]:
num_warmup_steps = 0
num_train_epochs = 3
output_dir = './models'
batch_size = 8

In [28]:
model_checkpoint

'prajjwal1/bert-tiny'

In [29]:
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
)

Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from t

In [30]:
model.config.num_labels

9

In [31]:
optimizer = AdamW(model.parameters(), lr=2e-5)

In [32]:
accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [33]:
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)


In [42]:
progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)

        predictions = outputs.logits.argmax(dim=-1)
        labels = batch["labels"]

        # Necessary to pad predictions and labels for being gathered
        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

        predictions_gathered = accelerator.gather(predictions)
        labels_gathered = accelerator.gather(labels)

        true_predictions, true_labels = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=true_predictions, references=true_labels)

    results = metric.compute()
    print(
        f"epoch {epoch}:",
        {
            key: results[f"overall_{key}"]
            for key in ["precision", "recall", "f1", "accuracy"]
        },
    )

#     # Save and upload
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir + f'/epoch_{epoch}', save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
#         repo.push_to_hub(
#             commit_message=f"Training in progress epoch {epoch}", blocking=False
#         )


  0%|                                                  | 0/5268 [00:00<?, ?it/s][A
  0%|                                          | 5/5268 [00:00<01:57, 44.93it/s][A
  0%|                                         | 10/5268 [00:00<02:09, 40.61it/s][A
  0%|                                         | 15/5268 [00:00<02:08, 40.98it/s][A
  0%|▏                                        | 20/5268 [00:00<02:07, 41.20it/s][A
  0%|▏                                        | 25/5268 [00:00<02:07, 41.10it/s][A
  1%|▏                                        | 30/5268 [00:00<02:07, 40.95it/s][A
  1%|▎                                        | 35/5268 [00:00<02:07, 41.14it/s][A
  1%|▎                                        | 40/5268 [00:00<02:05, 41.50it/s][A
  1%|▎                                        | 45/5268 [00:01<02:06, 41.38it/s][A
  1%|▍                                        | 50/5268 [00:01<02:06, 41.25it/s][A
  1%|▍                                        | 55/5268 [00:01<02:10, 39.83

  9%|███▍                                    | 460/5268 [00:11<02:04, 38.76it/s][A
  9%|███▌                                    | 464/5268 [00:11<02:05, 38.14it/s][A
  9%|███▌                                    | 469/5268 [00:11<02:02, 39.11it/s][A
  9%|███▌                                    | 474/5268 [00:11<01:59, 40.04it/s][A
  9%|███▋                                    | 479/5268 [00:12<01:57, 40.65it/s][A
  9%|███▋                                    | 484/5268 [00:12<01:56, 41.17it/s][A
  9%|███▋                                    | 489/5268 [00:12<01:56, 41.12it/s][A
  9%|███▊                                    | 494/5268 [00:12<01:56, 41.04it/s][A
  9%|███▊                                    | 499/5268 [00:12<01:58, 40.34it/s][A
 10%|███▊                                    | 504/5268 [00:12<02:01, 39.33it/s][A
 10%|███▊                                    | 508/5268 [00:12<02:01, 39.31it/s][A
 10%|███▉                                    | 512/5268 [00:12<02:01, 39.28i

 18%|███████                                 | 922/5268 [00:23<01:56, 37.16it/s][A
 18%|███████                                 | 926/5268 [00:23<01:54, 37.85it/s][A
 18%|███████                                 | 931/5268 [00:23<01:51, 38.74it/s][A
 18%|███████                                 | 935/5268 [00:23<01:53, 38.19it/s][A
 18%|███████▏                                | 939/5268 [00:23<01:59, 36.09it/s][A
 18%|███████▏                                | 943/5268 [00:23<02:00, 35.95it/s][A
 18%|███████▏                                | 948/5268 [00:23<01:53, 38.00it/s][A
 18%|███████▏                                | 952/5268 [00:23<01:52, 38.44it/s][A
 18%|███████▎                                | 957/5268 [00:24<01:49, 39.47it/s][A
 18%|███████▎                                | 962/5268 [00:24<01:47, 40.04it/s][A
 18%|███████▎                                | 967/5268 [00:24<01:47, 40.09it/s][A
 18%|███████▍                                | 972/5268 [00:24<01:46, 40.36i

 26%|██████████                             | 1361/5268 [00:34<01:51, 35.10it/s][A
 26%|██████████                             | 1365/5268 [00:34<01:50, 35.40it/s][A
 26%|██████████▏                            | 1369/5268 [00:34<01:50, 35.13it/s][A
 26%|██████████▏                            | 1373/5268 [00:34<01:50, 35.20it/s][A
 26%|██████████▏                            | 1377/5268 [00:35<01:46, 36.46it/s][A
 26%|██████████▏                            | 1381/5268 [00:35<01:48, 35.79it/s][A
 26%|██████████▎                            | 1385/5268 [00:35<01:49, 35.51it/s][A
 26%|██████████▎                            | 1389/5268 [00:35<01:47, 36.09it/s][A
 26%|██████████▎                            | 1393/5268 [00:35<01:49, 35.36it/s][A
 27%|██████████▎                            | 1397/5268 [00:35<01:50, 34.98it/s][A
 27%|██████████▎                            | 1401/5268 [00:35<01:48, 35.66it/s][A
 27%|██████████▍                            | 1405/5268 [00:35<01:50, 34.89i

epoch 0: {'precision': 0.629585997980478, 'recall': 0.5591928251121077, 'f1': 0.5923052564914503, 'accuracy': 0.919424277566842}



 33%|█████████████                          | 1762/5268 [00:48<10:41,  5.47it/s][A
 34%|█████████████                          | 1766/5268 [00:48<08:04,  7.23it/s][A
 34%|█████████████                          | 1771/5268 [00:48<05:47, 10.06it/s][A
 34%|█████████████▏                         | 1776/5268 [00:48<04:20, 13.39it/s][A
 34%|█████████████▏                         | 1780/5268 [00:48<03:35, 16.22it/s][A
 34%|█████████████▏                         | 1784/5268 [00:48<03:01, 19.15it/s][A
 34%|█████████████▏                         | 1788/5268 [00:48<02:36, 22.18it/s][A
 34%|█████████████▎                         | 1792/5268 [00:48<02:16, 25.38it/s][A
 34%|█████████████▎                         | 1796/5268 [00:48<02:03, 28.16it/s][A
 34%|█████████████▎                         | 1800/5268 [00:49<01:52, 30.78it/s][A
 34%|█████████████▎                         | 1805/5268 [00:49<01:43, 33.41it/s][A
 34%|█████████████▍                         | 1810/5268 [00:49<01:36, 35.75

 42%|████████████████▏                      | 2187/5268 [00:58<01:17, 39.91it/s][A
 42%|████████████████▏                      | 2191/5268 [00:59<01:19, 38.54it/s][A
 42%|████████████████▎                      | 2195/5268 [00:59<01:22, 37.42it/s][A
 42%|████████████████▎                      | 2199/5268 [00:59<01:21, 37.45it/s][A
 42%|████████████████▎                      | 2203/5268 [00:59<01:23, 36.82it/s][A
 42%|████████████████▎                      | 2207/5268 [00:59<01:24, 36.23it/s][A
 42%|████████████████▎                      | 2211/5268 [00:59<01:25, 35.68it/s][A
 42%|████████████████▍                      | 2216/5268 [00:59<01:20, 38.05it/s][A
 42%|████████████████▍                      | 2221/5268 [00:59<01:18, 38.87it/s][A
 42%|████████████████▍                      | 2226/5268 [01:00<01:15, 40.06it/s][A
 42%|████████████████▌                      | 2231/5268 [01:00<01:14, 40.51it/s][A
 42%|████████████████▌                      | 2236/5268 [01:00<01:16, 39.85i

 50%|███████████████████▌                   | 2650/5268 [01:10<01:05, 40.09it/s][A
 50%|███████████████████▋                   | 2655/5268 [01:10<01:05, 40.18it/s][A
 50%|███████████████████▋                   | 2660/5268 [01:10<01:03, 40.76it/s][A
 51%|███████████████████▋                   | 2665/5268 [01:11<01:03, 40.94it/s][A
 51%|███████████████████▊                   | 2670/5268 [01:11<01:02, 41.31it/s][A
 51%|███████████████████▊                   | 2675/5268 [01:11<01:03, 40.93it/s][A
 51%|███████████████████▊                   | 2680/5268 [01:11<01:02, 41.34it/s][A
 51%|███████████████████▉                   | 2685/5268 [01:11<01:02, 41.13it/s][A
 51%|███████████████████▉                   | 2690/5268 [01:11<01:04, 40.14it/s][A
 51%|███████████████████▉                   | 2695/5268 [01:11<01:03, 40.43it/s][A
 51%|███████████████████▉                   | 2700/5268 [01:11<01:02, 41.12it/s][A
 51%|████████████████████                   | 2705/5268 [01:12<01:01, 41.42i

 59%|██████████████████████▉                | 3099/5268 [01:22<00:58, 37.20it/s][A
 59%|██████████████████████▉                | 3103/5268 [01:22<00:59, 36.31it/s][A
 59%|███████████████████████                | 3107/5268 [01:22<01:00, 35.81it/s][A
 59%|███████████████████████                | 3111/5268 [01:22<01:00, 35.51it/s][A
 59%|███████████████████████                | 3115/5268 [01:22<01:03, 33.94it/s][A
 59%|███████████████████████                | 3119/5268 [01:22<01:02, 34.28it/s][A
 59%|███████████████████████                | 3123/5268 [01:22<01:01, 34.69it/s][A
 59%|███████████████████████▏               | 3127/5268 [01:22<01:00, 35.16it/s][A
 59%|███████████████████████▏               | 3131/5268 [01:23<01:01, 34.49it/s][A
 60%|███████████████████████▏               | 3135/5268 [01:23<00:59, 35.57it/s][A
 60%|███████████████████████▏               | 3139/5268 [01:23<01:02, 33.94it/s][A
 60%|███████████████████████▎               | 3143/5268 [01:23<01:01, 34.50i

epoch 1: {'precision': 0.629585997980478, 'recall': 0.5591928251121077, 'f1': 0.5923052564914503, 'accuracy': 0.919424277566842}



 67%|██████████████████████████             | 3521/5268 [01:36<03:40,  7.94it/s][A
 67%|██████████████████████████             | 3525/5268 [01:36<02:51, 10.15it/s][A
 67%|██████████████████████████▏            | 3529/5268 [01:36<02:15, 12.82it/s][A
 67%|██████████████████████████▏            | 3533/5268 [01:36<01:49, 15.91it/s][A
 67%|██████████████████████████▏            | 3537/5268 [01:36<01:29, 19.25it/s][A
 67%|██████████████████████████▏            | 3541/5268 [01:36<01:16, 22.65it/s][A
 67%|██████████████████████████▏            | 3545/5268 [01:36<01:07, 25.49it/s][A
 67%|██████████████████████████▎            | 3550/5268 [01:36<00:58, 29.40it/s][A
 67%|██████████████████████████▎            | 3554/5268 [01:36<00:54, 31.67it/s][A
 68%|██████████████████████████▎            | 3558/5268 [01:36<00:51, 33.22it/s][A
 68%|██████████████████████████▎            | 3562/5268 [01:37<00:48, 34.88it/s][A
 68%|██████████████████████████▍            | 3567/5268 [01:37<00:46, 36.86

 76%|█████████████████████████████▍         | 3978/5268 [01:47<00:32, 40.17it/s][A
 76%|█████████████████████████████▍         | 3983/5268 [01:47<00:31, 40.23it/s][A
 76%|█████████████████████████████▌         | 3988/5268 [01:47<00:32, 39.43it/s][A
 76%|█████████████████████████████▌         | 3993/5268 [01:47<00:31, 40.20it/s][A
 76%|█████████████████████████████▌         | 3998/5268 [01:47<00:31, 40.38it/s][A
 76%|█████████████████████████████▋         | 4003/5268 [01:48<00:30, 41.04it/s][A
 76%|█████████████████████████████▋         | 4008/5268 [01:48<00:31, 40.29it/s][A
 76%|█████████████████████████████▋         | 4013/5268 [01:48<00:30, 40.56it/s][A
 76%|█████████████████████████████▋         | 4018/5268 [01:48<00:30, 40.55it/s][A
 76%|█████████████████████████████▊         | 4023/5268 [01:48<00:30, 40.17it/s][A
 76%|█████████████████████████████▊         | 4028/5268 [01:48<00:31, 38.83it/s][A
 77%|█████████████████████████████▊         | 4033/5268 [01:48<00:31, 39.74i

 84%|████████████████████████████████▉      | 4451/5268 [01:59<00:19, 41.04it/s][A
 85%|████████████████████████████████▉      | 4456/5268 [01:59<00:20, 40.36it/s][A
 85%|█████████████████████████████████      | 4461/5268 [01:59<00:19, 40.54it/s][A
 85%|█████████████████████████████████      | 4466/5268 [01:59<00:20, 38.34it/s][A
 85%|█████████████████████████████████      | 4470/5268 [01:59<00:21, 37.51it/s][A
 85%|█████████████████████████████████▏     | 4475/5268 [01:59<00:20, 38.84it/s][A
 85%|█████████████████████████████████▏     | 4480/5268 [01:59<00:20, 39.40it/s][A
 85%|█████████████████████████████████▏     | 4485/5268 [02:00<00:19, 40.13it/s][A
 85%|█████████████████████████████████▏     | 4490/5268 [02:00<00:19, 40.73it/s][A
 85%|█████████████████████████████████▎     | 4495/5268 [02:00<00:19, 40.57it/s][A
 85%|█████████████████████████████████▎     | 4500/5268 [02:00<00:18, 41.01it/s][A
 86%|█████████████████████████████████▎     | 4505/5268 [02:00<00:18, 40.48i

 94%|████████████████████████████████████▌  | 4934/5268 [02:11<00:08, 41.01it/s][A
 94%|████████████████████████████████████▌  | 4939/5268 [02:11<00:07, 41.59it/s][A
 94%|████████████████████████████████████▌  | 4944/5268 [02:11<00:07, 41.09it/s][A
 94%|████████████████████████████████████▋  | 4949/5268 [02:11<00:07, 41.44it/s][A
 94%|████████████████████████████████████▋  | 4954/5268 [02:11<00:07, 40.98it/s][A
 94%|████████████████████████████████████▋  | 4959/5268 [02:11<00:07, 39.81it/s][A
 94%|████████████████████████████████████▋  | 4964/5268 [02:11<00:07, 40.50it/s][A
 94%|████████████████████████████████████▊  | 4969/5268 [02:11<00:07, 41.30it/s][A
 94%|████████████████████████████████████▊  | 4974/5268 [02:11<00:07, 41.55it/s][A
 95%|████████████████████████████████████▊  | 4979/5268 [02:12<00:06, 41.79it/s][A
 95%|████████████████████████████████████▉  | 4984/5268 [02:12<00:06, 41.47it/s][A
 95%|████████████████████████████████████▉  | 4989/5268 [02:12<00:06, 41.69i

epoch 2: {'precision': 0.629585997980478, 'recall': 0.5591928251121077, 'f1': 0.5923052564914503, 'accuracy': 0.919424277566842}


#  Анализ результатов

Протестируем модель на тестовых данных:

In [43]:
test_dataloader = DataLoader(
    tokenized_datasets["test"], 
    collate_fn=data_collator,
    batch_size=8
)

In [44]:
model = AutoModelForTokenClassification.from_pretrained('./models/epoch_2')

In [46]:
model.eval()

for batch in test_dataloader:
    with torch.no_grad():
        outputs = model(**batch)

    predictions = outputs.logits.argmax(dim=-1)
    labels = batch["labels"]

    # Necessary to pad predictions and labels for being gathered
    predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
    labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

    predictions_gathered = accelerator.gather(predictions)
    labels_gathered = accelerator.gather(labels)

    true_predictions, true_labels = postprocess(predictions_gathered, labels_gathered)
    metric.add_batch(predictions=true_predictions, references=true_labels)

    results = metric.compute()
    print(
        f"epoch {epoch}:",
        {
            key: results[f"overall_{key}"]
            for key in ["precision", "recall", "f1", "accuracy"]
        },
    )

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


epoch 2: {'precision': 0.6470588235294118, 'recall': 0.4782608695652174, 'f1': 0.55, 'accuracy': 0.9329608938547486}
epoch 2: {'precision': 0.7058823529411765, 'recall': 0.7058823529411765, 'f1': 0.7058823529411765, 'accuracy': 0.9685863874345549}
epoch 2: {'precision': 0.6, 'recall': 0.5, 'f1': 0.5454545454545454, 'accuracy': 0.9152542372881356}
epoch 2: {'precision': 0.47368421052631576, 'recall': 0.45, 'f1': 0.46153846153846156, 'accuracy': 0.9158415841584159}
epoch 2: {'precision': 0.8181818181818182, 'recall': 0.7105263157894737, 'f1': 0.7605633802816901, 'accuracy': 0.9624060150375939}
epoch 2: {'precision': 0.9285714285714286, 'recall': 0.9285714285714286, 'f1': 0.9285714285714286, 'accuracy': 0.9837837837837838}
epoch 2: {'precision': 0.7, 'recall': 0.6363636363636364, 'f1': 0.6666666666666666, 'accuracy': 0.8676470588235294}
epoch 2: {'precision': 0.9285714285714286, 'recall': 0.9285714285714286, 'f1': 0.9285714285714286, 'accuracy': 0.989247311827957}
epoch 2: {'precision': 1

epoch 2: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'accuracy': 1.0}
epoch 2: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'accuracy': 1.0}
epoch 2: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'accuracy': 1.0}
epoch 2: {'precision': 0.625, 'recall': 0.5263157894736842, 'f1': 0.5714285714285714, 'accuracy': 0.8823529411764706}
epoch 2: {'precision': 0.6363636363636364, 'recall': 0.5, 'f1': 0.56, 'accuracy': 0.8901098901098901}
epoch 2: {'precision': 0.72, 'recall': 0.6206896551724138, 'f1': 0.6666666666666666, 'accuracy': 0.901840490797546}
epoch 2: {'precision': 0.4909090909090909, 'recall': 0.4426229508196721, 'f1': 0.46551724137931033, 'accuracy': 0.8009049773755657}
epoch 2: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'accuracy': 0.9210526315789473}
epoch 2: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'accuracy': 1.0}
epoch 2: {'precision': 0.5, 'recall': 0.375, 'f1': 0.42857142857142855, 'accuracy': 0.9347826086956522}
epoch 2: {'precision': 0.5, 'recall': 0.4166666666666667, 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


epoch 2: {'precision': 0.8888888888888888, 'recall': 1.0, 'f1': 0.9411764705882353, 'accuracy': 0.9682539682539683}
epoch 2: {'precision': 0.6666666666666666, 'recall': 0.46153846153846156, 'f1': 0.5454545454545455, 'accuracy': 0.916030534351145}
epoch 2: {'precision': 0.5714285714285714, 'recall': 0.4, 'f1': 0.47058823529411764, 'accuracy': 0.9577464788732394}
epoch 2: {'precision': 0.5, 'recall': 0.2857142857142857, 'f1': 0.36363636363636365, 'accuracy': 0.9146341463414634}
epoch 2: {'precision': 0.6, 'recall': 0.46153846153846156, 'f1': 0.5217391304347826, 'accuracy': 0.9102564102564102}
epoch 2: {'precision': 0.4583333333333333, 'recall': 0.4230769230769231, 'f1': 0.43999999999999995, 'accuracy': 0.8571428571428571}
epoch 2: {'precision': 0.35714285714285715, 'recall': 0.3125, 'f1': 0.3333333333333333, 'accuracy': 0.9184782608695652}
epoch 2: {'precision': 0.5, 'recall': 0.3888888888888889, 'f1': 0.43750000000000006, 'accuracy': 0.8951612903225806}
epoch 2: {'precision': 0.45, 'rec

epoch 2: {'precision': 0.6521739130434783, 'recall': 0.5555555555555556, 'f1': 0.6, 'accuracy': 0.9214876033057852}
epoch 2: {'precision': 0.6875, 'recall': 0.5, 'f1': 0.5789473684210527, 'accuracy': 0.9285714285714286}
epoch 2: {'precision': 0.3333333333333333, 'recall': 0.25, 'f1': 0.28571428571428575, 'accuracy': 0.9096045197740112}
epoch 2: {'precision': 0.5, 'recall': 0.4166666666666667, 'f1': 0.45454545454545453, 'accuracy': 0.9236641221374046}
epoch 2: {'precision': 0.3333333333333333, 'recall': 0.4, 'f1': 0.3636363636363636, 'accuracy': 0.9328859060402684}
epoch 2: {'precision': 0.4666666666666667, 'recall': 0.5384615384615384, 'f1': 0.5, 'accuracy': 0.9259259259259259}
epoch 2: {'precision': 0.5833333333333334, 'recall': 0.4375, 'f1': 0.5, 'accuracy': 0.9042553191489362}
epoch 2: {'precision': 0.8, 'recall': 0.5714285714285714, 'f1': 0.6666666666666666, 'accuracy': 0.9719626168224299}
epoch 2: {'precision': 0.5, 'recall': 0.4, 'f1': 0.4444444444444445, 'accuracy': 0.9595141700

epoch 2: {'precision': 0.7272727272727273, 'recall': 0.6666666666666666, 'f1': 0.6956521739130435, 'accuracy': 0.9455782312925171}
epoch 2: {'precision': 0.5454545454545454, 'recall': 0.46153846153846156, 'f1': 0.4999999999999999, 'accuracy': 0.9205298013245033}
epoch 2: {'precision': 0.5, 'recall': 0.2857142857142857, 'f1': 0.36363636363636365, 'accuracy': 0.9396551724137931}
epoch 2: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'accuracy': 1.0}
epoch 2: {'precision': 1.0, 'recall': 0.4, 'f1': 0.5714285714285715, 'accuracy': 0.9705882352941176}
epoch 2: {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'accuracy': 0.9424460431654677}
epoch 2: {'precision': 0.2, 'recall': 0.25, 'f1': 0.22222222222222224, 'accuracy': 0.964824120603015}
epoch 2: {'precision': 0.8181818181818182, 'recall': 0.75, 'f1': 0.7826086956521738, 'accuracy': 0.9705882352941176}
epoch 2: {'precision': 0.6086956521739131, 'recall': 0.56, 'f1': 0.5833333333333334, 'accuracy': 0.9333333333333333}
epoch 2: {'precision': 0.5

epoch 2: {'precision': 0.625, 'recall': 0.47619047619047616, 'f1': 0.5405405405405405, 'accuracy': 0.918918918918919}
epoch 2: {'precision': 0.6842105263157895, 'recall': 0.52, 'f1': 0.5909090909090909, 'accuracy': 0.9146341463414634}
epoch 2: {'precision': 0.42857142857142855, 'recall': 0.3333333333333333, 'f1': 0.375, 'accuracy': 0.9097744360902256}
epoch 2: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'accuracy': 1.0}
epoch 2: {'precision': 0.7857142857142857, 'recall': 0.7857142857142857, 'f1': 0.7857142857142857, 'accuracy': 0.9767441860465116}
epoch 2: {'precision': 0.9333333333333333, 'recall': 1.0, 'f1': 0.9655172413793104, 'accuracy': 0.978021978021978}
epoch 2: {'precision': 0.8181818181818182, 'recall': 0.8181818181818182, 'f1': 0.8181818181818182, 'accuracy': 0.9642857142857143}
epoch 2: {'precision': 0.8571428571428571, 'recall': 0.75, 'f1': 0.7999999999999999, 'accuracy': 0.9862068965517241}
epoch 2: {'precision': 0.7777777777777778, 'recall': 0.6363636363636364, 'f1': 0.

epoch 2: {'precision': 0.7142857142857143, 'recall': 0.5555555555555556, 'f1': 0.6250000000000001, 'accuracy': 0.8846153846153846}
epoch 2: {'precision': 1.0, 'recall': 0.625, 'f1': 0.7692307692307693, 'accuracy': 0.8571428571428571}
epoch 2: {'precision': 0.8571428571428571, 'recall': 0.75, 'f1': 0.7999999999999999, 'accuracy': 0.9038461538461539}
epoch 2: {'precision': 0.625, 'recall': 0.5, 'f1': 0.5555555555555556, 'accuracy': 0.9433962264150944}
epoch 2: {'precision': 0.4375, 'recall': 0.3888888888888889, 'f1': 0.411764705882353, 'accuracy': 0.6551724137931034}
epoch 2: {'precision': 0.5833333333333334, 'recall': 0.5, 'f1': 0.5384615384615384, 'accuracy': 0.84}
epoch 2: {'precision': 0.23809523809523808, 'recall': 0.20833333333333334, 'f1': 0.22222222222222224, 'accuracy': 0.8135593220338984}
epoch 2: {'precision': 0.631578947368421, 'recall': 0.7058823529411765, 'f1': 0.6666666666666667, 'accuracy': 0.9414893617021277}
epoch 2: {'precision': 0.4, 'recall': 0.35294117647058826, 'f1

In [38]:
results = tokenized_datasets['test'].map(postprocess, batched=True, batch_size=batch_size)


  0%|                                                   | 0/432 [00:00<?, ?ba/s][A


TypeError: postprocess() missing 1 required positional argument: 'labels'