In [49]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")
sys.path.append("./")
from ner.train import *
from chem_sentencepiece.chem_sentencepiece import ChemSentencePiece
from config import config_dics
import matplotlib.pyplot as plt
from collections import Counter
%matplotlib inline

seed_num = 42
random.seed(seed_num)
np.random.seed(seed_num)
torch.manual_seed(seed_num)
torch.cuda.manual_seed(seed_num)
torch.backends.cudnn.deterministic = True

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
args_config = "SW2k.4k.2000k.NoLM"
config_dic = config_dics["SW2k.4k.2000k.NoLM"]
pprint.pprint(config_dic)

{'cache_dir': './Repository/Cache/',
 'char_emb_dim': 30,
 'char_hidden_dim': 100,
 'glove_path': './Repository/GloVe/gv.d50',
 'gpu': True,
 'grad_clip': 0,
 'lm_batch_size': 10,
 'lm_dropout': 0.5,
 'lm_epoch': 0,
 'lm_input_path': './Repository/LargeCorpus/large_corpus_2000k.txt',
 'lm_lr': 1,
 'lm_model_dir': './Repository/LanguageModel/',
 'ner_batch_size': 10,
 'ner_dropout': 0.5,
 'ner_epoch': 100,
 'ner_input_dir': './Repository/Chemdner/',
 'ner_lr': 0.01,
 'ner_model_dir': './Repository/NERModel/',
 'number_normalize': True,
 'sp_path': {'SW2k': './Repository/SentencePiece/sp2000.model',
             'SW4k': './Repository/SentencePiece/sp4000.model'},
 'sw_emb_dim': 50,
 'sw_hidden_dim': 100,
 'vocab_dir': './Repository/Vocabulary/',
 'weight_decay': 1e-05,
 'word_emb_dim': 50,
 'word_hidden_dim': 200}


In [20]:
# load sentence piece
sps: dict = {}
for sp_key, sp_path in config_dic["sp_path"].items():
    sp = ChemSentencePiece()
    sp.load(sp_path)
    sps[sp_key] = sp

In [29]:
# load train data
print("=========== Load train data ===========")
train_word_documents, train_label_documents = load_seq_data(os.path.join(config_dic.get("ner_input_dir"), "train.bioes"), config_dic.get("number_normalize"))
valid_word_documents, valid_label_documents = load_seq_data(os.path.join(config_dic.get("ner_input_dir"), "valid.bioes"), config_dic.get("number_normalize"))
test_word_documents, test_label_documents = load_seq_data(os.path.join(config_dic.get("ner_input_dir"), "test.bioes"), config_dic.get("number_normalize"))

train_char_documents = [[[char for char in word] for word in document] for document in train_word_documents] # Document数 x 文字数
valid_char_documents = [[[char for char in word] for word in document] for document in valid_word_documents]
test_char_documents = [[[char for char in word] for word in document] for document in test_word_documents]

train_sw_documents_dicts = {}
valid_sw_documents_dicts = {}
test_sw_documents_dicts = {}
for sp_key, sp in sps.items():
    train_sw_documents_dicts[sp_key] = get_sw_documents(train_word_documents, sp)
    valid_sw_documents_dicts[sp_key] = get_sw_documents(valid_word_documents, sp)
    test_sw_documents_dicts[sp_key] = get_sw_documents(test_word_documents, sp)
    

  0%|          | 0/903738 [00:00<?, ?it/s]



100%|██████████| 903738/903738 [00:02<00:00, 313719.86it/s]
100%|██████████| 898421/898421 [00:02<00:00, 405074.53it/s]
100%|██████████| 776037/776037 [00:01<00:00, 403814.45it/s]


In [31]:
# load vocabulary
print("=========== Build vocabulary ===========")
if os.path.exists(os.path.join(config_dic.get("vocab_dir"), f"{args_config}.word.dic")):
    word_dic = Dictionary.load(os.path.join(config_dic.get("vocab_dir"), f"{args_config}.word.dic"))
    char_dic = Dictionary.load(os.path.join(config_dic.get("vocab_dir"), f"{args_config}.char.dic"))
    sw_dicts = {}
    for sp_key, sp in sps.items():
        sw_dicts[sp_key] = Dictionary.load(os.path.join(config_dic.get("vocab_dir"), f"{args_config}.sw{sp_key}.dic"))
else:
    special_token_dict = {PADDING: 0, UNKNOWN: 1, START: 2, END: 3}
    word_dic = Dictionary()
    word_dic.token2id = special_token_dict
    char_dic = Dictionary()
    char_dic.token2id = special_token_dict
    sw_dicts = {}
    for sp_key, sp in sps.items():
        _dic = Dictionary()
        _dic.token2id = special_token_dict
        sw_dicts[sp_key] = _dic
label_dic = Dictionary(train_label_documents)
label_dic.patch_with_special_tokens({PADDING: 0})
label_dic.id2token = {_id: label for label, _id in label_dic.token2id.items()}

# add vocabulary
word_dic.add_documents(train_word_documents)
char_dic.add_documents(list(chain.from_iterable(train_char_documents)))
for sp_key, train_sw_documents in train_sw_documents_dicts.items():
    sw_dicts[sp_key].add_documents(train_sw_documents)



In [5]:
# lmで学習した後で、Trainデータで未知語になってしまった単語
# train_oov_counter = Counter()
# for document in train_word_documents:
#     for word in document:
#         if not word_dic.token2id.get(word):
#             train_oov_counter[word] += 1
# train_oov_counter.most_common(100)

In [32]:
# load GloVe
if config_dic.get("glove_path"):
    print("============== Load Pretrain Word Embeddings ================")
    word2vec = load_pretrain_embeddings(config_dic.get("glove_path"), emb_dim=config_dic.get("word_emb_dim"))
    pretrain_embeddings = build_pretrain_embeddings(word2vec, word_dic, emb_dim=config_dic.get("word_emb_dim"))
else:
    pretrain_embeddings = None

PerfectMatch: 35278, CaseMatch: 423, NotMatch: 6543.


In [44]:
seq_model = SeqModel(config_dic, len(word_dic.token2id), len(char_dic.token2id), [len(sw_dic.token2id) for sw_dic in sw_dicts.values()], len(label_dic.token2id), pretrain_embeddings)
#seq_model.load_expanded_state_dict(torch.load("./Repository/LanguageModel/BaseLine.LM.lm_lr5.0.model"))
optimizer = torch.optim.SGD(seq_model.parameters(), lr=config_dic.get("ner_lr"), weight_decay=config_dic.get("weight_decay"), momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.999)

print(seq_model)
print(optimizer)

42244
42244
SeqModel(
  (word_lstm): WordLSTM(
    (char_rep): CharRep(
      (dropout): Dropout(p=0.5)
      (char_embeddings): Embedding(42244, 30)
      (char_lstm): LSTM(30, 50, batch_first=True, bidirectional=True)
    )
    (subword_rep_list): ModuleList(
      (0): SubwordRep(
        (dropout): Dropout(p=0.5)
        (embeddings): Embedding(42244, 50)
        (lstm): LSTM(50, 50, batch_first=True, bidirectional=True)
      )
      (1): SubwordRep(
        (dropout): Dropout(p=0.5)
        (embeddings): Embedding(42244, 50)
        (lstm): LSTM(50, 50, batch_first=True, bidirectional=True)
      )
    )
    (dropout): Dropout(p=0.5)
    (word_embedding): Embedding(42244, 50)
    (lstm): LSTM(350, 100, batch_first=True, bidirectional=True)
  )
  (hidden2tag): Linear(in_features=200, out_features=8, bias=True)
  (crf): CRF()
)
SGD (
Parameter Group 0
    dampening: 0
    initial_lr: 0.01
    lr: 0.01
    momentum: 0.9
    nesterov: False
    weight_decay: 1e-05
)


# 学習

