In [1]:
import os
import json
from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import BertTokenizer
from transformers import BertModel
from sklearn.metrics import classification_report

In [2]:
from dataset import Dataset_NER, ner_collate_fn
from tag_id_converter import Tag_ID_Converter

In [3]:
PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)
bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
PATH_dir = '../data/en_ner_data'
# PATH_dir = '../0_data/En_NER_POS'
PATH_en_train = os.path.join(PATH_dir, 'prepro_train.json')
PATH_en_test = os.path.join(PATH_dir, 'prepro_test.json')
PATH_en_valid = os.path.join(PATH_dir, 'prepro_valid.json')
# PATH_ko_dev = os.path.join(PATH_dir, 'prepro_valid.json')
# PATH_tag_cnt_dict = os.path.join(PATH_dir, 'prepro_tag_cnt.json')

In [5]:
dataset_train = Dataset_NER(PATH_en_train)
dataset_test = Dataset_NER(PATH_en_test)
dataset_valid = Dataset_NER(PATH_en_valid)

In [6]:
print(
		len(dataset_train),
		len(dataset_test),
		len(dataset_valid)
	)

14987 3684 3466


In [7]:
tag_converter = Tag_ID_Converter(PATH_dir, ['prepro_train_tag_list.json', 'prepro_test_tag_list.json', 'prepro_valid_tag_list.json'])

In [8]:
tag_converter.id_to_tag

{0: '[PAD]',
 1: 'I-ORG',
 2: 'B-LOC',
 3: 'I-LOC',
 4: 'B-PER',
 5: 'O',
 6: 'B-MISC',
 7: 'I-MISC',
 8: 'B-ORG',
 9: 'I-PER'}

In [34]:

tag_converter.tag_to_id

{'[PAD]': 0,
 'I-ORG': 1,
 'B-LOC': 2,
 'I-LOC': 3,
 'B-PER': 4,
 'O': 5,
 'B-MISC': 6,
 'I-MISC': 7,
 'B-ORG': 8,
 'I-PER': 9}

In [10]:
batch_size = 16
partial_collate_fn = partial(ner_collate_fn, tokenizer, tag_converter)

In [11]:
dataloader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=partial_collate_fn
)
dataloader_test = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=partial_collate_fn
)
dataloader_valid = DataLoader(
    dataset_valid,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=partial_collate_fn
)

In [12]:
class Bert_NER(nn.Module):

    def __init__(self, bert, output):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(p = 0.1)
        self.lin = nn.Linear(768, output)
        self.softmax = nn.Softmax(2)

    def forward(self,**kargs):
        emb = self.bert(**kargs)
        e = self.dropout(emb['last_hidden_state'])
        w = self.lin(e)
        return w

In [13]:
tag_len = len(tag_converter.tag_to_id)
model = Bert_NER(bert, tag_len)

In [14]:
CELoss = nn.CrossEntropyLoss(ignore_index=0)
optimizer = AdamW(model.parameters(), lr=1.0e-5)

In [15]:
model.cuda(1)
device = model.bert.device
device

device(type='cuda', index=1)

In [16]:
train_epoch = 5
grad_accumulation_steps = 8

In [17]:
for epoch in range(train_epoch):
    model.train()

    for iteration, batch in enumerate(dataloader_train):
        batch_inputs = {k: v.cuda(device) for k, v in list(batch[0].items())}
        batch_labels = batch[1].cuda(device)

        output = model(**batch_inputs)
        loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))

        loss.backward()

        if (iteration + 1) % grad_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            print(f'epoch - {epoch}: update step {(iteration + 1) // grad_accumulation_steps}')
            print(f'{iteration + 1} - loss: {loss.item()}')

    print(f'epoch {epoch} END')
  

epoch - 0: update step 1
10 - loss: 2.2721500396728516
epoch - 0: update step 2
20 - loss: 1.8595995903015137
epoch - 0: update step 3
30 - loss: 1.5840789079666138
epoch - 0: update step 4
40 - loss: 1.284036636352539
epoch - 0: update step 5
epoch - 0: update step 6
50 - loss: 1.2129381895065308
epoch - 0: update step 7
60 - loss: 0.9315145611763
epoch - 0: update step 8
70 - loss: 0.9551402926445007
epoch - 0: update step 9
80 - loss: 0.8640445470809937
epoch - 0: update step 10
epoch - 0: update step 11
90 - loss: 0.747524082660675
epoch - 0: update step 12
100 - loss: 0.9073916077613831
epoch - 0: update step 13
110 - loss: 0.8516948223114014
epoch - 0: update step 14
120 - loss: 1.044482946395874
epoch - 0: update step 15
epoch - 0: update step 16
130 - loss: 0.6896594166755676
epoch - 0: update step 17
140 - loss: 0.7908639311790466
epoch - 0: update step 18
150 - loss: 0.711876392364502
epoch - 0: update step 19
160 - loss: 0.9316486716270447
epoch - 0: update step 20
epoch - 0

In [28]:
model.eval()

