In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from models.encoders import DinoViT_8
from models.data import SynthCardDataModule, JSRTDataModule

In [14]:
dinovit8 = DinoViT_8()

# data = SynthCardDataModule(batch_size=1, rate_maps=1.0, augmentation=False, cache=True)
data = JSRTDataModule(batch_size=5, augmentation=False)

Using cache found in /homes/bc1623/.cache/torch/hub/facebookresearch_dino_main


True


Loading Data: 100%|██████████| 210/210 [00:00<00:00, 72535.93it/s]
Loading Data: 100%|██████████| 12/12 [00:00<00:00, 43314.67it/s]
Loading Data: 100%|██████████| 25/25 [00:00<00:00, 57080.89it/s]


In [15]:
batch = next(iter(data.train_dataloader()))

In [6]:
with torch.no_grad():
    dinovit8(batch['image'])

In [9]:
def visualize_attention_maps(model, x, layers_to_visualize):
    # Forward pass through the model to get attention maps
    attention_maps = []

    def hook_fn(module, input, output):
        # Assuming the module output is the attention map
        # print(len(output)) # 2
        attention_maps.append(output[1])

    # Register hooks to capture attention maps from specified layers
    hooks = []
    for block in model.blocks:
        hook = block.attn.register_forward_hook(hook_fn)
        hooks.append(hook)

    # Forward pass to compute attention
    with torch.no_grad():
        _ = model(x)

    # Remove hooks after extraction
    for hook in hooks:
        hook.remove()

    # Visualize the attention maps
    for idx, att_map in enumerate(attention_maps):
        plt.figure(figsize=(10, 10))
        # plt.title(f'Layer {layers_to_visualize[idx]} Attention Map')
        plt.imshow(att_map.squeeze().sum(dim=0)[0, 1:].reshape(14, 14).cpu().numpy(), cmap='viridis')
        plt.colorbar()
        plt.show()

In [10]:
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
image = batch['image'][0].unsqueeze(0).repeat(1, 3, 1, 1)
visualize_attention_maps(model, image, [0, 1, 2, 3, 4, 5])

Using cache found in /homes/bc1623/.cache/torch/hub/facebookresearch_dino_main


RuntimeError: shape '[14, 14]' is invalid for input of size 784

<Figure size 1000x1000 with 0 Axes>

In [27]:
print(len(model.blocks))

12


In [11]:
# Copyright (c) Facebook, Inc. and its affiliates.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import random
import colorsys
import requests
from io import BytesIO

import skimage.io
from skimage.measure import find_contours
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

import utils
import vision_transformer as vits


def apply_mask(image, mask, color, alpha=0.5):
    for c in range(3):
        image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
    return image


def random_colors(N, bright=True):
    """
    Generate random colors.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors


def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
    fig = plt.figure(figsize=figsize, frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax = plt.gca()

    N = 1
    mask = mask[None, :, :]
    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    margin = 0
    ax.set_ylim(height + margin, -margin)
    ax.set_xlim(-margin, width + margin)
    ax.axis('off')
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]
        _mask = mask[i]
        # if blur:
        #     _mask = cv2.blur(_mask,(10,10))
        # Mask
        masked_image = apply_mask(masked_image, _mask, color, alpha)
        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        if contour:
            padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
            padded_mask[1:-1, 1:-1] = _mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8), aspect='auto')
    fig.savefig(fname)
    print(f"{fname} saved.")
    return

In [18]:
device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8', pretrained=True)
model.to(device)

img = batch['image'][1].unsqueeze(0).repeat(1, 3, 1, 1)
transform = pth_transforms.Compose([
    # pth_transforms.Resize(128),
    # pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = transform(img)
# make the image divisible by the patch size
#w, h = img.shape[1] - img.shape[1] % 8, img.shape[2] - img.shape[2] % 8
# img = img[:, :, :w, :h]#.unsqueeze(0)
# print(w, h)

w_featmap = img.shape[-2] // 8
h_featmap = img.shape[-1] // 8

with torch.no_grad():
    attentions = model.get_last_selfattention(img.to(device))

nh = attentions.shape[1] # number of head
print(attentions.shape) # torch.Size([1, 6, 257, 257])
# we keep only the output patch attention
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

# if args.threshold is not None:
#     # we keep only a certain percentage of the mass
#     val, idx = torch.sort(attentions)
#     val /= torch.sum(val, dim=1, keepdim=True)
#     cumval = torch.cumsum(val, dim=1)
#     th_attn = cumval > (1 - None)
#     idx2 = torch.argsort(idx)
#     for head in range(nh):
#         th_attn[head] = th_attn[head][idx2[head]]
#     th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
#     # interpolate
#     th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=8, mode="nearest")[0].cpu().numpy()

attentions = attentions.reshape(nh, w_featmap, h_featmap)
attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=8, mode="nearest")[0].cpu().numpy()

# save attentions heatmaps
os.makedirs('/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/', exist_ok=True)
torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join('/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/', "img.png"))
print(nh)
for j in range(nh):
    fname = os.path.join('/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/', "attn-head" + str(j) + ".png")
    plt.imsave(fname=fname, arr=attentions[j], format='png')
    print(f"{fname} saved.")

# if args.threshold is not None:
#     image = skimage.io.imread(os.path.join('/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/', "img.png"))
#     for j in range(nh):
#         display_instances(image, th_attn[j], fname=os.path.join('/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/', "mask_th" + str(None) + "_head" + str(j) +".png"), blur=False)

Using cache found in /homes/bc1623/.cache/torch/hub/facebookresearch_dino_main


torch.Size([1, 12, 785, 785])
12
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head0.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head1.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head2.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head3.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head4.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head5.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head6.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head7.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_attention_maps/attn-head8.png saved.
/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dino_atte