In [None]:
import numpy as np
import os
from torch.utils.data import DataLoader, TensorDataset
import gc
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import torch

### Code to collect LLM activations (written for the latest Transformers release that supports Python 3.9)


def preprocess(paragraphs, tokenizer, seq_len):
    """Tokenize ``paragraphs`` into fixed-length rows, each prefixed with BOS.

    Every paragraph is tokenized without special tokens and the results are
    concatenated into one token stream. The stream is cut into rows of
    ``seq_len`` tokens (the incomplete tail is dropped) and a
    ``tokenizer.bos_token_id`` column is prepended.

    Returns:
        ``(inputs, labels)``: two ``torch.long`` tensors of shape
        ``(n_rows, seq_len + 1)``. ``labels`` is a copy of ``inputs`` in which
        every position except the last one is set to -100.

    Raises:
        ValueError: if the stream holds fewer than ``seq_len`` tokens.
    """
    stream = [
        token
        for text in paragraphs
        for token in tokenizer(text, add_special_tokens=False)["input_ids"]
    ]

    if len(stream) < seq_len:
        raise ValueError("Dataset too small")

    n_rows = len(stream) // seq_len
    body = torch.tensor(stream, dtype=torch.long)[: n_rows * seq_len].view(
        n_rows, seq_len
    )
    bos_column = torch.full((n_rows, 1), tokenizer.bos_token_id, dtype=torch.long)
    inputs = torch.cat([bos_column, body], dim=1)

    labels = inputs.clone()
    labels[:, :-1] = -100  # only the final position contributes to the loss
    return inputs, labels


def get_loader(dataset, tokenizer, batch_size, name="train", seq_len=512, val=False):
    """Return an unshuffled DataLoader over ``(input_ids, labels)`` rows.

    Rows are built by ``preprocess`` from the ``"text"`` column of
    ``dataset[name]``. ``val`` is accepted for API compatibility but unused.
    """
    inputs, labels = preprocess(dataset[name]["text"], tokenizer, seq_len=seq_len)
    return DataLoader(
        TensorDataset(inputs, labels), batch_size=batch_size, shuffle=False
    )


