## Data Preprocessing

In [1]:
import random

import numpy as np
import torch

# A single seed value is handed to PyTorch, Python's `random` module and
# NumPy so that all three random number generators start from a known state.
seed = 599

torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

In [2]:
from datasets import load_dataset
from pandas import DataFrame as df

# Load the CoNLL-2003 dataset; the result is a DatasetDict holding the
# "train", "validation" and "test" splits.
raw_datasets = load_dataset("conll2003")

Reusing dataset conll2003 (/root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

In [4]:
# Pick one training sentence (row 4) to illustrate the preprocessing steps below.
example_entry = raw_datasets["train"][4]

In [5]:
# Feature descriptor of the "ner_tags" column; `.feature.names` lists the tag
# strings, indexed by tag id.
ner_feature = raw_datasets["train"].features["ner_tags"]
# The dataset's tag names with one extra "SPECIAL" entry appended at the end.
feature_names = ner_feature.feature.names + ["SPECIAL"]

In [6]:
# Words of the example sentence and the integer NER tag id of each word.
tokens, tag_ids = example_entry["tokens"], example_entry["ner_tags"]
# Lazily translate every tag id into its tag name.
tags = map(feature_names.__getitem__, tag_ids)

# One row per word: (word, tag id, tag name).
df(zip(tokens, tag_ids, tags))

Unnamed: 0,0,1,2
0,Germany,5,B-LOC
1,'s,0,O
2,representative,0,O
3,to,0,O
4,the,0,O
5,European,3,B-ORG
6,Union,4,I-ORG
7,'s,0,O
8,veterinary,0,O
9,committee,0,O


In [7]:
from transformers import AutoTokenizer

# Tokenizer that belongs to the "bert-base-cased" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [8]:
tokenizer.is_fast

True

In [9]:
# Tokenize the example sentence. `tokens` is already a list of words, hence
# `is_split_into_words=True`.
inputs = tokenizer(tokens, is_split_into_words=True, truncation=True)
# inputs = tokenizer(tokens)
# df(inputs.tokens())

In [10]:
# inputs.word_ids()

In [11]:
# df(zip(inputs.tokens(), inputs.word_ids()))

In [12]:
from typing import List, Optional


def align_labels_with_split_tokens(
    labels_for_word_id: List[int],
    word_ids: List[Optional[int]],
) -> List[int]:
    """
    Expand word-level tag ids into one tag id per tokenizer word piece.

    Relies on the tag id layout used in this notebook: every "B-" tag has an
    odd id and the matching "I-" tag is the next (even) id.

    Args:
        labels_for_word_id: tag id of each original word, indexed by word id.
        word_ids: for every word piece, the id of the word it came from, or
            None for a special token such as [CLS] or [SEP].

    Returns:
        A list of tag ids with the same length as `word_ids`:
        * -100 for special tokens,
        * the word's own tag id for the first piece of a word,
        * for later pieces of the same word, the word's tag id with any "B-"
          tag turned into the corresponding "I-" tag.

    For example, if LAMB (B-ORG) is split into LA and ##MB,
    LA is labelled B-ORG and ##MB is labelled I-ORG.
    """
    output = []
    prev_word = None
    for word_id in word_ids:
        if word_id is None:
            # [CLS] or [SEP]
            new_tag_id = -100
        else:
            new_tag_id = labels_for_word_id[word_id]
            # Non-leading part of a word that was split:
            # flip any "B-" (odd label id) into "I-" (by adding 1).
            if word_id == prev_word and new_tag_id % 2 == 1:
                new_tag_id += 1

        output.append(new_tag_id)
        prev_word = word_id

    return output


In [13]:
# Lookup table from label id to a readable name, used only for display below.
new_feature_names = dict(enumerate(feature_names))
# -100 is the id align_labels_with_split_tokens gives to special tokens.
new_feature_names[-100] = "SPECIAL"

In [14]:
# Per-word-piece label ids for the example sentence.
new_labels = align_labels_with_split_tokens(tag_ids, inputs.word_ids())
# Lazily translate each label id (including -100) into its display name.
new_tags = map(new_feature_names.__getitem__, new_labels)

# One row per word piece: (piece, source word id, label id, label name).
df(zip(inputs.tokens(), inputs.word_ids(), new_labels, new_tags))

Unnamed: 0,0,1,2,3
0,[CLS],,-100,SPECIAL
1,Germany,0.0,5,B-LOC
2,',1.0,0,O
3,s,1.0,0,O
4,representative,2.0,0,O
5,to,3.0,0,O
6,the,4.0,0,O
7,European,5.0,3,B-ORG
8,Union,6.0,4,I-ORG
9,',7.0,0,O


### Apply to the entire dataset

In [15]:
# Peek at "tokens".
df(raw_datasets["train"]["tokens"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,103,104,105,106,107,108,109,110,111,112
0,EU,rejects,German,call,to,boycott,British,lamb,.,,...,,,,,,,,,,
1,Peter,Blackburn,,,,,,,,,...,,,,,,,,,,
2,BRUSSELS,1996-08-22,,,,,,,,,...,,,,,,,,,,
3,The,European,Commission,said,on,Thursday,it,disagreed,with,German,...,,,,,,,,,,
4,Germany,'s,representative,to,the,European,Union,'s,veterinary,committee,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14036,on,Friday,:,,,,,,,,...,,,,,,,,,,
14037,Division,two,,,,,,,,,...,,,,,,,,,,
14038,Plymouth,2,Preston,1,,,,,,,...,,,,,,,,,,
14039,Division,three,,,,,,,,,...,,,,,,,,,,


In [16]:
def tokenize_and_align(examples):
    """
    Tokenize a batch of pre-split sentences and attach aligned labels.

    `examples` is a batch with a "tokens" entry (one list of words per
    sentence) and a "ner_tags" entry (one list of tag ids per sentence).
    The tokenizer output is returned with an additional "labels" entry that
    holds, for every sentence, one tag id per word piece as produced by
    align_labels_with_split_tokens.
    """
    encoded = tokenizer(
        examples["tokens"],
        # Truncate to the maximum possible length of the model.
        truncation=True,
        is_split_into_words=True,
    )

    # word_ids(i) tells, for sentence i, which word each word piece came from.
    encoded["labels"] = [
        align_labels_with_split_tokens(sentence_tag_ids, encoded.word_ids(position))
        for position, sentence_tag_ids in enumerate(examples["ner_tags"])
    ]
    return encoded

In [17]:
# Run tokenize_and_align over every split in batches. The original columns are
# removed, so only what tokenize_and_align returns is kept.
tokenized_datasets = raw_datasets.map(
    tokenize_and_align,
    batched=True,
    remove_columns=raw_datasets["train"].column_names
)

Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6/cache-3b2f5f9281210785.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6/cache-1a6416a06fcc9d44.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6/cache-712ad618bec2be3e.arrow


## Data Collation

In [18]:
from transformers import DataCollatorForTokenClassification

# Collator that pads every field of a batch to a common length. With the
# default settings used here the labels are padded with -100 (see the printed
# "labels" tensor below).
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer, 
    # label_pad_token_id=len(feature_names)-1,
)
# Collate the first two training examples to inspect the padded result.
batch = data_collator([tokenized_datasets["train"][index] for index in range(2)])

for key, value in batch.items():
    print(key)
    print(value)
    print()

attention_mask
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])

