In [1]:
# import packages & variables
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification
import json

# Parameters
model_name = 'meta-llama/Meta-Llama-3.1-8B'
non_infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_10.non_infringement.json'
infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra_10.infringement.json'
checkpoint_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/models/train_input_reference_last_layer_Siamese.pth'


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define CustumMLP for internal states train
class CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CustomMLP, self).__init__()
        self.down = nn.Linear(input_dim, hidden_dim)
        self.gate = nn.Linear(input_dim, hidden_dim)
        self.up = nn.Linear(hidden_dim, 1)
        self.activation = nn.SiLU()

    def forward(self, x):
        down_output = self.down(x)
        gate_output = self.gate(x)
        gated_output = down_output * self.activation(gate_output)
        return self.up(gated_output)

In [3]:
class SiameseNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SiameseNetwork, self).__init__()
        self.encoder = CustomMLP(input_dim, hidden_dim)

    def forward(self, x1, x2):
        # Pass both inputs through the shared encoder
        encoded_x1 = self.encoder(x1)
        encoded_x2 = self.encoder(x2)
        
        # Compute the absolute difference between the two encoded outputs
        diff = torch.abs(encoded_x1 - encoded_x2)
        
        # Optionally, add more layers here if needed
        similarity = F.sigmoid(diff)  # Similarity score can be based on this
        
        return similarity

