In [1]:
from patch_sampler import PatchSampler
from psample import PatchSampleF
import torch

### Without MLP

In [2]:
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

layer_outs = [
    torch.randn((batch_size, 128, 128, 128)),
    torch.randn((batch_size, 256, 64, 64)),
    torch.randn((batch_size, 512, 32, 32)),
]

In [3]:
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=torch.device('mps')
)
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=torch.device('mps'),
    use_mlp=False
)

In [4]:
from einops import rearrange
iters = 100

for _ in range(iters):
  with torch.no_grad():
    cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=False)
    cust_patches_reselected, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_idx,
        apply_mlp=False
    )

    paper_patches, paper_idx = paper_patch_sampler(
        layer_outs,
        patch_ids=cust_idx
    )

    for p_cust,  p_cust_2, p_paper in zip(cust_patches, cust_patches_reselected, paper_patches):
      if p_cust.shape != p_paper.shape:
        p_paper = rearrange(p_paper, 'b n d -> (b n) d')
      print(p_cust.shape)
      print(p_paper.shape)
      print(p_cust)
      print(p_cust_2)
      print(p_paper)
      print('--')
      assert (torch.allclose(p_cust.cpu(), p_paper.cpu()))

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


torch.Size([1, 64, 128])
torch.Size([1, 64, 128])
tensor([[[ 0.0675,  0.2414, -0.0128,  ...,  0.0399,  0.0090,  0.0085],
         [-0.1192, -0.2386, -0.1461,  ...,  0.0713, -0.0326,  0.1385],
         [-0.1379, -0.0914,  0.0511,  ...,  0.1046, -0.0264, -0.1339],
         ...,
         [-0.0964,  0.0749, -0.0101,  ...,  0.0673, -0.0614,  0.1518],
         [-0.0304, -0.0056,  0.0061,  ..., -0.0957,  0.0130,  0.0378],
         [ 0.0369,  0.0005, -0.2460,  ...,  0.0648,  0.0296, -0.1856]]],
       device='mps:0')
tensor([[[ 0.0675,  0.2414, -0.0128,  ...,  0.0399,  0.0090,  0.0085],
         [-0.1192, -0.2386, -0.1461,  ...,  0.0713, -0.0326,  0.1385],
         [-0.1379, -0.0914,  0.0511,  ...,  0.1046, -0.0264, -0.1339],
         ...,
         [-0.0964,  0.0749, -0.0101,  ...,  0.0673, -0.0614,  0.1518],
         [-0.0304, -0.0056,  0.0061,  ..., -0.0957,  0.0130,  0.0378],
         [ 0.0369,  0.0005, -0.2460,  ...,  0.0648,  0.0296, -0.1856]]],
       device='mps:0')
tensor([[[ 0.0675,  

### With MLP

In [5]:
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=torch.device('cpu')
)
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=torch.device('cpu'),
    use_mlp=True
)
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

layer_outs = [
    torch.randn((batch_size, 128, 128, 128)),
    torch.randn((batch_size, 256, 64, 64)),
    torch.randn((batch_size, 512, 32, 32))
]
cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)
for layer_idx in range(len(cust_patches)):
  setattr(
      paper_patch_sampler,
      f'mlp_{layer_idx}',
      getattr(
          custom_patch_sampler,
          f'mlp_{layer_idx}',
      )
  )
paper_patch_sampler.mlp_init = True

In [6]:
iters = 100
for _ in range(iters):
  with torch.no_grad():
    cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)
    cust_patches_reselected, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_idx,
        apply_mlp=True
    )

    paper_patches, paper_idx = paper_patch_sampler(
        layer_outs,
        patch_ids=cust_idx
    )

    for p_cust,  p_cust_2, p_paper in zip(cust_patches, cust_patches_reselected, paper_patches):
      if p_cust.shape != p_paper.shape:
        p_paper = rearrange(p_paper, 'b n d -> (b n) d')
      print(p_cust.shape)
      print(p_paper.shape)
      print(p_cust)
      print(p_cust_2)
      print(p_paper)
      print('--')
      assert (torch.allclose(p_cust.cpu(), p_paper.cpu()))

torch.Size([1, 64, 256])
torch.Size([1, 64, 256])
tensor([[[ 0.1255,  0.3112, -0.0393,  ...,  0.2481,  0.1130, -0.0625],
         [-0.0342,  0.1184,  0.1945,  ...,  0.1095, -0.0822, -0.0399],
         [-0.0660,  0.0413, -0.1335,  ...,  0.0986,  0.1469, -0.3254],
         ...,
         [ 0.2583,  0.1981, -0.0987,  ..., -0.1915, -0.0083,  0.3108],
         [-0.0103,  0.0261, -0.0055,  ...,  0.1353,  0.0624, -0.1268],
         [-0.1417,  0.1236,  0.1100,  ...,  0.2465,  0.0427,  0.0736]]])
tensor([[[ 0.1255,  0.3112, -0.0393,  ...,  0.2481,  0.1130, -0.0625],
         [-0.0342,  0.1184,  0.1945,  ...,  0.1095, -0.0822, -0.0399],
         [-0.0660,  0.0413, -0.1335,  ...,  0.0986,  0.1469, -0.3254],
         ...,
         [ 0.2583,  0.1981, -0.0987,  ..., -0.1915, -0.0083,  0.3108],
         [-0.0103,  0.0261, -0.0055,  ...,  0.1353,  0.0624, -0.1268],
         [-0.1417,  0.1236,  0.1100,  ...,  0.2465,  0.0427,  0.0736]]])
tensor([[[ 0.1255,  0.3112, -0.0393,  ...,  0.2481,  0.1130, -0.06