# First Version

In [1]:
# yt_rag_gpu_cached.py
# Drop-in for your original script: same QUESTION/LLM chain, faster indexing.

# --- Imports ---
import hashlib
import json
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled

import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer

from langchain.docstore.document import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq


# -------------------- Config --------------------
load_dotenv()                               # Load variables from a local .env file into the environment.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")    # None when the variable is not set.

VIDEO_ID = "3qHkcs3kG44"                    # YouTube video whose transcript is indexed.
QUESTION = (                                # The question sent through the RAG chain.
    "Is the topic of nuclear fusion discussed in this video? "
    "If yes, what exactly was discussed?"
)

CACHE_DIR = Path(".yt_rag_cache_0")         # Root folder for all cache files.
CACHE_DIR.mkdir(exist_ok=True)              # Create the folder if it doesn't exist.
EMB_CACHE_ROOT = CACHE_DIR / "emb_cache_st" # Embedding cache root, a subfolder of CACHE_DIR.
EMB_CACHE_ROOT.mkdir(exist_ok=True)         # Make sure this exists, too.
TARGET_WINDOW_SECONDS = 110                 # ~90-120s is a sweet spot for long YT videos.

# Retrieve N best-matching chunks for a query
TOP_K = 4

# Sentence transformer model
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # 384-d, fast+solid
BATCH_SIZE = 1024 # How many texts to embed at once. A larger batch uses more GPU RAM.


# -------------------- Helpers --------------------
def _vid_dir(video_id: str) -> Path:
    """Return the cache folder for ``video_id``, creating it if necessary.

    The folder sits directly under ``CACHE_DIR`` and is named after the
    video ID.
    """
    video_dir = CACHE_DIR.joinpath(video_id)
    video_dir.mkdir(parents=True, exist_ok=True)
    return video_dir

def _ts(sec: float) -> str:
    '''
    To convert seconds to H:MM:SS or M:SS string.
    '''
    sec = int(sec)                          # make sure its an integer
    h = sec // 3600                         # to convert to hours, divide it by 3600
    m = (sec % 3600) // 60                  # to convery to minutes, divide further by 60
    s =sec % 60                             # modulo of seconds
    return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}"

