<a href="https://colab.research.google.com/github/kderekmackenzie/royalty-qna/blob/main/royalty_qna.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ---- 1. Install Dependencies ----
!pip install -q openai chromadb PyPDF2 tiktoken

# ---- 2. Upload PDF File ----
from google.colab import files

# The upload result is a mapping keyed by uploaded file name.
uploaded = files.upload()

# Fail with an actionable message instead of a bare StopIteration when the
# upload dialog was dismissed without selecting a file.
if not uploaded:
    raise RuntimeError("No file was uploaded. Re-run this cell and select a PDF.")

# Use the (first) uploaded file name as PDF path
PDF_PATH = next(iter(uploaded))

# ---- 3. Set Your OpenAI API Key (Securely) ----
import os
import getpass
# getpass prompts without echoing the typed key, so it does not end up in the
# notebook output. The key is kept only in this process's environment, where
# embed_and_store_chunks() later reads it via os.environ["OPENAI_API_KEY"].
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key:")

# ---- 4. OpenAI v1 SDK Setup ----
from openai import OpenAI
# Constructed with no arguments, so it must run after the environment variable
# above is set (the SDK is expected to pick the key up from OPENAI_API_KEY).
# `client` is a module-level object used by answer_question() for both the
# embeddings call and the chat completion.
client = OpenAI()

# ---- 5. Extract Text from PDF ----
import PyPDF2

def extract_text_from_pdf(pdf_path):
    """Return the concatenated text of every page in the PDF at ``pdf_path``.

    Each page whose ``extract_text()`` result is non-empty contributes that
    text followed by a newline; pages that yield nothing are skipped. The
    result is an empty string when no page yields text.
    """
    with open(pdf_path, 'rb') as pdf_file:
        page_texts = (
            page.extract_text() for page in PyPDF2.PdfReader(pdf_file).pages
        )
        # The generator pulls pages from the reader on demand, so it has to be
        # consumed here, before the file handle is closed.
        return "".join(extracted + "\n" for extracted in page_texts if extracted)

# ---- 6. Chunk the Text into Embeddable Segments ----
import tiktoken

def chunk_text(text, max_tokens=500):
    """Split ``text`` into space-joined chunks of roughly ``max_tokens`` tokens.

    The text is split on whitespace and words are accumulated one at a time.
    As soon as the accumulated words, joined by single spaces, encode to at
    least ``max_tokens`` tokens under the ``cl100k_base`` encoding, they are
    emitted as one chunk and accumulation restarts. Any leftover words form a
    final, shorter chunk.

    Returns:
        A list of chunk strings (empty when ``text`` contains no words).
    """
    encoder = tiktoken.get_encoding("cl100k_base")
    segments = []
    pending = []

    for piece in text.split():
        pending.append(piece)
        candidate = " ".join(pending)
        if len(encoder.encode(candidate)) < max_tokens:
            continue
        # Token budget reached: emit the accumulated words and start afresh.
        segments.append(candidate)
        pending = []

    # Flush whatever did not reach the budget.
    if pending:
        segments.append(" ".join(pending))

    return segments

# ---- 7. Embed and Store in ChromaDB ----
import chromadb
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

