In [1]:
from google.colab import drive, userdata
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install chromadb

Collecting chromadb
  Downloading chromadb-1.0.7-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Collecting build>=1.0.3 (from chromadb)
  Downloading build-1.2.2.post1-py3-none-any.whl.metadata (6.5 kB)
Collecting chroma-hnswlib==0.7.6 (from chromadb)
  Downloading chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (252 bytes)
Collecting fastapi==0.115.9 (from chromadb)
  Downloading fastapi-0.115.9-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn>=0.18.3 (from uvicorn[standard]>=0.18.3->chromadb)
  Downloading uvicorn-0.34.2-py3-none-any.whl.metadata (6.5 kB)
Collecting posthog>=2.4.0 (from chromadb)
  Downloading posthog-4.0.1-py2.py3-none-any.whl.metadata (3.0 kB)
Collecting onnxruntime>=1.14.1 (from chromadb)
  Downloading onnxruntime-1.21.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)
  Downloading opentelem

In [3]:
!pip install sentence_transformers

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence_transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence_transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence_transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence_transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence_transformers)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.11.0->sentence_transformers)
 

In [4]:
import os
import json
import nltk
import chromadb
import logging
import datetime
from sentence_transformers import SentenceTransformer
from tqdm import tqdm # Use notebook version for Colab
import uuid # Option for unique IDs, though combining page/chunk is better
import gc # Import garbage collector

nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)


True

In [None]:
# --- Configuration ---
BASE_DIR = '/content/drive/My Drive/SUNY_Poly_DSA598/' # Adjust as needed
WIKI_JSONL_DIR = os.path.join(BASE_DIR, 'datasets/FEVER/wiki-pages')
# Define where to store the persistent ChromaDB database
CHROMA_DB_PATH = os.path.join(BASE_DIR, 'chroma_db/fever_wiki_index_finetuned_debug') # Use a different path for debug runs if needed
# Path to your fine-tuned sBERT model
MODEL_PATH = os.path.join(BASE_DIR, 'models/sBERT/all-mpnet-base-v2_n1024_04-20_12:22_(ORCL_TEST)')
# If MODEL_PATH uses the specific name from your previous code:
# MODEL_PATH = f"{BASE_DIR}models/sBERT/all-mpnet-base-v2_n1024_04-20_12:22_(ORCL_TEST)"

COLLECTION_NAME = "fever_wiki_finetuned_sbert_debug"

# Chunking and Batching Parameters
SENTENCES_PER_CHUNK = 5       # How many sentences to group into one chunk
PROCESS_BATCH_SIZE = 64     # How many pages to process before embedding/adding to DB (adjust based on RAM)
MAX_FILES_TO_PROCESS = 1    # Set to a small number (e.g., 1) for debugging

# --- Setup Logging ---
# Logging setup remains useful even with debug prints for potential library/background messages
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Helper Function for Chunking ---
def chunk_text_by_sentences(text, sentences_per_chunk):
    """Splits text into chunks of specified number of sentences."""
    sentences = nltk.sent_tokenize(text)
    chunks = []
    current_chunk_sentences = []
    for i, sentence in enumerate(sentences):
        current_chunk_sentences.append(sentence)
        if (i + 1) % sentences_per_chunk == 0:
            chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences = []
    # Add any remaining sentences as the last chunk
    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))
    return chunks

