# Load data (from previous notebook)

In [None]:
# Read the whole corpus and split it into chunks on the "@@@" delimiter.
# A context manager is used so the file handle is closed deterministically
# instead of being left open until garbage collection.
# NOTE(review): no explicit encoding is given, so the platform default is
# used -- confirm it matches how the previous notebook wrote the file.
with open("sentences.txt") as f:
    sentences = f.read().split("@@@")

In [None]:
# Sanity check: number of text chunks obtained by splitting on "@@@".
len(sentences)

In [None]:
import numpy as np
# Load the first set of corpus embeddings from disk; later cells pass it to
# search() together with `model`.
# NOTE(review): presumably one row per entry in `sentences`, produced by the
# previous notebook -- confirm the two are aligned.
with open("sentences.npy", "rb") as f:
    sembeddings = np.load(f)

In [None]:
# Load the second set of corpus embeddings; later cells pass it to search()
# together with `model2`.
# NOTE(review): presumably the same corpus encoded with a different
# bi-encoder -- confirm against the previous notebook.
with open("sentences2.npy", "rb") as f:
    sembeddings2 = np.load(f)

# Retrieval

In [None]:
import numpy as np
import pandas as pd
def search(query, text, corpus_embeddings, bi_encoder, cross_encoder, top_k=100, top_n=20):
    """Retrieve candidates with a bi-encoder, then re-rank them with a cross-encoder.

    Parameters
    ----------
    query : str
        The search query.
    text : sequence of str
        Corpus texts; ``text[i]`` must correspond to ``corpus_embeddings[i]``.
    corpus_embeddings : array-like
        Pre-computed corpus embeddings, one row per text. They should come
        from the same model as ``bi_encoder``.
    bi_encoder : object
        Must provide ``encode(query)`` returning the query embedding.
    cross_encoder : object
        Must provide ``predict(pairs)`` returning one score per
        ``[query, text]`` pair.
    top_k : int, default 100
        Number of bi-encoder candidates handed to the cross-encoder.
    top_n : int, default 20
        Number of re-ranked results to return.

    Returns
    -------
    pandas.DataFrame
        Up to ``top_n`` rows with columns ``text``, ``score`` (bi-encoder
        similarity) and ``cross-score`` (cross-encoder score), sorted by
        ``cross-score`` in descending order. The frame is empty, but has the
        same columns, when there are no candidates.
    """
    # Encode the query so the search space can be restricted cheaply.
    question_embedding = bi_encoder.encode(query)

    # Similarity is the plain dot product; it equals cosine similarity only
    # when both the corpus and the query embeddings are normalized.
    sim = np.dot(corpus_embeddings, question_embedding)

    # Indices of the top_k most similar texts, best first.
    candidate_ids = sim.argsort()[::-1][0:top_k]
    hits = [{"text": text[i], "score": sim[i]} for i in candidate_ids]

    # Nothing to re-rank (empty corpus or top_k == 0): skip the cross-encoder
    # rather than calling it with an empty batch.
    if not hits:
        return pd.DataFrame(columns=["text", "score", "cross-score"])

    # Only the top hits are re-ranked; cross-encoding takes most of the time.
    cross_input = [[query, hit["text"]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_input)

    # Attach each cross-encoder score to its hit.
    for hit, cross_score in zip(hits, cross_scores):
        hit["cross-score"] = cross_score

    # Re-sort by cross-score, descending.
    hits.sort(key=lambda hit: hit["cross-score"], reverse=True)

    # Return the best top_n re-ranked results as a DataFrame.
    return pd.DataFrame(hits[0:top_n])

In [None]:
# Bi-encoder used to embed queries; passed to search() together with
# `sembeddings`.
# NOTE(review): presumably the same model that produced sentences.npy --
# verify against the previous notebook.
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

In [None]:
# Second bi-encoder; passed to search() together with `sembeddings2`.
# NOTE(review): presumably the model that produced sentences2.npy -- verify.
model2 = SentenceTransformer('multi-qa-mpnet-base-dot-v1')

In [None]:
# Cross-encoder that search() uses to re-rank the bi-encoder candidates
# (it scores [query, text] pairs via predict()).
from sentence_transformers import CrossEncoder, util
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [None]:
# Widen text columns so the result tables below show each hit in full.
# NOTE(review): 0 appears intended to disable truncation; recent pandas
# documents None as the "unlimited" value -- confirm for the installed version.
pd.set_option('display.max_colwidth', 0)

In [None]:
# Statement-style query: candidates from `model`/`sembeddings`, re-ranked by `cross_encoder`.
search("The climate crisis is worse in poorer countries", sentences, sembeddings, model, cross_encoder)

In [None]:
# Same statement-style query, but candidates come from `model2`/`sembeddings2` for comparison.
search("The climate crisis is worse in poorer countries", sentences, sembeddings2, model2, cross_encoder)

In [None]:
# Question-style query: candidates from `model`/`sembeddings`, re-ranked by `cross_encoder`.
search("Which countries are impacted most by the climate crisis?", sentences, sembeddings, model, cross_encoder)

In [None]:
# Alternative cross-encoder, used by the following cells in place of `cross_encoder`.
# NOTE(review): the model name suggests it was trained on QNLI
# (question/answer entailment) -- confirm on the model card.
scross_encoder = CrossEncoder("cross-encoder/qnli-electra-base")

In [None]:
# Same question-style query, now re-ranked by the alternative `scross_encoder`.
search("Which countries are impacted most by the climate crisis?", sentences, sembeddings, model, scross_encoder)

In [None]:
# The same question posed in German, to see how the pipeline copes with a non-English query.
search("Welche Länder sind am meisten von der Klimakrise betroffen?", sentences, sembeddings, model, scross_encoder)