input_ids
tensor([[  101,  7270, 22961,  1528,  1840,  1106, 21423,  1418,  2495, 12913,
           119,   102],
        [  101,  1943, 14428,   102,     0,     0,     0,     0,     0,     0,
             0,     0]])

labels
tensor([[-100,    3,    0,    7,    0,    0,    0,    7,    0,    0,    0, -100],
        [-100,    1,    2, -100, -100, -100, -100, -100, -100, -100, -100, -100]])

token_type_ids
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])



In [19]:
from datasets import load_metric
# seqeval reports precision / recall / F1 per entity type plus overall figures
# (see the example output a few cells below).
metric = load_metric("seqeval")

In [20]:
# Gold tag ids of the first training sentence.
labels = raw_datasets["train"][0]["ner_tags"]
# Fake predictions: a copy of the gold ids with positions 2 and 6 set to tag id 0.
predictions = list(labels)
predictions[2] = predictions[6] = 0

# Convert both id sequences into tag names, as the metric expects strings.
labels = [feature_names[tag_id] for tag_id in labels]
predictions = [feature_names[tag_id] for tag_id in predictions]

In [21]:
df(zip(labels, predictions))

Unnamed: 0,0,1
0,B-ORG,B-ORG
1,O,O
2,B-MISC,O
3,O,O
4,O,O
5,O,O
6,B-MISC,O
7,O,O
8,O,O


