# In [1]:
import os
import sqlite3
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import torch
from tqdm import tqdm
import json

# Expose only physical GPU 4 to this process. CUDA reads this variable when it is
# first initialised (the torch.cuda.is_available() call below), so it must be set first.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the sentence encoder used to embed the `input` texts and place it on `device`
# (Module.to returns the module itself, so `model` is the moved encoder).
model = SentenceTransformer('sentence-transformers/all-roberta-large-v1').to(device)

def _load_input_reference_pairs(json_path):
    """Read one JSON file and return its entries as ``(input, reference)`` tuples.

    Assumes the file holds a JSON array of objects with ``input`` and
    ``reference`` keys; an entry missing either key raises ``KeyError``.
    """
    with open(json_path, 'r', encoding='utf-8') as file:
        entries = json.load(file)
    return [(entry['input'], entry['reference']) for entry in entries]


def load_data(non_infringement_file, infringement_file):
    """Load ``(input, reference)`` pairs from both JSON files and concatenate them.

    Args:
        non_infringement_file: path to the JSON file of non-infringement examples.
        infringement_file: path to the JSON file of infringement examples.

    Returns:
        A list of ``(input, reference)`` tuples: every non-infringement pair
        first, then every infringement pair, each group in file order. The
        order matters because list positions line up with the order vectors
        are added to the FAISS index and rows are inserted into SQLite.
    """
    non_infringement_pairs = _load_input_reference_pairs(non_infringement_file)
    infringement_pairs = _load_input_reference_pairs(infringement_file)
    return non_infringement_pairs + infringement_pairs

# Example usage: build the combined pair list from the extra_30 test split.
all_pairs = load_data(
    '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.non_infringement.json',
    '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.infringement.json',
)

# Split the pairs into two parallel lists: the `input` texts get embedded,
# while the references are stored next to them in the database.
input_texts = [inp for inp, _ref in all_pairs]
references = [ref for _inp, ref in all_pairs]

def batch_encode_texts(model, texts, batch_size=8):
    """Embed ``texts`` with ``model`` in chunks of ``batch_size`` and stack the rows.

    Args:
        model: a SentenceTransformer-like encoder exposing ``encode`` (and
            ``get_sentence_embedding_dimension``, used only for empty input).
        texts: sequence of strings to embed.
        batch_size: number of texts handed to each ``model.encode`` call.

    Returns:
        A NumPy array with one embedding row per text, in input order. Empty
        ``texts`` yields a float32 array of shape ``(0, dim)`` instead of the
        ``ValueError`` that ``np.vstack`` raises on an empty list.

    Note: encoding runs on the module-level ``device`` selected at start-up.
    """
    if len(texts) == 0:
        return np.empty((0, model.get_sentence_embedding_dimension()), dtype=np.float32)

    all_vectors = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Encoding inputs"):
        batch = texts[i:i + batch_size]
        batch_vectors = model.encode(batch, convert_to_tensor=True, device=device)
        # Move each batch off the GPU right away so device memory is not held.
        all_vectors.append(batch_vectors.cpu().numpy())
    return np.vstack(all_vectors)

# Encode the `input` texts; row i of `input_vectors` belongs to input_texts[i].
input_vectors = batch_encode_texts(model, input_texts)

# Build an IVF (inverted-file) index over the input vectors. Despite its name,
# `gpu_index` is an ordinary CPU index: it is never moved to a GPU in this cell.
dimension = input_vectors.shape[1]
nlist = 3  # Example number of clusters (training needs at least nlist vectors)
quantizer = faiss.IndexFlatL2(dimension)
gpu_index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
# An IVF index searches only `nprobe` of its `nlist` lists, and nprobe defaults
# to 1, so a query would scan roughly a third of the vectors and could miss its
# true nearest neighbour. Probe every list so search stays exhaustive. nprobe is
# written with the index, so readers inherit it unless they override it.
gpu_index.nprobe = nlist

# Train the coarse quantizer, then add the vectors (ids are assigned 0..n-1 in order)
print("Training the index on input vectors...")
gpu_index.train(input_vectors)
print("Index training completed.")
print("Adding input vectors to the index...")
gpu_index.add(input_vectors)
print("Input vectors added to the index.")

# Save the FAISS index (overwrites any previous file)
faiss.write_index(gpu_index, 'faiss_index.index')
print("FAISS index has been saved.")

def setup_database(db_path='rag_db.sqlite'):
    """Open (creating if needed) the SQLite database and ensure ``documents`` exists.

    The table holds the input text and its reference text. Embeddings are not
    stored here; the vectors live in the FAISS index.

    Args:
        db_path: path of the SQLite file to open; defaults to the original
            ``'rag_db.sqlite'`` in the current working directory.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for closing it.

    Raises:
        sqlite3.Error: if the table cannot be created. The connection is closed
            before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(''' 
        CREATE TABLE IF NOT EXISTS documents (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            document_text TEXT NOT NULL,
            reference_text TEXT NOT NULL  -- Add reference_text column to store references
            -- embedding BLOB NOT NULL
        )
    ''')
        conn.commit()
    except sqlite3.Error:
        # Do not leak the connection when schema setup fails.
        conn.close()
        raise
    return conn

def store_documents(conn, inputs, references, embeddings):
    """Insert one ``documents`` row per (input, reference) pair and commit.

    Args:
        conn: open ``sqlite3.Connection`` whose database has the ``documents`` table.
        inputs: iterable of input texts, stored in ``document_text``.
        references: iterable of reference texts, stored in ``reference_text``.
        embeddings: iterable with one entry per document. Embeddings are NOT
            written because the table has no embedding column. The argument is
            kept for interface compatibility and, as before, still caps the row
            count because it is zipped together with the other two.
    """
    rows = ((inp, ref) for inp, ref, _emb in zip(inputs, references, embeddings))
    conn.executemany(
        'INSERT INTO documents (document_text, reference_text) VALUES (?, ?)',
        rows,
    )
    conn.commit()

# Store documents and references in the database. The FAISS index above is rebuilt
# from scratch on every run, so the table is emptied first. Otherwise re-running
# this cell appends a second copy of every row, and the SQLite row ids stop lining
# up with the order of vectors in the freshly written FAISS index.
conn = setup_database()
try:
    # The DELETEs and the inserts share one transaction (committed inside
    # store_documents), so a failure leaves the previous contents intact.
    conn.execute('DELETE FROM documents')
    # Reset the AUTOINCREMENT counter so ids restart at 1, as on a fresh database.
    conn.execute("DELETE FROM sqlite_sequence WHERE name = 'documents'")
    store_documents(conn, input_texts, references, input_vectors)
finally:
    # Close the database connection even if storing fails
    conn.close()



# --- Output of the cell above ---
#   from .autonotebook import tqdm as notebook_tqdm
# Encoding inputs: 100%|██████████| 238/238 [00:09<00:00, 26.09it/s]
#
#
# Training the index on input vectors...
# Index training completed.
# Adding input vectors to the index...
# Input vectors added to the index.
# FAISS index has been saved.
