In [1]:
import torch

In [2]:
from paper_qsa_patch_sampler import PatchSampleF as PaperQSAPatchSampler
from qsa_patch_sampler import QSAPatchSampler

In [3]:
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

layer_outs = [
    torch.rand((batch_size, 128, 128, 128)),
    torch.rand((batch_size, 256, 64, 64)),
    torch.rand((batch_size, 512, 32, 32)),
]

In [4]:
from qsa_patch_sampler import QSAType


custom_patch_sampler = QSAPatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    qsa_type=QSAType.GLOBAL,
    max_spatial_size=64*64,
    device=torch.device('cpu')
)
paper_patch_sampler = PaperQSAPatchSampler(
    nc=patch_embedding_dim,
    device=torch.device('cpu'),
    use_mlp=False
)

### Test reselection without MLP

In [5]:
iters = 1

for _ in range(iters):
  with torch.no_grad():
    layer_outs = [
        torch.rand((batch_size, 64, 256, 256)),
        torch.rand((batch_size, 128, 128, 128)),
        torch.rand((batch_size, 256, 64, 64)),
        torch.rand((batch_size, 512, 32, 32)),
    ]
    selected_patches, patch_idx, attn_map = custom_patch_sampler(
        layer_outs,
        apply_mlp=False
    )
    reselected_patches, reselected_patch_idx, reselected_attn_map = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=patch_idx,
        attn_map_per_layer=attn_map,
        apply_mlp=False
    )
    for layer_idx in range(0, len(layer_outs)):
      assert torch.equal(
          selected_patches[layer_idx],
          reselected_patches[layer_idx]
      )
      if patch_idx[layer_idx] != None:
        assert torch.equal(
            patch_idx[layer_idx],
            reselected_patch_idx[layer_idx]
        )
      else:
        assert torch.equal(
            attn_map[layer_idx],
            reselected_attn_map[layer_idx]
        )

### Test reselection with MLP

In [6]:
iters = 1

for _ in range(iters):
  with torch.no_grad():
    layer_outs = [
        torch.rand((batch_size, 64, 256, 256)),
        torch.rand((batch_size, 128, 128, 128)),
        torch.rand((batch_size, 256, 64, 64)),
        torch.rand((batch_size, 512, 32, 32)),
    ]
    selected_patches, patch_idx, attn_map = custom_patch_sampler(
        layer_outs,
        apply_mlp=True
    )
    reselected_patches, reselected_patch_idx, reselected_attn_map = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=patch_idx,
        attn_map_per_layer=attn_map,
        apply_mlp=True
    )
    for layer_idx in range(0, len(layer_outs)):
      assert torch.equal(
          selected_patches[layer_idx],
          reselected_patches[layer_idx]
      )
      if patch_idx[layer_idx] != None:
        assert torch.equal(
            patch_idx[layer_idx],
            reselected_patch_idx[layer_idx]
        )
      else:
        assert torch.equal(
            attn_map[layer_idx],
            reselected_attn_map[layer_idx]
        )

### Compare against paper patch sampler (without MLP)

In [7]:
paper_patch_sampler = PaperQSAPatchSampler(
    nc=patch_embedding_dim,
    device=torch.device('cpu'),
    use_mlp=False
)
custom_patch_sampler = QSAPatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    qsa_type=QSAType.GLOBAL,
    max_spatial_size=64*64,
    device=torch.device('cpu')
)

In [8]:
from einops import rearrange

iters = 1

