<a href="https://colab.research.google.com/github/nishikaz/PlayGround/blob/master/GoogleColab/NERwithBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NER with BERT

Reference: [https://medium.com/@yingbiao/ner-with-bert-in-action-936ff275bc73](https://medium.com/@yingbiao/ner-with-bert-in-action-936ff275bc73)

In [7]:
!mkdir -p data && wget -P data/ https://raw.githubusercontent.com/billpku/NLP_In_Action/master/data/ner_dataset.csv

--2021-01-20 02:52:01--  https://raw.githubusercontent.com/billpku/NLP_In_Action/master/data/ner_dataset.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 14159575 (14M) [text/plain]
Saving to: ‘data/ner_dataset.csv’


2021-01-20 02:52:02 (34.1 MB/s) - ‘data/ner_dataset.csv’ saved [14159575/14159575]



## 1. Load data

In [22]:
import pandas as pd

data_path = 'data/'
data_file_address = 'data/ner_dataset.csv'

# Forward-fill missing values so every row carries a sentence id.
df_data = pd.read_csv(data_file_address, sep=',', encoding='latin1').ffill()
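
In this CSV the `Sentence #` column appears to be populated only on the first token of each sentence, which is presumably why the load step forward-fills missing values. A minimal sketch on a hypothetical three-row frame (the toy values are illustrative and not taken from the dataset) shows the effect:

```python
import pandas as pd

# Toy frame (hypothetical values): the sentence id is present only on the
# first token of each sentence, and missing on the rows that follow.
toy = pd.DataFrame({
    "Sentence #": ["Sentence: 1", None, "Sentence: 2"],
    "Word": ["Hello", "world", "Bye"],
})

# Forward fill copies the last seen sentence id down to the following rows.
filled = toy.ffill()
print(filled["Sentence #"].tolist())
```

After the fill, every row can be grouped by its sentence id.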

In [39]:
df_data.head(n=20)

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,Sentence: 1,of,IN,O
2,Sentence: 1,demonstrators,NNS,O
3,Sentence: 1,have,VBP,O
4,Sentence: 1,marched,VBN,O
5,Sentence: 1,through,IN,O
6,Sentence: 1,London,NNP,B-geo
7,Sentence: 1,to,TO,O
8,Sentence: 1,protest,VB,O
9,Sentence: 1,the,DT,O


In [38]:
df_data.Tag.unique()

array(['O', 'B-geo', 'B-gpe', 'B-per', 'I-geo', 'B-org', 'I-org', 'B-tim',
       'B-art', 'I-art', 'I-per', 'I-gpe', 'I-tim', 'B-nat', 'B-eve',
       'I-eve', 'I-nat'], dtype=object)

In [29]:
df_data.Tag.value_counts()

O        887908
B-geo     37644
B-tim     20333
B-org     20143
I-per     17251
B-per     16990
I-org     16784
B-gpe     15870
I-geo      7414
I-tim      6528
B-art       402
B-eve       308
I-art       297
I-eve       253
B-nat       201
I-gpe       198
I-nat        51
Name: Tag, dtype: int64

In [31]:
class SentenceGetter(object):
  """Group the token-level rows of the dataframe into sentences."""

  def __init__(self, data):
    self.n_sent = 1
    self.data = data
    self.empty = False
    # Turn the rows of one sentence into a list of (word, POS, tag) tuples.
    agg_func = lambda s: list(zip(s["Word"].values.tolist(),
                                  s["POS"].values.tolist(),
                                  s["Tag"].values.tolist()))
    self.grouped = self.data.groupby("Sentence #").apply(agg_func)
    self.sentences = [s for s in self.grouped]

  def get_next(self):
    """Return the next sentence, or None once no sentence is left."""
    try:
      s = self.grouped['Sentence: {}'.format(self.n_sent)]
      self.n_sent += 1
      return s
    except KeyError:
      return None

In [37]:
getter = SentenceGetter(df_data)

sentences = [[s[0] for s in sent] for sent in getter.sentences]
labels = [[s[2] for s in sent] for sent in getter.sentences]
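
The grouping that `SentenceGetter` performs can be checked on a small hypothetical frame. The sketch below assumes the same column names as `df_data` and uses made-up rows. It builds the per-sentence lists by iterating over the groups directly instead of calling `apply`, which is an alternative formulation and not the notebook's own code.

```python
import pandas as pd

# Toy rows (hypothetical) with the same columns as df_data.
toy = pd.DataFrame({
    "Sentence #": ["Sentence: 1", "Sentence: 1", "Sentence: 2"],
    "Word": ["London", "calling", "Hello"],
    "POS": ["NNP", "VBG", "UH"],
    "Tag": ["B-geo", "O", "O"],
})

toy_sentences, toy_labels = [], []
# sort=False keeps the sentences in the order they appear in the frame.
for _, group in toy.groupby("Sentence #", sort=False):
    toy_sentences.append(group["Word"].tolist())
    toy_labels.append(group["Tag"].tolist())

print(toy_sentences)
print(toy_labels)
```

With the default `sort=True`, string keys are ordered lexicographically, so `'Sentence: 10'` comes before `'Sentence: 2'`. Passing `sort=False` keeps the file order, which may matter if sentence positions are compared against the original data.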

In [40]:
address = 0
for sentence, label in zip(sentences[address], labels[address]):
  print(label, '\t', sentence)

O 	 Thousands
O 	 of
O 	 demonstrators
O 	 have
O 	 marched
O 	 through
B-geo 	 London
O 	 to
O 	 protest
O 	 the
O 	 war
O 	 in
B-geo 	 Iraq
O 	 and
O 	 demand
O 	 the
O 	 withdrawal
O 	 of
B-gpe 	 British
O 	 troops
O 	 from
O 	 that
O 	 country
O 	 .


## 2. Prepare the data for training

In [42]:
!pip install transformers

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/cd/40/866cbfac4601e0f74c7303d533a9c5d4a53858bd402e08e3e294dd271f25/transformers-4.2.1-py3-none-any.whl (1.8MB)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
Collecting tokenizers==0.9.4
  Downloading https://files.pythonhosted.org/packages/0f/1c/e789a8b12e28be5bc1ce2156cf87cb522b379be9cadc7ad8091a4cc107c4/tokenizers-0.9.4-cp36-cp36m-manylinux2010_x86_64.whl (2.9MB)
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893261 sha256=b8cd66af86cd2

In [58]:
import transformers
import torch

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# Each sentence is already a list of words, hence is_split_into_words=True.
tokenized_sentences = tokenizer(sentences, return_tensors='pt', truncation=True, padding=True, is_split_into_words=True)

In [72]:
label2int, int2label = {}, {}
for i, each_label in enumerate(df_data.Tag.unique()):
  label2int[each_label] = i
  int2label[i] = each_label

In [59]:
class NERDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

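# NOTE: labels still holds tag strings here (one per word, not per token),
# so converting them to a tensor fails until they are mapped to integers.
dataset = NERDataset(tokenized_sentences, labels)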

In [69]:
torch.tensor(dataset.labels[0])

ValueError: ignored
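
The `ValueError` is expected at this point: `labels` still holds tag strings such as `'B-geo'`, and `torch.tensor` only accepts numeric data. The labels also have one entry per word, while the tokenizer may split a word into several sub-tokens and adds special and padding tokens. One common remedy, sketched here as an assumption rather than as the referenced article's method, is to map each tag through `label2int` and align it to the tokens by word index, marking positions that belong to no word with `-100` (the default `ignore_index` of PyTorch's cross-entropy loss). `align_labels` is a hypothetical helper, and the `word_ids` list is written by hand; with a fast tokenizer it would presumably come from `tokenized_sentences.word_ids(batch_index=i)`.

```python
def align_labels(word_ids, word_tags, label2int, ignore_index=-100):
    """Return one integer label per token, given each token's word index."""
    aligned = []
    for word_id in word_ids:
        if word_id is None:
            # Special or padding token: excluded from the loss.
            aligned.append(ignore_index)
        else:
            # Every sub-token inherits the tag of the word it came from.
            aligned.append(label2int[word_tags[word_id]])
    return aligned

# Hypothetical mapping and tags for a two-word sentence.
label2int = {"O": 0, "B-geo": 1}
word_tags = ["O", "B-geo"]
# Hand-written word indices: [CLS], word 0, word 1 split in two, [SEP], [PAD].
word_ids = [None, 0, 1, 1, None, None]

aligned = align_labels(word_ids, word_tags, label2int)
print(aligned)  # [-100, 0, 1, 1, -100, -100]
```

A list built this way for each sentence has the same length as its `input_ids`, so it could be passed to `NERDataset` in place of the raw `labels`. Labelling every sub-token with the word's tag is one design choice; another is to label only the first sub-token and ignore the rest.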