def fetch_transcript(video_id: str, languages: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Download the transcript for a given video.

    Args:
        video_id: ID of the YouTube video to fetch captions for.
        languages: Language codes forwarded to ``YouTubeTranscriptApi().fetch``.
            When omitted, ``["en"]`` is used.

    Returns:
        The result of ``to_raw_data()`` on the fetched transcript (a list of
        per-segment dicts).

    Raises:
        SystemExit: If captions are disabled for the video, or if fetching
            fails for any other reason. The original exception is attached
            as ``__cause__``.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if languages is None:
        languages = ["en"]

    t0 = time.time()                                                      # start time, for the timing message
    try:
        # Only the calls that talk to the transcript API live inside the try.
        fetched = YouTubeTranscriptApi().fetch(video_id, languages=languages)
        data = fetched.to_raw_data()
    except TranscriptsDisabled as exc:
        raise SystemExit("No captions available for this video.") from exc
    except Exception as exc:
        # Top-level boundary: turn any other failure into a clean exit,
        # keeping the underlying error reachable through the cause chain.
        raise SystemExit(f"Failed to fetch transcript: {exc}") from exc

    print(f"‚úÖ Transcript fetched: {len(data)} segments in {time.time()-t0:.2f}s")
    return data

def group_segments(segments: List[Dict[str, Any]], target_window_s: int = TARGET_WINDOW_SECONDS) -> List[Dict[str, Any]]:
    """Merge consecutive caption segments into windows of about ``target_window_s`` seconds.

    A window is closed as soon as the span from its first segment's start to
    the current segment's end reaches ``target_window_s``. Whatever is left
    after the last segment becomes a final, possibly shorter, window.

    Returns:
        A list of dicts with keys ``"start"``, ``"end"`` and ``"text"``, where
        ``"text"`` is the space-joined, stripped text of the window's segments.
    """
    windows = []
    pending = []                      # stripped texts of the window being assembled
    window_start = None               # start time of the window being assembled
    window_end = None                 # end time of the latest segment in that window

    for seg in segments:
        seg_start = seg["start"]
        seg_end = seg["start"] + seg.get("duration", 0)   # a missing duration counts as 0
        if window_start is None:
            window_start = seg_start
        window_end = seg_end
        pending.append(seg["text"].strip())

        if (window_end - window_start) >= target_window_s:
            windows.append({"start": window_start, "end": window_end, "text": " ".join(pending).strip()})
            pending = []
            window_start = window_end = None

    # Flush the trailing partial window, if any.
    if pending:
        windows.append({"start": window_start, "end": window_end, "text": " ".join(pending).strip()})

    print(f"üß© Windows created: {len(windows)} (‚âà{target_window_s}s each)")
    return windows

def make_docs(video_id: str, windows: List[Dict[str, Any]]) -> List[Document]:
    """Wrap every window in a LangChain ``Document``.

    The window text becomes ``page_content``. The metadata records the video
    ID, the window's start/end values, its position in ``windows`` and a
    formatted ``time_range`` string built with ``_ts``.
    """
    docs = []
    for window_id, window in enumerate(windows):
        metadata = dict(
            video_id=video_id,
            start=window["start"],
            end=window["end"],
            window_id=window_id,
            time_range=f"{_ts(window['start'])}‚Äì{_ts(window['end'])}",
        )
        docs.append(Document(page_content=window["text"], metadata=metadata))
    print(f"üìÑ Document objects: {len(docs)}")
    return docs

# -------------------- GPU embedding + cache --------------------
def make_st_model(model_name: str = MODEL_NAME) -> SentenceTransformer:
    """Instantiate ``model_name`` on ``"cuda"`` when torch reports CUDA is available, else on ``"cpu"``."""
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    return SentenceTransformer(model_name, device=device)

def _hash_text(t: str) -> str:
    '''
    Used to generate unique cache filenames for each text chunk.
    '''
    return hashlib.md5(t.encode("utf-8")).hexdigest()

def embed_with_cache(model: SentenceTransformer, texts: List[str], cache_dir: Path, batch_size: int = BATCH_SIZE) -> np.ndarray:
    """Embed ``texts``, reusing per-text ``.npy`` files stored in ``cache_dir``.

    Each text maps to ``<md5-of-text>.npy``. Texts whose file already exists
    are loaded from disk; the rest are encoded together with ``model.encode``
    (L2-normalised), saved to the cache, and placed at their original positions.

    Args:
        model: The SentenceTransformer used for cache misses.
        texts: Texts to embed; output row ``i`` corresponds to ``texts[i]``.
        cache_dir: Directory holding the cached vectors (created if missing).
        batch_size: Batch size passed to ``model.encode``.

    Returns:
        A float32 array with one row per text. For an empty ``texts`` an
        array with zero rows is returned instead of raising.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)

    if len(texts) == 0:
        # np.vstack([]) raises ValueError, so hand back an empty matrix whose
        # width is the model's reported embedding size (0 if it reports none).
        dim = model.get_sentence_embedding_dimension() or 0
        return np.empty((0, dim), dtype="float32")

    # Hash every text once; the same path is used for the lookup and the save.
    paths = [cache_dir / f"{_hash_text(t)}.npy" for t in texts]
    vecs = [np.load(fp) if fp.exists() else None for fp in paths]
    miss_idx = [i for i, vec in enumerate(vecs) if vec is None]

    if miss_idx:
        with torch.inference_mode():                      # no autograd bookkeeping while encoding
            arr_new = model.encode(
                [texts[i] for i in miss_idx],
                batch_size=batch_size,
                convert_to_numpy=True,
                normalize_embeddings=True,                # unit-length rows: inner product == cosine
            ).astype("float32")
        for i, row in zip(miss_idx, arr_new):
            vecs[i] = row
            np.save(paths[i], row)                        # persist for future runs

    return np.vstack(vecs).astype("float32")

# -------------------- Index build/load --------------------
def build_or_load_index(video_id: str, force_rebuild: bool = False):
    """Load the cached FAISS index for ``video_id`` or build it from the transcript.

    Args:
        video_id: YouTube video ID; also names the cache sub-folders.
        force_rebuild: When True, delete this video's index folder and its
            embedding cache before doing anything else, so nothing computed
            earlier (possibly with another model) is reused.

    Returns:
        A tuple ``(index, metas, model)`` on both the load path and the build
        path: the FAISS index, the list of per-window metadata dicts, and the
        SentenceTransformer model.
    """
    vdir = _vid_dir(video_id)                        # per-video cache folder
    idx_dir = vdir / "faiss_st"                      # FAISS-specific folder inside it
    meta_fp = vdir / "meta_st.json"                  # summary JSON for the build
    metas_path = idx_dir / "metas.jsonl"             # one metadata dict per line
    idx_path = idx_dir / "index.faiss"               # the FAISS index file
    cache_dir = EMB_CACHE_ROOT / video_id            # embedding cache for this video

    if force_rebuild:
        shutil.rmtree(vdir, ignore_errors=True)
        # Cached vectors are keyed only by text hash, so they must go too;
        # otherwise a changed model would silently reuse old embeddings.
        shutil.rmtree(cache_dir, ignore_errors=True)
    vdir.mkdir(parents=True, exist_ok=True)

    # Fast path: everything needed is already on disk.
    if idx_path.exists() and meta_fp.exists() and metas_path.exists():
        index = faiss.read_index(str(idx_path))
        metas = [json.loads(line) for line in metas_path.read_text(encoding="utf-8").splitlines()]
        print(f"‚ö° Loaded cached index: {idx_dir}")
        # Return the same 3-tuple shape as the build path below.
        return index, metas, make_st_model(MODEL_NAME)

    # Build fresh.
    t0 = time.time()                                                          # start of the whole build
    segs = fetch_transcript(video_id)                                         # transcript segments
    wins = group_segments(segs, target_window_s=TARGET_WINDOW_SECONDS)        # group into windows
    docs = make_docs(video_id, wins)                                          # windows -> Documents

    texts = [d.page_content for d in docs]
    model = make_st_model(MODEL_NAME)

    t = time.time()
    arr = embed_with_cache(model, texts, cache_dir=cache_dir, batch_size=BATCH_SIZE)
    t_embed = time.time() - t                                                 # embedding time only

    d = arr.shape[1]                                                          # vector dimension
    index = faiss.IndexFlatIP(d)      # inner-product index; vectors are L2-normalised, so IP == cosine
    index.add(arr)

    idx_dir.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(idx_path))
    with metas_path.open("w", encoding="utf-8") as f:
        for doc in docs:
            f.write(json.dumps(doc.metadata, ensure_ascii=False) + "\n")
    with meta_fp.open("w", encoding="utf-8") as f:
        json.dump({
            "video_id": video_id,
            "num_windows": len(wins),
            "chunk_sec": TARGET_WINDOW_SECONDS,
            "model": MODEL_NAME,
            "dim": int(d),
            "embed_time_s": round(t_embed, 3),
        }, f, ensure_ascii=False, indent=2)

    print(f"‚úÖ Index built in {time.time()-t0:.2f}s (embed {t_embed:.2f}s)‚Üí {idx_dir}")
    metas = [doc.metadata for doc in docs]
    return index, metas, model

# -------------------- Glue: build retriever & run chain --------------------
# We‚Äôll store page_content alongside metadata so retrieval returns real docs.
def build_index_and_retriever(video_id: str, force_rebuild: bool = False, k: int = TOP_K):
    """Build or load the FAISS index plus the full per-window payloads.

    Unlike the metadata-only variant, each entry stored in
    ``metas_full.jsonl`` is ``{"page_content": ..., "metadata": ...}`` so that
    retrieval can return real documents.

    Args:
        video_id: YouTube video ID; also names the cache sub-folders.
        force_rebuild: When True, delete this video's index folder and its
            embedding cache first, so nothing computed earlier (possibly with
            another model) is reused.
        k: Accepted for call-site symmetry with the retriever; not used here.

    Returns:
        A tuple ``(index, metas_full, model)``.
    """
    vdir = _vid_dir(video_id)
    idx_dir = vdir / "faiss_st"
    metas_path = idx_dir / "metas_full.jsonl"        # content + metadata, one JSON object per line
    idx_path = idx_dir / "index.faiss"
    cache_dir = EMB_CACHE_ROOT / video_id

    if force_rebuild:
        shutil.rmtree(vdir, ignore_errors=True)
        # Cached vectors are keyed only by text hash, so they must go too;
        # otherwise a changed model would silently reuse old embeddings.
        shutil.rmtree(cache_dir, ignore_errors=True)
    vdir.mkdir(parents=True, exist_ok=True)

    # Load path: both files are already on disk.
    if idx_path.exists() and metas_path.exists():
        index = faiss.read_index(str(idx_path))
        metas_full = [json.loads(line) for line in metas_path.read_text(encoding="utf-8").splitlines()]
        model = make_st_model(MODEL_NAME)
        return index, metas_full, model

    # Fresh build.
    segs = fetch_transcript(video_id)
    wins = group_segments(segs, target_window_s=TARGET_WINDOW_SECONDS)
    docs = make_docs(video_id, wins)

    texts = [d.page_content for d in docs]
    model = make_st_model(MODEL_NAME)
    arr = embed_with_cache(model, texts, cache_dir=cache_dir, batch_size=BATCH_SIZE)

    d = arr.shape[1]
    index = faiss.IndexFlatIP(d)                     # vectors are L2-normalised, so IP == cosine
    index.add(arr)

    idx_dir.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(idx_path))

    # Keep the payloads in memory while writing them, rather than reading the
    # file straight back after it has been written.
    metas_full = [{"page_content": ddoc.page_content, "metadata": ddoc.metadata} for ddoc in docs]
    with metas_path.open("w", encoding="utf-8") as f:
        for payload in metas_full:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    return index, metas_full, model

class LCStyleRetriever:
    """Minimal FAISS-backed retriever that hands back ``langchain`` ``Document`` objects."""

    def __init__(self, index: faiss.Index, metas_full: List[Dict[str, Any]], model: SentenceTransformer, k: int = TOP_K):
        """Keep references to the index, the payload list, the embedding model and ``k``.

        ``metas_full`` is a list of ``{"page_content": ..., "metadata": ...}``
        dicts, indexed by the positions the FAISS index returns.
        """
        self.index, self.metas_full, self.model, self.k = index, metas_full, model, k

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Embed ``query``, search the index for ``self.k`` neighbours and return them as Documents."""
        with torch.inference_mode():
            embedded = self.model.encode(
                [query],
                batch_size=1,
                convert_to_numpy=True,
                normalize_embeddings=True,
            )
            query_vec = embedded.astype("float32")

        # search() returns (scores, positions); only the positions are needed.
        _scores, positions = self.index.search(query_vec, self.k)

        # Positions equal to -1 are skipped.
        hits = [self.metas_full[pos] for pos in positions[0] if pos != -1]
        return [Document(page_content=hit["page_content"], metadata=hit["metadata"]) for hit in hits]


# -------------------- Build retriever + LLM chain (matches your original shape) --------------------

# Build (or load from cache) the index for VIDEO_ID; returns index, metas_full and the embedding model.
index, metas_full, model = build_index_and_retriever(VIDEO_ID, force_rebuild=False, k=TOP_K)

# Wrap them in an LCStyleRetriever that returns the TOP_K best-matching windows.
retriever = LCStyleRetriever(index, metas_full, model, k=TOP_K)

# PromptTemplate defines how the final prompt to the LLM will look.
prompt = PromptTemplate(
    input_variables=["context", "question"], # input_variables - the placeholders to fill in ({context} and {question}).
    template=( # template - format string containing those placeholders.
        "You are a concise, helpful assistant. Use ONLY the context to answer.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n\n"
        "If the answer isn't in the context, say you can't find it."
    ),
)

# NOTE(review): GROQ_API_KEY is not passed explicitly; presumably ChatGroq reads it from the environment - confirm.
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0.2) # Creates the ChatGroq LLM client, with model and temperature.

