In [1]:
import torch 
import argparse
from utils import dotdict
from activation_dataset import setup_token_data
import wandb
import json
from datetime import datetime
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict, load_dataset
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
import multiprocessing as mp
from transformers import GPT2Tokenizer, PreTrainedTokenizerBase, AutoModelForCausalLM, AutoTokenizer


cfg = dotdict()
# models: "EleutherAI/pythia-70m-deduped", "usvsnsp/pythia-6.9b-ppo", "lomahony/eleuther-pythia6.9b-hh-sft"
# cfg.model_name="lomahony/eleuther-pythia6.9b-hh-sft"
cfg.model_name="EleutherAI/pythia-6.9b-deduped"
cfg.layers=[10]
cfg.setting="residual"
cfg.tensor_name="gpt_neox.layers.{layer}"
original_l1_alpha = 8e-4
cfg.l1_alpha=original_l1_alpha
cfg.sparsity=None
cfg.num_epochs=10
cfg.model_batch_size=8
cfg.max_length = 256
cfg.lr=1e-3
cfg.kl=False
cfg.reconstruction=False
# cfg.dataset_name="NeelNanda/pile-10k"
cfg.dataset_name="Dahoas/rm-static"
cfg.device="cuda:0"
cfg.ratio = 4
cfg.seed = 0
# cfg.device="cpu"
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

In [2]:
dset = load_dataset(cfg.dataset_name, split="train")
alt_dset = load_dataset("Elriggs/openwebtext-100k", split="train")

Downloading metadata:   0%|          | 0.00/926 [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/530 [00:00<?, ?B/s]

Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/Dahoas___parquet/default-b9d2c4937d617106/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/68.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.61M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/76256 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5103 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/Dahoas___parquet/default-b9d2c4937d617106/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


Downloading readme:   0%|          | 0.00/366 [00:00<?, ?B/s]

Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/303M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/100000 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/Elriggs___parquet/Elriggs--openwebtext-100k-79076ecafee8a6d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


In [3]:
dset.column_names

['prompt', 'response', 'chosen', 'rejected']

In [6]:
string_list = dset['chosen'] + dset['rejected']

In [7]:
def create_tokenized_dataset(texts, tokenizer, max_length):
    """
    Convert a list of strings into a tokenized datasets.Dataset.

    Args:
    - texts (list of str): The list of strings to be tokenized.
    - tokenizer: A tokenizer from the HuggingFace Transformers library.
    - max_length (int): Maximum length to which the tokenized strings should be truncated/padded.

    Returns:
    - A datasets.Dataset with tokenized inputs and attention masks.
    """

    # Create a datasets.Dataset from the list of texts
    dset = Dataset.from_dict({'text': texts})

    # Define a function to tokenize the texts
    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)

    # Tokenize the texts
    tokenized_dset = dset.map(tokenize_function, batched=True)

    return tokenized_dset

In [8]:
create_tokenized_dataset(string_list, tokenizer, 256)

Map:   0%|          | 0/152512 [00:00<?, ? examples/s]

Using pad_token, but it is not set yet.


ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.

In [None]:
x = alt_dset
text_key = "text"
max_length = cfg.max_length


chunk_size = min(tokenizer.model_max_length, max_length)  # tokenizer max length is 1024 for gpt2
sep = tokenizer.eos_token or "<|endoftext|>"
joined_text = sep.join([""] + x[text_key])
output = tokenizer(
    # Concatenate all the samples together, separated by the EOS token.
    joined_text,  # start with an eos token
    max_length=chunk_size,
    return_attention_mask=False,
    return_overflowing_tokens=True,
    truncation=True,
)

if overflow := output.pop("overflowing_tokens", None):
    # Slow Tokenizers return unnested lists of ints
    assert isinstance(output["input_ids"][0], int)

    # Chunk the overflow into batches of size `chunk_size`
    chunks = [output["input_ids"]] + [
        overflow[i * chunk_size : (i + 1) * chunk_size] for i in range(math.ceil(len(overflow) / chunk_size))
    ]
    output = {"input_ids": chunks}

