# Plan : NER with XLM-RoBERTa

Task: **perform NER for a Swiss customer**

We will:
1. Import XLM-RoBERTa (a cross-lingual language model based on RoBERTa)
2. Fine-tune it for Named Entity Recognition (NER)

Switzerland = 63% DE + 23% FR + 8% IT + 6% EN

As DE is the most widely spoken language, we will fine-tune on DE and rely on zero-shot cross-lingual transfer for FR, IT and EN.

# Dataset : XTREME

WikiANN (the PAN-X subset of XTREME): Wikipedia articles in many languages.
Each article is annotated with LOC, PER and ORG tags in IOB2 format.

In [1]:
from datasets import load_dataset
import torch
import pandas as pd

from collections import defaultdict
from datasets import DatasetDict

In [2]:
# Languages spoken in Switzerland and the share of each one in the corpus we build.
langs = ["de", "fr", "it", "en"]
fracs = [0.629, 0.229, 0.084, 0.059]

# Missing language keys are created on first access as empty DatasetDicts.
panx_ch = defaultdict(DatasetDict)

for lang, frac in zip(langs, fracs):
    # DatasetDict with keys {'train', 'validation', 'test'}, each mapping to a Dataset
    monolingual_ds = load_dataset("xtreme", name=f"PAN-X.{lang}")

    for split, split_ds in monolingual_ds.items():
        # Number of rows to keep so this language matches its spoken ratio.
        n_keep = int(frac * split_ds.num_rows)
        # Shuffle with a fixed seed (reproducible, avoids ordering bias), then downsample.
        panx_ch[lang][split] = split_ds.shuffle(seed=0).select(range(n_keep))


Found cached dataset xtreme (C:/Users/nikit/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.de\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-e5ddf09f1ae095ec.arrow
Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.de\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-25e7e2dd003d0fa6.arrow
Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.de\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-73a95bc0accfea8b.arrow
Found cached dataset xtreme (C:/Users/nikit/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.fr\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-6ff29513007ec78b.arrow
Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.fr\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-c5c9a4fc19dfd7d6.arrow
Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.fr\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-9711ab25936b81b7.arrow
Found cached dataset xtreme (C:/Users/nikit/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.it\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-daa9a1770078307c.arrow
Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.it\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-5e244c05031bab3c.arrow
Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.it\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-497ee15c12bff58d.arrow
Found cached dataset xtreme (C:/Users/nikit/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.en\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-757845faa9fa6949.arrow
Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.en\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-305cefc7ffa49fd9.arrow
Loading cached shuffled indices for dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.en\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-e5ec5e6ba7c1237d.arrow


## Inspect one example

In [3]:
# Show every field of one French training example as a (name, value) pair.
fr_sample = panx_ch['fr']['train'][200]
for field_name, field_value in fr_sample.items():
    print((field_name, field_value))

('tokens', ['**', 'À', "l'ouest", 'par', 'la', 'province', 'de', 'Banten', '.'])
('ner_tags', [0, 0, 0, 0, 0, 5, 6, 6, 0])
('langs', ['fr', 'fr', 'fr', 'fr', 'fr', 'fr', 'fr', 'fr', 'fr'])


In [4]:
# make tags in human-readable format
tags = panx_ch['fr']['train'].features['ner_tags'].feature
print(f'\n{tags=}')

def create_tag_names(batch):
    return {'ner_tags_str': [tags.int2str(idx) for idx in batch['ner_tags']]}

panx_fr = panx_ch['fr'].map(create_tag_names)
fr_example = panx_fr['train'][200]
pd.DataFrame([fr_example['tokens'], fr_example['ner_tags_str']])

Loading cached processed dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.fr\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-8d1477a4eff16eb5.arrow


Loading cached processed dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.fr\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-0832c54610a26671.arrow
Loading cached processed dataset at C:\Users\nikit\.cache\huggingface\datasets\xtreme\PAN-X.fr\1.0.0\29f5d57a48779f37ccb75cb8708d1095448aad0713b425bdc1ff9a4a128a56e4\cache-af085c7a24eed84e.arrow



tags=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], id=None)


Unnamed: 0,0,1,2,3,4,5,6,7,8
0,**,À,l'ouest,par,la,province,de,Banten,.
1,O,O,O,O,O,B-LOC,I-LOC,I-LOC,O


# Create Custom RoBERTa-based Model for Token Classification

In [5]:
import torch.nn as nn

import transformers
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from transformers.models.roberta.modeling_roberta import RobertaModel

from transformers.modeling_outputs import TokenClassifierOutput

class MyXLMRoBERTaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa body with a per-token linear classification head.

    The head is dropout followed by a single linear layer that maps every
    token's hidden state to ``config.num_labels`` logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Model body; the pooling layer is disabled.
        self.roberta = RobertaModel(config, add_pooling_layer=False)

        # Token-classification head.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialise the weights of the modules created above.
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        """Run the body and the head; compute the loss when labels are given.

        Returns a ``TokenClassifierOutput`` holding ``loss`` (``None`` when
        ``labels`` is ``None``), the per-token ``logits`` and the body's
        ``hidden_states`` / ``attentions``.
        """
        body_output = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )

        # First element of the body output is fed to the head for every token.
        sequence_output = body_output[0]
        logits = self.classifier(self.dropout(sequence_output))

        if labels is None:
            loss = None
        else:
            # Flatten so that each token is one classification sample.
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            loss = nn.CrossEntropyLoss()(flat_logits, flat_labels)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=body_output.hidden_states,
            attentions=body_output.attentions,
        )


## Initialize the model

