# Vision Transformer (ViT-Base) Embeddings on Sample Images

In [1]:
# Packages
from PIL import Image
import os
from torch import nn
import warnings
warnings.filterwarnings('ignore')
from transformers import ViTImageProcessor
from transformers import ViTModel
import torch

In [2]:
# Pick the compute device: use the GPU when CUDA is usable, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Load Images

Non-blurred sample images

In [3]:
# Folder holding the non-blurred sample images (relative to this notebook).
NO_BLUR_DIR = '../../../Images/single class samples/No Blur'


def _open_image(path):
    """Read the image at `path` fully into memory and release its file handle.

    `Image.open` is lazy: it identifies the file but leaves it open until the
    pixel data is first needed. Loading inside the `with` block reads the
    pixels now and closes the file on exit; the returned image stays usable.
    """
    with Image.open(path) as img:
        img.load()
    return img


# Load every non-blurred sample image.
# - sorted(): os.listdir returns entries in arbitrary order, so sort to keep
#   list positions reproducible across runs and machines.
# - directories and hidden entries (e.g. .ipynb_checkpoints, .DS_Store) are
#   skipped because they are not images and Image.open would fail on them.
sample_images = [
    _open_image(os.path.join(NO_BLUR_DIR, img_name))
    for img_name in sorted(os.listdir(NO_BLUR_DIR))
    if not img_name.startswith('.')
    and os.path.isfile(os.path.join(NO_BLUR_DIR, img_name))
]

# for img in sample_images:
#     plt.imshow(img)
#     plt.show()

## Load Model

In [4]:
# Load the pretrained ViT backbone from its checkpoint
vision_transformer = ViTModel.from_pretrained(
    pretrained_model_name_or_path='google/vit-base-patch16-224-in21k',
)
# Move the weights onto the selected device; as the cell's last expression,
# the returned module is echoed so the architecture is printed below
vision_transformer.to(device=device)

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation(

## Preprocess Images

In [5]:
# Image processor loaded from the same checkpoint name as the model
processor = ViTImageProcessor.from_pretrained(
    pretrained_model_name_or_path='google/vit-base-patch16-224-in21k',
)

# Convenience wrapper: preprocess a single image per call
def preprocess_image(image):
    """Run `image` through the ViT image processor and return PyTorch tensors.

    Args:
        image: The image to preprocess; forwarded to `processor` as `images=`.

    Returns:
        The processor's output with `return_tensors="pt"`.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    return model_inputs

# Preprocess every sample image, keeping the order of `sample_images`
sample_images_preprocessed = list(map(preprocess_image, sample_images))

## Extract Embeddings

In [6]:
# Function to extract embedding
def get_embedding(image):
    """Return the ViT pooler output for one preprocessed image.

    Args:
        image: Mapping of model inputs as produced by `preprocess_image`;
            it is unpacked as keyword arguments to the model.

    Returns:
        `outputs.pooler_output` from the model, on the model's device.
        Presumably one row per image in the batch — the printed embedding
        in this notebook is 2-D.
    """
    # The processor builds its tensors on the CPU, but the model may have been
    # moved to the GPU. Mixing devices in a forward pass raises a RuntimeError,
    # so move every tensor input to wherever the model's weights live.
    model_device = next(vision_transformer.parameters()).device
    inputs = {
        name: value.to(model_device) if torch.is_tensor(value) else value
        for name, value in image.items()
    }
    # Inference only: skip gradient tracking
    with torch.no_grad():
        outputs = vision_transformer(**inputs)
    # Extract the pooler output
    return outputs.pooler_output

# Embed every preprocessed image, keeping the input order
sample_images_embeddings = list(map(get_embedding, sample_images_preprocessed))

In [8]:
# Show the first embedding followed by its shape, one per line
print(sample_images_embeddings[0], sample_images_embeddings[0].shape, sep='\n')

tensor([[-6.4970e-02, -6.2431e-01, -3.5794e-01,  1.8799e-01, -3.5673e-01,
          4.8694e-01, -6.2571e-02,  3.0874e-01, -4.4544e-01,  3.7174e-01,
          2.7562e-01, -9.2562e-02, -9.1622e-02,  2.1673e-01, -4.4156e-01,
          2.0910e-02,  4.4827e-02, -2.3749e-01, -2.0859e-01,  6.7395e-02,
          1.4700e-02,  1.1039e-01, -7.1225e-01,  5.5117e-01, -1.2027e-01,
         -2.9632e-02, -3.1065e-01,  5.0272e-01,  2.3829e-01, -7.5610e-01,
          8.2920e-02,  5.3965e-01,  7.4124e-02, -5.2682e-01,  3.9484e-01,
         -1.3055e-01,  3.8558e-01,  1.3085e-01,  5.6282e-01,  1.2323e-02,
          6.9629e-02, -1.4153e-01,  6.5359e-01, -3.1777e-01,  4.8312e-02,
          3.2245e-01, -2.9875e-01, -1.3998e-01, -8.5723e-02, -3.3001e-01,
         -8.1837e-02,  6.5450e-02,  1.0227e-01,  6.5339e-01, -3.0757e-01,
          1.6409e-01,  1.2940e-01,  6.1121e-02,  9.4881e-03,  4.0508e-01,
          4.7414e-01, -3.3056e-02, -8.6009e-02, -2.4015e-02,  2.4787e-01,
         -5.4722e-01,  5.5199e-02,  8.