In [1]:
# === Standard Library ===
import json
import os
import sys
import uuid
from typing import Any, Dict, List, Optional, Union, TypedDict, Annotated

# === Data Handling ===
import pandas as pd

# === NLP Models ===
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModel

# === LangGraph Core ===
from langgraph.graph import StateGraph, START, END
from langgraph.types import interrupt, Command
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

# === LangChain & LLM ===
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
import ollama  # used directly in custom LLM calls

# === Validation ===
from pydantic import BaseModel


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local directory holding the all-MiniLM-L6-v2 checkpoint loaded in the next cell.
# Set the MINILM_PATH environment variable to point elsewhere instead of editing
# this machine-specific absolute path.
path = os.environ.get(
    "MINILM_PATH",
    "/Users/danielrubibreton/Desktop/PythonStuff/hface/all-MiniLM-L6-v2",
)

In [3]:
# Load the tokenizer and encoder weights from the local checkpoint directory in `path`.
# NOTE(review): neither object is referenced by find_relevant_sections, which builds
# its own SentenceTransformer from a model name; confirm they are still needed.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=path)
model = AutoModel.from_pretrained(pretrained_model_name_or_path=path)

In [17]:
# Smoke check that langchain_core messages build and render in this kernel.
sample_message = HumanMessage(content="Hello")
sample_message.pretty_print()


Hello


In [4]:
# Load the curriculum as a nested mapping {book: {chapter: chapter_text}}.
# json.load instead of eval(): eval executes any Python found in the file and
# rejects valid JSON literals such as true/false/null.
with open("cfa2025.json", "r", encoding="utf-8") as file:
    book_json = json.load(file)

In [5]:
# "Book -> Chapter" titles, used for title-level semantic matching.
flat_sections = [
    f"{book} -> {chapter}"
    for book, chapters in book_json.items()
    for chapter in chapters
]

# "Book -> Chapter -> text" entries, used for the content-level fallback search.
flat_book = [
    f"{book} -> {chapter} -> {text}"
    for book, chapters in book_json.items()
    for chapter, text in chapters.items()
]

In [6]:
# Loaded SentenceTransformer models keyed by (model_name, device), so repeated
# queries reuse the weights instead of reloading them on every call.
_SENTENCE_MODELS: Dict[Any, Any] = {}
# Corpus embeddings keyed by (model_name, device, tuple(corpus)); flat_sections and
# flat_book do not change between queries, so each is encoded only once.
_CORPUS_EMBEDDINGS: Dict[Any, Any] = {}


def _get_sentence_model(model_name: str, device: str) -> Any:
    """Return the cached SentenceTransformer for (model_name, device), loading it on first use."""
    key = (model_name, device)
    if key not in _SENTENCE_MODELS:
        _SENTENCE_MODELS[key] = SentenceTransformer(model_name, device=device)
    return _SENTENCE_MODELS[key]


def _encode_corpus(model: Any, model_name: str, device: str, corpus: List[str]) -> Any:
    """Return tensor embeddings for `corpus`, computing them only the first time this corpus is seen."""
    key = (model_name, device, tuple(corpus))
    if key not in _CORPUS_EMBEDDINGS:
        _CORPUS_EMBEDDINGS[key] = model.encode(corpus, convert_to_tensor=True)
    return _CORPUS_EMBEDDINGS[key]


def _rank_hits(scores: Any, threshold: float, limit: int) -> List[Any]:
    """Return (index, score) pairs with score >= threshold, best first, at most `limit` of them."""
    scored = [(i, s) for i, s in enumerate(scores.tolist()) if s >= threshold]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:limit]


