In [1]:
import torch

from sparse_models import SparseMLP, SimpleSparseMLP
from transformers import AutoModelForCausalLM, AutoTokenizer
from interp_utils import register_hook, remove_hooks

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# `sparsity` selects which checkpoint file is loaded below; `sparsity_levels`
# is not referenced again in the cells visible here.
sparsity_levels = torch.arange(-5,5)
sparsity = 3
fname = f'mlp_F20000_S{sparsity}_R2.pt'

# Alternative model/checkpoint, kept for reference:
# mlp = SparseMLP(n_features=6000, d_model=768, disable_comet=True)
# mlp.load_state_dict(torch.load(f'../sparse-mlps/{fname}'))

mlp = SimpleSparseMLP(n_features=20000, d_model=768)
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# from a GPU session still loads when this notebook runs on a CPU-only machine
# (without it torch.load raises while deserializing CUDA storages).
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
mlp.load_state_dict(torch.load(f'../../ts-autoencoder/mlps/{fname}', map_location=device))

mlp.to(device)

def mlp_replacement_hook(module, inp, out):
    """Hook that returns the sparse MLP's prediction for a module's input.

    The signature matches a PyTorch forward hook; `module` and `out` are
    accepted but not used. The first positional input of the hooked module,
    `inp[0]`, is passed through the notebook-level `mlp`: activations come
    from `mlp.get_acts`, are decoded with `mlp.decoder`, and `mlp.output_bias`
    (with a new leading axis) is added. Everything runs with gradient tracking
    disabled.

    Returns:
        The decoded prediction. Presumably `register_hook` installs this as a
        forward hook, in which case the returned value replaces the hooked
        module's output -- confirm against `interp_utils`.
    """
    with torch.no_grad():
        sparse_acts = mlp.get_acts(inp[0])
        decoded = mlp.decoder(sparse_acts)
        return decoded + mlp.output_bias[None]

# Load the TinyStories-33M tokenizer and causal LM, then move the model to
# the device chosen above.
tokenizer = AutoTokenizer.from_pretrained('roneneldan/TinyStories-33M')
model = AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-33M')
model = model.to(device)



In [2]:
# Start from a clean slate in case an earlier run left hooks attached.
remove_hooks(model)
batch_toks = tokenizer.encode(' Once upon a time', return_tensors='pt').to(device)

# Pass 1: forward with mlp_replacement_hook registered on the first block's MLP.
register_hook(model.transformer.h[0].mlp, mlp_replacement_hook)
with torch.no_grad():
    x = model(batch_toks)

# Pass 2: same tokens after the hooks have been removed again.
remove_hooks(model)
with torch.no_grad():
    y = model(batch_toks)

Removing hook:  GPTNeoMLP mlp_replacement_hook


In [3]:
import torch.distributions as dists
def enc(s):
    """Encode the string `s` with the notebook's `tokenizer` and return the
    resulting PyTorch tensor (`return_tensors='pt'`) moved to `device`."""
    token_tensor = tokenizer.encode(s, return_tensors='pt')
    return token_tensor.to(device)

def dec(tok_ids):
    """Turn token ids back into text by forwarding `tok_ids` unchanged to
    `tokenizer.batch_decode` and returning its result."""
    decoded = tokenizer.batch_decode(tok_ids)
    return decoded
# Clear any hooks left over from earlier cells, then attach the sparse-MLP
# replacement hook to the first block's MLP. The hook is intentionally left
# registered after this cell, as before.
remove_hooks(model)

register_hook(model.transformer.h[0].mlp, mlp_replacement_hook)

prompt = 'Once upon a time'
temperature = 0.9     # logits are scaled by 1/temperature before sampling
n_new_tokens = 300    # number of tokens to sample after the prompt

# Sample in token-id space and decode once at the end. Decoding each sampled
# token on its own and re-encoding the growing string is lossy: a token that
# holds only part of a multi-byte character decodes to U+FFFD, so the text fed
# back to the model no longer matches the tokens that were actually sampled.
tok_ids = enc(prompt)

# No gradients are needed for sampling; without no_grad every forward pass
# would build an autograd graph.
with torch.no_grad():
    for _ in range(n_new_tokens):
        logits = model(tok_ids).logits

        # Distribution over the next token, taken from the last position of
        # the single sequence in the batch.
        next_id = dists.Categorical(logits=(1/temperature)*logits[0,-1]).sample()

        # Append the sampled id as a new column of the (1, seq_len) id tensor.
        tok_ids = torch.cat([tok_ids, next_id.reshape(1, 1)], dim=1)

prompt = dec(tok_ids)[0]

print(prompt)

Once upon a time there was a red cat named Fluffy. Fluffy loved to eat milk. One day, Fluffy�! She wanted to play with a crystal.


Maggie asked everyone, "Will you play with me?" But no one said yes.


So Flappy didn't give up. She thought and thought, and then she decided to try and steal the red crystal. She crept up to the door and called, "No, don't do it!"


The voice said, "Why not? I'm scared and didn't want to play with you."


Maggie thought for a moment. Then she said, "It's ok Fluffy. I don't want you to have fun with me."


The voice said, "Thank you for being kind. I'll leave soon. Just be careful when you're eating."


Fluffy was embarrassed. Did she know that she could eat the red crystal.


<|endoftext|>


Max said, "Can I have a cookie?"


Max replied, "No, it's too expensive."
Max felt sad. He thought it was too silly. He thought, "Yeah! I'm going to play here anyway!"


Max walked around the park, with Max running and laughing.


Suddenly he spotted something. It wa

: 