# Conversational RAG for STM data analysis

Here we developed a RAG agent to interface with custom data analysis code. 

The user interacts with the agent to analyse STM data. The agent then provides new code that can be used for plotting data.

We use the "test_rag_output.ipynb" to the RAG prediction.

# Imports and Functions

In [9]:
import os
import json
from pathlib import Path
from typing import List

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

import uuid

In [2]:

from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import MessagesPlaceholder
from langchain_classic.chains import create_history_aware_retriever

from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

## Function for persist_db dir

In [3]:

from __future__ import annotations

import re
from pathlib import Path
from typing import Optional

def get_persist_dir(
    base_dir: str | Path,
    basename: str,
    *,
    new_persist: bool = False,
    create: bool = True,
) -> Path:
    """
    Find (or create) a persist directory under base_dir.

    Looks for directories named: f"{basename}_{run_id}" where run_id is an int.
    - If new_persist=False: returns the latest existing folder if found, else basename_0
    - If new_persist=True : returns a new folder with run_id = (latest + 1) or 0 if none exist
    """
    base_path = Path(base_dir).expanduser().resolve()
    base_path.mkdir(parents=True, exist_ok=True)

    pattern = re.compile(rf"^{re.escape(basename)}_(\d+)$")

    max_run_id: Optional[int] = None
    for p in base_path.iterdir():
        if not p.is_dir():
            continue
        m = pattern.match(p.name)
        if not m:
            continue
        run_id = int(m.group(1))
        if max_run_id is None or run_id > max_run_id:
            max_run_id = run_id

    if max_run_id is None:
        next_id = 0
    else:
        next_id = (max_run_id + 1) if new_persist else max_run_id

    persist_path = base_path / f"{basename}_{next_id}"

    if create:
        persist_path.mkdir(parents=True, exist_ok=True)

    return persist_path


## RAG chain function