def find_relevant_sections(
    query: str,
    top_k: int = 5,
    score_threshold1: float = 0.5,
    score_threshold2: float = 0.3,
    model_name: str = 'all-MiniLM-L6-v2',
    return_content: bool = False,
    device: str = 'mps',
) -> Union[List[str], Dict[str, str]]:
    """Semantic search over the curriculum stored in ``book_json``.

    Stage 1 compares the query with every "Book -> Chapter" title (``flat_sections``)
    and keeps up to ``top_k`` titles whose cosine similarity is >= ``score_threshold1``.
    Only if none qualify, stage 2 compares it with every "Book -> Chapter -> text"
    entry (``flat_book``) and keeps up to ``top_k + 2`` entries scoring
    >= ``score_threshold2``.

    Args:
        query: Free-text question.
        top_k: Maximum number of title-level hits (the fallback allows ``top_k + 2``).
        score_threshold1: Minimum similarity for title-level hits.
        score_threshold2: Minimum similarity for content-level fallback hits.
        model_name: SentenceTransformer model name or path.
        return_content: If True, return the chapter text for each hit as well.
        device: torch device for the encoder (default 'mps', matching the original).

    Returns:
        If ``return_content`` is False: matching "Book -> Chapter" titles, best first
        (possibly empty). Otherwise: a dict mapping each of those titles to its text.
    """
    model = _get_sentence_model(model_name, device)
    query_emb = model.encode(query, convert_to_tensor=True)

    # 1) Title-level matching
    title_embs = _encode_corpus(model, model_name, device, flat_sections)
    title_scores = util.cos_sim(query_emb, title_embs)[0]
    hits = _rank_hits(title_scores, score_threshold1, top_k)

    if hits:
        section_keys = [flat_sections[i] for i, _ in hits]
    else:
        # 2) Fallback: content-level matching, normalized back to "Book -> Chapter"
        content_embs = _encode_corpus(model, model_name, device, flat_book)
        content_scores = util.cos_sim(query_emb, content_embs)[0]
        section_keys = []
        for idx, _ in _rank_hits(content_scores, score_threshold2, top_k + 2):
            book, chap, _ = flat_book[idx].split(" -> ", 2)
            section_keys.append(f"{book} -> {chap}")

    if not return_content:
        return section_keys

    content: Dict[str, str] = {}
    for sec in section_keys:
        book, chap = sec.split(" -> ", 1)
        content[sec] = book_json[book][chap]
    return content

In [7]:
%%time
# Smoke test: display the "Book -> Chapter" titles returned for a sample question and time the call.
find_relevant_sections("What is interest rate (or yield)?")

CPU times: user 1.15 s, sys: 534 ms, total: 1.69 s
Wall time: 3.08 s


['Fixed Income -> 2.1.Maturity Structure of Interest Rates',
 'Quantitative Methods -> 2.Interest Rates and Time Value of Money',
 'Fixed Income -> 2.2.Yield-to-Maturity',
 'Quantitative Methods -> 2.1.Determinants of Interest Rates']

