In [16]:
# import packages
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# Variables
model_name = 'meta-llama/Meta-Llama-3.1-8B'
non_infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_10.non_infringement.json'
infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_10.infringement.json'
checkpoint_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_last_token.pth'

In [17]:
# Define CustumMLP for internal states train
class CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CustomMLP, self).__init__()
        self.down = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)
        self.up = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, x):
        down_output = self.down(x)
        gate_output = self.gate(x)
        gated_output = down_output * self.activation(gate_output)
        return self.up(gated_output)

In [18]:
def extract_hidden_states(texts, model, tokenizer, batch_size=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = nn.DataParallel(model)
    hidden_states = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        
        # 访问最后一个隐藏层的最后一个token的隐藏状态
        # hidden_states[-1]表示最后一个隐藏层，mean(dim=1)表示取所有头的均值
        last_layer_hidden_states = outputs.hidden_states[-1]
        last_token_hidden_states = last_layer_hidden_states[:, -1, :]  # -1表示最后一个token
        hidden_states.append(last_token_hidden_states.cpu().numpy())
    return np.vstack(hidden_states)

In [19]:
# lode data for infringement & non infringement
def load_data(non_infringement_file, infringement_file):
    with open(non_infringement_file, 'r', encoding='utf-8') as file:
        non_infringement_json_data = json.load(file)

    non_infringement_outputs = [entry['input'] for entry in non_infringement_json_data]
    y_non_infringement = [1] * len(non_infringement_outputs)

    with open(infringement_file, 'r', encoding='utf-8') as file:
        infringement_json_data = json.load(file)

    infringement_outputs = [entry['input'] for entry in infringement_json_data]
    y_infringement = [0] * len(infringement_outputs)

    return non_infringement_outputs, y_non_infringement, infringement_outputs, y_infringement

In [20]:
from sklearn.metrics import accuracy_score, classification_report, f1_score

# Train for best model
def train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim, epochs=2500, lr=0.001, checkpoint_path=checkpoint_file):
    custom_mlp = CustomMLP(input_dim, hidden_dim)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(custom_mlp.parameters(), lr=lr)

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

    best_accuracy = -float('inf')  # Initialize the best accuracy to negative infinity
    best_f1 = -float('inf')  # Initialize the best F1-score to negative infinity
    best_model_state = None  # Store the state of the best model
    best_epoch = 0  # Track the epoch with the best accuracy
    losses = []

    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        custom_mlp.train()
        optimizer.zero_grad()
        outputs = custom_mlp(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")
            
            custom_mlp.eval()
            X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
            with torch.no_grad():
                y_pred_logits = custom_mlp(X_test_tensor)
                y_pred = (torch.sigmoid(y_pred_logits) > 0.5).float().numpy()
            
            accuracy = accuracy_score(y_test, y_pred)
            f1 = f1_score(y_test, y_pred)  # Calculate F1-score
            print(f"Test Accuracy at Epoch {epoch + 1}: {accuracy * 100:.2f}%")
            print(f"Test F1-score at Epoch {epoch + 1}: {f1:.4f}")
            
            report = classification_report(y_test, y_pred, target_names=["infringement", "non_infringement"])
            print(f"Classification Report at Epoch {epoch + 1}:\n{report}")

            # Check if the current model is the best based on F1-score
            if f1 > best_f1:
                best_accuracy = accuracy
                best_f1 = f1
                best_model_state = custom_mlp.state_dict()
                best_epoch = epoch + 1
                torch.save(best_model_state, checkpoint_path)
                print(f"New best model saved with F1-score {best_f1:.4f} at epoch {best_epoch}")
                print(f"Best Classification Report at Epoch {best_epoch}:\n{report}")

    # Load the best model state
    custom_mlp.load_state_dict(torch.load(checkpoint_path))

    # Plot loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.show()

    print(f"Best Model was saved at epoch {best_epoch} with F1-score {best_f1:.4f} and accuracy {best_accuracy * 100:.2f}%")
    return custom_mlp, losses, best_accuracy, best_f1


In [21]:

tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
tokenizer.pad_token = tokenizer.eos_token

non_infringement_outputs, y_non_infringement, infringement_outputs, y_infringement = load_data(non_infringement_file, infringement_file)

y_non_infringement = np.array(y_non_infringement)
y_infringement = np.array(y_infringement)


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.19s/it]


