In [None]:
"""
csv_chunk_embed_qdrant.py

- Reads every *.csv file in `DATA_DIR`.
- Splits the configured text column (`TEXT_COLUMN`) into overlapping word chunks.
- Creates embeddings using a HuggingFace sentence-transformers model.
- Stores vectors in Qdrant with payload metadata (source_file, row_index, chunk_index, text).
- Allows interactive similarity search.
"""

import os
import math
import uuid
from typing import List, Dict, Iterable
import pandas as pd
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance, PointStruct
from tqdm import tqdm


In [None]:
# ---------- USER CONFIG ----------
DATA_DIR = "data"      # directory containing *.csv
TEXT_COLUMN = "text"       # column name to read from CSVs (change if needed)
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # HF sentence-transformers model
COLLECTION_NAME = "csv_chunks"
CHUNK_MAX_WORDS = 120      # max words per chunk
CHUNK_OVERLAP_WORDS = 20   # overlap between adjacent chunks (must be < CHUNK_MAX_WORDS)
BATCH_SIZE = 64            # embedding batch size
# Qdrant endpoint; can be overridden via environment variables without editing the notebook.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", "6333"))
# ---------------------------------

# Fail fast on settings that would otherwise hang chunking or break batching later.
assert CHUNK_MAX_WORDS > 0, "CHUNK_MAX_WORDS must be positive"
assert 0 <= CHUNK_OVERLAP_WORDS < CHUNK_MAX_WORDS, "CHUNK_OVERLAP_WORDS must be in [0, CHUNK_MAX_WORDS)"
assert BATCH_SIZE > 0, "BATCH_SIZE must be positive"
# ---------------------------------

In [None]:

def chunk_text(text: str, max_words: int = 120, overlap: int = 20) -> List[str]:
    """Split ``text`` into whitespace-delimited word chunks with overlap.

    Consecutive chunks share ``overlap`` words, so context that straddles a
    chunk boundary still appears intact in at least one chunk.

    Args:
        text: Input text. Non-string or blank input yields no chunks.
        max_words: Maximum number of words per chunk (must be > 0).
        overlap: Number of trailing words of one chunk repeated at the start
            of the next (must satisfy ``0 <= overlap < max_words``).

    Returns:
        List of chunk strings, with words re-joined by single spaces.

    Raises:
        ValueError: If ``max_words``/``overlap`` would prevent the window from
            advancing (these values previously caused an infinite loop) or
            would skip words (negative overlap).
    """
    if max_words <= 0:
        raise ValueError(f"max_words must be positive, got {max_words}")
    if not 0 <= overlap < max_words:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_words ({max_words}), got {overlap}"
        )
    if not isinstance(text, str) or not text.strip():
        return []

    words = text.split()
    step = max_words - overlap  # strictly positive thanks to the checks above
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + max_words]))
        # Stop once this window reaches the end; otherwise the next window
        # would be a pure subset of the overlap region.
        if start + max_words >= len(words):
            break
    return chunks


In [None]:

def iter_csv_rows(data_dir: str, text_column: str) -> Iterable[tuple]:
    """Yield ``(filename, row_index, text)`` for every row of every CSV in ``data_dir``.

    Files are visited in sorted filename order, so repeated runs produce the
    same row sequence. Unreadable files and files lacking ``text_column`` are
    reported and skipped rather than aborting the whole scan.

    Args:
        data_dir: Directory scanned (non-recursively) for regular files whose
            name ends in ``.csv`` (case-insensitive).
        text_column: Name of the column whose values are yielded.

    Yields:
        ``(filename, row_index, text)``, where ``filename`` is the base name,
        ``row_index`` is the row's DataFrame index label (its 0-based position
        with the default index), and ``text`` is the cell value as a string.
        Empty cells come back as ``""`` because NaN detection is disabled.
    """
    for fname in sorted(os.listdir(data_dir)):
        path = os.path.join(data_dir, fname)
        if not fname.lower().endswith(".csv") or not os.path.isfile(path):
            continue
        try:
            # Read every column as str with NaN detection off so empty cells
            # are "" instead of float('nan').
            df = pd.read_csv(path, dtype=str, keep_default_na=False)
        except Exception as e:
            # Best-effort scan: report the bad file and keep going.
            print(f"Failed to read {path}: {e}")
            continue
        if text_column not in df.columns:
            print(f"Warning: {fname} missing column '{text_column}', skipping.")
            continue
        # Iterate just the needed column; much cheaper than DataFrame.iterrows().
        for idx, text in df[text_column].items():
            yield fname, idx, text


In [None]:

def embed_batches(model: "SentenceTransformer", texts: Iterable[str], batch_size: int = 64):
    """Encode ``texts`` in consecutive fixed-size batches.

    Args:
        model: Embedding model; only its ``encode`` method is called.
        texts: Texts to embed. Consumed once and materialised into a list.
        batch_size: Number of texts per ``model.encode`` call (must be > 0).

    Yields:
        ``(start_index, batch_texts, embeddings)``. ``start_index`` is the
        offset of the batch's first text within ``texts``. ``batch_texts`` is
        that slice of texts. ``embeddings`` is whatever ``model.encode``
        returns for the slice (requested as numpy); presumably one row per
        text — confirm against the model in use.

    Raises:
        ValueError: On first iteration, if ``batch_size`` is not positive
            (a negative value would otherwise silently yield nothing).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    texts = list(texts)
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        emb = model.encode(batch, show_progress_bar=False, batch_size=len(batch), convert_to_numpy=True)
        yield start, batch, emb

In [None]:
# 1) load model
print("Loading embedding model:", MODEL_NAME)
model = SentenceTransformer(MODEL_NAME)
emb_dim = model.get_sentence_embedding_dimension()
print("Embedding dimension:", emb_dim)

# 2) connect to Qdrant
print("Connecting to Qdrant at", f"{QDRANT_HOST}:{QDRANT_PORT}")
client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, prefer_grpc=False)

# 3) create collection (if not exists)
# Use create_collection, not recreate_collection: the latter is deprecated in
# qdrant-client and drops any existing collection of the same name.
existing_collections = [c.name for c in client.get_collections().collections]

if COLLECTION_NAME not in existing_collections:
    print("Creating collection:", COLLECTION_NAME)
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=emb_dim, distance=Distance.COSINE),
    )
else:
    # NOTE(review): assumes the existing collection was built with the same
    # model/vector size; a size mismatch will presumably be rejected on upsert.
    print("Collection exists:", COLLECTION_NAME)

# 4) iterate CSV rows, chunk and prepare records
# Parallel lists: texts_for_embedding[i] is the text embedded for payload mapping[i].
texts_for_embedding = []
mapping = []

print("Reading CSVs and chunking texts...")

for fname, row_idx, text in iter_csv_rows(DATA_DIR, TEXT_COLUMN):
    for chunk_idx, chunk in enumerate(chunk_text(text, CHUNK_MAX_WORDS, CHUNK_OVERLAP_WORDS)):
        texts_for_embedding.append(chunk)
        mapping.append({
            "source_file": fname,
            "row_index": row_idx,
            "chunk_index": chunk_idx,
            "text": chunk[:1000],  # stored preview is capped at 1000 characters
        })

if not texts_for_embedding:
    print("No text chunks found. Check DATA_DIR and TEXT_COLUMN.")
else:
    print(f"Total chunks to embed: {len(texts_for_embedding)}")

    # 5) embed in batches and upsert to Qdrant
    # Point IDs are derived deterministically from (collection, file, row,
    # chunk), so re-running this cell overwrites existing points instead of
    # appending duplicates (random uuid4 IDs duplicated the index on every run).
    # Caveat: points for chunks that no longer exist (rows removed, chunk
    # settings changed) are not deleted.
    print("Embedding and uploading vectors in batches...")

    n_upserted = 0
    n_batches = math.ceil(len(texts_for_embedding) / BATCH_SIZE)

    for start_idx, batch_texts, batch_emb in tqdm(
        embed_batches(model, texts_for_embedding, BATCH_SIZE),
        total=n_batches,
        desc="Upserting",
    ):
        batch_meta = mapping[start_idx:start_idx + len(batch_texts)]
        points_batch = [
            PointStruct(
                id=str(
                    uuid.uuid5(
                        uuid.NAMESPACE_URL,
                        f"{COLLECTION_NAME}/{meta['source_file']}/"
                        f"{meta['row_index']}/{meta['chunk_index']}",
                    )
                ),
                vector=emb.tolist(),
                payload=meta,
            )
            for meta, emb in zip(batch_meta, batch_emb)
        ]

        client.upsert(
            collection_name=COLLECTION_NAME,
            points=points_batch,
        )
        n_upserted += len(points_batch)

    print(f"Finished upserting {n_upserted} vectors into '{COLLECTION_NAME}'.")


In [None]:

# 6) interactive search
SEARCH_TOP_K = 10  # number of hits shown per query

print("\nIndexing complete. You can now enter queries. Type 'exit' to quit.")

while True:
    try:
        query = input("\nEnter query: ").strip()
    except (EOFError, KeyboardInterrupt, NotImplementedError):
        # Stdin closed, kernel interrupted, or the frontend disallows stdin.
        # IPython raises StdinNotImplementedError, a NotImplementedError
        # subclass, for the last case (e.g. headless "Run All"). Leave the
        # loop instead of failing the cell.
        print("\nLeaving interactive search.")
        break

    if not query:
        continue

    if query.lower() in ("exit", "quit"):
        break

    q_emb = model.encode([query], convert_to_numpy=True)[0]

    response = client.query_points(
        collection_name=COLLECTION_NAME,
        query=q_emb.tolist(),
        limit=SEARCH_TOP_K,
    )
    hits = response.points

    if not hits:
        print("No results.")
        continue

    for rank, hit in enumerate(hits, start=1):
        payload = hit.payload or {}
        preview = (payload.get("text") or "")[:500]

        print(
            f"\n{rank}. score={hit.score:.4f} — "
            f"file={payload.get('source_file')} "
            f"row={payload.get('row_index')} "
            f"chunk={payload.get('chunk_index')}"
        )
        print(preview)