def format_docs(retrieved_docs: List[Document]) -> str:
    """Concatenate the ``page_content`` of every document, separated by a blank line."""
    contents = [doc.page_content for doc in retrieved_docs]
    return "\n\n".join(contents)

# Fan the incoming question out into the two prompt variables:
#   "context"  - retrieve matching windows, then join their text with format_docs
#   "question" - the question itself, passed through unchanged
parallel = RunnableParallel({
    "context": RunnableLambda(lambda q: retriever.get_relevant_documents(q)) | RunnableLambda(format_docs),
    "question": RunnablePassthrough(),
})
parser = StrOutputParser()                               # ensures we end with a plain string.
main_chain = parallel | prompt | llm | parser            # retrieval -> prompt -> LLM -> string

# --- Run ---
answer = main_chain.invoke(QUESTION)                     # run the whole pipeline on QUESTION
print("\n=== Answer ===\n")
print(answer)



=== Answer ===

Yes, the topic of nuclear fusion is discussed in this video. 

The discussion about nuclear fusion starts with the speaker suggesting that we should be building nuclear fusion test plants on the moon. They mention that the problem with fission nuclear fission is that it was built with a bomb, and there are issues with dirty nukes, Fukushima, Three Mile Island, and Chernobyl. They also mention that we need a way to iterate on nuclear fission and eventually fusion to get them working safely, cleanly, and passively. 