In [22]:
metric.compute(predictions=[predictions], references=[labels])

  _warn_prf(average, modifier, msg_start, len(result))


{'MISC': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 2},
 'ORG': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'overall_precision': 1.0,
 'overall_recall': 0.3333333333333333,
 'overall_f1': 0.5,
 'overall_accuracy': 0.7777777777777778}

In [23]:
# The dataset's own tag names, i.e. without the extra "SPECIAL" entry that
# feature_names carries.
label_names = ner_feature.feature.names
label_names

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

In [24]:
# id2label = {str(index): label for index, label in enumerate(label_names)}
# label2id = {v: k for k, v in id2label.items()}

## Model

In [25]:
import torch.nn as nn
# Targets equal to -100 (the id given to special tokens and used for label
# padding) are excluded from the loss.
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

In [26]:
import torch
from torch import nn
from transformers import BertModel, BertConfig

# Module-level switch read by CustomModel.forward on every call:
# True  -> run the full BERT model and use its first output,
# False -> use only BERT's embedding layer.
include_transformer = False


class CustomModel(nn.Module):
    """
    Token classifier: frozen BERT features -> BiLSTM -> dropout -> linear layer.

    Args:
        bert_config: BertConfig providing `hidden_size`, `hidden_dropout_prob`
            and `num_labels`.
        bilstm_dimension: hidden size of each direction of the BiLSTM.
        pretrained_model_name_or_path: optional checkpoint name or path. When
            given, the BERT weights are loaded from it with
            `BertModel.from_pretrained`. When None (the default, kept for
            backward compatibility) BERT is built from `bert_config` alone,
            which creates the architecture with freshly initialised weights,
            NOT pre-trained ones.
    """

    def __init__(self, bert_config, bilstm_dimension=256, pretrained_model_name_or_path=None):
        super().__init__()

        if pretrained_model_name_or_path is None:
            # NOTE: constructing from a config does not load any checkpoint;
            # the (frozen) BERT weights below are then random.
            self.bert = BertModel(bert_config)
        else:
            self.bert = BertModel.from_pretrained(
                pretrained_model_name_or_path, config=bert_config
            )

        # BERT consists of an embedding layer followed by the transformer
        # encoder. Keep a direct handle on the embedding layer so forward()
        # can use it alone when include_transformer is False.
        self.bert_embedder = self.bert.embeddings

        # Freeze BERT weights in both the embeddings and the transformer.
        for parameter in self.bert.parameters():
            parameter.requires_grad = False

        self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)

        self.bilstm = nn.LSTM(
            bert_config.hidden_size,
            bilstm_dimension,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        # Both LSTM directions are concatenated, hence 2 * bilstm_dimension inputs.
        self.linear = nn.Linear(bilstm_dimension * 2, bert_config.num_labels)

    def get_trainable_parameters(self):
        """Return the BiLSTM and linear-layer parameters (the non-frozen part) as a list."""
        return list(self.bilstm.parameters()) + list(self.linear.parameters())

    def forward(self, **batch):
        """
        Compute per-token class scores for a collated batch.

        Reads "input_ids" and "token_type_ids" from `batch` (and
        "attention_mask" when include_transformer is True); other entries such
        as "labels" are ignored.

        Returns:
            Tensor of shape (batch_size, padded_sentence_length, num_categories).
        """
        if include_transformer:
            bert_embedding = self.bert(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"]
            )[0]
        else:
            # The embedding layer takes no attention mask.
            # NOTE(review): padded positions are therefore embedded and passed
            # through the BiLSTM like real tokens; their outputs only drop out
            # of the loss because their labels are -100.
            bert_embedding = self.bert_embedder(
                input_ids=batch["input_ids"],
                token_type_ids=batch["token_type_ids"]
            )

        bilstm_embedding = self.bilstm(bert_embedding)[0]
        bilstm_embedding = self.dropout(bilstm_embedding)

        output = self.linear(bilstm_embedding)  # (batch_size, padded_sentence_length, num_categories)

        return output

# num_labels is len(feature_names): the dataset's tag names plus the extra
# "SPECIAL" entry.
bert_config = BertConfig.from_pretrained("bert-base-cased", num_labels=len(feature_names))
# NOTE(review): only the configuration is passed on; CustomModel builds its
# BertModel from this config, so confirm whether pre-trained weights were
# meant to be loaded as well.
model = CustomModel(bert_config)

In [27]:
from torch.utils.data import DataLoader

# Training batches: shuffled, 32 sentences each, padded by data_collator.
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=32,
)

