# RAG for STM data analysis

Here we develop a retrieval-augmented generation (RAG) agent that interfaces with our custom data-analysis code for scanning tunneling microscopy (STM) data.

The user asks the agent how to analyse STM data, and the agent responds with new Python code that can be used to plot the data.

We use the notebook "test_rag_output.ipynb" to test the RAG predictions.

# Imports

In [1]:
import os
import json
from pathlib import Path
from typing import List

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser



  from .autonotebook import tqdm as notebook_tqdm


## Helper to select the persist directory

In [2]:

from __future__ import annotations

import re
from pathlib import Path
from typing import Optional

def get_persist_dir(
    base_dir: str | Path,
    basename: str,
    *,
    new_persist: bool = False,
    create: bool = True,
) -> Path:
    """
    Find (or create) a persist directory under base_dir.

    Looks for directories named: f"{basename}_{run_id}" where run_id is an int.
    - If new_persist=False: returns the latest existing folder if found, else basename_0
    - If new_persist=True : returns a new folder with run_id = (latest + 1) or 0 if none exist
    """
    base_path = Path(base_dir).expanduser().resolve()
    base_path.mkdir(parents=True, exist_ok=True)

    pattern = re.compile(rf"^{re.escape(basename)}_(\d+)$")

    max_run_id: Optional[int] = None
    for p in base_path.iterdir():
        if not p.is_dir():
            continue
        m = pattern.match(p.name)
        if not m:
            continue
        run_id = int(m.group(1))
        if max_run_id is None or run_id > max_run_id:
            max_run_id = run_id

    if max_run_id is None:
        next_id = 0
    else:
        next_id = (max_run_id + 1) if new_persist else max_run_id

    persist_path = base_path / f"{basename}_{next_id}"

    if create:
        persist_path.mkdir(parents=True, exist_ok=True)

    return persist_path


# 2D image data

In [10]:
# Set new_persist to True to build a brand-new persist directory (next run id)
# instead of reusing the most recent one.
new_persist = True
folder_path = "./stm_data_code_sample/2D_image_data"

# -----------------------------
# 1) Environment variables
# -----------------------------
# OPENAI_API_KEY has to be exported before launching; if it is absent an empty
# string is stored so the lookups below do not raise KeyError.
os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY") or ""
print(
    "OPENAI_API_KEY starts with:",
    f"{os.environ['OPENAI_API_KEY'][:5]}..." if os.environ["OPENAI_API_KEY"] else "Not Set",
)

# LangSmith tracing is switched off so runs without a LangSmith account do not hit 401 errors.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "false",
        "LANGCHAIN_API_KEY": "",
        "LANGCHAIN_PROJECT": "rag-code-search",
    }
)


# -----------------------------
# 2) Text splitting / chunking (tuned for code)
# -----------------------------
# Chunks of at most 1200 characters (length_function=len) with 200 characters of
# overlap; separators are tried in order: blank line, newline, space, character.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", " ", ""],
    length_function=len,
    chunk_overlap=200,
    chunk_size=1200,
)


# -----------------------------
# 3) Load .py and .ipynb from a folder
# -----------------------------
def load_py_file(path: Path) -> Document:
    """Read a Python source file (undecodable bytes dropped) into a Document tagged type "py"."""
    source_text = path.read_text(encoding="utf-8", errors="ignore")
    meta = {"source": str(path), "type": "py"}
    return Document(page_content=source_text, metadata=meta)

def load_ipynb_file(path: Path) -> Document:
    """
    Flatten a Jupyter notebook into a single Document tagged type "ipynb".

    Non-empty code and markdown cells are kept in notebook order, each preceded
    by a comment line saying which kind of cell it was. Cells of any other type
    and cells whose source is blank are dropped. Cells are separated by a blank
    line.
    """
    notebook = json.loads(path.read_text(encoding="utf-8", errors="ignore"))

    sections = []
    for cell in notebook.get("cells", []):
        kind = cell.get("cell_type", "")
        raw = cell.get("source", [])
        # nbformat stores cell source either as one string or as a list of lines.
        body = (("".join(raw) if isinstance(raw, list) else raw) or "").strip()
        if not body:
            continue

        if kind == "code":
            header = "# --- notebook code cell ---\n"
        elif kind == "markdown":
            header = "# --- notebook markdown cell ---\n"
        else:
            continue
        sections.append(header + body)

    return Document(
        page_content="\n\n".join(sections),
        metadata={"source": str(path), "type": "ipynb"},
    )

