Load and process the data



In [2]:
! pip install transformers
import torch
import json
from transformers import BertTokenizer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Training hyperparameters shared by the cells below.
batch_size = 32       # sentences per mini-batch
learning_rate = 2e-5  # initial learning rate handed to the optimizer
max_len = 180         # fixed length every encoded sequence is padded/truncated to
epochs = 4            # number of passes over the training set

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name raises when no CUDA device is available,
# so only query it when the GPU was actually selected.
print(torch.cuda.get_device_name(0) if dev.type == "cuda" else "cpu")

class DataGetter:
    """Read a JSON-lines NER file and convert it to character-level BIO tags.

    Each non-empty line of the file must be a JSON object with a ``'text'``
    string and a ``'label'`` mapping shaped like
    ``{entity_type: {entity_text: [[start, end], ...]}}``, where ``start`` and
    ``end`` are inclusive character offsets into ``text``
    (e.g. ``{'name': {'叶老桂': [[9, 11]]}}``).
    """

    def __init__(self, path):
        self.path = path
        self.sentence = []            # one list of characters per input line
        self.label = []               # one list of BIO tag strings per input line
        self.tokenized_sentence = []  # never filled in by this class
        self.numeric_label = []       # self.label mapped through self.tag2idx
        self.tag2idx = {}             # BIO tag string -> integer id

    def bio_converter(self):
        """Parse ``self.path`` and build BIO labels for every sentence.

        Every character starts as ``'O'``. For each annotated span the first
        character becomes ``'B-<type>'`` and the remaining characters
        ``'I-<type>'``. All spans listed for an entity are tagged, not just
        the first one. Results are appended to ``self.label`` and
        ``self.sentence``, and ``self.tag2idx`` / ``self.numeric_label`` are
        rebuilt afterwards.

        Returns:
            tuple: ``(self.label, self.sentence)``.

        Raises:
            KeyError: if a record lacks the ``'text'`` or ``'label'`` key.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        # The corpus contains Chinese text; do not rely on the platform default encoding.
        with open(self.path, 'r', encoding='utf-8') as fp:
            for line in fp:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                record = json.loads(line)
                text = record['text']
                target = ['O'] * len(text)  # 'O' is the default tag
                for entity_type, entity_dict in record['label'].items():
                    for spans in entity_dict.values():
                        # An entity text may occur several times in one sentence.
                        for start, end in spans:
                            target[start] = 'B-' + str(entity_type)
                            for i in range(start + 1, end + 1):
                                target[i] = 'I-' + str(entity_type)
                self.label.append(target)
                self.sentence.append(list(text))
        print("数据保存完毕")
        self.convert_labels_to_id()
        return self.label, self.sentence

    def convert_labels_to_id(self):
        """Rebuild ``self.tag2idx`` and store the integer form of ``self.label``."""
        self.count_labels()
        self.numeric_label = [[self.tag2idx.get(l) for l in lab] for lab in self.label]

    def count_labels(self):
        """Collect the distinct BIO tags in ``self.label`` and give each an integer id."""
        tag_values = list(set(tag for tags in self.label for tag in tags))
        # NOTE(review): ids follow set iteration order, which is not guaranteed
        # to be the same from one interpreter run to the next; code that
        # hard-codes a tag id should look it up in tag2idx instead.
        self.tag2idx = {t: i for i, t in enumerate(tag_values)}

    def get_tag2idx(self):
        """Return the tag -> id dictionary built by ``count_labels``."""
        return self.tag2idx

    def get_numeric_labels(self):
        """Return the labels as lists of integer ids (one list per sentence)."""
        return self.numeric_label
    

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
[K     |████████████████████████████████| 778kB 4.6MB/s 
[?25hCollecting tokenizers==0.8.1.rc1
[?25l  Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
[K     |████████████████████████████████| 3.0MB 23.0MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 42.0MB/s 
Collecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl 

In [5]:
# Tokenizer matching the pretrained checkpoint used for the model below.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# Despite its name this is a DataGetter (file parser), not a torch DataLoader.
train_data_loader = DataGetter('/content/train.json')
# bio_converter returns (BIO tag strings, character lists); train_labels is
# overwritten with the integer ids two lines below.
train_labels, train_sentences = train_data_loader.bio_converter()
tag2idx = train_data_loader.get_tag2idx()
train_labels = train_data_loader.get_numeric_labels()
print(tag2idx) # in the recorded run 'O' has id 13 (see the printed dict); ids come from set iteration order, so they may differ on another run

数据保存完毕
{'B-movie': 0, 'I-movie': 1, 'I-scene': 2, 'I-book': 3, 'B-position': 4, 'B-book': 5, 'B-game': 6, 'I-organization': 7, 'B-scene': 8, 'I-company': 9, 'I-game': 10, 'B-government': 11, 'B-company': 12, 'O': 13, 'I-name': 14, 'B-organization': 15, 'I-position': 16, 'I-address': 17, 'I-government': 18, 'B-address': 19, 'B-name': 20}


Customize our dataset to meet BERT's input requirements

In [7]:
class CustomData:
    """Map-style dataset yielding BERT inputs and aligned tag ids for one sentence.

    Args:
        tokenizer: object exposing ``encode_plus`` (a ``BertTokenizer`` in this
            notebook).
        sentences: list of sentences, each a list of single-character tokens.
        labels: list of integer tag-id lists, parallel to ``sentences``.
        max_len: fixed length of every returned tensor.
        pad_label: tag id written at special-token and padding positions.
            Defaults to 13, the id 'O' had in the run recorded in this
            notebook; pass ``tag2idx['O']`` to stay correct if the tag ids
            change.
    """

    def __init__(self, tokenizer, sentences, labels, max_len, pad_label=13):
        self.tokenizer = tokenizer
        self.sentences = sentences
        self.labels = labels
        self.max_len = max_len
        self.pad_label = pad_label
        self.len = len(sentences)

    def __getitem__(self, index):
        """Encode sentence ``index`` and return a dict of three ``torch.long`` tensors.

        Returns:
            dict: ``'ids'`` (token ids), ``'mask'`` (attention mask) and
            ``'tags'`` (tag ids), each of length ``max_len``.
        """
        sentence = self.sentences[index]
        # Use the tokenizer and length given to this instance, not notebook globals.
        inputs = self.tokenizer.encode_plus(
            sentence,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            return_token_type_ids=True,
            truncation=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']

        # Assumes the tokenizer wraps a single sequence as [CLS] tokens [SEP]
        # (BertTokenizer does): position 0 is a special token, so the real
        # labels start at position 1 and at most max_len - 2 of them survive
        # truncation.
        # Build a new list so the labels stored in self.labels are not
        # extended on every access.
        real = list(self.labels[index])[:max(self.max_len - 2, 0)]
        label = [self.pad_label] + real
        label += [self.pad_label] * (self.max_len - len(label))
        label = label[:self.max_len]

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'tags': torch.tensor(label, dtype=torch.long)
        }

    def __len__(self):
        """Number of sentences in the dataset."""
        return self.len

In [8]:
# Wrap the training sentences and their numeric labels in the dataset class above.
training_set = CustomData(tokenizer, train_sentences, train_labels, max_len)
train_params = {'batch_size': batch_size,
                'shuffle': True,
                'num_workers': 0
                }
# The DataLoader keyword arguments are kept together in a dict.
training_loader = DataLoader(training_set, **train_params) # batches the dataset's tensor dicts using the parameters above

Set up the BERT model


In [9]:
from transformers import BertForTokenClassification, AdamW, BertModel, BertConfig
import torch

class BERTClass(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around ``BertForTokenClassification``.

    Args:
        num_labels: size of the tag set. When omitted, ``len(tag2idx)`` is
            used, i.e. the notebook-level tag dictionary must already exist.
    """

    def __init__(self, num_labels=None):
        super().__init__()  # a single parent initialisation is sufficient
        if num_labels is None:
            num_labels = len(tag2idx)
        self.l1 = BertForTokenClassification.from_pretrained(
            'bert-base-chinese',
            num_labels=num_labels,
        )

    def forward(self, ids, masks, labels=None):
        """Run the wrapped model and return its output unchanged.

        Args:
            ids: token-id tensor.
            masks: attention-mask tensor, forwarded as ``attention_mask``.
            labels: optional tag-id tensor; leave it out for inference.

        Returns:
            Whatever ``BertForTokenClassification`` returns; the training loop
            in this notebook reads the loss from element 0 when ``labels`` is
            given.
        """
        return self.l1(ids, attention_mask=masks, labels=labels)

