In [4]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")
sys.path.append("./")
from ner.train import *
from chem_sentencepiece.chem_sentencepiece import ChemSentencePiece
from chem_sentencepiece.char_tokenizer import CharTokenizer
from config import config_dics
import matplotlib.pyplot as plt
from collections import Counter
%matplotlib inline

seed_num = 42
random.seed(seed_num)
np.random.seed(seed_num)
torch.manual_seed(seed_num)
torch.cuda.manual_seed(seed_num)
torch.backends.cudnn.deterministic = True

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
!python3.7 lm/train.py --config BL.contextual.attention

2019/07/02 22:14:00
{'cache_dir': './Repository/Cache/',
 'char_emb_dim': 30,
 'char_hidden_dim': 100,
 'glove_path': './Repository/GloVe/gv.d50',
 'gpu': True,
 'grad_clip': 0,
 'lm_batch_size': 10,
 'lm_dropout': 0.5,
 'lm_epoch': 0,
 'lm_input_path': './Repository/LargeCorpus/large_corpus_2000k.txt',
 'lm_lr': 5,
 'lm_model_dir': './Repository/LanguageModel/',
 'ner_batch_size': 10,
 'ner_dropout': 0.5,
 'ner_epoch': 150,
 'ner_input_dir': './Repository/Chemdner/',
 'ner_lr': 0.01,
 'ner_model_dir': './Repository/NERModel/',
 'number_normalize': True,
 'sp_path': {},
 'sw_emb_dim': 50,
 'sw_hidden_dim': 50,
 'use_modality_attention': True,
 'vocab_dir': './Repository/Vocabulary/',
 'weight_decay': 1e-05,
 'word_emb_dim': 50,
 'word_hidden_dim': 200}
