In [4]:
import torch

from sparse_models import SparseMLP, SimpleSparseMLP
from transformers import AutoModelForCausalLM, AutoTokenizer
from interp_utils import register_hook, remove_hooks

device = 'cuda' if torch.cuda.is_available() else 'cpu'


sparsity_levels = torch.arange(-5,5)
sparsity = 3

# fname = f'mlp_F6000_S{sparsity}_R1'
# mlp = SparseMLP(n_features=6000, d_model=768, disable_comet=True)

fname = f'mlp_F20000_S{sparsity}_R2'
mlp = SimpleSparseMLP(n_features=20000, d_model=768)


mlp.load_state_dict(torch.load(f'../../ts-autoencoder/mlps/{fname}.pt'))
mlp.to(device)

# mlp.decoder.weight.data[:,12000:] = torch.tensor(0.)

def mlp_replacement_hook(module, inp, out):
    with torch.no_grad():
        acts =  mlp.get_acts(inp[0])
        pred = mlp.decoder(acts) + mlp.output_bias[None]
        return pred

def mlp_ablation_hook(module, inp, out):
    with torch.no_grad():
        # out = torch.zeros_like(out)
        out = out+inp[0]
        return out
    
def get_sparse_neuron_ablation_hook(neuron_idx):
    def sparse_neuron_ablation_hook(module, inp, out):
        with torch.no_grad():
            acts = mlp.get_acts(inp[0])
            acts[:,:,neuron_idx] = torch.tensor(0.)
            pred = mlp.decoder(acts) + mlp.output_bias[None]
            return pred
    return sparse_neuron_ablation_hook

def get_neuron_tail_ablation_hook(neuron_idx):
    def tail_ablation_hook(module, inp, out):
        with torch.no_grad():
            acts = mlp.get_acts(inp[0])
            acts[:,:,neuron_idx:] = torch.tensor(0.)
            pred = mlp.decoder(acts) + mlp.output_bias[None]
            return pred
    return tail_ablation_hook



model = AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-33M').to(device)
tokenizer = AutoTokenizer.from_pretrained('roneneldan/TinyStories-33M')

tokenizer.pad_token = tokenizer.eos_token
tokenizer.max_len = 128



In [2]:
import datasets

BATCH_SIZE=20

dataset = datasets.load_dataset('roneneldan/TinyStories', split='validation')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)


def encode_batch(batch):
    out = tokenizer(batch['text'], truncation=True, padding='max_length', max_length=128, return_tensors='pt')
    input_ids, attention_mask = out['input_ids'], out['attention_mask']
    
    return input_ids, attention_mask


Found cached dataset parquet (/home/noa/.cache/huggingface/datasets/roneneldan___parquet/roneneldan--TinyStories-6ac769f186d7da53/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [19]:
from tqdm import tqdm
import torch.nn.functional as F

remove_hooks(model)

# register_hook(model.transformer.h[0].mlp, mlp_replacement_hook)
# register_hook(model.transformer.h[0].mlp, mlp_ablation_hook)



def get_val_loss(model):
    total_losses = []
    for i, batch in tqdm(enumerate(dataloader)):
        input_ids, attention_mask = encode_batch(batch)
        input_ids = input_ids[attention_mask[:,-1] ==1].to(device)
        with torch.no_grad():
            logits = model(input_ids).logits
            next_token = input_ids[:,1:]
            logprobs = F.log_softmax(logits[:,:-1], dim=-1)
            batch_loss = -logprobs.gather(index=next_token.unsqueeze(-1), dim=-1).mean()
            total_losses.append(batch_loss)
    loss = torch.mean(torch.tensor(total_losses))
    return loss

remove_hooks(model)
og_loss = get_val_loss(model)

remove_hooks(model)
register_hook(model.transformer.h[0].mlp, mlp_replacement_hook)
mlp_replacement_loss = get_val_loss(model)

remove_hooks(model)
register_hook(model.transformer.h[0].mlp, mlp_ablation_hook)
zero_ablation_loss = get_val_loss(model)

# neuron_losses = []
tail_losses = []
for neuron_idx in range(8000, 10000, 2000):
    # remove_hooks(model)
    # neuron_ablation_hook = get_sparse_neuron_ablation_hook(neuron_idx)
    # register_hook(model.transformer.h[0].mlp, neuron_ablation_hook)
    # neuron_loss = get_val_loss(model)
    # neuron_losses.append(neuron_loss)

    remove_hooks(model)
    tail_ablation_hook = get_neuron_tail_ablation_hook(neuron_idx)
    register_hook(model.transformer.h[0].mlp, tail_ablation_hook)
    tail_loss = get_val_loss(model)
    tail_losses.append(tail_loss)
  
    
# neuron_losses = torch.stack(neuron_losses)
tail_losses = torch.stack(tail_losses)

all_losses = {
    'og_loss': og_loss,
    'mlp_replacement_loss': mlp_replacement_loss,
    'zero_ablation_loss': zero_ablation_loss,
    # 'neuron_losses': neuron_losses,
    'tail_losses': tail_losses
}

torch.save(all_losses, f'{fname}_losses.pt')

Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:19, 55.94it/s]
1100it [00:25, 42.83it/s]