The speaker also mentions that there are Gen 4 nuclear reactors that are passive fail-safe, meaning that when they fail, they fail into a safe state.


# Other versions

In [18]:
# -------------------- Imports --------------------
import os
import time
import json
import hashlib
import shutil
from pathlib import Path

from dotenv import load_dotenv
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled

import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer

from langchain.docstore.document import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq


# -------------------- Configuration --------------------
# Load environment variables from .env (this should contain GROQ_API_KEY)
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# The YouTube video you want to query
VIDEO_ID = "3qHkcs3kG44"

# The question you want to ask about this video
QUESTION = (
    "Is the topic of nuclear fusion discussed in this video? "
    "If yes, what exactly was discussed?"
)

# Root directory to store all cached data (transcript/embeddings/index)
CACHE_DIR = Path(".yt_rag_cache_1")
CACHE_DIR.mkdir(exist_ok=True)

# Root directory for embedding cache (per video)
EMB_CACHE_ROOT = CACHE_DIR / "emb_cache_st"
EMB_CACHE_ROOT.mkdir(exist_ok=True)

# How long each transcript window (chunk) should be, in seconds
TARGET_WINDOW_SECONDS = 110   # around 90-120 seconds is good for long videos

# How many top matching chunks to retrieve for each query
TOP_K = 4

