In [1]:
import torch
import torch.nn as nn

References used for the code below:

- ViG, `gcn_lib.py`: https://github.com/LKLQQ/ViG/blob/dea9ad27c2e5514ec85c3c2a082267d5d427e1e3/src/models/vig/gcn_lib.py#L32
- ViG, `vig.py`: https://github.com/LKLQQ/ViG/blob/dea9ad27c2e5514ec85c3c2a082267d5d427e1e3/src/models/vig/vig.py

- deep_gcns_torch, `torch_edge.py`: https://github.com/lightaime/deep_gcns_torch/blob/master/gcn_lib/dense/torch_edge.py
- deep_gcns_torch, `torch_nn.py`: https://github.com/lightaime/deep_gcns_torch/blob/751382aa2d25e25a2792c133cc99f8cfddae0657/gcn_lib/dense/torch_nn.py#L75

In [2]:
# Random stand-in for one 3-channel 224x224 image, laid out as (B, C, H, W).
image = torch.randn(1,3,224,224)

In [3]:
H = W = 224  # image height and width
C = 3  # number of input channels
P = 16  # side length of each square patch (used as conv kernel size and stride)
N=H*W//(P**2)  # number of non-overlapping patches: 224*224 // 16**2 = 196

emb_dim = 768  # size of the embedding produced for every patch

### Patches

In [4]:
# Patch embedding: a PxP convolution with stride P turns every non-overlapping
# patch into one emb_dim-dimensional vector.
conv_proy = nn.Conv2d(C, emb_dim, kernel_size=P, stride=P)

# Run the projection once and reuse it for both the shape check and the tokens.
patch_grid = conv_proy(image)  # (B, emb_dim, H // P, W // P)
print(patch_grid.shape)

# Flatten the spatial grid and move the channels last -> (B, N, emb_dim).
batch = patch_grid.flatten(2).transpose(1,2)

torch.Size([1, 768, 14, 14])


In [5]:
batch.shape

torch.Size([1, 196, 768])

### KNN pairwise distance

In [6]:
def pairwise_distance(x):
    """
    Compute the pairwise Euclidean (L2) distance between the points of a point cloud.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so the whole
    distance matrix comes from a single matrix product.

    Args:
        x: tensor (batch_size, num_points, num_dims); an unbatched
           (num_points, num_dims) tensor is accepted as well.
    Returns:
        pairwise distance: (batch_size, num_points, num_points), or
        (num_points, num_points) for unbatched input.
    """
    # Transpose the last two dims (not fixed dims 2 and 1) so that batched and
    # unbatched inputs are both handled.
    x_inner = -2*torch.matmul(x, x.transpose(-2, -1))
    x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
    squared_dist = x_square + x_inner + x_square.transpose(-2, -1)
    # Floating-point round-off can push (near-)zero entries slightly below
    # zero, which would make sqrt return NaN; clamp before taking the root.
    return torch.sqrt(torch.clamp(squared_dist, min=0))

In [7]:
def pairwise_distance2(x):
    """
    Compute the pairwise Euclidean (L2) distance between the points of a point cloud.

    Same formula as the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.
    The transposes act on the last two dims only: `x.T` reverses *all* dims of
    a 3-D tensor, which yields wrong distances for batched input.

    Args:
        x: tensor (batch_size, num_points, num_dims); an unbatched
           (num_points, num_dims) tensor is accepted as well.
    Returns:
        pairwise distance: (batch_size, num_points, num_points), or
        (num_points, num_points) for unbatched input.
    """
    x_inner = -2*torch.matmul(x, x.transpose(-2, -1))
    x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
    squared_dist = x_square + x_inner + x_square.transpose(-2, -1)
    # Round-off can make tiny squared distances negative; clamp so sqrt never
    # produces NaN.
    return torch.sqrt(torch.clamp(squared_dist, min=0))

