In [None]:
from collections import Counter, defaultdict

import numpy as np
from seqeval.metrics import classification_report
import torch
from torch import nn
from transformers import (
    AutoConfig, AutoTokenizer, #TrainingArguments, <- not M1 compatible
    XLMRobertaConfig)
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import (
    RobertaModel, RobertaPreTrainedModel)

from datasets import (
    DatasetDict, get_dataset_config_names, load_dataset)
import pandas as pd

In [2]:
xtreme_subsets = get_dataset_config_names('xtreme')
print(f'XTREME has {len(xtreme_subsets)} configs')

XTREME has 183 configs


In [3]:
panx_subsets = [s for s in xtreme_subsets if s.startswith('PAN')]
print(len(panx_subsets))
panx_subsets[:3]

40


['PAN-X.af', 'PAN-X.ar', 'PAN-X.bg']

In [4]:
# German
load_dataset('xtreme', name='PAN-X.de')

Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    validation: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 20000
    })
})

In [5]:
langs = ['de', 'fr', 'it', 'en']
fracs = [0.629, 0.229, 0.084, 0.059]
# Return a DatasetDict if key doesn't exist
panx_ch = defaultdict(DatasetDict) 
for lang, frac in zip(langs, fracs):
    # Load monolingual corpus
    ds = load_dataset('xtreme', name=f'PAN-X.{lang}')
    # Shuffle and downsample ea split according to prop. spoken
    for split in ds:
        panx_ch[lang][split] = (
            ds[split]
            .shuffle(seed=0)
            .select(range(int(frac * ds[split].num_rows))))

Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-7318edec81f76aa6.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-cbd29dccd93ef58f.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-4433310f7a3b2793.arrow
Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-4a1996403248b4e2.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-5d4f9e5aefa05972.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-6789784a489dc7d6.arrow
Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-845df155c04c1192.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-4038e5f0ccb7a363.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-e220bc62f3b2de61.arrow
Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-2292d48c0b6f8502.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-56d73ebf7717cb83.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-5117c26f1eb0d215.arrow


In [6]:
pd.DataFrame(
    {lang: [panx_ch[lang]['train'].num_rows] for lang in langs},
    index=['n_training'])

Unnamed: 0,de,fr,it,en
n_training,12580,4580,1680,1180


In [7]:
element = panx_ch['de']['train'][0]
for k, v in element.items():
    print(f'{k}: {v}')

tokens: ['2.000', 'Einwohnern', 'an', 'der', 'Danziger', 'Bucht', 'in', 'der', 'polnischen', 'Woiwodschaft', 'Pommern', '.']
ner_tags: [0, 0, 0, 0, 5, 6, 0, 0, 5, 5, 6, 0]
langs: ['de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de']


In [8]:
for k, v in panx_ch['de']['train'].features.items():
    print(f'{k}: {v}')

tokens: Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)
ner_tags: Sequence(feature=ClassLabel(num_classes=7, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], names_file=None, id=None), length=-1, id=None)
langs: Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)


In [9]:
tags = panx_ch['de']['train'].features['ner_tags'].feature
tags

ClassLabel(num_classes=7, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], names_file=None, id=None)

In [10]:
def create_tag_names(batch):
    return {'ner_tags_str': [tags.int2str(idx) 
                             for idx in batch['ner_tags']]}

In [11]:
panx_de = panx_ch['de'].map(create_tag_names)

Loading cached processed dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-b62b878f6a8a1f1b.arrow
Loading cached processed dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-58ddfb338a44ff7d.arrow
Loading cached processed dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-b059df060587602c.arrow


In [12]:
de_ex = panx_de['train'][1]
pd.DataFrame([de_ex['tokens'], de_ex['ner_tags_str']], 
             index=['tokens', 'tags'])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
tokens,Sie,geht,hinter,Walluf,nahtlos,in,die,Bundesautobahn,66,über,.
tags,O,O,O,B-ORG,O,O,O,B-ORG,I-ORG,O,O