In [22]:
print("Extracting hidden states for non_infringement texts...")
X_non_infringement = extract_hidden_states(non_infringement_outputs, model, tokenizer)

print("Extracting hidden states for infringement texts...")
X_infringement = extract_hidden_states(infringement_outputs, model, tokenizer)

Extracting hidden states for non_infringement texts...


Processing data batches: 100%|██████████| 65/65 [00:15<00:00,  4.08it/s]


Extracting hidden states for infringement texts...


Processing data batches: 100%|██████████| 81/81 [00:23<00:00,  3.38it/s]


In [23]:
split_index_non_infringement = int(0.8 * len(X_non_infringement))
X_non_infringement_train = X_non_infringement[:split_index_non_infringement]
X_non_infringement_test = X_non_infringement[split_index_non_infringement:]
y_non_infringement_train = y_non_infringement[:split_index_non_infringement]
y_non_infringement_test = y_non_infringement[split_index_non_infringement:]

split_index_infringement = int(0.8 * len(X_infringement))
X_infringement_train = X_infringement[:split_index_infringement]
X_infringement_test = X_infringement[split_index_infringement:]
y_infringement_train = y_infringement[:split_index_infringement]
y_infringement_test = y_infringement[split_index_infringement:]

X_train = np.vstack((X_non_infringement_train, X_infringement_train))
X_test = np.vstack((X_non_infringement_test, X_infringement_test))
y_train = np.concatenate((y_non_infringement_train, y_infringement_train))
y_test = np.concatenate((y_non_infringement_test, y_infringement_test))

print("Data successfully split into training and test sets.")


Data successfully split into training and test sets.


In [None]:
input_dim = X_train.shape[1]
hidden_dim = 256

custom_mlp, losses, best_accuracy, best_f1 = train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim)

Training Epochs:   0%|          | 10/2500 [00:04<17:00,  2.44it/s]

Epoch 10/2500, Loss: 0.4185
Test Accuracy at Epoch 10: 78.63%
Test F1-score at Epoch 10: 0.8062
Classification Report at Epoch 10:
                  precision    recall  f1-score   support

    infringement       1.00      0.62      0.76        65
non_infringement       0.68      1.00      0.81        52

        accuracy                           0.79       117
       macro avg       0.84      0.81      0.78       117
    weighted avg       0.86      0.79      0.78       117

New best model saved with F1-score 0.8062 at epoch 10
Best Classification Report at Epoch 10:
                  precision    recall  f1-score   support

    infringement       1.00      0.62      0.76        65
non_infringement       0.68      1.00      0.81        52

        accuracy                           0.79       117
       macro avg       0.84      0.81      0.78       117
    weighted avg       0.86      0.79      0.78       117



Training Epochs:   1%|          | 20/2500 [00:08<16:18,  2.53it/s]

Epoch 20/2500, Loss: 0.3664
Test Accuracy at Epoch 20: 71.79%
Test F1-score at Epoch 20: 0.7130
Classification Report at Epoch 20:
                  precision    recall  f1-score   support

    infringement       0.80      0.66      0.72        65
non_infringement       0.65      0.79      0.71        52

        accuracy                           0.72       117
       macro avg       0.72      0.72      0.72       117
    weighted avg       0.73      0.72      0.72       117



Training Epochs:   1%|          | 30/2500 [00:12<16:12,  2.54it/s]

Epoch 30/2500, Loss: 0.2659
Test Accuracy at Epoch 30: 80.34%
Test F1-score at Epoch 30: 0.7723
Classification Report at Epoch 30:
                  precision    recall  f1-score   support

    infringement       0.81      0.85      0.83        65