In [4]:
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss function.
    Takes embeddings of two samples and a label: 1 if similar, 0 if dissimilar.
    """
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss = (label) * torch.pow(euclidean_distance, 2) + \
               (1 - label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        return loss.mean()

In [5]:
# Extract hidden states/reference embeddings
def extract_hidden_states(texts, model, tokenizer, batch_size=128):
    hidden_states = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        hidden_states.append(outputs.hidden_states[-1].mean(dim=1).cpu().numpy())
    return np.vstack(hidden_states)

def extract_reference_embeddings(references, model, tokenizer, batch_size=128):
    embeddings = []
    for i in tqdm(range(0, len(references), batch_size), desc="Processing references"):
        batch_references = references[i:i + batch_size]
        inputs = tokenizer(batch_references, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings.append(outputs.pooler_output.cpu().numpy())
    return np.vstack(embeddings)

In [6]:
# load data for infringement & non infringement
def load_data(non_infringement_file, infringement_file):
    with open(non_infringement_file, 'r', encoding='utf-8') as file:
        non_infringement_json_data = json.load(file)

    non_infringement_outputs = [entry['input'] for entry in non_infringement_json_data]
    non_infringement_references = [entry['reference'] for entry in non_infringement_json_data]
    y_non_infringement = [1] * len(non_infringement_outputs)

    with open(infringement_file, 'r', encoding='utf-8') as file:
        infringement_json_data = json.load(file)

    infringement_outputs = [entry['input'] for entry in infringement_json_data]
    infringement_references = [entry['reference'] for entry in infringement_json_data]
    y_infringement = [0] * len(infringement_outputs)

    return non_infringement_outputs, non_infringement_references, y_non_infringement, infringement_outputs, infringement_references, y_infringement

In [7]:
def train_siamese_model(X1_train, X2_train, y_train, X1_test, X2_test, y_test, input_dim, hidden_dim, epochs=500, lr=0.001, checkpoint_path="siamese_model.pth"):
    model = SiameseNetwork(input_dim, hidden_dim)
    criterion = ContrastiveLoss(margin=1.0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    X1_train_tensor = torch.tensor(X1_train, dtype=torch.float32)
    X2_train_tensor = torch.tensor(X2_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32)

    best_loss = float('inf')
    best_model_state = None
    best_epoch = 0
    losses = []

    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        model.train()
        optimizer.zero_grad()

        # Forward pass for training data pairs
        outputs = model(X1_train_tensor, X2_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")
            
            # Validation step
            model.eval()
            X1_test_tensor = torch.tensor(X1_test, dtype=torch.float32)
            X2_test_tensor = torch.tensor(X2_test, dtype=torch.float32)
            with torch.no_grad():
                y_pred = model(X1_test_tensor, X2_test_tensor)
                y_pred = (y_pred > 0.5).float().numpy()
            
            accuracy = accuracy_score(y_test, y_pred)
            print(f"Test Accuracy at Epoch {epoch + 1}: {accuracy * 100:.2f}%")

            if loss < best_loss:
                best_loss = loss
                best_model_state = model.state_dict()
                best_epoch = epoch + 1
                torch.save(best_model_state, checkpoint_path)
                print(f"New best model saved at epoch {best_epoch} with loss {best_loss:.4f}")

    model.load_state_dict(torch.load(checkpoint_path))

    print(f"Best Model was saved at epoch {best_epoch} with best loss {best_loss:.4f}")
    return model, losses, best_loss


In [8]:
# def main for reference embedding/train
def main(model_name, non_infringement_file, infringement_file, checkpoint_path):
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
    model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
    tokenizer.pad_token = tokenizer.eos_token
    bert_tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
    bert_model = AutoModel.from_pretrained('google-bert/bert-base-uncased')
    bert_tokenizer.pad_token = tokenizer.eos_token

    non_infringement_outputs, non_infringement_references, y_non_infringement, infringement_outputs, infringement_references, y_infringement = load_data(
        non_infringement_file, infringement_file
    )

    y_non_infringement = np.array(y_non_infringement)
    y_infringement = np.array(y_infringement)

    print("Extracting hidden states for non_infringement texts...")
    X_non_infringement = extract_hidden_states(non_infringement_outputs, model, tokenizer)
    print("Extracting reference embeddings for non_infringement texts...")
    reference_embeddings_non_infringement = extract_reference_embeddings(non_infringement_references, bert_model, bert_tokenizer)
    X_non_infringement_combined = np.hstack([X_non_infringement, reference_embeddings_non_infringement])

    print("Extracting hidden states for infringement texts...")
    X_infringement = extract_hidden_states(infringement_outputs, model, tokenizer)
    print("Extracting reference embeddings for infringement texts...")
    reference_embeddings_infringement = extract_reference_embeddings(infringement_references, bert_model, bert_tokenizer)
    X_infringement_combined = np.hstack([X_infringement, reference_embeddings_infringement])

    split_index_non_infringement = int(0.8 * len(X_non_infringement_combined))
    X_non_infringement_train = X_non_infringement_combined[:split_index_non_infringement]
    X_non_infringement_test = X_non_infringement_combined[split_index_non_infringement:]
    y_non_infringement_train = y_non_infringement[:split_index_non_infringement]
    y_non_infringement_test = y_non_infringement[split_index_non_infringement:]

    split_index_infringement = int(0.8 * len(X_infringement_combined))
    X_infringement_train = X_infringement_combined[:split_index_infringement]
    X_infringement_test = X_infringement_combined[split_index_infringement:]
    y_infringement_train = y_infringement[:split_index_infringement]
    y_infringement_test = y_infringement[split_index_infringement:]

    # X_train = np.vstack((X_non_infringement_train, X_infringement_train))
    # X_test = np.vstack((X_non_infringement_test, X_infringement_test))
    # y_train = np.concatenate((y_non_infringement_train, y_infringement_train))
    # y_test = np.concatenate((y_non_infringement_test, y_infringement_test))

    
    # Prepare paired data for Siamese Network
    X_train_hidden_states = np.vstack((X_non_infringement_train, X_infringement_train))
    X_train_reference_embeddings = np.vstack((reference_embeddings_non_infringement, reference_embeddings_infringement))
    y_train = np.concatenate((y_non_infringement_train, y_infringement_train))

    X_test_hidden_states = np.vstack((X_non_infringement_test, X_infringement_test))
    X_test_reference_embeddings = np.vstack((reference_embeddings_non_infringement, reference_embeddings_infringement))
    y_test = np.concatenate((y_non_infringement_test, y_infringement_test))

    input_dim = X_train_hidden_states.shape[1]
    hidden_dim = 256
    print(f"Training MLP model with input_dim={input_dim} and hidden_dim={hidden_dim}")
    
    best_model, losses, best_loss = train_siamese_model(
        X_train_hidden_states, X_train_reference_embeddings, y_train,
        X_test_hidden_states, X_test_reference_embeddings, y_test,
        input_dim, hidden_dim
    )


In [9]:

# Run main
main(model_name, non_infringement_file, infringement_file, checkpoint_file)


Loading checkpoint shards: 100%|██████████| 4/4 [00:09<00:00,  2.30s/it]


Extracting hidden states for non_infringement texts...


Processing data batches:   0%|          | 0/3 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
Processing data batches: 100%|██████████| 3/3 [11:51<00:00, 237.31s/it]


Extracting reference embeddings for non_infringement texts...


Processing references: 100%|██████████| 3/3 [00:34<00:00, 11.39s/it]


Extracting hidden states for infringement texts...


Processing data batches: 100%|██████████| 3/3 [11:06<00:00, 222.17s/it]


Extracting reference embeddings for infringement texts...


Processing references: 100%|██████████| 3/3 [00:31<00:00, 10.42s/it]


Training MLP model with input_dim=4864 and hidden_dim=256


Training Epochs:   0%|          | 0/500 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (580x768 and 4864x256)