In [4]:
def get_rag_chain(folder_path: str, persist_basename: str, 
                  new_persist: bool = False, chunk_size: int = 1200, chunk_overlap: int = 200, 
                  model: str = "gpt-5", temperature: float = 0.0, api_key: Optional[str] = None):
    

    # -----------------------------
    # 1) Environment variables
    # -----------------------------
    # Make sure OPENAI_API_KEY is set in your environment before running
    if api_key is None:
        os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") or ""
    else:
        os.environ["OPENAI_API_KEY"] = api_key

    print("OPENAI_API_KEY starts with:", (os.environ["OPENAI_API_KEY"][:5] + "...") if os.environ["OPENAI_API_KEY"] else "Not Set")

    # Disable LangSmith tracing (prevents 401 errors if you don't use LangSmith)
    os.environ["LANGCHAIN_TRACING_V2"] = "false"
    os.environ["LANGCHAIN_API_KEY"] = ""
    os.environ["LANGCHAIN_PROJECT"] = "rag-code-search"


    # -----------------------------
    # 2) Text splitting / chunking (tuned for code)
    # -----------------------------
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size= chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )


    # -----------------------------
    # 3) Load .py and .ipynb from a folder
    # -----------------------------
    def load_py_file(path: Path) -> Document:
        text = path.read_text(encoding="utf-8", errors="ignore")
        return Document(page_content=text, metadata={"source": str(path), "type": "py"})

    def load_ipynb_file(path: Path) -> Document:
        nb = json.loads(path.read_text(encoding="utf-8", errors="ignore"))

        parts = []
        for cell in nb.get("cells", []):
            cell_type = cell.get("cell_type", "")
            src = cell.get("source", [])
            if isinstance(src, list):
                src = "".join(src)
            src = (src or "").strip()
            if not src:
                continue

            # Keep both markdown + code, but label them
            if cell_type == "code":
                parts.append("# --- notebook code cell ---\n" + src)
            elif cell_type == "markdown":
                parts.append("# --- notebook markdown cell ---\n" + src)

        text = "\n\n".join(parts)
        return Document(page_content=text, metadata={"source": str(path), "type": "ipynb"})

    def load_code_documents(folder_path: str) -> List[Document]:
        folder = Path(folder_path)
        documents: List[Document] = []

        for path in folder.rglob("*"):
            if path.is_dir():
                continue

            suffix = path.suffix.lower()
            if suffix == ".py":
                documents.append(load_py_file(path))
            elif suffix == ".ipynb":
                documents.append(load_ipynb_file(path))
            else:
                continue

        return documents



    documents = load_code_documents(folder_path)
    print(f"Loaded {len(documents)} code documents from {folder_path!r} (.py + .ipynb).")

    splits = text_splitter.split_documents(documents)
    print(f"Split into {len(splits)} chunks.")


    # -----------------------------
    # 4) Embedding + Vector store (persisted)
    # -----------------------------
    #PERSIST_DIR = Path("./chroma_db_code_1")
    PERSIST_DIR = get_persist_dir("./persists/", persist_basename, new_persist=new_persist)


    COLLECTION_NAME = "code_collection"

    embedding_function = OpenAIEmbeddings(model="text-embedding-3-large")

    # IMPORTANT: if you run from_documents every time, you rebuild the DB.
    # The logic below loads existing DB if present, otherwise builds it.
    db_file = Path(PERSIST_DIR) / "chroma.sqlite3"
    if db_file.exists():
        vectorstore = Chroma(
            collection_name=COLLECTION_NAME,
            persist_directory=PERSIST_DIR,
            embedding_function=embedding_function,
        )
        print(f"Loaded existing vector DB from {PERSIST_DIR!r}.")
    else:
        if not splits:
            raise ValueError("No code chunks found. Check that ./docs contains .py or .ipynb files.")
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_function,
            collection_name=COLLECTION_NAME,
            persist_directory=PERSIST_DIR,
        )
        #vectorstore.persist()
        print(f"Created and persisted vector DB at {PERSIST_DIR!r}.")


    # -----------------------------
    # 5) Retriever
    # -----------------------------
    #retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "fetch_k": 20, "lambda_mult": 0.5}
    )

    # -----------------------------
    # 6) RAG chain that outputs PYTHON CODE ONLY (as a string)
    # -----------------------------
    def docs2str(docs: List[Document]) -> str:
        # Keep source hints so the model can reference repo utilities if they exist
        out = []
        for d in docs:
            src = d.metadata.get("source", "unknown")
            out.append(f"### SOURCE: {src}\n{d.page_content}")
        return "\n\n".join(out)

    template = """You are a coding assistant.

    You must write Python code ONLY (no markdown fences, no explanations).
    Your output must be a single Python script as plain text.

    CONTEXT (repository code snippets):
    {context}

    Task:
    {question}

    Python code:"""

    prompt = ChatPromptTemplate.from_template(template)

    llm = ChatOpenAI(model=model, temperature=temperature)

    rag_chain = (
        {"context": retriever | docs2str, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain











## QA rag chain function

In [None]:

def get_qa_rag_chain(folder_path: str, persist_basename: str, 
                  new_persist: bool = False, chunk_size: int = 1200, chunk_overlap: int = 200, 
                  model: str = "gpt-5", temperature: float = 0.0, api_key: Optional[str] = None):
    

    # -----------------------------
    # 1) Environment variables
    # -----------------------------
    # Make sure OPENAI_API_KEY is set in your environment before running
    if api_key is None:
        os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") or ""
    else:
        os.environ["OPENAI_API_KEY"] = api_key

    print("OPENAI_API_KEY starts with:", (os.environ["OPENAI_API_KEY"][:5] + "...") if os.environ["OPENAI_API_KEY"] else "Not Set")

    # Disable LangSmith tracing (prevents 401 errors if you don't use LangSmith)
    os.environ["LANGCHAIN_TRACING_V2"] = "false"
    os.environ["LANGCHAIN_API_KEY"] = ""
    os.environ["LANGCHAIN_PROJECT"] = "rag-code-search"


    # -----------------------------
    # 2) Text splitting / chunking (tuned for code)
    # -----------------------------
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size= chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )


    # -----------------------------
    # 3) Load .py and .ipynb from a folder
    # -----------------------------
    def load_py_file(path: Path) -> Document:
        text = path.read_text(encoding="utf-8", errors="ignore")
        return Document(page_content=text, metadata={"source": str(path), "type": "py"})

    def load_ipynb_file(path: Path) -> Document:
        nb = json.loads(path.read_text(encoding="utf-8", errors="ignore"))

        parts = []
        for cell in nb.get("cells", []):
            cell_type = cell.get("cell_type", "")
            src = cell.get("source", [])
            if isinstance(src, list):
                src = "".join(src)
            src = (src or "").strip()
            if not src:
                continue

            # Keep both markdown + code, but label them
            if cell_type == "code":
                parts.append("# --- notebook code cell ---\n" + src)
            elif cell_type == "markdown":
                parts.append("# --- notebook markdown cell ---\n" + src)

        text = "\n\n".join(parts)
        return Document(page_content=text, metadata={"source": str(path), "type": "ipynb"})

    def load_code_documents(folder_path: str) -> List[Document]:
        folder = Path(folder_path)
        documents: List[Document] = []

        for path in folder.rglob("*"):
            if path.is_dir():
                continue

            suffix = path.suffix.lower()
            if suffix == ".py":
                documents.append(load_py_file(path))
            elif suffix == ".ipynb":
                documents.append(load_ipynb_file(path))
            else:
                continue

        return documents



    documents = load_code_documents(folder_path)
    print(f"Loaded {len(documents)} code documents from {folder_path!r} (.py + .ipynb).")

    splits = text_splitter.split_documents(documents)
    print(f"Split into {len(splits)} chunks.")


    # -----------------------------
    # 4) Embedding + Vector store (persisted)
    # -----------------------------
    #PERSIST_DIR = Path("./chroma_db_code_1")
    PERSIST_DIR = get_persist_dir("./persists/", persist_basename, new_persist=new_persist)


    COLLECTION_NAME = "code_collection"

    embedding_function = OpenAIEmbeddings(model="text-embedding-3-large")

    # IMPORTANT: if you run from_documents every time, you rebuild the DB.
    # The logic below loads existing DB if present, otherwise builds it.
    db_file = Path(PERSIST_DIR) / "chroma.sqlite3"
    if db_file.exists():
        vectorstore = Chroma(
            collection_name=COLLECTION_NAME,
            persist_directory=PERSIST_DIR,
            embedding_function=embedding_function,
        )
        print(f"Loaded existing vector DB from {PERSIST_DIR!r}.")
    else:
        if not splits:
            raise ValueError("No code chunks found. Check that ./docs contains .py or .ipynb files.")
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_function,
            collection_name=COLLECTION_NAME,
            persist_directory=PERSIST_DIR,
        )
        #vectorstore.persist()
        print(f"Created and persisted vector DB at {PERSIST_DIR!r}.")


    # -----------------------------
    # 5) Retriever
    # -----------------------------
    #retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "fetch_k": 20, "lambda_mult": 0.5}
    )

    # -----------------------------
    # 6) RAG chain that outputs PYTHON CODE ONLY (as a string)
    # -----------------------------
    def docs2str(docs: List[Document]) -> str:
        # Keep source hints so the model can reference repo utilities if they exist
        out = []
        for d in docs:
            src = d.metadata.get("source", "unknown")
            out.append(f"### SOURCE: {src}\n{d.page_content}")
        return "\n\n".join(out)

    
    llm = ChatOpenAI(model=model, temperature=temperature)

    # -----------------------------
    # 7 Query prompt (history -> standalone retrieval query)
    # -----------------------------
    contextualize_q_system_prompt = """
    You are a query rewriting assistant for retrieval.

    Given the chat history and the user's latest question:
    - Rewrite the question into a standalone search query that can be used to retrieve relevant code/docs.
    - Resolve references like "it", "that", "the previous one", etc. using the chat history.
    - Preserve exact identifiers (function names, class names, file paths, error messages, config keys).
    - Do NOT answer the question. Only output the rewritten query.
    - If the question is already standalone, return it unchanged.
    """.strip()

    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )


    history_aware_retriever = create_history_aware_retriever(
        llm=llm,
        retriever=retriever,
        prompt=contextualize_q_prompt,
    )


    # -----------------------------
    # 8) Answer prompt (use retrieved docs + chat history)
    # -----------------------------
    qa_system_prompt = """
    You are a helpful assistant.

    Use BOTH:
    - The retrieved context (primary source of truth).
    - The chat history (to understand intent, constraints, and resolve references).

    Rules:
    - Output only the python code as a single script, such the user can run it directly. Do not add ```python``` markers
    - If using explanations or markdown text, comment them.
    - Prefer the retrieved context for factual claims about the code/docs.
    - Use chat history mainly to interpret what the user means and what constraints they set earlier.
    - If the retrieved context does not contain the answer, say so clearly and suggest what to search for next.
    """.strip()

    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "Question: {input}\n\nRetrieved context:\n{context}"),
        ]
    )

    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

    # -----------------------------
    # 9) Full RAG chain
    # -----------------------------
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)


    return rag_chain

