In [16]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import NLTKTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import os
import glob
from dotenv import load_dotenv
from langchain_core.language_models.base import BaseLanguageModel
from ragas.llms import LangchainLLMWrapper

load_dotenv()
key = os.getenv("GOOGLE_API_KEY")


In [17]:
# --- Load and chunk data ---
pdf_files = glob.glob("data/*.pdf")
pages = []
for pdf_file in pdf_files:
    loader = PyPDFLoader(pdf_file)
    pages.extend(loader.load_and_split())

In [18]:
text_splitter = NLTKTextSplitter(chunk_size=5000, chunk_overlap=1000)
chunks = text_splitter.split_documents(pages)

# --- First-stage vector DB ---
embedding_model = GoogleGenerativeAIEmbeddings(google_api_key=key, model="models/embedding-001")
db = Chroma.from_documents(chunks, embedding_model, persist_directory="chroma_db_")
db.persist()
db_connection = Chroma(persist_directory="chroma_db_", embedding_function=embedding_model)

first_retriever = db_connection.as_retriever(search_kwargs={"k": 30})

In [19]:
# --- ColBERT Reranker: Use bi-encoder for pairwise re-ranking ---
tokenizer = AutoTokenizer.from_pretrained("colbert-ir/colbertv2.0")
model = AutoModel.from_pretrained("colbert-ir/colbertv2.0")

def colbert_rerank(query, candidate_docs):
    query_inputs = tokenizer(query, return_tensors='pt', truncation=True)
    query_embedding = model(**query_inputs).last_hidden_state.mean(dim=1)

    scored = []
    for doc in candidate_docs:
        inputs = tokenizer(doc.page_content, return_tensors='pt', truncation=True)
        doc_embedding = model(**inputs).last_hidden_state.mean(dim=1)
        sim = cosine_similarity(query_embedding.detach().numpy(), doc_embedding.detach().numpy())[0][0]
        scored.append((sim, doc))

    top_ranked = sorted(scored, key=lambda x: x[0], reverse=True)
    return [doc for _, doc in top_ranked[:10]]  # Return top-10 reranked

# --- Prompt Chain ---
chat_template = ChatPromptTemplate.from_messages([
    SystemMessage(content="""You are a helpful academic assistant.
    Please answer the question using only the provided context. 
    Do not include any explanations or additional information beyond what is asked.
    If the context does not contain enough information, say "I don't know" rather than making up an answer."""),
    HumanMessagePromptTemplate.from_template("""Answer the question based on the given context.
    Context: {context}
    Question: {question}
    Answer: """)
])
output_parser = StrOutputParser()

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# --- RAG Chain with rerank ---
def rag_two_stage(question):
    candidates = first_retriever.invoke(question)
    reranked = colbert_rerank(question, candidates)
    context = format_docs(reranked)
    chain = (
        {"context": lambda _: context, "question": lambda _: question}
        | chat_template
        | ChatGoogleGenerativeAI(google_api_key=key, model="gemini-1.5-flash-latest")
        | output_parser
    )
    answer = chain.invoke(question)

    return {
        "question": question,
        "answer": answer,
        "contexts": [doc.page_content for doc in reranked]
    }

In [22]:
from datasets import Dataset

examples = [
    {"question": "What is the role of aggregate functions in SQL?", "ground_truth": "They perform calculations on sets of values."},
    {"question": "Define relationship in the E-R model.", "ground_truth": "An association among several entities."},
    {"question": "What is the purpose of a canonical cover?", "ground_truth": "A minimal set of functional dependencies equivalent to the original."},
    {"question": "What is the main goal of a DBMS?", "ground_truth": "To provide efficient and convenient access to data."},
    {"question": "List three applications of DBMS.", "ground_truth": "Banking, Airlines, Manufacturing."},
    {"question": "How does UNION differ from INTERSECT in SQL?", "ground_truth": "UNION merges results, INTERSECT finds common rows."},
    {"question": "Define data independence.", "ground_truth": "Ability to modify schema at one level without affecting the next."},
    {"question": "What is a superkey?", "ground_truth": "A set of attributes that uniquely identify an entity."},
    {"question": "What is normalization in databases?", "ground_truth": "The process of structuring a relational database to reduce redundancy."},
    {"question": "What does the SELECT clause do in SQL?", "ground_truth": "Specifies the attributes to retrieve."},
    {"question": "What is a functional dependency?", "ground_truth": "A constraint between two sets of attributes."},
    {"question": "What is data redundancy in file systems?", "ground_truth": "Duplication of information across files."},
    {"question": "What is a candidate key?", "ground_truth": "A minimal superkey."},
    {"question": "What is a derived attribute in E-R model?", "ground_truth": "An attribute whose values can be derived from other attributes."},
   
]

dataset = Dataset.from_list(examples)

In [23]:
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    context_precision,
    context_recall,
)

results = [rag_two_stage(row["question"]) for row in dataset]

for col in ["answer", "contexts"]:
    if col in dataset.column_names:
        dataset = dataset.remove_columns(col)

dataset = dataset.add_column("answer", [r["answer"] for r in results])
dataset = dataset.add_column("contexts", [r["contexts"] for r in results])

from ragas import evaluate

my_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key="key")
wrapped_llm = LangchainLLMWrapper(my_llm)

score = evaluate(
    dataset,
    metrics=[faithfulness,answer_relevancy,context_precision, context_recall],
)
print(score)

Evaluating: 100%|██████████| 56/56 [00:56<00:00,  1.02s/it]


{'faithfulness': 0.9048, 'answer_relevancy': 0.8318, 'context_precision': 0.7621, 'context_recall': 0.9286}