In [13]:
split2freqs = defaultdict(Counter)
for split, dataset in panx_de.items():
    for row in dataset['ner_tags_str']:
        for tag in row:
            if tag.startswith('B'):
                tag_type = tag.split('-')[1]
                split2freqs[split][tag_type] += 1
pd.DataFrame.from_dict(split2freqs, orient='index')

Unnamed: 0,ORG,LOC,PER
validation,2683,3172,2893
test,2573,3180,3071
train,5366,6186,5810


In [14]:
bert_mod_name = 'bert-base-cased'
xlmr_mod_name = 'xlm-roberta-base'
bert_tokenizer = AutoTokenizer.from_pretrained(bert_mod_name)
xlmr_tokenizer = AutoTokenizer.from_pretrained(xlmr_mod_name)

In [15]:
text = 'Jack Sparrow loves New York!'
bert_tokens = bert_tokenizer(text).tokens()
bert_tokens

['[CLS]', 'Jack', 'Spa', '##rrow', 'loves', 'New', 'York', '!', '[SEP]']

In [16]:
xlmr_tokens = xlmr_tokenizer(text).tokens()
xlmr_tokens

['<s>', '▁Jack', '▁Spar', 'row', '▁love', 's', '▁New', '▁York', '!', '</s>']

In [17]:
u'\u2581'

'▁'

In [18]:
''.join(xlmr_tokens).replace(u'\u2581', ' ')

'<s> Jack Sparrow loves New York!</s>'

In [19]:
class XLMRobertaForTokenClassification(RobertaPreTrainedModel):
    config_class = XLMRobertaConfig
    
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels)
        self.init_weights()
        
    def forward(
            self, input_ids=None, attention_mask=None,
            token_type_ids=None, labels=None, **kwargs):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs)
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.view(-1))
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions)

In [20]:
idx2tag = {idx: tag for idx, tag in enumerate(tags.names)}
tag2idx = {tag: idx for idx, tag in idx2tag.items()}

In [21]:
xlmr_config = AutoConfig.from_pretrained(
    xlmr_mod_name,
    num_labels=tags.num_classes,
    id2label=idx2tag,
    label2id=tag2idx)

