<a href="https://colab.research.google.com/github/eric8he/SAE_ViTGPT/blob/main/EvalSAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install sae-lens

In [2]:
from sae_lens import SAE
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from typing import List
import torchvision
from torchvision import transforms
from torch import optim
import torch.nn.functional as F

In [3]:
# Configuration
# NOTE: this notebook only *evaluates* a pretrained SAE (feature extraction inside
# torch.no_grad()); no cell performs any training. TRAIN_ALL_LAYERS, LEARNING_RATE and
# NUM_EPOCHS are not referenced by any visible cell — presumably leftovers from a
# training notebook; kept so nothing that might import them breaks.
TRAIN_ALL_LAYERS = False  # unused in this notebook
LEARNING_RATE = 1e-4  # unused in this notebook
BATCH_SIZE = 128  # images per streamed ImageNet batch
NUM_EPOCHS = 3  # unused in this notebook
TARGET_LAYER = 9  # index into model.decoder.transformer.h whose input activation is captured
NUM_BATCHES_PER_EPOCH = 200  # batches processed in the single extraction pass (BATCH_SIZE * this images)

In [5]:
# Choose the GPU when one is available; the model and SAE are both placed on this device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pretrained ViT-encoder / GPT-2-decoder captioning model plus its image processor and tokenizer.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# Sparse autoencoder loaded from files in the current working directory.
sae = SAE.load_from_pretrained(path="./", device=str(device))

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.46.3"
}

