## Data Preprocessing

In [1]:
seed = 599

import torch
torch.manual_seed(seed)

import random
random.seed(seed)

import numpy as np
np.random.seed(seed)

In [2]:
from datasets import load_dataset
from pandas import DataFrame as df

raw_datasets = load_dataset("conll2003")

Reusing dataset conll2003 (/root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

In [4]:
example_entry = raw_datasets["train"][4]

In [5]:
ner_feature = raw_datasets["train"].features["ner_tags"]
feature_names = ner_feature.feature.names + ["SPECIAL"]

In [6]:
tokens = example_entry["tokens"]
tag_ids = example_entry["ner_tags"]
tags = map(lambda tag_id:feature_names[tag_id], tag_ids)

df(zip(tokens, tag_ids, tags))

Unnamed: 0,0,1,2
0,Germany,5,B-LOC
1,'s,0,O
2,representative,0,O
3,to,0,O
4,the,0,O
5,European,3,B-ORG
6,Union,4,I-ORG
7,'s,0,O
8,veterinary,0,O
9,committee,0,O


In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [8]:
tokenizer.is_fast

True

In [9]:
inputs = tokenizer(tokens, is_split_into_words=True, truncation=True)
# inputs = tokenizer(tokens)
# df(inputs.tokens())

In [10]:
# inputs.word_ids()

In [11]:
# df(zip(inputs.tokens(), inputs.word_ids()))

In [12]:
from typing import List
def align_labels_with_split_tokens(labels_for_word_id: List[int], word_ids: List[int]) -> List[str]:
    """
    Given the map from id of each original word to the orignal label 
    and the word_ids list indicating how the tokenizer split each word, 
    return labels for the tokenized sentence, with word parts 
    padded as "inside".

    For example, if LAMB (B-ORG) is split into LA and ##MB,
    label LA as B-ORG and ##MB as I-ORG. 
    """
    output = []
    prev_word = None
    for word_id in word_ids:
        if word_id is None: 
            # [CLS] or [SEP]
            new_tag_id = -100

        else:
            original_tag_id = labels_for_word_id[word_id]
            if word_id != prev_word:
                # New word.
                # Use the exact same tag id.
                new_tag_id = original_tag_id
            
            else:
                # Non-leading part of a word that was split.
                # Flip any "B-" (odd label id) into "I-" (by adding 1.)
                if (original_tag_id % 2 == 1):
                    new_tag_id = original_tag_id + 1
                else:
                    new_tag_id = original_tag_id

        output.append(new_tag_id)

        prev_word = word_id

    return output


In [13]:
new_feature_names = dict(enumerate(feature_names))
new_feature_names[-100] = "SPECIAL"

In [14]:
new_labels = align_labels_with_split_tokens(tag_ids, inputs.word_ids())

new_tags = map(lambda tag_id: new_feature_names[tag_id], new_labels)

df(zip(inputs.tokens(), inputs.word_ids(), new_labels, new_tags))

Unnamed: 0,0,1,2,3
0,[CLS],,-100,SPECIAL
1,Germany,0.0,5,B-LOC
2,',1.0,0,O
3,s,1.0,0,O
4,representative,2.0,0,O
5,to,3.0,0,O
6,the,4.0,0,O
7,European,5.0,3,B-ORG
8,Union,6.0,4,I-ORG
9,',7.0,0,O


### Apply to the entire dataset

In [15]:
# Peek at "tokens".
df(raw_datasets["train"]["tokens"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,103,104,105,106,107,108,109,110,111,112
0,EU,rejects,German,call,to,boycott,British,lamb,.,,...,,,,,,,,,,
1,Peter,Blackburn,,,,,,,,,...,,,,,,,,,,
2,BRUSSELS,1996-08-22,,,,,,,,,...,,,,,,,,,,
3,The,European,Commission,said,on,Thursday,it,disagreed,with,German,...,,,,,,,,,,
4,Germany,'s,representative,to,the,European,Union,'s,veterinary,committee,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14036,on,Friday,:,,,,,,,,...,,,,,,,,,,
14037,Division,two,,,,,,,,,...,,,,,,,,,,
14038,Plymouth,2,Preston,1,,,,,,,...,,,,,,,,,,
14039,Division,three,,,,,,,,,...,,,,,,,,,,


In [16]:
def tokenize_and_align(examples):
    """
    Given a list of sentences 
    and a list of token ids, output a dataset of the tokenized
    features and the properly-aligned labels. 
    """

    list_of_sentences_of_tokens = examples["tokens"]
    list_of_sentences_of_tag_ids = examples["ner_tags"]

    tokenized_inputs = tokenizer(
        list_of_sentences_of_tokens, 
        truncation=True,  
        # Truncate to the maximum possible length of the model. 
        is_split_into_words=True
    )
    
    list_of_aligned_tags_ids = []

    for (sentence_index, tag_ids) in enumerate(list_of_sentences_of_tag_ids):
        token_ids = tokenized_inputs.word_ids(sentence_index)
        aligned_tag_ids = align_labels_with_split_tokens(tag_ids, token_ids)
        list_of_aligned_tags_ids.append(aligned_tag_ids)

    tokenized_inputs["labels"] = list_of_aligned_tags_ids
    return tokenized_inputs

In [17]:
tokenized_datasets = raw_datasets.map(
    tokenize_and_align,
    batched=True,
    remove_columns=raw_datasets["train"].column_names
)

Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6/cache-3b2f5f9281210785.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6/cache-1a6416a06fcc9d44.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6/cache-712ad618bec2be3e.arrow


## Data Collation

In [18]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer, 
    # label_pad_token_id=len(feature_names)-1,
)
batch = data_collator([tokenized_datasets["train"][index] for index in range(2)])

for key, value in batch.items():
    print(key)
    print(value)
    print()

attention_mask
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])

