# In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import json
import os
import time
import torch.nn as nn

# Hide every GPU from this process, so every `torch.cuda.is_available()` check
# below falls back to the CPU. This only takes effect if CUDA has not been
# initialised yet in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

model_name = 'meta-llama/Meta-Llama-3.1-8B'
bert_model_name = 'google-bert/bert-base-uncased'
checkpoint_file = 'train_input_reference_last_token.pth'
non_infringement_file = 'extra_30.non_infringement.json'
infringement_file = 'extra_30.infringement.json'

# Load models and tokenizers.
# Causal LM: its last-layer hidden states are used as features, so
# output_hidden_states must be enabled.
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
# Reuse the Llama EOS token for padding (presumably because the Llama tokenizer
# defines no pad token of its own -- TODO confirm for the pinned version).
tokenizer.pad_token = tokenizer.eos_token

# BERT encoder: its pooled output embeds the reference text.
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
# bert_tokenizer keeps its own built-in pad token: Llama's EOS string is not in
# BERT's vocabulary, so reusing it for padding would emit [UNK] ids.

# Define Custom MLP
class CustomMLP(nn.Module):
    """Gated MLP head mapping each feature row to a single output value.

    The output is ``up(down(x) * silu(gate(x)))``. The attribute names
    ``down``, ``gate`` and ``up`` are the keys of the saved state dict, so
    they must stay as they are for checkpoints to load.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Keep the creation order down -> gate -> up: it fixes the order in
        # which parameters are randomly initialised.
        self.down = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)
        self.up = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, x):
        """Return one unnormalised score (logit) per input row."""
        gated = self.down(x) * self.activation(self.gate(x))
        return self.up(gated)

# Extract hidden states from the causal LM model (Meta-Llama)
def extract_hidden_states(texts, model, tokenizer):
    """Embed each text as the final-layer hidden state of its last token.

    Texts are encoded one at a time, so no padding is introduced. The
    wall-clock time of each tokenise + forward pass is recorded.

    Args:
        texts: Non-empty sequence of strings to encode.
        model: Causal LM whose outputs expose ``hidden_states`` (this file
            loads it with ``output_hidden_states=True``).
        tokenizer: Tokenizer matching ``model``.

    Returns:
        ``(features, avg_time_per_sample)``. ``features`` stacks one row per
        text; ``avg_time_per_sample`` is the mean per-text time in seconds.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    if len(texts) == 0:
        # Without this, np.mean([]) would print "nan" and np.vstack([]) would
        # then fail with an unrelated-looking error.
        raise ValueError("extract_hidden_states() needs at least one text")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = nn.DataParallel(model)
    hidden_states = []
    data_times = []  # seconds spent on each text

    for text in tqdm(texts, desc="Processing data points"):
        start_time = time.time()
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        last_layer_hidden_states = outputs.hidden_states[-1]
        # A single string is never padded, so position -1 is its real last token.
        last_token_hidden_states = last_layer_hidden_states[:, -1, :]
        # Cast before .numpy(): NumPy has no bfloat16, so half-precision
        # model outputs would otherwise fail to convert.
        hidden_states.append(last_token_hidden_states.float().cpu().numpy())
        data_times.append(time.time() - start_time)

    avg_time_per_sample = np.mean(data_times)
    print(f"Average time per sample for extracting hidden states: {avg_time_per_sample:.4f} seconds")

    return np.vstack(hidden_states), avg_time_per_sample