def load_code_documents(folder_path: str) -> List[Document]:
    """
    Recursively load every ``.py`` and ``.ipynb`` file under ``folder_path``.

    Anything with a path component below ``folder_path`` that starts with ``.``
    is skipped. That covers hidden directories such as ``.ipynb_checkpoints``
    and ``.git``, and hidden files. Jupyter's checkpoint copies would otherwise
    be indexed as duplicate documents and crowd the retriever results. Files
    are visited in sorted path order so the document list, and therefore the
    chunk order, is reproducible across runs.

    Args:
        folder_path: Root folder to search.

    Returns:
        One Document per file, built by ``load_py_file`` / ``load_ipynb_file``.
    """
    folder = Path(folder_path)
    loaders = {".py": load_py_file, ".ipynb": load_ipynb_file}
    documents: List[Document] = []

    for path in sorted(folder.rglob("*")):
        loader = loaders.get(path.suffix.lower())
        if loader is None or path.is_dir():
            continue
        if any(part.startswith(".") for part in path.relative_to(folder).parts):
            continue
        documents.append(loader(path))

    return documents



documents = load_code_documents(folder_path)
print(f"Loaded {len(documents)} code documents from {folder_path!r} (.py + .ipynb).")

splits = text_splitter.split_documents(documents)
print(f"Split into {len(splits)} chunks.")


# -----------------------------
# 4) Embedding + Vector store (persisted)
# -----------------------------
# Reuses the latest ./persists/chroma_db_code_<N>, or creates chroma_db_code_<N+1>
# when new_persist is True.
PERSIST_DIR = get_persist_dir("./persists/", "chroma_db_code", new_persist=new_persist)

COLLECTION_NAME = "code_collection"

embedding_function = OpenAIEmbeddings(model="text-embedding-3-large")

# from_documents re-embeds every chunk (paid API calls), so an existing on-disk
# DB is loaded instead of rebuilt. NOTE: a loaded DB reflects the files that were
# indexed when it was built; changes in folder_path are only picked up with
# new_persist=True.
db_file = Path(PERSIST_DIR) / "chroma.sqlite3"
if db_file.exists():
    vectorstore = Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory=str(PERSIST_DIR),
        embedding_function=embedding_function,
    )
    print(f"Loaded existing vector DB from {PERSIST_DIR!r}.")
else:
    if not splits:
        raise ValueError(
            f"No code chunks found. Check that {folder_path!r} contains .py or .ipynb files."
        )
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_function,
        collection_name=COLLECTION_NAME,
        persist_directory=str(PERSIST_DIR),
    )
    # Chroma >= 0.4 writes to persist_directory automatically; no persist() call needed.
    print(f"Created and persisted vector DB at {PERSIST_DIR!r}.")


# -----------------------------
# 5) Retriever
# -----------------------------
# Maximal-marginal-relevance search: fetch 20 candidates, return 6 that balance
# relevance and diversity (lambda_mult=0.5). For plain similarity search use
# vectorstore.as_retriever(search_kwargs={"k": 6}).
retriever = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 6, "fetch_k": 20, "lambda_mult": 0.5},
)

# -----------------------------
# 6) RAG chain that outputs PYTHON CODE ONLY (as a string)
# -----------------------------
def docs2str(docs: List[Document]) -> str:
    """
    Render retrieved documents as one context string for the prompt.

    Each document becomes a "### SOURCE: <path>" header line followed by its
    text, so the model can refer to repository utilities by file. Documents
    without a "source" entry are labelled "unknown". Entries are separated by
    a blank line.
    """
    return "\n\n".join(
        f"### SOURCE: {doc.metadata.get('source', 'unknown')}\n{doc.page_content}"
        for doc in docs
    )

