In [1]:
"""
end_to_end_weaviate_bge_m3.py

- Reads 3 JSON files: schemes (list), acts (dict or list).
- Chunks schemes and acts differently.
- Generates embeddings using BAAI/bge-m3 (Hugging Face).
- Inserts into Weaviate with metadata (vector store + BM25).
- Query example: dense retrieval (top_k) + rerank with bge-reranker-v2-m3.
- Uses tqdm for progress feedback.
"""

'\nend_to_end_weaviate_bge_m3.py\n\n- Reads 3 JSON files: schemes (list), acts (dict or list).\n- Chunks schemes and acts differently.\n- Generates embeddings using BAAI/bge-m3 (Hugging Face).\n- Inserts into Weaviate with metadata (vector store + BM25).\n- Query example: dense retrieval (top_k) + rerank with bge-reranker-v2-m3.\n- Uses tqdm for progress feedback.\n'

In [2]:
import json, math, time, os, re
from pathlib import Path
from typing import List, Dict
from tqdm.auto import tqdm
import numpy as np

# HuggingFace transformers / torch
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

# Weaviate client
import weaviate


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Enforce minimum torch version to avoid CVE-2025-32434 when calling torch.load
def _torch_at_least(major_required: int = 2, minor_required: int = 6) -> bool:
    version_match = re.match(r"(\d+)\.(\d+)", torch.__version__)
    if not version_match:
        return False
    major_num, minor_num = map(int, version_match.groups())
    return (major_num, minor_num) >= (major_required, minor_required)

if not _torch_at_least():
    raise RuntimeError(
        "Detected torch %s. Due to CVE-2025-32434, torch >= 2.6 is required "
        "for safe torch.load calls. Upgrade torch (e.g. `pip install --upgrade \"torch>=2.6.0\"`) "
        "or switch to safetensors weights." % torch.__version__
    )

In [5]:
from preprocessing import (
    build_docs_from_schemes_enhanced,
    build_docs_from_acts_enhanced
)


In [None]:
# --- Config ---
SCHEMES_FILES = ["schemes_agriculture.json", "schemes_finance.json", "schemes_education.json"]  # adapt
ACTS_FILES = ["acts_finance.json", "acts_agriculture.json"]  # adapt

# WEAVIATE_URL = "http://localhost:8080"   # update if using cloud
# WEAVIATE_API_KEY = None  # if needed

BGE_EMBED_MODEL = "BAAI/bge-m3"
BGE_RERANKER = "BAAI/bge-reranker-v2-m3"

# Embedding & batching
BATCH_SIZE = 32
EMBED_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RERANK_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Chunking parameters
SCHEME_MAX_CHARS = 1200
SCHEME_OVERLAP = 350
ACT_SECTION_THRESHOLD = 3000
ACT_MAX_CHARS = 1800
ACT_OVERLAP = 350

# Weaviate class name
WEAVIATE_CLASS = "GovDoc" # Write your cluster class name here


In [11]:
# ---------- STEP 2: Load BGE M3 model + tokenizer (embedding) ----------
def load_embedding_model(model_name=BGE_EMBED_MODEL, device=EMBED_DEVICE):
    print(f"Loading embedding model {model_name} -> {device}")
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()
    t = time.time()-t0
    print(f"Loaded embedding model in {t:.1f}s")
    return tokenizer, model, t

