In this Colab we'll download two pre-trained image networks (a ResNet image classifier and a self-supervised vision transformer) and dig into them a little.

In [None]:
!pip install transformers datasets
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import requests
import scipy.ndimage as nim
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, ViTFeatureExtractor, ViTModel, ViTConfig
tfkl = tf.keras.layers

In [None]:
## Let's download some pics and run them through the feature extractor, which
## prepares them at the 224x224 standard dimensions.  All grabbed from unsplash.com

## Feel free to change these with your own urls/uploaded images
urls = ["https://source.unsplash.com/lylCw4zcA7I",
        "https://source.unsplash.com/QJ2HGuSSQz0",
        "https://source.unsplash.com/p7tai9P7H-s",
        "https://source.unsplash.com/5U_28ojjgms",
        "https://source.unsplash.com/Gk8LG7dsHWA",
        "https://source.unsplash.com/1Fsb2C7hxQ0",
        "https://source.unsplash.com/CiUR8zISX60",
        "https://source.unsplash.com/uVnRa6mOLOM",
        "https://source.unsplash.com/DJ7bWa-Gwks"]


def fetch_image(url, timeout=30):
  """Download `url` and return its contents as an RGB PIL image.

  Args:
    url: Address of the image to download.
    timeout: Seconds to wait for the server before giving up, so that a stalled
      connection cannot hang the notebook indefinitely.

  Returns:
    A fully loaded `PIL.Image.Image` in RGB mode.

  Raises:
    requests.HTTPError: If the server answers with an error status.
  """
  response = requests.get(url, stream=True, timeout=timeout)
  # Fail here with a clear HTTP error rather than later inside PIL with an
  # "cannot identify image file" message.
  response.raise_for_status()
  # The raw stream is not content-decoded (gzip/deflate) unless asked.
  response.raw.decode_content = True
  # convert() reads the whole stream and guarantees three colour channels, even
  # for grayscale / palette / RGBA sources.
  return Image.open(response.raw).convert('RGB')


images = [fetch_image(url) for url in urls]
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vits8')
images_resized = feature_extractor(images=images, return_tensors="pt")
for image in images:
  plt.imshow(image)
  plt.axis('off')
  plt.show()

# ResNet

In [None]:
# We can download one of the versions of ResNet through tensorflow
resnet = tf.keras.applications.ResNet50V2()
# Get ready for a bunch of output -- you can see all the layers, in all their glory.
# summary() prints the table itself and returns None, so it is not wrapped in
# print() (that would append a stray "None" to the output).
resnet.summary()

In [None]:
# The first conv layer holds 64 kernels of size 7x7; draw all of them in an 8x8 grid
first_conv_layer = resnet.layers[2]
weights = first_conv_layer.weights[0]
print("First conv layer weights shape:", weights.shape)

num_kernels = weights.shape[-1]
plt.figure(figsize=(10, 10))
for kernel_id in range(num_kernels):
  # The last axis of the weight tensor indexes the kernels
  kernel = np.float32(weights[..., kernel_id])
  # Rescale to [0, 1] so the kernel can be shown as an image
  kernel_min, kernel_max = kernel.min(), kernel.max()
  kernel = (kernel - kernel_min) / (kernel_max - kernel_min)
  plt.subplot(8, 8, kernel_id + 1)
  plt.imshow(kernel)
  plt.axis('off')
plt.show()

In [None]:
# The feature extractor returned PyTorch tensors; move axis 1 to the end
# (presumably channels-first -> channels-last) so the Keras layers accept them.
images_pytorch = images_resized['pixel_values']
images_numpy = images_pytorch.detach().numpy()
images_tensorflow = tf.transpose(images_numpy, perm=[0, 2, 3, 1])

In [None]:
# A Keras layer is callable, so the image batch can be fed straight through the
# first conv layer on its own
conv_outputs = first_conv_layer(images_tensorflow)
conv_outputs_shape = conv_outputs.shape
print('Output from first conv layer:', conv_outputs_shape)

In [None]:
# Each output channel is the convolution of the image with the corresponding
# [7, 7, 3] kernel.  One figure row per kernel: the kernel itself, then its
# response on every image.
inches_per_subplot = 2.5
num_columns = len(images) + 1  # kernel panel + one panel per image
for kernel_id in range(16):
  kernel = np.float32(first_conv_layer.weights[0][..., kernel_id])
  # Rescale to [0, 1] so the kernel can be shown as an image
  kernel_min, kernel_max = kernel.min(), kernel.max()
  kernel = (kernel - kernel_min) / (kernel_max - kernel_min)

  plt.figure(figsize=(num_columns*inches_per_subplot, inches_per_subplot))
  plt.subplot(1, num_columns, 1)
  plt.imshow(kernel)
  plt.axis('off')
  plt.title(f'Kernel {kernel_id}', fontsize=16)

  # Responses of this kernel on every image
  channel_maps = conv_outputs[..., kernel_id]
  for image_id in range(len(images)):
    plt.subplot(1, num_columns, image_id + 2)
    plt.imshow(channel_maps[image_id])
    plt.axis('off')
  plt.tight_layout()
  plt.show()