# --- Core Processing Function (with Debug) ---
def process_page_batch(pages_batch, collection, sbert_model, sentences_per_chunk, debug=False):
    """
    Helper to chunk, embed, and add a batch of pages to ChromaDB.
    Includes debug printing.
    """
    if debug: print(f"[DEBUG] process_page_batch: Processing batch of {len(pages_batch)} pages.")
    batch_chunk_texts = []
    batch_chunk_ids = []
    batch_metadatas = []

    pages_processed_in_batch = 0
    for page_data in pages_batch:
        page_id = page_data['id']
        page_text = page_data['text']

        # Chunk the text
        chunks = chunk_text_by_sentences(page_text, sentences_per_chunk)
        if not chunks:
            logger.warning(f"Page {page_id} produced no text chunks, skipping.")
            if debug: print(f"[DEBUG] Page {page_id} produced no text chunks, skipping.")
            continue

        # Prepare data for this page's chunks
        for i, chunk_text in enumerate(chunks):
            chunk_id = f"{page_id}_chunk_{i}" # Unique ID for the chunk
            batch_chunk_texts.append(chunk_text)
            batch_chunk_ids.append(chunk_id)
            batch_metadatas.append({
                'page_id': page_id,       # Store original page title/ID
                'chunk_index': i,         # Store chunk sequence number
                'sentence_count': chunk_text.count('.') + chunk_text.count('?') + chunk_text.count('!') # Approx sentences
            })
        pages_processed_in_batch += 1

    # Embed all chunks in the batch together for efficiency
    if not batch_chunk_texts:
        if debug: print("[DEBUG] process_page_batch: No chunks generated in this batch.")
        return pages_processed_in_batch, 0 # No chunks generated in this batch

    try:
        # logger.debug(f"Embedding {len(batch_chunk_texts)} chunks...") # logger.debug might not show by default
        if debug: print(f"[DEBUG] process_page_batch: Embedding {len(batch_chunk_texts)} chunks...")
        # Convert to list of floats for ChromaDB compatibility if needed, depending on version/backend
        embeddings = sbert_model.encode(batch_chunk_texts, show_progress_bar=False).tolist()
        if debug: print(f"[DEBUG] process_page_batch: Embedding complete. Shape: {len(embeddings)} x {len(embeddings[0]) if embeddings else 0}")

        # Add batch to ChromaDB collection
        # Use upsert=True if you might re-run and want to overwrite existing IDs
        # logger.debug(f"Adding {len(batch_chunk_ids)} items to Chroma collection...")
        if debug: print(f"[DEBUG] process_page_batch: Adding {len(batch_chunk_ids)} items to Chroma collection '{collection.name}'...")
        collection.add(
            ids=batch_chunk_ids,
            embeddings=embeddings,
            metadatas=batch_metadatas,
            documents=batch_chunk_texts # Store the chunk text itself
        )
        if debug: print(f"[DEBUG] process_page_batch: Addition complete.")

        # Explicitly delete large variables to suggest garbage collection
        del embeddings
        del batch_chunk_texts
        del batch_chunk_ids
        del batch_metadatas
        gc.collect() # Suggest garbage collection
        if debug: print(f"[DEBUG] process_page_batch: Cleaned up batch variables.")

        return pages_processed_in_batch, len(chunks) * pages_processed_in_batch # Return estimate based on pages processed

    except Exception as e:
        logger.error(f"Error during embedding or adding batch to ChromaDB: {e}")
        if debug: print(f"[ERROR] process_page_batch: Error during embedding or adding batch: {e}")
        # Decide how to handle partial failures - skip batch, retry?
        return pages_processed_in_batch, 0 # Report 0 chunks added for this failed batch

