## API Keys

In [None]:
# Set your API keys in the .env file.
import os
from dotenv import load_dotenv

load_dotenv()

# Fail fast with a clear message when a required key is missing. An explicit
# raise is used instead of `assert`, because assertions are stripped when
# Python runs with -O; all missing keys are reported at once.
_REQUIRED_KEYS = ("OPENAI_API_KEY", "LANGCHAIN_API_KEY", "YOUTUBE_API_KEY")
_missing_keys = [key for key in _REQUIRED_KEYS if not os.getenv(key)]
if _missing_keys:
    raise RuntimeError(f"{', '.join(_missing_keys)} is not set")

## Traceability

In [None]:
from langchain.callbacks.tracers.langchain import LangChainTracer

# LangChain tracer instance; it is attached as a callback to the ChatOpenAI
# LLM created in the LLM cell so model calls are traced.
# NOTE(review): presumably authenticates with LANGCHAIN_API_KEY (checked in the
# API Keys cell) via the environment — confirm against your LangSmith setup.
tracer = LangChainTracer()

## System Prompts

In [None]:
from langchain_core.messages import SystemMessage
from langchain_core.prompts import SystemMessagePromptTemplate

# Instruction for the video-discovery step: the agent must answer with a bare
# comma-separated list of video IDs, which the graph node later splits on ','.
SysPrompt1 = SystemMessage(
    content=(
        "Return a comma separated list of exactly 5 valid YouTube video IDs \n"
        "that are most relevant to the user's query.\n"
        "For example: 'id1,id2,id3,id4,id5'"
    )
)

# Per-video instruction; {video_id} is filled in for each candidate video.
SysPrompt2 = SystemMessagePromptTemplate.from_template(
    "Return the youtube transcript for the video with ID {video_id}"
)

# Instruction for the final answer step, which appends retrieved context.
SysPrompt3 = SystemMessage(
    content="Use the context below to answer the question."
)

## Tools

In [None]:
from langchain_core.tools import Tool
from youtube_transcript_api import YouTubeTranscriptApi
from googleapiclient.discovery import build

YOUTUBE_API_SERVICE_NAME = "youtube"
YOUTUBE_API_VERSION = "v3"

def youtube_search(query, max_results=5):
    """Search YouTube for videos matching ``query``.

    A fresh YouTube Data API client is built on every call, using the
    ``YOUTUBE_API_KEY`` environment variable as the developer key.

    Args:
        query: Free-text search string.
        max_results: Maximum number of hits requested from the API.

    Returns:
        A list of ``{"title": ..., "video_id": ...}`` dicts, in the order the
        API returned them.
    """
    client = build(
        YOUTUBE_API_SERVICE_NAME,
        YOUTUBE_API_VERSION,
        developerKey=os.getenv("YOUTUBE_API_KEY"),
    )
    request = client.search().list(
        q=query,
        type="video",  # only videos, so every hit carries id.videoId
        part="id,snippet",
        maxResults=max_results,
    )
    payload = request.execute()

    return [
        {"title": hit["snippet"]["title"], "video_id": hit["id"]["videoId"]}
        for hit in payload["items"]
    ]

# Expose youtube_search to the agent as a single-input LangChain tool.
youtube_search_tool = Tool.from_function(
    name="youtube_search",
    description="Search for ID's of youtube videos that are most relevant to a user's query",
    func=youtube_search,
)

def fetch_youtube_transcript(video_id):
    """Return the transcript of ``video_id`` as newline-joined caption text.

    This is best-effort and never raises. Any failure, whether fetching or
    joining, is returned as a warning string so the agent can read it.
    """
    try:
        segments = YouTubeTranscriptApi.get_transcript(video_id)
        text = "\n".join(segment['text'] for segment in segments)
    except Exception as err:
        return f"⚠️ Error fetching transcript: {err}"
    return text

# Expose fetch_youtube_transcript to the agent as a single-input tool.
youtube_transcript_tool = Tool.from_function(
    name="fetch_youtube_transcript",
    description="Returns the full transcript of a YouTube video given its ID (e.g., 'xZX4KHrqwhM').",
    func=fetch_youtube_transcript,
)

## LLM

In [None]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate,HumanMessagePromptTemplate
from langchain_core.messages import HumanMessage
from langchain.agents import initialize_agent, AgentType

# create the LLM
llm = ChatOpenAI(model="gpt-4o", temperature=0, callbacks=[tracer])

# Agent that can call both YouTube tools via OpenAI function calling.
# (AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION is an alternative agent type.)
# NOTE(review): `initialize_agent` appears deprecated in recent LangChain
# releases — consider migrating when upgrading.
agent = initialize_agent(
    llm=llm,
    tools=[youtube_search_tool, youtube_transcript_tool],
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)

# One chat prompt per graph step. Each pairs a system instruction with the
# user's query, rendered as the human turn.
prompt_videos = ChatPromptTemplate.from_messages(
    [SysPrompt1, HumanMessagePromptTemplate.from_template("{query}")]
)

prompt_video_transcript = ChatPromptTemplate.from_messages(
    [SysPrompt2, HumanMessagePromptTemplate.from_template("{query}")]
)

prompt_answer = ChatPromptTemplate.from_messages(
    [
        SysPrompt3,
        HumanMessagePromptTemplate.from_template("{query}\n\nContext:\n{context}"),
    ]
)


## Knowledgebase

In [None]:
# Set up any embeddings / vector DBs here.
from langchain_openai import OpenAIEmbeddings

# Embedding model used when indexing transcript chunks into FAISS.
embedding_model = OpenAIEmbeddings()

