In [2]:
import numpy as np
from sklearn.cluster import KMeans
import numpy as np
from nanopq import PQ
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
from tqdm import tqdm



  from tqdm.autonotebook import tqdm, trange


In [3]:
def residual_quantization(doc_embeddings, num_levels, cluster_num, random_state=42):
    """
    Encode each document embedding as a sequence of codebook indices using
    residual quantization (RQ).

    At every level a KMeans codebook is fit on the current residuals. Each
    residual is then assigned to its nearest codeword (Euclidean distance), and
    that codeword is subtracted to form the residual for the next level.

    Args:
        doc_embeddings: numpy array [num_docs, embedding_dim]
        num_levels: how many quantization levels (tokens per docID)
        cluster_num: number of clusters per level (codebook size); KMeans
            needs at least cluster_num documents to fit
        random_state: seed passed to KMeans at every level (defaults to 42,
            the value previously hard-coded)
    Returns:
        (docid_to_rqcode, codebooks):
            docid_to_rqcode: list of token lists, one per doc, each holding
                num_levels Python ints in [0, cluster_num)
            codebooks: list of num_levels arrays [cluster_num, embedding_dim],
                the KMeans centroids fit at each level
    """
    residuals = np.asarray(doc_embeddings)
    num_docs = len(residuals)
    codebooks = []
    level_codes = []  # one [num_docs] array of codeword indices per level

    for level in range(num_levels):
        print(f"Training level {level+1}/{num_levels}...")

        # Step 1: Train codebook using KMeans on the current residuals
        kmeans = KMeans(n_clusters=cluster_num, random_state=random_state)
        kmeans.fit(residuals)
        codebook = kmeans.cluster_centers_
        codebooks.append(codebook)

        # Step 2: assign every doc to its nearest codeword in one vectorized
        # call (same Euclidean nearest-centroid criterion as a per-doc argmin)
        nearest = kmeans.predict(residuals)
        level_codes.append(nearest)

        # Residual for the next level: subtract each doc's chosen codeword
        residuals = residuals - codebook[nearest]

    if not level_codes:
        # num_levels == 0: every doc gets an empty code
        return [[] for _ in range(num_docs)], codebooks

    # [num_docs, num_levels] -> nested lists of plain Python ints
    docid_to_rqcode = np.stack(level_codes, axis=1).astype(int).tolist()
    return docid_to_rqcode, codebooks


In [15]:
# 1. Toy corpus: docID -> passage text (insertion order defines doc order below)
passages = dict(
    doc1="Artificial intelligence is transforming industries.",
    doc2="Machine learning helps computers learn from data.",
    doc3="Quantum computing is a new paradigm.",
    doc4="The future of AI includes ethical challenges.",
)



In [16]:
# 2. Load the GTR-T5-Base sentence encoder and embed every passage.
#    docids and texts are taken from the same dict, so their orders line up.
model = SentenceTransformer("sentence-transformers/gtr-t5-base")
docids, texts = list(passages), list(passages.values())
embeddings = model.encode(texts, convert_to_numpy=True)

In [17]:
# Quick sanity check of the encoder output: shape, dtype, and the (numpy-truncated) values
embeddings_summary = (embeddings.shape, embeddings.dtype, embeddings)
embeddings_summary

((4, 768),
 dtype('float32'),
 array([[ 0.00273099, -0.0302944 ,  0.05884   , ...,  0.03470925,
         -0.01520736,  0.03500362],
        [-0.02251787, -0.05412124,  0.05713512, ...,  0.00543885,
          0.00965429, -0.01286693],
        [ 0.01178704, -0.01564313,  0.01263544, ..., -0.0136919 ,
         -0.00466294,  0.01595624],
        [ 0.01346409, -0.00766769,  0.01716973, ..., -0.00420209,
         -0.02281265,  0.05193746]], dtype=float32))

In [20]:
# RQ settings for this toy example. Real corpora typically use 256-entry
# codebooks; 2 is used here because there are only 4 documents.
num_levels = 3
cluster_num = 2

rq_codes, codebooks = residual_quantization(embeddings, num_levels, cluster_num)

