In [None]:

import torch
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch import nn
import os
import time

# Use GPU: expose only the devices listed in CUDA_VISIBLE_DEVICES ("1,2,3,4") to this process.
# NOTE(review): this assignment runs after `import torch` above; presumably it still takes
# effect because CUDA has not been initialized at this point — consider setting it before
# importing torch to be safe.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"

# Define the trained CustomMLP model
class CustomMLP(nn.Module):
    """Gated MLP head that maps a feature vector to a single output unit.

    The input is projected twice in parallel (``down`` and ``gate``); the
    ``down`` projection is multiplied element-wise by ``SiLU(gate(x))`` and the
    product is projected to one value by ``up``.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the two parallel hidden projections.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # The attribute names double as state-dict keys, so they must not be
        # renamed if existing checkpoints are to keep loading.
        self.down = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)
        self.up = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, x):
        """Return the raw output of shape ``(..., 1)`` for input ``x``."""
        gate_weights = self.activation(self.gate(x))
        return self.up(self.down(x) * gate_weights)

# Load a large model and distribute it across multiple GPUs
def load_large_model(model_name):
    """Load a causal language model and spread it over the visible GPUs.

    Args:
        model_name: Model identifier or local path passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model, configured to return hidden states on every forward pass.
    """
    print("Loading model across multiple GPUs...")
    load_options = {
        # Automatically balance the layers across multiple GPUs.
        "device_map": "balanced",
        # Offload parts of the model to disk if memory is insufficient.
        "offload_folder": "offload",
        "offload_state_dict": True,
        # The caller reads `outputs.hidden_states`, so they must be returned.
        "output_hidden_states": True,
    }
    llm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    print("Model loaded successfully.")
    return llm

# Extract the hidden states of texts and measure execution time
def extract_hidden_states(texts, model, tokenizer, batch_size=4):
    """Extract the last-layer hidden state of the final real token of each text.

    Texts are tokenized in batches with padding and truncation, run through
    ``model`` without gradients, and the hidden state of the last non-padding
    token of every sequence is collected.

    Args:
        texts: Sequence of input strings.
        model: Model whose forward output exposes ``hidden_states``.
        tokenizer: Tokenizer callable returning PyTorch tensors for a list of strings.
        batch_size: Number of texts processed per forward pass.

    Returns:
        A tuple ``(features, extract_time)`` where ``features`` is a NumPy array
        with one row per text and ``extract_time`` is the elapsed wall-clock
        time of the extraction loop in seconds.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    if len(texts) == 0:
        # np.vstack([]) would raise an opaque ValueError; fail early and clearly.
        raise ValueError("extract_hidden_states() requires at least one text")

    hidden_states = []

    # Record the start time for hidden state extraction (monotonic clock).
    start_time = time.perf_counter()

    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)

        with torch.no_grad():
            outputs = model(**inputs)

        last_layer_hidden_states = outputs.hidden_states[-1]

        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            # No mask available: fall back to the final position of every sequence.
            last_token_hidden_states = last_layer_hidden_states[:, -1, :]
        else:
            # With padding, position -1 can be a pad token for the shorter
            # sequences of a batch. Pick the highest attended position instead;
            # this is correct for both left- and right-padded batches.
            attention_mask = attention_mask.to(last_layer_hidden_states.device)
            positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
            last_token_index = (attention_mask * positions).argmax(dim=1)
            row_index = torch.arange(last_layer_hidden_states.shape[0], device=attention_mask.device)
            last_token_hidden_states = last_layer_hidden_states[row_index, last_token_index]

        hidden_states.append(last_token_hidden_states.cpu().numpy())  # Ensure the data is on CPU

    # Calculate the time taken for hidden state extraction
    extract_time = time.perf_counter() - start_time
    print(f"Time taken to extract hidden states: {extract_time:.4f} seconds")

    return np.vstack(hidden_states), extract_time

