In [3]:
# get the data
from datasets import load_dataset
data = load_dataset('ola13/small-the_pile-dedup')

In [4]:
len(data['train']['text'])

100000

In [5]:
len(data['train']['text'][0])

13276

In [1]:
# load llm
from transformers import GPTNeoXForCausalLM, AutoTokenizer

model = GPTNeoXForCausalLM.from_pretrained(
  "EleutherAI/pythia-70m-deduped",
  cache_dir="./pythia-70m-deduped/step3000",
)

tokenizer = AutoTokenizer.from_pretrained(
  "EleutherAI/pythia-70m-deduped",
  cache_dir="./pythia-70m-deduped/step3000",
)

inputs = tokenizer("Hello, I am", return_tensors="pt")
tokens = model.generate(**inputs)
tokenizer.decode(tokens[0])


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


'Hello, I am looking for a way to get my name in the mail. I am looking for'

In [29]:
100000 * 128

12800000

In [30]:
128*100

12800

In [32]:
model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (a

In [61]:
# define the sae 
import torch
import torch.nn as nn

class SAE(nn.Module):
    def __init__(self, model_dim, sae_dim):
        super(SAE, self).__init__()
        self.encoder = nn.Linear(model_dim, sae_dim, bias=True)
        self.decoder = nn.Linear(sae_dim, model_dim, bias=True)
        self.model_bias = nn.Parameter(torch.randn(model_dim))
        self.relu = nn.ReLU()

    def forward(self, x):
        f = self.relu(self.encoder(x - self.model_bias))
        xhat = self.decoder(f)
        return xhat

model_dim = 512
sae_dim = model_dim * 4
sae = SAE(model_dim, sae_dim)

In [64]:
x = torch.randn(512)
y = sae(x)
y.shape

torch.Size([512])

In [35]:
# get the activations, train sae

layer = 3
context_len = 128
n_train_tokens = context_len * 20
n_trained = 0
max_cache_size = context_len * 10
cache = torch.empty(0, model_dim)
context_idx = 0  # TODO: need to randomize documents? 
sae_batch_size = context_len * 1

# sample from the cache
def sample_and_remove(cache, n):
    indices = torch.randperm(cache.size(0))[:n]
    sampled_acts = cache[indices]
    mask = torch.ones(cache.size(0), dtype=torch.bool)
    mask[indices] = False
    new_cache = cache[mask]
    return sampled_acts, new_cache

# until trained on X tokens:
while n_trained < n_train_tokens:
    # if the cache is at least half empty:
    if cache.shape[0] < max_cache_size // 2:
        while cache.shape[0] < max_cache_size:
            # TODO: change this to batches
            print(context_idx)
            text = data['train']['text'][context_idx]
            inputs = tokenizer(text, return_tensors="pt")
            inputs_mod = {'input_ids': inputs['input_ids'][0, :context_len].view(1, -1), 'attention_mask': inputs['attention_mask'][0, :context_len].view(1, -1)}
            output = model(**inputs_mod, output_hidden_states=True)
            acts = output['hidden_states'][layer+1][0,:, :]
            cache = torch.cat((cache, acts))
            context_idx += 1

    # sample from the cache
    sae_acts_batch, cache = sample_and_remove(cache, sae_batch_size)

    # train another step in the autoencoder
    

    # increment trained
    n_trained += sae_batch_size
    break


n_trained  0
cache  0
0
cache  128
1
cache  256
2
cache  384
3
cache  464
4
cache  592
5
cache  720
6
cache  848
7
cache  976
8
cache  1104
9
cache  1232
10
