In [1]:
import torch
from einops import rearrange

from paper_patch_sampler import PatchSampleF
from patch_sampler import PatchSampler

### Without MLP

In [2]:
# Seed torch's global RNG so the random feature maps drawn with torch.randn
# in this notebook are the same on every Restart & Run All.
random_seed = 0
torch.manual_seed(random_seed)

# Shared configuration for the sampler comparison.
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

# Three random stand-ins for encoder layer outputs.
# NOTE(review): presumably laid out as (batch, channels, height, width) -
# confirm against what the samplers expect.
layer_outs = [
    torch.randn((batch_size, 128, 128, 128)),
    torch.randn((batch_size, 256, 64, 64)),
    torch.randn((batch_size, 512, 32, 32)),
]

In [3]:
# Use Apple's MPS backend when it is available, otherwise fall back to CPU so
# this cell also runs on machines without MPS.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Sampler under test.
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=device
)
# Reference sampler, built with its MLP disabled for the "Without MLP" check.
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=device,
    use_mlp=False
)

In [4]:
iters = 100

# Compare the custom sampler with the reference sampler (MLP disabled) on
# freshly drawn random feature maps at every iteration.
for step in range(iters):
  with torch.no_grad():
    layer_outs = [
        torch.randn((batch_size, 128, 128, 128)),
        torch.randn((batch_size, 256, 64, 64)),
        torch.randn((batch_size, 512, 32, 32)),
    ]
    # First pass: the custom sampler returns patches and the per-layer
    # indices it used.
    cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=False)
    # Second pass: feed those indices back in to re-select the same patches.
    cust_patches_reselected, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_idx,
        apply_mlp=False
    )

    # Reference sampler, driven by the same indices (flattened per layer).
    paper_patches, paper_idx = paper_patch_sampler(
        layer_outs,
        patch_ids=[idx.flatten() for idx in cust_idx],
        num_patches=num_patches_per_layer
    )

    # zip() below would silently drop layers if the lengths disagreed.
    assert len(cust_patches) == len(cust_patches_reselected) == len(paper_patches), (
        'samplers returned a different number of layers'
    )

    for layer_idx, (p_cust, p_cust_2, p_paper) in enumerate(
        zip(cust_patches, cust_patches_reselected, paper_patches)
    ):
      # Merge the batch and patch axes so rows line up with p_paper, which is
      # compared directly against the flattened custom output.
      p_cust = rearrange(p_cust, 'b n d -> (b n) d')
      p_cust_2 = rearrange(p_cust_2, 'b n d -> (b n) d')

      # Show diagnostics once instead of flooding the output on every one of
      # the `iters` iterations.
      if step == 0:
        print(p_cust.shape)
        print(p_paper.shape)
        print(p_cust)
        print(p_cust_2)
        print(p_paper)
        print(torch.linalg.norm(p_cust, dim=-1))
        print(torch.linalg.norm(p_paper, dim=-1))
        print('--')

      # Re-selecting with the returned indices must reproduce the first pass.
      assert torch.allclose(p_cust.cpu(), p_cust_2.cpu()), (
          f'iteration {step}, layer {layer_idx}: re-selected patches differ'
      )
      # The custom sampler must agree with the reference implementation.
      assert torch.allclose(p_cust.cpu(), p_paper.cpu()), (
          f'iteration {step}, layer {layer_idx}: custom and paper patches differ'
      )

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)
  print(torch.linalg.norm(p_cust, dim=-1))


torch.Size([64, 128])
torch.Size([64, 128])
tensor([[ 0.0796,  0.1434, -0.1073,  ...,  0.0437, -0.0064,  0.1051],
        [ 0.1018,  0.0298,  0.1218,  ...,  0.0919, -0.0857, -0.0961],
        [-0.0949, -0.0416, -0.1090,  ...,  0.0374, -0.0377,  0.0285],
        ...,
        [-0.0015, -0.0062, -0.0474,  ..., -0.1694,  0.0164, -0.0524],
        [-0.0580,  0.1656, -0.0069,  ..., -0.0934, -0.1507, -0.1073],
        [ 0.0707, -0.0605, -0.0769,  ...,  0.1190, -0.0379, -0.0299]],
       device='mps:0')
tensor([[ 0.0796,  0.1434, -0.1073,  ...,  0.0437, -0.0064,  0.1051],
        [ 0.1018,  0.0298,  0.1218,  ...,  0.0919, -0.0857, -0.0961],
        [-0.0949, -0.0416, -0.1090,  ...,  0.0374, -0.0377,  0.0285],
        ...,
        [-0.0015, -0.0062, -0.0474,  ..., -0.1694,  0.0164, -0.0524],
        [-0.0580,  0.1656, -0.0069,  ..., -0.0934, -0.1507, -0.1073],
        [ 0.0707, -0.0605, -0.0769,  ...,  0.1190, -0.0379, -0.0299]],
       device='mps:0')