# Validation batches: same batch size and collation, original order kept.
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    collate_fn=data_collator,
    batch_size=32,
)

In [28]:
from torch.optim import AdamW

# AdamW with its default hyper-parameters over all model parameters. The BERT
# parameters have requires_grad=False (set in CustomModel.__init__), so
# presumably only the BiLSTM and linear layers are actually updated.
# optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.01)
optimizer = AdamW(model.parameters())

In [29]:
from accelerate import Accelerator

# Hand the model, optimizer and both dataloaders to Accelerate; the prepared
# versions returned here replace the originals for the rest of the notebook.
accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [30]:
from transformers import get_scheduler

num_train_epochs = 10
# One optimizer step is taken per training batch, so steps per epoch equals
# the number of batches in the (prepared) training dataloader.
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch
print(num_training_steps)

# "linear" schedule with no warm-up steps, spanning all training steps.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

4390


In [31]:
def compute_metrics(eval_preds):
    """
    Score a (logits, labels) pair with the seqeval metric.

    The predicted class of every position is the argmax over the last axis of
    the logits. Positions whose label is -100 are dropped from both the
    references and the predictions, and the remaining ids are converted to tag
    names through `label_names`.

    Returns a dict with the overall "precision", "recall", "f1" and "accuracy".
    """
    logits, labels = eval_preds
    predicted_ids = np.argmax(logits, axis=-1)

    # Reference tags: every label row, minus the ignored (-100) positions.
    reference_tags = []
    for label_row in labels:
        reference_tags.append(
            [label_names[label_id] for label_id in label_row if label_id != -100]
        )

    # Predicted tags: keep a prediction only where the paired label is not -100.
    predicted_tags = []
    for predicted_row, label_row in zip(predicted_ids, labels):
        kept = []
        for predicted_id, label_id in zip(predicted_row, label_row):
            if label_id != -100:
                kept.append(label_names[predicted_id])
        predicted_tags.append(kept)

    scores = metric.compute(predictions=predicted_tags, references=reference_tags)

    summary = {}
    summary["precision"] = scores["overall_precision"]
    summary["recall"] = scores["overall_recall"]
    summary["f1"] = scores["overall_f1"]
    summary["accuracy"] = scores["overall_accuracy"]
    return summary

In [32]:
def postprocess(predictions, labels):
    """
    Convert a batch of class scores and label ids into lists of tag names.

    Args:
        predictions: class scores indexed [sentence][token][class]
            (a 3-D tensor, or a sequence of 2-D tensors/arrays).
        labels: gold label ids indexed [sentence][token].

    Returns:
        (truncated_labels, truncated_predictions): two lists with one list of
        tag names per sentence, aligned position by position.

    Positions are kept only when the gold label id is a valid index into
    `label_names`. This drops the -100 ignore index and also any other
    out-of-range value that may have been used for padding, instead of
    failing with an IndexError.

    The predicted class is the argmax over the first len(label_names) scores
    only. The model may emit more scores than there are tag names (the extra
    "SPECIAL" class); such a class has no tag name and can never be a valid
    prediction.
    """
    num_known_labels = len(label_names)

    truncated_labels = []
    truncated_predictions = []

    for sentence_predictions, sentence_labels in zip(predictions, labels):
        # One argmax and one transfer to Python per sentence, rather than one
        # per token.
        sentence_scores = torch.as_tensor(sentence_predictions)
        predicted_ids = sentence_scores[..., :num_known_labels].argmax(dim=-1).tolist()
        label_ids = torch.as_tensor(sentence_labels).tolist()

        sentence_true = []
        sentence_paired = []
        for predicted_id, label_id in zip(predicted_ids, label_ids):
            if 0 <= label_id < num_known_labels:
                sentence_true.append(label_names[label_id])
                sentence_paired.append(label_names[predicted_id])

        truncated_labels.append(sentence_true)
        truncated_predictions.append(sentence_paired)

    return truncated_labels, truncated_predictions