# Print the Semantic ID (one codeword index per level) of every document
for doc_name, semantic_id in zip(docids, rq_codes):
    print(f"{doc_name}: Semantic ID = {semantic_id}")


Training level 1/3...
Training level 2/3...
Training level 3/3...
doc1: Semantic ID = [1, 0, 0]
doc2: Semantic ID = [0, 0, 0]
doc3: Semantic ID = [1, 0, 1]
doc4: Semantic ID = [1, 1, 0]


In [21]:
# Map each RQ code to pseudo-token IDs placed after the regular vocabulary.
# Each level owns a contiguous slice of `level_offset` IDs, so the same
# codeword index at different levels never maps to the same token.
mock_vocab_size = 32128
level_offset = 256  # IDs reserved per level; must be >= cluster_num
assert cluster_num <= level_offset, (
    f"cluster_num={cluster_num} exceeds the {level_offset} IDs reserved per level; "
    "tokens from adjacent levels would collide"
)

print("Quantized RQ DocIDs:")
encoded_ids = []

for idx, code in enumerate(rq_codes):
    # Offset each token by (level * level_offset) to avoid collisions between levels
    new_doc_code = [int(x) + i * level_offset for i, x in enumerate(code)]
    # Shift into the vocabulary space (avoid overlap with normal tokens)
    encoded = ','.join(str(x + mock_vocab_size) for x in new_doc_code)
    encoded_ids.append(encoded)
    print(f"{docids[idx]}\t{encoded}")


Quantized RQ DocIDs:
doc1	32129,32384,32640
doc2	32128,32384,32640
doc3	32129,32384,32641
doc4	32129,32385,32640


In [22]:
# Decode: rebuild an approximate embedding for every doc by summing, level by
# level, the codeword its RQ code selects. Then check which original passage
# each reconstruction is most similar to. `normalize` and `cosine_similarity`
# come from the imports cell at the top of the notebook.
codes_by_level = np.asarray(rq_codes, dtype=int)  # [num_docs, num_levels]
reconstructed_embeddings = np.zeros_like(embeddings)
for level, codebook in enumerate(codebooks):
    # Gather the chosen codeword for every doc at this level in one step
    reconstructed_embeddings += codebook[codes_by_level[:, level]]

# Normalize
embeddings_norm = normalize(embeddings, axis=1)
reconstructed_embeddings_norm = normalize(reconstructed_embeddings, axis=1)

# Compute cosine similarities (rows: reconstructed docs, columns: original texts).
# cosine_similarity L2-normalizes its inputs itself, so the explicit
# normalize() above does not change these scores.
similarities = cosine_similarity(reconstructed_embeddings_norm, embeddings_norm)

# Print results
print("\nDecoded IDs and their closest texts:")
for i, sim in enumerate(similarities):
    closest_idx = np.argmax(sim)
    print(f"Decoded text for doc {i + 1}: {texts[closest_idx]}")
    print(f"Similarity scores for doc {i + 1}:")
    for j, score in enumerate(sim):
        print(f"  Text {j + 1}: {texts[j]} (Similarity: {score:.4f})")



Decoded IDs and their closest texts:
Decoded text for doc 1: Artificial intelligence is transforming industries.
Similarity scores for doc 1:
  Text 1: Artificial intelligence is transforming industries. (Similarity: 0.9545)
  Text 2: Machine learning helps computers learn from data. (Similarity: 0.6810)
  Text 3: Quantum computing is a new paradigm. (Similarity: 0.7936)
  Text 4: The future of AI includes ethical challenges. (Similarity: 0.7724)
Decoded text for doc 2: Machine learning helps computers learn from data.
Similarity scores for doc 2:
  Text 1: Artificial intelligence is transforming industries. (Similarity: 0.7010)
  Text 2: Machine learning helps computers learn from data. (Similarity: 0.9828)
  Text 3: Quantum computing is a new paradigm. (Similarity: 0.5172)
  Text 4: The future of AI includes ethical challenges. (Similarity: 0.4770)
Decoded text for doc 3: Quantum computing is a new paradigm.
Similarity scores for doc 3:
  Text 1: Artificial intelligence is transform