Dev Session:

Goals:
- Validate that the tokenizer / batching is done correctly. 
- Speed it up.

In [None]:
import torch
import os 
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Configuration for a sparse autoencoder (SAE) run on the GPT-2 small
# residual stream at layer 10.
cfg = LanguageModelSAERunnerConfig(

    # Data Generating Function (Model + Training Distribution)
    model_name = "gpt2-small",
    hook_point = "blocks.10.hook_resid_pre",
    hook_point_layer = 10,
    d_in = 768,
    dataset_path = "Skylion007/openwebtext",
    is_dataset_tokenized=False,
    
    # SAE Parameters
    expansion_factor = 64, # determines the dimension of the SAE.
    
    # Training Parameters
    lr = 1e-4,
    l1_coefficient = 5e-3,
    lr_scheduler_name=None,
    train_batch_size = 4096,
    context_size = 128,
    
    # Activation Store Parameters
    n_batches_in_buffer = 128,
    total_training_tokens = 1_000_000 * 3, # 3M tokens for this dev session (an earlier note: 200M tokens seems doable overnight).
    store_batch_size = 32,
    
    # Resampling protocol
    feature_sampling_method = 'l2',
    feature_sampling_window = 100,
    feature_reinit_scale = 0.2,
    dead_feature_window=5000,
    dead_feature_threshold = 1e-7,
    
    # WANDB
    log_to_wandb = True,
    wandb_project= "mats_sae_training_gpt2_small",
    wandb_entity = None,
    wandb_log_frequency=20,
    
    # Misc
    device = "mps",  # PyTorch's Apple Metal backend; change when running on other hardware
    seed = 42,
    n_checkpoints = 0,
    checkpoint_path = "checkpoints",
    dtype = torch.float32,
    )

# NOTE(review): presumably this runs the whole training loop; it executes
# before any profiler is enabled, so it is not covered by the profiling below.
sparse_autoencoder = language_model_sae_runner(cfg)

import cProfile, pstats, io
from pstats import SortKey
# Profile whatever runs between enable() and disable(), then print the
# collected statistics sorted by cumulative time.
pr = cProfile.Profile()
pr.enable()


# NOTE(review): the runner call below is commented out, so nothing executes
# inside the profiled region and the printed report will be essentially empty.
# sparse_autoencoder = language_model_sae_runner(cfg)
pr.disable()
s = io.StringIO()
sortby = SortKey.CUMULATIVE
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())


In [None]:
# from mats_sae_training.activation_store import ActivationStore
from tqdm import tqdm
from datasets import load_dataset

