In [21]:
from collections import Counter, defaultdict

from torch import nn
from transformers import AutoTokenizer, XLMRobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import (
    RobertaModel, RobertaPreTrainedModel)

from datasets import (
    DatasetDict, get_dataset_config_names, load_dataset)
import pandas as pd

In [2]:
# Fetch the names of every configuration (subset) published under the
# 'xtreme' dataset and report how many there are.
xtreme_subsets = get_dataset_config_names('xtreme')
print(f'XTREME has {len(xtreme_subsets)} configs')

XTREME has 183 configs


In [3]:
# Keep only the configs whose name begins with 'PAN' (the PAN-X subsets).
panx_subsets = [
    config_name
    for config_name in xtreme_subsets
    if config_name.startswith('PAN')]
# Report the count, then preview the first three names as the cell output.
print(len(panx_subsets))
panx_subsets[:3]

40


['PAN-X.af', 'PAN-X.ar', 'PAN-X.bg']

In [4]:
# German: load the 'PAN-X.de' config of 'xtreme'. The result is not assigned;
# it is only displayed as the cell output to inspect the splits and columns.
load_dataset('xtreme', name='PAN-X.de')

Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    validation: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 20000
    })
})

In [5]:
# Languages to combine, and the fraction of each monolingual corpus to keep
# (the two lists are paired by position).
langs = ['de', 'fr', 'it', 'en']
fracs = [0.629, 0.229, 0.084, 0.059]
# Looking up a language that is not yet present creates an empty DatasetDict.
panx_ch = defaultdict(DatasetDict)
for lang, frac in zip(langs, fracs):
    # Fetch the monolingual PAN-X corpus for this language.
    ds = load_dataset('xtreme', name=f'PAN-X.{lang}')
    for split in ds:
        # Number of rows to keep: the requested fraction of the split,
        # truncated to an integer.
        n_keep = int(frac * ds[split].num_rows)
        # Fixed seed so the same rows are selected on every run.
        shuffled = ds[split].shuffle(seed=0)
        panx_ch[lang][split] = shuffled.select(range(n_keep))

Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-7318edec81f76aa6.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-cbd29dccd93ef58f.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-4433310f7a3b2793.arrow
Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-4a1996403248b4e2.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-5d4f9e5aefa05972.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-6789784a489dc7d6.arrow
Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-845df155c04c1192.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-4038e5f0ccb7a363.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-e220bc62f3b2de61.arrow
Reusing dataset xtreme (/Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-2292d48c0b6f8502.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-56d73ebf7717cb83.arrow
Loading cached shuffled indices for dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-5117c26f1eb0d215.arrow


In [6]:
# One-row table with the number of training examples kept per language.
n_train_by_lang = {
    lang: [panx_ch[lang]['train'].num_rows] for lang in langs}
pd.DataFrame(n_train_by_lang, index=['n_training'])

Unnamed: 0,de,fr,it,en
n_training,12580,4580,1680,1180


In [7]:
# Inspect a single record: take the first row of the German training split
# and print each of its fields on its own line.
element = panx_ch['de']['train'][0]
for k, v in element.items():
    print(f'{k}: {v}')

tokens: ['2.000', 'Einwohnern', 'an', 'der', 'Danziger', 'Bucht', 'in', 'der', 'polnischen', 'Woiwodschaft', 'Pommern', '.']
ner_tags: [0, 0, 0, 0, 5, 6, 0, 0, 5, 5, 6, 0]
langs: ['de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de']


In [8]:
# Print the schema of the German training split: one line per column name
# together with its feature type.
for k, v in panx_ch['de']['train'].features.items():
    print(f'{k}: {v}')

tokens: Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)
ner_tags: Sequence(feature=ClassLabel(num_classes=7, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], names_file=None, id=None), length=-1, id=None)
langs: Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)


In [9]:
# 'ner_tags' is a sequence feature; `.feature` is its element type, which
# holds the mapping between integer tag ids and tag names. Kept in `tags`
# because `create_tag_names` reads it as a global.
tags = panx_ch['de']['train'].features['ner_tags'].feature
tags

ClassLabel(num_classes=7, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], names_file=None, id=None)

In [10]:
def create_tag_names(batch):
    """Translate the integer NER tag ids in `batch` into their string names.

    Relies on the notebook-level `tags` object for the id -> name lookup
    (`tags.int2str`).

    Args:
        batch: mapping with a 'ner_tags' entry that is iterated as a flat
            sequence of tag ids. NOTE(review): despite the name, the only
            visible caller uses `.map(create_tag_names)` without `batched`,
            so this presumably receives one example at a time -- confirm.

    Returns:
        dict with the single key 'ner_tags_str' holding the list of tag
        names, in the same order as the input ids.
    """
    tag_ids = batch['ner_tags']
    tag_names = list(map(tags.int2str, tag_ids))
    return {'ner_tags_str': tag_names}

In [11]:
# Run `create_tag_names` over the German data to add a 'ner_tags_str' column;
# the result is stored under a new name so `panx_ch['de']` is not rebound.
panx_de = panx_ch['de'].map(create_tag_names)

