In [5]:
import json
class obj:
    """Lightweight attribute-style view over a dictionary.

    Every key of the dictionary handed to the constructor becomes an
    attribute of the instance, e.g. ``obj({"epochs": 1}).epochs == 1``.
    """

    # constructor
    def __init__(self, dict1):
        """Copy every key/value pair of ``dict1`` onto the instance."""
        self.__dict__.update(dict1)

    def __repr__(self):
        """Return ``obj(key=value, ...)`` so logs and cell output show the stored values."""
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
# Run configuration. After the conversion below every key is read as an
# attribute (args.device, args.epochs, ...).
# NOTE(review): "train_dir" and "dev_dir" point at the same folder, so the dev
# metrics are computed on the very mentions used for training.
args = {
    "candidates": 20,
    "device": 'mps',
    "dev_dir":'datasets/development/processed_dev',
    "train_dir":'datasets/development/processed_dev',
    "dictionary_path": 'datasets/development/dev_dictionary.txt',
    "max_length": 25,
    "model_name_or_path": 'dmis-lab/biobert-base-cased-v1.1',
    "batch_size": 16,
    "epochs": 1,
    "loss_fn": "mse",
    "contextualized": False,
    "similarity_type": 'binary'
}
# Round-trip through JSON so `object_hook` wraps the dict (and any nested
# dict added later) in `obj`, giving attribute access to the settings.
args = json.loads(json.dumps(args), object_hook=obj)

from importlib import reload

import numpy as np
import time
import torch
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer
)

# Local modules
from src.candidateDataset import CandidateDataset
from src.rerankNet import RerankNet
from src.umls import Umls
import src.utils as utils

In [6]:
# Initialize
start = time.time()  # wall-clock reference taken before any heavy setup
LOGGER = utils.init_logging()
# Log the settings as a plain dict so every value is readable in the log,
# independent of how the `args` object itself prints.
LOGGER.info(vars(args))
utils.init_seed(42)
# Encoder and tokenizer are loaded from the same checkpoint; only the encoder
# is moved to the configured device.
bert = AutoModel.from_pretrained(args.model_name_or_path).to(args.device)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

# Set loss function
# Accepted values of `args.loss_fn`, mapped to the name of the implementing
# function inside `utils`.
loss_fn_attrs = {'nll': 'marginal_nll', 'mse': 'mse_loss', 'mse5': 'mse5_loss'}
if args.loss_fn not in loss_fn_attrs:
    raise Exception(f"Invalid loss function {args.loss_fn}")
# Looked up by name so that only the selected loss is resolved on `utils`.
loss_fn = getattr(utils, loss_fn_attrs[args.loss_fn])
    
# Build model
# Wrap the pretrained encoder and its tokenizer in the project's RerankNet.
model = RerankNet(encoder=bert, tokenizer=tokenizer, device=args.device)

# Load UMLS data
# NOTE(review): 'umls/processed' is a hard-coded relative path, unlike the
# dataset and dictionary locations, which are read from `args`.
umls = Umls('umls/processed')
LOGGER.info("UMLS data loaded")

# Load dictionary
# Later cells index this as dictionary[:,0] and dictionary[idxs], so it is
# presumably a 2-D array whose first column holds the names — TODO confirm.
dictionary = utils.load_dictionary(args.dictionary_path)
LOGGER.info("Dictionary loaded")

# Load training data
# The dataset is created without candidates here; they are attached later via
# train_set.set_candidate_idxs(...) in the candidate-generation cell.
train_mentions = utils.load_mentions(args.train_dir)
train_set = CandidateDataset(train_mentions, dictionary, model.tokenizer, args.max_length, args.candidates, args.similarity_type, umls) 
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)

# Load dev data for validation
dev_mentions = utils.load_mentions(args.dev_dir)
# Logged once both the training and the dev mentions have been read.
LOGGER.info("Mentions loaded")

