In [1]:
%env CUDA_VISIBLE_DEVICES=0

import os
import random

import torch
from torch import nn
from torch.nn import functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

from datasets import load_dataset

from tqdm.notebook import tqdm, trange

env: CUDA_VISIBLE_DEVICES=0


In [None]:
# Load Llama-2-7B onto the GPU in its checkpoint dtype. The Hugging Face access
# token is read from the HF_TOKEN environment variable instead of being
# hardcoded; with None, the Hub client uses its default credential lookup.
# NOTE(review): the token previously committed in this cell is exposed and
# should be revoked.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    token=os.environ.get("HF_TOKEN"),
    device_map="cuda",
    torch_dtype="auto",
)

In [None]:
SEED: int = 0  # seeds the RNG that get_wikitext2 uses to pick calibration windows

def _sample_windows(input_ids, seqlen, nsamples, seed):
    """Stack ``nsamples`` random ``seqlen``-token windows cut from row 0 of ``input_ids``."""
    # Start offsets lie in [0, N - seqlen - 1] (randint is inclusive), so the
    # final token of the stream is never included. This is kept as-is so the
    # sampled windows match earlier runs with the same seed.
    max_start = input_ids.shape[1] - seqlen - 1
    if max_start < 0:
        raise ValueError(
            f"need at least seqlen + 1 = {seqlen + 1} tokens to sample a window, "
            f"got {input_ids.shape[1]}"
        )
    # Private RNG: produces the same offsets as random.seed(seed) + random.randint
    # without clobbering the global `random` state.
    rng = random.Random(seed)
    starts = [rng.randint(0, max_start) for _ in range(nsamples)]
    if not starts:
        return torch.empty((0, seqlen), dtype=input_ids.dtype)
    return torch.stack([input_ids[0, start:start + seqlen] for start in starts])


def _chunk_rows(input_ids, seqlen):
    """Drop the ragged tail of row 0 of ``input_ids`` and reshape it to ``(-1, seqlen)``."""
    nchunks = input_ids.shape[1] // seqlen
    return input_ids[0, :nchunks * seqlen].reshape(nchunks, seqlen)


def get_wikitext2(seed, seqlen, nsamples=128):
    """Build calibration windows and evaluation chunks from WikiText-2 (raw).

    The train split is joined with "\\n\\n" and tokenized once, and ``nsamples``
    windows of ``seqlen`` tokens are drawn at random offsets. The test split is
    joined the same way, truncated to a whole number of ``seqlen`` chunks and
    reshaped into rows.

    Args:
        seed: Seed for the window-offset RNG. The global ``random`` state is
            left untouched.
        seqlen: Number of tokens per training window / test chunk.
        nsamples: Number of random training windows to draw.

    Returns:
        ``(train_batch, test_input_ids)`` with shapes ``(nsamples, seqlen)`` and
        ``(num_test_tokens // seqlen, seqlen)``.

    Raises:
        ValueError: if the tokenized train split has fewer than ``seqlen + 1``
            tokens.
    """
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # Token comes from the environment rather than being hardcoded in the notebook.
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-2-7b-hf", use_fast=True, token=os.environ.get("HF_TOKEN"),
    )

    train_input_ids = tokenizer("\n\n".join(traindata['text']), return_tensors='pt').input_ids
    train_batch = _sample_windows(train_input_ids, seqlen, nsamples, seed)

    test_input_ids = tokenizer("\n\n".join(testdata['text']), return_tensors='pt').input_ids
    return train_batch, _chunk_rows(test_input_ids, seqlen)

SEQLEN = 4096    # tokens per window; presumably chosen to match Llama-2's 4096-token context
NSAMPLES = 128   # number of random calibration windows (get_wikitext2's default)
train_batch, test_input_ids = get_wikitext2(SEED, SEQLEN, nsamples=NSAMPLES)
assert train_batch.shape == (NSAMPLES, SEQLEN), train_batch.shape
assert test_input_ids.shape[1] == SEQLEN, test_input_ids.shape