In [8]:
def dense_knn_matrix(x, k=16):
    """Build a dense k-nearest-neighbour edge index from pairwise distances.

    Args:
        x: (batch_size, num_dims, num_points, 1)
        k: int, number of neighbours kept per point
    Returns:
        stacked index tensor (2, batch_size, num_points, k): entry 0 holds the
        neighbour indices, entry 1 the index of the centre point itself.
    """
    with torch.no_grad():
        # (B, C, N, 1) -> (B, N, C): one feature row per point.
        points = x.transpose(2, 1).squeeze(-1)
        num_batches, num_points, _ = points.shape
        # The largest negated distances are the smallest distances.
        neg_dist = -pairwise_distance(points.detach())
        neighbor_idx = torch.topk(neg_dist, k=k).indices
        # Each point's own index, repeated once per neighbour slot.
        own_idx = torch.arange(0, num_points, device=points.device)
        center_idx = own_idx.view(1, -1, 1).expand(num_batches, -1, k)
    return torch.stack((neighbor_idx, center_idx), dim=0)

#### Pairwise distance

In [9]:
torch.arange(0, 192).expand(1, 4, -1).transpose(2, 1).shape

torch.Size([1, 192, 4])

In [10]:
x = batch  # reuse the (1, 196, 768) patch embeddings as the point cloud

In [55]:
torch.sum(torch.mul(x, x), dim=-1, keepdim=True).shape

torch.Size([1, 196, 1])

In [56]:
(batch @ batch.transpose(1,2)).shape

torch.Size([1, 196, 196])

In [57]:
batch.shape

torch.Size([1, 196, 768])

In [58]:
pairwise_distance(batch).shape

torch.Size([1, 196, 196])

In [116]:
# Tiny integer example (2 points, 2 dims) for checking the distance formula by hand.
t = torch.tensor(
[[1,2], [2,4]]
)

In [117]:
t.shape

torch.Size([2, 2])

In [118]:
t

tensor([[1, 2],
        [2, 4]])

In [119]:
pairwise_distance(t)

tensor([[0.0000, 2.2361],
        [2.2361, 0.0000]])

In [123]:
(t@t.transpose(0,1))

tensor([[ 5, 10],
        [10, 20]])

In [124]:
t

tensor([[1, 2],
        [2, 4]])

In [125]:
t@t.transpose(0,1)

tensor([[ 5, 10],
        [10, 20]])

In [127]:
-2*torch.matmul(t, t.transpose(0,1))

tensor([[-10, -20],
        [-20, -40]])

In [128]:
torch.matmul(t, t.transpose(0,1))

tensor([[ 5, 10],
        [10, 20]])

### l2 distance KNN: EDGE INDEX

In [107]:
x.shape

torch.Size([1, 196, 768])

In [197]:
# Random stand-in with the same shape as the patch embeddings: (1, 196, 768).
x = torch.randn(1,196,768)

In [273]:
# Toy input (1 batch element, 3 points, 2 dims), small enough to verify the results by hand.
x= torch.tensor([[[1,3],[1,2], [0,1]]])

In [274]:
def dense_knn_matrix(x, k=2):
    """Build a dense k-nearest-neighbour edge index from pairwise distances.

    Args:
        x: (batch_size, num_points, num_dims) point features, already in
           points-first layout (no transpose/squeeze is applied here).
        k: int, number of neighbours kept per point. Self-loops are not
           excluded: a point's zero distance to itself competes in the top-k.
    Returns:
        edge index (2, batch_size, num_points, k): entry 0 holds the neighbour
        indices (x_j), entry 1 the index of the centre point (x_i).
    """
    with torch.no_grad():
        batch_size, n_points, _ = x.shape
        # topk on the negated distances selects the k smallest distances.
        _, nn_idx = torch.topk(-pairwise_distance(x), k=k)  # neighbour idx, x_j
        # Each point's own index, repeated once per neighbour slot (x_i).
        center_idx = torch.arange(0, n_points, device=x.device).expand(batch_size, k, -1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)

In [275]:
# (2, B, N, k) with the default k=2: [0] = neighbour indices, [1] = centre indices.
edge_index = dense_knn_matrix(x)

In [276]:
edge_index[0,0,1,:] # the neighbours of patch 1 in batch element 0

tensor([1, 0])

In [277]:
edge_index[1].shape

torch.Size([1, 3, 2])

### fetch nodes from edge index

