<a href="https://colab.research.google.com/github/mirjampaales/cool-ml-project/blob/main/named_entity_recognition/predict_xlmr_ner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


## Environment setup

In [1]:
# Fetch the xlm-roberta-ner repository; its model/ and utils/ packages are imported further below.
! git clone https://github.com/mukhal/xlm-roberta-ner.git 

Cloning into 'xlm-roberta-ner'...
remote: Enumerating objects: 312, done.[K
remote: Counting objects: 100% (312/312), done.[K
remote: Compressing objects: 100% (187/187), done.[K
remote: Total 312 (delta 165), reused 245 (delta 118), pack-reused 0[K
Receiving objects: 100% (312/312), 2.89 MiB | 9.71 MiB/s, done.
Resolving deltas: 100% (165/165), done.


In [None]:
# %pip installs into the environment of the running kernel (unlike !pip,
# which may target a different Python interpreter).
%pip install -r xlm-roberta-ner/requirements.txt

Downloading the pretrained XLM-R model

In [5]:
%cd xlm-roberta-ner

# Fetch and unpack the pretrained XLM-R base checkpoint (~489 MB) only when it
# is not already present, so re-running this cell neither fails on mkdir nor
# downloads the archive again.
import os
if not os.path.exists('pretrained_models/xlmr.base/model.pt'):
  ! mkdir -p pretrained_models
  ! wget -P pretrained_models https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz
  ! tar xzvf pretrained_models/xlmr.base.tar.gz  --directory pretrained_models/
  ! rm -f pretrained_models/xlmr.base.tar.gz

/content/xlm-roberta-ner
--2021-12-13 00:29:55--  https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.75.142, 104.22.74.142, 172.67.9.4, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.75.142|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 512274718 (489M) [application/gzip]
Saving to: ‘pretrained_models/xlmr.base.tar.gz’


2021-12-13 00:30:10 (33.6 MB/s) - ‘pretrained_models/xlmr.base.tar.gz’ saved [512274718/512274718]

xlmr.base/
xlmr.base/dict.txt
xlmr.base/sentencepiece.bpe.model
xlmr.base/model.pt


In [81]:
import torch, os
import pandas as pd
from collections import defaultdict
from torch.utils.data import SequentialSampler, DataLoader

from model.xlmr_for_token_classification import XLMRForTokenClassification
from utils.data_utils import InputExample, convert_examples_to_features, InputFeatures, NerProcessor, create_dataset

## Model loading

This notebook assumes a finetuned model that was trained on top of XLM-R base. If a different base model was used, download that pretrained model instead and adjust the `hidden_size` parameter to match it (768 for XLM-R base). Also set the path to the finetuned model's weights in the loading cell below.

A multilingually finetuned XLM-R base model can be found [here](https://drive.google.com/file/d/1vVRnEup8AEoUEp1XehV52zkAr0BEwRry/view?usp=sharing).

In [7]:
# Run on the GPU whenever one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Label names come from the same processor used for training. The classifier
# gets one extra output class: index 0 is reserved for ignored positions
# (special tokens, padding, non-first subwords) and real labels start at 1.
processor = NerProcessor()
label_list = processor.get_labels()
num_labels = 1 + len(label_list)

In [9]:
# Wrap the pretrained XLM-R encoder with a token-classification head.
# hidden_size must match the pretrained encoder (768 for XLM-R base); dropout
# is set to 0 because this notebook only runs inference.
model = XLMRForTokenClassification(
    pretrained_path='pretrained_models/xlmr.base/',
    n_labels=num_labels,
    hidden_size=768,
    dropout_p=0,
    device=device,
)

loading archive file pretrained_models/xlmr.base/
| dictionary: 250001 types


In [None]:
# Path to the finetuned weights. The default points into a mounted Google
# Drive (this notebook does not mount it); override it with the XLMR_NER_MODEL
# environment variable instead of editing the cell.
MODEL_PATH = os.environ.get(
    'XLMR_NER_MODEL',
    '/content/drive/MyDrive/Colab Notebooks/Machine Learning (Fall 2021)/model.pt')
if not os.path.isfile(MODEL_PATH):
  raise FileNotFoundError(
      'Finetuned model not found at {!r}; mount Google Drive or set '
      'XLMR_NER_MODEL to the checkpoint path.'.format(MODEL_PATH))

# NOTE: torch.load unpickles the file, which can execute arbitrary code, so
# only load checkpoints from a source you trust.
state_dict = torch.load(MODEL_PATH, map_location=torch.device(device))
model.load_state_dict(state_dict)
model.to(device);  # trailing ';' suppresses the very long model repr

## Inference methods

In [60]:
def _sentence_to_features(sentence, o_label_id, ignore_id, max_seq_length):
  """Encode one sentence as padded model inputs.

  Returns (feature, words): the InputFeatures for the sentence and the list of
  words that own a prediction slot, i.e. words that produced at least one
  subword and whose first subword survived truncation to max_seq_length.
  """
  budget = max_seq_length - 2  # positions left for word pieces after <s> and </s>
  token_ids, valid, words = [], [], []
  for word in sentence.split(' '):
    if len(token_ids) >= budget:
      break  # this word's first subword (and all later words) would be cut off
    pieces = model.encode_word(word.strip()) if word.strip() else []
    if len(pieces) == 0:
      # e.g. the empty string produced by a double space: it gets no position
      # to predict for, so it must not appear in the returned words either.
      continue
    words.append(word)
    token_ids.extend(pieces)
    # Only the first subword of each word is predicted; the rest are ignored.
    valid.extend([1] + [0] * (len(pieces) - 1))
  token_ids, valid = token_ids[:budget], valid[:budget]
  label_ids = [o_label_id if is_first else ignore_id for is_first in valid]

  # Wrap with <s> (id 0) and </s> (id 2), then right-pad with id 1 (padding).
  token_ids = [0] + token_ids + [2]
  valid = [0] + valid + [0]
  label_ids = [ignore_id] + label_ids + [ignore_id]
  input_mask = [1] * len(token_ids)
  n_pad = max_seq_length - len(token_ids)
  token_ids += [1] * n_pad
  input_mask += [0] * n_pad
  valid += [0] * n_pad
  label_ids += [ignore_id] * n_pad

  feature = InputFeatures(input_ids=token_ids,
                          input_mask=input_mask,
                          label_id=label_ids,
                          valid_ids=valid,
                          label_mask=list(valid))  # same positions: first subwords only
  return feature, words


def _predict_labels(features, batch_size):
  """Run the model over encoded features; return label strings at valid positions."""
  data = create_dataset(features)
  loader = DataLoader(data, sampler=SequentialSampler(data), batch_size=batch_size)
  id_to_label = {i: label for i, label in enumerate(label_list, 1)}

  model.eval()  # turn off dropout
  y_pred = []
  for input_ids, _label_ids, _label_mask, valid_ids in loader:
    with torch.no_grad():
      logits = model(input_ids.to(device), labels=None, labels_mask=None,
                     valid_mask=valid_ids.to(device))
    pred_ids = torch.argmax(logits, dim=2).cpu().numpy()
    # Index on the CPU copy: element-wise indexing of a GPU tensor syncs per element.
    valid_np = valid_ids.cpu().numpy()
    for row_pred, row_valid in zip(pred_ids, valid_np):
      # Class 0 is the training-time "ignore" slot and has no label name;
      # fall back to 'O' instead of raising KeyError if the model emits it.
      y_pred.append([id_to_label.get(int(p), 'O')
                     for p, is_first in zip(row_pred, row_valid) if is_first])
  return y_pred


def predict(sentences, max_seq_length=128, batch_size=16):
  """Predict an NER label for each space-separated word of each sentence.

  Uses the notebook-level ``model``, ``label_list`` and ``device``.

  Args:
    sentences: iterable of raw sentence strings.
    max_seq_length: maximum number of subword tokens per sentence, including
      the <s> and </s> markers. Words whose first subword falls beyond this
      limit get no prediction and are omitted from the output.
    batch_size: number of sentences per forward pass.

  Returns:
    (words, labels): two parallel lists with one entry per sentence.
    ``words[k]`` holds the words of sentence k that received a prediction
    (split on single spaces; empty words and words cut off by truncation are
    omitted) and ``labels[k]`` the predicted label strings, so that
    ``len(words[k]) == len(labels[k])`` always holds.

  Raises:
    ValueError: if max_seq_length leaves no room for <s> and </s>.
  """
  sentences = list(sentences)
  if max_seq_length < 2:
    raise ValueError('max_seq_length must be at least 2 (room for <s> and </s>)')
  if not sentences:
    return [], []

  # Training-time label ids: real labels are numbered from 1, 0 means "ignore".
  label_to_id = {label: i for i, label in enumerate(label_list, 1)}
  ignore_id = 0

  features, kept_words = [], []
  for sentence in sentences:
    feature, words = _sentence_to_features(sentence, label_to_id['O'], ignore_id,
                                           max_seq_length)
    features.append(feature)
    kept_words.append(words)

  y_pred = _predict_labels(features, batch_size)

  # Each kept word owns exactly one valid position, so words and labels line up 1:1.
  assert len(y_pred) == len(kept_words)
  assert all(len(w) == len(p) for w, p in zip(kept_words, y_pred))
  return kept_words, y_pred

In [98]:
def extract_names(sentences, predict_fn=None):
  """Group per-word NER predictions into person, organisation and location names.

  Args:
    sentences: list of raw sentence strings.
    predict_fn: callable with the same contract as ``predict`` (returns two
      parallel lists: words and labels per sentence). Defaults to ``predict``;
      mainly a hook for running without the model, e.g. in tests.

  Returns:
    (per, org, loc): three lists with one entry per sentence; each entry is
    the list of entity strings (words joined by single spaces) of that type,
    in sentence order. Entities of other types (e.g. MISC) are ignored.
  """
  if predict_fn is None:
    predict_fn = predict
  tokens, labels = predict_fn(sentences)
  per, org, loc = [], [], []

  for sent_tokens, sent_labels in zip(tokens, labels):
    spans = []        # (entity_type, [words]) in sentence order
    open_type = None  # type of the entity currently being built, None outside one
    for token, label in zip(sent_tokens, sent_labels):
      prefix, _, ent_type = label.partition('-')
      if prefix == 'I' and ent_type == open_type:
        spans[-1][1].append(token)  # continue the current entity
      elif prefix in ('B', 'I') and ent_type:
        # B- always starts a new entity (so "B-LOC B-LOC" yields two names);
        # an I- without a matching open entity is treated as a start as well.
        spans.append((ent_type, [token]))
        open_type = ent_type
      else:
        open_type = None  # 'O' (or anything unrecognised) closes the entity

    entities = {'PER': [], 'ORG': [], 'LOC': []}
    for ent_type, words in spans:
      if ent_type in entities:  # skip types not reported here, e.g. MISC
        entities[ent_type].append(' '.join(words))

    per.append(entities['PER'])
    org.append(entities['ORG'])
    loc.append(entities['LOC'])

  return per, org, loc

## Prediction

The steps below can be applied to any text file with one sentence per line.

In [100]:
# One sentence per line. Read the lines directly rather than with
# pd.read_csv(sep="\n"): newer pandas versions reject "\n" as a separator
# (ValueError), and the CSV parser would treat quote characters inside a
# sentence as field quoting. Blank lines are skipped.
with open('../test.txt', encoding='utf-8') as fh:
  df = pd.DataFrame({'sentence': [line.strip() for line in fh if line.strip()]})

In [101]:
# Tag every sentence and collect its person / organisation / location names.
per, org, loc = extract_names(list(df['sentence']))

In [102]:
# Attach the extracted entities to the sentences as new columns.
for column, values in (('per', per), ('org', org), ('loc', loc)):
  df[column] = values

In [119]:
# per/org/loc hold Python lists, so they are written as their string form
# (e.g. ['Petser']); parse them back with ast.literal_eval when re-reading.
df.to_csv('../ner.csv', index=False)

In [120]:
# Display the sentences together with their extracted entities.
df

Unnamed: 0,sentence,per,org,loc
0,Main entrance from the Petser monastery,[],[],[Petser]
1,View Tallinn from the Old Kopli Road,[],[],[Old Kopli Road]
2,Photo postcard,[],[],[]
3,Tallinn : Aleksander Nevski Cathedral,[],[],[Aleksander Nevski Cathedral]
4,Reval : Strandpforten installation,[],[],[]
...,...,...,...,...
111,"Talvinen näkymä Antinkadulta (=Lönnrotinkatu),...",[],"[Helsingin Suomalaisen Reaalilyseon, Ressun lu...","[Antinkadulta, Vanhan kirkon, Kirkkotoria]"
112,"Eerikinkatu 6, 4, 2. Taustalla Yrjönkatu 25.",[],[],"[Eerikinkatu, Yrjönkatu]"
113,Vappukulkue ylittämässä Pitkääsiltaa matkalla ...,[],[],"[Pitkääsiltaa, Stadionille.]"
114,Pallopeliä pelataan Kalliolinnassa. Kalliolinn...,[],[],[Kalliolinnassa. Kalliolinnantie]