# Prompt template. It fixes only the output format (a plain Python script) and
# general conventions. The concrete task (data file, quantity to plot, whether
# to save) comes from the user's question, so the two cannot contradict each
# other. The previous version hard-coded "plot Z-height of ./docs/image.sxm and
# save it", which clashed with questions such as "Don't save the figure".
template = """You are a coding assistant for scanning tunneling microscopy (STM) data analysis.

You must write Python code ONLY (no markdown fences, no explanations).
Your output must be a single Python script as plain text.

Task:
- Write a Python script that does exactly what the Question below asks.
- Use the data file path(s) given in the Question verbatim.
- If a library is needed, include a short install hint as a Python comment (e.g. # pip install ...).
- Prefer using any existing loader/utility patterns found in the CONTEXT below.
- If the Question does not name a data channel, pick the most appropriate one
  (e.g. the Z/height channel for topography images).
- Plot image/map data with matplotlib imshow + colorbar.
- Only save figures to disk if the Question asks for it; otherwise just show them.

CONTEXT (repository code snippets):
{context}

Question: {question}

Python code:"""

prompt = ChatPromptTemplate.from_template(template)

# NOTE(review): some reasoning models (e.g. gpt-5) only accept the default
# temperature; confirm temperature=0 is actually honoured.
llm = ChatOpenAI(model="gpt-5", temperature=0)

# The question is sent both to the retriever (results formatted by docs2str)
# and, unchanged, into the prompt. The LLM reply is returned as a plain string.
rag_chain = (
    {"context": retriever | docs2str, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)



OPENAI_API_KEY starts with: sk-pr...
Loaded 3 code documents from './stm_data_code_sample/2D_image_data' (.py + .ipynb).
Split into 6 chunks.
Created and persisted vector DB at WindowsPath('C:/Users/ggn/1_py_scripts/RAG/RAG_data_analysis_hackathon_2025/persists/chroma_db_code_7').


## RAG Query

In [12]:
# -----------------------------
# 7) Ask the question
# -----------------------------
# The question carries the concrete task (file path, what to plot, whether to
# save). The chain returns the generated script as a string.
question = (
    "use the file at './stm_data_code_sample/2D_image_data/image.sxm' "
    "to plot the Z-height image of the image.sxm file. Don't save the figure"
)
response_code_str = rag_chain.invoke(question)

print("\n===== GENERATED PYTHON CODE (STRING) =====\n", response_code_str, sep="\n")


===== GENERATED PYTHON CODE (STRING) =====

# pip install stmpy scikit-learn scipy matplotlib
import os
import sys
import matplotlib.pyplot as plt

# Ensure we can import stm_utils from the repository structure
repo_rel_utils_dir = os.path.join('.', 'stm_data_code_sample', '2D_image_data')
if repo_rel_utils_dir not in sys.path:
    sys.path.insert(0, repo_rel_utils_dir)

try:
    from stm_utils import Sxm_Image
except Exception as e:
    raise ImportError("Failed to import Sxm_Image from stm_utils. Ensure the path is correct and dependencies are installed. "
                      "Try: pip install stmpy scikit-learn scipy matplotlib") from e

def select_height_channel(channels):
    # Preference order and common aliases for topography/height channels
    preferred = [
        'Z_Fwd', 'Z_Bkd', 'Z_Forward', 'Z_Backward',
        'Height', 'Topography', 'Topo', 'Z', 'ZSensor'
    ]
    ch_lower_map = {ch.lower(): ch for ch in channels}

    # Exact preferred names (case-insensitive)
   

# 3D hyperspectral data analysis

In [15]:
# Set new_persist to True to build a brand-new persist directory (next run id)
# instead of reusing the most recent one.
new_persist = True
folder_path = "./stm_data_code_sample/3D_hyperspectral_data"

# -----------------------------
# 1) Environment variables
# -----------------------------
# OPENAI_API_KEY has to be exported before launching; if it is absent an empty
# string is stored so the lookups below do not raise KeyError.
os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY") or ""
print(
    "OPENAI_API_KEY starts with:",
    f"{os.environ['OPENAI_API_KEY'][:5]}..." if os.environ["OPENAI_API_KEY"] else "Not Set",
)

# LangSmith tracing is switched off so runs without a LangSmith account do not hit 401 errors.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "false",
        "LANGCHAIN_API_KEY": "",
        "LANGCHAIN_PROJECT": "rag-code-search",
    }
)


# -----------------------------
# 2) Text splitting / chunking (tuned for code)
# -----------------------------
# Smaller chunks than the 2D cell: at most 800 characters (length_function=len)
# with 200 characters of overlap; separators tried in order: blank line,
# newline, space, character.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", " ", ""],
    length_function=len,
    chunk_overlap=200,
    chunk_size=800,
)


