In [16]:
import os, json, uuid, pathlib  #step 1
import pyarrow as pa
import pyarrow.csv as pacsv
import pandas as pd
from sentence_transformers import SentenceTransformer
import chromadb

# ---- Storage locations and names ----
# Input CSV read by the loading cell (point this at your real CSV if different).
DATA_CSV = "synthetic_retail_data.csv"
# Folder handed to the persistent ChromaDB client.
CHROMA_PATH = "chromadb_store"
# Name of the Chroma collection the row chunks are stored in.
COLLECTION_NAME = "retail_chunks_arrow"
# File the JSON lookup index is written to and read back from.
JSON_INDEX_PATH = "arrow_index.json"

# ---- Synthetic-data switches ----
# Flip GENERATE_SYNTHETIC to True to regenerate synthetic data quickly.
# NOTE(review): neither switch is read by the cells shown here; N_ROWS_SYN is
# presumably the number of synthetic rows to generate — confirm.
GENERATE_SYNTHETIC = False
N_ROWS_SYN = 200


In [17]:
# Load the CSV through PyArrow's CSV reader and let it infer column types #step2
read_options = pacsv.ReadOptions(autogenerate_column_names=False)
parse_options = pacsv.ParseOptions(delimiter=",")
convert_options = pacsv.ConvertOptions()  # defaults; the inferred schema is printed below

table: pa.Table = pacsv.read_csv(
    DATA_CSV,
    read_options=read_options,
    parse_options=parse_options,
    convert_options=convert_options,
)

print("Schema:", table.schema)
print("Rows:", table.num_rows)

# ---- Assign roles ----
# "Categorical" columns are rendered into the chunk text that gets embedded;
# "numeric" columns are carried along as metadata.
# Edit the preferred_* lists below to match your real CSV.

all_cols = table.schema.names

# Listing of numeric Arrow types, kept for reference: the detection below
# relies on the pa.types predicates and never reads this tuple.
numeric_types = (
    pa.int8(), pa.int16(), pa.int32(), pa.int64(),
    pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64(),
    pa.float16(), pa.float32(), pa.float64(),
)

# Auto-detection: integer or floating columns are numeric ...
numeric_cols = [
    name
    for name in all_cols
    if any(
        is_kind(table.schema.field(name).type)
        for is_kind in (pa.types.is_integer, pa.types.is_floating)
    )
]
# ... and every remaining column (strings, dates, timestamps, ...) is categorical.
categorical_cols = [name for name in all_cols if name not in numeric_cols]

# Explicit split recommended for this dataset.
preferred_categorical = ["InvoiceNo", "StockCode", "Description", "InvoiceDate", "Country"]
preferred_numeric = ["Quantity", "UnitPrice", "CustomerID"]

# Preferred names that actually exist in the CSV win; when none of them
# exist, the auto-detected list is kept unchanged.
_present_categorical = [c for c in preferred_categorical if c in all_cols]
_present_numeric = [c for c in preferred_numeric if c in all_cols]
if _present_categorical:
    categorical_cols = _present_categorical
if _present_numeric:
    numeric_cols = _present_numeric

print("Categorical (chunk text):", categorical_cols)
print("Numeric (metadata):", numeric_cols)


Schema: InvoiceNo: string
StockCode: string
Description: string
Quantity: int64
InvoiceDate: date32[day]
UnitPrice: double
CustomerID: int64
Country: string
Rows: 200
Categorical (chunk text): ['InvoiceNo', 'StockCode', 'Description', 'InvoiceDate', 'Country']
Numeric (metadata): ['Quantity', 'UnitPrice', 'CustomerID']


In [18]:
def to_python_value(arr, i):
    """Return element ``i`` of ``arr`` as a native Python object.

    The element is fetched by indexing and converted through its ``as_py()``
    method, which is how PyArrow scalars expose their Python value.
    """
    scalar = arr[i]
    return scalar.as_py()

