In [1]:
# import packages & variables
import argparse
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification
import json
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# Parameters
model_name = 'meta-llama/Meta-Llama-3.1-8B'
non_infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra.non_infringement.json'
infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra.infringement.json'
checkpoint_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_reference_last_layer.pth'


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define CustumMLP for internal states train
class CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CustomMLP, self).__init__()
        self.down = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)
        self.up = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, x):
        down_output = self.down(x)
        gate_output = self.gate(x)
        gated_output = down_output * self.activation(gate_output)
        return self.up(gated_output)

In [3]:
# Extract hidden states/reference embeddings
def extract_hidden_states(texts, model, tokenizer, batch_size=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = nn.DataParallel(model)
    hidden_states = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        hidden_states.append(outputs.hidden_states[-1].mean(dim=1).cpu().numpy())
    return np.vstack(hidden_states)

def extract_reference_embeddings(references, model, tokenizer, batch_size=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = nn.DataParallel(model)
    embeddings = []
    for i in tqdm(range(0, len(references), batch_size), desc="Processing references"):
        batch_references = references[i:i + batch_size]
        inputs = tokenizer(batch_references, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings.append(outputs.pooler_output.cpu().numpy())
    return np.vstack(embeddings)

In [4]:
# load data for infringement & non infringement
def load_data(non_infringement_file, infringement_file):
    with open(non_infringement_file, 'r', encoding='utf-8') as file:
        non_infringement_json_data = json.load(file)

    non_infringement_outputs = [entry['input'] for entry in non_infringement_json_data]
    non_infringement_references = [entry['reference'] for entry in non_infringement_json_data]
    y_non_infringement = [1] * len(non_infringement_outputs)

    with open(infringement_file, 'r', encoding='utf-8') as file:
        infringement_json_data = json.load(file)

    infringement_outputs = [entry['input'] for entry in infringement_json_data]
    infringement_references = [entry['reference'] for entry in infringement_json_data]
    y_infringement = [0] * len(infringement_outputs)

    return non_infringement_outputs, non_infringement_references, y_non_infringement, infringement_outputs, infringement_references, y_infringement

In [5]:
# Train for best model
def train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim, epochs=500, lr=0.001, checkpoint_path=checkpoint_file):
    custom_mlp = CustomMLP(input_dim, hidden_dim)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(custom_mlp.parameters(), lr=lr)

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

    best_accuracy = -float('inf')
    best_model_state = None
    best_epoch = 0
    losses = []

    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        custom_mlp.train()
        optimizer.zero_grad()
        outputs = custom_mlp(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")
            
            custom_mlp.eval()
            X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
            with torch.no_grad():
                y_pred_logits = custom_mlp(X_test_tensor)
                y_pred = (torch.sigmoid(y_pred_logits) > 0.5).float().numpy()
            
            accuracy = accuracy_score(y_test, y_pred)
            print(f"Test Accuracy at Epoch {epoch + 1}: {accuracy * 100:.2f}%")
            
            report = classification_report(y_test, y_pred, target_names=["infringement", "non_infringement"])
            print(f"Classification Report at Epoch {epoch + 1}:\n{report}")

            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_model_state = custom_mlp.state_dict()
                best_epoch = epoch + 1
                torch.save(best_model_state, checkpoint_path)
                print(f"New best model saved with accuracy {best_accuracy * 100:.2f}% at epoch {best_epoch}")
                print(f"Best Classification Report at Epoch {best_epoch}:\n{report}")

    custom_mlp.load_state_dict(torch.load(checkpoint_path))

    plt.figure(figsize=(10, 5))
    plt.plot(losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.show()

    print(f"Final Model Accuracy: {best_accuracy * 100:.2f}%")
    
    return custom_mlp, losses, best_accuracy

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
tokenizer.pad_token = tokenizer.eos_token
bert_tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
bert_model = AutoModel.from_pretrained('google-bert/bert-base-uncased')
bert_tokenizer.pad_token = tokenizer.eos_token

non_infringement_outputs, non_infringement_references, y_non_infringement, infringement_outputs, infringement_references, y_infringement = load_data(
    non_infringement_file, infringement_file
)

y_non_infringement = np.array(y_non_infringement)
y_infringement = np.array(y_infringement)


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s]


In [7]:

# Function to add a separator between two sets of hidden states
def concatenate_with_separator(array1, array2, separator_value=0):
    # Create a separator with the same number of columns as the hidden states
    separator = np.full((array1.shape[0], 1), separator_value)  # Assuming you want a single-column separator
    return np.hstack([array1, separator, array2])

print("Extracting hidden states for non_infringement texts...")
X_non_infringement = extract_hidden_states(non_infringement_outputs, model, tokenizer)
print("Extracting reference embeddings for non_infringement texts...")
reference_embeddings_non_infringement = extract_hidden_states(non_infringement_references, model, tokenizer)

# Concatenate with a separator
X_non_infringement_combined = concatenate_with_separator(X_non_infringement, reference_embeddings_non_infringement)

print("Extracting hidden states for infringement texts...")
X_infringement = extract_hidden_states(infringement_outputs, model, tokenizer)
print("Extracting reference embeddings for infringement texts...")
reference_embeddings_infringement = extract_hidden_states(infringement_references, model, tokenizer)

# Concatenate with a separator
X_infringement_combined = concatenate_with_separator(X_infringement, reference_embeddings_infringement)


Extracting hidden states for non_infringement texts...


Processing data batches:   0%|          | 0/65 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
Processing data batches: 100%|██████████| 65/65 [00:16<00:00,  3.97it/s]


Extracting reference embeddings for non_infringement texts...


Processing data batches: 100%|██████████| 65/65 [00:20<00:00,  3.16it/s]


Extracting hidden states for infringement texts...


Processing data batches: 100%|██████████| 81/81 [00:23<00:00,  3.44it/s]


Extracting reference embeddings for infringement texts...


Processing data batches: 100%|██████████| 81/81 [00:23<00:00,  3.42it/s]


In [8]:
split_index_non_infringement = int(0.8 * len(X_non_infringement_combined))
X_non_infringement_train = X_non_infringement_combined[:split_index_non_infringement]
X_non_infringement_test = X_non_infringement_combined[split_index_non_infringement:]
y_non_infringement_train = y_non_infringement[:split_index_non_infringement]
y_non_infringement_test = y_non_infringement[split_index_non_infringement:]

split_index_infringement = int(0.8 * len(X_infringement_combined))
X_infringement_train = X_infringement_combined[:split_index_infringement]
X_infringement_test = X_infringement_combined[split_index_infringement:]
y_infringement_train = y_infringement[:split_index_infringement]
y_infringement_test = y_infringement[split_index_infringement:]

X_train = np.vstack((X_non_infringement_train, X_infringement_train))
X_test = np.vstack((X_non_infringement_test, X_infringement_test))
y_train = np.concatenate((y_non_infringement_train, y_infringement_train))
y_test = np.concatenate((y_non_infringement_test, y_infringement_test))

print("Data successfully split into training and test sets.")

Data successfully split into training and test sets.


In [9]:
input_dim = X_train.shape[1]
hidden_dim = 256 
print(f"Training MLP model with input_dim={input_dim} and hidden_dim={hidden_dim}")

custom_mlp, losses, best_accuracy = train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim)

Training MLP model with input_dim=8193 and hidden_dim=256


Training Epochs:   2%|▏         | 12/500 [00:00<00:21, 22.41it/s]

Epoch 10/500, Loss: 0.4412
Test Accuracy at Epoch 10: 86.32%
Classification Report at Epoch 10:
                  precision    recall  f1-score   support

    infringement       0.89      0.86      0.88        65
non_infringement       0.83      0.87      0.85        52

        accuracy                           0.86       117
       macro avg       0.86      0.86      0.86       117
    weighted avg       0.86      0.86      0.86       117

New best model saved with accuracy 86.32% at epoch 10
Best Classification Report at Epoch 10:
                  precision    recall  f1-score   support

    infringement       0.89      0.86      0.88        65
non_infringement       0.83      0.87      0.85        52

        accuracy                           0.86       117
       macro avg       0.86      0.86      0.86       117
    weighted avg       0.86      0.86      0.86       117



Training Epochs:   5%|▍         | 23/500 [00:00<00:18, 25.11it/s]

Epoch 20/500, Loss: 0.1721
Test Accuracy at Epoch 20: 90.60%
Classification Report at Epoch 20:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117

New best model saved with accuracy 90.60% at epoch 20
Best Classification Report at Epoch 20:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:   7%|▋         | 34/500 [00:01<00:16, 28.31it/s]

Epoch 30/500, Loss: 0.0724
Test Accuracy at Epoch 30: 90.60%
Classification Report at Epoch 30:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:   9%|▉         | 45/500 [00:01<00:16, 28.15it/s]

Epoch 40/500, Loss: 0.0301
Test Accuracy at Epoch 40: 90.60%
Classification Report at Epoch 40:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  11%|█         | 56/500 [00:02<00:14, 30.63it/s]

Epoch 50/500, Loss: 0.0161
Test Accuracy at Epoch 50: 90.60%
Classification Report at Epoch 50:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  13%|█▎        | 64/500 [00:02<00:15, 27.49it/s]

Epoch 60/500, Loss: 0.0114
Test Accuracy at Epoch 60: 90.60%
Classification Report at Epoch 60:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  15%|█▍        | 74/500 [00:02<00:14, 29.25it/s]

Epoch 70/500, Loss: 0.0077
Test Accuracy at Epoch 70: 90.60%
Classification Report at Epoch 70:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  17%|█▋        | 86/500 [00:03<00:13, 30.52it/s]

Epoch 80/500, Loss: 0.0055
Test Accuracy at Epoch 80: 90.60%
Classification Report at Epoch 80:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  19%|█▉        | 94/500 [00:03<00:13, 29.92it/s]

Epoch 90/500, Loss: 0.0043
Test Accuracy at Epoch 90: 90.60%
Classification Report at Epoch 90:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  20%|██        | 102/500 [00:03<00:12, 32.27it/s]

Epoch 100/500, Loss: 0.0038
Test Accuracy at Epoch 100: 90.60%
Classification Report at Epoch 100:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  23%|██▎       | 114/500 [00:03<00:11, 32.47it/s]

Epoch 110/500, Loss: 0.0037
Test Accuracy at Epoch 110: 90.60%
Classification Report at Epoch 110:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  25%|██▌       | 126/500 [00:04<00:11, 32.66it/s]

Epoch 120/500, Loss: 0.0035
Test Accuracy at Epoch 120: 90.60%
Classification Report at Epoch 120:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  27%|██▋       | 134/500 [00:04<00:11, 31.71it/s]

Epoch 130/500, Loss: 0.0034
Test Accuracy at Epoch 130: 90.60%
Classification Report at Epoch 130:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  28%|██▊       | 142/500 [00:04<00:12, 29.26it/s]

Epoch 140/500, Loss: 0.0033
Test Accuracy at Epoch 140: 90.60%
Classification Report at Epoch 140:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  31%|███       | 154/500 [00:05<00:11, 29.86it/s]

Epoch 150/500, Loss: 0.0033
Test Accuracy at Epoch 150: 90.60%
Classification Report at Epoch 150:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  32%|███▏      | 162/500 [00:05<00:11, 29.89it/s]

Epoch 160/500, Loss: 0.0032
Test Accuracy at Epoch 160: 90.60%
Classification Report at Epoch 160:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  35%|███▍      | 174/500 [00:05<00:10, 30.81it/s]

Epoch 170/500, Loss: 0.0032
Test Accuracy at Epoch 170: 90.60%
Classification Report at Epoch 170:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  37%|███▋      | 185/500 [00:06<00:10, 30.43it/s]

Epoch 180/500, Loss: 0.0032
Test Accuracy at Epoch 180: 90.60%
Classification Report at Epoch 180:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  39%|███▊      | 193/500 [00:06<00:10, 28.64it/s]

Epoch 190/500, Loss: 0.0032
Test Accuracy at Epoch 190: 90.60%
Classification Report at Epoch 190:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  40%|████      | 201/500 [00:06<00:10, 27.71it/s]

Epoch 200/500, Loss: 0.0034
Test Accuracy at Epoch 200: 90.60%
Classification Report at Epoch 200:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  43%|████▎     | 213/500 [00:07<00:09, 29.00it/s]

Epoch 210/500, Loss: 0.0032
Test Accuracy at Epoch 210: 90.60%
Classification Report at Epoch 210:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  44%|████▍     | 221/500 [00:07<00:09, 30.27it/s]

Epoch 220/500, Loss: 0.0031
Test Accuracy at Epoch 220: 90.60%
Classification Report at Epoch 220:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  47%|████▋     | 233/500 [00:07<00:08, 30.87it/s]

Epoch 230/500, Loss: 0.0031
Test Accuracy at Epoch 230: 90.60%
Classification Report at Epoch 230:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  49%|████▉     | 245/500 [00:08<00:08, 31.36it/s]

Epoch 240/500, Loss: 0.0031
Test Accuracy at Epoch 240: 90.60%
Classification Report at Epoch 240:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  51%|█████     | 253/500 [00:08<00:08, 28.93it/s]

Epoch 250/500, Loss: 0.0031
Test Accuracy at Epoch 250: 90.60%
Classification Report at Epoch 250:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  52%|█████▏    | 261/500 [00:08<00:08, 27.24it/s]

Epoch 260/500, Loss: 0.0031
Test Accuracy at Epoch 260: 90.60%
Classification Report at Epoch 260:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  55%|█████▍    | 273/500 [00:09<00:08, 27.97it/s]

Epoch 270/500, Loss: 0.0031
Test Accuracy at Epoch 270: 90.60%
Classification Report at Epoch 270:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  57%|█████▋    | 285/500 [00:09<00:07, 28.52it/s]

Epoch 280/500, Loss: 0.0031
Test Accuracy at Epoch 280: 90.60%
Classification Report at Epoch 280:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  59%|█████▊    | 293/500 [00:09<00:06, 30.58it/s]

Epoch 290/500, Loss: 0.0031
Test Accuracy at Epoch 290: 90.60%
Classification Report at Epoch 290:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  61%|██████    | 304/500 [00:10<00:06, 29.55it/s]

Epoch 300/500, Loss: 0.0031
Test Accuracy at Epoch 300: 90.60%
Classification Report at Epoch 300:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  63%|██████▎   | 315/500 [00:10<00:06, 28.86it/s]

Epoch 310/500, Loss: 0.0031
Test Accuracy at Epoch 310: 90.60%
Classification Report at Epoch 310:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  65%|██████▌   | 327/500 [00:11<00:05, 31.87it/s]

Epoch 320/500, Loss: 0.0031
Test Accuracy at Epoch 320: 90.60%
Classification Report at Epoch 320:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  67%|██████▋   | 335/500 [00:11<00:05, 31.63it/s]

Epoch 330/500, Loss: 0.0031
Test Accuracy at Epoch 330: 90.60%
Classification Report at Epoch 330:
                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



Training Epochs:  69%|██████▊   | 343/500 [00:11<00:04, 33.62it/s]

In [21]:
def save_checkpoint(model, optimizer, epoch, loss, filepath):
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to '{filepath}'.")

save_checkpoint(custom_mlp, torch.optim.Adam(custom_mlp.parameters()), len(losses), losses[-1], checkpoint_file)

Checkpoint saved to '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_reference_last_layer.pth'.


In [22]:
y_pred_final = (torch.sigmoid(torch.tensor(custom_mlp(torch.tensor(X_test, dtype=torch.float32)))) > 0.5).float().numpy()
print(classification_report(y_test, y_pred_final, target_names=["infringement", "non_infringement"]))

                  precision    recall  f1-score   support

    infringement       0.98      0.85      0.91        65
non_infringement       0.84      0.98      0.90        52

        accuracy                           0.91       117
       macro avg       0.91      0.91      0.91       117
    weighted avg       0.92      0.91      0.91       117



  y_pred_final = (torch.sigmoid(torch.tensor(custom_mlp(torch.tensor(X_test, dtype=torch.float32)))) > 0.5).float().numpy()
