<a href="https://colab.research.google.com/github/eric8he/SAE_ViTGPT/blob/main/EvalSAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install sae-lens

In [2]:
from sae_lens import SAE
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from typing import List
import torchvision
from torchvision import transforms
from torch import optim
import torch.nn.functional as F

In [3]:
# Run configuration (no optimizer step is taken in the cells below; this is inference only).
# Images per streamed batch.
BATCH_SIZE = 512
# Number of batches to stream; 200 * 512 = 102,400 images in total.
NUM_BATCHES_PER_EPOCH = 200

In [4]:
# Load the ViT-GPT2 captioning checkpoint (model, image preprocessor, tokenizer)
# and an SAE saved in the current working directory.
CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model.to(device)

# SAE weights/config are read from "./" and placed on the same device as the model.
sae = SAE.load_from_pretrained(path="./", device=str(device))

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.46.3"
}

Config of the decoder: <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> is overwritten by shared decoder config: GPT2Config {
  "activation_function": "gelu_new",
  "add_cross_attention": true,
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "decoder_start_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_rang

In [5]:
class TrainableVisionEncoder(torch.nn.Module):
    """Run caption generation and push a captured decoder activation through an SAE.

    A forward hook on ``model.decoder.transformer.h[target_layer]`` records the
    first positional input of that block each time the block runs. If
    ``model.generate`` runs the decoder several times (e.g. once per generated
    token), only the input from the *last* run is kept.

    Despite the class name, nothing here is trained, and the hooked layer lives
    in the decoder, not the vision encoder.

    Args:
        model: encoder-decoder model exposing ``generate(pixel_values=...)`` and
            an indexable ``decoder.transformer.h``.
        sae: object with ``encode(x)`` and ``decode(z)`` methods.
        target_layer: index into ``model.decoder.transformer.h`` to hook.
    """

    def __init__(self, model, sae, target_layer):
        super().__init__()
        self.model = model
        self.sae = sae
        self.target_layer = target_layer
        self.target_act = None

        # Register hook to capture activations
        def gather_target_act_hook(mod, inputs, outputs):
            # Record the block's first positional input. Returning None leaves
            # the block's output untouched.
            self.target_act = inputs[0]

        self.hook_handle = self.model.decoder.transformer.h[target_layer].register_forward_hook(
            gather_target_act_hook
        )

    def forward(self, pixel_values):
        """Generate from ``pixel_values`` and run the SAE on the captured activation.

        Returns:
            tuple ``(generate_outputs, captured_activation, sae_reconstruction,
            sae_feature_activations)``. The SAE is run in float32.

        Raises:
            RuntimeError: if the hooked layer did not run during this call
                (for example after ``remove_hook()``).
        """
        # Clear the previous capture so a stale activation from an earlier call
        # can never be returned as if it belonged to this batch.
        self.target_act = None

        outputs = self.model.generate(pixel_values=pixel_values)

        target_act = self.target_act
        if target_act is None:
            raise RuntimeError(
                f"Hooked decoder layer {self.target_layer} did not run during generate(); "
                "was remove_hook() already called?"
            )

        # Get SAE reconstruction
        sae_encoded = self.sae.encode(target_act.to(torch.float32))
        sae_decoded = self.sae.decode(sae_encoded)

        return outputs, target_act, sae_decoded, sae_encoded

    def remove_hook(self):
        """Detach the forward hook; subsequent ``forward`` calls raise RuntimeError."""
        self.hook_handle.remove()

In [6]:
def process_batch_images(images):
    """Preprocess a list of PIL images with the global `feature_extractor`.

    Returns the resulting `pixel_values` tensor moved to the global `device`.
    """
    encoded = feature_extractor(images=images, return_tensors="pt")
    pixel_values = encoded.pixel_values
    return pixel_values.to(device)

# Stream the ImageNet-1k train split, shuffle it with a fixed seed, and group it
# into batches of BATCH_SIZE examples.
imgnet = load_dataset("imagenet-1k", split="train", streaming=True)
ds = imgnet.shuffle(seed=42)
batches = ds.batch(batch_size=BATCH_SIZE)

In [7]:
# Wrap the captioner and SAE, hooking decoder block TARGET_LAYER.
# The trailing expression displays the module tree.
TARGET_LAYER = 9
trainable_model = TrainableVisionEncoder(model, sae, TARGET_LAYER)
trainable_model.to(device)

TrainableVisionEncoder(
  (model): VisionEncoderDecoderModel(
    (encoder): ViTModel(
      (embeddings): ViTEmbeddings(
        (patch_embeddings): ViTPatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): ViTEncoder(
        (layer): ModuleList(
          (0-11): 12 x ViTLayer(
            (attention): ViTSdpaAttention(
              (attention): ViTSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): ViTSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
       

In [11]:
# Thumbnails of every processed image, in processing order.
pics = list()
# SAE feature index -> indices into `pics` of the images on which that feature is nonzero.
neurons = dict()

In [17]:
# Stream NUM_BATCHES_PER_EPOCH batches through the captioner + SAE and record, for
# every SAE feature, which images activate it. Nothing is trained here (no_grad).
# Results are appended to `pics` / `neurons`; re-run the initialisation cell first
# for a fresh start. The hook is removed at the end, so re-running this cell also
# requires re-running the cell that builds `trainable_model`.
print("Starting training...")
batch_count = 0
try:
    with torch.no_grad():
        # zip() checks the range first, so no extra batch is pulled from the stream
        # (and decoded) once the limit is reached.
        for _, batch in zip(range(NUM_BATCHES_PER_EPOCH), batches):
            # Convert to RGB directly. Round-tripping through np.array would
            # misread CMYK images as RGBA and reduce palette images to raw indices.
            images = []
            for item in batch["image"]:
                img = item if isinstance(item, Image.Image) else Image.fromarray(np.array(item))
                images.append(img if img.mode == "RGB" else img.convert(mode="RGB"))
            batch_pixel_values = process_batch_images(images)

            # The 4th output is the SAE *encoding* (feature activations), not the reconstruction.
            _, _, _, feature_acts = trainable_model(batch_pixel_values)
            # One host transfer per batch instead of a device sync per nonzero index.
            feature_acts = feature_acts.cpu()

            # Store results
            for image, acts in zip(images, feature_acts):
                # max(1, ...) prevents a zero-size thumbnail for very small images.
                pics.append(image.resize(tuple(max(1, s // 4) for s in image.size)))
                pic_idx = len(pics) - 1
                # acts[0]: activation vector at the first captured position
                # (assumes one captured position per image — TODO confirm).
                for feature_idx in torch.nonzero(acts[0]).squeeze(-1).tolist():
                    neurons.setdefault(feature_idx, []).append(pic_idx)

            batch_count += 1

            print(f"Batch {batch_count}/{NUM_BATCHES_PER_EPOCH}")
finally:
    # Clean up: always detach the hook, even if the loop fails part-way.
    trainable_model.remove_hook()

Starting training...




Batch 1/200
Batch 2/200
Batch 3/200
Batch 4/200




Batch 5/200
Batch 6/200
Batch 7/200
Batch 8/200
Batch 9/200
Batch 10/200
Batch 11/200
Batch 12/200
Batch 13/200
Batch 14/200
Batch 15/200
Batch 16/200
Batch 17/200
Batch 18/200
Batch 19/200
Batch 20/200
Batch 21/200
Batch 22/200
Batch 23/200
Batch 24/200
Batch 25/200
Batch 26/200
Batch 27/200
Batch 28/200
Batch 29/200
Batch 30/200
Batch 31/200
Batch 32/200
Batch 33/200
Batch 34/200
Batch 35/200
Batch 36/200
Batch 37/200
Batch 38/200
Batch 39/200
Batch 40/200
Batch 41/200
Batch 42/200
Batch 43/200
Batch 44/200
Batch 45/200
Batch 46/200
Batch 47/200
Batch 48/200
Batch 49/200
Batch 50/200
Batch 51/200
Batch 52/200
Batch 53/200
Batch 54/200
Batch 55/200
Batch 56/200
Batch 57/200
Batch 58/200
Batch 59/200
Batch 60/200
Batch 61/200
Batch 62/200
Batch 63/200
Batch 64/200
Batch 65/200
Batch 66/200
Batch 67/200
Batch 68/200
Batch 69/200
Batch 70/200
Batch 71/200
Batch 72/200
Batch 73/200
Batch 74/200
Batch 75/200
Batch 76/200
Batch 77/200
Batch 78/200
Batch 79/200
Batch 80/200
Batch 81/200
Batc

In [19]:
# NOTE(review): `output` is not created by any cell shown here (hidden kernel state).
# Judging by the printed sample it is a list of (PIL image, activation tensor) pairs.
# TODO: add the cell that builds it so the notebook survives Restart & Run All.
print(len(output))
first_pair = output[0]
print(first_pair)
activation_shape = first_pair[1].shape
print(activation_shape)

102400
(<PIL.Image.Image image mode=RGB size=104x125 at 0x7AB0FFAEC4F0>, tensor([[0.0000, 0.0000, 0.6383,  ..., 0.0000, 0.0000, 0.0000]]))
torch.Size([1, 24576])


In [None]:
from PIL import Image
import torch
from tqdm import tqdm

# Rank images per SAE feature and keep only each feature's strongest examples.
# NOTE(review): relies on `output` (a list of (image, activation tensor) pairs, judging
# by the printed sample above), which no cell shown here creates — TODO: add its producer.
TOP_K = 32       # examples kept per feature
MIN_ACTIVE = 8   # keep a feature only if its MIN_ACTIVE-th strongest activation is nonzero
assert 1 <= MIN_ACTIVE <= TOP_K, "MIN_ACTIVE must lie in [1, TOP_K]"

input_list = output

# Separate images and activations
images, activations = zip(*input_list)

# Stack into one [num_images, num_features] tensor ([102400, 24576] in the run above).
activations = torch.cat(activations, dim=0)
print(activations.shape)

# torch.topk instead of a full argsort. An argsort over dim 0 allocates an int64 index
# tensor the same shape as `activations` (~20 GB at [102400, 24576]). Keeping every
# image for every feature would also mean ~2.5 billion Python list entries.
k = min(TOP_K, activations.size(0))
top_values, top_indices = torch.topk(activations, k, dim=0)  # each [k, num_features]
print(top_indices.shape)

# Feature-major Python lists, so dict entries are built without per-element .item() calls.
top_values_by_feature = top_values.T.tolist()
top_indices_by_feature = top_indices.T.tolist()

print("done building sorted lists")

# With fewer than MIN_ACTIVE images in total, no feature can qualify.
if k >= MIN_ACTIVE:
    kept_features = torch.nonzero(top_values[MIN_ACTIVE - 1] != 0).flatten().tolist()
else:
    kept_features = []

# feature index -> [(image, activation), ...] for its top-k images, strongest first
output_dict = {
    i: [(images[idx], value) for idx, value in zip(top_indices_by_feature[i], top_values_by_feature[i])]
    for i in kept_features
}

print(f"Processed {len(input_list)} items into a dictionary with {len(output_dict)} keys.")

torch.Size([102400, 24576])


In [22]:
import pickle

# Persist the raw (image, activation) pairs held in `output`.
# The relative path presumably relies on Google Drive being mounted at ./drive (Colab).
PICKLE_PATH = "drive/MyDrive/arr-final.pkl"
with open(PICKLE_PATH, mode="wb") as pickle_file:
    pickle.dump(output, pickle_file)