import os 
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def get_batch_tokens(iterable_dataset, cfg, model):
    """
    Streams a batch of tokens from a dataset.

    Documents are pulled from ``iterable_dataset`` one at a time and packed
    back-to-back into rows of exactly ``cfg.context_size`` tokens. When a
    document has to be split across rows, the remainder is prefixed with the
    tokenizer's BOS token before it continues into the next row. A row is
    emitted as soon as it is full, including when a document ends exactly on
    the row boundary, so the following document is not split at offset zero
    and does not receive an extra BOS.

    Tokens left over once the batch is full (the rest of the current
    document and any partially filled row) are discarded.

    Args:
        iterable_dataset: Iterator yielding dicts with a ``"text"`` entry
            (when ``cfg.is_dataset_tokenized`` is false) or a ``"tokens"``
            entry (when it is true).
        cfg: Object exposing ``store_batch_size``, ``context_size``,
            ``device`` and ``is_dataset_tokenized``.
        model: Object exposing ``to_tokens`` and ``tokenizer.bos_token_id``.

    Returns:
        A tensor of token ids with shape
        ``(cfg.store_batch_size, cfg.context_size)`` on ``cfg.device``.

    Raises:
        StopIteration: If the dataset is exhausted before the batch is full.
    """
    batch_size = cfg.store_batch_size
    context_size = cfg.context_size
    device = cfg.device

    # Completed rows are collected in a list and stacked once at the end,
    # instead of re-concatenating the whole batch tensor for every new row.
    rows = []
    current_batch = []
    current_length = 0

    # The context manager closes the progress bar even if the dataset raises.
    with tqdm(total=batch_size, desc="Filling batches") as pbar:
        while len(rows) < batch_size:
            if not cfg.is_dataset_tokenized:
                s = next(iterable_dataset)["text"]
                tokens = model.to_tokens(s, truncate=False, move_to_device=True).squeeze(0)
                assert len(tokens.shape) == 1, f"tokens.shape should be 1D but was {tokens.shape}"
            else:
                tokens = torch.tensor(
                    next(iterable_dataset)["tokens"],
                    dtype=torch.long,
                    device=device,
                )
            token_len = tokens.shape[0]

            # Stop consuming the document once the batch is complete; any
            # further rows would be thrown away by the final slice anyway.
            while token_len > 0 and len(rows) < batch_size:
                # Space left in the current row
                space_left = context_size - current_length

                if token_len <= space_left:
                    # The rest of the document fits entirely.
                    current_batch.append(tokens)
                    current_length += token_len
                    token_len = 0
                else:
                    # Take as much as will fit, then restart the remainder
                    # with a BOS token at the head of the next row.
                    current_batch.append(tokens[:space_left])
                    bos = torch.tensor(
                        [model.tokenizer.bos_token_id],
                        dtype=torch.long,
                        device=tokens.device,
                    )
                    tokens = torch.cat((bos, tokens[space_left:]), dim=0)
                    token_len = tokens.shape[0]
                    current_length = context_size

                # Flush on every path, so a document that ends exactly on
                # the row boundary completes the row immediately.
                if current_length == context_size:
                    rows.append(torch.cat(current_batch, dim=0))
                    current_batch = []
                    current_length = 0

            pbar.n = min(len(rows), batch_size)
            pbar.refresh()

    if not rows:
        return torch.empty((0, context_size), dtype=torch.long, device=device)
    return torch.stack(rows[:batch_size]).to(device)
    
    
# Stream the training split lazily instead of downloading it up front.
data_path = "EleutherAI/the_pile_deduplicated"
dataset = load_dataset(data_path, split="train", streaming=True)



import cProfile, pstats, io
from pstats import SortKey
# Profile three calls to get_batch_tokens and print the statistics sorted by
# cumulative time.
pr = cProfile.Profile()
pr.enable()
# ... do something ...
# NOTE(review): iter(dataset) builds a fresh iterator on every pass, so each
# call presumably re-reads the stream from its start — confirm this is intended.
# NOTE(review): `model` is not defined in the visible code; presumably it is
# created in another cell.
for i in range(3):
    batch_tokens = get_batch_tokens(iter(dataset), cfg, model)
pr.disable()
s = io.StringIO()
sortby = SortKey.CUMULATIVE
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())
print(batch_tokens.shape)

In [None]:
# TODO: preprocess_tokenized_dataset, preprocess_text_data, and preprocessing for other dataset types
def preprocess_tokenized_dataset(source_batch: dict, context_size: int) -> dict:
    """Split every tokenized prompt in a batch into full-length context windows.

    Args:
        source_batch: Batch mapping whose ``"tokens"`` entry holds one token
            sequence per prompt.
        context_size: Number of tokens in each emitted window.

    Returns:
        A dict with a single ``"input_ids"`` key listing every complete
        window in prompt order. A trailing window shorter than
        ``context_size`` is dropped.
    """
    windows = []
    for prompt_tokens in source_batch["tokens"]:
        # The last admissible start offset is the one that still leaves
        # `context_size` tokens, so every slice below is a full window.
        last_start = len(prompt_tokens) - context_size
        for start in range(0, last_start + 1, context_size):
            windows.append(prompt_tokens[start : start + context_size])

    return {"input_ids": windows}


