In [1]:
import os
import csv
import re
import pandas as pd
import numpy as np
import torch
from collections import OrderedDict

In [2]:
# Load the pretrained 'bert-base-uncased' checkpoint. output_hidden_states=True
# asks the model to return the hidden states of every layer in its forward
# output; get_bert_embeddings reads them from there.
from transformers import BertModel, BertTokenizer
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True,)
# Tokenizer loaded from the same checkpoint so token ids match the model's vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
def bert_text_preparation(text, tokenizer):
    """
    Preprocesses text input in a way that BERT can interpret.

    The text is wrapped in "[CLS]" / "[SEP]" markers, tokenized, and converted
    to vocabulary ids. If the tokenizer exposes an integer
    ``model_max_length`` and the sequence is longer than that, it is truncated
    (keeping the final token, i.e. the "[SEP]" marker) and a warning is
    emitted, so that over-long inputs do not overflow the model's position
    embeddings.

    Args:
        text: Sentence to encode (str).
        tokenizer: Object providing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)``, e.g. a ``BertTokenizer``.

    Returns:
        Tuple ``(tokenized_text, tokens_tensor, segments_tensor)`` where
        ``tokenized_text`` is the list of tokens including the two markers,
        and both tensors have shape ``[1, len(tokenized_text)]``.
    """
    marked_text = "[CLS] " + text + " [SEP]"
    tokenized_text = tokenizer.tokenize(marked_text)

    # Guard against sequences longer than the model supports.
    max_length = getattr(tokenizer, "model_max_length", None)
    if isinstance(max_length, int) and 2 <= max_length < len(tokenized_text):
        import warnings
        warnings.warn(
            "Input of %d tokens exceeds the maximum of %d and was truncated."
            % (len(tokenized_text), max_length)
        )
        # keep the closing marker so the sequence stays well-formed
        tokenized_text = tokenized_text[:max_length - 1] + tokenized_text[-1:]

    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Every token is assigned segment id 1.
    # NOTE(review): Hugging Face BERT conventionally uses segment id 0 for the
    # first/only sentence; the value is kept at 1 to preserve existing behaviour.
    segments_ids = [1] * len(indexed_tokens)
    # convert inputs to tensors with a leading batch dimension of 1
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensor = torch.tensor([segments_ids])
    return tokenized_text, tokens_tensor, segments_tensor

In [4]:
def get_bert_embeddings(tokens_tensor, segments_tensor, model):
    """
    Obtains BERT embeddings for tokens.

    Runs a single sentence through ``model`` and returns, for every token,
    the sum of that token's hidden states over the last four layers.

    Args:
        tokens_tensor: Token ids, shape ``[1, num_tokens]``.
        segments_tensor: Segment (token type) ids, shape ``[1, num_tokens]``.
        model: Callable accepting ``input_ids`` and ``token_type_ids`` keyword
            arguments whose output exposes the per-layer hidden states at
            index 2 (a ``BertModel`` loaded with ``output_hidden_states=True``).

    Returns:
        List with one 1-D tensor (length = hidden size) per input token.

    Raises:
        ValueError: If the input is not a batch of exactly one sentence.
    """
    # gradient calculation is disabled: this is inference only
    with torch.no_grad():
        # Keyword arguments are required here: the second positional parameter
        # of Hugging Face's BertModel.forward is attention_mask, so passing the
        # segment ids positionally would silently use them as the attention mask.
        outputs = model(input_ids=tokens_tensor, token_type_ids=segments_tensor)
        hidden_states = outputs[2]
    # Stack the per-layer outputs into one tensor of shape
    # [num_layers, batch, num_tokens, hidden]. For Hugging Face BERT the tuple
    # holds the embedding output plus one entry per encoder layer.
    layers = torch.stack(hidden_states, dim=0)
    # remove dimension 1, the "batches"
    layers = torch.squeeze(layers, dim=1)
    if layers.dim() != 3:
        raise ValueError("get_bert_embeddings expects a batch of exactly one sentence")
    # sum the vectors from the last four layers -> [num_tokens, hidden]
    summed = layers[-4:].sum(dim=0)
    # one vector per token, in sentence order
    return [token_vec for token_vec in summed]

In [5]:
def wrap_embeddings(sentences):
    """
    Computes contextual token embeddings for a list of sentences.

    Each sentence is tokenized with the module-level ``tokenizer`` and embedded
    with the module-level ``model``. The "[CLS]" / "[SEP]" markers are dropped,
    and after every sentence a delimiter entry is appended: the token "@@@"
    paired with an all-zero integer vector.

    Args:
        sentences: Iterable of sentence strings.

    Returns:
        Tuple ``(context_embeddings, context_tokens)``: two parallel lists
        holding one tensor and one token string per entry.
    """
    context_embeddings = []
    context_tokens = []
    for sentence in sentences:
        tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(sentence, tokenizer)
        list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)

        # The embedding list is aligned index-for-index with tokenized_text, so
        # walking positions 1..len-2 skips the two markers and pairs every token
        # with its own embedding, including repeated tokens.
        for position in range(1, len(tokenized_text) - 1):
            context_tokens.append(tokenized_text[position])
            context_embeddings.append(list_token_embeddings[position])

        # add delimiter; its width follows the embedding size of this sentence
        hidden_size = list_token_embeddings[0].shape[0]
        context_tokens.append("@@@")
        context_embeddings.append(torch.tensor([0]).repeat(hidden_size))
    return context_embeddings, context_tokens

In [9]:
# Collect the input files whose names start with "sentences". os.listdir
# returns entries in arbitrary order, so the list is sorted to make the
# txt_files[0:1] selection reproducible across runs and machines.
txt_files = sorted(f for f in os.listdir('../../output/loco/data/') if re.match(r'sentences', f))
txt_files[0:1]

['sentences.txt']

In [15]:
# For each selected sentence file, embed all sentences and write two files:
#   metadata_<stem>.tsv   - one token per line
#   embeddings_<stem>.tsv - one tab-separated vector per line, same order
for txt_file in txt_files[0:1]:
    # NOTE(review): read_fwf with delimiter="@" is used as a line reader here;
    # presumably "@" never occurs in the sentences - confirm against the data.
    df = pd.read_fwf('../../output/loco/data/' + txt_file, header = None, delimiter = "@")
    sentences = df[0].tolist()
    context_embeddings, context_tokens = wrap_embeddings(sentences)

    filepath = '../../output/loco/embeddings/'
    # make sure the output directory exists before writing
    os.makedirs(filepath, exist_ok=True)
    # swap only the file extension; str.replace would also rewrite any other
    # "txt" occurring in the file name
    stem = os.path.splitext(txt_file)[0]

    name = 'metadata_' + stem + '.tsv'
    with open(os.path.join(filepath, name), 'w', encoding='utf-8') as file_metadata:
        file_metadata.writelines(token + '\n' for token in context_tokens)

    name = 'embeddings_' + stem + '.tsv'
    # newline='' lets the csv module control line endings (avoids blank rows
    # on platforms that translate "\n")
    with open(os.path.join(filepath, name), 'w', newline='', encoding='utf-8') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        writer.writerows(embedding.numpy() for embedding in context_embeddings)