In [1]:
import torch
import torch.nn as nn
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoConfig,
)
from IPython.display import Markdown
from huggingface_hub import notebook_login
from datasets import load_dataset
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
import chromadb
import numpy as np

In [2]:
# Load PTB dataset (Penn Treebank)
dataset = load_dataset("ptb_text_only")

documents = []
for item in dataset["train"]:
    # For the "ptb_text_only" config, the text is typically in item["sentence"]
    text = item["sentence"]
    documents.append(text)

# Use the recursive character splitter
recur_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=60, separators=["\n\n", "\n", "\.", " ", ""]
)

# Perform the splits using the splitter
data_splits = recur_splitter.split_text("".join(documents))
print("Number of splits:", len(data_splits))

  chunk_size=1000, chunk_overlap=60, separators=["\n\n", "\n", "\.", " ", ""]


Number of splits: 5293


In [3]:
### Using embeddings by MPNET: https://huggingface.co/sentence-transformers/all-mpnet-base-v2
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
encode_kwargs = {"normalize_embeddings": False}
hf_embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

  hf_embeddings = HuggingFaceEmbeddings(


In [4]:
persist_directory = "./vector_store/"
original_client = chromadb.PersistentClient(path=persist_directory)
org_collection = original_client.get_collection(name="langchain")

org_data = original_data = org_collection.get(include=["embeddings", "documents"])

In [5]:
org_emb = torch.Tensor(original_data["embeddings"])
print(org_emb.shape)

torch.Size([5293, 768])


In [7]:
def topk_cosine_sim(x: torch.Tensor, k: int):
    x = torch.nn.functional.normalize(x, p=2, dim=1)  # normalize to unit vectors
    sim_matrix = x @ x.T  # cosine similarity
    topk_sim, topk_idx = torch.topk(
        sim_matrix, k=k, dim=1
    )  # +1 because self-similarity is 1
    return topk_sim[:, :], topk_idx[:, :]  # remove self-match

_, brute_force_idx = topk_cosine_sim(org_emb, k=10)
print(brute_force_idx[0])

tensor([   0,    3,    2,    1, 5082,   59, 4967, 3340,  496,  387])


In [6]:
def uniform_quantization(tensor: torch.Tensor, clip_val: torch.Tensor, bit):
    scale = (2 ** (bit - 1)) - 1
    tensor_q: torch.Tensor = tensor.clamp(-clip_val, clip_val) / clip_val * scale
    tensor_q = (tensor_q.round() - tensor_q).detach() + tensor_q  # STE 적용
    tensor_q_int = tensor_q.to(torch.int8)
    # print(tensor_q_int)
    msb_2_bits = tensor_q_int & 0xC0
    mid_2_bits = tensor_q_int & 0x30
    mid2_2_bits = tensor_q_int & 0x0C
    lsb_4_bits = tensor_q_int & 0x03
    # print(msb_2_bits, mid_2_bits, mid2_2_bits, lsb_4_bits)
    msb_2_bits_scaled = msb_2_bits / scale * clip_val
    mid_2_bits_scaled = mid_2_bits / scale * clip_val
    mid2_2_bits_scaled = mid2_2_bits / scale * clip_val
    lsb_4_bits_scaled = lsb_4_bits / scale * clip_val
    # print(msb_2_bits_scaled, mid_2_bits_scaled, mid2_2_bits_scaled, lsb_4_bits_scaled)
    # return msb_2_bits_scaled, mid_2_bits_scaled, mid2_2_bits_scaled, lsb_4_bits_scaled
    return tensor_q_int / scale * clip_val

In [8]:
def noise_inject_tensor(weight_tensor: torch.Tensor, std: torch.Tensor, typ: bool):
    device = weight_tensor.device
    std = std.to(device)
    if typ:
        std_reshaped = std.view(-1, 1) if std.dim() == 1 else std
        adjusted_noise = 1.0 + std_reshaped * torch.randn_like(weight_tensor)
    else:
        adjusted_noise = 1.0 + std * torch.randn_like(weight_tensor)
    return torch.mul(weight_tensor, adjusted_noise).to(device)

In [9]:
cliff_val = org_emb.std() * 3
print(cliff_val)
org_emb_q_4bit = uniform_quantization(org_emb, cliff_val, 4)

q_collection = original_client.get_or_create_collection(
    name=f"4_bit_q_cliff_{cliff_val}",
    # embedding_function=hf_embeddings,
    metadata={"hnsw:space": "cosine"},
)
quantized_data = q_collection.get(include=["embeddings"])
print(len(quantized_data["embeddings"]))

tensor(0.1083)
0


In [11]:
q_collection.add(
    ids=[str(i) for i in range(len(org_emb_q_4bit))],
    embeddings=org_emb_q_4bit.numpy(),
    documents=org_data["documents"],
),

(None,)

In [12]:
org_data["ids"]
id_idx_map = {id: idx for idx, id in enumerate(org_data["ids"])}

In [13]:
q_match = 0
match = 0
k = 10

for idx, query in enumerate(org_emb):
    q_query_result = q_collection.query(
        query_embeddings=query.numpy(),
        n_results=k,
        include=["documents", "embeddings"],
    )
    query_result = org_collection.query(
        query_embeddings=query.numpy(),
        n_results=k,
        include=["documents", "embeddings"],
    )
    query_idx = set(map(lambda x: id_idx_map[x], query_result["ids"][0]))
    q_query_idx = set(map(lambda x: int(x), q_query_result["ids"][0]))
    bf = set(map(lambda x: int(x), brute_force_idx[idx]))
    # print(query_idx, q_query_idx, bf)
    match += len(query_idx.intersection(bf))
    q_match += len(q_query_idx.intersection(bf))
    # break
print("recall:", match / (k * len(org_emb)))
print("q_recall:", q_match / (k * len(org_emb)))

recall: 0.9990553561307387
q_recall: 0.954146986586057
