In [1]:
# import packages & variables
import argparse
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoModel
import json


# Parameters
model_name = 'meta-llama/Meta-Llama-3.1-8B'
non_infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra.non_infringement.json'
infringement_file = '/home/guangwei/LLM-COPYRIGHT/copyright_newVersion/test_division/extra.infringement.json'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_data(non_infringement_file, infringement_file):
    non_infringement_texts = []
    infringement_texts = []

    with open(non_infringement_file) as f:
        non_infringement_data = json.load(f)
        non_infringement_texts = [item['input'] for item in non_infringement_data]

    with open(infringement_file) as f:
        infringement_data = json.load(f)
        infringement_texts = [item['input'] for item in infringement_data]

    return non_infringement_texts, infringement_texts

In [3]:
def extract_last_token_hidden_states(texts, model, tokenizer, batch_size=2048):
    hidden_states = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing data batches"):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        hidden_states_list = []
        for layer in outputs.hidden_states:
            # print(f"Layer shape: {layer.shape}")
            if len(layer.shape) == 3:
                last_hidden_state = layer[:, -1, :]
            elif len(layer.shape) == 2:
                last_hidden_state = layer.unsqueeze(1)
            else:
                raise ValueError("Unexpected layer shape: {}".format(layer.shape))
                
            hidden_states_list.append(last_hidden_state.cpu().numpy())

        hidden_states.append(hidden_states_list)

    return np.vstack(hidden_states)

In [4]:
def generate_ngrams(texts, n=2):
    ngram_texts = []
    for text in texts:
        tokens = text.split()
        ngrams = [' '.join(tokens[i:i+n]) for i in range(len(tokens)-n+1)]
        ngram_texts.extend(ngrams)
    return ngram_texts

In [5]:
def visualize_trends(hidden_states):
    num_layers = len(hidden_states[0])
    for layer_idx in range(num_layers):
        layer_states = np.array([hidden_states[i][layer_idx] for i in range(len(hidden_states))])
        plt.figure(figsize=(12, 6))
        plt.title(f"Trends of Last Token Hidden States - Layer {layer_idx + 1}")
        
        for i in range(layer_states.shape[1]):
            plt.plot(layer_states[:, i], label=f'Unit {i + 1}')
        
        plt.xlabel('Sample Index')
        plt.ylabel('Hidden State Value')
        plt.legend()
        plt.show()

In [None]:
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, output_hidden_states=True)
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    tokenizer.pad_token = tokenizer.eos_token
    non_infringement_texts, infringement_texts = load_data(non_infringement_file, infringement_file)
    all_texts = non_infringement_texts + infringement_texts
    ngram_texts = generate_ngrams(all_texts, n=2)
    hidden_states = extract_last_token_hidden_states(ngram_texts, model, tokenizer)
    visualize_trends(hidden_states)

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.28it/s]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Meta-Llama-3.1-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Processing data batches:   1%|▏         | 1/71 [03:03<3:34:18, 183.69s/it]