class HookManager:
    """Capture Q/K/V projection outputs of a model via forward hooks.

    Captured outputs are stored in ``value_residuals`` under keys ``"q<i>"``,
    ``"k<i>"`` and ``"v<i>"``, where ``i`` is the decoder-layer index. Only
    layers with ``i % 5 == 0`` and the last layer are hooked.
    """

    def __init__(self, model, model_name):
        self.hook_handles = []  # handles returned by register_forward_hook
        self.value_residuals = {}  # hook name -> latest module output
        self.model = model
        self.model_name = model_name

    def make_hook(self, name):
        """Return a forward hook that stores the module output under ``name``."""

        def valres_hook(module, input, output):
            self.value_residuals[name] = output

        return valres_hook

    def register_hooks(self):
        """Register hooks for the model family named by ``self.model_name``.

        Raises:
            ValueError: if ``model_name`` is not supported. Previously this was
                a silent no-op that only surfaced later as a missing-key error.
        """
        if self.model_name == "llama":
            self.register_hooks_llama()
        else:
            raise ValueError(f"Unsupported model name: {self.model_name!r}")

    def register_hooks_llama(self):
        """Hook q_proj/k_proj/v_proj on every 5th decoder layer and the last one.

        Expects ``model.model.layers[i].self_attn`` to expose ``q_proj``,
        ``k_proj`` and ``v_proj`` submodules.
        """
        layers = self.model.model.layers
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            if i % 5 != 0 and i != last:
                continue
            attn = layer.self_attn
            for prefix, proj in (("q", attn.q_proj), ("v", attn.v_proj), ("k", attn.k_proj)):
                handle = proj.register_forward_hook(self.make_hook(prefix + str(i)))
                self.hook_handles.append(handle)
        print("hooks registered")

    def clear_hooks(self):
        """Remove every registered hook."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

    def clear_residuals(self):
        """Drop captured outputs and release cached accelerator memory."""
        # clear() in place so any caller holding this dict sees it emptied too.
        self.value_residuals.clear()
        # Collect first so memory freed by the collector can be returned by
        # empty_cache (the original order did empty_cache before gc.collect).
        gc.collect()
        torch.cuda.empty_cache()

    def clear(self):
        """Remove all hooks and drop all captured outputs."""
        self.clear_hooks()
        self.clear_residuals()


def archive_residuals(value_residuals, attentions, n_layers, id_batch, head_dim=128):
    """Split captured Q/K/V outputs into per-layer, per-head CPU tensors.

    Args:
        value_residuals: mapping keyed ``"q<i>"``/``"k<i>"``/``"v<i>"`` for the
            sampled layers (every 5th and the last); each value is reshaped as
            ``(batch, seq, n_heads * head_dim)``.
        attentions: ``None`` or an indexable of at least ``n_layers`` tensors,
            each presumably ``(batch, n_heads, seq, seq)`` — confirm against
            the model's ``output_attentions`` output.
        n_layers: number of decoder layers (length of the returned list).
        id_batch: batch index stored with each sampled layer.
        head_dim: per-head width. The default 128 reproduces the previous
            hard-coded 32 x 128 split of a 4096-wide projection; the head count
            is inferred, so narrower (e.g. grouped-query K/V) outputs also work.

    Returns:
        A list of ``n_layers`` dicts. Sampled layers hold ``"V"``, ``"K"``,
        ``"Q"`` of shape ``(n_heads, batch, seq, head_dim)`` plus ``"id_batch"``.
        When ``attentions`` is given, every layer also holds ``"attentions"``
        of shape ``(n_heads, batch, seq, seq)``. All tensors are detached CPU
        tensors.
    """
    dump_dics = [{} for _ in range(n_layers)]

    if attentions is not None:
        for layer_idx in range(n_layers):
            # Per layer instead of stack-then-slice: torch.save writes a view's
            # entire underlying storage, so slicing one stacked tensor put all
            # layers' attentions into every layer's file. Moving each layer to
            # CPU first also avoids torch.stack across devices.
            dump_dics[layer_idx]["attentions"] = (
                attentions[layer_idx].detach().cpu().transpose(1, 0).contiguous()
            )

    last = n_layers - 1
    for layer_idx in range(n_layers):
        if layer_idx % 5 != 0 and layer_idx != last:
            continue
        for key, prefix in (("V", "v"), ("K", "k"), ("Q", "q")):
            t = value_residuals[prefix + str(layer_idx)].detach().cpu().contiguous()
            batch, seq = t.shape[0], t.shape[1]
            # (batch, seq, H*D) -> (batch, seq, H, D) -> (H, batch, seq, D)
            dump_dics[layer_idx][key] = t.reshape(batch, seq, -1, head_dim).permute(
                2, 0, 1, 3
            )
        dump_dics[layer_idx]["id_batch"] = id_batch
    return dump_dics


def _dump_batch(model, hook_manager, data, target, batch_idx, n_layers, output_attentions):
    """Run one batch and save its hooked activations; errors are printed, not raised."""
    try:
        outputs = model(
            input_ids=data, labels=target, output_attentions=output_attentions
        )
    except Exception as e:
        print(e)
        print("Error with model output")
        return
    try:
        attentions = outputs["attentions"] if output_attentions else None
        dump_dics = archive_residuals(
            hook_manager.value_residuals, attentions, n_layers, batch_idx
        )
        for layer_idx, dump in enumerate(dump_dics):
            torch.save(dump, "layer" + str(layer_idx) + "/dump" + str(batch_idx) + ".pt")
    except Exception as e:
        print(e)
        print("Error with tensor saving")


def collect_activations(model, loader, output_attentions=False, stop_indice=5, device=None):
    """Run ``model`` over ``loader`` and dump hooked activations to disk.

    For each processed batch ``i`` whose forward pass and archiving succeed, a
    file ``layer<k>/dump<i>.pt`` (relative to the current working directory) is
    written for every layer ``k``. Batches ``0 .. stop_indice`` (inclusive) are
    processed; a failing batch is reported and skipped.

    Args:
        model: LLaMA-style model exposing ``model.model.layers``.
        loader: iterable of ``(input_ids, labels)`` tensor pairs.
        output_attentions: also request and save attention maps.
        stop_indice: index of the last batch to process (inclusive).
        device: where to move each batch. Defaults to ``model.device`` when the
            model has that attribute, else the device of its first parameter
            (the previous version depended on a module-level ``device`` global).
    """
    model.eval()
    if device is None:
        device = getattr(model, "device", None)
        if device is None:
            device = next(model.parameters()).device

    n_layers = len(model.model.layers)
    for k in range(n_layers):
        os.makedirs("layer" + str(k), exist_ok=True)

    hook_manager = HookManager(model, "llama")
    # Register once; hooks overwrite their entries on every forward pass.
    hook_manager.register_hooks()
    try:
        with torch.no_grad():
            for i, (data, target) in enumerate(loader):
                try:
                    _dump_batch(
                        model,
                        hook_manager,
                        data.to(device),
                        target.to(device),
                        i,
                        n_layers,
                        output_attentions,
                    )
                finally:
                    # Drop this batch's activations before the next forward pass.
                    hook_manager.clear_residuals()
                    gc.collect()
                    torch.cuda.empty_cache()
                if i == stop_indice:
                    break
    finally:
        # Never leave hooks attached to the model, even if iteration raises.
        hook_manager.clear_hooks()


In [None]:
if __name__ == "__main__":
    # Dumps are written relative to the working directory (layer<k>/dump<i>.pt).
    # The ACTIVATIONS_DIR environment variable overrides the default location.
    os.chdir(os.environ.get("ACTIVATIONS_DIR", "/data/mgiles/shil6478/activations"))
    # Kept as a module-level name: collect_activations may read it as a global.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_id = "togethercomputer/LLaMA-2-7B-32K"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model_8bit = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", load_in_8bit=True
    )
    ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")
    # batch_size=2; each row is 1024 tokens plus the prepended BOS token.
    train_loader = get_loader(ds, tokenizer, 2, seq_len=1024)
    # stop_indice is inclusive: batches 0..350 are processed.
    collect_activations(
        model_8bit, train_loader, output_attentions=False, stop_indice=350
    )
