In [4]:
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('all-distilroberta-v1') #max input length 512 output dimension 768
# model = SentenceTransformer('multi-qa-mpnet-base-dot-v1') #max input length 512 output dimension 768
# model = SentenceTransformer('multi-qa-distilbert-cos-v1') #max input length 512 output dimension 768
# model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1') #max input length 512 output dimension 384


In [151]:
class SentenceBert:
    "SentenceBert class to encode and rank sentences using SentenceBert model."
    def __init__(self, model_name="all-distilroberta-v1"):
        self.model = SentenceTransformer(model_name)
    
    def encode(self, sentence):
        return self.model.encode(sentence)
    
    def cos_sim(self, embedding1, embedding2):
        return util.cos_sim(embedding1, embedding2)
    
    def rank_sentences(self, sentence, sentences):
        embedding = self.encode(sentence)
        embeddings = self.encode(sentences)
        return self.cos_sim(embedding, embeddings)
    
    def top_k_sentences(self, sentence, sentences, k=5):
        scores = self.rank_sentences(sentence, sentences)
        sorted_scores, sorted_indices = scores.sort(descending=True)
        print(sorted_scores, sorted_indices)
        top_k_indices = sorted_indices[0][:k]
        top_k_sentences = [(sentences[i], scores[0][i].item()) for i in top_k_indices]
        return top_k_sentences

In [152]:
sentence_bert = SentenceBert()

In [153]:
res = sentence_bert.encode(["seppo", "teppo", "matti", "teppo", "seppoilua"])

In [154]:
test_sentences = ["The quick brown fox jumps over the lazy dog."
,"Despite the rain, the match continued without any delay."
,"Artificial intelligence and machine learning are transforming industries."
,"The cake was decorated with fresh roses and tasted just as sweet."
,"Quantum computing holds the potential to revolutionize technology."
,"The stock market experienced a significant downturn last week."
,"Renewable energy sources are becoming more cost-effective and widespread."
,"The archaeologist discovered ancient ruins beneath the city's streets."
,"The novel's intricate plot twists left readers both puzzled and intrigued."
,"Advancements in medical research are leading to groundbreaking treatments."]

In [155]:
query = "energy crisis needs actions that are novel and rely on data"

In [156]:
sentence_bert.top_k_sentences(query, test_sentences, k=2)

tensor([[ 0.4707,  0.2198,  0.1863,  0.1602,  0.1415,  0.0593,  0.0189, -0.0234,
         -0.0499, -0.0736]]) tensor([[6, 2, 5, 4, 9, 8, 7, 0, 1, 3]])


[('Renewable energy sources are becoming more cost-effective and widespread.',
  0.47070589661598206),
 ('Artificial intelligence and machine learning are transforming industries.',
  0.2198319286108017)]