# -----------------------------
# 3) Load .py and .ipynb from a folder
# -----------------------------
def load_py_file(path: Path) -> Document:
    """Read a Python source file (undecodable bytes dropped) into a Document tagged type "py"."""
    source_text = path.read_text(encoding="utf-8", errors="ignore")
    meta = {"source": str(path), "type": "py"}
    return Document(page_content=source_text, metadata=meta)

def load_ipynb_file(path: Path) -> Document:
    """
    Flatten a Jupyter notebook into a single Document tagged type "ipynb".

    Non-empty code and markdown cells are kept in notebook order, each preceded
    by a comment line saying which kind of cell it was. Cells of any other type
    and cells whose source is blank are dropped. Cells are separated by a blank
    line.
    """
    notebook = json.loads(path.read_text(encoding="utf-8", errors="ignore"))

    sections = []
    for cell in notebook.get("cells", []):
        kind = cell.get("cell_type", "")
        raw = cell.get("source", [])
        # nbformat stores cell source either as one string or as a list of lines.
        body = (("".join(raw) if isinstance(raw, list) else raw) or "").strip()
        if not body:
            continue

        if kind == "code":
            header = "# --- notebook code cell ---\n"
        elif kind == "markdown":
            header = "# --- notebook markdown cell ---\n"
        else:
            continue
        sections.append(header + body)

    return Document(
        page_content="\n\n".join(sections),
        metadata={"source": str(path), "type": "ipynb"},
    )

def load_code_documents(folder_path: str) -> List[Document]:
    """
    Recursively load every ``.py`` and ``.ipynb`` file under ``folder_path``.

    Anything with a path component below ``folder_path`` that starts with ``.``
    is skipped. That covers hidden directories such as ``.ipynb_checkpoints``
    and ``.git``, and hidden files. Jupyter's checkpoint copies would otherwise
    be indexed as duplicate documents and crowd the retriever results. Files
    are visited in sorted path order so the document list, and therefore the
    chunk order, is reproducible across runs.

    Args:
        folder_path: Root folder to search.

    Returns:
        One Document per file, built by ``load_py_file`` / ``load_ipynb_file``.
    """
    folder = Path(folder_path)
    loaders = {".py": load_py_file, ".ipynb": load_ipynb_file}
    documents: List[Document] = []

    for path in sorted(folder.rglob("*")):
        loader = loaders.get(path.suffix.lower())
        if loader is None or path.is_dir():
            continue
        if any(part.startswith(".") for part in path.relative_to(folder).parts):
            continue
        documents.append(loader(path))

    return documents



documents = load_code_documents(folder_path)
print(f"Loaded {len(documents)} code documents from {folder_path!r} (.py + .ipynb).")

splits = text_splitter.split_documents(documents)
print(f"Split into {len(splits)} chunks.")


# -----------------------------
# 4) Embedding + Vector store (persisted)
# -----------------------------
# Reuses the latest ./persists/chroma_db_code3d_<N>, or creates
# chroma_db_code3d_<N+1> when new_persist is True. This prefix is separate from
# the 2D index.
PERSIST_DIR = get_persist_dir("./persists/", "chroma_db_code3d", new_persist=new_persist)

COLLECTION_NAME = "code_collection"

embedding_function = OpenAIEmbeddings(model="text-embedding-3-large")

# from_documents re-embeds every chunk (paid API calls), so an existing on-disk
# DB is loaded instead of rebuilt. NOTE: a loaded DB reflects the files that were
# indexed when it was built; changes in folder_path are only picked up with
# new_persist=True.
db_file = Path(PERSIST_DIR) / "chroma.sqlite3"
if db_file.exists():
    vectorstore = Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory=str(PERSIST_DIR),
        embedding_function=embedding_function,
    )
    print(f"Loaded existing vector DB from {PERSIST_DIR!r}.")
else:
    if not splits:
        raise ValueError(
            f"No code chunks found. Check that {folder_path!r} contains .py or .ipynb files."
        )
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_function,
        collection_name=COLLECTION_NAME,
        persist_directory=str(PERSIST_DIR),
    )
    # Chroma >= 0.4 writes to persist_directory automatically; no persist() call needed.
    print(f"Created and persisted vector DB at {PERSIST_DIR!r}.")


