<a href="https://colab.research.google.com/github/eric8he/SAE_ViTGPT/blob/main/SAE_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install sae_lens

In [2]:
from sae_lens import SAE
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, default_data_collator
import torch
from PIL import Image
import numpy as np
from datasets import load_dataset
import sys
from torch.utils.data import DataLoader
from typing import List
import torchvision
from torchvision import transforms

In [3]:
%%capture
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

batch_size = 512  # Adjust based on your GPU memory
gen_kwargs = {"max_length": 16, "num_beams": 4}

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.46.3"
}

Config of the decoder: <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> is overwritten by shared decoder config: GPT2Config {
  "activation_function": "gelu_new",
  "add_cross_attention": true,
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "decoder_start_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_rang

In [4]:
def process_batch_images(images):
    """Turn a batch of images into model-ready pixel values.

    The images are passed through the module-level ``feature_extractor``
    requesting PyTorch tensors, and the resulting ``pixel_values`` are moved
    to the module-level ``device``.

    Args:
        images: The batch of images to preprocess.

    Returns:
        The ``pixel_values`` tensor of the batch, placed on ``device``.
    """
    encoded = feature_extractor(images=images, return_tensors="pt")
    pixel_values = encoded.pixel_values
    return pixel_values.to(device)

def gather_residual_activations(model, target_layer, batch_inputs):
    """Capture the input of one decoder block while generating captions.

    A forward hook is attached to ``model.decoder.transformer.h[target_layer]``
    and ``model.generate`` is run on ``batch_inputs`` with the module-level
    ``gen_kwargs`` under ``torch.no_grad()``. The hook overwrites its capture
    on every forward call of the block, so the value returned is the block
    input seen on the LAST call made during generation.

    The hook is always detached again, even when generation raises or is
    interrupted, so repeated calls never leave stale hooks on the layer.

    Args:
        model: Model exposing ``decoder.transformer.h`` and ``generate``.
        target_layer: Index of the decoder block to hook.
        batch_inputs: Pixel values passed to ``generate`` as ``pixel_values``.

    Returns:
        The first positional input of the hooked block from its last forward
        call, or ``None`` if the block was never called.
    """
    target_act = None

    def gather_target_act_hook(mod, inputs, outputs):
        nonlocal target_act
        # Keep only the first positional input of the block.
        target_act = inputs[0]
        return outputs

    handle = model.decoder.transformer.h[target_layer].register_forward_hook(gather_target_act_hook)
    try:
        with torch.no_grad():
            # The generated token ids are not needed; only the hook's capture is.
            model.generate(pixel_values=batch_inputs, **gen_kwargs)
    finally:
        # Detach the hook on every exit path (including KeyboardInterrupt).
        handle.remove()
    return target_act

In [5]:
# Load the pretrained sparse autoencoder for blocks.9.hook_resid_pre of the
# gpt2-small-res-jb release. It is placed on the same `device` the captioning
# model was moved to, so the notebook also runs when no GPU is available
# instead of failing on a hard-coded "cuda:0".
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.9.hook_resid_pre",
    device=str(device),
)

blocks.9.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [6]:
# Build the streaming input pipeline: ImageNet-1k validation split, capped at
# `num_samples_to_take` examples and grouped into batches of `batch_size`.
# NOTE: the dataset repository ships custom loading code, so load_dataset asks
# for interactive confirmation before running it.
num_samples_to_take = 100000

imgnet = load_dataset("imagenet-1k", split="validation", streaming=True)
ds = imgnet.take(num_samples_to_take)
batches = ds.batch(batch_size=batch_size)

Downloading builder script:   0%|          | 0.00/4.58k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/85.4k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

The repository for imagenet-1k contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/imagenet-1k.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


In [7]:
# Collect (half-resolution image, SAE feature activations) pairs.
data = []
print(f"Processing in batches of {batch_size}")

batch_idx = 0
num_batches = 500  # upper bound on the number of batches to process

for batch_idx, batch in enumerate(batches, start=1):
    # Reuse items that already are PIL images: a round-trip through a numpy
    # array loses the image mode (e.g. a 4-channel CMYK image comes back as RGBA).
    images = [
        item if isinstance(item, Image.Image) else Image.fromarray(np.array(item))
        for item in batch["image"]
    ]
    images = [img if img.mode == "RGB" else img.convert(mode="RGB") for img in images]
    batch_pixel_values = process_batch_images(images)

    # Get activations for the batch. Encoding under no_grad keeps the stored
    # tensors free of an autograd graph that would otherwise be kept alive
    # for every entry in `data`.
    target_act = gather_residual_activations(model, 9, batch_pixel_values)
    with torch.no_grad():
        sae_acts = sae.encode(target_act.to(torch.float32))

    # NOTE(review): with beam search (gen_kwargs["num_beams"] > 1) the decoder
    # batch holds several consecutive rows per image, so a plain zip would pair
    # image i with a row that belongs to image i // num_beams. Keep the first
    # row of each image's group so rows and images line up one-to-one.
    rows_per_image = sae_acts.shape[0] // len(images) if images else 0
    if rows_per_image > 1:
        sae_acts = sae_acts[::rows_per_image]

    # Store results (images are downscaled to half size, at least 1 pixel).
    for image, acts in zip(images, sae_acts):
        half_size = [max(1, side // 2) for side in image.size]
        data.append((image.resize(half_size), acts.cpu()))

    if batch_idx % 10 == 0:
        print(f"Batch {batch_idx} done")
        print(f"{len(data)} examples stored")
    # Stop after exactly `num_batches` batches.
    if batch_idx >= num_batches:
        print("Done!")
        break

Processing in batches of 512


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Batch 10 done
Batch 20 done


KeyboardInterrupt: 

In [8]:
# Serialize the collected `data` list to a pickle file under drive/MyDrive.
import pickle

with open("drive/MyDrive/arr-big.pkl", mode="wb") as out_file:
    pickle.dump(data, out_file)

In [9]:
# Number of (image, activations) pairs collected so far.
print(len(data))

10240


In [40]:
# List the current working directory together with file sizes.
!ls -l

total 2510888
-rw-r--r-- 1 root root 2571140508 Dec  5 10:15 arr.pkl
drwxr-xr-x 1 root root       4096 Dec  3 19:31 sample_data