# Pooling function (mean pooling)
def mean_pooling(last_hidden, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
    sum_embeddings = torch.sum(last_hidden * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

def embed_texts(tokenizer, model, texts:List[str], batch_size=BATCH_SIZE, device=EMBED_DEVICE):
    embeddings = []
    with torch.no_grad():
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding batches"):
            batch = texts[i:i+batch_size]
            enc = tokenizer(batch, truncation=True, padding=True, return_tensors="pt")
            input_ids = enc['input_ids'].to(device)
            attention_mask = enc['attention_mask'].to(device)
            out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            # pooling strategy: mean pooling over token embeddings
            last_hidden = out.last_hidden_state
            pooled = mean_pooling(last_hidden, attention_mask)  # (B, D)
            pooled = pooled.cpu().numpy()
            embeddings.append(pooled)
    return np.vstack(embeddings)

In [12]:
# ---------- STEP 3: Create Weaviate schema and upsert ----------
def create_weaviate_client(url=WEAVIATE_URL, api_key=WEAVIATE_API_KEY):
    if api_key:
        client = weaviate.Client(url=url, additional_headers={"X-API-KEY": api_key})
    else:
        client = weaviate.Client(url=url)
    return client

def ensure_schema(client: weaviate.Client, class_name=WEAVIATE_CLASS, vector_dim=None):
    # Create class with BM25 & vectorizer none (we'll pass vectors explicitly)
    schema = {
        "class": class_name,
        "vectorizer": "none",  # we send vectors explicitly
        "properties": [
            {"name":"text", "dataType":["text"]},
            {"name":"doc_type", "dataType":["string"]},
            {"name":"source_file","dataType":["string"]},
            {"name":"doc_id","dataType":["string"]},
            {"name":"chunk_id","dataType":["string"]},
            {"name":"metadata","dataType":["text"]}
        ]
    }
    existing = client.schema.get()
    classes = [c['class'] for c in existing.get('classes', [])]
    if class_name in classes:
        print("Weaviate class exists - deleting and recreating (update as needed)")
        client.schema.delete_class(class_name)
    client.schema.create_class(schema)
    print("Weaviate schema created.")

def upsert_weaviate(client, class_name, docs, embeddings, batch_size=64):
    # docs: list of docs with .text and .metadata
    assert len(docs) == embeddings.shape[0]
    with client.batch as batch:
        batch.batch_size = batch_size
        for i in tqdm(range(len(docs)), desc="Weaviate upsert"):
            d = docs[i]
            ev = embeddings[i].tolist()
            meta_json = json.dumps(d["metadata"], ensure_ascii=False)
            obj = {
                "text": d["text"],
                "doc_type": d["doc_type"],
                "source_file": d["source_file"],
                "doc_id": d["doc_id"],
                "chunk_id": d["chunk_id"],
                "metadata": meta_json
            }
            batch.add_data_object(obj, class_name, vector=ev)

In [13]:
# ---------- STEP 4: Retrieval & rerank ----------
def weaviate_hybrid_search(client, class_name, query, top_k=50, alpha=0.5):
    """
    alpha: 0-> pure BM25, 1-> pure vector. Weaviate uses 'hybrid' operator with alpha weighting.
    """
    near_text = {
        "concepts": [query]
    }
    # python client query
    res = client.query.get(class_name, ["text", "doc_id", "chunk_id", "metadata"]).with_hybrid(query, alpha).with_limit(top_k).do()
    # parse results
    objs = []
    try:
        hits = res['data']['Get'][class_name]
        for h in hits:
            vec_score = h.get('_additional', {}).get('score')  # vector score or fused score
            objs.append({
                "text": h['text'],
                "doc_id": h['doc_id'],
                "chunk_id": h['chunk_id'],
                "metadata": json.loads(h['metadata']),
                "score": vec_score
            })
    except Exception as e:
        print("Query returned unexpected shape:", e)
        return []
    return objs

def load_reranker(model_name=BGE_RERANKER, device=RERANK_DEVICE):
    print(f"Loading reranker {model_name} -> {device}")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.to(device)
    model.eval()
    t = time.time() - t0
    print(f"Loaded reranker in {t:.1f}s")
    return tok, model, t

def rerank_with_model(tokenizer, model, query, candidates, device=RERANK_DEVICE, batch_size=32):
    """
    candidates: list[str] texts
    returns scores aligned with candidates
    """
    scores = []
    with torch.no_grad():
        for i in range(0, len(candidates), batch_size):
            batch = candidates[i:i+batch_size]
            enc = tokenizer([query]*len(batch), batch, truncation=True, padding=True, return_tensors="pt")
            enc = {k:v.to(device) for k,v in enc.items()}
            out = model(**enc)
            logits = out.logits.squeeze(-1).cpu().numpy()  # shape (B,)
            # If logits are unbounded, optionally pass through sigmoid to get 0-1 score
            scores.extend(logits.tolist())
    return scores

In [None]:
print("Building docs from schemes and acts...")
schemes_docs = build_docs_from_schemes_enhanced(SCHEMES_FILES)
acts_docs = build_docs_from_acts_enhanced(ACTS_FILES)
docs = schemes_docs + acts_docs
print(f"Built {len(docs)} chunks (schemes: {len(schemes_docs)}, acts: {len(acts_docs)})")

In [None]:
# 2. load embedding model
tok, emb_model, emb_load_time = load_embedding_model(BGE_EMBED_MODEL, device=EMBED_DEVICE)

# 3. embed texts
texts = [d['text'] for d in docs]
print("Generating embeddings (this may take a while)...")
t0 = time.time()
embeddings = embed_texts(tok, emb_model, texts, batch_size=BATCH_SIZE, device=EMBED_DEVICE)
t_emb = time.time() - t0
print(f"Embedding completed in {t_emb:.1f}s for {len(texts)} chunks -> {embeddings.shape}")

In [18]:
import weaviate
from weaviate.connect import ConnectionParams
from weaviate import AuthApiKey



In [None]:
# 4-5. Connect to Weaviate v4 and upsert precomputed embeddings

import os
import json
import uuid
import numpy as np
from tqdm.auto import tqdm

import weaviate
from weaviate.classes.init import Auth
from weaviate.classes.config import Property, DataType, Configure

# Ensure these variables are set
WEAVIATE_API_KEY = "keep your api key"
WEAVIATE_REST_URL = "keep your rest url"
WEAVIATE_GRPC_URL = "keep your grpc url"
WEAVIATE_URL = "keep your weaviate url"
WEAVIATE_COLLECTION = "GovDocs"  # your class / collection name
VECTOR_DIM = 1024   # BGE-m3 output




# ---------------------------------------------------
# 1. CONNECT TO WEAVIATE CLOUD
# ---------------------------------------------------
client = weaviate.connect_to_weaviate_cloud(
    cluster_url=WEAVIATE_URL,
    auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
)

print("Connected:", client.is_ready())


# ---------------------------------------------------
# 2. CREATE COLLECTION IF NOT EXISTS
# ---------------------------------------------------
existing = client.collections.list_all()

if WEAVIATE_COLLECTION not in [c.name for c in existing]:
    print(f"Creating collection '{WEAVIATE_COLLECTION}' ...")

    client.collections.create(
        name=WEAVIATE_COLLECTION,
        vectorizer_config=Configure.Vectorizer.none(),  # we supply vectors
        properties=[
            Property(name="text", data_type=DataType.TEXT),
            Property(name="preview", data_type=DataType.TEXT),
            Property(name="doc_type", data_type=DataType.TEXT),
            Property(name="doc_id", data_type=DataType.TEXT),
            Property(name="chunk_id", data_type=DataType.TEXT),
            Property(name="metadata_json", data_type=DataType.TEXT),
        ]
    )
else:
    print(f"Collection '{WEAVIATE_COLLECTION}' already exists.")


collection = client.collections.get(WEAVIATE_COLLECTION)


# ---------------------------------------------------
# 3. PREP EMBEDDINGS
# ---------------------------------------------------
embeddings = np.asarray(embeddings).astype(float)
assert len(docs) == embeddings.shape[0], "Mismatch between docs and embedding count"


# ---------------------------------------------------
# 4. UPSERT (INSERT) IN BATCHES WITH TQDM
# ---------------------------------------------------
BATCH_SIZE = 64
print(f"Uploading {len(docs)} items to Weaviate…")

with collection.batch.fixed_size(BATCH_SIZE) as batch:
    for i in tqdm(range(len(docs)), desc="Weaviate insert"):
        d = docs[i]
        vec = embeddings[i].tolist()

        properties = {
            "text": d["text"],
            "preview": d.get("preview", ""),
            "doc_type": d.get("doc_type", ""),
            "doc_id": d.get("doc_id", ""),
            "chunk_id": d.get("chunk_id", ""),
            "metadata_json": json.dumps(d.get("metadata", {}), ensure_ascii=False),
        }

        batch.add_object(
            properties=properties,
            uuid=str(uuid.uuid4()),
            vector=vec,
        )

        if batch.number_errors > 0:
            print("Errors:", batch.get_errors())
            raise RuntimeError("Batch insert failed")

print("✅ Upload complete!")

Connected: True
Creating collection 'GovDocs' ...


            Use the `vector_config` argument instead.
            


Uploading 13321 items to Weaviate…


Weaviate insert: 100%|██████████| 13321/13321 [00:21<00:00, 609.38it/s]



✅ Upload complete!


In [27]:
# 6. load reranker
rer_tok, rer_model, rer_load_time = load_reranker(BGE_RERANKER, device=RERANK_DEVICE)

Loading reranker BAAI/bge-reranker-v2-m3 -> cuda
Loaded reranker in 24.8s
Loaded reranker in 24.8s


In [None]:
def retrieve_hybrid_v4(client, collection_name, query, query_embedding, top_k=50, alpha=0.3):
    collection = client.collections.get(collection_name)

    result = collection.query.hybrid(
        query=query,
        vector=query_embedding,
        alpha=alpha,
        limit=top_k,
        return_properties=["text", "doc_id", "chunk_id", "preview", "metadata_json", "doc_type" ],
        include_vector=False
    )

    docs = []
    for obj in result.objects:
        score = obj.metadata.score
        if score is None:
            score = 0.0

        docs.append({
            "text": obj.properties.get("text", ""),
            "doc_id": obj.properties.get("doc_id", ""),
            "chunk_id": obj.properties.get("chunk_id", ""),
            "preview": obj.properties.get("preview", ""),
            "doc_type": obj.properties.get("doc_type", ""),
            "metadata": json.loads(obj.properties.get("metadata_json", "{}")),
            "hybrid_score": float(score)
        })
    return docs


In [56]:
def rerank_candidates(query, candidates, rer_tok, rer_model, device="cuda", batch_size=16):
    """
    Given retrieved candidates, reranks them using BGE reranker.
    Adds 'rerank_score' to each candidate.
    """
    texts = [c["text"] for c in candidates]

    # your existing function
    scores = rerank_with_model(
        rer_tok,
        rer_model,
        query,
        texts,
        device=device,
        batch_size=batch_size
    )

    for i in range(len(candidates)):
        candidates[i]["rerank_score"] = float(scores[i])

    return candidates


In [62]:
def retrieve_and_rerank(query, client, collection_name, embed_tok, embed_model, rer_tok, rer_model, device="cuda"):
    # 1) Embed query
    q_emb = embed_texts(embed_tok, embed_model, [query], batch_size=1, device=device)[0]
    q_emb = q_emb.astype(np.float32)

    # 2) Hybrid retrieval
    retrieved = retrieve_hybrid_v4(
        client,
        collection_name,
        query,
        q_emb,
        top_k=50,
        alpha=0.5
    )
    print(f"Hybrid retrieved: {len(retrieved)} docs")

    # 3) Rerank with BGE
    reranked = rerank_candidates(query, retrieved, rer_tok, rer_model, device=device)

    # 4) Sort by rerank score (descending)
    reranked_sorted = sorted(reranked, key=lambda x: x["rerank_score"], reverse=True)

    return reranked_sorted


In [63]:
query = "How does the IGST Act treat input tax credit on goods brought into one State but used in another?"

results = retrieve_and_rerank(
    query=query,
    client=client,
    collection_name=WEAVIATE_COLLECTION,
    embed_tok=tok,
    embed_model=emb_model,
    rer_tok=rer_tok,
    rer_model=rer_model,
    device=EMBED_DEVICE
)

print("\nTop 10 results:\n")

for r in results[:10]:
    hybrid = r.get("hybrid_score", 0.0) or 0.0
    rerank = r.get("rerank_score", 0.0) or 0.0

    print("Hybrid:", round(hybrid, 4),
          "| Rerank:", round(rerank, 4),
          "| Doc:", r["doc_id"], "Chunk:", r["chunk_id"])
    print(r["text"][:300].replace("\n", " "), "...\n---\n")



Embedding batches: 100%|██████████| 1/1 [00:00<00:00, 61.05it/s]



Hybrid retrieved: 50 docs

Top 10 results:

Hybrid: 0.0 | Rerank: 0.8097 | Doc: act_3 Chunk: Definitions_10
l tax, State tax, integrated tax or Union territory tax charged on any supply of goods or services or both made to him and includes—\n(a) the integrated goods and services tax charged on import of goods;\n(b) the tax payable under the provisions of sub-sections (3) and (4) of section 9;\n(c) the tax ...
---

Hybrid: 0.0 | Rerank: 0.5155 | Doc: act_3 Chunk: CHAPTER X - PAYMENT OF TAX_3
ee or any other amount available in the electronic cash ledger under this Act, to the electronic cash ledger for,–—\n(a) integrated tax, central tax, State tax, Union territory tax or cess; or\n(b) integrated tax or central tax of a distinct person as specified in sub-section (4) or, as the case may ...
---

Hybrid: 0.0 | Rerank: 0.382 | Doc: act_5 Chunk: CHAPTER IV - PAYMENT OF TAX_1
nner of utilisation of the input tax credit** on account of integrated tax, Central tax, State tax or Union territor

In [64]:
def inspect_results(results, top_k=35):
    print("\n===== UNIQUE DOCS RETRIEVED =====")
    unique = {}
    for r in results:
        uid = f"{r['doc_id']}::{r['chunk_id']}"
        if uid not in unique:
            unique[uid] = r

    print(f"Total retrieved: {len(results)}")
    print(f"Unique chunks: {len(unique)}\n")

    print("===== TOP RESULTS (sorted by rerank_score) =====")
    sorted_items = sorted(unique.values(), key=lambda x: x['rerank_score'], reverse=True)

    for r in sorted_items[:top_k]:
        print("\n---")
        print("Doc:", r["doc_id"], "| Chunk:", r["chunk_id"])
        print("Hybrid Score:", r["hybrid_score"])
        print("Rerank Score:", r["rerank_score"])
        print("Preview:", r["text"])
        print("---")
inspect_results(results, top_k=20)



===== UNIQUE DOCS RETRIEVED =====
Total retrieved: 50
Unique chunks: 50

===== TOP RESULTS (sorted by rerank_score) =====

---
Doc: act_3 | Chunk: Definitions_10
Hybrid Score: 0.0
Rerank Score: 0.8096749782562256
Preview: l tax, State tax, integrated tax or Union territory tax charged on any supply of goods or services or both made to him and includes—\n(a) the integrated goods and services tax charged on import of goods;\n(b) the tax payable under the provisions of sub-sections (3) and (4) of section 9;\n(c) the tax payable under the provisions of sub-sections (3) and (4) of section 5 of the Integrated Goods and Services Tax Act;\n(d) the tax payable under the provisions of sub-sections (3) and (4) of section 9 of the respective State Goods and Services Tax Act; or\n(e) the tax payable under the provisions of sub-sections (3) and (4) of section 7 of the Union Territory Goods and Services Tax Act,\nbut does not include the tax paid under the composition levy;', 'input tax credit': '(6