In [14]:
import os
import sqlite3
import json
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import torch
from tqdm import tqdm

# Expose only GPU index 4 to CUDA, then choose the device string used below.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Instantiate the Sentence Transformer model and place it on the chosen device.
model = SentenceTransformer('sentence-transformers/all-roberta-large-v1').to(device)

# Function to load data from JSON files
def load_data(non_infringement_file, infringement_file):
    """Load reference texts and their labels from two JSON files.

    Each file is parsed with ``json.load`` and iterated; every entry must
    provide a ``'reference'`` key.

    Args:
        non_infringement_file: Path to the JSON file of non-infringement entries.
        infringement_file: Path to the JSON file of infringement entries.

    Returns:
        A 4-tuple ``(non_infringement_references, y_non_infringement,
        infringement_references, y_infringement)``. The label lists contain
        one ``1`` per non-infringement reference and one ``0`` per
        infringement reference, so each label list always has the same
        length as its reference list.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
        KeyError: If an entry has no ``'reference'`` key.
    """
    with open(non_infringement_file, 'r', encoding='utf-8') as file:
        non_infringement_json_data = json.load(file)

    # Extract the reference text of every non-infringement entry.
    non_infringement_references = [entry['reference'] for entry in non_infringement_json_data]
    # Size the labels from the list built above; the previous code used the
    # undefined name `non_infringement_outputs` here.
    y_non_infringement = [1] * len(non_infringement_references)

    with open(infringement_file, 'r', encoding='utf-8') as file:
        infringement_json_data = json.load(file)

    # Extract the reference text of every infringement entry.
    infringement_references = [entry['reference'] for entry in infringement_json_data]
    # Same fix as above: the previous code used the undefined `infringement_outputs`.
    y_infringement = [0] * len(infringement_references)

    return (non_infringement_references, y_non_infringement, infringement_references, y_infringement)

# Example usage: load both splits and unpack the four returned lists.
(
    non_infringement_references,
    y_non_infringement,
    infringement_references,
    y_infringement,
) = load_data(
    '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.non_infringement.json',
    '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.infringement.json',
)


# Encode references in batches
def batch_encode_references(model, references, batch_size=8):
    """Encode ``references`` slice by slice and stack the results row-wise.

    Every slice of at most ``batch_size`` items is passed to ``model.encode``
    with ``convert_to_tensor=True`` and the module-level ``device``; the
    returned tensor is moved to the CPU and converted to a NumPy array.
    Progress is reported through a ``tqdm`` bar labelled "Encoding references".

    Args:
        model: Object exposing ``encode(batch, convert_to_tensor=..., device=...)``.
        references: Sliceable sequence of items to encode.
        batch_size: Maximum number of items handed to ``encode`` per call.

    Returns:
        The per-batch arrays combined with ``np.vstack``.
    """
    batch_starts = range(0, len(references), batch_size)
    encoded_chunks = [
        model.encode(
            references[start:start + batch_size],
            convert_to_tensor=True,
            device=device,
        ).cpu().numpy()
        for start in tqdm(batch_starts, desc="Encoding references")
    ]
    return np.vstack(encoded_chunks)

# Encode the complete references
# NOTE(review): `references` is not assigned anywhere in this cell (load_data
# returns non_infringement_references / infringement_references instead) —
# presumably it comes from an earlier notebook cell; confirm before reuse.
reference_vectors = batch_encode_references(model, references)

# Initialize FAISS index
# The embedding width is taken from the second axis of the stacked vectors.
dimension = reference_vectors.shape[1]
# Number of inverted lists passed to IndexIVFFlat below.
nlist = 3
# Flat L2 index handed to IndexIVFFlat as its quantizer.
quantizer = faiss.IndexFlatL2(dimension)
# NOTE(review): despite the name, this index is built with the plain
# faiss.IndexIVFFlat constructor and nothing in this cell moves it to a GPU.
gpu_index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)

# Train and add vectors to the FAISS index
# The index is trained on the same vectors that are then added to it;
# train() is called before add().
print("Training the index...")
gpu_index.train(reference_vectors)
print("Index training completed.")
print("Adding vectors to the index...")
gpu_index.add(reference_vectors)
print("Vectors added to the index.")