input_ids
tensor([[  101,  7270, 22961,  1528,  1840,  1106, 21423,  1418,  2495, 12913,
           119,   102],
        [  101,  1943, 14428,   102,     0,     0,     0,     0,     0,     0,
             0,     0]])

labels
tensor([[-100,    3,    0,    7,    0,    0,    0,    7,    0,    0,    0, -100],
        [-100,    1,    2, -100, -100, -100, -100, -100, -100, -100, -100, -100]])

token_type_ids
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])



In [19]:
from datasets import load_metric
metric = load_metric("seqeval")

In [69]:
labels = raw_datasets["train"][0]["ner_tags"]
labels
predictions = labels.copy()
predictions[2] = 0
predictions[6] = 0

labels = list(map(lambda tag_id:feature_names[tag_id], labels))
predictions = list(map(lambda tag_id:feature_names[tag_id], predictions))

In [70]:
df(zip(labels, predictions))

Unnamed: 0,0,1
0,B-ORG,B-ORG
1,O,O
2,B-MISC,O
3,O,O
4,O,O
5,O,O
6,B-MISC,O
7,O,O
8,O,O


In [71]:
metric.compute(predictions=[predictions], references=[labels])

  _warn_prf(average, modifier, msg_start, len(result))


{'MISC': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 2},
 'ORG': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'overall_precision': 1.0,
 'overall_recall': 0.3333333333333333,
 'overall_f1': 0.5,
 'overall_accuracy': 0.7777777777777778}

In [23]:
label_names = ner_feature.feature.names
label_names

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

In [24]:
# id2label = {str(index): label for index, label in enumerate(label_names)}
# label2id = {v: k for k, v in id2label.items()}

## Model

In [133]:
import torch.nn as nn
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

In [134]:
import torch
from torch import nn
from transformers import BertModel, BertConfig

include_transformer = True

class CustomModel(nn.Module):
    def __init__(self, bert_config, bilstm_dimension=256):
        super(CustomModel, self).__init__()
        
        self.bert = BertModel(bert_config)
        
        # While pre-trained BERT includes both an embedder and a
        # context-aware pre-trained transformer,
        # include only the embedder if include_transformer = False.
        bert_layers = list(self.bert.children())
        self.bert_embedder = bert_layers[0]

        # Freeze BERT weights in both the embeddings and the transformer.
        for parameter in self.bert.parameters():
            parameter.requires_grad = False

        self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)

        self.bilstm = nn.LSTM(bert_config.hidden_size, bilstm_dimension, num_layers=1, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(bilstm_dimension * 2, bert_config.num_labels)

    def get_trainable_parameters(self):
        return list(self.bilstm.parameters()) + list(self.linear.parameters())

    def forward(self, **batch):
        if include_transformer:
            bert_embedding = self.bert(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"]
            )[0]
        else:
            bert_embedding = self.bert_embedder(
                input_ids=batch["input_ids"],
                # attention_mask=batch["attention_mask"],
                # The embedder is supposed to ignore paddings.
                token_type_ids=batch["token_type_ids"]
            )
            
        bilstm_embedding = self.bilstm(bert_embedding)[0]
        bilstm_embedding = self.dropout(bilstm_embedding)

        output = self.linear(bilstm_embedding)  # (batch_size, padded_sentence_length, num_categories)

        return output

bert_config = BertConfig.from_pretrained("bert-base-cased", num_labels=len(feature_names))
model = CustomModel(bert_config)

In [135]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=32,
)

eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    collate_fn=data_collator,
    batch_size=32,
)

In [136]:
from torch.optim import AdamW

# optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.01)
optimizer = AdamW(model.parameters())

In [137]:
from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [138]:
from transformers import get_scheduler

num_train_epochs = 10
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch
print(num_training_steps)

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

4390


In [139]:
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }

In [140]:
def postprocess(predictions, labels):
    predictions = predictions #.cpu().clone().numpy()
    labels = labels #.cpu().clone().numpy()

    # Remove ignored index (special tokens) and convert to labels
    truncated_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    # truncated_predictions = [
    #     [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
    #     for prediction, label in zip(predictions, labels)
    # ]
    truncated_predictions = []

    for sentence_predictions, sentence_labels in zip(predictions, labels):
        sentence_paired = []
        for prediction, label in zip(sentence_predictions, sentence_labels):
            prediction_index = torch.argmax(prediction).cpu()
            if label != -100:
                sentence_paired.append(label_names[prediction_index])
        
        truncated_predictions.append(sentence_paired)

    return truncated_labels, truncated_predictions

In [141]:
import pprint

In [142]:
from tqdm.auto import tqdm
import torch

progress_bar = tqdm(range(num_training_steps))
# progress_bar = tqdm(range(102))


for epoch in range(num_train_epochs):
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = loss_fct(outputs.reshape(-1, len(feature_names)), batch["labels"].flatten())
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    model.eval()

    for batch in tqdm(eval_dataloader):
        with torch.no_grad():
            predictions = model(**batch)

        labels = batch["labels"]

        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=len(feature_names)-1)
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=len(feature_names)-1)

        predictions_gathered = accelerator.gather(predictions)
        labels_gathered = accelerator.gather(labels)

        true_labels, prediction_labels = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=prediction_labels, references=true_labels)

    #print(metric)
    #print(prediction_labels)
    results = metric.compute()
    print(
        f"Epoch {epoch}: ")
    pprint.pprint(results)


accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)

  0%|          | 0/4390 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 0: 
{'LOC': {'f1': 0.2589285714285714,
         'number': 3005,
         'precision': 0.613947696139477,
         'recall': 0.16405990016638936},
 'MISC': {'f1': 0.0, 'number': 1416, 'precision': 0.0, 'recall': 0.0},
 'ORG': {'f1': 0.1258301293254107,
         'number': 2334,
         'precision': 0.3415559772296015,
         'recall': 0.07712082262210797},
 'PER': {'f1': 0.13798449612403102,
         'number': 3101,
         'precision': 0.3472041612483745,
         'recall': 0.08610125765881974},
 'overall_accuracy': 0.7970354488025729,
 'overall_f1': 0.15725637808448348,
 'overall_precision': 0.4478323010957599,
 'overall_recall': 0.09537337662337662}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 1: 
{'LOC': {'f1': 0.4501260352898811,
         'number': 1837,
         'precision': 0.6648936170212766,
         'recall': 0.3402286336418073},
 'MISC': {'f1': 0.008298755186721992,
          'number': 922,
          'precision': 0.09523809523809523,
          'recall': 0.004338394793926247},
 'ORG': {'f1': 0.2427017744705209,
         'number': 1341,
         'precision': 0.5221674876847291,
         'recall': 0.1580909768829232},
 'PER': {'f1': 0.2564734895191122,
         'number': 1842,
         'precision': 0.5279187817258884,
         'recall': 0.16938110749185667},
 'overall_accuracy': 0.8225702007417437,
 'overall_f1': 0.2911248579724782,
 'overall_precision': 0.5826174835775644,
 'overall_recall': 0.19404240996297542}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 2: 
{'LOC': {'f1': 0.4952029520295203,
         'number': 1837,
         'precision': 0.7686139747995419,
         'recall': 0.3652694610778443},
 'MISC': {'f1': 0.025974025974025976,
          'number': 922,
          'precision': 0.16455696202531644,
          'recall': 0.014099783080260303},
 'ORG': {'f1': 0.26997245179063356,
         'number': 1341,
         'precision': 0.5168776371308017,
         'recall': 0.18269947800149142},
 'PER': {'f1': 0.2963702963702964,
         'number': 1842,
         'precision': 0.3832902670111972,
         'recall': 0.24158523344191096},
 'overall_accuracy': 0.8381850827103079,
 'overall_f1': 0.3221948645796694,
 'overall_precision': 0.5311171240819482,
 'overall_recall': 0.23123527431841132}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 3: 
{'LOC': {'f1': 0.5351192420777523,
         'number': 1837,
         'precision': 0.6691176470588235,
         'recall': 0.4458356015242243},
 'MISC': {'f1': 0.12600969305331178,
          'number': 922,
          'precision': 0.2468354430379747,
          'recall': 0.08459869848156182},
 'ORG': {'f1': 0.34912718204488774,
         'number': 1341,
         'precision': 0.5271084337349398,
         'recall': 0.2609992542878449},
 'PER': {'f1': 0.39215686274509803,
         'number': 1842,
         'precision': 0.41144901610017887,
         'recall': 0.3745928338762215},
 'overall_accuracy': 0.8605992818037322,
 'overall_f1': 0.39438053547795987,
 'overall_precision': 0.49909817057459416,
 'overall_recall': 0.3259845169976439}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 4: 
{'LOC': {'f1': 0.5372062663185377,
         'number': 1837,
         'precision': 0.6707416462917686,
         'recall': 0.44801306477953184},
 'MISC': {'f1': 0.1733442354865086,
          'number': 922,
          'precision': 0.3521594684385382,
          'recall': 0.11496746203904555},
 'ORG': {'f1': 0.34273858921161826,
         'number': 1341,
         'precision': 0.3863423760523854,
         'recall': 0.307979120059657},
 'PER': {'f1': 0.3877887788778878,
         'number': 1842,
         'precision': 0.39297658862876256,
         'recall': 0.38273615635179153},
 'overall_accuracy': 0.8666921763701654,
 'overall_f1': 0.39620632923642696,
 'overall_precision': 0.4661808244135732,
 'overall_recall': 0.34449680242342645}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 5: 
{'LOC': {'f1': 0.5671242048878473,
         'number': 1837,
         'precision': 0.7365217391304347,
         'recall': 0.46107784431137727},
 'MISC': {'f1': 0.20806794055201697,
          'number': 922,
          'precision': 0.29938900203665986,
          'recall': 0.1594360086767896},
 'ORG': {'f1': 0.38657718120805373,
         'number': 1341,
         'precision': 0.48322147651006714,
         'recall': 0.3221476510067114},
 'PER': {'f1': 0.4146489104116223,
         'number': 1842,
         'precision': 0.4685362517099863,
         'recall': 0.3718783930510315},
 'overall_accuracy': 0.8707099546712191,
 'overall_f1': 0.42479122648153733,
 'overall_precision': 0.5281461095821867,
 'overall_recall': 0.3552675866711545}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 6: 
{'LOC': {'f1': 0.5872340425531916,
         'number': 1837,
         'precision': 0.7364532019704434,
         'recall': 0.48829613500272184},
 'MISC': {'f1': 0.2427921092564492,
          'number': 922,
          'precision': 0.40404040404040403,
          'recall': 0.1735357917570499},
 'ORG': {'f1': 0.3793395427603726,
         'number': 1341,
         'precision': 0.4387855044074437,
         'recall': 0.3340790454884415},
 'PER': {'f1': 0.47179744296816245,
         'number': 1842,
         'precision': 0.4382859804378202,
         'recall': 0.51085776330076},
 'overall_accuracy': 0.8807028904456349,
 'overall_f1': 0.45617306975009325,
 'overall_precision': 0.5115014638226684,
 'overall_recall': 0.4116459104678559}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 7: 
{'LOC': {'f1': 0.5841772151898734,
         'number': 1837,
         'precision': 0.6976568405139834,
         'recall': 0.502449646162221},
 'MISC': {'f1': 0.28062678062678065,
          'number': 922,
          'precision': 0.4087136929460581,
          'recall': 0.21366594360086769},
 'ORG': {'f1': 0.3940552016985138,
         'number': 1341,
         'precision': 0.45759368836291914,
         'recall': 0.3460104399701715},
 'PER': {'f1': 0.44833524684270953,
         'number': 1842,
         'precision': 0.4756394640682095,
         'recall': 0.4239956568946797},
 'overall_accuracy': 0.8816742214634721,
 'overall_f1': 0.4546765356147265,
 'overall_precision': 0.5301501905402376,
 'overall_recall': 0.39801413665432517}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 8: 
{'LOC': {'f1': 0.5936102236421725,
         'number': 1837,
         'precision': 0.7184841453982985,
         'recall': 0.5057158410451824},
 'MISC': {'f1': 0.2599431818181818,
          'number': 922,
          'precision': 0.3765432098765432,
          'recall': 0.1984815618221258},
 'ORG': {'f1': 0.40115321252059305,
         'number': 1341,
         'precision': 0.44802207911683534,
         'recall': 0.36316181953765847},
 'PER': {'f1': 0.47512576858580213,
         'number': 1842,
         'precision': 0.48963133640552997,
         'recall': 0.46145494028230183},
 'overall_accuracy': 0.8849855772060988,
 'overall_f1': 0.46452959028831564,
 'overall_precision': 0.5321599304650152,
 'overall_recall': 0.4121507909794682}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 9: 
{'LOC': {'f1': 0.5984943538268507,
         'number': 1837,
         'precision': 0.7061435973353072,
         'recall': 0.5193249863908547},
 'MISC': {'f1': 0.29715061058344644,
          'number': 922,
          'precision': 0.3967391304347826,
          'recall': 0.23752711496746204},
 'ORG': {'f1': 0.4140261934938741,
         'number': 1341,
         'precision': 0.4775828460038986,
         'recall': 0.36539895600298283},
 'PER': {'f1': 0.4813186813186813,
         'number': 1842,
         'precision': 0.4872080088987764,
         'recall': 0.4755700325732899},
 'overall_accuracy': 0.8870312591982104,
 'overall_f1': 0.47595838410347735,
 'overall_precision': 0.5371271419504972,
 'overall_recall': 0.42729720632783574}


