In [1]:
# import packages
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# Variables
model_name = 'meta-llama/Meta-Llama-3.1-8B'
non_infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.non_infringement.json'
infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_30.infringement.json'
checkpoint_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_last_layer.pth'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define CustumMLP for internal states train
class CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CustomMLP, self).__init__()
        self.down = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)
        self.up = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, x):
        down_output = self.down(x)
        gate_output = self.gate(x)
        gated_output = down_output * self.activation(gate_output)
        return self.up(gated_output)

In [3]:
# Only extract hidden states
def extract_hidden_states(texts, model, tokenizer, batch_size=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = nn.DataParallel(model)
    hidden_states = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        hidden_states.append(outputs.hidden_states[3].mean(dim=1).cpu().numpy())
    return np.vstack(hidden_states)

In [4]:
# lode data for infringement & non infringement
def load_data(non_infringement_file, infringement_file):
    with open(non_infringement_file, 'r', encoding='utf-8') as file:
        non_infringement_json_data = json.load(file)

    non_infringement_outputs = [entry['input'] for entry in non_infringement_json_data]
    y_non_infringement = [1] * len(non_infringement_outputs)

    with open(infringement_file, 'r', encoding='utf-8') as file:
        infringement_json_data = json.load(file)

    infringement_outputs = [entry['input'] for entry in infringement_json_data]
    y_infringement = [0] * len(infringement_outputs)

    return non_infringement_outputs, y_non_infringement, infringement_outputs, y_infringement

In [5]:
# Train for best model
def train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim, epochs=500, lr=0.001, checkpoint_path=checkpoint_file):
    custom_mlp = CustomMLP(input_dim, hidden_dim)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(custom_mlp.parameters(), lr=lr)

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

    best_accuracy = -float('inf')  # Initialize the best accuracy to negative infinity
    best_model_state = None  # Store the state of the best model
    best_epoch = 0  # Track the epoch with the best accuracy
    losses = []

    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        custom_mlp.train()
        optimizer.zero_grad()
        outputs = custom_mlp(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")
            
            custom_mlp.eval()
            X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
            with torch.no_grad():
                y_pred_logits = custom_mlp(X_test_tensor)
                y_pred = (torch.sigmoid(y_pred_logits) > 0.5).float().numpy()
            
            accuracy = accuracy_score(y_test, y_pred)
            print(f"Test Accuracy at Epoch {epoch + 1}: {accuracy * 100:.2f}%")
            
            report = classification_report(y_test, y_pred, target_names=["infringement", "non_infringement"])
            print(f"Classification Report at Epoch {epoch + 1}:\n{report}")

            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_model_state = custom_mlp.state_dict()
                best_epoch = epoch + 1
                torch.save(best_model_state, checkpoint_path)
                print(f"New best model saved with accuracy {best_accuracy * 100:.2f}% at epoch {best_epoch}")
                print(f"Best Classification Report at Epoch {best_epoch}:\n{report}")

    custom_mlp.load_state_dict(torch.load(checkpoint_path))

    plt.figure(figsize=(10, 5))
    plt.plot(losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.show()

    print(f"Best Model was saved at epoch {best_epoch} with accuracy {best_accuracy * 100:.2f}%")
    return custom_mlp, losses, best_accuracy

In [6]:

tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
tokenizer.pad_token = tokenizer.eos_token

non_infringement_outputs, y_non_infringement, infringement_outputs, y_infringement = load_data(non_infringement_file, infringement_file)

y_non_infringement = np.array(y_non_infringement)
y_infringement = np.array(y_infringement)


Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


In [7]:
print("Extracting hidden states for non_infringement texts...")
X_non_infringement = extract_hidden_states(non_infringement_outputs, model, tokenizer)

print("Extracting hidden states for infringement texts...")
X_infringement = extract_hidden_states(infringement_outputs, model, tokenizer)

Extracting hidden states for non_infringement texts...


Processing data batches:   0%|          | 0/232 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
Processing data batches: 100%|██████████| 232/232 [01:06<00:00,  3.49it/s]


Extracting hidden states for infringement texts...


Processing data batches: 100%|██████████| 243/243 [01:25<00:00,  2.83it/s]


In [8]:
split_index_non_infringement = int(0.8 * len(X_non_infringement))
X_non_infringement_train = X_non_infringement[:split_index_non_infringement]
X_non_infringement_test = X_non_infringement[split_index_non_infringement:]
y_non_infringement_train = y_non_infringement[:split_index_non_infringement]
y_non_infringement_test = y_non_infringement[split_index_non_infringement:]

split_index_infringement = int(0.8 * len(X_infringement))
X_infringement_train = X_infringement[:split_index_infringement]
X_infringement_test = X_infringement[split_index_infringement:]
y_infringement_train = y_infringement[:split_index_infringement]
y_infringement_test = y_infringement[split_index_infringement:]

X_train = np.vstack((X_non_infringement_train, X_infringement_train))
X_test = np.vstack((X_non_infringement_test, X_infringement_test))
y_train = np.concatenate((y_non_infringement_train, y_infringement_train))
y_test = np.concatenate((y_non_infringement_test, y_infringement_test))

print("Data successfully split into training and test sets.")

Data successfully split into training and test sets.


In [9]:
input_dim = X_train.shape[1]
hidden_dim = 256

custom_mlp, losses, best_accuracy = train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim)

Training Epochs:   2%|▏         | 9/500 [00:10<08:13,  1.00s/it]

Epoch 10/500, Loss: 0.6638


Training Epochs:   2%|▏         | 10/500 [00:11<08:53,  1.09s/it]

Test Accuracy at Epoch 10: 54.47%
Classification Report at Epoch 10:
                  precision    recall  f1-score   support

    infringement       0.55      0.62      0.58       194
non_infringement       0.54      0.47      0.50       186

        accuracy                           0.54       380
       macro avg       0.54      0.54      0.54       380
    weighted avg       0.54      0.54      0.54       380

New best model saved with accuracy 54.47% at epoch 10
Best Classification Report at Epoch 10:
                  precision    recall  f1-score   support

    infringement       0.55      0.62      0.58       194
non_infringement       0.54      0.47      0.50       186

        accuracy                           0.54       380
       macro avg       0.54      0.54      0.54       380
    weighted avg       0.54      0.54      0.54       380



Training Epochs:   4%|▍         | 19/500 [00:22<09:21,  1.17s/it]

Epoch 20/500, Loss: 0.5929


Training Epochs:   4%|▍         | 20/500 [00:24<10:31,  1.32s/it]

Test Accuracy at Epoch 20: 65.53%
Classification Report at Epoch 20:
                  precision    recall  f1-score   support

    infringement       0.70      0.57      0.63       194
non_infringement       0.62      0.74      0.68       186

        accuracy                           0.66       380
       macro avg       0.66      0.66      0.65       380
    weighted avg       0.66      0.66      0.65       380

New best model saved with accuracy 65.53% at epoch 20
Best Classification Report at Epoch 20:
                  precision    recall  f1-score   support

    infringement       0.70      0.57      0.63       194
non_infringement       0.62      0.74      0.68       186

        accuracy                           0.66       380
       macro avg       0.66      0.66      0.65       380
    weighted avg       0.66      0.66      0.65       380



Training Epochs:   6%|▌         | 29/500 [00:34<09:26,  1.20s/it]

Epoch 30/500, Loss: 0.4940


Training Epochs:   6%|▌         | 30/500 [00:35<10:06,  1.29s/it]

Test Accuracy at Epoch 30: 67.37%
Classification Report at Epoch 30:
                  precision    recall  f1-score   support

    infringement       0.73      0.57      0.64       194
non_infringement       0.64      0.78      0.70       186

        accuracy                           0.67       380
       macro avg       0.68      0.68      0.67       380
    weighted avg       0.68      0.67      0.67       380

New best model saved with accuracy 67.37% at epoch 30
Best Classification Report at Epoch 30:
                  precision    recall  f1-score   support

    infringement       0.73      0.57      0.64       194
non_infringement       0.64      0.78      0.70       186

        accuracy                           0.67       380
       macro avg       0.68      0.68      0.67       380
    weighted avg       0.68      0.67      0.67       380



Training Epochs:   8%|▊         | 39/500 [00:47<09:25,  1.23s/it]

Epoch 40/500, Loss: 0.4182


Training Epochs:   8%|▊         | 40/500 [00:48<09:18,  1.21s/it]

Test Accuracy at Epoch 40: 72.89%
Classification Report at Epoch 40:
                  precision    recall  f1-score   support

    infringement       0.80      0.62      0.70       194
non_infringement       0.68      0.84      0.75       186

        accuracy                           0.73       380
       macro avg       0.74      0.73      0.73       380
    weighted avg       0.74      0.73      0.73       380

New best model saved with accuracy 72.89% at epoch 40
Best Classification Report at Epoch 40:
                  precision    recall  f1-score   support

    infringement       0.80      0.62      0.70       194
non_infringement       0.68      0.84      0.75       186

        accuracy                           0.73       380
       macro avg       0.74      0.73      0.73       380
    weighted avg       0.74      0.73      0.73       380



Training Epochs:  10%|▉         | 49/500 [00:58<08:49,  1.17s/it]

Epoch 50/500, Loss: 0.3830
Test Accuracy at Epoch 50: 75.79%
Classification Report at Epoch 50:
                  precision    recall  f1-score   support

    infringement       0.83      0.66      0.74       194
non_infringement       0.71      0.86      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  10%|█         | 50/500 [01:00<09:45,  1.30s/it]

New best model saved with accuracy 75.79% at epoch 50
Best Classification Report at Epoch 50:
                  precision    recall  f1-score   support

    infringement       0.83      0.66      0.74       194
non_infringement       0.71      0.86      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  12%|█▏        | 59/500 [01:10<07:53,  1.07s/it]

Epoch 60/500, Loss: 0.3636


Training Epochs:  12%|█▏        | 60/500 [01:11<08:56,  1.22s/it]

Test Accuracy at Epoch 60: 75.26%
Classification Report at Epoch 60:
                  precision    recall  f1-score   support

    infringement       0.80      0.68      0.74       194
non_infringement       0.71      0.83      0.77       186

        accuracy                           0.75       380
       macro avg       0.76      0.75      0.75       380
    weighted avg       0.76      0.75      0.75       380



Training Epochs:  14%|█▍        | 69/500 [01:24<09:11,  1.28s/it]

Epoch 70/500, Loss: 0.3519


Training Epochs:  14%|█▍        | 70/500 [01:25<09:58,  1.39s/it]

Test Accuracy at Epoch 70: 76.32%
Classification Report at Epoch 70:
                  precision    recall  f1-score   support

    infringement       0.81      0.70      0.75       194
non_infringement       0.72      0.83      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380

New best model saved with accuracy 76.32% at epoch 70
Best Classification Report at Epoch 70:
                  precision    recall  f1-score   support

    infringement       0.81      0.70      0.75       194
non_infringement       0.72      0.83      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  16%|█▌        | 79/500 [01:35<07:37,  1.09s/it]

Epoch 80/500, Loss: 0.3433


Training Epochs:  16%|█▌        | 80/500 [01:36<08:25,  1.20s/it]

Test Accuracy at Epoch 80: 75.79%
Classification Report at Epoch 80:
                  precision    recall  f1-score   support

    infringement       0.82      0.68      0.74       194
non_infringement       0.71      0.84      0.77       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  18%|█▊        | 89/500 [01:47<07:51,  1.15s/it]

Epoch 90/500, Loss: 0.3360


Training Epochs:  18%|█▊        | 90/500 [01:48<07:55,  1.16s/it]

Test Accuracy at Epoch 90: 76.05%
Classification Report at Epoch 90:
                  precision    recall  f1-score   support

    infringement       0.82      0.69      0.75       194
non_infringement       0.72      0.84      0.77       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  20%|█▉        | 99/500 [02:00<09:03,  1.35s/it]

Epoch 100/500, Loss: 0.3357


Training Epochs:  20%|██        | 100/500 [02:02<10:05,  1.51s/it]

Test Accuracy at Epoch 100: 77.37%
Classification Report at Epoch 100:
                  precision    recall  f1-score   support

    infringement       0.81      0.72      0.77       194
non_infringement       0.74      0.83      0.78       186

        accuracy                           0.77       380
       macro avg       0.78      0.77      0.77       380
    weighted avg       0.78      0.77      0.77       380

New best model saved with accuracy 77.37% at epoch 100
Best Classification Report at Epoch 100:
                  precision    recall  f1-score   support

    infringement       0.81      0.72      0.77       194
non_infringement       0.74      0.83      0.78       186

        accuracy                           0.77       380
       macro avg       0.78      0.77      0.77       380
    weighted avg       0.78      0.77      0.77       380



Training Epochs:  22%|██▏       | 109/500 [02:12<07:06,  1.09s/it]

Epoch 110/500, Loss: 0.3232


Training Epochs:  22%|██▏       | 110/500 [02:13<07:39,  1.18s/it]

Test Accuracy at Epoch 110: 76.32%
Classification Report at Epoch 110:
                  precision    recall  f1-score   support

    infringement       0.85      0.65      0.74       194
non_infringement       0.71      0.88      0.78       186

        accuracy                           0.76       380
       macro avg       0.78      0.77      0.76       380
    weighted avg       0.78      0.76      0.76       380



Training Epochs:  24%|██▍       | 119/500 [02:25<08:15,  1.30s/it]

Epoch 120/500, Loss: 0.3140


Training Epochs:  24%|██▍       | 120/500 [02:26<08:12,  1.30s/it]

Test Accuracy at Epoch 120: 75.79%
Classification Report at Epoch 120:
                  precision    recall  f1-score   support

    infringement       0.83      0.66      0.74       194
non_infringement       0.71      0.85      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  26%|██▌       | 129/500 [02:36<07:03,  1.14s/it]

Epoch 130/500, Loss: 0.3026


Training Epochs:  26%|██▌       | 130/500 [02:37<07:53,  1.28s/it]

Test Accuracy at Epoch 130: 76.58%
Classification Report at Epoch 130:
                  precision    recall  f1-score   support

    infringement       0.84      0.66      0.74       194
non_infringement       0.71      0.87      0.78       186

        accuracy                           0.77       380
       macro avg       0.78      0.77      0.76       380
    weighted avg       0.78      0.77      0.76       380



Training Epochs:  28%|██▊       | 139/500 [02:49<07:03,  1.17s/it]

Epoch 140/500, Loss: 0.2890


Training Epochs:  28%|██▊       | 140/500 [02:50<07:07,  1.19s/it]

Test Accuracy at Epoch 140: 76.32%
Classification Report at Epoch 140:
                  precision    recall  f1-score   support

    infringement       0.83      0.67      0.74       194
non_infringement       0.71      0.86      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.77      0.76       380
    weighted avg       0.78      0.76      0.76       380



Training Epochs:  30%|██▉       | 149/500 [03:02<07:52,  1.35s/it]

Epoch 150/500, Loss: 0.2772


Training Epochs:  30%|███       | 150/500 [03:04<08:40,  1.49s/it]

Test Accuracy at Epoch 150: 76.05%
Classification Report at Epoch 150:
                  precision    recall  f1-score   support

    infringement       0.82      0.69      0.75       194
non_infringement       0.72      0.84      0.77       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  32%|███▏      | 159/500 [03:15<07:39,  1.35s/it]

Epoch 160/500, Loss: 0.2586


Training Epochs:  32%|███▏      | 160/500 [03:17<07:57,  1.41s/it]

Test Accuracy at Epoch 160: 76.05%
Classification Report at Epoch 160:
                  precision    recall  f1-score   support

    infringement       0.82      0.68      0.74       194
non_infringement       0.72      0.84      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  34%|███▍      | 169/500 [03:27<06:02,  1.10s/it]

Epoch 170/500, Loss: 0.2446
Test Accuracy at Epoch 170: 75.26%


Training Epochs:  34%|███▍      | 170/500 [03:29<06:11,  1.13s/it]

Classification Report at Epoch 170:
                  precision    recall  f1-score   support

    infringement       0.83      0.65      0.73       194
non_infringement       0.70      0.86      0.77       186

        accuracy                           0.75       380
       macro avg       0.77      0.75      0.75       380
    weighted avg       0.77      0.75      0.75       380



Training Epochs:  36%|███▌      | 179/500 [03:40<06:49,  1.28s/it]

Epoch 180/500, Loss: 0.2302


Training Epochs:  36%|███▌      | 180/500 [03:42<07:24,  1.39s/it]

Test Accuracy at Epoch 180: 78.68%
Classification Report at Epoch 180:
                  precision    recall  f1-score   support

    infringement       0.83      0.73      0.78       194
non_infringement       0.75      0.84      0.79       186

        accuracy                           0.79       380
       macro avg       0.79      0.79      0.79       380
    weighted avg       0.79      0.79      0.79       380

New best model saved with accuracy 78.68% at epoch 180
Best Classification Report at Epoch 180:
                  precision    recall  f1-score   support

    infringement       0.83      0.73      0.78       194
non_infringement       0.75      0.84      0.79       186

        accuracy                           0.79       380
       macro avg       0.79      0.79      0.79       380
    weighted avg       0.79      0.79      0.79       380



Training Epochs:  38%|███▊      | 189/500 [03:52<06:36,  1.27s/it]

Epoch 190/500, Loss: 0.2135


Training Epochs:  38%|███▊      | 190/500 [03:54<07:00,  1.36s/it]

Test Accuracy at Epoch 190: 75.53%
Classification Report at Epoch 190:
                  precision    recall  f1-score   support

    infringement       0.83      0.65      0.73       194
non_infringement       0.70      0.86      0.77       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.75       380
    weighted avg       0.77      0.76      0.75       380



Training Epochs:  40%|███▉      | 199/500 [04:04<05:05,  1.02s/it]

Epoch 200/500, Loss: 0.2025


Training Epochs:  40%|████      | 200/500 [04:05<06:02,  1.21s/it]

Test Accuracy at Epoch 200: 75.79%
Classification Report at Epoch 200:
                  precision    recall  f1-score   support

    infringement       0.84      0.65      0.73       194
non_infringement       0.71      0.87      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  42%|████▏     | 209/500 [04:18<06:43,  1.39s/it]

Epoch 210/500, Loss: 0.1903


Training Epochs:  42%|████▏     | 210/500 [04:19<06:35,  1.37s/it]

Test Accuracy at Epoch 210: 76.32%
Classification Report at Epoch 210:
                  precision    recall  f1-score   support

    infringement       0.85      0.65      0.74       194
non_infringement       0.71      0.88      0.78       186

        accuracy                           0.76       380
       macro avg       0.78      0.77      0.76       380
    weighted avg       0.78      0.76      0.76       380



Training Epochs:  44%|████▍     | 219/500 [04:30<05:56,  1.27s/it]

Epoch 220/500, Loss: 0.1798


Training Epochs:  44%|████▍     | 220/500 [04:32<06:41,  1.43s/it]

Test Accuracy at Epoch 220: 75.79%
Classification Report at Epoch 220:
                  precision    recall  f1-score   support

    infringement       0.84      0.65      0.73       194
non_infringement       0.71      0.87      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  46%|████▌     | 229/500 [04:43<05:19,  1.18s/it]

Epoch 230/500, Loss: 0.1772


Training Epochs:  46%|████▌     | 230/500 [04:44<05:37,  1.25s/it]

Test Accuracy at Epoch 230: 75.79%
Classification Report at Epoch 230:
                  precision    recall  f1-score   support

    infringement       0.84      0.65      0.73       194
non_infringement       0.71      0.87      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  48%|████▊     | 239/500 [04:56<06:01,  1.39s/it]

Epoch 240/500, Loss: 0.1678
Test Accuracy at Epoch 240: 77.11%


Training Epochs:  48%|████▊     | 240/500 [04:58<05:59,  1.38s/it]

Classification Report at Epoch 240:
                  precision    recall  f1-score   support

    infringement       0.86      0.65      0.74       194
non_infringement       0.71      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.77      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  50%|████▉     | 249/500 [05:08<04:45,  1.14s/it]

Epoch 250/500, Loss: 0.1595


Training Epochs:  50%|█████     | 250/500 [05:09<05:16,  1.27s/it]

Test Accuracy at Epoch 250: 77.37%
Classification Report at Epoch 250:
                  precision    recall  f1-score   support

    infringement       0.83      0.70      0.76       194
non_infringement       0.73      0.85      0.79       186

        accuracy                           0.77       380
       macro avg       0.78      0.78      0.77       380
    weighted avg       0.78      0.77      0.77       380



Training Epochs:  52%|█████▏    | 259/500 [05:22<05:23,  1.34s/it]

Epoch 260/500, Loss: 0.1550


Training Epochs:  52%|█████▏    | 260/500 [05:23<05:28,  1.37s/it]

Test Accuracy at Epoch 260: 75.79%
Classification Report at Epoch 260:
                  precision    recall  f1-score   support

    infringement       0.83      0.66      0.74       194
non_infringement       0.71      0.86      0.78       186

        accuracy                           0.76       380
       macro avg       0.77      0.76      0.76       380
    weighted avg       0.77      0.76      0.76       380



Training Epochs:  54%|█████▍    | 269/500 [05:35<05:10,  1.34s/it]

Epoch 270/500, Loss: 0.1505


Training Epochs:  54%|█████▍    | 270/500 [05:37<05:18,  1.38s/it]

Test Accuracy at Epoch 270: 76.84%
Classification Report at Epoch 270:
                  precision    recall  f1-score   support

    infringement       0.85      0.66      0.74       194
non_infringement       0.71      0.88      0.79       186

        accuracy                           0.77       380
       macro avg       0.78      0.77      0.77       380
    weighted avg       0.78      0.77      0.77       380



Training Epochs:  56%|█████▌    | 279/500 [05:49<05:16,  1.43s/it]

Epoch 280/500, Loss: 0.1466


Training Epochs:  56%|█████▌    | 280/500 [05:51<05:43,  1.56s/it]

Test Accuracy at Epoch 280: 77.11%
Classification Report at Epoch 280:
                  precision    recall  f1-score   support

    infringement       0.86      0.65      0.74       194
non_infringement       0.71      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.77      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  58%|█████▊    | 289/500 [06:02<04:27,  1.27s/it]

Epoch 290/500, Loss: 0.1434


Training Epochs:  58%|█████▊    | 290/500 [06:04<04:52,  1.39s/it]

Test Accuracy at Epoch 290: 76.84%
Classification Report at Epoch 290:
                  precision    recall  f1-score   support

    infringement       0.85      0.66      0.74       194
non_infringement       0.71      0.88      0.79       186

        accuracy                           0.77       380
       macro avg       0.78      0.77      0.77       380
    weighted avg       0.78      0.77      0.77       380



Training Epochs:  60%|█████▉    | 299/500 [06:16<04:30,  1.35s/it]

Epoch 300/500, Loss: 0.1407


Training Epochs:  60%|██████    | 300/500 [06:18<04:30,  1.35s/it]

Test Accuracy at Epoch 300: 77.11%
Classification Report at Epoch 300:
                  precision    recall  f1-score   support

    infringement       0.86      0.65      0.74       194
non_infringement       0.71      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.77      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  62%|██████▏   | 309/500 [06:30<04:18,  1.35s/it]

Epoch 310/500, Loss: 0.1414


Training Epochs:  62%|██████▏   | 310/500 [06:31<04:31,  1.43s/it]

Test Accuracy at Epoch 310: 76.84%
Classification Report at Epoch 310:
                  precision    recall  f1-score   support

    infringement       0.86      0.65      0.74       194
non_infringement       0.71      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.77      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  64%|██████▍   | 319/500 [06:44<04:22,  1.45s/it]

Epoch 320/500, Loss: 0.1375


Training Epochs:  64%|██████▍   | 320/500 [06:46<04:44,  1.58s/it]

Test Accuracy at Epoch 320: 77.11%
Classification Report at Epoch 320:
                  precision    recall  f1-score   support

    infringement       0.86      0.65      0.74       194
non_infringement       0.71      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.77      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  66%|██████▌   | 329/500 [06:57<03:27,  1.21s/it]

Epoch 330/500, Loss: 0.1345


Training Epochs:  66%|██████▌   | 330/500 [06:58<03:40,  1.29s/it]

Test Accuracy at Epoch 330: 77.37%
Classification Report at Epoch 330:
                  precision    recall  f1-score   support

    infringement       0.86      0.66      0.75       194
non_infringement       0.72      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.78      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  68%|██████▊   | 339/500 [07:10<03:25,  1.28s/it]

Epoch 340/500, Loss: 0.1329


Training Epochs:  68%|██████▊   | 340/500 [07:11<03:43,  1.40s/it]

Test Accuracy at Epoch 340: 77.37%
Classification Report at Epoch 340:
                  precision    recall  f1-score   support

    infringement       0.86      0.66      0.75       194
non_infringement       0.72      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.78      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  70%|██████▉   | 349/500 [07:23<03:25,  1.36s/it]

Epoch 350/500, Loss: 0.1313


Training Epochs:  70%|███████   | 350/500 [07:24<03:33,  1.43s/it]

Test Accuracy at Epoch 350: 77.11%
Classification Report at Epoch 350:
                  precision    recall  f1-score   support

    infringement       0.86      0.65      0.74       194
non_infringement       0.71      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.77      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  72%|███████▏  | 359/500 [07:34<02:21,  1.00s/it]

Epoch 360/500, Loss: 0.1299


Training Epochs:  72%|███████▏  | 360/500 [07:35<02:37,  1.12s/it]

Test Accuracy at Epoch 360: 77.37%
Classification Report at Epoch 360:
                  precision    recall  f1-score   support

    infringement       0.86      0.66      0.75       194
non_infringement       0.72      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.78      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  74%|███████▍  | 370/500 [07:44<01:26,  1.51it/s]

Epoch 370/500, Loss: 0.1299
Test Accuracy at Epoch 370: 79.21%
Classification Report at Epoch 370:
                  precision    recall  f1-score   support

    infringement       0.85      0.72      0.78       194
non_infringement       0.75      0.87      0.80       186

        accuracy                           0.79       380
       macro avg       0.80      0.79      0.79       380
    weighted avg       0.80      0.79      0.79       380

New best model saved with accuracy 79.21% at epoch 370
Best Classification Report at Epoch 370:
                  precision    recall  f1-score   support

    infringement       0.85      0.72      0.78       194
non_infringement       0.75      0.87      0.80       186

        accuracy                           0.79       380
       macro avg       0.80      0.79      0.79       380
    weighted avg       0.80      0.79      0.79       380



Training Epochs:  76%|███████▌  | 380/500 [07:49<01:04,  1.86it/s]

Epoch 380/500, Loss: 0.1294
Test Accuracy at Epoch 380: 77.11%
Classification Report at Epoch 380:
                  precision    recall  f1-score   support

    infringement       0.86      0.65      0.74       194
non_infringement       0.71      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.77      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  78%|███████▊  | 390/500 [07:55<01:14,  1.48it/s]

Epoch 390/500, Loss: 0.1276
Test Accuracy at Epoch 390: 77.89%
Classification Report at Epoch 390:
                  precision    recall  f1-score   support

    infringement       0.85      0.69      0.76       194
non_infringement       0.73      0.88      0.80       186

        accuracy                           0.78       380
       macro avg       0.79      0.78      0.78       380
    weighted avg       0.79      0.78      0.78       380



Training Epochs:  80%|████████  | 400/500 [08:00<00:42,  2.38it/s]

Epoch 400/500, Loss: 0.1256
Test Accuracy at Epoch 400: 77.37%
Classification Report at Epoch 400:
                  precision    recall  f1-score   support

    infringement       0.86      0.66      0.75       194
non_infringement       0.72      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.79      0.78      0.77       380
    weighted avg       0.79      0.77      0.77       380



Training Epochs:  82%|████████▏ | 410/500 [08:05<00:37,  2.40it/s]

Epoch 410/500, Loss: 0.1252
Test Accuracy at Epoch 410: 77.89%
Classification Report at Epoch 410:
                  precision    recall  f1-score   support

    infringement       0.85      0.69      0.76       194
non_infringement       0.73      0.88      0.80       186

        accuracy                           0.78       380
       macro avg       0.79      0.78      0.78       380
    weighted avg       0.79      0.78      0.78       380



Training Epochs:  84%|████████▍ | 421/500 [08:10<00:26,  2.99it/s]

Epoch 420/500, Loss: 0.1234
Test Accuracy at Epoch 420: 77.89%
Classification Report at Epoch 420:
                  precision    recall  f1-score   support

    infringement       0.85      0.69      0.76       194
non_infringement       0.73      0.88      0.80       186

        accuracy                           0.78       380
       macro avg       0.79      0.78      0.78       380
    weighted avg       0.79      0.78      0.78       380



Training Epochs:  86%|████████▌ | 430/500 [08:14<00:33,  2.11it/s]

Epoch 430/500, Loss: 0.1266
Test Accuracy at Epoch 430: 76.32%
Classification Report at Epoch 430:
                  precision    recall  f1-score   support

    infringement       0.86      0.64      0.73       194
non_infringement       0.70      0.89      0.79       186

        accuracy                           0.76       380
       macro avg       0.78      0.77      0.76       380
    weighted avg       0.78      0.76      0.76       380



Training Epochs:  88%|████████▊ | 439/500 [08:20<00:34,  1.75it/s]

Epoch 440/500, Loss: 0.1220


Training Epochs:  88%|████████▊ | 440/500 [08:21<00:41,  1.45it/s]

Test Accuracy at Epoch 440: 76.58%
Classification Report at Epoch 440:
                  precision    recall  f1-score   support

    infringement       0.86      0.64      0.74       194
non_infringement       0.71      0.89      0.79       186

        accuracy                           0.77       380
       macro avg       0.78      0.77      0.76       380
    weighted avg       0.79      0.77      0.76       380



Training Epochs:  90%|█████████ | 450/500 [08:27<00:39,  1.28it/s]

Epoch 450/500, Loss: 0.1237
Test Accuracy at Epoch 450: 76.05%
Classification Report at Epoch 450:
                  precision    recall  f1-score   support

    infringement       0.86      0.63      0.73       194
non_infringement       0.70      0.89      0.78       186

        accuracy                           0.76       380
       macro avg       0.78      0.76      0.76       380
    weighted avg       0.78      0.76      0.76       380



Training Epochs:  92%|█████████▏| 460/500 [08:34<00:26,  1.53it/s]

Epoch 460/500, Loss: 0.1233
Test Accuracy at Epoch 460: 76.58%
Classification Report at Epoch 460:
                  precision    recall  f1-score   support

    infringement       0.85      0.65      0.74       194
non_infringement       0.71      0.88      0.79       186

        accuracy                           0.77       380
       macro avg       0.78      0.77      0.76       380
    weighted avg       0.78      0.77      0.76       380



Training Epochs:  94%|█████████▍| 470/500 [08:38<00:12,  2.31it/s]

Epoch 470/500, Loss: 0.1200
Test Accuracy at Epoch 470: 77.89%
Classification Report at Epoch 470:
                  precision    recall  f1-score   support

    infringement       0.85      0.69      0.76       194
non_infringement       0.73      0.88      0.80       186

        accuracy                           0.78       380
       macro avg       0.79      0.78      0.78       380
    weighted avg       0.79      0.78      0.78       380



Training Epochs:  96%|█████████▌| 480/500 [08:44<00:10,  1.87it/s]

Epoch 480/500, Loss: 0.1197
Test Accuracy at Epoch 480: 77.89%
Classification Report at Epoch 480:
                  precision    recall  f1-score   support

    infringement       0.85      0.69      0.76       194
non_infringement       0.73      0.88      0.80       186

        accuracy                           0.78       380
       macro avg       0.79      0.78      0.78       380
    weighted avg       0.79      0.78      0.78       380



Training Epochs:  98%|█████████▊| 490/500 [08:48<00:02,  3.62it/s]

Epoch 490/500, Loss: 0.1189
Test Accuracy at Epoch 490: 77.89%
Classification Report at Epoch 490:
                  precision    recall  f1-score   support

    infringement       0.85      0.69      0.76       194
non_infringement       0.73      0.88      0.80       186

        accuracy                           0.78       380
       macro avg       0.79      0.78      0.78       380
    weighted avg       0.79      0.78      0.78       380



Training Epochs: 100%|██████████| 500/500 [08:49<00:00,  1.06s/it]
  custom_mlp.load_state_dict(torch.load(checkpoint_path))


Epoch 500/500, Loss: 0.1182
Test Accuracy at Epoch 500: 77.89%
Classification Report at Epoch 500:
                  precision    recall  f1-score   support

    infringement       0.85      0.69      0.76       194
non_infringement       0.73      0.88      0.80       186

        accuracy                           0.78       380
       macro avg       0.79      0.78      0.78       380
    weighted avg       0.79      0.78      0.78       380



RuntimeError: Error(s) in loading state_dict for CustomMLP:
	Missing key(s) in state_dict: "down.weight", "down.bias", "gate.weight", "gate.bias", "up.weight", "up.bias". 
	Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "loss". 

In [None]:
def save_checkpoint(model, optimizer, epoch, loss, filepath):
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to '{filepath}'.")

save_checkpoint(custom_mlp, torch.optim.Adam(custom_mlp.parameters()), len(losses), losses[-1], checkpoint_file)

Checkpoint saved to '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_last_layer.pth'.


In [None]:
print(f"Final Model Accuracy: {best_accuracy * 100:.2f}%")
y_pred_final = (torch.sigmoid(torch.tensor(custom_mlp(torch.tensor(X_test, dtype=torch.float32)))) > 0.5).float().numpy()
print(classification_report(y_test, y_pred_final, target_names=["infringement", "non_infringement"]))

  y_pred_final = (torch.sigmoid(torch.tensor(custom_mlp(torch.tensor(X_test, dtype=torch.float32)))) > 0.5).float().numpy()


Final Model Accuracy: 81.84%
                  precision    recall  f1-score   support

    infringement       0.87      0.76      0.81       194
non_infringement       0.78      0.88      0.83       186

        accuracy                           0.82       380
       macro avg       0.82      0.82      0.82       380
    weighted avg       0.82      0.82      0.82       380

