In [1]:
from patch_sampler import PatchSampler
from paper_patch_sampler import PatchSampleF

In [2]:
import torch

# Seed torch's global RNG first so that the `torch.randn` feature maps drawn in
# the cells below are identical on every Restart & Run All; a failing assertion
# can then be reproduced instead of depending on one lucky/unlucky draw.
# NOTE(review): if PatchSampler picks its patch indices with a different RNG
# (e.g. numpy), seed that one here as well — confirm against patch_sampler.py.
seed = 0
torch.manual_seed(seed)

# Shared configuration used by every check in this notebook.
batch_size = 1
patch_embedding_dim = 256     # passed as `patch_embedding_dim` to PatchSampler and as `nc` to PatchSampleF
num_patches_per_layer = 256   # passed as `num_patches_per_layer` / `num_patches` to the samplers

In [3]:
# Warm-up: build both samplers on the CPU, run one batch of random feature maps
# through the custom sampler with `apply_mlp=True`, then point the paper
# sampler's `mlp_<i>` attributes at the custom sampler's `mlp_<i>` attributes so
# both samplers hold the very same objects.
cpu = torch.device('cpu')
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=cpu,
)
paper_patch_sampler = PatchSampleF(nc=patch_embedding_dim, device=cpu, use_mlp=True)

with torch.no_grad():
  # Three random feature maps: (channels, spatial side) = (128, 128), (256, 64), (512, 32).
  layer_outs = [
      torch.randn((batch_size, channels, side, side))
      for channels, side in ((128, 128), (256, 64), (512, 32))
  ]
  cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)

  # One `mlp_<i>` attribute is copied per entry in the returned patches.
  for layer_idx in range(len(cust_patches)):
    mlp_name = f'mlp_{layer_idx}'
    shared_mlp = getattr(custom_patch_sampler, mlp_name)
    setattr(paper_patch_sampler, mlp_name, shared_mlp)

  # Flag on PatchSampleF — presumably marks its MLPs as already created so it
  # does not build its own; TODO confirm in paper_patch_sampler.py.
  paper_patch_sampler.mlp_init = True

### Selecting correct patches

In [4]:
# Fresh samplers for the index-selection checks below; the paper sampler is
# constructed with `use_mlp=False`.
cpu = torch.device('cpu')
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=cpu,
)
paper_patch_sampler = PatchSampleF(nc=patch_embedding_dim, device=cpu, use_mlp=False)

In [5]:
from einops import rearrange
import torch.nn.functional as F

# Index-selection check on single-image batches: with `apply_mlp=False`, the
# patches returned by the custom sampler must equal the L2-normalised feature
# vectors found at the indices it reports.
iters = 100
batch_size = 1

for _ in range(iters):
  layer_outs = [
      torch.randn((batch_size, 128, 128, 128)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 512, 32, 32)),
  ]
  with torch.no_grad():
    custom_patches, patches_idx = custom_patch_sampler(
        layer_outs,
        apply_mlp=False
    )

    for (layer_idx, layer_out) in enumerate(layer_outs):
      # Flatten the spatial grid: row k of each sample is the channel vector at
      # flattened position k.
      rearrange_layer_out = rearrange(layer_out, 'b c h w -> b (h w) c')

      for batch_idx in range(batch_size):
        # Index by `batch_idx` instead of a literal 0 so the check stays correct
        # if `batch_size` above is ever raised (identical result for batch_size == 1).
        expected_patches = torch.index_select(input=rearrange_layer_out[batch_idx],
                                              index=patches_idx[layer_idx][batch_idx], dim=0)
        expected_normalized_patches = F.normalize(expected_patches, p=2, dim=-1)

        assert torch.equal(expected_normalized_patches,
                           custom_patches[layer_idx][batch_idx])

In [6]:
from einops import rearrange
import torch.nn.functional as F

# Index-selection check on batches of four images: for every layer and every
# sample, gather the rows named by the sampler's indices from the flattened
# feature map, L2-normalise them, and require an exact match with the patches
# the custom sampler returned (`apply_mlp=False`).
iters = 100
batch_size = 4

for _ in range(iters):
  # (channels, spatial side) of the three random feature maps.
  layer_outs = [
      torch.randn((batch_size, channels, side, side))
      for channels, side in ((128, 128), (256, 64), (512, 32))
  ]
  with torch.no_grad():
    custom_patches, patches_idx = custom_patch_sampler(layer_outs, apply_mlp=False)

    for layer_idx, layer_out in enumerate(layer_outs):
      # b c h w -> b (h w) c: one row per spatial position, channels last.
      rearrange_layer_out = rearrange(layer_out, 'b c h w -> b (h w) c')

      for batch_idx in range(batch_size):
        picked_rows = patches_idx[layer_idx][batch_idx]
        expected_patches = rearrange_layer_out[batch_idx].index_select(0, picked_rows)
        expected_normalized_patches = F.normalize(expected_patches, p=2, dim=-1)

        assert torch.equal(
            expected_normalized_patches,
            custom_patches[layer_idx][batch_idx],
        )

## Direct comparison with paper implementation

### With MLP projection

In [7]:
# Samplers for the with-MLP comparison. The paper sampler is built with
# `use_mlp=True` and then given the custom sampler's `mlp_<i>` attributes, so
# both samplers use the same objects for the projection.
cpu = torch.device('cpu')
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=cpu,
)
paper_patch_sampler = PatchSampleF(nc=patch_embedding_dim, device=cpu, use_mlp=True)

