In [1]:
%matplotlib inline

import torch
import torchvision
from torch import nn
import os
from transformer import *
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from PIL.ImageOps import invert

%config InlineBackend.figure_format = 'retina'
os.environ["CUDA_VISIBLE_DEVICES"]=""

import math

from PIL import Image
import requests

import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from matplotlib import rc

# equivalent to rcParams['animation.html'] = 'html5'
rc('animation', html='html5')

/media/nicola/Data/Workspace/CORnet-SamedifferentTask


In [2]:
class SVRTVal(object):

    def __init__(self, model, dataset_path, angle=None, resize=False, normalize=False, image_set='val'):
        self.resize = resize
        self.normalize = normalize
        self.angle = angle
        self.image_set = image_set
        self.dataset_path = dataset_path
        self.name = image_set
        self.model = model
        self.data_loader = self.data()
        self.shape = (self.model.vit.patch_res, self.model.vit.patch_res)

    def data(self):
        transforms = []
        '''if self.angle is not None:
            transforms.append(torchvision.transforms.Pad(80, fill=(255, 255, 255)))
            transforms.append(torchvision.transforms.Lambda(
                lambda img: torchvision.transforms.functional.rotate(img, self.angle, resample=Image.BILINEAR)
            ))
            transforms.append(torchvision.transforms.CenterCrop(192))'''
        transforms = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])
        ])

        if self.dataset_path is not None:
            val_dataset = self.dataset_path
            print('Using validation from {}'.format(val_dataset))
        else:
            val_dataset = FLAGS.data_path

        dataset = torchvision.datasets.ImageFolder(
            os.path.join(val_dataset, self.image_set),
            transforms)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  num_workers=0,
                                                  pin_memory=True)

        return data_loader

    def __call__(self, inp):
        self.model.eval()
        with torch.no_grad():
            #inp = inp.cuda()
            attn = self.model.get_last_selfattention(inp)
            
        return attn    

In [44]:
# load the model
checkpoint_path = "runs/from_scratch_uncertainty/problem_5/econvviut-hires-medium_adam_lr0.0001_28000-training-set_depth0_udepth9_dropout_data-augmentation_multiloss_epochs160/best_checkpoint.pth.tar"
dataset_path = "/media/nicola/SSD/Datasets/svrt-HEAD-d34ac2b/results_problem_5"
img_id = 56000

model_chkp = torch.load(checkpoint_path, map_location="cpu")
model = transformer_econvviut_hires_multiloss_medium(depth=0, u_depth=9)
# print(model_chkp['state_dict'].keys())
model.load_state_dict(model_chkp['state_dict'], strict=False)

# Construct the main class and compute attention
svrt_model = SVRTVal(model, dataset_path)
sample = svrt_model.data_loader.dataset[img_id]
print(sample)
sample = sample[0].unsqueeze(0) # emulate batch size 1
sattn = svrt_model(sample)[0]

print("Self-attention shape:", sattn.shape)

shape = svrt_model.shape
# average over heads
interesting_head = (sattn[0, :, 0, 0].view(-1, 1, 1) != sattn[0, :, :, :]).any(1).any(1)
print(interesting_head)
sattn = sattn.mean(dim=1)


# reshape
sattn = sattn[0, 1:, 1:].reshape(shape + shape)
print("Reshaped self-attention:", sattn.shape)