# Prediction function
def predict_model(model, X_test, threshold=0.5):
    """Run ``model`` on ``X_test`` and threshold the sigmoid of its output.

    Args:
        model: A ``torch.nn.Module`` producing one logit per input row.
        X_test: Array-like of features convertible to a float32 tensor.
        threshold: Probability above which a sample is labelled ``1.0``.

    Returns:
        A tuple ``(predictions, prediction_time)``: a NumPy float array of
        0.0/1.0 labels with the same shape as the model output, and the
        elapsed time in seconds of the forward pass plus thresholding.
    """
    model.eval()

    # Run on whatever device the model actually lives on, so a CPU model on a
    # CUDA-capable machine (or a model on a non-default GPU) does not hit a
    # device mismatch. Parameter-less modules fall back to the old heuristic.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X_test_tensor = torch.as_tensor(X_test, dtype=torch.float32).to(device)

    # Record the start time (monotonic, high-resolution clock)
    start_time = time.perf_counter()

    with torch.no_grad():
        logits = model(X_test_tensor)
        probabilities = torch.sigmoid(logits)
        # Moving the result to CPU also waits for any pending GPU work,
        # so the measured interval covers the full computation.
        predictions = (probabilities > threshold).float().cpu().numpy()

    # Calculate prediction time
    prediction_time = time.perf_counter() - start_time
    return predictions, prediction_time