In [33]:
import pprint

In [34]:
from tqdm.auto import tqdm
import torch

# One tick per optimizer step across all epochs.
progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # ---- Training pass ----
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        # Flatten to (tokens, classes) scores and (tokens,) targets; targets
        # equal to -100 are skipped by loss_fct.
        loss = loss_fct(outputs.reshape(-1, len(feature_names)), batch["labels"].flatten())
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # ---- Evaluation pass on the validation split ----
    model.eval()

    for batch in tqdm(eval_dataloader):
        with torch.no_grad():
            predictions = model(**batch)

        labels = batch["labels"]

        # Bring the tensors of all processes to a common sequence length
        # before gathering. The value used to pad the scores is irrelevant
        # because the matching labels are padded with the ignore index.
        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=len(feature_names)-1)
        # Labels must be padded with -100 (the ignore index), not with a real
        # class id, so that postprocess drops the padded positions.
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

        predictions_gathered = accelerator.gather(predictions)
        labels_gathered = accelerator.gather(labels)

        true_labels, prediction_labels = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=prediction_labels, references=true_labels)

    # compute() scores everything added since the previous call.
    results = metric.compute()
    print(f"Epoch {epoch}: ")
    pprint.pprint(results)


accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)

  0%|          | 0/4390 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 0: 
{'LOC': {'f1': 0.5354126055880442,
         'number': 1837,
         'precision': 0.6639806607574537,
         'recall': 0.44855743059335873},
 'MISC': {'f1': 0.129783693843594,
          'number': 922,
          'precision': 0.2785714285714286,
          'recall': 0.08459869848156182},
 'ORG': {'f1': 0.30693556320708315,
         'number': 1341,
         'precision': 0.4508670520231214,
         'recall': 0.232662192393736},
 'PER': {'f1': 0.3926470588235294,
         'number': 1842,
         'precision': 0.3579088471849866,
         'recall': 0.43485342019543977},
 'overall_accuracy': 0.8642196974156708,
 'overall_f1': 0.38776099297604155,
 'overall_precision': 0.4527072567962256,
 'overall_recall': 0.33911141029956243}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 1: 
{'LOC': {'f1': 0.6254150316933292,
         'number': 1837,
         'precision': 0.7018970189701897,
         'recall': 0.5639629831246598},
 'MISC': {'f1': 0.40654997176736307,
          'number': 922,
          'precision': 0.42402826855123676,
          'recall': 0.39045553145336226},
 'ORG': {'f1': 0.38625204582651396,
         'number': 1341,
         'precision': 0.42792384406165007,
         'recall': 0.35197613721103654},
 'PER': {'f1': 0.4615028901734104,
         'number': 1842,
         'precision': 0.4019331453886428,
         'recall': 0.5418023887079262},
 'overall_accuracy': 0.8939630305527757,
 'overall_f1': 0.4835906521555724,
 'overall_precision': 0.484858737946202,
 'overall_recall': 0.4823291820935712}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 2: 
{'LOC': {'f1': 0.6773144040224786,
         'number': 1837,
         'precision': 0.741580310880829,
         'recall': 0.623298856831791},
 'MISC': {'f1': 0.485895221646517,
          'number': 922,
          'precision': 0.5177914110429448,
          'recall': 0.45770065075921906},
 'ORG': {'f1': 0.4443458980044346,
         'number': 1341,
         'precision': 0.5481400437636762,
         'recall': 0.37360178970917224},
 'PER': {'f1': 0.4774614472123369,
         'number': 1842,
         'precision': 0.5261437908496732,
         'recall': 0.4370249728555918},
 'overall_accuracy': 0.8993936539706835,
 'overall_f1': 0.5347603536528618,
 'overall_precision': 0.5981678117843015,
 'overall_recall': 0.48350723662066647}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 3: 
{'LOC': {'f1': 0.7300872502110892,
         'number': 1837,
         'precision': 0.7558275058275058,
         'recall': 0.7060424605334785},
 'MISC': {'f1': 0.5430140329469189,
          'number': 922,
          'precision': 0.6206415620641562,
          'recall': 0.482646420824295},
 'ORG': {'f1': 0.48530519969856817,
         'number': 1341,
         'precision': 0.49047981721249045,
         'recall': 0.4802386278896346},
 'PER': {'f1': 0.549532710280374,
         'number': 1842,
         'precision': 0.5407251707829742,
         'recall': 0.5586319218241043},
 'overall_accuracy': 0.9161123211867899,
 'overall_f1': 0.5892502803899576,
 'overall_precision': 0.604531775535493,
 'overall_recall': 0.5747223157186132}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 4: 
{'LOC': {'f1': 0.7424501424501425,
         'number': 1837,
         'precision': 0.7788404064554693,
         'recall': 0.7093086554164398},
 'MISC': {'f1': 0.5610519724483406,
          'number': 922,
          'precision': 0.6637037037037037,
          'recall': 0.48590021691973967},
 'ORG': {'f1': 0.5211038961038962,
         'number': 1341,
         'precision': 0.5716829919857525,
         'recall': 0.47874720357941836},
 'PER': {'f1': 0.5850599781897492,
         'number': 1842,
         'precision': 0.5876232201533407,
         'recall': 0.5825190010857764},
 'overall_accuracy': 0.9210278448225113,
 'overall_f1': 0.6167808523890025,
 'overall_precision': 0.6543326411176137,
 'overall_recall': 0.5833052844160216}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 5: 
{'LOC': {'f1': 0.7567720090293453,
         'number': 1837,
         'precision': 0.7855887521968365,
         'recall': 0.7299945563418617},
 'MISC': {'f1': 0.6055344546934346,
          'number': 922,
          'precision': 0.6058631921824105,
          'recall': 0.6052060737527115},
 'ORG': {'f1': 0.5534591194968553,
         'number': 1341,
         'precision': 0.5852036575228595,
         'recall': 0.5249813571961223},
 'PER': {'f1': 0.6159026159026159,
         'number': 1842,
         'precision': 0.5889053987122338,
         'recall': 0.6454940282301845},
 'overall_accuracy': 0.9293430270206628,
 'overall_f1': 0.6431478968792401,
 'overall_precision': 0.6482051282051282,
 'overall_recall': 0.6381689666778863}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 6: 
{'LOC': {'f1': 0.7589212748569871,
         'number': 1837,
         'precision': 0.7595419847328244,
         'recall': 0.7583015786608601},
 'MISC': {'f1': 0.6203389830508474,
          'number': 922,
          'precision': 0.6474056603773585,
          'recall': 0.5954446854663774},
 'ORG': {'f1': 0.5480203841630733,
         'number': 1341,
         'precision': 0.5776859504132231,
         'recall': 0.5212527964205816},
 'PER': {'f1': 0.6097035040431267,
         'number': 1842,
         'precision': 0.6054603854389722,
         'recall': 0.6140065146579805},
 'overall_accuracy': 0.9306234179078119,
 'overall_f1': 0.6446761237395316,
 'overall_precision': 0.6548611111111111,
 'overall_recall': 0.6348030966004712}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 7: 
{'LOC': {'f1': 0.7718587145338738,
         'number': 1837,
         'precision': 0.8243661100803958,
         'recall': 0.7256396298312466},
 'MISC': {'f1': 0.6504249291784703,
          'number': 922,
          'precision': 0.6809015421115066,
          'recall': 0.6225596529284165},
 'ORG': {'f1': 0.5416523825848475,
         'number': 1341,
         'precision': 0.501269035532995,
         'recall': 0.5891126025354213},
 'PER': {'f1': 0.6409163558870538,
         'number': 1842,
         'precision': 0.62918410041841,
         'recall': 0.6530944625407166},
 'overall_accuracy': 0.9322128686642727,
 'overall_f1': 0.6560134566862911,
 'overall_precision': 0.6556825823806322,
 'overall_recall': 0.6563446650959273}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 8: 
{'LOC': {'f1': 0.7746362887729893,
         'number': 1837,
         'precision': 0.7812846068660022,
         'recall': 0.7681001633097442},
 'MISC': {'f1': 0.6605922551252849,
          'number': 922,
          'precision': 0.6954436450839329,
          'recall': 0.6290672451193059},
 'ORG': {'f1': 0.5583573487031701,
         'number': 1341,
         'precision': 0.5400696864111498,
         'recall': 0.5779269202087994},
 'PER': {'f1': 0.6515353805073432,
         'number': 1842,
         'precision': 0.6410930110352075,
         'recall': 0.6623235613463626},
 'overall_accuracy': 0.9351415788544181,
 'overall_f1': 0.6687919463087248,
 'overall_precision': 0.6667781866845098,
 'overall_recall': 0.6708179064288119}


  0%|          | 0/102 [00:00<?, ?it/s]