tensor([[ 0.0796,  0.1434, -0.1073,  ..., 

### With MLP

In [5]:
# Configuration comes first so the samplers below are actually built from the
# values written in this cell.
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

# Sampler under test, on CPU for the "With MLP" comparison.
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=torch.device('cpu')
)
# Reference sampler with its MLP enabled.
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=torch.device('cpu'),
    use_mlp=True
)

layer_outs = [
    torch.randn((batch_size, 128, 128, 128)),
    torch.randn((batch_size, 256, 64, 64)),
    torch.randn((batch_size, 512, 32, 32))
]

# Run the custom sampler once with the MLP enabled before reading its
# `mlp_{i}` attributes.
# NOTE(review): presumably the MLPs are created lazily on the first call -
# confirm in patch_sampler.
cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)

# Attach the custom sampler's per-layer MLP modules to the reference sampler
# so both project patches with the same module objects.
for layer_idx in range(len(cust_patches)):
  mlp_name = f'mlp_{layer_idx}'
  setattr(paper_patch_sampler, mlp_name, getattr(custom_patch_sampler, mlp_name))

# Flag the reference sampler's MLPs as initialised.
# NOTE(review): presumably this stops it from building its own MLPs on the
# first call - confirm in paper_patch_sampler.
paper_patch_sampler.mlp_init = True

In [6]:
iters = 100

# Compare the custom sampler with the reference sampler (MLP enabled) on the
# `layer_outs` built in the setup cell.
for step in range(iters):
  with torch.no_grad():
    # First pass: patches plus the per-layer indices that were used.
    cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)
    # Second pass: re-select with the indices returned above.
    cust_patches_reselected, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_idx,
        apply_mlp=True
    )

    # Reference sampler, driven by the same indices (flattened per layer).
    paper_patches, paper_idx = paper_patch_sampler(
        layer_outs,
        patch_ids=[idx.flatten() for idx in cust_idx]
    )

    # zip() below would silently drop layers if the lengths disagreed.
    assert len(cust_patches) == len(cust_patches_reselected) == len(paper_patches), (
        'samplers returned a different number of layers'
    )

    for layer_idx, (p_cust, p_cust_2, p_paper) in enumerate(
        zip(cust_patches, cust_patches_reselected, paper_patches)
    ):
      # Merge the batch and patch axes so rows line up with p_paper.
      p_cust = rearrange(p_cust, 'b n d -> (b n) d')
      p_cust_2 = rearrange(p_cust_2, 'b n d -> (b n) d')

      # Show diagnostics once instead of on every one of the `iters` iterations.
      if step == 0:
        print(p_cust)
        print(p_cust_2)
        print(p_paper)
        print(torch.linalg.norm(p_cust, dim=-1))
        print(torch.linalg.norm(p_paper, dim=-1))
        print('--')

      # Re-selecting with the returned indices must reproduce the first pass.
      assert torch.allclose(p_cust.cpu(), p_cust_2.cpu()), (
          f'iteration {step}, layer {layer_idx}: re-selected patches differ'
      )
      # The custom sampler must agree with the reference implementation.
      assert torch.allclose(p_cust.cpu(), p_paper.cpu()), (
          f'iteration {step}, layer {layer_idx}: custom and paper patches differ'
      )

tensor([[-0.0276, -0.0110,  0.0232,  ..., -0.0286, -0.1086, -0.0086],
        [-0.0637, -0.0783, -0.0016,  ...,  0.0504, -0.0233,  0.0763],
        [-0.0973,  0.0178, -0.0174,  ...,  0.1148, -0.0150,  0.0210],
        ...,
        [-0.0161, -0.0548,  0.0618,  ...,  0.0233,  0.0315,  0.0117],
        [-0.0771, -0.0804,  0.0199,  ...,  0.0250,  0.0721,  0.0872],
        [-0.0191,  0.0709,  0.0695,  ...,  0.0348, -0.0385,  0.0470]])
tensor([[-0.0276, -0.0110,  0.0232,  ..., -0.0286, -0.1086, -0.0086],
        [-0.0637, -0.0783, -0.0016,  ...,  0.0504, -0.0233,  0.0763],
        [-0.0973,  0.0178, -0.0174,  ...,  0.1148, -0.0150,  0.0210],
        ...,
        [-0.0161, -0.0548,  0.0618,  ...,  0.0233,  0.0315,  0.0117],
        [-0.0771, -0.0804,  0.0199,  ...,  0.0250,  0.0721,  0.0872],
        [-0.0191,  0.0709,  0.0695,  ...,  0.0348, -0.0385,  0.0470]])
tensor([[-0.0276, -0.0110,  0.0232,  ..., -0.0286, -0.1086, -0.0086],
        [-0.0637, -0.0783, -0.0016,  ...,  0.0504, -0.0233,  0