Use Cache data. ./Repository/Cache/large_corpus_2000k.txt.word_documents
Use Cache data. ./Repository/Cache/large_corpus_2000k.txt.Char_documents
PerfectMatch: 305739, CaseMatch: 1631, NotMatch: 100521.
[712]
LanguageModel(
  (word_lstm)

In [5]:
args_config = "BL.contextual.attention"
config_dic = config_dics["BL.contextual.attention"]
pprint.pprint(config_dic)

{'cache_dir': './Repository/Cache/',
 'char_emb_dim': 30,
 'char_hidden_dim': 100,
 'glove_path': './Repository/GloVe/gv.d50',
 'gpu': True,
 'grad_clip': 0,
 'lm_batch_size': 10,
 'lm_dropout': 0.5,
 'lm_epoch': 0,
 'lm_input_path': './Repository/LargeCorpus/large_corpus_2000k.txt',
 'lm_lr': 5,
 'lm_model_dir': './Repository/LanguageModel/',
 'ner_batch_size': 10,
 'ner_dropout': 0.5,
 'ner_epoch': 150,
 'ner_input_dir': './Repository/Chemdner/',
 'ner_lr': 0.01,
 'ner_model_dir': './Repository/NERModel/',
 'number_normalize': True,
 'sp_path': {},
 'sw_emb_dim': 50,
 'sw_hidden_dim': 50,
 'use_modality_attention': True,
 'vocab_dir': './Repository/Vocabulary/',
 'weight_decay': 1e-05,
 'word_emb_dim': 50,
 'word_hidden_dim': 200}


In [6]:
# load sentence piece
sps: dict = {"Char": CharTokenizer()}
for sp_key, sp_path in config_dic["sp_path"].items():
    sps[sp_key] = ChemSentencePiece.load(sp_path)

In [7]:
# load train data
print("=========== Load train data ===========")
train_word_documents, train_label_documents = load_seq_data(os.path.join(config_dic.get("ner_input_dir"), "train.bioes"), config_dic.get("number_normalize"))
valid_word_documents, valid_label_documents = load_seq_data(os.path.join(config_dic.get("ner_input_dir"), "valid.bioes"), config_dic.get("number_normalize"))
test_word_documents, test_label_documents = load_seq_data(os.path.join(config_dic.get("ner_input_dir"), "test.bioes"), config_dic.get("number_normalize"))

train_sw_documents_dicts = {}
valid_sw_documents_dicts = {}
test_sw_documents_dicts = {}
for sp_key, sp in sps.items():
    train_sw_documents_dicts[sp_key] = get_sw_documents(train_word_documents, sp)
    valid_sw_documents_dicts[sp_key] = get_sw_documents(valid_word_documents, sp)
    test_sw_documents_dicts[sp_key] = get_sw_documents(test_word_documents, sp)
    

  0%|          | 0/903738 [00:00<?, ?it/s]



100%|██████████| 903738/903738 [00:02<00:00, 367502.53it/s]
100%|██████████| 898421/898421 [00:02<00:00, 404374.38it/s]
100%|██████████| 776037/776037 [00:02<00:00, 385301.83it/s]


In [8]:
# load vocabulary
print("=========== Build vocabulary ===========")
if os.path.exists(os.path.join(config_dic.get("vocab_dir"), f"{args_config}.word.dic")):
    word_dic = Dictionary.load(os.path.join(config_dic.get("vocab_dir"), f"{args_config}.word.dic"))
    sw_dicts = {}
    for sp_key, sp in sps.items():
        sw_dicts[sp_key] = Dictionary.load(os.path.join(config_dic.get("vocab_dir"), f"{args_config}.{sp_key}.dic"))
else:
    special_token_dict = {PADDING: 0, UNKNOWN: 1, START: 2, END: 3}
    word_dic = Dictionary()
    word_dic.token2id = special_token_dict
    sw_dicts = {}
    for sp_key, sp in sps.items():
        _dic = Dictionary()
        _dic.token2id = special_token_dict
        sw_dicts[sp_key] = _dic
label_dic = Dictionary(train_label_documents)
label_dic.patch_with_special_tokens({PADDING: 0})
label_dic.id2token = {_id: label for label, _id in label_dic.token2id.items()}

# add vocabulary
word_dic.add_documents(train_word_documents)
for sp_key, train_sw_documents in train_sw_documents_dicts.items():
    sw_dicts[sp_key].add_documents(train_sw_documents)



In [15]:
# load GloVe
if config_dic.get("glove_path"):
    print("============== Load Pretrain Word Embeddings ================")
    word2vec = load_pretrain_embeddings(config_dic.get("glove_path"), emb_dim=config_dic.get("word_emb_dim"))
    pretrain_embeddings = build_pretrain_embeddings(word2vec, word_dic, emb_dim=config_dic.get("word_emb_dim"))
else:
    pretrain_embeddings = None

PerfectMatch: 306631, CaseMatch: 2023, NotMatch: 102215.


In [16]:
config_dic["ner_lr"] = 0.015

In [18]:
seq_model = SeqModel(config_dic, len(word_dic.token2id), None, [len(sw_dic.token2id) for sw_dic in sw_dicts.values()], len(label_dic.token2id), pretrain_embeddings)
optimizer = torch.optim.SGD(seq_model.parameters(), lr=config_dic.get("ner_lr"), weight_decay=config_dic.get("weight_decay"), momentum=0)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.99)

print(seq_model)
print(optimizer)

SeqModel(
  (word_lstm): WordLSTM(
    (subword_rep_list): ModuleList(
      (0): SubwordRep(
        (dropout): Dropout(p=0.5)
        (embeddings): Embedding(722, 50)
        (lstm): LSTM(50, 25, batch_first=True, bidirectional=True)
      )
    )
    (dropout): Dropout(p=0.5)
    (word_embedding): Embedding(410869, 50)
    (modality_att): ModalityAttention(
      (drop): Dropout(p=0.5)
      (w_m): Linear(in_features=100, out_features=100, bias=True)
      (sigmoid): Sigmoid()
    )
    (lstm): LSTM(50, 100, batch_first=True, bidirectional=True)
  )
  (hidden2tag): Linear(in_features=200, out_features=8, bias=True)
  (crf): CRF()
)
SGD (
Parameter Group 0
    dampening: 0
    initial_lr: 0.015
    lr: 0.015
    momentum: 0
    nesterov: False
    weight_decay: 1e-05
)


