# Vision Transformer On Sample Images

In [40]:
# Packages
from PIL import Image
import os
from torch import nn
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
from transformers import ViTForImageClassification
from transformers import ViTImageProcessor
import torch

In [41]:
# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Load Images

Non-blurred

In [42]:
sample_images = [Image.open('../../../Images/single class samples/No Blur/' + img_path) for img_path in os.listdir('../../../Images/single class samples/No Blur')]
# for img in sample_images:
#     plt.imshow(img)
#     plt.show()

## Load Model

In [43]:
vision_transformer = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
vision_transformer.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

## Get all items except for last layer

In [44]:
# get layers between certain indices
def slice_model(original_model, from_layer=None, to_layer=None):
    return nn.Sequential(*list(original_model.children())[from_layer:to_layer])

# get all layers except the last one
model_vit_feature = slice_model(vision_transformer, to_layer=-1).to('cpu')
# print summary of layers
#print(vision_transformer)
print(model_vit_feature)

Sequential(
  (0): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=

## Preprocess Images

In [45]:
# Pre-process image
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
def preprocess_image(image):
    inputs = processor(images=image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values
    return pixel_values

# Sample images pre-processed
sample_images_preprocessed = [preprocess_image(img) for img in sample_images]

## Extract Embeddings

In [46]:
def get_embedding(pixel_values):
  # Run the image through the model
  with torch.no_grad():
    outputs = model_vit_feature(pixel_values)
  # Extract the CLS token embedding
  embedding = outputs.last_hidden_state[:, 0, :]
  return embedding

# Get embeddings
sample_images_embeddings = [get_embedding(img) for img in sample_images_preprocessed]

In [47]:
# Check embedding extraction code
print(model_vit_feature(sample_images_preprocessed[0]))
print(get_embedding(sample_images_preprocessed[0]))

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.6706, -1.0077,  1.2121,  ...,  0.4465, -0.0515, -1.0279],
         [-0.2915,  0.1586,  1.3501,  ...,  0.7603, -0.2999,  0.1151],
         [ 0.5654, -0.2566,  1.0829,  ...,  1.4887, -0.6268, -0.5772],
         ...,
         [-0.0767, -1.2364,  0.8187,  ...,  0.3521, -0.8355, -0.7715],
         [-0.0942, -0.9965,  0.8878,  ...,  0.5728, -1.1754, -0.5289],
         [ 0.9328, -0.5469,  0.8430,  ...,  0.8277,  0.0075, -1.0804]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=None, hidden_states=None, attentions=None)
tensor([[ 6.7058e-01, -1.0077e+00,  1.2121e+00, -1.5262e-01, -1.1599e+00,
          3.5820e-01, -9.2690e-02,  1.6586e-01,  7.7243e-01, -4.1376e-02,
          8.8223e-01, -8.7197e-01,  7.6799e-01, -1.5431e+00, -1.4675e+00,
          6.4721e-01,  5.8765e-01,  9.7706e-01, -3.2920e-01,  1.7110e+00,
          1.3971e+00,  1.1511e+00,  1.2677e+00,  2.4742e+00, -6.5813e-01,
         -5.0540e-01, -7.7757e-01, -1.7593

In [49]:
from transformers import ViTImageProcessor, ViTModel

new_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
new_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

inputs = new_processor(images=sample_images[0], return_tensors="pt")
outputs = new_model(**inputs)
last_hidden_state = outputs.last_hidden_state

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

torch.Size([1, 197, 768])


In [52]:
print(outputs)
print(last_hidden_state)
print(last_hidden_state.shape)
print(outputs.pooler_output.shape)

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.0307, -0.2018,  0.1547,  ...,  0.0699,  0.1101, -0.2558],
         [-0.1421, -0.0840,  0.2028,  ...,  0.0760, -0.0321, -0.1833],
         [ 0.0235, -0.0660,  0.1798,  ...,  0.2299,  0.0648, -0.3490],
         ...,
         [-0.0284, -0.3854,  0.0872,  ...,  0.0911, -0.2026, -0.3441],
         [-0.0162, -0.3857,  0.1141,  ...,  0.1286, -0.2539, -0.3561],
         [ 0.1285, -0.0645,  0.1626,  ...,  0.0578,  0.0265, -0.2455]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-6.4970e-02, -6.2431e-01, -3.5794e-01,  1.8799e-01, -3.5673e-01,
          4.8694e-01, -6.2571e-02,  3.0874e-01, -4.4544e-01,  3.7174e-01,
          2.7562e-01, -9.2562e-02, -9.1622e-02,  2.1673e-01, -4.4156e-01,
          2.0910e-02,  4.4827e-02, -2.3749e-01, -2.0859e-01,  6.7395e-02,
          1.4700e-02,  1.1039e-01, -7.1225e-01,  5.5117e-01, -1.2027e-01,
         -2.9632e-02, -3.1065e-01,  5.0272e-01,  2.3829e-01, -7.5610e-01,
          8