# -----------------------------
# 5) Retriever
# -----------------------------
# Maximal-marginal-relevance search: fetch 20 candidates, return 6 that balance
# relevance and diversity (lambda_mult=0.5). For plain similarity search use
# vectorstore.as_retriever(search_kwargs={"k": 6}).
retriever = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 6, "fetch_k": 20, "lambda_mult": 0.5},
)

# -----------------------------
# 6) RAG chain that outputs PYTHON CODE ONLY (as a string)
# -----------------------------
def docs2str(docs: List[Document]) -> str:
    """
    Render retrieved documents as one context string for the prompt.

    Each document becomes a "### SOURCE: <path>" header line followed by its
    text, so the model can refer to repository utilities by file. Documents
    without a "source" entry are labelled "unknown". Entries are separated by
    a blank line.
    """
    return "\n\n".join(
        f"### SOURCE: {doc.metadata.get('source', 'unknown')}\n{doc.page_content}"
        for doc in docs
    )

# Prompt template. It fixes only the output format (a plain Python script) and
# general conventions. The concrete task (data file, quantity to plot, whether
# to save) comes from the user's question. The previous version was the 2D
# "plot Z-height of ./docs/image.sxm and save it" prompt, which contradicted
# hyperspectral (.3ds) questions such as current maps at a given bias.
template = """You are a coding assistant for scanning tunneling microscopy (STM) data analysis.

You must write Python code ONLY (no markdown fences, no explanations).
Your output must be a single Python script as plain text.

Task:
- Write a Python script that does exactly what the Question below asks.
- Use the data file path(s) given in the Question verbatim.
- If a library is needed, include a short install hint as a Python comment (e.g. # pip install ...).
- Prefer using any existing loader/utility patterns found in the CONTEXT below.
- If the Question does not name a data channel, pick the most appropriate one
  (e.g. the Z/height channel for topography images).
- Plot image/map data with matplotlib imshow + colorbar.
- Only save figures to disk if the Question asks for it; otherwise just show them.

CONTEXT (repository code snippets):
{context}

Question: {question}

Python code:"""

prompt = ChatPromptTemplate.from_template(template)

# NOTE(review): some reasoning models (e.g. gpt-5) only accept the default
# temperature; confirm temperature=0 is actually honoured.
llm = ChatOpenAI(model="gpt-5", temperature=0)

# The question is sent both to the retriever (results formatted by docs2str)
# and, unchanged, into the prompt. The LLM reply is returned as a plain string.
rag_chain = (
    {"context": retriever | docs2str, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)


OPENAI_API_KEY starts with: sk-pr...
Loaded 2 code documents from './stm_data_code_sample/3D_hyperspectral_data' (.py + .ipynb).
Split into 10 chunks.
Created and persisted vector DB at WindowsPath('C:/Users/ggn/1_py_scripts/RAG/RAG_data_analysis_hackathon_2025/persists/chroma_db_code3d_1').


In [16]:
# -----------------------------
# 7) Ask the question
# -----------------------------
# The question carries the concrete task (file path, quantity, bias value,
# whether to save). The chain returns the generated script as a string.
question = (
    "use the file at './stm_data_code_sample/3D_hyperspectral_data/cits_data.3ds' "
    "to plot the current map image at a bias of 1.2 V. Don't save the figure"
)
response_code_str = rag_chain.invoke(question)

print("\n===== GENERATED PYTHON CODE (STRING) =====\n", response_code_str, sep="\n")


===== GENERATED PYTHON CODE (STRING) =====

import os
import sys
import numpy as np
import matplotlib.pyplot as plt

# pip install hoffmanstmpy

try:
    import stmpy
except ImportError as e:
    raise ImportError("stmpy is required. Install via: pip install hoffmanstmpy") from e

def load_cits_class_module(module_path):
    import importlib.util
    spec = importlib.util.spec_from_file_location("CITS_Class", module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Ensure stmpy is available to the module if it relies on a global reference
    setattr(module, "stmpy", stmpy)
    # Ensure numpy if needed
    if not hasattr(module, "np"):
        setattr(module, "np", np)
    return module

def main():
    repo_root = os.path.abspath(".")
    data_path = os.path.join(repo_root, "stm_data_code_sample", "3D_hyperspectral_data", "ci