non_infringement       0.80      0.75      0.77        52

        accuracy                           0.80       117
       macro avg       0.80      0.80      0.80       117
    weighted avg       0.80      0.80      0.80       117



Training Epochs:   2%|▏         | 40/2500 [00:16<16:46,  2.44it/s]

Epoch 40/2500, Loss: 0.1406
Test Accuracy at Epoch 40: 88.03%
Test F1-score at Epoch 40: 0.8704
Classification Report at Epoch 40:
                  precision    recall  f1-score   support

    infringement       0.92      0.86      0.89        65
non_infringement       0.84      0.90      0.87        52

        accuracy                           0.88       117
       macro avg       0.88      0.88      0.88       117
    weighted avg       0.88      0.88      0.88       117

New best model saved with F1-score 0.8704 at epoch 40
Best Classification Report at Epoch 40:
                  precision    recall  f1-score   support

    infringement       0.92      0.86      0.89        65
non_infringement       0.84      0.90      0.87        52

        accuracy                           0.88       117
       macro avg       0.88      0.88      0.88       117
    weighted avg       0.88      0.88      0.88       117



Training Epochs:   2%|▏         | 50/2500 [00:19<16:12,  2.52it/s]

Epoch 50/2500, Loss: 0.0703
Test Accuracy at Epoch 50: 88.89%
Test F1-score at Epoch 50: 0.8807
Classification Report at Epoch 50:
                  precision    recall  f1-score   support

    infringement       0.93      0.86      0.90        65
non_infringement       0.84      0.92      0.88        52

        accuracy                           0.89       117
       macro avg       0.89      0.89      0.89       117
    weighted avg       0.89      0.89      0.89       117

New best model saved with F1-score 0.8807 at epoch 50
Best Classification Report at Epoch 50:
                  precision    recall  f1-score   support

    infringement       0.93      0.86      0.90        65
non_infringement       0.84      0.92      0.88        52

        accuracy                           0.89       117
       macro avg       0.89      0.89      0.89       117
    weighted avg       0.89      0.89      0.89       117



Training Epochs:   2%|▏         | 60/2500 [00:23<16:00,  2.54it/s]

Epoch 60/2500, Loss: 0.0410
Test Accuracy at Epoch 60: 88.03%
Test F1-score at Epoch 60: 0.8704
Classification Report at Epoch 60:
                  precision    recall  f1-score   support

    infringement       0.92      0.86      0.89        65
non_infringement       0.84      0.90      0.87        52

        accuracy                           0.88       117
       macro avg       0.88      0.88      0.88       117
    weighted avg       0.88      0.88      0.88       117



Training Epochs:   3%|▎         | 70/2500 [00:27<16:08,  2.51it/s]

Epoch 70/2500, Loss: 0.0296
Test Accuracy at Epoch 70: 88.03%
Test F1-score at Epoch 70: 0.8704
Classification Report at Epoch 70:
                  precision    recall  f1-score   support

    infringement       0.92      0.86      0.89        65
non_infringement       0.84      0.90      0.87        52

        accuracy                           0.88       117
       macro avg       0.88      0.88      0.88       117
    weighted avg       0.88      0.88      0.88       117



Training Epochs:   3%|▎         | 80/2500 [00:31<15:53,  2.54it/s]

Epoch 80/2500, Loss: 0.0260
Test Accuracy at Epoch 80: 88.89%
Test F1-score at Epoch 80: 0.8807
Classification Report at Epoch 80:
                  precision    recall  f1-score   support

    infringement       0.93      0.86      0.90        65
non_infringement       0.84      0.92      0.88        52

        accuracy                           0.89       117
       macro avg       0.89      0.89      0.89       117
    weighted avg       0.89      0.89      0.89       117



Training Epochs:   4%|▎         | 90/2500 [00:35<15:58,  2.52it/s]

