In [1]:
%load_ext autoreload
%autoreload 2
# Set HuggingFace cache directory to scratch to save space.
# This cell runs before `transformers` is imported so the setting is already
# in the environment when the library is loaded.
import getpass
import os

# USER is not always exported (e.g. some containers / batch jobs); fall back to
# the login name from the account database instead of raising KeyError.
_scratch_user = os.environ.get('USER') or getpass.getuser()
os.environ['HUGGINGFACE_HUB_CACHE'] = f'/scratch/{_scratch_user}/huggingface_cache'
# Allocator option for PyTorch's CUDA caching allocator; set before any CUDA work happens.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [2]:
import pickle

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local checkpoint directory, shared by the model and tokenizer loads below so
# the two can never point at different checkpoints.
MODEL_PATH = "./llama-huggingface/llama-2-7b-chat"

# Load the weights in bfloat16, then move the whole model onto the GPU.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16)
model = model.to('cuda')

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# `texts` is later indexed by integer position and passed to the tokenizer
# (presumably a list of strings — confirm against the script that wrote the file).
# NOTE(review): unpickling can execute arbitrary code; only load this file from a trusted source.
with open("wikitext103-v1-filtered.pkl", "rb") as f:
    texts = pickle.load(f)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# add_padding = False

# if add_padding:
#     # https://discuss.huggingface.co/t/llama2-pad-token-for-batched-inference/48020
#     # add_special_tokens doesn't work for some reason
#     # tokenizer.add_special_tokens({"pad_token": "[PAD]"})
#     tokenizer.pad_token = "[PAD]"
#     tokenizer.padding_side = "left"
    
#     # Mini-test
#     def mini_test():
#         tokenization = tokenizer(texts[:1], padding=True, return_attention_mask=True, return_tensors='pt')
#         sequence_lengths = tokenization.attention_mask.sum(dim=-1)
#         print(sequence_lengths)
#         # print(tokenization.input_ids[0, -sequence_lengths[0]:])
#         # print(tokenizer.decode(tokenization.input_ids[0]))
#         # assert (tokenization.input_ids[0, -sequence_lengths[0]:] > 0).all()
    
#     mini_test()

# shuffled_index = 0
# tokenization = tokenizer(
#     texts[shuffled_index],
#     return_tensors='pt',
#     return_attention_mask=True,
#     padding=False,
# ).to('cuda')
# sequence_lengths = tokenization.attention_mask.sum(dim=-1)
# with torch.no_grad():
#     outputs = model.forward(
#         **tokenization,
#         output_hidden_states=True,
#     )
#     # (1 + n_layers) tuple of tensors [batch_size, sequence_length, d_model=4096]
#     hidden_states = outputs.hidden_states

# hidden_states[-1].view(-1, 4096)[0].abs().sum()
# (encoder.encode.weight @ hidden_states[-1].view(-1, 4096)[[1], :].T)
# hidden_states[-1].view(-1, 4096).T

# torch.save(encoder.state_dict(), "encoder.pt")
# torch.save(encoder_optim.state_dict(), "encoder_optim.pt")

In [12]:
import tqdm
import wandb
import quantized_autoencoder
import gc
# Unbind any encoder / optimizer left over from a previous execution of this
# cell *before* collecting. While the names are still bound, gc.collect() and
# empty_cache() cannot release their GPU memory, so a re-run would briefly hold
# two full copies of the parameters and optimizer state.
encoder = None
encoder_optim = None
gc.collect()
torch.cuda.empty_cache()

# Hyper-parameters handed to the autoencoder constructor.
cfg = {
    "llm_dims": 4096,          # width of the LLM hidden states fed to the encoder
    "sparse_dims": 4096 * 64,
    "topk": 4096,
    "enc_dtype": "bf16",
    "seed": 0,
}

# at some point: create a dictionary of models / optimizers
# might eventually have one model per layer
encoder = quantized_autoencoder.QSAE(cfg=cfg).to(device='cuda')
# Adam with PyTorch's default hyper-parameters.
encoder_optim = torch.optim.Adam(encoder.parameters())

In [13]:
def _save_checkpoint():
    """Write the encoder and optimizer state dicts to the working directory."""
    torch.save(encoder.state_dict(), "encoder.pt")
    torch.save(encoder_optim.state_dict(), "encoder_optim.pt")


def _qsae_step(acts, act_bs=1024):
    """Run one optimizer step of the quantized autoencoder on `acts`.

    Gradients are accumulated over slices of at most `act_bs` rows so that the
    whole activation matrix never has to go through the encoder at once; each
    slice's loss is weighted by its share of the rows, so the accumulated
    gradient equals that of the row-weighted mean loss.

    Args:
        acts: 2-D tensor of activations, one row per token.
        act_bs: maximum number of rows per encoder call.

    Returns:
        Dict of per-row-averaged metrics ready to pass to ``wandb.log``.
    """
    n_acts = len(acts)
    q_tot = 0
    r_tot = 0
    unquantized_tot = 0

    encoder_optim.zero_grad()
    std_dev = torch.std(acts).item()

    for act_batch in acts.split(act_bs):
        rows = len(act_batch)
        # The encoder returns a 5-tuple; only the scores and the two error
        # terms are used here.
        _x_reconstructed, scores, _quantized_activations, quantization_error, reconstruction_error = encoder(act_batch)

        loss = quantization_error + reconstruction_error
        (loss * rows / n_acts).backward()

        q_tot += (quantization_error * rows).item()
        r_tot += (reconstruction_error * rows).item()
        unquantized_tot += (scores * rows).sum().item()

    encoder_optim.step()

    return {
        'reconstruction_loss': r_tot / n_acts,
        'quantization_loss': q_tot / n_acts,
        'unquantized_total': unquantized_tot / n_acts,
        'loss': (r_tot + q_tot) / n_acts,
        'std_dev': std_dev,
    }