In [None]:
LAYER_ID = 10  # index into model.model.layers
# Module whose input activations are collected by the forward hook below.
LAYER = model.model.layers[LAYER_ID].mlp.down_proj
# Later cells save LAYER.weight and treat the last dim of LAYER's input as its
# feature axis; check it is a plain linear layer before hooking it.
assert isinstance(LAYER, nn.Linear), type(LAYER)

In [None]:
# Accumulators filled by the forward hook below; re-running this cell resets them.
HESSIAN = None   # running float32 sum of X^T X over every captured token row
NUM_SAMPLES = 0  # number of times the hook has fired (one per forward call of LAYER)
INPUTS = []      # CPU copies of every captured input activation

@torch.no_grad()
def update_hessian(_, inp, out):
    """Forward hook: add ``X^T X`` of LAYER's input to HESSIAN and keep a CPU copy.

    Args:
        _: The hooked module (unused).
        inp: Tuple of positional inputs to the module; ``inp[0]`` is the activation.
        out: The module output (unused).
    """
    global HESSIAN
    global NUM_SAMPLES
    x = inp[0].detach()  # (..., in_features): last dim is LAYER's input width
    # Copy straight to host memory (copy=True also covers an input already on CPU),
    # avoiding the extra device-side clone.
    INPUTS.append(x.to("cpu", copy=True))
    x = x.reshape(-1, x.shape[-1]).float()  # tokens x in_features
    gram = x.t().matmul(x)                  # in_features x in_features
    NUM_SAMPLES += 1
    if HESSIAN is None:
        HESSIAN = gram
    else:
        HESSIAN += gram


# If this cell is re-run, detach the previously registered hook first;
# otherwise both hooks fire and every statistic is counted twice.
_previous_hook = globals().get("hook")
if isinstance(_previous_hook, torch.utils.hooks.RemovableHandle):
    _previous_hook.remove()
del _previous_hook

hook = LAYER.register_forward_hook(update_hessian)

In [None]:
# Push every calibration sequence through the model so the hook on LAYER fills
# HESSIAN / NUM_SAMPLES / INPUTS; the model outputs themselves are discarded.
try:
    with torch.no_grad():
        for sample_idx in trange(train_batch.shape[0]):
            # A slice keeps the leading batch dim of 1; .to() copies the ids to the model's device.
            input_ids = train_batch[sample_idx:sample_idx + 1].to(model.device)
            # No KV cache is needed for a single calibration pass.
            model(input_ids, use_cache=False)
finally:
    # Detach the hook so later forward passes don't keep appending to the statistics.
    hook.remove()

assert NUM_SAMPLES == train_batch.shape[0], (
    f"hook fired {NUM_SAMPLES} times for {train_batch.shape[0]} sequences; "
    "re-run the hook cell before calibrating"
)

In [None]:
# Concatenate the per-call activation copies along dim 0 (one entry per hook
# call). Only do it while INPUTS is still the hook's list, so re-running this
# cell does not call torch.cat on the already-concatenated tensor.
if isinstance(INPUTS, list):
    assert INPUTS, "no activations were captured; run the hook and calibration cells first"
    INPUTS = torch.cat(INPUTS, dim=0)

In [None]:
# Turn the accumulated sum of X^T X into a mean over hook calls (one call per
# calibration sequence). The identity check makes re-running this cell a no-op
# instead of dividing a second time; a fresh HESSIAN built by the hook cell is a
# different object, so it still gets normalized.
assert HESSIAN is not None and NUM_SAMPLES > 0, (
    "no activations were captured; run the hook and calibration cells first"
)
if HESSIAN is not globals().get("HESSIAN_MEAN"):
    HESSIAN = HESSIAN / NUM_SAMPLES
    HESSIAN_MEAN = HESSIAN

In [None]:
# Persist HESSIAN, LAYER's weight and the captured INPUTS to the working directory.
torch.save(obj=HESSIAN, f="hessian.pt")
torch.save(obj=LAYER.weight.detach(), f="weight.pt")
torch.save(obj=INPUTS, f="inputs.pt")