# RAG for a more complex microscopy codebase

Here we develop a RAG agent to interface with [sidpy](https://pycroscopy.github.io/sidpy/), a Dask-based Python framework for microscopy data.  

The user interacts with the agent to perform basic operations to convert their data structures into sidpy data.


# Imports and Functions

In [2]:
import os
import json
from pathlib import Path
from typing import List

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser



## Function for persist_dir

In [3]:

from __future__ import annotations

import re
from pathlib import Path
from typing import Optional

def get_persist_dir(
    base_dir: str | Path,
    basename: str,
    *,
    new_persist: bool = False,
    create: bool = True,
) -> Path:
    """
    Resolve the numbered persist directory to use under ``base_dir``.

    Run directories are named ``f"{basename}_{run_id}"`` with an integer
    ``run_id``. ``base_dir`` itself is always created if it is missing.

    Args:
        base_dir: Parent folder that holds the numbered run directories.
        basename: Prefix shared by all run directories.
        new_persist: When False, reuse the highest-numbered existing run
            directory; when True, pick the next free number instead.
            If no run directory exists yet, number 0 is used either way.
        create: When True, create the selected directory if it is missing.

    Returns:
        Absolute path of the selected run directory.
    """
    root = Path(base_dir).expanduser().resolve()
    root.mkdir(parents=True, exist_ok=True)

    run_dir_re = re.compile(rf"^{re.escape(basename)}_(\d+)$")

    # Only sub-directories whose name is exactly "<basename>_<digits>" count;
    # plain files and differently named folders are ignored.
    name_matches = (
        run_dir_re.match(entry.name) for entry in root.iterdir() if entry.is_dir()
    )
    found_ids = [int(hit.group(1)) for hit in name_matches if hit]

    if found_ids:
        latest = max(found_ids)
        chosen_id = latest + 1 if new_persist else latest
    else:
        chosen_id = 0

    target = root / f"{basename}_{chosen_id}"
    if create:
        target.mkdir(parents=True, exist_ok=True)
    return target


## RAG chain function

In [None]:
import os


# Placeholder only: replace "YOUR_API_KEY" locally before running this cell,
# and never commit a real key to version control.
os.environ.update({"OPENAI_API_KEY": "YOUR_API_KEY"})

In [4]:
def get_rag_chain(folder_path: str, persist_basename: str,
                  new_persist: bool = False, chunk_size: int = 1200, chunk_overlap: int = 200,
                  model: str = "gpt-5", temperature: float = 0, api_key: Optional[str] = None):
    """
    Build a retrieval-augmented chain that answers a task with Python code only.

    The function loads every ``.py`` and ``.ipynb`` file under ``folder_path``,
    splits the text into chunks, opens a persisted Chroma collection under
    ``./persists/<persist_basename>_<run_id>`` and indexes the chunks if that
    collection is still empty. Progress is reported with ``print``.

    Args:
        folder_path: Folder that is searched recursively for code files.
        persist_basename: Prefix of the persist directory (see ``get_persist_dir``).
        new_persist: Start a new persist directory instead of reusing the latest.
        chunk_size: Maximum chunk length in characters.
        chunk_overlap: Number of characters shared by neighbouring chunks.
        model: Chat model name passed to ``ChatOpenAI``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.
        api_key: OpenAI API key. When None, the ``OPENAI_API_KEY`` environment
            variable is left as it is.

    Returns:
        A runnable chain; ``chain.invoke(task)`` returns the generated code as a string.

    Raises:
        ValueError: If the collection is empty and no chunks were found to index.
    """
    print("=" * 60)
    print("STARTING get_rag_chain")
    print(f"Parameters: folder_path={folder_path}, persist_basename={persist_basename}")
    print(f"new_persist={new_persist}, chunk_size={chunk_size}, chunk_overlap={chunk_overlap}")
    print(f"model={model}, temperature={temperature}")
    print("=" * 60)

    # -----------------------------
    # 1) Environment variables
    # -----------------------------
    print("\n[1/6] Setting up environment variables...")
    # Only write the key when one was passed in explicitly; otherwise keep the
    # environment untouched rather than replacing a missing key with "".
    if api_key is not None:
        os.environ["OPENAI_API_KEY"] = api_key

    current_key = os.environ.get("OPENAI_API_KEY", "")
    print("OPENAI_API_KEY starts with:", (current_key[:5] + "...") if current_key else "Not Set")

    # Disable LangSmith tracing (prevents 401 errors if you don't use LangSmith)
    os.environ["LANGCHAIN_TRACING_V2"] = "false"
    os.environ["LANGCHAIN_API_KEY"] = ""
    os.environ["LANGCHAIN_PROJECT"] = "rag-code-search"
    print("✓ Environment variables configured")

    # -----------------------------
    # 2) Text splitting / chunking (tuned for code)
    # -----------------------------
    print("\n[2/6] Creating text splitter...")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    print("✓ Text splitter created")

    # -----------------------------
    # 3) Load .py and .ipynb from a folder
    # -----------------------------
    print("\n[3/6] Loading code documents...")

    def load_py_file(path: Path) -> Document:
        """Read a Python source file into a Document tagged with its path."""
        text = path.read_text(encoding="utf-8", errors="ignore")
        return Document(page_content=text, metadata={"source": str(path), "type": "py"})

    def load_ipynb_file(path: Path) -> Document:
        """Flatten the code and markdown cells of a notebook into one Document."""
        nb = json.loads(path.read_text(encoding="utf-8", errors="ignore"))

        parts = []
        for cell in nb.get("cells", []):
            cell_type = cell.get("cell_type", "")
            src = cell.get("source", [])
            # Cell source may be stored as a list of lines or as a single string.
            if isinstance(src, list):
                src = "".join(src)
            src = (src or "").strip()
            if not src:
                continue

            # Keep both markdown + code, but label them
            if cell_type == "code":
                parts.append("# --- notebook code cell ---\n" + src)
            elif cell_type == "markdown":
                parts.append("# --- notebook markdown cell ---\n" + src)

        text = "\n\n".join(parts)
        return Document(page_content=text, metadata={"source": str(path), "type": "ipynb"})

    def load_code_documents(folder_path: str) -> List[Document]:
        """Recursively collect Documents for all .py and .ipynb files in a folder."""
        folder = Path(folder_path)
        documents: List[Document] = []

        for path in folder.rglob("*"):
            if path.is_dir():
                continue

            suffix = path.suffix.lower()
            if suffix == ".py":
                documents.append(load_py_file(path))
            elif suffix == ".ipynb":
                # One malformed notebook should not abort indexing of the whole folder.
                try:
                    documents.append(load_ipynb_file(path))
                except json.JSONDecodeError as exc:
                    print(f"  ⚠️  Skipping unreadable notebook {str(path)!r}: {exc}")

        return documents

    documents = load_code_documents(folder_path)
    print(f"✓ Loaded {len(documents)} code documents from {folder_path!r} (.py + .ipynb).")

    print("  Splitting documents into chunks...")
    splits = text_splitter.split_documents(documents)
    print(f"✓ Split into {len(splits)} chunks.")

    # -----------------------------
    # 4) Embedding + Vector store (persisted)
    # -----------------------------
    print("\n[4/6] Setting up vector store...")
    persist_dir = get_persist_dir("./persists/", persist_basename, new_persist=new_persist)
    print(f"  Persist directory: {persist_dir}")

    collection_name = "code_collection"

    print("  Creating embedding function...")
    embedding_function = OpenAIEmbeddings(model="text-embedding-3-large")
    print("✓ Embedding function created")

    # get_persist_dir() has already created the directory, so its existence says
    # nothing about whether it holds an index. Open the collection and check
    # whether it contains any entries; index the chunks only when it is empty,
    # so an existing DB is not rebuilt on every run.
    print(f"  Opening vector DB at {persist_dir}...")
    print("  ⚠️  WARNING: If on OneDrive, this may cause kernel crash!")
    vectorstore = Chroma(
        collection_name=collection_name,
        persist_directory=str(persist_dir),
        embedding_function=embedding_function,
    )
    if vectorstore.get(limit=1)["ids"]:
        print(f"✓ Loaded existing vector DB from {persist_dir!r}.")
    else:
        if not splits:
            raise ValueError(
                f"No code chunks found. Check that {folder_path!r} contains .py or .ipynb files."
            )
        print("  Creating new vector DB (this may take a moment)...")
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_function,
            collection_name=collection_name,
            persist_directory=str(persist_dir),
        )
        print(f"✓ Created and persisted vector DB at {persist_dir!r}.")

    # -----------------------------
    # 5) Retriever
    # -----------------------------
    print("\n[5/6] Creating retriever...")
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "fetch_k": 20, "lambda_mult": 0.5}
    )
    print("✓ Retriever created with MMR search (k=6, fetch_k=20)")

    # -----------------------------
    # 6) RAG chain that outputs PYTHON CODE ONLY (as a string)
    # -----------------------------
    print("\n[6/6] Building RAG chain...")

    def docs2str(docs: List[Document]) -> str:
        """Join retrieved chunks into one context string, each prefixed by its source."""
        # Keep source hints so the model can reference repo utilities if they exist
        out = []
        for d in docs:
            src = d.metadata.get("source", "unknown")
            out.append(f"### SOURCE: {src}\n{d.page_content}")
        return "\n\n".join(out)

    template = """You are a coding assistant.

    You must write Python code ONLY (no markdown fences, no explanations).
    Your output must be a single Python script as plain text.

    CONTEXT (repository code snippets):
    {context}

    Task:
    {question}

    Python code:"""

    prompt = ChatPromptTemplate.from_template(template)
    print("  Prompt template created")

    print(f"  Initializing LLM ({model})...")
    llm = ChatOpenAI(model=model, temperature=temperature)
    print("✓ LLM initialized")

    print("  Assembling chain...")
    # The task string is used both as the retrieval query and as the {question} slot.
    rag_chain = (
        {"context": retriever | docs2str, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    print("✓ RAG chain assembled")

    print("\n" + "=" * 60)
    print("✓✓✓ get_rag_chain completed successfully!")
    print("=" * 60 + "\n")

    return rag_chain

## Write code to py-file

In [17]:

def write_generated_code_to_file(
    code_str: str,
    prompt_str: str = "",
    filename: str = "main_rag_test.py",
    encoding: str = "utf-8",
) -> Path:
    """
    Overwrite `filename` with a prompt header followed by `code_str`.

    The prompt is written as `#` comment lines so that no prompt text (quotes,
    backslashes, triple quotes) can break the file, and the generated code
    starts at column 0 so the result is a runnable Python script.
    Missing parent directories are created.

    Args:
        code_str: Generated source code; non-string values are converted with `str()`.
        prompt_str: Prompt that produced the code, recorded in the header.
        filename: Destination path; `~` is expanded and the path is made absolute.
        encoding: Text encoding used for writing.

    Returns:
        The absolute Path of the written file.
    """
    path = Path(filename).expanduser().resolve()
    path.parent.mkdir(parents=True, exist_ok=True)

    # Ensure we're writing a string (some chains return dicts / messages)
    if not isinstance(code_str, str):
        code_str = str(code_str)

    # Add a trailing newline for nicer diffs/editors
    if not code_str.endswith("\n"):
        code_str += "\n"

    # Every prompt line becomes its own comment line; rstrip() avoids trailing
    # whitespace on blank prompt lines.
    header_lines = ["# Prompt:"]
    header_lines.extend(
        f"# {line}".rstrip() for line in str(prompt_str).strip().splitlines()
    )
    header = "\n".join(header_lines)

    path.write_text(f"{header}\n\n{code_str}", encoding=encoding)

    return path

# Learn sidpy

In [None]:
# --- Vector store settings ---
# Set new_persist to True to create a new persist directory
# (for example when creating a new database); False reuses the latest one.
new_persist = False
persist_basename = "chroma_db_sidpy"

# --- Code base to index and how to chunk it ---
folder_path = "C:/Users/zwx/Documents/sidpy"
chunk_size, chunk_overlap = 800, 100

# --- Model settings ---
# The key is read from the environment; it is None when the variable is not set.
api_key = os.environ.get("OPENAI_API_KEY")
model, temperature = "gpt-5", 0

# --- Task sent to the RAG chain ---
task = """
Write code to generate a sidpy dataset for a fake 2D image of size 256x256.\n
You can fill the image with random values from 0 to 1.
"""

In [7]:
# Build the RAG chain from the settings defined in the configuration cell.
# The two required arguments are passed positionally, the options by keyword.
rag_chain = get_rag_chain(
    folder_path,
    persist_basename,
    new_persist=new_persist,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    model=model,
    temperature=temperature,
    api_key=api_key,
)

STARTING get_rag_chain
Parameters: folder_path=C:/Users/zwx/Documents/sidpy, persist_basename=chroma_db_sidpy
new_persist=False, chunk_size=800, chunk_overlap=100
model=gpt-5, temperature=0

[1/6] Setting up environment variables...
OPENAI_API_KEY starts with: sk-pr...
✓ Environment variables configured

[2/6] Creating text splitter...
✓ Text splitter created

[3/6] Loading code documents...
✓ Loaded 101 code documents from 'C:/Users/zwx/Documents/sidpy' (.py + .ipynb).
  Splitting documents into chunks...
✓ Split into 2466 chunks.

[4/6] Setting up vector store...
  Persist directory: C:\Users\zwx\Documents\RAG-for-SPM-data-analysis\persists\chroma_db_sidpy_0
  Creating embedding function...
✓ Embedding function created
  Loading existing vector DB from C:\Users\zwx\Documents\RAG-for-SPM-data-analysis\persists\chroma_db_sidpy_0...
✓ Loaded existing vector DB from WindowsPath('C:/Users/zwx/Documents/RAG-for-SPM-data-analysis/persists/chroma_db_sidpy_0').

[5/6] Creating retriever...
✓ 

In [11]:
# Bare expression: the notebook displays the chain's repr so its components
# (retriever, prompt, model) can be inspected.
rag_chain

{
  context: VectorStoreRetriever(tags=['Chroma', 'OpenAIEmbeddings'], vectorstore=<langchain_chroma.vectorstores.Chroma object at 0x0000018B3636AE10>, search_type='mmr', search_kwargs={'k': 6, 'fetch_k': 20, 'lambda_mult': 0.5})
           | RunnableLambda(docs2str),
  question: RunnablePassthrough()
}
| ChatPromptTemplate(input_variables=['context', 'question'], input_types={}, partial_variables={}, messages=[HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['context', 'question'], input_types={}, partial_variables={}, template='You are a coding assistant.\n\n    You must write Python code ONLY (no markdown fences, no explanations).\n    Your output must be a single Python script as plain text.\n\n    CONTEXT (repository code snippets):\n    {context}\n\n    Task:\n    {question}\n\n    Python code:'), additional_kwargs={})])
| ChatOpenAI(profile={'max_input_tokens': 400000, 'max_output_tokens': 128000, 'image_inputs': True, 'audio_inputs': False, 'video_inputs': Fals

In [8]:

print('✓ rag_chain created successfully')

# Banner announcing the invocation.
print("\n" + "=" * 60, "INVOKING RAG CHAIN", "=" * 60, sep="\n")

# Show at most the first 100 characters of the task.
if len(task) > 100:
    print(f"Task: {task[:100]}...")
else:
    print(f"Task: {task}")

print("\n[Step 1/3] Retrieving relevant documents from vector store...")
print("  ⚠️  WARNING: ChromaDB query on OneDrive may crash kernel here!")

try:
    print("\n[Step 2/3] Calling LLM to generate code...")
    print(f"  Using model: {model}")
    print("  This may take 10-30 seconds depending on complexity...")

    # Retrieval and generation both happen inside this single call.
    response_code_str = rag_chain.invoke(task)

    print("\n[Step 3/3] ✓ Response received successfully!")
    print(f"  Generated code length: {len(response_code_str)} characters")

except Exception as exc:
    import traceback

    # Report the failure in a readable form, then re-raise so the cell still fails.
    print("\n❌ ERROR during rag_chain.invoke()!")
    print(f"Error type: {type(exc).__name__}", f"Error message: {exc}", sep="\n")
    print("\nFull traceback:")
    traceback.print_exc()
    raise

print("\n===== GENERATED PYTHON CODE (STRING) =====\n")
print(response_code_str)


✓ rag_chain created successfully

INVOKING RAG CHAIN
Task: 
Write code to generate a sidpy dataset for a fake 2D image of size 256x256.

You can fill the image...

[Step 1/3] Retrieving relevant documents from vector store...

[Step 2/3] Calling LLM to generate code...
  Using model: gpt-5
  This may take 10-30 seconds depending on complexity...

[Step 3/3] ✓ Response received successfully!
  Generated code length: 1796 characters

===== GENERATED PYTHON CODE (STRING) =====

import numpy as np

# Create random 2D image data
np.random.seed(0)
image = np.random.rand(256, 256)

# Import sidpy Dataset and Dimension with fallbacks for different package structures
try:
    from sidpy import Dataset, Dimension
except Exception:
    try:
        from sidpy.sid.dataset import Dataset
        from sidpy.sid.dimension import Dimension
    except Exception as e:
        raise ImportError("sidpy is required to run this script. Please install sidpy.") from e

# Create the sidpy Dataset
if hasattr(Da

In [10]:
# Section banner: a blank line, the title, and a blank line.
print("", "===== WRITING GENERATED CODE TO FILE =====", "", sep="\n")


===== WRITING GENERATED CODE TO FILE =====



## Test the code
Write the generated code to "test_sidpy_rag.py". The file is overwritten if it already exists.

Test it by running "python test_sidpy_rag.py" in a terminal.

In [18]:
# Save the generated script together with the task that produced it.
# As the last expression of the cell, the returned path is displayed.
write_generated_code_to_file(
    response_code_str,
    task,
    filename="test_sidpy_rag.py",
)

WindowsPath('C:/Users/zwx/Documents/RAG-for-SPM-data-analysis/test_sidpy_rag.py')