# AbstainQA — Confidence-Gated RAG System

Designed a retrieval-based QA system that deterministically abstains when document evidence is insufficient. Implemented confidence estimation from embedding similarity, enforced LLM grounding via system constraints, and validated safe failure behavior under empty and low-relevance retrieval.

In [1]:
!pip install -q tiktoken
!pip install -q langchain-huggingface sentence-transformers

In [None]:
!pip install chromadb

In [None]:
import google.generativeai as genai
import os
import json

In [4]:

import re
import json
import hashlib
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Generator
import tiktoken

@dataclass
class Chunk:
    """One piece of a chunked document together with its provenance.

    Instances are plain records; ``to_dict`` gives a dictionary copy that is
    convenient for JSON serialisation.
    """
    # Identifier assigned by whoever creates the chunk.
    id: str
    # Text of the chunk.
    content: str
    # Name of the document the chunk was taken from.
    source_file: str
    # Header titles leading to this chunk, outermost first.
    section_path: list[str]
    # Number of tokens in ``content``.
    token_count: int
    # Whether the chunk text contains a markdown table.
    has_table: bool

    def to_dict(self) -> dict:
        """Return the fields as a new dictionary (containers are copied)."""
        as_mapping = asdict(self)
        return as_mapping

class MarkdownChunker:
    """Split markdown into token-bounded chunks along header boundaries.

    Every header-delimited section becomes a single chunk when it fits within
    ``max_tokens``. Larger sections are split on blank-line paragraph
    boundaries, and the section's header path is repeated at the top of each
    piece so that every chunk carries its own context.

    Known limits of the approach (unchanged from the original behaviour):
    a single paragraph or table larger than the budget is emitted whole, so a
    chunk can exceed ``max_tokens``; and any line that looks like an ATX
    header starts a new section, including such lines inside fenced code.
    """

    # ATX header: one to six '#', whitespace, then the title.
    HEADER_PATTERN = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE)
    # Pipe table: header row, delimiter row, then at least one body row.
    TABLE_PATTERN = re.compile(r'^\|.+\|$\n^\|[-:\s|]+\|$(?:\n^\|.+\|$)+', re.MULTILINE)

    def __init__(self, max_tokens: int = 1000, model: str = "cl100k_base"):
        """Create a chunker.

        Args:
            max_tokens: Target upper bound on tokens per chunk.
            model: Name passed to ``tiktoken.get_encoding`` (an encoding
                name such as ``"cl100k_base"``).
        """
        self.max_tokens = max_tokens
        self.tokenizer = tiktoken.get_encoding(model)

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens the configured tokenizer yields for ``text``."""
        return len(self.tokenizer.encode(text))

    def generate_chunk_id(self, source: str, content: str, index: int) -> str:
        """Return a deterministic 16-hex-character id for a chunk.

        The id is derived from the source name, the chunk index and the first
        100 characters of the content, so re-chunking unchanged input yields
        the same ids.
        """
        hash_input = f"{source}:{index}:{content[:100]}"
        return hashlib.sha256(hash_input.encode()).hexdigest()[:16]

    def extract_tables(self, text: str) -> list[tuple[int, int, str]]:
        """Return ``(start, end, table_text)`` for every pipe table in ``text``."""
        return [
            (match.start(), match.end(), match.group())
            for match in self.TABLE_PATTERN.finditer(text)
        ]

    def _section_record(self, level: int, title: str, body_lines: list[str]) -> dict:
        """Build the section dictionary for one run of lines under a header."""
        body = '\n'.join(body_lines)
        return {
            'level': level,
            'title': title,
            'content': body.strip(),
            # Table detection runs on the raw (unstripped) section text.
            'has_table': bool(self.TABLE_PATTERN.search(body)),
        }

    def parse_sections(self, content: str) -> list[dict]:
        """Split ``content`` into header-delimited sections.

        Returns:
            A list of dicts with keys ``level`` (0 for text before the first
            header), ``title``, ``content`` and ``has_table``, in document
            order.
        """
        sections: list[dict] = []
        level, title, body_lines = 0, '', []

        for line in content.split('\n'):
            header_match = self.HEADER_PATTERN.match(line)
            if header_match is None:
                body_lines.append(line)
                continue
            # A header closes the section collected so far (if there is one).
            if body_lines or title:
                sections.append(self._section_record(level, title, body_lines))
            level = len(header_match.group(1))
            title = header_match.group(2).strip()
            body_lines = []

        if body_lines or title:
            sections.append(self._section_record(level, title, body_lines))
        return sections

    def build_section_path(self, sections: list[dict], current_idx: int) -> list[str]:
        """Return the titles of the enclosing headers plus the section's own title.

        Walks backwards from ``current_idx`` collecting the nearest preceding
        section at each strictly shallower level, stopping at level 1.
        """
        current = sections[current_idx]
        path = [current['title']] if current['title'] else []
        current_level = current['level']
        for i in range(current_idx - 1, -1, -1):
            section = sections[i]
            if section['level'] < current_level and section['title']:
                path.insert(0, section['title'])
                current_level = section['level']
                if current_level == 1:
                    break
        return path

    def _build_header_prefix(self, section_path: list[str]) -> str:
        """Render ``section_path`` as nested headers followed by a blank line.

        The header depth is the position in the path (first entry gets '#',
        second '##', ...), not the level used in the source document.
        """
        if not section_path:
            return ""
        header_lines = [
            '#' * depth + ' ' + title
            for depth, title in enumerate(section_path, start=1)
        ]
        return '\n'.join(header_lines) + '\n\n'

    def _make_chunk(self, content: str, source_file: str, section_path: list[str],
                    index: int, has_table: bool) -> "Chunk":
        """Construct a ``Chunk`` for ``content``, deriving its id and token count."""
        return Chunk(
            id=self.generate_chunk_id(source_file, content, index),
            content=content,
            source_file=source_file,
            section_path=section_path,
            token_count=self.count_tokens(content),
            has_table=has_table,
        )

    def split_large_section(self, text: str, section_path: list[str], source_file: str, start_index: int) -> Generator["Chunk", None, None]:
        """Split an oversized section on paragraph boundaries.

        Paragraphs (separated by blank lines) are packed greedily into chunks
        whose paragraph tokens stay within ``max_tokens`` minus the header
        prefix. A table that alone exceeds the budget is emitted as its own
        chunk rather than being cut.

        Args:
            text: Section body without its header line.
            section_path: Header titles used to build the repeated prefix.
            source_file: Source name recorded on each chunk.
            start_index: Index used for the first emitted chunk; later chunks
                count up from it.

        Yields:
            ``Chunk`` objects in document order.
        """
        header_prefix = self._build_header_prefix(section_path)
        available_tokens = self.max_tokens - self.count_tokens(header_prefix)
        chunk_index = start_index
        pending: list[str] = []
        pending_tokens = 0

        for para in re.split(r'\n\n+', text):
            para = para.strip()
            if not para:
                continue
            para_tokens = self.count_tokens(para)
            oversized_table = (
                para_tokens > available_tokens
                and bool(self.TABLE_PATTERN.search(para))
            )

            # Flush what has been collected when the next paragraph would not
            # fit (an oversized table never fits, so it always flushes too).
            if pending and (oversized_table or pending_tokens + para_tokens > available_tokens):
                pending_text = '\n\n'.join(pending)
                # Fix: detect tables in the flushed text instead of assuming
                # none; small tables may have been collected before this point.
                yield self._make_chunk(
                    header_prefix + pending_text, source_file, section_path,
                    chunk_index, bool(self.TABLE_PATTERN.search(pending_text)),
                )
                chunk_index += 1
                pending = []
                pending_tokens = 0

            if oversized_table:
                # Keep the table intact in a chunk of its own.
                yield self._make_chunk(
                    header_prefix + para, source_file, section_path, chunk_index, True,
                )
                chunk_index += 1
                continue

            pending.append(para)
            pending_tokens += para_tokens

        if pending:
            pending_text = '\n\n'.join(pending)
            yield self._make_chunk(
                header_prefix + pending_text, source_file, section_path,
                chunk_index, bool(self.TABLE_PATTERN.search(pending_text)),
            )

    def chunk_document(self, content: str, source_file: str) -> Generator["Chunk", None, None]:
        """Yield chunks for a whole markdown document.

        Sections that fit in ``max_tokens`` are emitted as-is (header line,
        blank line, body); larger ones are delegated to
        ``split_large_section``. Chunk indices increase across the document.
        """
        sections = self.parse_sections(content)
        chunk_index = 0
        for position, section in enumerate(sections):
            section_path = self.build_section_path(sections, position)
            if section['title']:
                header = '#' * section['level'] + ' ' + section['title']
                full_text = (header + '\n\n' + section['content']).strip()
            else:
                full_text = section['content'].strip()
            if not full_text:
                continue

            token_count = self.count_tokens(full_text)
            if token_count <= self.max_tokens:
                yield Chunk(
                    id=self.generate_chunk_id(source_file, full_text, chunk_index),
                    content=full_text,
                    source_file=source_file,
                    section_path=section_path,
                    token_count=token_count,
                    has_table=section['has_table'],
                )
                chunk_index += 1
                continue

            for chunk in self.split_large_section(section['content'], section_path, source_file, chunk_index):
                yield chunk
                chunk_index += 1


In [5]:
import torch
from langchain_huggingface import HuggingFaceEmbeddings

def get_embedding_model():
    """Create the LangChain HuggingFace embedding wrapper for ``BAAI/bge-m3``.

    The model is placed on CUDA when ``torch`` reports it as available and on
    the CPU otherwise; the chosen device is printed.

    Returns:
        A ``HuggingFaceEmbeddings`` instance configured to normalise its
        output vectors.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    print(f"Loading BGE-M3 model on: {target_device}")

    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs={
            "device": target_device,
            # NOTE(review): this flag allows code shipped with the model
            # repository to run at load time — confirm it is really needed.
            "trust_remote_code": True,
        },
        # Unit-length vectors, as wanted for cosine-style similarity.
        encode_kwargs={"normalize_embeddings": True},
    )

