In [1]:
import torch
import matplotlib.pyplot as plt

In [2]:
from transformers import BertTokenizer, get_linear_schedule_with_warmup, BertConfig, BertForMaskedLM
model_path = 'dmis-lab/biobert-base-cased-v1.2'
tokenizer = BertTokenizer.from_pretrained(model_path)
bert_lm = BertForMaskedLM.from_pretrained(model_path)
bert_lm.to('cuda:5')

In [None]:
from torch.utils.data import TensorDataset, DataLoader
import copy
state_dict_bak = copy.deepcopy(bert_lm.state_dict())

In [None]:
bert_lm.load_state_dict(state_dict_bak)

<All keys matched successfully>

In [None]:
from tqdm import tqdm

synonym_list, dictionaries = torch.load('/tmp/synonym_list')
names = sorted(dictionaries['ncbi-disease'].values())

inputs = tokenizer(['[MASK] is identical with '+n for n in names], max_length=10, padding='max_length', return_tensors='pt', truncation=True)
dataset = TensorDataset(inputs.input_ids, inputs.attention_mask, torch.arange(len(inputs.input_ids)))
dataloader = DataLoader(dataset, batch_size=16)

@torch.no_grad()
def get_name_emb():
    bert_lm.eval()
    cls_token = []
    for batch in dataloader:
        input_ids, attention_mask, labels = (i.to('cuda:5') for i in batch)
        last_hidden_state = bert_lm.bert(input_ids, attention_mask).last_hidden_state
        cls_token.append(last_hidden_state[:, 0].clone().detach())
    cls_token = torch.cat(cls_token, dim=0)
    return cls_token


mention2id = {v:i for i, v in enumerate(names)}

ent_syn_pairs = []


for cui in synonym_list['ncbi-disease']:
    if cui not in dictionaries['ncbi-disease']:
        continue
    ent = mention2id[dictionaries['ncbi-disease'][cui]]
    for syn in synonym_list['ncbi-disease'][cui]:
        ent_syn_pairs.append((ent, syn))

inputs = tokenizer(['[MASK] is identical with ' + syn for ent, syn in ent_syn_pairs], max_length=10, padding='max_length', return_tensors='pt', truncation=True)
testdataset = TensorDataset(inputs.input_ids, inputs.attention_mask, torch.LongTensor([ent for ent, syn in ent_syn_pairs]))
testdataloader = DataLoader(testdataset, batch_size=16)

def test(name_emb):
    scores = []
    labels = []
    with torch.no_grad():
        bert_lm.eval()
        for batch in testdataloader:
            input_ids, attention_mask, label = (i.to('cuda:5') for i in batch)
            cls_output = bert_lm.bert(input_ids, attention_mask).last_hidden_state[:, 1]
            score = cls_output.matmul(name_emb.T)
            scores.append(score)
            labels.append(label)
    scores = torch.cat(scores, dim=0)
    labels = torch.cat(labels, dim=0)

    acc1 = (scores.topk(1, dim=1)[1] == labels.unsqueeze(1)).any(dim=1).float().mean()
    acc10 = (scores.topk(10, dim=1)[1] == labels.unsqueeze(1)).any(dim=1).float().mean()

    return acc1, acc10, scores, labels

