In [2]:
import os
import sqlite3
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import torch
from tqdm import tqdm
import json
import zlib

# Expose only GPU 4 to this process, then fall back to the CPU when CUDA is
# not available.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the Sentence Transformer model and move it onto the selected device
model = SentenceTransformer('sentence-transformers/all-roberta-large-v1').to(device)

# Load JSON data and prepare pairs
def load_data(non_infringement_file, infringement_file):
    """Read both JSON datasets and return their (input, reference) pairs.

    Each file is parsed as UTF-8 JSON and iterated entry by entry; every
    entry must provide an ``'input'`` and a ``'reference'`` key.

    Args:
        non_infringement_file: Path of the non-infringement JSON file.
        infringement_file: Path of the infringement JSON file.

    Returns:
        A list of ``(input, reference)`` tuples: all non-infringement pairs
        first, followed by all infringement pairs, each in file order.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
        KeyError: If an entry lacks ``'input'`` or ``'reference'``.
    """
    def read_pairs(path):
        # Parse one dataset and pair each entry's input with its reference.
        with open(path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
        return [(record['input'], record['reference']) for record in records]

    # The non-infringement file is read first, so its pairs lead the result.
    pairs = read_pairs(non_infringement_file)
    pairs.extend(read_pairs(infringement_file))
    return pairs

# Example usage
# NOTE(review): both paths start with a space character - confirm that this
# matches the directory name on disk.
all_pairs = load_data(
    non_infringement_file=' test_division/extra_30.non_infringement.json',
    infringement_file=' test_division/extra_30.infringement.json',
)

# Split the pairs back into parallel lists: the `input` texts and their `references`
input_texts = [text for text, _ in all_pairs]
references = [reference for _, reference in all_pairs]

# Encode `input` texts and `reference` texts in batches
def batch_encode_texts(model, texts, batch_size=8):
    """Encode ``texts`` in consecutive slices and stack the results.

    Each slice of at most ``batch_size`` texts is passed to ``model.encode``
    (using the module-level ``device``), copied to the CPU, and converted to
    a NumPy array; progress is reported through ``tqdm``.

    Args:
        model: Object exposing ``encode(batch, convert_to_tensor=True, device=...)``.
        texts: Sliceable sequence of texts to encode.
        batch_size: Maximum number of texts handed to ``model.encode`` at once.

    Returns:
        The per-batch arrays stacked vertically with ``np.vstack``
        (presumably one row per text - confirm against the model's output).

    Raises:
        ValueError: If ``texts`` is empty, because ``np.vstack`` then
            receives no arrays.
    """
    batch_starts = range(0, len(texts), batch_size)
    encoded_batches = [
        model.encode(
            texts[start:start + batch_size],
            convert_to_tensor=True,
            device=device,
        ).cpu().numpy()
        for start in tqdm(batch_starts, desc="Encoding texts")
    ]
    return np.vstack(encoded_batches)

# Embed both text lists with the default batch size: inputs first, then references
input_vectors = batch_encode_texts(model=model, texts=input_texts)
reference_vectors = batch_encode_texts(model=model, texts=references)

# Quantize embeddings to int8 (reduce memory usage)
def quantize_embeddings(embeddings, dtype=np.int8):
    """Linearly quantize ``embeddings`` onto the full range of an integer dtype.

    The global minimum of ``embeddings`` maps to ``np.iinfo(dtype).min`` and
    the global maximum to ``np.iinfo(dtype).max``. One scale and one minimum
    are shared by the whole array.

    To recover approximate values from the codes ``q``::

        (q.astype(np.float32) - np.iinfo(dtype).min) * scale + min_val

    Args:
        embeddings: Non-empty NumPy array of values to quantize.
        dtype: Target integer dtype (signed or unsigned). Defaults to
            ``np.int8``.

    Returns:
        A tuple ``(quantized_embeddings, scale, min_val)`` where
        ``quantized_embeddings`` has the shape of ``embeddings`` and dtype
        ``dtype``, ``scale`` is the real-valued width of one quantization
        step, and ``min_val`` is the minimum of ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` is empty or ``dtype`` is not an
            integer dtype.
    """
    info = np.iinfo(dtype)
    min_val = embeddings.min()
    max_val = embeddings.max()

    scale = (max_val - min_val) / (info.max - info.min)
    if scale == 0:
        # Constant input: any positive step works and avoids a 0/0 division,
        # which would otherwise turn every code into NaN before the cast.
        scale = scale.dtype.type(1)

    # Step index of each value counted from the minimum: 0 .. (max - min).
    steps = np.round((embeddings - min_val) / scale)
    # Shift the step indices so they start at the dtype's minimum. Without
    # the shift a signed dtype receives values up to (max - min), e.g. 255
    # for int8, which is out of range for the cast. The clip absorbs
    # floating-point rounding at the two ends.
    codes = np.clip(steps + info.min, info.min, info.max)
    quantized_embeddings = codes.astype(dtype)
    return quantized_embeddings, scale, min_val

# Compress the embeddings using zlib
def compress_embedding(embedding):
    """Return the zlib-compressed raw bytes of ``embedding``.

    Args:
        embedding: Object exposing ``tobytes()`` (e.g. a NumPy array).

    Returns:
        ``bytes`` produced by ``zlib.compress`` at its default level. The
        array's dtype and shape are not stored in the result.
    """
    raw_bytes = embedding.tobytes()
    return zlib.compress(raw_bytes)

# Quantize embeddings to int8 codes; the returned scale/minimum are needed to
# decode them later (compression happens when the rows are stored)
input_vectors_quantized, scale_input, min_input = quantize_embeddings(input_vectors, dtype=np.int8)
reference_vectors_quantized, scale_reference, min_reference = quantize_embeddings(reference_vectors, dtype=np.int8)

# Setting up FAISS Index
dimension = input_vectors_quantized.shape[1]
nlist = 3  # Number of inverted lists, i.e. coarse clusters
m = 8  # Number of PQ sub-quantizers (FAISS needs `dimension` divisible by m)
nbits = 8  # Bits per sub-quantizer code

quantizer = faiss.IndexFlatL2(dimension)  # Flat L2 index used as the coarse quantizer
# NOTE: despite the variable name, nothing here moves this index to a GPU.
gpu_index = faiss.IndexIVFPQ(quantizer, dimension, nlist, m, nbits)

# FAISS operates on contiguous float32 matrices. Convert the int8 codes
# explicitly instead of relying on the Python wrapper to cast them.
faiss_input_vectors = np.ascontiguousarray(input_vectors_quantized, dtype=np.float32)

# Train the quantized index on input vectors
print("Training the quantized FAISS index on input vectors...")
gpu_index.train(faiss_input_vectors)
print("FAISS index training completed.")

# Add input vectors to the quantized FAISS index
print("Adding input vectors to the quantized FAISS index...")
gpu_index.add(faiss_input_vectors)
print("Input vectors added to the FAISS index.")

# Save the quantized FAISS index
faiss.write_index(gpu_index, 'faiss_index_quantized.index')
print("Quantized FAISS index has been saved.")

# Setting up SQLite database to store vectors (input and reference embeddings)
def setup_database(db_path='rag_db.sqlite'):
    """Open the SQLite database and make sure the ``documents`` table exists.

    Args:
        db_path: Database location passed to ``sqlite3.connect``. Defaults to
            ``'rag_db.sqlite'``; ``':memory:'`` gives a throwaway database.

    Returns:
        An open ``sqlite3.Connection``. The caller is responsible for
        closing it.

    Raises:
        sqlite3.Error: If the database cannot be opened or the table cannot
            be created; the connection is closed before re-raising.
    """
    conn = sqlite3.connect(db_path)
    try:
        # One row per input/reference pair: both embeddings as BLOBs plus the
        # scale/minimum values that accompany each of them.
        conn.execute(''' 
        CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            input_embedding BLOB NOT NULL,
            reference_embedding BLOB NOT NULL,
            input_scale FLOAT NOT NULL,
            input_min FLOAT NOT NULL,
            reference_scale FLOAT NOT NULL,
            reference_min FLOAT NOT NULL
        )
    ''')
        conn.commit()
    except sqlite3.Error:
        # Do not leak the connection when the schema cannot be created.
        conn.close()
        raise
    return conn

# Store embeddings in the database
def store_embeddings(conn, input_embeddings, reference_embeddings, scale_input, min_input, scale_reference, min_reference):
    """Insert one ``documents`` row per (input, reference) embedding pair.

    Both embeddings of a pair are compressed with ``compress_embedding`` and
    stored next to the scale/minimum values, which are the same for every row.

    Args:
        conn: Open ``sqlite3.Connection`` whose database has a ``documents``
            table.
        input_embeddings: Iterable of input embeddings.
        reference_embeddings: Iterable of reference embeddings, paired with
            ``input_embeddings`` by position; ``zip`` stops at the shorter one.
        scale_input: Scale value stored in ``input_scale``.
        min_input: Minimum value stored in ``input_min``.
        scale_reference: Scale value stored in ``reference_scale``.
        min_reference: Minimum value stored in ``reference_min``.

    Raises:
        sqlite3.Error: If an insert fails; the transaction is rolled back.
    """
    # NumPy scalars such as np.float32 are not Python floats. sqlite3 would
    # bind them through the buffer protocol and store raw bytes (a BLOB) in
    # the FLOAT columns, so convert them to plain floats once, up front.
    metadata = (
        float(scale_input),
        float(min_input),
        float(scale_reference),
        float(min_reference),
    )

    # Rows are generated lazily: each pair is compressed just before insertion.
    rows = (
        (compress_embedding(input_emb), compress_embedding(ref_emb), *metadata)
        for input_emb, ref_emb in zip(input_embeddings, reference_embeddings)
    )

    # The connection context manager commits on success and rolls back if
    # any insert raises, so a failure never leaves a half-written batch.
    with conn:
        conn.executemany(
            'INSERT INTO documents (input_embedding, reference_embedding, input_scale, input_min, reference_scale, reference_min) VALUES (?, ?, ?, ?, ?, ?)',
            rows,
        )

# Store embeddings in the database
conn = setup_database()
try:
    store_embeddings(conn, input_vectors_quantized, reference_vectors_quantized, scale_input, min_input, scale_reference, min_reference)
finally:
    # Close the database connection, even when storing fails
    conn.close()


Encoding texts: 100%|██████████| 238/238 [00:08<00:00, 26.66it/s]
Encoding texts: 100%|██████████| 238/238 [00:09<00:00, 25.21it/s]


Training the quantized FAISS index on input vectors...




FAISS index training completed.
Adding input vectors to the quantized FAISS index...
Input vectors added to the FAISS index.
Quantized FAISS index has been saved.