Epoch 9: 
{'LOC': {'f1': 0.7804063456721403,
         'number': 1837,
         'precision': 0.7984054669703873,
         'recall': 0.7632008709853021},
 'MISC': {'f1': 0.6583990980834273,
          'number': 922,
          'precision': 0.6854460093896714,
          'recall': 0.6334056399132321},
 'ORG': {'f1': 0.5742056074766355,
         'number': 1341,
         'precision': 0.5757121439280359,
         'recall': 0.5727069351230425},
 'PER': {'f1': 0.6439595529536987,
         'number': 1842,
         'precision': 0.6315240083507306,
         'recall': 0.6568946796959826},
 'overall_accuracy': 0.9357302643197739,
 'overall_f1': 0.671864406779661,
 'overall_precision': 0.6766814612495732,
 'overall_recall': 0.6671154493436553}


In [38]:
example_labels = []  # (Ground Truth, Predictions), one pair per validation sentence

# Make sure dropout is disabled even if this cell is run on its own.
model.eval()

for batch in tqdm(eval_dataloader):
    with torch.no_grad():
        predictions = model(**batch)

    labels = batch["labels"]

    # The value used to pad the scores is irrelevant: the matching labels are
    # padded with the ignore index and dropped by postprocess.
    predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=len(feature_names)-1)
    # Pad labels with -100 (the ignore index), not with a real class id.
    labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

    predictions_gathered = accelerator.gather(predictions)
    labels_gathered = accelerator.gather(labels)

    true_labels, prediction_labels = postprocess(predictions_gathered, labels_gathered)

    example_labels.extend(zip(true_labels, prediction_labels))