with torch.no_grad():
  # NOTE: `batch_size` is not set in this cell; it is whatever the previously
  # executed cell left behind.
  layer_outs = [
      torch.randn((batch_size, channels, side, side))
      for channels, side in ((128, 128), (256, 64), (512, 32))
  ]
  # One pass with `apply_mlp=True` before the `mlp_<i>` attributes are read.
  cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)

  for layer_idx in range(len(cust_patches)):
    mlp_name = f'mlp_{layer_idx}'
    setattr(paper_patch_sampler, mlp_name, getattr(custom_patch_sampler, mlp_name))

  # Flag on PatchSampleF — presumably marks its MLPs as initialised; TODO confirm.
  paper_patch_sampler.mlp_init = True

In [8]:
# With-MLP comparison: the custom sampler, the custom sampler re-run on its own
# indices, and the paper sampler fed those same indices must agree on both the
# indices and the resulting patches.
batch_size = 1
iters = 100

for _ in range(iters):
  layer_outs = [
      torch.randn((batch_size, 128, 128, 128)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 512, 32, 32)),
  ]
  with torch.no_grad():
    # `apply_mlp=True` is spelled out (every other call in this notebook passes
    # it explicitly) so this section does not depend on PatchSampler's default.
    custom_patches, custom_idx = custom_patch_sampler(
        layer_outs, apply_mlp=True)
    custom_patches_2, custom_idx_2 = custom_patch_sampler(
        layer_outs, custom_idx, apply_mlp=True)
    # The paper sampler receives the custom sampler's indices, flattened per layer.
    paper_patches, paper_idx = paper_patch_sampler(
        feats=layer_outs,
        num_patches=num_patches_per_layer,
        patch_ids=[
            idx.flatten()
            for idx in custom_idx
        ]
    )
    # Bring the paper sampler's indices back to one row per sample for comparison.
    paper_idx = [
        idx.reshape(batch_size, -1)
        for idx in paper_idx
    ]

    # zip() stops at the shortest input, so check the per-layer counts first;
    # otherwise a missing layer would make the loops below pass vacuously.
    assert len(custom_idx) == len(custom_idx_2) == len(paper_idx) == len(layer_outs)
    assert len(custom_patches) == len(custom_patches_2) == len(paper_patches) == len(layer_outs)

    for i, j in zip(custom_idx, custom_idx_2):
      assert torch.equal(i, j)
    for i, j in zip(custom_idx_2, paper_idx):
      assert torch.equal(i, j)

    # Same sampler, same indices -> exact equality; against the paper
    # implementation only closeness (torch.allclose) is required.
    for p, q in zip(custom_patches, custom_patches_2):
      assert torch.equal(p, q)
    for p, q in zip(custom_patches_2, paper_patches):
      assert torch.allclose(p, q)

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


### Without MLP projection

In [9]:
# Replace only the paper sampler, now built with `use_mlp=False`;
# `custom_patch_sampler` is not reassigned here and keeps its current value.
paper_patch_sampler = PatchSampleF(nc=patch_embedding_dim, device=torch.device('cpu'), use_mlp=False)

In [10]:
# Without-MLP comparison: with `apply_mlp=False` on the custom sampler and a
# paper sampler built with `use_mlp=False`, both must agree on the indices and
# on the resulting patches.
iters = 100
batch_size = 1

for _ in range(iters):
  layer_outs = [
      torch.randn((batch_size, 128, 128, 128)),
      torch.randn((batch_size, 256, 64, 64)),
      torch.randn((batch_size, 512, 32, 32)),
  ]
  with torch.no_grad():
    custom_patches, custom_idx = custom_patch_sampler(
        layer_outs, apply_mlp=False)
    # Second pass re-uses the indices returned by the first pass.
    custom_patches_2, custom_idx_2 = custom_patch_sampler(
        layer_outs, custom_idx, apply_mlp=False)

    # The paper sampler receives the custom sampler's indices, flattened per layer.
    paper_patches, paper_idx = paper_patch_sampler(
        feats=layer_outs,
        num_patches=num_patches_per_layer,
        patch_ids=[
            idx.flatten()
            for idx in custom_idx
        ]
    )
    # Bring the paper sampler's indices back to one row per sample for comparison.
    paper_idx = [
        idx.reshape(batch_size, -1)
        for idx in paper_idx
    ]

    # zip() stops at the shortest input, so check the per-layer counts first;
    # otherwise a missing layer would make the loops below pass vacuously.
    assert len(custom_idx) == len(custom_idx_2) == len(paper_idx) == len(layer_outs)
    assert len(custom_patches) == len(custom_patches_2) == len(paper_patches) == len(layer_outs)

    for i, j in zip(custom_idx, custom_idx_2):
      assert torch.equal(i, j)
    for i, j in zip(custom_idx_2, paper_idx):
      assert torch.equal(i, j)

    # Same sampler, same indices -> exact equality; against the paper
    # implementation only closeness (torch.allclose) is required.
    for p, q in zip(custom_patches, custom_patches_2):
      assert torch.equal(p, q)
    for p, q in zip(custom_patches_2, paper_patches):
      assert torch.allclose(p, q)