In [6]:
import numpy as np

class GuardrailManager:
    """Decide whether retrieval evidence is strong enough to answer.

    ``calculate_confidence`` turns retrieval distances into a score in
    ``[0, 1]``; ``check_guardrails`` compares that score with the configured
    threshold and reports whether the pipeline should answer or abstain.
    """

    def __init__(self, confidence_threshold=0.35):
        """Store the minimum confidence required to answer."""
        self.CONFIDENCE_THRESHOLD = confidence_threshold

    def calculate_confidence(self, retrieved_chunks, k=3):
        """Score retrieval quality from the chunks' ``"distance"`` values.

        Each distance ``d`` is mapped to a similarity ``1 / (1 + d)``. The
        score combines the best similarity, the mean of the top ``k`` and
        the gap between the best and second-best result.

        Args:
            retrieved_chunks: Sequence of dicts that each carry a
                ``"distance"`` entry.
            k: Number of top results averaged; values below 1 are treated
                as 1 so the average is never taken over an empty slice.

        Returns:
            A float in ``[0, 1]`` rounded to two decimals; ``0.0`` when there
            are no chunks or no finite distances.
        """
        if not retrieved_chunks:
            return 0.0

        distances = np.array([c["distance"] for c in retrieved_chunks], dtype=float)
        # Non-finite distances would turn the whole score into NaN, which
        # then slips past a "score < threshold" comparison.
        distances = distances[np.isfinite(distances)]
        if distances.size == 0:
            return 0.0

        # Clip at zero so a slightly negative distance cannot produce a
        # similarity above 1 or a division by zero at d == -1.
        sims = 1.0 / (1.0 + np.clip(distances, 0.0, None))

        sims_sorted = np.sort(sims)[::-1]
        top1 = sims_sorted[0]
        avg_topk = sims_sorted[:max(int(k), 1)].mean()
        runner_up = sims_sorted[1] if sims_sorted.size > 1 else 0.0
        gap = top1 - runner_up

        # Evidence aggregation with safety anchor
        confidence = (
            0.6 * top1          # sufficiency: strength of the best match
          + 0.25 * avg_topk     # consistency: agreement across the top k
          + 0.15 * gap          # separation: bonus when the best match stands out
        )

        return round(float(min(max(confidence, 0.0), 1.0)), 2)

    def check_guardrails(self, query, retrieved_chunks, confidence_score):
        """Apply the abstention rules.

        Args:
            query: The user query (not used by the current rules).
            retrieved_chunks: Chunks returned by retrieval.
            confidence_score: Score from ``calculate_confidence``.

        Returns:
            ``(passed, reason)``: ``passed`` is ``False`` when retrieval is
            empty, the score is not a finite number, or the score is below
            the threshold.
        """
        # Case 1: Empty Retrieval
        if not retrieved_chunks:
            return False, "Abstain: No relevant documents found."

        # Case 2: Unusable score. NaN compares False against any threshold,
        # so it must be rejected explicitly to keep the gate fail-safe.
        try:
            score_is_finite = bool(np.isfinite(float(confidence_score)))
        except (TypeError, ValueError):
            score_is_finite = False
        if not score_is_finite:
            return False, f"Abstain: Confidence score {confidence_score} is not a finite number."

        # Case 3: Confidence Threshold
        if confidence_score < self.CONFIDENCE_THRESHOLD:
            return False, f"Abstain: Confidence score {confidence_score} is below threshold {self.CONFIDENCE_THRESHOLD}."

        return True, "Passed"



