# RAG for STM data analysis

Here we developed a RAG agent to interface with custom data analysis code. 

The user interacts with the agent to analyse STM data. The agent then provides new code that can be used for plotting data.

We use the "test_rag_output.ipynb" to the RAG prediction.

# Imports and Functions

In [1]:
import os
import json
from pathlib import Path
from typing import List

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser



  from .autonotebook import tqdm as notebook_tqdm


## Function for persist_dir

In [19]:

from __future__ import annotations

import re
from pathlib import Path
from typing import Optional

def get_persist_dir(
    base_dir: str | Path,
    basename: str,
    *,
    new_persist: bool = False,
    create: bool = True,
) -> Path:
    """
    Find (or create) a persist directory under base_dir.

    Looks for directories named: f"{basename}_{run_id}" where run_id is an int.
    - If new_persist=False: returns the latest existing folder if found, else basename_0
    - If new_persist=True : returns a new folder with run_id = (latest + 1) or 0 if none exist
    """
    base_path = Path(base_dir).expanduser().resolve()
    base_path.mkdir(parents=True, exist_ok=True)

    pattern = re.compile(rf"^{re.escape(basename)}_(\d+)$")

    max_run_id: Optional[int] = None
    for p in base_path.iterdir():
        if not p.is_dir():
            continue
        m = pattern.match(p.name)
        if not m:
            continue
        run_id = int(m.group(1))
        if max_run_id is None or run_id > max_run_id:
            max_run_id = run_id

    if max_run_id is None:
        next_id = 0
    else:
        next_id = (max_run_id + 1) if new_persist else max_run_id

    persist_path = base_path / f"{basename}_{next_id}"

    if create:
        persist_path.mkdir(parents=True, exist_ok=True)

    return persist_path


## RAG chain function