In [14]:
class LocalLLM:
    """Minimal single-turn client for a local Ollama model.

    ``invoke`` sends one user message through ``ollama.chat`` and returns only the
    text after the last ``</think>`` tag, so reasoning models such as deepseek-r1
    do not leak their thinking preamble into the answer.
    """

    def __init__(
        self,
        model: str = "deepseek-r1:1.5b",
        temperature: float = 0,
        max_tokens: Optional[int] = None,
        num_thread: int = 10,
    ):
        """
        Args:
            model: Ollama model tag.
            temperature: Sampling temperature passed to Ollama.
            max_tokens: Generation cap, sent as Ollama's ``num_predict`` option;
                None leaves Ollama's own default in place.
            num_thread: CPU thread count passed to Ollama (previously hard-coded to 10).
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.num_thread = num_thread

    def _options(self) -> dict:
        """Build the Ollama ``options`` dict for a request."""
        options = {
            "temperature": self.temperature,
            "num_thread": self.num_thread,
            "low_vram": False,
        }
        # max_tokens used to be accepted and then silently ignored.
        if self.max_tokens is not None:
            options["num_predict"] = self.max_tokens
        return options

    def invoke(self, prompt: str) -> str:
        """Send `prompt` as a single user message and return the cleaned reply text."""
        response = ollama.chat(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            options=self._options(),
        )
        # Keep only what follows the final </think> tag; trim surrounding newlines.
        return response["message"]["content"].split("</think>")[-1].strip("\n")


In [10]:
# Shared LLM client used by response_node (values spelled out; they match LocalLLM's defaults).
llm = LocalLLM(model="deepseek-r1:1.5b", temperature=0)

In [None]:

# -------------------------------------------------------------------------
# 2. GraphState definition
# -------------------------------------------------------------------------
class GraphState(TypedDict):
    """State shared by the nodes of the CFA Q&A graph."""

    # The user's question; retrieval_node reads it and confirm_node may replace it.
    query: str
    # "Book -> Chapter" titles from find_relevant_sections; confirm_node narrows
    # them to the user's choice. Always a list (or unset), never a plain string.
    retrieved_sections: Optional[List[str]]
    # Final answer text written by response_node.
    response: Optional[str]
    # Conversation log of langchain message objects; the add_messages reducer merges
    # each node's returned messages into this list instead of overwriting it.
    messages: Annotated[List[Any], add_messages]
    # Routing key set by confirm_node ("retrieve", "confirm" or "full_retrieval")
    # and read by the conditional edge leaving the "confirm" node.
    goto: Optional[str]
    # Concatenated chapter text built by full_retrieval_node for the prompt.
    context: str


# -------------------------------------------------------------------------
# 3. Query and retrieval nodes
# -------------------------------------------------------------------------

def user_query_node(state: GraphState) -> GraphState:
    """Append the incoming question to the conversation as a HumanMessage.

    Returns a partial state update containing only ``messages``.
    """
    question_message = HumanMessage(content=state["query"])
    return {"messages": [question_message]}



def retrieval_node(state: GraphState) -> GraphState:
    """Run semantic search for ``state["query"]`` and record the candidate titles.

    Stores the "Book -> Chapter" titles returned by find_relevant_sections in
    ``retrieved_sections`` and logs them as one readable SystemMessage.
    """
    section_list = find_relevant_sections(state["query"])

    # Render the titles as one newline-separated string. Passing the raw list as
    # message content makes each title a separate content block instead of text.
    if section_list:
        summary = "Candidate sections:\n" + "\n".join(section_list)
    else:
        summary = "No candidate sections found."

    return {
        "messages": [SystemMessage(content=summary)],
        "retrieved_sections": section_list,
    }


# -------------------------------------------------------------------------
# 4. Confirm node using interrupt()
# -------------------------------------------------------------------------

def confirm_node(state: GraphState) -> Dict[str, Any]:
    """Human-in-the-loop step that narrows the retrieved sections to one.

    Sets ``goto`` for the conditional edge that follows this node:
      * no sections -> interrupt() asking for a rephrased query; stores it as
        ``query`` and routes to "retrieve".
      * one section -> no interrupt; routes to "full_retrieval".
      * several     -> interrupt() listing numbered options. A valid 1-based
        choice keeps only that section and routes to "full_retrieval"; any other
        reply routes back to "confirm" to ask again.

    NOTE: LangGraph's interrupt() requires the graph to be compiled with a checkpointer.
    """
    # Treat a missing/None value like "no results" instead of crashing on len(None).
    sections = state.get("retrieved_sections") or []

    # 1) None -> ask user to rephrase query
    if not sections:
        new_q = interrupt({
            "prompt": "I couldn't find any sections matching your question. Please rephrase or clarify your query."
        })
        return {
            "messages": [SystemMessage(content=f"User provided new query: {new_q}")],
            "query": new_q,
            "goto": "retrieve"
        }

    # 2) Exactly one -> skip confirmation
    if len(sections) == 1:
        return {
            "messages": [
                SystemMessage(content=f"Only one section found: {sections[0]}. Skipping confirmation.")
            ],
            "retrieved_sections": sections,
            "goto": "full_retrieval"
        }

    # 3) Multiple -> present numbered options
    opts = "\n".join(f"{i+1}. {sec}" for i, sec in enumerate(sections))
    choice = interrupt({
        "prompt": (
            "I found multiple relevant sections. Please select one by number:\n\n"
            f"{opts}\n\nReply with 1, 2, 3, etc."
        )
    })

    # The resume value may be an int or a str; anything that is not an in-range
    # 1-based number is rejected.
    try:
        idx = int(str(choice).strip()) - 1
    except ValueError:
        idx = -1
    if not 0 <= idx < len(sections):
        return {
            "messages": [SystemMessage(content=f"Invalid choice '{choice}'. Let me try again.")],
            "goto": "confirm"
        }

    # valid selection
    picked = sections[idx]
    return {
        "messages": [
            SystemMessage(content=f"User selected option {idx+1}: {picked}")
        ],
        "retrieved_sections": [picked],
        "goto": "full_retrieval"
    }

# -------------------------------------------------------------------------
# 5. Full retrieval node
# -------------------------------------------------------------------------
def full_retrieval_node(state: GraphState) -> GraphState:
    """Fetch the full chapter text for every selected section into ``context``.

    Each entry of ``state["retrieved_sections"]`` is a "Book -> Chapter" title.
    Its text is looked up in ``book_json`` and prefixed with a
    "Book & Chapter: <title>" header. Entries are joined with newlines.
    """
    section_list = state.get("retrieved_sections") or []
    all_text = []
    for sec in section_list:
        book, chap = sec.split(" -> ", 1)
        all_text.append(f"Book & Chapter: {sec}\n{book_json[book][chap]}")

    return {
        "context": "\n".join(all_text),
        "messages": [SystemMessage(content="Fetched full text for all candidate sections.")]
    }


# -------------------------------------------------------------------------
# 6. Response node
# -------------------------------------------------------------------------
def response_node(state: GraphState) -> GraphState:
    """Answer ``state["query"]`` with the local LLM, grounded in ``state["context"]``.

    Writes the answer to ``response`` and appends it to the log as an AIMessage.
    """
    sections_block = state["context"]
    question = state["query"]
    prompt_parts = [
        "Below are the relevant CFA curriculum sections:\n\n",
        f"{sections_block}\n\n",
        f"Question: {question}\n\n",
        "Please answer in a concise, yet complete manner, citing the relevant sections as needed.",
    ]
    answer = llm.invoke("".join(prompt_parts))
    return {"response": answer, "messages": [AIMessage(content=answer)]}


In [None]:
# Workflow:
#   START -> retrieve -> confirm -> {full_retrieval | retrieve | confirm}
#   full_retrieval -> respond -> END
# confirm_node writes state["goto"], which selects the branch taken after "confirm".
graph = (
    StateGraph(GraphState)
    .add_node("retrieve",       retrieval_node)
    .add_node("confirm",        confirm_node)
    .add_node("full_retrieval", full_retrieval_node)
    .add_node("respond",        response_node)

    .add_edge(START,            "retrieve")
    .add_edge("retrieve",       "confirm")
    .add_conditional_edges(
        "confirm",
        path=lambda state: state.get("goto"),
        # "confirm" is required for the invalid-choice branch, which re-prompts.
        path_map={
            "full_retrieval": "full_retrieval",
            "retrieve": "retrieve",
            "confirm": "confirm",
        },
    )
    .add_edge("full_retrieval", "respond")
    .add_edge("respond",        END)
    # interrupt() in confirm_node can only pause/resume with a checkpointer.
    .compile(checkpointer=MemorySaver())
    # Default thread so graph.stream(...) also works without an explicit config;
    # pass {"configurable": {"thread_id": ...}} per conversation to isolate runs.
    .with_config({"configurable": {"thread_id": "cfa-chat"}})
)

graph.get_graph().print_ascii()


In [None]:
# Interactive chat over the compiled `graph`. Type "exit" to stop.
# Example question: What is the difference between Venture capital and Private Equity?
# Each question runs in its own checkpointer thread. When confirm_node pauses the
# graph with interrupt(), the user is asked and the run is resumed with Command(resume=...).


def _stream_turn(graph_input: Any, config: Dict[str, Any]) -> Optional[str]:
    """Run the graph once, print the final answer, and return the interrupt prompt if it paused."""
    for update in graph.stream(graph_input, config=config, stream_mode="updates"):
        if "__interrupt__" in update:
            pending = update["__interrupt__"][0].value
            if isinstance(pending, dict):
                return str(pending.get("prompt", pending))
            return str(pending)
        answer = (update.get("respond") or {}).get("response")
        if answer:
            print(answer)
    return None


while True:
    user_query = input("You: ")
    if user_query.strip().lower() == "exit":
        print("Exiting chat. Goodbye!")
        break

    # retrieval_node reads state["query"], so it must be part of the input.
    turn_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
    graph_input = {
        "query": user_query,
        "messages": [{"role": "user", "content": user_query}],
    }

    prompt = _stream_turn(graph_input, turn_config)
    while prompt is not None:
        reply = input(f"{prompt}\n> ")
        prompt = _stream_turn(Command(resume=reply), turn_config)
    print()  # blank line between turns
