In [7]:
# Automatically reload edited local modules (e.g. `autoencoder`, imported later)
# before executing code, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Set HuggingFace cache directory to scratch to save space.
# Must run before transformers / huggingface_hub are imported, since they read
# this environment variable when they are first imported.
import getpass
import os

# $USER is not set in every environment (some containers / batch jobs); fall back
# to getpass.getuser() there instead of failing with a KeyError.
_hf_cache_user = os.environ.get('USER') or getpass.getuser()
os.environ['HUGGINGFACE_HUB_CACHE'] = '/scratch/' + _hf_cache_user + '/huggingface_cache'

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local Hugging Face-format checkpoint; shared by the model and its tokenizer so
# the two can never point at different directories.
LLAMA_PATH = "./llama-huggingface/llama-2-7b-chat"

# NOTE(review): no torch_dtype is passed, so the weights load in transformers'
# default dtype (float32 in most versions; verify for the installed version).
# Consider torch_dtype=torch.bfloat16 to roughly halve GPU memory and the size of
# the hidden states collected later.
model = AutoModelForCausalLM.from_pretrained(
    LLAMA_PATH,
).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(
    LLAMA_PATH
)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
import pickle

# Load the corpus of texts to run inference over (presumably a filtered
# WikiText-103 dump, judging by the file name; verify).
# NOTE: pickle.load can execute arbitrary code; only load files you created or trust.
with open("wikitext103-v1-filtered.pkl", mode="rb") as pickle_file:
    texts = pickle.load(pickle_file)

In [5]:
# Toggle for batched (padded) tokenization. Off by default: the inference loop
# tokenizes one text at a time with padding=False, so no pad token is needed.
add_padding = False

if add_padding:
    # https://discuss.huggingface.co/t/llama2-pad-token-for-batched-inference/48020
    # add_special_tokens doesn't work for some reason
    # tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # NOTE(review): "[PAD]" is presumably not in the LLaMA vocabulary, so its id
    # may resolve to the unk token; check tokenizer.pad_token_id before relying on it.
    tokenizer.pad_token = "[PAD]"
    # Left padding keeps each sequence's real tokens at the end of the row.
    tokenizer.padding_side = "left"

    def mini_test(n_texts=2):
        """Tokenize the first ``n_texts`` texts with padding and sanity-check it.

        At least two texts (of different lengths) are needed for any padding
        to occur; with a single text the check is vacuous.

        Returns:
            Tensor with the number of real (non-pad) tokens per text.
        """
        tokenization = tokenizer(texts[:n_texts], padding=True, return_attention_mask=True, return_tensors='pt')
        sequence_lengths = tokenization.attention_mask.sum(dim=-1)
        # With left padding every row ends in a real token, so the final
        # attention-mask column must be all ones.
        assert (tokenization.attention_mask[:, -1] == 1).all(), "expected left padding"
        return sequence_lengths

    print(mini_test())

In [10]:
# Create the autoencoder and its Adam optimizer. This is a single model for now;
# it may later become a dict of models / optimizers, one per layer.

import autoencoder

cfg = {
    "act_size": 4096,        # presumably the LLaMA-2-7B hidden size (d_model=4096); verify
    "dict_size": 4096 * 16,  # 16x the activation size
    "enc_dtype": "bf16",
    "l1_coeff": 3e-3,
    "seed": 0,
    "device": "cuda"
}

# Deliberately NOT named `model`: that name holds the LLaMA model used by the
# inference loop. Reusing it here silently replaced the LLM whenever this cell
# ran first (e.g. on Restart & Run All).
autoencoder_model = autoencoder.AutoEncoder(cfg=cfg)
optim = torch.optim.Adam(autoencoder_model.parameters())


In [6]:
import tqdm

def run(save_every=None, save_prefix="hidden_states"):
    """Run the global causal LM ``model`` over every entry of ``texts`` and collect hidden states.

    Each text is tokenized on its own (batch of one, no padding), moved to CUDA,
    and passed through ``model`` with ``output_hidden_states=True``. The
    per-layer hidden states are stacked along a new leading dimension and
    copied to CPU.

    Args:
        save_every: if set, every ``save_every`` texts the accumulated tensors
            are written with ``torch.save`` to ``f"{save_prefix}_{chunk_id}.pt"``
            and dropped from memory; any remainder is flushed at the end. This
            keeps host RAM bounded. With the default ``None``, everything stays
            in memory, which for a full corpus can exhaust RAM.
        save_prefix: file-name prefix for saved chunks.

    Returns:
        List of the stacked hidden-state tensors still held in memory, one per
        text: all texts when ``save_every`` is None, otherwise empty.
    """
    hidden_states_chunk = []
    chunk_id = 0

    def _flush_chunk():
        # Persist the current chunk to disk and release it from memory.
        nonlocal chunk_id
        torch.save(hidden_states_chunk, f"{save_prefix}_{chunk_id}.pt")
        hidden_states_chunk.clear()
        chunk_id += 1

    # start=1 so `index` counts texts processed so far (used for periodic actions).
    for index, text in enumerate(tqdm.tqdm(texts, desc='Running inference'), start=1):
        tokenization = tokenizer(
            text,
            return_tensors='pt',
            return_attention_mask=True,
            padding=False,
        ).to('cuda')
        with torch.no_grad():
            # Call the module itself rather than .forward() so registered hooks run.
            outputs = model(
                **tokenization,
                output_hidden_states=True,
            )
        # outputs.hidden_states: (1 + n_layers) tuple of tensors
        # [batch_size=1, sequence_length, d_model=4096]; stack into one tensor.
        hidden_states_chunk.append(torch.stack(outputs.hidden_states, dim=0).cpu())

        if save_every and index % save_every == 0:
            _flush_chunk()

        if index % 1024 == 0:
            torch.cuda.empty_cache()

    if save_every and hidden_states_chunk:
        _flush_chunk()
    return hidden_states_chunk

# Assign the result so the notebook does not try to render the (huge) list.
collected_hidden_states = run()


Running inference:   0%|          | 738/749962 [03:04<51:56:53,  4.01it/s]  


KeyboardInterrupt: 