# 学習

In [19]:
## start training
epoch = config_dic.get("ner_epoch")
for epoch_i in range(epoch):
    print("Epoch: %s/%s" %(epoch_i, epoch))
    print(f"Learning Rate: {lr_scheduler.get_lr()}")
    # shuffle
    random_ids = list(range(len(train_word_documents)))
    random.shuffle(random_ids)
    
    #####################  Batch Initialize ############################
    total_loss, batch_ave_loss, right_token, total_token = 0, 0, 0, 0
    batch_size = config_dic.get("ner_batch_size")
    batch_steps = len(train_word_documents) // batch_size + 1
    seq_model.train()
    seq_model.zero_grad()
    optimizer.zero_grad()

    for batch_i in range(batch_steps):
        start_time = time.time()
        batch_ids = random_ids[batch_i * batch_size: (batch_i + 1) * batch_size]
        batch_word_documents = [train_word_documents[i] for i in batch_ids]
        batch_label_documents = [train_label_documents[i] for i in batch_ids]
        word_features = get_word_features(batch_word_documents, word_dic, config_dic.get("gpu"))
        #char_features = get_char_features(batch_word_documents, char_dic, config_dic.get("gpu"))
        sw_features_list = []
        for sp_key, sp in sps.items():
            sw_features_list.append(get_sw_features(batch_word_documents, sw_dicts[sp_key], sp, config_dic.get("gpu")))
        label_features = get_label_features(batch_label_documents, label_dic, config_dic.get("gpu"))
        loss, train_tag_seq = seq_model.neg_log_likelihood_loss(word_features, None, sw_features_list, label_features)
        batch_ave_loss += loss.data
        total_loss += loss.data
        loss.backward()

        optimizer.step()
        seq_model.zero_grad()

        rt, tt = predict_check(train_tag_seq, label_features.get("label_ids"), word_features.get("masks"))
        right_token += rt
        total_token += tt
        if batch_i != 0 and batch_i % 50 == 0:
            if batch_ave_loss > 1e8 or str(loss) == "nan":
                print("Error: Loss Explosion (>1e8)! EXIT...")
                exit(1)
            sys.stdout.flush()
            print(f"""Batch: {batch_i}; Time(sec/batch): {time.time() - start_time:.4f}; Loss: {batch_ave_loss:.4f} Right: {right_token}, Total: {total_token}, Accuracy: {right_token / total_token:.4f}""") 
            batch_ave_loss = 0
    print(f"Total Loss: {total_loss}")

    ################ valid predict check #####################
    print("============== Valid Evaluate ===========")
    true_seqs, pred_seqs = [], []
    right_token, total_token = 0, 0
    batch_steps = len(valid_word_documents) // batch_size + 1
    random_ids = list(range(len(valid_word_documents)))
    seq_model.eval()
    for batch_i in range(batch_steps):
        batch_ids = random_ids[batch_i * batch_size: (batch_i + 1) * batch_size]
        batch_word_documents = [valid_word_documents[i] for i in batch_ids]
        batch_label_documents = [valid_label_documents[i] for i in batch_ids]

        valid_word_features = get_word_features(batch_word_documents, word_dic, config_dic.get("gpu"))
        #valid_char_features = get_char_features(batch_word_documents, char_dic, config_dic.get("gpu"))
        valid_sw_features_list = []
        for sp_key, sp in sps.items():
            valid_sw_features_list.append(get_sw_features(batch_word_documents, sw_dicts[sp_key], sp, config_dic.get("gpu")))
        valid_label_features = get_label_features(batch_label_documents, label_dic, config_dic.get("gpu"))
        valid_tag_seq = seq_model.forward(valid_word_features, None, valid_sw_features_list)
        masks = valid_word_features.get("masks")
        rt, tt = predict_check(valid_tag_seq, valid_label_features.get("label_ids"), masks)
        right_token += rt
        total_token += tt
        ################ evaluate by precision, recall and fscore ###################
        true_seqs.extend([label_dic.id2token.get(int(label_id), label_dic.token2id["O"]) for label_id in valid_label_features.get("label_ids").masked_select(masks)])
        pred_seqs.extend([label_dic.id2token.get(int(label_id), label_dic.token2id["O"]) for label_id in valid_tag_seq.masked_select(masks)])
    precision, recall, fscore = evaluate(true_seqs, pred_seqs)
    print(f"Right: {right_token}, Total: {total_token}, Accuracy: {right_token / total_token:.4f}")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, Fscore: {fscore:.4f}")

