# In [None]:
# rag_ollama.py
import os
import glob
import json
import faiss
import numpy as np
import ollama
from sentence_transformers import SentenceTransformer
from typing import List, Tuple

# ---------- Config ----------
# Folder that the CLI passes to build_faiss_index(); scanned for *.txt files.
DOCS_DIR = "docs/"
# Model name handed to SentenceTransformer for both indexing and querying.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# File the FAISS index is written to by the build step and read from at query time.
INDEX_PATH = "faiss_index.bin"
# JSON file holding the per-chunk metadata list saved alongside the index.
META_PATH = "faiss_meta.json"
# Default number of chunks requested from the index per query.
TOP_K = 4


# ---------- Helpers: load & chunk documents ----------
def load_text_files(folder: str) -> List[Tuple[str, str]]:
    """Read every ``*.txt`` file directly inside *folder*.

    Files are decoded as UTF-8. Sub-directories are not searched.

    Returns:
        A list of ``(file_name, file_contents)`` pairs, where ``file_name``
        is the base name without any directory component. The list is empty
        when no matching file exists.
    """
    pattern = os.path.join(folder, "*.txt")
    loaded: List[Tuple[str, str]] = []
    for path in glob.glob(pattern):
        with open(path, "r", encoding="utf-8") as handle:
            contents = handle.read()
        loaded.append((os.path.basename(path), contents))
    return loaded