# --- Main Indexing Function (with Debug) ---
def build_chroma_index_incrementally(wiki_dir, db_path, collection_name, model_path,
                                     sentences_per_chunk=5, process_batch_size=32,
                                     max_files=None, debug=False):
    """
    Loads wiki pages, chunks text, embeds, and builds a persistent ChromaDB index incrementally.
    Includes debug printing.
    """
    if debug: print(f"[DEBUG] Starting build_chroma_index_incrementally with debug mode ENABLED.")

    if not os.path.exists(wiki_dir):
        logger.error(f"Wikipedia JSONL directory not found: {wiki_dir}")
        if debug: print(f"[ERROR] Wikipedia JSONL directory not found: {wiki_dir}")
        return

    # --- Initialize ChromaDB Client (Persistent) ---
    logger.info(f"Initializing persistent ChromaDB client at: {db_path}")
    if debug: print(f"[DEBUG] Initializing persistent ChromaDB client at: {db_path}")
    os.makedirs(db_path, exist_ok=True) # Ensure directory exists
    try:
        client = chromadb.PersistentClient(path=db_path)
        if debug: print(f"[DEBUG] ChromaDB client initialized.")
    except Exception as e:
        logger.error(f"Failed to initialize ChromaDB client: {e}")
        if debug: print(f"[ERROR] Failed to initialize ChromaDB client: {e}")
        return

    # --- Load Sentence Transformer Model ---
    logger.info(f"Loading sBERT model from: {model_path}")
    if debug: print(f"[DEBUG] Loading sBERT model from: {model_path}")
    try:
        # Ensure model runs on GPU if available (implicitly handled by SentenceTransformer usually)
        sbert_model = SentenceTransformer(model_path) # Specify device if needed
        if debug: print(f"[DEBUG] sBERT model loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load Sentence Transformer model: {e}")
        if debug: print(f"[ERROR] Failed to load Sentence Transformer model: {e}")
        return

    # --- Get or Create Chroma Collection ---
    logger.info(f"Getting or creating ChromaDB collection: {collection_name}")
    if debug: print(f"[DEBUG] Getting or creating ChromaDB collection: {collection_name}")
    try:
        collection = client.get_or_create_collection(
            name=collection_name,
            # metadata={"hnsw:space": "cosine"} # Optional: Specify distance metric
        )
        initial_count = collection.count()
        logger.info(f"Collection '{collection_name}' ready. Current count: {initial_count}")
        if debug: print(f"[DEBUG] Collection '{collection_name}' ready. Current count: {initial_count}")
    except Exception as e:
        logger.error(f"Failed to get or create ChromaDB collection: {e}")
        if debug: print(f"[ERROR] Failed to get or create ChromaDB collection: {e}")
        return


    # --- Iterate through wiki files ---
    wiki_files = sorted([f for f in os.listdir(wiki_dir) if f.startswith('wiki-') and f.endswith('.jsonl')])
    if max_files is not None:
        wiki_files = wiki_files[:max_files]
    logger.info(f"Found {len(wiki_files)} wiki files to process.")
    if debug: print(f"[DEBUG] Found {len(wiki_files)} wiki files to process.")

    total_pages_processed = 0
    total_chunks_added = 0

    # Use tqdm for progress bar over files
    for filename in tqdm(wiki_files, desc="Processing Wiki Files"):
        filepath = os.path.join(wiki_dir, filename)
        logger.info(f"Processing file: {filename}")
        if debug: print(f"\n[DEBUG] Processing file: {filename} at path: {filepath}")

        pages_batch = [] # Accumulate pages before processing
        file_pages_processed = 0
        file_chunks_added = 0

        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                # Use tqdm for progress within a file (can be verbose)
                # for line_num, line in tqdm(enumerate(f), desc=f"Lines in {filename}", leave=False): # Optionally add inner progress
                f_len = sum(1 for _ in f) # Get file length
                f.seek(0) # Reset file pointer

                for line_num, line in enumerate(f):
                    try:
                        data = json.loads(line)
                        page_id = str(data.get('id', '')) # Ensure ID is string
                        page_text = data.get('text', '')

                        if not page_id or not page_text:
                            # logger.warning(f"Skipping line {line_num+1} in {filename}: Missing 'id' or 'text'.")
                            # Optional: Add debug print for skipped lines if needed
                            if debug: print(f"[DEBUG] Skipping line {line_num+1} in {filename}: Missing 'id' or 'text'.")
                            continue

                        if debug: print(f"[DEBUG] Processing line {line_num+1} of {f_len} in {filename}.")
                        pages_batch.append({'id': page_id, 'text': page_text})

                        # When batch is full, process it
                        if len(pages_batch) >= process_batch_size:
                            if debug: print(f"[DEBUG] Reached batch size ({process_batch_size}), processing batch...")
                            processed_count, chunks_count = process_page_batch(pages_batch, collection, sbert_model, sentences_per_chunk, debug=debug)
                            file_pages_processed += processed_count
                            file_chunks_added += chunks_count
                            pages_batch = [] # Reset batch
                            if debug: print(f"[DEBUG] Batch processed. File pages so far: {file_pages_processed}, File chunks so far: {file_chunks_added}")


                    except json.JSONDecodeError:
                        logger.warning(f"Skipping invalid JSON line {line_num+1} in {filename}.")
                        if debug: print(f"[WARN] Skipping invalid JSON line {line_num+1} in {filename}.")
                    except Exception as e:
                        logger.error(f"Error processing line {line_num+1} in {filename}: {e}")
                        if debug: print(f"[ERROR] Error processing line {line_num+1} in {filename}: {e}")

                # Process any remaining pages in the last batch for this file
                if pages_batch:
                    if debug: print(f"[DEBUG] Processing final batch for file {filename} (size: {len(pages_batch)})...")
                    processed_count, chunks_count = process_page_batch(pages_batch, collection, sbert_model, sentences_per_chunk, debug=debug)
                    file_pages_processed += processed_count
                    file_chunks_added += chunks_count
                    if debug: print(f"[DEBUG] Final batch processed.")


        except Exception as e:
            logger.error(f"Error reading or processing file {filename}: {e}")
            if debug: print(f"[ERROR] Error reading or processing file {filename}: {e}")

        total_pages_processed += file_pages_processed
        total_chunks_added += file_chunks_added
        logger.info(f"Finished processing {filename}. Pages in file: {file_pages_processed}, Chunks added: {file_chunks_added}. Total pages overall: {total_pages_processed}, Total chunks overall: {total_chunks_added}")
        if debug: print(f"[DEBUG] Finished processing {filename}. Pages in file: {file_pages_processed}, Chunks added: {file_chunks_added}. Cumulative pages: {total_pages_processed}, Cumulative chunks: {total_chunks_added}")
        # Optional: Persist aggressively after each file if memory is extremely tight, though PersistentClient should handle it.
        # client.persist()
        # if debug: print(f"[DEBUG] Explicitly persisted DB state after file {filename}")


    logger.info(f"--- Indexing complete ---")
    if debug: print(f"\n--- [DEBUG] Indexing complete ---")
    logger.info(f"Total Wikipedia pages processed: {total_pages_processed}")
    if debug: print(f"[DEBUG] Total Wikipedia pages processed: {total_pages_processed}")
    logger.info(f"Total text chunks added to collection '{collection_name}': {total_chunks_added}")
    if debug: print(f"[DEBUG] Total text chunks added to collection '{collection_name}': {total_chunks_added}")
    final_count = collection.count()
    logger.info(f"Final collection count: {final_count}")
    if debug: print(f"[DEBUG] Final collection count: {final_count}")
    logger.info(f"Database saved to: {db_path}")
    if debug: print(f"[DEBUG] Database saved to: {db_path}")


