In [1]:
import torch
from einops import rearrange

from paper_patch_sampler import PatchSampleF
from patch_sampler import PatchSampler

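### Without MLP

Check that `PatchSampler` and the reference `PatchSampleF` return identical patches when both are given the same patch indices and no projection MLP is applied.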

In [2]:
# Shared configuration for the sampler comparisons below.
SEED = 0
# Seed torch's global RNG so the random feature maps (and any later torch-RNG draws) are
# reproducible under Restart & Run All.
torch.manual_seed(SEED)

batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

# Synthetic per-layer feature maps: channels double while the last two dims halve.
# Presumably (batch, channels, height, width) layout as the samplers expect — TODO confirm.
layer_outs = [
    torch.randn((batch_size, 128, 128, 128)),
    torch.randn((batch_size, 256, 64, 64)),
    torch.randn((batch_size, 512, 32, 32)),
]

In [3]:
# Prefer Apple's MPS backend when present, otherwise fall back to CPU, so the comparison
# also runs on machines without MPS instead of erroring on an unavailable device.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=device
)
# Reference sampler; use_mlp=False so the raw sampled features are compared.
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=device,
    use_mlp=False
)

In [4]:
from einops import rearrange

iters = 100

for it in range(iters):
  with torch.no_grad():
    # Sample fresh patch indices, then re-run with those same indices: the custom
    # sampler must reproduce its own patches when given explicit indices.
    cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=False)
    cust_patches_reselected, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_idx,
        apply_mlp=False
    )

    # Feed the same indices to the reference implementation so outputs are comparable.
    paper_patches, paper_idx = paper_patch_sampler(
        layer_outs,
        patch_ids=cust_idx
    )

    for layer_idx, (p_cust, p_cust_2, p_paper) in enumerate(
        zip(cust_patches, cust_patches_reselected, paper_patches)
    ):
      if p_cust.shape != p_paper.shape:
        # Reference output is 3-D (b, n, d) here; flatten the first two axes to match.
        p_paper = rearrange(p_paper, 'b n d -> (b n) d')
      if it == 0:
        # Show one sample per layer rather than flooding the output every iteration.
        print(p_cust.shape)
        print(p_paper.shape)
        print(p_cust)
        print(p_cust_2)
        print(p_paper)
        print('--')
      # torch.allclose broadcasts, so compare shapes explicitly or a mismatch could pass.
      assert p_cust.shape == p_paper.shape, (
          f'iter {it}, layer {layer_idx}: shape {tuple(p_cust.shape)} != {tuple(p_paper.shape)}'
      )
      assert torch.allclose(p_cust.cpu(), p_cust_2.cpu()), (
          f'iter {it}, layer {layer_idx}: reselecting with the same indices changed the patches'
      )
      assert torch.allclose(p_cust.cpu(), p_paper.cpu()), (
          f'iter {it}, layer {layer_idx}: custom and reference patches differ'
      )


torch.Size([1, 64, 128])
torch.Size([1, 64, 128])
tensor([[[ 0.0270, -0.3414, -0.1048,  ...,  0.1136,  0.1116, -0.0014],
         [-0.0233,  0.1204, -0.0871,  ..., -0.0417,  0.0231, -0.1255],
         [ 0.1173,  0.1513, -0.0027,  ...,  0.2310,  0.1249,  0.0330],
         ...,
         [ 0.1136, -0.0540, -0.0473,  ...,  0.0663,  0.1141,  0.1562],
         [-0.1289, -0.0229,  0.0590,  ..., -0.1199, -0.0579, -0.1357],
         [ 0.0484, -0.0878,  0.0049,  ..., -0.1418, -0.1040, -0.1231]]],
       device='mps:0')
tensor([[[ 0.0270, -0.3414, -0.1048,  ...,  0.1136,  0.1116, -0.0014],
         [-0.0233,  0.1204, -0.0871,  ..., -0.0417,  0.0231, -0.1255],
         [ 0.1173,  0.1513, -0.0027,  ...,  0.2310,  0.1249,  0.0330],
         ...,
         [ 0.1136, -0.0540, -0.0473,  ...,  0.0663,  0.1141,  0.1562],
         [-0.1289, -0.0229,  0.0590,  ..., -0.1199, -0.0579, -0.1357],
         [ 0.0484, -0.0878,  0.0049,  ..., -0.1418, -0.1040, -0.1231]]],
       device='mps:0')
tensor([[[ 0.0270, -

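### With MLP

Repeat the comparison with the projection MLPs enabled. The reference sampler is given the custom sampler's MLP modules, so both apply identical weights.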

In [5]:
# Configuration is assigned before it is consumed by the sampler constructors below, so
# editing these values actually affects the samplers.
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=torch.device('cpu')
)
# Reference sampler with its projection MLPs enabled.
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=torch.device('cpu'),
    use_mlp=True
)

layer_outs = [
    torch.randn((batch_size, 128, 128, 128)),
    torch.randn((batch_size, 256, 64, 64)),
    torch.randn((batch_size, 512, 32, 32))
]

# One forward pass with apply_mlp=True; presumably this lazily creates the custom
# sampler's per-layer `mlp_{i}` modules that are read just below — TODO confirm.
cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)

# Give the reference sampler the *same* MLP module objects (shared, not copied) so both
# implementations project with identical weights.
for layer_idx in range(len(cust_patches)):
  setattr(
      paper_patch_sampler,
      f'mlp_{layer_idx}',
      getattr(
          custom_patch_sampler,
          f'mlp_{layer_idx}',
      )
  )
# NOTE(review): presumably prevents PatchSampleF from creating its own MLPs on first call — verify.
paper_patch_sampler.mlp_init = True

In [6]:
iters = 100

for it in range(iters):
  with torch.no_grad():
    # Sample fresh patch indices (MLP applied), then re-run with those same indices.
    cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)
    cust_patches_reselected, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_idx,
        apply_mlp=True
    )

    # The reference sampler was handed the custom MLP modules, so the same indices
    # should yield the same projected patches.
    paper_patches, paper_idx = paper_patch_sampler(
        layer_outs,
        patch_ids=cust_idx
    )

    for layer_idx, (p_cust, p_cust_2, p_paper) in enumerate(
        zip(cust_patches, cust_patches_reselected, paper_patches)
    ):
      if p_cust.shape != p_paper.shape:
        # Reference output is 3-D (b, n, d) here; flatten the first two axes to match.
        p_paper = rearrange(p_paper, 'b n d -> (b n) d')
      if it == 0:
        # Show one sample per layer rather than flooding the output every iteration.
        print(p_cust.shape)
        print(p_paper.shape)
        print(p_cust)
        print(p_cust_2)
        print(p_paper)
        print('--')
      # torch.allclose broadcasts, so compare shapes explicitly or a mismatch could pass.
      assert p_cust.shape == p_paper.shape, (
          f'iter {it}, layer {layer_idx}: shape {tuple(p_cust.shape)} != {tuple(p_paper.shape)}'
      )
      assert torch.allclose(p_cust.cpu(), p_cust_2.cpu()), (
          f'iter {it}, layer {layer_idx}: reselecting with the same indices changed the patches'
      )
      assert torch.allclose(p_cust.cpu(), p_paper.cpu()), (
          f'iter {it}, layer {layer_idx}: custom and reference patches differ'
      )

torch.Size([1, 64, 256])
torch.Size([1, 64, 256])
tensor([[[ 0.0916, -0.0703,  0.2186,  ..., -0.0133,  0.2434, -0.1071],
         [ 0.3118,  0.0463,  0.0259,  ..., -0.1222,  0.1357, -0.0802],
         [ 0.0211,  0.2662,  0.1254,  ..., -0.2056, -0.0801,  0.0724],
         ...,
         [-0.0245,  0.0285,  0.0482,  ...,  0.0569, -0.0477, -0.1285],
         [-0.0717, -0.0592, -0.0422,  ..., -0.2130,  0.0051, -0.1489],
         [-0.2511,  0.1555,  0.0959,  ...,  0.1343, -0.0511, -0.0779]]])
tensor([[[ 0.0916, -0.0703,  0.2186,  ..., -0.0133,  0.2434, -0.1071],
         [ 0.3118,  0.0463,  0.0259,  ..., -0.1222,  0.1357, -0.0802],
         [ 0.0211,  0.2662,  0.1254,  ..., -0.2056, -0.0801,  0.0724],
         ...,
         [-0.0245,  0.0285,  0.0482,  ...,  0.0569, -0.0477, -0.1285],
         [-0.0717, -0.0592, -0.0422,  ..., -0.2130,  0.0051, -0.1489],
         [-0.2511,  0.1555,  0.0959,  ...,  0.1343, -0.0511, -0.0779]]])
tensor([[[ 0.0916, -0.0703,  0.2186,  ..., -0.0133,  0.2434, -0.10