def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
    """Split *text* into overlapping chunks of whitespace-separated words.

    The text is tokenised with ``str.split()``, so any run of whitespace is
    collapsed to a single space in the returned chunks. Consecutive chunks
    share ``overlap`` words.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Number of words repeated at the start of the next chunk;
            must satisfy ``0 <= overlap < chunk_size``.

    Returns:
        The list of chunks, empty when *text* contains no words.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    # overlap >= chunk_size would give a non-positive step (the window would
    # never advance); a negative overlap would silently skip words.
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    tokens = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(" ".join(tokens[start:start + chunk_size]))
        # Once a window reaches the last word, any further window would only
        # repeat words already covered, so stop here.
        if start + chunk_size >= len(tokens):
            break
    return chunks


# ---------- Build embeddings + FAISS index ----------
def build_faiss_index(docs_dir: str):
    """Embed every ``.txt`` document under *docs_dir* and save a FAISS index.

    Each document is split into word chunks (200 words, 40-word overlap).
    All chunks are embedded with the model named by ``EMBEDDING_MODEL``,
    L2-normalised, and added to an inner-product index, so a normalised
    query vector yields cosine-similarity scores. The index is written to
    ``INDEX_PATH`` and the chunk metadata (``doc_id``, ``chunk_id``,
    ``text``) to ``META_PATH`` as JSON. Row *i* of the index corresponds to
    entry *i* of the metadata list.

    Args:
        docs_dir: Folder containing the ``.txt`` files to index.

    Raises:
        ValueError: If no chunks could be produced from *docs_dir*.
    """
    metadata = []
    for doc_id, text in load_text_files(docs_dir):
        for i, chunk in enumerate(chunk_text(text, chunk_size=200, overlap=40)):
            metadata.append({"doc_id": doc_id, "chunk_id": i, "text": chunk})
    if not metadata:
        raise ValueError(f"No document chunks found. Put .txt files in {docs_dir}")

    # Load the model only once there is something to embed, then encode all
    # chunks in one call so the library can batch them.
    model = SentenceTransformer(EMBEDDING_MODEL)
    embeddings = model.encode(
        [entry["text"] for entry in metadata], convert_to_numpy=True
    )
    # FAISS operates on C-contiguous float32 matrices.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    dim = embeddings.shape[1]

    index = faiss.IndexFlatIP(dim)
    faiss.normalize_L2(embeddings)  # in place; unit vectors make IP == cosine
    index.add(embeddings)

    faiss.write_index(index, INDEX_PATH)
    with open(META_PATH, "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)
    print(f"Built index with {index.ntotal} vectors, saved to {INDEX_PATH}")


# ---------- Load index + retrieve ----------
def load_index():
    """Load the persisted FAISS index and its chunk metadata.

    Returns:
        A ``(index, metadata)`` pair: the index read from ``INDEX_PATH`` and
        the JSON-decoded contents of ``META_PATH``.
    """
    faiss_index = faiss.read_index(INDEX_PATH)
    with open(META_PATH, "r", encoding="utf-8") as handle:
        chunk_records = json.loads(handle.read())
    return faiss_index, chunk_records


def retrieve(query: str, index, metadata, model: SentenceTransformer, k: int = TOP_K):
    """Return up to *k* indexed chunks most similar to *query*.

    Args:
        query: The question text to embed.
        index: FAISS index to search.
        metadata: Sequence of chunk records, indexed by the ids the search
            returns.
        model: Embedding model used to encode *query*.
        k: Number of neighbours to request from the index.

    Returns:
        A list of ``{"score": float, "meta": <metadata entry>}`` dicts in the
        order the index returned them. Negative ids are skipped, so the list
        may hold fewer than *k* items.
    """
    q_emb = model.encode(query, convert_to_numpy=True)
    # Build the (1, dim) float32 matrix once and keep a name for it, so the
    # in-place normalisation below is applied to the array that is searched
    # rather than to a temporary reshape() result.
    q_emb = np.ascontiguousarray(q_emb, dtype=np.float32).reshape(1, -1)
    faiss.normalize_L2(q_emb)
    scores, ids = index.search(q_emb, k)

    results = []
    for score, idx in zip(scores[0], ids[0]):
        if idx < 0:  # negative id: no neighbour in this slot
            continue
        results.append({"score": float(score), "meta": metadata[int(idx)]})
    return results


# ---------- Ollama generation ----------

def generate_with_ollama(query: str, retrieved_chunks: list, model: str = "llama3.1:8b"):
    """Answer *query* with an Ollama model, grounded in the retrieved chunks.

    The text of every retrieved chunk is joined into one context section and
    placed in a prompt that tells the model to answer from that context or
    reply 'I don't know'.

    Args:
        query: The user's question.
        retrieved_chunks: Items shaped like the results of ``retrieve``; each
            must provide ``item["meta"]["text"]``.
        model: Name of the Ollama model to call. Defaults to the model this
            script was written against.

    Returns:
        The generated answer with surrounding whitespace stripped.
    """
    context = "\n\n---\n\n".join(c["meta"]["text"] for c in retrieved_chunks)
    prompt = (
        "You are a helpful assistant. Use the following context to answer the question. "
        "If the answer is not in the context, say 'I don't know'.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\nAnswer:"
    )

    response = ollama.generate(model=model, prompt=prompt)
    return response["response"].strip()


# ---------- Simple interactive function ----------
def answer_query(query: str):
    """Run the full retrieve-then-generate pipeline for one question.

    Loads the saved index and the embedding model, fetches the ``TOP_K``
    best-matching chunks, prints their score and origin, and returns the
    answer produced by ``generate_with_ollama``.
    """
    faiss_index, chunk_records = load_index()
    encoder = SentenceTransformer(EMBEDDING_MODEL)
    hits = retrieve(query, faiss_index, chunk_records, encoder, k=TOP_K)

    print("Retrieved chunks (score, doc_id, chunk_id):")
    for hit in hits:
        origin = hit["meta"]
        print(hit["score"], origin["doc_id"], origin["chunk_id"])

    answer = generate_with_ollama(query, hits)
    return answer


# ---------- CLI ----------
if __name__ == "__main__":
    import argparse

    # Two modes: --build (re)creates the index, --query asks a question
    # against an existing index. --build wins when both are given.
    cli = argparse.ArgumentParser()
    cli.add_argument("--build", action="store_true", help="Build index from docs/")
    cli.add_argument("--query", type=str, help="Query text")
    options = cli.parse_args()

    if options.build:
        build_faiss_index(DOCS_DIR)
    elif not options.query:
        # Neither flag given (or an empty query): show usage hint.
        print("Run with --build to build index, then --query 'your question' to ask.")
    else:
        answer = answer_query(options.query)
        print("\n=== Answer ===\n")
        print(answer)