In [193]:
example_labels = []  # (Ground Truth, Predictions)

for batch in tqdm(eval_dataloader):
    with torch.no_grad():
        predictions = model(**batch)

    labels = batch["labels"]

    predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=len(feature_names)-1)
    labels = accelerator.pad_across_processes(labels, dim=1, pad_index=len(feature_names)-1)

    predictions_gathered = accelerator.gather(predictions)
    labels_gathered = accelerator.gather(labels)

    true_labels, prediction_labels = postprocess(predictions_gathered, labels_gathered)
    
    example_labels.extend(zip(true_labels, prediction_labels))


  0%|          | 0/102 [00:00<?, ?it/s]

In [197]:
df(example_labels[0])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,25,26,27,28,29,30,31,32,33,34
0,O,O,O,O,O,B-ORG,I-ORG,I-ORG,I-ORG,I-ORG,...,O,O,O,O,O,O,O,O,O,O
1,O,O,O,O,O,B-PER,I-PER,I-PER,I-MISC,O,...,O,O,O,O,O,O,O,O,O,O


In [207]:
inputs = tokenizer(raw_datasets["validation"][5]["tokens"], is_split_into_words=True, truncation=True)
# inputs.tokens()


In [269]:
def get_tokenized_sentence(dataset_entry):
    """
    Given a data sentence entry (a sentence),
    return a tokenized list of words from that sentence.
    """
    words = dataset_entry["tokens"]
    tokenized_words = tokenizer(words, is_split_into_words=True, truncation=True).tokens()
    return tokenized_words

tokenized_sentences = list(map(get_tokenized_sentence, raw_datasets["validation"]))

# List of strings, one for each line in the output.
output = []

for sentence, (ground_truth_labels, predicted_labels) in zip(tokenized_sentences, example_labels):
    for word, ground_truth_label, predicted_label in zip(sentence[1:-1], ground_truth_labels, predicted_labels):
        line = f"{word} {ground_truth_label} {predicted_label}\n"
        output.append(line)

    output.append("\n")

with open("eval-nocrf-with-transformer.txt", "w") as output_file:
    output_file.writelines(output)

In [259]:
new_labels = align_labels_with_split_tokens(raw_datasets["validation"][3]["ner_tags"], inputs.word_ids())

0

In [265]:
inputs = tokenizer(raw_datasets["validation"][3]["tokens"], is_split_into_words=True, truncation=True)

df(zip(inputs.tokens(), inputs.word_ids(), tokenized_datasets["validation"][3]["labels"]))

Unnamed: 0,0,1,2
0,[CLS],,-100
1,Their,0.0,0
2,stay,1.0,0
3,on,2.0,0
4,top,3.0,0
5,",",4.0,0
6,though,5.0,0
7,",",6.0,0
8,may,7.0,0
9,be,8.0,0


In [264]:
tokenized_datasets["validation"][3]["labels"]

[-100,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 3,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 0,
 -100]