## Toy example of sentence embedding

In [None]:
from sentence_transformers import SentenceTransformer, util
import torch

In [2]:
# Load the pretrained sentence-embedding model (its files are downloaded on first use).
EMBEDDER_MODEL_NAME = 'paraphrase-MiniLM-L6-v2'
embedder = SentenceTransformer(EMBEDDER_MODEL_NAME)

Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.69k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/629 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/516 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [3]:
# Display the model's repr to inspect the modules it is composed of.
embedder

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

In [4]:
# Example sentences that make up the corpus the queries are matched against.
corpus = [
    "A man is eating food.",
    "A man is eating a piece of bread.",
    "The girl is carrying a baby.",
    "A man is riding a horse.",
    "A woman is playing violin.",
    "Two men pushed carts through the woods.",
    "A man is riding a white horse on an enclosed ground.",
    "A monkey is playing drums.",
    "A cheetah is running behind its prey.",
]

In [6]:
# Query sentences that will be looked up in the corpus.
queries = [
    'A man is eating pasta.',
    'Someone in a gorilla costume is playing a set of drums.',
    'A cheetah chases prey on across a field.',
]

# Embed every corpus sentence once, up front, requesting the result as a tensor.
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)

In [None]:
# Number of hits to report per query, capped by the corpus size so that
# torch.topk is never asked for more entries than the corpus contains.
top_k = min(5, len(corpus))

for query in queries:
    query_embedding = embedder.encode(query, convert_to_tensor=True)

    # We use cosine similarity and torch.topk to find the top_k highest scores.
    # [0] takes the first row of the returned score matrix (there is one query).
    cos_scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings)[0]
    # top_results[0] holds the scores, top_results[1] the matching corpus indices.
    top_results = torch.topk(cos_scores, k=top_k)

    print("\n\n======================\n\n")
    print("Query:", query)
    # Report the actual number of hits instead of a hard-coded "5".
    print("\nTop {} most similar sentences in corpus:".format(top_k))

    for score, idx in zip(top_results[0], top_results[1]):
        # The colon is required: "{:.4f}" is a format spec, whereas "{.4f}" is
        # parsed as an attribute lookup on the argument and raises at runtime.
        print(corpus[idx], "score:{:.4f}".format(score))

    # Alternatively, we can also use util.semantic_search to perform cosine similarity + topk:
    # hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
    # hits = hits[0]      # Get the hits for the first query
    # for hit in hits:
    #     print(corpus[hit['corpus_id']], "(Score: {:.4f})".format(hit['score']))

## Re-Ranker based on Cross-Encoder

- in the example application, a cross-encoder is used for reranking
<br>
- use the cross-encoder to score the retrieved documents against a query

In [7]:
!pip install -U rank_bm25

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting sentence-transformers
  Using cached https://files.pythonhosted.org/packages/3b/fd/8a81047bbd9fa134a3f27e12937d2a487bd49d353a038916a5d7ed4e5543/sentence-transformers-2.0.0.tar.gz
Collecting rank_bm25
  Downloading https://files.pythonhosted.org/packages/16/5a/23ed3132063a0684ea66fb410260c71c4ffda3b99f8f1c021d1e245401b5/rank_bm25-0.2.1-py3-none-any.whl
Collecting transformers<5.0.0,>=4.6.0 (from sentence-transformers)
  Using cached https://files.pythonhosted.org/packages/f9/35/75ade98da1e5b36a7a055fa3257d9b61769914a416b8844d6e4b237219e4/transformers-4.9.2-py3-none-any.whl
Collecting torch>=1.6.0 (from sentence-transformers)
  Using cached https://files.pythonhosted.org/packages/51/cd/3fa53975000c3

In [10]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util

# Load the pretrained cross-encoder model used for re-ranking.
CROSS_ENCODER_MODEL_NAME = 'cross-encoder/ms-marco-TinyBERT-L-6'
cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL_NAME)

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/541 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [None]:
# Re-rank the retrieved hits with the cross-encoder.
# NOTE(review): `query`, `hits` and `passages` are not defined in this cell;
# they must already exist in the kernel -- confirm which cell provides them.
cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)

# Attach each cross-encoder score to the hit its input pair was built from
# (cross_inp was built from hits in order, so the two sequences line up).
for hit, cross_score in zip(hits, cross_scores):
    hit['cross-score'] = cross_score

# Output the top-5 hits, ordered by the cross-encoder score. The sort key must
# be 'cross-score' (not the retrieval 'score'), otherwise the re-ranking has no
# effect on the order and the printed scores do not match the ranking.
hits = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)
for hit in hits[0:5]:
    print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))