In [None]:
# All right, here's the cool part, we're going to compute gradients of the 
# activation with respect to the input image

# To do so, we first need to turn the images into a tf.Variable so that it's
# automatically tracked by tensorflow's GradientTape
images_tensorflow_variable = tf.Variable(images_tensorflow)
num_columns = len(images) + 1  # kernel panel + one panel per image
for kernel_id in range(16):
  kernel = np.float32(first_conv_layer.weights[0][..., kernel_id])
  # Rescale to [0, 1] so the kernel can be shown as an image
  kernel = (kernel - kernel.min()) / (kernel.max() - kernel.min())
  plt.figure(figsize=(num_columns*inches_per_subplot, inches_per_subplot))
  plt.subplot(1, num_columns, 1)
  plt.imshow(kernel)
  plt.axis('off')
  plt.title(f'Kernel {kernel_id}', fontsize=16)

  with tf.GradientTape() as tape:
    # Scalar objective: the summed squared response of this kernel over all images
    kernel_activation = tf.square(first_conv_layer(images_tensorflow_variable)[..., kernel_id])
    kernel_activation = tf.reduce_sum(kernel_activation)
  grads = np.float32(tape.gradient(kernel_activation, images_tensorflow_variable))
  # Note the gradient is the same size as the input image, and the color of the gradient is relevant
  # The rescaling to [0, 1] uses one min/max shared by the whole batch
  grads = (grads - grads.min()) / (grads.max() - grads.min())
  # Show the gradient for every image (the figure has one column per image),
  # not just a hard-coded first six
  for image_id in range(len(images)):
    plt.subplot(1, num_columns, image_id+2)
    plt.imshow(grads[image_id])
    plt.axis('off')
  plt.show()

In [None]:
# To look beyond the first layer we'll build a new model from the resnet's
# input and the output of one of its layers.
# To keep things simple, restrict attention to the resnet layers whose name
# ends in 'conv'.
is_conv_layer = [layer.name.endswith('conv') for layer in resnet.layers]
conv_inds = np.where(is_conv_layer)[0]
print("Indices of conv layers:", conv_inds)

In [None]:
# Truncated network: same input as the resnet, output taken at layer `layer_id`
# (presumably one of the conv-layer indices printed earlier -- confirm before changing)
layer_id = 23
mini_model = tf.keras.Model(resnet.input, resnet.layers[layer_id].output)

intermed_outp = mini_model(images_tensorflow_variable)
print('Intermediate output shape:', intermed_outp.shape[1:])
for kernel_id in range(16):  # Just look at the first 16 features of this layer
  # One column per image; the width matches the number of subplots actually drawn
  plt.figure(figsize=(len(images)*inches_per_subplot, inches_per_subplot))
  for image_id in range(len(images)):
    plt.subplot(1, len(images), image_id+1)
    image = np.float32(intermed_outp[image_id, ..., kernel_id])
    # Rescale each feature map to [0, 1] independently for display
    image = (image - image.min()) / (image.max() - image.min())
    plt.imshow(image)
    plt.axis('off')
  plt.show()

In [None]:
# Gradients of the later-layer activations with respect to the input pixels
for kernel_id in range(16):
  with tf.GradientTape() as tape:
    # Scalar objective: summed squared value of this feature over all images
    feature_maps = mini_model(images_tensorflow_variable)[..., kernel_id]
    kernel_activation = tf.reduce_sum(tf.square(feature_maps))
  grads = np.float32(tape.gradient(kernel_activation, images_tensorflow_variable))

  plt.figure(figsize=(len(images)*inches_per_subplot, inches_per_subplot))
  for image_id in range(len(images)):
    # Rescale each image's gradient to [0, 1] on its own
    grad_image = grads[image_id]
    grad_min, grad_max = grad_image.min(), grad_image.max()
    grad_image = (grad_image - grad_min) / (grad_max - grad_min)
    plt.subplot(1, len(images), image_id + 1)
    plt.imshow(grad_image)
    plt.axis('off')
  plt.show()

In [None]:
# What if we perturbed the original pixels in the direction of these gradients?
# And then, recompute the gradients and perturb the new image to maximize a specific activation

