In [3]:
from sentence_transformers import SentenceTransformer, models
from torch import nn

In [4]:
# Assemble a SentenceTransformer from three stacked modules:
#   BERT token encoder -> pooling -> dense down-projection.

# Token-level encoder. Input texts longer than 256 tokens are truncated.
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)

# Collapse the per-token embeddings into one fixed-size sentence vector.
# `pooling_mode` accepts "mean", "max" or "cls"; when given, it overrides the
# individual pooling_mode_* flags. "mean" is the library default, but it is
# passed explicitly so the pooling strategy is visible at the call site.
pooling_model = models.Pooling(
    word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
    pooling_mode="mean",
)

# Fully connected layer with a Tanh activation on top of the pooled vector.
# It down-projects to 256 dimensions, so embeddings from this model have
# 256 dimensions instead of BERT's 768.
dense_model = models.Dense(
    in_features=pooling_model.get_sentence_embedding_dimension(),
    out_features=256,
    activation_function=nn.Tanh(),
)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 256, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
)