def preprocess_text_data(source_batch: dict, context_size: int) -> dict:
    """Tokenize a batch of raw text prompts, truncating to ``context_size``.

    Args:
        source_batch: Batch mapping whose ``"text"`` entry holds the prompts.
        context_size: Maximum number of tokens kept per prompt.

    Returns:
        The tokenizer's output for the prompts, requested as padded PyTorch
        tensors.

    NOTE(review): relies on a module-level ``model`` exposing ``tokenizer``,
    as the previous body did — confirm it is defined before this is called.
    """
    prompts = source_batch["text"]

    # Tokenize the prompts taken from the batch itself and cap them at
    # `context_size`; the previous body read `self` and `batch_size`, neither
    # of which exists inside this function.
    return model.tokenizer(
        prompts,
        return_tensors='pt',
        max_length=context_size,
        padding=True,
        truncation=True
    )



def get_mapped_dataset(cfg):
    """Stream the train split named by ``cfg``, chunk it, and shuffle it.

    The dataset at ``cfg.dataset_path`` is opened in streaming mode, mapped
    in batches through ``preprocess_tokenized_dataset`` with
    ``cfg.context_size``, stripped of its original columns, and wrapped in a
    buffered shuffle.

    NOTE(review): the tokenized preprocessor is used unconditionally, so the
    source presumably needs a ``"tokens"`` column — confirm against the
    dataset in use.

    Args:
        cfg: Object exposing ``context_size`` and ``dataset_path``.

    Returns:
        The mapped and shuffled streaming dataset.
    """
    window_size, source_path = cfg.context_size, cfg.dataset_path
    shuffle_buffer_size = 1024
    map_batch_size = 1024

    raw_stream = load_dataset(source_path, streaming=True, split="train")  # type: ignore

    # Peek at the first example to learn the source columns, so that all of
    # them can be removed once the preprocessor has produced its own output.
    first_example = next(iter(raw_stream))
    source_columns = list(first_example.keys())

    chunked_stream = raw_stream.map(
        preprocess_tokenized_dataset,
        batched=True,
        batch_size=map_batch_size,
        fn_kwargs={"context_size": window_size},
        remove_columns=source_columns,
    )

    # Approximate shuffling: a streamed dataset is shuffled through a buffer
    # of `shuffle_buffer_size` items rather than globally.
    # https://huggingface.co/docs/datasets/v2.14.5/stream#shuffle
    return chunked_stream.shuffle(buffer_size=shuffle_buffer_size)


data_path = "EleutherAI/the_pile_deduplicated"
# NOTE(review): this stream is replaced on the next line; get_mapped_dataset
# loads cfg.dataset_path itself and never reads data_path.
dataset = load_dataset(data_path, split="train", streaming=True)
dataset = get_mapped_dataset(cfg)



In [None]:
def text_batch(self, batch_size=None):
    """
    Pull the next `batch_size` items from `self.data` and return them as a list.

    When `batch_size` is None, `self.in_batch_size` items are taken instead.
    """
    count = self.in_batch_size if batch_size is None else batch_size

    items = []
    for _ in range(count):
        items.append(next(self.data))
    return items


def tokenized_batch(self, batch_size=None):
    """
    Fetch a batch of texts and run them through the model's tokenizer.

    The texts come from `self.text_batch(batch_size=batch_size)`. The
    tokenizer is asked for PyTorch tensors, with padding enabled and
    truncation to `self.ctx_len`.
    """
    batch_texts = self.text_batch(batch_size=batch_size)
    tokenize = self.model.tokenizer
    tokenizer_options = {
        "return_tensors": 'pt',
        "max_length": self.ctx_len,
        "padding": True,
        "truncation": True,
    }
    return tokenize(batch_texts, **tokenizer_options)

In [None]:
import numpy as np
from typing import Dict, List
import einops
from datasets import load_dataset