# Load all data
def load_all_data(non_infringement_file, infringement_file):
    """Read both JSON test files and build their texts and labels.

    Each file must contain a JSON list of objects with an ``'input'`` key.
    Non-infringement entries are labelled ``1`` and infringement entries ``0``.

    Args:
        non_infringement_file: Path to the non-infringement JSON file.
        infringement_file: Path to the infringement JSON file.

    Returns:
        A 4-tuple ``(non_infringement_texts, non_infringement_labels,
        infringement_texts, infringement_labels)``.
    """
    def read_inputs(path):
        # Collect the 'input' field of every record stored in the file.
        with open(path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
        return [record['input'] for record in records]

    non_infringement_texts = read_inputs(non_infringement_file)
    infringement_texts = read_inputs(infringement_file)

    positive_labels = [1 for _ in non_infringement_texts]
    negative_labels = [0 for _ in infringement_texts]

    return non_infringement_texts, positive_labels, infringement_texts, negative_labels

# Main function
def main(non_infringement_file, infringement_file, checkpoint_path, model_name, batch_size=4):
    """Evaluate a trained CustomMLP on hidden states extracted from a large LM.

    Loads the tokenizer, the language model and the CustomMLP checkpoint,
    extracts last-token hidden states for both test files, classifies every
    sample individually (to measure per-sample latency) and prints timing,
    accuracy and a classification report.

    Args:
        non_infringement_file: JSON file with non-infringement samples (label 1).
        infringement_file: JSON file with infringement samples (label 0).
        checkpoint_path: Path to the CustomMLP state-dict checkpoint.
        model_name: Identifier or local path of the language model.
        batch_size: Batch size used for hidden-state extraction.
    """
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
    model = load_large_model(model_name)

    # Reuse EOS as the padding token when the tokenizer has none. Registering a
    # brand-new '[PAD]' token would grow the vocabulary without resizing the
    # model's embedding matrix, so no new token is added.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Set device to ensure both model and data are on the same device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the trained CustomMLP. `map_location` lets a checkpoint saved on a
    # different GPU load here, and `weights_only=True` restricts unpickling to
    # tensors (it also silences the torch.load FutureWarning).
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    # nn.Linear stores its weight as (out_features, in_features), so the layer
    # sizes can be read from the checkpoint instead of being hard-coded.
    hidden_dim, input_dim = state_dict["down.weight"].shape
    custom_mlp = CustomMLP(input_dim=input_dim, hidden_dim=hidden_dim)
    custom_mlp.load_state_dict(state_dict)
    custom_mlp.to(device)

    # Load all data
    non_infringement_outputs, y_non_infringement, infringement_outputs, y_infringement = load_all_data(
        non_infringement_file, infringement_file
    )

    # Extract hidden states of the texts
    print("Extracting hidden states for non_infringement texts...")
    X_non_infringement, extract_time_non_infringement = extract_hidden_states(
        non_infringement_outputs, model, tokenizer, batch_size
    )

    print("Extracting hidden states for infringement texts...")
    X_infringement, extract_time_infringement = extract_hidden_states(
        infringement_outputs, model, tokenizer, batch_size
    )

    # Combine the data
    X_test = np.vstack((X_non_infringement, X_infringement))
    y_test = np.concatenate((y_non_infringement, y_infringement))

    # Track total prediction time
    total_prediction_time = 0
    total_samples = len(X_test)

    # Predict one sample at a time on purpose: the goal is to measure the
    # per-sample latency of the classifier, not batched throughput.
    print("Predicting on test set...")
    predictions = []
    for i in tqdm(range(total_samples), desc="Predicting samples"):
        single_sample = X_test[i:i + 1]
        single_prediction, prediction_time = predict_model(custom_mlp, single_sample, threshold=0.5)
        predictions.append(single_prediction)
        total_prediction_time += prediction_time

    # Calculate average prediction time per sample
    average_prediction_time = total_prediction_time / total_samples
    print(f"Average prediction time per sample: {average_prediction_time:.6f} seconds")

    # Calculate total average time (extraction + prediction)
    total_time = extract_time_non_infringement + extract_time_infringement + total_prediction_time
    average_total_time = total_time / total_samples
    print(f"Average total time per sample (extraction + prediction): {average_total_time:.6f} seconds")

    # Print results. Flatten the (N, 1) predictions to match the 1-D label vector.
    predictions = np.concatenate(predictions).ravel()
    accuracy = accuracy_score(y_test, predictions)
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    print("Classification Report:")
    # Pin the label order so each name maps to the intended class (0 = infringement).
    print(classification_report(y_test, predictions, labels=[0, 1],
                                target_names=["Infringement", "Non-Infringement"]))

if __name__ == "__main__":
    # Inputs for this evaluation run: the two test splits, the trained
    # CustomMLP checkpoint and the local snapshot of the language model.
    non_infringement_file = "/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/literal.non_infringement.json"
    infringement_file = "/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/literal.infringement.json"
    checkpoint_path = "/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/0.pth"
    model_name = "/raid/data/guangwei/huggingface/hub/models--meta-llama--Llama-3.1-70B/snapshots/349b2ddb53ce8f2849a6c168a81980ab25258dac/"

    # Extract hidden states one text at a time (batch_size=1).
    main(
        non_infringement_file=non_infringement_file,
        infringement_file=infringement_file,
        checkpoint_path=checkpoint_path,
        model_name=model_name,
        batch_size=1,
    )


  from .autonotebook import tqdm as notebook_tqdm


Loading model across multiple GPUs...


Loading checkpoint shards: 100%|██████████| 30/30 [01:11<00:00,  2.38s/it]
  custom_mlp.load_state_dict(torch.load(checkpoint_path))


Model loaded successfully.
Extracting hidden states for non_infringement texts...


Processing data batches: 100%|██████████| 590/590 [21:37<00:00,  2.20s/it]


Time taken to extract hidden states: 1297.2193 seconds
Extracting hidden states for infringement texts...


Processing data batches: 100%|██████████| 168/168 [06:12<00:00,  2.22s/it]


Time taken to extract hidden states: 372.8114 seconds
Predicting on test set...


Predicting samples: 100%|██████████| 758/758 [00:00<00:00, 3391.09it/s]

Average prediction time per sample: 0.000209 seconds
Average total time per sample (extraction + prediction): 2.203416 seconds
Test Accuracy: 79.82%
Classification Report:
                  precision    recall  f1-score   support

    Infringement       0.56      0.42      0.48       168
Non-Infringement       0.85      0.91      0.87       590

        accuracy                           0.80       758
       macro avg       0.70      0.66      0.68       758
    weighted avg       0.78      0.80      0.79       758






: 