In [1]:
import os
import json
from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import BertTokenizer
from transformers import BertModel
from sklearn.metrics import classification_report

In [2]:
from dataset import Dataset_NER, ner_collate_fn
from tag_id_converter import Tag_ID_Converter

In [3]:
PRETAINED_MODEL_NAME = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(PRETAINED_MODEL_NAME)
bert = BertModel.from_pretrained(PRETAINED_MODEL_NAME)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
PATH_dir = '../data/en_ner_data'
# PATH_dir = '../0_data/En_NER_POS'
PATH_en_train = os.path.join(PATH_dir, 'prepro_train.json')
PATH_en_test = os.path.join(PATH_dir, 'prepro_test.json')
PATH_en_valid = os.path.join(PATH_dir, 'prepro_valid.json')
# PATH_ko_dev = os.path.join(PATH_dir, 'prepro_valid.json')
# PATH_tag_cnt_dict = os.path.join(PATH_dir, 'prepro_tag_cnt.json')

In [5]:
dataset_train = Dataset_NER(PATH_en_train)
dataset_test = Dataset_NER(PATH_en_test)
dataset_valid = Dataset_NER(PATH_en_valid)

In [6]:
tag_converter = Tag_ID_Converter(PATH_dir, ['prepro_train_tag_list.json', 'prepro_test_tag_list.json', 'prepro_valid_tag_list.json'])

In [7]:
batch_size = 16
partial_collate_fn = partial(ner_collate_fn, tokenizer, tag_converter)

In [8]:
dataloader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=partial_collate_fn
)
dataloader_test = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=partial_collate_fn
)
dataloader_valid = DataLoader(
    dataset_valid,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=partial_collate_fn
)

In [9]:
class Bert_NER(nn.Module):

    def __init__(self, bert, output):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(p = 0.1)
        self.lin = nn.Linear(768, output)
        self.softmax = nn.Softmax(2)

    def forward(self,**kargs):
        emb = self.bert(**kargs)
        e = self.dropout(emb['last_hidden_state'])
        w = self.lin(e)
        return w

In [10]:
tag_len = len(tag_converter.tag_to_id)
model = Bert_NER(bert, tag_len)

In [12]:
CELoss = nn.CrossEntropyLoss(ignore_index=0)
optimizer = AdamW(model.parameters(), lr=1.0e-5)

In [13]:
## loading model 
# loaded_model = torch.load('../model/model.pt')

## loading model parameter
model.load_state_dict(torch.load('../model/model_state_dict.pt'))

<All keys matched successfully>

In [15]:
model.cuda(2)
device = model.bert.device
device

device(type='cuda', index=2)

In [16]:
model.eval()

gold_list = []
pred_list = []

with torch.no_grad():
    for iteration, batch in enumerate(dataloader_test):
        batch_inputs = {k: v.cuda(device) for k, v in list(batch[0].items())}
        batch_labels = batch[1].cuda(device)
        
        output = model(**batch_inputs)
        loss = CELoss(output.view(-1, output.size(-1)), batch_labels.view(-1))
        
        print('loss:', loss.item())
        pred_ids = torch.argmax(output, dim=-1)
        
        for g, p in zip(batch_labels, pred_ids):
            gold_mask = g != tag_converter.pad_id
            
            gold = tag_converter.convert_id_to_tag_list(g[gold_mask].tolist())
            pred = tag_converter.convert_id_to_tag_list(p[gold_mask].tolist())
            gold_list.append(gold)
            pred_list.append(pred)
            
            print(f'GOLD:\t\t{gold}')
            print(f'PREDICTION:\t{pred}')

loss: 9.704825401306152
GOLD:		['O', 'O', 'O', 'O', 'O', 'O', 'O']
PREDICTION:	['B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC']
GOLD:		['O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
PREDICTION:	['B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-MISC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'O', 'I-LOC', 'I-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC']
GOLD:		['B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER']
PREDICTION:	['O', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC']
GOLD:		['B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O']
PREDICTION:	['B-MISC', 'I-ORG', 'I-ORG', 'I-ORG', 'B-LOC', 'B-MISC', 'I-ORG', 'I-ORG', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC', 'B-LOC']
GOLD:		['B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC

In [17]:
gold_list_flat = []
pred_list_flat = []
for g, p in zip(gold_list, pred_list):
    gold_list_flat += g
    pred_list_flat += p

In [18]:
test_tag_list = []

with open ('../data/en_ner_data/prepro_test_tag_list.json', 'r') as f:
	test_tag_list = json.load(f)

test_tag_list

['B-ORG', 'I-PER', 'I-MISC', 'B-MISC', 'I-ORG', 'B-PER', 'O', 'B-LOC', 'I-LOC']

In [19]:
tags = test_tag_list
tags.remove('O')
print(tags)

['B-ORG', 'I-PER', 'I-MISC', 'B-MISC', 'I-ORG', 'B-PER', 'B-LOC', 'I-LOC']


In [20]:

print(classification_report(gold_list_flat, pred_list_flat, digits=5, labels=tags))

              precision    recall  f1-score   support

       B-ORG    0.05321   0.02433   0.03339      1603
       I-PER    0.01180   0.00966   0.01063      3415
      I-MISC    0.00062   0.00118   0.00081       851
      B-MISC    0.01198   0.02878   0.01691       695
       I-ORG    0.06194   0.04025   0.04879      2758
       B-PER    0.00000   0.00000   0.00000      1576
       B-LOC    0.00030   0.00977   0.00058      1638
       I-LOC    0.00460   0.00955   0.00621      1675

   micro avg    0.00357   0.01661   0.00587     14211
   macro avg    0.01806   0.01544   0.01467     14211
weighted avg    0.02206   0.01661   0.01746     14211



In [17]:
def get_chunk_type(tag_name):
    tag_class = tag_name.split('-')[0]
    tag_type = tag_name.split('-')[-1]
    return tag_class, tag_type

In [18]:
def get_chunks(seq):
    default = "O"

    chunks = []
    chunk_type, chunk_start = None, None
    for i, tok in enumerate(seq):
        # End of a chunk 1
        if tok == default and chunk_type is not None:
            # Add a chunk.
            chunk = (chunk_type, chunk_start, i)
            chunks.append(chunk)
            chunk_type, chunk_start = None, None

        # End of a chunk + start of a chunk!
        elif tok != default:
            tok_chunk_class, tok_chunk_type = get_chunk_type(tok)
            if chunk_type is None:
                chunk_type, chunk_start = tok_chunk_type, i
            elif tok_chunk_type != chunk_type or tok_chunk_class == "B":
                chunk = (chunk_type, chunk_start, i)
                chunks.append(chunk)
                chunk_type, chunk_start = tok_chunk_type, i
        else:
            pass

    # end condition
    if chunk_type is not None:
        chunk = (chunk_type, chunk_start, len(seq))
        chunks.append(chunk)

    return chunks

In [19]:
def evaluate_ner_F1(total_answers, total_preds):
    num_match = num_preds = num_answers = 0

    for answers, preds in zip(total_answers, total_preds):

        answer_seg_result = set(get_chunks(answers))
        pred_seg_result = set(get_chunks(preds))

        num_match += len(answer_seg_result & pred_seg_result)
        num_answers += len(answer_seg_result)
        num_preds += len(pred_seg_result)

    precision = 100.0 * num_match / num_preds
    recall = 100.0 * num_match / num_answers
    F1 = 2 * precision * recall / (precision + recall)

    return precision, recall, F1

In [20]:
evaluate_ner_F1(gold_list, pred_list)

(0.5633581796783053, 6.513062409288825, 1.0370177795109552)