In [1]:
import json
class obj:
    # constructor
    def __init__(self, dict1):
        self.__dict__.update(dict1)
args = {
    "candidates": 20,
    "device": 'mps',
    "dev_dir":'datasets/development/processed_dev',
    "train_dir":'datasets/development/processed_dev',
    "dictionary_path": 'datasets/development/dev_dictionary.txt',
    "max_length": 25,
    "model_name_or_path": 'dmis-lab/biobert-base-cased-v1.1',
    "batch_size": 16,
    "epochs": 1,
    "loss_fn": "mse5",
    "contextualized": False,
    "similarity_type": 'binary'
}
args = json.loads(json.dumps(args), object_hook=obj)
vars(args)

from importlib import reload

import numpy as np
import time
import torch
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer
)

# Local modules
from src.candidateDataset import CandidateDataset
from src.rerankNet import RerankNet
from src.umls import Umls
import src.utils as utils

  assert(sum([len(ixs)!=2 for ixs in token_ixs]), f"Offsets not lining up for mention in {file}")


In [2]:
LOGGER = utils.init_logging()
LOGGER.info(args)
utils.init_seed(42)
bert = AutoModel.from_pretrained(args.model_name_or_path).to(args.device)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

# Set loss function
if args.loss_fn=='nll':
    loss_fn = utils.marginal_nll
elif args.loss_fn=='mse':
    loss_fn = utils.mse_loss
elif args.loss_fn=='mse5':
    loss_fn = utils.mse5_loss
else:
    raise Exception(f"Invalid loss function {args.loss_fn}")
    
# Build model
model = RerankNet(bert, device = args.device)

# Load UMLS data
umls = Umls('umls/processed')
LOGGER.info("UMLS data loaded")

# Load dictionary
dictionary = utils.load_dictionary(args.dictionary_path)
LOGGER.info("Dictionary loaded")

# Load training data
train_mentions = utils.load_mentions(args.train_dir)
train_set = CandidateDataset(train_mentions, dictionary, tokenizer, args.max_length, args.candidates, args.similarity_type, umls) 
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)

# Load dev data for validation
dev_mentions = utils.load_mentions(args.dev_dir)
LOGGER.info("Mentions loaded")

07/15/2022 11:01:51 AM: [ <__main__.obj object at 0x1101c8970> ]
Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
07/15/2022 11:01:55 AM: [ UMLS data loaded 

In [3]:
# Training loop
for epoch in range(args.epochs):
        ############## Candidate Generation ##############
        train_candidate_idxs = utils.get_topk_candidates(
                dict_names=list(dictionary[:,0]), 
                mentions=train_mentions, 
                tokenizer=tokenizer, 
                encoder=bert, 
                max_length=args.max_length, 
                device=args.device, 
                topk=args.candidates)
                                
        # Add candidates to training dataset
        train_set.set_candidate_idxs(train_candidate_idxs)

        ###################### Train ######################
        # Train encoder to properly rank candidates
        train_loss = 0
        train_steps = 0
        model.train()
        for i, data in tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Training epoch {epoch}'):
                model.optimizer.zero_grad()
                batch_x, batch_y = data
                batch_pred = model(batch_x)
                loss = loss_fn(batch_pred, batch_y.to(args.device))
                # loss.backward() #TODO: Not working on the mac
                model.optimizer.step()
                train_loss += loss.item()
                train_steps += 1

        train_loss = train_loss / (train_steps + 1e-9)
        LOGGER.info('Epoch {}: loss/train_per_epoch={}/{}'.format(epoch,train_loss,epoch))

        #################### Evaluate ####################
        # Get candidates on dev dataset
        dev_candidate_idxs = utils.get_topk_candidates(
                dict_names=list(dictionary[:,0]), 
                mentions=dev_mentions, 
                tokenizer=tokenizer, 
                encoder=bert, 
                max_length=args.max_length, 
                device=args.device, 
                topk=5) # Only need top five candidates to evaluate performance

        # Log performance on dev after each epoch
        results = utils.evaluate(dev_mentions, dictionary[dev_candidate_idxs], umls)
        if 'acc1' in results: LOGGER.info("Epoch {}: acc@1={}".format(epoch,results['acc1']))
        if 'acc5' in results: LOGGER.info("Epoch {}: acc@5={}".format(epoch,results['acc5']))
        if 'umls_similarity' in results: LOGGER.info("Epoch {}: umls_similarity={}".format(epoch,results['umls_similarity']))

Bulk embedding...: 100%|██████████| 1/1 [00:01<00:00,  1.63s/it]
Bulk embedding...: 100%|██████████| 1/1 [00:00<00:00,  3.75it/s]
Training epoch 0: 100%|██████████| 1/1 [00:02<00:00,  2.11s/it]
07/15/2022 11:02:05 AM: [ Epoch 0: loss/train_per_epoch=3.2499999967499997/0 ]
Bulk embedding...: 100%|██████████| 1/1 [00:03<00:00,  3.04s/it]
Bulk embedding...: 100%|██████████| 1/1 [00:00<00:00,  3.24it/s]
07/15/2022 11:02:09 AM: [ Epoch 0: acc@1=0.875 ]
07/15/2022 11:02:09 AM: [ Epoch 0: acc@5=0.875 ]
07/15/2022 11:02:09 AM: [ Epoch 0: umls_similarity=0.9166666666666666 ]


In [397]:
import src.utils as utils
reload(utils)

  # Check all annotations were fixed


<module 'src.utils' from '/Users/evan/code/thesis/src/utils.py'>

In [5]:
np.array([1,2],np.float32)

array([1., 2.], dtype=float32)

In [131]:
def mse_loss(score, target):
    "Calculates MSE loss between max similarity of the candidates and similarity of top prediction"
    # Find similarity of the top prediction
    pred_ixs = score.argmax(dim=1)
    predicted_similarity = torch.gather(target, 1, pred_ixs.unsqueeze_(dim=1)).squeeze()

    # Find max similarity for each mention of the available candidates
    expected_similarity = torch.max(target, dim=1).values
    return torch.nn.functional.mse_loss(expected_similarity, predicted_similarity)
    
mse_loss(batch_pred, batch_y)

tensor(0.2500)

In [125]:
torch.nn.functional.mse_loss(expected, predicted)

tensor(0.2500)

In [40]:
predicted_ixs = utils.retrieve_candidates(batch_pred.detach().cpu(), topk=5)
predicted = torch.stack([batch_y[i][predicted_ixs[i]] for i in range(batch_pred.shape[0])]).requires_grad_()
predicted

tensor([[0., 0., 1., 1., 0.],
        [0., 1., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0.]], requires_grad=True)

In [41]:
torch.nn.functional.binary_cross_entropy(predicted, expected)

tensor(17.5000, grad_fn=<BinaryCrossEntropyBackward0>)

In [62]:
input, torch.sigmoid(input), target

(tensor([[ 0.1844, -0.1064],
         [ 0.1777,  1.3448],
         [-0.6941, -1.8600]], requires_grad=True),
 tensor([[0.5460, 0.4734],
         [0.5443, 0.7933],
         [0.3331, 0.1347]], grad_fn=<SigmoidBackward0>),
 tensor([[0.7057, 0.1659],
         [0.3779, 0.1089],
         [0.0501, 0.6333]]))