In [1]:
%load_ext autoreload
%autoreload 2
import os
import copy
import numpy as np
import json
import argparse
import random
import scipy
import config
from LLAMA import LLAMA
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, pipeline
import utils_llama.activation as ana
import scipy
import math
import time
import pickle
import datasets
from collections import Counter
import torch
import torch.nn as nn
import torch.optim as optim
from itertools import chain


# 设置随机种子以便结果可重复
torch.manual_seed(0)
torch.cuda.empty_cache()
torch.set_grad_enabled(True)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x7f5772f4c0a0>

In [2]:
class ARGS:
    def __init__(self):
        self.subject = 'S1'
        self.gpt = 'perceived'
        self.sessions = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20]
        self.layer = 17
        self.layer2 = 18
        self.act_name = 'ffn_gate'
        self.window = 15
        self.chunk = 4

args = ARGS()

model_dir = '/ossfs/workspace/nas/gzhch/data/models/Llama-2-7b-hf'
model = AutoModelForCausalLM.from_pretrained(
    model_dir, 
    device_map='auto',
    torch_dtype=torch.float16,
).eval()

# model = None

tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

## load cached llm act if possible
cache_dir = '/ossfs/workspace/nas/gzhch/data/cache'
llama = LLAMA(model, tokenizer, cache_dir)

2024-01-31 21:22:42,823 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|██████████| 2/2 [05:48<00:00, 174.38s/it]


