In [None]:
# import packages & variables
import argparse
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification
import json
import os
import time
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Parameters
model_name = 'meta-llama/Meta-Llama-3.1-8B'
non_infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.non_infringement.json'
infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.infringement.json'
checkpoint_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_reference_0_layer.pth'


In [None]:
# Define CustumMLP for internal states train
class CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CustomMLP, self).__init__()
        self.down = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)
        self.up = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, x):
        down_output = self.down(x)
        gate_output = self.gate(x)
        gated_output = down_output * self.activation(gate_output)
        return self.up(gated_output)

In [None]:
# Extract hidden states/reference embeddings
def extract_hidden_states(texts, model, tokenizer, batch_size=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    hidden_states = []
    
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)
        
        with torch.no_grad():
            outputs = model(**inputs)
        hidden_states.append(outputs.hidden_states[-1].mean(dim=1).cpu().numpy())
    
    return np.vstack(hidden_states)



def extract_decoding(references, model, tokenizer, bert_model, bert_tokenizer, max_length=300, batch_size=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    bert_model.to(device)
    bert_model = nn.DataParallel(bert_model)
    
    outputs = []
    ttft_list = []
    tpot_list = []
    sysprompt = "You will be shown a series of passages from famous literary works. After these examples, you will receive a prefix from another passage and be asked to complete it based on the text of a famous work. Provide only the continuation for the last given prefix without any extra commentary, formatting, or additional text."
    
    for i in tqdm(range(0, len(references), batch_size), desc="Processing references"):
        batch_references = references[i:i + batch_size]
        # Create the input prompt by concatenating sysprompt and the actual reference text
        prompt = sysprompt + " " + batch_references[0]
        
        # Tokenize and prepare input for generation model
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, padding_side='left').to(device)
        tokenizer.pad_token_id = tokenizer.eos_token_id

        with torch.no_grad():
            # Generate sequences
            start_time = time.time()
            generated_ids = model.generate(
                **inputs, max_new_tokens=20, output_hidden_states=False, return_dict_in_generate=True, use_cache=True
            )
            ttft = time.time() - start_time
            num_tokens_generated = generated_ids["sequences"].shape[1]
            tpot = ttft / num_tokens_generated if num_tokens_generated > 0 else float('inf')

            # Decode the generated sequences to text
            generated_texts = tokenizer.batch_decode(generated_ids["sequences"], skip_special_tokens=True)
            
            # Remove the prompt part from generated_texts if it's included
            generated_texts_cleaned = []
            for text in generated_texts:
                # Remove the prompt portion
                text = text[len(prompt):].strip()
                generated_texts_cleaned.append(text)
            
            # Print the cleaned generated texts
            # print("Generated Continuations:", generated_texts_cleaned)

            # Tokenize the cleaned text for BERT model
            bert_inputs = bert_tokenizer(generated_texts_cleaned, return_tensors="pt", padding=True, truncation=True).to(device)
            bert_outputs = bert_model(**bert_inputs)
            
            # Extract BERT embeddings
            embedding = bert_outputs.pooler_output.cpu().numpy()
        
        outputs.append(embedding)
        ttft_list.append(ttft)
        tpot_list.append(tpot)

    # Calculate averages for TTFT and TPOT
    average_ttft = np.mean(ttft_list)
    average_tpot = np.mean(tpot_list)
    
    return np.vstack(outputs), average_ttft, average_tpot

In [None]:
# load data for infringement & non infringement
def load_data(non_infringement_file, infringement_file):
    with open(non_infringement_file, 'r', encoding='utf-8') as file:
        non_infringement_json_data = json.load(file)

    non_infringement_outputs = [entry['input'] for entry in non_infringement_json_data]
    non_infringement_references = [entry['reference'] for entry in non_infringement_json_data]
    y_non_infringement = [1] * len(non_infringement_outputs)

    with open(infringement_file, 'r', encoding='utf-8') as file:
        infringement_json_data = json.load(file)

    infringement_outputs = [entry['input'] for entry in infringement_json_data]
    infringement_references = [entry['reference'] for entry in infringement_json_data]
    y_infringement = [0] * len(infringement_outputs)

    return non_infringement_outputs, non_infringement_references, y_non_infringement, infringement_outputs, infringement_references, y_infringement

In [None]:
from sklearn.metrics import accuracy_score, classification_report, f1_score