In [17]:
# Two-way mappings between label ids and tag names; both are stored in the model config.
index2tag = {idx: tag for idx, tag in enumerate(tags.names)}
tag2index = {tag: idx for idx, tag in enumerate(tags.names)}


xlmr_model_name = "xlm-roberta-base"
# init xlmr tokenizer
xlmr_tokenizer = transformers.AutoTokenizer.from_pretrained(xlmr_model_name)

# We do not load the model directly from the checkpoint name alone, because
# the number of output classes and the id <-> label mappings must be overridden.
# So we build the config first and pass it to from_pretrained(..) below.
# NOTE: the variable name `xmlr_config` is kept as-is because other cells may reference it.
xmlr_config = transformers.AutoConfig.from_pretrained(xlmr_model_name,
                                                      num_labels=tags.num_classes,
                                                      id2label=index2tag,
                                                      label2id=tag2index)

# Use the GPU when one is available, otherwise fall back to CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# init our model
# MyXLMRoBERTaForTokenClassification.from_pretrained(..) is inherited from RobertaPreTrainedModel
xlmr_model = MyXLMRoBERTaForTokenClassification.from_pretrained(
    xlmr_model_name, config=xmlr_config).to(device)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing MyXLMRoBERTaForTokenClassification: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing MyXLMRoBERTaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MyXLMRoBERTaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MyXLMRoBERTaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias', 'robert

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
Tokens,<s>,▁Jack,▁Spar,row,▁love,s,▁New,▁Your,k,!,</s>
Tags,O,O,O,O,O,O,O,O,O,O,O


## Let's run a dummy inference

In [21]:
text = "Jack Sparrow loves New Yourk!"

def tag_text(tex):
    """Tag every sub-word token of ``tex`` with the model's predicted NER label.

    Parameters
    ----------
    tex : str
        Raw text to tokenize and tag.

    Returns
    -------
    pandas.DataFrame
        Two rows, ``"Tokens"`` and ``"Tags"``, with one column per token.

    Also prints the shape of the logits tensor.
    """
    # Use the argument, not the notebook-level `text` variable.
    # {'input_ids' : [0, 217, 37 .. 2], 'attention_mask' : [1,1...1]}
    tokenizer_output = xlmr_tokenizer(tex, return_tensors="pt")

    # [<s>	▁Jack	▁Spar	row	...	</s>]
    tokens = tokenizer_output.tokens()

    # [[0, 217, 37 .. 2]]
    input_ids = tokenizer_output.input_ids

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # logits, one row of class scores per token
        outputs = xlmr_model(input_ids.to(device))[0]
    print(f'{outputs.shape=}')

    # index of the best class per token, e.g. [[0, 1, 2, 2, 0, 0, 5, 6, 6, 0, 0]]
    predictions = torch.argmax(outputs, dim=2)

    # e.g. [O, B-PER, I-PER, I-PER, O, O, B-LOC, I-LOC, I-LOC, O, O]
    preds = [tags.names[p] for p in predictions[0].cpu().numpy()]

    return pd.DataFrame([tokens, preds], index=["Tokens", "Tags"])

tag_text(text)

outputs.shape=torch.Size([1, 11, 7])


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
Tokens,<s>,▁Jack,▁Spar,row,▁love,s,▁New,▁Your,k,!,</s>
Tags,O,O,O,O,O,O,O,O,O,O,O


## Now, let's prepare Train Batch

### First, let's check how the tokenizer splits words
- the tokenizer splits the ONE word "l'ouest" into FOUR tokens: '▁l', "'", 'ou', 'est'
- only the first token of each word should carry the word's NER label, so the 3 trailing sub-tokens are masked (label = -100) and ignored for NER

In [59]:
# Words of the French sample and their integer NER labels (one label per word).
words, labels = fr_example['tokens'], fr_example['ner_tags']
print(f'len(words) = {len(words)}\n{words=}')
print(f'{labels=}')

# Tokenize pre-split words:
# {input_ids : [0, 96, 25, 796, 525 ... ], attention_mask : [1 1 .. 1] }
tokenized_input = xlmr_tokenizer(words, is_split_into_words=True)

tokens = xlmr_tokenizer.convert_ids_to_tokens(tokenized_input.input_ids)
print(f'\nlen(tokens) = {len(tokens)}\n{tokens=}')

# For each token, the index of the word it came from (None for <s> and </s>):
# [None 0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 7, 8, 8, None]
word_ids = tokenized_input.word_ids()

# Keep the label only on the first token of each word; every other position
# (special tokens and trailing sub-tokens) is masked with -100.
label_ids = []
previous_word_idx = None
for word_idx in word_ids:
    starts_new_word = word_idx is not None and word_idx != previous_word_idx
    label_ids.append(labels[word_idx] if starts_new_word else -100)
    previous_word_idx = word_idx

# Replace the integer labels with readable names; masked positions show as "IGN".
labels = ["IGN" if label_id == -100 else index2tag[label_id] for label_id in label_ids]

pd.DataFrame([tokens, labels])

len(words) = 9
words=['**', 'À', "l'ouest", 'par', 'la', 'province', 'de', 'Banten', '.']
labels=[0, 0, 0, 0, 0, 5, 6, 6, 0]

len(tokens) = 16
tokens=['<s>', '▁**', '▁À', '▁l', "'", 'ou', 'est', '▁par', '▁la', '▁province', '▁de', '▁Ban', 'ten', '▁', '.', '</s>']


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
0,<s>,▁**,▁À,▁l,',ou,est,▁par,▁la,▁province,▁de,▁Ban,ten,▁,.,</s>
1,IGN,O,O,O,IGN,IGN,IGN,O,O,B-LOC,I-LOC,I-LOC,IGN,O,IGN,IGN
