In [1]:
import json
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
import tabulate
import timm.models.swin_transformer as st
import torch
from einops import rearrange
from IPython.display import Image, display
from skimage.filters import gaussian
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.swin_transformer import swin_tiny_patch4_window7_224
from torchvision.transforms import CenterCrop, Normalize, ToTensor

torch.hub.set_dir(Path("~/torchhub").expanduser().resolve().as_posix())
p = Path("10-SwinTransformerExploration/")

IMAGENET_CLASSES = json.loads(Path.read_text(p / "imagenet_classes.json"))
IMAGENET_CLASSES = {int(k): v for k, v in IMAGENET_CLASSES.items()}

What combinations of architecture parameters are ok?

In [2]:
def foo(img_size, patch_size, window_size, embed_dim, num_heads):
    img = torch.rand(1, 3, img_size, img_size)
    model = st.SwinTransformer(
        img_size=img_size,
        patch_size=patch_size,
        num_classes=0,
        window_size=window_size,
        depths=(2, 2, 2, 2),
        num_heads=(num_heads, num_heads, num_heads, num_heads),
        embed_dim=embed_dim,
    )
    model.forward_features(img)


df = []
embed_dim = 128
num_heads = 4
for img_size, patch_size, window_size in product(
    [224, 256, 384],
    [4, 8, 16, 32],
    [7, 8, 12],
):
    num_patches = img_size / patch_size
    num_windows = num_patches / window_size
    ok = "ok "
    try:
        foo(img_size, patch_size, window_size, embed_dim, num_heads)
    except Exception as e:
        ok = str(e)
    df.append((img_size, patch_size, window_size, num_patches, num_windows, ok))

df = pd.DataFrame(
    df,
    columns=[
        "img_size",
        "patch_size",
        "window_size",
        "num_patches",
        "num_windows",
        "ok",
    ],
)
with pd.option_context("display.max_rows", None):
    display(df.round(1))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Unnamed: 0,img_size,patch_size,window_size,num_patches,num_windows,ok
0,224,4,7,56.0,8.0,ok
1,224,4,8,56.0,7.0,"shape '[1, 3, 8, 3, 8, 1]' is invalid for inpu..."
2,224,4,12,56.0,4.7,"shape '[1, 4, 12, 4, 12, 1]' is invalid for in..."
3,224,8,7,28.0,4.0,x size (7*7) are not even.
4,224,8,8,28.0,3.5,"shape '[1, 3, 8, 3, 8, 1]' is invalid for inpu..."
5,224,8,12,28.0,2.3,"shape '[1, 2, 12, 2, 12, 1]' is invalid for in..."
6,224,16,7,14.0,2.0,x size (7*7) are not even.
7,224,16,8,14.0,1.8,"shape '[1, 1, 8, 1, 8, 1]' is invalid for inpu..."
8,224,16,12,14.0,1.2,"shape '[1, 1, 12, 1, 12, 1]' is invalid for in..."
9,224,32,7,7.0,1.0,shift_size must in 0-window_size


Test image

In [3]:
img_pil = PIL.Image.open(p / "cat-dog.jpg")
img_pil = CenterCrop(min(img_pil.size))(img_pil)
img_pil = img_pil.resize((224, 224))
img = ToTensor()(img_pil)
img = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(img)
img.shape

torch.Size([3, 224, 224])

Hooks to collect intermediate features

In [4]:
model = swin_tiny_patch4_window7_224(pretrained=True)
model.eval()

feats = []
model.patch_embed.register_forward_hook(lambda m, i, o: feats.append(o.detach()))
for layer in model.layers:
    layer.blocks[-1].register_forward_hook(lambda m, i, o: feats.append(o.detach()))

with torch.no_grad():
    logits = model(img[None, :, :, :])[0]

table = []
table.append(("img", img_size, img_size, 3))
table.append(
    ("patch", np.sqrt(feats[0].shape[1]), np.sqrt(feats[0].shape[1]), feats[0].shape[2])
)
for i, f in enumerate(feats[1:]):
    B, HW, C = f.shape
    H = W = int(np.sqrt(HW))
    table.append((f"layer {i}", H, W, C))
print(tabulate.tabulate(table, headers=["", "H", "W", "C"], floatfmt=".0f"))

           H    W    C
-------  ---  ---  ---
img      384  384    3
patch     56   56   96
layer 0   56   56   96
layer 1   28   28  192
layer 2   14   14  384
layer 3    7    7  768


Gradient-based explanations (with some spatial gaussian smoothing)

In [5]:
model = swin_tiny_patch4_window7_224(pretrained=True)
model.eval()

img.requires_grad_()
logits = model(img[None, :, :, :])[0]
topk = logits.softmax(-1).topk(5)

