In [None]:
!pip install pypdf
!pip install sentence-transformers faiss-gpu-cu12



In [None]:
!pip install -U -q sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [None]:
import torch
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pypdf import PdfReader
import faiss
import numpy as np
from google.colab import userdata
from sentence_transformers.cross_encoder import CrossEncoder
import google.generativeai as genai

In [None]:
def extract_text_from_pdf(pdf_path):
    """Extract the text of every page of a PDF as one string.

    Args:
        pdf_path: Location of the PDF; passed unchanged to ``PdfReader``.

    Returns:
        str: The page texts in document order, separated by a newline.
        Pages that yield no text are skipped. An empty string is returned
        when no page yields any text.
    """
    reader = PdfReader(pdf_path)
    # `or ""` tolerates extractors that return None for image-only pages.
    page_texts = [page.extract_text() or "" for page in reader.pages]
    # Join with a newline so the last word of one page is not fused with the
    # first word of the next, which would corrupt chunks spanning a page break.
    return "\n".join(page_text for page_text in page_texts if page_text)

In [None]:
def chunk_text(text, chunk_size=512, overlap=50):
    """Split ``text`` into overlapping chunks.

    Args:
        text: The string to split.
        chunk_size: Forwarded as the splitter's ``chunk_size``.
        overlap: Forwarded as the splitter's ``chunk_overlap``.

    Returns:
        Whatever ``RecursiveCharacterTextSplitter.split_text`` returns for
        ``text`` (the text chunks).
    """
    splitter_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": overlap,
        # Lengths are measured with len(), i.e. in characters rather than tokens.
        "length_function": len,
        "is_separator_regex": False,
    }
    return RecursiveCharacterTextSplitter(**splitter_options).split_text(text)

In [None]:
def build_rag_system(pdf_path, fine_tuned_model_path):
    """Build a retrieve -> rerank -> generate pipeline over a single PDF.

    The PDF is extracted and chunked, every chunk is embedded with the given
    SentenceTransformer model, and the embeddings are stored in a flat L2
    FAISS index. Two closures sharing that state are returned.

    Args:
        pdf_path: Location of the PDF to index.
        fine_tuned_model_path: Model name or path passed to
            ``SentenceTransformer``.

    Returns:
        tuple: ``(retrieve, generate_answer)`` where

        * ``retrieve(query, top_k=10)`` returns a list holding the single
          best chunk after reranking the ``top_k`` nearest chunks (an empty
          list if no candidate is available);
        * ``generate_answer(query, context)`` returns the Gemini model's
          text answer for ``query`` given an iterable of context strings.

    Raises:
        ValueError: If the PDF produced no text chunks to index.
    """
    text = extract_text_from_pdf(pdf_path)
    chunks = chunk_text(text)
    if not chunks:
        # Fail early with a clear message, before any model is loaded, instead
        # of an opaque IndexError on `shape[1]` of an empty embedding tensor.
        raise ValueError(f"No text chunks could be extracted from {pdf_path!r}; nothing to index.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the fine-tuned model
    embedding_model = SentenceTransformer(fine_tuned_model_path, token=userdata.get('HF_TOKEN')).to(device=device)
    # Load the reranker model
    reranker = CrossEncoder('BAAI/bge-reranker-v2-m3', device=device)
    # Initialize the Gemini API
    GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
    genai.configure(api_key=GOOGLE_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-2.5-pro')

    def _as_faiss_input(embeddings):
        """Convert an embedding tensor to the float32, C-contiguous array FAISS expects."""
        return np.ascontiguousarray(embeddings.detach().cpu().float().numpy(), dtype=np.float32)

    # Create embeddings for the chunks
    chunk_embeddings = embedding_model.encode(chunks, convert_to_tensor=True)
    # Build a FAISS index
    index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
    index.add(_as_faiss_input(chunk_embeddings))

    def retrieve(query, top_k=10):
        """Return a one-element list with the best-matching chunk for ``query``."""
        # FAISS pads its result with -1 when k exceeds the number of indexed
        # vectors; `chunks[-1]` would then silently inject the last chunk.
        k = min(top_k, index.ntotal)
        if k <= 0:
            return []
        query_embedding = embedding_model.encode([query], convert_to_tensor=True)
        _, neighbour_ids = index.search(_as_faiss_input(query_embedding), k=k)
        # Defensive: drop any remaining -1 placeholders.
        candidate_ids = [int(i) for i in neighbour_ids[0] if i >= 0]
        if not candidate_ids:
            return []
        # Prepare for reranking: create pairs of (query, chunk)
        rerank_pairs = [[query, chunks[i]] for i in candidate_ids]
        # Rerank the retrieved chunks
        rerank_scores = reranker.predict(rerank_pairs)
        # Sort candidates by reranker score, best first
        reranked_positions = np.argsort(rerank_scores)[::-1]
        # Return the top 1 reranked chunk
        return [chunks[candidate_ids[int(pos)]] for pos in reranked_positions[:1]]

    def generate_answer(query, context):
        """Ask the Gemini model to answer ``query`` using the given context strings."""
        prompt = f"""
        You are a helpful assistant that answers questions based on the provided context.
        Context:
        {"".join(context)}
        Question:
        {query}
        Answer:
        """
        response = gemini_model.generate_content(prompt)
        return response.text

    return retrieve, generate_answer

In [None]:
if __name__ == '__main__':
    # Demo run: index one PDF and answer a single question about it.
    pdf_path = "/content/2501.00309v2.pdf"
    fine_tuned_model_path = 'google/embeddinggemma-300m'

    # Building the system loads the models and embeds the whole document once;
    # the two returned callables reuse that state for every query.
    retriever, generator = build_rag_system(pdf_path, fine_tuned_model_path)

    user_query = "What is Neo4j?"
    retrieved_context = retriever(user_query)
    final_answer = generator(user_query, retrieved_context)

    # Single call; emits the same header line followed by the answer.
    print("\n--- Final Answer ---", final_answer, sep="\n")


--- Final Answer ---
Based on the provided context, Neo4j is a graph database platform that offers a comprehensive set of tools for storing, visualizing, managing, and querying graph data. It also includes an LLM Graph Builder for extracting graphs with LLMs and provides GraphRAG demos.