07/15/2022 01:31:42 PM: [ <__main__.obj object at 0x160e298b0> ]
Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
07/15/2022 01:31:47 PM: [ UMLS data loaded 

In [9]:
############## Candidate Generation ##############
# Retrieve, for every training mention, the `args.candidates` dictionary
# entries selected by the current encoder.
epoch = 0
train_candidate_idxs = utils.get_topk_candidates(
    dict_names=list(dictionary[:, 0]),
    mentions=train_mentions,
    tokenizer=model.tokenizer,
    encoder=model.encoder,
    max_length=args.max_length,
    device=args.device,
    topk=args.candidates,
)

# Add candidates to training dataset
train_set.set_candidate_idxs(train_candidate_idxs)
max_acc1 = train_set.max_acc1()
LOGGER.info('Epoch {}: max possible acc@1 = {}'.format(epoch, max_acc1))

###################### Train ######################
# Train encoder to properly rank candidates
train_loss = 0
train_steps = 0
model.train()
for batch_x, batch_y in tqdm(train_loader, total=len(train_loader), desc=f'Training epoch {epoch}'):
    model.optimizer.zero_grad()
    batch_pred = model(batch_x)
    # Only the targets are moved to the device here; batch_x is handed to the model as-is.
    loss = loss_fn(batch_pred, batch_y.to(args.device))
    loss.backward()
    model.optimizer.step()
    train_loss += loss.item()
    train_steps += 1

# Exact mean over the batches seen. max(..., 1) keeps an empty loader at 0.0
# without dividing by zero and without skewing the average of real batches.
train_loss = train_loss / max(train_steps, 1)
# The value printed after the final '/' is the epoch index.
LOGGER.info('Epoch {}: loss/train_per_epoch={}/{}'.format(epoch,train_loss,epoch))

#################### Evaluate ####################
# Get candidates on dev dataset
dev_candidate_idxs = utils.get_topk_candidates(
    dict_names=list(dictionary[:, 0]),
    mentions=dev_mentions,
    tokenizer=model.tokenizer,
    encoder=model.encoder,
    max_length=args.max_length,
    device=args.device,
    topk=5,  # Only need top five candidates to evaluate performance
)

# Log performance on dev after each epoch
dev_candidates = dictionary[dev_candidate_idxs]
results = utils.evaluate(dev_mentions, dev_candidates, umls)
# Each metric is reported only when the evaluation actually produced it.
metric_templates = (
    ('acc1', "Epoch {}: acc@1={}"),
    ('acc5', "Epoch {}: acc@5={}"),
    ('umls_similarity', "Epoch {}: umls_similarity={}"),
)
for metric_key, template in metric_templates:
    if metric_key in results:
        LOGGER.info(template.format(epoch, results[metric_key]))

Bulk embedding...: 100%|██████████| 1/1 [00:03<00:00,  3.39s/it]
Bulk embedding...: 100%|██████████| 1/1 [00:00<00:00,  1.86it/s]
07/15/2022 01:33:17 PM: [ Epoch 0: max possible acc@1 = 0.875 ]
Training epoch 0: 100%|██████████| 1/1 [00:08<00:00,  8.34s/it]
07/15/2022 01:33:25 PM: [ Epoch 0: loss/train_per_epoch=0.0/0 ]
Bulk embedding...: 100%|██████████| 1/1 [00:07<00:00,  7.14s/it]
Bulk embedding...: 100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
07/15/2022 01:33:34 PM: [ Epoch 0: acc@1=0.875 ]
07/15/2022 01:33:34 PM: [ Epoch 0: acc@5=0.875 ]
07/15/2022 01:33:34 PM: [ Epoch 0: umls_similarity=0.9166666666666666 ]


In [397]:
# Re-import the local helper module so edits made to src/utils.py are picked up
# without restarting the kernel.
# NOTE(review): names bound before the reload (e.g. `loss_fn`) keep pointing at
# the pre-reload function objects and must be re-assigned to use the new code.
import src.utils as utils
reload(utils)

# NOTE(review): leftover note with no code attached — "Check all annotations were fixed"


<module 'src.utils' from '/Users/evan/code/thesis/src/utils.py'>

In [8]:
def mse_loss(score, target):
    """Calculates MSE loss between max similarity of the candidates and similarity of top prediction.

    Args:
        score: candidate scores; only the position of the maximum along
            dim=1 is used. Assumed to be (batch, candidates) — TODO confirm.
        target: similarity values laid out like ``score``; indexed along
            dim=1 with the positions taken from ``score``.

    Returns:
        Scalar tensor holding the mean squared difference between the largest
        similarity available in each row of ``target`` and the similarity of
        the candidate that ``score`` ranks first.

    NOTE(review): ``score`` enters the computation only through ``argmax``,
    which is not differentiable, so calling ``backward()`` on this loss does
    not propagate gradients to whatever produced ``score``. The
    ``requires_grad_()`` call below only allows ``backward()`` to run.
    """
    # Only floating-point tensors can require gradients, so promote integer
    # similarity labels; floating-point targets keep their dtype.
    if not target.is_floating_point():
        target = target.float()

    # Find similarity of the top prediction
    pred_ixs = score.argmax(dim=1, keepdim=True)
    # squeeze(1) removes only the gathered column. A bare squeeze() would also
    # drop the batch axis for a single-row batch and make the two operands of
    # mse_loss differ in shape.
    predicted_similarity = torch.gather(target, 1, pred_ixs).squeeze(1).requires_grad_()

    # Find max similarity for each mention of the available candidates
    expected_similarity = torch.max(target, dim=1).values
    return torch.nn.functional.mse_loss(expected_similarity, predicted_similarity)

# Ad-hoc check of the mse_loss defined in this cell on one batch.
# NOTE(review): relies on `batch_pred` and `batch_y` left in the kernel by the
# training-loop cell; this fails on a fresh kernel unless that cell ran first.
mse_loss(batch_pred, batch_y.to(args.device))

tensor(0.1250, device='mps:0', grad_fn=<MseLossBackward0>)