# Overview
This notebook explores how to use OpenAI's <a href="https://platform.openai.com/docs/assistants/overview">Assistants API</a> to perform retrieval-augmented generation (RAG) over documents imported into a vector store. I put this in a notebook because the API documentation is incomplete and there are gaps in the APIs themselves (e.g., there is no way to enumerate threads).

Here's an outline of the steps:
1. Create a vector store and upload documents
2. Create an assistant that uses the vector store
3. Get a conversation thread
4. Execute a chat on that thread using the assistant

In each step, we first check whether the objects have already been created before creating them again. This adds some complexity for chat threads, since OpenAI has no API to enumerate threads.

In [None]:
# Imports needed for the rest of the script
from sys import modules
if "openai" not in modules:
    print ("Installing OpenAI")
    %pip install openai
import os
from pathlib import Path
import json
from openai import OpenAI

In [None]:
# Some global configuration variables

# The OpenAI client used in the rest of the notebook
CLIENT = OpenAI()

# Name of the assistant; the vector store and thread names are derived from it
ASSISTANT_NAME = "RAG assistant v 0.0.1"

# Directory (relative to the working directory) holding the RAG documents
DIR = "documents"

# File extensions that are uploaded to the vector store
SUPPORTED_EXTENSIONS = [".pdf", ".txt", ".docx"]

In [None]:
# Step 1: Create a vector store and upload documents

def get_vector_store(client, assistant_name):
  """Return the vector store belonging to the assistant, creating it on first use.

  The store is identified by name: "<assistant_name> vector store".
  The first existing store with that name is reused; otherwise a new one is created.
  """
  target_name = assistant_name + " vector store"

  # First store whose name matches, or None when there is none yet
  existing = next(
    (store for store in client.beta.vector_stores.list() if store.name == target_name),
    None,
  )

  if existing is None:
    print ("Creating new vector store")
    return client.beta.vector_stores.create(name=target_name)

  print ("Using existing vector store")
  return existing

def upload_document(client, vector_store, filepath):
  """Upload a document if needed and make sure it is linked to the vector store.

  A file previously uploaded for assistants (matched by file name) is reused
  instead of being uploaded again. It is still linked to `vector_store` when it
  is not attached yet, e.g. after the vector store was recreated under a new
  assistant name.

  Args:
    client: OpenAI client.
    vector_store: Vector store object; only its `id` is used.
    filepath: Local path (str or Path) of the document.
  """
  filename = Path(filepath).name

  # Reuse a previously uploaded file with the same name, if there is one
  existing_file = next(
    (f for f in client.files.list(purpose="assistants") if f.filename == filename),
    None,
  )

  if existing_file is not None:
    # Per the API reference, a vector store file's id is the id of the uploaded file it wraps
    linked_ids = {
      vs_file.id
      for vs_file in client.beta.vector_stores.files.list(vector_store_id=vector_store.id)
    }
    if existing_file.id in linked_ids:
      print (f"{filepath} already uploaded")
      return
    print (f"{filepath} already uploaded; linking it to the vector store")
    file_id = existing_file.id
  else:
    print (f"Uploading {filepath} and linking to the vector store")
    file_id = client.files.create(file=Path(filepath), purpose="assistants").id

  vs_file = client.beta.vector_stores.files.create_and_poll(file_id=file_id, vector_store_id=vector_store.id)
  # Polling can finish with a non-successful status (e.g. "failed"); report it instead of ignoring it
  if vs_file.status != "completed":
    print (f"Warning: {filepath} was not indexed (status: {vs_file.status})")

def upload_documents(client, vector_store, dir, supported_extensions):
  """Upload every supported document in a directory to the vector store.

  Only regular files directly inside `dir` are considered (no recursion), so a
  sub-directory whose name happens to end in a supported suffix is skipped.
  Extensions are matched case-insensitively (".PDF" counts as ".pdf"). Files
  are processed in sorted order so runs are reproducible.

  Args:
    client: OpenAI client.
    vector_store: Vector store the documents are linked to.
    dir: Directory (str or Path) containing the documents.
    supported_extensions: Iterable of suffixes including the dot, e.g. ".pdf".
  """
  extensions = {ext.lower() for ext in supported_extensions}
  filepaths = sorted(
    entry for entry in Path(dir).iterdir()
    if entry.is_file() and entry.suffix.lower() in extensions
  )
  for filepath in filepaths:
    upload_document(client, vector_store, filepath)

