In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import json
import os
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM

# Restrict which CUDA devices this process can see.
# NOTE(review): an empty value normally hides every GPU from the process —
# confirm this is intended, since the model is loaded with device_map="balanced".
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Identifier passed to AutoTokenizer.from_pretrained and load_large_model.
# Left empty here; presumably filled in before running — TODO confirm.
model_name = ''
# JSON datasets read by load_data; only each entry's 'input' field is used.
non_infringement_file = 'literal.non_infringement.json'
infringement_file = 'literal.infringement.json'
# Default checkpoint_path of train_model, i.e. the target handed to torch.save.
# Left empty here; presumably filled in before running — TODO confirm.
checkpoint_file = ''

class CustomMLP(nn.Module):
    """Gated MLP that maps a feature vector to a single output value.

    Two parallel linear projections, ``down`` and ``gate``, take the input to
    ``hidden_dim``. The ``down`` branch is multiplied element-wise by the
    SiLU-activated ``gate`` branch, and ``up`` projects the product to one
    value per sample.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the two hidden projections.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Submodules are registered in the order down, gate, up, activation;
        # this fixes the state_dict key order and the seeded initialisation.
        self.down = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.gate = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.up = nn.Linear(in_features=hidden_dim, out_features=1)
        self.activation = nn.SiLU()

    def forward(self, x):
        """Return ``up(down(x) * SiLU(gate(x)))`` for the input ``x``."""
        projected = self.down(x)
        gate_signal = self.activation(self.gate(x))
        return self.up(projected * gate_signal)

def load_data(non_infringement_file, infringement_file):
    """Read both JSON datasets and return their texts with binary labels.

    Each file is parsed with ``json.load`` and iterated as a sequence of
    records; the ``'input'`` value of every record is collected.
    Non-infringement records are labelled ``1`` and infringement records
    are labelled ``0``.

    Args:
        non_infringement_file: Path to the non-infringement JSON file.
        infringement_file: Path to the infringement JSON file.

    Returns:
        A 4-tuple of lists: non-infringement texts, their labels,
        infringement texts, their labels.
    """

    def read_inputs(path):
        # Collect the 'input' field of every record stored in one JSON file.
        with open(path, 'r', encoding='utf-8') as handle:
            return [record['input'] for record in json.load(handle)]

    # The non-infringement file is read first, then the infringement file.
    non_infringement_outputs = read_inputs(non_infringement_file)
    infringement_outputs = read_inputs(infringement_file)

    return (
        non_infringement_outputs,
        [1 for _ in non_infringement_outputs],
        infringement_outputs,
        [0 for _ in infringement_outputs],
    )

def generate_random_hidden_states(num_samples, hidden_dim):
    """Return a ``(num_samples, hidden_dim)`` array of standard-normal draws.

    The values are drawn from NumPy's global random state, so results are
    reproducible only if the caller seeds it with ``np.random.seed``.
    """
    shape = (num_samples, hidden_dim)
    return np.random.standard_normal(size=shape)

def load_large_model(model_name):
    """Load a causal language model via ``AutoModelForCausalLM.from_pretrained``.

    The model is requested with ``device_map="balanced"``, an ``"offload"``
    offload folder, state-dict offloading enabled, and
    ``output_hidden_states=True``. Progress messages are printed before and
    after loading.

    Args:
        model_name: Identifier forwarded to ``from_pretrained``.

    Returns:
        The loaded model object.
    """
    print("Loading model across multiple GPUs...")
    # Keep the loading options together so they are easy to review.
    load_options = {
        "device_map": "balanced",
        "offload_folder": "offload",
        "offload_state_dict": True,
        "output_hidden_states": True,
    }
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    print("Model loaded successfully.")
    return loaded_model

def train_model(X_train, y_train, X_test, y_test, input_dim, hidden_dim, epochs=250, lr=0.001, checkpoint_path=checkpoint_file):
    """Train a ``CustomMLP`` binary classifier and track its test metrics.

    Each epoch is one full-batch Adam step on ``BCEWithLogitsLoss``. Every
    10th epoch the model is evaluated on the test set (sigmoid output
    thresholded at 0.5), and the weights are written to ``checkpoint_path``
    whenever test accuracy or F1 beats its best value so far. Both metrics
    share one checkpoint file, so it holds the weights from the most recent
    evaluation that improved either of them. The training loss curve is
    plotted at the end.

    Args:
        X_train: Training features, convertible to a float32 tensor.
        y_train: Training labels (0/1), one per training row.
        X_test: Test features, convertible to a float32 tensor.
        y_test: Test labels (0/1), one per test row.
        input_dim: Number of input features per sample.
        hidden_dim: Hidden width of the ``CustomMLP``.
        epochs: Number of full-batch optimisation steps.
        lr: Adam learning rate.
        checkpoint_path: Target passed to ``torch.save`` for checkpoints.

    Returns:
        A tuple ``(model, best_accuracy, best_f1)``. ``model`` holds the
        weights from the final epoch; it is not reloaded from the checkpoint.
        Both metrics stay at ``-inf`` if no evaluation ran (``epochs < 10``).
    """
    # Use the GPU only when one is actually visible; a hard-coded "cuda"
    # raises when CUDA is unavailable (e.g. CUDA_VISIBLE_DEVICES is empty).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    custom_mlp = CustomMLP(input_dim, hidden_dim).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(custom_mlp.parameters(), lr=lr)

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    # Labels become a column so they match the model's (N, 1) output.
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(device)
    # The test features never change, so convert them once, not per evaluation.
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)

    best_f1 = -float('inf')
    best_accuracy = -float('inf')
    losses = []

    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        custom_mlp.train()
        optimizer.zero_grad()
        outputs = custom_mlp(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Evaluate only on every 10th epoch.
        if (epoch + 1) % 10 != 0:
            continue

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")
        custom_mlp.eval()
        with torch.no_grad():
            y_pred_logits = custom_mlp(X_test_tensor)
            # Flatten the (N, 1) predictions to match the 1-D label vector.
            y_pred = (torch.sigmoid(y_pred_logits) > 0.5).float().cpu().numpy().ravel()

        accuracy = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred)
        print(f"Test Accuracy: {accuracy * 100:.2f}%, F1-score: {f1:.4f}")

        accuracy_improved = accuracy > best_accuracy
        f1_improved = f1 > best_f1
        if accuracy_improved or f1_improved:
            # Both cases store the same weights at the same path, so a single
            # write is enough even when both metrics improve at once.
            torch.save(custom_mlp.state_dict(), checkpoint_path)
        if accuracy_improved:
            best_accuracy = accuracy
            print(f"New best model saved with Accuracy {best_accuracy * 100:.2f}%.")
        if f1_improved:
            best_f1 = f1
            print(f"New best model saved with F1-score {best_f1:.4f}.")

    plt.plot(losses)
    plt.title("Training Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    return custom_mlp, best_accuracy, best_f1

# Load the tokenizer and the language model, then reuse the end-of-sequence
# token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
model = load_large_model(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Read the labelled texts of both classes.
(
    non_infringement_outputs,
    y_non_infringement,
    infringement_outputs,
    y_infringement,
) = load_data(non_infringement_file, infringement_file)

# Features are random vectors of width 768, one row per text; the loaded
# model is not consulted when building them.
print("Generating random hidden states for non-infringement texts...")
X_non_infringement = generate_random_hidden_states(len(non_infringement_outputs), 768)

print("Generating random hidden states for infringement texts...")
X_infringement = generate_random_hidden_states(len(infringement_outputs), 768)

# Per class, the first 80% of rows go to training and the rest to testing.
split_index_non_infringement = int(len(X_non_infringement) * 0.8)
split_index_infringement = int(len(X_infringement) * 0.8)

X_train = np.concatenate(
    [
        X_non_infringement[:split_index_non_infringement],
        X_infringement[:split_index_infringement],
    ],
    axis=0,
)
X_test = np.concatenate(
    [
        X_non_infringement[split_index_non_infringement:],
        X_infringement[split_index_infringement:],
    ],
    axis=0,
)
y_train = np.concatenate(
    [
        y_non_infringement[:split_index_non_infringement],
        y_infringement[:split_index_infringement],
    ]
)
y_test = np.concatenate(
    [
        y_non_infringement[split_index_non_infringement:],
        y_infringement[split_index_infringement:],
    ]
)

print("Data successfully split into training and test sets.")

# The classifier's input width follows the feature matrix.
input_dim = X_train.shape[-1]
hidden_dim = 256

trained_model, best_accuracy, best_f1 = train_model(
    X_train, y_train, X_test, y_test, input_dim, hidden_dim
)
print(f"Training complete. Best Accuracy: {best_accuracy * 100:.2f}%, Best F1-score: {best_f1:.4f}")