# SentenceTransformer model (text -> vector)
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # 384-dim, fast and solid

# Batch size for embedding (reduce if GPU memory is small)
BATCH_SIZE = 512


# -------------------- Helper functions --------------------
def get_video_dir(video_id):
    """
    Return the cache directory for one video, creating it when missing.

    The directory is ``CACHE_DIR / video_id``.
    """
    video_dir = CACHE_DIR.joinpath(video_id)
    video_dir.mkdir(parents=True, exist_ok=True)
    return video_dir


def seconds_to_timestamp(sec):
    """
    Render a second count as "H:MM:SS", or "M:SS" when the hour part is zero.

    The input is passed through int() first, dropping any fractional part.
    Example: 75.3 -> "1:15"
    """
    total_minutes, s = divmod(int(sec), 60)
    h, m = divmod(total_minutes, 60)
    return f"{h:d}:{m:02d}:{s:02d}" if h else f"{m:d}:{s:02d}"


def fetch_transcript(video_id, languages=None):
    """
    Download the transcript (subtitles) for a given YouTube video.

    Args:
        video_id: ID of the YouTube video.
        languages: Language codes forwarded to ``YouTubeTranscriptApi().fetch``.
            When omitted, ``["en"]`` is used.

    Returns the result of ``to_raw_data()``: a list of segments, where each
    segment is a dict containing:
    - 'start': when the segment starts (in seconds)
    - 'duration': how long it lasts
    - 'text': the subtitle text

    Raises SystemExit when captions are disabled or when fetching fails for
    any other reason; the original exception is attached as ``__cause__``.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if languages is None:
        languages = ["en"]

    t0 = time.time()
    try:
        # Only the calls that talk to the transcript API live inside the try.
        fetched = YouTubeTranscriptApi().fetch(video_id, languages=languages)
        data = fetched.to_raw_data()
    except TranscriptsDisabled as exc:
        raise SystemExit("No captions available for this video.") from exc
    except Exception as exc:
        # Top-level boundary: turn any other failure into a clean exit,
        # keeping the underlying error reachable through the cause chain.
        raise SystemExit(f"Failed to fetch transcript: {exc}") from exc

    print(f"‚úÖ Transcript fetched: {len(data)} segments in {time.time() - t0:.2f}s")
    return data


def group_segments_into_windows(segments, target_window_s=TARGET_WINDOW_SECONDS):
    """
    Merge consecutive transcript segments into "windows" of roughly
    target_window_s seconds each.

    A window is closed once the span from its first segment's start to the
    current segment's end reaches target_window_s; any remainder after the
    last segment becomes a final window.

    Returns a list of dicts like:
    {
        "start": ...,
        "end": ...,
        "text": "combined transcript text"
    }
    """
    windows = []
    texts, first_start, last_end = [], None, None

    for seg in segments:
        seg_start = seg["start"]
        seg_end = seg["start"] + seg.get("duration", 0)  # a missing duration counts as 0

        # The first segment of a window fixes the window's start time
        if first_start is None:
            first_start = seg_start

        last_end = seg_end
        texts.append(seg["text"].strip())

        # Long enough: emit the window and start collecting the next one
        if (last_end - first_start) >= target_window_s:
            windows.append({"start": first_start, "end": last_end, "text": " ".join(texts).strip()})
            texts, first_start, last_end = [], None, None

    # Whatever is still pending forms the last window
    if texts:
        windows.append({"start": first_start, "end": last_end, "text": " ".join(texts).strip()})

    print(f"üß© Windows created: {len(windows)} (‚âà{target_window_s}s each)")
    return windows


def make_documents_from_windows(video_id, windows):
    """
    Wrap each window in a LangChain Document.

    The window text becomes page_content; the metadata carries the video ID,
    the start/end values, the window's position and a formatted time range.
    """
    docs = []
    for window_id, window in enumerate(windows):
        metadata = dict(
            video_id=video_id,
            start=window["start"],
            end=window["end"],
            window_id=window_id,
            time_range=f"{seconds_to_timestamp(window['start'])}‚Äì{seconds_to_timestamp(window['end'])}",
        )
        docs.append(Document(page_content=window["text"], metadata=metadata))

    print(f"üìÑ Document objects: {len(docs)}")
    return docs


# -------------------- Embedding model + caching --------------------
def make_sentence_transformer_model(model_name=MODEL_NAME):
    """
    Instantiate the SentenceTransformer named model_name on "cuda" when torch
    reports CUDA is available, otherwise on "cpu". The chosen device is printed.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"üß† Loading SentenceTransformer on: {device}")
    return SentenceTransformer(model_name, device=device)