def build_row_dict(batch, row_idx_in_batch):
    """Collect one row of ``batch`` into a ``{column name: Python value}`` dict.

    Every column listed in ``batch.schema.names`` contributes one entry, whose
    value is produced by ``to_python_value`` for the requested row index.
    """
    return {
        name: to_python_value(batch.column(name), row_idx_in_batch)
        for name in batch.schema.names
    }

def make_chunk_text(row_dict, cat_cols):
    """Render the selected columns of a row as one ``"name: value"`` string.

    Columns are emitted in the order given by ``cat_cols`` and joined with
    ``", "``. A column missing from ``row_dict`` is rendered with an empty
    value; values are formatted with their default string conversion.
    """
    pieces = []
    for col in cat_cols:
        value = row_dict.get(col, '')
        pieces.append(f"{col}: {value}")
    return ", ".join(pieces)

# Parallel lists, one entry per CSV row: embedded text, Chroma metadata, and ID.
chunks = []
metadatas = []
id_list = []
row_counter = 0

# Categorical fields that are also copied into the metadata so they can be
# filtered on / displayed without parsing the chunk text.
categorical_meta_keys = ("Country", "Description", "InvoiceDate", "InvoiceNo", "StockCode")

for batch in table.to_batches(max_chunksize=2048):
    # Wrap the record batch in a Table so columns can be fetched by name.
    batch = pa.Table.from_batches([batch])
    for i in range(batch.num_rows):
        row = build_row_dict(batch, i)

        # Chunk text is built from the categorical fields only.
        text = make_chunk_text(row, categorical_cols)

        # Numeric fields go into the metadata. Null cells are left out
        # entirely: a None metadata value is not a str/int/float/bool and can
        # be rejected by Chroma's metadata validation.
        meta = {k: row[k] for k in numeric_cols if row[k] is not None}

        for keep in categorical_meta_keys:
            val = row.get(keep)
            if val is None:
                # Column absent from the CSV or null in this row: omit the key.
                continue
            if not isinstance(val, (int, float, bool)):
                # Dates, timestamps and every other non-primitive → string.
                val = str(val)
            meta[keep] = val

        # The ID is the running row position, so it stays stable across
        # re-runs on the same CSV.
        id_list.append(str(row_counter))
        chunks.append(text)
        metadatas.append(meta)
        row_counter += 1

print(f"Built {len(chunks)} chunks. Example:\n", chunks[:2], "\nMetadata example:\n", metadatas[:2])



Built 200 chunks. Example:
 ['InvoiceNo: INV1000, StockCode: STK588, Description: Shoes, InvoiceDate: 2023-01-01, Country: Germany', 'InvoiceNo: INV1001, StockCode: STK335, Description: Laptop, InvoiceDate: 2023-01-02, Country: USA'] 
Metadata example:
 [{'Quantity': 3, 'UnitPrice': 306.04, 'CustomerID': 11745, 'Country': 'Germany', 'Description': 'Shoes', 'InvoiceDate': '2023-01-01', 'InvoiceNo': 'INV1000', 'StockCode': 'STK588'}, {'Quantity': 10, 'UnitPrice': 63.06, 'CustomerID': 12621, 'Country': 'USA', 'Description': 'Laptop', 'InvoiceDate': '2023-01-02', 'InvoiceNo': 'INV1001', 'StockCode': 'STK335'}]


In [19]:
# Load embedding model (fast & small)
model = SentenceTransformer("all-MiniLM-L6-v2")

# Encode all chunk texts in batches
embeddings = model.encode(chunks, batch_size=128, show_progress_bar=True)
# Guard the dimension lookup so an empty chunk list does not raise IndexError.
print("Embeddings shape:", len(embeddings), "x", len(embeddings[0]) if len(embeddings) else 0)

# Connect to persistent ChromaDB
client = chromadb.PersistentClient(path=CHROMA_PATH)
collection = client.get_or_create_collection(COLLECTION_NAME)