def _l1_sae_step(acts, act_bs=1024):
    """Run one optimizer step for an L1-penalised autoencoder on `acts`.

    Alternative to `_qsae_step`; it expects the encoder to return
    ``(loss, x_reconstruct, mid_acts, l2_loss, l1_loss)`` and to provide
    ``make_decoder_weights_and_grad_unit_norm()``.

    Args:
        acts: 2-D tensor of activations, one row per token.
        act_bs: maximum number of rows per encoder call.

    Returns:
        Dict of per-row-averaged metrics ready to pass to ``wandb.log``.
    """
    n_acts = len(acts)
    l1_tot = 0
    l2_tot = 0

    for act_batch in acts.split(act_bs):
        rows = len(act_batch)
        loss, _x_reconstruct, _mid_acts, l2_loss, l1_loss = encoder(act_batch)
        (loss * rows / n_acts).backward()

        l1_tot += (l1_loss * rows).item()
        l2_tot += (l2_loss * rows).item()

    encoder.make_decoder_weights_and_grad_unit_norm()
    encoder_optim.step()
    encoder_optim.zero_grad()

    return {
        'l2_loss': l2_tot / n_acts,
        'l1_loss': l1_tot / n_acts,
        'loss': (l1_tot + l2_tot) / n_acts,
    }


def run():
    """Train the autoencoder on LLM hidden states, one text per optimizer step.

    For every entry of the notebook-level `texts` (visited in a seeded random
    order) the text is tokenized, pushed through `model` without gradients, and
    the hidden states of one layer are flattened to ``[tokens, d_model]`` and
    used for a single training step of `encoder` / `encoder_optim`. Metrics are
    logged to wandb after every step.

    Relies on the notebook globals `texts`, `tokenizer`, `model`, `cfg`,
    `encoder` and `encoder_optim`.

    Side effects:
        * Starts a wandb run and always finishes it, with a non-zero exit code
          if the loop is interrupted or raises.
        * Overwrites ``encoder.pt`` / ``encoder_optim.pt`` every 4096 texts and
          once more when the loop completes.
    """
    # hidden_states[0] corresponds to embeddings.
    # hidden_states[16] is after the 16th transformer layer.
    layer = 16
    act_bs = 1024      # rows per encoder call; gradients are accumulated across slices
    use_qsae = True
    shuffle = True

    # wandb.init(mode='dryrun')
    # Log the configuration the encoder was actually built from, so the run
    # metadata cannot drift from the real hyper-parameters.
    wandb.init(project='llm-mechanics', config={**cfg, 'layer': layer})

    exit_code = 1  # reported to wandb unless the loop finishes cleanly
    try:
        if shuffle:
            torch.manual_seed(0)
            order = torch.randperm(len(texts), device='cpu')
        else:
            order = torch.arange(len(texts), device='cpu')

        train_step = _qsae_step if use_qsae else _l1_sae_step

        with tqdm.tqdm(desc='Running inference', total=len(texts)) as pbar:
            for index, shuffled_index in enumerate(order.tolist()):
                tokenization = tokenizer(
                    texts[shuffled_index],
                    return_tensors='pt',
                    return_attention_mask=True,
                    padding=False,
                ).to('cuda')

                with torch.no_grad():
                    outputs = model(
                        **tokenization,
                        output_hidden_states=True,
                    )
                    # outputs.hidden_states is a (1 + n_layers) tuple of tensors
                    # [batch_size, sequence_length, d_model]; keep one layer only.
                    hidden = outputs.hidden_states[layer]
                # Drop the remaining layers and the logits before training so
                # they do not occupy GPU memory during the encoder step.
                del outputs

                acts = hidden.reshape(-1, hidden.shape[-1])

                # Guard against a tokenization with no tokens, which would
                # otherwise divide by zero when averaging the metrics.
                if len(acts) > 0:
                    wandb.log(train_step(acts, act_bs))

                if index % 1024 == 0 and index > 0:
                    torch.cuda.empty_cache()

                if index % 4096 == 0 and index > 0:
                    _save_checkpoint()

                pbar.update()

        # Persist the final weights; the periodic save alone would lose up to
        # 4095 steps of training.
        _save_checkpoint()
        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)


run()


VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,▂▂▁▅▁▂▁▁▁▃▃▅▃▆▁▁▂▁▁▄▁▂█▁▂▁▂▂▂▃▁▂▁▃▃▂▁▁▂▁
quantization_loss,▄▄▅▅▁▁▄▃▄▆▅▃▄█▃▂▄▁▅▂▄▆▃▂▄▁▂▃▃▄▂▂▄▄▂▆▂▄▁▃
reconstruction_loss,▂▂▁▅▁▂▁▁▁▃▃▅▃▆▁▁▂▁▁▄▁▂█▁▂▁▂▂▂▃▁▂▁▃▃▂▁▁▂▁
std_dev,▂▃▂▆▂▃▁▂▁▅▄▆▄▇▁▁▃▁▂▅▁▃█▂▂▁▂▃▃▄▂▃▂▄▄▄▁▁▃▂
unquantized_total,▃▂▃▂▄▅▅▃▄▁▂▁▃▁▅▄▂▅▃▂▄▂▁▄▃█▃▄▃▂▄▂▄▂▃▂▄▆▆▃

0,1
loss,3.85006
quantization_loss,0.62829
reconstruction_loss,3.22177
std_dev,1.6875
unquantized_total,17218721.68421


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112046766922706, max=1.0…

Running inference:   1%|          | 3942/749962 [06:47<21:26:25,  9.67it/s]


KeyboardInterrupt: 