In [278]:
def batched_index_select(x, idx):
    """Fetch per-vertex neighbour features for a given neighbour index.

    Args:
        x (Tensor): input features, shape (B, C, N, 1).
        idx (Tensor): indices into the N vertices of ``x``, shape (B, M, k).
            M normally equals N, but any number of query rows is accepted.

    Returns:
        Tensor: gathered features, shape (B, C, M, k), where
        ``out[b, :, m, j] == x[b, :, idx[b, m, j], 0]``.
    """
    batch_size, num_dims, num_vertices = x.shape[:3]
    num_queries, k = idx.shape[-2:]

    # Shift every sample's indices so they address rows of the flattened
    # (B * N, C) feature table built below.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices
    flat_idx = (idx + idx_base).contiguous().view(-1)

    # (B, C, N, 1) -> (B, N, C, 1) -> (B * N, C): one row per vertex.
    x = x.transpose(2, 1)
    feature = x.contiguous().view(batch_size * num_vertices, -1)[flat_idx, :]
    # Rows come back ordered as (b, m, j); restore that grid, then move channels forward.
    feature = feature.view(batch_size, num_queries, k, num_dims).permute(0, 3, 1, 2).contiguous()
    return feature #B, C, M, K

In [279]:
batched_index_select(x.transpose(1,2).unsqueeze(-1),edge_index[1]).shape

torch.Size([1, 2, 3, 2])

---

In [280]:
# Gather features in (B, C, N, k) layout: xi repeats each centre node k times
# (edge_index[1]), xj holds its k nearest neighbours (edge_index[0]).
xi = batched_index_select(x.transpose(1,2).unsqueeze(-1),edge_index[1])
xj = batched_index_select(x.transpose(1,2).unsqueeze(-1),edge_index[0])

In [283]:
xi

tensor([[[[1, 1],
          [1, 1],
          [0, 0]],

         [[3, 3],
          [2, 2],
          [1, 1]]]])

In [284]:
xj

tensor([[[[1, 1],
          [1, 1],
          [0, 1]],

         [[3, 2],
          [2, 3],
          [1, 2]]]])

In [285]:
# Max over the k neighbours (last dim) of the relative features xj - xi.
torch.max(xj - xi, -1, keepdim=True).values

tensor([[[[0],
          [0],
          [1]],

         [[0],
          [1],
          [1]]]])

In [311]:
def batched_index_select(x, idx):
    """Fetch per-vertex neighbour features, keeping the channels last.

    Variant of the gather above that skips the final permute, so the result
    stays in the order the rows were gathered.

    Args:
        x (Tensor): input features, shape (B, C, N, 1).
        idx (Tensor): indices into the N vertices of ``x``, shape (B, M, k).
            M normally equals N, but any number of query rows is accepted.

    Returns:
        Tensor: gathered features, shape (B, M, k, C), where
        ``out[b, m, j, :] == x[b, :, idx[b, m, j], 0]``.
    """
    batch_size, num_dims, num_vertices = x.shape[:3]
    num_queries, k = idx.shape[-2:]

    # Shift every sample's indices so they address rows of the flattened
    # (B * N, C) feature table built below.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices
    flat_idx = (idx + idx_base).contiguous().view(-1)

    # (B, C, N, 1) -> (B, N, C, 1) -> (B * N, C): one row per vertex.
    x = x.transpose(2, 1)
    feature = x.contiguous().view(batch_size * num_vertices, -1)[flat_idx, :]
    # Rows come back ordered as (b, m, j); restore that grid with channels last.
    feature = feature.view(batch_size, num_queries, k, num_dims).contiguous()
    return feature

In [312]:
# Same gather with the channels-last variant: both results are (B, N, k, C).
# xi repeats each centre node (edge_index[1]); xj holds its neighbours (edge_index[0]).
xi = batched_index_select(x.transpose(1,2).unsqueeze(-1),edge_index[1])
xj = batched_index_select(x.transpose(1,2).unsqueeze(-1),edge_index[0])

In [313]:
xi.shape

torch.Size([1, 3, 2, 2])

In [314]:
x

tensor([[[1, 3],
         [1, 2],
         [0, 1]]])

In [315]:
xj

tensor([[[[1, 3],
          [1, 2]],

         [[1, 2],
          [1, 3]],

         [[0, 1],
          [1, 2]]]])

In [316]:
torch.max(xj,dim=2,keepdim=True).values

tensor([[[[1, 3]],

         [[1, 3]],

         [[1, 2]]]])

In [320]:
# Max over the k neighbours (dim 2 in the channels-last layout), then permute
# to (B, C, N, 1): same result as the channels-first computation.
torch.max(xj - xi, 2, keepdim=True).values.permute(0, 3, 1, 2) #Same result

tensor([[[[0],
          [0],
          [1]],

         [[0],
          [1],
          [1]]]])