Epoch 90/2500, Loss: 0.0247
Test Accuracy at Epoch 90: 87.18%
Test F1-score at Epoch 90: 0.8598
Classification Report at Epoch 90:
                  precision    recall  f1-score   support

    infringement       0.90      0.86      0.88        65
non_infringement       0.84      0.88      0.86        52

        accuracy                           0.87       117
       macro avg       0.87      0.87      0.87       117
    weighted avg       0.87      0.87      0.87       117



Training Epochs:   4%|▍         | 100/2500 [00:39<15:46,  2.54it/s]

Epoch 100/2500, Loss: 0.0242
Test Accuracy at Epoch 100: 87.18%
Test F1-score at Epoch 100: 0.8598
Classification Report at Epoch 100:
                  precision    recall  f1-score   support

    infringement       0.90      0.86      0.88        65
non_infringement       0.84      0.88      0.86        52

        accuracy                           0.87       117
       macro avg       0.87      0.87      0.87       117
    weighted avg       0.87      0.87      0.87       117



Training Epochs:   4%|▍         | 110/2500 [00:43<17:33,  2.27it/s]

Epoch 110/2500, Loss: 0.0239
Test Accuracy at Epoch 110: 87.18%
Test F1-score at Epoch 110: 0.8598
Classification Report at Epoch 110:
                  precision    recall  f1-score   support

    infringement       0.90      0.86      0.88        65
non_infringement       0.84      0.88      0.86        52

        accuracy                           0.87       117
       macro avg       0.87      0.87      0.87       117
    weighted avg       0.87      0.87      0.87       117



Training Epochs:   5%|▍         | 120/2500 [00:47<16:08,  2.46it/s]

Epoch 120/2500, Loss: 0.0237
Test Accuracy at Epoch 120: 87.18%
Test F1-score at Epoch 120: 0.8598
Classification Report at Epoch 120:
                  precision    recall  f1-score   support

    infringement       0.90      0.86      0.88        65
non_infringement       0.84      0.88      0.86        52

        accuracy                           0.87       117
       macro avg       0.87      0.87      0.87       117
    weighted avg       0.87      0.87      0.87       117



Training Epochs:   5%|▌         | 130/2500 [00:51<15:31,  2.54it/s]

Epoch 130/2500, Loss: 0.0236
Test Accuracy at Epoch 130: 87.18%
Test F1-score at Epoch 130: 0.8598
Classification Report at Epoch 130:
                  precision    recall  f1-score   support

    infringement       0.90      0.86      0.88        65
non_infringement       0.84      0.88      0.86        52

        accuracy                           0.87       117
       macro avg       0.87      0.87      0.87       117
    weighted avg       0.87      0.87      0.87       117



Training Epochs:   6%|▌         | 140/2500 [00:54<15:35,  2.52it/s]

Epoch 140/2500, Loss: 0.0235
Test Accuracy at Epoch 140: 87.18%
Test F1-score at Epoch 140: 0.8598
Classification Report at Epoch 140:
                  precision    recall  f1-score   support

    infringement       0.90      0.86      0.88        65
non_infringement       0.84      0.88      0.86        52

        accuracy                           0.87       117
       macro avg       0.87      0.87      0.87       117
    weighted avg       0.87      0.87      0.87       117



Training Epochs:   6%|▌         | 150/2500 [00:58<15:25,  2.54it/s]

Epoch 150/2500, Loss: 0.0235
Test Accuracy at Epoch 150: 85.47%
Test F1-score at Epoch 150: 0.8381
Classification Report at Epoch 150:
                  precision    recall  f1-score   support

    infringement       0.88      0.86      0.87        65
non_infringement       0.83      0.85      0.84        52

        accuracy                           0.85       117
       macro avg       0.85      0.85      0.85       117
    weighted avg       0.86      0.85      0.85       117



Training Epochs:   6%|▋         | 160/2500 [01:02<15:22,  2.54it/s]

