In [1]:
import os
import time

import matplotlib.pyplot as plt
import torch
from tqdm.notebook import tqdm
from transformers import ViTForImageClassification

import sys; sys.path.append("../src/")
from models import SmoothMaskedImageClassifier
from banal import sample_level_k_weights
from image_utils import load_images_from_directory

# Run on the GPU when PyTorch reports one is available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Seed the global torch RNG so the torch.randperm draw below is repeatable.
torch.manual_seed(1234)

# Location of the sample images. Set the IMAGENET_SAMPLE_DIR environment
# variable to run on another machine; the default keeps the original path.
image_dir = os.environ.get(
    "IMAGENET_SAMPLE_DIR", "/home/antonxue/foo/imagenet-sample-images"
)
images = load_images_from_directory(image_dir)

# Keep a random subset of at most 10 images and move it to the compute device.
images = images[torch.randperm(len(images))[:10]].to(device)
print(images.shape)

torch.Size([10, 3, 224, 224])


In [3]:
class BinarizedImageClassifier(torch.nn.Module):
    """Turn a masked image classifier into a function of the mask alone.

    One image is stored as a module buffer. Each forward call pairs a copy of
    that image with every row of ``alpha`` and passes both, positionally, to
    the wrapped classifier.
    """

    def __init__(self, masked_image_classifier, image):
        """Store the wrapped classifier and register ``image`` as a buffer.

        ``image`` must be 3-dimensional; this is checked with an ``assert``
        after the buffer has been registered.
        """
        super().__init__()
        self.masked_image_classifier = masked_image_classifier
        self.register_buffer("image", image)
        assert image.ndim == 3

    def forward(self, alpha):
        """Call the wrapped classifier on ``alpha.size(0)`` copies of the image."""
        batch_size = alpha.size(0)
        # Add a leading batch axis, then tile the image once per alpha row.
        batched_image = self.image[None].repeat(batch_size, 1, 1, 1)
        return self.masked_image_classifier(batched_image, alpha)

In [4]:
# Passed to sample_level_k_weights as its size argument. 196 == 14 * 14, which
# matches the (14, 14) grid_size given to the model -- presumably one feature
# per image patch (TODO confirm against sample_level_k_weights).
n = 196

# Values of lambda_ to sweep, one SmoothMaskedImageClassifier per value.
lambdas = [1.0, 0.9, 0.8]

# Levels k to evaluate: 0, 14, 28, ..., 196 (15 values, step 14).
ks = list(range(0, 14 * 15, 14))

In [5]:
# Per-lambda results: entry i is a 1-D tensor with one value per entry of `ks`
# (averaged over `images`) for lambdas[i].
lambda_masses, lambda_vars = [], []

with torch.no_grad():
    for lambda_ in lambdas:
        # Wrap a freshly loaded pretrained ViT in the smoothing classifier
        # configured with this lambda.
        vit_model = SmoothMaskedImageClassifier(
            ViTForImageClassification.from_pretrained("google/vit-base-patch16-224"),
            lambda_ = lambda_,
            num_samples = 16,
            grid_size = (14, 14)
        )

        k_masses, k_vars = [], []
        # Label the bar at creation so it is identifiable even before the
        # first (slow) step has finished.
        pbar = tqdm(ks, desc=f"lambda {lambda_:.3f}")
        for k in pbar:
            # Show the step that is currently running rather than the one that
            # just completed.
            pbar.set_description(f"lambda {lambda_:.3f}, k {k}")

            image_masses, image_vars = [], []
            for image in images:
                # Fix the image so the model is a function of the mask only;
                # .eval()/.to(device) also apply to the wrapped vit_model when
                # it is an nn.Module.
                bin_model = BinarizedImageClassifier(vit_model, image).eval().to(device)
                out = sample_level_k_weights(
                    bin_model, n, k, num_subsets=8, input_samples=16
                )
                # Collapse each returned tensor to one scalar per image.
                image_masses.append(out["average_mass"].sum().item())
                image_vars.append(out["average_variance"].sum().item())

            # Average the per-image scalars for this level k.
            k_masses.append(torch.tensor(image_masses).mean().item())
            k_vars.append(torch.tensor(image_vars).mean().item())

        lambda_masses.append(torch.tensor(k_masses))
        lambda_vars.append(torch.tensor(k_vars))

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Plot the average spectral mass per degree, one curve per lambda, and save it.
fs = 14  # base font size for labels and legend

# Fail with a clear message if the spectrum computation was interrupted,
# instead of plotting an incomplete set of curves.
assert len(lambda_masses) == len(lambdas), (
    f"expected {len(lambdas)} spectra, found {len(lambda_masses)}; "
    "re-run the spectrum computation cell to completion"
)

# No plt.clf() here: with no current figure it would create (and display) an
# extra empty figure before plt.subplots() makes the real one.
fig, ax = plt.subplots(figsize=(6,4))

# The first entry (k = 0) is skipped, so the x axis starts at ks[1].
# Raw strings keep "\lambda" literal instead of relying on the invalid
# "\l" escape sequence.
for lambda_, masses in zip(lambdas, lambda_masses):
    ax.plot(ks[1:], masses[1:].numpy(), label=rf"$\lambda = {lambda_}$")

# Ticks every 28 degrees; 0 is omitted because the k = 0 point is not drawn.
ax.set_xticks([28, 56, 84, 112, 140, 168, 196])

ax.set_xlabel("Degree", fontsize=fs)
ax.set_ylabel("Avg. Spectral Mass", fontsize=fs)
ax.tick_params(axis="x", labelsize=fs-2)
ax.tick_params(axis="y", labelsize=fs-2)

ax.legend(title="VIT", loc="upper right", title_fontsize=fs, fontsize=fs)
plt.savefig("../figures/vit_spectrum_vs_lambda.png", bbox_inches="tight")