In [1]:
from patch_sampler import PatchSampler
from paper_patch_sampler import PatchSampleF
import torch

### Without MLP

In [2]:
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

layer_outs = [
    torch.randn((batch_size, 128, 128, 128)),
    torch.randn((batch_size, 256, 64, 64)),
    torch.randn((batch_size, 512, 32, 32)),
]

In [3]:
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=torch.device('mps')
)
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=torch.device('mps'),
    use_mlp=False
)

In [4]:
from einops import rearrange
iters = 100

for _ in range(iters):
  with torch.no_grad():
    layer_outs = [
        torch.randn((batch_size, 128, 128, 128)),
        torch.randn((batch_size, 256, 64, 64)),
        torch.randn((batch_size, 512, 32, 32)),
    ]
    cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=False)
    cust_patches_reselected, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_idx,
        apply_mlp=False
    )

    paper_patches, paper_idx = paper_patch_sampler(
        layer_outs,
        patch_ids=[
            idx.flatten()
            for idx in cust_idx
        ],
        num_patches=num_patches_per_layer
    )

    for p_cust,  p_cust_2, p_paper in zip(cust_patches, cust_patches_reselected, paper_patches):
      p_cust = rearrange(p_cust, 'b n d -> (b n) d')
      p_cust_2 = rearrange(p_cust_2, 'b n d -> (b n) d')
      print(p_cust.shape)
      print(p_paper.shape)
      print(p_cust)
      print(p_cust_2)
      print(p_paper)
      print(torch.linalg.norm(p_cust, dim=-1))
      print(torch.linalg.norm(p_paper, dim=-1))
      print('--')
      assert (torch.allclose(p_cust.cpu(), p_paper.cpu()))

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)
  print(torch.linalg.norm(p_cust, dim=-1))


torch.Size([64, 128])
torch.Size([64, 128])
tensor([[-0.0056,  0.0631, -0.0447,  ...,  0.0446, -0.0151, -0.1973],
        [-0.1099, -0.0311, -0.0499,  ..., -0.2154, -0.0131, -0.0282],
        [ 0.0741,  0.0007, -0.1228,  ..., -0.0893,  0.0356, -0.0698],
        ...,
        [-0.0442, -0.0687, -0.0746,  ...,  0.0802,  0.0865, -0.0413],
        [ 0.0321,  0.0288,  0.1202,  ..., -0.0682, -0.0215, -0.1044],
        [-0.0397, -0.0547, -0.1687,  ..., -0.0590, -0.1688,  0.0219]],
       device='mps:0')
tensor([[-0.0056,  0.0631, -0.0447,  ...,  0.0446, -0.0151, -0.1973],
        [-0.1099, -0.0311, -0.0499,  ..., -0.2154, -0.0131, -0.0282],
        [ 0.0741,  0.0007, -0.1228,  ..., -0.0893,  0.0356, -0.0698],
        ...,
        [-0.0442, -0.0687, -0.0746,  ...,  0.0802,  0.0865, -0.0413],
        [ 0.0321,  0.0288,  0.1202,  ..., -0.0682, -0.0215, -0.1044],
        [-0.0397, -0.0547, -0.1687,  ..., -0.0590, -0.1688,  0.0219]],
       device='mps:0')
tensor([[-0.0056,  0.0631, -0.0447,  ..., 

### With MLP

In [5]:
custom_patch_sampler = PatchSampler(
    patch_embedding_dim=patch_embedding_dim,
    num_patches_per_layer=num_patches_per_layer,
    device=torch.device('cpu')
)
paper_patch_sampler = PatchSampleF(
    nc=patch_embedding_dim,
    device=torch.device('cpu'),
    use_mlp=True
)
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 64

layer_outs = [
    torch.randn((batch_size, 128, 128, 128)),
    torch.randn((batch_size, 256, 64, 64)),
    torch.randn((batch_size, 512, 32, 32))
]
cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)
for layer_idx in range(len(cust_patches)):
  setattr(
      paper_patch_sampler,
      f'mlp_{layer_idx}',
      getattr(
          custom_patch_sampler,
          f'mlp_{layer_idx}',
      )
  )
paper_patch_sampler.mlp_init = True

In [6]:
iters = 100
for _ in range(iters):
  with torch.no_grad():
    cust_patches, cust_idx = custom_patch_sampler(layer_outs, apply_mlp=True)
    cust_patches_reselected, _ = custom_patch_sampler(
        layer_outs,
        patch_idx_per_layer=cust_idx,
        apply_mlp=True
    )

    paper_patches, paper_idx = paper_patch_sampler(
        layer_outs,
        patch_ids=[
            idx.flatten()
            for idx in cust_idx
        ])

    for p_cust,  p_cust_2, p_paper in zip(cust_patches, cust_patches_reselected, paper_patches):
      p_cust = rearrange(p_cust, 'b n d -> (b n) d')
      p_cust_2 = rearrange(p_cust_2, 'b n d -> (b n) d')
      print(p_cust)
      print(p_cust_2)
      print(p_paper)
      print(torch.linalg.norm(p_cust, dim=-1))
      print(torch.linalg.norm(p_paper, dim=-1))
      print('--')
      assert (torch.allclose(p_cust.cpu(), p_paper.cpu()))

tensor([[ 0.1261,  0.0618,  0.0794,  ..., -0.0329, -0.0652,  0.0606],
        [ 0.0513,  0.0795,  0.1131,  ..., -0.0296, -0.0085, -0.0119],
        [ 0.0827,  0.0007,  0.1203,  ...,  0.1169, -0.0418, -0.0250],
        ...,
        [-0.0057,  0.0844,  0.1317,  ...,  0.0335,  0.0059, -0.0412],
        [-0.0037, -0.0071,  0.1378,  ...,  0.0343, -0.0238, -0.0635],
        [ 0.0120,  0.1050,  0.0892,  ...,  0.0484, -0.0064,  0.0303]])
tensor([[ 0.1261,  0.0618,  0.0794,  ..., -0.0329, -0.0652,  0.0606],
        [ 0.0513,  0.0795,  0.1131,  ..., -0.0296, -0.0085, -0.0119],
        [ 0.0827,  0.0007,  0.1203,  ...,  0.1169, -0.0418, -0.0250],
        ...,
        [-0.0057,  0.0844,  0.1317,  ...,  0.0335,  0.0059, -0.0412],
        [-0.0037, -0.0071,  0.1378,  ...,  0.0343, -0.0238, -0.0635],
        [ 0.0120,  0.1050,  0.0892,  ...,  0.0484, -0.0064,  0.0303]])
tensor([[ 0.1261,  0.0618,  0.0794,  ..., -0.0329, -0.0652,  0.0606],
        [ 0.0513,  0.0795,  0.1131,  ..., -0.0296, -0.0085, -0