In [1]:
import numpy as np
from PIL import Image
from datasets import load_dataset
from transformers import ViTForImageClassification
from torchvision import transforms
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam import run_dff_on_image

from modules import reshape_vit_huggingface, print_top_categories, run_grad_cam_on_image

In [None]:
# Load the sample dataset and pick the first image of its "test" split.
dataset = load_dataset("huggingface/cats-image")
# Index the row first and the column second: `dataset["test"]["image"][0]`
# reads the entire image column before keeping only its first element,
# whereas row-first access only touches the one example that is needed.
image = dataset["test"][0]["image"]
# Bare last expression so the notebook renders the image inline.
image

In [None]:
# --- input: checkpoint to explain and the size the image is resized to ---
model_pretrain = 'google/vit-large-patch32-384'
image_size = (384, 384)

# --- model: load the classifier and invert id2label into a name -> index map ---
model = ViTForImageClassification.from_pretrained(model_pretrain)
category_dict = {label: index for index, label in model.config.id2label.items()}

# --- image: resize the PIL image, then convert it to a tensor ---
# NOTE(review): only `ToTensor` is applied (no mean/std normalisation) —
# confirm this matches the preprocessing the checkpoint expects.
image_resized = image.resize(image_size)
tensor_resized = transforms.ToTensor()(image_resized)

# Deep Feature Factorization, and those sneaky LayerNorm layers
- Deep Feature Factorization (DFF) runs Non-Negative Matrix Factorization (NMF) on the features to cluster them into several feature concepts.
- Every concept then gets a feature representation. We can associate every concept with the categories by running the classifier on each of these representations and displaying the result in a legend next to the image.

Ref: Tutorials [DFF](https://jacobgil.github.io/pytorch-gradcam-book/Deep%20Feature%20Factorizations.html)



In [None]:
# Deep Feature Factorization: factorize the activations of the model's
# `vit.layernorm` module and render the result next to the input image.
target_layer_dff = model.vit.layernorm
image_dff = run_dff_on_image(
    model=model,
    target_layer=target_layer_dff,
    classifier=model.classifier,  # head used to label each concept
    img_pil=image_resized,
    img_tensor=tensor_resized,
    reshape_transform=reshape_vit_huggingface,
    n_components=4,  # presumably the number of concepts to factorize into
    top_k=2,         # presumably the number of categories listed per concept
)
# `run_dff_on_image` output is wrapped in a PIL image for inline display.
display(Image.fromarray(image_dff))

# Grad-CAM

In [None]:
# Grad-CAM: one target per class label to explain, looked up by name in
# `category_dict` (label name -> class index).
targets_for_gradcam = [
    ClassifierOutputTarget(category_dict[label])
    for label in ("Egyptian cat", "remote control, remote")
]
# Hook the `output` submodule of the second-to-last encoder layer.
target_layer_gradcam = model.vit.encoder.layer[-2].output
image_grad_cam = run_grad_cam_on_image(
    model=model,
    target_layer=target_layer_gradcam,
    targets_for_gradcam=targets_for_gradcam,
    input_tensor=tensor_resized,
    input_image=image_resized,
    reshape_transform=reshape_vit_huggingface,
)
# Wrap the helper's array result in a PIL image for inline display.
display(Image.fromarray(image_grad_cam))

In [None]:
# Call the project helper from `modules` with the current model and input tensor.
# NOTE(review): presumably prints the highest-scoring class labels — confirm in `modules`.
print_top_categories(model, tensor_resized)

In [None]:
# input
# Use the ImageNet-1k fine-tuned checkpoint. The `-in21k` variant is published
# without a fine-tuned classification head, so `ViTForImageClassification`
# would get a randomly initialised head with placeholder labels and the
# "Egyptian cat" / "remote control, remote" lookups in the Grad-CAM cell
# below would raise KeyError.
model_pretrain = 'google/vit-base-patch16-224'
image_size = (224, 224)

# model config: load the classifier and invert id2label into a
# label-name -> class-index lookup.
model = ViTForImageClassification.from_pretrained(model_pretrain)
category_dict = {label: index for index, label in model.config.id2label.items()}

# image input: resize the PIL image, then convert it to a tensor.
# NOTE(review): only `ToTensor` is applied (no mean/std normalisation) —
# confirm this matches the preprocessing the checkpoint expects.
image_resized = image.resize(image_size)
tensor_resized = transforms.ToTensor()(image_resized)

In [None]:
# Deep Feature Factorization for the currently loaded checkpoint: factorize
# the activations of `model.vit.layernorm` and render the result.
target_layer_dff = model.vit.layernorm
image_dff = run_dff_on_image(
    model=model,
    target_layer=target_layer_dff,
    classifier=model.classifier,  # head used to label each concept
    img_pil=image_resized,
    img_tensor=tensor_resized,
    reshape_transform=reshape_vit_huggingface,
    n_components=4,  # presumably the number of concepts to factorize into
    top_k=2,         # presumably the number of categories listed per concept
)
# `run_dff_on_image` output is wrapped in a PIL image for inline display.
display(Image.fromarray(image_dff))

In [None]:
# Grad-CAM for the currently loaded checkpoint: one target per class label,
# looked up by name in `category_dict` (label name -> class index).
targets_for_gradcam = [
    ClassifierOutputTarget(category_dict[label])
    for label in ("Egyptian cat", "remote control, remote")
]
# Hook the `output` submodule of the second-to-last encoder layer.
target_layer_gradcam = model.vit.encoder.layer[-2].output
image_grad_cam = run_grad_cam_on_image(
    model=model,
    target_layer=target_layer_gradcam,
    targets_for_gradcam=targets_for_gradcam,
    input_tensor=tensor_resized,
    input_image=image_resized,
    reshape_transform=reshape_vit_huggingface,
)
# Wrap the helper's array result in a PIL image for inline display.
display(Image.fromarray(image_grad_cam))