# Load data (from previous notebook)

In [None]:
import os
import urllib.request

# Remote location of the sentence corpus and the local name it is cached under
url = "https://github.com/datanizing/m3-llm-workshop/raw/main/sentences.txt"
file_name = "sentences.txt"

# Fetch the corpus only when no local copy is present
if os.path.isfile(file_name):
    print(f"{file_name} already exists.")
else:
    print(f"{file_name} does not exist. Downloading...")
    urllib.request.urlretrieve(url, file_name)
    print(f"Downloaded {file_name}.")

In [None]:
# Read the whole corpus and split it into individual entries on the "@@@"
# separator. The context manager closes the file handle deterministically
# instead of leaving it to the garbage collector.
with open("sentences.txt", encoding="utf-8") as corpus_file:
    sentences = corpus_file.read().split("@@@")

In [None]:
# Number of text entries obtained by splitting the corpus on "@@@"
len(sentences)

In [None]:
import numpy as np
# Load the precomputed embedding matrix stored in "sentences-saev2.npy"
# (passed to search() together with `model` further down)
with open("sentences-saev2.npy", "rb") as npy_file:
    sembeddings = np.load(npy_file)

In [None]:
# Load the second precomputed embedding matrix from "sentences-mpnet.npy"
# (passed to search() together with `model2` further down)
with open("sentences-mpnet.npy", "rb") as npy_file:
    sembeddings2 = np.load(npy_file)

# Retrieval

In [None]:
import numpy as np
import pandas as pd
def search(query, text, corpus_embeddings, bi_encoder, cross_encoder, prompt_name="query", top_k=100, top_n=20):
    """Retrieve candidates with a bi-encoder and re-rank them with a cross-encoder.

    Args:
        query: The query string.
        text: Sequence of corpus passages, index-aligned with
            ``corpus_embeddings``.
        corpus_embeddings: Precomputed passage embeddings; they should come
            from the same model as ``bi_encoder``.
        bi_encoder: Object providing ``encode(query, prompt_name=...)`` and
            ``similarity(query_embedding, corpus_embeddings)``.
        cross_encoder: Object providing ``predict(pairs)`` that returns one
            score per ``[query, passage]`` pair.
        prompt_name: Forwarded unchanged to ``bi_encoder.encode``; pass
            ``None`` for models that do not define named prompts.
        top_k: Number of bi-encoder candidates handed to the cross-encoder.
        top_n: Number of re-ranked rows returned (previously fixed at 20).

    Returns:
        A ``pandas.DataFrame`` with the columns ``text``, ``score``
        (bi-encoder similarity) and ``cross-score`` (cross-encoder score),
        sorted by ``cross-score`` in descending order. The frame is empty
        when there are no candidates.
    """
    # Encode the query to restrict the search space
    query_embedding = bi_encoder.encode(query, prompt_name=prompt_name)

    # Use the similarity function of the encoder that was passed in rather
    # than a notebook-level global, so the function works with any bi-encoder
    # and does not depend on cell execution order.
    # NOTE(review): assumes similarity() returns a 2-D tensor whose first row
    # supports .numpy() -- confirm for non-CPU setups.
    sim = bi_encoder.similarity(query_embedding, corpus_embeddings)[0].numpy()

    # Indices of the top_k most similar passages, best first
    candidate_idx = sim.argsort()[::-1][:top_k]
    hits = [{"text": text[i], "score": sim[i]} for i in candidate_idx]

    # Nothing to re-rank: skip the cross-encoder and return an empty frame
    if not hits:
        return pd.DataFrame(columns=["text", "score", "cross-score"])

    # Cross-encode only the candidates (this takes most of the time)
    cross_scores = cross_encoder.predict([[query, hit["text"]] for hit in hits])

    # Attach the cross-encoder score to its candidate
    for hit, cross_score in zip(hits, cross_scores):
        hit["cross-score"] = cross_score

    # Re-sort by cross-score, descending
    hits.sort(key=lambda hit: hit["cross-score"], reverse=True)

    # Return the best top_n re-ranked results as a dataframe
    return pd.DataFrame(hits[:top_n])

In [None]:
# Bi-encoder used to embed queries for the first retrieval stage;
# it is passed to search() together with `sembeddings` below
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("Snowflake/snowflake-arctic-embed-l-v2.0")

In [None]:
# Second bi-encoder for comparison; it is passed to search() together with
# `sembeddings2` and prompt_name=None below
model2 = SentenceTransformer('all-mpnet-base-v2')

In [None]:
# Cross-encoder that scores [query, passage] pairs for the re-ranking step
# inside search()
from sentence_transformers import CrossEncoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [None]:
# Adjust pandas' column display width so long passages are shown in full.
# NOTE(review): pandas documents None as the "unlimited" value for this
# option -- confirm that 0 behaves as intended on the installed version.
pd.set_option('display.max_colwidth', 0)

In [None]:
# Query with the Snowflake bi-encoder (`model` / `sembeddings`), re-rank with
# the cross-encoder and colour the numeric columns with a gradient
search("Is the climate crisis worse for poorer countries?", 
       sentences, sembeddings, model, cross_encoder, prompt_name="query").style.background_gradient(cmap='coolwarm')

In [None]:
# Same query with the mpnet bi-encoder (`model2` / `sembeddings2`);
# prompt_name=None is passed instead of the default "query"
search("Is the climate crisis worse for poorer countries?", 
       sentences, sembeddings2, model2, cross_encoder, prompt_name=None).style.background_gradient(cmap='coolwarm')

In [None]:
# Second query with the Snowflake bi-encoder; prompt_name falls back to the
# default "query"
search("Which countries are impacted most by the climate crisis?", 
       sentences, sembeddings, model, cross_encoder).style.background_gradient(cmap='coolwarm')

In [None]:
# Second query with the mpnet bi-encoder (`model2` / `sembeddings2`)
search("Which countries are impacted most by the climate crisis?", 
       sentences, sembeddings2, model2, cross_encoder, prompt_name=None).style.background_gradient(cmap='coolwarm')

In [None]:
# One of many further alternatives: replace the cross-encoder with a
# multilingual reranker.
# NOTE(review): trust_remote_code=True presumably lets code shipped with the
# model repository run locally -- confirm the source is trusted.
cross_encoder = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual",
    model_kwargs={"torch_dtype": "auto"},
    trust_remote_code=True)

In [None]:
# German-language query ("Are poor countries hit harder by the climate
# crisis?") using the Snowflake bi-encoder and whichever cross_encoder is
# currently bound
search("Sind arme Länder durch die Klimakrise stärker betroffen?", 
       sentences, sembeddings, model, cross_encoder).style.background_gradient(cmap='coolwarm')

In [None]:
# German-language query ("Which countries are most affected by the climate
# crisis?") using the Snowflake bi-encoder and whichever cross_encoder is
# currently bound
search("Welche Länder sind durch die Klimakrise am meisten betroffen?", 
       sentences, sembeddings, model, cross_encoder).style.background_gradient(cmap='coolwarm')