In [3]:
def load_data(task_name, n_shot=1, seed=42):
    data_dirs = {
        'xsum' : '/ossfs/workspace/nas/gzhch/data/datasets/xsum',
        'gsm8k' : '/ossfs/workspace/nas/gzhch/data/datasets/gsm8k',
        'alpaca' : '/ossfs/workspace/nas/gzhch/data/datasets/alpaca',
        'wmt' : '/ossfs/workspace/nas/gzhch/data/datasets/wmt14_de-en_test',
        'wikitext2' : '/ossfs/workspace/nas/gzhch/data/datasets/wikitext-2-v1',
        'wikitext_dense' : '/ossfs/workspace/nas/gzhch/data/datasets/wikitext-2-v1',
        'wikitext_eval' : '/ossfs/workspace/nas/gzhch/data/datasets/wikitext-2-v1',
        'cross_language' : '/ossfs/workspace/nas/gzhch/data/datasets/wmt14_de-en_test',
    }
    if task_name == 'gsm8k':
        dataset = datasets.load_dataset(data_dirs[task_name])
    elif task_name == 'wikitext2':
        dataset = datasets.load_from_disk(data_dirs[task_name])
        dataset = dataset['train'].filter(lambda x: len(x['text'])>100) 
        dataset = dataset.select(random.sample(range(len(dataset)), 1000))

    elif task_name == 'wikitext_eval':
        dataset = datasets.load_from_disk(data_dirs[task_name])
        dataset = dataset['test'].filter(lambda x: len(x['text'])>100) 

    elif task_name == 'cross_language':
        dataset = datasets.load_from_disk(data_dirs[task_name])
        de_data = dataset.map(lambda e: dict(text=e['translation']['de']))
        en_data = dataset.map(lambda e: dict(text=e['translation']['en']))
        return en_data, de_data

    elif task_name == 'wikitext_dense':
        def tokenize_texts(examples):
            tokenized_inputs = tokenizer(examples["text"])
            return tokenized_inputs

        def group_texts(examples):
            # Concatenate all texts.
            max_length = 1024
            concatenated_examples = {k: list(chain(*examples[k])) for k in ['input_ids']}
            total_length = len(concatenated_examples['input_ids'])
            # print(len(concatenated_examples['input_ids']), '\n')
            # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
            # customize this part to your needs.
            if total_length >= max_length:
                total_length = (total_length // max_length) * max_length
            # else:
                # print('aaa')
            # Split by chunks of max_len.
            # result = {
            #     k: [t[i : i + max_length] for i in range(0, total_length, max_length)]
            #     for k, t in concatenated_examples.items()
            # }
            result = {'input_ids': [concatenated_examples['input_ids'][i : i + max_length] for i in range(0, total_length, max_length)]}
            return result

        dataset = datasets.load_from_disk(data_dirs[task_name])
        dataset = dataset.map(tokenize_texts, batched=True, num_proc=4)
        dataset = dataset.map(group_texts, batched=True, num_proc=4, remove_columns=['text', 'attention_mask'])
        dataset['train'] = dataset['train'].shuffle(seed=seed)

    return dataset

# 创建一个简单的两层全连接神经网络
class Projector(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Projector, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        # self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.act = nn.SiLU()
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.act(out)
        out = self.fc2(out)
        out = self.act(out)
        return out

class LinearProjector(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearProjector, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        # self.act = nn.SiLU()
    
    def forward(self, x):
        out = self.fc(x)
        # out = self.act(out)
        return out


@torch.no_grad()
def eval(x, y, net):
    output = net(x)
    loss = criterion(output, y)
    return loss

def deduplication(data):
    tokens = data['context'][:, 5]
    unique_tokens = []
    unique_token_ids = []
    for idx in range(len(tokens)):
        if tokens[idx] not in unique_tokens:
            unique_tokens.append(tokens[idx])
            unique_token_ids.append(idx)
    random.shuffle(unique_token_ids)
    ids = unique_token_ids

    return {k : v[ids] for k, v in data.items()}

def train(net, train_set, stim_neurons=None, resp_neurons=None, max_step=100000):
    logs = []
    # layer1, layer2 = 10, 15
    total_batch = len(train_set) // batch_size

    total_batch = min(total_batch, max_step)
    for b in range(total_batch):
        input_ids = train_set[b * batch_size: (b + 1) * batch_size]['input_ids']
        input_ids = torch.tensor(input_ids)
        input = dict(input_ids=input_ids, attention_mask=torch.ones(input_ids.shape))
        with torch.no_grad():
            res = llama.get_neuron_activation_and_loss(input)

        if stim_neurons is not None:
            X = res['ffn_gate'][:, layer1, stim_neurons].cuda().float()
        else:
            X = res['ffn_gate'][:, layer1, :].cuda().float()
        if resp_neurons is not None:
            Y = res['ffn_gate'][:, layer2, resp_neurons].cuda().float()
        else:
            Y = res['ffn_gate'][:, layer2, :].cuda().float()

        output = net(X)
        loss = criterion(output, Y)
        
        optimizer.zero_grad() 
        (loss * output.shape[1]).backward()        
        optimizer.step()       
        
        if (b+1) % 1 == 0:
            eval_loss = eval(test_X.cuda(), test_Y.cuda(), net).item()
            print(f'Epoch [{b+1}/{total_batch}], Train Loss: {loss.item():.6f}, Eval Loss: {eval_loss:.6f}')
            logs.append(f'Epoch [{b+1}/{total_batch}], Train Loss: {loss.item():.6f}, Eval Loss: {eval_loss:.6f}')
    return logs

def evaluate_ppl(eval_data, model, fake_ffn=None, num_of_batch=3, **forwrd_args):
    ppls = []
    batch_size = 100
    for b in range(num_of_batch):
        input = tokenizer(eval_data['text'][b * batch_size: (b + 1) * batch_size], padding='longest', return_tensors='pt')
        result = ana.custom_forward(model, input['input_ids'].cuda(), inspect_acts=['ffn_gate'], fake_ffn=fake_ffn, **forwrd_args)
        logits = result['logits']
        labels = input['input_ids']
        input_ids = input['input_ids'][:, :-1]

        # calculate loss
        shift_logits = logits[..., :-1, :].contiguous().view(-1, 32000)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        loss_fct = torch.nn.CrossEntropyLoss(reduce=False)
        loss = loss_fct(shift_logits, shift_labels).view(labels.shape[0], -1)
        t = (loss * input['attention_mask'][:, :-1]).sum(dim=1)/input['attention_mask'].sum(dim=1)
        ppls += torch.exp(t).tolist()
    ppl = torch.nan_to_num(torch.tensor(ppls)).mean().tolist()
    return ppl

In [5]:
def get_log(layer1, layer2):

    with open(f'/ossfs/workspace/cache_v2/{layer1}-{layer2}.txt', 'r') as f:
        logs = f.readlines()
    return logs


def f(layer1, layer2):

    with open(f'/ossfs/workspace/cache_v2/{layer1}-{layer2}.txt', 'r') as f:
        logs = f.readlines()
    # return logs

    stim_neurons = None
    resp_neurons = None

    ### get text set
    test_X, test_Y = [], []
    if stim_neurons is not None:
        test_X = test_data['ffn_gate'][:, layer1, stim_neurons].cuda().half()
    else:
        test_X = test_data['ffn_gate'][:, layer1, :].cuda().half()
    if resp_neurons is not None:
        test_Y = test_data['ffn_gate'][:, layer2, resp_neurons].cuda().half()
    else:
        test_Y = test_data['ffn_gate'][:, layer2, :].cuda().half()

    save_path = f'/ossfs/workspace/cache_v2/net_{layer1}_{layer2}.pt'
    # save_path = f'/ossfs/workspace/nas/gzhch/data/cache/llama-7b/net_{layer1}_{layer2}.pt'
    # if not os.path.exists(save_path):
    #     save_path = f'/ossfs/workspace/nas/gzhch/data/cache/llama-7b/net_{layer1}_{layer2}.pt'

    net = torch.load(save_path).half()

    pred = net(test_X)

    th = 0.6
    pred = net(test_X)
    delta = pred - test_Y
    ids = torch.nonzero(((delta.std(dim=0)) / test_Y.std(dim=0)).abs() < th).squeeze()
    return pred, test_Y, logs

In [None]:
wiki_data = load_data('wikitext_dense')

# get test data once and for all
batch_size = 10
test_data = []
for b in range(5):
    input_ids = wiki_data['validation'][b * batch_size: (b + 1) * batch_size]['input_ids']
    input_ids = torch.tensor(input_ids)
    input = dict(input_ids=input_ids, attention_mask=torch.ones(input_ids.shape))
    with torch.no_grad():
        res = llama.get_neuron_activation_and_loss(input)
        test_data.append(res)
test_data = {k: torch.cat([i[k] for i in test_data]) for k in test_data[0].keys()}

In [12]:
## get log
results = [[] for _ in range(0, 32, 2)]
for i, layer1 in enumerate(range(0, 32, 2)):
    for j, layer2 in enumerate(range(0, 32, 2)):
        logs = get_log(layer1, layer2)
        results[i].append(float(logs[-1].split()[-1]))

In [190]:
layer1 = 8
layer2 = 10
pred, test_Y, _ = f(layer1, layer2)

neuron_pearson = []
for i in range(pred.shape[1]):
    stat = scipy.stats.pearsonr(pred[:, i].cpu().detach(), test_Y[:, i].cpu().detach())
    neuron_pearson.append(stat.statistic)
neuron_pearson = torch.tensor(neuron_pearson)

neuron_std = (pred - test_Y).std(dim=0).cpu()

In [303]:
# indices = neuron_pearson.topk(100, largest=True).indices.cpu()
indices = neuron_std.topk(10, largest=True).indices.cpu()
neuron_id = indices
neuron_weight = model.model.layers[layer2].mlp.down_proj.weight[:, neuron_id]
lm_head = model.lm_head.weight
logit_contribution = torch.matmul(lm_head, neuron_weight.to(lm_head.device)).transpose(0, 1)
logits = logit_contribution.topk(100, dim=1).indices
# tokenizer.convert_ids_to_tokens(logit_contribution.topk(10, dim=0).indices.view(-1))

In [6]:
layer1 = 8
layer2 = 10
pred, test_Y, _ = f(layer1, layer2)


FileNotFoundError: [Errno 2] No such file or directory: '/ossfs/workspace/cache_v2/8-10.txt'

In [333]:
layer = 20
neurons = test_data['ffn_gate'][:, layer, :]
neurons.min()

tensor(-28.7031, dtype=torch.float16)

In [334]:
with open('/ossfs/workspace/test_data.pkl', 'wb') as f:
    pickle.dump(test_data, f)

In [261]:
layer1 = 4
layer2 = 6
en_data, de_data = load_data('cross_language')
eval_data = en_data
fake_ffn = ana.FFNProjector(layer1, layer2, torch.load(f'/ossfs/workspace/cache_v2/net_{layer1}_{layer2}.pt'))

batch = [9]
ppls = []
batch_size = 100
for b in range(10):
    input = tokenizer(eval_data['text'][b * batch_size: (b + 1) * batch_size], padding='longest', return_tensors='pt')
    result = ana.custom_forward(model, input['input_ids'].cuda(), inspect_acts=['ffn_gate'], fake_ffn=fake_ffn)
    logits = result['logits']
    labels = input['input_ids']
    input_ids = input['input_ids'][:, :-1]

    # calculate loss
    shift_logits = logits[..., :-1, :].contiguous().view(-1, 32000)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    loss_fct = torch.nn.CrossEntropyLoss(reduce=False)
    loss = loss_fct(shift_logits, shift_labels).view(labels.shape[0], -1)
    # print(loss)
    t = (loss * input['attention_mask'][:, :-1]).sum(dim=1)/input['attention_mask'].sum(dim=1)
    ppls += torch.exp(t).tolist()
ppl = torch.nan_to_num(torch.tensor(ppls)).mean().tolist()




In [263]:
torch.tensor(ppls).mean()

tensor(42.5913)