In [1]:
import os

import torch
from transformers import BertTokenizer, BertModel

# Checkpoint name shared by the tokenizer and the model below.
BERT_MODEL = 'bert-base-cased'
# Alternative checkpoint (BioBERT); uncomment to use it instead.
# BERT_MODEL = 'dmis-lab/biobert-base-cased-v1.1'

# Keep the input's case: lower-casing is disabled explicitly for the cased checkpoint.
tokenizer = BertTokenizer.from_pretrained(
    BERT_MODEL,
    do_lower_case=False,
)

# output_hidden_states=True makes the forward pass also return the hidden
# states of every layer, not just the final one.
model = BertModel.from_pretrained(
    BERT_MODEL,
    output_hidden_states=True,
)

In [2]:
def get_hidden_states(model, tokenizer, marked_text):
    """Run ``model`` on ``marked_text`` and return its per-layer hidden states.

    ``marked_text`` is tokenized as-is, so any special tokens (``[CLS]``,
    ``[SEP]``) must already be present in the string.

    Args:
        model: a BERT model loaded with ``output_hidden_states=True``.
        tokenizer: the tokenizer matching ``model``.
        marked_text: input text, including its special tokens.

    Returns:
        ``outputs[2]`` of the forward pass, i.e. the hidden-states tuple when the
        model was created with ``output_hidden_states=True``. The input is a
        batch of one sequence, so each element is presumably
        ``(1, seq_len, hidden_size)`` — per the transformers docs.
    """
    tokenized_text = tokenizer.tokenize(marked_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

    # Batch of one sequence; explicit long dtype so ids are never built as floats.
    tokens_tensor = torch.tensor([indexed_tokens], dtype=torch.long)
    # Single-sentence input: attend to every token, and every token belongs to
    # segment 0.
    attention_mask = torch.ones_like(tokens_tensor)
    token_type_ids = torch.zeros_like(tokens_tensor)

    with torch.no_grad():
        # Pass everything by keyword: in `transformers` the second positional
        # parameter of BertModel.forward is `attention_mask`, not
        # `token_type_ids`, so positional "segment ids" would be misread.
        outputs = model(
            input_ids=tokens_tensor,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
    # Index 2 of the output is the hidden-states tuple when output_hidden_states=True.
    return outputs[2]

In [3]:
def get_embeddings(text, bert_model=None, bert_tokenizer=None):
    """Return per-token, per-layer activations for ``text``.

    ``text`` is wrapped here as ``"[CLS] " + text + " [SEP]"``, so callers must
    pass raw text without special tokens.

    Args:
        text: raw input text.
        bert_model: model to run; defaults to the notebook-level ``model``.
        bert_tokenizer: tokenizer to use; defaults to the notebook-level
            ``tokenizer``.

    Returns:
        A tensor indexed as ``[token_position][hidden_state_index]``, i.e. shape
        ``(seq_len, n_hidden_states, hidden_size)`` (assuming each hidden state is
        ``(1, seq_len, hidden_size)``). Position 0 is the ``[CLS]`` token.
    """
    # Resolve defaults at call time (not as `model=model` in the signature) so
    # re-running the loading cell is picked up without redefining this function.
    if bert_model is None:
        bert_model = model
    if bert_tokenizer is None:
        bert_tokenizer = tokenizer

    marked_text = "[CLS] " + text + " [SEP]"
    hidden_states = get_hidden_states(bert_model, bert_tokenizer, marked_text)

    # (n_hidden_states, 1, seq_len, hidden) -> drop the batch dim
    # -> (n_hidden_states, seq_len, hidden) -> (seq_len, n_hidden_states, hidden)
    token_embeddings = torch.stack(hidden_states, dim=0)
    token_embeddings = torch.squeeze(token_embeddings, dim=1)
    return token_embeddings.permute(1, 0, 2)

In [4]:
# Example input text; it is split into WordPiece tokens once, up front.
input_sentence = (
    "A 58-year-old African-American woman presents to the ER with episodic "
    "pressing/burning anterior chest pain that began two days earlier for the "
    "first time in her life. "
    "The pain started while she was walking, radiates to the back, and is "
    "accompanied by nausea, diaphoresis and mild dyspnea, but is not increased "
    "on inspiration. "
    "The latest episode of pain ended half an hour prior to her arrival. "
    "She is known to have hypertension and obesity. "
    "She denies smoking, diabetes, hypercholesterolemia, or a family history of "
    "heart disease. "
    "She currently takes no medications. "
    "Physical examination is normal. "
    "The EKG shows nonspecific changes."
)
tokenized_input = tokenizer.tokenize(input_sentence)

In [5]:
# Embed each word-initial WordPiece token on its own and collect one TSV row per
# token for the hidden state at index 1 and for the final hidden state.
# Labels are written in the same order as the rows.
first_embeddings = []
last_embeddings = []

os.makedirs("embeddings", exist_ok=True)

with open("embeddings/label.tsv", "w") as labels_file:
    for token in tokenized_input:
        # Skip continuation pieces ("##..."): each row is the first piece of a word.
        if token.startswith("##"):
            continue

        print(token, file=labels_file)

        # get_embeddings adds [CLS]/[SEP] itself. Wrapping the token here as
        # well would put a second [CLS] at position 1, and that would be
        # embedded instead of the token.
        token_embeddings = get_embeddings(token)

        # Position 1 is the token (position 0 is [CLS]), assuming the token
        # re-tokenizes to a single piece. Hidden-state index 1 is the first
        # encoder layer's output; -1 is the last layer for any model depth.
        first_embedding_layer = token_embeddings[1][1]
        last_embedding_layer = token_embeddings[1][-1]

        # Tab-join without a trailing separator so each row has exactly
        # hidden_size fields (no empty extra column).
        last_embeddings.append('\t'.join(map(str, last_embedding_layer.tolist())))
        first_embeddings.append('\t'.join(map(str, first_embedding_layer.tolist())))

In [7]:
# Write the collected vectors, one row per line. Rows were appended in the same
# loop that wrote embeddings/label.tsv, so line i of each file matches label i.
os.makedirs("embeddings/bert", exist_ok=True)

for tsv_path, tsv_rows in (
    ("embeddings/bert/first_embedding.tsv", first_embeddings),
    ("embeddings/bert/last_embedding.tsv", last_embeddings),
):
    with open(tsv_path, "w") as f:
        for row in tsv_rows:
            print(row, file=f)