In [8]:
# import packages & variables
import argparse
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import json

# Parameters
# Hugging Face model identifier; main() loads both the tokenizer and the causal LM from it.
model_name = 'meta-llama/Meta-Llama-3.1-8B'
# JSON files read by load_data(); every entry is accessed through its 'input' and 'reference' keys.
# NOTE(review): these are absolute, machine-specific paths — adjust them before running elsewhere.
non_infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra.non_infringement.json'
infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra.infringement.json'
# Destination of the best-scoring state_dict; also the default `checkpoint_path` of train_model().
checkpoint_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_reference_last_layer_median_sharedrepresentation.pth'


In [9]:
class SharedRepresentationModel(nn.Module):
    """Classifier head that fuses two feature vectors in a shared space.

    Each of the two inputs is linearly projected to ``projection_dim``; the
    projections are multiplied element-wise, passed through a linear layer
    followed by SiLU, and finally reduced to a single value per example.
    """

    def __init__(self, input_dim, ref_dim, projection_dim, hidden_dim):
        super().__init__()
        # One projection per input kind, both landing in the same space.
        self.input_projection = nn.Linear(input_dim, projection_dim)
        self.ref_projection = nn.Linear(ref_dim, projection_dim)

        # Hidden ("gate") layer, single-output head and the non-linearity
        # applied between them.
        self.gate = nn.Linear(projection_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, input_states, ref_embeddings):
        """Return the raw output of the final linear layer (no sigmoid is applied here)."""
        # Align the two views by multiplying their projections element-wise.
        fused = self.input_projection(input_states) * self.ref_projection(ref_embeddings)
        hidden = self.gate(fused)
        return self.output_layer(self.activation(hidden))


In [10]:
# Extract hidden states/reference embeddings
def extract_hidden_states(texts, model, tokenizer, batch_size=4):
    """Mean-pool the last entry of ``hidden_states`` for every text.

    The texts are tokenized in batches with padding, so shorter texts in a
    batch carry padding positions. Those positions are excluded from the
    average by weighting with the tokenizer's ``attention_mask``; otherwise
    the embedding of a text would depend on the length of its batch mates.

    Args:
        texts: Sequence of strings to encode.
        model: Callable invoked as ``model(**inputs)`` whose result exposes
            ``hidden_states``; the last entry is pooled over ``dim=1``
            (assumed to be (batch, tokens, features) — TODO confirm for
            non-Hugging-Face models).
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``; its result must support
            ``.items()`` and ``.get()``.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        A 2-D float32 NumPy array with one row per text.

    Raises:
        ValueError: If ``texts`` is empty (there is nothing to stack).
    """
    if len(texts) == 0:
        raise ValueError("extract_hidden_states received no texts to encode")

    # Models that expose ``.device`` get their inputs moved there.
    device = getattr(model, "device", None)

    pooled_batches = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
        if device is not None:
            inputs = {
                key: (value.to(device) if hasattr(value, "to") else value)
                for key, value in inputs.items()
            }

        with torch.no_grad():
            outputs = model(**inputs)
            # Pool in float32 so half-precision models still convert to NumPy.
            last_layer = outputs.hidden_states[-1].float()

            attention_mask = inputs.get("attention_mask")
            if attention_mask is None:
                # No mask available: fall back to a plain mean over dim=1.
                pooled = last_layer.mean(dim=1)
            else:
                # presumably 1 marks real tokens and 0 marks padding (Hugging Face convention)
                weights = attention_mask.to(last_layer.device).unsqueeze(-1).to(last_layer.dtype)
                # clamp avoids a division by zero for an all-padding row
                token_counts = weights.sum(dim=1).clamp(min=1.0)
                pooled = (last_layer * weights).sum(dim=1) / token_counts

        pooled_batches.append(pooled.cpu().numpy())
    return np.vstack(pooled_batches)

def extract_reference_embeddings(references, model, tokenizer, batch_size=4):
    """Encode ``references`` in batches and stack the model's ``pooler_output``.

    Args:
        references: Sequence of reference strings.
        model: Callable invoked as ``model(**inputs)``; its result must expose
            ``pooler_output``.
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``.
        batch_size: Number of references encoded per forward pass.

    Returns:
        NumPy array produced by vertically stacking every batch's
        ``pooler_output``.
    """
    def _pooled(chunk):
        # Tokenize one chunk and run it through the model without gradients.
        encoded = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            result = model(**encoded)
        return result.pooler_output.cpu().numpy()

    batch_starts = range(0, len(references), batch_size)
    return np.vstack([
        _pooled(references[start:start + batch_size])
        for start in tqdm(batch_starts, desc="Processing references")
    ])

