In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import RobertaTokenizerFast, TFRobertaModel

In [None]:
class BertEmbedding(tf.keras.Model):
    """Text encoder: ``roberta-base`` followed by a two-layer dense projection head.

    The hidden state at the first sequence position is passed through
    ``Dense(embedding_dim, relu)`` then ``Dense(embedding_dim, linear)`` to give one
    embedding vector per input sequence.

    Args:
        max_len: Sequence length of the build-time dummy input; inputs are presumably
            expected to be padded/truncated to this length (TODO confirm).
        encoder_weights: Unused. Accepted only so existing call sites keep working;
            the encoder always loads the pretrained ``roberta-base`` weights.
        output_attention: If True, ``call`` returns ``(embedding, attentions)``;
            otherwise it returns only ``embedding``.
        embedding_dim: Width of both projection layers. Defaults to 512; must match
            any saved weights loaded into the model afterwards.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, max_len=128, encoder_weights=None, output_attention=False,
                 embedding_dim=512, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.output_attention = output_attention
        # NOTE(review): output_attentions is also enabled in the encoder config,
        # presumably because HF TF models may ignore the call-time flag in graph
        # mode (e.g. inside predict) — confirm before removing it.
        self.encoder = TFRobertaModel.from_pretrained('roberta-base', output_attentions=True)

        self.dense1 = tf.keras.layers.Dense(embedding_dim, activation='relu')
        self.dense2 = tf.keras.layers.Dense(embedding_dim, activation='linear')

        # One forward pass creates all variables so load_weights()/summary() work
        # right after construction. Only shapes/dtypes matter here, so a small,
        # deterministic all-ones batch is used instead of random data.
        build_batch = {
            'input_ids': np.ones((1, self.max_len), dtype='int32'),
            'attention_mask': np.ones((1, self.max_len), dtype='int32'),
        }
        self(build_batch)

    def call(self, x, training=False):
        """Embed a batch of tokenized sequences.

        Args:
            x: Mapping with ``'input_ids'`` and ``'attention_mask'`` entries,
                presumably shaped (batch, max_len) — TODO confirm.
            training: Forwarded to the encoder. Defaults to False so that a direct
                ``model(x)`` call is deterministic inference; Keras ``fit`` and
                ``predict`` set this flag explicitly.

        Returns:
            ``embedding`` of shape (batch, embedding_dim), or
            ``(embedding, attentions)`` when ``output_attention`` is True.
        """
        encoded = self.encoder(
            input_ids=x['input_ids'],
            attention_mask=x['attention_mask'],
            training=training,
            output_attentions=True,
        )

        # Hidden state at position 0 (the leading special token added by the
        # tokenizer when add_special_tokens=True) summarizes the sequence.
        first_token_state = encoded[0][:, 0, :]
        embedding = self.dense2(self.dense1(first_token_state))

        if self.output_attention:
            return embedding, encoded.attentions
        return embedding

In [None]:
# Configuration. The checkpoint file name embeds max_len, so changing max_len
# selects a different saved encoder.
max_len = 128
path_to_model = 'models'
model_path = os.path.join(path_to_model, f'siamese_encoder_{max_len}.tf')

# Build the architecture (the constructor runs a dummy forward pass to create the
# variables), then restore the trained weights into it.
model = BertEmbedding(max_len=max_len, output_attention=False)
model.load_weights(model_path)

In [None]:
# Show the layers and parameter counts of the model whose weights were just loaded.
model.summary()

In [None]:
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

# Placeholder texts to embed — replace with the real abstracts.
s = ["abstract 1", "abstract 2", "abstract 3", "abstract 4", "abstract 5"]

# Pad/truncate every text to max_len, the sequence length the model was built with.
# Calling the tokenizer directly replaces the deprecated batch_encode_plus.
encoded_batch = tokenizer(
    s,
    add_special_tokens=True,
    padding='max_length',
    max_length=max_len,
    truncation=True,
    return_tensors='np',
)
# Cast to int32 to match the dtype of the model's build-time dummy input.
inputs = {key: value.astype('int32') for key, value in encoded_batch.items()}

In [None]:
# Embed every tokenized text: with output_attention=False the model returns only the
# embeddings, one row per input text (presumably shape (n_texts, 512) — TODO confirm).
out = model.predict(inputs)

In [None]:
def normalize(x):
    """Scale each vector along the last axis to unit L2 norm.

    Args:
        x: Array-like of vectors; for a 2-D input each row is normalized,
            a 1-D input is treated as a single vector.

    Returns:
        Float array with the same shape as ``x``. All-zero vectors are returned
        as zeros instead of becoming NaN through a 0/0 division.
    """
    x = np.asarray(x)
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    # Dividing all-zero vectors by 1 keeps them at zero rather than producing NaN.
    return x / np.where(norms == 0, 1, norms)


def cosine(x, y):
    """Pairwise cosine similarity between rows of ``x`` and rows of ``y``, rescaled to [0, 1].

    The raw cosine in [-1, 1] is mapped through ``(cos + 1) / 2``. A value of 1 means
    same direction and 0 means opposite. A value of 0.5 means orthogonal, or that one
    of the vectors is all zeros.

    Returns:
        Array of shape (len(x), len(y)) for 2-D inputs.
    """
    xn = normalize(x)
    yn = normalize(y)
    return (xn @ yn.T + 1) / 2

In [None]:
# Heatmap of rescaled cosine similarity between every pair of embeddings.
# Labels are derived from the number of embeddings so they stay aligned with the
# input texts if texts are added or removed.
labels = [f's{i}' for i in range(len(out))]

fig, ax = plt.subplots()
# cosine() is bounded to [0, 1], so fix the colour range to keep plots comparable.
sns.heatmap(cosine(out, out), xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=ax)
ax.set_title('Rescaled cosine similarity between text embeddings');