In [89]:
import os
from pathlib import Path
import streamlit as st
from dotenv import load_dotenv

load_dotenv()

True

In [90]:
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_community.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [91]:
# Read the API credentials from the process environment (load_dotenv() was
# called in the first cell). os.getenv returns None for a missing variable.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

In [92]:
# os.environ only accepts str values: assigning None (variable missing from
# both the shell and the .env file) raises an opaque "TypeError: str expected".
# Fail early with a message that names the missing variable instead.
for _name, _value in (
    ("GROQ_API_KEY", GROQ_API_KEY),
    ("HUGGINGFACEHUB_API_TOKEN", HUGGINGFACEHUB_API_TOKEN),
):
    if _value is None:
        raise RuntimeError(
            f"Environment variable {_name} is not set; define it in your shell or .env file."
        )
    os.environ[_name] = _value

In [93]:
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings

# Chat model used to generate the final answers.
llm = ChatGroq(model="gemma2-9b-it")

# Embedding model handed to the FAISS vector store further down.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)

In [94]:
def load_docs(docs_dir: Path):
    """Load every non-empty .txt and .pdf file under docs_dir as LangChain Documents.

    Args:
        docs_dir: Root directory; it is searched recursively.

    Returns:
        A flat list of Documents in sorted-path order. Each document's metadata
        gets a "file_name" entry (the bare file name) unless the loader already
        provided one, so downstream code can cite a short name.

    Raises:
        FileNotFoundError: If docs_dir does not exist or is not a directory.
    """
    if not docs_dir.is_dir():
        # glob() on a missing directory yields nothing, which would only surface
        # much later as an obscure failure while building the index.
        raise FileNotFoundError(f"Document directory not found or not a directory: {docs_dir}")

    # Loader constructors keyed by lower-cased suffix; built lazily so a loader
    # is only touched when a matching file is actually present.
    loader_factories = {
        ".txt": lambda path: TextLoader(str(path), encoding="utf-8"),
        ".pdf": lambda path: PyPDFLoader(str(path)),
    }

    docs = []
    # Sort so document (and therefore chunk) order is reproducible; raw glob
    # order is arbitrary.
    for p in sorted(docs_dir.glob("**/*")):
        make_loader = loader_factories.get(p.suffix.lower())
        if make_loader is None or not p.is_file() or p.stat().st_size == 0:
            continue  # unsupported type, directory, or empty file
        loaded = make_loader(p).load()
        for d in loaded:
            d.metadata.setdefault("file_name", p.name)
        docs.extend(loaded)
    return docs

In [95]:
try:
    import tiktoken

    def count_tokens(text, model_name="gemma2-9b-it"):
        """Return the number of tiktoken tokens in text.

        The encoding registered for model_name is tried first; when tiktoken
        raises KeyError the cl100k_base encoding is used instead.
        """
        try:
            return len(tiktoken.encoding_for_model(model_name).encode(text))
        except KeyError:
            # Fallback to cl100k_base encoding (used by GPT-3.5/4)
            fallback_encoding = tiktoken.get_encoding("cl100k_base")
            return len(fallback_encoding.encode(text))
except Exception:
    def count_tokens(text, model_name=None):
        """Estimate the token count as 0.75 * whitespace-separated words (minimum 1)."""
        word_count = len(text.split())
        estimate = int(word_count * 0.75)
        return max(1, estimate)


In [96]:
from langchain.schema import Document

