In [1]:
from qsa_patch_sampler import QSAType
from qsa_patch_sampler import QSAPatchSampler
from paper_qsa_local_patch_sampler import PatchSampleF
from einops import rearrange

In [2]:
import torch

# Seed the global RNG so the random feature maps (and any random sampling done
# by the samplers) are identical on every Restart & Run All; without this a
# failing equality assertion below cannot be reproduced.
seed = 0
torch.manual_seed(seed)

# Leading (batch) dimension of every random feature map generated below.
batch_size = 2
# Passed to QSAPatchSampler as `patch_embedding_dim` and to PatchSampleF as `nc`.
patch_embedding_dim = 256
# Passed to both samplers as the number of patches to draw per layer.
num_patches_per_layer = 256

In [3]:
# Both samplers live on the CPU; the custom one is built first, then the
# reference implementation from the paper (with its MLP disabled).
cpu_device = torch.device('cpu')
custom_patch_sampler = QSAPatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    qsa_type=QSAType.LOCAL,
    max_spatial_size=64 * 64,
    device=cpu_device,
)
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=cpu_device,
    use_mlp=False,
)

# (channels, height, width) of each random feature map, in layer order.
layer_shapes = [
    (3, 256, 256),
    (128, 128, 128),
    (256, 64, 64),
    (256, 64, 64),
    (256, 64, 64),
    (512, 32, 32),
]

with torch.no_grad():
  layer_outs = [torch.randn((batch_size, *shape)) for shape in layer_shapes]

  # One warm-up call on the custom sampler; only the number of returned
  # per-layer patch entries is used below.
  cust_patches, cust_idx, cust_attn_map = custom_patch_sampler(
      layer_outs,
      apply_mlp=False,
  )

  # Point the paper sampler's `mlp_<i>` attributes at the custom sampler's
  # objects so that both samplers reference the very same attribute values.
  for layer_idx in range(len(cust_patches)):
    mlp_name = f'mlp_{layer_idx}'
    setattr(paper_patch_sampler, mlp_name, getattr(custom_patch_sampler, mlp_name))

  # NOTE(review): presumably stops PatchSampleF from building its own MLPs on
  # first use — confirm against its implementation.
  paper_patch_sampler.mlp_init = True

### No MLP

In [4]:
iters = 10
# Only layers from this index onward are compared.
# NOTE(review): presumably the two implementations sample the earlier layers
# differently — confirm before lowering this.
first_checked_layer = 3

for iteration in range(iters):
  # Fresh random feature maps for this iteration, one tensor per layer.
  layer_outs = [
      torch.randn((batch_size, 3, 256, 256)),
      torch.randn((batch_size, 128, 128, 128)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 512, 32, 32)),
  ]
  with torch.no_grad():
    # First pass: the custom sampler computes its own indices and attention.
    custom_patches, custom_idx, custom_attn = custom_patch_sampler(
        layer_outs,
        apply_mlp=False
    )
    # Second pass: the indices and attention from the first pass are fed back
    # in, so the result should reproduce the first pass exactly.
    custom_patches_2, custom_idx_2, custom_attn_2 = custom_patch_sampler(
        layer_outs,
        custom_idx,
        custom_attn,
        apply_mlp=False
    )
    # Reference implementation on the same inputs.
    paper_patches, paper_idx, paper_attn = paper_patch_sampler(
        feats=layer_outs,
        num_patches=num_patches_per_layer
    )

    for layer_idx in range(first_checked_layer, len(layer_outs)):
      where = f'iteration {iteration}, layer {layer_idx}'

      assert torch.equal(custom_attn[layer_idx], custom_attn_2[layer_idx]), (
          f'custom attention changed when re-fed ({where})'
      )

      # The paper's attention is reshaped to the custom layout before comparing.
      assert torch.equal(
          custom_attn_2[layer_idx],
          paper_attn[layer_idx].reshape(custom_attn_2[layer_idx].shape)
      ), f'custom and paper attention differ ({where})'

      assert torch.equal(
          custom_patches[layer_idx],
          custom_patches_2[layer_idx]
      ), f'custom patches changed when indices/attention were re-fed ({where})'

      # Custom patches are flattened from (b, n, d) to (b * n, d) to line up
      # with the paper sampler's output.
      assert torch.allclose(
          input=rearrange(custom_patches_2[layer_idx], 'b n d -> (b n) d'),
          other=paper_patches[layer_idx]
      ), f'custom and paper patches differ ({where})'

