In [1]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from tqdm.notebook import tqdm
from transformers import ViTForImageClassification

import sys; sys.path.append("../src/")
from stability import soft_stability_rate
from models import MaskedImageClassifier, CertifiedMuSImageClassifier
from image_utils import load_images_from_directory
from explanations import \
    get_lime_for_image, get_shap_for_image, get_intgrad_for_image, get_mfaba_for_image

# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Load images; they will be ordered the same every time.
# The directory defaults to the original hard-coded path but can be redirected
# by setting the IMAGENET_SAMPLE_DIR environment variable, so the notebook also
# runs on machines where that absolute path does not exist.
IMAGE_DIR = os.environ.get(
    "IMAGENET_SAMPLE_DIR", "/shared_data0/helenjin/imagenet-sample-images/"
)
# Set MAX_IMAGES to an int (e.g. 20) to work on a leading subset; None keeps
# every loaded image.
MAX_IMAGES = None

images = load_images_from_directory(IMAGE_DIR)
if MAX_IMAGES is not None:
    images = images[:MAX_IMAGES]
print(images.shape)

torch.Size([1000, 3, 224, 224])


In [3]:
# Fetch the pretrained ViT checkpoint, put it in evaluation mode, and move it
# onto `device`.
# TODO: this should be our custom fine-tuned models later!
raw_vit = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224"
)
raw_vit.eval()
raw_vit.to(device);

In [7]:
# `model` is an alias for the raw ViT loaded above.
model = raw_vit

# Element count summed over every parameter tensor of the model.
total_params = sum(map(torch.numel, model.parameters()))

# Same sum, restricted to parameters that have requires_grad set.
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

Total parameters: 86567656
Trainable parameters: 86567656


In [4]:
# Wrap the raw ViT in the project's MaskedImageClassifier, then put the wrapper
# in evaluation mode and move it onto `device`.
# NOTE(review): assumes MaskedImageClassifier follows the usual nn.Module
# convention where eval() returns the module itself — confirm in ../src/models.
wrapped_vit = MaskedImageClassifier(raw_vit)
wrapped_vit.eval()
wrapped_vit.to(device);

In [5]:
# Take the first loaded image as the single example passed to the explanation
# call below.
img = images[0]

In [6]:
# Explain the wrapped classifier's behaviour on `img` with the project's
# integrated-gradients helper. With return_verbose=True the call is unpacked
# into three values here.
# NOTE(review): the names suggest (alpha/selection, explanation, raw
# attributions) — confirm against get_intgrad_for_image in ../src/explanations.
intgrad_expln_alpha, intgrad_expln, intgrad_expln_attrs = get_intgrad_for_image(
    wrapped_vit,
    img,
    return_verbose=True,
)


x.grad tensor([[[[-0.0264, -0.0197, -0.0167,  ...,  0.0073,  0.0001,  0.0037],
          [-0.0112, -0.0039, -0.0033,  ...,  0.0015,  0.0024,  0.0009],
          [-0.0096, -0.0039, -0.0066,  ..., -0.0015,  0.0024,  0.0076],
          ...,
          [-0.0053,  0.0056,  0.0202,  ...,  0.0246,  0.0134,  0.0100],
          [-0.0027, -0.0091,  0.0019,  ...,  0.0150,  0.0028,  0.0079],
          [ 0.0244, -0.0001,  0.0035,  ...,  0.0247,  0.0204,  0.0373]],

         [[ 0.0046,  0.0009, -0.0002,  ...,  0.0282,  0.0232,  0.0338],
          [ 0.0109,  0.0081,  0.0068,  ...,  0.0148,  0.0154,  0.0195],
          [ 0.0115,  0.0085,  0.0065,  ...,  0.0121,  0.0128,  0.0219],
          ...,
          [ 0.0191,  0.0169,  0.0227,  ...,  0.0053,  0.0006,  0.0148],
          [ 0.0238,  0.0177,  0.0189,  ...,  0.0012, -0.0026,  0.0120],
          [ 0.0270,  0.0147,  0.0145,  ...,  0.0119,  0.0124,  0.0287]],

         [[-0.0013, -0.0069, -0.0090,  ...,  0.0016,  0.0015,  0.0141],
          [-0.0049, -0.

KeyboardInterrupt: 