In [22]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
xlmr_model = (XLMRobertaForTokenClassification
              .from_pretrained(xlmr_mod_name, config=xlmr_config)
              .to(device))

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForTokenClassification: ['lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'roberta.pooler.dense.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classif

In [23]:
input_ids = xlmr_tokenizer.encode(text, return_tensors='pt')
pd.DataFrame(
    [xlmr_tokens, input_ids[0].numpy()], index=['Tokens', 'Input IDs'])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
Tokens,<s>,▁Jack,▁Spar,row,▁love,s,▁New,▁York,!,</s>
Input IDs,0,21763,37456,15555,5161,7,2356,5753,38,2


In [24]:
outputs = xlmr_model(input_ids.to(device)).logits
predictions = torch.argmax(outputs, dim=-1)
print(f'N. tokens in seq: {len(xlmr_tokens)}')
print(f'Shape out: {outputs.shape}')

N. tokens in seq: 10
Shape out: torch.Size([1, 10, 7])


In [25]:
def tag_text(text, tags, model, tokenizer):
    tokens = tokenizer(text).tokens()
    input_ids = (
        xlmr_tokenizer(text, return_tensors='pt').input_ids.to(device))
    outputs = model(inputs)[0]
    preds = torch.argmax(outputs, dim=2)
    preds = [tag.names[p] for p in preds[0].cpu().numpy()]
    return pd.DataFrame([tokens, preds], index=['Tokens', 'Tags'])

In [26]:
words, labels = de_ex['tokens'], de_ex['ner_tags']

In [27]:
tokenized_input = xlmr_tokenizer(
    de_ex['tokens'], is_split_into_words=True)
tokens = xlmr_tokenizer.convert_ids_to_tokens(
    tokenized_input['input_ids'])
pd.DataFrame([tokens], index=['Tokens'])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18
Tokens,<s>,▁Sie,▁geht,▁hinter,▁Wall,uf,▁na,ht,los,▁in,▁die,▁Bundes,auto,bahn,▁66,▁über,▁,.,</s>


In [28]:
word_ids = tokenized_input.word_ids()
pd.DataFrame([tokens, word_ids], index=['Tokens', 'Word IDs'])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18
Tokens,<s>,▁Sie,▁geht,▁hinter,▁Wall,uf,▁na,ht,los,▁in,▁die,▁Bundes,auto,bahn,▁66,▁über,▁,.,</s>
Word IDs,,0,1,2,3,3,4,4,4,5,6,7,7,7,8,9,10,10,


In [30]:
prev_word_idx = None
label_ids = []

for word_idx in word_ids:
    if word_idx is None or word_idx == prev_word_idx:
        labels.append(-100)
    else:
        label_ids.append(labels[word_idx])
    prev_word_idx = word_idx
    
labels = [idx2tag[lab] if lab != -100 else 'IGN' for lab in label_ids]
index = ['tokens', 'word ids', 'label ids', 'labels']
pd.DataFrame([tokens, word_ids, label_ids, labels], index=index).T

Unnamed: 0,tokens,word ids,label ids,labels
0,<s>,,0.0,O
1,▁Sie,0.0,0.0,O
2,▁geht,1.0,0.0,O
3,▁hinter,2.0,3.0,B-ORG
4,▁Wall,3.0,0.0,O
5,uf,3.0,0.0,O
6,▁na,4.0,0.0,O
7,ht,4.0,3.0,B-ORG
8,los,4.0,4.0,I-ORG
9,▁in,5.0,0.0,O


In [34]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = xlmr_tokenizer(
        examples['tokens'], truncation=True, is_split_into_words=True)
    labels = []
    for i, lab in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        prev_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None or word_idx == prev_word_idx:
                label_ids.append(-100)
            else:
                label_ids.append(lab[word_idx])
            prev_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs

In [35]:
def encode_panx_dataset(corpus):
    return corpus.map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=['langs', 'ner_tags', 'tokens'])

In [36]:
panx_de_encoded = encode_panx_dataset(panx_ch['de'])

  0%|          | 0/7 [00:00<?, ?ba/s]

  0%|          | 0/7 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

### Eval

In [39]:
y_true = [
    ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'],
    ['B-PER', 'I-PER', 'O']]
y_pred = [
    ['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'],
    ['B-PER', 'I-PER', 'O']]
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

        MISC       0.00      0.00      0.00         1
         PER       1.00      1.00      1.00         1

   micro avg       0.50      0.50      0.50         2
   macro avg       0.50      0.50      0.50         2
weighted avg       0.50      0.50      0.50         2



In [41]:
def align_preds(preds, label_ids):
    preds = np.argmax(preds, axis=2)
    batch_size, seq_len = preds.shape
    labels_list, preds_list = [], []
    for batch in range(batch_size):
        ex_labels, ex_preds = [], []
        for seq in rang(seq_len):
            if lable_ids[batch, seq] != 100:
                ex_labels.append(idx2tag[label_ids[batch][seq]])
                ex_preds.append(idx2tag[preds[batch][seq]])
        labels_list.append(ex_labels)
        preds_list.append(ex_preds)
    return preds_list, labels_list

```python
EPOCHS = 3
BATCH = 24
LOG_STEPS = len(panx_de_encoded['train']) // BATCH
mod_name = f'{xlmr_model_name}-finetuned-panx-de'
training_args = TrainingArguments(
    output_dir=mod_name,
    log_level='error',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH,
    per_device_eval_batch_size=BATCH,
    evaluation_strategy='epoch',
    save_steps=1e6,
    weight_decay=0.01,
    disable_tqdm=False,
    logging_steps=LOG_STEPS,
    push_to_hub=False)
```