### With MLP

In [5]:
# Rebuild both samplers on the CPU, this time with the paper sampler's MLP
# enabled; the custom one is built first, then the reference implementation.
cpu_device = torch.device('cpu')
custom_patch_sampler = QSAPatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    qsa_type=QSAType.LOCAL,
    max_spatial_size=64 * 64,
    device=cpu_device,
)
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=cpu_device,
    use_mlp=True,
)

# (channels, height, width) of each random feature map, in layer order.
layer_shapes = [
    (3, 256, 256),
    (128, 128, 128),
    (256, 64, 64),
    (256, 64, 64),
    (256, 64, 64),
    (512, 32, 32),
]

with torch.no_grad():
  layer_outs = [torch.randn((batch_size, *shape)) for shape in layer_shapes]

  # One warm-up call on the custom sampler with the MLP applied; only the
  # number of returned per-layer patch entries is used below.
  cust_patches, cust_idx, cust_attn_map = custom_patch_sampler(
      layer_outs,
      apply_mlp=True,
  )

  # Point the paper sampler's `mlp_<i>` attributes at the custom sampler's
  # objects so that both samplers reference the very same attribute values.
  for layer_idx in range(len(cust_patches)):
    mlp_name = f'mlp_{layer_idx}'
    setattr(paper_patch_sampler, mlp_name, getattr(custom_patch_sampler, mlp_name))

  # NOTE(review): presumably stops PatchSampleF from building its own MLPs on
  # first use — confirm against its implementation.
  paper_patch_sampler.mlp_init = True

In [6]:
iters = 10
# Only layers from this index onward are compared.
# NOTE(review): presumably the two implementations sample the earlier layers
# differently — confirm before lowering this.
first_checked_layer = 3

for iteration in range(iters):
  # Fresh random feature maps for this iteration, one tensor per layer.
  layer_outs = [
      torch.randn((batch_size, 3, 256, 256)),
      torch.randn((batch_size, 128, 128, 128)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 512, 32, 32)),
  ]
  with torch.no_grad():
    # First pass: the custom sampler computes its own indices and attention.
    custom_patches, custom_idx, custom_attn = custom_patch_sampler(
        layer_outs,
        apply_mlp=True
    )
    # Second pass: the indices and attention from the first pass are fed back
    # in, so the result should reproduce the first pass exactly.
    custom_patches_2, custom_idx_2, custom_attn_2 = custom_patch_sampler(
        layer_outs,
        custom_idx,
        custom_attn,
        apply_mlp=True
    )
    # Reference implementation on the same inputs.
    paper_patches, paper_idx, paper_attn = paper_patch_sampler(
        feats=layer_outs,
        num_patches=num_patches_per_layer
    )

    for layer_idx in range(first_checked_layer, len(layer_outs)):
      where = f'iteration {iteration}, layer {layer_idx}'

      assert torch.equal(custom_attn[layer_idx], custom_attn_2[layer_idx]), (
          f'custom attention changed when re-fed ({where})'
      )

      # The paper's attention is reshaped to the custom layout before comparing.
      assert torch.equal(
          custom_attn_2[layer_idx],
          paper_attn[layer_idx].reshape(custom_attn_2[layer_idx].shape)
      ), f'custom and paper attention differ ({where})'

      assert torch.equal(
          custom_patches[layer_idx],
          custom_patches_2[layer_idx]
      ), f'custom patches changed when indices/attention were re-fed ({where})'

      # Custom patches are flattened from (b, n, d) to (b * n, d) to line up
      # with the paper sampler's output.
      assert torch.allclose(
          input=rearrange(custom_patches_2[layer_idx], 'b n d -> (b n) d'),
          other=paper_patches[layer_idx]
      ), f'custom and paper patches differ ({where})'