In [1]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch import nn
import os
import time

# Restrict this process to physical GPU 3 (it will appear as cuda:0).
# PyTorch initializes CUDA lazily, so this must run before the first CUDA call
# in the process; re-running the cell after CUDA is initialized has no effect.
os.environ.update(CUDA_VISIBLE_DEVICES="3")

# 加载训练好的 CustomMLP
class CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CustomMLP, self).__init__()
        self.down = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)
        self.up = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, x):
        down_output = self.down(x)
        gate_output = self.gate(x)
        gated_output = down_output * self.activation(gate_output)
        return self.up(gated_output)

# Extract per-text hidden states from the LLM and time the extraction.
def extract_hidden_states(texts, model, tokenizer, batch_size=4):
    """Return the final-layer hidden state of each text's last real token.

    Args:
        texts: non-empty list of strings to encode.
        model: Hugging Face causal LM; called with ``output_hidden_states=True``
            and expected to return an object with a ``hidden_states`` tuple.
        tokenizer: tokenizer matching ``model``. It needs a pad token whenever
            a batch contains texts of different lengths.
        batch_size: number of texts per forward pass.

    Returns:
        (features, extract_time): ``features`` is a float32 NumPy array with one
        row per text, in the order of ``texts``; ``extract_time`` is the
        wall-clock seconds spent in the extraction loop.

    Raises:
        ValueError: if ``texts`` is empty.
    """
    if not texts:
        raise ValueError("texts must not be empty")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # keep model and inputs on the same device
    model.eval()
    # DataParallel only helps when there are several GPUs to split a batch
    # across; on a single device it is pure scatter/gather overhead.
    runner = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model
    hidden_states = []

    start_time = time.perf_counter()

    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
        # Move every tensor produced by the tokenizer to the model's device.
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            # Request hidden states explicitly rather than relying on the
            # model config having been created with output_hidden_states=True.
            outputs = runner(**inputs, output_hidden_states=True)

        # Last layer, indexed as [batch, position, hidden].
        last_layer = outputs.hidden_states[-1]
        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            last_token_hidden_states = last_layer[:, -1, :]
        else:
            # Index of the last non-padding token per row. Plain [:, -1] would
            # read a PAD position for shorter texts under right padding; this
            # works for left and right padding alike.
            positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
            last_idx = (attention_mask.long() * positions).max(dim=1).values
            rows = torch.arange(last_layer.shape[0], device=last_layer.device)
            last_token_hidden_states = last_layer[rows, last_idx.to(last_layer.device)]

        # .float() first: NumPy has no bfloat16, and half-precision models would
        # otherwise fail or lose precision here.
        hidden_states.append(last_token_hidden_states.float().cpu().numpy())

    extract_time = time.perf_counter() - start_time
    print(f"Time taken to extract hidden states: {extract_time:.4f} seconds")

    return np.vstack(hidden_states), extract_time

# Inference helper for the trained classifier.
def predict_model(model, X_test, threshold=0.5):
    """Classify feature rows with ``model`` and time the inference.

    Args:
        model: module returning one logit per input row (e.g. ``CustomMLP``).
        X_test: array-like of feature rows; converted to float32.
        threshold: probability cut-off; a row is labelled 1.0 when
            ``sigmoid(logit) > threshold``, otherwise 0.0.

    Returns:
        (predictions, prediction_time): ``predictions`` is a float NumPy array
        of 0.0/1.0 with the same shape as the model output; ``prediction_time``
        is the wall-clock seconds for forward pass, sigmoid and copy to CPU.
    """
    model.eval()
    # Send the inputs to wherever the model's weights live. The old code
    # assumed CUDA whenever it was available, which breaks for a CPU model.
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameter-less model: fall back to the old guess
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X_test_tensor = torch.as_tensor(X_test, dtype=torch.float32, device=device)

    start_time = time.perf_counter()

    with torch.no_grad():
        logits = model(X_test_tensor)
        probabilities = torch.sigmoid(logits)
        # The blocking .cpu() copy waits for the GPU, so the timing below
        # covers the actual kernel execution.
        predictions = (probabilities > threshold).float().cpu().numpy()

    prediction_time = time.perf_counter() - start_time
    return predictions, prediction_time