total_tokens = sum(len(ids) for ids in output["input_ids"])
total_bytes = len(joined_text.encode("utf-8"))

if not return_final_batch:
    # We know that the last sample will almost always be less than the max
    # number of tokens, and we don't want to pad, so we just drop it.
    output = {k: v[:-1] for k, v in output.items()}

output_batch_size = len(output["input_ids"])

if output_batch_size == 0:
    raise ValueError(
        "Not enough data to create a single batch complete batch."
        " Either allow the final batch to be returned,"
        " or supply more data."
    )

In [None]:
sep = tokenizer.eos_token or "<|endoftext|>"
def concat_prompt_response(example):
    example['text_chosen'] = example['prompt'] + example['chosen']
    example['text_rejected'] = example['prompt'] + example['rejected']
    return example

dset = dset.map(concat_prompt_response)

In [None]:
chunk_size = min(tokenizer.model_max_length, cfg.max_length)
output = tokenizer(
            dset["text_chosen"],  # start with an eos token
            max_length=chunk_size,
            return_attention_mask=False,
            return_overflowing_tokens=True,
            truncation=True,
        )

In [None]:
del alt_dset

In [None]:
def chunk_and_tokenize_rm(
    data: Dataset,
    tokenizer: PreTrainedTokenizerBase,
    *,
    format: str = "torch",
    num_proc: int = min(mp.cpu_count() // 2, 8),
    text_key: str = "text_chosen",
    max_length: int = 2048,
    return_final_batch: bool = False,
    load_from_cache_file: bool = True,
):
    """Perform GPT-style chunking and tokenization on a dataset.

    The resulting dataset will consist entirely of chunks exactly `max_length` tokens
    long. Long sequences will be split into multiple chunks, and short sequences will
    be merged with their neighbors, using `eos_token` as a separator. The fist token
    will also always be an `eos_token`.

    Args:
        data: The dataset to chunk and tokenize.
        tokenizer: The tokenizer to use.
        format: The format to return the dataset in, passed to `Dataset.with_format`.
        num_proc: The number of processes to use for tokenization.
        text_key: The key in the dataset to use as the text to tokenize.
        max_length: The maximum length of a batch of input ids.
        return_final_batch: Whether to return the final batch, which may be smaller
            than the others.
        load_from_cache_file: Whether to load from the cache file.

    Returns:
        * The chunked and tokenized dataset.
        * The ratio of nats to bits per byte see https://arxiv.org/pdf/2101.00027.pdf,
            section 3.1.
    """

    def _tokenize_fn(x: Dict[str, list]):
        chunk_size = min(tokenizer.model_max_length, max_length)  # tokenizer max length is 1024 for gpt2
        print
        sep = tokenizer.eos_token or "<|endoftext|>"
        joined_text = sep.join([""] + x[text_key])
        output = tokenizer(
            # Concatenate all the samples together, separated by the EOS token.
            joined_text,  # start with an eos token
            max_length=chunk_size,
            return_attention_mask=False,
            return_overflowing_tokens=True,
            truncation=True,
        )

        if overflow := output.pop("overflowing_tokens", None):
            # Slow Tokenizers return unnested lists of ints
            assert isinstance(output["input_ids"][0], int)

            # Chunk the overflow into batches of size `chunk_size`
            chunks = [output["input_ids"]] + [
                overflow[i * chunk_size : (i + 1) * chunk_size] for i in range(math.ceil(len(overflow) / chunk_size))
            ]
            output = {"input_ids": chunks}


        if not return_final_batch:
            # We know that the last sample will almost always be less than the max
            # number of tokens, and we don't want to pad, so we just drop it.
            output = {k: v[:-1] for k, v in output.items()}

        output_batch_size = len(output["input_ids"])

        if output_batch_size == 0:
            raise ValueError(
                "Not enough data to create a single batch complete batch."
                " Either allow the final batch to be returned,"
                " or supply more data."
            )

        return output

    data = data.map(
        _tokenize_fn,
        # Batching is important for ensuring that we don't waste tokens
        # since we always throw away the last element of the batch we
        # want to keep the batch size as large as possible
        batched=True,
        batch_size=2048,
        num_proc=num_proc,
        remove_columns=data.column_names,
        load_from_cache_file=load_from_cache_file,
    )
    return data.with_format(format, columns=["input_ids"])

In [None]:
chunked = chunk_and_tokenize_rm(dset, tokenizer, max_length=cfg.max_length)

In [None]:
# Download the dataset
# TODO iteratively grab dataset?
cfg.max_length = 256
token_loader = setup_token_data(cfg, tokenizer, model, seed=cfg.seed)
num_tokens = cfg.max_length*cfg.model_batch_size*len(token_loader)
print(f"Number of tokens: {num_tokens}")

In [None]:
# Run 1 datapoint on model to get the activation size
from baukit import Trace

text = "1"
tokens = tokenizer(text, return_tensors="pt").input_ids.to(cfg.device)
# Your activation name will be different. In the next cells, we will show you how to find it.
with torch.no_grad():
    with Trace(model, tensor_names[0]) as ret:
        _ = model(tokens)
        representation = ret.output
        # check if instance tuple
        if(isinstance(representation, tuple)):
            representation = representation[0]
        activation_size = representation.shape[-1]
print(f"Activation size: {activation_size}")

In [None]:
# Initialize New autoencoder
from autoencoders.learned_dict import TiedSAE, UntiedSAE, AnthropicSAE
from torch import nn
params = dict()
n_dict_components = activation_size*cfg.ratio
params["encoder"] = torch.empty((n_dict_components, activation_size), device=cfg.device)
nn.init.xavier_uniform_(params["encoder"])

params["decoder"] = torch.empty((n_dict_components, activation_size), device=cfg.device)
nn.init.xavier_uniform_(params["decoder"])

params["encoder_bias"] = torch.empty((n_dict_components,), device=cfg.device)
nn.init.zeros_(params["encoder_bias"])

params["shift_bias"] = torch.empty((activation_size,), device=cfg.device)
nn.init.zeros_(params["shift_bias"])

autoencoder = AnthropicSAE(  # TiedSAE, UntiedSAE, AnthropicSAE
    # n_feats = n_dict_components, 
    # activation_size=activation_size,
    encoder=params["encoder"],
    encoder_bias=params["encoder_bias"],
    decoder=params["decoder"],
    shift_bias=params["shift_bias"],
)
autoencoder.to_device(cfg.device)
autoencoder.set_grad()
# autoencoder.encoder.requires_grad = True
# autoencoder.encoder_bias.requires_grad = True
# autoencoder.decoder.requires_grad = True
# autoencoder.shift_bias.requires_grad = True
optimizer = torch.optim.Adam(
    [
        autoencoder.encoder, 
        autoencoder.encoder_bias,
        autoencoder.decoder,
        autoencoder.shift_bias,
    ], lr=cfg.lr)

In [None]:
# Set target sparsity to 10% of activation_size if not set
if cfg.sparsity is None:
    cfg.sparsity = int(activation_size*0.05)
    print(f"Target sparsity: {cfg.sparsity}")

target_lower_sparsity = cfg.sparsity * 0.9
target_upper_sparsity = cfg.sparsity * 1.1
adjustment_factor = 0.1  # You can set this to whatever you like

In [None]:
original_bias = autoencoder.encoder_bias.clone().detach()
# Wandb setup
secrets = json.load(open("secrets.json"))
wandb.login(key=secrets["wandb_key"])
start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
wandb_run_name = f"{cfg.model_name}_{start_time[4:]}_{cfg.sparsity}"  # trim year
print(f"wandb_run_name: {wandb_run_name}")

In [None]:
wandb.init(project="sparse coding", config=dict(cfg), name=wandb_run_name)

In [None]:
time_since_activation = torch.zeros(autoencoder.encoder.shape[0])
total_activations = torch.zeros(autoencoder.encoder.shape[0])
max_num_tokens = 30_000_000
save_every = 5_000_000
num_saved_so_far = 0
# Freeze model parameters 
model.eval()
model.requires_grad_(False)
model.to(cfg.device)
last_encoder = autoencoder.encoder.clone().detach()
for i, batch in enumerate(tqdm(token_loader)):
    tokens = batch["input_ids"].to(cfg.device)
    with torch.no_grad(): # As long as not doing KL divergence, don't need gradients for model
        with Trace(model, tensor_names[0]) as ret:
            _ = model(tokens)
            representation = ret.output
            if(isinstance(representation, tuple)):
                representation = representation[0]
    layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")
    # activation_saver.save_batch(layer_activations.clone().cpu().detach())

    c = autoencoder.encode(layer_activations)
    x_hat = autoencoder.decode(c)
    
    reconstruction_loss = (x_hat - layer_activations).pow(2).mean()
    l1_loss = torch.norm(c, 1, dim=-1).mean()
    total_loss = reconstruction_loss + cfg.l1_alpha*l1_loss

    time_since_activation += 1
    time_since_activation = time_since_activation * (c.sum(dim=0).cpu()==0)
    total_activations += c.sum(dim=0).cpu()
    if ((i+1) % 10 == 0): # Check here so first check is model w/o change
        # self_similarity = torch.cosine_similarity(c, last_encoder, dim=-1).mean().cpu().item()
        # Above is wrong, should be similarity between encoder and last encoder
        self_similarity = torch.cosine_similarity(autoencoder.encoder, last_encoder, dim=-1).mean().cpu().item()
        last_encoder = autoencoder.encoder.clone().detach()
        num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
        with torch.no_grad():
            sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
            # Count number of dead_features are zero
            num_dead_features = (time_since_activation >= min(i, 200)).sum().item()
        print(f"Sparsity: {sparsity:.1f} | Dead Features: {num_dead_features} | Total Loss: {total_loss:.2f} | Reconstruction Loss: {reconstruction_loss:.2f} | L1 Loss: {cfg.l1_alpha*l1_loss:.2f} | l1_alpha: {cfg.l1_alpha:.2e} | Tokens: {num_tokens_so_far} | Self Similarity: {self_similarity:.2f}")
        wandb.log({
            'Sparsity': sparsity,
            'Dead Features': num_dead_features,
            'Total Loss': total_loss.item(),
            'Reconstruction Loss': reconstruction_loss.item(),
            'L1 Loss': (cfg.l1_alpha*l1_loss).item(),
            'l1_alpha': cfg.l1_alpha,
            'Tokens': num_tokens_so_far,
            'Self Similarity': self_similarity
        })
        
        dead_features = torch.zeros(autoencoder.encoder.shape[0])
        
        if(num_tokens_so_far > max_num_tokens):
            print(f"Reached max number of tokens: {max_num_tokens}")
            break
        
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    # resample_period = 10000
    # if (i % resample_period == 0):
    #     # RESAMPLING
    #     with torch.no_grad():
    #         # Count number of dead_features are zero
    #         num_dead_features = (total_activations == 0).sum().item()
    #         print(f"Dead Features: {num_dead_features}")
            
    #     if num_dead_features > 0:
    #         print("Resampling!")
    #         # hyperparams:
    #         max_resample_tokens = 1000 # the number of token activations that we consider for inserting into the dictionary
    #         # compute loss of model on random subset of inputs
    #         resample_loader = setup_token_data(cfg, tokenizer, model, seed=i)
    #         num_resample_data = 0

    #         resample_activations = torch.empty(0, activation_size)
    #         resample_losses = torch.empty(0)

    #         for resample_batch in resample_loader:
    #             resample_tokens = resample_batch["input_ids"].to(cfg.device)
    #             with torch.no_grad(): # As long as not doing KL divergence, don't need gradients for model
    #                 with Trace(model, tensor_names[0]) as ret:
    #                     _ = model(resample_tokens)
    #                     representation = ret.output
    #                     if(isinstance(representation, tuple)):
    #                         representation = representation[0]
    #             layer_activations = rearrange(representation, "b seq d_model -> (b seq) d_model")
    #             resample_activations = torch.cat((resample_activations, layer_activations.detach().cpu()), dim=0)

    #             c = autoencoder.encode(layer_activations)
    #             x_hat = autoencoder.decode(c)
                
    #             reconstruction_loss = (x_hat - layer_activations).pow(2).mean(dim=-1)
    #             l1_loss = torch.norm(c, 1, dim=-1)
    #             temp_loss = reconstruction_loss + cfg.l1_alpha*l1_loss
                
    #             resample_losses = torch.cat((resample_losses, temp_loss.detach().cpu()), dim=0)
                
    #             num_resample_data +=layer_activations.shape[0]
    #             if num_resample_data > max_resample_tokens:
    #                 break

                
    #         # sample num_dead_features vectors of input activations
    #         probabilities = resample_losses**2
    #         probabilities /= probabilities.sum()
    #         sampled_indices = torch.multinomial(probabilities, num_dead_features, replacement=True)
    #         new_vectors = resample_activations[sampled_indices]

    #         # calculate average encoder norm of alive neurons
    #         alive_neurons = list((total_activations!=0))
    #         modified_columns = total_activations==0
    #         avg_norm = autoencoder.encoder.data[alive_neurons].norm(dim=-1).mean()

    #         # replace dictionary and encoder weights with vectors
    #         new_vectors = new_vectors / new_vectors.norm(dim=1, keepdim=True)
            
    #         params_to_modify = [autoencoder.encoder, autoencoder.encoder_bias]

    #         current_weights = autoencoder.encoder.data
    #         current_weights[modified_columns] = (new_vectors.to(cfg.device) * avg_norm * 0.02)
    #         autoencoder.encoder.data = current_weights

    #         current_weights = autoencoder.encoder_bias.data
    #         current_weights[modified_columns] = 0
    #         autoencoder.encoder_bias.data = current_weights
            
    #         if hasattr(autoencoder, 'decoder'):
    #             current_weights = autoencoder.decoder.data
    #             current_weights[modified_columns] = new_vectors.to(cfg.device)
    #             autoencoder.decoder.data = current_weights
    #             params_to_modify += [autoencoder.decoder]

    #         for param_group in optimizer.param_groups:
    #             for param in param_group['params']:
    #                 if any(param is d_ for d_ in params_to_modify):
    #                     # Extract the corresponding rows from m and v
    #                     m = optimizer.state[param]['exp_avg']
    #                     v = optimizer.state[param]['exp_avg_sq']
                        
    #                     # Update the m and v values for the modified columns
    #                     m[modified_columns] = 0  # Reset moving average for modified columns
    #                     v[modified_columns] = 0  # Reset squared moving average for modified columns
        
    #     total_activations = torch.zeros(autoencoder.encoder.shape[0])

    if ((i+2) % 10_000==0): # save periodically but before big changes
        model_save_name = cfg.model_name.split("/")[-1]
        save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}_ckpt{num_saved_so_far}"  # trim year

        # Make directory traiend_models if it doesn't exist
        import os
        if not os.path.exists("trained_models"):
            os.makedirs("trained_models")
        # Save model
        torch.save(autoencoder, f"trained_models/{save_name}.pt")
        
        num_saved_so_far += 1

    # Running sparsity check
    num_tokens_so_far = i*cfg.max_length*cfg.model_batch_size
    if(num_tokens_so_far > 200000):
        if(i % 100 == 0):
            with torch.no_grad():
                sparsity = (c != 0).float().mean(dim=0).sum().cpu().item()
            if sparsity > target_upper_sparsity:
                cfg.l1_alpha *= (1 + adjustment_factor)
            elif sparsity < target_lower_sparsity:
                cfg.l1_alpha *= (1 - adjustment_factor)
            # print(f"Sparsity: {sparsity:.1f} | l1_alpha: {cfg.l1_alpha:.2e}")

In [None]:
model_save_name = cfg.model_name.split("/")[-1]
save_name = f"{model_save_name}_sp{cfg.sparsity}_r{cfg.ratio}_{tensor_names[0]}"  # trim year

# Make directory traiend_models if it doesn't exist
import os
if not os.path.exists("trained_models"):
    os.makedirs("trained_models")
# Save model
torch.save(autoencoder, f"trained_models/{save_name}.pt")

In [None]:
wandb.finish()