# Experimenting with KB-BERT for NER

## Check the basic pipeline

In [2]:
from transformers import pipeline

In [3]:
# The model weights and the tokenizer are loaded from the same checkpoint name.
ner_checkpoint = 'KB/bert-base-swedish-cased-ner'
nlp = pipeline(
    'ner',
    model=ner_checkpoint,
    tokenizer=ner_checkpoint,
)

# Smoke test on a short Swedish sentence; the returned entity list is the cell output.
nlp('Idag släpper KB tre språkmodeller.')

Downloading: 100%|██████████| 992/992 [00:00<00:00, 440kB/s]
Downloading: 100%|██████████| 499M/499M [00:52<00:00, 9.51MB/s]
Downloading: 100%|██████████| 399k/399k [00:00<00:00, 680kB/s]
Downloading: 100%|██████████| 2.00/2.00 [00:00<00:00, 521B/s]
Downloading: 100%|██████████| 112/112 [00:00<00:00, 50.0kB/s]
Downloading: 100%|██████████| 182/182 [00:00<00:00, 55.7kB/s]
Downloading: 100%|██████████| 3.00/3.00 [00:00<00:00, 749B/s]
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[{'word': 'Idag', 'score': 0.9997400045394897, 'entity': 'TME', 'index': 1},
 {'word': 'KB', 'score': 0.8388986587524414, 'entity': 'ORG', 'index': 3}]

## Use the pre-trained model directly

In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch 

In [43]:
# Tokenizer and token-classification model are both loaded from the same
# checkpoint name.
ner_model_id = 'KB/bert-base-swedish-cased-ner'
tokenizer, model = (
    AutoTokenizer.from_pretrained(ner_model_id),
    AutoModelForTokenClassification.from_pretrained(ner_model_id),
)

In [44]:
# Display the model's mapping from class index to entity label.
model.config.id2label

{0: 'O',
 1: 'OBJ',
 2: 'TME',
 3: 'ORG/PRS',
 4: 'OBJ/ORG',
 5: 'PRS/WRK',
 6: 'WRK',
 7: 'LOC',
 8: 'ORG',
 9: 'PER',
 10: 'LOC/PRS',
 11: 'LOC/ORG',
 12: 'MSR',
 13: 'EVN'}

In [45]:
# Mapping from class index to entity label. It is a dict, despite the name;
# the name is kept so other cells that reference `label_list` keep working.
label_list = model.config.id2label

sequence = "H&M investerar i miljövänliga kläder."

# Encode once; the resulting ids include the special tokens ([CLS] / [SEP]).
inputs = tokenizer.encode(sequence, return_tensors="pt")

# Recover the word pieces directly from the ids that are fed to the model, so
# tokens and predictions line up one-to-one by construction. Re-tokenizing the
# decoded string is not guaranteed to reproduce the same pieces.
tokens = tokenizer.convert_ids_to_tokens(inputs[0].tolist())

# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    # The first element of the model output is indexed as
    # (batch, token, label) below.
    outputs = model(inputs)[0]

# Pick the highest-scoring label index for every token.
predictions = torch.argmax(outputs, dim=2)

print([(token, label_list[prediction]) for token, prediction in zip(tokens, predictions[0].tolist())])

[('[CLS]', 'O'), ('H', 'ORG'), ('&', 'ORG'), ('M', 'ORG'), ('investerar', 'O'), ('i', 'O'), ('miljövän', 'O'), ('##liga', 'O'), ('kläder', 'O'), ('.', 'O'), ('[SEP]', 'O')]


In [46]:
# Shape of the model output: one sequence, one row per token (including the
# special tokens), one score per entry of id2label.
outputs.shape

torch.Size([1, 11, 14])