In [1]:
import json
class obj:
    # constructor
    def __init__(self, dict1):
        self.__dict__.update(dict1)
args = {
    "candidates": 20,
    "device": 'mps',
    "path": 'datasets/development',
    "max_length": 25,
    "batch_size": 16,
    "epochs": 1,
    "loss_fn": "nll",
    "contextualized": False
}
args = json.loads(json.dumps(args), object_hook=obj)
vars(args)

{'candidates': 20,
 'device': 'mps',
 'path': 'datasets/development',
 'max_length': 25,
 'batch_size': 16,
 'epochs': 1,
 'loss_fn': 'nll',
 'contextualized': False}

In [2]:
from importlib import reload

import numpy as np
import time
import torch
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer
)

# Local modules
from src.candidateDataset import CandidateDataset
from src.rerankNet import RerankNet
from src.umls import Umls
import src.utils as utils

#TODO: Add additional loss functions
loss_fn = utils.marginal_nll

# Initialize
LOGGER = utils.init_logging()
utils.init_seed(42)
model_name_or_path = 'dmis-lab/biobert-base-cased-v1.1'
bert = AutoModel.from_pretrained(model_name_or_path).to(args.device)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Build model
model = RerankNet(bert, 
                  device = args.device, 
                  loss_fn='nll')

# Load UMLS data
umls = Umls('umls/processed')

# Load dictionary
dictionary = utils.load_dictionary(args.path+'/dev_dictionary.txt')
LOGGER.info("Dictionary loaded")

# Load training data
train_mentions = utils.load_mentions(args.path+'/processed_dev')
train_set = CandidateDataset(train_mentions, dictionary, tokenizer, args.max_length, args.candidates) 
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)

dev_mentions = utils.load_mentions(args.path+'/processed_dev')
dev_set = CandidateDataset(dev_mentions, dictionary, tokenizer, args.max_length, args.candidates) 
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=args.batch_size, shuffle=True)
LOGGER.info("Mentions loaded")

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 179/179 [00:00<00:00, 1141003.67it/s]
07/13/2022 04:21:50 PM: [ Dictionary loaded ]
100%|██

In [3]:
# Begin Training loop
LOGGER.info("train!")
start = time.time()
for epoch in range(args.epochs):
    ############## Candidate Generation ##############
    train_candidate_idxs = utils.get_topk_candidates(
            dict_names=list(dictionary[:,0]), 
            mentions=train_mentions, 
            tokenizer=tokenizer, 
            encoder=bert, 
            max_length=args.max_length, 
            device=args.device, 
            topk=args.candidates)
                        
    # Add candidates to training dataset
    train_set.set_candidate_idxs(train_candidate_idxs)
    
    ###################### Train ######################
    # Train encoder to properly rank candidates
    train_loss = 0
    train_steps = 0
    model.train()
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Training epoch {epoch}'):
        model.optimizer.zero_grad()
        batch_x, batch_y = data
        batch_pred = model(batch_x)
        loss = loss_fn(batch_pred, batch_y.to(args.device))
        # loss.backward() #TODO: Not working on the mac
        model.optimizer.step()
        train_loss += loss.item()
        train_steps += 1

    train_loss = train_loss / (train_steps + 1e-9)
    LOGGER.info('Epoch {}: loss/train_per_epoch={}/{}'.format(epoch,train_loss,epoch))
    
    #################### Evaluate ####################
    # Get candidates on dev dataset
    dev_candidate_idxs = utils.get_topk_candidates(
            dict_names=list(dictionary[:,0]), 
            mentions=dev_mentions, 
            tokenizer=tokenizer, 
            encoder=bert, 
            max_length=args.max_length, 
            device=args.device, 
            topk=5) # Only need top five candidates to evaluate performance
    
    # Log performance on dev after each epoch
    results = utils.evaluate(dev_mentions, dictionary[dev_candidate_idxs], umls)
    if 'acc1' in results: LOGGER.info("Epoch {}: acc@1={}".format(epoch,results['acc1']))
    if 'acc5' in results: LOGGER.info("Epoch {}: acc@5={}".format(epoch,results['acc5']))
    if 'umls_similarity' in results: LOGGER.info("Epoch {}: umls_similarity={}".format(epoch,results['umls_similarity']))


07/13/2022 04:21:54 PM: [ train! ]
Bulk embedding...: 100%|██████████| 1/1 [00:01<00:00,  1.75s/it]
Bulk embedding...: 100%|██████████| 1/1 [00:00<00:00,  3.52it/s]
Training epoch 0: 100%|██████████| 1/1 [00:02<00:00,  2.48s/it]
07/13/2022 04:21:59 PM: [ Epoch 0: loss/train_per_epoch=4.364671702788648/0 ]
Bulk embedding...: 100%|██████████| 1/1 [00:03<00:00,  3.55s/it]
Bulk embedding...: 100%|██████████| 1/1 [00:00<00:00,  2.47it/s]
07/13/2022 04:22:03 PM: [ Epoch 0: acc@1=0.875 ]
07/13/2022 04:22:03 PM: [ Epoch 0: acc@5=0.875 ]
07/13/2022 04:22:03 PM: [ Epoch 0: umls_similarity=0.9166666666666666 ]


In [60]:
x = batch_x
mention_tokens, candidate_tokens = x
batch_size, candidates, max_length = candidate_tokens.input_ids.shape
        
# Embed mentions
mention_tokens = mention_tokens.to(args.device)
mention_embeds = bert(
            input_ids=mention_tokens['input_ids'].squeeze(1),
            token_type_ids=mention_tokens['token_type_ids'].squeeze(1),
            attention_mask=mention_tokens['attention_mask'].squeeze(1)
)
mention_embeds = mention_embeds[0][:,0].unsqueeze(1) # [CLS] embedding for mentions : [batch_size, 1, hidden]

# Embed candidate names
candidate_tokens = candidate_tokens.to(args.device)
candidate_embeds = bert(
            input_ids=candidate_tokens['input_ids'].reshape(-1, max_length),
            token_type_ids=candidate_tokens['token_type_ids'].reshape(-1, max_length),
            attention_mask=candidate_tokens['attention_mask'].reshape(-1, max_length)
)
candidate_embeds = candidate_embeds[0][:,0].reshape(batch_size, candidates, -1) # [batch_size, topk, hidden]

In [62]:
batch_size, candidates, max_length = candidate_tokens.input_ids.shape
candidate_embeds.shape

torch.Size([8, 20, 768])

In [11]:
import src.utils as utils
reload(utils)

<module 'src.utils' from '/Users/evan/code/thesis/src/utils.py'>