Removing hook:  GPTNeoMLP mlp_replacement_hook


1100it [00:20, 54.41it/s]


Removing hook:  GPTNeoMLP mlp_ablation_hook


1100it [00:26, 41.53it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:25, 42.81it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:25, 43.08it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:25, 43.33it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:25, 43.57it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:25, 42.98it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:25, 43.12it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:26, 41.91it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:26, 41.09it/s]


Removing hook:  GPTNeoMLP tail_ablation_hook


1100it [00:25, 42.64it/s]


In [20]:
all_losses

{'og_loss': tensor(1.2556),
 'mlp_replacement_loss': tensor(1.3804),
 'zero_ablation_loss': tensor(2.9380),
 'tail_losses': tensor([8.2413, 1.4679, 1.3925, 1.3789, 1.3791, 1.3804, 1.3811, 1.3807, 1.3800,
         1.3806])}

In [16]:
from interp_utils import hist
# # neuron_losses = torch.stack(neuron_losses)
# # tail_losses = torch.stack(tail_losses)

# all_losses = {
#     'og_loss': og_loss,
#     'mlp_replacement_loss': mlp_replacement_loss,
#     'zero_ablation_loss': zero_ablation_loss,
#     'neuron_losses': neuron_losses,
#     'tail_losses': tail_losses
# }

# torch.save(all_losses, f'{fname}_losses.pt')

hist(tail_losses)


In [17]:
og_loss

tensor(1.2556)

In [139]:
# neuron_losses = torch.stack(neuron_losses)


# results = {
#     'neuron_losses': neuron_losses,nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn

# }

'mlp_F20000_S3_R2.pt'

In [None]:
MLP replacement: 1.3811
Full MLP: 1.2556
Zero-ablated MLP: 6.4111
Identity-ablated MLP: 3.9379
OG mlp + identity: 2.938


In [114]:
(1.2471-0.9087)/(6.3950-0.9087)

0.06168091427738187

In [107]:
loss

tensor(1.2471)

In [100]:
loss

tensor(1.2471)

tensor(0.9087, device='cuda:0', grad_fn=<NegBackward0>)

In [85]:
logprobs.shape

torch.Size([9, 127, 50257])

In [94]:
next_token



tensor(0.9087, device='cuda:0', grad_fn=<NegBackward0>)

In [18]:
remove_hooks(model)
batch_toks = tokenizer.encode(' Once upon a time', return_tensors='pt').to(device)
register_hook(model.transformer.h[0].mlp, mlp_replacement_hook)
with torch.no_grad():
    x = model(batch_toks)
    remove_hooks(model)
    y = model(batch_toks)

Removing hook:  GPTNeoMLP mlp_replacement_hook


In [50]:
x.logits

tensor([[[ 4.8164,  4.4057, -2.6297,  ..., -6.2258, -1.1136,  2.4448],
         [ 3.0801,  3.6401, -6.6639,  ..., -5.3223, -2.6751,  4.5713],
         [ 3.4398, -0.7651, -5.8565,  ..., -6.8624, -0.2880, -0.1177],
         [ 9.7025,  3.3765, -3.7329,  ..., -6.9522, -2.0069, -0.7559]]],
       device='cuda:0')

In [59]:
# import torch.distributions as dists
# def enc(s):
#     return tokenizer.encode(s, return_tensors='pt').to(device)

# def dec(tok_ids):
#     return tokenizer.batch_decode(tok_ids)
# remove_hooks(model)

# register_hook(model.transformer.h[0].mlp, mlp_replacement_hook)


# # for _ in range(20):
# #     prompt = dec(model.generate(enc(prompt)), temperature=1.0)[0]
# # print(prompt)

# for _ in range(5):
#     prompt = 'Jack was a'
#     for _ in range(300):
#         logits = model(enc(prompt)).logits

#         sampled_tok = tokenizer.decode(dists.Categorical(logits=(1/0.9)*logits[0,-1]).sample())

        
#         prompt += sampled_tok

#         if sampled_tok == '<|endoftext|>':
#             break

#     print(prompt)

Categorical(logits: torch.Size([50257]))