def train(names, epochs, lr):
    optimizer = torch.optim.Adam(bert_lm.parameters(), lr=lr)
    crit = torch.nn.CrossEntropyLoss(reduction='mean')
    for epoch in range(epochs):
        name_emb = get_name_emb()

        acc1, acc10, _, _ = test(name_emb)

        print('acc1, acc10', acc1, acc10)
        pbar = tqdm(dataloader)
        for batch in pbar:
            bert_lm.train()    
            optimizer.zero_grad()
            input_ids, attention_mask, labels = (i.to('cuda:5') for i in batch)
            last_hidden_state = bert_lm.bert(input_ids, attention_mask).last_hidden_state
            cls_token = last_hidden_state[:, 1]
            sim = cls_token.matmul(name_emb.T)
            loss = crit(sim, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix_str('%.2e'%(loss))

        

In [None]:

train(names, epochs=20, lr=1e-5)

  0%|          | 2/770 [00:00<00:47, 16.32it/s, 9.42e+00]

acc1, acc10 tensor(0., device='cuda:5') tensor(0., device='cuda:5')


100%|██████████| 770/770 [00:45<00:00, 17.06it/s, 9.42e+00]
  0%|          | 2/770 [00:00<00:45, 16.89it/s, 9.42e+00]

acc1, acc10 tensor(0., device='cuda:5') tensor(0.0023, device='cuda:5')


100%|██████████| 770/770 [00:46<00:00, 16.44it/s, 9.42e+00]
  0%|          | 2/770 [00:00<00:39, 19.47it/s, 9.42e+00]

acc1, acc10 tensor(0., device='cuda:5') tensor(0.0023, device='cuda:5')


100%|██████████| 770/770 [00:47<00:00, 16.17it/s, 9.42e+00]
  0%|          | 2/770 [00:00<00:46, 16.56it/s, 9.42e+00]

acc1, acc10 tensor(0., device='cuda:5') tensor(0., device='cuda:5')


  8%|▊         | 62/770 [00:03<00:44, 15.91it/s, 9.42e+00]

In [10]:

train(names, epochs=20, lr=1e-4)

  0%|          | 1/770 [00:00<01:22,  9.36it/s, 9.50e+00]

acc1, acc10 tensor(0.0199, device='cuda:5') tensor(0.0736, device='cuda:5')


100%|██████████| 770/770 [00:53<00:00, 14.49it/s, 1.61e-02]
  0%|          | 2/770 [00:00<00:46, 16.61it/s, 0.00e+00]

acc1, acc10 tensor(0.1799, device='cuda:5') tensor(0.3575, device='cuda:5')


100%|██████████| 770/770 [00:49<00:00, 15.48it/s, 6.64e-01]
  0%|          | 2/770 [00:00<00:47, 16.21it/s, 1.19e-07]

acc1, acc10 tensor(0.1741, device='cuda:5') tensor(0.3341, device='cuda:5')


100%|██████████| 770/770 [00:50<00:00, 15.37it/s, 1.37e-06]
  0%|          | 2/770 [00:00<00:46, 16.42it/s, 6.79e-05]

acc1, acc10 tensor(0.1647, device='cuda:5') tensor(0.3259, device='cuda:5')


100%|██████████| 770/770 [00:48<00:00, 15.93it/s, 5.96e-08]
  0%|          | 2/770 [00:00<00:47, 16.20it/s, 1.28e-04]

acc1, acc10 tensor(0.1659, device='cuda:5') tensor(0.3306, device='cuda:5')


100%|██████████| 770/770 [00:48<00:00, 15.74it/s, 1.72e-06]
  0%|          | 2/770 [00:00<00:45, 16.93it/s, 5.20e-05]

acc1, acc10 tensor(0.1600, device='cuda:5') tensor(0.3143, device='cuda:5')


100%|██████████| 770/770 [00:51<00:00, 14.96it/s, 8.10e-04]
  0%|          | 2/770 [00:00<00:47, 16.32it/s, 4.95e-04]

acc1, acc10 tensor(0.1636, device='cuda:5') tensor(0.3131, device='cuda:5')


100%|██████████| 770/770 [00:49<00:00, 15.50it/s, 1.27e-04]
  0%|          | 2/770 [00:00<00:46, 16.37it/s, 2.84e-06]

acc1, acc10 tensor(0.1530, device='cuda:5') tensor(0.3061, device='cuda:5')


 10%|▉         | 74/770 [00:04<00:45, 15.42it/s, 5.35e-03]

<All keys matched successfully>

(tensor(0.1764, device='cuda:5'), tensor(0.3692, device='cuda:5'))