In [1]:
import pickle
from transformers import BertTokenizer, BertModel
from torch.nn.utils.rnn import pad_sequence
from torch import nn
import torch

In [2]:
class CustomTokenizer:
    """Convert raw sentences into a padded batch of vocabulary indices.

    Each sentence is split into tokens by the pretrained
    ``bert-base-uncased`` tokenizer. Every token is then looked up in the
    supplied ``vocabulary``; tokens that are not present map to the index
    stored under the ``"<UNK>"`` key.
    """

    def __init__(self, vocabulary):
        """Store the vocabulary and load the pretrained BERT tokenizer.

        Args:
            vocabulary: Mapping from token string to integer index. It must
                contain a ``"<UNK>"`` entry if out-of-vocabulary tokens can
                occur (a ``KeyError`` is raised otherwise, at lookup time).
        """
        self.vocabulary = vocabulary
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def tokenize(self, sentences: list):
        """Encode ``sentences`` as a right-padded ``torch.long`` index tensor.

        Args:
            sentences: Iterable of sentence strings. A single bare string is
                treated as a one-sentence batch instead of being iterated
                character by character.

        Returns:
            A 2-D ``torch.long`` tensor with one row per sentence, padded on
            the right with ``0`` up to the longest sentence in the batch.
            An empty batch yields a tensor of shape ``(0, 0)``.

        Raises:
            KeyError: If an unknown token is met and ``"<UNK>"`` is not in
                the vocabulary.
        """
        if isinstance(sentences, str):
            sentences = [sentences]

        vocabulary = self.vocabulary
        encoded = []
        for sentence in sentences:
            token_ids = [
                vocabulary[token] if token in vocabulary else vocabulary["<UNK>"]
                for token in self.tokenizer.tokenize(sentence)
            ]
            # Force an integer dtype: torch.tensor([]) would otherwise be
            # float32 for a sentence with no tokens, and pad_sequence could
            # then produce a float batch that an embedding lookup rejects.
            encoded.append(torch.tensor(token_ids, dtype=torch.long))

        if not encoded:
            # pad_sequence raises on an empty list; return an empty batch.
            return torch.empty((0, 0), dtype=torch.long)

        # pad_sequence already returns a fresh tensor, so it is returned
        # directly instead of being re-wrapped in torch.tensor(...), which
        # only added a copy and a warning.
        return pad_sequence(encoded, batch_first=True, padding_value=0)

In [3]:
class BertBatchEmbedding:
    """Embed batches of sentences with every hidden layer of pretrained BERT.

    Wraps ``bert-base-uncased`` (loaded with ``output_hidden_states=True``
    and switched to eval mode) together with its matching tokenizer.
    """

    def __init__(self, device=None):
        """Load the pretrained model and tokenizer.

        Args:
            device: Device to run the model on (string or ``torch.device``).
                Defaults to CUDA when it is available and to the CPU
                otherwise, so the class no longer crashes on machines
                without a GPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True).eval().to(self.device)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def transform(self, sentences):
        """Return the stacked per-layer hidden states for ``sentences``.

        Args:
            sentences: Batch of sentence strings handed to the tokenizer's
                ``batch_encode_plus``.

        Returns:
            A 4-D tensor: the hidden-state tensors are stacked on a new
            leading axis and then permuted with ``(1, 2, 0, 3)``. Assuming
            each hidden state is ``(batch, seq, hidden)``, the result is
            ``(batch, seq, layer, hidden)`` — TODO confirm against the
            installed transformers version.
        """
        # NOTE(review): `pad_to_max_length` is deprecated in recent
        # transformers releases in favour of `padding=True`; kept unchanged
        # because the installed version is not visible from this notebook.
        padded_sequence = self.tokenizer.batch_encode_plus(sentences, return_tensors="pt", pad_to_max_length=True)
        input_ids = padded_sequence['input_ids'].to(self.device)
        attention_mask = padded_sequence["attention_mask"].to(self.device)

        # Gradients are not disabled here; parameters() is exposed,
        # presumably so callers can fine-tune the model — confirm.
        out = self.model(input_ids, attention_mask)

        # Third output element; the model was loaded with
        # output_hidden_states=True so the hidden states are returned.
        hidden_states = out[2]
        token_embeddings = torch.stack(hidden_states, dim=0)
        return token_embeddings.permute(1, 2, 0, 3)

    def parameters(self):
        """Return the parameters of the wrapped BERT model."""
        return self.model.parameters()

In [9]:
class CustomBertEmbedding:
    """Static lookup-table embedding built from pickled weights.

    Loads an embedding matrix and a token-to-index vocabulary from pickle
    files, wraps the matrix in a frozen ``nn.Embedding`` and reduces each
    looked-up vector by averaging its trailing ``num_layers`` chunks of
    ``hidden_size`` values.
    """

    def __init__(self,
                 weights_path='datasets/embeddings/weights.pickle',
                 vocab_path='datasets/embeddings/vocab.pickle'):
        """Load the embedding weights and vocabulary from disk.

        Args:
            weights_path: Pickle file holding the embedding matrix.
            vocab_path: Pickle file holding the token-to-index vocabulary.
        """
        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file. Only point these paths at files you created or fully trust.
        with open(weights_path, 'rb') as handle:
            weights = torch.tensor(pickle.load(handle))

        with open(vocab_path, 'rb') as handle:
            vocabulary = pickle.load(handle)

        # Index 0 is the padding row, matching the padding_value=0 used by
        # CustomTokenizer.tokenize.
        self.model = nn.Embedding.from_pretrained(weights, padding_idx=0, freeze=True)
        self.tokenizer = CustomTokenizer(vocabulary)

    def embeddings(self, text):
        """Tokenize ``text``, look up its embeddings and reduce them.

        Args:
            text: Batch of sentence strings passed to the tokenizer.

        Returns:
            The output of :meth:`encode` applied to the looked-up vectors.
        """
        token_ids = self.tokenizer.tokenize(text)
        return self.encode(self.model(token_ids))

    def encode(self, embedded_text, hidden_size=768, num_layers=4):
        """Average the trailing ``num_layers`` chunks of each vector.

        Args:
            embedded_text: 3-D tensor whose last axis holds at least
                ``hidden_size * num_layers`` values.
            hidden_size: Width of one chunk (default 768).
            num_layers: Number of trailing chunks to average (default 4,
                i.e. the last 3072 values with the default width).

        Returns:
            Tensor of shape ``(dim0, dim1, hidden_size)``.
        """
        dim0, dim1 = embedded_text.shape[:2]
        # Keep only the last hidden_size * num_layers features, view them
        # as (num_layers, hidden_size) and average over the layer axis.
        tail = embedded_text[:, :, -hidden_size * num_layers:]
        return tail.reshape(dim0, dim1, num_layers, hidden_size).mean(dim=2)

    def parameters(self):
        """Return the parameters of the wrapped ``nn.Embedding``."""
        return self.model.parameters()

In [10]:
# Small sample batch used to smoke-test the embedding pipeline.
# NOTE(review): "happpening" is misspelled — presumably deliberate, to
# exercise handling of unusual tokens; confirm before correcting.
sentences = [
    "We have a problem.",
    "What is happpening",
    "He was lying",
]

In [11]:
# Build the embedder; the constructor reads the pickled weights and
# vocabulary from datasets/embeddings/.
custom = CustomBertEmbedding()

In [12]:
# Embed the sample batch and display the result's shape; encode() reduces
# the last axis to 768, so the shape is (n_sentences, padded_len, 768).
custom.embeddings(sentences).shape



torch.Size([3, 6, 768])