In [1]:
import sys
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Make the repository root importable so the project packages used by this
# notebook can be resolved. The membership check keeps the cell idempotent:
# re-running it does not pile up duplicate entries.
if '../' not in sys.path:
    sys.path.append('../')

In [3]:
from interpretability.sae.sae import SAE
from utils.inference_utils.llm import LLM
from utils.train_utils.train_sae import train_sae

In [4]:
# Checkpoint to load and the device it should run on.
model_name_or_path = 'Qwen/Qwen2-1.5B-Instruct'
device = 'cuda'

# Seed the global RNGs so that stochastic steps later in the notebook
# (e.g. DataLoader shuffling, torch-based weight initialisation) are
# repeatable after Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Instantiate the LLM wrapper for the configured checkpoint on the target device.
llm = LLM(model_name_or_path=model_name_or_path, device=device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# List every addressable sub-module of the model; pick from these names the
# module whose activations should be saved into the training dataset.
for module_name, _submodule in llm.model.named_modules():
    print(module_name)


model
model.embed_tokens
model.layers
model.layers.0
model.layers.0.self_attn
model.layers.0.self_attn.q_proj
model.layers.0.self_attn.k_proj
model.layers.0.self_attn.v_proj
model.layers.0.self_attn.o_proj
model.layers.0.self_attn.rotary_emb
model.layers.0.mlp
model.layers.0.mlp.gate_proj
model.layers.0.mlp.up_proj
model.layers.0.mlp.down_proj
model.layers.0.mlp.act_fn
model.layers.0.input_layernorm
model.layers.0.post_attention_layernorm
model.layers.1
model.layers.1.self_attn
model.layers.1.self_attn.q_proj
model.layers.1.self_attn.k_proj
model.layers.1.self_attn.v_proj
model.layers.1.self_attn.o_proj
model.layers.1.self_attn.rotary_emb
model.layers.1.mlp
model.layers.1.mlp.gate_proj
model.layers.1.mlp.up_proj
model.layers.1.mlp.down_proj
model.layers.1.mlp.act_fn
model.layers.1.input_layernorm
model.layers.1.post_attention_layernorm
model.layers.2
model.layers.2.self_attn
model.layers.2.self_attn.q_proj
model.layers.2.self_attn.k_proj
model.layers.2.self_attn.v_proj
model.layers.2.

In [7]:
# Let it be `model.layers.27` for 'Qwen/Qwen2-1.5B-Instruct'.
# This is the name of the module whose activations will be collected; it
# must match one of the names reported by `llm.model.named_modules()`.

layer_name = 'model.layers.27'

## Load data for SAE training

In [8]:
from datasets import load_dataset

In [9]:
# Fetch the prompt dataset that supplies the input texts for SAE training.
dataset = load_dataset('kendrivp/openthoughts_gpt5')

In [10]:
# Read the 'question' column of the train split straight from the dataset.
# This avoids converting every column of the split to a pandas DataFrame
# only to keep a single one. `list(...)` guarantees a plain Python list
# regardless of the container type the installed `datasets` version returns.
prompts = list(dataset['train']['question'])

In [11]:
# Preview the first two prompts as a sanity check on the loaded column.
prompts[:2]

['I am in desperate need of some ear defenders, so I can program in peace. Unfortunately, I don\'t have ear defenders. But I do have a pair of headphones, a microphone and a microcontroller, so I thought I\'d make some noise-cancelling headphones.\nHowever, there is one tiny problem. I can\'t program in noisy environments! So you\'ll need to write the code for me.\nTask\nGiven a constant noisy input, output the input\'s "complement"; a value that will completely cancel out the input value (0x00 cancels 0xff, 0x7f cancels 0x80). Because the headphones block some noise, also take a value as input that you will divide the "complement" by, using the following algorithm: if (c < 0x80) ceil ( 0x7f - ( (0x7f - c) / n ) ) else floor ( ( (c - 0x80) / n ) + 0x80 ) where c is the "complement" and n is the value to divide by. (If the value exceeds the byte range, it should be clipped to 0 or 255.)\nThe input will always be a sequence of bytes, via stdin or function parameter (see bonuses) but the 

In [12]:
# Collect the activations of the chosen layer for a small sample of prompts.
# The layer is taken from `layer_name` (configured earlier) rather than a
# repeated string literal, so the extraction layer and the checkpoint
# filename built from `layer_name` cannot drift apart.
activations_dataset = llm.get_hidden_state(
    input_text=prompts[:4],
    layer_name=layer_name
)
# Hand cached, currently unused GPU memory back to the driver.
torch.cuda.empty_cache()

In [13]:
# Display the collected activations tensor.
activations_dataset

tensor([[-1.7109,  1.3359, -0.2852,  ..., -1.0312, -0.0045, -0.2891],
        [ 1.8516, -0.9180, -1.7109,  ..., -0.9336,  1.5703, -1.6953],
        [-0.9922,  0.3359, -1.3672,  ..., -0.2598, -0.9531,  0.6523],
        [-1.3516, -1.3672,  0.0070,  ..., -0.0571,  0.1367,  1.1250]],
       dtype=torch.bfloat16)

In [14]:
# Dimensions of the collected activations tensor.
activations_dataset.size()

torch.Size([4, 1536])

In [15]:
# Wrap the activations tensor in a TensorDataset so a DataLoader can batch it.
# The name is rebound from a tensor to a dataset, so guard the wrap: running
# this cell a second time would otherwise try to wrap a TensorDataset inside
# another TensorDataset and fail.
if not isinstance(activations_dataset, TensorDataset):
    activations_dataset = TensorDataset(activations_dataset)

In [16]:
# Serve the activations in mini-batches of two rows, reshuffled on each pass.
dataloader = DataLoader(activations_dataset, batch_size=2, shuffle=True)

In [17]:
# Walk the loader once and show each mini-batch it yields.
for activation_batch in dataloader:
    print(activation_batch)

[tensor([[ 1.8516, -0.9180, -1.7109,  ..., -0.9336,  1.5703, -1.6953],
        [-0.9922,  0.3359, -1.3672,  ..., -0.2598, -0.9531,  0.6523]],
       dtype=torch.bfloat16)]
[tensor([[-1.3516, -1.3672,  0.0070,  ..., -0.0571,  0.1367,  1.1250],
        [-1.7109,  1.3359, -0.2852,  ..., -1.0312, -0.0045, -0.2891]],
       dtype=torch.bfloat16)]


In [18]:
# Build the sparse autoencoder. Its input width is read from the second
# dimension of the stored activations tensor, and it is placed on the same
# `device` configured for the LLM instead of a separately hard-coded string.
sae = SAE(
    in_hidden_state_size=activations_dataset.tensors[0].shape[1],
    sparse_hidden_state_size=4096,
    device=device
)

In [19]:
# Name the checkpoint after the model and the hooked layer.
# The model part is the last path component of `model_name_or_path`, which
# works for hub ids ('org/name'), bare names and local paths alike. It is
# computed outside the f-string because reusing the same quote character
# inside an f-string expression is a SyntaxError before Python 3.12.
model_short_name = model_name_or_path.rstrip('/').split('/')[-1]
sae_checkpoint_path = f'./models/{model_short_name}_{layer_name}_SAE.pt'

train_sae(
    sae=sae,
    dataloader=dataloader,
    epochs=20,
    path_to_save=sae_checkpoint_path
)

Epoch 1/20: 100%|██████████| 2/2 [00:00<00:00, 26.71it/s]


Epoch 1/20, Loss: 10.593750


Epoch 2/20: 100%|██████████| 2/2 [00:00<00:00, 183.49it/s]


Epoch 2/20, Loss: 7.796875


Epoch 3/20: 100%|██████████| 2/2 [00:00<00:00, 171.21it/s]


Epoch 3/20, Loss: 5.875000


Epoch 4/20: 100%|██████████| 2/2 [00:00<00:00, 132.40it/s]


Epoch 4/20, Loss: 4.765625


Epoch 5/20: 100%|██████████| 2/2 [00:00<00:00, 133.34it/s]


Epoch 5/20, Loss: 3.375000


Epoch 6/20: 100%|██████████| 2/2 [00:00<00:00, 150.30it/s]


Epoch 6/20, Loss: 2.710938


Epoch 7/20: 100%|██████████| 2/2 [00:00<00:00, 150.74it/s]


Epoch 7/20, Loss: 2.328125


Epoch 8/20: 100%|██████████| 2/2 [00:00<00:00, 154.65it/s]


Epoch 8/20, Loss: 1.804688


Epoch 9/20: 100%|██████████| 2/2 [00:00<00:00, 163.68it/s]


Epoch 9/20, Loss: 1.621094


Epoch 10/20: 100%|██████████| 2/2 [00:00<00:00, 160.41it/s]


Epoch 10/20, Loss: 1.640625


Epoch 11/20: 100%|██████████| 2/2 [00:00<00:00, 165.15it/s]


Epoch 11/20, Loss: 1.039062


Epoch 12/20: 100%|██████████| 2/2 [00:00<00:00, 159.14it/s]


Epoch 12/20, Loss: 1.234375


Epoch 13/20: 100%|██████████| 2/2 [00:00<00:00, 167.85it/s]


Epoch 13/20, Loss: 0.839844


Epoch 14/20: 100%|██████████| 2/2 [00:00<00:00, 158.04it/s]


Epoch 14/20, Loss: 0.958984


Epoch 15/20: 100%|██████████| 2/2 [00:00<00:00, 158.90it/s]


Epoch 15/20, Loss: 0.835938


Epoch 16/20: 100%|██████████| 2/2 [00:00<00:00, 158.94it/s]


Epoch 16/20, Loss: 0.718750


Epoch 17/20: 100%|██████████| 2/2 [00:00<00:00, 136.62it/s]


Epoch 17/20, Loss: 0.767578


Epoch 18/20: 100%|██████████| 2/2 [00:00<00:00, 127.36it/s]


Epoch 18/20, Loss: 0.689453


Epoch 19/20: 100%|██████████| 2/2 [00:00<00:00, 150.69it/s]


Epoch 19/20, Loss: 0.726562


Epoch 20/20: 100%|██████████| 2/2 [00:00<00:00, 170.03it/s]


Epoch 20/20, Loss: 0.794922
Saved SAE model to ./models/Qwen2-1.5B-Instruct_model.layers.27_SAE.pt


'./models/Qwen2-1.5B-Instruct_model.layers.27_SAE.pt'