In [7]:
import chromadb
from chromadb import EmbeddingFunction, Documents, Embeddings
import uuid
import os
import shutil
import json
import glob
from langchain_core.documents import Document  # The correct new import

# --- 1. UPDATED VECTOR STORE MANAGER ---
class ChromaEmbeddingAdapter(EmbeddingFunction):
    """Expose a LangChain embeddings object through Chroma's ``EmbeddingFunction``."""

    def __init__(self, langchain_embeddings):
        # The wrapped object must provide ``embed_documents``.
        self.ef = langchain_embeddings

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` by delegating to the wrapped ``embed_documents``."""
        vectors = self.ef.embed_documents(input)
        return vectors

    def name(self):
        """Return the label reported for this embedding function."""
        return "bge-m3"

class VectorStoreManager:
    """Thin wrapper around a persistent Chroma collection of policy chunks."""

    def __init__(self, embedding_function, collection_name="uoh_policies", persist_path="./policy_db_v3"):
        """Open (or create) the collection stored under ``persist_path``.

        Args:
            embedding_function: LangChain embeddings object; it is wrapped in
                ``ChromaEmbeddingAdapter`` before being handed to Chroma.
            collection_name: Name of the collection to load or create.
            persist_path: Directory used by the persistent Chroma client.
        """
        self.client = chromadb.PersistentClient(path=persist_path)
        self.embedding_adapter = ChromaEmbeddingAdapter(embedding_function)
        # NOTE(review): no distance metric is configured here, so whatever
        # Chroma uses by default applies — confirm it matches the metric the
        # confidence threshold was tuned for.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=self.embedding_adapter
        )
        print(f"Collection '{collection_name}' loaded from '{persist_path}'.")

    @staticmethod
    def _stable_ids(chunks):
        """Return one unique id per chunk, preferring the metadata ``"id"``.

        A random UUID is used only when a chunk has no usable id or repeats
        an id already taken earlier in the same batch.
        """
        ids = []
        taken = set()
        for chunk in chunks:
            candidate = (chunk.metadata or {}).get("id")
            chunk_id = str(candidate) if candidate else ""
            if not chunk_id or chunk_id in taken:
                chunk_id = str(uuid.uuid4())
            taken.add(chunk_id)
            ids.append(chunk_id)
        return ids

    def add_chunks(self, chunks):
        """Write chunks to the collection without duplicating earlier runs.

        Chunks are stored under their deterministic metadata id and written
        with ``upsert``, so ingesting the same document again replaces its
        entries. Previously every call generated fresh random ids and
        appended, which duplicated the whole corpus on each re-ingestion.
        Entries written by that older scheme are not removed by this method.

        Args:
            chunks: Objects with ``page_content`` and ``metadata`` attributes.
        """
        if not chunks:
            return
        ids = self._stable_ids(chunks)
        documents = [chunk.page_content for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]
        self.collection.upsert(ids=ids, documents=documents, metadatas=metadatas)
        print(f"Added {len(chunks)} chunks.")

    def retrieve_context(self, query, top_k=5):
        """Return up to ``top_k`` matches for ``query``.

        Returns:
            A list of dicts with keys ``id``, ``text``, ``metadata`` and
            ``distance``, in the order returned by the collection; empty when
            nothing was found.
        """
        results = self.collection.query(
            query_texts=[query], n_results=top_k, include=["documents", "metadatas", "distances"]
        )
        # Results are nested one level per query text; only one was sent.
        if not results['ids'] or not results['ids'][0]:
            return []
        return [
            {"id": chunk_id, "text": text, "metadata": metadata, "distance": distance}
            for chunk_id, text, metadata, distance in zip(
                results['ids'][0],
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0],
            )
        ]

# --- 2. UPDATED RAG PIPELINE ---
class RAGPipeline:
    """End-to-end pipeline: chunk, store, retrieve, gate on confidence, generate."""

    def __init__(self, db_path="./policy_db_v3"):
        """Build all components.

        Args:
            db_path: Directory passed to ``VectorStoreManager`` as its
                persistence path.
        """
        print("Initializing RAG Pipeline with Custom MarkdownChunker...")

        self.chunker = MarkdownChunker(max_tokens=1000)
        self.embedding_model = get_embedding_model()
        self.vector_db = VectorStoreManager(embedding_function=self.embedding_model, persist_path=db_path)
        self.guardrails = GuardrailManager()
        self.llm = LLMGenerator()
        print("Pipeline Ready.")

    def ingest_document(self, markdown_text: str, source_name: str):
        """Chunk ``markdown_text`` and store the chunks in the vector database.

        Each chunk is converted to a LangChain ``Document`` whose metadata
        holds only scalar values (booleans and the section path are turned
        into strings).
        """
        custom_chunks = list(self.chunker.chunk_document(markdown_text, source_name))

        compatible_chunks = [
            Document(
                page_content=c.content,
                metadata={
                    "id": c.id,
                    "source": c.source_file,
                    "has_table": str(c.has_table),
                    "token_count": c.token_count,
                    "section_path": " > ".join(c.section_path)
                },
            )
            for c in custom_chunks
        ]

        print(f"Adding {len(compatible_chunks)} chunks from '{source_name}'...")
        self.vector_db.add_chunks(compatible_chunks)

    @staticmethod
    def _usable_confidence(value):
        """Return ``value`` if it is a finite number, otherwise ``0.0``.

        Booleans, strings, ``None`` and NaN/infinity are rejected so that a
        malformed model answer can only lower the reported confidence.
        """
        import math

        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return 0.0
        return value if math.isfinite(value) else 0.0

    def query(self, user_query: str):
        """Answer ``user_query`` or abstain.

        Returns:
            A dict with ``answer``, ``confidence`` and ``sources``. When the
            guardrails reject the retrieval, a fixed abstention answer with
            confidence ``0.0`` is returned and the LLM is not called. When
            the model output is not a JSON object, a fixed parse-error
            answer is returned. Otherwise the model's JSON is returned with
            its confidence capped by the retrieval confidence.
        """
        print(f"\nProcessing Query: '{user_query}'")
        retrieved_chunks = self.vector_db.retrieve_context(user_query)
        retrieval_confidence = self.guardrails.calculate_confidence(retrieved_chunks)
        # The textual reason is not part of the response shape.
        passed, _reason = self.guardrails.check_guardrails(user_query, retrieved_chunks, retrieval_confidence)

        if not passed:
            return {
                "answer": "I cannot answer this question based on the provided policy documents.",
                "confidence": 0.0,
                "sources": []
            }

        parse_failure = {"answer": "Error parsing model response.", "confidence": 0.0, "sources": []}

        raw_response = self.llm.generate(user_query, retrieved_chunks)
        try:
            final_json = json.loads(raw_response)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers a response that is not text at all (e.g. None).
            return parse_failure

        # Valid JSON that is not an object (list, string, number) cannot
        # carry the expected fields.
        if not isinstance(final_json, dict):
            return parse_failure

        model_confidence = self._usable_confidence(final_json.get('confidence', 0.0))
        final_json['confidence'] = min(model_confidence, retrieval_confidence)
        return final_json



In [8]:
class LLMGenerator:
    """Generate grounded JSON answers with a Gemini model."""

    def __init__(self, model_name="gemini-1.5-flash-001", api_key=None):
        """Configure the Gemini client and create the model handle.

        The API key is never stored in source: it is taken from ``api_key``
        or, when that is omitted, from the ``GOOGLE_API_KEY`` environment
        variable. Any key that was previously committed in this file should
        be treated as leaked and rotated.

        Args:
            model_name: Gemini model id to use.
            api_key: Optional key overriding the environment variable.

        Raises:
            RuntimeError: If no API key is available.
        """
        resolved_key = api_key or os.environ.get("GOOGLE_API_KEY")
        if not resolved_key:
            raise RuntimeError(
                "GOOGLE_API_KEY is not set. Export it in the environment or "
                "pass api_key=... to LLMGenerator."
            )
        genai.configure(api_key=resolved_key)

        # Ask for JSON output; fall back to a plain 'gemini-pro' handle if
        # the requested model handle cannot be created.
        try:
            self.model = genai.GenerativeModel(
                model_name=model_name,
                generation_config={"response_mime_type": "application/json"}
            )
            print(f"✅ LLM initialized with model: {model_name}")
        except Exception as e:
            print(f"⚠️ Error init model '{model_name}': {e}")
            print("Falling back to 'gemini-pro'...")
            self.model = genai.GenerativeModel("gemini-pro")

    def generate(self, query: str, context_chunks: list) -> str:
        """Ask the model to answer ``query`` using only ``context_chunks``.

        Args:
            query: The user question.
            context_chunks: Dicts with optional ``"text"`` and ``"id"`` keys;
                missing keys fall back to ``''`` and ``'unknown'``.

        Returns:
            The model's raw text response, expected to be JSON. If the call
            fails for any reason, a JSON string describing the error with
            confidence ``0.0`` and no sources is returned instead of raising.
        """
        # Number the sources from 1 and include their ids so the model can
        # cite them.
        context_str = "".join(
            f"Source {position} (ID: {chunk.get('id', 'unknown')}):\n{chunk.get('text', '')}\n\n"
            for position, chunk in enumerate(context_chunks, start=1)
        )

        prompt = f"""
        You are a strict policy assistant. Answer based ONLY on the sources.

        ## USER QUERY:
        "{query}"

        ## SOURCES:
        {context_str}

        ## INSTRUCTIONS:
        1. Answer strictly using the sources.
        2. If the info is missing, return "confidence": 0.0.
        3. Output VALID JSON.

        ## OUTPUT JSON:
        {{
            "answer": "...",
            "confidence": <float>,
            "sources": ["source_id"]
        }}
        """

        try:
            response = self.model.generate_content(prompt)
            return response.text
        except Exception as e:
            # Keep the pipeline's JSON contract even when generation fails.
            return json.dumps({
                "answer": f"LLM Generation Error: {str(e)}",
                "confidence": 0.0,
                "sources": []
            })

In [None]:
# Build the pipeline with its defaults: the chunker, embedding model, vector
# store (database path "./policy_db_v3"), guardrails and LLM client are all
# created inside RAGPipeline.__init__.
rag_system = RAGPipeline()



In [10]:
import glob
# Ingest every markdown file found directly under ./policies. A file that
# fails to read or ingest is reported and skipped so the rest still load.
md_files = glob.glob("./policies/*.md")
if not md_files:
    print("❌ No files found in ./policies")
else:
    print(f"Re-ingesting {len(md_files)} files...")
    for file_path in md_files:
        try:
            # The bare file name is what gets recorded as the chunk source.
            filename = os.path.basename(file_path)
            with open(file_path, "r", encoding="utf-8") as f:
                markdown_text = f.read()
                rag_system.ingest_document(markdown_text, source_name=filename)
        except Exception as e:
            print(f"Skipping {file_path}: {e}")
    print("✅ Database rebuilt with Cosine Similarity!")

Re-ingesting 11 files...
Adding 11 chunks from 'Consultancy-Policy-_-Guidelines_compressed.md'...
Added 11 chunks.
Adding 37 chunks from 'University-of-Hyderabad-Act_compressed.md'...
Added 37 chunks.
Adding 5 chunks from 'Policy-for-Student-Assistanceship-and-Scholarship_compressed.md'...
Added 5 chunks.
Adding 21 chunks from 'Information-Technology-Policy_compressed.md'...
Added 21 chunks.
Adding 10 chunks from 'Reservation-Policy-Student-Admissions-_compressed.md'...
Added 10 chunks.
Adding 13 chunks from 'POLICY.md'...
Added 13 chunks.
Adding 3 chunks from 'Reservation-Policy-for-Appointment-of-Teachers-and-Staff_compressed.md'...
Added 3 chunks.
Adding 22 chunks from 'admission_policy.md'...
Added 22 chunks.
Adding 18 chunks from 'Research-Policy-_compressed.md'...
Added 18 chunks.
Adding 9 chunks from 'ForeignStudentAdmissionPolicy_compressed.md'...
Added 9 chunks.
Adding 19 chunks from 'Teaching-and-Evaluation-Regulations_compressed.md'...
Added 19 chunks.
✅ Database rebuilt with Cosine Similarity!

In [11]:
# Replace the pipeline's generator with one that targets a different model id.
rag_system.llm = LLMGenerator(model_name="gemini-3-flash-preview")
# Re-create the guardrails with an explicit threshold (0.35 is also the
# GuardrailManager default).
rag_system.guardrails = GuardrailManager(confidence_threshold=0.35)

✅ LLM initialized with model: gemini-3-flash-preview


In [12]:
# Smoke test: run one query through the pipeline and pretty-print the
# returned dict as JSON between two separator rules.
print("-" * 40)
response = rag_system.query("How to secure myself")
print(json.dumps(response, indent=2))
print("-" * 40)

----------------------------------------

Processing Query: 'How to secure myself'
{
  "answer": "According to the university IT policies, you can secure yourself and your systems by following these guidelines:\n\n1. **System Responsibility**: You are responsible for the security and integrity of your own systems. If a computer is 'hacked into,' it is recommended that the system be shut down immediately to limit potential damage and prevent the attack from spreading. You must take reasonable steps to ensure the machine is not compromised before network privileges are restored (Source 2).\n2. **Data Protection**: Perform backups of critical data in compliance with the requirements of the IT Act of India (Source 2).\n3. **Identity and Communication**: Ensure all electronic communications accurately identify you as the sender; the use of anonymous or masquerading mail forwarders is prohibited (Source 2, 5). Do not use official university email addresses to register for personal social net