# Batch size shared by the delete and add loops below (avoids RAM spikes and
# oversized single requests).
BATCH = 1000

# Clear previous data so re-running this cell does not collide with existing
# IDs. include=[] fetches only the IDs instead of pulling every stored
# document and metadata record into memory.
if collection.count() > 0:
    stale_ids = collection.get(include=[]).get("ids", [])
    for start in range(0, len(stale_ids), BATCH):
        collection.delete(ids=stale_ids[start:start + BATCH])

# Add the new rows in manageable batches
for start in range(0, len(chunks), BATCH):
    end = start + BATCH
    collection.add(
        ids=id_list[start:end],
        documents=chunks[start:end],
        embeddings=embeddings[start:end],
        metadatas=metadatas[start:end]
    )

print("✅ Stored in ChromaDB:", COLLECTION_NAME, "Count:", collection.count())


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Embeddings shape: 200 x 384
✅ Stored in ChromaDB: retail_chunks_arrow Count: 200


In [20]:
# Assemble the JSON lookup index: provenance, column roles, a per-ID mapping
# and a few summary stats.
index = {
    "version": "1.0",
    "source_csv": str(DATA_CSV),
    "collection": COLLECTION_NAME,
    "field_roles": {
        "categorical": categorical_cols,
        "numeric_meta": numeric_cols,
    },
    # id -> minimal pointer info, so common fields can be fetched without
    # querying Chroma: InvoiceNo (used as the primary key when available)
    # plus Description and Country.
    "id_mapping": {
        row_id: {
            field: row_meta.get(field)
            for field in ("InvoiceNo", "Description", "Country")
        }
        for row_id, row_meta in zip(id_list, metadatas)
    },
    "stats": {
        "rows": len(id_list),
        "embedding_model": "all-MiniLM-L6-v2",
        "embedding_dim": int(len(embeddings[0])) if len(embeddings) else 0,
    },
}

# Serialise as pretty-printed UTF-8 JSON, keeping non-ASCII characters as-is.
with open(JSON_INDEX_PATH, "w", encoding="utf-8") as out_file:
    out_file.write(json.dumps(index, ensure_ascii=False, indent=2))

print(f"✅ Wrote JSON index → {JSON_INDEX_PATH}")


✅ Wrote JSON index → arrow_index.json


In [21]:
def search(query_text, k=5):
    """Embed ``query_text`` and query the Chroma collection for ``k`` results.

    Uses the notebook-level ``model`` to encode the query and the
    notebook-level ``collection`` to run the lookup.

    Args:
        query_text: Free-text query to embed.
        k: Number of results requested from the collection.

    Returns:
        The result object from ``collection.query``, with documents,
        metadatas and distances included.
    """
    return collection.query(
        query_embeddings=model.encode([query_text]),
        n_results=k,
        # "ids" is deliberately not listed in include
        include=["documents", "metadatas", "distances"],
    )

# Run a sample query and show each hit with its distance.
demo = search("cheap laptop", k=5)

for doc, meta, dist in zip(demo["documents"][0], demo["metadatas"][0], demo["distances"][0]):
    print(f"\n(distance: {dist:.4f})")
    print("Doc:", doc)
    print("Meta:", meta)

# Load the JSON index so hits can be mapped back through our own IDs
with open(JSON_INDEX_PATH, "r", encoding="utf-8") as f:
    jindex = json.load(f)

# Example: look up the FIRST RETURNED hit in the JSON index. The key is the
# hit's own ID taken from the query result (not a fixed "0"), so the index
# entry printed here always belongs to the top document shown above.
hit_ids = demo.get("ids") or [[]]
if hit_ids[0]:
    first_id = hit_ids[0][0]
    # JSON index maps IDs → fields
    print("\n🔎 From JSON index:", jindex["id_mapping"].get(first_id))



