In [None]:
%%capture --no-display

!pip install huggingface-hub
!pip install datasets
!pip install transformers
!pip install sentence-transformers

!pip install faiss-gpu
!pip install faiss-cpu

!pip install gradio

In [None]:
import numpy as np
import pandas as pd
import faiss
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from sentence_transformers import util
from datasets import Dataset, load_dataset
from huggingface_hub import notebook_login

import os
import json
import gzip

In [None]:
# Run the Hugging Face Hub login helper inside the notebook. A later cell calls
# `push_to_hub`, which presumably needs these credentials — TODO confirm.
notebook_login()

In [None]:
# Load bi-encoder and tokenizer (used to embed passages and queries for retrieval)
bi_encoder_name = "sentence-transformers/msmarco-distilbert-base-v2"
bi_encoder = AutoModel.from_pretrained(bi_encoder_name)
# `batch_size` is not a tokenizer option (batching is decided by whoever calls the
# tokenizer), so the stray `batch_size=2` keyword has been dropped.
bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_name)

# Load cross-encoder and tokenizer (used to score query/passage pairs when re-ranking)
cross_encoder_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = AutoModelForSequenceClassification.from_pretrained(cross_encoder_name)
cross_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_name)

In [None]:
%%capture --no-display

import torch

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bi_encoder.to(device)

In [None]:
def load_wikipedia_data(wikipedia_filepath):
    """Load every paragraph from a gzipped JSON-lines file.

    Args:
        wikipedia_filepath: Either a local path to the ``.jsonl.gz`` file or an
            ``http(s)://`` URL. A URL is downloaded once into the current
            working directory (under the URL's file name) and that local copy
            is reused on later calls.

    Returns:
        A flat list made by concatenating the ``'paragraphs'`` entry of each
        JSON record, in file order.

    Raises:
        ValueError: If a URL is given that does not end in a file name.
    """
    from urllib.parse import urlparse

    parsed = urlparse(wikipedia_filepath)
    if parsed.scheme in ('http', 'https'):
        # Keep the remote URL and the local cache path separate: using the URL
        # itself as a filesystem path creates odd "http:/host/..." directories
        # and is not a valid path on every platform.
        local_path = os.path.basename(parsed.path)
        if not local_path:
            raise ValueError(f"URL does not point to a file: {wikipedia_filepath!r}")
        if not os.path.exists(local_path):
            util.http_get(wikipedia_filepath, local_path)
    else:
        # Plain local path: gzip.open below reports a missing file itself.
        local_path = wikipedia_filepath

    passages = []
    with gzip.open(local_path, 'rt', encoding='utf8') as fIn:
        for line in fIn:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            data = json.loads(line)
            passages.extend(data['paragraphs'])

    return passages

# Gzipped JSON-lines Wikipedia dump hosted on sbert.net.
wikipedia_filepath = 'http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz'

# Only load the raw passages here. The following cell wraps them in a
# DataFrame / Dataset, so repeating that conversion in this cell was redundant.
corpus = load_wikipedia_data(wikipedia_filepath)

In [None]:
# Put the passages into a single-column ('text') DataFrame, then expose that
# frame as a Hugging Face Dataset for the mapping step below.
corpus = pd.DataFrame(data=corpus, columns=['text'])
ds = Dataset.from_pandas(corpus)

In [None]:
def cls_pooling(model_output):
    """Return the hidden state at sequence position 0 for every batch item.

    Reads ``model_output.last_hidden_state`` and takes index 0 along its
    second axis, keeping the first axis intact.
    """
    hidden_states = model_output.last_hidden_state
    first_position = hidden_states[:, 0]
    return first_position

In [None]:
def create_embeddings(corpus, model, tokenizer):
    """Embed text with ``model`` and return the pooled first-position vectors.

    Args:
        corpus: Text passed straight to ``tokenizer`` (a string or a list of
            strings).
        model: Torch module called with the tokenized inputs as keyword
            arguments; its output must expose ``last_hidden_state``.
        tokenizer: Callable invoked with ``padding=True``, ``truncation=True``
            and ``return_tensors="pt"``.

    Returns:
        The tensor produced by ``cls_pooling`` on the model output. It lives on
        the model's device and carries no autograd graph.

    NOTE(review): this bi-encoder's model card may recommend mean pooling rather
    than first-token pooling — confirm before regenerating embeddings.
    """
    encoded_input = tokenizer(
        corpus, padding=True, truncation=True, return_tensors="pt"
    )
    # Send the inputs to wherever the model's weights live, so the function
    # does not depend on a notebook-global `device` matching the model.
    model_device = next(model.parameters()).device
    encoded_input = {k: v.to(model_device) for k, v in encoded_input.items()}
    # Inference only: skip autograd bookkeeping to save memory across the
    # many calls made while embedding the whole corpus.
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)