for i in range(5):
    grad, *_ = torch.autograd.grad(topk.values[i], img, retain_graph=True)
    grad = grad.norm(dim=0).numpy()
    grad = gaussian(grad, sigma=3)

    fig, axs = plt.subplots(1, 2, figsize=(6, 3), sharex=True, sharey=True)
    axs[0].set_title(f"{topk.values[i]:.1%} {IMAGENET_CLASSES[int(topk.indices[i])]}")
    axs[0].imshow(img_pil)
    axs[1].set_title("Input gradient")
    axs[1].imshow(grad, interpolation="none", vmin=0)

    fig.tight_layout()
    fig.set_facecolor("white")
    fig.savefig(p / f"gradient-{i}.png")
    plt.close(fig)
    display(Image(url=p / f"gradient-{i}.png"))

Hooks to collect attention maps

In [6]:
attns = {}


def make_hook(name, res, win, shift):
    def hook(m, i, o):
        attns[name] = (res, win, shift, o.detach())

    return hook


model = st.swin_tiny_patch4_window7_224(pretrained=True)
model.eval()

for name, module in model.named_modules():
    if not isinstance(module, st.SwinTransformerBlock):
        continue
    module.attn.softmax.register_forward_hook(
        make_hook(name, module.input_resolution, module.window_size, module.shift_size)
    )

_ = model(img[None, :, :, :])

print(
    tabulate.tabulate(
        [
            (name, res, win, shift, tuple(attn.shape))
            for name, (res, win, shift, attn) in attns.items()
        ],
        headers=["name", "resolution", "window", "shift", "attn"],
    )
)

name               resolution      window    shift  attn
-----------------  ------------  --------  -------  ---------------
layers.0.blocks.0  (56, 56)             7        0  (64, 3, 49, 49)
layers.0.blocks.1  (56, 56)             7        3  (64, 3, 49, 49)
layers.1.blocks.0  (28, 28)             7        0  (16, 6, 49, 49)
layers.1.blocks.1  (28, 28)             7        3  (16, 6, 49, 49)
layers.2.blocks.0  (14, 14)             7        0  (4, 12, 49, 49)
layers.2.blocks.1  (14, 14)             7        3  (4, 12, 49, 49)
layers.2.blocks.2  (14, 14)             7        0  (4, 12, 49, 49)
layers.2.blocks.3  (14, 14)             7        3  (4, 12, 49, 49)
layers.2.blocks.4  (14, 14)             7        0  (4, 12, 49, 49)
layers.2.blocks.5  (14, 14)             7        3  (4, 12, 49, 49)
layers.3.blocks.0  (7, 7)               7        0  (1, 24, 49, 49)
layers.3.blocks.1  (7, 7)               7        0  (1, 24, 49, 49)


Attention visualization: for each layer, average over the heads, undo shift and merge windows

