In [1]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# Bound the network wait and surface HTTP errors (404, 5xx) right here,
# instead of hanging forever or handing an error page body to PIL, which
# would then fail with an unhelpful "cannot identify image file" error.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Single source of truth for the checkpoint, so the processor and the model
# can never be loaded from mismatched repos by a typo in one of two copies.
CHECKPOINT = 'google/vit-base-patch16-224'

processor = ViTImageProcessor.from_pretrained(CHECKPOINT)
model = ViTForImageClassification.from_pretrained(CHECKPOINT)

# Turn the downloaded PIL image into model inputs; "pt" requests PyTorch
# tensors (later cells read `inputs.pixel_values`).
inputs = processor(images=image, return_tensors="pt")

In [3]:
# Imported here because this cell runs before the embedding cell that also imports torch.
import torch

# Pure inference: disable autograd so no computation graph is built and kept
# alive through `outputs` / `logits` after the cell finishes.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Predicted class: Egyptian cat


In [4]:
# Pull three kinds of image embedding out of the ViT backbone:
#   A. a single global vector per image,
#   B. the full token sequence produced by the encoder,
#   C. the raw patch projections, taken before the encoder runs.
import torch

model.eval()  # inference mode for layers such as dropout

with torch.no_grad():
    # Run only the backbone (no classification head) on the `inputs`
    # produced by the processor in an earlier cell.
    vit_outputs = model.vit(**inputs)
    seq_emb = vit_outputs.last_hidden_state

    # A. Global embedding: take `pooler_output` when the backbone provides one,
    #    otherwise fall back to sequence position 0 (the CLS token).
    #    NOTE(review): the backbone inside the classification model may be built
    #    without a pooler, in which case this is simply the CLS token — confirm.
    pooled = getattr(vit_outputs, "pooler_output", None)
    if pooled is None:
        pooled = seq_emb[:, 0, :]
    print("Pooled embedding shape:", pooled.shape)  # (batch, hidden_dim)

    # B. Every token coming out of the encoder (CLS followed by the patches).
    print("Sequence embeddings shape (batch, seq_len, hidden_dim):", seq_emb.shape)

    # C. Patch projection applied straight to the pixels, i.e. before the CLS
    #    token and position embeddings are added. `pixel_values` is laid out as
    #    (batch, channels, height, width); the result may already be
    #    (batch, seq_len, hidden_dim) depending on the HF ViT class.
    patch_embeds = model.vit.embeddings.patch_embeddings(inputs.pixel_values)
    print("Patch embeddings shape (before adding cls token/pos):", patch_embeds.shape)

    # Move the global embedding to NumPy for downstream, non-torch consumers.
    pooled_np = pooled.cpu().numpy()
    print("Pooled embedding (first example, first 5 values):", pooled_np[0, :5])


Pooled embedding shape: torch.Size([1, 768])
Sequence embeddings shape (batch, seq_len, hidden_dim): torch.Size([1, 197, 768])
Patch embeddings shape (before adding cls token/pos): torch.Size([1, 196, 768])
Pooled embedding (first example, first 5 values): [0.29419428 0.8350226  1.903902   0.081404   1.0390357 ]