for _ in range(iters):
  with torch.no_grad():
    layer_outs = [
        torch.rand((batch_size, 64, 512, 512)),
        torch.rand((batch_size, 128, 256, 256)),
        torch.rand((batch_size, 256, 128, 128)),
        torch.rand((batch_size, 256, 64, 64)),
        torch.rand((batch_size, 256, 64, 64)),
    ]
    cust_patches, cust_patch_idx, cust_attn_map = custom_patch_sampler(
        layer_outs,
        apply_mlp=False
    )
    cust_patches_reselected, _, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_patch_idx,
        attn_map_per_layer=cust_attn_map,
        apply_mlp=False
    )
    paper_patches, paper_patch_idx, paper_attn_map = paper_patch_sampler(
        feats=layer_outs,
        num_patches=num_patches_per_layer,
    )

    for layer_idx, layer_selection in enumerate(
        zip(
            cust_patches,
            cust_patches,
            paper_patches
        )
    ):
      if layer_idx >= 3:
        p_cust, p_cust_reselected, p_paper_selected = layer_selection

        if p_cust.shape != p_paper_selected.shape:
          p_cust = rearrange(p_cust, 'b n d -> (b n) d')

        if cust_attn_map[layer_idx] != None:
          assert (
              torch.allclose(
                  cust_attn_map[layer_idx],
                  paper_attn_map[layer_idx]
              )
          )
        else:
          assert (
              torch.allclose(
                  cust_patch_idx[layer_idx],
                  paper_patch_idx[layer_idx]
              )
          )

        print(torch.linalg.norm(p_cust, dim=-1))
        print(torch.linalg.norm(p_paper_selected, dim=-1))
        print(p_cust.shape)
        print(p_paper_selected.shape)

        print(p_cust)
        print(p_cust_reselected)
        print(p_paper_selected)
        print('--')
        assert (torch.allclose(p_cust, p_paper_selected))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000,

### Compare against paper patch sampler (with MLP)

In [9]:
paper_patch_sampler = PaperQSAPatchSampler(
    nc=patch_embedding_dim,
    device=torch.device('cpu'),
    use_mlp=True
)
custom_patch_sampler = QSAPatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    qsa_type=QSAType.GLOBAL,
    max_spatial_size=64*64,
    device=torch.device('cpu')
)

layer_outs = [
    torch.rand((batch_size, 64, 512, 512)),
    torch.rand((batch_size, 128, 256, 256)),
    torch.rand((batch_size, 256, 128, 128)),
    torch.rand((batch_size, 256, 64, 64)),
    torch.rand((batch_size, 256, 64, 64)),
]
cust_patches, _, _ = custom_patch_sampler(layer_outs, apply_mlp=True)
for layer_idx in range(len(cust_patches)):
  setattr(
      paper_patch_sampler,
      f'mlp_{layer_idx}',
      getattr(
          custom_patch_sampler,
          f'mlp_{layer_idx}',
      )
  )
paper_patch_sampler.mlp_init = True

In [10]:
from einops import rearrange

iters = 100

for _ in range(iters):
  with torch.no_grad():
    layer_outs = [
        torch.rand((batch_size, 64, 512, 512)),
        torch.rand((batch_size, 128, 256, 256)),
        torch.rand((batch_size, 256, 128, 128)),
        torch.rand((batch_size, 256, 64, 64)),
        torch.rand((batch_size, 256, 64, 64)),
    ]
    cust_patches, cust_patch_idx, cust_attn_map = custom_patch_sampler(
        layer_outs,
        apply_mlp=True
    )
    cust_patches_reselected, _, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_patch_idx,
        attn_map_per_layer=cust_attn_map,
        apply_mlp=True
    )
    paper_patches, paper_patch_idx, paper_attn_map = paper_patch_sampler(
        feats=layer_outs,
        num_patches=num_patches_per_layer,
    )

    for layer_idx, layer_selection in enumerate(
        zip(
            cust_patches,
            cust_patches,
            paper_patches
        )
    ):
      if layer_idx >= 3:
        p_cust, p_cust_reselected, p_paper_selected = layer_selection

        if p_cust.shape != p_paper_selected.shape:
          p_cust = rearrange(p_cust, 'b n d -> (b n) d')

        if cust_attn_map[layer_idx] != None:
          assert (
              torch.allclose(
                  cust_attn_map[layer_idx],
                  paper_attn_map[layer_idx]
              )
          )
        else:
          assert (
              torch.allclose(
                  cust_patch_idx[layer_idx],
                  paper_patch_idx[layer_idx]
              )
          )

        print(torch.linalg.norm(p_cust, dim=-1))
        print(torch.linalg.norm(p_paper_selected, dim=-1))
        print(p_cust.shape)
        print(p_paper_selected.shape)

        print(p_cust)
        print(p_cust_reselected)
        print(p_paper_selected)
        print('--')
        assert (torch.allclose(p_cust, p_paper_selected))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000,