vectorstore = None  # Global FAISS store (lazy init): assigned in step_embed_docs, read in step_rag_answer

## Graph Nodes/Lambdas

In [None]:
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

def step_get_video_ids(state):
    """Ask the agent for YouTube video IDs relevant to ``state["query"]``.

    Args:
        state: Graph state; only ``"query"`` is read.

    Returns:
        A state update holding the query and the cleaned list of video IDs.
    """
    prompt_value = prompt_videos.invoke({"query": state["query"]})
    # The agent executor has a single "input" key. Pass the rendered prompt as
    # plain text; otherwise the ChatPromptValue object is str()-ed into its
    # Python repr before it reaches the model.
    response = agent.invoke({"input": prompt_value.to_string()})

    # The model may answer "id1, id2" or wrap the list in quotes (the system
    # prompt's own example is quoted). Strip whitespace and quotes from each
    # token, drop blanks, and de-duplicate while keeping the original order.
    tokens = (
        token.strip(" \t\r\n'\"`") for token in response['output'].split(',')
    )
    video_ids = list(dict.fromkeys(token for token in tokens if token))

    print(f"✅ Retrieved and filtered Video IDs: {video_ids}")
    return {"query": state["query"], "video_ids": video_ids}

def step_get_transcripts(state):
    """Fetch a transcript for every video ID via the agent.

    This is best-effort: a video whose agent call fails is reported and
    skipped rather than aborting the whole graph run.

    Args:
        state: Graph state; ``"query"`` and ``"video_ids"`` are read.

    Returns:
        A state update with the query and one ``Document`` per successful video,
        with the video ID stored in ``metadata["video_id"]``.
    """
    # `Document` is otherwise only imported in the later Orchestrator cell.
    # Importing it here removes the dependency on cell execution order.
    from langchain_core.documents import Document

    docs = []
    for vid in state["video_ids"]:
        prompt_value = prompt_video_transcript.invoke(
            {"query": state["query"], "video_id": vid}
        )
        try:
            # Pass the rendered prompt text under the agent's "input" key
            # instead of the ChatPromptValue object (which would be repr()-ed).
            response = agent.invoke({"input": prompt_value.to_string()})
        except Exception as exc:
            print(f"⚠️ Warning: No transcript was retrieved for video {vid}: {exc}")
            continue
        docs.append(
            Document(page_content=response['output'], metadata={"video_id": vid})
        )

    if not docs:
        print("⚠️ Warning: No transcripts were retrieved for the given video IDs.")
    return {"query": state["query"], "documents": docs}

def step_embed_docs(state):
    """Chunk the retrieved transcripts and (re)build the global FAISS index.

    Side effect: rebinds the module-level ``vectorstore``. It is reset to
    ``None`` when there is nothing to index, so a later answer step cannot
    retrieve context left over from a previous query.

    Args:
        state: Graph state; ``"query"`` and ``"documents"`` are read.

    Returns:
        A state update containing only the query.
    """
    global vectorstore

    documents = state.get("documents") or []
    chunks = []
    if documents:
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = splitter.split_documents(documents)

    # FAISS.from_documents needs at least one chunk to build an index.
    vectorstore = FAISS.from_documents(chunks, embedding_model) if chunks else None
    return {"query": state["query"]}

def step_rag_answer(state):
    """Answer the query using the top transcript chunks as context.

    When no index exists, the agent is asked with an empty context.

    Args:
        state: Graph state; only ``"query"`` is read.

    Returns:
        A state update holding the agent's final ``"answer"`` text.
    """
    context = ""
    if vectorstore is not None:
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        results = retriever.invoke(state["query"])
        context = "\n\n".join(doc.page_content for doc in results)

    prompt_value = prompt_answer.invoke({"query": state["query"], "context": context})
    # Pass the rendered prompt text under the agent's "input" key instead of
    # the ChatPromptValue object (which would reach the model as its repr).
    response = agent.invoke({"input": prompt_value.to_string()})
    return {"answer": response['output']}

## Orchestrator

In [None]:
# set up LangChain/LangGraph here
from typing import TypedDict, Annotated, List
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph
import operator

class State(TypedDict):
    """Shared state passed between the LangGraph nodes of the RAG pipeline."""

    query: str  # user question; supplied to app.invoke and echoed by each node
    video_ids: List[str]  # candidate YouTube IDs produced by step_get_video_ids
    documents: List[Document]  # transcript Documents produced by step_get_transcripts
    answer: str  # final answer text produced by step_rag_answer

# Linear pipeline, in execution order: (node name, node function).
_pipeline_steps = (
    ("get_video_ids", step_get_video_ids),
    ("get_transcripts", step_get_transcripts),
    ("embed_docs", step_embed_docs),
    ("rag_answer", step_rag_answer),
)

graph = StateGraph(State)
for _node_name, _node_fn in _pipeline_steps:
    graph.add_node(_node_name, RunnableLambda(_node_fn))

# Chain the nodes: first is the entry point, each feeds the next, last finishes.
_node_names = [name for name, _ in _pipeline_steps]
graph.set_entry_point(_node_names[0])
for _src, _dst in zip(_node_names, _node_names[1:]):
    graph.add_edge(_src, _dst)
graph.set_finish_point(_node_names[-1])

app = graph.compile()


## The Graph

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG inline in the notebook.
display(Image(data=app.get_graph().draw_mermaid_png()))

## Kick off the Query

In [None]:
query = "What are quantum effects in biology?"
result = app.invoke({"query": query})
# Header (with its blank line), then the answer on its own lines.
print("🧠 Final Agent Answer:\n", result["answer"], sep="\n")