In [26]:
def get_rag_chain(folder_path: str, persist_basename: str, 
                  new_persist: bool = False, chunk_size: int = 1200, chunk_overlap: int = 200, 
                  model: str = "gpt-5", temperature: float = 0.0, api_key: Optional[str] = None):
    

    # -----------------------------
    # 1) Environment variables
    # -----------------------------
    # Make sure OPENAI_API_KEY is set in your environment before running
    if api_key is None:
        os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") or ""
    else:
        os.environ["OPENAI_API_KEY"] = api_key

    print("OPENAI_API_KEY starts with:", (os.environ["OPENAI_API_KEY"][:5] + "...") if os.environ["OPENAI_API_KEY"] else "Not Set")

    # Disable LangSmith tracing (prevents 401 errors if you don't use LangSmith)
    os.environ["LANGCHAIN_TRACING_V2"] = "false"
    os.environ["LANGCHAIN_API_KEY"] = ""
    os.environ["LANGCHAIN_PROJECT"] = "rag-code-search"


    # -----------------------------
    # 2) Text splitting / chunking (tuned for code)
    # -----------------------------
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size= chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )


    # -----------------------------
    # 3) Load .py and .ipynb from a folder
    # -----------------------------
    def load_py_file(path: Path) -> Document:
        text = path.read_text(encoding="utf-8", errors="ignore")
        return Document(page_content=text, metadata={"source": str(path), "type": "py"})

    def load_ipynb_file(path: Path) -> Document:
        nb = json.loads(path.read_text(encoding="utf-8", errors="ignore"))

        parts = []
        for cell in nb.get("cells", []):
            cell_type = cell.get("cell_type", "")
            src = cell.get("source", [])
            if isinstance(src, list):
                src = "".join(src)
            src = (src or "").strip()
            if not src:
                continue

            # Keep both markdown + code, but label them
            if cell_type == "code":
                parts.append("# --- notebook code cell ---\n" + src)
            elif cell_type == "markdown":
                parts.append("# --- notebook markdown cell ---\n" + src)

        text = "\n\n".join(parts)
        return Document(page_content=text, metadata={"source": str(path), "type": "ipynb"})

    def load_code_documents(folder_path: str) -> List[Document]:
        folder = Path(folder_path)
        documents: List[Document] = []

        for path in folder.rglob("*"):
            if path.is_dir():
                continue

            suffix = path.suffix.lower()
            if suffix == ".py":
                documents.append(load_py_file(path))
            elif suffix == ".ipynb":
                documents.append(load_ipynb_file(path))
            else:
                continue

        return documents



    documents = load_code_documents(folder_path)
    print(f"Loaded {len(documents)} code documents from {folder_path!r} (.py + .ipynb).")

    splits = text_splitter.split_documents(documents)
    print(f"Split into {len(splits)} chunks.")


    # -----------------------------
    # 4) Embedding + Vector store (persisted)
    # -----------------------------
    #PERSIST_DIR = Path("./chroma_db_code_1")
    PERSIST_DIR = get_persist_dir("./persists/", persist_basename, new_persist=new_persist)


    COLLECTION_NAME = "code_collection"

    embedding_function = OpenAIEmbeddings(model="text-embedding-3-large")

    # IMPORTANT: if you run from_documents every time, you rebuild the DB.
    # The logic below loads existing DB if present, otherwise builds it.
    db_file = Path(PERSIST_DIR) / "chroma.sqlite3"
    if db_file.exists():
        vectorstore = Chroma(
            collection_name=COLLECTION_NAME,
            persist_directory=PERSIST_DIR,
            embedding_function=embedding_function,
        )
        print(f"Loaded existing vector DB from {PERSIST_DIR!r}.")
    else:
        if not splits:
            raise ValueError("No code chunks found. Check that ./docs contains .py or .ipynb files.")
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_function,
            collection_name=COLLECTION_NAME,
            persist_directory=PERSIST_DIR,
        )
        #vectorstore.persist()
        print(f"Created and persisted vector DB at {PERSIST_DIR!r}.")


    # -----------------------------
    # 5) Retriever
    # -----------------------------
    #retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "fetch_k": 20, "lambda_mult": 0.5}
    )

    # -----------------------------
    # 6) RAG chain that outputs PYTHON CODE ONLY (as a string)
    # -----------------------------
    def docs2str(docs: List[Document]) -> str:
        # Keep source hints so the model can reference repo utilities if they exist
        out = []
        for d in docs:
            src = d.metadata.get("source", "unknown")
            out.append(f"### SOURCE: {src}\n{d.page_content}")
        return "\n\n".join(out)

    template = """You are a coding assistant.

    You must write Python code ONLY (no markdown fences, no explanations).
    Your output must be a single Python script as plain text.

    CONTEXT (repository code snippets):
    {context}

    Task:
    {question}

    Python code:"""

    prompt = ChatPromptTemplate.from_template(template)

    llm = ChatOpenAI(model=model, temperature=temperature)

    rag_chain = (
        {"context": retriever | docs2str, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain


## Write code to py-file

In [21]:

def write_generated_code_to_file(
    code_str: str,
    filename: str = "main_rag_test.py",
    encoding: str = "utf-8",
) -> Path:
    """
    Overwrite `filename` with `code_str`. Creates the file if it doesn't exist.
    Returns the Path to the written file.
    """
    path = Path(filename).expanduser().resolve()
    path.parent.mkdir(parents=True, exist_ok=True)

    # Ensure we're writing a string (some chains return dicts / messages)
    if not isinstance(code_str, str):
        code_str = str(code_str)

    # Add a trailing newline for nicer diffs/editors
    if not code_str.endswith("\n"):
        code_str += "\n"

    path.write_text(code_str, encoding=encoding)
    return path

# 2D image data

In [None]:
new_persist = False # Set to True to create a new persist directory. Set to True while creating new db.
folder_path = "./stm_data_code_sample/2D_image_data"

chunk_size = 1200
chunk_overlap = 200


api_key = os.getenv("OPENAI_API_KEY")
persist_basename = "chroma_db_code"

model = "gpt-5"
temperature = 0

task = """
Write code to plot the Z-height image of the file "image.sxm" located in "./stm_data_code_sample/2D_image_data".
- Import necessary libraries
- The script should:
  1) load ./stm_data_code_sample/2D_image_data/image.sxm
  2) import stm_utils from ./stm_data_code_sample/2D_image_data
  2) extract the Z/height channel (or the most appropriate topography channel)
  3) plot it with matplotlib imshow + colorbar
  4) Do not save the image to file, just show it
"""

In [18]:
rag_chain = get_rag_chain(
    folder_path=folder_path,
    persist_basename=persist_basename,
    new_persist=new_persist,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    model=model,
    temperature=temperature,
    api_key=api_key,
)

response_code_str = rag_chain.invoke(task)

print("\n===== GENERATED PYTHON CODE (STRING) =====\n")
print(response_code_str)


OPENAI_API_KEY starts with: sk-pr...
Loaded 3 code documents from './stm_data_code_sample/2D_image_data' (.py + .ipynb).
Split into 6 chunks.
Loaded existing vector DB from WindowsPath('C:/Users/ggn/1_py_scripts/RAG/RAG_data_analysis_hackathon_2025/persists/chroma_db_code_10').

===== GENERATED PYTHON CODE (STRING) =====

import os
import sys
import numpy as np
import matplotlib.pyplot as plt

def get_script_dir():
    try:
        return os.path.dirname(os.path.abspath(__file__))
    except NameError:
        return os.getcwd()

script_dir = get_script_dir()
data_dir = os.path.join(script_dir, "stm_data_code_sample", "2D_image_data")
sys.path.insert(0, data_dir)

from stm_utils import Sxm_Image

filepath = os.path.join(data_dir, "image.sxm")
im = Sxm_Image(filepath)

channels = im.get_channels()

def channel_score(ch):
    name = ch.lower()
    if name.startswith('z'):
        base = 0
    elif 'topo' in name or 'topography' in name:
        base = 10
    elif 'height' in name:
      

## Test the code
Write the code to "main_rag_test.py". Overwrites if already exists.

Test it by running "python main_rag_test.py" on terminal.

In [22]:
write_generated_code_to_file(
    code_str=response_code_str,
    filename="main_rag_test.py",
)


WindowsPath('C:/Users/ggn/1_py_scripts/RAG/RAG_data_analysis_hackathon_2025/main_rag_test.py')

# 3D hyperspectral data analysis

In [None]:
new_persist = True # Set to True to create a new persist directory. Set to True while creating new db.
folder_path = "./stm_data_code_sample/3D_hyperspectral_data"

# Chunk size and overlap hyperparameters
chunk_size = 1200
chunk_overlap = 200


api_key = os.getenv("OPENAI_API_KEY")
persist_basename = "chroma_db_code3d"

model = "gpt-5"
temperature = 0  #Vary in the range 0-2.

task = """
Write code to plot the current map at a probe bias of 1.2 V for the cits_data.3ds".
- Import necessary libraries
- The script should:
  1) load ./stm_data_code_sample/3D_hyperspectral_data/cits_data.3ds
  2) extract the Z/height channel (or the most appropriate topography channel)
  3) plot it with matplotlib imshow + colorbar
  4) Do not save the image to file, just show it
"""

In [24]:
rag_chain = get_rag_chain(
    folder_path=folder_path,
    persist_basename=persist_basename,
    new_persist=new_persist,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    model=model,
    temperature=temperature,
    api_key=api_key,
)

response_code_str = rag_chain.invoke(task)

print("\n===== GENERATED PYTHON CODE (STRING) =====\n")
print(response_code_str)

OPENAI_API_KEY starts with: sk-pr...
Loaded 2 code documents from './stm_data_code_sample/3D_hyperspectral_data' (.py + .ipynb).
Split into 7 chunks.
Created and persisted vector DB at WindowsPath('C:/Users/ggn/1_py_scripts/RAG/RAG_data_analysis_hackathon_2025/persists/chroma_db_code3d_2').

===== GENERATED PYTHON CODE (STRING) =====

import os
import numpy as np
import matplotlib.pyplot as plt

import stmpy

def get_topography(smd):
    # Try common attribute names
    for name in ['topo', 'topography', 'z', 'Z', 'Topo', 'Topography', 'height', 'Height']:
        if hasattr(smd, name):
            topo = getattr(smd, name)
            topo = np.asarray(topo)
            if topo.ndim == 2:
                return topo
            if topo.ndim >= 3:
                return topo[0]
    # Try within dict-like attributes
    for obj in [getattr(smd, 'signals', None), getattr(smd, 'grid', None), getattr(smd, '__dict__', None)]:
        if isinstance(obj, dict):
            for k, v in obj.ite

## Test the code
Write the code to "main_rag_test.py". Overwrites if already exists.

Test it by running "python main_rag_test.py" on terminal.

In [25]:
write_generated_code_to_file(
    code_str=response_code_str,
    filename="main_rag_test.py",
)

WindowsPath('C:/Users/ggn/1_py_scripts/RAG/RAG_data_analysis_hackathon_2025/main_rag_test.py')