gold_list = []
pred_list = []

with torch.no_grad():
    for iteration, batch in enumerate(dataloader_test):
        batch_inputs = {k: v.cuda(device) for k, v in list(batch[0].items())}
        batch_labels = batch[1].cuda(device)
        
        output = model(**batch_inputs)
        loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))
        
        print('loss:', loss.item())
        pred_ids = torch.argmax(output, dim=-1)
        
        for g, p in zip(batch_labels, pred_ids):
            gold_mask = g != tag_converter.pad_id
            
            gold = tag_converter.convert_id_to_tag_list(g[gold_mask].tolist())
            pred = tag_converter.convert_id_to_tag_list(p[gold_mask].tolist())
            gold_list.append(gold)
            pred_list.append(pred)
            
            print(f'GOLD:\t\t{gold}')
            print(f'PREDICTION:\t{pred}')

loss: 0.02454856038093567
GOLD:		['O', 'O', 'O', 'O', 'O', 'O', 'O']
PREDICTION:	['O', 'O', 'O', 'O', 'O', 'O', 'O']
GOLD:		['O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
PREDICTION:	['O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
GOLD:		['B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER']
PREDICTION:	['B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER']
GOLD:		['B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O']
PREDICTION:	['B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O']
GOLD:		['B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
PREDICTION:	['B-LOC', 'O', 'O', '

In [29]:
gold_list_flat = []
pred_list_flat = []
for g, p in zip(gold_list, pred_list):
    gold_list_flat += g
    pred_list_flat += p

In [37]:
test_tag_list = []

with open ('../data/en_ner_data/prepro_test_tag_list.json', 'r') as f:
	test_tag_list = json.load(f)

test_tag_list

['B-ORG', 'I-PER', 'I-MISC', 'B-MISC', 'I-ORG', 'B-PER', 'O', 'B-LOC', 'I-LOC']

In [38]:
tags = test_tag_list
tags.remove('O')
print(tags)

['B-ORG', 'I-PER', 'I-MISC', 'B-MISC', 'I-ORG', 'B-PER', 'B-LOC', 'I-LOC']


In [39]:

print(classification_report(gold_list_flat, pred_list_flat, digits=5, labels=tags))

              precision    recall  f1-score   support

       B-ORG    0.89110   0.89333   0.89221      1603
       I-PER    0.95146   0.97013   0.96071      3415
      I-MISC    0.58112   0.64395   0.61093       851
      B-MISC    0.78854   0.83165   0.80952       695
       I-ORG    0.90057   0.91298   0.90673      2758
       B-PER    0.94909   0.95812   0.95358      1576
       B-LOC    0.90958   0.92735   0.91838      1638
       I-LOC    0.86161   0.92179   0.89068      1675

   micro avg    0.88696   0.91211   0.89936     14211
   macro avg    0.85413   0.88241   0.86784     14211
weighted avg    0.88895   0.91211   0.90024     14211



In [None]:
def get_chunk_type(tag_name):
    tag_class = tag_name.split('-')[0]
    tag_type = tag_name.split('-')[-1]
    return tag_class, tag_type

In [23]:
def get_chunks(seq):
    default = "O"

    chunks = []
    chunk_type, chunk_start = None, None
    for i, tok in enumerate(seq):
        # End of a chunk 1
        if tok == default and chunk_type is not None:
            # Add a chunk.
            chunk = (chunk_type, chunk_start, i)
            chunks.append(chunk)
            chunk_type, chunk_start = None, None

        # End of a chunk + start of a chunk!
        elif tok != default:
            tok_chunk_class, tok_chunk_type = get_chunk_type(tok)
            if chunk_type is None:
                chunk_type, chunk_start = tok_chunk_type, i
            elif tok_chunk_type != chunk_type or tok_chunk_class == "B":
                chunk = (chunk_type, chunk_start, i)
                chunks.append(chunk)
                chunk_type, chunk_start = tok_chunk_type, i
        else:
            pass

    # end condition
    if chunk_type is not None:
        chunk = (chunk_type, chunk_start, len(seq))
        chunks.append(chunk)

    return chunks

In [24]:
def evaluate_ner_F1(total_answers, total_preds):
    num_match = num_preds = num_answers = 0

    for answers, preds in zip(total_answers, total_preds):

        answer_seg_result = set(get_chunks(answers))
        pred_seg_result = set(get_chunks(preds))

        num_match += len(answer_seg_result & pred_seg_result)
        num_answers += len(answer_seg_result)
        num_preds += len(pred_seg_result)

    precision = 100.0 * num_match / num_preds
    recall = 100.0 * num_match / num_answers
    F1 = 2 * precision * recall / (precision + recall)

    return precision, recall, F1

In [25]:
evaluate_ner_F1(gold_list, pred_list)

(86.70045438657812, 90.00362844702467, 88.32116788321167)

In [41]:
torch.save(model, '../model/model.pt')

In [42]:
torch.save(model.state_dict(), '../model/model_state_dict.pt')