In [None]:
# Embed the corpus in batches: one forward pass per 32 passages instead of one
# per passage. With `batched=True` the function receives a dict of lists and
# returns one embedding row per input passage.
embeddings_dataset = ds.map(
    lambda batch: {
        "embeddings": create_embeddings(batch["text"], bi_encoder, bi_tokenizer)
        .detach()
        .cpu()
        .numpy()
    },
    batched=True,
    batch_size=32,
)

Map:   0%|          | 0/509663 [00:00<?, ? examples/s]

In [None]:
# Publish the embedded corpus to the Hugging Face Hub under this repo id.
# (`load_dataset` is already imported in the imports cell and is not used here,
# so the duplicate import was removed.)
embeddings_dataset.push_to_hub("LukeSajkowski/simplewiki-2020-11-01-embeddings")

In [None]:
# Fetch the embeddings dataset from the Hub and keep its "train" split.
datasets = load_dataset("LukeSajkowski/simplewiki-2020-11-01-embeddings")
embeddings_dataset = datasets["train"]

In [None]:
# Create a FAISS index over the "embeddings" column for efficient search;
# the search function below queries this index by the same column name.
embeddings_dataset.add_faiss_index(column="embeddings")

In [None]:
# Re-rank using cross-encoder
def rerank(query, candidates, model, tokenizer):
    """Order ``candidates`` by the model's score for ``(query, candidate)``, best first.

    Args:
        query: Query string paired with every candidate.
        candidates: Sequence of candidate passages.
        model: Callable taking the tokenized inputs as keyword arguments and
            returning an object with a ``logits`` tensor, one row per pair.
        tokenizer: Callable that tokenizes two parallel lists (queries,
            passages) and returns a mapping of tensors.

    Returns:
        A new list of the candidates sorted by descending score. An empty
        ``candidates`` yields ``[]``. If the model emits several logits per
        pair, the first column is used as the score.
    """
    candidates = list(candidates)
    if not candidates:
        # Nothing to score; avoid calling the tokenizer with an empty batch.
        return []

    inputs = tokenizer([query] * len(candidates), candidates, return_tensors="pt", truncation=True, padding=True)

    # Keep the inputs on the same device as the model when it exposes one.
    model_device = getattr(model, "device", None)
    if model_device is not None:
        inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # A bare .squeeze() turns a one-candidate batch into a scalar, which then
    # cannot be zipped with the candidates; reshape per pair instead.
    scores = logits.reshape(len(candidates), -1)[:, 0].tolist()

    ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
    return [candidate for _, candidate in ranked]

In [None]:
import gradio as gr

def search_and_rerank(query, top_k=25):
    """Retrieve passages for ``query`` from the FAISS index, then re-rank them.

    Args:
        query: Free-text query.
        top_k: Number of nearest neighbours fetched from the index before
            re-ranking (defaults to the previously hard-coded 25).

    Returns:
        A string starting with ``"Reranked results:"`` followed by one numbered
        line per re-ranked passage, or a short prompt when the query is blank.
    """
    if not query or not query.strip():
        # A blank submission has nothing meaningful to search for.
        return "Please enter a query."

    # Perform bi-encoder search
    query_embedding = create_embeddings(query, bi_encoder, bi_tokenizer).cpu().detach().numpy()
    _, samples = embeddings_dataset.get_nearest_examples("embeddings", query_embedding, k=top_k)

    # The retrieval order is not preserved downstream: rerank() re-sorts every
    # candidate by its cross-encoder score, so sorting by index score first
    # (and the DataFrame round trip that did it) is unnecessary.
    top_k_candidates = list(samples["text"])

    # Re-rank using cross-encoder
    reranked_docs = rerank(query, top_k_candidates, cross_encoder, cross_tokenizer)

    # Return the final reranked results as a formatted string
    lines = ["Reranked results:"]
    lines.extend(f"{idx}. {doc}" for idx, doc in enumerate(reranked_docs, start=1))
    return "\n".join(lines) + "\n"

# Gradio interface
# Use the top-level component classes: the `gr.inputs` / `gr.outputs` namespaces
# are deprecated in Gradio 3.x and are not available in later releases.
input_query = gr.Textbox(lines=2, label="Enter your query")
output_result = gr.Textbox(label="Results")

interface = gr.Interface(fn=search_and_rerank, inputs=input_query, outputs=output_result, title="Search and Rerank")
interface.launch()