## DB connection functions

In [19]:
import sqlite3
from datetime import datetime

DB_NAME = "rag_chat_data.db"

def get_db_connection():
    conn = sqlite3.connect(DB_NAME)
    conn.row_factory = sqlite3.Row
    return conn

def create_application_logs():
    conn = get_db_connection()
    conn.execute('''CREATE TABLE IF NOT EXISTS application_logs
                    (id INTEGER PRIMARY KEY AUTOINCREMENT,
                     session_id TEXT,
                     user_query TEXT,
                     gpt_response TEXT,
                     model TEXT,
                     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')
    conn.close()


def insert_application_logs(session_id, user_query, gpt_response, model):
    conn = get_db_connection()
    conn.execute('INSERT INTO application_logs (session_id, user_query, gpt_response, model) VALUES (?, ?, ?, ?)',
                 (session_id, user_query, gpt_response, model))
    conn.commit()
    conn.close()


def get_chat_history(session_id):
    conn = get_db_connection()
    cursor = conn.cursor()
    cursor.execute('SELECT user_query, gpt_response FROM application_logs WHERE session_id = ? ORDER BY created_at', (session_id,))
    messages = []
    for row in cursor.fetchall():
        messages.extend([
            {"role": "human", "content": row['user_query']},
            {"role": "ai", "content": row['gpt_response']}
        ])
    conn.close()
    return messages

def get_chat_history(session_id):
    conn = get_db_connection()
    cursor = conn.cursor()
    cursor.execute('SELECT user_query, gpt_response FROM application_logs WHERE session_id = ? ORDER BY created_at', (session_id,))
    messages = []
    for row in cursor.fetchall():
        messages.extend([
        HumanMessage(content=row['user_query']),
        AIMessage(content=row['gpt_response'])
])
    conn.close()
    return messages



# Initialize the database
create_application_logs()

## Write code to py-file

In [20]:

def write_generated_code_to_file(
    code_str: str,
    filename: str = "main_rag_test.py",
    encoding: str = "utf-8",
) -> Path:
    """
    Overwrite `filename` with `code_str`. Creates the file if it doesn't exist.
    Returns the Path to the written file.
    """
    path = Path(filename).expanduser().resolve()
    path.parent.mkdir(parents=True, exist_ok=True)

    # Ensure we're writing a string (some chains return dicts / messages)
    if not isinstance(code_str, str):
        code_str = str(code_str)

    # Add a trailing newline for nicer diffs/editors
    if not code_str.endswith("\n"):
        code_str += "\n"

    path.write_text(code_str, encoding=encoding)
    return path

# Coversational data analysis

Here we analyse the hyperspectral data in "./stm_data_code_sample/3D_hyperspectral_data"

In [21]:
new_persist = False # Set to True to create a new persist directory. Set to True while creating new db.
folder_path = "./stm_data_code_sample/3D_hyperspectral_data"

# Chunk size and overlap hyperparameters
chunk_size = 1200
chunk_overlap = 200


api_key = os.getenv("OPENAI_API_KEY")
persist_basename = "chroma_db_code3d"

model = "gpt-4o-mini"
temperature = 0  #Vary in the range 0-2.

rag_chain = get_qa_rag_chain(
    folder_path=folder_path,
    persist_basename=persist_basename,
    new_persist=new_persist,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    model=model,
    temperature=temperature,
    api_key=api_key,
)



OPENAI_API_KEY starts with: sk-pr...
Loaded 2 code documents from './stm_data_code_sample/3D_hyperspectral_data' (.py + .ipynb).
Split into 7 chunks.
Loaded existing vector DB from WindowsPath('C:/Users/ggn/1_py_scripts/RAG/RAG_data_analysis_hackathon_2025/persists/chroma_db_code3d_3').


In [23]:

session_id = str(uuid.uuid4())
chat_history = get_chat_history(session_id)
print("chat_history:", chat_history)
task1 = """
Write code to plot the current map at a probe bias of 1.2 V for the cits_data.3ds".
- Import necessary libraries
- The script should:
  1) load ./stm_data_code_sample/3D_hyperspectral_data/cits_data.3ds
  2) extract the Z/height channel (or the most appropriate topography channel)
  3) plot it with matplotlib imshow + colorbar
  4) Do not save the image to file, just show it
"""

answer1 = rag_chain.invoke({"input": task1, "chat_history":chat_history})['answer']
insert_application_logs(session_id, task1, answer1, "gpt-5")
print(f"Human: {task1}")
print(f"\n===== AI GENERATED PYTHON CODE (STRING) =====\n{answer1}\n")

chat_history: []
Human: 
Write code to plot the current map at a probe bias of 1.2 V for the cits_data.3ds".
- Import necessary libraries
- The script should:
  1) load ./stm_data_code_sample/3D_hyperspectral_data/cits_data.3ds
  2) extract the Z/height channel (or the most appropriate topography channel)
  3) plot it with matplotlib imshow + colorbar
  4) Do not save the image to file, just show it


===== AI GENERATED PYTHON CODE (STRING) =====
import numpy as np
import matplotlib.pyplot as plt
from CITS_Class import CITS_Analysis

# Load the CITS data
filepath = './stm_data_code_sample/3D_hyperspectral_data/cits_data.3ds'
data = CITS_Analysis(filepath)

# Set the probe bias
probe_bias = 1.2

# Extract the current map at the specified probe bias
i_2D, V_actual = data.current_map(probe_bias)
print('Nearest Probed bias = ', V_actual)

# Plot the current map
plt.imshow(i_2D, aspect='auto', origin='lower')
plt.colorbar()  # Add a colorbar to the plot
plt.title(f'Current Map at Probe Bias

In [24]:
chat_history = get_chat_history(session_id)
print("chat_history:", chat_history)

chat_history: [HumanMessage(content='\nWrite code to plot the current map at a probe bias of 1.2 V for the cits_data.3ds".\n- Import necessary libraries\n- The script should:\n  1) load ./stm_data_code_sample/3D_hyperspectral_data/cits_data.3ds\n  2) extract the Z/height channel (or the most appropriate topography channel)\n  3) plot it with matplotlib imshow + colorbar\n  4) Do not save the image to file, just show it\n', additional_kwargs={}, response_metadata={}), AIMessage(content="import numpy as np\nimport matplotlib.pyplot as plt\nfrom CITS_Class import CITS_Analysis\n\n# Load the CITS data\nfilepath = './stm_data_code_sample/3D_hyperspectral_data/cits_data.3ds'\ndata = CITS_Analysis(filepath)\n\n# Set the probe bias\nprobe_bias = 1.2\n\n# Extract the current map at the specified probe bias\ni_2D, V_actual = data.current_map(probe_bias)\nprint('Nearest Probed bias = ', V_actual)\n\n# Plot the current map\nplt.imshow(i_2D, aspect='auto', origin='lower')\nplt.colorbar()  # Add a c

In [None]:

task2 = "provide the code, but change the filepath to 'cits_data.3ds'"

answer2 = rag_chain.invoke({"input": task2, "chat_history":chat_history})['answer']
insert_application_logs(session_id, task2, answer2, "gpt-4o-mini")
print(f"Human: {task2}")
print(f"\n===== AI GENERATED PYTHON CODE (STRING) =====\n\n{answer2}\n")

Human: provide the code, but change the filepath to 'cits_data.3ds'

===== AI GENERATED PYTHON CODE (STRING) =====
import numpy as np
import matplotlib.pyplot as plt
from CITS_Class import CITS_Analysis

# Load the CITS data
filepath = 'cits_data.3ds'
data = CITS_Analysis(filepath)

# Set the probe bias
probe_bias = 1.2

# Extract the current map at the specified probe bias
i_2D, V_actual = data.current_map(probe_bias)
print('Nearest Probed bias = ', V_actual)

# Plot the current map
plt.imshow(i_2D, aspect='auto', origin='lower')
plt.colorbar()  # Add a colorbar to the plot
plt.title(f'Current Map at Probe Bias of {probe_bias} V')
plt.xlabel('X Position')
plt.ylabel('Y Position')
plt.show()



In [None]:
task3 = "Great. Now provide the didv_x map at the same probe_bias. Retain the same filepath"
answer3 = rag_chain.invoke({"input": task3, "chat_history":chat_history})['answer']
insert_application_logs(session_id, task3, answer3, "gpt-4o-mini")
print(f"Human: {task3}")
print(f"\n===== AI GENERATED PYTHON CODE (STRING) =====\n\n{answer3}\n")

Human: Great. Now provide the didv_x map at the same probe_bias. Retain the same filepath

===== AI GENERATED PYTHON CODE (STRING) =====

import numpy as np
import matplotlib.pyplot as plt
from CITS_Class import CITS_Analysis

# Load the CITS data
filepath = './stm_data_code_sample/3D_hyperspectral_data/cits_data.3ds'
data = CITS_Analysis(filepath)

# Set the probe bias
probe_bias = 1.2

# Extract the didv_x map at the specified probe bias
didv_x_2D, V_actual = data.didv_x_map(probe_bias)
print('Nearest Probed bias = ', V_actual)

# Plot the didv_x map
plt.imshow(didv_x_2D, aspect='auto', origin='lower')
plt.colorbar()  # Add a colorbar to the plot
plt.title(f'didv_x Map at Probe Bias of {probe_bias} V')
plt.xlabel('X Position')
plt.ylabel('Y Position')
plt.show()



## Test the code
Write the code to "main_rag_test.py". Overwrites if already exists.

Test it by running "python main_rag_test.py" on terminal.

In [None]:
write_generated_code_to_file(
    code_str=answer2,
    filename="main_rag_test.py",
)

WindowsPath('C:/Users/ggn/1_py_scripts/RAG/RAG_data_analysis_hackathon_2025/main_rag_test.py')