Epoch: 0/150
Learning Rate: [0.015]
Batch: 50; Time(sec/batch): 0.2237; Loss: 22661.3145 Right: 11946.0, Total: 13892.0, Accuracy: 0.8599
Batch: 100; Time(sec/batch): 0.2394; Loss: 7527.8320 Right: 24525.0, Total: 27789.0, Accuracy: 0.8825
Batch: 150; Time(sec/batch): 0.1904; Loss: 5537.0254 Right: 37473.0, Total: 41803.0, Accuracy: 0.8964
Batch: 200; Time(sec/batch): 0.2061; Loss: 4238.6865 Right: 50215.0, Total: 55941.0, Accuracy: 0.8976
Batch: 250; Time(sec/batch): 0.1946; Loss: 3401.1235 Right: 63256.0, Total: 70155.0, Accuracy: 0.9017
Batch: 300; Time(sec/batch): 0.1574; Loss: 2748.5903 Right: 76449.0, Total: 84279.0, Accuracy: 0.9071
Batch: 350; Time(sec/batch): 0.3103; Loss: 5522.6191 Right: 89058.0, Total: 98279.0, Accuracy: 0.9062
Batch: 400; Time(sec/batch): 0.2657; Loss: 2560.0706 Right: 101995.0, Total: 112180.0, Accuracy: 0.9092
Batch: 450; Time(sec/batch): 0.1719; Loss: 2730.5583 Right: 115662.0, Total: 127056.0, Accuracy: 0.9103
Batch: 500; Time(sec/batch): 0.3039; Loss:

KeyboardInterrupt: 

# Predict!!!

In [13]:
word_dic.id2token = {v: k for k, v in word_dic.token2id.items()}
char_dic.id2token = {v: k for k, v in char_dic.token2id.items()}
seq_model.load_state_dict(torch.load("./Repository/NERModel/SW2k.4k.8k.2000k.NoLM.model"))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [52]:
word_seqs, original_word_seqs, true_seqs, pred_seqs = [], [], [], []
right_token, total_token = 0, 0
batch_size = config_dic.get("ner_batch_size")
batch_steps = len(valid_word_documents) // batch_size + 1
random_ids = list(range(len(valid_word_documents)))
seq_model.eval()
for batch_i in tqdm(range(batch_steps)):
    batch_ids = random_ids[batch_i * batch_size: (batch_i + 1) * batch_size]
    batch_word_documents = [valid_word_documents[i] for i in batch_ids]
    batch_label_documents = [valid_label_documents[i] for i in batch_ids]

    valid_word_features = get_word_features(batch_word_documents, word_dic, config_dic.get("gpu"))
    valid_char_features = get_char_features(batch_word_documents, char_dic, config_dic.get("gpu"))
    
    valid_sw_features_list = []
    for sp_key, sp in sps.items():
        valid_sw_features_list.append(get_sw_features(batch_word_documents, sw_dicts[sp_key], sp, config_dic.get("gpu")))
    
    valid_label_features = get_label_features(batch_label_documents, label_dic, config_dic.get("gpu"))
    valid_tag_seq = seq_model.forward(valid_word_features, valid_char_features, valid_sw_features_list)
    rt, tt = predict_check(valid_tag_seq, valid_label_features.get("label_ids"), valid_word_features.get("masks"))
    right_token += rt
    total_token += tt
    ################ evaluate by precision, recall and fscore ###################
    masks = valid_word_features.get("masks")
    #word_seqs.extend([word_dic.id2token.get(int(word_id)) for word_id in valid_word_features.get("word_ids").masked_select(masks)])
    original_word_seqs.extend([word for word_document in batch_word_documents for word in word_document])