(distance: 1.2907)
Doc: InvoiceNo: INV1077, StockCode: STK462, Description: Laptop, InvoiceDate: 2023-03-19, Country: India
Meta: {'UnitPrice': 157.66, 'StockCode': 'STK462', 'Quantity': 5, 'InvoiceDate': '2023-03-19', 'Country': 'India', 'CustomerID': 11308, 'InvoiceNo': 'INV1077', 'Description': 'Laptop'}

(distance: 1.3074)
Doc: InvoiceNo: INV1062, StockCode: STK749, Description: Laptop, InvoiceDate: 2023-03-04, Country: India
Meta: {'CustomerID': 15547, 'StockCode': 'STK749', 'Description': 'Laptop', 'Quantity': 5, 'InvoiceDate': '2023-03-04', 'UnitPrice': 157.69, 'InvoiceNo': 'INV1062', 'Country': 'India'}

(distance: 1.3128)
Doc: InvoiceNo: INV1008, StockCode: STK147, Description: Laptop, InvoiceDate: 2023-01-09, Country: Germany
Meta: {'CustomerID': 14914, 'StockCode': 'STK147', 'InvoiceNo': 'INV1008', 'Description': 'Laptop', 'InvoiceDate': '2023-01-09', 'Quantity': 1, 'UnitPrice': 389.69, 'Country': 'Germany'}

(distance: 1.3145)
Doc: InvoiceNo: INV1196, StockCode: STK952, De

In [None]:
# Evaluate recall@k for a few keywords based on Description membership
def recall_at_k(query, keyword, k=10):
    """Measure how many keyword-matching rows appear in the top-``k`` results.

    Ground truth is every row of the notebook-level ``table`` whose
    ``Description`` contains ``keyword`` (case-insensitive). Row positions are
    used as IDs, matching the positional IDs assigned when the chunks were
    built. Null descriptions never count as a match.

    Because at most ``k`` rows are retrieved, the score cannot exceed
    ``k / len(ground truth)``.

    Args:
        query: Text passed to ``search``.
        keyword: Substring that defines the ground-truth rows.
        k: Number of results to retrieve.

    Returns:
        ``(recall, hits, gt_ids)``: the fraction of ground-truth IDs that were
        retrieved, the list of retrieved ground-truth IDs, and the full list
        of ground-truth IDs. Returns ``(0.0, [], [])`` when no row matches
        the keyword.
    """
    needle = keyword.lower()

    # Only the Description column is needed; reading it directly avoids
    # converting the whole table to pandas on every call.
    descriptions = table.column("Description").to_pylist()

    # Skip nulls explicitly: stringifying them would yield "None", which could
    # falsely match keywords such as "on" or "none".
    gt_ids = [
        str(pos)
        for pos, text in enumerate(descriptions)
        if text is not None and needle in str(text).lower()
    ]

    # Nothing to recall: return before spending a search on it.
    if not gt_ids:
        return 0.0, [], []

    res = search(query, k=k)
    retrieved = set(res["ids"][0]) if "ids" in res else set()

    overlap = retrieved & set(gt_ids)
    recall = len(overlap) / len(gt_ids)
    return recall, list(overlap), gt_ids


# Spot-check recall@10 on a few (query, keyword) pairs
for q, kw in (
    ("laptop", "laptop"),
    ("cheap shoes", "shoes"),
    ("book", "book"),
):
    r, hit, gt = recall_at_k(q, kw, k=10)
    report = f"\nQuery='{q}'  keyword='{kw}'  Recall@10={r:.2f}  Hits={hit[:10]}  GT_count={len(gt)}"
    print(report)




Query='laptop'  keyword='laptop'  Recall@10=0.36  Hits=['127', '33', '19', '62', '77', '116', '196', '166', '21', '190']  GT_count=28

Query='cheap shoes'  keyword='shoes'  Recall@10=0.30  Hits=['92', '55', '199', '78', '152', '173', '142', '90', '144', '167']  GT_count=33

Query='book'  keyword='book'  Recall@10=0.32  Hits=['82', '74', '66', '194', '102', '157', '27', '162', '26', '137']  GT_count=31