def embed_and_store_chunks(chunks, collection_name="royalty_contract_chunks",
                           batch_size=100):
    """Embed ``chunks`` with OpenAI and store them in a fresh Chroma collection.

    Any existing collection called ``collection_name`` is dropped first, so the
    returned collection contains only ``chunks``. Documents get the ids
    ``chunk_0``, ``chunk_1``, ... in input order.

    Args:
        chunks: Iterable of text chunks to embed and store.
        collection_name: Name of the Chroma collection to (re)create.
        batch_size: Number of chunks sent per ``collection.add`` call. Adding
            in batches avoids one embedding request per chunk; lower it if
            individual chunks are very large.

    Returns:
        The newly created Chroma collection.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If the ``OPENAI_API_KEY`` environment variable is not set.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    chroma_client = chromadb.Client()

    # Best-effort cleanup of a collection left over from a previous run. The
    # exception raised for a missing collection differs between chromadb
    # releases, so catch Exception rather than a specific class -- but not a
    # bare `except:`, which would also swallow KeyboardInterrupt/SystemExit.
    # If the delete failed for a real reason, create_collection below will
    # surface it.
    try:
        chroma_client.delete_collection(name=collection_name)
    except Exception:
        pass

    embedding_fn = OpenAIEmbeddingFunction(api_key=os.environ["OPENAI_API_KEY"],
                                           model_name="text-embedding-3-small")

    collection = chroma_client.create_collection(name=collection_name,
                                                 embedding_function=embedding_fn)

    # Materialise once so any iterable can be sliced into batches.
    chunk_list = list(chunks)
    for start in range(0, len(chunk_list), batch_size):
        batch = chunk_list[start:start + batch_size]
        collection.add(
            documents=batch,
            ids=[f"chunk_{i}" for i in range(start, start + len(batch))],
        )

    return collection

# ---- 8. Ask Questions to GPT-4o with Retrieved Context ----
def answer_question(question, collection, top_k=8, model="gpt-4o", *, debug=True):
    """Answer ``question`` using the most relevant stored contract chunks.

    The question is embedded, the closest chunks are retrieved from
    ``collection``, and the chat model is asked to answer from that context.

    Args:
        question: The user's question text.
        collection: Chroma collection holding the embedded contract chunks.
        top_k: Maximum number of chunks to retrieve as context.
        model: Chat model used to generate the answer.
        debug: When true (the default), print the retrieved context so it is
            visible exactly what the model was given.

    Returns:
        The model's answer as a string.
    """
    # Uses the same embedding model name that embed_and_store_chunks() uses
    # for the stored chunks, so query and document vectors are comparable.
    response = client.embeddings.create(
        model="text-embedding-3-small",
        input=question
    )
    embedding = response.data[0].embedding

    # Never request more results than the collection holds: short documents
    # can produce fewer than top_k chunks. An empty collection is passed
    # through unchanged so Chroma's own handling applies.
    available = collection.count()
    n_results = min(top_k, available) if available else top_k

    results = collection.query(query_embeddings=[embedding], n_results=n_results)
    context = "\n---\n".join(results["documents"][0])

    if debug:
        print("\n[DEBUG] Context Used:\n", context)  # See exactly what GPT is seeing

    prompt = f"""You are the worlds most precise and knowledgable expert in mining royalty contracts. You go through every contract line by line before providing an answer. Your career depends on it.

Extract precise numerical answers from the following contract excerpts.
If values are from tables or flat lists, quote them exactly.

Context:
{context}

Question: {question}
Answer (include source quote):"""

    # temperature=0 keeps the extraction as deterministic as the API allows.
    chat_completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )

    return chat_completion.choices[0].message.content


# ---- 9. Run the Pipeline ----
text = extract_text_from_pdf(PDF_PATH)
chunks = chunk_text(text)

# Stop early when nothing could be extracted (e.g. a scanned, image-only PDF):
# otherwise the collection would be empty and every answer would be generated
# without any contract context.
if not chunks:
    raise ValueError(
        f"No text could be extracted from {PDF_PATH!r}; "
        "the PDF may be scanned or image-only."
    )

collection = embed_and_store_chunks(chunks)

# ---- 10. Interactive Q&A Loop ----
# This loop is the only place questions are asked; leaving it ends the
# session, with no further prompt or API call afterwards.
while True:
    question = input("\nAsk a question about the royalty contract (or type 'exit' to quit):\n").strip()
    # Compare case-insensitively on the stripped input so " Exit " still quits.
    if question.lower() in ("exit", "quit", "q"):
        print("Exiting Q&A.")
        break
    if not question:
        # Nothing to ask: skip the embedding/chat round-trip and prompt again.
        continue
    answer = answer_question(question, collection)
    print("\nAnswer:\n", answer)