In [None]:
## start training
epoch = config_dic.get("ner_epoch")
for epoch_i in range(epoch):
    print("Epoch: %s/%s" %(epoch_i, epoch))
    lr_scheduler.step()
    print(f"Learning Rate: {lr_scheduler.get_lr()}")
    # shuffle
    random_ids = list(range(len(train_word_documents)))
    random.shuffle(random_ids)
    
    #####################  Batch Initialize ############################
    total_loss, batch_ave_loss, right_token, total_token = 0, 0, 0, 0
    batch_size = config_dic.get("ner_batch_size")
    batch_steps = len(train_word_documents) // batch_size + 1
    seq_model.train()
    seq_model.zero_grad()
    optimizer.zero_grad()

    # # LangageModelを使用した時に重みを固定する。
    # if epoch_i < 1 and config_dic.get("lm_model_dir"):
    #     print("=========== seq_model.word_lstm.eval() ============")
    #     seq_model.word_lstm.eval()
    #     for param in seq_model.word_lstm.parameters():
    #         param.requires_grad = False

    for batch_i in range(batch_steps):
        start_time = time.time()
        batch_ids = random_ids[batch_i * batch_size: (batch_i + 1) * batch_size]
        batch_word_documents = [train_word_documents[i] for i in batch_ids]
        batch_label_documents = [train_label_documents[i] for i in batch_ids]
        word_features = get_word_features(batch_word_documents, word_dic, config_dic.get("gpu"))
        char_features = get_char_features(batch_word_documents, char_dic, config_dic.get("gpu"))
        sw_features_list = []
        for sp_key, sp in sps.items():
            sw_features_list.append(get_sw_features(batch_word_documents, sw_dicts[sp_key], sp, config_dic.get("gpu")))
        label_features = get_label_features(batch_label_documents, label_dic, config_dic.get("gpu"))
        loss, train_tag_seq = seq_model.neg_log_likelihood_loss(word_features, char_features, sw_features_list, label_features)
        batch_ave_loss += loss.data
        total_loss += loss.data
        loss.backward()

        optimizer.step()
        seq_model.zero_grad()

        rt, tt = predict_check(train_tag_seq, label_features.get("label_ids"), word_features.get("masks"))
        right_token += rt
        total_token += tt
        if batch_i != 0 and batch_i % 50 == 0:
            if batch_ave_loss > 1e8 or str(loss) == "nan":
                print("Error: Loss Explosion (>1e8)! EXIT...")
                exit(1)
            sys.stdout.flush()
            print(f"""Batch: {batch_i}; Time(sec/batch): {time.time() - start_time:.4f}; Loss: {batch_ave_loss:.4f} Right: {right_token}, Total: {total_token}, Accuracy: {right_token / total_token:.4f}""") 
            batch_ave_loss = 0
    print(f"Total Loss: {total_loss}")

    ################ valid predict check #####################
    print("============== Valid Evaluate ===========")
    true_seqs, pred_seqs = [], []
    right_token, total_token = 0, 0
    batch_steps = len(valid_word_documents) // batch_size + 1
    random_ids = list(range(len(valid_word_documents)))
    seq_model.eval()
    for batch_i in range(batch_steps):
        batch_ids = random_ids[batch_i * batch_size: (batch_i + 1) * batch_size]
        batch_word_documents = [valid_word_documents[i] for i in batch_ids]
        batch_label_documents = [valid_label_documents[i] for i in batch_ids]

        valid_word_features = get_word_features(batch_word_documents, word_dic, config_dic.get("gpu"))
        valid_char_features = get_char_features(batch_word_documents, char_dic, config_dic.get("gpu"))
        valid_sw_features_list = []
        for sp_key, sp in sps.items():
            valid_sw_features_list.append(get_sw_features(batch_word_documents, sw_dicts[sp_key], sp, config_dic.get("gpu")))
        valid_label_features = get_label_features(batch_label_documents, label_dic, config_dic.get("gpu"))
        valid_tag_seq = seq_model.forward(valid_word_features, valid_char_features, valid_sw_features_list)
        masks = valid_word_features.get("masks")
        rt, tt = predict_check(valid_tag_seq, valid_label_features.get("label_ids"), masks)
        right_token += rt
        total_token += tt
        ################ evaluate by precision, recall and fscore ###################
        true_seqs.extend([label_dic.id2token.get(int(label_id), label_dic.token2id["O"]) for label_id in valid_label_features.get("label_ids").masked_select(masks)])
        pred_seqs.extend([label_dic.id2token.get(int(label_id), label_dic.token2id["O"]) for label_id in valid_tag_seq.masked_select(masks)])
    precision, recall, fscore = evaluate(true_seqs, pred_seqs)
    print(f"Right: {right_token}, Total: {total_token}, Accuracy: {right_token / total_token:.4f}")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, Fscore: {fscore:.4f}")