# Module-level settings read by tokenize_function below.
column_name = "text"  # key of the text column in each incoming batch
add_bos_token = True  # when true, a BOS id is prepended to every packed row
seq_len = 128  # ids per packed row, not counting the optional BOS id
# NOTE(review): `model` is not defined in the visible code; presumably it is
# created in another cell.
tokenizer = model.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
    """Tokenize a batch of documents and pack the ids into fixed-length rows.

    The documents under ``column_name`` are joined with the tokenizer's EOS
    token, cut into 20 character chunks of roughly equal length, tokenized
    together, stripped of padding ids and reshaped into rows of ``seq_len``
    ids. Leftover ids that do not fill a row are dropped. When
    ``add_bos_token`` is set, a BOS id is prepended to every row, which makes
    each row ``seq_len + 1`` ids long.

    Reads the module-level ``column_name``, ``tokenizer``, ``seq_len`` and
    ``add_bos_token``.

    Args:
        examples: Batch mapping column names to lists of strings.

    Returns:
        ``{"tokens": array}`` with one row per packed sequence.
    """
    num_chunks = 20
    documents = examples[column_name]
    joined = tokenizer.eos_token.join(documents)

    # Ceiling division, so that `num_chunks` slices cover the whole string.
    chunk_length = (len(joined) - 1) // num_chunks + 1
    chunks = []
    for k in range(num_chunks):
        chunks.append(joined[k * chunk_length : (k + 1) * chunk_length])

    # NumPy output because HuggingFace map doesn't want tensors returned.
    encoded = tokenizer(chunks, return_tensors="np", padding=True)
    ids = encoded["input_ids"].flatten()

    # Padding was only needed to tokenize the ragged chunks in one call.
    # NOTE(review): if the pad id equals the EOS id, this mask also removes
    # the document separators — confirm for the tokenizer in use.
    is_real_token = ids != tokenizer.pad_token_id
    ids = ids[is_real_token]

    # Keep only as many ids as fill whole rows.
    num_batches = len(ids) // (seq_len)
    ids = ids[: seq_len * num_batches]
    rows = einops.rearrange(
        ids, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
    )

    if add_bos_token:
        bos_column = np.full((num_batches, 1), tokenizer.bos_token_id)
        rows = np.concatenate([bos_column, rows], axis=1)
    return {"tokens": rows}

data_path = "EleutherAI/the_pile_deduplicated"
dataset = load_dataset(data_path, split="train", streaming=True)

# Apply tokenize_function batch by batch and drop the raw text column.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=[column_name],
)
buffer_size = 1024
# Shuffle through a buffer of `buffer_size` items; `dataset` now names the
# shuffled, tokenized stream rather than the raw one loaded above.
dataset = tokenized_dataset.shuffle(buffer_size=buffer_size)
# tokenized_dataset.set_format(type="torch", columns=["tokens"])
# NOTE(review): despite the name, this is presumably a single example (a dict
# with a "tokens" entry), not a batch.
batch_tokens = next(iter(dataset))
batch_tokens

In [None]:
# Draw 32 consecutive examples from one iterator; calling iter(dataset)
# inside the comprehension would build a fresh iterator for every row.
# torch.as_tensor converts each row first, because torch.stack only accepts
# tensors and the examples may hold lists or NumPy arrays.
_example_stream = iter(dataset)
tensor_tokens = torch.stack(
    [torch.as_tensor(next(_example_stream)["tokens"]) for _ in range(32)]
)

In [None]:
tensor_tokens

In [None]:
# NOTE(review): tokenized_dataset was created with streaming=True; streaming
# datasets may only offer with_format(...) rather than set_format(...) —
# confirm against the datasets documentation. Also, `dataset` (the shuffled
# stream) was derived before this call.
tokenized_dataset.set_format(type="torch", columns=["tokens"])


In [None]:
# Presumably decodes the sampled token ids back to text for a visual check.
model.to_string(batch_tokens['tokens'])

In [None]:
from torch.utils.data import DataLoader


def _collate_tokens(examples):
    """Stack the ``"tokens"`` entry of each example into one batch tensor.

    Args:
        examples: List of example dicts, each holding a ``"tokens"``
            sequence; all sequences must have the same length.

    Returns:
        ``{"tokens": tensor}`` with one row per example.
    """
    return {
        "tokens": torch.stack(
            [torch.as_tensor(example["tokens"]) for example in examples]
        )
    }


# The loader hands collate_fn a list of example dicts, which torch.stack
# cannot consume directly, so the "tokens" entries are stacked explicitly
# and returned under the same key.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=0,
    collate_fn=_collate_tokens,
)

batch_tokens = next(iter(dataloader))
batch_tokens

In [None]:
batch_tokens['tokens']