def get_tokenized_sentence(dataset_entry):
    """
    Return the tokenizer's token strings for one dataset entry.

    The entry's "tokens" field (the sentence as a list of words) is run
    through the tokenizer with truncation enabled, and the resulting list of
    token strings is returned.
    """
    encoding = tokenizer(
        dataset_entry["tokens"],
        is_split_into_words=True,
        truncation=True,
    )
    return encoding.tokens()

tokenized_sentences = [get_tokenized_sentence(entry) for entry in raw_datasets["validation"]]

# List of strings, one for each line in the output.
output = []

for sentence, (ground_truth_labels, predicted_labels) in zip(tokenized_sentences, example_labels):
    # sentence[1:-1] skips the first and last token of the tokenized sentence,
    # which have no entry in the label lists.
    output.extend(
        f"{word} {ground_truth_label} {predicted_label}\n"
        for word, ground_truth_label, predicted_label
        in zip(sentence[1:-1], ground_truth_labels, predicted_labels)
    )
    # Blank line between sentences.
    output.append("\n")

# Explicit UTF-8 so the file does not depend on the platform's default
# encoding (tokens may contain non-ASCII characters).
with open("eval-nocrf-no-transformer.txt", "w", encoding="utf-8") as output_file:
    output_file.writelines(output)

  0%|          | 0/102 [00:00<?, ?it/s]