def build_splits(docs, max_tokens=750, overlap_tokens=120):
    """Split documents into overlapping, token-bounded chunks at word boundaries.

    Each document's text is split on whitespace and consecutive words are
    greedily packed into a chunk for as long as count_tokens() stays within
    max_tokens. Every chunk becomes a new Document that copies the parent's
    metadata and adds "_chunk_start" / "_chunk_end" (word offsets into the
    parent), "source" and "orig_filename".

    Args:
        docs: Iterable of objects exposing page_content and metadata.
        max_tokens: Upper bound on count_tokens() per chunk. A single word that
            exceeds the bound on its own is emitted as a one-word chunk.
        overlap_tokens: Approximate overlap between consecutive chunks. It is
            converted to words by dividing by 0.75 and capped at half of the
            chunk just emitted so the window always moves forward.

    Returns:
        List of Document chunks, in input order.

    Raises:
        ValueError: If max_tokens is not positive or overlap_tokens is negative.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")
    if overlap_tokens < 0:
        raise ValueError("overlap_tokens must be non-negative")

    overlap_words = int(overlap_tokens / 0.75)
    splits = []
    for doc in docs:
        base_meta = dict(doc.metadata or {})  # tolerate metadata=None
        source = base_meta.get("source", getattr(doc, "source", "unknown"))
        orig_filename = base_meta.get("file_name") or base_meta.get("source", "unknown")

        words = doc.page_content.split()
        total = len(words)
        start = 0
        while start < total:
            # Start from a generous window and shrink one word at a time until
            # the candidate fits the token budget.
            end = min(total, start + max_tokens * 2)  # safe upper bound
            while end > start:
                if count_tokens(" ".join(words[start:end])) <= max_tokens:
                    break
                end -= 1
            if end == start:
                # Even one word is over budget: emit it alone rather than
                # grabbing a large window that would exceed the limit further.
                end = start + 1

            chunk_meta = dict(base_meta)
            chunk_meta.update({
                "_chunk_start": start,
                "_chunk_end": end,
                "source": source,
                "orig_filename": orig_filename,
            })
            splits.append(Document(page_content=" ".join(words[start:end]), metadata=chunk_meta))

            if end >= total:
                # The document is fully covered. Continuing would only emit
                # ever-shorter duplicates of the tail.
                break
            # Never step back more than half the chunk; otherwise short chunks
            # (fewer words than the overlap) would advance one word at a time.
            step_back = min(overlap_words, (end - start) // 2)
            start = end - step_back
    return splits

In [97]:
def make_embeddings_and_persist(splits, embeddings, persist_path: Path,
                                force_rebuild: bool = False, batch_size: int = 32):
    """Load a persisted FAISS index, or embed splits in batches, build and save one.

    Args:
        splits: Chunk Documents (page_content + metadata) to index.
        embeddings: Embeddings object; used to embed the chunks and, later,
            queries against the returned store.
        persist_path: Directory of the on-disk index.
        force_rebuild: When False (default) an existing persist_path is loaded
            as-is and splits is ignored, so the index can be stale after the
            documents or chunking change. Pass True to re-embed and overwrite.
        batch_size: Number of chunk texts sent to embed_documents() per call.

    Returns:
        The loaded or newly built FAISS vector store.

    Raises:
        ValueError: If an index must be built but splits is empty, or if
            batch_size is not positive.
    """
    if persist_path.exists() and not force_rebuild:
        # SECURITY: load_local deserializes pickled data. The flag below is only
        # appropriate because this index is written by this same notebook; never
        # point persist_path at an untrusted location.
        return FAISS.load_local(str(persist_path), embeddings, allow_dangerous_deserialization=True)

    if not splits:
        raise ValueError(
            "No document chunks to index; check that the documents directory contains non-empty .txt/.pdf files."
        )
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    texts = [d.page_content for d in splits]
    metadatas = [d.metadata for d in splits]
    vectors = []
    for i in range(0, len(texts), batch_size):
        vectors.extend(embeddings.embed_documents(texts[i:i + batch_size]))

    # Build the index from the vectors computed above, so every chunk is
    # embedded exactly once.
    vs = FAISS.from_embeddings(list(zip(texts, vectors)), embeddings, metadatas=metadatas)
    vs.save_local(str(persist_path))
    return vs



In [98]:
from sentence_transformers import CrossEncoder

# Module-level cross-encoder used by rerank_with_cross_encoder() to score
# (query, passage) pairs.
# NOTE(review): the model name suggests MS MARCO training data, which is
# presumably English — confirm it ranks Sinhala passages acceptably.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def rerank_with_cross_encoder(query, docs, top_k=3):
    """Re-rank documents by cross-encoder score and keep the best top_k.

    Args:
        query: Query string paired with every document.
        docs: Iterable of objects exposing page_content.
        top_k: Number of highest-scoring documents to return.

    Returns:
        Up to top_k documents, highest score first. An empty list is returned
        when there are no candidates, without calling the model.
    """
    # Materialize once: docs is iterated twice below, which would silently
    # drop every result for a one-shot iterator.
    docs = list(docs)
    if not docs:
        return []
    pairs = [[query, d.page_content] for d in docs]
    scores = cross_encoder.predict(pairs)
    # Sort on the score only, so Document objects are never compared on ties.
    ranked = sorted(zip(scores, docs), key=lambda pair: pair[0], reverse=True)
    return [d for _, d in ranked[:top_k]]

def ensure_vectorstore(splits, embeddings: Embeddings, persist_path: Path):
    """Return the FAISS store for splits by delegating to make_embeddings_and_persist."""
    store = make_embeddings_and_persist(splits, embeddings, persist_path)
    return store

In [99]:
# Load the corpus from the "docs" directory (relative to the working
# directory) and cut it into overlapping chunks for indexing.
docs_dir = Path("docs")
docs = load_docs(docs_dir)
splits = build_splits(docs, max_tokens=750, overlap_tokens=120)  # Token-aware chunking

In [100]:
# Build the FAISS index, or reuse the one already saved under "vectorstore".
# CAUTION: when that directory exists, make_embeddings_and_persist() loads it
# and ignores `splits`, so delete it after changing the documents or chunking.
# NOTE(review): the recorded evaluation output (a single hit whose
# orig_filename is None) looks like an index built from older splits — confirm.
persist_path = Path("vectorstore")
vectorstore = make_embeddings_and_persist(splits, embeddings, persist_path)

In [101]:
# Retriever using maximal-marginal-relevance search.
# NOTE(review): per the LangChain docs, fetch_k candidates are fetched, k are
# returned, and lambda_mult trades relevance against diversity — confirm
# against the installed version.
retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 20, "lambda_mult": 0.5})

In [102]:
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate

system = SystemMessagePromptTemplate.from_template(
    "ඔබ සිංහලෙන් පිළිතුරු දෙන සහායකයෙක් වනවා. සපයන ලද මූලාශ්‍ර පමණක් භාවිතා කර පිළිතුරු දෙන්න. "
    "පිළිතුර මූලාශ්‍රවල නොමැති නම්, මම නොදනිමි කියා සහ සොයන්නට හෝ වැඩි විස්තර ඉල්ලන්නට යෝජනා කරන්න. "
    "පිළිතුරු සංක්ෂිප්තව තබා ගන්න."
)

human = HumanMessagePromptTemplate.from_template(
    "පරිශීලක ප්‍රශ්නය: {question}\n\nපහත සන්දර්භය භාවිතා කරන්න: {context}\n\nමූලාශ්‍ර රේඛාගතව උපුටා දක්වන්න: [1], [2] වැනි ආකාරයකින්."
)

# Chat prompt made of the system and human templates defined above; it is
# formatted with the `question` and `context` variables.
prompt = ChatPromptTemplate.from_messages([system, human])


In [103]:

from langchain.chains import RetrievalQA

class EnhancedQAChain:
    """Retrieval-augmented QA: retrieve, re-rank with a cross-encoder, ask the LLM.

    Args:
        llm: Chat model exposing invoke(messages) and returning an object with
            a `content` attribute.
        retriever: Retriever exposing invoke(query) -> list of documents.
        cross_encoder: Model exposing predict(pairs) -> one score per
            [query, passage] pair. This instance is the one used for scoring.
        top_k: Number of re-ranked documents placed into the prompt.
    """

    def __init__(self, llm, retriever, cross_encoder, top_k=3):
        self.llm = llm
        self.retriever = retriever
        self.cross_encoder = cross_encoder
        self.top_k = top_k

    def retrieve_and_rerank(self, query):
        """Return up to top_k retrieved documents, best cross-encoder score first."""
        # Retrievers are Runnables; invoke() replaces the deprecated
        # get_relevant_documents().
        candidates = list(self.retriever.invoke(query))
        if not candidates:
            # Nothing retrieved: skip the model call on an empty batch.
            return []
        pairs = [[query, d.page_content] for d in candidates]
        # Score with the encoder passed to __init__, not a module-level global.
        scores = self.cross_encoder.predict(pairs)
        ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
        return [d for _, d in ranked[: self.top_k]]

    def format_docs(self, docs):
        """Format documents as numbered, citable context blocks.

        The citation label is the chunk's "orig_filename" metadata, falling
        back to "source" and finally "unknown".
        """
        blocks = []
        for i, doc in enumerate(docs, 1):
            meta = doc.metadata or {}
            source = meta.get("orig_filename") or meta.get("source") or "unknown"
            blocks.append(f"[{i}] {doc.page_content} (Source: {source})")
        return "\n\n".join(blocks)

    def invoke(self, query):
        """Answer query from the retrieved context and return the reply text."""
        docs = self.retrieve_and_rerank(query)
        context = self.format_docs(docs)

        # `prompt` is the module-level ChatPromptTemplate defined in an earlier cell.
        full_prompt = prompt.format_messages(question=query, context=context)

        response = self.llm.invoke(full_prompt)

        return response.content

# Wire the chain to the llm, retriever and cross_encoder objects created in
# the earlier cells.
qa_chain = EnhancedQAChain(llm, retriever, cross_encoder)

In [107]:
try:
    query = "AI හි ප්‍රධාන වර්ග"
    result = qa_chain.invoke(query)
    print("Answer:", result)
except Exception as e:
    print(f"Error during query: {e}")

Answer: AI හි ප්‍රධාන වර්ග අතර ඒවායේ යෙදුම් වලට අනුව 
1. Narrow AI (Weak AI) 
2. General AI (Strong AI) 
3. Super AI 

ලෙස වේ. [1] 



In [109]:
def precision_at_k(qa_pairs, retriever, cross_encoder, k=3):
    """Fraction of queries whose expected file appears in the top-k re-ranked hits.

    Despite the name this is a hit rate: a query counts once if any of its
    top-k documents matches, regardless of how many of them do.

    Args:
        qa_pairs: Sequence of (query, expected_filename) tuples.
        retriever: Retriever exposing invoke(query) -> list of documents.
        cross_encoder: Model exposing predict(pairs); used to re-rank the
            retrieved candidates.
        k: Number of re-ranked documents inspected per query.

    Returns:
        hits / len(qa_pairs), or 0 when qa_pairs is empty.
    """
    if not qa_pairs:
        return 0

    hits = 0
    for q, expected_filename in qa_pairs:
        # invoke() replaces the deprecated get_relevant_documents().
        candidates = list(retriever.invoke(q))
        if candidates:
            # Score with the encoder that was passed in.
            scores = cross_encoder.predict([[q, d.page_content] for d in candidates])
            ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
            reranked = [d for _, d in ranked[:k]]
        else:
            reranked = []

        # Fall back to "source" so indexes whose chunks lack "orig_filename"
        # can still be evaluated.
        sources = []
        for d in reranked:
            meta = d.metadata or {}
            sources.append(meta.get("orig_filename") or meta.get("source"))

        print(f"Query: {q}")
        print(f"Top-{k} retrieved documents' sources: {sources}")
        if any(expected_filename in (s or "") for s in sources):
            hits += 1
    return hits / len(qa_pairs)

eval_queries = [
    ("AI යනු කුමක්ද?", "lesson1.txt"),
    ("කෘතිම බුද්ධිකත්වයේ මූලික කරුණු මොනවාද?", "lesson1.txt"),
]

try:
    precision = precision_at_k(eval_queries, retriever, cross_encoder, k=3)
    print(f"Precision at {3}: {precision:.2%}")
except Exception as e:
    print(f"Evaluation error: {e}")

Query: AI යනු කුමක්ද?
Top-3 retrieved documents' sources: [None]
Query: කෘතිම බුද්ධිකත්වයේ මූලික කරුණු මොනවාද?
Top-3 retrieved documents' sources: [None]
Precision at 3: 0.00%