# Train for best model
def train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim, epochs=2000, lr=0.001, checkpoint_path=checkpoint_file):
    custom_mlp = CustomMLP(input_dim, hidden_dim)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(custom_mlp.parameters(), lr=lr)

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

    best_accuracy = -float('inf')
    best_f1 = -float('inf')  # Track best F1-score
    best_model_state = None
    best_epoch = 0
    losses = []

    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        custom_mlp.train()
        optimizer.zero_grad()
        outputs = custom_mlp(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Every 10 epochs, evaluate the model
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")
            
            custom_mlp.eval()
            X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
            with torch.no_grad():
                y_pred_logits = custom_mlp(X_test_tensor)
                y_pred = (torch.sigmoid(y_pred_logits) > 0.5).float().numpy()
            
            # Calculate accuracy and F1-score
            accuracy = accuracy_score(y_test, y_pred)
            f1 = f1_score(y_test, y_pred)  # F1-score calculation
            
            print(f"Test Accuracy at Epoch {epoch + 1}: {accuracy * 100:.2f}%")
            print(f"Test F1-score at Epoch {epoch + 1}: {f1:.4f}")
            
            # Generate classification report
            report = classification_report(y_test, y_pred, target_names=["infringement", "non_infringement"])
            print(f"Classification Report at Epoch {epoch + 1}:\n{report}")

            # Save the model if it achieves a better F1-score
            if f1 > best_f1:
                best_accuracy = accuracy
                best_f1 = f1
                best_model_state = custom_mlp.state_dict()
                best_epoch = epoch + 1
                torch.save(best_model_state, checkpoint_path)
                print(f"New best model saved with F1-score {best_f1:.4f} at epoch {best_epoch}")
                print(f"Best Classification Report at Epoch {best_epoch}:\n{report}")

    # Load the best model state
    custom_mlp.load_state_dict(torch.load(checkpoint_path))

    # Plot loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.show()

    print(f"Final Model Accuracy: {best_accuracy * 100:.2f}%")
    print(f"Final Model F1-score: {best_f1:.4f}")
    
    return custom_mlp, losses, best_accuracy, best_f1


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
tokenizer.pad_token = tokenizer.eos_token
bert_tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
bert_model = AutoModel.from_pretrained('google-bert/bert-base-uncased')
bert_tokenizer.pad_token = tokenizer.eos_token

non_infringement_outputs, non_infringement_references, y_non_infringement, infringement_outputs, infringement_references, y_infringement = load_data(
    non_infringement_file, infringement_file
)

y_non_infringement = np.array(y_non_infringement)
y_infringement = np.array(y_infringement)


In [23]:
# print("Extracting hidden states for non_infringement texts...")
# X_non_infringement = extract_hidden_states(non_infringement_outputs, model, tokenizer)
# print("Extracting reference embeddings for non_infringement texts...")
# reference_embeddings_non_infringement = extract_reference_embeddings(non_infringement_references, bert_model, bert_tokenizer)
# X_non_infringement_combined = np.hstack([X_non_infringement, reference_embeddings_non_infringement])

# print("Extracting hidden states for infringement texts...")
# X_infringement = extract_hidden_states(infringement_outputs, model, tokenizer)
# print("Extracting reference embeddings for infringement texts...")
# reference_embeddings_infringement = extract_reference_embeddings(infringement_references, bert_model, bert_tokenizer)
# X_infringement_combined = np.hstack([X_infringement, reference_embeddings_infringement])


# 在主程序中
print("Extracting hidden states for non_infringement texts...")
X_non_infringement = extract_hidden_states(non_infringement_outputs, model, tokenizer)
print("Extracting decoding for non_infringement texts...")
last_token_hidden_states_non_infringement, totaltime_non_infringement, tpot_non_infringement = extract_decoding(non_infringement_outputs, model, tokenizer, bert_model, bert_tokenizer)

if (X_non_infringement.shape[0]!=last_token_hidden_states_non_infringement.shape[0]):
    print("X_non_infringement.shape[0]:", X_non_infringement.shape[0])
    print("last_token_hidden_states_non_infringement.shape[0]:", last_token_hidden_states_non_infringement.shape[0])
    # 对齐两个数组的行数
    min_rows = min(X_non_infringement.shape[0], last_token_hidden_states_non_infringement.shape[0])

    # 裁剪两个数组到相同的行数
    X_non_infringement_aligned = X_non_infringement[:min_rows, :]
    last_token_hidden_states_non_infringement_aligned = last_token_hidden_states_non_infringement[:min_rows, :]
else:
    X_non_infringement_aligned = X_non_infringement
    last_token_hidden_states_non_infringement_aligned = last_token_hidden_states_non_infringement

    # 进行合并
X_non_infringement_combined = np.hstack([X_non_infringement_aligned, last_token_hidden_states_non_infringement_aligned])



print("Extracting hidden states for infringement texts...")
X_infringement = extract_hidden_states(infringement_outputs, model, tokenizer)
print("Extracting decoding for infringement texts...")
last_token_hidden_states_infringement, totaltime_infringement, tpot_infringement = extract_decoding(infringement_outputs, model, tokenizer, bert_model, bert_tokenizer)

if (X_infringement.shape[0]!=last_token_hidden_states_infringement.shape[0]):
    print("X_infringement.shape[0]:", X_infringement.shape[0])
    print("last_token_hidden_states_non_infringement.shape[0]:", last_token_hidden_states_infringement.shape[0])
    # 对齐两个数组的行数
    min_rows = min(X_infringement.shape[0], last_token_hidden_states_infringement.shape[0])

    # 裁剪两个数组到相同的行数
    X_infringement_aligned = X_infringement[:min_rows, :]
    last_token_hidden_states_infringement_aligned = last_token_hidden_states_infringement[:min_rows, :]
else:
    X_infringement_aligned = X_infringement
    last_token_hidden_states_infringement_aligned = last_token_hidden_states_infringement

    # 进行合并
X_infringement_combined = np.hstack([X_infringement_aligned, last_token_hidden_states_infringement_aligned])

# 计算平均值
average_totaltime = (totaltime_non_infringement + totaltime_infringement) / 2
print("Average Total Time:", average_totaltime)
average_tpot = (tpot_non_infringement + tpot_infringement) / 2
print("Average Tpot:", average_tpot)

Processing references:  35%|███▍      | 85/243 [01:13<01:56,  1.36it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing references:  35%|███▌      | 86/243 [01:14<01:55,  1.35it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing references:  36%|███▌      | 87/243 [01:14<01:55,  1.35it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing references:  36%|███▌      | 88/243 [01:15<01:55,  1.35it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing references:  37%|███▋      | 89/243 [01:16<01:54,  1.35it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing references:  37%|███▋      | 90/243 [01:17<01:58,  1.29it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Processing references:  37%|███▋      | 91/243 [01:18<01:55,  1.31it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Proces

X_infringement.shape[0]: 970
last_token_hidden_states_non_infringement.shape[0]: 243
Average Total Time: 0.7300534809910356
Average Tpot: 0.005586614325218218





In [24]:
split_index_non_infringement = int(0.8 * len(X_non_infringement_combined))
X_non_infringement_train = X_non_infringement_combined[:split_index_non_infringement]
X_non_infringement_test = X_non_infringement_combined[split_index_non_infringement:]
y_non_infringement_train = y_non_infringement[:split_index_non_infringement]
y_non_infringement_test = y_non_infringement[split_index_non_infringement:]

split_index_infringement = int(0.8 * len(X_infringement_combined))
X_infringement_train = X_infringement_combined[:split_index_infringement]
X_infringement_test = X_infringement_combined[split_index_infringement:]
y_infringement_train = y_infringement[:split_index_infringement]
y_infringement_test = y_infringement[split_index_infringement:]

X_train = np.vstack((X_non_infringement_train, X_infringement_train))
X_test = np.vstack((X_non_infringement_test, X_infringement_test))
y_train = np.concatenate((y_non_infringement_train, y_infringement_train))
y_test = np.concatenate((y_non_infringement_test, y_infringement_test))

print("Data successfully split into training and test sets.")

Data successfully split into training and test sets.


In [25]:
input_dim = X_train.shape[1]
hidden_dim = 4096
print(f"Training MLP model with input_dim={input_dim} and hidden_dim={hidden_dim}")

custom_mlp, losses, best_accuracy, best_f1 = train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim)

Training MLP model with input_dim=4864 and hidden_dim=4096


Training Epochs:   0%|          | 9/2000 [00:04<17:19,  1.92it/s]

Epoch 10/2000, Loss: 0.7142





ValueError: Found input variables with inconsistent numbers of samples: [1519, 96]

In [None]:
def save_checkpoint(model, optimizer, epoch, loss, filepath):
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to '{filepath}'.")

save_checkpoint(custom_mlp, torch.optim.Adam(custom_mlp.parameters()), len(losses), losses[-1], checkpoint_file)

In [None]:
y_pred_final = (torch.sigmoid(torch.tensor(custom_mlp(torch.tensor(X_test, dtype=torch.float32)))) > 0.5).float().numpy()
print(classification_report(y_test, y_pred_final, target_names=["infringement", "non_infringement"]))

In [None]:
print(totaltime_non_infringement)
print(tpot_non_infringement)

print(totaltime_infringement)
print(tpot_infringement)