In [1]:
import argparse 
import torch

from datasets import load_dataset
from transformers import AutoImageProcessor

from lens import HookedVisionTransformer, plot_heatmaps

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The same checkpoint name is used for both the hooked model and the image
# processor.
model_name_or_path = "google/vit-base-patch16-224"
model = HookedVisionTransformer.from_pretrained(model_name_or_path, device)
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)

# NOTE(review): the `task=` argument is deprecated in recent `datasets`
# releases — confirm the installed version still accepts it.
dataset = load_dataset(
    "cifar10",
    cache_dir="./../cache",
    task="image-classification"
)

# Index the row first and the column second: `dataset['test']['image'][0]`
# would materialise (and decode) the whole image column just to keep the
# first element, whereas row-first access only decodes that one sample.
sample_image = dataset['test'][0]['image']

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    features = image_processor(sample_image, return_tensors="pt")
    # Rebind the result so the code does not rely on `.to()` mutating in place.
    features = features.to(device)
    outputs, cache = model.run_with_cache(**features)

# Show which entries the cache returned by `run_with_cache` holds.
cache.keys()




dict_keys(['embeddings', 'encoder.0.ln1', 'encoder.0.attn.q', 'encoder.0.attn.k', 'encoder.0.attn.v', 'encoder.0.ln2', 'encoder.0.intermediate', 'encoder.1.ln1', 'encoder.1.attn.q', 'encoder.1.attn.k', 'encoder.1.attn.v', 'encoder.1.ln2', 'encoder.1.intermediate', 'encoder.2.ln1', 'encoder.2.attn.q', 'encoder.2.attn.k', 'encoder.2.attn.v', 'encoder.2.ln2', 'encoder.2.intermediate', 'encoder.3.ln1', 'encoder.3.attn.q', 'encoder.3.attn.k', 'encoder.3.attn.v', 'encoder.3.ln2', 'encoder.3.intermediate', 'encoder.4.ln1', 'encoder.4.attn.q', 'encoder.4.attn.k', 'encoder.4.attn.v', 'encoder.4.ln2', 'encoder.4.intermediate', 'encoder.5.ln1', 'encoder.5.attn.q', 'encoder.5.attn.k', 'encoder.5.attn.v', 'encoder.5.ln2', 'encoder.5.intermediate', 'encoder.6.ln1', 'encoder.6.attn.q', 'encoder.6.attn.k', 'encoder.6.attn.v', 'encoder.6.ln2', 'encoder.6.intermediate', 'encoder.7.ln1', 'encoder.7.attn.q', 'encoder.7.attn.k', 'encoder.7.attn.v', 'encoder.7.ln2', 'encoder.7.intermediate', 'encoder.8.ln1'

In [None]:
if __name__ == "__main__":

    # Command-line interface: the only option is the checkpoint name.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='google/vit-base-patch16-224',
        help='model name from HuggingFace hub'
    )
    # Inside a Jupyter kernel `sys.argv` carries the kernel's own arguments
    # (e.g. `-f <connection file>`), which make `parse_args()` abort with
    # "unrecognized arguments". `parse_known_args()` parses the declared
    # options and returns the leftovers separately instead of exiting.
    args, _unknown_args = parser.parse_known_args()

    # NOTE(review): `main` is not defined in any visible cell — confirm it is
    # defined or imported before this cell is run.
    main(args)