# Load both evaluation splits.
def load_all_data(non_infringement_file, infringement_file):
    """Read the two JSON splits and attach binary labels.

    Each file is parsed with ``json.load`` and iterated as a sequence of
    records that carry an ``'input'`` field; only that field is kept.

    Returns:
        (non_infringement_texts, non_infringement_labels,
         infringement_texts, infringement_labels), where the label lists are
        all 1 for the non-infringement split and all 0 for the infringement one.
    """
    def _read_inputs(path):
        # Parse one split and keep only its text field.
        with open(path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
        return [record['input'] for record in records]

    non_infringement_texts = _read_inputs(non_infringement_file)
    infringement_texts = _read_inputs(infringement_file)
    return (
        non_infringement_texts,
        [1] * len(non_infringement_texts),
        infringement_texts,
        [0] * len(infringement_texts),
    )

# Main program: evaluate the trained CustomMLP probe on held-out data.
def main(non_infringement_file, infringement_file, checkpoint_path, model_name, batch_size=4):
    """Evaluate the trained CustomMLP classifier and report accuracy and latency.

    Loads the causal LM and tokenizer ``model_name`` and extracts last-token,
    final-layer hidden states for every text in the two JSON files. It then
    classifies each sample individually so that per-sample latency can be
    measured, and prints timing statistics plus an accuracy/classification
    report.

    Args:
        non_infringement_file: JSON file of non-infringement samples (label 1).
        infringement_file: JSON file of infringement samples (label 0).
        checkpoint_path: path to the saved CustomMLP ``state_dict``.
        model_name: Hugging Face model id or local path of the base LLM.
        batch_size: texts per forward pass during hidden-state extraction.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
    model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)

    # Batched tokenization needs a pad token. Reuse EOS instead of adding a new
    # '[PAD]' special token: a new token would enlarge the vocabulary beyond
    # the model's (un-resized) embedding matrix.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Features are taken from the final sequence position; left padding keeps
    # each text's real last token there whenever a batch needs padding.
    tokenizer.padding_side = "left"

    # Keep the classifier and its inputs on the same device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Size the probe from the LLM's hidden width instead of hard-coding 4096
    # (kept as the fallback). hidden_dim must match the saved checkpoint.
    input_dim = getattr(model.config, "hidden_size", 4096)
    custom_mlp = CustomMLP(input_dim=input_dim, hidden_dim=256)
    # map_location: a GPU-saved checkpoint also loads on a CPU-only host.
    # weights_only: load tensors only, without unpickling arbitrary objects.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    custom_mlp.load_state_dict(state_dict)
    custom_mlp.to(device)

    (non_infringement_outputs, y_non_infringement,
     infringement_outputs, y_infringement) = load_all_data(non_infringement_file, infringement_file)

    print("Extracting hidden states for non_infringement texts...")
    X_non_infringement, extract_time_non_infringement = extract_hidden_states(
        non_infringement_outputs, model, tokenizer, batch_size)

    print("Extracting hidden states for infringement texts...")
    X_infringement, extract_time_infringement = extract_hidden_states(
        infringement_outputs, model, tokenizer, batch_size)

    # Stack both splits into one test set; label order follows the row order.
    X_test = np.vstack((X_non_infringement, X_infringement))
    y_test = np.concatenate((y_non_infringement, y_infringement))
    total_samples = len(X_test)

    # Predict one sample at a time on purpose: the goal is per-sample latency.
    print("Predicting on test set...")
    predictions = []
    total_prediction_time = 0.0
    for i in tqdm(range(total_samples), desc="Predicting samples"):
        single_prediction, prediction_time = predict_model(custom_mlp, X_test[i:i + 1], threshold=0.5)
        predictions.append(single_prediction)
        total_prediction_time += prediction_time

    average_prediction_time = total_prediction_time / total_samples
    print(f"Average prediction time per sample: {average_prediction_time:.6f} seconds")

    total_time = extract_time_non_infringement + extract_time_infringement + total_prediction_time
    average_total_time = total_time / total_samples
    print(f"Average total time per sample (extraction + prediction): {average_total_time:.6f} seconds")

    # Flatten the (N, 1) float predictions into integer labels for sklearn.
    predictions = np.concatenate(predictions).ravel().astype(int)
    accuracy = accuracy_score(y_test, predictions)
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    print("Classification Report:")
    # Label 0 = infringement, 1 = non-infringement (as assigned by
    # load_all_data); pin the label order so target_names always line up.
    print(classification_report(y_test, predictions, labels=[0, 1],
                                target_names=["Infringement", "Non-Infringement"]))

if __name__ == "__main__":
    # Evaluation inputs: held-out test splits, the trained probe checkpoint,
    # and the base LLM used for hidden-state extraction.
    model_name = 'meta-llama/Meta-Llama-3.1-8B'
    non_infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.non_infringement.json'
    infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.infringement.json'
    checkpoint_path = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_last_token.pth'

    # One text per forward pass, so tokenization never has to pad.
    main(
        non_infringement_file=non_infringement_file,
        infringement_file=infringement_file,
        checkpoint_path=checkpoint_path,
        model_name=model_name,
        batch_size=1,
    )


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it]
  custom_mlp.load_state_dict(torch.load(checkpoint_path))


Extracting hidden states for non_infringement texts...


Processing data batches: 100%|██████████| 743/743 [00:43<00:00, 17.10it/s]


Time taken to extract hidden states: 43.4474 seconds
Extracting hidden states for infringement texts...


Processing data batches: 100%|██████████| 706/706 [00:37<00:00, 18.61it/s]


Time taken to extract hidden states: 37.9508 seconds
Predicting on test set...


Predicting samples: 100%|██████████| 1449/1449 [00:00<00:00, 2950.40it/s]


Average prediction time per sample: 0.000246 seconds
Average total time per sample (extraction + prediction): 0.056421 seconds
Test Accuracy: 66.32%
Classification Report:
                  precision    recall  f1-score   support

    Infringement       0.77      0.44      0.56       706
Non-Infringement       0.62      0.88      0.73       743

        accuracy                           0.66      1449
       macro avg       0.70      0.66      0.64      1449
    weighted avg       0.69      0.66      0.65      1449

