In [2]:
import sqlite3
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import torch

# Select the compute device: CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the Sentence Transformer model directly onto the selected device.
# Passing ``device=`` to the constructor places the weights there at load time,
# rather than loading first and then relocating them with a separate ``.to()`` call.
model = SentenceTransformer('sentence-transformers/all-roberta-large-v1', device=device)

# Function to retrieve the most similar record's reference embedding using input text
def get_similar_reference(input_text, db_name='rag_db.sqlite', faiss_index_path='faiss_index_quantized.index'):
    """
    Retrieve the reference embedding paired with the stored input embedding most similar to ``input_text``.

    The text is embedded with the module-level ``model`` on ``device``. Its single nearest neighbour
    is looked up in the FAISS index at ``faiss_index_path``. The returned FAISS id is then used as a
    zero-based row position (ordered by ``rowid``) into the ``documents`` table, and that row's
    ``reference_embedding`` BLOB is decoded.

    Parameters:
    - input_text: The input text to search for in the database.
    - db_name: Path of the SQLite database holding the ``documents`` table.
    - faiss_index_path: Path to the saved FAISS index for fast similarity search.

    Returns:
    - reference_embedding: 1-D float32 numpy array decoded from the matching row's BLOB.
      Returns None if any of these hold:
      - the index is untrained or empty;
      - FAISS found no neighbour;
      - no row exists at that position;
      - the stored BLOB is NULL.
    """
    # Step 1: Embed the query. FAISS expects a C-contiguous float32 matrix of shape (1, d).
    input_vector = model.encode([input_text], convert_to_tensor=True, device=device)
    input_vector = np.ascontiguousarray(input_vector.cpu().numpy(), dtype=np.float32)

    # Step 2: Load the FAISS index (expected to be the quantized index built offline).
    faiss_index = faiss.read_index(faiss_index_path)

    # An untrained or empty index cannot answer the query. Training it on the single query
    # vector is not meaningful: quantized indexes need many training points, and training
    # does not add any vectors, so a search would still return nothing.
    if not faiss_index.is_trained or faiss_index.ntotal == 0:
        print("FAISS index is untrained or empty; rebuild it from the stored embeddings.")
        return None

    # Step 3: Search the FAISS index for the single nearest stored embedding.
    D, I = faiss_index.search(input_vector, k=1)
    print(f"Distances: {D}, Indices: {I}")

    # Step 4: FAISS returns numpy.int64 ids; convert to a plain int for the sqlite3 binding.
    most_similar_index = int(I[0][0])
    # FAISS reports -1 when no neighbour was found. SQLite treats a negative OFFSET as 0,
    # which would silently return the first row, so bail out explicitly.
    if most_similar_index < 0:
        return None

    # Step 5: Fetch the reference embedding at that row position. ORDER BY rowid makes the
    # OFFSET deterministic; without it SQLite does not guarantee any row order.
    # NOTE(review): assumes vectors were added to the FAISS index in rowid order -- confirm
    # against the code that built the index.
    conn = sqlite3.connect(db_name)
    try:
        row = conn.execute(
            "SELECT reference_embedding FROM documents ORDER BY rowid LIMIT 1 OFFSET ?",
            (most_similar_index,),
        ).fetchone()
    finally:
        # Close even if the query raises, so the connection is never leaked.
        conn.close()

    if row is None or row[0] is None:
        return None  # Return None if no reference found

    # Decode the BLOB back into a numpy array.
    # Assumes it was written from float32 data (e.g. ndarray.tobytes()) -- TODO confirm.
    return np.frombuffer(row[0], dtype=np.float32)

# Example usage: runs when executed as a script or notebook cell, but not when this file is imported.
if __name__ == "__main__":
    input_text = "wall.' 'O'Brien!' said Winston, making an effort to control his voice. 'You know this is not necessary. What is it that you want me to do?' O'Brien made no direct answer. When he spoke it was in the schoolmasterish manner that he sometimes affected. He looked thoughtfully into the distance, as though he were addressing an audience somewhere behind Winston's back. 'By itself,' he said, 'pain is not always enough. There are occasions when a human being will stand out against pain, even to the point of death. But for everyone there is something unendurable--something that cannot be contemplated. Courage and cowardice are not involved. If you are falling from a height it is not cowardly to clutch at a rope. If you have come up from deep water it is not cowardly to fill your lungs with air. It is merely an instinct which cannot be destroyed. It is the same with the rats. For you, they are unendurable. They are a form"
    reference_embedding = get_similar_reference(input_text)

    # Print the corresponding reference embedding
    if reference_embedding is not None:
        print("Reference embedding:", reference_embedding)
    else:
        print("No matching reference found.")


Distances: [[0.40222144]], Indices: [[928]]
Reference embedding: [-0.0121234   0.05992422 -0.01780817 ...  0.0265995   0.00911881
  0.00684574]