Epoch: 0/100
Learning Rate: [0.01]
Batch: 50; Time(sec/batch): 0.2404; Loss: 7656.4614 Right: 13369.0, Total: 14661.0, Accuracy: 0.9119
Batch: 100; Time(sec/batch): 0.2860; Loss: 7417.3423 Right: 25938.0, Total: 28618.0, Accuracy: 0.9064
Batch: 150; Time(sec/batch): 0.3285; Loss: 8437.2373 Right: 39160.0, Total: 43400.0, Accuracy: 0.9023
Batch: 200; Time(sec/batch): 0.2794; Loss: 21255.3750 Right: 51977.0, Total: 58102.0, Accuracy: 0.8946
Batch: 250; Time(sec/batch): 0.2916; Loss: 9184.3809 Right: 64878.0, Total: 72275.0, Accuracy: 0.8977
Batch: 300; Time(sec/batch): 0.2828; Loss: 8582.0781 Right: 77660.0, Total: 86333.0, Accuracy: 0.8995
Batch: 350; Time(sec/batch): 0.2454; Loss: 7880.2236 Right: 90572.0, Total: 100531.0, Accuracy: 0.9009


# Predict!!!

In [23]:
word_dic.id2token = {v: k for k, v in word_dic.token2id.items()}
char_dic.id2token = {v: k for k, v in char_dic.token2id.items()}

In [108]:
################ valid predict check #####################
print("============== Predict Check==========")
batch_size = 10
true_seqs, pred_seqs, word_seqs, char_seqs = [], [], [], []
right_token, total_token = 0, 0
batch_steps = len(valid_word_documents) // batch_size + 1
random_ids = list(range(len(valid_word_documents)))
seq_model.eval()
for batch_i in tqdm(range(batch_steps)):
    batch_ids = random_ids[batch_i * batch_size: (batch_i + 1) * batch_size]
    batch_word_documents = [valid_word_documents[i] for i in batch_ids]
    batch_label_documents = [valid_label_documents[i] for i in batch_ids]

    valid_word_features = get_word_features(batch_word_documents, word_dic, config_dic.get("gpu"))
    valid_char_features = get_char_features(batch_word_documents, char_dic, config_dic.get("gpu"))
    if config_dic.get("sp_path"):
        valid_sw_features = get_sw_features(batch_word_documents, sw_dic, sp, config_dic.get("gpu"))
    else:
        valid_sw_features = None
    valid_label_features = get_label_features(batch_label_documents, label_dic, config_dic.get("gpu"))
    valid_tag_seq = seq_model.forward(valid_word_features, valid_char_features, valid_sw_features)
    rt, tt = predict_check(valid_tag_seq, valid_label_features.get("label_ids"), valid_word_features.get("masks"))
    right_token += rt
    total_token += tt
    ################ evaluate by precision, recall and fscore ###################
    masks = valid_word_features.get("masks")
    word_seqs.extend([word_dic.id2token.get(int(word_id)) for word_id in valid_word_features.get("word_ids").masked_select(masks)])
    