Loading cached processed dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-b62b878f6a8a1f1b.arrow
Loading cached processed dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-58ddfb338a44ff7d.arrow
Loading cached processed dataset at /Users/damiansp/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/2fc6b63c5326cc0d1f73060649612889b3a7ed8a6605c91cecdbd228a7158b17/cache-b059df060587602c.arrow


In [12]:
# Show the second German training example as a two-row table so each token
# sits directly above its tag name.
de_ex = panx_de['train'][1]
token_row = de_ex['tokens']
tag_row = de_ex['ner_tags_str']
pd.DataFrame([token_row, tag_row], index=['tokens', 'tags'])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
tokens,Sie,geht,hinter,Walluf,nahtlos,in,die,Bundesautobahn,66,über,.
tags,O,O,O,B-ORG,O,O,O,B-ORG,I-ORG,O,O


In [13]:
# Count tags per split and per entity type. Only tags starting with 'B' are
# counted; the entity type is the part after the first '-'.
split2freqs = defaultdict(Counter)
for split, dataset in panx_de.items():
    for row in dataset['ner_tags_str']:
        for tag in row:
            if not tag.startswith('B'):
                # Any other tag is skipped.
                continue
            tag_type = tag.split('-')[1]
            split2freqs[split][tag_type] += 1
# One row per split, one column per entity type.
pd.DataFrame.from_dict(split2freqs, orient='index')

Unnamed: 0,ORG,LOC,PER
validation,2683,3172,2893
test,2573,3180,3071
train,5366,6186,5810


In [14]:
# Load two pretrained tokenizers so their output on the same text can be
# compared in the following cells: cased BERT and XLM-RoBERTa (base).
bert_mod_name = 'bert-base-cased'
xlmr_mod_name = 'xlm-roberta-base'
bert_tokenizer = AutoTokenizer.from_pretrained(bert_mod_name)
xlmr_tokenizer = AutoTokenizer.from_pretrained(xlmr_mod_name)

Downloading:   0%|          | 0.00/615 [00:00<?, ?B/s]

In [15]:
# Tokenize a sample sentence with the BERT tokenizer and display the
# resulting token strings, including the special tokens it adds.
text = 'Jack Sparrow loves New York!'
bert_tokens = bert_tokenizer(text).tokens()
bert_tokens

['[CLS]', 'Jack', 'Spa', '##rrow', 'loves', 'New', 'York', '!', '[SEP]']

In [16]:
# Tokenize the same `text` with the XLM-R tokenizer for comparison with the
# BERT tokens.
xlmr_tokens = xlmr_tokenizer(text).tokens()
xlmr_tokens

['<s>', '▁Jack', '▁Spar', 'row', '▁love', 's', '▁New', '▁York', '!', '</s>']

In [17]:
# Display U+2581, the character that prefixes several of the XLM-R tokens
# shown in the previous output.
u'\u2581'

'▁'

In [18]:
# Rebuild a string from the XLM-R tokens: concatenate them and turn every
# U+2581 marker back into a space. The special tokens stay in the result.
''.join(xlmr_tokens).replace(u'\u2581', ' ')

'<s> Jack Sparrow loves New York!</s>'

In [22]:
class XLMRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder followed by a per-token linear classification head.

    The encoder output is passed through dropout and a linear layer that
    produces `config.num_labels` logits per position. When labels are given
    to `forward`, a cross-entropy loss over all positions is returned too.
    """

    # Configuration class associated with this model.
    config_class = XLMRobertaConfig

    def __init__(self, config):
        """Build the encoder body and the classification head.

        Args:
            config: model configuration; `num_labels`, `hidden_size` and
                `hidden_dropout_prob` are read from it here.
        """
        super().__init__(config)
        n_labels = config.num_labels
        self.num_labels = n_labels
        # Encoder body, built without its pooling layer.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Classification head: dropout, then hidden_size -> n_labels.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, n_labels)
        # Inherited weight-initialisation hook.
        self.init_weights()

    def forward(
            self, input_ids=None, attention_mask=None,
            token_type_ids=None, labels=None, **kwargs):
        """Run the encoder and head, optionally computing the loss.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded unchanged
                to the encoder.
            labels: optional target class ids. They are flattened and
                compared with the flattened logits using
                `nn.CrossEntropyLoss` with its default arguments.
            **kwargs: any further keyword arguments for the encoder.

        Returns:
            TokenClassifierOutput carrying `loss` (None when `labels` is
            None), `logits`, and the encoder's `hidden_states` and
            `attentions`.
        """
        encoder_out = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs)
        # First element of the encoder output.
        # NOTE(review): presumably the last hidden state -- confirm.
        token_repr = self.dropout(encoder_out[0])
        logits = self.classifier(token_repr)
        if labels is None:
            loss = None
        else:
            # Collapse all leading dimensions so every position counts as
            # one classification example.
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            criterion = nn.CrossEntropyLoss()
            loss = criterion(flat_logits, flat_labels)
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions)