In [3]:
import json
class obj:
    """Attribute-style view over a dictionary.

    Every key of the mapping handed to the constructor becomes an instance
    attribute, e.g. ``obj({"a": 1}).a == 1``. Used below as a JSON
    ``object_hook`` so decoded dicts can be read with dot access.
    """

    def __init__(self, dict1):
        # Copy the mapping's items straight into the instance namespace.
        vars(self).update(dict1)
# Run configuration, gathered in one place so it is easy to tweak.
args = dict(
    candidates=20,
    device='mps',
    path='datasets/development',
    max_length=25,
    batch_size=16,
    epochs=1,
    loss_fn="nll",
    contextualized=False,
    similarity_type='binary',
)

# Round-trip through JSON so every decoded dict is handed to `obj`,
# which exposes the settings as attributes (args.device, args.path, ...).
args = json.loads(json.dumps(args), object_hook=obj)

# Display the resulting settings as the cell output.
vars(args)

{'candidates': 20,
 'device': 'mps',
 'path': 'datasets/development',
 'max_length': 25,
 'batch_size': 16,
 'epochs': 1,
 'loss_fn': 'nll',
 'contextualized': False,
 'similarity_type': 'binary'}

In [4]:
from importlib import reload

import numpy as np
import time
import torch
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer
)

# Local modules
from src.candidateDataset import CandidateDataset
from src.rerankNet import RerankNet
from src.umls import Umls
import src.utils as utils

# TODO: Add additional loss functions
loss_fn = utils.marginal_nll

# Logging and seeding happen before anything else is loaded
LOGGER = utils.init_logging()
utils.init_seed(42)

# Pretrained encoder (moved to the configured device) and its tokenizer
model_name_or_path = 'dmis-lab/biobert-base-cased-v1.1'
bert = AutoModel.from_pretrained(model_name_or_path).to(args.device)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Reranking model built around the encoder
model = RerankNet(
    bert,
    device=args.device,
    loss_fn='nll',
)

# UMLS data
umls = Umls('umls/processed')

# Dictionary
dictionary = utils.load_dictionary(args.path + '/dev_dictionary.txt')
LOGGER.info("Dictionary loaded")

# Training mentions -> dataset -> shuffled loader
# NOTE(review): the training and dev mentions below are both read from
# 'processed_dev' -- confirm this is intentional for this development setup.
train_mentions = utils.load_mentions(args.path + '/processed_dev')
train_set = CandidateDataset(
    train_mentions,
    dictionary,
    tokenizer,
    args.max_length,
    args.candidates,
    args.similarity_type,
    umls,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
)

# Dev mentions -> dataset -> shuffled loader
dev_mentions = utils.load_mentions(args.path + '/processed_dev')
dev_set = CandidateDataset(
    dev_mentions,
    dictionary,
    tokenizer,
    args.max_length,
    args.candidates,
    args.similarity_type,
    umls,
)
dev_loader = torch.utils.data.DataLoader(
    dev_set,
    batch_size=args.batch_size,
    shuffle=True,
)
LOGGER.info("Mentions loaded")

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 179/179 [00:00<00:00, 1171264.30it/s]
07/14/2022 12:33:07 PM: [ Dictionary loaded ]
100%|██

In [3]:
# Begin Training loop
# Per epoch: (1) retrieve the top-k dictionary candidates for every training mention,
# (2) train the reranker on those candidates, (3) retrieve candidates for the dev
# mentions and log the evaluation metrics.
LOGGER.info("train!")
start = time.time()  # wall-clock start of training; not read inside this cell
for epoch in range(args.epochs):
    ############## Candidate Generation ##############
    train_candidate_idxs = utils.get_topk_candidates(
            dict_names=list(dictionary[:,0]),
            mentions=train_mentions,
            tokenizer=tokenizer,
            encoder=bert,
            max_length=args.max_length,
            device=args.device,
            topk=args.candidates)

    # Add candidates to training dataset
    train_set.set_candidate_idxs(train_candidate_idxs)

    ###################### Train ######################
    # Train encoder to properly rank candidates
    train_loss = 0
    train_steps = 0
    model.train()
    for batch_x, batch_y in tqdm(train_loader, total=len(train_loader), desc=f'Training epoch {epoch}'):
        model.optimizer.zero_grad()
        batch_pred = model(batch_x)
        loss = loss_fn(batch_pred, batch_y.to(args.device))
        # Back-propagate before stepping: without this call no gradients exist and
        # optimizer.step() leaves every weight untouched, i.e. nothing is trained.
        # If an operator is missing on the MPS backend, PyTorch's error message suggests
        # setting PYTORCH_ENABLE_MPS_FALLBACK=1 (or run with a different args.device).
        loss.backward()
        model.optimizer.step()
        train_loss += loss.item()
        train_steps += 1

    # The epsilon keeps the division defined when the loader yields no batches
    train_loss = train_loss / (train_steps + 1e-9)
    LOGGER.info('Epoch {}: loss/train_per_epoch={}/{}'.format(epoch,train_loss,epoch))

    #################### Evaluate ####################
    # Leave training mode before embedding the dev mentions (and before the next
    # epoch's candidate generation).
    # NOTE(review): assumes `bert` is registered as a submodule of RerankNet, so that
    # model.train()/model.eval() toggle the encoder as well -- confirm.
    model.eval()

    # Get candidates on dev dataset
    dev_candidate_idxs = utils.get_topk_candidates(
            dict_names=list(dictionary[:,0]),
            mentions=dev_mentions,
            tokenizer=tokenizer,
            encoder=bert,
            max_length=args.max_length,
            device=args.device,
            topk=5) # Only need top five candidates to evaluate performance

    # Log performance on dev after each epoch
    results = utils.evaluate(dev_mentions, dictionary[dev_candidate_idxs], umls)
    if 'acc1' in results: LOGGER.info("Epoch {}: acc@1={}".format(epoch,results['acc1']))
    if 'acc5' in results: LOGGER.info("Epoch {}: acc@5={}".format(epoch,results['acc5']))
    if 'umls_similarity' in results: LOGGER.info("Epoch {}: umls_similarity={}".format(epoch,results['umls_similarity']))


07/14/2022 10:58:18 AM: [ train! ]
Bulk embedding...: 100%|██████████| 1/1 [00:04<00:00,  4.31s/it]
Bulk embedding...: 100%|██████████| 1/1 [00:01<00:00,  1.73s/it]
Training epoch 0:   0%|          | 0/1 [00:10<?, ?it/s]


NotImplementedError: The operator 'aten::index.Tensor' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [397]:
# Re-import the local utils module and reload it in place, so the kernel picks up
# the current contents of src/utils.py. `reload` returns the module object, which is
# what this cell displays as its output.
import src.utils as utils
reload(utils)

  # Check all annotations were fixed


<module 'src.utils' from '/Users/evan/code/thesis/src/utils.py'>

In [398]:
# n2c2 test split: load its mentions, then run the contextualized bulk embedding
# over them. The array returned by the call is the cell's displayed output.
doc_dir = 'datasets/n2c2/processed_test'
test_mentions = utils.load_mentions(doc_dir)
utils.bulk_embed_contextualized(
    test_mentions,
    bert,
    tokenizer,
    doc_dir,
    max_length=300,
    device='mps',
    show_progress=True,
)

100%|██████████| 40/40 [00:00<00:00, 4524.48it/s]
Bulk embedding contextualized...:   5%|▌         | 2/40 [00:26<09:01, 14.24s/it]

Bad annotation in 0198
154 tensor([21947])


Bulk embedding contextualized...: 100%|██████████| 40/40 [29:05<00:00, 43.63s/it]   


array([[-0.15507695,  0.06025336, -0.22768542, ..., -0.16558737,
         0.19927882,  0.16621587],
       [-0.01320879,  0.08289725, -0.12982935, ...,  0.49201033,
        -0.26241493, -0.01202229],
       [-0.15307908, -0.1627849 , -0.13408834, ...,  0.20175664,
        -0.19037403, -0.09205182],
       ...,
       [ 0.07421985, -0.19934276,  0.7217341 , ...,  0.38563707,
        -0.3801183 ,  0.42218462],
       [ 0.01618759, -0.21299045, -0.23948702, ...,  0.46053863,
         0.15869279,  0.48101687],
       [ 0.08568915,  0.19834511, -0.04924138, ...,  0.17086448,
        -0.00849999, -0.03946758]], dtype=float32)

In [399]:
# n2c2 dev split: load its mentions, then run the contextualized bulk embedding
# over them. The array returned by the call is the cell's displayed output.
# NOTE(review): this rebinds `dev_mentions`, which an earlier cell loaded from args.path.
doc_dir = 'datasets/n2c2/processed_dev'
dev_mentions = utils.load_mentions(doc_dir)
utils.bulk_embed_contextualized(
    dev_mentions,
    bert,
    tokenizer,
    doc_dir,
    max_length=300,
    device='mps',
    show_progress=True,
)

100%|██████████| 10/10 [00:00<00:00, 776.22it/s]
Bulk embedding contextualized...: 100%|██████████| 10/10 [04:04<00:00, 24.46s/it]


array([[ 0.09808334, -0.19211553,  0.4358952 , ...,  0.3686808 ,
         0.25733262, -0.30773503],
       [-0.1085677 ,  0.24839644,  0.2609251 , ...,  0.20432799,
         0.38009298, -0.00730523],
       [-0.2050771 ,  0.46604535,  0.47189686, ...,  0.05825787,
         0.06764286,  0.06848229],
       ...,
       [ 0.29949102, -0.07483034,  0.14035282, ...,  0.14551571,
         0.7869169 ,  0.2301968 ],
       [-0.04362535,  0.40494886, -0.18296073, ...,  0.1948201 ,
         0.37076458, -0.00378272],
       [ 0.04414786, -0.03230633,  0.33434194, ...,  0.01382919,
        -0.04408884, -0.10813165]], dtype=float32)

In [400]:
# n2c2 train split: load its mentions, then run the contextualized bulk embedding
# over them.
# NOTE(review): this rebinds `train_mentions`, which an earlier cell loaded from args.path.
doc_dir = 'datasets/n2c2/processed_train'
train_mentions = utils.load_mentions(doc_dir)
utils.bulk_embed_contextualized(
    train_mentions,
    bert,
    tokenizer,
    doc_dir,
    max_length=300,
    device='mps',
    show_progress=True,
)

100%|██████████| 50/50 [00:00<00:00, 2799.75it/s]
Bulk embedding contextualized...:   6%|▌         | 3/50 [03:03<58:47, 75.05s/it]

In [391]:

def bulk_embed_contextualized(mentions, file, encoder, tokenizer, doc_dir, max_length, device, show_progress=True):
    """Embed every annotated mention of one document using in-document context.

    The document ``{doc_dir}/{file}.txt`` is read line by line, each line is
    tokenized and encoded separately, and a mention's embedding is the mean of the
    encoder's output vectors over the tokens its character span covers.

    Args:
        mentions: 2-D array; column 3 is compared against ``file`` to select the
            document's rows and column 2 holds the character span as ``"start|end"``.
        file: document identifier (file name without the ``.txt`` suffix).
        encoder: callable invoked as ``encoder(**tokens)``; element ``[0]`` of its
            result is used as the per-token output (sentences x tokens x hidden).
        tokenizer: callable returning a mapping that includes ``offset_mapping``
            and supports ``.to(device)``.
        doc_dir: directory containing the document text files.
        max_length: padded/truncated token length of each line.
        device: device the tokenized inputs are moved to before encoding.
        show_progress: when true (default), print the file identifier being processed.

    Returns:
        ``np.ndarray`` of shape ``(number of mentions in file, hidden size)``.

    Raises:
        AssertionError: if a mention's span still does not line up with exactly one
            start token and one end token after the boundary-snapping repair.
    """
    if show_progress:
        print(file)

    # Tokenize the entire document, one line ("sentence") per batch row
    with open(f'{doc_dir}/{file}.txt') as f:
        doc = f.readlines()
    doc_tokens = tokenizer(doc, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", return_offsets_mapping=True)

    # The offsets are only needed here; they must not be forwarded to the encoder
    offsets = doc_tokens.pop('offset_mapping')

    # Character position at which each line starts within the whole document
    sentence_starts = np.zeros(len(doc), dtype=np.int64)
    sentence_starts[1:] = np.cumsum([len(line) for line in doc[:-1]], dtype=np.int64)

    # Turn line-level token offsets into document-level ones and drop the sentence
    # dimension, so that row k describes flattened token k
    offsets = torch.as_tensor(offsets.numpy() + sentence_starts[:, None, None], dtype=torch.int32)
    offsets = offsets.reshape(-1, 2)

    # Character-level mention spans of this document, from the annotations
    file_mask = mentions[:,3]==file
    mention_offsets = mentions[:,2][file_mask]
    mention_offsets = torch.IntTensor([list(map(int,l.split('|'))) for l in mention_offsets])

    # Rows whose start equals their end (e.g. padding) carry no text and are ignored
    is_token = (offsets[:,0]!=offsets[:,1]).unsqueeze(1)

    def _match(span):
        # Flattened token indexes whose start equals span[0] or whose end equals span[1].
        # A well-aligned mention yields exactly two entries (first and last token; the
        # same index twice for a single-token mention).
        return ((offsets==span) & is_token).nonzero(as_tuple=True)[0]

    token_ixs = [_match(span) for span in mention_offsets]

    # If offsets do not match tokens, we have to go in and fix them. Unfortunately, this does happen
    misaligned = [i for i, ixs in enumerate(token_ixs) if len(ixs)!=2]
    if misaligned:
        LOGGER.info(f"Bad annotation in {file}")
        for i in misaligned:
            print(i, token_ixs[i])
            char_start, char_end = mention_offsets[i][0].item(), mention_offsets[i][1].item()
            # Widen the span to the nearest token start at/before it and the nearest
            # token end at/after it
            starts = offsets[:,0][offsets[:,0]<=char_start]
            ends = offsets[:,1][offsets[:,1]>=char_end]
            if len(starts)==0 or len(ends)==0:
                # Nothing to snap to; the check below reports the mention
                continue
            token_ixs[i] = _match(torch.stack([starts.max(), ends.min()]))

    # Check all annotations were fixed
    unresolved = [i for i, ixs in enumerate(token_ixs) if len(ixs)!=2]
    assert not unresolved, f"Offsets not lining up for mention in {file}"

    with torch.no_grad():
        # Encode each sentence within doc
        doc_tokens = doc_tokens.to(device)
        outputs = encoder(**doc_tokens)

        # Flatten the sentence dimension; the hidden size is taken from the output
        hidden_states = outputs[0]
        embedding = hidden_states.reshape(-1, hidden_states.shape[-1])

        if token_ixs:
            # Average embeddings for all tokens in each mention; [|mentions|, hidden]
            doc_embeds = torch.stack([embedding[s.item():e.item()+1].mean(0) for s,e in token_ixs])
        else:
            # No mentions annotated for this file: return an empty [0, hidden] result
            doc_embeds = embedding[:0]

    return doc_embeds.cpu().detach().numpy()

  assert(sum([len(ixs)!=2 for ixs in token_ixs]), f"Offsets not lining up for mention in {file}")