# --- Code to Run in Notebook Cell ---

# Make sure MODEL_PATH is correctly set before running
if "YYYY-MM-DD" in MODEL_PATH: # Basic check if placeholder is still there
     print("ERROR: Please update MODEL_PATH with the correct path to your fine-tuned model.")
     # Or raise an error: raise ValueError("MODEL_PATH needs to be updated.")
else:
    # Set debug=True for the main call
    build_chroma_index_incrementally(
        wiki_dir=WIKI_JSONL_DIR,
        db_path=CHROMA_DB_PATH,
        collection_name=COLLECTION_NAME,
        model_path=MODEL_PATH,
        sentences_per_chunk=SENTENCES_PER_CHUNK,
        process_batch_size=PROCESS_BATCH_SIZE,
        max_files=MAX_FILES_TO_PROCESS,
        debug=True # <--- Set Debug Mode Here
    )

[DEBUG] Starting build_chroma_index_incrementally with debug mode ENABLED.
[DEBUG] Initializing persistent ChromaDB client at: /content/drive/My Drive/SUNY_Poly_DSA598/chroma_db/fever_wiki_index_finetuned_debug
[DEBUG] ChromaDB client initialized.
[DEBUG] Loading sBERT model from: /content/drive/My Drive/SUNY_Poly_DSA598/models/sBERT/all-mpnet-base-v2_n1024_04-20_12:22_(ORCL_TEST)
[DEBUG] sBERT model loaded successfully.
[DEBUG] Getting or creating ChromaDB collection: fever_wiki_finetuned_sbert_debug
[DEBUG] Collection 'fever_wiki_finetuned_sbert_debug' ready. Current count: 0
[DEBUG] Found 1 wiki files to process.


Processing Wiki Files:   0%|          | 0/1 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[DEBUG] Processing line 44157 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44158 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44159 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44160 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44161 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44162 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44163 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44164 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44165 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44166 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44167 of 50000 in wiki-001.jsonl.
[DEBUG] Processing line 44168 of 50000 in wiki-001.jsonl.
[DEBUG] Reached batch size (64), processing batch...
[DEBUG] process_page_batch: Processing batch of 64 pages.
[DEBUG] process_page_batch: Embedding 71 chunks...
[DEBUG] process_page_batch: Embedding complete. Shape: 71 x 768
[DEBUG] proce