Config of the decoder: <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> is overwritten by shared decoder config: GPT2Config {
  "activation_function": "gelu_new",
  "add_cross_attention": true,
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "decoder_start_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_rang

In [15]:
class TrainableVisionEncoder(torch.nn.Module):
    """Run caption generation and pass a captured decoder activation through an SAE.

    Despite the name, the forward hook sits on the *decoder*: it records the first
    positional input of ``model.decoder.transformer.h[target_layer]``. Nothing in this
    class trains anything; it only wires a forward hook and the SAE together.

    Args:
        model: encoder-decoder model exposing ``decoder.transformer.h`` and ``generate``.
        sae: object with ``encode`` / ``decode`` methods (a sae_lens ``SAE`` in this notebook).
        target_layer: index of the decoder block whose input is captured.
    """

    def __init__(self, model, sae, target_layer):
        super().__init__()
        self.model = model
        self.sae = sae
        self.target_layer = target_layer
        # Most recent activation written by the hook (None until the hook fires).
        self.target_act = None

        def gather_target_act_hook(mod, inputs, outputs):
            # Overwritten on every call of the hooked block, so after ``generate`` this
            # holds only the activation from the *last* decoding step.
            # NOTE(review): with KV caching that step presumably covers just the newest
            # token, i.e. shape (batch, 1, hidden) — consistent with the (1, d_sae) rows
            # stored downstream; confirm if this is the intended feature to collect.
            self.target_act = inputs[0]
            return outputs

        self.hook_handle = self.model.decoder.transformer.h[target_layer].register_forward_hook(
            gather_target_act_hook
        )

    def forward(self, pixel_values):
        """Generate from ``pixel_values`` and SAE-encode/decode the captured activation.

        Returns:
            Tuple ``(generate output, captured activation, SAE reconstruction,
            SAE feature activations)``.

        Raises:
            RuntimeError: if the hook did not fire during this call (e.g. after
                ``remove_hook()``), instead of silently reusing a stale activation.
        """
        # Drop any activation left over from a previous call so it can never be
        # mistaken for this batch's.
        self.target_act = None
        outputs = self.model.generate(pixel_values=pixel_values)

        target_act = self.target_act
        if target_act is None:
            raise RuntimeError(
                f"Forward hook on decoder layer {self.target_layer} did not fire; "
                "was remove_hook() already called?"
            )

        # NOTE(review): if generation used beam search, the leading dim would presumably be
        # batch * num_beams rather than batch — confirm before pairing rows with images.
        sae_encoded = self.sae.encode(target_act.to(torch.float32))
        sae_decoded = self.sae.decode(sae_encoded)

        return outputs, target_act, sae_decoded, sae_encoded

    def remove_hook(self):
        """Detach the forward hook registered in ``__init__``."""
        self.hook_handle.remove()

In [16]:
def process_batch_images(images):
    """Preprocess a list of PIL images with the ViT image processor.

    Returns the resulting ``pixel_values`` tensor moved to ``device``.
    """
    encoded = feature_extractor(images=images, return_tensors="pt")
    return encoded.pixel_values.to(device)


# Stream the ImageNet-1k training split (streaming=True), shuffle it with a fixed
# seed, and group examples into batches of BATCH_SIZE.
imgnet = load_dataset("imagenet-1k", split="train", streaming=True)
ds = imgnet.shuffle(seed=42)
batches = ds.batch(batch_size=BATCH_SIZE)

In [17]:
# Wrap the captioning model + SAE. If this cell is re-run, first detach the hook of the
# previously created wrapper so hooks do not pile up on the same decoder layer.
if "trainable_model" in globals():
    trainable_model.remove_hook()
trainable_model = TrainableVisionEncoder(model, sae, TARGET_LAYER)
trainable_model.to(device)

TrainableVisionEncoder(
  (model): VisionEncoderDecoderModel(
    (encoder): ViTModel(
      (embeddings): ViTEmbeddings(
        (patch_embeddings): ViTPatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): ViTEncoder(
        (layer): ModuleList(
          (0-11): 12 x ViTLayer(
            (attention): ViTSdpaAttention(
              (attention): ViTSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): ViTSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
       

In [18]:
# Accumulates (thumbnail image, SAE feature tensor) pairs, one per processed image.
output = list()

In [19]:
# Extract SAE feature activations for NUM_BATCHES_PER_EPOCH streamed ImageNet batches.
# Inference only: no gradients are needed and nothing is trained here.
print("Starting feature extraction...")
output = []  # reset so re-running this cell does not append duplicate entries
batch_count = 0
try:
    with torch.no_grad():
        for batch in batches:
            if batch_count >= NUM_BATCHES_PER_EPOCH:
                break

            # Build PIL images and convert any non-RGB image to RGB for the processor.
            images = [Image.fromarray(np.array(item)) for item in batch["image"]]
            images = [img if img.mode == "RGB" else img.convert(mode="RGB") for img in images]
            batch_pixel_values = process_batch_images(images)

            # The 4th return value is the SAE *encoding* (feature activations), not the reconstruction.
            _, _, _, feature_acts = trainable_model(batch_pixel_values)

            # One device->host copy per batch; clone each row so every stored tensor owns a
            # compact storage rather than a view that may pin (and serialise) the whole batch.
            feature_acts = feature_acts.cpu()
            for image, acts in zip(images, feature_acts):
                thumb_size = tuple(max(1, s // 4) for s in image.size)  # 1/4-scale thumbnail
                output.append((image.resize(thumb_size), acts.clone()))

            batch_count += 1
            if batch_count % 10 == 0:
                print(f"Batch {batch_count}/{NUM_BATCHES_PER_EPOCH}")
finally:
    # Always detach the forward hook, even if the loop raises or is interrupted.
    trainable_model.remove_hook()

Starting training...




Batch 10/200




Batch 20/200
Batch 30/200
Batch 40/200
Batch 50/200
Batch 60/200
Batch 70/200
Batch 80/200
Batch 90/200
Batch 100/200
Batch 110/200
Batch 120/200
Batch 130/200
Batch 140/200
Batch 150/200
Batch 160/200
Batch 170/200
Batch 180/200
Batch 190/200
Batch 200/200


In [21]:
# Number of stored entries, the first (thumbnail, features) pair, and its feature-tensor shape.
print(len(output), output[0], output[0][1].shape, sep="\n")

25600
(<PIL.Image.Image image mode=RGB size=104x125 at 0x7FF38A419030>, tensor([[0.0000, 0.0000, 0.6383,  ..., 0.0000, 0.0000, 0.0000]]))
torch.Size([1, 24576])


In [22]:
import pickle

# Persist the (thumbnail PIL image, feature tensor) pairs.
# NOTE(review): targets a mounted Google Drive folder, but no visible cell mounts Drive
# (e.g. google.colab.drive.mount); if drive/MyDrive does not exist, open() fails.
with open("drive/MyDrive/arr-final.pkl", mode="wb") as pkl_file:
    pickle.dump(output, pkl_file)