# We'll do a little better and use an adaptive optimizer
# Also rather than just boosting the activation values, we want the spread to be
# boosted so that it's clear where the boosting is happening
# (but definitely mess around with all of this, try different things)

update_step_size = 1e-3
num_perturb_steps = 500
moment_power = 2
inches_per_subplot = 5
for kernel_id in range(8):
  # Start every kernel from a fresh copy of the original images, with a fresh optimizer
  images_tensorflow_adj = tf.Variable(images_tensorflow)
  opt = tf.keras.optimizers.Adam(update_step_size)
  for update_step in range(num_perturb_steps):
    with tf.GradientTape() as tape:
      kernel_activation = mini_model(images_tensorflow_adj)[..., kernel_id]
      # Flatten the spatial dimensions: one row of activations per image
      kernel_activation = tf.reshape(kernel_activation, [len(images), -1])
      kernel_activation_exp = tf.pow(kernel_activation, moment_power)
      # mean(x^p) - mean(x)^p per image, summed over images; with p = 2 this is
      # the variance.  Negated because the optimizer minimizes.
      loss = -tf.reduce_sum(tf.reduce_mean(kernel_activation_exp, axis=-1) - tf.pow(tf.reduce_mean(kernel_activation, axis=-1), moment_power))
    # Hand the gradient tensor straight to the optimizer; converting it to a
    # NumPy array first would only add a copy on every step.
    grads = tape.gradient(loss, images_tensorflow_adj)
    opt.apply_gradients([(grads, images_tensorflow_adj)])
    # Keep the perturbed pixels inside a fixed range
    images_tensorflow_adj.assign(tf.clip_by_value(images_tensorflow_adj, -2.5, 2.5))
  plt.figure(figsize=(len(images)*inches_per_subplot, inches_per_subplot))
  for image_id in range(len(images)):
    plt.subplot(1, len(images), image_id+1)
    image = np.float32(images_tensorflow_adj[image_id])
    # Rescale to [0, 1] for display
    image = (image-image.min())/(image.max()-image.min())
    plt.imshow(image)
    plt.axis('off')
  plt.tight_layout()
  plt.show()

# Vision transformer

In [None]:
## We're downloading a self supervised vision transformer, meaning it didn't actually use any labels, just prior knowledge about augmentations
## Let's visualize the attention maps for the images
# Ask the model to return the hidden states and attention maps along with its usual output
config = ViTConfig.from_pretrained('facebook/dino-vits8', output_hidden_states=True, output_attentions=True)
model = ViTModel.from_pretrained('facebook/dino-vits8', config=config)

# This is inference only, so skip autograd bookkeeping: otherwise every hidden
# state and attention map stays attached to a computation graph.
with torch.no_grad():
  outputs = model(**images_resized)

In [None]:
# Inspect the attention maps: 12 layers, each with 6 attention heads.  Every
# image is split into 28x28 = 784 patches, with one placeholder token in front,
# and attention is recorded from every token to every other token (785x785).
[layer_attention.shape for layer_attention in outputs['attentions']]

In [None]:
# For every image: show the image, then a grid of attention maps with one
# column per layer and one row per attention head.
inches_per_subplot = 2
num_attention_heads = 6
num_layers = 12
for image_ind in range(len(images)):
  plt.figure(figsize=(4, 4))
  plt.imshow(images[image_ind])
  plt.axis('off')
  plt.show()

  grid_size = (num_layers*inches_per_subplot, num_attention_heads*inches_per_subplot)
  plt.figure(figsize=grid_size)
  for layer_ind in range(num_layers):
    layer_attention = outputs['attentions'][layer_ind]
    for head_ind in range(num_attention_heads):
      # Row = attention head, column = layer
      plt.subplot(num_attention_heads, num_layers, head_ind*num_layers + layer_ind + 1)
      token_attention = layer_attention[image_ind, head_ind].detach().numpy()
      # Attention paid by the first (placeholder) token to each of the 784 patches
      patch_attention = token_attention[0][1:].reshape([28, 28])
      ## upsample the 28x28 patch grid to 224x224
      patch_attention = tf.reshape(patch_attention, [1, 28, 28, 1])
      upsampled = tf.image.resize(patch_attention, [224, 224], method='nearest')
      plt.imshow(upsampled[0, ..., 0])

      # Label only the top row with the layer and the left column with the head
      if head_ind == 0:
        plt.title(f'Layer {layer_ind+1}', fontsize=18)
      if layer_ind == 0:
        plt.ylabel(f'Attn head {head_ind+1}', fontsize=18)
      plt.xticks([])
      plt.yticks([])
  plt.tight_layout()
  plt.show()
  print()