Conv strides: [1, 2, 2, 2]; Num patches: 16 x 16
Conv Model: <class 'transformer.vit.EquivariantConvModel'>
Vit Depth: 0; U Transf Depth: 9
Using validation from /media/nicola/SSD/Datasets/svrt-HEAD-d34ac2b/results_problem_5
(tensor([[[2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         ...,
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489],
         [2.2489, 2.2489, 2.2489,  ..., 2.2489, 2.2489, 2.2489]],

        [[2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
         [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
         [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
         ...,
         [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
         [2.4286, 2.4286, 2.4286,  ..., 2.4286, 2.4286, 2.4286],
         [2.4286, 2.4286, 2.4

In [None]:
sattn[4, 4, :, :]

In [None]:
# downsampling factor for the CNN
fact = 128 // 8

# let's select 4 reference points for visualization
# idxs = [(64, 64), (64, 64), (64, 64), (64, 64),]
idxs = [(35, 80)] * 4

# here we create the canvas
fig = plt.figure(constrained_layout=True, figsize=(25 * 0.7, 8.5 * 0.7))
# and we add one plot per reference point
gs = fig.add_gridspec(2, 4)
axs = [
    fig.add_subplot(gs[0, 0]),
    fig.add_subplot(gs[1, 0]),
    fig.add_subplot(gs[0, -1]),
    fig.add_subplot(gs[1, -1]),
]

# for each one of the reference points, let's plot the self-attention
# for that point
for idx_o, ax in zip(idxs, axs):
    idx = (idx_o[0] // fact, idx_o[1] // fact)
    print(idx)
    ax.imshow(sattn[..., idx[0], idx[1]], cmap='cividis', interpolation='nearest')
    ax.axis('off')
    ax.set_title(f'self-attention{idx_o}')

# and now let's add the central image, with the reference points as red circles
fcenter_ax = fig.add_subplot(gs[:, 1:-1])
im = sample.squeeze(0).permute(1, 2, 0).numpy()
fcenter_ax.imshow(im)
for (y, x) in idxs:
    scale = 1 # im.height / img.shape[-2]
    x = ((x // fact) + 0.5) * fact
    y = ((y // fact) + 0.5) * fact
    fcenter_ax.add_patch(plt.Circle((x * scale, y * scale), fact // 2, color='r'))
    fcenter_ax.axis('off')

In [45]:
# INTERACTIVE

class AttentionVisualizer:
    def __init__(self, model):
        self.model = model

        self.url = ""
        self.cur_url = None
        self.pil_img = None
        self.tensor_img = None

        self.conv_features = None
        self.enc_attn_weights = None
        self.dec_attn_weights = None

        self.setup_widgets()

    def setup_widgets(self):
        self.sliders = [
            widgets.Text(
                value='14',
                placeholder='Type something',
                description='Img ID:',
                disabled=False,
                continuous_update=False,
                layout=widgets.Layout(width='100%')
            ),
            widgets.FloatSlider(min=0, max=0.99,
                        step=0.02, description='X coordinate', value=0.72,
                        continuous_update=False,
                        layout=widgets.Layout(width='50%')
                        ),
            widgets.FloatSlider(min=0, max=0.99,
                        step=0.02, description='Y coordinate', value=0.40,
                        continuous_update=False,
                        layout=widgets.Layout(width='50%')),
            widgets.Checkbox(
              value=False,
              description='Direction of self attention',
              disabled=False,
              indent=False,
              layout=widgets.Layout(width='50%'),
          ),
            widgets.Checkbox(
              value=False,
              description='Show red dot in attention',
              disabled=False,
              indent=False,
              layout=widgets.Layout(width='50%'),
          )
        ]
        self.o = widgets.Output()

    def compute_features(self, img):
        model = self.model
        # propagate through the model
        img = img.unsqueeze(0) # emulate batch size 1
        sattns = model(img)

        p_sattns = []
        for sattn in sattns:
            shape = self.model.shape
            # average over heads
            # interesting_head = (sattn[0, :, 0, 0].view(-1, 1, 1) != sattn[0, :, :, :]).any(1).any(1)
            # print(interesting_head)
            sattn = sattn.mean(dim=1)

            # reshape
            sattn = sattn[0, 1:, 1:].reshape(shape + shape)
            # print("Reshaped self-attention:", sattn.shape)
            
            p_sattns.append(sattn)
            
        self.enc_attn_weights = p_sattns[-1]
        self.enc_attns_weights = p_sattns
    
    def compute_on_image(self, url):
        if url != self.url:
            self.url = url
            img_id = int(url)
            sample = self.model.data_loader.dataset[img_id]
            self.tensor_img = sample[0]
            out_image = self.tensor_img * torch.Tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) + torch.Tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
            print(self.tensor_img.shape)
            self.pil_img = Image.fromarray((out_image.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
            self.compute_features(self.tensor_img)
    
    def update_chart(self, change):
        with self.o:
            clear_output()

            # j and i are the x and y coordinates of where to look at
            # sattn_dir is which direction to consider in the self-attention matrix
            # sattn_dot displays a red dot or not in the self-attention map
            url, j, i, sattn_dir, sattn_dot = [s.value for s in self.sliders]

            fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(9, 4))
            self.compute_on_image(url)

            # convert reference point to absolute coordinates
            j = int(j * self.tensor_img.shape[-1])
            i = int(i * self.tensor_img.shape[-2])

            # how much was the original image upsampled before feeding it to the model
            scale = self.pil_img.height / self.tensor_img.shape[-2]

            # compute the downsampling factor for the model
            # it should be 32 for standard DETR and 16 for DC5
            sattn = self.enc_attn_weights
            fact = 2 ** round(math.log2(self.tensor_img.shape[-1] / sattn.shape[-1]))

            # round the position at the downsampling factor
            x = ((j // fact) + 0.5) * fact
            y = ((i // fact) + 0.5) * fact

            axs[0].imshow(self.pil_img)
            axs[0].axis('off')
            axs[0].add_patch(plt.Circle((x * scale, y * scale), fact // 2, color='r'))

            idx = (i // fact, j // fact)
            
            if sattn_dir:
                sattn_map = sattn[idx[0], idx[1], ...]
            else:
                sattn_map = sattn[..., idx[0], idx[1]]
            
            # axs[1].imshow(sattn_map, cmap='cividis', interpolation='nearest')
            #if sattn_dot:
                #axs[1].add_patch(plt.Circle((idx[1],idx[0]), 1, color='r'))
            #axs[1].axis('off')
            #axs[1].set_title(f'self-attention{(i, j)}')
            
            xmin, xmax, ymin, ymax = (0, self.pil_img.width, 0, self.pil_img.height)
            axs[1].imshow(sattn_map, cmap='cividis', interpolation='nearest', extent=(xmin, xmax, ymin, ymax))
            axs[1].imshow(invert(self.pil_img), alpha=0.3, extent=(xmin, xmax, ymin, ymax))
            if sattn_dot:
                axs[1].add_patch(plt.Circle((idx[1],idx[0]), 1, color='r'))
            axs[1].set_title(f'self-attention{(i, j)}')
            axs[1].axis('off')
            
            fig_anim, axs_anim = plt.subplots(figsize=(9, 4))
            
            def update(k):
                axs_anim.axis('off')
                sattn = self.enc_attns_weights[k]
                if sattn_dir:
                    sattn_map = sattn[idx[0], idx[1], ...]
                else:
                    sattn_map = sattn[..., idx[0], idx[1]]
                axs_anim.set_title(f'self-attention{(i, j)}-timestep{k}')
                axs_anim.imshow(sattn_map, cmap='cividis', interpolation='nearest', extent=(xmin, xmax, ymin, ymax))
                axs_anim.imshow(invert(self.pil_img), alpha=0.3, extent=(xmin, xmax, ymin, ymax))
                if sattn_dot:
                    axs_anim.add_patch(plt.Circle((idx[1],idx[0]), 1, color='r'))
                    
            anim = FuncAnimation(fig_anim, update, frames=np.arange(0, 10), interval=500)
            
            plt.show()
            # anim.save('line.gif', dpi=80, writer='imagemagick')
            
            # plots the attention over time-step in a single figure
            fig, axs = plt.subplots(ncols=6, nrows=1, figsize=(9, 4))
            axs[0].imshow(self.pil_img)
            axs[0].axis('off')
            axs[0].add_patch(plt.Circle((x * scale, y * scale), fact // 2, color='r'))
            
            def draw(k):
                axs[k+1].axis('off')
                sattn = self.enc_attns_weights[k]
                if sattn_dir:
                    sattn_map = sattn[idx[0], idx[1], ...]
                else:
                    sattn_map = sattn[..., idx[0], idx[1]]
                axs[k+1].set_title(f't={k}')
                axs[k+1].imshow(sattn_map, cmap='cividis', interpolation='nearest', extent=(xmin, xmax, ymin, ymax))
                axs[k+1].imshow(invert(self.pil_img), alpha=0.3, extent=(xmin, xmax, ymin, ymax))
#                 if sattn_dot:
#                     axs[k+1].add_patch(plt.Circle((idx[1],idx[0]), 1, color='r'))
            
            for k in range(0, 5):
                draw(k)
                
            plt.show()
        
    def run(self):
      for s in self.sliders:
          s.observe(self.update_chart, 'value')
      self.update_chart(None)
      url, x, y, d, sattn_d = self.sliders
      res = widgets.VBox(
      [
          url,
          widgets.HBox([x, y]),
          widgets.HBox([d, sattn_d]),
          self.o
      ])
      return res

In [46]:
w = AttentionVisualizer(svrt_model)
w.run()

VBox(children=(Text(value='14', continuous_update=False, description='Img ID:', layout=Layout(width='100%'), p…