# Find or create the assistant's vector store, then upload every supported document in DIR to it
vector_store = get_vector_store(client=CLIENT, assistant_name=ASSISTANT_NAME)
upload_documents(
  client=CLIENT,
  vector_store=vector_store,
  dir=DIR,
  supported_extensions=SUPPORTED_EXTENSIONS,
)

In [None]:
# Step 2: Create an assistant that uses the vector store
def get_assistant(client, assistant_name, vector_store):
  """Return the RAG assistant, creating it on first use.

  An existing assistant (matched by name) is reused. If it is not linked to
  `vector_store`, its file_search resources are updated to point at it. This
  covers the vector store having been recreated since the assistant was made.

  Args:
    client: OpenAI client.
    assistant_name: Name identifying the assistant.
    vector_store: Vector store the assistant should search; only its `id` is used.
  Returns:
    The existing, updated, or newly created assistant object.
  """
  tool_resources = {"file_search": {"vector_store_ids": [vector_store.id]}}

  # If it was previously created, reuse it (relinking it to the vector store when needed)
  for assistant in client.beta.assistants.list():
    if assistant.name != assistant_name:
      continue
    resources = assistant.tool_resources
    file_search = resources.file_search if resources is not None else None
    linked_ids = (file_search.vector_store_ids or []) if file_search is not None else []
    if vector_store.id in linked_ids:
      print ("Using existing assistant")
      return assistant
    print ("Using existing assistant, relinking it to the vector store")
    return client.beta.assistants.update(assistant.id, tool_resources=tool_resources)

  # Else create it
  print ("Creating new assistant")
  return client.beta.assistants.create(
    name=assistant_name,
    instructions="You are an assistant that will help answer questions about documents in a document library. "
      "You will always include references to both the document and section you used to answer a given query.",
    tools=[{"type": "file_search"}],
    tool_resources=tool_resources,
    model="gpt-4o",
  )


# Find or create the assistant that searches the vector store from step 1
assistant = get_assistant(
  client=CLIENT,
  assistant_name=ASSISTANT_NAME,
  vector_store=vector_store,
)

In [None]:
# Step 3: Create a thread

# OpenAI currently has no way to enumerate threads, so this is a hack to add local persistence
# It also has no user-friendly way to identify a thread; only thread ID. So we create a wrapper class that stores name in metadata

class OpenAIThreads:
    """Named registry of OpenAI assistant threads, persisted to a local JSON file.

    The Assistants API cannot enumerate threads and identifies them only by ID.
    This class therefore keeps a name -> thread ID mapping in `_file` and caches
    retrieved thread objects in memory. All state lives on the class itself;
    it is used through its classmethods and never instantiated.
    """
    _class_initialized = False      # Set once the persisted mapping has been loaded
    _thread_ids = {}                # Mapping of thread names to IDs (persisted to _file)
    _threads = {}                   # In-memory cache mapping thread IDs to thread objects
    _file = "openai_threads.json"   # JSON file (relative to the working directory) holding _thread_ids
    _client = OpenAI()              # OpenAI client used for all thread API calls

    @classmethod
    def _initialize_class(cls):
        """Load the persisted name -> ID mapping the first time the class is used."""
        if not cls._class_initialized:
            print ("Initializing OpenAIThread class")
            if os.path.exists(cls._file):
                with open(cls._file) as f:
                    cls._thread_ids = json.load(f)
            cls._class_initialized = True

    @classmethod
    def _save(cls):
        """Write the name -> ID mapping to the JSON file."""
        with open(cls._file, "w") as f:
            json.dump(cls._thread_ids, f)

    @classmethod
    def get(cls, name):
        """Return the thread registered under `name`, creating and persisting one if needed.

        If the stored ID no longer exists in OpenAI (for example, it was deleted
        elsewhere), a new thread is created and replaces the stale entry.
        """
        from openai import NotFoundError

        cls._initialize_class()

        # If the thread already exists, load it (from the cache when possible)
        if name in cls._thread_ids:
            print (f"Loading existing thread named '{name}'")
            thread_id = cls._thread_ids[name]
            if thread_id in cls._threads:
                return cls._threads[thread_id]
            try:
                thread = cls._client.beta.threads.retrieve(thread_id)
            except NotFoundError:
                # Stale ID: fall through and create a replacement under the same name
                print (f"Thread '{name}' not found in OpenAI; creating a new one")
            else:
                cls._threads[thread_id] = thread
                return thread

        # Else create a new thread; the name is also stored in its metadata
        thread = cls._client.beta.threads.create(metadata={"name": name})
        cls._thread_ids[name] = thread.id
        cls._threads[thread.id] = thread
        cls._save()
        return thread

    @classmethod
    def delete(cls, name):
        """Delete the thread registered under `name`, both in OpenAI and locally.

        A thread that no longer exists in OpenAI is simply dropped from the local
        registry. Any other API error propagates and leaves the registry untouched.
        Threads cannot be enumerated, so forgetting the ID of a thread that still
        exists would orphan it permanently.
        """
        from openai import NotFoundError

        cls._initialize_class()

        if name not in cls._thread_ids:
            print (f"No thread named '{name}'")
            return

        thread_id = cls._thread_ids[name]
        try:
            cls._client.beta.threads.delete(thread_id)
        except NotFoundError:
            print (f"Thread '{name}' not found in OpenAI")
        cls._thread_ids.pop(name, None)
        cls._threads.pop(thread_id, None)
        cls._save()

    @classmethod
    def list_names(cls):
        """Return a view of the registered thread names, loading the persisted mapping first."""
        cls._initialize_class()
        return cls._thread_ids.keys()