# Setting up SQLite database to store text and embeddings
def setup_database():
    """Open ``reference_db.sqlite`` and ensure the ``reference_data`` table exists.

    The table has an auto-incrementing integer ``id``, a non-null ``text``
    column and a non-null ``embedding`` BLOB column. Creation is skipped when
    the table is already present.

    Returns:
        The open ``sqlite3.Connection``; closing it is left to the caller.
    """
    connection = sqlite3.connect('reference_db.sqlite')
    create_table_sql = '''
        CREATE TABLE IF NOT EXISTS reference_data (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            text TEXT NOT NULL,
            embedding BLOB NOT NULL
        )
    '''
    connection.cursor().execute(create_table_sql)
    connection.commit()
    return connection

# Function to store references and their embeddings directly in the database
def store_references(conn, references, embeddings):
    """Insert each reference text with its embedding bytes into ``reference_data``.

    ``references`` and ``embeddings`` are paired with ``zip``, so pairing stops
    at the shorter of the two. Each embedding is serialised with its
    ``tobytes()`` method and written to the ``embedding`` column. All inserts
    are committed together at the end.

    Args:
        conn: Open database connection providing ``cursor()`` and ``commit()``.
        references: Iterable of reference texts.
        embeddings: Iterable of objects exposing ``tobytes()``, one per reference.
    """
    rows = ((text, vector.tobytes()) for text, vector in zip(references, embeddings))
    conn.cursor().executemany(
        'INSERT INTO reference_data (text, embedding) VALUES (?, ?)',
        rows,
    )
    conn.commit()

# Store references in the database
# setup_database() creates the table only if it is missing and store_references()
# inserts unconditionally, so re-running this cell appends the same rows again.
conn = setup_database()
# NOTE(review): `references` is not assigned in this cell — presumably the same
# list that produced reference_vectors; confirm the two are aligned.
store_references(conn, references, reference_vectors)

# Function to find the most relevant reference based on input text
def search_next_sentence(input_text, top_k=1):
    """Return the stored references closest to ``input_text`` in the FAISS index.

    The text is encoded with the module-level ``model`` on ``device``, the
    resulting vector is searched against the module-level ``gpu_index``, and
    the returned labels are used as positions in the module-level
    ``references`` list.

    Args:
        input_text: Query text to encode and search for.
        top_k: Maximum number of neighbours to request from the index.

    Returns:
        A list of entries from ``references``, in the order the index returned
        them. It may contain fewer than ``top_k`` items (possibly none) when
        the index finds fewer neighbours.
    """
    print(f"Searching for next sentence for input: '{input_text}'...")
    input_vector = model.encode([input_text], convert_to_tensor=True, device=device).cpu().numpy()
    _, indices = gpu_index.search(input_vector, top_k)
    # FAISS fills missing result slots with label -1; indexing references[-1]
    # would silently return the last reference as if it were a match.
    return [references[i] for i in indices[0] if i >= 0]

# Example input sentence
input_sentence = "foes right and left. Ser Rodrik hammered at the big man in the shadowskin cloak, their horses dancing round each other as they traded blow for blow."
# search_next_sentence returns a list (top_k defaults to 1), so the value
# printed below is a list rather than a bare string.
next_sentence = search_next_sentence(input_sentence)
print("Recommended next sentence:", next_sentence)

# Close the database connection
# The search above does not use `conn`; this closes the connection opened by
# setup_database() earlier in the cell.
conn.close()


Encoding references: 100%|██████████| 1/1 [00:00<00:00, 20.82it/s]


Training the index...
Index training completed.
Adding vectors to the index...
Vectors added to the index.
Searching for next sentence for input: 'foes right and left. Ser Rodrik hammered at the big man in the shadowskin cloak, their horses dancing round each other as they traded blow for blow.'...
Recommended next sentence: ['. Tyrion danced back in while the brigand\'s leg was still pinned beneath his fallen mount, and buried the axe in the man\'s neck, just above the shoulder blades. As he struggled to yank the blade loose, he heard Marillion moaning under the bodies. "Someone help me," the singer']
