In [1]:
import os

from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util
import torch
from torch import nn, Tensor, device

# Pretrained sentence-transformers checkpoint to start from.
model_name = "bert-base-nli-stsb-mean-tokens"
sbert_hugginface = SentenceTransformer(model_name_or_path=model_name)

In [3]:
# Cache the downloaded model locally; create the target folder so a fresh checkout
# (Restart & Run All) does not fail with FileNotFoundError.
os.makedirs("../pretrained_bert", exist_ok=True)
torch.save(sbert_hugginface, "../pretrained_bert/bert-base-nli-stsb-mean-tokens.pt")

In [25]:
class SBERT(nn.Module):
    """Mean-pooling sentence encoder wrapped around a transformer module.

    Also serves as the base class of the entity-aware variant (``EntitySBert``):
    the wrapped model can be either an entity-aware transformer or a regular one.

    Args:
        bert_model: module whose forward takes a features dict and returns a dict
            containing at least ``'token_embeddings'`` and ``'attention_mask'``
            (e.g. ``SentenceTransformer(...)[0]``).
        device: device the wrapped model is moved to (default ``"cuda"``).
    """

    def __init__(self, bert_model, device="cuda"):
        super().__init__()
        self.bert_model = bert_model.to(device)

    def _target_device(self):
        """Return the device of the wrapped model's parameters (CPU if it has none)."""
        # Derived from the parameters instead of being stored, so it stays correct
        # after ``.to(...)`` and for instances pickled before this helper existed.
        first_param = next(self.bert_model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def forward(self, features):
        """Run the wrapped model and add a mean-pooled ``'sentence_embedding'``.

        Args:
            features: dict of model inputs; tensor values are moved in place to
                the wrapped model's device before the forward pass.

        Returns:
            The same ``features`` dict, updated with ``'sentence_embedding'``.
            Assumes token embeddings are (batch, seq, hidden), which gives
            embeddings of shape (batch, hidden). TODO confirm for entity models.
        """
        target = self._target_device()
        for key in list(features):
            value = features[key]
            if isinstance(value, Tensor):
                features[key] = value.to(target)

        bert_features = self.bert_model(features)
        token_embeddings = bert_features['token_embeddings']
        attention_mask = bert_features['attention_mask']

        # Average over real tokens only: zero out padding, then divide by the token count.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp avoids 0/0 = NaN for rows with no unmasked tokens (as in upstream Pooling).
        counts = torch.clamp(mask.sum(1), min=1e-9)

        features.update({"sentence_embedding": summed / counts})
        return features

class EntitySBert(SBERT):
    """Entity-aware SBERT: same pooling as ``SBERT``, wrapping an entity-aware transformer."""

    def __init__(self, esbert_model, device="cuda"):
        super().__init__(esbert_model, device)


In [5]:
# Reload the pickled SentenceTransformer saved above. map_location="cpu" lets this run on
# machines without the GPU the checkpoint may have been saved from; SBERT(...) later moves
# the wrapped module to its target device.
# NOTE(review): this unpickles a whole module object. Only load trusted files. On torch
# versions where weights_only defaults to True, pass weights_only=False explicitly.
sbert_hugginface = torch.load("../pretrained_bert/bert-base-nli-stsb-mean-tokens.pt", map_location="cpu")

In [29]:
# Keep only the transformer module (index 0); SBERT.forward re-implements mean pooling.
sbert_ours = SBERT(bert_model=sbert_hugginface[0])  # plain SBERT, no entity awareness

In [30]:
# Pickles the whole module, so loading it later requires the SBERT class to be defined.
torch.save(obj=sbert_ours, f="../pretrained_bert/SBERT-base-nli-stsb-mean-tokens.pt")

In [31]:
# Round-trip check: the pickled wrapper must reload as an SBERT (needs SBERT defined above).
sbert_reloaded = torch.load("../pretrained_bert/SBERT-base-nli-stsb-mean-tokens.pt")
assert isinstance(sbert_reloaded, SBERT), type(sbert_reloaded)
sbert_reloaded

SBERT(
  (bert_model): Transformer(
    (auto_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
            

# Save the fine-tuned SBERT (its transformer module wrapped in `SBERT`)

In [33]:
# Load the fine-tuned checkpoint directory produced by the training run.
sbert_hugginface = SentenceTransformer(model_name_or_path="../output/exp_sbert_ep2_mgn2.0_btch32_norm1.0_max_seq_256")

In [34]:
# Wrap only the fine-tuned transformer module; SBERT.forward supplies the mean pooling.
sbert_ours = SBERT(bert_model=sbert_hugginface[0])  # plain SBERT, no entity awareness

In [35]:
# Store the wrapper next to the fine-tuned checkpoint it was built from.
torch.save(obj=sbert_ours, f="../output/exp_sbert_ep2_mgn2.0_btch32_norm1.0_max_seq_256/sbert.pt")