def test_threads():
    """Interactively exercise OpenAIThreads.

    On each pass, show the registered thread names and prompt for one. A name
    that is already registered is deleted; any other name is created. An empty
    entry ends the loop.
    """
    while True:
        print ("Threads:")
        print (OpenAIThreads.list_names())
        entry = input("Enter a thread name to add or remove, or press Enter to continue: ")
        if not entry:
            break
        action = OpenAIThreads.delete if entry in OpenAIThreads.list_names() else OpenAIThreads.get
        action(entry)


In [None]:
# Step 4: Start a chat session on the thread

# Start from a fresh thread on every run: remove any thread left over from a previous run, then create a new one
OpenAIThreads.delete(name=ASSISTANT_NAME + " thread")
thread = OpenAIThreads.get(name=ASSISTANT_NAME + " thread")

# Create a streaming event handler to interact with the assistant
from typing_extensions import override
from openai import AssistantEventHandler
class EventHandler(AssistantEventHandler):
  """Print a streamed assistant run to stdout as events arrive.

  Text is echoed chunk by chunk after an "assistant > " prefix. Code-interpreter
  input and log output are echoed as well. The assistant created in step 2 only
  enables file_search, so that branch is presumably not hit in this notebook.
  """

  @override
  def on_text_created(self, text) -> None:
    """Print the prefix that starts each new block of assistant text."""
    print("\nassistant > ", end="", flush=True)

  @override
  def on_text_delta(self, delta, snapshot):
    """Print the next chunk of streamed text."""
    # TextDelta.value is optional; skip empty deltas rather than printing "None"
    if delta.value:
      print(delta.value, end="", flush=True)

  @override
  def on_tool_call_delta(self, delta, snapshot):
    """Echo code-interpreter input and log output as they stream in."""
    if delta.type == 'code_interpreter':
      if delta.code_interpreter.input:
        print(delta.code_interpreter.input, end="", flush=True)
      if delta.code_interpreter.outputs:
        print("\n\noutput >", flush=True)
        for output in delta.code_interpreter.outputs:
          if output.type == "logs":
            print(f"\n{output.logs}", flush=True)

# Start a chat session: each query is appended to the thread, then the assistant is run on the thread
while True:
    query = input("Enter a query, or press Enter to exit: ").strip()
    if query == "":
        break
    message = CLIENT.beta.threads.messages.create(          # Add their query to the thread
        thread_id=thread.id,
        role="user",
        content=query
    )
    with CLIENT.beta.threads.runs.stream(                   # Run the assistant on the thread, streaming via EventHandler
        thread_id=thread.id,
        assistant_id=assistant.id,
        event_handler=EventHandler()
    ) as stream:
        stream.until_done()
    # The streamed reply does not end with a newline; add one so the next prompt starts on its own line
    print()