# Extract reference embeddings from the BERT model
def extract_reference_embeddings(references, model, tokenizer):
    """Embed each reference text with the encoder's ``pooler_output``.

    References are encoded one at a time. The wall-clock time of each
    tokenise + forward pass is recorded.

    Args:
        references: Non-empty sequence of strings to encode.
        model: Encoder whose outputs expose ``pooler_output`` (the BERT model
            loaded in this file).
        tokenizer: Tokenizer matching ``model``.

    Returns:
        ``(embeddings, avg_time_per_sample)``. ``embeddings`` stacks one row
        per reference; ``avg_time_per_sample`` is the mean per-reference time
        in seconds.

    Raises:
        ValueError: If ``references`` is empty.
    """
    if len(references) == 0:
        # Without this, np.mean([]) would print "nan" and np.vstack([]) would
        # then fail with an unrelated-looking error.
        raise ValueError("extract_reference_embeddings() needs at least one reference")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = nn.DataParallel(model)
    embeddings = []
    data_times = []  # seconds spent on each reference

    for reference in tqdm(references, desc="Processing references"):
        start_time = time.time()
        inputs = tokenizer(reference, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Cast before .numpy(): NumPy has no bfloat16, so half-precision
        # model outputs would otherwise fail to convert.
        embeddings.append(outputs.pooler_output.float().cpu().numpy())
        data_times.append(time.time() - start_time)

    avg_time_per_sample = np.mean(data_times)
    print(f"Average time per sample for extracting reference embeddings: {avg_time_per_sample:.4f} seconds")

    return np.vstack(embeddings), avg_time_per_sample

# Load data
def load_data(non_infringement_file, infringement_file):
    """Read the two labelled JSON datasets.

    Each file must hold a JSON list of objects with ``'input'`` and
    ``'reference'`` keys. Non-infringement samples get label 1 and
    infringement samples get label 0.

    Returns:
        A 6-tuple of lists, in this order:
        ``(non_infringement_inputs, non_infringement_references,
        non_infringement_labels, infringement_inputs,
        infringement_references, infringement_labels)``.
    """
    def read_split(path, label):
        # Load one file and split it into inputs, references and constant labels.
        with open(path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
        inputs = [record['input'] for record in records]
        references = [record['reference'] for record in records]
        return inputs, references, [label] * len(inputs)

    return (*read_split(non_infringement_file, 1), *read_split(infringement_file, 0))

# Predict method (similar to training, but no backpropagation)
def predict_model(model, X_test, batch_size=4, threshold=0.5):
    """Classify each row of ``X_test`` and report the mean per-row latency.

    Rows are scored one at a time so that a per-sample time can be measured.
    ``batch_size`` is accepted for interface compatibility but is not used.

    Args:
        model: Module mapping a ``(1, n_features)`` float tensor to logits.
        X_test: Non-empty 2-D array-like of features.
        batch_size: Unused.
        threshold: A row is predicted positive when ``sigmoid(logit) > threshold``.

    Returns:
        ``(predictions, avg_time_per_sample)``. ``predictions`` is a boolean
        array that stacks one model output row per sample; with ``CustomMLP``
        its shape is ``(len(X_test), 1)``.

    Raises:
        ValueError: If ``X_test`` is empty.
    """
    if len(X_test) == 0:
        raise ValueError("predict_model() needs at least one sample")

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Put the model on the same device as the inputs; otherwise a model left
    # on the CPU would fail on GPU tensors.
    model.to(device)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    predictions = []
    data_times = []  # seconds spent on each sample

    for i in tqdm(range(len(X_test)), desc="Predicting data points"):
        start_time = time.time()
        batch_data = X_test_tensor[i:i+1]  # keep the batch dimension
        with torch.no_grad():
            logits = model(batch_data)
            probabilities = torch.sigmoid(logits)
            predictions.append((probabilities > threshold).cpu().numpy())
        data_times.append(time.time() - start_time)

    avg_time_per_sample = np.mean(data_times)
    print(f"Average time per sample for prediction: {avg_time_per_sample:.4f} seconds")

    return np.vstack(predictions), avg_time_per_sample

# Main prediction process
def main_predict(non_infringement_file, infringement_file, checkpoint_file):
    """Evaluate the saved MLP checkpoint on the two labelled JSON datasets.

    For every sample, the features are the Llama last-token hidden state of
    ``input`` concatenated with the BERT pooled embedding of ``reference``.
    The function prints accuracy, a classification report and the average
    end-to-end time per sample.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load and process data
    (non_infringement_outputs, non_infringement_references, y_non_infringement,
     infringement_outputs, infringement_references, y_infringement) = load_data(non_infringement_file, infringement_file)

    # Extract hidden states and reference embeddings
    print("Extracting hidden states for non_infringement texts...")
    X_non_infringement, non_inf_hidden_time = extract_hidden_states(non_infringement_outputs, model, tokenizer)
    print("Extracting reference embeddings for non_infringement texts...")
    reference_embeddings_non_infringement, non_inf_reference_time = extract_reference_embeddings(non_infringement_references, bert_model, bert_tokenizer)
    X_non_infringement_combined = np.hstack([X_non_infringement, reference_embeddings_non_infringement])

    print("Extracting hidden states for infringement texts...")
    X_infringement, inf_hidden_time = extract_hidden_states(infringement_outputs, model, tokenizer)
    print("Extracting reference embeddings for infringement texts...")
    reference_embeddings_infringement, inf_reference_time = extract_reference_embeddings(infringement_references, bert_model, bert_tokenizer)
    X_infringement_combined = np.hstack([X_infringement, reference_embeddings_infringement])

    # Combine data: non-infringement rows (label 1) first, then infringement rows (label 0).
    X_test = np.vstack((X_non_infringement_combined, X_infringement_combined))
    y_test = np.concatenate((y_non_infringement, y_infringement))

    # Load trained model. map_location lets a checkpoint saved on a GPU load on
    # a CPU-only run (e.g. when CUDA_VISIBLE_DEVICES hides all GPUs).
    custom_mlp = CustomMLP(input_dim=X_test.shape[1], hidden_dim=256)
    custom_mlp.load_state_dict(torch.load(checkpoint_file, map_location=device))
    custom_mlp.to(device)

    # Make predictions
    print("Making predictions on the test set...")
    predictions, prediction_avg_time_per_sample = predict_model(custom_mlp, X_test)
    # Flatten the stacked boolean column into 0/1 labels shaped like y_test.
    y_pred = np.asarray(predictions).ravel().astype(int)

    # Evaluate the model (label 0 = infringement, label 1 = non_infringement).
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    print("Classification Report:")
    print(classification_report(y_test, y_pred, target_names=["infringement", "non_infringement"]))

    # Feature-extraction times were averaged separately for each dataset, so
    # weight each dataset's average by its number of samples.
    n_non = len(non_infringement_outputs)
    n_inf = len(infringement_outputs)
    n_total = n_non + n_inf
    hidden_state_avg_time_per_sample = (n_non * non_inf_hidden_time + n_inf * inf_hidden_time) / n_total
    reference_embedding_avg_time_per_sample = (n_non * non_inf_reference_time + n_inf * inf_reference_time) / n_total
    avg_total_time = hidden_state_avg_time_per_sample + reference_embedding_avg_time_per_sample + prediction_avg_time_per_sample
    print(f"Average total time per sample: {avg_total_time:.4f} seconds")

if __name__ == "__main__":
    # Script entry point: evaluate the checkpoint on the module-level dataset paths.
    main_predict(
        non_infringement_file=non_infringement_file,
        infringement_file=infringement_file,
        checkpoint_file=checkpoint_file,
    )