def hash_text_to_filename(text):
    """
    Return the hexadecimal MD5 digest of text (encoded as UTF-8).

    The digest is used as a cache filename stem, which keeps filenames short
    and free of characters taken from the text itself.
    """
    digest = hashlib.md5()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()


def embed_texts_with_cache(model, texts, cache_dir, batch_size=BATCH_SIZE):
    """
    Embed a list of texts using SentenceTransformer, but use a local cache
    directory so that repeated runs do NOT have to embed the same text again.

    Steps:
    - For each text, check if its embedding .npy file exists in the cache.
    - If yes, load it.
    - If no, remember it as a "miss" and embed later.
    - After embedding the misses, save them to disk and place them in the
      right positions in the output array.

    Returns a float32 NumPy array with one row per text. An empty texts list
    yields an array with zero rows instead of raising.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)

    if len(texts) == 0:
        # np.vstack([]) raises ValueError, so hand back an empty matrix whose
        # width is the model's reported embedding size (0 if it reports none).
        dim = model.get_sentence_embedding_dimension() or 0
        return np.empty((0, dim), dtype="float32")

    # Hash every text once; the same path serves the lookup and the save.
    filepaths = [cache_dir / (hash_text_to_filename(text) + ".npy") for text in texts]

    # First pass: load cache hits, leave None where the embedding is missing
    vectors = [np.load(fp) if fp.exists() else None for fp in filepaths]
    miss_indices = [i for i, vec in enumerate(vectors) if vec is None]

    # Second pass: embed all "missing" texts in one or more batches
    if miss_indices:
        misses = [texts[i] for i in miss_indices]
        print(f"üßÆ Embedding {len(misses)} new texts (not in cache yet)...")
        with torch.inference_mode():
            arr_new = model.encode(
                misses,
                batch_size=batch_size,
                convert_to_numpy=True,
                normalize_embeddings=True,  # L2-normalize for cosine similarity
            ).astype("float32")

        # Store each new embedding both in memory and on disk
        for original_index, vec in zip(miss_indices, arr_new):
            vectors[original_index] = vec
            np.save(filepaths[original_index], vec)

    # Stack all vectors into a single 2D array (num_texts, dim)
    return np.vstack(vectors).astype("float32")


# -------------------- Build or load FAISS index + metadata --------------------
def build_index_and_retriever(video_id, force_rebuild=False, k=TOP_K):
    """
    Build the FAISS index for a video, or load it when it is already cached.

    On a fresh build this:
    - downloads the transcript
    - groups it into windows and turns them into Documents
    - embeds the window texts (using the embedding cache)
    - builds the FAISS index and saves it together with the full metadata

    Args:
        video_id: YouTube video ID; also names the cache sub-folders.
        force_rebuild: When True, delete this video's index folder AND its
            embedding cache first, so nothing computed earlier (for example
            with a different model) is reused.
        k: Accepted for call-site symmetry with the retriever; not used here.

    Returns:
        - index: FAISS index
        - metas_full: list of dictionaries with "page_content" and "metadata"
        - model: SentenceTransformer model
    """
    video_dir = get_video_dir(video_id)
    index_dir = video_dir / "faiss_st"
    metas_path = index_dir / "metas_full.jsonl"  # stores content + metadata
    index_path = index_dir / "index.faiss"
    cache_dir = EMB_CACHE_ROOT / video_id

    if force_rebuild:
        print("‚ôªÔ∏è force_rebuild=True ‚Üí deleting old cache for this video...")
        shutil.rmtree(video_dir, ignore_errors=True)
        # Cached vectors are keyed only by text hash, so they must go too;
        # otherwise a changed model would silently reuse old embeddings.
        shutil.rmtree(cache_dir, ignore_errors=True)
        video_dir.mkdir(parents=True, exist_ok=True)

    index_dir.mkdir(parents=True, exist_ok=True)

    # If we do NOT already have an index and metadata, build them from scratch
    if not (index_path.exists() and metas_path.exists()):
        print("üì¶ No existing index found. Building a new one...")

        # 1. Fetch transcript
        segments = fetch_transcript(video_id)

        # 2. Group into windows
        windows = group_segments_into_windows(segments, target_window_s=TARGET_WINDOW_SECONDS)

        # 3. Turn windows into Documents
        docs = make_documents_from_windows(video_id, windows)

        # 4. Extract raw texts and embed them (with caching)
        texts = [doc.page_content for doc in docs]
        model = make_sentence_transformer_model(MODEL_NAME)
        embeddings = embed_texts_with_cache(model, texts, cache_dir=cache_dir, batch_size=BATCH_SIZE)

        # 5. Build FAISS index (inner product == cosine because we normalized embeddings)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)

        # 6. Save FAISS index to disk
        faiss.write_index(index, str(index_path))

        # 7. Save "full" metadata (page_content + metadata) to JSON Lines file
        metas_full = [{"page_content": doc.page_content, "metadata": doc.metadata} for doc in docs]
        with metas_path.open("w", encoding="utf-8") as f:
            for payload in metas_full:
                f.write(json.dumps(payload, ensure_ascii=False) + "\n")

        print("‚úÖ New index built and saved.")
        return index, metas_full, model

    # If we DO have an index and metadata, just load them
    print("‚ö° Found existing index. Loading from disk...")
    index = faiss.read_index(str(index_path))
    metas_full = [json.loads(line) for line in metas_path.read_text(encoding="utf-8").splitlines()]
    model = make_sentence_transformer_model(MODEL_NAME)

    return index, metas_full, model


# -------------------- Simple retriever class --------------------
class LCStyleRetriever:
    """
    Minimal retriever with a LangChain-like surface.

    It bundles three things:
    - a FAISS index to search,
    - the SentenceTransformer model used to embed the query,
    - the stored payloads ({"page_content": ..., "metadata": ...}), looked up
      by the positions the index search returns.

    Given a query string it hands back LangChain Document objects.
    """

    def __init__(self, index, metas_full, model, k=TOP_K):
        """
        Args:
            index: FAISS index that will be searched.
            metas_full: list of {"page_content": ..., "metadata": ...} dicts,
                addressed by the positions returned from the index search.
            model: object whose encode() turns the query into a vector.
            k: how many neighbors to request per search.
        """
        self.k = k
        self.model = model
        self.metas_full = metas_full
        self.index = index

    def get_relevant_documents(self, query):
        """
        Return the Documents stored at the positions nearest to ``query``.

        The query is encoded into a normalized float32 vector, the index is
        asked for ``self.k`` neighbors, and every returned position is mapped
        back to its stored payload.
        """
        # Encoding happens under inference_mode, as in any pure forward pass.
        with torch.inference_mode():
            encoded = self.model.encode(
                [query],
                batch_size=1,
                convert_to_numpy=True,
                normalize_embeddings=True,
            )
            query_vec = encoded.astype("float32")

        # search() yields (distances, indices); only the indices are needed.
        _, neighbor_ids = self.index.search(query_vec, self.k)

        # One query was sent, so only row 0 matters. Positions equal to -1 are
        # skipped (presumably the index's "no neighbor" marker, e.g. when it
        # holds fewer than k vectors).
        payloads = (self.metas_full[pos] for pos in neighbor_ids[0] if pos != -1)
        return [
            Document(page_content=entry["page_content"], metadata=entry["metadata"])
            for entry in payloads
        ]


# -------------------- Build retriever + LLM chain --------------------
# Obtain the FAISS index, the stored chunk payloads and the embedding model
# for VIDEO_ID, then wrap all three in a retriever.
# (Flip force_rebuild to True after changing chunking/model to get a fresh index.)
index, metas_full, st_model = build_index_and_retriever(
    VIDEO_ID, force_rebuild=False, k=TOP_K
)
retriever = LCStyleRetriever(
    index=index,
    metas_full=metas_full,
    model=st_model,
    k=TOP_K,
)

# Prompt shown to the LLM; it has two slots, {context} and {question}.
prompt = PromptTemplate(
    template=(
        "You are a concise, helpful assistant. Use ONLY the context to answer.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n\n"
        "If the answer isn't in the context, say you can't find it."
    ),
    input_variables=["context", "question"],
)

# Groq-hosted chat model used to generate the answer.
llm = ChatGroq(temperature=0.2, model="llama-3.1-8b-instant")


def format_docs(retrieved_docs):
    """
    Merge retrieved documents into the single 'context' string for the LLM.

    Each document contributes its ``page_content``; consecutive pieces are
    separated by one blank line.
    """
    separator = "\n\n"
    pieces = [item.page_content for item in retrieved_docs]
    return separator.join(pieces)


# Fan-out step. The incoming question string feeds two branches:
#   - "context":  look up matching documents, then merge them into one string
#   - "question": forwarded unchanged
# The output is a mapping with the "context" and "question" keys the prompt uses.
parallel = RunnableParallel(
    context=(
        RunnableLambda(lambda q: retriever.get_relevant_documents(q))
        | RunnableLambda(format_docs)
    ),
    question=RunnablePassthrough(),
)

# Reduce the LLM response to a plain string.
parser = StrOutputParser()

# Full pipeline:
# question → parallel → prompt → llm → parser → final answer string
main_chain = parallel | prompt | llm | parser


# -------------------- Run the chain --------------------
if __name__ == "__main__":
    # Send QUESTION through the whole RAG chain, then print the reply under a
    # banner. Using sep="\n" reproduces the output of two consecutive prints.
    answer = main_chain.invoke(QUESTION)
    print("\n=== Answer ===\n", answer, sep="\n")


📦 No existing index found. Building a new one...
✅ Transcript fetched: 4076 segments in 2.86s
🧩 Windows created: 73 (≈110s each)
📄 Document objects: 73
🧠 Loading SentenceTransformer on: cuda
🧮 Embedding 73 new texts (not in cache yet)...
✅ New index built and saved.

=== Answer ===

Yes, the topic of nuclear fusion is discussed in this video. Specifically, the conversation mentions that nuclear fusion is a technology that is not far from working, and that building nuclear fusion test plants on the moon could be a possibility.
