In [1]:
import tensorflow as tf
from transformers import DistilBertTokenizerFast, TFDistilBertModel
import warnings
warnings.filterwarnings('ignore')
from transformers import logging
logging.set_verbosity_error()

In [2]:
# Corpus
s = ['Hello world', 'hello World']

In [3]:
checkpoint = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)

In [4]:
n_tokens = 10

In [5]:
tokens = tokenizer(s, 
                   max_length=n_tokens, 
                   padding="max_length", 
                   truncation=True, 
                   return_attention_mask=True, 
                   return_tensors='tf')

In [6]:
tokens

{'input_ids': <tf.Tensor: shape=(2, 10), dtype=int32, numpy=
array([[ 101, 7592, 2088,  102,    0,    0,    0,    0,    0,    0],
       [ 101, 7592, 2088,  102,    0,    0,    0,    0,    0,    0]])>, 'attention_mask': <tf.Tensor: shape=(2, 10), dtype=int32, numpy=
array([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])>}

In [7]:
model = TFDistilBertModel.from_pretrained(checkpoint, output_hidden_states=True)

In [8]:
# Get Model embeddings
model_output = model(input_ids=tokens['input_ids'], attention_mask=tokens['attention_mask'])

In [9]:
# Print model embeddings
model_output

TFBaseModelOutput(last_hidden_state=<tf.Tensor: shape=(2, 10, 768), dtype=float32, numpy=
array([[[-0.16980135, -0.16616774,  0.02564099, ..., -0.12552133,
          0.17291942,  0.4330536 ],
        [-0.42125538,  0.17608328,  0.49991077, ..., -0.1428831 ,
          0.6527468 ,  0.4612236 ],
        [-0.21821067, -0.22477578,  0.40346316, ...,  0.08059662,
          0.19904637,  0.00151423],
        ...,
        [-0.16411218, -0.30850208,  0.1660112 , ...,  0.22800869,
         -0.07948929,  0.25606138],
        [-0.21185686, -0.31004387,  0.15777205, ...,  0.23552594,
         -0.05401385,  0.2552799 ],
        [-0.24953173, -0.25728527,  0.18696974, ...,  0.15722668,
          0.03211151,  0.30800217]],

       [[-0.16980135, -0.16616774,  0.02564099, ..., -0.12552133,
          0.17291942,  0.4330536 ],
        [-0.42125538,  0.17608328,  0.49991077, ..., -0.1428831 ,
          0.6527468 ,  0.4612236 ],
        [-0.21821067, -0.22477578,  0.40346316, ...,  0.08059662,
          0.1

In [10]:
n_hidden_states = 3 # last n hidden states to consider
hidden_states = None
# constraining hidden states between 1 and 6. As DistilBERT has 6 hidden states
n_hidden_states = 6 if n_hidden_states > 6 else n_hidden_states # return 6 if n_hidden_states > 6
n_hidden_states = 1 if n_hidden_states < 1 else n_hidden_states # return 1 if n_hidden_states < 1
use_only_cls_token = True # Use only [CLS] or compute mean of all the token embeddings in sentence 

In [11]:
hidden_states = model_output[1][-n_hidden_states:] # Get last n hidden states

In [12]:
hidden_states

(<tf.Tensor: shape=(2, 10, 768), dtype=float32, numpy=
 array([[[-0.30223027, -0.5468851 , -0.38858387, ..., -0.566273  ,
           0.31462207,  0.61148447],
         [ 0.08165848,  0.35500383,  0.75847495, ...,  0.1939865 ,
           0.4375318 ,  0.15307696],
         [ 0.20826975,  0.03681483,  0.21884404, ..., -0.30036265,
           0.42666844, -0.49590847],
         ...,
         [ 0.04514224, -0.3245637 , -0.18558994, ..., -0.15828153,
          -0.08757351,  0.01658031],
         [-0.02783131, -0.35155678, -0.20163992, ..., -0.15443051,
          -0.11263376,  0.05099101],
         [-0.03329036, -0.32443106, -0.09240969, ..., -0.09645841,
           0.1388489 ,  0.05601536]],
 
        [[-0.30223027, -0.5468851 , -0.38858387, ..., -0.566273  ,
           0.31462207,  0.61148447],
         [ 0.08165848,  0.35500383,  0.75847495, ...,  0.1939865 ,
           0.4375318 ,  0.15307696],
         [ 0.20826975,  0.03681483,  0.21884404, ..., -0.30036265,
           0.42666844, -0.495

In [13]:
len(hidden_states)

3

In [14]:
hidden_states_tf = tf.convert_to_tensor(hidden_states)

In [15]:
hidden_states_tf.shape

TensorShape([3, 2, 10, 768])

In [16]:
# USE ONLY CLS START

In [17]:
cls_emb = hidden_states_tf[:,:,0,:] # Get 0th token i.e. [CLS] embedding

In [18]:
cls_emb.shape

TensorShape([3, 2, 768])

In [19]:
cls_emb[0]

<tf.Tensor: shape=(2, 768), dtype=float32, numpy=
array([[-0.30223027, -0.5468851 , -0.38858387, ..., -0.566273  ,
         0.31462207,  0.61148447],
       [-0.30223027, -0.5468851 , -0.38858387, ..., -0.566273  ,
         0.31462207,  0.61148447]], dtype=float32)>

In [20]:
doc2vec = tf.math.reduce_mean(cls_emb, axis=0) # compute mean of embeddings of all hidden states

In [21]:
doc2vec.shape

TensorShape([2, 768])

In [22]:
doc2vec

<tf.Tensor: shape=(2, 768), dtype=float32, numpy=
array([[-0.29142508, -0.40222403, -0.09847391, ..., -0.3118422 ,
         0.27015662,  0.5957739 ],
       [-0.29142508, -0.40222403, -0.09847391, ..., -0.3118422 ,
         0.27015662,  0.5957739 ]], dtype=float32)>

In [23]:
# USE ONLY CLS END

In [24]:
# USE ALL TOKENS START

In [25]:
# considering all the token embeddings except [CLS] and computing mean of token embeddings for selected hidden states
hidden_states_doc2vec = tf.map_fn(lambda x: tf.math.reduce_mean(x[:,1:,:], axis=1), hidden_states_tf)

In [26]:
hidden_states_doc2vec.shape

TensorShape([3, 2, 768])

In [27]:
# computing mean of embeddings of all hidden states
doc2vec = tf.math.reduce_mean(hidden_states_doc2vec, axis=0)

In [28]:
doc2vec.shape

TensorShape([2, 768])

In [29]:
doc2vec.numpy()

array([[-0.1435384 , -0.15883817,  0.12035283, ...,  0.02115129,
         0.00880054,  0.10272264],
       [-0.14353843, -0.15883818,  0.12035289, ...,  0.02115124,
         0.00880051,  0.10272262]], dtype=float32)

In [30]:
# USE ALL TOKENS END

In [31]:
class DistilBertDoc2Vec:
    '''
    Returns Doc2Vec embeddings of documents in a corpus.
    
    Params:
    n_hidden_states: int
     Number of hidden states of DistilBERT to consider to compute Doc2Vec
    
    n_tokens: int
     Max tokens in a document
     
    use_only_cls_token: bool
     Boolean specifying to use only [CLS] token embedding or compute mean of all the tokens
     
    return_tf_tensor: bool
     Boolean specifying if Doc2Vec embeddings must be returned as tf tensor or numpy array.
    '''
    def __init__(self, n_hidden_states: int=6, n_tokens: int=512, use_only_cls_token: bool=False,
                return_tf_tensor: bool=True):
        self.n_hidden_states = n_hidden_states
        self.n_tokens = n_tokens
        self.use_only_cls_token = use_only_cls_token
        self.return_tf_tensor = return_tf_tensor
        self.checkpoint = 'distilbert-base-uncased'
    
    @staticmethod
    def __get_valid_input(value: int, bounds: list[int]):
        '''
        Forces the value to be within the bounds.
        
        Params:
        value: int
         Value to be checked for validity.
         
        bounds: list[int]
         Bounds within which the value must be present
        '''
        value = bounds[1] if value > bounds[1] else value
        value = bounds[0] if value < bounds[0] else value
        return value
    
    def __get_tokens(self, corpus: list[str]):
        '''
        Tokenize the corpus.
        '''
        tokenizer = DistilBertTokenizerFast.from_pretrained(self.checkpoint)
        tokens = tokenizer(corpus, 
                           max_length=self.n_tokens, 
                           padding="max_length", 
                           truncation=True, 
                           return_attention_mask=True, 
                           return_tensors='tf')
        return tokens
    
    
    def __get_model_embeddings(self, corpus: list[str]):
        '''
        Return embeddings of the corpus
        '''
        model = TFDistilBertModel.from_pretrained(self.checkpoint, output_hidden_states=True)
        tokens = self.__get_tokens(corpus)
        model_embeddings = model(input_ids=tokens['input_ids'], attention_mask=tokens['attention_mask'])
        return model_embeddings
        
    def get_doc2vec(self, corpus: list[str]):
        '''
        Get Doc2Vec embeddings for the corpus.
        
        Params:
        corpus: list[str]
         List of documents.
         
        Returns:
        doc2vec
         Doc2Vec embeddings of the corpus
        '''
        n_hidden_states = self.__get_valid_input(self.n_hidden_states, [1,6])
        n_tokens = self.__get_valid_input(self.n_tokens, [1,512])
        model_embeddings = self.__get_model_embeddings(corpus)
        hidden_states = model_embeddings[1][-n_hidden_states:]
        hidden_states_tf = tf.convert_to_tensor(hidden_states)
        if self.use_only_cls_token:
            hidden_states_doc2vec = hidden_states_tf[:,:,0,:]
        else:
            hidden_states_doc2vec = tf.map_fn(lambda x: tf.math.reduce_mean(x[:,1:,:], axis=1), hidden_states_tf)
        
        doc2vec = tf.math.reduce_mean(hidden_states_doc2vec, axis=0)
        
        if self.return_tf_tensor:
            return doc2vec
        
        return doc2vec.numpy()

In [32]:
dbw2v = DistilBertDoc2Vec(use_only_cls_token=False)
dbw2v.get_doc2vec(s)

<tf.Tensor: shape=(2, 768), dtype=float32, numpy=
array([[-0.17648351, -0.12502472,  0.3354447 , ..., -0.08912349,
        -0.11272036, -0.01180212],
       [-0.17648351, -0.12502472,  0.3354447 , ..., -0.08912349,
        -0.11272036, -0.01180212]], dtype=float32)>

In [33]:
dbw2v = DistilBertDoc2Vec(use_only_cls_token=True)
dbw2v.get_doc2vec(s)

<tf.Tensor: shape=(2, 768), dtype=float32, numpy=
array([[-0.14095749, -0.34704125,  0.01186716, ..., -0.2896978 ,
         0.23704112,  0.35186458],
       [-0.14095749, -0.34704125,  0.01186716, ..., -0.2896978 ,
         0.23704112,  0.35186458]], dtype=float32)>

In [34]:
dbw2v = DistilBertDoc2Vec(use_only_cls_token=True, n_hidden_states=0)
dbw2v.get_doc2vec(s)

<tf.Tensor: shape=(2, 768), dtype=float32, numpy=
array([[-0.16980127, -0.16616768,  0.02564128, ..., -0.12552126,
         0.17291929,  0.43305337],
       [-0.16980127, -0.16616768,  0.02564128, ..., -0.12552126,
         0.17291929,  0.43305337]], dtype=float32)>

In [35]:
dbw2v = DistilBertDoc2Vec(use_only_cls_token=True, n_hidden_states=1)
dbw2v.get_doc2vec(s)

<tf.Tensor: shape=(2, 768), dtype=float32, numpy=
array([[-0.16980127, -0.16616768,  0.02564128, ..., -0.12552126,
         0.17291929,  0.43305337],
       [-0.16980127, -0.16616768,  0.02564128, ..., -0.12552126,
         0.17291929,  0.43305337]], dtype=float32)>

In [36]:
dbw2v = DistilBertDoc2Vec(use_only_cls_token=True, n_hidden_states=100)
dbw2v.get_doc2vec(s)

<tf.Tensor: shape=(2, 768), dtype=float32, numpy=
array([[-0.14095749, -0.34704125,  0.01186716, ..., -0.2896978 ,
         0.23704112,  0.35186458],
       [-0.14095749, -0.34704125,  0.01186716, ..., -0.2896978 ,
         0.23704112,  0.35186458]], dtype=float32)>

In [37]:
dbw2v = DistilBertDoc2Vec(use_only_cls_token=True, n_hidden_states=6)
dbw2v.get_doc2vec(s)

<tf.Tensor: shape=(2, 768), dtype=float32, numpy=
array([[-0.14095749, -0.34704125,  0.01186716, ..., -0.2896978 ,
         0.23704112,  0.35186458],
       [-0.14095749, -0.34704125,  0.01186716, ..., -0.2896978 ,
         0.23704112,  0.35186458]], dtype=float32)>