<a href="https://colab.research.google.com/github/eric8he/SAE_ViTGPT/blob/main/TrainSAELayer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install sae-lens

In [2]:
from sae_lens import SAE
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from typing import List
import torchvision
from torchvision import transforms
from torch import optim
import torch.nn.functional as F

In [13]:
# Configuration
TRAIN_ALL_LAYERS = False  # Set to True to train the entire model
LEARNING_RATE = 1e-4
BATCH_SIZE = 1024  # Reduced batch size to accommodate training
NUM_EPOCHS = 3
TARGET_LAYER = 9
NUM_BATCHES_PER_EPOCH = 100

In [4]:
# Initialize models and processors
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Initialize SAE
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{TARGET_LAYER}.hook_resid_pre",
    device=str(device),
)

config.json:   0%|          | 0.00/4.61k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/982M [00:00<?, ?B/s]

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.46.3"
}

Config of the decoder: <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> is overwritten by shared decoder config: GPT2Config {
  "activation_function": "gelu_new",
  "add_cross_attention": true,
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "decoder_start_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_rang

preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/241 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/120 [00:00<?, ?B/s]

blocks.9.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [14]:
class TrainableVisionEncoder(torch.nn.Module):
    def __init__(self, model, sae, target_layer):
        super().__init__()
        self.model = model
        self.sae = sae
        self.target_layer = target_layer
        self.target_act = None

        # Register hook to capture activations
        def gather_target_act_hook(mod, inputs, outputs):
            self.target_act = inputs[0]
            return outputs

        self.hook_handle = self.model.decoder.transformer.h[target_layer].register_forward_hook(
            gather_target_act_hook
        )

    def forward(self, pixel_values):
        # Get model outputs
        outputs = self.model.generate(pixel_values=pixel_values)

        # Get SAE reconstruction
        sae_encoded = self.sae.encode(self.target_act.to(torch.float32))
        sae_decoded = self.sae.decode(sae_encoded)

        return outputs, self.target_act, sae_decoded

    def remove_hook(self):
        self.hook_handle.remove()

In [15]:
def train_model():
    def process_batch_images(images):
        return feature_extractor(images=images, return_tensors="pt").pixel_values.to(device)

    # Set up training parameters
    if TRAIN_ALL_LAYERS:
        model.train()
        trainable_params = list(model.parameters()) + list(sae.parameters())
    else:
        model.eval()
        trainable_params = sae.parameters()

    optimizer = optim.Adam(trainable_params, lr=LEARNING_RATE)

    # Create dataset
    imgnet = load_dataset("imagenet-1k", split="train", streaming=True)
    ds = imgnet.shuffle(seed=42)
    batches = ds.batch(batch_size=BATCH_SIZE)

    # Create trainable model
    trainable_model = TrainableVisionEncoder(model, sae, TARGET_LAYER)
    trainable_model.to(device)

    print("Starting training...")
    for epoch in range(NUM_EPOCHS):
        total_loss = 0
        batch_count = 0

        for batch in batches:
            if batch_count >= NUM_BATCHES_PER_EPOCH:
                break

            # Process images
            images = [Image.fromarray(np.array(item)) for item in batch["image"]]
            images = [i.convert(mode="RGB") if i.mode != "RGB" else i for i in images]
            batch_pixel_values = process_batch_images(images)

            # Forward pass
            optimizer.zero_grad()
            outputs, target_act, sae_decoded = trainable_model(batch_pixel_values)

            # Calculate losses
            reconstruction_loss = F.mse_loss(sae_decoded, target_act)

            # Add L1 sparsity loss for SAE
            l1_loss = torch.mean(torch.abs(sae.encode(target_act.to(torch.float32))))

            # Combine losses
            loss = reconstruction_loss + 0.001 * l1_loss

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

            if batch_count % 10 == 0:
                print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {batch_count}/{NUM_BATCHES_PER_EPOCH}, "
                      f"Loss: {total_loss/batch_count:.6f}")

        avg_loss = total_loss / batch_count
        print(f"Epoch {epoch+1} complete, Average Loss: {avg_loss:.6f}")

    # Clean up
    trainable_model.remove_hook()

    return trainable_model

In [16]:
trained_model = train_model()

# Save the trained SAE
save_path = f"trained_sae_layer_{TARGET_LAYER}"
sae.save_model(save_path)
print(f"Trained SAE saved to {save_path}")

Starting training...
Epoch 1/3, Batch 10/100, Loss: 2.272266
Epoch 1/3, Batch 20/100, Loss: 1.864353
Epoch 1/3, Batch 30/100, Loss: 1.603935
Epoch 1/3, Batch 40/100, Loss: 1.418858
Epoch 1/3, Batch 50/100, Loss: 1.281867
Epoch 1/3, Batch 60/100, Loss: 1.175657
Epoch 1/3, Batch 70/100, Loss: 1.088944
Epoch 1/3, Batch 80/100, Loss: 1.016665
Epoch 1/3, Batch 90/100, Loss: 0.954988
Epoch 1/3, Batch 100/100, Loss: 0.902394
Epoch 1 complete, Average Loss: 0.902394
Epoch 2/3, Batch 10/100, Loss: 0.395146
Epoch 2/3, Batch 20/100, Loss: 0.381537
Epoch 2/3, Batch 30/100, Loss: 0.370123
Epoch 2/3, Batch 40/100, Loss: 0.358287
Epoch 2/3, Batch 50/100, Loss: 0.347966
Epoch 2/3, Batch 60/100, Loss: 0.338764
Epoch 2/3, Batch 70/100, Loss: 0.330144
Epoch 2/3, Batch 80/100, Loss: 0.321874
Epoch 2/3, Batch 90/100, Loss: 0.313894
Epoch 2/3, Batch 100/100, Loss: 0.306618
Epoch 2 complete, Average Loss: 0.306618
Epoch 3/3, Batch 10/100, Loss: 0.230306
Epoch 3/3, Batch 20/100, Loss: 0.225386
Epoch 3/3, Batc