# Build the wrapper (this loads the 'bert-base-chinese' weights) and move it to
# the selected device. The trailing expression also makes the notebook display
# the module structure as the cell output.
model = BERTClass()
model.to(dev)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=624.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=411577189.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-c

BERTClass(
  (l1): BertForTokenClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(21128, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
       

Set up for fine-tuning

In [15]:
from transformers import get_linear_schedule_with_warmup

# Plain Adam over every model parameter.
# NOTE(review): AdamW is imported from transformers in the model cell but is not used here - confirm which optimizer is intended.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate) # set up the optimizer

# Upper bound for the gradient norm; consumed by clip_grad_norm_ in train().
max_grad_norm = 1.0

# Total number of training steps is number of batches * number of epochs.
total_steps = len(training_loader) * epochs
print('total steps: %d'%total_steps)
# Create the learning rate scheduler.
# It is stepped once per batch in train(), so the training loop has to run for
# exactly `epochs` epochs to consume total_steps. No warmup steps are used.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

total steps: 1344


Fit the model for NER

In [11]:
def train(epoch):
    """Run one pass over ``training_loader``, updating ``model`` in place.

    Args:
        epoch: epoch number; only used in the progress message.
    """
    model.train()
    for batch_idx, batch in enumerate(training_loader):
        # Move the three tensors of the batch onto the training device.
        input_ids = batch['ids'].to(dev, dtype=torch.long)
        attention_mask = batch['mask'].to(dev, dtype=torch.long)
        tag_ids = batch['tags'].to(dev, dtype=torch.long)

        # With labels supplied, the first element of the model output is the loss.
        outputs = model(input_ids, attention_mask, labels=tag_ids)
        loss = outputs[0]

        # Report the loss of every 500th batch (including the first).
        if batch_idx % 500 == 0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')

        # Standard update: clear old gradients, backpropagate, clip, then step
        # the optimizer and the per-batch learning-rate schedule.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        scheduler.step()

In [12]:
# Use the configured epoch count rather than a literal: the scheduler's total
# step count is derived from `epochs`, so the loop must run the same number of
# epochs for the learning-rate schedule to finish exactly at the end of training.
for epoch in range(epochs):
  train(epoch)

Epoch: 0, Loss:  3.041417360305786
Epoch: 1, Loss:  0.15226371586322784
Epoch: 2, Loss:  0.14114871621131897
Epoch: 3, Loss:  0.11926872283220291


Validate the model