Epoch 160/2500, Loss: 0.0234
Test Accuracy at Epoch 160: 85.47%
Test F1-score at Epoch 160: 0.8381
Classification Report at Epoch 160:
                  precision    recall  f1-score   support

    infringement       0.88      0.86      0.87        65
non_infringement       0.83      0.85      0.84        52

        accuracy                           0.85       117
       macro avg       0.85      0.85      0.85       117
    weighted avg       0.86      0.85      0.85       117



Training Epochs:   7%|▋         | 170/2500 [01:06<15:34,  2.49it/s]

Epoch 170/2500, Loss: 0.0234
Test Accuracy at Epoch 170: 85.47%
Test F1-score at Epoch 170: 0.8381
Classification Report at Epoch 170:
                  precision    recall  f1-score   support

    infringement       0.88      0.86      0.87        65
non_infringement       0.83      0.85      0.84        52

        accuracy                           0.85       117
       macro avg       0.85      0.85      0.85       117
    weighted avg       0.86      0.85      0.85       117



Training Epochs:   7%|▋         | 180/2500 [01:10<15:16,  2.53it/s]

Epoch 180/2500, Loss: 0.0234
Test Accuracy at Epoch 180: 84.62%
Test F1-score at Epoch 180: 0.8302
Classification Report at Epoch 180:
                  precision    recall  f1-score   support

    infringement       0.87      0.85      0.86        65
non_infringement       0.81      0.85      0.83        52

        accuracy                           0.85       117
       macro avg       0.84      0.85      0.84       117
    weighted avg       0.85      0.85      0.85       117



Training Epochs:   8%|▊         | 190/2500 [01:14<15:08,  2.54it/s]

Epoch 190/2500, Loss: 0.0233
Test Accuracy at Epoch 190: 84.62%
Test F1-score at Epoch 190: 0.8302
Classification Report at Epoch 190:
                  precision    recall  f1-score   support

    infringement       0.87      0.85      0.86        65
non_infringement       0.81      0.85      0.83        52

        accuracy                           0.85       117
       macro avg       0.84      0.85      0.84       117
    weighted avg       0.85      0.85      0.85       117



Training Epochs:   8%|▊         | 200/2500 [01:18<15:05,  2.54it/s]

Epoch 200/2500, Loss: 0.0233
Test Accuracy at Epoch 200: 84.62%
Test F1-score at Epoch 200: 0.8302
Classification Report at Epoch 200:
                  precision    recall  f1-score   support

    infringement       0.87      0.85      0.86        65
non_infringement       0.81      0.85      0.83        52

        accuracy                           0.85       117
       macro avg       0.84      0.85      0.84       117
    weighted avg       0.85      0.85      0.85       117



Training Epochs:   8%|▊         | 205/2500 [01:19<14:34,  2.62it/s]

In [None]:
def save_checkpoint(model, optimizer, epoch, loss, filepath):
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to '{filepath}'.")

save_checkpoint(custom_mlp, torch.optim.Adam(custom_mlp.parameters()), len(losses), losses[-1], checkpoint_file)

Checkpoint saved to '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_last_token.pth'.


In [None]:
print(f"Final Model Accuracy: {best_accuracy * 100:.2f}%")
y_pred_final = (torch.sigmoid(torch.tensor(custom_mlp(torch.tensor(X_test, dtype=torch.float32)))) > 0.5).float().numpy()
print(classification_report(y_test, y_pred_final, target_names=["infringement", "non_infringement"]))

Final Model Accuracy: 91.53%
                  precision    recall  f1-score   support

    infringement       0.93      0.86      0.89        49
non_infringement       0.90      0.96      0.93        69

        accuracy                           0.92       118
       macro avg       0.92      0.91      0.91       118
    weighted avg       0.92      0.92      0.91       118



  y_pred_final = (torch.sigmoid(torch.tensor(custom_mlp(torch.tensor(X_test, dtype=torch.float32)))) > 0.5).float().numpy()


: 