#     char_ids = valid_char_features.get("char_ids")
#     char_masks = masks.unsqueeze(-1).expand(char_ids.shape)
#     chars = []
#     for i, char_id in enumerate(char_ids.masked_select(char_masks)):
#         if i != 0 and i % char_ids.shape[-1] == 0:
#             char_seqs.append(chars)
#             chars = []
#         if not char_id == char_dic.token2id[PADDING]:
#             chars.append(char_dic.id2token.get(int(char_id)))
#     char_seqs.append(chars)
    
    true_seqs.extend([label_dic.id2token[int(label_id)] for label_id in valid_label_features.get("label_ids").masked_select(masks)])
    pred_seqs.extend([label_dic.id2token[int(label_id)] for label_id in valid_tag_seq.masked_select(masks)])
    
precision, recall, fscore = evaluate(true_seqs, pred_seqs)
print(f"Right: {right_token}, Total: {total_token}, Accuracy: {right_token / total_token:.4f}")
print(f"Precision: {precision}, Recall: {recall}, Fscore: {fscore}")

  0%|          | 2/3081 [00:00<03:24, 15.07it/s]



100%|██████████| 3081/3081 [02:59<00:00, 17.15it/s]


True: 29526, Correct: 24403, Pred: 31515
Right: 849610.0, Total: 867613.0, Accuracy: 0.9792
Precision: 0.7743296842773283, Recall: 0.8264919054392739, Fscore: 0.7995609508363233


# Evaluate by entities

In [116]:
true_entities = []
true_entity = ""
for i, (word, label) in enumerate(zip(word_seqs, true_seqs)):
    if label == "O":
        pass
    elif label == "B-CHEM" and not true_entity:
        true_entity += word
    elif label == "I-CHEM" and true_entity:
        true_entity += (" " + word)
    elif label == "E-CHEM" and true_entity:
        true_entity += (" " + word)
        true_entities.append(true_entity)
        true_entity = ""
    elif label == "S-CHEM":
        true_entities.append(word)
        true_entity = ""
    else:
        # ありえないやつも考えなくてはいけない。
        print("Warning!! The combination of the labels is not compatible.")
        true_entity = ""
    pre_label = label
true_entity_counter = Counter(true_entities)

pred_entities = []
pred_entity = ""
for i, (word, label) in enumerate(zip(word_seqs, pred_seqs)):
    if label == "O":
        pass
    elif label == "B-CHEM" and not pred_entity:
        pred_entity += word
    elif label == "I-CHEM" and pred_entity:
        pred_entity += (" " + word)
    elif label == "E-CHEM" and pred_entity:
        pred_entity += (" " + word)
        pred_entities.append(pred_entity)
        pred_entity = ""
    elif label == "S-CHEM":
        pred_entities.append(word)
        pred_entity = ""
    else:
        # ありえないやつも考えなくてはいけない。
        print("Warning!! The combination of the labels is not compatible.")
        pred_entity = ""
    pre_label = label
pred_entity_counter = Counter(pred_entities)

In [141]:
results = {"entity": [], "true_count": [], "pred_count": [], "tp": [], "fp": [], "fn": []}
for entity in set(true_entities) | set(pred_entities):
    pred_count = pred_entity_counter.get(entity, 0)
    true_count = true_entity_counter.get(entity, 0)
    results["entity"].append(entity)
    results["true_count"].append(true_count)
    results["pred_count"].append(pred_count)
    results["tp"].append(min(true_count, pred_count))
    results["fp"].append(max(true_count - pred_count, 0))
    results["fn"].append(max(pred_count - true_count, 0))

In [142]:
import pandas as pd
df = pd.DataFrame(results).sort_values("true_count", ascending=False)

In [149]:
df.tail(10000)

Unnamed: 0,entity,true_count,pred_count,tp,fp,fn
4054,Steroid,4,3,3,1,0
8591,diallyl disulfide,4,4,4,0,0
3350,polyesters,4,0,0,4,0
4275,quinpirole,4,4,4,0,0
8192,telmisartan,4,4,4,0,0
4895,NaB,4,4,4,0,0
949,carbonate,4,4,4,0,0
10444,ZnS,4,4,4,0,0
4059,Cort,4,1,1,3,0
2394,MAs,4,0,0,4,0