#     print("=============================")
#     print([word_dic.id2token.get(int(word_id)) for word_id in valid_word_features.get("word_ids").masked_select(masks)])
#     print([word for word_document in batch_word_documents for word in word_document])
    
#     char_ids = valid_char_features.get("char_ids")
#     char_masks = masks.unsqueeze(-1).expand(char_ids.shape)
#     chars = []
#     for i, char_id in enumerate(char_ids.masked_select(char_masks)):
#         if i != 0 and i % char_ids.shape[-1] == 0:
#             char_seqs.append(chars)
#             chars = []
#         if not char_id == char_dic.token2id[PADDING]:
#             chars.append(char_dic.id2token.get(int(char_id)))
#     char_seqs.append(chars)

    true_seqs.extend([label_dic.id2token[int(label_id)] for label_id in valid_label_features.get("label_ids").masked_select(masks)])
    pred_seqs.extend([label_dic.id2token[int(label_id)] for label_id in valid_tag_seq.masked_select(masks)])

precision, recall, fscore = evaluate(true_seqs, pred_seqs)
print(f"Right: {right_token}, Total: {total_token}, Accuracy: {right_token / total_token:.4f}")
print(f"Precision: {precision}, Recall: {recall}, Fscore: {fscore}")






  0%|          | 0/3081 [00:00<?, ?it/s][A[A[A[A[A




  0%|          | 1/3081 [00:00<05:38,  9.10it/s][A[A[A[A[A




  0%|          | 3/3081 [00:00<05:10,  9.90it/s][A[A[A[A[A




  0%|          | 5/3081 [00:00<05:00, 10.25it/s][A[A[A[A[A




  0%|          | 6/3081 [00:00<05:07,  9.98it/s][A[A[A[A[A




  0%|          | 8/3081 [00:00<04:51, 10.54it/s][A[A[A[A[A




  0%|          | 9/3081 [00:00<05:02, 10.16it/s][A[A[A[A[A




  0%|          | 11/3081 [00:00<04:41, 10.91it/s][A[A[A[A[A




  0%|          | 13/3081 [00:01<04:40, 10.93it/s][A[A[A[A[A




  0%|          | 15/3081 [00:01<04:50, 10.56it/s][A[A[A[A[A




  1%|          | 17/3081 [00:01<04:43, 10.81it/s][A[A[A[A[A




  1%|          | 19/3081 [00:01<04:32, 11.24it/s][A[A[A[A[A




  1%|          | 21/3081 [00:01<04:19, 11.78it/s][A[A[A[A[A




  1%|          | 23/3081 [00:02<04:23, 11.60it/s][A[A[A[A[A




  1%|          | 25/3081 [00:02<04:13, 12.05

 15%|█▌        | 465/3081 [00:39<03:46, 11.56it/s][A[A[A[A[A




 15%|█▌        | 467/3081 [00:39<03:42, 11.74it/s][A[A[A[A[A




 15%|█▌        | 469/3081 [00:40<03:37, 12.01it/s][A[A[A[A[A




 15%|█▌        | 471/3081 [00:40<03:35, 12.13it/s][A[A[A[A[A




 15%|█▌        | 473/3081 [00:40<03:35, 12.13it/s][A[A[A[A[A




 15%|█▌        | 475/3081 [00:40<03:37, 11.98it/s][A[A[A[A[A




 15%|█▌        | 477/3081 [00:40<04:10, 10.40it/s][A[A[A[A[A




 16%|█▌        | 479/3081 [00:40<03:58, 10.89it/s][A[A[A[A[A




 16%|█▌        | 481/3081 [00:41<04:06, 10.55it/s][A[A[A[A[A




 16%|█▌        | 483/3081 [00:41<04:01, 10.74it/s][A[A[A[A[A




 16%|█▌        | 485/3081 [00:41<03:58, 10.86it/s][A[A[A[A[A




 16%|█▌        | 487/3081 [00:41<03:44, 11.54it/s][A[A[A[A[A




 16%|█▌        | 489/3081 [00:41<03:36, 12.00it/s][A[A[A[A[A




 16%|█▌        | 491/3081 [00:41<03:36, 11.99it/s][A[A[A[A[A




 16%|█▌        | 493

 30%|███       | 933/3081 [01:20<03:02, 11.75it/s][A[A[A[A[A




 30%|███       | 935/3081 [01:20<02:51, 12.51it/s][A[A[A[A[A




 30%|███       | 937/3081 [01:20<03:01, 11.80it/s][A[A[A[A[A




 30%|███       | 939/3081 [01:21<02:58, 12.02it/s][A[A[A[A[A




 31%|███       | 941/3081 [01:21<02:56, 12.11it/s][A[A[A[A[A




 31%|███       | 943/3081 [01:21<03:00, 11.86it/s][A[A[A[A[A




 31%|███       | 945/3081 [01:21<02:55, 12.16it/s][A[A[A[A[A




 31%|███       | 947/3081 [01:21<02:51, 12.45it/s][A[A[A[A[A




 31%|███       | 949/3081 [01:21<02:46, 12.79it/s][A[A[A[A[A




 31%|███       | 951/3081 [01:22<02:48, 12.65it/s][A[A[A[A[A




 31%|███       | 953/3081 [01:22<02:51, 12.40it/s][A[A[A[A[A




 31%|███       | 955/3081 [01:22<03:02, 11.68it/s][A[A[A[A[A




 31%|███       | 957/3081 [01:22<03:12, 11.05it/s][A[A[A[A[A




 31%|███       | 959/3081 [01:22<03:03, 11.58it/s][A[A[A[A[A




 31%|███       | 961

 45%|████▌     | 1393/3081 [02:00<02:26, 11.49it/s][A[A[A[A[A




 45%|████▌     | 1395/3081 [02:00<02:29, 11.28it/s][A[A[A[A[A




 45%|████▌     | 1397/3081 [02:00<02:25, 11.54it/s][A[A[A[A[A




 45%|████▌     | 1399/3081 [02:00<02:33, 10.98it/s][A[A[A[A[A




 45%|████▌     | 1401/3081 [02:00<02:36, 10.74it/s][A[A[A[A[A




 46%|████▌     | 1403/3081 [02:01<02:33, 10.93it/s][A[A[A[A[A




 46%|████▌     | 1405/3081 [02:01<02:31, 11.05it/s][A[A[A[A[A




 46%|████▌     | 1407/3081 [02:01<02:32, 10.99it/s][A[A[A[A[A




 46%|████▌     | 1409/3081 [02:01<02:33, 10.86it/s][A[A[A[A[A




 46%|████▌     | 1411/3081 [02:01<02:34, 10.79it/s][A[A[A[A[A




 46%|████▌     | 1413/3081 [02:02<02:35, 10.72it/s][A[A[A[A[A




 46%|████▌     | 1415/3081 [02:02<02:33, 10.88it/s][A[A[A[A[A




 46%|████▌     | 1417/3081 [02:02<02:34, 10.80it/s][A[A[A[A[A




 46%|████▌     | 1419/3081 [02:02<02:32, 10.88it/s][A[A[A[A[A




 46%|█

 60%|██████    | 1853/3081 [02:39<01:48, 11.37it/s][A[A[A[A[A




 60%|██████    | 1855/3081 [02:40<01:47, 11.43it/s][A[A[A[A[A




 60%|██████    | 1857/3081 [02:40<01:52, 10.84it/s][A[A[A[A[A




 60%|██████    | 1859/3081 [02:40<01:50, 11.01it/s][A[A[A[A[A




 60%|██████    | 1861/3081 [02:40<01:43, 11.74it/s][A[A[A[A[A




 60%|██████    | 1863/3081 [02:40<01:45, 11.58it/s][A[A[A[A[A




 61%|██████    | 1865/3081 [02:41<01:50, 11.02it/s][A[A[A[A[A




 61%|██████    | 1867/3081 [02:41<01:51, 10.89it/s][A[A[A[A[A




 61%|██████    | 1869/3081 [02:41<01:47, 11.32it/s][A[A[A[A[A




 61%|██████    | 1871/3081 [02:41<01:41, 11.96it/s][A[A[A[A[A




 61%|██████    | 1873/3081 [02:41<01:39, 12.11it/s][A[A[A[A[A




 61%|██████    | 1875/3081 [02:41<01:36, 12.48it/s][A[A[A[A[A




 61%|██████    | 1877/3081 [02:41<01:36, 12.50it/s][A[A[A[A[A




 61%|██████    | 1879/3081 [02:42<01:37, 12.31it/s][A[A[A[A[A




 61%|█

 75%|███████▌  | 2313/3081 [03:19<01:07, 11.39it/s][A[A[A[A[A




 75%|███████▌  | 2315/3081 [03:20<01:06, 11.49it/s][A[A[A[A[A




 75%|███████▌  | 2317/3081 [03:20<01:08, 11.18it/s][A[A[A[A[A




 75%|███████▌  | 2319/3081 [03:20<01:06, 11.42it/s][A[A[A[A[A




 75%|███████▌  | 2321/3081 [03:20<01:04, 11.75it/s][A[A[A[A[A




 75%|███████▌  | 2323/3081 [03:20<01:04, 11.79it/s][A[A[A[A[A




 75%|███████▌  | 2325/3081 [03:20<01:02, 12.08it/s][A[A[A[A[A




 76%|███████▌  | 2327/3081 [03:21<00:58, 12.98it/s][A[A[A[A[A




 76%|███████▌  | 2329/3081 [03:21<00:59, 12.60it/s][A[A[A[A[A




 76%|███████▌  | 2331/3081 [03:21<01:02, 12.01it/s][A[A[A[A[A




 76%|███████▌  | 2333/3081 [03:21<01:03, 11.71it/s][A[A[A[A[A




 76%|███████▌  | 2335/3081 [03:21<01:04, 11.59it/s][A[A[A[A[A




 76%|███████▌  | 2337/3081 [03:21<01:03, 11.81it/s][A[A[A[A[A




 76%|███████▌  | 2339/3081 [03:22<01:04, 11.52it/s][A[A[A[A[A




 76%|█

 90%|█████████ | 2773/3081 [03:59<00:27, 11.06it/s][A[A[A[A[A




 90%|█████████ | 2775/3081 [03:59<00:26, 11.53it/s][A[A[A[A[A




 90%|█████████ | 2777/3081 [04:00<00:26, 11.36it/s][A[A[A[A[A




 90%|█████████ | 2779/3081 [04:00<00:28, 10.68it/s][A[A[A[A[A




 90%|█████████ | 2781/3081 [04:00<00:26, 11.12it/s][A[A[A[A[A




 90%|█████████ | 2783/3081 [04:00<00:26, 11.18it/s][A[A[A[A[A




 90%|█████████ | 2785/3081 [04:00<00:26, 11.26it/s][A[A[A[A[A




 90%|█████████ | 2787/3081 [04:01<00:25, 11.43it/s][A[A[A[A[A




 91%|█████████ | 2789/3081 [04:01<00:25, 11.55it/s][A[A[A[A[A




 91%|█████████ | 2791/3081 [04:01<00:24, 11.88it/s][A[A[A[A[A




 91%|█████████ | 2793/3081 [04:01<00:23, 12.39it/s][A[A[A[A[A




 91%|█████████ | 2795/3081 [04:01<00:22, 12.64it/s][A[A[A[A[A




 91%|█████████ | 2797/3081 [04:01<00:23, 12.22it/s][A[A[A[A[A




 91%|█████████ | 2799/3081 [04:02<00:22, 12.40it/s][A[A[A[A[A




 91%|█

True: 29526, Correct: 25483, Pred: 29187
Right: 856376.0, Total: 867613.0, Accuracy: 0.9870
Precision: 0.8730941857676363, Recall: 0.8630698367540472, Fscore: 0.8680530717217652


# Evaluate by entities

In [55]:
import pandas as pd
true_entities = []
true_entity = ""
for i, (word, label) in enumerate(zip(original_word_seqs, true_seqs)):
    if label == "O":
        pass
    elif label == "B-CHEM" and not true_entity:
        true_entity += word
    elif label == "I-CHEM" and true_entity:
        true_entity += (" " + word)
    elif label == "E-CHEM" and true_entity:
        true_entity += (" " + word)
        true_entities.append(true_entity)
        true_entity = ""
    elif label == "S-CHEM":
        true_entities.append(word)
        true_entity = ""
    else:
        # ありえないやつも考えなくてはいけない。
        print("Warning!! The combination of the labels is not compatible.")
        true_entity = ""
    pre_label = label
true_entity_counter = Counter(true_entities)

pred_entities = []
pred_entity = ""
for i, (word, label) in enumerate(zip(original_word_seqs, pred_seqs)):
    if label == "O":
        pass
    elif label == "B-CHEM" and not pred_entity:
        pred_entity += word
    elif label == "I-CHEM" and pred_entity:
        pred_entity += (" " + word)
    elif label == "E-CHEM" and pred_entity:
        pred_entity += (" " + word)
        pred_entities.append(pred_entity)
        pred_entity = ""
    elif label == "S-CHEM":
        pred_entities.append(word)
        pred_entity = ""
    else:
        # ありえないやつも考えなくてはいけない。
        print("Warning!! The combination of the labels is not compatible.")
        pred_entity = ""
    pre_label = label
pred_entity_counter = Counter(pred_entities)

results = {"entity": [], "true_count": [], "pred_count": [], "tp": [], "fp": [], "fn": []}
for entity in set(true_entities) | set(pred_entities):
    pred_count = pred_entity_counter.get(entity, 0)
    true_count = true_entity_counter.get(entity, 0)
    results["entity"].append(entity)
    results["true_count"].append(true_count)
    results["pred_count"].append(pred_count)
    results["tp"].append(min(true_count, pred_count))
    results["fp"].append(max(true_count - pred_count, 0))
    results["fn"].append(max(pred_count - true_count, 0))

df = pd.DataFrame(results).sort_values("true_count", ascending=False)

In [56]:
all_word_set = word_dic.token2id.keys()
df["Unknown"] = df.entity.map(lambda x: any([w not in all_word_set for w in x.split(" ")]))

In [65]:
for sp_id, sp in sps.items():
    df[sp_id] = df.entity.map(lambda x: [" ".join(sp.tokenize(word)) for word in x.split(" ")])

In [66]:
df

Unnamed: 0,entity,true_count,pred_count,tp,fp,fn,Unknown,SW2k,SW4k,SW8k
675,glucose,371,375,371,0,4,False,[▁glucose],[▁glucose],[▁glucose]
281,oxygen,212,212,212,0,0,False,[▁oxygen],[▁oxygen],[▁oxygen]
1643,graphene,172,179,172,0,7,False,[▁graphene],[▁graphene],[▁graphene]
4213,ethanol,145,145,145,0,0,False,[▁ ethanol],[▁ethanol],[▁ethanol]
2078,glutathione,144,141,141,3,0,False,[▁glutathione],[▁glutathione],[▁glutathione]
3179,cholesterol,139,139,139,0,0,False,[▁cholesterol],[▁cholesterol],[▁cholesterol]
3569,carbon,137,135,135,2,0,False,[▁carbon],[▁carbon],[▁carbon]
5705,Ca ( 0 + ),136,136,136,0,0,False,"[▁Ca, ▁(, ▁0, ▁+, ▁)]","[▁Ca, ▁(, ▁0, ▁+, ▁)]","[▁Ca, ▁(, ▁0, ▁+, ▁)]"
5897,calcium,135,138,135,0,3,False,[▁calcium],[▁calcium],[▁calcium]
5610,iron,133,147,133,0,14,False,[▁iron],[▁iron],[▁iron]