In [7]:
for name, (res, win, shift, attn) in attns.items():
    print(name)
    print(f"{res=} {win=} {shift=}")

    # attn: [win_h*win_w, heads, patch_h*patch_w, patch_h*patch_w]
    #                               ^queries^         ^keys^
    res_h, res_w = res
    win_h = win_w = int(np.sqrt(attn.shape[0]))
    q_patch_h = q_patch_w = int(np.sqrt(attn.shape[2]))
    k_patch_h = k_patch_w = int(np.sqrt(attn.shape[3]))
    px_per_patch = np.array(model.patch_embed.img_size) // np.array(res)
    px_per_window = px_per_patch * np.array([win_h, win_w])
    assert attn.shape[2] == attn.shape[3]
    assert res_h == q_patch_h * win_h
    assert res_w == q_patch_w * win_w
    print("attn", tuple(attn.shape))
    print(f"num windows: {win_h}x{win_w} = {win_h*win_w}")
    print(f"num patches per window: {q_patch_h}x{q_patch_w} = {q_patch_h*q_patch_w}")
    print(
        f"pixels per patch {px_per_patch[0]}x{px_per_patch[1]} = {np.prod(px_per_patch)}"
    )
    print(
        f"pixels per window {px_per_window[0]}x{px_per_window[1]} = {np.prod(px_per_window)}"
    )

    # Reduce heads by averaging
    # attn: [win_h*win_w, queries_patch_h*queries_patch_w, keys_patch_h*keys_patch_w]
    attn = attn.mean(1)

    # Undo windowing -> Creates block-sparse matrix
    # attn: [(win_h*win_w)*(patch_h*patch_w), (win_h*win_w)*(patch_h*patch_w)]
    #                 ^queries^                          ^keys^
    attn = torch.block_diag(*attn)

    # Merge windows into height and width
    attn = rearrange(
        attn,
        "(q_win_h q_win_w q_patch_h q_patch_w) (k_win_h k_win_w k_patch_h k_patch_w)"
        "->"
        "(q_win_h q_patch_h) (q_win_w q_patch_w) (k_win_h k_patch_h) (k_win_w k_patch_w)",
        q_win_h=win_h,
        q_win_w=win_w,
        k_win_h=win_h,
        k_win_w=win_w,
        q_patch_h=q_patch_h,
        q_patch_w=q_patch_w,
        k_patch_h=k_patch_h,
        k_patch_w=k_patch_w,
    )
    assert attn.shape == (res_h, res_w, res_h, res_w)

    # Undo roll of *_patch_* dimensions
    if shift != 0:
        attn = torch.roll(attn, (shift, shift, shift, shift), dims=(0, 1, 2, 3))

    fig, axs = plt.subplots(1, 2 + 5, figsize=5 * np.array([2 + 5, 1]))
    axs[0].imshow(img_pil)
    axs[0].set_title("Input Image")

    axs[1].imshow(attn.mean((0, 1)), interpolation="none", vmin=0)
    axs[1].set_title("Avg attn received by each patch")

    for ax, q_hw in zip(
        axs[2:], [(0, 0), (-1, 0), (res_h // 2, res_w // 2), (0, -1), (-1, -1)]
    ):
        ax.imshow(attn[q_hw[0], q_hw[1], :, :], interpolation="none", vmin=0)
        ax.set_title(f"Where does query patch {q_hw} attend?")

    fig.tight_layout()
    fig.set_facecolor("white")
    fig.savefig(p / f"attn.{name}.png")
    plt.close(fig)
    display(Image(url=p / f"attn.{name}.png"))

layers.0.blocks.0
res=(56, 56) win=7 shift=0
attn (64, 3, 49, 49)
num windows: 8x8 = 64
num patches per window: 7x7 = 49
pixels per patch 4x4 = 16
pixels per window 32x32 = 1024


layers.0.blocks.1
res=(56, 56) win=7 shift=3
attn (64, 3, 49, 49)
num windows: 8x8 = 64
num patches per window: 7x7 = 49
pixels per patch 4x4 = 16
pixels per window 32x32 = 1024


layers.1.blocks.0
res=(28, 28) win=7 shift=0
attn (16, 6, 49, 49)
num windows: 4x4 = 16
num patches per window: 7x7 = 49
pixels per patch 8x8 = 64
pixels per window 32x32 = 1024


layers.1.blocks.1
res=(28, 28) win=7 shift=3
attn (16, 6, 49, 49)
num windows: 4x4 = 16
num patches per window: 7x7 = 49
pixels per patch 8x8 = 64
pixels per window 32x32 = 1024


layers.2.blocks.0
res=(14, 14) win=7 shift=0
attn (4, 12, 49, 49)
num windows: 2x2 = 4
num patches per window: 7x7 = 49
pixels per patch 16x16 = 256
pixels per window 32x32 = 1024


layers.2.blocks.1
res=(14, 14) win=7 shift=3
attn (4, 12, 49, 49)
num windows: 2x2 = 4
num patches per window: 7x7 = 49
pixels per patch 16x16 = 256
pixels per window 32x32 = 1024


layers.2.blocks.2
res=(14, 14) win=7 shift=0
attn (4, 12, 49, 49)
num windows: 2x2 = 4
num patches per window: 7x7 = 49
pixels per patch 16x16 = 256
pixels per window 32x32 = 1024


layers.2.blocks.3
res=(14, 14) win=7 shift=3
attn (4, 12, 49, 49)
num windows: 2x2 = 4
num patches per window: 7x7 = 49
pixels per patch 16x16 = 256
pixels per window 32x32 = 1024


layers.2.blocks.4
res=(14, 14) win=7 shift=0
attn (4, 12, 49, 49)
num windows: 2x2 = 4
num patches per window: 7x7 = 49
pixels per patch 16x16 = 256
pixels per window 32x32 = 1024


layers.2.blocks.5
res=(14, 14) win=7 shift=3
attn (4, 12, 49, 49)
num windows: 2x2 = 4
num patches per window: 7x7 = 49
pixels per patch 16x16 = 256
pixels per window 32x32 = 1024


layers.3.blocks.0
res=(7, 7) win=7 shift=0
attn (1, 24, 49, 49)
num windows: 1x1 = 1
num patches per window: 7x7 = 49
pixels per patch 32x32 = 1024
pixels per window 32x32 = 1024


layers.3.blocks.1
res=(7, 7) win=7 shift=0
attn (1, 24, 49, 49)
num windows: 1x1 = 1
num patches per window: 7x7 = 49
pixels per patch 32x32 = 1024
pixels per window 32x32 = 1024