In [11]:
# load data for infringement & non infringement
def load_data(non_infringement_file, infringement_file):
    """Read both JSON datasets and build their texts, references and labels.

    Each file must contain a JSON list of objects with ``'input'`` and
    ``'reference'`` keys. Non-infringement entries are labelled ``1`` and
    infringement entries ``0``.

    Returns:
        A 6-tuple ``(non_infringement_inputs, non_infringement_references,
        non_infringement_labels, infringement_inputs,
        infringement_references, infringement_labels)``.
    """
    def _read_inputs_and_references(path):
        # Parse one dataset file and pull out the two fields of every record.
        with open(path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
        texts = [record['input'] for record in records]
        refs = [record['reference'] for record in records]
        return texts, refs

    negatives, negative_refs = _read_inputs_and_references(non_infringement_file)
    positives, positive_refs = _read_inputs_and_references(infringement_file)

    negative_labels = [1] * len(negatives)
    positive_labels = [0] * len(positives)

    return (
        negatives, negative_refs, negative_labels,
        positives, positive_refs, positive_labels,
    )

In [12]:
def train_model(X_train_input, X_train_ref, y_train, X_test_input, X_test_ref, y_test,
                input_dim, ref_dim, projection_dim, hidden_dim, epochs=500, lr=0.001, checkpoint_path=checkpoint_file):
    """Train a SharedRepresentationModel full-batch and keep the best evaluated weights.

    Every epoch performs a single optimisation step on the whole training set.
    Every 10th epoch the model is scored on the test arrays; whenever accuracy
    improves, the weights are written to ``checkpoint_path`` and remembered in
    memory together with their classification report.

    NOTE(review): the "best" model is selected on the same test split that is
    reported, so the reported accuracy is optimistic.

    Args:
        X_train_input, X_train_ref: Training features for the two model inputs.
        y_train: Training labels (0 = infringement, 1 = non_infringement).
        X_test_input, X_test_ref, y_test: Evaluation counterparts.
        input_dim, ref_dim, projection_dim, hidden_dim: Model dimensions.
        epochs: Number of full-batch optimisation steps.
        lr: Adam learning rate.
        checkpoint_path: File that receives the best ``state_dict``.

    Returns:
        ``(shared_model, losses, best_accuracy)``. The model holds the best
        evaluated weights, or the final weights when no evaluation ran
        (``epochs < 10``), in which case ``best_accuracy`` is ``-inf``.
    """
    import copy  # local import: only needed to snapshot the best weights

    # Define the shared representation model
    shared_model = SharedRepresentationModel(input_dim, ref_dim, projection_dim, hidden_dim)

    # Binary Cross-Entropy Loss and Adam Optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(shared_model.parameters(), lr=lr)

    # Convert all data to tensors once; the test tensors are loop-invariant.
    X_train_input_tensor = torch.tensor(X_train_input, dtype=torch.float32)
    X_train_ref_tensor = torch.tensor(X_train_ref, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
    X_test_input_tensor = torch.tensor(X_test_input, dtype=torch.float32)
    X_test_ref_tensor = torch.tensor(X_test_ref, dtype=torch.float32)

    best_accuracy = -float('inf')
    best_model_state = None
    best_report = None
    best_epoch = 0
    losses = []

    # Training loop
    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        shared_model.train()
        optimizer.zero_grad()

        # Forward pass
        outputs = shared_model(X_train_input_tensor, X_train_ref_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Only evaluate on every 10th epoch
        if (epoch + 1) % 10 != 0:
            continue

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")

        # Evaluate on test set
        shared_model.eval()
        with torch.no_grad():
            y_pred_logits = shared_model(X_test_input_tensor, X_test_ref_tensor)
            # Flatten the (N, 1) predictions so they line up with the 1-D labels.
            y_pred = (torch.sigmoid(y_pred_logits) > 0.5).float().numpy().ravel()

        accuracy = accuracy_score(y_test, y_pred)
        print(f"Test Accuracy at Epoch {epoch + 1}: {accuracy * 100:.2f}%")

        report = classification_report(y_test, y_pred, target_names=["infringement", "non_infringement"])
        print(f"Classification Report at Epoch {epoch + 1}:\n{report}")

        # Save best model
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # state_dict() tensors alias the live parameters, so take a real copy.
            best_model_state = copy.deepcopy(shared_model.state_dict())
            best_report = report
            best_epoch = epoch + 1
            torch.save(best_model_state, checkpoint_path)
            print(f"New best model saved with accuracy {best_accuracy * 100:.2f}% at epoch {best_epoch}")
            print(f"Best Classification Report at Epoch {best_epoch}:\n{report}")

    # Restore the best weights from the in-memory snapshot (identical to what
    # was written to checkpoint_path). Nothing is loaded when no evaluation ran,
    # which avoids picking up a stale checkpoint from an earlier run.
    if best_model_state is not None:
        shared_model.load_state_dict(best_model_state)

    # Plot training loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.show()

    if best_model_state is not None:
        print(f"Best Model was saved at epoch {best_epoch} with accuracy {best_accuracy * 100:.2f}%")
        # Report belonging to the best epoch, not to the last evaluation.
        print(f"Best Classification Report:\n{best_report}")
    else:
        print("No evaluation was run (fewer than 10 epochs); returning the final-epoch weights.")
    return shared_model, losses, best_accuracy


In [13]:
def main(model_name, non_infringement_file, infringement_file, checkpoint_path):
    """Run the full pipeline: load models and data, extract features, train.

    Args:
        model_name: Hugging Face identifier of the causal LM whose hidden
            states are used as input features.
        non_infringement_file: JSON file with the non-infringement samples.
        infringement_file: JSON file with the infringement samples.
        checkpoint_path: File that receives the best classifier weights; it is
            forwarded to ``train_model``.
    """
    # Load the tokenizer and models
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
    model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
    # Reuse EOS as the padding token so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

    # For reference embeddings, we use BERT. Its tokenizer already ships a
    # padding token, so it must not be overwritten with the causal LM's EOS
    # token (a string that is not part of BERT's vocabulary).
    bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    bert_model = AutoModel.from_pretrained('bert-base-uncased')

    # Load infringement and non-infringement data
    (non_infringement_outputs, non_infringement_references, y_non_infringement,
     infringement_outputs, infringement_references, y_infringement) = load_data(
        non_infringement_file, infringement_file
    )

    # Convert labels to numpy arrays
    y_non_infringement = np.array(y_non_infringement)
    y_infringement = np.array(y_infringement)

    # Extract hidden states for non-infringement outputs
    print("Extracting hidden states for non-infringement outputs...")
    X_non_infringement = extract_hidden_states(non_infringement_outputs, model, tokenizer)

    # Extract reference embeddings for non-infringement references
    print("Extracting reference embeddings for non-infringement references...")
    reference_embeddings_non_infringement = extract_reference_embeddings(non_infringement_references, bert_model, bert_tokenizer)

    # Extract hidden states for infringement outputs
    print("Extracting hidden states for infringement outputs...")
    X_infringement = extract_hidden_states(infringement_outputs, model, tokenizer)

    # Extract reference embeddings for infringement references
    print("Extracting reference embeddings for infringement references...")
    reference_embeddings_infringement = extract_reference_embeddings(infringement_references, bert_model, bert_tokenizer)

    def _split_80_20(features, references, labels):
        # First 80% of the rows (in file order) train, the remainder tests.
        cut = int(0.8 * len(features))
        train_part = (features[:cut], references[:cut], labels[:cut])
        test_part = (features[cut:], references[cut:], labels[cut:])
        return train_part, test_part

    # Split each class separately so both appear in train and test
    non_inf_train, non_inf_test = _split_80_20(
        X_non_infringement, reference_embeddings_non_infringement, y_non_infringement
    )
    inf_train, inf_test = _split_80_20(
        X_infringement, reference_embeddings_infringement, y_infringement
    )

    # Combine infringement and non-infringement data (non-infringement rows first)
    X_train_input = np.vstack((non_inf_train[0], inf_train[0]))
    X_test_input = np.vstack((non_inf_test[0], inf_test[0]))
    X_train_ref = np.vstack((non_inf_train[1], inf_train[1]))
    X_test_ref = np.vstack((non_inf_test[1], inf_test[1]))
    y_train = np.concatenate((non_inf_train[2], inf_train[2]))
    y_test = np.concatenate((non_inf_test[2], inf_test[2]))

    # Set dimensions for training
    input_dim = X_train_input.shape[1]
    ref_dim = X_train_ref.shape[1]
    projection_dim = 256  # Set projection dimensionality
    hidden_dim = 256      # Hidden layer dimensionality

    print(f"Training model with input_dim={input_dim}, ref_dim={ref_dim}, projection_dim={projection_dim}, hidden_dim={hidden_dim}")

    # Train the shared representation model; forward checkpoint_path so the
    # caller's argument is honoured instead of train_model's default.
    train_model(X_train_input, X_train_ref, y_train,
                X_test_input, X_test_ref, y_test,
                input_dim, ref_dim, projection_dim, hidden_dim,
                checkpoint_path=checkpoint_path)


In [None]:

# Kick off the full pipeline with the notebook-level configuration
main(
    model_name=model_name,
    non_infringement_file=non_infringement_file,
    infringement_file=infringement_file,
    checkpoint_path=checkpoint_file,
)
